From cb28281a2acabf87e91582ce5ace562544ae2730 Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Mon, 13 Jan 2020 06:58:48 -0500 Subject: Fixed race condition in thread instrumentation. --- .../src/Control/Concurrent/Lifted/Instrument.hs | 75 +++++++++++----------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs b/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs index a0bb7dc5..bd6ee4b8 100644 --- a/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs +++ b/lifted-concurrent/src/Control/Concurrent/Lifted/Instrument.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} module Control.Concurrent.Lifted.Instrument ( module Control.Concurrent.Lifted , forkLabeled @@ -11,6 +12,7 @@ module Control.Concurrent.Lifted.Instrument , PerThread(..) ) where +import Control.Concurrent as Raw (forkOSWithUnmask) import qualified Control.Concurrent.Lifted as Raw import Control.Concurrent.Lifted hiding (fork,forkOS) import Control.Exception (fromException) @@ -47,28 +49,29 @@ globals = unsafePerformIO $ newMVar $ GlobalState forkLabeled :: HasCallStack => String -> IO () -> IO ThreadId forkLabeled lbl action = do - t <- instrumented GHC.forkIO action + t <- instrumented GHC.forkIOWithUnmask action labelThread t lbl return t {-# INLINE forkLabeled #-} forkOSLabeled :: HasCallStack => String -> IO () -> IO ThreadId forkOSLabeled lbl action = do - t <- instrumented Raw.forkOS action + t <- instrumented Raw.forkOSWithUnmask action labelThread t lbl return t {-# INLINE forkOSLabeled #-} forkIO :: HasCallStack => IO () -> IO ThreadId -forkIO = instrumented GHC.forkIO +forkIO = instrumented GHC.forkIOWithUnmask {-# INLINE forkIO #-} -forkOS :: ( HasCallStack, MonadBaseControl IO m ) => m () -> m ThreadId -forkOS = instrumented Raw.forkOS +forkOS :: -- ( HasCallStack, MonadBaseControl IO m ) => m () -> m ThreadId + HasCallStack => IO () -> IO ThreadId +forkOS = instrumented Raw.forkOSWithUnmask {-# INLINE forkOS #-} fork :: ( HasCallStack, MonadBaseControl IO m ) => m () -> m ThreadId -fork = instrumented Raw.fork +fork = instrumented Raw.forkWithUnmask {-# INLINE fork #-} shortCallStack :: [([Char], SrcLoc)] -> String @@ -80,47 +83,43 @@ defaultLabel stack = case getCallStack stack of _ : sites -> shortCallStack sites sites -> shortCallStack sites + instrumented :: ( HasCallStack, MonadBaseControl IO m ) => - (m () -> m ThreadId) -> m () -> m ThreadId + (((forall a. m a -> m a) -> m ()) -> m ThreadId) -> m () -> m ThreadId instrumented rawFork action = do mvar <- newEmptyMVar - t <- rawFork $ do - tid <- myThreadId - tm <- liftBase getCurrentTime - bracket_ (modifyThreads $! \ts -> Map.union ts (Map.singleton tid (PerThread (defaultLabel callStack) tm))) - (return ()) - $ do catch action $ \e -> case fromException e of - Just ThreadKilled -> return () + tm <- liftBase getCurrentTime + t <- mask_ $ rawFork $ \unmask -> do + tid <- myThreadId + let scrapIt = do takeMVar mvar + modifyThreads $! Map.delete tid + io <- catch (unmask action >> return scrapIt) $ \e -> case fromException e of + Just ThreadKilled -> return scrapIt Nothing -> liftBase $ do - bracket (takeMVar globals) - (\g -> do - let l = concat [ show e - , " (" - , maybe "" lbl $ Map.lookup tid (threads g) - , ")" - ] - foldr seq (return ()) l - putMVar globals $! g { threads = Map.insert tid (PerThread l tm) $ threads g } - throwIO e) - (\g -> do - let l = concat [ show e - , " (" - , maybe "" lbl $ Map.lookup tid (threads g) - , ")" - ] - reportException g l) - -- Remove the thread only if it terminated normally or was killed. - takeMVar mvar - modifyThreads $! Map.delete tid - liftBase $ labelThread_ t (defaultLabel callStack) + g <- takeMVar globals + let l = concat [ show e + , " (" + , maybe "" lbl $ Map.lookup tid (threads g) + , ")" + ] + reportException g l + let l = concat [ show e + , " (" + , maybe "" lbl $ Map.lookup tid (threads g) + , ")" + ] + foldr seq (return ()) l + putMVar globals $! g { threads = Map.insert tid (PerThread l tm) $ threads g } + return $ return () -- Remove the thread only if it terminated normally or was killed. + io -- scrap record on normal termination + liftBase $ labelThread_ t (defaultLabel callStack) tm putMVar mvar () return t -labelThread_ :: ThreadId -> String -> IO () -labelThread_ tid s = do +labelThread_ :: ThreadId -> String -> UTCTime -> IO () +labelThread_ tid s tm = do foldr seq (return ()) s GHC.labelThread tid s - tm <- liftBase getCurrentTime let updateIt (Just pt) = Just $ pt { lbl = s } updateIt Nothing = Just $ PerThread s tm modifyThreads $! Map.alter updateIt tid -- cgit v1.2.3