From 1e16567904a147b842070c1d98c83dc3e9c00c98 Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Tue, 28 Jan 2020 17:23:02 -0500 Subject: Kill tcp session on exception. --- server/src/Network/QueryResponse/TCP.hs | 79 ++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 35 deletions(-) (limited to 'server') diff --git a/server/src/Network/QueryResponse/TCP.hs b/server/src/Network/QueryResponse/TCP.hs index 24aacd98..639212cb 100644 --- a/server/src/Network/QueryResponse/TCP.hs +++ b/server/src/Network/QueryResponse/TCP.hs @@ -71,9 +71,12 @@ data StreamHandshake addr x y = StreamHandshake , streamAddr :: addr -> SockAddr } -killSession :: TCPSession st -> IO () -killSession PendingTCPSession = return () -killSession TCPSession{tcpThread=t} = killThread t +killSession :: TCPSession (SessionProtocol x y) -> IO () +killSession TCPSession{tcpState=st,tcpHandle=h,tcpThread=t} = do + catchIOError (streamGoodbye st >> hClose h) + (\e -> return ()) + killThread t +killSession _ = return () showStat :: IsString p => TCPSession st -> p showStat r = case r of PendingTCPSession -> "pending." @@ -82,6 +85,30 @@ showStat r = case r of PendingTCPSession -> "pending." tcp_timeout :: Int tcp_timeout = 10000000 + +removeOnFail :: TCPCache (SessionProtocol x y) -> Handle -> TCPAddress -> IO () -> IO () +removeOnFail tcpcache h addr action = action `catchIOError` \e -> do + join $ atomically $ do + c <- readTVar (lru tcpcache) + case MM.lookup' addr c of + Just (tm, v@TCPSession {tcpHandle=stored}) | h == stored -> do + modifyTVar' (lru tcpcache) $ MM.delete addr + return $ killSession v + _ -> return $ return () + dput XTCP $ "TCP-send " ++ show addr ++ " " ++ show e + +lookupSession :: TCPCache st -> POSIXTime -> TCPAddress -> Bool -> STM (Maybe (Down POSIXTime, TCPSession st)) +lookupSession tcpcache now saddr bDoCon = do + c <- readTVar (lru tcpcache) + let v = MM.lookup' saddr c + case v of + Nothing | bDoCon -> writeTVar (lru tcpcache) + $ MM.insert' saddr PendingTCPSession (Down now) c + | otherwise -> return () + Just (tm, v) -> writeTVar (lru tcpcache) + $ MM.insert' saddr v (Down now) c + return v + acquireConnection :: TMVar (Arrival a addr x) -> TCPCache (SessionProtocol x y) -> StreamHandshake addr x y @@ -91,16 +118,8 @@ acquireConnection :: TMVar (Arrival a addr x) acquireConnection mvar tcpcache stream addr bDoCon = do now <- getPOSIXTime -- dput XTCP $ "acquireConnection 0 " ++ show (streamAddr stream addr) - entry <- atomically $ do - c <- readTVar (lru tcpcache) - let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c - case v of - Nothing | bDoCon -> writeTVar (lru tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) PendingTCPSession (Down now) c - | otherwise -> return () - Just (tm, v) -> writeTVar (lru tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down now) c - return v + let saddr = TCPAddress $ streamAddr stream addr + entry <- atomically $ lookupSession tcpcache now saddr bDoCon -- dput XTCP $ "acquireConnection 1 " ++ show (streamAddr stream addr, fmap (second showStat) entry) case entry of Nothing -> fmap join $ forM (guard bDoCon) $ \() -> do @@ -114,15 +133,14 @@ acquireConnection mvar tcpcache stream addr bDoCon = do return h) $ \e -> return Nothing when (isNothing mh) $ do - atomically $ modifyTVar' (lru tcpcache) - $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr Socket.close sock ret <- fmap join $ forM mh $ \h -> do mst <- catchIOError (Just <$> streamHello stream addr h) (\e -> return Nothing) case mst of Nothing -> do - atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr return Nothing Just st -> do dput XTCP $ "TCP Connected! " ++ show (streamAddr stream addr) @@ -143,8 +161,7 @@ acquireConnection mvar tcpcache stream addr bDoCon = do loop Nothing -> do dput XTCP $ "TCP disconnected: " ++ show (streamAddr stream addr) - do atomically $ modifyTVar' (lru tcpcache) - $ MM.delete (TCPAddress $ streamAddr stream addr) + do atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr c <- atomically $ readTVar (lru tcpcache) now <- getPOSIXTime forM_ (zip [1..] $ MM.toList c) $ \(i,MM.Binding (TCPAddress addr) r (Down tm)) -> do @@ -163,33 +180,28 @@ acquireConnection mvar tcpcache stream addr bDoCon = do retires <- atomically $ do c <- readTVar (lru tcpcache) let (rs,c') = MM.takeView (tcpMax tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down t) c + $ MM.insert' saddr v (Down t) c writeTVar (lru tcpcache) c' writeTVar signal True return rs forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkLabeled ("tcp-close:"++show k) $ do dput XTCP $ "TCP dropped: " ++ show k killSession r - case r of TCPSession {tcpState=st,tcpHandle=h} -> do - streamGoodbye st - hClose h - `catchIOError` \e -> return () - _ -> return () - - return $ Just $ streamEncode st + return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y when (isNothing ret) $ do - atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr return ret Just (tm, PendingTCPSession) | not bDoCon -> return Nothing | otherwise -> fmap join $ timeout tcp_timeout $ atomically $ do c <- readTVar (lru tcpcache) - let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c + let v = MM.lookup' saddr c case v of - Just (_,TCPSession{tcpState=st}) -> return $ Just $ streamEncode st - Nothing -> return Nothing - _ -> retry - Just (tm, v@TCPSession {tcpState=st}) -> return $ Just $ streamEncode st + Just (_,TCPSession{tcpHandle=h,tcpState=st}) + -> return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y + Nothing -> return Nothing + _ -> retry + Just (tm, v@TCPSession {tcpHandle=h,tcpState=st}) -> return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO () closeAll tcpcache stream = do @@ -197,9 +209,6 @@ closeAll tcpcache stream = do cache <- atomically $ swapTVar (lru tcpcache) MM.empty forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r tm) -> do killSession r - case r of TCPSession{tcpState=st,tcpHandle=h} -> catchIOError (streamGoodbye st >> hClose h) - (\e -> return ()) - _ -> return () -- Use a cache of TCP client connections for sending (and receiving) packets. -- The boolean value prepended to the message allows the sender to specify -- cgit v1.2.3