summaryrefslogtreecommitdiff
path: root/server/src/Network/QueryResponse.hs
blob: 0c2009164bd5d660c56dd7ac1ff5494050bfe5b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
-- | This module can implement any query\/response protocol.  It was written
-- with Kademlia implementations in mind.

{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DeriveFoldable        #-}
{-# LANGUAGE DeriveFunctor         #-}
{-# LANGUAGE DeriveTraversable     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TupleSections         #-}
module Network.QueryResponse where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent
import GHC.Conc           (labelThread)
#endif
import Control.Arrow
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.ByteString    as B
         ;import Data.ByteString    (ByteString)
import Data.Dependent.Map as DMap
import Data.Dependent.Sum
import Data.Function
import Data.Functor.Contravariant
import Data.Functor.Identity
import Data.GADT.Show
import qualified Data.IntMap.Strict as IntMap
         ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict    as Map
         ;import Data.Map.Strict    (Map)
import Data.Time.Clock.POSIX
import Data.Traversable             (Traversable)
import qualified Data.Word64Map     as W64Map
         ;import Data.Word64Map     (Word64Map)
import Data.Word
import Data.Maybe
import GHC.Conc (closeFdWith)
import GHC.Event
import Network.Socket
import Network.Socket.ByteString    as B
import System.Endian
import System.IO
import System.IO.Error
import System.Timeout

import DPut
import DebugTag
import Data.TableMethods

-- | The outcome of a query sent to a remote server, or of any other IO
-- process that can time out or be canceled.
--
--   * 'Success' carries the value that was produced.
--
--   * 'TimedOut' and 'Canceled' carry no payload.
data Result a = Success a | TimedOut | Canceled
 deriving (Functor, Foldable, Traversable, Eq, Ord, Show)

-- | Forget why a 'Result' failed: 'Success' becomes 'Just', while 'TimedOut'
-- and 'Canceled' both collapse to 'Nothing'.
resultToMaybe :: Result a -> Maybe a
resultToMaybe r = case r of
    Success a -> Just a
    TimedOut  -> Nothing
    Canceled  -> Nothing

-- | An inbound packet, or a condition raised while monitoring a connection.
--
-- The payload fields of 'ParseError' and 'Arrival' are strict.
data Arrival err addr x
    = Terminated       -- ^ Virtual message that signals EOF.
    | Discarded        -- ^ Message dropped or passed to another thread.
    | ParseError !err  -- ^ A badly-formed message was received.
    | Arrival { arrivedFrom :: !addr , arrivedMsg  :: !x } -- ^ Inbound message: its origin address and its payload.

-- | Three methods are required to implement a datagram based query\/response
-- protocol.  Inbound packets have type /x/ and outbound packets have type
-- /y/.
data TransportA err addr x y = Transport
    { -- | An STM transaction yielding the next inbound event (a packet with
      -- its origin address, or one of the other 'Arrival' conditions) paired
      -- with an IO action.  'forkListener' runs that IO action right after
      -- the transaction commits.
      awaitMessage :: STM (Arrival err addr x, IO ())
      -- | Send a /y/ packet to the given destination /addr/.
    , sendMessage :: addr -> y -> IO ()
      -- | Activate ('True') or deactivate ('False') this 'Transport'.
      -- 'forkListener' passes 'True' before it starts listening and
      -- 'closeTransport' passes 'False'.
    , setActive :: Bool -> IO ()
    }

-- | A 'TransportA' whose inbound and outbound packet types coincide.
type Transport err addr x = TransportA err addr x x

-- | A transport that never delivers an inbound message ('awaitMessage'
-- always retries), drops every outbound message and ignores activation
-- requests.
nullTransport :: TransportA err addr x y
nullTransport = Transport
    { awaitMessage = retry
    , sendMessage  = \_dest _pkt -> doNothing
    , setActive    = const doNothing
    }
  where
    doNothing = return ()

-- | Deactivate a transport; shorthand for calling 'setActive' with 'False'.
closeTransport :: TransportA err addr x y -> IO ()
closeTransport = flip setActive False

-- | Wrap a 'Transport' so that it speaks higher-level addresses and packet
-- representations.  Typical uses are turning UDP 'ByteString's into bencoded
-- syntax trees, or adding an encryption layer in which addresses carry
-- associated public keys.
--
-- Inbound conditions other than 'Arrival' pass through unchanged; an
-- 'Arrival' that fails to parse is reported as a 'ParseError'.
layerTransportM ::
        (x -> addr -> STM (Either err (x', addr')))
        -- ^ Attempt to lift a low-level packet/address pair into the
        -- higher-level representation.
        -> (y' -> addr' -> IO (y, addr))
        -- ^ Encode a high-level packet/address pair into the low-level
        -- representation.
        -> TransportA err addr x y
        -- ^ The low-level transport to be transformed.
        -> TransportA err addr' x' y'
layerTransportM parse encode tr =
    tr { awaitMessage = awaitMessage tr >>= \(arrival, io) ->
            case arrival of
                Terminated         -> return (Terminated, io)
                Discarded          -> return (Discarded, io)
                ParseError e       -> return (ParseError e, io)
                Arrival from lower ->
                    parse lower from >>= \case
                        Left e               -> return (ParseError e, io)
                        Right (upper, from') -> return (Arrival from' upper, io)
       , sendMessage = \dest upper ->
            encode upper dest >>= \(lower, dest') -> sendMessage tr dest' lower
       }


-- | Pure variant of 'layerTransportM': the parser and the encoder are plain
-- functions rather than 'STM' \/ 'IO' actions.
layerTransport ::
        (x -> addr -> Either err (x', addr'))
        -- ^ Attempt to lift a low-level packet/address pair into the
        -- higher-level representation.
        -> (y' -> addr' -> (y, addr))
        -- ^ Encode a high-level packet/address pair into the low-level
        -- representation.
        -> TransportA err addr x y
        -- ^ The low-level transport to be transformed.
        -> TransportA err addr' x' y'
layerTransport parse encode = layerTransportM liftedParse liftedEncode
  where
    liftedParse  pkt from = return (parse pkt from)
    liftedEncode pkt dest = return (encode pkt dest)

-- | Partitions a 'Transport' into two higher-level transports.  Note: A
-- 'TChan' is used to share the same underlying socket, so be sure to fork a
-- thread for both returned 'Transport's to avoid hanging.
--
-- Messages the classifier returns as 'Left' are queued for the first
-- returned transport and show up as 'Discarded' on the second one; 'Right'
-- messages are delivered by the second transport directly.
partitionTransportM :: ((b,a) -> STM (Either (x,xaddr) (b,a)))
                      -> ((y,xaddr) -> IO (Maybe (c,a)))
                      -> TransportA err a b c
                      -> IO (TransportA err xaddr x y, TransportA err a b c)
partitionTransportM parse encodex tr = do
    sideChan <- atomically newTChan
    let -- Hand an item (or the end-of-stream marker 'Nothing') to the side transport.
        forward item = atomically (writeTChan sideChan item)
        ytr = tr
            { awaitMessage = awaitMessage tr >>= \(m, io) -> case m of
                Arrival adr msg -> parse (msg, adr) >>= \case
                    Left x           -> return (Discarded, io >> forward (Just x))
                    Right (y, yaddr) -> return (Arrival yaddr y, io)
                Terminated -> return (Terminated, io >> forward Nothing)
                _          -> return (m, io)
            , sendMessage = sendMessage tr
            }
        xtr = Transport
            { awaitMessage = readTChan sideChan >>= \case
                Nothing         -> return (Terminated, return ())
                Just (x, xaddr) -> return (Arrival xaddr x, return ())
            , sendMessage = \dest out ->
                -- Nothing is sent when the encoder yields 'Nothing'.
                encodex (out, dest) >>= mapM_ (\(pkt, lowDest) -> sendMessage tr lowDest pkt)
            , setActive = \_ -> return ()
            }
    return (xtr, ytr)

-- | Pure variant of 'partitionTransportM': the classifier and the encoder
-- are plain functions.  The same caveat applies: a 'TChan' is used to share
-- the underlying socket, so fork a thread for both returned 'Transport's to
-- avoid hanging.
partitionTransport :: ((b,a) -> Either (x,xaddr) (b,a))
                      -> ((y,xaddr) -> Maybe (c,a))
                      -> TransportA err a b c
                      -> IO (TransportA err xaddr x y, TransportA err a b c)
partitionTransport parse encodex tr =
    partitionTransportM (\inbound  -> return (parse inbound))
                        (\outbound -> return (encodex outbound))
                        tr

-- | Post-process every event produced by a transport.  The handler may
-- replace the event and contribute an IO action, which runs after the IO
-- action supplied by the underlying transport.
addHandler :: (Arrival err addr x -> STM (Arrival err addr x, IO ())) -> TransportA err addr x y -> TransportA err addr x y
addHandler f tr = tr { awaitMessage = awaitMessage tr >>= chain }
  where
    chain (m, before) = do
        (m', after) <- f m
        return (m', before >> after)

-- | Build a handler (suitable for 'addHandler') that leaves every event
-- untouched and schedules the given action for actual 'Arrival's only.
forArrival :: Applicative m => (addr -> x -> IO ()) -> Arrival err addr x -> m (Arrival err addr x, IO ())
forArrival f m = case m of
    Arrival addr x -> pure (Arrival addr x, f addr x)
    other          -> pure (other, return ())

-- | Wrap a 'Transport' so that the given action is invoked for every packet
-- it receives.
onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x
onInbound = addHandler . forArrival

-- * Using a query\/response client.

-- | Activate the transport and fork a thread that pumps inbound events until
-- 'Terminated' is seen.  The returned action deactivates the transport; it
-- does not kill the thread directly.
--
--  Example usage:
--
--  > -- Start client.
--  > quitServer <- forkListener "listener" (\_ -> return()) (clientNet client)
--  > -- Send a query q, receive a response r.
--  > r <- sendQuery client method q
--  > -- Quit client.
--  > quitServer
forkListener :: String -> (err -> IO ()) -> Transport err addr x -> IO (IO ())
forkListener name onParseError client = do
    setActive client True
    _listener <- forkIO $ do
        self <- myThreadId
        labelThread self ("listener." ++ name)
        let pump = do
                (arrival, followUp) <- atomically (awaitMessage client)
                -- The transport's own IO action always runs first.
                followUp
                case arrival of
                    Terminated   -> return ()
                    ParseError e -> onParseError e >> pump
                    _            -> pump
        pump
        dput XMisc $ "Listener died: " ++ name
    return (setActive client False)

-- * Implementing a query\/response 'Client'.

-- | These methods indicate what should be done upon various conditions.  Write
-- to a log file, make debug prints, or simply ignore them.
--
--   [ /addr/ ]  Address of remote peer.
--
--   [ /x/ ]     Incoming or outgoing packet.
--
--   [ /meth/ ]  Method id of incoming or outgoing request.
--
--   [ /tid/ ]   Transaction id for outgoing packet.  No field below mentions
--               it; it only appears as a type parameter.
--
--   [ /err/ ]   Error information, typically a 'String'.
data ErrorReporter addr x meth tid err = ErrorReporter
    { -- | Incoming: failed to parse packet.
      reportParseError :: err -> IO ()
      -- | Incoming: no handler for request.  Receives the method id, the
      -- remote address and the packet.
    , reportMissingHandler :: meth -> addr -> x -> IO ()
      -- | Incoming: unable to identify request.  Receives the remote address,
      -- the packet and the error.
    , reportUnknown :: addr -> x -> err -> IO ()
    }

-- | An 'ErrorReporter' that silently discards every report.
ignoreErrors :: ErrorReporter addr x meth tid err
ignoreErrors = ErrorReporter
    (const quiet)
    (\_ _ _ -> quiet)
    (\_ _ _ -> quiet)
  where
    quiet = return ()

-- | An 'ErrorReporter' that sends every report to 'dput' under the 'XMisc'
-- tag.  The packet itself is never printed.
logErrors :: ( Show addr
               , Show meth
               ) => ErrorReporter addr x meth tid String
logErrors = ErrorReporter
    { reportParseError     = dput XMisc
    , reportMissingHandler = \meth addr _ ->
        dput XMisc (show addr ++ " --> Missing handler (" ++ show meth ++ ")")
    , reportUnknown        = \addr _ err ->
        dput XMisc (show addr ++ " --> " ++ err)
    }

-- | An 'ErrorReporter' that writes every report as one line to the given
-- 'Handle'.  The packet itself is never printed.
printErrors :: ( Show addr
               , Show meth
               ) => Handle -> ErrorReporter addr x meth tid String
printErrors h = ErrorReporter
    { reportParseError     = emit
    , reportMissingHandler = \meth addr _ ->
        emit (show addr ++ " --> Missing handler (" ++ show meth ++ ")")
    , reportUnknown        = \addr _ err ->
        emit (show addr ++ " --> " ++ err)
    }
  where
    emit = hPutStrLn h

-- Change the /err/ type for an 'ErrorReporter': errors are converted with the
-- supplied function before they reach the wrapped reporter.  The
-- missing-handler callback does not involve /err/ and is reused as is.
instance Contravariant (ErrorReporter addr x meth tid) where
    contramap f (ErrorReporter onParse onMissing onUnknown) =
        ErrorReporter (onParse . f)
                      onMissing
                      (\addr x -> onUnknown addr x . f)

-- | An incoming message is classified into one of the following cases.
data MessageClass err meth tid addr x
    = IsQuery meth tid                                    -- ^ An unsolicited query is handled based on its /meth/ value.  Any response
                                                          -- should include the provided /tid/ value.
    | IsResponse tid                                      -- ^ A response to an outgoing query we associated with a /tid/ value.
    | IsUnsolicited (addr -> addr -> IO (Maybe (x -> x))) -- ^ Transactionless informative packet.  'handleMessage' invokes the io action
                                                          -- with our own address first and the remote origin address second.  If it
                                                          -- handles the message, it should return Nothing. Otherwise, it should return
                                                          -- a transform (usually /id/) to apply before the next handler examines it.
                                                          -- NOTE(review): 'handleMessage' currently discards this result -- confirm
                                                          -- whether the transform is still meant to be honored.
    | IsUnknown err                                       -- ^ None of the above.

-- | Handler for an inbound query of type /x/ from an address of type /addr/.
type MethodHandler err tid addr x = MethodHandlerA  err tid addr x x

-- | Handler for an inbound query of type /x/ with outbound response of type
-- /y/ to an address of type /addr/.  The parsed query type and the response
-- type are existentially quantified, so they are private to each handler.
data MethodHandlerA err tid addr x y = forall a b. MethodHandler
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Serialize the response for transmission, given the transaction id
      -- /tid/ followed by two addresses.  'dispatchQuery' passes our own
      -- address first and the remote origin of the query second.
    , methodSerialize :: tid -> addr -> addr -> b -> y
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
      --
      -- TODO: Allow queries to be ignored?  (See 'DropQuery', which
      -- 'dispatchQuery' catches in order to suppress the reply.)
    , methodAction :: addr -> a -> IO b
    }
      -- | See also 'IsUnsolicited' which likely makes this constructor unnecessary.
    | forall a. NoReply
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.  No reply is sent.
    , noreplyAction :: addr -> a -> IO ()
    }


-- | To dispatch responses to our outbound queries, we require two
-- primitives.  See the 'transactionMethods' function to create these
-- primitives out of a lookup table and a generator for transaction ids.
--
-- The type variable /d/ is used to represent the current state of the
-- transaction generator and the table of pending transactions.
data TransactionMethods d qid addr x = TransactionMethods
    {
      -- | Before a query is sent, this function stores the callback that
      -- will receive the response (or the failure condition).  The returned
      -- /qid/ is a transaction id that can later be passed to
      -- 'dispatchResponse', e.g. with 'TimedOut' or 'Canceled' when the
      -- remote peer is not responding.
      dispatchRegister :: POSIXTime -- time of expiry
                          -> (qid -> Result x -> IO ()) -- callback upon response (or timeout)
                          -> addr
                          -> d
                          -> STM (qid, d)
      -- | This method is invoked when an incoming packet /x/ indicates it is
      -- a response to the transaction with id /qid/ (or when the transaction
      -- times out or is canceled).  It yields the updated state together
      -- with an IO action that completes the dispatch.
    , dispatchResponse :: qid -> Result x -> d -> STM (d, IO ())
    }

-- | 'DispatchMethodsA' specialized to identical inbound and outbound packet
-- types.
type DispatchMethods tbl err meth tid addr x = DispatchMethodsA tbl err meth tid addr x x

-- | A set of methods necessary for dispatching incoming packets.  Inbound
-- packets have type /x/; replies produced by handlers have type /y/.
data DispatchMethodsA tbl err meth tid addr x y = DispatchMethods
    { -- | Classify an inbound packet as a query, a response, an unsolicited
      -- packet or something unknown (see 'MessageClass').
      classifyInbound :: x -> MessageClass err meth tid addr x
      -- | Lookup the handler for an inbound query; 'Nothing' means there is
      -- no handler for that method.
    , lookupHandler :: meth -> Maybe (MethodHandlerA err tid addr x y)
      -- | Methods for handling incoming responses.
    , tableMethods :: TransactionMethods tbl tid addr x
    }

-- | 'ClientA' specialized to identical inbound and outbound packet types.
type Client err meth tid addr x = ClientA err meth tid addr x x

-- | All inputs required to implement a query\/response client.  The type of
-- the pending-transaction state (/tbl/) is existentially quantified, so the
-- fields have to be accessed by pattern matching on 'Client'.
data ClientA err meth tid addr x y = forall tbl. Client
    { -- | The 'Transport' used to dispatch and receive packets.
      clientNet :: TransportA err addr x y
      -- | Methods for handling inbound packets.
    , clientDispatcher :: DispatchMethodsA tbl err meth tid addr x y
      -- | Methods for reporting various conditions.
    , clientErrorReporter :: ErrorReporter addr x meth tid err
      -- | State necessary for routing inbound responses and assigning unique
      -- /tid/ values for outgoing queries.
    , clientPending :: TVar tbl
      -- | An action yielding this client\'s own address.  It is invoked once
      -- on each outbound and inbound packet.  It is valid for this to always
      -- return the same value.
      --
      -- The argument, if supplied, is the remote address for the transaction.
      -- This can be used to maintain consistent aliases for specific peers.
    , clientAddress :: Maybe addr -> IO addr
      -- | Transform a query /tid/ value to an appropriate response /tid/
      -- value.  Normally, this would be the identity transformation, but if
      -- /tid/ includes a unique cryptographic nonce, then it should be
      -- generated here.
    , clientResponseId :: tid -> IO tid
    }

-- | These four parameters are required to implement an outgoing query.  A
-- peer-to-peer algorithm will define a 'MethodSerializer' for every 'MethodHandler' that
-- might be returned by 'lookupHandler'.
data MethodSerializerA tid addr x y meth a b = MethodSerializer
    { -- | Returns the microseconds to wait for a response to this query being
      -- sent to the given address.  The /addr/ may also be modified to add
      -- routing information; 'asyncQuery' sends to the returned address.
      methodTimeout :: addr -> STM (addr,Int)
      -- | A method identifier used for error reporting.  This needn't be the
      -- same as the /meth/ argument to 'MethodHandler', but it is suggested.
    , method :: meth
      -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
      -- The /addr/ arguments are, respectively, our own origin address and the
      -- destination of the request.  The /tid/ argument is useful for attaching
      -- auxiliary notations on all outgoing packets.
    , wrapQuery :: tid -> addr -> addr -> a -> x
      -- | Parse an inbound packet /y/ into a response /b/ for this query.
    , unwrapResponse :: y -> b
    }

-- | A 'MethodSerializerA' whose query and response packets share one type.
type MethodSerializer tid addr x meth a b = MethodSerializerA tid addr x x meth a b

-- | Convert a duration given in microseconds into a 'POSIXTime' difference
-- (which is measured in seconds).
microsecondsDiff :: Int -> POSIXTime
microsecondsDiff = (/ 1000000) . fromIntegral

-- | Send a query without waiting for the answer.  The supplied callback is
-- invoked with the transaction id and the 'Result' once a response is
-- dispatched, the query times out, or the send fails and the query is
-- canceled.  The transaction id is returned.
--
-- Outline:
--
--   1. Register a callback in the pending-transaction table, with an expiry
--      of now plus the timeout given by 'methodTimeout'.
--
--   2. Register a timer that reports 'TimedOut' after that many microseconds.
--
--   3. Send the serialized query.  If sending raises an 'IOError', the query
--      is canceled via 'cancelQuery'.
asyncQuery :: Show meth => Client err meth qid addr x
              -> MethodSerializer qid addr x meth a b
              -> a
              -> addr
              -> (qid -> Result b -> IO ())
              -> IO qid
asyncQuery c@(Client net d err pending whoami _) meth q addr0 withResponse = do
    tm <- getSystemTimerManager
    now <- getPOSIXTime
    -- Holds the timer key once the timeout is registered; 'Nothing' marks
    -- that the response callback has already run.
    keyvar <- newEmptyMVar
    (qid,addr,expiry) <- atomically $ do
        tbl <- readTVar pending
        -- methodTimeout may also rewrite the destination address.
        (addr,expiry) <- methodTimeout meth addr0
        (qid, tbl') <- dispatchRegister (tableMethods d)
                (now + microsecondsDiff expiry)
                (\qid result -> do
                    -- Disarm the timer (ignoring any exception from doing
                    -- so), then deliver the unwrapped result.
                    tm_key <- swapMVar keyvar Nothing
                    mapM_ (unregisterTimeout tm) tm_key `catch` (\(SomeException _) -> return ())
                    withResponse qid $ fmap (unwrapResponse meth) result)
                addr
                tbl
        writeTVar pending tbl'
        return (qid,addr,expiry)
    tm_key <- registerTimeout tm expiry $ do
        atomically $ do
            tbl <- readTVar pending
            -- Below, we discard the returned IO action since we will call
            -- withResponse directly later.
            (v,_) <- dispatchResponse (tableMethods d) qid TimedOut tbl
            writeTVar pending v
        -- NOTE(review): takeMVar leaves keyvar empty; a response callback
        -- racing with this timeout would presumably block in swapMVar above
        -- -- confirm whether that interleaving is possible.
        m <- takeMVar keyvar
        -- Report the timeout only if the response callback has not run yet.
        forM_ m $ \_ -> withResponse qid TimedOut
    putMVar keyvar (Just tm_key)
    self <- whoami (Just addr)
    afterward <- newTVarIO $ return () -- Will be overridden by cancelation handler
    do sendMessage net addr (wrapQuery meth qid self addr q)
       return ()
     `catchIOError` \e -> atomically $ cancelQuery c (writeTVar afterward) qid
    -- Run whatever the cancelation handler stored (a no-op on success).
    join $ atomically $ readTVar afterward
    return qid

-- | Resolve a pending transaction with 'Canceled'.  The table of pending
-- transactions is updated inside the transaction; the IO action produced by
-- the dispatch is handed to the supplied continuation rather than run here.
cancelQuery :: ClientA err meth qid addr x y -> (IO () -> STM ()) -> qid -> STM ()
cancelQuery (Client _ d _ pending _ _) runIO qid = do
    (remaining, notify) <-
        readTVar pending >>= dispatchResponse (tableMethods d) qid Canceled
    writeTVar pending remaining
    runIO notify

-- | Send a query to a remote peer and block until its 'Result' is known.
-- Note that this function will always time out if 'forkListener' was never
-- invoked to spawn a thread to receive and dispatch the response.
sendQuery :: Show meth =>
        Client err meth tid addr x              -- ^ A query/response implementation.
        -> MethodSerializer tid addr x meth a b -- ^ Information for marshaling the query.
        -> a                                    -- ^ The outbound query.
        -> addr                                 -- ^ Destination address of query.
        -> IO (Result b)                        -- ^ The response or failure condition.
sendQuery c meth q addr0 = do
    reply <- newEmptyMVar
    -- The callback just hands the result over to the waiting caller.
    _ <- asyncQuery c meth q addr0 (\_ r -> putMVar reply r)
    takeMVar reply

-- | Adapt a 'MethodHandler' to a different address type by converting every
-- address with the given function before the handler sees it.
contramapAddr :: (a -> b) -> MethodHandler err tid b x -> MethodHandler err tid a x
contramapAddr f handler = case handler of
    MethodHandler parse serialize act ->
        MethodHandler parse
                      (\tid src dst -> serialize tid (f src) (f dst))
                      (act . f)
    NoReply parse act ->
        NoReply parse (act . f)

-- | Query handlers can throw this to ignore a query instead of responding to
-- it.  'dispatchQuery' catches it around a 'MethodHandler' action and then
-- produces no reply.
data DropQuery = DropQuery
 deriving Show

-- Uses the default 'Exception' methods.
instance Exception DropQuery

-- | Attempt to invoke a 'MethodHandler' upon a given inbound query.  When the
-- query parses, the result is an IO action that runs the handler and yields
-- the serialized reply, if there is one ('NoReply' handlers and handlers that
-- throw 'DropQuery' yield 'Nothing').  Otherwise the parse error is returned.
dispatchQuery :: MethodHandlerA err tid addr x y -- ^ Handler to invoke.
                 -> tid                          -- ^ The transaction id for this query\/response session.
                 -> addr                         -- ^ Our own address, to which the query was sent.
                 -> x                            -- ^ The query packet.
                 -> addr                         -- ^ The origin address of the query.
                 -> Either err (IO (Maybe y))
dispatchQuery (MethodHandler unwrapQ wrapR f) tid self x addr =
    case unwrapQ x of
        Left e  -> Left e
        Right a -> Right $
            (Just . wrapR tid self addr <$> f addr a)
                `catch` \DropQuery -> return Nothing
dispatchQuery (NoReply unwrapQ f) _ _ x addr =
    case unwrapQ x of
        Left e  -> Left e
        Right a -> Right (f addr a >> return Nothing)

-- | Like 'transactionMethods' but allows extra information to be stored in the
-- table of pending transactions.  This also enables multiple 'Client's to
-- share a single transaction table.
--
-- Registering generates a fresh id and inserts the stored callback under it;
-- dispatching a response removes the entry (if present) and yields the IO
-- action that invokes it.  Unknown ids leave the state untouched.
transactionMethods' ::
    ((qid -> Result x -> IO ()) -> a)    -- ^ Wrap the response callback into a table entry.
    -> (a -> qid -> Result x -> IO void) -- ^ Invoke the callback held by a table entry.
    -> TableMethods t qid                -- ^ Table methods to lookup values by /tid/.
    -> (g -> (qid,g))                    -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t a) qid addr x
transactionMethods' store load (TableMethods tblInsert tblDelete tblLookup) generate =
    TransactionMethods
        { dispatchRegister = register
        , dispatchResponse = respond
        }
  where
    -- The destination address is not needed by this implementation.
    register expiresAt callback _dest (gen, table) =
        let (tid, gen') = generate gen
            table'      = tblInsert tid (store callback) expiresAt table
        in return (tid, (gen', table'))

    respond tid result (gen, table) =
        maybe (return ((gen, table), return ()))
              (\entry -> return ((gen, tblDelete tid table), void $ load entry tid result))
              (tblLookup tid table)

-- | Build 'TransactionMethods' from lookup-table primitives and a function
-- for generating unique transaction ids.  The response callbacks themselves
-- are stored as the table entries.
transactionMethods ::
    TableMethods t qid  -- ^ Table methods to lookup values by /tid/.
    -> (g -> (qid,g))   -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t (qid -> Result x -> IO ())) qid addr x
transactionMethods = transactionMethods' id id

-- | Classify a single inbound event and decide what to do with it.  The
-- result has the shape expected by 'addHandler': the (possibly replaced)
-- event plus an IO action to run afterwards.
--
--   * Queries with a registered handler, unsolicited packets and responses
--     are consumed: the event becomes 'Discarded' and the IO action performs
--     the actual work.
--
--   * Queries without a handler and unclassifiable packets keep the original
--     event; the IO action only reports the problem.
--
--   * Anything that is not an 'Arrival' passes through untouched.
handleMessage ::
    ClientA err meth tid addr x y
    -> Arrival err addr x
    -> STM (Arrival err addr x, IO ())
handleMessage (Client net d err pending whoami responseID) msg@(Arrival addr plain) = do
    -- Just (Left e)              -> do reportParseError err e
    --                                  return $! Just id
    -- Just (Right (plain, addr)) -> do
        case classifyInbound d plain of
            IsQuery meth tid -> case lookupHandler d meth of
                                Nothing -> return (msg, reportMissingHandler err meth addr plain)
                                Just m  -> return $ (,) Discarded $ do
                                  self <- whoami (Just addr)
                                  tid' <- responseID tid
                                  -- A parse failure is reported; otherwise the
                                  -- handler runs and its reply, if any, is sent
                                  -- back to the origin of the query.
                                  either (\e -> reportParseError err e)
                                         (\iom -> iom >>= mapM_ (sendMessage net addr))
                                         (dispatchQuery m tid' self plain addr)
            IsUnsolicited action -> return $ (,) Discarded $ do
                self <- whoami (Just addr)
                -- The action receives our own address, then the remote one.
                -- Its result is ignored here.
                _ <- action self addr
                return ()
            IsResponse tid -> return $ (,) Discarded $ do
                -- Update the pending table in its own transaction, then run
                -- the IO action it produced.
                action <- atomically $ do
                    ts0 <- readTVar pending
                    (ts, action) <- dispatchResponse (tableMethods d) tid (Success plain) ts0
                    writeTVar pending ts
                    return action
                action
            IsUnknown e -> return (msg, reportUnknown err addr plain e)
    -- Nothing -> return $! id
handleMessage _ msg = return (msg, return ())

-- * UDP Datagrams.

-- | Report the address family of a 'SockAddr'.  'Network.Socket' lacks this
-- accessor, so it is provided here.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
    SockAddrInet{}  -> AF_INET
    SockAddrInet6{} -> AF_INET6
    SockAddrUnix{}  -> AF_UNIX
#if !MIN_VERSION_network(3,0,0)
    _               -> AF_CAN -- SockAddrCan constructor deprecated
#endif

-- | Exception handler used by 'udpTransport''.  Packets with an empty payload
-- may trigger an EOF exception; in that case the supplied default arrival is
-- returned instead of throwing.  If the transport has been marked closed, the
-- socket is closed and 'Terminated' is returned regardless of the error.
-- Any other error is rethrown.
ignoreEOF :: Socket -> MVar () -> Arrival e a x -> IOError -> IO (Arrival e a x)
ignoreEOF sock isClosed def e =
    tryReadMVar isClosed >>= maybe stillOpen (\() -> shutDown)
  where
    shutDown = do
        close sock
        dput XMisc "Closing UDP socket."
        pure Terminated
    stillOpen
        | isEOFError e = pure def
        | otherwise    = throwIO e

-- | Maximum number of bytes requested per datagram by 'udpTransport''
-- (64 KiB, hard-coded).
udpBufferSize :: Int
udpBufferSize = 64 * 1024

-- | Wrapper around 'B.sendTo' that silently ignores DoesNotExistError.
--
-- The number of bytes sent is discarded.  Any other 'IOError' is rethrown
-- with 'throwIO', so it is raised when the handler runs, in sequence with
-- the surrounding IO actions.
saferSendTo :: Socket -> ByteString -> SockAddr -> IO ()
saferSendTo sock bs saddr = void (B.sendTo sock bs saddr)
    `catch` \e ->
        -- sendTo: does not exist (Network is unreachable)
        --       Occurs when IPv6 or IPv4 network is not available.
        --       Currently, we require -threaded to prevent a forever-hang in this case.
        if isDoesNotExistError e
            then return ()
            else throwIO e

-- | Like 'udpTransport' except also returns the raw socket (for broadcast use).
--
-- A datagram socket of the bind address's family is created, the 'Broadcast'
-- option is enabled and the socket is bound.  Received packets are relayed
-- through a 'TChan' by a reader thread that is forked when the transport is
-- activated.
--
-- Note: Throws an exception if unable to bind.
udpTransport' :: Show err => SockAddr -> IO (Transport err SockAddr ByteString, Socket)
udpTransport' bind_address = do
  let family = sockAddrFamily bind_address
  sock <- socket family Datagram defaultProtocol
  when (family == AF_INET6) $ do
    -- Turn off the IPv6Only option on IPv6 sockets.
    setSocketOption sock IPv6Only 0
  setSocketOption sock Broadcast 1
  bind sock bind_address
  -- Filled once the transport is deactivated; consulted by 'ignoreEOF'.
  isClosed <- newEmptyMVar
  -- Queue between the reader thread and 'awaitMessage'.
  udpTChan <- atomically newTChan
  let tr = Transport {
      awaitMessage = fmap (,return()) $ readTChan udpTChan
      -- Destination addresses are converted to match the socket's family.
    , sendMessage = case family of
        AF_INET6 -> \case
            (SockAddrInet port addr) -> \bs ->
                -- Change IPv4 to 4mapped6 address.
                saferSendTo sock bs $ SockAddrInet6 port 0 (0,0,0x0000ffff,fromBE32 addr) 0
            addr6 -> \bs -> saferSendTo sock bs addr6
        AF_INET  -> \case
            (SockAddrInet6 port 0 (0,0,0x0000ffff,raw4) 0) -> \bs -> do
                let host4 = toBE32 raw4
                -- Change 4mapped6 to ordinary IPv4.
                -- dput XMisc $ "4mapped6 -> "++show (SockAddrInet port host4)
                saferSendTo sock bs (SockAddrInet port host4)
            -- Any other IPv6 destination cannot be reached from an IPv4
            -- socket; the packet is dropped with a debug message.
            addr@(SockAddrInet6 {}) -> \bs -> dput XMisc ("Discarding packet to "++show addr)
            addr4 -> \bs -> saferSendTo sock bs addr4
        _ -> \addr bs -> saferSendTo sock bs addr
    , setActive = \case
      False -> do
        dput XMisc $ "closeTransport for udpTransport' called. " ++ show bind_address
        tryPutMVar isClosed () -- signal awaitMessage that the transport is closed.
-- network >= 3.1 already provides withFdSocket; older versions get a local
-- stand-in built on fdSocket.
#if MIN_VERSION_network (3,1,0)
#elif MIN_VERSION_network(3,0,0)
        let withFdSocket sock f = fdSocket sock >>= f >>= seq sock . return
#else
        let withFdSocket sock f = f (fdSocket sock) >>= seq sock . return
#endif
        withFdSocket sock $ \fd -> do
            -- The "close" callback handed to closeFdWith deliberately does
            -- nothing; the socket itself is closed later by 'ignoreEOF'.
            let sorryGHCButIAmNotFuckingClosingTheSocketYet fd = return ()
            -- This call is necessary to interrupt the blocking recvFrom call in awaitMessage.
            closeFdWith sorryGHCButIAmNotFuckingClosingTheSocketYet (fromIntegral fd)
      True -> do
            -- Reader thread: receive packets (up to 'udpBufferSize' bytes
            -- each) and queue them until 'Terminated' is produced.
            udpThread <- forkIO $ fix $ \again -> do
                r <- handle (ignoreEOF sock isClosed $ Arrival (SockAddrInet 0 0) B.empty) $ do
                        uncurry (flip Arrival) <$!> B.recvFrom sock udpBufferSize
                atomically $ writeTChan udpTChan r
                case r of Terminated -> return ()
                          _          -> again
            labelThread udpThread ("udp.io."++show bind_address)
    }
  return (tr, sock)

-- | Build a 'Transport' that sends and receives 'ByteString's over a UDP
-- socket bound to the given listen-address.  This is 'udpTransport'' with the
-- raw socket dropped from the result.  It is a useful low-level 'Transport'
-- that can be transformed for higher-level protocols using 'layerTransport'.
udpTransport :: Show err => SockAddr -> IO (Transport err SockAddr ByteString)
udpTransport = fmap fst . udpTransport'

-- | An in-memory 'Transport' backed by 'TChan's.
--
-- * @chanFromAddr@ maps a destination address to the channel that receives
--   messages sent to it.
-- * @self@ is the address stamped on every outgoing message.
-- * @achan@ is the inbound channel of this transport.
-- * @aclosed@ becomes 'True' once @setActive False@ has been called.
--
-- Queued messages are delivered before closure is reported: 'Terminated' is
-- returned only when @achan@ has nothing to read and @aclosed@ is set.
chanTransport :: (addr -> TChan (x, addr)) -> addr -> TChan (x, addr) -> TVar Bool -> Transport err addr x
chanTransport chanFromAddr self achan aclosed = Transport
    { awaitMessage = (, return ()) <$> (nextArrival `orElse` terminatedIfClosed)
    , sendMessage  = \them bs -> atomically (writeTChan (chanFromAddr them) (bs, self))
    , setActive    = toggle
    }
  where
    -- Pop one queued (payload, sender) pair and present it as an Arrival.
    nextArrival = (\(payload, sender) -> Arrival sender payload) <$> readTChan achan

    -- Blocks (via check) until the closed flag is set, then reports Terminated.
    terminatedIfClosed = do
        closed <- readTVar aclosed
        check closed
        return Terminated

    -- Activation is a no-op; deactivation only raises the closed flag.
    toggle True  = return ()
    toggle False = atomically (writeTVar aclosed True)

-- | Returns a pair of transports linked together to simulate two computers
-- talking to each other.
--
-- The first transport has address @SockAddrInet 1 1@ and the second
-- @SockAddrInet 2 2@.  Whatever one side sends is queued on the inbound
-- channel of the other side, regardless of the destination address given.
testPairTransport :: IO (Transport err SockAddr ByteString, Transport err SockAddr ByteString)
testPairTransport = do
    (inboxA, inboxB, closedA, closedB) <- atomically $ do
        ca <- newTChan
        cb <- newTChan
        fa <- newTVar False
        fb <- newTVar False
        return (ca, cb, fa, fb)
    let addrA = SockAddrInet 1 1
        addrB = SockAddrInet 2 2
        sideA = chanTransport (const inboxB) addrA inboxA closedA
        sideB = chanTransport (const inboxA) addrB inboxB closedB
    return (sideA, sideB)

-- | A 'Transport' with its type parameters reordered so that the address type
-- comes last.  'mergeTransports' uses this shape for the values of its 'DMap',
-- where the address type is selected by the key tag.
newtype ByAddress err x addr = ByAddress (Transport err addr x)

-- | Wraps a plain @x@ while carrying @addr@ as a phantom type parameter (it
-- does not appear in the wrapped value).  In this part of the file it is
-- referenced only from commented-out code in 'mergeTransports'.
newtype Tagged x addr = Tagged x

-- | Re-address an 'Arrival' by pairing its sender address with the given tag.
-- Only the 'Arrival' constructor carries an address; 'ParseError',
-- 'Discarded' and 'Terminated' are passed through as they are.
decorateAddr :: tag addr -> Arrival e addr x -> Arrival e (DSum tag Identity) x
decorateAddr tag = \case
    Arrival addr x -> Arrival (tag ==> addr) x
    ParseError e   -> ParseError e
    Discarded      -> Discarded
    Terminated     -> Terminated

-- | Combine a tag-indexed collection of transports into a single 'Transport'
-- whose addresses are tagged with the key of the transport they belong to.
--
-- * @awaitMessage@ tries each member in the fold order of the map and yields
--   the first result available, re-addressed with 'decorateAddr'; it retries
--   when no member has anything to deliver.
-- * @sendMessage@ routes to the member stored under the tag of the
--   destination; when the map has no such member the message is silently
--   dropped.
-- * @setActive@ forwards the toggle to every member in turn.
mergeTransports :: GCompare tag => DMap tag (ByAddress err x) -> IO (Transport err (DSum tag Identity) x)
mergeTransports tmap =
    -- Disabled alternative kept from the original implementation:
    -- vmap <- traverseWithKey (\k v -> Tagged <$> newEmptyMVar) tmap
    -- foldrWithKey (\k v n -> forkMergeBranch k v >> n) (return ()) vmap
    return Transport
        { awaitMessage = foldrWithKey
            (\key (ByAddress branch) others ->
                fmap (first (decorateAddr key)) (awaitMessage branch) `orElse` others)
            retry
            tmap
        , sendMessage = \(tag :=> Identity addr) payload ->
            maybe (return ())
                  (\(ByAddress branch) -> sendMessage branch addr payload)
                  (DMap.lookup tag tmap)
        , setActive = \toggle ->
            foldrWithKey (\_ (ByAddress branch) rest -> setActive branch toggle >> rest)
                         (return ())
                         tmap
        }