summaryrefslogtreecommitdiff
path: root/src/Network/QueryResponse/TCP.hs
blob: 83ae367f597b76456f2cf232a9abc2ce409651aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc                  (labelThread)
#endif

import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.IO
import System.IO.Error

import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | A single open TCP link held in a 'TCPCache'.
data TCPSession st = TCPSession
    { tcpHandle :: MVar Handle
      -- ^ The connection's handle.  Senders hold this 'MVar' while writing,
      -- which serializes writes; shutdown code takes it before closing.
    , tcpState  :: st
      -- ^ Per-connection state returned by the 'streamHello' handshake.
    , tcpThread :: ThreadId
      -- ^ The forked thread that decodes inbound messages on this link.
    }

-- | Key type of the session cache: a peer 'SockAddr', wrapped so this
-- module can give it a 'Hashable' instance.
newtype TCPAddress = TCPAddress SockAddr deriving (Eq, Ord)

-- | Hash an address from its constituent fields.
--
-- IPv4 and IPv6 addresses hash their port (as a 'Word16') together with
-- the remaining constructor fields.  Any other address family hashes its
-- 'show' text.  Equal addresses have equal text, so this stays consistent
-- with the derived 'Eq'.  Unlike a constant result, it also mixes in the
-- salt.
instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress sa) = case sa of
        SockAddrInet port host ->
            hashWithSalt salt (fromIntegral port :: Word16, host)
        SockAddrInet6 port flow host scope ->
            hashWithSalt salt (fromIntegral port :: Word16, flow, host, scope)
        other ->
            hashWithSalt salt (show other)

-- | Size-bounded pool of open TCP sessions, keyed by peer address.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress POSIXTime (TCPSession st))
      -- ^ Sessions with the 'POSIXTime' at which each was last acquired
      -- as its priority.
    , tcpMax :: Int
      -- ^ Size bound passed to 'MM.takeView' when a new session is added.
    }

-- | The wire protocol spoken over each TCP connection.
--
-- @st@ is per-connection state, @x@ the inbound message type and @y@ the
-- outbound message type.
data StreamTransform st x y = StreamTransform
    { -- | "Hello" protocol upon fresh connection; yields the connection state.
      streamHello :: Handle -> IO st
    , -- | "Goodbye" protocol upon termination, run before the handle is closed.
      streamGoodbye :: st -> Handle -> IO ()
    , -- | Parse the next inbound message.  'Nothing' ends that connection's
      -- read loop.
      streamDecode :: st -> Handle -> IO (Maybe x)
    , -- | Serialize an outbound message into the bytes written to the handle.
      streamEncode :: st -> y -> IO ByteString
    }

-- | Find the cached TCP session for @addr@, opening a new one if there is
-- none, and return an action that sends one outbound message over it.
--
-- Opening a session does three things:
--
-- * connects a fresh socket;
-- * runs 'streamHello';
-- * forks a reader thread that posts each decoded inbound message into
--   @mvar@, tagged with @addr@.
--
-- The new session is then inserted into the cache.  Sessions evicted by
-- the 'tcpMax' bound are shut down in background threads: the reader is
-- killed, 'streamGoodbye' runs with that session's own state, and the
-- handle is closed.
--
-- Returns 'Nothing' if connecting the socket or the hello handshake fails
-- with an I/O error.  Errors from 'getProtocolNumber' or 'socket'
-- propagate to the caller.
acquireConnection :: MVar (Maybe (Either a (x, SockAddr)))
                           -> TCPCache st
                           -> StreamTransform st x y
                           -> SockAddr
                           -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr = do
    cache <- atomically $ readTVar (lru tcpcache)
    case MM.lookup' key cache of
        Just (_, v) -> do
            t <- getPOSIXTime
            -- Refresh the timestamp against the *current* cache contents
            -- rather than the snapshot read above, so concurrent updates
            -- to other entries are not overwritten.
            atomically $ modifyTVar' (lru tcpcache) $ MM.insert' key v t
            return $ Just $ sender (tcpHandle v) (tcpState v)
        Nothing -> do
            mhandle <- openHandle
            fmap join $ forM mhandle $ \h -> do
                -- Don't leak the handle if the handshake fails.
                mst <- (Just <$> streamHello stream h)
                           `catchIOError` \_ -> hClose h >> return Nothing
                forM mst (startSession h)
  where
    key = TCPAddress addr

    -- Serialize one message and write it while holding the handle lock.
    sender mh st y = do
        bs <- streamEncode stream st y
        withMVar mh (`hPut` bs)

    -- Connect a new socket to addr.  Nothing if the connect (or the handle
    -- conversion) fails; the socket is closed in that case.
    openHandle = do
        proto <- getProtocolNumber "tcp"
        sock <- socket (socketFamily addr) Stream proto
        (connect sock addr >> Just <$> socketToHandle sock ReadWriteMode)
            `catchIOError` \_ -> close sock >> return Nothing

    -- Fork the reader for a freshly opened handle, cache the new session,
    -- retire sessions beyond the size bound, and return the send action.
    startSession h st = do
        t <- getPOSIXTime
        mh <- newMVar h
        rthread <- forkIO $ fix $ \loop -> do
            -- A socket error ends this link the same way end-of-stream does.
            x <- streamDecode stream st h `catchIOError` \_ -> return Nothing
            case x of
                Just u -> do
                    putMVar mvar $ Just $ Right (u, addr)
                    loop
                Nothing -> do
                    -- Only this one link has ended.  A bare 'Nothing' in the
                    -- shared mvar carries no address, so it cannot describe
                    -- a single link.  Consumers presumably read it as
                    -- end-of-stream for the whole transport, so none is
                    -- posted here.
                    atomically $ modifyTVar' (lru tcpcache) $ MM.delete key
                    hClose h
        labelThread rthread ("tcp:"++show addr)
        let v = TCPSession
                { tcpHandle = mh
                , tcpState  = st
                , tcpThread = rthread
                }
        -- Insert and trim in a single transaction so concurrent updates to
        -- the cache are not lost.
        retires <- atomically $ do
            c <- readTVar (lru tcpcache)
            let (rs, c') = MM.takeView (tcpMax tcpcache) $ MM.insert' key v t c
            writeTVar (lru tcpcache) c'
            return rs
        forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
            myThreadId >>= flip labelThread ("tcp-close:"++show k)
            killThread (tcpThread r)
            rh <- takeMVar (tcpHandle r)
            -- Say goodbye using the *retired* session's state.  This is best
            -- effort because the peer may already be gone.
            streamGoodbye stream (tcpState r) rh `catchIOError` \_ -> return ()
            hClose rh
        return $ sender mh st

-- | Shut down every cached session.
--
-- For each session: kill its reader thread, take its handle lock, run
-- 'streamGoodbye' and close the handle.
--
-- The cache is atomically replaced by an empty one first.  Otherwise a
-- later lookup would find a session whose handle 'MVar' was taken here,
-- and its send action would block forever in 'withMVar'.
closeAll :: TCPCache st -> StreamTransform st x y -> IO ()
closeAll tcpcache stream = do
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> do
        killThread (tcpThread r)
        h <- takeMVar (tcpHandle r)
        -- Best effort: the peer may already be gone, and one failed goodbye
        -- must not leave the remaining sessions open.
        streamGoodbye stream (tcpState r) h `catchIOError` \_ -> return ()
        hClose h

-- | Build a 'Transport' over a pool of TCP connections.
--
-- Each outbound message goes to a cached connection for its destination,
-- which 'acquireConnection' opens on demand.  Inbound messages from all
-- connections are delivered through one shared 'MVar'.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
                -> StreamTransform st x y
                -> IO (TransportA err SockAddr x y)
tcpTransport maxcon stream = do
    inbox <- newEmptyMVar
    pool <- atomically $ do
        sessions <- newTVar MM.empty
        return TCPCache { lru = sessions, tcpMax = maxcon }
    let send addr y = acquireConnection inbox pool stream addr >>= mapM_ ($ y)
    return Transport
        { awaitMessage   = \k -> takeMVar inbox >>= k
        , sendMessage    = send
        , closeTransport = closeAll pool stream
        }