summaryrefslogtreecommitdiff
path: root/dht/src/Network/SessionTransports.hs
blob: 68233cd4d52912937f0be1a67b99c3580e5527f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
{-# LANGUAGE LambdaCase     #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TupleSections  #-}
module Network.SessionTransports
    ( Sessions
    , initSessions
    , newSession
    , sessionHandler
    ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.TMVar
import Control.Monad
import qualified Data.IntMap.Strict as IntMap
         ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict    as Map
         ;import Data.Map.Strict    (Map)

import qualified Data.Tox.DHT.Multi as Multi
import Network.Address (SockAddr,either4or6)
import Network.QueryResponse
import qualified Data.IntervalSet as S
         ;import Data.IntervalSet (IntSet)

-- | Registry of the sessions that share one raw send function.
--
-- 'newSession' inserts into all three tables in a single transaction, and a
-- session's @setActive False@ removes its entries from them again.
data Sessions x = Sessions
    { sessionsByAddr  :: TVar (Map Multi.SessionAddress (IntMap (x -> IO Bool)))
      -- ^ For each (canonicalized) remote address, the packet handlers of the
      -- sessions opened to it, keyed by session id.  A handler returns 'True'
      -- when it accepted the packet, which stops further dispatch.
    , sessionsById    :: TVar (IntMap Multi.SessionAddress)
      -- ^ Session id to the (canonicalized) address it was opened for.
    , sessionIds      :: TVar IntSet
      -- ^ Session ids currently allocated.
    , sessionsSendRaw :: Multi.SessionAddress -> x -> IO ()
      -- ^ Sends an already-wrapped packet to an address.
    }

-- | Create an empty session registry.
--
-- The supplied function is stored as 'sessionsSendRaw'; sessions created by
-- 'newSession' use it to transmit their outgoing raw packets.
initSessions :: (Multi.SessionAddress -> x -> IO ()) -> IO (Sessions x)
initSessions send =
    atomically $
        Sessions <$> newTVar Map.empty    -- sessionsByAddr
                 <*> newTVar IntMap.empty -- sessionsById
                 <*> newTVar S.empty      -- sessionIds
                 <*> pure send            -- sessionsSendRaw



-- | Remove session @sid@ from one address's handler table; intended for use
-- with 'Map.alter'.  A table that becomes empty collapses to 'Nothing', so
-- the address key disappears from the outer map instead of lingering with an
-- empty 'IntMap'.
rmSession :: Int -> Maybe (IntMap x) -> Maybe (IntMap x)
rmSession sid = mfilter (not . IntMap.null) . fmap (IntMap.delete sid)

-- | Open a new session to the given address on top of the shared raw
-- transport held in 'Sessions'.
--
-- * @unwrap@ turns an outgoing message into a raw packet, which is sent with
--   'sessionsSendRaw' to the session's (canonicalized) address.
-- * @wrap@ tries to decode an incoming raw packet from that address; when it
--   yields @Just (msg, from)@ the message is delivered to this session.
--
-- Returns 'Nothing' when 'S.nearestOutsider' yields no free session id,
-- otherwise the allocated id together with the session's transport.
--
-- Calling @setActive False@ on the transport makes 'awaitMessage' report
-- 'Terminated' and unregisters the session.  Only the first such call
-- releases the id.  Once released, the id may be handed to a later session,
-- so a repeated call must not remove that session's entries.
newSession :: Sessions raw
                    -> (addr -> y -> IO raw)
                    -> (Multi.SessionAddress -> raw -> IO (Maybe (x, addr)))
                    -> Multi.SessionAddress
                    -> IO (Maybe (Int,TransportA err addr x y))
newSession Sessions{sessionsByAddr,sessionsById,sessionIds,sessionsSendRaw} unwrap wrap addr0 = do
    -- mvar:  one-slot mailbox read by awaitMessage; 'Nothing' in it means
    --        the session was terminated.
    -- alive: True until the first @setActive False@.
    (mvar, alive) <- atomically $ (,) <$> newEmptyTMVar <*> newTVar True
    let saddr = -- Canonical in case of 6-mapped-4 addresses.
                Multi.canonize addr0
        handlePacket x = do
            m <- wrap saddr x
            case m of
                Nothing -> return False
                Just x' -> atomically $ do
                    -- A dispatcher may still hold this handler after the
                    -- session was closed.  Drop the packet rather than
                    -- parking it in (or blocking on) a dead mailbox.  If we
                    -- are blocked on a full mailbox, the write to 'alive'
                    -- re-runs this transaction.
                    live <- readTVar alive
                    if live
                        then do putTMVar mvar $! Just $! x'
                                return True
                        else return False
    msid <- atomically $ do
        mfree <- S.nearestOutsider 0 <$> readTVar sessionIds
        forM mfree $ \sid -> do
            modifyTVar' sessionIds $ S.insert sid
            modifyTVar' sessionsById $ IntMap.insert sid saddr
            modifyTVar' sessionsByAddr $ Map.insertWith IntMap.union saddr
                                       $ IntMap.singleton sid handlePacket
            return sid
    forM msid $ \sid -> do
      let tr = Transport
            { awaitMessage = do
                x <- takeTMVar mvar
                return $ (, return ()) $ maybe Terminated (uncurry $ flip Arrival) x
            , sendMessage = \addr x -> do
                x' <- unwrap addr x
                sessionsSendRaw saddr x'
            , setActive = \case
                False -> atomically $ do
                    -- Discard any undelivered packet and leave the
                    -- termination marker for awaitMessage.
                    void $ tryTakeTMVar mvar
                    putTMVar mvar Nothing
                    -- Unregister exactly once, in the same transaction, so
                    -- there is no window where the session is terminated
                    -- but still registered, and so a repeated call cannot
                    -- evict a newer session that reuses this id.
                    wasAlive <- swapTVar alive False
                    when wasAlive $ do
                        modifyTVar' sessionIds $ S.delete sid
                        modifyTVar' sessionsById $ IntMap.delete sid
                        modifyTVar' sessionsByAddr $ Map.alter (rmSession sid) saddr
                True -> return ()
            }
      return (sid,tr)

-- | Incoming-packet hook for the shared raw transport.
--
-- Each 'Arrival' is reported to the caller as 'Discarded'.  The returned IO
-- action offers the packet to the handlers registered for the
-- (canonicalized) source address, in 'IntMap.elems' order, and stops at the
-- first handler that accepts it.  Any other event is passed through
-- unchanged with a no-op action.
sessionHandler :: Sessions x -> Arrival err Multi.SessionAddress x
                             -> STM (Arrival err Multi.SessionAddress x, IO ())
sessionHandler Sessions{sessionsByAddr} (Arrival addr0 x) =
    return (Discarded, deliver)
  where
    -- Canonical in case of 6-mapped-4 addresses.
    from = Multi.canonize addr0
    -- Run handlers left to right until one reports True.
    firstTaker handlers =
        foldr (\h rest -> h x >>= \taken -> unless taken rest) (return ()) handlers
    deliver = do
        table <- readTVarIO sessionsByAddr
        mapM_ (firstTaker . IntMap.elems) (Map.lookup from table)
sessionHandler _ m = return (m, return ())