{-# LANGUAGE CPP                        #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | A stream-oriented (TCP) transport for "Network.QueryResponse".
--
-- Sending to an address opens a TCP connection on demand, or reuses a
-- cached one.  Every connection gets a reader thread that decodes inbound
-- messages and hands them to 'awaitMessage'.  The number of cached
-- connections is bounded; bindings in excess of the bound are retired
-- (closed) as chosen by 'MM.takeView'.
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread)
#endif

import Control.Concurrent.STM
import Control.Exception (onException)
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.IO
import System.IO.Error

import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | An open TCP connection, its protocol state and the thread reading it.
data TCPSession st = TCPSession
    { tcpHandle :: Handle    -- ^ Handle wrapping the connected socket.
    , tcpState  :: st        -- ^ Per-connection protocol state.
    , tcpThread :: ThreadId  -- ^ Reader thread decoding inbound messages.
    }

-- | A 'SockAddr' wrapper used as the connection-cache key.
newtype TCPAddress = TCPAddress SockAddr
    deriving (Eq,Ord)

-- | Hashes IPv4 and IPv6 addresses by their components; every other
-- address family hashes to the constant 0.
instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr   -> hashWithSalt salt (fromIntegral port :: Word16,addr)
        SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
        _                        -> 0

-- | Cache of open connections.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
      -- ^ Open sessions keyed by address; the priority is the time of
      -- last use wrapped in 'Down'.
    , tcpMax :: Int
      -- ^ Bound passed to 'MM.takeView' when a new session is inserted.
    }

-- | Per-connection protocol operations produced by 'streamHello'.
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO ()           -- ^ "Goodbye" protocol upon termination.
    , streamDecode  :: IO (Maybe x)    -- ^ Parse inbound messages.
    , streamEncode  :: y -> IO ()      -- ^ Serialize outbound messages.
    }

-- | How to establish a session with a peer.
data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y)
      -- ^ "Hello" protocol upon fresh connection.
    , streamAddr  :: addr -> SockAddr
      -- ^ Socket address to connect to for a peer address.
    }

-- | Obtain a send function for @addr@.
--
-- If a session for @addr@ is cached, its last-use time is refreshed and
-- its encoder is returned.  Otherwise a new TCP connection is opened,
-- 'streamHello' is run on it, a reader thread is forked and the session
-- is inserted into the cache.  Sessions pushed out by 'MM.takeView' are
-- then shut down on separate threads.  Each shutdown kills the reader,
-- runs that session's own 'streamGoodbye' and closes its handle.
--
-- The reader thread forwards every decoded message to @mvar@ as
-- @Just (Right (msg, addr))@.  When decoding yields 'Nothing' or fails
-- with an IO error, the session is removed from the cache and its handle
-- is closed.  Nothing is written to @mvar@ in that case, because a
-- 'Nothing' there would make the whole transport look closed.
--
-- Returns 'Nothing' when the connection cannot be established; the socket
-- is closed in that case.  Exceptions from 'getProtocolNumber' and
-- 'streamHello' propagate to the caller.  The handle is closed first if
-- 'streamHello' fails.
--
-- NOTE(review): two concurrent calls for the same uncached address may
-- both connect, and the later insert replaces the earlier session in the
-- cache.
acquireConnection :: MVar (Maybe (Either a (x, addr)))
                  -> TCPCache (SessionProtocol x y)
                  -> StreamHandshake addr x y
                  -> addr
                  -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr = do
    cache <- atomically $ readTVar (lru tcpcache)
    case MM.lookup' key cache of
        Nothing -> do
            mh <- openHandle
            case mh of
                Nothing -> return Nothing
                Just h  -> do
                    st <- streamHello stream addr h `onException` hClose h
                    t <- getPOSIXTime
                    rthread <- forkIO $ readLoop h st
                    labelThread rthread ("tcp:" ++ show sockaddr)
                    let v = TCPSession { tcpHandle = h
                                       , tcpState  = st
                                       , tcpThread = rthread
                                       }
                    -- Read-modify-write in one transaction so concurrent
                    -- deletions by reader threads are not overwritten.
                    retires <- atomically $ do
                        c <- readTVar (lru tcpcache)
                        let (rs, c') = MM.takeView (tcpMax tcpcache)
                                     $ MM.insert' key v (Down t) c
                        writeTVar (lru tcpcache) c'
                        return rs
                    forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
                        myThreadId >>= flip labelThread ("tcp-close:" ++ show k)
                        killThread (tcpThread r)
                        -- Say goodbye on the retired session, not the new one.
                        streamGoodbye (tcpState r)
                        hClose (tcpHandle r)
                    return $ Just $ streamEncode st
        Just (_, v) -> do
            t <- getPOSIXTime
            -- Only refresh the priority if the session is still cached; its
            -- reader thread may have removed it since the lookup above.
            atomically $ modifyTVar' (lru tcpcache) $ \c ->
                case MM.lookup' key c of
                    Just _  -> MM.insert' key v (Down t) c
                    Nothing -> c
            return $ Just $ streamEncode (tcpState v)
  where
    sockaddr = streamAddr stream addr
    key      = TCPAddress sockaddr

    -- Open a connected handle; IO errors while connecting yield Nothing.
    openHandle = do
        proto <- getProtocolNumber "tcp"
        catchIOError (Just <$> dialTCP proto) (\_ -> return Nothing)

    -- Create and connect a socket, closing it if connecting fails.
    dialTCP proto = do
        sock <- socket (socketFamily sockaddr) Stream proto
        (connect sock sockaddr >> socketToHandle sock ReadWriteMode)
            `onException` close sock

    -- Forward decoded messages until the stream ends, then clean up.
    readLoop h st = fix $ \loop -> do
        x <- streamDecode st `catchIOError` (\_ -> return Nothing)
        case x of
            Just u  -> do
                putMVar mvar $ Just $ Right (u, addr)
                loop
            Nothing -> do
                atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
                hClose h

-- | Shut down every cached session.  Each one has its reader thread killed,
-- its 'streamGoodbye' run and its handle closed.  The cache is emptied
-- first.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache _stream = do
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> do
        killThread (tcpThread r)
        streamGoodbye (tcpState r)
        hClose (tcpHandle r)

-- | Create a TCP transport.
--
-- 'sendMessage' connects on demand via 'acquireConnection'.  The message is
-- silently dropped when no connection can be established.  'awaitMessage'
-- blocks until some reader thread delivers a message.  'closeTransport'
-- shuts down all cached sessions.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
             -> StreamHandshake addr x y
             -> IO (TransportA err addr x y)
tcpTransport maxcon stream = do
    msgvar <- newEmptyMVar
    tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar MM.empty
    return Transport
        { awaitMessage   = (takeMVar msgvar >>=)
        , sendMessage    = \addr y -> do
              msock <- acquireConnection msgvar tcpcache stream addr
              mapM_ ($ y) msock
        , closeTransport = closeAll tcpcache stream
        }