1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where
#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread)
#endif
import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.IO
import System.IO.Error
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse
-- | Everything kept in the cache for one open TCP link.
data TCPSession st = TCPSession
    { tcpHandle :: MVar Handle
      -- ^ The connected handle.  Senders hold this 'MVar' (via 'withMVar')
      --   for the duration of a write; shutdown code takes it before closing.
    , tcpState  :: st
      -- ^ Per-link state produced by 'streamHello' when the link was opened.
    , tcpThread :: ThreadId
      -- ^ The thread that decodes inbound messages from 'tcpHandle'.
    }
-- | A 'SockAddr' wrapped for use as the key of the connection cache
-- ('lru'); this module gives it its own 'Hashable' instance.
newtype TCPAddress = TCPAddress SockAddr
    deriving (Eq, Ord)
-- | Hash an address from its meaningful components.
--
-- IPv4 and IPv6 addresses are hashed from their port (as a 'Word16') and
-- host fields.  Every other address family is hashed through its 'show'
-- rendering.  Unlike the previous constant @0@, this mixes in the salt and
-- distinguishes distinct addresses.  It still agrees with the derived 'Eq'
-- instance, because equal values render identically.
instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress sa) = case sa of
        SockAddrInet port host ->
            hashWithSalt salt (fromIntegral port :: Word16, host)
        SockAddrInet6 port flow host scope ->
            hashWithSalt salt (fromIntegral port :: Word16, flow, host, scope)
        _ -> hashWithSalt salt (show sa)
-- | The set of open TCP sessions, keyed by remote address and ordered by
-- the time each one was last acquired.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress POSIXTime (TCPSession st))
      -- ^ Sessions keyed by address.  The priority is the 'POSIXTime' of the
      --   most recent 'acquireConnection' for that address.
    , tcpMax :: Int
      -- ^ Bound handed to 'MM.takeView' whenever a new session is inserted;
      --   the excess bindings it reports are shut down.
    }
-- | Protocol callbacks that turn a raw TCP 'Handle' into a typed message
-- stream.  @st@ is the per-link state returned by 'streamHello', @x@ the
-- inbound message type and @y@ the outbound message type.
data StreamTransform st x y = StreamTransform
    { streamHello   :: Handle -> IO st
      -- ^ "Hello" protocol, run once on a freshly connected handle.
    , streamGoodbye :: st -> Handle -> IO ()
      -- ^ "Goodbye" protocol, run on a session's handle before it is closed
      --   when the session is evicted or shut down.
    , streamDecode  :: st -> Handle -> IO (Maybe x)
      -- ^ Read the next inbound message; 'Nothing' ends the reader loop.
    , streamEncode  :: st -> y -> IO ByteString
      -- ^ Serialize one outbound message.
    }
-- | Return an action that sends one outbound message to @addr@.  A cached
-- TCP session is reused when there is one; otherwise a new one is opened.
--
-- * Cache hit: the session's priority in 'lru' is refreshed to the
--   current time.
-- * Cache miss: a socket is connected and 'streamHello' is run on its
--   handle.  A reader thread is forked that pushes every decoded inbound
--   message into @mvar@ as @Just (Right (msg, addr))@.  When
--   'streamDecode' yields 'Nothing', the reader removes the session from
--   the cache and closes the handle.  It does /not/ push 'Nothing' into
--   @mvar@, because one peer hanging up is not a shutdown of the shared
--   inbox.
-- * After the new session is inserted, every binding that 'MM.takeView'
--   reports beyond 'tcpMax' is shut down on its own thread.  Shutdown
--   kills the reader, runs 'streamGoodbye' with that session's own
--   'tcpState', and closes the handle.
--
-- Returns 'Nothing' if the TCP connection cannot be established.
--
-- NOTE(review): if 'streamDecode' throws, the reader thread dies without
-- removing the session from the cache.  Two concurrent misses for the same
-- address both connect, and the later insert presumably replaces the
-- earlier session without shutting it down.
acquireConnection :: MVar (Maybe (Either a (x, SockAddr)))
                  -> TCPCache st
                  -> StreamTransform st x y
                  -> SockAddr
                  -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr = do
    cache <- atomically $ readTVar (lru tcpcache)
    case MM.lookup' key cache of
        Just (_, v) -> do
            t <- getPOSIXTime
            -- Modify the live cache rather than writing back the snapshot
            -- read above, so concurrent inserts/deletes are not lost.
            atomically $ modifyTVar' (lru tcpcache) $ MM.insert' key v t
            return $ Just $ sender v
        Nothing -> do
            proto <- getProtocolNumber "tcp"
            sock <- socket (socketFamily addr) Stream proto
            connected <- (True <$ connect sock addr)
                            `catchIOError` (\_ -> False <$ close sock)
            if connected
                then Just <$> openSession sock
                else return Nothing
  where
    key = TCPAddress addr

    -- Encode one message with the session's state and write it while
    -- holding the handle's lock.
    sender (TCPSession { tcpHandle = mh, tcpState = st }) y = do
        bs <- streamEncode stream st y
        withMVar mh (`hPut` bs)

    -- Wrap a connected socket in a fresh session, start its reader thread,
    -- insert it into the cache and shut down whatever gets evicted.
    openSession sock = do
        h <- socketToHandle sock ReadWriteMode
        st <- streamHello stream h
        t <- getPOSIXTime
        mh <- newMVar h
        rthread <- forkIO $ fix $ \loop -> do
            mx <- streamDecode stream st h
            case mx of
                Just x -> do
                    putMVar mvar $ Just $ Right (x, addr)
                    loop
                Nothing -> do
                    atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
                    hClose h
        labelThread rthread ("tcp:"++show addr)
        let v = TCPSession
                  { tcpHandle = mh
                  , tcpState  = st
                  , tcpThread = rthread
                  }
        -- Insert against the current cache contents in one transaction.
        retires <- atomically $ do
            current <- readTVar (lru tcpcache)
            let (rs, current') = MM.takeView (tcpMax tcpcache)
                               $ MM.insert' key v t current
            writeTVar (lru tcpcache) current'
            return rs
        forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
            myThreadId >>= flip labelThread ("tcp-close:"++show k)
            killThread (tcpThread r)
            hr <- takeMVar (tcpHandle r)
            -- The evicted session's own state, not the new session's.
            streamGoodbye stream (tcpState r) hr
            hClose hr
        return $ sender v
-- | Shut down every cached TCP session.  For each one this kills its
-- reader thread, runs 'streamGoodbye' with the session's own 'tcpState'
-- and closes its handle.
--
-- The cache is emptied atomically before any session is touched.
-- Previously the shut-down sessions, whose handle 'MVar's had been taken,
-- stayed in the cache.  A second 'closeAll' then blocked forever in
-- 'takeMVar', and so did any later send that hit the cache in 'withMVar'.
closeAll :: TCPCache st -> StreamTransform st x y -> IO ()
closeAll tcpcache stream = do
    sessions <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList sessions) $ \(MM.Binding _ r _) -> do
        killThread (tcpThread r)
        h <- takeMVar (tcpHandle r)
        streamGoodbye stream (tcpState r) h
        hClose h
-- | Build a 'TransportA' that carries messages over TCP.
--
-- Each send goes through 'acquireConnection', which reuses or opens a
-- cached link to the destination.  Inbound messages from every link's
-- reader thread arrive in a single shared 'MVar', which 'awaitMessage'
-- takes from.  'closeTransport' shuts down all cached links via 'closeAll'.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
             -> StreamTransform st x y
             -> IO (TransportA err SockAddr x y)
tcpTransport maxcon stream = do
    inbox  <- newEmptyMVar
    lruvar <- atomically $ newTVar MM.empty
    let cache = TCPCache { lru = lruvar, tcpMax = maxcon }
        send addr y = acquireConnection inbox cache stream addr >>= mapM_ ($ y)
    return Transport
        { awaitMessage   = \kont -> takeMVar inbox >>= kont
        , sendMessage    = send
        , closeTransport = closeAll cache stream
        }
|