summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Network/QueryResponse/TCP.hs128
1 files changed, 128 insertions, 0 deletions
diff --git a/src/Network/QueryResponse/TCP.hs b/src/Network/QueryResponse/TCP.hs
new file mode 100644
index 00000000..83ae367f
--- /dev/null
+++ b/src/Network/QueryResponse/TCP.hs
@@ -0,0 +1,128 @@
1{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2{-# LANGUAGE CPP #-}
3module Network.QueryResponse.TCP where
4
5#ifdef THREAD_DEBUG
6import Control.Concurrent.Lifted.Instrument
7#else
8import Control.Concurrent.Lifted
9import GHC.Conc (labelThread)
10#endif
11
12import Control.Concurrent.STM
13import Control.Monad
14import Data.ByteString (ByteString,hPut)
15import Data.Function
16import Data.Hashable
17import Data.Time.Clock.POSIX
18import Data.Word
19import Network.BSD
20import Network.Socket
21import System.IO
22import System.IO.Error
23
24import Connection.Tcp (socketFamily)
25import qualified Data.MinMaxPSQ as MM
26import Network.QueryResponse
27
28data TCPSession st = TCPSession
29 { tcpHandle :: MVar Handle
30 , tcpState :: st
31 , tcpThread :: ThreadId
32 }
33
34newtype TCPAddress = TCPAddress SockAddr
35 deriving (Eq,Ord)
36
37instance Hashable TCPAddress where
38 hashWithSalt salt (TCPAddress x) = case x of
39 SockAddrInet port addr -> hashWithSalt salt (fromIntegral port :: Word16,addr)
40 SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
41 _ -> 0
42
43data TCPCache st = TCPCache
44 { lru :: TVar (MM.MinMaxPSQ' TCPAddress POSIXTime (TCPSession st))
45 , tcpMax :: Int
46 }
47
48data StreamTransform st x y = StreamTransform
49 { streamHello :: Handle -> IO st -- ^ "Hello" protocol upon fresh connection.
50 , streamGoodbye :: st -> Handle -> IO () -- ^ "Goodbye" protocol upon termination.
51 , streamDecode :: st -> Handle -> IO (Maybe x) -- ^ Parse inbound messages.
52 , streamEncode :: st -> y -> IO ByteString -- ^ Serialize outbound messages.
53 }
54
55acquireConnection :: MVar (Maybe (Either a (x, SockAddr)))
56 -> TCPCache st
57 -> StreamTransform st x y
58 -> SockAddr
59 -> IO (Maybe (y -> IO ()))
60acquireConnection mvar tcpcache stream addr = do
61 cache <- atomically $ readTVar (lru tcpcache)
62 case MM.lookup' (TCPAddress addr) cache of
63 Nothing -> do
64 proto <- getProtocolNumber "tcp"
65 sock <- socket (socketFamily addr) Stream proto
66 connect sock addr `catchIOError` (\e -> close sock)
67 h <- socketToHandle sock ReadWriteMode
68 st <- streamHello stream h
69 t <- getPOSIXTime
70 mh <- newMVar h
71 rthread <- forkIO $ fix $ \loop -> do
72 x <- streamDecode stream st h
73 putMVar mvar $ fmap (\u -> Right (u, addr)) x
74 case x of
75 Just _ -> loop
76 Nothing -> do
77 atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress addr)
78 hClose h
79 labelThread rthread ("tcp:"++show addr)
80 let v = TCPSession
81 { tcpHandle = mh
82 , tcpState = st
83 , tcpThread = rthread
84 }
85 let (retires,cache') = MM.takeView (tcpMax tcpcache) $ MM.insert' (TCPAddress addr) v t cache
86 forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
87 myThreadId >>= flip labelThread ("tcp-close:"++show k)
88 killThread (tcpThread r)
89 h <- takeMVar (tcpHandle r)
90 streamGoodbye stream st h
91 hClose h
92 atomically $ writeTVar (lru tcpcache) cache'
93
94 return $ Just $ \y -> do
95 bs <- streamEncode stream st y
96 withMVar mh (`hPut` bs)
97 Just (tm,v) -> do
98 t <- getPOSIXTime
99 let TCPSession { tcpHandle = mh, tcpState = st } = v
100 cache' = MM.insert' (TCPAddress addr) v t cache
101 atomically $ writeTVar (lru tcpcache) cache'
102 return $ Just $ \y -> do
103 bs <- streamEncode stream st y
104 withMVar mh (`hPut` bs)
105
106closeAll :: TCPCache st -> StreamTransform st x y -> IO ()
107closeAll tcpcache stream = do
108 cache <- atomically $ readTVar (lru tcpcache)
109 forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r tm) -> do
110 let st = tcpState r
111 killThread (tcpThread r)
112 h <- takeMVar $ tcpHandle r
113 streamGoodbye stream st h
114 hClose h
115
116tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
117 -> StreamTransform st x y
118 -> IO (TransportA err SockAddr x y)
119tcpTransport maxcon stream = do
120 msgvar <- newEmptyMVar
121 tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar (MM.empty)
122 return Transport
123 { awaitMessage = (takeMVar msgvar >>=)
124 , sendMessage = \addr y -> do
125 msock <- acquireConnection msgvar tcpcache stream addr
126 mapM_ ($ y) msock
127 , closeTransport = closeAll tcpcache stream
128 }