summaryrefslogtreecommitdiff
path: root/src/Network/QueryResponse/TCP.hs
blob: efeab305585b9101562b1f0707517be725398199 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc                  (labelThread)
#endif

import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.Timeout
import System.IO
import System.IO.Error

import DebugTag
import DPut
import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | Book-keeping for one live TCP connection stored in a 'TCPCache'.
data TCPSession st = TCPSession
    { tcpHandle :: Handle   -- ^ Handle of the connection; closed with 'hClose' when the session is dropped.
    , tcpState  :: st       -- ^ Per-connection state; this module instantiates it to 'SessionProtocol'.
    , tcpThread :: ThreadId -- ^ Reader thread forked by 'acquireConnection'; killed when the session is dropped.
    }

-- | Wrapper around 'SockAddr' used as the key of the session queue in
-- 'TCPCache'.  It carries the 'Hashable' instance defined in this module.
newtype TCPAddress = TCPAddress SockAddr
 deriving (Eq,Ord)

-- | Hash a 'TCPAddress' by the components of the wrapped 'SockAddr'.
--
-- IPv4 and IPv6 addresses are hashed structurally (port narrowed to 'Word16'
-- plus the remaining constructor fields).  Every other address kind is hashed
-- through its 'show' rendering, so the salt is always honoured and distinct
-- addresses do not all collapse onto one constant.
instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr   -> hashWithSalt salt (fromIntegral port :: Word16,addr)
        SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
        -- Previously a constant 0, which ignored the salt and therefore also
        -- discarded whatever had been hashed before this value in a chain.
        _                        -> hashWithSalt salt (show x)

-- | Shared table of open TCP sessions.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
      -- ^ Sessions keyed by remote address.  The priority is 'Down' of the
      -- POSIX time at which the session was inserted or last looked up by
      -- 'acquireConnection'.
    , tcpMax :: Int
      -- ^ Bound handed to @MM.takeView@ after each insertion to pick the
      -- sessions to retire; 'tcpTransport' sets it from its maximum-links
      -- argument.
    }

-- | The per-connection actions produced by 'streamHello' for an established
-- stream: @x@ is the inbound message type, @y@ the outbound one.
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO ()        -- ^ "Goodbye" protocol upon termination.
    , streamDecode  :: IO (Maybe x) -- ^ Parse inbound messages.
    , streamEncode  :: y -> IO ()   -- ^ Serialize outbound messages.
    }

-- | How to start a session on a freshly connected handle and how to turn the
-- caller's address type into a socket address.
data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y) -- ^ "Hello" protocol upon fresh connection.
    , streamAddr  :: addr -> SockAddr -- ^ Socket address to connect to; also used (wrapped in 'TCPAddress') as the cache key.
    }

-- | Look up (or, when permitted, establish) the TCP session for an address
-- and return the action that encodes and sends a message on it.
--
-- On a cache miss with the connect flag set, a socket is opened to
-- @streamAddr stream addr@, 'streamHello' is run on the resulting handle, a
-- reader thread is forked that feeds every decoded message into the 'MVar',
-- and the session is inserted into the cache.  Sessions that @MM.takeView@
-- hands back as surplus are shut down, each in its own thread.
--
-- On a cache hit the session's priority is refreshed with the current time.
--
-- Returns 'Nothing' when there is no cached session and either connecting is
-- not allowed or the connection attempt raised an 'IOError'.
acquireConnection :: MVar (Maybe (Either a (x, addr)))   -- ^ Mailbox that receives decoded inbound messages.
                           -> TCPCache (SessionProtocol x y) -- ^ Shared session cache.
                           -> StreamHandshake addr x y       -- ^ Handshake and address mapping.
                           -> addr                           -- ^ Remote peer.
                           -> Bool                           -- ^ Whether a new connection may be opened on a cache miss.
                           -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr bDoCon = do
    cache <- atomically $ readTVar (lru tcpcache)
    case MM.lookup' key cache of
        Nothing -> fmap join $ forM (guard bDoCon) $ \() -> do
            proto <- getProtocolNumber "tcp"
            mh <- catchIOError (do sock <- socket (socketFamily saddr) Stream proto
                                   -- Close the socket and re-raise, so that a
                                   -- failed connect is reported as Nothing
                                   -- instead of wrapping a closed socket.
                                   connect sock saddr
                                       `catchIOError` (\e -> close sock >> ioError e)
                                   h <- socketToHandle sock ReadWriteMode
                                   return $ Just h)
                               $ \_ -> return Nothing
            fmap join $ forM mh $ \h -> do
                -- Do not leave the handle open when the handshake fails.
                st <- streamHello stream addr h
                          `catchIOError` (\e -> hClose h >> ioError e)
                t <- getPOSIXTime
                rthread <- forkIO $ fix $ \loop -> do
                    x <- streamDecode st
                    case x of
                        Just u -> do
                            -- Give the consumer one second to take the
                            -- message; after that it is dropped.
                            void $ timeout (1000000) $ putMVar mvar $ Just $ Right (u, addr)
                            loop
                        Nothing -> do
                            dput XTCP $ "TCP disconnected: " ++ show saddr
                            atomically $ modifyTVar' (lru tcpcache)
                                       $ MM.delete key
                            hClose h
                labelThread rthread ("tcp:"++show saddr)
                let v = TCPSession
                        { tcpHandle = h
                        , tcpState  = st
                        , tcpThread = rthread
                        }
                retires <- atomically $ do
                    c <- readTVar (lru tcpcache)
                    let (rs,c') = MM.takeView (tcpMax tcpcache)
                                $ MM.insert' key v (Down t) c
                    writeTVar (lru tcpcache) c'
                    return rs
                forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
                    myThreadId >>= flip labelThread ("tcp-close:"++show k)
                    dput XTCP $ "TCP dropped: " ++ show k
                    killThread (tcpThread r)
                    -- Say goodbye on the retired session, not on the one
                    -- that was just opened.
                    streamGoodbye (tcpState r)
                    hClose (tcpHandle r)

                return $ Just $ streamEncode st
        Just (_,v) -> do
            t <- getPOSIXTime
            -- Refresh the priority inside a single transaction so that
            -- concurrent inserts and deletes are not overwritten by the
            -- snapshot read above.  A session removed in the meantime is not
            -- resurrected.
            atomically $ modifyTVar' (lru tcpcache) $ \c ->
                case MM.lookup' key c of
                    Just (_,cur) -> MM.insert' key cur (Down t) c
                    Nothing      -> c
            return $ Just $ streamEncode (tcpState v)
  where
    saddr = streamAddr stream addr
    key   = TCPAddress saddr

-- | Empty the cache and shut down every session that was in it.
--
-- For each session the reader thread is killed, the goodbye action is run and
-- the handle is closed.  An 'IOError' raised by the goodbye or by closing the
-- handle is logged and does not prevent the remaining sessions from being
-- shut down; they are no longer reachable through the cache at that point.
--
-- The 'StreamHandshake' argument is not used.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache stream = do
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r _) -> do
        killThread (tcpThread r)
        streamGoodbye (tcpState r) `catchIOError` \e ->
            dput XTCP $ "TCP goodbye failed: " ++ show addr ++ " " ++ show e
        hClose (tcpHandle r) `catchIOError` \e ->
            dput XTCP $ "TCP close failed: " ++ show addr ++ " " ++ show e

-- | Build a transport whose messages travel over cached TCP sessions.
--
-- Inbound messages from every session's reader thread arrive through one
-- shared 'MVar'.  Outbound messages carry a flag saying whether a new
-- connection may be opened when no session to the destination is cached.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
                -> StreamHandshake addr x y
                -> IO (TransportA err addr x (Bool,y))
tcpTransport maxcon stream = do
    msgvar   <- newEmptyMVar
    sessions <- atomically (newTVar MM.empty)
    let tcpcache = TCPCache { lru = sessions, tcpMax = maxcon }
    return Transport
        { awaitMessage   = \f -> do
                                x <- takeMVar msgvar
                                -- Only the continuation is guarded; on an
                                -- IOError it is re-run with Nothing.
                                f x `catchIOError` \e -> do
                                    dput XTCP ("TCP transport stopped. " ++ show e)
                                    f Nothing
        , sendMessage    = \addr (bDoCon,y) -> do
                                let deliver = do
                                        msend <- acquireConnection msgvar tcpcache stream addr bDoCon
                                        forM_ msend $ \send -> send y
                                -- Sending happens in its own thread and
                                -- IOErrors there are discarded.
                                t <- forkIO $ deliver `catchIOError` \_ -> return ()
                                labelThread t "tcp-send"
        , closeTransport = closeAll tcpcache stream
        }