summaryrefslogtreecommitdiff
path: root/dht/src/Network/QueryResponse.hs
blob: 9c33b91149452a4ab32e55db1e29a90ab34f406a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
-- | This module can implement any query\/response protocol.  It was written
-- with Kademlia implementations in mind.

{-# LANGUAGE CPP                   #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TupleSections         #-}
module Network.QueryResponse where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent
import GHC.Conc           (labelThread)
#endif
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.ByteString    as B
         ;import Data.ByteString    (ByteString)
import Data.Function
import Data.Functor.Contravariant
import qualified Data.IntMap.Strict as IntMap
         ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict    as Map
         ;import Data.Map.Strict    (Map)
import Data.Time.Clock.POSIX
import qualified Data.Word64Map     as W64Map
         ;import Data.Word64Map     (Word64Map)
import Data.Word
import Data.Maybe
import GHC.Conc (closeFdWith)
import GHC.Event
import Network.Socket
import Network.Socket.ByteString    as B
import System.Endian
import System.IO
import System.IO.Error
import System.Timeout
import DPut
import DebugTag
import Data.TableMethods

-- | Three methods are required to implement a datagram based query\/response
-- protocol: one to receive, one to send and one to shut down.
--
-- The type parameters are: /err/ for parse failures, /addr/ for peer
-- addresses, /x/ for inbound packets and /y/ for outbound packets.
data TransportA err addr x y = Transport
    { -- | Blocks until an inbound packet is available and then hands the
      -- outcome to the supplied continuation.  The continuation receives
      -- 'Nothing' when no more packets are expected due to a shutdown or
      -- close event.  Otherwise it receives either the packet, parsed as
      -- type /x/ and paired with its origin address /addr/, or a parse
      -- failure of type /err/.
      awaitMessage :: forall a. (Maybe (Either err (x, addr)) -> IO a) -> IO a
      -- | Send a /y/ packet to the given destination /addr/.
    , sendMessage :: addr -> y -> IO ()
      -- | Shutdown and clean up any state related to this 'Transport'.
    , closeTransport :: IO ()
    }

-- | A 'TransportA' whose inbound and outbound packets share the same type /x/.
type Transport err addr x = TransportA err addr x x

-- | Lift a 'Transport' to higher-level addresses and packet representations,
-- where both conversions may perform IO.  Typical uses are turning UDP
-- 'ByteString's into bencoded syntax trees, or adding an encryption layer in
-- which addresses carry associated public keys.
--
-- Parse failures reported by the underlying transport are passed through
-- untouched; only successfully received packets are given to the parser.
-- The 'closeTransport' action of the underlying transport is retained.
layerTransportM ::
        (x -> addr -> IO (Either err (x', addr')))
        -- ^ Attempt to turn a low-level packet and address into their
        -- higher-level representation.
        -> (y' -> addr' -> IO (y, addr))
        -- ^ Encode a high-level packet and address into the lower-level
        -- representation.
        -> TransportA err addr x y
        -- ^ The low-level transport to be transformed.
        -> TransportA err addr' x' y'
layerTransportM parse encode tr =
    tr { awaitMessage = \kont -> awaitMessage tr $ \inbound ->
            case inbound of
                Nothing                -> kont Nothing
                Just (Left e)          -> kont (Just (Left e))
                Just (Right (x, addr)) -> parse x addr >>= kont . Just
       , sendMessage = \hiAddr hiMsg ->
            encode hiMsg hiAddr >>= \(msg, addr) -> sendMessage tr addr msg
       }


-- | Pure variant of 'layerTransportM': lift a 'Transport' to higher-level
-- addresses and packet representations using conversions that need no IO.
layerTransport ::
        (x -> addr -> Either err (x', addr'))
        -- ^ Attempt to turn a low-level packet and address into their
        -- higher-level representation.
        -> (y' -> addr' -> (y, addr))
        -- ^ Encode a high-level packet and address into the lower-level
        -- representation.
        -> TransportA err addr x y
        -- ^ The low-level transport to be transformed.
        -> TransportA err addr' x' y'
layerTransport parse encode tr = layerTransportM parseM encodeM tr
  where
    parseM  x addr = return (parse x addr)
    encodeM y addr = return (encode y addr)

-- | Partitions a 'Transport' into two higher-level transports using pure
-- classification and encoding functions.  See 'partitionTransportM' for the
-- details.  Note: An 'MVar' is used to share the same underlying socket, so
-- be sure to fork a thread for both returned 'Transport's to avoid hanging.
partitionTransport :: ((b,a) -> Either (x,xaddr) (b,a))
                      -> ((x,xaddr) -> Maybe (b,a))
                      -> Transport err a b
                      -> IO (Transport err xaddr x, Transport err a b)
partitionTransport parse encodex tr =
    partitionTransportM (\pkt -> return (parse pkt))
                        (\out -> return (encodex out))
                        tr

-- | Partitions a 'Transport' into two higher-level transports.
--
-- The first returned transport reads from the underlying transport and
-- classifies every packet: a 'Left' result is delivered to its own reader,
-- a 'Right' result is handed through an 'MVar' to the second returned
-- transport.  End-of-stream is forwarded to both.
--
-- Note: because an 'MVar' is used to share the same underlying socket, be
-- sure to fork a thread for both returned 'Transport's to avoid hanging.
-- Only the first transport keeps the original 'closeTransport' action; the
-- second one's is a no-op.
partitionTransportM :: ((b,a) -> IO (Either (x,xaddr) (b,a)))
                      -> ((x,xaddr) -> IO (Maybe (b,a)))
                      -> Transport err a b
                      -> IO (Transport err xaddr x, Transport err a b)
partitionTransportM parse encodex tr = do
    handoff <- newEmptyMVar
    let primary = tr
            { awaitMessage = \kont ->
                let pump = awaitMessage tr $ \inbound -> case inbound of
                        Nothing          -> putMVar handoff Nothing >> kont Nothing
                        Just (Left e)    -> kont (Just (Left e))
                        Just (Right pkt) -> do
                            sorted <- parse pkt
                            case sorted of
                                Left mine    -> kont (Just (Right mine))
                                -- Not ours: pass it on and keep reading.
                                Right theirs -> putMVar handoff (Just theirs) >> pump
                in pump
            , sendMessage = \xaddr x -> do
                lowered <- encodex (x, xaddr)
                case lowered of
                    Nothing     -> return ()
                    Just (b, a) -> sendMessage tr a b
            }
        secondary = Transport
            { awaitMessage   = \kont -> do
                next <- takeMVar handoff
                kont (fmap Right next)
            , sendMessage    = sendMessage tr
            , closeTransport = return ()
            }
    return (primary, secondary)

-- | Like 'partitionTransportM', except that the encoder for the first
-- returned transport may also divert an outbound message to a separate send
-- action: a 'Left' result is passed to the supplied /forkedSend/ function,
-- a 'Right' result is sent on the underlying transport, and 'Nothing' drops
-- the message.
partitionAndForkTransport ::
                      (dst -> msg -> IO ())
                      -> ((b,a) -> IO (Either (x,xaddr) (b,a)))
                      -> ((x,xaddr) -> IO (Maybe (Either (msg,dst) (b,a))))
                      -> Transport err a b
                      -> IO (Transport err xaddr x, Transport err a b)
partitionAndForkTransport forkedSend parse encodex tr = do
    handoff <- newEmptyMVar
    let primary = tr
            { awaitMessage = \kont ->
                let pump = awaitMessage tr $ \inbound -> case inbound of
                        Nothing          -> putMVar handoff Nothing >> kont Nothing
                        Just (Left e)    -> kont (Just (Left e))
                        Just (Right pkt) -> do
                            sorted <- parse pkt
                            case sorted of
                                Left mine    -> kont (Just (Right mine))
                                -- Not ours: pass it on and keep reading.
                                Right theirs -> putMVar handoff (Just theirs) >> pump
                in pump
            , sendMessage = \xaddr x -> do
                routed <- encodex (x, xaddr)
                case routed of
                    Nothing                -> return ()
                    Just (Left (msg, dst)) -> forkedSend dst msg
                    Just (Right (b, a))    -> sendMessage tr a b
            }
        secondary = Transport
            { awaitMessage   = \kont -> do
                next <- takeMVar handoff
                kont (fmap Right next)
            , sendMessage    = sendMessage tr
            , closeTransport = return ()
            }
    return (primary, secondary)

-- | Install a handler in front of a 'Transport'.  For every successfully
-- parsed packet the handler @f addr x@ decides what happens next:
--
--      * 'Nothing' consumes the packet; the transport waits for the next one.
--
--      * @'Just' 'id'@ leaves the packet for a different handler.
--
--      * @'Just' g@ applies /g/ to the packet and leaves the result for a
--        different handler.
--
-- Parse errors are reported through the first argument and then still
-- passed on to the reader of the resulting transport.
addHandler :: (err -> IO ()) -> (addr -> x -> IO (Maybe (x -> x))) -> Transport err addr x -> Transport err addr x
addHandler onParseError f tr = tr
    { awaitMessage = \kont ->
        let eat = awaitMessage tr $ \inbound -> case inbound of
                Nothing                -> kont Nothing
                Just (Left e)          -> do
                    onParseError e
                    kont (Just (Left e))
                Just (Right (x, addr)) -> do
                    verdict <- f addr x
                    case verdict of
                        Nothing -> eat
                        Just g  -> kont (Just (Right (g x, addr)))
        in eat
    }

-- | Modify a 'Transport' to run an action on every successfully received
-- packet.  The packet itself is passed on unchanged and parse errors are
-- not reported to the action.
onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x
onInbound f tr = addHandler ignoreParseError observe tr
  where
    ignoreParseError _ = return ()
    observe addr x = do
        f addr x
        return (Just id)

-- * Using a query\/response client.

-- | Fork a thread that handles inbound packets by repeatedly calling
-- 'awaitMessage' until it yields 'Nothing'.  The returned action closes the
-- transport; it does not kill the thread directly.
--
--  Example usage:
--
--  > -- Start client.
--  > quitServer <- forkListener "listener" (clientNet client)
--  > -- Send a query q, receive a response r.
--  > r <- sendQuery client method q
--  > -- Quit client.
--  > quitServer
forkListener :: String -> Transport err addr x -> IO (IO ())
forkListener name client = do
    _ <- forkIO listener
    return (closeTransport client)
  where
    -- Body of the forked thread: label it, drain packets, report the exit.
    listener = do
        me <- myThreadId
        labelThread me ("listener." ++ name)
        drain
        dput XMisc ("Listener died: " ++ name)
    -- Keep receiving until the transport signals end-of-stream.
    drain = awaitMessage client $ \inbound -> case inbound of
        Nothing -> return ()
        Just _  -> drain

-- | Register a pending transaction and transmit the query without waiting
-- for the reply.  The supplied callback is stored in the pending table (after
-- being composed with 'unwrapResponse') so that it can be invoked when a
-- matching response is dispatched.
--
-- Returns the transaction id, the time at which the query was registered and
-- the timeout reported by 'methodTimeout'.  An 'IOError' raised while sending
-- is ignored; the transaction stays registered in that case.
asyncQuery_ :: Client err meth tid addr x
              -> MethodSerializer tid addr x meth a b
              -> a
              -> addr
              -> (Maybe b -> IO ())
              -> IO (tid,POSIXTime,Int)
asyncQuery_ Client{ clientNet = net, clientDispatcher = d, clientPending = pending, clientAddress = whoami }
            meth q addr0 withResponse = do
    now <- getPOSIXTime
    let onReply = withResponse . fmap (unwrapResponse meth)
    (tid, addr, expiry) <- atomically $ do
        tbl <- readTVar pending
        ((tid', addr', expiry'), tbl') <-
            dispatchRegister (tableMethods d) (methodTimeout meth) now onReply addr0 tbl
        writeTVar pending tbl'
        return (tid', addr', expiry')
    self <- whoami (Just addr)
    -- Best effort: a failed send is ignored here.
    sendMessage net addr (wrapQuery meth tid self addr q)
        `catchIOError` (\_ -> return ())
    return (tid, now, expiry)

-- | Send a query without blocking for the reply.  The callback receives
-- @'Just' response@ when a reply is dispatched, or 'Nothing' when the timer
-- registered here fires first.
--
-- NOTE(review): the callback appears to be reachable from both the timer and
-- the response path; whether it can run twice in a race is not evident from
-- this function alone -- confirm.
asyncQuery :: Show meth => Client err meth tid addr x
              -> MethodSerializer tid addr x meth a b
              -> a
              -> addr
              -> (Maybe b -> IO ())
              -> IO ()
asyncQuery client meth q addr withResponse0 = do
    tm <- getSystemTimerManager
    -- Filled with the transaction id once 'asyncQuery_' has returned, so the
    -- timeout handler knows which pending entry to cancel.
    tidvar <- newEmptyMVar
    -- Provisional timer (1000000, presumably microseconds per GHC.Event);
    -- it is re-armed with the real expiry below.
    timedout <- registerTimeout tm 1000000 $ do
            dput XMisc $ "async TIMEDOUT " ++ show (method meth)
            withResponse0 Nothing
            -- NOTE(review): blocks until the tid has been published below;
            -- if this fires before 'putMVar tidvar' the timer callback waits.
            tid <- takeMVar tidvar
            dput XMisc $ "async TIMEDOUT mvar " ++ show (method meth)
            case client of
                Client { clientDispatcher = d, clientPending = pending } -> do
                    -- Forget the pending transaction.
                    atomically $ readTVar pending >>= dispatchCancel (tableMethods d) tid >>= writeTVar pending
    (tid,now,expiry) <- asyncQuery_ client meth q addr $ \x -> do
        -- A reply was dispatched: disarm the timer before notifying the caller.
        unregisterTimeout tm timedout
        withResponse0 x
    putMVar tidvar tid
    -- Replace the provisional timeout with the one chosen for this method.
    updateTimeout tm timedout expiry
    dput XMisc $ "FIN asyncQuery "++show (method meth)++" TIMEOUT="++show expiry

-- | Send a query to a remote peer and block until the response arrives or
-- the method's timeout elapses.  On timeout the pending transaction is
-- cancelled.  Note that this function will always time out if
-- 'forkListener' was never invoked to spawn a thread to receive and
-- dispatch the response.
sendQuery ::
    forall err a b tbl x meth tid addr.
        Client err meth tid addr x              -- ^ A query/response implementation.
        -> MethodSerializer tid addr x meth a b -- ^ Information for marshaling the query.
        -> a                                    -- ^ The outbound query.
        -> addr                                 -- ^ Destination address of query.
        -> IO (Maybe b)                         -- ^ The response, or 'Nothing' if it timed out.
sendQuery client@Client{ clientDispatcher = d, clientPending = pending } meth q addr0 = do
    reply <- newEmptyMVar
    -- Only an actual response fills the MVar; a 'Nothing' notification is ignored.
    (tid, _, expiry) <- asyncQuery_ client meth q addr0 $ \outcome ->
        case outcome of
            Just b  -> putMVar reply b
            Nothing -> return ()
    answer <- timeout expiry (takeMVar reply)
    when (isNothing answer) $ atomically $ do
        tbl  <- readTVar pending
        tbl' <- dispatchCancel (tableMethods d) tid tbl
        writeTVar pending tbl'
    return answer

-- * Implementing a query\/response 'Client'.

-- | All inputs required to implement a query\/response client.  The type of
-- the pending-transaction table is existentially hidden; it only has to
-- agree between 'clientDispatcher' and 'clientPending'.
data Client err meth tid addr x = forall tbl. Client
    { -- | The 'Transport' used to dispatch and receive packets.
      clientNet :: Transport err addr x
      -- | Methods for handling inbound packets.
    , clientDispatcher :: DispatchMethods tbl err meth tid addr x
      -- | Methods for reporting various conditions.
    , clientErrorReporter :: ErrorReporter addr x meth tid err
      -- | State necessary for routing inbound responses and assigning unique
      -- /tid/ values for outgoing queries.
    , clientPending :: TVar tbl
      -- | An action yielding this client\'s own address.  It is invoked for
      -- each outbound query and for each inbound query or unsolicited packet
      -- that is handled.  It is valid for this to always return the same
      -- value.
      --
      -- The argument, if supplied, is the remote address for the transaction.
      -- This can be used to maintain consistent aliases for specific peers.
    , clientAddress :: Maybe addr -> IO addr
      -- | Transform a query /tid/ value to an appropriate response /tid/
      -- value.  Normally, this would be the identity transformation, but if
      -- /tid/ includes a unique cryptographic nonce, then it should be
      -- generated here.
    , clientResponseId :: tid -> IO tid
    }

-- | An incoming message is classified into one of the following cases.
data MessageClass err meth tid addr x
    = IsQuery meth tid                                    -- ^ An unsolicited query is handled based on its /meth/ value.  Any response
                                                          -- should include the provided /tid/ value.
    | IsResponse tid                                      -- ^ A response to an outgoing query we associated with a /tid/ value.
    | IsUnsolicited (addr -> addr -> IO (Maybe (x -> x))) -- ^ Transactionless informative packet.  The IO action is invoked with two
                                                          -- addresses; 'handleMessage' passes our own address first and the remote
                                                          -- peer's address second.  If it handles the message, it should return
                                                          -- Nothing. Otherwise, it should return a transform (usually /id/) to apply
                                                          -- before the next handler examines it.
    | IsUnknown err                                       -- ^ None of the above.

-- | Handler for an inbound query of type /x/ from an address of type /addr/.
-- The parsed query type and the reply type are existentially hidden.
data MethodHandler err tid addr x = forall a b. MethodHandler
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Serialize the response for transmission, given the transaction id
      -- and two addresses.  'dispatchQuery' supplies our own address first
      -- and the origin address of the query second.
    , methodSerialize :: tid -> addr -> addr -> b -> x
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
    , methodAction :: addr -> a -> IO b
    }
      -- | A handler that never sends a reply.  See also 'IsUnsolicited' which
      -- likely makes this constructor unnecessary.
    | forall a. NoReply
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
    , noreplyAction :: addr -> a -> IO ()
    }

-- | Adapt a 'MethodHandler' to a different address type by converting every
-- address it is given with the supplied function.
contramapAddr :: (a -> b) -> MethodHandler err tid b x -> MethodHandler err tid a x
contramapAddr f handler = case handler of
    MethodHandler parse serialize act -> MethodHandler
        { methodParse     = parse
        , methodSerialize = \tid src dst result -> serialize tid (f src) (f dst) result
        , methodAction    = \addr arg -> act (f addr) arg
        }
    NoReply parse act -> NoReply
        { methodParse   = parse
        , noreplyAction = \addr arg -> act (f addr) arg
        }


-- | Attempt to invoke a 'MethodHandler' upon a given inbound query.  When the
-- query parses, the result is an IO action that runs the handler and yields
-- the serialized reply, or 'Nothing' for a 'NoReply' handler.  When it does
-- not parse, the parse error is returned instead.
dispatchQuery :: MethodHandler err tid addr x -- ^ Handler to invoke.
                 -> tid                       -- ^ The transaction id for this query\/response session.
                 -> addr                      -- ^ Our own address, to which the query was sent.
                 -> x                         -- ^ The query packet.
                 -> addr                      -- ^ The origin address of the query.
                 -> Either err (IO (Maybe x))
dispatchQuery handler tid self x addr = case handler of
    MethodHandler unwrapQ wrapR act -> do
        a <- unwrapQ x
        Right $ do
            b <- act addr a
            return (Just (wrapR tid self addr b))
    NoReply unwrapQ act -> do
        a <- unwrapQ x
        Right (act addr a >> return Nothing)

-- | These four parameters are required to implement an outgoing query.  A
-- peer-to-peer algorithm will define a 'MethodSerializer' for every
-- 'MethodHandler' that might be returned by 'lookupHandler'.
data MethodSerializer tid addr x meth a b = MethodSerializer
    { -- | Returns the microseconds to wait for a response to this query being
      -- sent to the given address.  The /addr/ may also be modified to add
      -- routing information.
      methodTimeout :: tid -> addr -> STM (addr,Int)
      -- | A method identifier used for error reporting.  This needn't be the
      -- same as the /meth/ argument to 'MethodHandler', but it is suggested.
    , method :: meth
      -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
      -- The /addr/ arguments are, respectively, our own origin address and the
      -- destination of the request.  The /tid/ argument is useful for attaching
      -- auxiliary notations on all outgoing packets.
    , wrapQuery :: tid -> addr -> addr -> a -> x
      -- | Parse an inbound packet /x/ into a response /b/ for this query.
      -- The type offers no way to signal a parse failure.
    , unwrapResponse :: x -> b
    }


-- | To dispatch responses to our outbound queries, we require three
-- primitives.  See the 'transactionMethods' function to create these
-- primitives out of a lookup table and a generator for transaction ids.
--
-- The type variable /d/ is used to represent the current state of the
-- transaction generator and the table of pending transactions.
data TransactionMethods d tid addr x = TransactionMethods
    {
      -- | Before a query is sent, this function stores a callback that is to
      -- be invoked with the response.  Its arguments are, in order: a function
      -- choosing the (possibly modified) destination address and the timeout,
      -- the current time, the callback, the destination address and the
      -- current state.  The returned /tid/ is a transaction id that can be
      -- used to forget the callback if the remote peer is not responding; it
      -- is returned together with the chosen address, the timeout and the
      -- updated state.
      dispatchRegister :: (tid -> addr -> STM (addr,Int)) -> POSIXTime -> (Maybe x -> IO ()) -> addr -> d -> STM ((tid,addr,Int), d)
      -- | This method is invoked when an incoming packet /x/ indicates it is
      -- a response to the transaction with id /tid/.  The returned IO action
      -- delivers the packet to whoever registered the transaction, thus
      -- completing the dispatch.
    , dispatchResponse :: tid -> x -> d -> STM (d, IO ())
      -- | When a timeout interval elapses, this method is called to remove the
      -- transaction from the table.
    , dispatchCancel :: tid -> d -> STM d
    }

-- | Build 'TransactionMethods' from lookup-table primitives and a generator
-- of unique transaction ids.  The table stores the response callbacks
-- themselves; see 'transactionMethods'' to store richer entries.
transactionMethods ::
    TableMethods t tid  -- ^ Table methods to lookup values by /tid/.
    -> (g -> (tid,g))   -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t (Maybe x -> IO ())) tid addr x
transactionMethods methods generate =
    transactionMethods' (\callback -> callback)
                        (\callback reply -> callback reply)
                        methods
                        generate

-- | Convert a count of microseconds into a 'POSIXTime' interval in seconds.
microsecondsDiff :: Int -> POSIXTime
microsecondsDiff = (/ 1000000) . fromIntegral

-- | Like 'transactionMethods' but allows extra information to be stored in
-- the table of pending transactions.  This also enables multiple 'Client's
-- to share a single transaction table.
--
-- Each entry is inserted with an expiry time of registration time plus the
-- timeout chosen for the query.  A response for an unknown /tid/ leaves the
-- table untouched and yields a no-op action.
transactionMethods' ::
    ((Maybe x -> IO ()) -> a)    -- ^ Wrap the response callback into a table entry.
    -> (a -> Maybe x -> IO void) -- ^ Deliver a response to a table entry.
    -> TableMethods t tid        -- ^ Table methods to lookup values by /tid/.
    -> (g -> (tid,g))            -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t a) tid addr x
transactionMethods' store load (TableMethods tInsert tDelete tLookup) generate = TransactionMethods
    { dispatchRegister = register
    , dispatchResponse = respond
    , dispatchCancel   = cancel
    }
  where
    -- Drop a transaction from the table.
    cancel tid (g, tbl) = return (g, tDelete tid tbl)

    -- Allocate a fresh tid, ask for the timeout and record the callback.
    register getTimeout now callback dest (g, tbl) = do
        let (tid, g') = generate g
        (dest', expiry) <- getTimeout tid dest
        let deadline = now + microsecondsDiff expiry
            tbl'     = tInsert tid (store callback) deadline tbl
        return ((tid, dest', expiry), (g', tbl'))

    -- Remove the matching entry, if any, and produce its delivery action.
    respond tid x (g, tbl) = case tLookup tid tbl of
        Nothing    -> return ((g, tbl), return ())
        Just entry -> return ((g, tDelete tid tbl), void (load entry (Just x)))

-- | A set of methods necessary for dispatching incoming packets.  The /tbl/
-- parameter is the state type threaded through 'tableMethods'.
data DispatchMethods tbl err meth tid addr x = DispatchMethods
    { -- | Classify an inbound packet; see 'MessageClass' for the possible
      -- outcomes.
      classifyInbound :: x -> MessageClass err meth tid addr x
      -- | Lookup the handler for an inbound query.  'Nothing' means no
      -- handler is available for the method.
    , lookupHandler :: meth -> Maybe (MethodHandler err tid addr x)
      -- | Methods for handling incoming responses.
    , tableMethods :: TransactionMethods tbl tid addr x
    }

-- | These methods indicate what should be done upon various conditions.  Write
-- to a log file, make debug prints, or simply ignore them.
--
--   [ /addr/ ]  Address of remote peer.
--
--   [ /x/ ]     Incoming or outgoing packet.
--
--   [ /meth/ ]  Method id of incoming or outgoing request.
--
--   [ /tid/ ]   Transaction id for outgoing packet.  None of the fields
--               below mention it.
--
--   [ /err/ ]   Error information, typically a 'String'.
data ErrorReporter addr x meth tid err = ErrorReporter
    { -- | Incoming: failed to parse packet.
      reportParseError :: err -> IO ()
      -- | Incoming: no handler for request.  Receives the method, the remote
      -- address and the packet.
    , reportMissingHandler :: meth -> addr -> x -> IO ()
      -- | Incoming: unable to identify request.  Receives the remote address,
      -- the packet and the classification error.
    , reportUnknown :: addr -> x -> err -> IO ()
    }

-- | An 'ErrorReporter' that silently discards every report.
ignoreErrors :: ErrorReporter addr x meth tid err
ignoreErrors = ErrorReporter
    (\_     -> return ())  -- parse errors
    (\_ _ _ -> return ())  -- missing handlers
    (\_ _ _ -> return ())  -- unknown packets

-- | An 'ErrorReporter' that writes every report to the 'XMisc' debug channel
-- via 'dput'.  Reports that concern a peer are prefixed with its address.
logErrors :: ( Show addr
               , Show meth
               ) => ErrorReporter addr x meth tid String
logErrors = ErrorReporter
    { reportParseError     = \err -> say err
    , reportMissingHandler = \meth addr _ ->
        say (show addr ++ " --> Missing handler (" ++ show meth ++ ")")
    , reportUnknown        = \addr _ err ->
        say (show addr ++ " --> " ++ err)
    }
  where
    say msg = dput XMisc msg

-- | An 'ErrorReporter' that prints every report as a line on the given
-- 'Handle'.  Reports that concern a peer are prefixed with its address.
printErrors :: ( Show addr
               , Show meth
               ) => Handle -> ErrorReporter addr x meth tid String
printErrors h = ErrorReporter
    { reportParseError     = \err -> emit err
    , reportMissingHandler = \meth addr _ ->
        emit (show addr ++ " --> Missing handler (" ++ show meth ++ ")")
    , reportUnknown        = \addr _ err ->
        emit (show addr ++ " --> " ++ err)
    }
  where
    emit msg = hPutStrLn h msg

-- | Change the /err/ type of an 'ErrorReporter' by converting errors before
-- they reach the underlying reporter.  Missing-handler reports carry no
-- /err/ value and are passed through as they are.
instance Contravariant (ErrorReporter addr x meth tid) where
    contramap f (ErrorReporter parseErr missing unknown) = ErrorReporter
        { reportParseError     = \e -> parseErr (f e)
        , reportMissingHandler = missing
        , reportUnknown        = \addr x e -> unknown addr x (f e)
        }

-- | Handle a single inbound packet.  The signature matches the handler
-- argument of 'addHandler', so the result tells the transport what to do
-- with the packet next:
--
--   * 'Nothing' means the packet was consumed here.
--
--   * @'Just' g@ means the packet (after applying /g/) should be left for a
--     different handler.
--
-- Per message class:
--
--   * 'IsQuery': the handler is looked up and run and its reply, if any, is
--     sent back to the origin.  A missing handler or a query that fails to
--     parse is reported and the packet is passed on unchanged.
--
--   * 'IsUnsolicited': the supplied action is run with our own address and
--     the origin address, and its verdict is returned as is, as documented
--     on 'IsUnsolicited'.
--
--   * 'IsResponse': the pending transaction with the matching /tid/ is
--     completed and the packet is consumed.
--
--   * 'IsUnknown': the condition is reported and the packet is passed on
--     unchanged.
handleMessage ::
    Client err meth tid addr x
    -> addr
    -> x
    -> IO (Maybe (x -> x))
handleMessage (Client net d err pending whoami responseID) addr plain =
    case classifyInbound d plain of
        IsQuery meth tid -> case lookupHandler d meth of
            Nothing -> do
                reportMissingHandler err meth addr plain
                return $! Just id
            Just handler -> do
                self <- whoami (Just addr)
                tid' <- responseID tid
                case dispatchQuery handler tid' self plain addr of
                    Left e -> do
                        reportParseError err e
                        return $! Just id
                    Right respond -> do
                        reply <- respond
                        -- 'NoReply' handlers yield 'Nothing', so nothing is sent.
                        mapM_ (sendMessage net addr) reply
                        return $! Nothing
        IsUnsolicited action -> do
            self <- whoami (Just addr)
            -- Honour the action's verdict instead of discarding it: 'Nothing'
            -- consumes the packet, 'Just' passes it to the next handler.
            action self addr
        IsResponse tid -> do
            deliver <- atomically $ do
                ts0 <- readTVar pending
                (ts, act) <- dispatchResponse (tableMethods d) tid plain ts0
                writeTVar pending ts
                return act
            -- Run the delivery outside of the STM transaction.
            deliver
            return $! Nothing
        IsUnknown e -> do
            reportUnknown err addr plain e
            return $! Just id

-- * UDP Datagrams.

-- | Determine the address family of a 'SockAddr'.  'Network.Socket' lacks
-- this accessor, hence the local definition.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily saddr = case saddr of
    SockAddrInet{}  -> AF_INET
    SockAddrInet6{} -> AF_INET6
    SockAddrUnix{}  -> AF_UNIX
#if !MIN_VERSION_network(3,0,0)
    _               -> AF_CAN -- SockAddrCan constructor deprecated
#endif

-- | Exception handler used by 'udpTransport' around its receive call.
-- Packets with an empty payload may trigger an EOF exception; this handler
-- avoids throwing in that case and yields the supplied default instead.
--
-- If the transport has been marked closed, the socket is closed and
-- 'Nothing' is returned regardless of the error.  Any other error is
-- rethrown.
ignoreEOF :: Socket -> MVar () -> a -> IOError -> IO (Maybe a)
ignoreEOF sock isClosed def e = do
    closedFlag <- tryReadMVar isClosed
    case closedFlag of
        Just () -> do
            close sock
            dput XMisc "Closing UDP socket."
            return Nothing
        Nothing
            | isEOFError e -> return (Just def)
            | otherwise    -> throwIO e

-- | Fixed upper bound, in bytes, on the size of an incoming UDP packet
-- received via 'udpTransport' (64 KiB).
udpBufferSize :: Int
udpBufferSize = 64 * 1024

-- | Wrapper around 'B.sendTo' that silently ignores DoesNotExistError and
-- discards the result of the send.  Every other 'IOError' is rethrown.
saferSendTo :: Socket -> ByteString -> SockAddr -> IO ()
saferSendTo sock bs saddr = void (B.sendTo sock bs saddr) `catch` onSendError
  where
    onSendError :: IOError -> IO ()
    onSendError e
        -- sendTo: does not exist (Network is unreachable)
        --       Occurs when IPv6 or IPv4 network is not available.
        --       Currently, we require -threaded to prevent a forever-hang in this case.
        | isDoesNotExistError e = return ()
        -- Rethrow with 'throwIO' so the exception is raised as an IO effect,
        -- in sequence with the surrounding actions.
        | otherwise             = throwIO e

-- | A 'udpTransport' uses a UDP socket to send and receive 'ByteString's.  The
-- argument is the listen-address for incoming packets.  This is a useful
-- low-level 'Transport' that can be transformed for higher-level protocols
-- using 'layerTransport'.  Use 'udpTransport'' when the raw socket is needed
-- as well.
udpTransport :: Show err => SockAddr -> IO (Transport err SockAddr ByteString)
udpTransport bind_address = do
    (transport, _) <- udpTransport' bind_address
    return transport

-- | Like 'udpTransport' except also returns the raw socket (for broadcast use).
--
-- A datagram socket of the bind address's family is created, configured
-- (dual-stack for IPv6, broadcast enabled) and bound.  If configuring or
-- binding fails, the socket is closed before the exception propagates.
--
-- Outbound addresses are adapted to the socket's family: on an IPv6 socket
-- IPv4 destinations are rewritten as IPv4-mapped IPv6 addresses, and on an
-- IPv4 socket IPv4-mapped IPv6 destinations are rewritten as plain IPv4 while
-- other IPv6 destinations are discarded with a debug message.
udpTransport' :: Show err => SockAddr -> IO (Transport err SockAddr ByteString, Socket)
udpTransport' bind_address = do
  let family = sockAddrFamily bind_address
  sock <- socket family Datagram defaultProtocol
  let configure = do
        when (family == AF_INET6) $
            setSocketOption sock IPv6Only 0
        setSocketOption sock Broadcast 1
        bind sock bind_address
  -- Do not leak the descriptor when setup fails (e.g. address in use).
  configure `onException` close sock
  isClosed <- newEmptyMVar
  let tr = Transport {
      awaitMessage = \kont -> do
        -- 'ignoreEOF' turns an EOF error into an empty packet and, once the
        -- transport has been closed, closes the socket and yields Nothing.
        r <- handle (ignoreEOF sock isClosed $ Right (B.empty, SockAddrInet 0 0)) $ do
                Just . Right <$!> B.recvFrom sock udpBufferSize
        kont $! r
    , sendMessage = case family of
        AF_INET6 -> \case
            (SockAddrInet port addr) -> \bs ->
                -- Change IPv4 to 4mapped6 address.
                saferSendTo sock bs $ SockAddrInet6 port 0 (0,0,0x0000ffff,fromBE32 addr) 0
            addr6 -> \bs -> saferSendTo sock bs addr6
        AF_INET  -> \case
            (SockAddrInet6 port 0 (0,0,0x0000ffff,raw4) 0) -> \bs -> do
                let host4 = toBE32 raw4
                -- Change 4mapped6 to ordinary IPv4.
                -- dput XMisc $ "4mapped6 -> "++show (SockAddrInet port host4)
                saferSendTo sock bs (SockAddrInet port host4)
            addr@(SockAddrInet6 {}) -> \bs -> dput XMisc ("Discarding packet to "++show addr)
            addr4 -> \bs -> saferSendTo sock bs addr4
        _ -> \addr bs -> saferSendTo sock bs addr
    , closeTransport = do
        dput XMisc $ "closeTransport for udpTransport' called. " ++ show bind_address
        _ <- tryPutMVar isClosed () -- signal awaitMessage that the transport is closed.
#if MIN_VERSION_network (3,1,0)
#elif MIN_VERSION_network(3,0,0)
        let withFdSocket sock f = fdSocket sock >>= f >>= seq sock . return
#else
        let withFdSocket sock f = f (fdSocket sock) >>= seq sock . return
#endif
        withFdSocket sock $ \fd -> do
            -- The descriptor itself is deliberately left open here; the
            -- socket is closed later by 'ignoreEOF' in the reader.
            let leaveFdOpen _ = return ()
            -- This call is necessary to interrupt the blocking recvFrom call in awaitMessage.
            closeFdWith leaveFdOpen (fromIntegral fd)
    }
  return (tr, sock)

-- | An in-process 'Transport' backed by 'TChan's.  Inbound packets are read
-- from the given channel; outbound packets are written, tagged with our own
-- address, to the channel selected for the destination.  Once the closed
-- flag is set and the inbound channel is empty, readers are told that no
-- more packets will arrive.
chanTransport :: (addr -> TChan (x, addr)) -> addr -> TChan (x, addr) -> TVar Bool -> Transport err addr x
chanTransport chanFromAddr self achan aclosed = Transport
    { awaitMessage   = \kont -> do
        event <- atomically (received `orElse` shutdown)
        kont (fmap Right event)
    , sendMessage    = \them bs ->
        atomically (writeTChan (chanFromAddr them) (bs, self))
    , closeTransport = atomically (writeTVar aclosed True)
    }
  where
    -- A queued packet takes priority over the closed flag.
    received = fmap Just (readTChan achan)
    -- Retries until the transport has been closed.
    shutdown = do
        closed <- readTVar aclosed
        check closed
        return Nothing

-- | Returns a pair of transports linked together to simulate two computers
-- talking to each other.  Whatever one side sends shows up as inbound on
-- the other side, regardless of the destination address given.
testPairTransport :: IO (Transport err SockAddr ByteString, Transport err SockAddr ByteString)
testPairTransport = do
    inboxA  <- newTChanIO
    inboxB  <- newTChanIO
    closedA <- newTVarIO False
    closedB <- newTVarIO False
    let addrA = SockAddrInet 1 1
        addrB = SockAddrInet 2 2
        sideA = chanTransport (const inboxB) addrA inboxA closedA
        sideB = chanTransport (const inboxA) addrB inboxB closedB
    return (sideA, sideB)