diff options
author | Joe Crayne <joe@jerkface.net> | 2018-12-02 00:07:39 -0500 |
---|---|---|
committer | Joe Crayne <joe@jerkface.net> | 2018-12-16 14:08:26 -0500 |
commit | e15b8709e5091808a50630372f278fcbd844d638 (patch) | |
tree | 7a1d0ce7198ee3c35da67ef5fd31d4acfd56deff /src/Network | |
parent | 7a20395e8fe10625a239337aba24c3480e9b5e45 (diff) |
TCP-based Network.QueryResponse.Transport.
Diffstat (limited to 'src/Network')
-rw-r--r-- | src/Network/QueryResponse/TCP.hs | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/src/Network/QueryResponse/TCP.hs b/src/Network/QueryResponse/TCP.hs new file mode 100644 index 00000000..83ae367f --- /dev/null +++ b/src/Network/QueryResponse/TCP.hs | |||
@@ -0,0 +1,128 @@ | |||
1 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
2 | {-# LANGUAGE CPP #-} | ||
3 | module Network.QueryResponse.TCP where | ||
4 | |||
5 | #ifdef THREAD_DEBUG | ||
6 | import Control.Concurrent.Lifted.Instrument | ||
7 | #else | ||
8 | import Control.Concurrent.Lifted | ||
9 | import GHC.Conc (labelThread) | ||
10 | #endif | ||
11 | |||
12 | import Control.Concurrent.STM | ||
13 | import Control.Monad | ||
14 | import Data.ByteString (ByteString,hPut) | ||
15 | import Data.Function | ||
16 | import Data.Hashable | ||
17 | import Data.Time.Clock.POSIX | ||
18 | import Data.Word | ||
19 | import Network.BSD | ||
20 | import Network.Socket | ||
21 | import System.IO | ||
22 | import System.IO.Error | ||
23 | |||
24 | import Connection.Tcp (socketFamily) | ||
25 | import qualified Data.MinMaxPSQ as MM | ||
26 | import Network.QueryResponse | ||
27 | |||
28 | data TCPSession st = TCPSession | ||
29 | { tcpHandle :: MVar Handle | ||
30 | , tcpState :: st | ||
31 | , tcpThread :: ThreadId | ||
32 | } | ||
33 | |||
34 | newtype TCPAddress = TCPAddress SockAddr | ||
35 | deriving (Eq,Ord) | ||
36 | |||
37 | instance Hashable TCPAddress where | ||
38 | hashWithSalt salt (TCPAddress x) = case x of | ||
39 | SockAddrInet port addr -> hashWithSalt salt (fromIntegral port :: Word16,addr) | ||
40 | SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d) | ||
41 | _ -> 0 | ||
42 | |||
43 | data TCPCache st = TCPCache | ||
44 | { lru :: TVar (MM.MinMaxPSQ' TCPAddress POSIXTime (TCPSession st)) | ||
45 | , tcpMax :: Int | ||
46 | } | ||
47 | |||
48 | data StreamTransform st x y = StreamTransform | ||
49 | { streamHello :: Handle -> IO st -- ^ "Hello" protocol upon fresh connection. | ||
50 | , streamGoodbye :: st -> Handle -> IO () -- ^ "Goodbye" protocol upon termination. | ||
51 | , streamDecode :: st -> Handle -> IO (Maybe x) -- ^ Parse inbound messages. | ||
52 | , streamEncode :: st -> y -> IO ByteString -- ^ Serialize outbound messages. | ||
53 | } | ||
54 | |||
55 | acquireConnection :: MVar (Maybe (Either a (x, SockAddr))) | ||
56 | -> TCPCache st | ||
57 | -> StreamTransform st x y | ||
58 | -> SockAddr | ||
59 | -> IO (Maybe (y -> IO ())) | ||
60 | acquireConnection mvar tcpcache stream addr = do | ||
61 | cache <- atomically $ readTVar (lru tcpcache) | ||
62 | case MM.lookup' (TCPAddress addr) cache of | ||
63 | Nothing -> do | ||
64 | proto <- getProtocolNumber "tcp" | ||
65 | sock <- socket (socketFamily addr) Stream proto | ||
66 | connect sock addr `catchIOError` (\e -> close sock) | ||
67 | h <- socketToHandle sock ReadWriteMode | ||
68 | st <- streamHello stream h | ||
69 | t <- getPOSIXTime | ||
70 | mh <- newMVar h | ||
71 | rthread <- forkIO $ fix $ \loop -> do | ||
72 | x <- streamDecode stream st h | ||
73 | putMVar mvar $ fmap (\u -> Right (u, addr)) x | ||
74 | case x of | ||
75 | Just _ -> loop | ||
76 | Nothing -> do | ||
77 | atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress addr) | ||
78 | hClose h | ||
79 | labelThread rthread ("tcp:"++show addr) | ||
80 | let v = TCPSession | ||
81 | { tcpHandle = mh | ||
82 | , tcpState = st | ||
83 | , tcpThread = rthread | ||
84 | } | ||
85 | let (retires,cache') = MM.takeView (tcpMax tcpcache) $ MM.insert' (TCPAddress addr) v t cache | ||
86 | forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do | ||
87 | myThreadId >>= flip labelThread ("tcp-close:"++show k) | ||
88 | killThread (tcpThread r) | ||
89 | h <- takeMVar (tcpHandle r) | ||
90 | streamGoodbye stream st h | ||
91 | hClose h | ||
92 | atomically $ writeTVar (lru tcpcache) cache' | ||
93 | |||
94 | return $ Just $ \y -> do | ||
95 | bs <- streamEncode stream st y | ||
96 | withMVar mh (`hPut` bs) | ||
97 | Just (tm,v) -> do | ||
98 | t <- getPOSIXTime | ||
99 | let TCPSession { tcpHandle = mh, tcpState = st } = v | ||
100 | cache' = MM.insert' (TCPAddress addr) v t cache | ||
101 | atomically $ writeTVar (lru tcpcache) cache' | ||
102 | return $ Just $ \y -> do | ||
103 | bs <- streamEncode stream st y | ||
104 | withMVar mh (`hPut` bs) | ||
105 | |||
106 | closeAll :: TCPCache st -> StreamTransform st x y -> IO () | ||
107 | closeAll tcpcache stream = do | ||
108 | cache <- atomically $ readTVar (lru tcpcache) | ||
109 | forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r tm) -> do | ||
110 | let st = tcpState r | ||
111 | killThread (tcpThread r) | ||
112 | h <- takeMVar $ tcpHandle r | ||
113 | streamGoodbye stream st h | ||
114 | hClose h | ||
115 | |||
116 | tcpTransport :: Int -- ^ maximum number of TCP links to maintain. | ||
117 | -> StreamTransform st x y | ||
118 | -> IO (TransportA err SockAddr x y) | ||
119 | tcpTransport maxcon stream = do | ||
120 | msgvar <- newEmptyMVar | ||
121 | tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar (MM.empty) | ||
122 | return Transport | ||
123 | { awaitMessage = (takeMVar msgvar >>=) | ||
124 | , sendMessage = \addr y -> do | ||
125 | msock <- acquireConnection msgvar tcpcache stream addr | ||
126 | mapM_ ($ y) msock | ||
127 | , closeTransport = closeAll tcpcache stream | ||
128 | } | ||