From 1e16567904a147b842070c1d98c83dc3e9c00c98 Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Tue, 28 Jan 2020 17:23:02 -0500 Subject: Kill tcp session on exception. --- dht/src/Network/Tox/TCP.hs | 16 +++---- server/src/Network/QueryResponse/TCP.hs | 79 ++++++++++++++++++--------------- 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/dht/src/Network/Tox/TCP.hs b/dht/src/Network/Tox/TCP.hs index 36289e19..e3780675 100644 --- a/dht/src/Network/Tox/TCP.hs +++ b/dht/src/Network/Tox/TCP.hs @@ -146,16 +146,12 @@ tcpStream crypto mkst = StreamHandshake dput XTCP $ "TCP exception: " ++ show e return Nothing , streamEncode = \y -> do - -- dput XTCP $ "TCP(acquire nonce):" ++ show addr ++ " <-- " ++ show y - n24 <- takeMVar nsend - -- dput XTCP $ "TCP(got nonce):" ++ show addr ++ " <-- " ++ show y - let bs = encode $ encrypt (noncef' n24) $ encodePlain y - ($ h) -- bracket (takeMVar hvar) (putMVar hvar) - $ \h -> hPut h (encode (fromIntegral $ Data.ByteString.length bs :: Word16) <> bs) - `catchIOError` \e -> dput XTCP $ "TCP write exception: " ++ show e - -- dput XTCP $ "TCP(incrementing nonce): " ++ show addr ++ " <-- " ++ show y - putMVar nsend (incrementNonce24 n24) - dput XTCP $ "TCP: " ++ show addr ++ " <-- " ++ show y + -- We need this to throw so the tcp session state can be cleaned up elsewhere. + bracket (takeMVar nsend) (putMVar nsend . incrementNonce24) + $ \n24 -> do + let bs = encode $ encrypt (noncef' n24) $ encodePlain y + hPut h (encode (fromIntegral $ Data.ByteString.length bs :: Word16) <> bs) + dput XTCP $ "TCP: " ++ show addr ++ " <-- " ++ show y } , streamAddr = nodeAddr } diff --git a/server/src/Network/QueryResponse/TCP.hs b/server/src/Network/QueryResponse/TCP.hs index 24aacd98..639212cb 100644 --- a/server/src/Network/QueryResponse/TCP.hs +++ b/server/src/Network/QueryResponse/TCP.hs @@ -71,9 +71,12 @@ data StreamHandshake addr x y = StreamHandshake , streamAddr :: addr -> SockAddr } -killSession :: TCPSession st -> IO () -killSession PendingTCPSession = return () -killSession TCPSession{tcpThread=t} = killThread t +killSession :: TCPSession (SessionProtocol x y) -> IO () +killSession TCPSession{tcpState=st,tcpHandle=h,tcpThread=t} = do + catchIOError (streamGoodbye st >> hClose h) + (\e -> return ()) + killThread t +killSession _ = return () showStat :: IsString p => TCPSession st -> p showStat r = case r of PendingTCPSession -> "pending." @@ -82,6 +85,30 @@ showStat r = case r of PendingTCPSession -> "pending." tcp_timeout :: Int tcp_timeout = 10000000 + +removeOnFail :: TCPCache (SessionProtocol x y) -> Handle -> TCPAddress -> IO () -> IO () +removeOnFail tcpcache h addr action = action `catchIOError` \e -> do + join $ atomically $ do + c <- readTVar (lru tcpcache) + case MM.lookup' addr c of + Just (tm, v@TCPSession {tcpHandle=stored}) | h == stored -> do + modifyTVar' (lru tcpcache) $ MM.delete addr + return $ killSession v + _ -> return $ return () + dput XTCP $ "TCP-send " ++ show addr ++ " " ++ show e + +lookupSession :: TCPCache st -> POSIXTime -> TCPAddress -> Bool -> STM (Maybe (Down POSIXTime, TCPSession st)) +lookupSession tcpcache now saddr bDoCon = do + c <- readTVar (lru tcpcache) + let v = MM.lookup' saddr c + case v of + Nothing | bDoCon -> writeTVar (lru tcpcache) + $ MM.insert' saddr PendingTCPSession (Down now) c + | otherwise -> return () + Just (tm, v) -> writeTVar (lru tcpcache) + $ MM.insert' saddr v (Down now) c + return v + acquireConnection :: TMVar (Arrival a addr x) -> TCPCache (SessionProtocol x y) -> StreamHandshake addr x y @@ -91,16 +118,8 @@ acquireConnection :: TMVar (Arrival a addr x) acquireConnection mvar tcpcache stream addr bDoCon = do now <- getPOSIXTime -- dput XTCP $ "acquireConnection 0 " ++ show (streamAddr stream addr) - entry <- atomically $ do - c <- readTVar (lru tcpcache) - let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c - case v of - Nothing | bDoCon -> writeTVar (lru tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) PendingTCPSession (Down now) c - | otherwise -> return () - Just (tm, v) -> writeTVar (lru tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down now) c - return v + let saddr = TCPAddress $ streamAddr stream addr + entry <- atomically $ lookupSession tcpcache now saddr bDoCon -- dput XTCP $ "acquireConnection 1 " ++ show (streamAddr stream addr, fmap (second showStat) entry) case entry of Nothing -> fmap join $ forM (guard bDoCon) $ \() -> do @@ -114,15 +133,14 @@ acquireConnection mvar tcpcache stream addr bDoCon = do return h) $ \e -> return Nothing when (isNothing mh) $ do - atomically $ modifyTVar' (lru tcpcache) - $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr Socket.close sock ret <- fmap join $ forM mh $ \h -> do mst <- catchIOError (Just <$> streamHello stream addr h) (\e -> return Nothing) case mst of Nothing -> do - atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr return Nothing Just st -> do dput XTCP $ "TCP Connected! " ++ show (streamAddr stream addr) @@ -143,8 +161,7 @@ acquireConnection mvar tcpcache stream addr bDoCon = do loop Nothing -> do dput XTCP $ "TCP disconnected: " ++ show (streamAddr stream addr) - do atomically $ modifyTVar' (lru tcpcache) - $ MM.delete (TCPAddress $ streamAddr stream addr) + do atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr c <- atomically $ readTVar (lru tcpcache) now <- getPOSIXTime forM_ (zip [1..] $ MM.toList c) $ \(i,MM.Binding (TCPAddress addr) r (Down tm)) -> do @@ -163,33 +180,28 @@ acquireConnection mvar tcpcache stream addr bDoCon = do retires <- atomically $ do c <- readTVar (lru tcpcache) let (rs,c') = MM.takeView (tcpMax tcpcache) - $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down t) c + $ MM.insert' saddr v (Down t) c writeTVar (lru tcpcache) c' writeTVar signal True return rs forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkLabeled ("tcp-close:"++show k) $ do dput XTCP $ "TCP dropped: " ++ show k killSession r - case r of TCPSession {tcpState=st,tcpHandle=h} -> do - streamGoodbye st - hClose h - `catchIOError` \e -> return () - _ -> return () - - return $ Just $ streamEncode st + return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y when (isNothing ret) $ do - atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr) + atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr return ret Just (tm, PendingTCPSession) | not bDoCon -> return Nothing | otherwise -> fmap join $ timeout tcp_timeout $ atomically $ do c <- readTVar (lru tcpcache) - let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c + let v = MM.lookup' saddr c case v of - Just (_,TCPSession{tcpState=st}) -> return $ Just $ streamEncode st - Nothing -> return Nothing - _ -> retry - Just (tm, v@TCPSession {tcpState=st}) -> return $ Just $ streamEncode st + Just (_,TCPSession{tcpHandle=h,tcpState=st}) + -> return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y + Nothing -> return Nothing + _ -> retry + Just (tm, v@TCPSession {tcpHandle=h,tcpState=st}) -> return $ Just $ \y -> removeOnFail tcpcache h saddr $ streamEncode st y closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO () closeAll tcpcache stream = do @@ -197,9 +209,6 @@ closeAll tcpcache stream = do cache <- atomically $ swapTVar (lru tcpcache) MM.empty forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r tm) -> do killSession r - case r of TCPSession{tcpState=st,tcpHandle=h} -> catchIOError (streamGoodbye st >> hClose h) - (\e -> return ()) - _ -> return () -- Use a cache of TCP client connections for sending (and receiving) packets. -- The boolean value prepended to the message allows the sender to specify -- cgit v1.2.3