1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where
#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc (labelThread)
#endif
import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.IO
import System.IO.Error
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse
-- | One open TCP session as stored in the connection cache.
data TCPSession st = TCPSession
  { tcpHandle :: Handle
    -- ^ Handle wrapping the connected socket.
  , tcpState  :: st
    -- ^ Per-session state; 'acquireConnection' stores the
    --   'SessionProtocol' returned by 'streamHello' here.
  , tcpThread :: ThreadId
    -- ^ Thread that loops on 'streamDecode' for this session.
  }
-- | Newtype over 'SockAddr' used as the key of the session cache
-- ('TCPCache'); this module gives it its own 'Hashable' instance.
newtype TCPAddress = TCPAddress SockAddr
  deriving (Eq, Ord)
-- | Hashes the port (as a 'Word16') together with the remaining fields of
-- IPv4 and IPv6 addresses. Any other address family hashes to the constant
-- @0@ regardless of the salt, which is weak but still consistent with the
-- derived 'Eq'.
instance Hashable TCPAddress where
  hashWithSalt salt (TCPAddress sockaddr) = hashAddr sockaddr
    where
      hashAddr (SockAddrInet port host) =
        hashWithSalt salt (fromIntegral port :: Word16, host)
      hashAddr (SockAddrInet6 port flow host scope) =
        hashWithSalt salt (fromIntegral port :: Word16, flow, host, scope)
      hashAddr _ = 0
-- | Bounded cache of open TCP sessions, keyed by remote address.
data TCPCache st = TCPCache
  { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
    -- ^ Sessions prioritised by @'Down'@ of a POSIX timestamp; in this
    --   module the timestamp is refreshed each time a session is acquired.
  , tcpMax :: Int
    -- ^ Maximum number of sessions to keep; presumably the surplus chosen by
    --   @MM.takeView@ is the least recently used — confirm against MinMaxPSQ.
  }
-- | Operations available on an established stream session.
data SessionProtocol x y = SessionProtocol
  { streamGoodbye :: IO ()
    -- ^ \"Goodbye\" protocol, run before the connection is torn down.
  , streamDecode  :: IO (Maybe x)
    -- ^ Read the next inbound message; 'Nothing' is treated as end of stream.
  , streamEncode  :: y -> IO ()
    -- ^ Serialize and send one outbound message.
  }
-- | How to reach a peer and start a session with it.
data StreamHandshake addr x y = StreamHandshake
  { streamHello :: addr -> Handle -> IO (SessionProtocol x y)
    -- ^ \"Hello\" protocol, run on a freshly connected handle; yields the
    --   session's protocol operations.
  , streamAddr  :: addr -> SockAddr
    -- ^ Socket address to connect to for a given peer.
  }
-- | Return a sender for the TCP session to @addr@, connecting if needed.
--
-- On a cache hit the session's timestamp is refreshed and its
-- 'streamEncode' is returned. On a miss a socket is connected, 'streamHello'
-- is run, a reader thread is forked that forwards every decoded message to
-- the shared inbox tagged with @addr@, and the session is inserted into the
-- cache. Sessions pushed out by the 'tcpMax' limit are shut down in the
-- background ('streamGoodbye' on their own protocol, then 'hClose').
--
-- Cache reads and writes are each done in a single STM transaction, so a
-- concurrent delete (e.g. by a reader thread hitting end of stream) or a
-- concurrent insert is never overwritten by a stale snapshot.
acquireConnection :: MVar (Maybe (Either a (x, addr)))
                     -- ^ Shared inbox that reader threads deliver messages to.
                  -> TCPCache (SessionProtocol x y)
                     -- ^ Session cache to consult and update.
                  -> StreamHandshake addr x y
                     -- ^ How to connect to and greet a new peer.
                  -> addr
                     -- ^ Destination peer.
                  -> IO (Maybe (y -> IO ()))
                     -- ^ Sender for the session, or 'Nothing' if the TCP
                     --   connect failed.
acquireConnection mvar tcpcache stream addr = do
  cached <- getPOSIXTime >>= atomically . touchCached
  case cached of
    Just sess -> return $ Just $ streamEncode (tcpState sess)
    Nothing   -> connectFresh
  where
    sockaddr = streamAddr stream addr
    key      = TCPAddress sockaddr

    -- Look up a cached session and bump its timestamp in the same
    -- transaction, so a session deleted concurrently is not resurrected.
    touchCached now = do
      cache <- readTVar (lru tcpcache)
      case MM.lookup' key cache of
        Nothing        -> return Nothing
        Just (_, sess) -> do
          writeTVar (lru tcpcache) $ MM.insert' key sess (Down now) cache
          return (Just sess)

    -- Open, greet and cache a brand-new session.
    -- NOTE(review): two concurrent misses for the same address both connect;
    -- the later insert replaces the earlier session without closing it.
    connectFresh = do
      proto <- getProtocolNumber "tcp"
      sock  <- socket (socketFamily sockaddr) Stream proto
      -- A failed connect used to fall through to socketToHandle on the
      -- already-closed socket; report the failure as Nothing instead.
      connected <- (connect sock sockaddr >> return True)
                     `catchIOError` (\_ -> close sock >> return False)
      if not connected
        then return Nothing
        else do
          h  <- socketToHandle sock ReadWriteMode
          st <- streamHello stream addr h
          rthread <- forkIO $ readLoop h st
          labelThread rthread ("tcp:" ++ show sockaddr)
          now <- getPOSIXTime
          let sess = TCPSession
                { tcpHandle = h
                , tcpState  = st
                , tcpThread = rthread
                }
          -- Insert and evict against the *current* cache contents.
          retires <- atomically $ do
            cache <- readTVar (lru tcpcache)
            let (evicted, cache') = MM.takeView (tcpMax tcpcache)
                                  $ MM.insert' key sess (Down now) cache
            writeTVar (lru tcpcache) cache'
            return evicted
          forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
            myThreadId >>= flip labelThread ("tcp-close:" ++ show k)
            killThread (tcpThread r)
            -- Say goodbye on the *retired* session's protocol; a failure
            -- (e.g. peer already gone) must not skip closing the handle.
            streamGoodbye (tcpState r) `catchIOError` const (return ())
            hClose (tcpHandle r)
          return $ Just $ streamEncode st

    -- Forward decoded messages until end of stream, then drop the session.
    -- End of stream is not posted to the shared inbox: a bare Nothing carries
    -- no address, so the consumer could not tell which connection ended, and
    -- presumably (per Network.QueryResponse) reads it as transport shutdown.
    readLoop h st = fix $ \loop -> do
      mmsg <- streamDecode st
      case mmsg of
        Just msg -> do
          putMVar mvar $ Just $ Right (msg, addr)
          loop
        Nothing -> do
          atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
          hClose h
-- | Shut down every cached TCP session and leave the cache empty.
--
-- The cache is swapped for an empty one first. Then, for each session, the
-- reader thread is killed, 'streamGoodbye' is run and the handle is closed.
-- Teardown is best-effort: an 'IOError' from the goodbye or the close (for
-- example writing to a peer that already reset the connection) is ignored,
-- so one dead session cannot prevent the remaining ones from being closed.
--
-- The 'StreamHandshake' argument is unused and kept only for interface
-- compatibility.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache _stream = do
  cache <- atomically $ swapTVar (lru tcpcache) MM.empty
  forM_ (MM.toList cache) $ \(MM.Binding _ sess _) -> do
    killThread (tcpThread sess)
    streamGoodbye (tcpState sess) `catchIOError` const (return ())
    hClose (tcpHandle sess) `catchIOError` const (return ())
-- | Build a transport that sends each message over a cached TCP connection
-- to its destination, connecting on demand.
--
-- Every connection's reader thread delivers into one shared 'MVar', which
-- 'awaitMessage' takes from; 'closeTransport' shuts down all cached
-- connections via 'closeAll'.
tcpTransport :: Int                      -- ^ Maximum number of TCP links to maintain.
             -> StreamHandshake addr x y -- ^ How to connect to and greet a peer.
             -> IO (TransportA err addr x y)
tcpTransport maxcon stream = do
  inbox  <- newEmptyMVar
  lruVar <- atomically $ newTVar MM.empty
  let cache = TCPCache { lru = lruVar, tcpMax = maxcon }
      send dest msg = do
        sender <- acquireConnection inbox cache stream dest
        mapM_ ($ msg) sender
  return Transport
    { awaitMessage   = \k -> takeMVar inbox >>= k
    , sendMessage    = send
    , closeTransport = closeAll cache stream
    }
|