summaryrefslogtreecommitdiff
path: root/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs
blob: 0b6082136bc70c468cfc75d11a75782b2bdee4e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MagicHash        #-}
{-# LANGUAGE RankNTypes       #-}
{-# LANGUAGE UnliftedFFITypes #-}
module Control.Concurrent.Lifted.Instrument
    ( module Control.Concurrent.Lifted
    , forkLabeled
    , forkIO
    , forkOS
    , forkOSLabeled
    , fork
    , labelThread
    , threadsInformation
    , PerThread(..)
    ) where

import Control.Concurrent as Raw (forkOSWithUnmask)
import qualified Control.Concurrent.Lifted as Raw
import Control.Concurrent.Lifted hiding (fork,forkOS)
import Control.Exception (fromException)
import Control.Monad.Trans.Control
import Foreign.C.Types
import GHC.Exts (ThreadId#)
import GHC.Conc (ThreadId(..))
import GHC.Stack
import System.IO.Unsafe
import qualified Data.Map.Strict  as Map
import Control.Exception.Lifted
import Control.Monad.Base
import qualified GHC.Conc as GHC
import Data.Time()
import Data.Time.Clock
import DPut
import DebugTag
import qualified Data.Vector as V
import Data.Vector (Vector)
import Data.Char

-- | Book-keeping stored in the global thread table for each instrumented
-- thread.
data PerThread = PerThread
    { lbl       :: String   -- ^ Thread label: the default call-site label, a
                            --   user-supplied label, or (after an abnormal
                            --   exit) the exception text with the old label.
    , startTime :: UTCTime  -- ^ Wall-clock time recorded when the thread was
                            --   forked.
    } deriving (Eq, Ord, Show)

{-# NOINLINE globalMVarArray #-}
-- | The global thread table, split into 128 independently locked shards.
-- 'hashThreadId' selects the shard for a given thread.
--
-- Allocated once through 'unsafePerformIO'; the NOINLINE pragma stops GHC
-- from duplicating the allocation at use sites.
globalMVarArray :: Vector (MVar (Map.Map ThreadId PerThread))
globalMVarArray = unsafePerformIO $ V.replicateM 128 (newMVar Map.empty)

-- | Process-wide configuration shared by all instrumented threads.
data GlobalState = GlobalState
    { reportException :: String -> IO ()
      -- ^ Called with a description whenever an instrumented thread exits
      --   with an exception other than 'ThreadKilled'.
    }

-- | Numeric id that the GHC runtime assigns to a thread; used by
-- 'hashThreadId' to choose a shard of 'globalMVarArray'.
--
-- NOTE(review): newer GHC RTS headers declare @rts_getThreadId@ as returning
-- a 64-bit @StgThreadID@. Declaring the result as 'CInt' presumably keeps only
-- the low bits, which is harmless for hashing. Confirm against the targeted
-- GHC version.
foreign import ccall unsafe "rts_getThreadId" rts_getThreadId :: ThreadId# -> CInt

-- | Index of the 'globalMVarArray' shard responsible for a thread.
-- 'mod' with a positive divisor keeps the result within bounds.
hashThreadId :: ThreadId -> Int
hashThreadId (ThreadId t) = shard
  where
    shard = rtsId `mod` V.length globalMVarArray
    rtsId = fromIntegral (rts_getThreadId t)


-- | Mutable process-wide settings. By default, exceptions are reported
-- through @dput XMisc@.
--
-- A top-level 'unsafePerformIO' global; NOINLINE keeps it a single MVar.
globals :: MVar GlobalState
globals = unsafePerformIO (newMVar (GlobalState { reportException = dput XMisc }))
{-# NOINLINE globals #-}

-- | Fork an instrumented 'IO' thread (see 'forkIO') and then give it the
-- label @lbl@, both as its GHC thread label and in the thread table.
forkLabeled :: HasCallStack => String -> IO () -> IO ThreadId
forkLabeled lbl action = do
    tid <- instrumented GHC.forkIOWithUnmask action
    tid <$ labelThread tid lbl
{-# INLINE forkLabeled #-}

-- | Like 'forkLabeled', but the new thread is created with
-- 'Raw.forkOSWithUnmask' (an OS-bound thread).
forkOSLabeled :: HasCallStack => String -> IO () -> IO ThreadId
forkOSLabeled lbl action = do
    tid <- instrumented Raw.forkOSWithUnmask action
    labelThread tid lbl >> return tid
{-# INLINE forkOSLabeled #-}

-- | Instrumented replacement for 'Control.Concurrent.forkIO'. The thread is
-- registered in the thread table under a label naming the call site.
forkIO :: HasCallStack => IO () -> IO ThreadId
forkIO action = instrumented GHC.forkIOWithUnmask action
{-# INLINE forkIO #-}

-- | Instrumented replacement for 'Control.Concurrent.forkOS', built on the
-- 'IO'-only 'Raw.forkOSWithUnmask'.
--
-- NOTE: a lifted signature,
-- @( HasCallStack, MonadBaseControl IO m ) => m () -> m ThreadId@,
-- is not offered; this variant works in 'IO' only.
forkOS :: HasCallStack => IO () -> IO ThreadId
forkOS action = instrumented Raw.forkOSWithUnmask action
{-# INLINE forkOS #-}

-- | Instrumented replacement for the lifted 'Raw.fork'. It works in any
-- monad with a 'MonadBaseControl' 'IO' instance.
fork :: ( HasCallStack, MonadBaseControl IO m ) => m () -> m ThreadId
fork act = instrumented Raw.forkWithUnmask act
{-# INLINE fork #-}

-- | Render the first frame of a call stack as @file:line@. An empty stack
-- renders as the empty string.
shortCallStack :: [([Char], SrcLoc)] -> String
shortCallStack frames = case frames of
    []             -> ""
    (_, site) : _  -> concat [srcLocFile site, ":", show (srcLocStartLine site)]

-- | Label for a freshly forked thread. The innermost frame is skipped, and
-- the next frame (the caller of the public fork wrapper) is rendered with
-- 'shortCallStack'. Stacks with fewer than two frames give @""@.
defaultLabel :: CallStack -> String
defaultLabel = shortCallStack . drop 1 . getCallStack


-- | Fork a thread with @rawFork@ and track it in the global thread table.
--
-- The parent registers the new thread with the fork time and a default label
-- produced by @defaultLabel callStack@. When the child's @action@ ends:
--
-- * normally, or via 'ThreadKilled': the thread's entry is removed;
-- * via any other exception, including other 'AsyncException's: the
--   exception is passed to the 'reportException' hook from 'globals', and the
--   entry is kept with its label replaced by
--   @show e ++ " (" ++ oldLabel ++ ")"@.
--
-- @action@ runs inside @unmask@; the surrounding book-keeping runs masked.
instrumented :: ( HasCallStack, MonadBaseControl IO m ) =>
                (((forall a. m a -> m a) -> m ()) -> m ThreadId) -> m () -> m ThreadId
instrumented rawFork action = do
    -- The parent fills this once the child's table entry exists. On every
    -- exit path the child waits for it before touching its own entry. That way
    -- the parent's registration can neither resurrect a removed entry nor
    -- overwrite the label stored for a failed thread.
    registered <- newEmptyMVar
    tm <- liftBase getCurrentTime
    mask_ $ do
        t <- rawFork $ \unmask -> do
            tid <- myThreadId
            outcome <- try (unmask action)
            takeMVar registered
            case outcome of
                Right _ -> modifyThreads tid $! Map.delete tid
                Left e
                    | Just ThreadKilled <- fromException e ->
                        modifyThreads tid $! Map.delete tid
                    | otherwise -> liftBase $ do
                        -- Keep the record on abnormal exit so it can be
                        -- inspected later via 'threadsInformation'.
                        g <- readMVar globals
                        mp <- readMVar (globalMVarArray V.! hashThreadId tid)
                        let msg = concat [ show e
                                         , " ("
                                         , maybe "" lbl $ Map.lookup tid mp
                                         , ")"
                                         ]
                        -- Force every character before reporting/storing.
                        foldr seq (return ()) msg
                        reportException g msg
                        modifyThreads tid $! Map.insert tid (PerThread msg tm)
        -- 'finally' guarantees the child is released even if registration
        -- throws, so it can never block forever on 'registered'.
        liftBase (labelThread_ t (defaultLabel callStack) tm)
            `finally` putMVar registered ()
        return t

-- | Set a thread's GHC label and record it in the thread table. A new entry
-- with start time @tm@ is created if none exists; an existing entry only
-- has its label replaced.
labelThread_ :: ThreadId -> String -> UTCTime -> IO ()
labelThread_ tid s tm = do
    foldr seq (pure ()) s
    GHC.labelThread tid s
    modifyThreads tid $! Map.alter (Just . maybe fresh relabel) tid
  where
    fresh      = PerThread s tm
    relabel pt = pt { lbl = s }

-- | Set a thread's GHC label and update the label of its thread-table entry.
-- If the thread has no entry (e.g. it already finished normally), the table
-- is left unchanged.
labelThread :: ThreadId -> String -> IO ()
labelThread tid s = do
    foldr seq (pure ()) s
    GHC.labelThread tid s
    modifyThreads tid $! Map.adjust relabel tid
  where
    relabel pt = pt { lbl = s }
{-# INLINE labelThread #-}

-- | All entries of the thread table, shard by shard in shard order.
--
-- Each shard is read separately, so the result is not an atomic snapshot
-- across shards.
threadsInformation :: IO [(ThreadId,PerThread)]
threadsInformation =
    Prelude.concatMap Map.toList <$> mapM readMVar (V.toList globalMVarArray)


-- | Apply @f@ to the shard of the thread table that holds @tid@.
--
-- The shard is chosen by 'hashThreadId'. @f@ receives and returns the whole
-- shard, which may also hold other threads' entries. The new map is forced
-- to WHNF while the shard's 'MVar' is held. If that evaluation throws, the
-- shard keeps its previous contents and the exception propagates; the MVar
-- is never left empty (the previous 'bracket' version would deadlock every
-- later access to the shard in that case).
modifyThreads :: MonadBaseControl IO m => ThreadId ->
                    (Map.Map ThreadId PerThread -> Map.Map ThreadId PerThread) -> m ()
modifyThreads tid f =
    modifyMVar_ (globalMVarArray V.! hashThreadId tid) (\m -> return $! f m)