summaryrefslogtreecommitdiff
path: root/src/Network/SessionTransports.hs
blob: e9daf6c1fc65289b4865eb80e8a5a55ef2fc7c03 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
{-# LANGUAGE NamedFieldPuns #-}
module Network.SessionTransports
    ( Sessions
    , initSessions
    , newSession
    , sessionHandler
    ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import qualified Data.IntMap.Strict as IntMap
         ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict    as Map
         ;import Data.Map.Strict    (Map)

import Network.Address (SockAddr,either4or6)
import Network.QueryResponse
import qualified Data.IntervalSet as S
         ;import Data.IntervalSet (IntSet)

-- | Registry of sessions multiplexed over a single raw packet transport.
-- The type parameter @x@ is the raw packet type.
data Sessions x = Sessions
    { sessionsByAddr  :: TVar (Map SockAddr (IntMap (x -> IO Bool)))
      -- ^ Packet handlers grouped by (canonicalized) peer address, then by
      -- session id. A handler returns 'True' when it accepted the packet,
      -- which stops dispatch to the remaining handlers for that address.
    , sessionsById    :: TVar (IntMap SockAddr)
      -- ^ Peer address of each registered session id.
    , sessionIds      :: TVar IntSet
      -- ^ Session ids currently in use (an "Data.IntervalSet" set).
    , sessionsSendRaw :: SockAddr -> x -> IO ()
      -- ^ Sends a raw packet to a peer; supplied to 'initSessions'.
    }

-- | Create an empty session registry that sends raw packets with the given
-- action. All tables start empty; they are allocated in a single STM
-- transaction.
initSessions :: (SockAddr -> x -> IO ()) -> IO (Sessions x)
initSessions send =
    atomically $ Sessions <$> newTVar Map.empty     -- sessionsByAddr
                          <*> newTVar IntMap.empty  -- sessionsById
                          <*> newTVar S.empty       -- sessionIds
                          <*> pure send             -- sessionsSendRaw



-- | Drop session @sid@ from one address's handler table; intended for use
-- with 'Map.alter'. The result is 'Nothing' both when there was no table and
-- when removing @sid@ leaves it empty, so empty tables never linger in the map.
rmSession :: Int -> Maybe (IntMap x) -> Maybe (IntMap x)
rmSession sid mtable = do
    table <- mtable
    let remaining = IntMap.delete sid table
    if IntMap.null remaining then Nothing else Just remaining

-- | Register a new session with the peer at the given address and return its
-- id together with a transport for it.
--
-- The session id is chosen with @'S.nearestOutsider' 0@ on the set of ids in
-- use (presumably the free id closest to 0); if none is found the result is
-- 'Nothing' and nothing is registered. The session's packet handler is
-- installed under the canonicalized address so that 'sessionHandler' can
-- route raw packets to it.
--
-- The returned transport:
--
--   * @awaitMessage@ blocks until a packet accepted by @wrap@ arrives, or
--     until the transport is closed, in which case the continuation is given
--     'Nothing'.
--   * @sendMessage@ encodes with @unwrap@ (which receives the address
--     argument) and always sends to this session's own peer address.
--   * @closeTransport@ atomically marks the session closed, discards any
--     undelivered packet, wakes a pending @awaitMessage@ with 'Nothing', and
--     unregisters the session.
newSession :: Sessions raw
                    -> (addr -> y -> IO raw)                       -- ^ encode an outgoing message
                    -> (SockAddr -> raw -> IO (Maybe (x, addr)))   -- ^ decode an incoming packet; 'Nothing' = not for this session
                    -> SockAddr                                    -- ^ peer address
                    -> IO (Maybe (Int,TransportA err addr x y))
newSession Sessions{sessionsByAddr,sessionsById,sessionIds,sessionsSendRaw} unwrap wrap addr0 = do
    -- One-packet hand-off slot between the dispatcher and awaitMessage;
    -- 'Nothing' in the slot is the close signal.
    slot   <- newEmptyTMVarIO
    -- Set by closeTransport. Checked by handlePacket in the same transaction
    -- as the hand-off, so a closed session never blocks the dispatcher.
    closed <- newTVarIO False
    let saddr = -- Canonical in case of 6-mapped-4 addresses.
                either id id $ either4or6 addr0
        handlePacket x = do
            m <- wrap saddr x
            case m of
                Nothing -> return False
                Just x' -> atomically $ do
                    -- sessionHandler dispatches from a snapshot of the map,
                    -- so a packet can reach a session that was just closed.
                    -- Nobody will drain its slot, so waiting would wedge the
                    -- dispatching thread: decline the packet instead.
                    isClosed <- readTVar closed
                    if isClosed
                        then return False
                        else do putTMVar slot $! Just $! x'
                                return True
    msid <- atomically $ do
        msid <- S.nearestOutsider 0 <$> readTVar sessionIds
        forM msid $ \sid -> do
            modifyTVar' sessionIds $ S.insert sid
            modifyTVar' sessionsById $ IntMap.insert sid saddr
            modifyTVar' sessionsByAddr $ Map.insertWith IntMap.union saddr
                                       $ IntMap.singleton sid handlePacket
            return sid
    -- The lambda body is indented past the enclosing block, so this parses
    -- without NondecreasingIndentation.
    forM msid $ \sid -> do
        let tr = Transport
                { awaitMessage = \kont -> do
                    x <- atomically $ takeTMVar slot
                    kont $! Right <$> x
                , sendMessage = \addr x -> do
                    x' <- unwrap addr x
                    sessionsSendRaw saddr x'
                , closeTransport = atomically $ do
                    -- A single transaction, so a concurrent handlePacket
                    -- cannot slip a packet in between emptying the slot and
                    -- posting the close signal. The put cannot block because
                    -- the slot was just emptied.
                    writeTVar closed True
                    _ <- tryTakeTMVar slot
                    putTMVar slot Nothing
                    modifyTVar' sessionIds $ S.delete sid
                    modifyTVar' sessionsById $ IntMap.delete sid
                    modifyTVar' sessionsByAddr $ Map.alter (rmSession sid) saddr
                }
        return (sid,tr)

-- | Packet handler that routes a raw packet to the sessions registered for
-- its (canonicalized) source address.
--
-- Handlers for that address are tried in ascending session-id order until
-- one returns 'True'. The result is always 'Nothing', so every packet is
-- consumed here, including packets from addresses with no session.
sessionHandler :: Sessions x -> (SockAddr -> x -> IO (Maybe (x -> x)))
sessionHandler Sessions{sessionsByAddr} = \addr0 x -> do
    table <- readTVarIO sessionsByAddr
    let canonical = either id id (either4or6 addr0) -- 6-mapped-4 folded to 4
        -- Run handlers left to right, stopping at the first that claims it.
        firstTaker = foldr (\h rest -> h x >>= \taken -> unless taken rest)
                           (return ())
    forM_ (Map.lookup canonical table) (firstTaker . IntMap.elems)
    return Nothing -- consume all packets.