summaryrefslogtreecommitdiff
path: root/src/Network/QueryResponse/TCP.hs
blob: 154e9145a9491048872856158bcd9a8e14a88b22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module Network.QueryResponse.TCP where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent.Lifted
import GHC.Conc                  (labelThread)
#endif

import Control.Concurrent.STM
import Control.Monad
import Data.ByteString (ByteString,hPut)
import Data.Function
import Data.Hashable
import Data.Ord
import Data.Time.Clock.POSIX
import Data.Word
import Network.BSD
import Network.Socket
import System.IO
import System.IO.Error

import Connection.Tcp (socketFamily)
import qualified Data.MinMaxPSQ as MM
import Network.QueryResponse

-- | A live TCP connection held in a 'TCPCache'.
data TCPSession st = TCPSession
    { tcpHandle :: Handle   -- ^ Handle made from the connected socket.
    , tcpState  :: st       -- ^ Per-connection protocol state; 'acquireConnection' stores the 'SessionProtocol' returned by 'streamHello' here.
    , tcpThread :: ThreadId -- ^ Thread that reads inbound messages; it is killed when the session is evicted or closed.
    }

-- | A 'SockAddr' wrapped in a newtype so that it can serve as a cache key
-- with its own 'Hashable' instance.
newtype TCPAddress = TCPAddress SockAddr
 deriving (Eq,Ord)

-- | Hash a socket address.
--
-- For IPv4 and IPv6 addresses the port and every other constructor field
-- are hashed directly.  Any other address family (e.g. Unix-domain
-- sockets) falls back to hashing its 'show' rendering.  Equal addresses
-- always render identically, so the instance stays consistent with 'Eq'.
-- Unlike the previous constant result, the fallback still mixes in the
-- salt and distinguishes different addresses.
instance Hashable TCPAddress where
    hashWithSalt salt (TCPAddress x) = case x of
        SockAddrInet port addr   -> hashWithSalt salt (fromIntegral port :: Word16, addr)
        SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16, b, c, d)
        _                        -> hashWithSalt salt (show x)

-- | Bounded cache of open TCP sessions, keyed by remote address.
data TCPCache st = TCPCache
    { lru    :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
      -- ^ Sessions keyed by address.  The priority is 'Down' of the time at
      -- which the session was last acquired (see 'acquireConnection').
    , tcpMax :: Int
      -- ^ Limit passed to 'MM.takeView' after each new session is inserted;
      -- the bindings it splits off are shut down.
    }

-- | Callbacks for one established stream, produced by 'streamHello'.
data SessionProtocol x y = SessionProtocol
    { streamGoodbye :: IO ()        -- ^ "Goodbye" protocol run when the session is shut down.
    , streamDecode  :: IO (Maybe x) -- ^ Parse the next inbound message; 'Nothing' ends the connection's reader loop.
    , streamEncode  :: y -> IO ()   -- ^ Serialize an outbound message.
    }

-- | How to open a session with a peer address.
data StreamHandshake addr x y = StreamHandshake
    { streamHello :: addr -> Handle -> IO (SessionProtocol x y) -- ^ "Hello" protocol run on a freshly connected handle.
    , streamAddr  :: addr -> SockAddr
      -- ^ Socket address to connect to; it also determines the cache key.
    }

-- | Return a send function for the session to @addr@, opening a new TCP
-- connection when none is cached.
--
-- * Cached session: its LRU priority is refreshed to the current time and
--   its 'streamEncode' is returned.
--
-- * No cached session: a socket is created and connected to
--   @'streamAddr' stream addr@, and 'streamHello' is run on the resulting
--   handle.  A reader thread is then forked; it repeatedly calls
--   'streamDecode' and puts each result, tagged with @addr@, into @mvar@.
--   The session is inserted into the cache.  Any bindings that
--   'MM.takeView' splits off to respect 'tcpMax' are shut down in
--   background threads.
--
-- Returns 'Nothing' if the socket could not be created or connected.
-- Exceptions thrown by 'streamHello' are not caught.
acquireConnection :: MVar (Maybe (Either a (x, addr)))
                           -> TCPCache (SessionProtocol x y)
                           -> StreamHandshake addr x y
                           -> addr
                           -> IO (Maybe (y -> IO ()))
acquireConnection mvar tcpcache stream addr = do
    now <- getPOSIXTime
    -- Look up and refresh in one transaction, so that a concurrent insert
    -- or delete cannot be overwritten by a stale snapshot of the cache.
    cached <- atomically $ do
        cache <- readTVar (lru tcpcache)
        forM (MM.lookup' key cache) $ \(_, v) -> do
            writeTVar (lru tcpcache) $ MM.insert' key v (Down now) cache
            return v
    case cached of
        Just v  -> return $ Just $ streamEncode (tcpState v)
        Nothing -> openHandle >>= mapM startSession
  where
    saddr = streamAddr stream addr
    key   = TCPAddress saddr

    -- Create a TCP socket connected to 'saddr'; 'Nothing' on any IO error.
    openHandle = do
        proto <- getProtocolNumber "tcp"
        (Just <$> connectSocket proto) `catchIOError` (\_ -> return Nothing)

    -- Connect a new socket and wrap it in a Handle.  On failure the socket
    -- is closed and the error re-thrown.  The caller then reports failure
    -- instead of building a Handle from a closed socket.
    connectSocket proto = do
        sock <- socket (socketFamily saddr) Stream proto
        (connect sock saddr >> socketToHandle sock ReadWriteMode)
            `catchIOError` (\e -> close sock >> ioError e)

    -- Handshake on a freshly connected handle, fork its reader, cache the
    -- session (shutting down any evicted ones) and return its encoder.
    startSession h = do
        st <- streamHello stream addr h
        rthread <- forkIO $ readerLoop st h
        labelThread rthread ("tcp:" ++ show saddr)
        t <- getPOSIXTime
        let v = TCPSession
                { tcpHandle = h
                , tcpState  = st
                , tcpThread = rthread
                }
        -- Insert against the *current* cache contents, not an earlier
        -- snapshot, so concurrent updates are preserved.
        retires <- atomically $ do
            cache <- readTVar (lru tcpcache)
            let grown        = MM.insert' key v (Down t) cache
                (rs, cache') = MM.takeView (tcpMax tcpcache) grown
            writeTVar (lru tcpcache) cache'
            return rs
        forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkIO $ do
            myThreadId >>= flip labelThread ("tcp-close:" ++ show k)
            retire r
        return $ streamEncode st

    -- Decode inbound messages until 'streamDecode' yields 'Nothing'.  Then
    -- drop this session from the cache and close its handle.
    readerLoop st h = fix $ \loop -> do
        x <- streamDecode st
        -- NOTE(review): the final 'Nothing' is also delivered to @mvar@, so
        -- the end of any single connection reaches the transport's
        -- consumer; confirm that is intended.
        putMVar mvar $ fmap (\u -> Right (u, addr)) x
        case x of
            Just _  -> loop
            Nothing -> do
                me <- myThreadId
                -- Only remove the entry if it is still this reader's
                -- session; a newer session to the same address may have
                -- replaced it in the meantime.
                atomically $ modifyTVar' (lru tcpcache) $ \cache ->
                    case MM.lookup' key cache of
                        Just (_, s) | tcpThread s == me -> MM.delete key cache
                        _                               -> cache
                hClose h

    -- Shut down an evicted session with its *own* goodbye protocol.  A
    -- failing goodbye (e.g. the peer already left) must not prevent the
    -- handle from being closed.
    retire r = do
        killThread (tcpThread r)
        streamGoodbye (tcpState r) `catchIOError` (\_ -> return ())
        hClose (tcpHandle r)

-- | Shut down every cached session and leave the cache empty.
--
-- The cache is atomically swapped for 'MM.empty' first, so concurrent
-- callers no longer find these sessions.  Each session then has its reader
-- thread killed, its 'streamGoodbye' run and its handle closed.
--
-- Shutdown is best-effort per session.  An 'IOError' from 'streamGoodbye'
-- or 'hClose' (e.g. the peer already hung up) is ignored, so the remaining
-- sessions are still shut down.  The sessions are no longer reachable from
-- the cache, so aborting would leak them.
--
-- The 'StreamHandshake' argument is not used.
closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
closeAll tcpcache _stream = do
    cache <- atomically $ swapTVar (lru tcpcache) MM.empty
    forM_ (MM.toList cache) $ \(MM.Binding _ r _) -> do
        killThread (tcpThread r)
        quietly $ streamGoodbye (tcpState r)
        quietly $ hClose (tcpHandle r)
  where
    quietly act = act `catchIOError` (\_ -> return ())

-- | Build a 'Transport' that exchanges messages over cached TCP sessions.
--
-- All sessions' reader threads deliver into one shared 'MVar', and
-- 'awaitMessage' takes from it.  'sendMessage' obtains the session for the
-- destination via 'acquireConnection', opening one if needed, and encodes
-- the message on it.  If no connection can be made, the message is
-- dropped.  'closeTransport' shuts down every cached session.
tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
                -> StreamHandshake addr x y
                -> IO (TransportA err addr x y)
tcpTransport maxcon stream = do
    inbox  <- newEmptyMVar
    lruVar <- newTVarIO MM.empty
    let cache = TCPCache { lru = lruVar, tcpMax = maxcon }
        deliver addr y = do
            sender <- acquireConnection inbox cache stream addr
            forM_ sender ($ y)
    return Transport
        { awaitMessage   = \kont -> takeMVar inbox >>= kont
        , sendMessage    = deliver
        , closeTransport = closeAll cache stream
        }