{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A cache of outbound TCP client connections, usable as a 'Transport'
-- for "Network.QueryResponse".
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread,forkIO)
import ForkLabeled
#endif

import Control.Arrow
import Control.Concurrent.STM
import Control.Concurrent.STM.TMVar
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Maybe
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Data.String (IsString(..))
import Network.BSD
import Network.Socket as Socket
import System.Timeout
import System.IO
import System.IO.Error

import DebugTag
import DebugUtil
import DPut
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | Cache entry for one remote peer.
data TCPSession st
    = PendingTCPSession
      -- ^ A connection attempt for this peer is in progress.
    | TCPSession
        { tcpHandle :: Handle   -- ^ Handle for the connected socket.
        , tcpState  :: st       -- ^ Per-session state ('SessionProtocol' in practice).
        , tcpThread :: ThreadId -- ^ Thread reading inbound messages.
        }

-- | Cache key: the peer's socket address.
newtype TCPAddress = TCPAddress SockAddr
    deriving (Eq,Ord,Show)

instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr -> hashWithSalt salt (fromIntegral port :: Word16,addr)
        SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
        -- NOTE(review): every other address family hashes to 0 regardless of
        -- the salt; presumably only IPv4/IPv6 peers are cached.
        _ -> 0

-- | Sessions keyed by peer address, ordered by time of last use.
data TCPCache st = TCPCache
    { lru :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
      -- ^ Sessions keyed by peer; the priority is the last-use time wrapped in 'Down'.
    , tcpMax :: Int
      -- ^ Number of sessions kept; extras are closed when a new session is inserted.
    }

-- This is a suitable /st/ parameter to 'TCPCache'
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO () -- ^ "Goodbye" protocol upon termination.
    , streamDecode :: IO (Maybe x) -- ^ Parse inbound messages.
    , streamEncode :: y -> IO () -- ^ Serialize outbound messages.
    }

data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y) -- ^ "Hello" protocol upon fresh connection.
    , streamAddr :: addr -> SockAddr
    }

-- | Shut down an established session: run its goodbye protocol, close its
-- handle and kill its reader thread. A pending session is left alone.
--
-- Each step is best-effort: IO errors are ignored, and a failing goodbye no
-- longer prevents the handle from being closed.
killSession :: TCPSession (SessionProtocol x y) -> IO ()
killSession TCPSession{tcpState=st,tcpHandle=h,tcpThread=t} = do
    streamGoodbye st `catchIOError` \_ -> return ()
    hClose h `catchIOError` \_ -> return ()
    killThread t
killSession PendingTCPSession = return ()

-- | Short status of a session, for log output.
showStat :: IsString p => TCPSession st -> p
showStat r = case r of PendingTCPSession -> "pending."
                       TCPSession {} -> "established."

-- | Timeout in microseconds (the unit taken by 'timeout'), i.e. 10 seconds.
tcp_timeout :: Int
tcp_timeout = 10000000

-- | Run a send @action@. If it raises an IO error and the cached session for
-- @addr@ still uses handle @h@, that session is removed from the cache and
-- killed. The error is logged, not rethrown.
removeOnFail :: TCPCache (SessionProtocol x y) -> Handle -> TCPAddress -> IO () -> IO ()
removeOnFail tcpcache h addr action = action `catchIOError` \e -> do
    join $ atomically $ do
        c <- readTVar (lru tcpcache)
        case MM.lookup' addr c of
            -- Only evict if the entry is the session we were sending on;
            -- it may already have been replaced.
            Just (_, v@TCPSession {tcpHandle=stored}) | h == stored -> do
                modifyTVar' (lru tcpcache) $ MM.delete addr
                return $ killSession v
            _ -> return $ return ()
    dput XTCP $ "TCP-send " ++ show addr ++ " " ++ show e

-- | Look up the entry for @saddr@ and mark it as used at @now@.
--
-- When no entry exists and @bDoCon@ is set, a 'PendingTCPSession'
-- placeholder is inserted. The entry as it was /before/ this call is returned.
lookupSession :: TCPCache st -> POSIXTime -> TCPAddress -> Bool -> STM (Maybe (Down POSIXTime, TCPSession st))
lookupSession tcpcache now saddr bDoCon = do
    c <- readTVar (lru tcpcache)
    let found = MM.lookup' saddr c
    case found of
        Nothing
            | bDoCon    -> writeTVar (lru tcpcache)
                               $ MM.insert' saddr PendingTCPSession (Down now) c
            | otherwise -> return ()
        Just (_, s) -> writeTVar (lru tcpcache) $ MM.insert' saddr s (Down now) c
    return found

-- | Obtain a send function for peer @addr@, reusing a cached session when
-- possible.
--
-- * Established session: its encoder is returned.
-- * Pending session and @bDoCon@ set: wait up to 'tcp_timeout' for it to be
--   established.
-- * Nothing cached and @bDoCon@ set: connect, run 'streamHello', fork a reader
--   thread that hands inbound messages to @mvar@, insert the session, and
--   close any sessions evicted beyond 'tcpMax'.
--
-- Returns 'Nothing' when no connection exists or one could not be made. A
-- failed attempt always removes its placeholder, so later sends can retry.
acquireConnection :: TMVar (Arrival a addr x)
                  -> TCPCache (SessionProtocol x y)
                  -> StreamHandshake addr x y
                  -> addr
                  -> Bool
                  -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr bDoCon = do
    now <- getPOSIXTime
    -- dput XTCP $ "acquireConnection 0 " ++ show (streamAddr stream addr)
    entry <- atomically $ lookupSession tcpcache now saddr bDoCon
    -- dput XTCP $ "acquireConnection 1 " ++ show (streamAddr stream addr, fmap (second showStat) entry)
    case entry of
        Nothing
            | bDoCon    -> connectNew
            | otherwise -> return Nothing
        Just (_, PendingTCPSession)
            | not bDoCon -> return Nothing
            | otherwise  -> fmap join $ timeout tcp_timeout $ atomically $ do
                c <- readTVar (lru tcpcache)
                case MM.lookup' saddr c of
                    Just (_, TCPSession{tcpHandle = h, tcpState = st}) ->
                        return $ Just $ sender h st
                    -- The attempt failed and its placeholder was removed.
                    Nothing -> return Nothing
                    -- Still pending: block until the entry changes.
                    _ -> retry
        Just (_, TCPSession{tcpHandle = h, tcpState = st}) ->
            return $ Just $ sender h st
  where
    peer = streamAddr stream addr
    saddr = TCPAddress peer
    peerStr = show peer

    -- Send function for an established session; IO errors evict it.
    sender h st y = removeOnFail tcpcache h saddr $ streamEncode st y

    -- Remove this peer's entry (our pending placeholder) from the cache.
    forgetPeer = atomically $ modifyTVar' (lru tcpcache) $ MM.delete saddr

    -- Make a brand-new connection. 'lookupSession' has already stored a
    -- 'PendingTCPSession' placeholder for this peer.
    connectNew = do
        mh <- openHandle
        ret <- case mh of
            Nothing -> return Nothing
            Just h -> do
                mst <- catchIOError (Just <$> streamHello stream addr h)
                                    (\_ -> return Nothing)
                case mst of
                    Nothing -> do
                        -- Handshake failed: release the connection.
                        hClose h `catchIOError` \_ -> return ()
                        return Nothing
                    Just st -> Just <$> establish h st
        -- On any failure, drop the placeholder so waiters see 'Nothing'
        -- and later sends may try again.
        when (isNothing ret) forgetPeer
        return ret

    -- Create a socket and connect it to the peer within 'tcp_timeout'.
    -- On an IO error (including from socket creation) or a timeout, the
    -- socket is closed and 'Nothing' is returned.
    openHandle = do
        msock <- catchIOError
            (Just <$> (getProtocolNumber "tcp" >>= socket (socketFamily peer) Stream))
            (\_ -> return Nothing)
        fmap join $ forM msock $ \sock -> do
            mh <- catchIOError
                (timeout tcp_timeout $ do
                    connect sock peer
                    h <- socketToHandle sock ReadWriteMode
                    hSetBuffering h NoBuffering
                    return h)
                (\_ -> return Nothing)
            when (isNothing mh) $ Socket.close sock
            return mh

    -- Start the reader for a freshly handshaken session, publish the session
    -- in the cache, and close any sessions evicted to make room for it.
    establish h st = do
        dput XTCP $ "TCP Connected! " ++ peerStr
        signal <- newTVarIO False
        rthread <- forkLabeled ("tcp:" ++ peerStr) $ readerLoop h st signal
        let v = TCPSession
                { tcpHandle = h
                , tcpState = st
                , tcpThread = rthread
                }
        t <- getPOSIXTime
        retires <- atomically $ do
            c <- readTVar (lru tcpcache)
            let (rs, c') = MM.takeView (tcpMax tcpcache) $ MM.insert' saddr v (Down t) c
            writeTVar (lru tcpcache) c'
            -- The reader starts only once the session is in the cache.
            writeTVar signal True
            return rs
        forM_ retires $ \(MM.Binding (TCPAddress k) r _) ->
            void $ forkLabeled ("tcp-close:" ++ show k) $ do
                dput XTCP $ "TCP dropped: " ++ show k
                killSession r
        return $ sender h st

    -- Hand decoded inbound messages to @mvar@ until the stream ends, then
    -- forget this session and close its handle.
    readerLoop h st signal = do
        atomically (readTVar signal >>= check)
        fix $ \loop -> do
            x <- streamDecode st
            dput XTCP $ "TCP streamDecode " ++ peerStr ++ " --> " ++ maybe "Nothing" (const "got") x
            case x of
                Just u -> do
                    m <- timeout tcp_timeout $ atomically (putTMVar mvar $ Arrival addr u)
                    when (isNothing m) $ do
                        dput XTCP $ "TCP " ++ peerStr ++ " dropped packet."
                        -- Receiver is stalled: discard the waiting message too.
                        void $ atomically $ tryTakeTMVar mvar
                    loop
                Nothing -> do
                    dput XTCP $ "TCP disconnected: " ++ peerStr
                    -- Only remove the entry if it is still this session; a
                    -- newer attempt for the same peer may have replaced it.
                    atomically $ modifyTVar' (lru tcpcache) $ \c ->
                        case MM.lookup' saddr c of
                            Just (_, TCPSession{tcpHandle = stored}) | stored == h -> MM.delete saddr c
                            _ -> c
                    c <- atomically $ readTVar (lru tcpcache)
                    tnow <- getPOSIXTime
                    forM_ (zip [1..] $ MM.toList c) $ \(i, MM.Binding (TCPAddress a) r (Down tm)) ->
                        dput XTCP $ unwords [show i ++ ".", "Still connected:", show a, show (tnow - tm), showStat r]
                    mreport <- timeout tcp_timeout $ threadReport False -- XXX: Paranoid timeout
                    case mreport of
                        Just treport -> dput XTCP treport
                        Nothing -> dput XTCP "TCP ERROR: threadReport timed out."
                    hClose h `catchIOError` \_ -> return ()

-- | Empty the cache and kill every session that was in it.
-- The 'StreamHandshake' argument is not used.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache stream = do
    dput XTCP "TCP.closeAll called."
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> killSession r

-- Use a cache of TCP client connections for sending (and receiving) packets.
-- The boolean value prepended to the message allows the sender to specify
-- whether or not a new connection will be initiated if necessary. If 'False'
-- is passed, then the packet will be sent only if there already exists a
-- connection.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
             -> StreamHandshake addr x y
             -> IO (TCPCache (SessionProtocol x y), TransportA err addr x (Bool,y))
tcpTransport maxcon stream = do
    msgvar <- atomically newEmptyTMVar
    tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar (MM.empty)
    return $ (,) tcpcache Transport
        { awaitMessage = do
            x <- takeTMVar msgvar
            return (x, return ())
        , sendMessage = \addr (bDoCon,y) ->
            -- Send asynchronously. IO errors from connecting as well as from
            -- sending are logged rather than killing the thread silently.
            void . forkLabeled "tcp-send" $
                (acquireConnection msgvar tcpcache stream addr bDoCon >>= mapM_ ($ y))
                    `catchIOError` \e -> dput XTCP $ "TCP-send: " ++ show e
        , setActive = \case
            False -> closeAll tcpcache stream >> atomically (putTMVar msgvar Terminated)
            True -> return ()
        }