{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A TCP-backed transport: a bounded cache of outbound TCP connections,
-- keyed by remote socket address, used for sending (and receiving) packets.
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread,forkIO)
import ForkLabeled
#endif

import Control.Arrow
import Control.Concurrent.STM
import Control.Concurrent.STM.TMVar
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Maybe
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Data.String (IsString(..))
import Network.BSD
import Network.Socket as Socket
import System.Timeout
import System.IO
import System.IO.Error

import DebugTag
import DebugUtil
import DPut
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | State of one cached connection.
data TCPSession st
    = PendingTCPSession
      -- ^ A connection attempt is in progress; no handle exists yet.
    | TCPSession
        { tcpHandle :: Handle   -- ^ Handle wrapping the connected socket.
        , tcpState  :: st       -- ^ Per-connection protocol state.
        , tcpThread :: ThreadId -- ^ Reader thread decoding inbound messages.
        }

-- | Cache key: the remote socket address of a connection.
newtype TCPAddress = TCPAddress SockAddr
    deriving (Eq,Ord,Show)

instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr   -> hashWithSalt salt (fromIntegral port :: Word16,addr)
        SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
        -- Every other address family hashes to 0 (the salt is ignored).
        -- NOTE(review): presumably only inet addresses are used as TCP keys.
        _                        -> 0

-- | Bounded cache of connections.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
      -- ^ Sessions keyed by address, prioritized by 'Down' of last-use time.
    , tcpMax :: Int
      -- ^ Bound passed to 'MM.takeView' whenever a new session is inserted;
      -- the entries it splits off are shut down.
    }

-- | This is a suitable /st/ parameter to 'TCPCache'.
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO ()            -- ^ "Goodbye" protocol upon termination.
    , streamDecode  :: IO (Maybe x)     -- ^ Parse inbound messages.
    , streamEncode  :: y -> IO ()       -- ^ Serialize outbound messages.
    }

-- | How to open and identify a stream to a peer.
data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y)
      -- ^ "Hello" protocol upon fresh connection.
    , streamAddr  :: addr -> SockAddr
      -- ^ Socket address to connect to for a peer; also used as the cache key.
    }

-- | Kill the reader thread of an established session (no-op when pending).
killSession :: TCPSession st -> IO ()
killSession PendingTCPSession           = return ()
killSession TCPSession{tcpThread = t}   = killThread t

-- | Short human-readable status of a session, for debug output.
showStat :: IsString p => TCPSession st -> p
showStat r = case r of
    PendingTCPSession -> "pending."
    TCPSession {}     -> "established."

-- | Timeout, in microseconds (10 seconds), for connecting, for waiting on a
-- pending connection, and for handing an inbound message to the consumer.
tcp_timeout :: Int
tcp_timeout = 10000000

-- | Look up (or, when the 'Bool' is 'True', create) the cached connection to
-- @addr@ and return its encoder for outbound messages.
--
-- * An existing entry has its last-use time refreshed.
-- * An established session yields its 'streamEncode'.
-- * A pending session (another thread is connecting) is waited on for up to
--   'tcp_timeout' when connecting is allowed; otherwise 'Nothing'.
-- * With no entry and connecting allowed, a 'PendingTCPSession' placeholder
--   is inserted, a socket is connected within 'tcp_timeout', 'streamHello'
--   runs, and a reader thread starts delivering decoded messages into @mvar@
--   as 'Arrival's. On every failure path the placeholder is removed, any
--   opened socket is closed, and 'Nothing' is returned.
acquireConnection :: TMVar (Arrival a addr x)
                  -> TCPCache (SessionProtocol x y)
                  -> StreamHandshake addr x y
                  -> addr
                  -> Bool
                  -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr bDoCon = do
    now <- getPOSIXTime
    -- dput XTCP $ "acquireConnection 0 " ++ show (streamAddr stream addr)
    entry <- atomically $ do
        c <- readTVar (lru tcpcache)
        let found = MM.lookup' key c
        case found of
            Nothing
                | bDoCon    -> writeTVar (lru tcpcache) $ MM.insert' key PendingTCPSession (Down now) c
                | otherwise -> return ()
            Just (_, s) -> writeTVar (lru tcpcache) $ MM.insert' key s (Down now) c
        return found
    -- dput XTCP $ "acquireConnection 1 " ++ show (streamAddr stream addr, fmap (second showStat) entry)
    case entry of
        Nothing
            | bDoCon    -> connectNew
            | otherwise -> return Nothing
        Just (_, PendingTCPSession)
            | bDoCon    -> awaitPending
            | otherwise -> return Nothing
        Just (_, TCPSession {tcpState = st}) -> return $ Just $ streamEncode st
  where
    sockAddr = streamAddr stream addr
    key      = TCPAddress sockAddr

    -- Remove the cache entry (placeholder or session) for this address.
    forget = atomically $ modifyTVar' (lru tcpcache) $ MM.delete key

    -- Another thread is connecting: wait (up to tcp_timeout) until it either
    -- publishes a session or removes the placeholder.
    awaitPending = fmap join $ timeout tcp_timeout $ atomically $ do
        c <- readTVar (lru tcpcache)
        case MM.lookup' key c of
            Just (_, TCPSession {tcpState = st}) -> return $ Just $ streamEncode st
            Nothing                              -> return Nothing
            _                                    -> retry

    -- Open a fresh connection, run the handshake and publish the session.
    connectNew = do
        mh <- openHandle
        case mh of
            Nothing -> do
                forget
                return Nothing
            Just h -> do
                mst <- catchIOError (Just <$> streamHello stream addr h) (\_ -> return Nothing)
                case mst of
                    Nothing -> do
                        forget
                        -- The handle owns the socket; close it so a failed
                        -- handshake does not leak the descriptor.
                        hClose h `catchIOError` \_ -> return ()
                        return Nothing
                    Just st -> Just <$> establish h st

    -- Create a socket and connect it within tcp_timeout. Any IOError,
    -- including one from creating the socket itself, yields Nothing and
    -- leaves no socket open.
    openHandle :: IO (Maybe Handle)
    openHandle = do
        msock <- catchIOError (Just <$> newSocket) (\_ -> return Nothing)
        fmap join $ forM msock $ \sock -> do
            mh <- catchIOError
                    (timeout tcp_timeout $ do
                        connect sock sockAddr
                        h <- socketToHandle sock ReadWriteMode
                        hSetBuffering h NoBuffering
                        return h)
                    (\_ -> return Nothing)
            when (isNothing mh) $ Socket.close sock
            return mh

    newSocket :: IO Socket
    newSocket = do
        proto <- getProtocolNumber "tcp"
        socket (socketFamily sockAddr) Stream proto

    -- Start the reader thread, publish the session in the cache, shut down
    -- the sessions split off by MM.takeView, and return the encoder.
    establish h st = do
        dput XTCP $ "TCP Connected! " ++ show sockAddr
        signal <- newTVarIO False
        rthread <- forkLabeled ("tcp:" ++ show sockAddr) $ readerLoop h st signal
        let session = TCPSession { tcpHandle = h
                                 , tcpState  = st
                                 , tcpThread = rthread
                                 }
        t <- getPOSIXTime
        retires <- atomically $ do
            c <- readTVar (lru tcpcache)
            let (rs, c') = MM.takeView (tcpMax tcpcache) $ MM.insert' key session (Down t) c
            writeTVar (lru tcpcache) c'
            -- Release the reader only once the session is in the cache.
            writeTVar signal True
            return rs
        forM_ retires $ \(MM.Binding (TCPAddress k) r _) ->
            void $ forkLabeled ("tcp-close:" ++ show k) $ do
                dput XTCP $ "TCP dropped: " ++ show k
                killSession r
                case r of
                    TCPSession {tcpState = rst, tcpHandle = rh} -> do
                        streamGoodbye rst
                        hClose rh `catchIOError` \_ -> return ()
                    PendingTCPSession -> return ()
        return $ streamEncode st

    -- Decode inbound messages and hand them to mvar until the stream ends.
    readerLoop h st signal = do
        atomically (readTVar signal >>= check)
        fix $ \loop -> do
            -- An IOError from the decoder is treated like end-of-stream so
            -- the cleanup below still runs.
            x <- streamDecode st `catchIOError` \e -> do
                dput XTCP $ "TCP streamDecode failed: " ++ show e
                return Nothing
            dput XTCP $ "TCP streamDecode " ++ show sockAddr ++ " --> " ++ maybe "Nothing" (const "got") x
            case x of
                Just u -> do
                    m <- timeout tcp_timeout $ atomically (putTMVar mvar $ Arrival addr u)
                    when (isNothing m) $ do
                        -- The consumer did not take the previous arrival in
                        -- time: drop this message and also discard the one
                        -- still waiting in mvar.
                        dput XTCP $ "TCP " ++ show sockAddr ++ " dropped packet."
                        void $ atomically $ tryTakeTMVar mvar
                    loop
                Nothing -> do
                    dput XTCP $ "TCP disconnected: " ++ show sockAddr
                    forget
                    c <- atomically $ readTVar (lru tcpcache)
                    later <- getPOSIXTime
                    forM_ (zip [1 :: Int ..] (MM.toList c)) $ \(i, MM.Binding (TCPAddress a) r (Down tm)) ->
                        dput XTCP $ unwords [show i ++ ".", "Still connected:", show a, show (later - tm), showStat r]
                    mreport <- timeout tcp_timeout $ threadReport False -- XXX: Paranoid timeout
                    case mreport of
                        Just treport -> dput XTCP treport
                        Nothing      -> dput XTCP "TCP ERROR: threadReport timed out."
                    hClose h `catchIOError` \_ -> return ()

-- | Empty the cache and shut down every established session: kill its reader
-- thread, run its goodbye protocol and close its handle. IO errors from
-- goodbye/close are ignored. The 'StreamHandshake' argument is unused.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache _ = do
    dput XTCP "TCP.closeAll called."
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> do
        killSession r
        case r of
            TCPSession {tcpState = st, tcpHandle = h} ->
                catchIOError (streamGoodbye st >> hClose h) (\_ -> return ())
            PendingTCPSession -> return ()

-- | Use a cache of TCP client connections for sending (and receiving) packets.
-- The boolean value prepended to the message allows the sender to specify
-- whether or not a new connection will be initiated if necessary. If 'False'
-- is passed, then the packet will be sent only if there already exists a
-- connection. Each send runs in its own thread; IO errors while acquiring
-- the connection or sending are logged, not propagated.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
             -> StreamHandshake addr x y
             -> IO (TCPCache (SessionProtocol x y), TransportA err addr x (Bool,y))
tcpTransport maxcon stream = do
    msgvar <- atomically newEmptyTMVar
    tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar MM.empty
    return (tcpcache, Transport
        { awaitMessage = \f -> takeTMVar msgvar >>= \x -> return $ do
            f x `catchIOError` (\e -> dput XTCP ("TCP transport stopped. " ++ show e) >> f Terminated)
        , sendMessage = \addr (bDoCon, y) ->
            void . forkLabeled "tcp-send" $
                (acquireConnection msgvar tcpcache stream addr bDoCon >>= mapM_ ($ y))
                    `catchIOError` \e -> dput XTCP $ "TCP-send: " ++ show e
        , setActive = \case
            False -> closeAll tcpcache stream >> atomically (putTMVar msgvar Terminated)
            True  -> return ()
        })