{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread)
#endif

import Control.Arrow
import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Maybe
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.Timeout
import System.IO
import System.IO.Error

import DebugTag
import DPut
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | A cache entry for one TCP peer.
--
-- 'PendingTCPSession' is a placeholder that marks a connection attempt as
-- in progress; 'TCPSession' is an established link.
data TCPSession st
    = PendingTCPSession
    | TCPSession
        { tcpHandle :: Handle   -- ^ Handle wrapping the connected socket.
        , tcpState  :: st       -- ^ Per-session state (in this module, the
                                --   'SessionProtocol' from the handshake).
        , tcpThread :: ThreadId -- ^ Thread attached to the session; killed
                                --   by 'killSession'.
        }

-- | Cache key: a 'SockAddr' wrapped so it can be given a 'Hashable' instance.
newtype TCPAddress = TCPAddress SockAddr
    deriving (Eq,Ord,Show)

instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr ->
            hashWithSalt salt (fromIntegral port :: Word16, addr)
        SockAddrInet6 port b c d ->
            hashWithSalt salt (fromIntegral port :: Word16, b, c, d)
        -- Other address families (e.g. Unix-domain paths) get no field-wise
        -- hash here.  Hash their 'Show' rendering instead of a constant so
        -- the salt is honoured and distinct addresses spread out; equal
        -- addresses render identically, so this stays consistent with 'Eq'.
        _ -> hashWithSalt salt (show x)

-- | A bounded cache of TCP sessions keyed by peer address.  The priority of
-- an entry is the time it was last touched, wrapped in 'Down'.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
    , tcpMax :: Int -- ^ Maximum number of sessions to retain.
    }

-- | The operations available on an established stream session.
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO ()        -- ^ "Goodbye" protocol upon termination.
    , streamDecode  :: IO (Maybe x) -- ^ Parse inbound messages; 'Nothing'
                                    --   is treated as end of stream.
    , streamEncode  :: y -> IO ()   -- ^ Serialize outbound messages.
    }

-- | How to reach a peer and start a session with it.
data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y)
      -- ^ "Hello" protocol upon fresh connection.
    , streamAddr  :: addr -> SockAddr
      -- ^ Socket address used to connect to (and cache) a peer.
    }

-- | Kill the thread of an established session; a pending session has no
-- thread, so this is a no-op for it.
killSession :: TCPSession st -> IO ()
killSession PendingTCPSession = return ()
killSession TCPSession{tcpThread=t} = killThread t

-- | A short description of a session's state, used in debug output.
showStat :: TCPSession st -> String
showStat r = case r of
    PendingTCPSession -> "pending."
    TCPSession {}     -> "established."
-- | Get a function that sends messages to @addr@ over a cached TCP session,
-- connecting first if necessary.
--
-- * If an established session is cached, it is reused and its cache
--   timestamp is refreshed.
-- * If another call is already connecting to @addr@ and @bDoCon@ is set,
--   wait up to 10 seconds for that attempt to finish.
-- * Otherwise, if @bDoCon@ is set, connect (10 second timeout) and run the
--   'streamHello' handshake.  Then fork a reader thread that forwards
--   decoded messages into @mvar@, and insert the session into the cache.
--   Any sessions pushed out of the cache by 'MM.takeView' are shut down in
--   background threads.
--
-- Returns 'Nothing' in two cases: no session could be obtained, or @bDoCon@
-- is 'False' and no established session is cached.  A failed connection or
-- handshake removes the pending placeholder so that a later call can retry.
acquireConnection :: MVar (Maybe (Either a (x, addr)))
                  -> TCPCache (SessionProtocol x y)
                  -> StreamHandshake addr x y
                  -> addr
                  -> Bool -- ^ Whether to connect if no session is cached.
                  -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr bDoCon = do
    now <- getPOSIXTime
    -- Look the peer up.  If we are about to connect, claim the slot with a
    -- placeholder so concurrent callers wait for us instead of racing; if
    -- the peer is already present, refresh its timestamp.
    entry <- atomically $ do
        c <- readTVar (lru tcpcache)
        let found = MM.lookup' key c
        case found of
            Nothing
                | bDoCon    -> writeTVar (lru tcpcache) $ MM.insert' key PendingTCPSession (Down now) c
                | otherwise -> return ()
            Just (_, s) -> modifyTVar' (lru tcpcache) $ MM.insert' key s (Down now)
        return found
    case entry of
        Nothing
            | bDoCon    -> connectFresh
            | otherwise -> return Nothing
        Just (_, PendingTCPSession)
            | not bDoCon -> return Nothing
            | otherwise  -> awaitPending
        Just (_, TCPSession{tcpState = st}) -> return $ Just $ streamEncode st
  where
    sockAddr = streamAddr stream addr
    key      = TCPAddress sockAddr

    -- Wait (at most 10s) for a concurrent connection attempt to the same
    -- peer.  If that attempt fails, it removes its placeholder and we give
    -- up; if it succeeds, we use its session.
    awaitPending = fmap join $ timeout 10000000 $ atomically $ do
        c <- readTVar (lru tcpcache)
        case MM.lookup' key c of
            Just (_, TCPSession{tcpState = st}) -> return $ Just $ streamEncode st
            Nothing                             -> return Nothing
            Just (_, PendingTCPSession)         -> retry

    -- Open a new session.  The caller has already inserted a
    -- 'PendingTCPSession' placeholder for 'key'; remove it again on failure.
    connectFresh = do
        mh <- openHandle
        ret <- fmap join $ forM mh establish
        when (isNothing ret) $
            atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
        return ret

    -- Connect a stream socket to 'sockAddr', waiting at most 10 seconds for
    -- the connect.  Any IO error yields 'Nothing'.  The socket is closed
    -- whenever the connect fails or times out, so it cannot leak.
    openHandle = flip catchIOError (\_ -> return Nothing) $ do
        -- Minimal systems may lack a protocols database; protocol 0 selects
        -- the default (TCP) protocol for a stream socket.
        proto <- getProtocolNumber "tcp" `catchIOError` \_ -> return defaultProtocol
        sock <- socket (socketFamily sockAddr) Stream proto
        connected <- timeout 10000000 (connect sock sockAddr)
                         `catchIOError` \_ -> return Nothing
        case connected of
            Nothing -> do
                close sock
                return Nothing
            Just () -> do
                h <- socketToHandle sock ReadWriteMode
                hSetBuffering h NoBuffering
                return (Just h)

    -- Run the handshake on a freshly connected handle, start its reader
    -- thread, and publish the session in the cache.  If the handshake fails
    -- with an IO error, the handle is closed and 'Nothing' is returned.
    establish h = do
        mst <- (Just <$> streamHello stream addr h) `catchIOError` \e -> do
            dput XTCP $ "TCP handshake failed: " ++ show sockAddr ++ ": " ++ show e
            hClose h `catchIOError` \_ -> return ()
            return Nothing
        forM mst $ \st -> do
            dput XTCP $ "TCP Connected! " ++ show sockAddr
            signal <- newTVarIO False
            rthread <- forkIO $ do
                -- Do not start reading until the session is in the cache.
                -- 'signal' is set in the same transaction as the insert, so
                -- a quick disconnect cannot delete the entry before it exists.
                atomically (readTVar signal >>= check)
                readerLoop h st
            labelThread rthread ("tcp:" ++ show sockAddr)
            let session = TCPSession { tcpHandle = h, tcpState = st, tcpThread = rthread }
            t <- getPOSIXTime
            retires <- atomically $ do
                c <- readTVar (lru tcpcache)
                let (rs, c') = MM.takeView (tcpMax tcpcache) $ MM.insert' key session (Down t) c
                writeTVar (lru tcpcache) c'
                writeTVar signal True
                return rs
            forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
                myThreadId >>= flip labelThread ("tcp-close:" ++ show k)
                dput XTCP $ "TCP dropped: " ++ show k
                shutdownTCPSession r
            return $ streamEncode st

    -- Forward decoded inbound messages to 'mvar' until the stream ends.  An
    -- IO error while reading is treated as the end of the stream.  When the
    -- stream ends, remove this peer from the cache and close the handle.
    readerLoop h st = fix $ \loop -> do
        x <- streamDecode st `catchIOError` \e -> do
            dput XTCP $ "TCP read error: " ++ show sockAddr ++ ": " ++ show e
            return Nothing
        dput XTCP $ "TCP streamDecode " ++ show sockAddr ++ " --> " ++ maybe "Nothing" (const "got") x
        case x of
            Just u -> do
                -- If nobody consumes within 1s, drop this message.  Also
                -- evict whatever message is stuck in the MVar, to unjam it.
                m <- timeout 1000000 $ putMVar mvar $ Just $ Right (u, addr)
                when (isNothing m) $ do
                    dput XTCP $ "TCP " ++ show sockAddr ++ " dropped packet."
                    void $ tryTakeMVar mvar
                loop
            Nothing -> do
                dput XTCP $ "TCP disconnected: " ++ show sockAddr
                atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
                c <- atomically $ readTVar (lru tcpcache)
                t <- getPOSIXTime
                forM_ (zip [1 :: Int ..] $ MM.toList c) $ \(i, MM.Binding (TCPAddress a) r (Down tm)) ->
                    dput XTCP $ unwords [show i ++ ".", "Still connected:", show a, show (t - tm), showStat r]
                hClose h

-- | Stop a session's thread.  If the session was established, also run the
-- protocol's goodbye and close its handle.  IO errors from the goodbye or
-- the close are logged rather than propagated.  As a result, one
-- misbehaving peer cannot keep its handle open or abort a bulk shutdown.
shutdownTCPSession :: TCPSession (SessionProtocol x y) -> IO ()
shutdownTCPSession r = do
    killSession r
    case r of
        TCPSession{tcpState = st, tcpHandle = h} -> do
            streamGoodbye st `catchIOError` report
            hClose h `catchIOError` report
        PendingTCPSession -> return ()
  where
    report e = dput XTCP $ "TCP close error: " ++ show e

-- | Empty the cache and shut down every session it held.  The
-- 'StreamHandshake' argument is not used.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache _ = do
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> shutdownTCPSession r

-- | Create a TCP transport backed by a cache of at most @maxcon@ sessions.
--
-- Each 'sendMessage' runs in its own thread.  The 'Bool' in its payload says
-- whether to open a connection when none is cached.  Inbound messages from
-- every session's reader thread are delivered through a single 'MVar' that
-- 'awaitMessage' takes from.  The cache is also returned, so that callers
-- can inspect it.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
             -> StreamHandshake addr x y
             -> IO (TCPCache (SessionProtocol x y), TransportA err addr x (Bool,y))
tcpTransport maxcon stream = do
    msgvar <- newEmptyMVar
    tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar MM.empty
    let transport = Transport
            { awaitMessage = \f ->
                takeMVar msgvar >>= \x ->
                    -- If the consumer fails with an IO error, tell it the
                    -- transport has stopped.
                    f x `catchIOError` \e -> do
                        dput XTCP ("TCP transport stopped. " ++ show e)
                        f Nothing
            , sendMessage = \addr (bDoCon, y) -> do
                -- Cover the connection attempt as well as the send, so that
                -- failures are logged instead of killing the thread silently.
                t <- forkIO $
                    (acquireConnection msgvar tcpcache stream addr bDoCon >>= mapM_ ($ y))
                        `catchIOError` \e -> dput XTCP $ "TCP-send: " ++ show e
                labelThread t "tcp-send"
            , closeTransport = closeAll tcpcache stream
            }
    return (tcpcache, transport)