summaryrefslogtreecommitdiff
path: root/src/Data/PacketQueue.hs
blob: b5d8a75637e123706ad3238c4893de811c95a58e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
-- | This module is useful for implementing a lossless protocol on top of a
-- lossy datagram-style protocol.  It implements a buffer in which packets may
-- be stored out of order, but from which they are extracted in the proper
-- sequence.
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.PacketQueue
    ( PacketQueue
    , getCapacity
    , getLastDequeuedPlus1
    , new
    , dequeue
    , getMissing
    , dequeueOrGetMissing
    , markButNotDequeue
    , enqueue
    , observeOutOfBand
    , PacketOutQueue
    , packetQueueViewList
    , newOutGoing
    , readyOutGoing
    , getRequested
    , peekPacket
    , tryAppendQueueOutgoing
    , dequeueOutgoing
    , getHighestHandledPacketPlus1
    , mapOutGoing
    , OutGoingResult(..)
    ) where

import Control.Concurrent.STM
import Control.Concurrent.STM.TArray
import Control.Monad
import Control.Applicative
import Data.Word
import Data.Array.MArray
import Data.Maybe
import DPut

-- | A fixed-capacity ring buffer of optional packets.  A packet with
-- sequence number @n@ is stored at index @n `mod` qsize@ of 'pktq'.
data PacketQueue a = PacketQueue
    { pktq    :: TArray Word32 (Maybe a) -- ^ Slots; 'Nothing' marks an empty slot.
    , seqno   :: TVar Word32 -- ^ Next sequence number to be dequeued (buffer_start).
    , qsize   :: Word32 -- ^ Number of slots in 'pktq'.
    , buffend :: TVar Word32 -- ^ On incoming, next packet they'll send + 1
                             --   (advanced by 'enqueue' and 'observeOutOfBand');
                             --   on outgoing, the sequence number of the next
                             --   slot 'tryAppendQueueOutgoing' writes to.
    }

-- | List every occupied slot of the queue as @(array index, packet)@, in
-- array index order.  Note that the first component is the slot index
-- (sequence number modulo capacity), not the packet's sequence number.
packetQueueViewList :: PacketQueue a -> STM [(Word32,a)]
packetQueueViewList p = mapMaybe occupied <$> getAssocs (pktq p)
  where
    occupied (idx, slot) = (,) idx <$> slot

-- | The sequence number 'dequeue' will wait for next, i.e. one past the last
-- packet taken out of the queue.
getLastDequeuedPlus1 :: PacketQueue a -> STM Word32
getLastDequeuedPlus1 = readTVar . seqno

-- | Number of slots in the queue, as fixed when it was created by 'new'.
getCapacity :: Applicative m => PacketQueue t -> m Word32
getCapacity = pure . qsize

-- | Create a new, empty PacketQueue.
--
-- The requested capacity is rounded to an even number of slots, with a
-- minimum of 2.  Odd capacities are rounded up, except 'maxBound', which is
-- rounded down because rounding up would wrap to 0.  A capacity of 0 used
-- to underflow @cap - 1@ and request a 2^32-element array.
--
-- The window initially starts and ends at the initial sequence number.
--
-- NOTE(review): slots are addressed as @n `mod` qsize@ on wrapping Word32
-- sequence numbers.  That addressing only stays collision-free across the
-- 2^32 wrap-around when the capacity is a power of two, so callers should
-- pass one.
new :: Word32 -- ^ Capacity of queue.
    -> Word32 -- ^ Initial sequence number.
    -> STM (PacketQueue a)
new capacity seqstart = do
    let cap | capacity < 2         = 2
            | even capacity        = capacity
            | capacity == maxBound = capacity - 1
            | otherwise            = capacity + 1
    q <- newArray (0,cap - 1) Nothing
    seqv <- newTVar seqstart
    -- buffend starts equal to seqno (empty window).  Starting it at 0 made
    -- the wrapping distance @buffend - seqno@ huge for a non-zero start, so
    -- 'enqueue' and 'observeOutOfBand' could never advance it.
    bufe <- newTVar seqstart
    return PacketQueue
        { pktq    = q
        , seqno   = seqv
        , qsize   = cap
        , buffend = bufe
        }

-- | Record out-of-band information about the sender: the number of the next
-- lossless packet they intend to send.
--
-- If that number lies within the queue's window (fewer than 'qsize' steps
-- past the next expected sequence number) and is not behind the current
-- window end, 'buffend' is set to that number plus one.  Otherwise the
-- observation is ignored.
observeOutOfBand :: PacketQueue a -> Word32 -> STM ()
observeOutOfBand PacketQueue { seqno, qsize, buffend } nextNo = do
    low <- readTVar seqno
    -- Wrapping Word32 subtraction: numbers behind 'low' become huge and
    -- fail the range check below.
    let proj = nextNo - low
    -- Ignore the observation if it is out of range.  The body is indented
    -- past the enclosing 'do' so it does not rely on NondecreasingIndentation.
    when (proj < qsize) $
        modifyTVar' buffend (\be -> if be - low <= proj then nextNo + 1 else be)

-- | Return the sequence numbers in the half-open window @[seqno, buffend)@
--   whose slots are still empty, in order starting from seqno.
--
--   The window length is computed with wrapping Word32 arithmetic, so it
--   stays correct across the 2^32 wrap-around.  If buffend is not ahead of
--   seqno by at most the queue capacity (including buffend == seqno), the
--   result is the empty list.
getMissing :: PacketQueue a -> STM [Word32]
getMissing PacketQueue { pktq, seqno, qsize, buffend } = do
    seqno0   <- readTVar seqno
    buffend0 <- readTVar buffend
    let len = buffend0 - seqno0
        -- 'iterate (+ 1)' wraps at maxBound, unlike an enumFromTo range.
        pnums | len <= qsize = take (fromIntegral len) (iterate (+ 1) seqno0)
              | otherwise    = []
    slots <- mapM (\n -> readArray pktq (n `mod` qsize)) pnums
    -- Pair each slot with the packet number it was read for; previously
    -- the numbering wrongly started at buffend0.
    return [ n | (n, Nothing) <- zip pnums slots ]

-- | If seqno < buffend, return (as 'Left') the sequence numbers in
--   @[seqno, buffend)@ whose slots are still empty.
--   Otherwise, behave as 'dequeue' would and return the packet as 'Right'.
--   TODO: Do we need this function? Delete it if not.
--
--   NOTE(review): 'enqueue' normally leaves buffend beyond every packet it
--   stores.  So when the next expected packet is present, this usually takes
--   the 'Left' branch rather than dequeuing it.  Confirm this is intended.
dequeueOrGetMissing :: PacketQueue a -> STM (Either [Word32] a)
dequeueOrGetMissing PacketQueue { pktq, seqno, qsize, buffend } = do
    seqno0 <- readTVar seqno
    buffend0 <- readTVar buffend
    if seqno0 < buffend0
      then do
        -- buffend0 > seqno0 >= 0 here, so buffend0 - 1 cannot underflow.
        let pnums = take (fromIntegral qsize) [ seqno0 .. buffend0 - 1 ]
        slots <- mapM (\n -> readArray pktq (n `mod` qsize)) pnums
        -- Number each slot by the packet it was read for (starting at seqno0).
        return (Left [ n | (n, Nothing) <- zip pnums slots ])
      else do
        let i = seqno0 `mod` qsize
        x <- maybe retry return =<< readArray pktq i
        writeArray pktq i Nothing
        -- (+ 1), not 'succ': Word32 'succ' throws at maxBound, while the
        -- rest of this module treats sequence numbers as wrapping.
        modifyTVar' seqno (+ 1)
        return (Right x)

-- | Retry until the next expected packet is enqueued.  Then remove it from
-- its slot, advance the expected sequence number by one, and return it.
dequeue :: PacketQueue a -> STM a
dequeue PacketQueue { pktq, seqno, qsize } = do
    i0 <- readTVar seqno
    let i = i0 `mod` qsize
    x <- maybe retry return =<< readArray pktq i
    writeArray pktq i Nothing
    -- (+ 1), not 'succ': Word32 'succ' throws at maxBound, while the rest
    -- of this module treats sequence numbers as wrapping.
    modifyTVar' seqno (+ 1)
    return x

-- | Like 'dequeue', but instead of emptying the slot it leaves the packet in
-- place with its flag set to 'True' (marked as viewed).  It retries until the
-- next expected packet is present, then advances the expected sequence
-- number and returns the packet.
markButNotDequeue :: PacketQueue (Bool,a) -> STM a
markButNotDequeue PacketQueue { pktq, seqno, qsize } = do
    i0 <- readTVar seqno
    let i = i0 `mod` qsize
    (_, x) <- maybe retry return =<< readArray pktq i
    writeArray pktq i (Just (True,x))
    -- (+ 1), not 'succ': Word32 'succ' throws at maxBound, while the rest
    -- of this module treats sequence numbers as wrapping.
    modifyTVar' seqno (+ 1)
    return x

-- | Store a packet that may have arrived out of order.
--
-- The packet is accepted when its sequence number is fewer than 'qsize'
-- steps (wrapping Word32 arithmetic) past the next expected one.  It is then
-- written to slot @no `mod` qsize@, overwriting whatever was there, and
-- 'buffend' is moved to @no + 1@ unless it already lies further ahead.
-- Any other packet is silently discarded; this never blocks.
--
-- The result is @proj `divMod` qsize@, where @proj@ is the packet's distance
-- from the next expected sequence number.  Its first component is therefore
-- 0 exactly when the packet was accepted.
enqueue :: PacketQueue a -- ^ The packet queue.
        -> Word32        -- ^ Sequence number of the packet.
        -> a             -- ^ The packet.
        -> STM (Word32,Word32)
enqueue PacketQueue{ pktq, seqno, qsize, buffend} no x = do
    start <- readTVar seqno
    let dist = no - start
        slot = no `mod` qsize
        -- Extend the window end only if this packet lies at or beyond it.
        newEnd end
            | end - start <= dist = no + 1
            | otherwise           = end
    if dist >= qsize
        then pure ()
        else do
            writeArray pktq slot (Just x)
            modifyTVar' buffend newEnd
    pure (dist `divMod` qsize)

-- lookup :: PacketQueue a -> Word32 -> STM (Maybe a)
-- lookup PacketQueue{ pktq, seqno, qsize } no = _todo

-----------------------------------------------------
-- * PacketOutQueue
--

-- | An outgoing packet queue, paired with the incoming queue of the same
-- connection and the callbacks used to encode outgoing messages.
data PacketOutQueue extra msg toWire fromWire = PacketOutQueue
    { pktoInPQ   :: PacketQueue fromWire -- ^ reference to the incoming 'PacketQueue'
    -- | Outgoing packets, each stored with the packet number it was sent under.
    , pktoOutPQ  :: PacketQueue (Word32,toWire)
    -- | Packet number for the next outgoing packet (updated from the toWire result).
    , pktoPacketNo :: TVar Word32
    -- | IO preparation step returning the action that supplies the @extra@ argument.
    , pktoToWireIO :: IO (STM extra)
    -- | Encoder: given the extra action, the next incoming packet number we
    -- expect, the outgoing buffer end, the packet number and a message,
    -- produce the wire packet and the next packet number, or 'Nothing'.
    , pktoToWire :: STM extra
                 -> Word32{-packet number we expect to receive-}
                 -> Word32{- buffer_end -}
                 -> Word32{- packet number -}
                 -> msg
                 -> STM (Maybe (toWire,Word32{-next packet no-}))
    }

-- | Rewrite every slot of the outgoing queue's array in place, in index
-- order.  An occupied slot @Just e@ becomes @f e@, so returning 'Nothing'
-- empties it; empty slots stay empty.  The queue's sequence counters are
-- left unchanged.
mapOutGoing :: ((Word32,towire) -> Maybe (Word32,towire)) -> PacketOutQueue extra msg towire fromwire -> STM ()
mapOutGoing f (PacketOutQueue { pktoOutPQ = PacketQueue { pktq } }) = do
    (lo, hi) <- getBounds pktq
    forM_ [lo .. hi] $ \i -> do
        slot <- readArray pktq i
        writeArray pktq i (slot >>= f)

-- | Create an outgoing queue tied to the given incoming queue.  The outgoing
-- 'PacketQueue' is built with 'new' from the capacity and initial sequence
-- number, and the packet counter starts at the given first packet number.
newOutGoing :: PacketQueue fromwire
                      -- ^ Incoming queue
            -> (STM io -> Word32 {-packet number we expect to receive-} -> Word32{-buffer_end-} -> Word32{-packet number-} -> msg -> STM (Maybe (wire,Word32{-next packet no-})))
                      -- ^ toWire callback
            -> IO (STM io)
                      -- ^ io action to get extra parameter
            -> Word32 -- ^ packet number of first outgoing packet
            -> Word32 -- ^ Capacity of queue.
            -> Word32 -- ^ Initial sequence number.
            -> STM (PacketOutQueue io msg wire fromwire)
newOutGoing inq towire toWireIO num capacity seqstart =
    assemble <$> new capacity seqstart <*> newTVar num
  where
    assemble outq counter = PacketOutQueue
        { pktoInPQ     = inq
        , pktoOutPQ    = outq
        , pktoPacketNo = counter
        , pktoToWireIO = toWireIO
        , pktoToWire   = towire
        }

-- | Outcome of 'tryAppendQueueOutgoing':
--
--   * 'OGSuccess'    — the packet was stored in the outgoing queue.
--   * 'OGFull'       — the target slot holds a packet that is still needed.
--   * 'OGEncodeFail' — the toWire callback returned 'Nothing'.
data OutGoingResult = OGSuccess | OGFull | OGEncodeFail
    deriving (Eq,Show)

-- | Run the queue's IO preparation step (the action given to 'newOutGoing')
-- before appending.  The returned STM action is presumably what callers pass
-- as the @getExtra@ argument of 'peekPacket' / 'tryAppendQueueOutgoing'.
readyOutGoing :: PacketOutQueue extra msg wire fromwire -> IO (STM extra)
readyOutGoing = pktoToWireIO

-- | Look up the outgoing packets a peer asked to have resent.
--
-- The request bytes are decoded relative to @snum@.  A running value @prev@
-- starts at @snum@.  Each non-zero byte @b@ names packet @prev + b@, which
-- then becomes the new @prev@.  A zero byte adds 255 to @prev@ without naming
-- a packet.  For each named packet number, the outgoing slot it maps to is
-- returned in request order: 'Nothing' for an empty slot, otherwise the
-- stored (packet number, wire) pair.  The @getExtra@ argument is unused.
getRequested :: STM extra -> PacketOutQueue extra msg wire fromwire -> Word32 -> [Word8] -> STM [Maybe (Word32,wire)]
getRequested _        _     _    [] = return []
getRequested _getExtra pktoq snum ns =
    mapM (readArray outArr . (`mod` qsize outq)) (decode snum ns)
  where
    outq   = pktoOutPQ pktoq
    outArr = pktq outq

    -- Turn the delta-encoded request bytes into absolute packet numbers.
    decode :: Word32 -> [Word8] -> [Word32]
    decode _    []         = []
    decode prev (0 : rest) = decode (prev + 255) rest
    decode prev (b : rest) = let pnum = fromIntegral b + prev
                             in pnum : decode pnum rest

-- | Encode @msg@ as the next outgoing packet would be encoded, without
-- enqueuing it.
--
-- This function only reads queue state: the outgoing 'buffend', the packet
-- counter, and the incoming queue's next expected sequence number.  It
-- returns whatever the toWire callback produces: @Just (wire, nextPacketNo)@,
-- or 'Nothing' if the message cannot be encoded.
peekPacket :: STM extra -> PacketOutQueue extra msg wire fromwire -> msg -> STM (Maybe (wire,Word32))
peekPacket getExtra (PacketOutQueue { pktoInPQ, pktoOutPQ, pktoPacketNo, pktoToWire }) msg = do
    be <- readTVar (buffend pktoOutPQ)
    pktno <- readTVar pktoPacketNo
    nextno <- readTVar (seqno pktoInPQ)
    pktoToWire getExtra nextno be pktno msg

-- | Convert a message to packet format and append it at the end of the
--   outgoing queue, i.e. into slot @buffend `mod` qsize@ of the outgoing
--   'PacketQueue'.  (‘End’ means the higher index; the queue wraps.)
--
--   * 'OGEncodeFail': the toWire callback returned 'Nothing'; this function
--     modifies nothing.
--   * 'OGSuccess': the target slot was empty, or held a packet whose number
--     is below 'getHighestHandledPacketPlus1'.  The new packet is stored
--     there, the outgoing 'buffend' is incremented, and the packet counter is
--     set to the next packet number returned by the callback.
--   * 'OGFull': the target slot holds a packet that is still needed.
tryAppendQueueOutgoing :: STM extra -> PacketOutQueue extra msg wire fromwire -> msg -> STM OutGoingResult
tryAppendQueueOutgoing getExtra q@(PacketOutQueue { pktoInPQ, pktoOutPQ, pktoPacketNo, pktoToWire }) msg
  = dtrace XNetCrypto "(tryAppendQueueOutgoing)" $ do
    be <- readTVar (buffend pktoOutPQ)
    let i = be `mod` qsize pktoOutPQ
    -- Current occupant of the slot the new packet would be written to.
    mbPkt <- readArray (pktq pktoOutPQ) i
    pktno <- readTVar pktoPacketNo
    nextno <- readTVar (seqno pktoInPQ)
    mbWire <- pktoToWire getExtra nextno be pktno msg
    -- Store the encoded packet in slot i and advance the queue.
    let commit pkt pktno' = do
            modifyTVar' (buffend pktoOutPQ) (+1)
            writeTVar pktoPacketNo $! pktno'
            writeArray (pktq pktoOutPQ) i (Just (pktno,pkt))
            return OGSuccess
    case dtrace XNetCrypto "(tryAppendQueueOutgoing mbWire)" mbWire of
        -- don't know how to send this message
        Nothing -> return OGEncodeFail
        Just (pkt,pktno')
            -> dtrace XNetCrypto "(tryAppendQueueOutgoing A)"
             $ case mbPkt of
                    -- slot is free, insert element
                    Nothing -> dtrace XNetCrypto "(tryAppendQueueOutgoing Nothing case)"
                             $ commit pkt pktno'
                    -- slot is taken
                    Just (n,_) -> dtrace XNetCrypto "(tryAppendQueueOutgoing Just case)" $ do
                        nn <- getHighestHandledPacketPlus1 q
                        if n < nn
                          -- ...but the old packet is no longer needed: overwrite it
                          then commit pkt pktno'
                          -- uh oh, this packet is still needed
                          else return OGFull

-- | Retry until the outgoing slot for the next sequence number is occupied.
-- Then advance the sequence number and return the stored
-- (packet number, wire) pair.  The slot itself is deliberately left filled.
--
-- NOTE(review): because this never clears slots, once the queue has wrapped
-- it can return a packet left over from a previous lap instead of retrying,
-- unless something else (e.g. 'mapOutGoing') has emptied that slot.
dequeueOutgoing :: PacketOutQueue extra msg wire fromwire -> STM (Word32,wire)
dequeueOutgoing (PacketOutQueue {pktoOutPQ=PacketQueue { pktq, seqno, qsize }}) = do
    i0 <- readTVar seqno
    let i = i0 `mod` qsize
    x <- maybe retry return =<< readArray pktq i
    -- writeArray pktq i Nothing -- not cleaning
    -- (+ 1), not 'succ': Word32 'succ' throws at maxBound, while the rest
    -- of this module treats sequence numbers as wrapping.
    modifyTVar' seqno (+ 1)
    return x

-- | Read 'buffend' of the paired incoming queue.  That value is maintained by
-- 'enqueue' / 'observeOutOfBand' as one past the highest incoming sequence
-- number they have recorded.
getHighestHandledPacketPlus1 :: PacketOutQueue extra msg wire fromwire -> STM Word32
getHighestHandledPacketPlus1 = readTVar . buffend . pktoInPQ