summaryrefslogtreecommitdiff
path: root/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs
blob: eeda4de83a37bd4718e696f5cf0c48fab1bd7eb3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
{-# LANGUAGE FlexibleContexts #-}
module Control.Concurrent.Lifted.Instrument
    ( module Control.Concurrent.Lifted
    , forkLabeled
    , forkIO
    , forkOS
    , fork
    , labelThread
    , threadsInformation
    , PerThread(..)
    ) where

import qualified Control.Concurrent.Lifted as Raw
import Control.Concurrent.Lifted hiding (fork,forkOS)
import Control.Exception (fromException)
import Control.Monad.Trans.Control
import System.IO.Unsafe
import qualified Data.Map.Strict  as Map
import Control.Exception.Lifted
import Control.Monad.Base
import qualified GHC.Conc as GHC
import Data.Time()
import Data.Time.Clock
import DPut
import DebugTag


-- | What the registry remembers about one instrumented thread.
data PerThread = PerThread
    { lbl       :: String
      -- ^ Label given via 'labelThread'; empty until one is set.  When the
      -- thread dies from an exception other than 'ThreadKilled', this is
      -- overwritten with the exception text followed by the previous label
      -- in parentheses.
    , startTime :: UTCTime
      -- ^ Time ('getCurrentTime') at which the thread registered itself.
    }
    deriving (Eq, Ord, Show)

-- | Mutable state shared by every instrumented thread, kept in 'globals'.
data GlobalState = GlobalState
    { threads         :: !(Map.Map ThreadId PerThread)
      -- ^ Registry of instrumented threads, keyed by thread id.  Strict so
      -- that storing a new state also forces the map.
    , reportException :: String -> IO ()
      -- ^ Sink that receives the message describing a crashed thread.
    }

-- | The single, process-wide 'GlobalState', guarded by an 'MVar'.
--
-- It is created with 'unsafePerformIO' so that it can live at top level.
-- The NOINLINE pragma below is what makes this sound: if the definition were
-- inlined, separate use sites could each allocate their own fresh 'MVar'.
-- Crash messages are initially handed to @dput XMisc@ (project DPut module).
globals :: MVar GlobalState
globals = unsafePerformIO (newMVar initialState)
  where
    initialState = GlobalState
        { threads         = Map.empty
        , reportException = dput XMisc
        }
{-# NOINLINE globals #-}

-- | Fork an instrumented thread (see 'forkIO') and give it a label.  The
-- label is applied both as the GHC thread label and in the registry returned
-- by 'threadsInformation'.
--
-- The new thread labels itself before running @action@.  The child creates
-- its registry entry only when it starts running, so a label applied only
-- from the parent can run before that entry exists.  'Map.adjust' would then
-- find nothing, the label would be missing from the registry, and a crash
-- report for the thread would show an empty label.
--
-- The parent also labels the thread, so the GHC-level label is already set
-- when this function returns.
forkLabeled :: String -> IO () -> IO ThreadId
forkLabeled label action = do
    t <- forkIO $ do
        -- Runs after the instrumented wrapper has registered this thread.
        self <- myThreadId
        labelThread self label
        action
    labelThread t label
    return t
{-# INLINE forkLabeled #-}

-- | Instrumented replacement for 'GHC.forkIO', specialised to 'IO'.  The
-- new thread is entered into the registry read by 'threadsInformation'
-- while it runs.
forkIO :: IO () -> IO ThreadId
forkIO action = instrumented GHC.forkIO action
{-# INLINE forkIO #-}

-- | Like lifted-base's 'Raw.forkOS', but the new thread is entered into the
-- registry read by 'threadsInformation' while it runs.
forkOS :: MonadBaseControl IO m => m () -> m ThreadId
forkOS action = instrumented Raw.forkOS action
{-# INLINE forkOS #-}

-- | Like lifted-base's 'Raw.fork', but the new thread is entered into the
-- registry read by 'threadsInformation' while it runs.
fork :: MonadBaseControl IO m => m () -> m ThreadId
fork action = instrumented Raw.fork action
{-# INLINE fork #-}

-- | Wrap a raw fork primitive so that the forked thread is recorded in the
-- registry read by 'threadsInformation'.
--
-- Before running @action@, the child registers itself with an empty label
-- and the current time.  Then:
--
-- * if @action@ returns normally, or dies from 'ThreadKilled', the entry is
--   removed;
-- * if it dies from any other exception, the entry is kept.  Its label is
--   replaced by the exception text followed by the old label in parentheses,
--   that text is passed to the 'reportException' sink, and the exception is
--   rethrown.
--
-- Returns whatever 'ThreadId' @rawFork@ returns.
instrumented :: MonadBaseControl IO m =>
                (m () -> m ThreadId) -> m () -> m ThreadId
instrumented rawFork action = rawFork $ do
    tid <- myThreadId
    tm <- liftBase getCurrentTime
    let body = do
            catch action $ \e -> case fromException e of
                Just ThreadKilled -> return ()
                -- Every other exception, including other AsyncException
                -- values (StackOverflow, HeapOverflow, UserInterrupt),
                -- counts as a crash; a narrower pattern would fail to match
                -- them.
                _                 -> liftBase $ recordCrash tid tm e
            -- Remove the thread only if it terminated normally or was
            -- killed; crashed threads stay listed with their error text.
            modifyThreads $ Map.delete tid
    -- bracket_ with a no-op release is used only for its masked acquire, so
    -- the registration runs with asynchronous exceptions masked.
    bracket_ (modifyThreads $ Map.insert tid (PerThread "" tm))
             (return ())
             body
  where
    -- Record a crash in the registry, report it, then rethrow it.  The
    -- registry is updated with modifyMVar, so 'globals' is restored if
    -- building the new state throws.  The report is made after the lock is
    -- released, so a failing or slow reporter cannot leave 'globals' empty
    -- or block other threads.
    recordCrash :: ThreadId -> UTCTime -> SomeException -> IO ()
    recordCrash tid tm e = do
        report <- modifyMVar globals $ \g -> do
            let l = concat [ show e
                           , " ("
                           , maybe "" lbl $ Map.lookup tid (threads g)
                           , ")"
                           ]
                g' = g { threads = Map.insert tid (PerThread l tm) (threads g) }
            g' `seq` return (g', reportException g l)
        report
        throwIO e

-- | Set a thread's GHC label and, if the thread is present in the registry,
-- its registry label too.
--
-- 'Map.adjust' leaves the registry unchanged for a thread that has not
-- registered yet or has already been removed.  Instrumented threads register
-- themselves from inside the new thread, so a call made immediately after
-- forking may update only the GHC label.
labelThread :: ThreadId -> String -> IO ()
labelThread tid s = GHC.labelThread tid s >> modifyThreads (Map.adjust rename tid)
  where
    rename pt = pt { lbl = s }
{-# INLINE labelThread #-}

-- | Snapshot of every thread currently in the registry, in ascending
-- 'ThreadId' order.  This includes threads that crashed, whose entries are
-- kept.
threadsInformation :: IO [(ThreadId,PerThread)]
threadsInformation = Map.toList . threads <$> readMVar globals


-- | Apply @f@ to the thread registry held in 'globals'.
--
-- Uses 'modifyMVar_' rather than a bare 'takeMVar'/'putMVar' pair.  If an
-- exception arrived between the take and the put (asynchronous, or thrown
-- while forcing @f@), 'globals' would be left empty.  Every later fork,
-- relabel or query would then block forever.  'modifyMVar_' restores the
-- old value in that case.
modifyThreads :: MonadBase IO m
              => (Map.Map ThreadId PerThread -> Map.Map ThreadId PerThread)
              -> m ()
modifyThreads f = liftBase $ modifyMVar_ globals $ \st ->
    -- Force the new state before storing it.  The strict 'threads' field
    -- means this also forces the updated map, so no thunk chain builds up.
    return $! st { threads = f (threads st) }