summaryrefslogtreecommitdiff
path: root/src/Network/QueryResponse.hs
blob: 41e254860514037a8c387064ddc68cbf0bf26f06 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
-- | This module can implement any query\/response protocol.  It was written
-- with Kademlia implementations in mind.

{-# LANGUAGE CPP                   #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TupleSections         #-}
module Network.QueryResponse where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent
import GHC.Conc           (labelThread)
#endif
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.ByteString    as B
         ;import Data.ByteString    (ByteString)
import Data.Function
import qualified Data.IntMap.Strict as IntMap
         ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict    as Map
         ;import Data.Map.Strict    (Map)
import Data.Maybe
import Data.Typeable
import Network.Socket
import Network.Socket.ByteString    as B
import System.Endian
import System.IO
import System.IO.Error
import System.Timeout

-- | Three methods are required to implement a datagram based query\/response protocol.
--
-- The type parameters are /err/ for parse failures, /addr/ for peer addresses
-- and /x/ for the packet representation.
data Transport err addr x = Transport
    { -- | Blocks until an inbound packet is available.  Yields 'Nothing' when
      -- no more packets are expected due to a shutdown or close event.
      -- Otherwise, the packet will be parsed as type /x/ together with an
      -- origin address /addr/.  Parse failure is indicated by a 'Left' value
      -- of type /err/.
      --
      -- The outcome is delivered in continuation-passing style: it is handed
      -- to the supplied callback rather than returned directly.
      awaitMessage :: forall a. (Maybe (Either err (x, addr)) -> IO a) -> IO a
      -- | Send an /x/ packet to the given destination /addr/.
    , sendMessage :: addr -> x -> IO ()
      -- | Shutdown and clean up any state related to this 'Transport'.
    , closeTransport :: IO ()
    }

-- | Re-express a 'Transport' in terms of higher-level addresses and packet
-- representations.  Typical uses are turning UDP 'ByteString's into bencoded
-- syntax trees, or adding an encryption layer in which addresses carry
-- associated public keys.
--
-- Parse errors reported by the underlying transport are passed along
-- unchanged; successfully received packets are run through the supplied
-- parser, which may itself fail with /err/.  'closeTransport' is inherited
-- from the underlying transport.
layerTransport ::
        (x -> addr -> Either err (x', addr'))
        -- ^ Attempt to lift a low-level packet/address pair into the
        -- higher-level representation.
        -> (x' -> addr' -> (x, addr))
        -- ^ Lower a high-level packet/address pair back into the low-level
        -- representation.
        -> Transport err addr x
        -- ^ The low-level transport being wrapped.
        -> Transport err addr' x'
layerTransport parse encode tr = tr
    { awaitMessage = \kont -> awaitMessage tr (kont . fmap liftPacket)
    , sendMessage  = \hiAddr hiMsg ->
        let (loMsg, loAddr) = encode hiMsg hiAddr
        in  sendMessage tr loAddr loMsg
    }
  where
    -- Keep an upstream failure as it is, otherwise try to parse the packet.
    liftPacket = either Left (uncurry parse)

-- | Partitions a 'Transport' into two higher-level transports.
--
-- Only the first returned transport reads from the underlying one.  Packets
-- that the classifier assigns to the second transport are handed over through
-- an 'MVar', so be sure to fork a thread for both returned 'Transport's to
-- avoid hanging.
--
-- The first transport inherits 'closeTransport' from the underlying
-- transport; closing the second one does nothing.
partitionTransport :: ((b,a) -> Either (x,xaddr) (y,yaddr))
                      -> ((x,xaddr) -> (b,a))
                      -> ((y,yaddr) -> (b,a))
                      -> Transport err a b
                      -> IO (Transport err xaddr x, Transport err yaddr y)
partitionTransport parse encodex encodey tr = do
    handoff <- newEmptyMVar
    let xtr = tr
            { awaitMessage = \kont ->
                let loop = awaitMessage tr $ \case
                        Nothing          -> kont Nothing
                        Just (Left e)    -> kont (Just (Left e))
                        Just (Right msg) -> case parse msg of
                            Left  forX -> kont (Just (Right forX))
                            -- Not ours: pass it to the sibling and keep waiting.
                            Right forY -> putMVar handoff forY >> loop
                in loop
            , sendMessage = \xaddr x ->
                let (msg, addr) = encodex (x, xaddr)
                in  sendMessage tr addr msg
            }
        ytr = Transport
            { awaitMessage = \kont -> do
                forY <- takeMVar handoff
                kont (Just (Right forY))
            , sendMessage = \yaddr y ->
                let (msg, addr) = encodey (y, yaddr)
                in  sendMessage tr addr msg
            , closeTransport = return ()
            }
    return (xtr, ytr)

-- | Install a packet handler in front of a 'Transport'.
--
-- The handler is run on every successfully parsed inbound packet.  When it
-- yields 'Nothing' the packet is considered consumed and the transport waits
-- for the next one.  When it yields @'Just' g@, the packet is passed on to
-- the continuation after applying @g@ to it.  Parse errors and the 'Nothing'
-- end-of-input signal bypass the handler.
addHandler :: (addr -> x -> IO (Maybe (x -> x))) -> Transport err addr x -> Transport err addr x
addHandler f tr = tr
    { awaitMessage = \kont ->
        let next = awaitMessage tr $ \case
                Nothing                -> kont Nothing
                Just (Left e)          -> kont (Just (Left e))
                Just (Right (x, addr)) -> f addr x >>= \case
                    Nothing     -> next
                    Just adjust -> kont (Just (Right (adjust x, addr)))
        in next
    }

-- | Modify a 'Transport' to run an action on every successfully received
-- packet.  The packet itself is passed along unchanged.
onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x
onInbound f = addHandler observe
  where
    -- Run the observer, then let the packet through untouched.
    observe addr x = Just id <$ f addr x

-- * Using a query\/response client.

-- | Fork a thread that handles inbound packets.  The returned action may be used
-- to terminate the thread and clean up any related state.
--
-- The thread repeatedly calls 'awaitMessage', discarding whatever reaches the
-- final continuation, and stops once the transport signals 'Nothing' (no more
-- packets are expected).  Any real processing is expected to happen inside
-- the transport itself, e.g. through handlers installed with 'addHandler'.
--
-- The returned action closes the transport and then kills the thread.
--
--  Example usage:
--
--  > -- Start client.
--  > quitServer <- forkListener "listener" (clientNet client)
--  > -- Send a query q, receive a response r.
--  > r <- sendQuery client method q addr
--  > -- Quit client.
--  > quitServer
forkListener :: String -> Transport err addr x -> IO (IO ())
forkListener name client = do
    thread_id <- forkIO $ do
        myThreadId >>= flip labelThread ("listener."++name)
        fix $ \loop -> awaitMessage client $ \case
            -- The transport was shut down: stop instead of spinning on it.
            Nothing -> return ()
            Just _  -> loop
    return $ do
        closeTransport client
        killThread thread_id

-- | Send a query to a remote peer.  Note that this function will always time
-- out if 'forkListener' was never invoked to spawn a thread to receive and
-- dispatch the response.
--
-- A transaction is registered in the client's pending table before the
-- packet is sent.  It is removed again when the wait times out, and also when
-- sending or waiting is aborted by an exception, so that failed queries do
-- not accumulate in the table.  The exception itself is re-thrown.
sendQuery ::
    forall err a b tbl x meth tid addr.
        Client err meth tid addr x              -- ^ A query/response implementation.
        -> MethodSerializer tid addr x meth a b -- ^ Information for marshalling the query.
        -> a                                    -- ^ The outbound query.
        -> addr                                 -- ^ Destination address of query.
        -> IO (Maybe b)                         -- ^ The response, or 'Nothing' if it timed out.
sendQuery (Client net d err pending whoami _) meth q addr = do
    mvar <- newEmptyMVar
    tid <- atomically $ do
        tbl <- readTVar pending
        let (tid, tbl') = dispatchRegister (tableMethods d) mvar tbl
        writeTVar pending tbl'
        return tid
    let -- Forget the transaction registered above.
        cancel = atomically $ modifyTVar' pending (dispatchCancel (tableMethods d) tid)
        awaitReply = do
            self <- whoami (Just addr)
            sendMessage net addr (wrapQuery meth tid self addr q)
            -- 'methodTimeout' is in seconds, 'timeout' expects microseconds.
            timeout (1000000 * methodTimeout meth) $ takeMVar mvar
    mres <- awaitReply `onException` cancel
    case mres of
        Just x -> return $ Just $ unwrapResponse meth x
        Nothing -> do
            cancel
            reportTimeout err (method meth) tid addr
            return Nothing

-- * Implementing a query\/response 'Client'.

-- | All inputs required to implement a query\/response client.
--
-- The type of the pending-transaction state /tbl/ is existentially
-- quantified, so it is only accessible through 'clientDispatcher' and
-- 'clientPending' together.
data Client err meth tid addr x = forall tbl. Client
    { -- | The 'Transport' used to dispatch and receive packets.
      clientNet :: Transport err addr x
      -- | Methods for handling inbound packets.
    , clientDispatcher :: DispatchMethods tbl err meth tid addr x
      -- | Methods for reporting various conditions.
    , clientErrorReporter :: ErrorReporter addr x meth tid err
      -- | State necessary for routing inbound responses and assigning unique
      -- /tid/ values for outgoing queries.
    , clientPending :: TVar tbl
      -- | An action yielding this client\'s own address.  Within this module
      -- it is called with 'Just' the remote peer\'s address: once by
      -- 'sendQuery' for each outbound query and once by 'handleMessage' for
      -- each inbound query that has a handler.  It is valid for this to
      -- always return the same value.
    , clientAddress :: Maybe addr -> IO addr
      -- | Transform a query /tid/ value to an appropriate response /tid/
      -- value.  Normally, this would be the identity transformation, but if
      -- /tid/ includes a unique cryptographic nonce, then it should be
      -- generated here.
    , clientResponseId :: tid -> IO tid
    }

-- | An incoming message can be classified into three cases.
data MessageClass err meth tid
    = IsQuery meth tid -- ^ An unsolicited query is handled based on its /meth/ value.  Any response
                       -- should include the provided /tid/ value.
    | IsResponse tid   -- ^ A response to an outgoing query we associated with a /tid/ value.
    | IsUnknown err    -- ^ None of the above.

-- | Handler for an inbound query of type /x/ from an address of type /addr/.
--
-- The parsed query type and the response type are existentially quantified,
-- so each handler may use whatever types suit its method.
data MethodHandler err tid addr x = forall a b. MethodHandler
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Serialize the response for transmission, given the transaction id
      -- /tid/ and the origin and destination addresses.  'dispatchQuery'
      -- passes our own address first and the querying peer\'s address second.
    , methodSerialize :: tid -> addr -> addr -> b -> x
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
    , methodAction :: addr -> a -> IO b
    }
    -- A handler for queries that are acted upon without sending any reply.
    | forall a. NoReply
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
    , noreplyAction :: addr -> a -> IO ()
    }

-- | Adapt a 'MethodHandler' to a different address type by converting every
-- address it is given (origin, destination and query source) with the
-- supplied function before the wrapped handler sees it.
contramapAddr :: (a -> b) -> MethodHandler err tid b x -> MethodHandler err tid a x
contramapAddr f = \case
    MethodHandler parse serialize act ->
        MethodHandler parse
                      (\tid src dst -> serialize tid (f src) (f dst))
                      (act . f)
    NoReply parse act ->
        NoReply parse (act . f)


-- | Attempt to invoke a 'MethodHandler' upon a given inbound query.  If the
-- query parses, the result is an IO action that runs the handler and yields
-- the serialized reply, or 'Nothing' for handlers that never reply.  If the
-- query does not parse, the parse error is returned instead and no action is
-- run.
dispatchQuery :: MethodHandler err tid addr x -- ^ Handler to invoke.
                 -> tid                       -- ^ The transaction id for this query\/response session.
                 -> addr                      -- ^ Our own address, to which the query was sent.
                 -> x                         -- ^ The query packet.
                 -> addr                      -- ^ The origin address of the query.
                 -> Either err (IO (Maybe x))
dispatchQuery handler tid self x addr = case handler of
    MethodHandler parse serialize act -> do
        query <- parse x
        return $ Just . serialize tid self addr <$> act addr query
    NoReply parse act -> do
        query <- parse x
        return $ Nothing <$ act addr query

-- | These four parameters are required to implement an outgoing query.  A
-- peer-to-peer algorithm will define a 'MethodSerializer' for every 'MethodHandler' that
-- might be returned by 'lookupHandler'.
data MethodSerializer tid addr x meth a b = MethodSerializer
    { -- | Seconds to wait for a response.
      methodTimeout :: Int
      -- | A method identifier used for error reporting.  This needn't be the
      -- same as the /meth/ argument to 'MethodHandler', but it is suggested.
    , method :: meth
      -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
      -- The /addr/ arguments are, respectively, our own origin address and the
      -- destination of the request.  The /tid/ argument is the transaction id
      -- that was registered for this query.
    , wrapQuery :: tid -> addr -> addr -> a -> x
      -- | Parse an inbound packet /x/ into a response /b/ for this query.
    , unwrapResponse :: x -> b
    }


-- | To dispatch responses to our outbound queries, we require three primitives.
-- See the 'transactionMethods' function to create these primitives out of a
-- lookup table and a generator for transaction ids.
--
-- The type variable /d/ is used to represent the current state of the
-- transaction generator and the table of pending transactions.
data TransactionMethods d tid x = TransactionMethods
    {
      -- | Before a query is sent, this function stores an 'MVar' to which the
      -- response will be written.  The returned /tid/ is a transaction id
      -- that can be used to forget the 'MVar' if the remote peer is not
      -- responding.
      dispatchRegister :: MVar x -> d -> (tid, d)
      -- | This method is invoked when an incoming packet /x/ indicates it is
      -- a response to the transaction with id /tid/.  The returned IO action
      -- will write the packet to the correct 'MVar', thus completing the
      -- dispatch.
    , dispatchResponse :: tid -> x -> d -> (d, IO ())
      -- | When a timeout interval elapses, this method is called to remove the
      -- transaction from the table.
    , dispatchCancel :: tid -> d -> d
    }

-- | The standard lookup table methods for use as input to 'transactionMethods'
-- in lieu of directly implementing 'TransactionMethods'.
--
-- Every method is polymorphic in the type of the stored values.
data TableMethods t tid = TableMethods
    { -- | Insert a new /tid/ entry into the transaction table.
      tblInsert :: forall a. tid -> a -> t a -> t a
      -- | Delete transaction /tid/ from the transaction table.
    , tblDelete :: forall a. tid -> t a -> t a
      -- | Lookup the value associated with transaction /tid/.
    , tblLookup :: forall a. tid -> t a -> Maybe a
    }

-- | Table methods backed by a strict 'IntMap', keyed by 'Int' transaction ids.
intMapMethods :: TableMethods IntMap Int
intMapMethods = TableMethods
    { tblInsert = IntMap.insert
    , tblDelete = IntMap.delete
    , tblLookup = IntMap.lookup
    }

-- | Table methods backed by a strict 'Map', usable with any ordered
-- transaction id type.
mapMethods :: Ord tid => TableMethods (Map tid) tid
mapMethods = TableMethods
    { tblInsert = Map.insert
    , tblDelete = Map.delete
    , tblLookup = Map.lookup
    }

-- | Change the key type for a lookup table implementation.
--
-- This can be used with 'intMapMethods' or 'mapMethods' to restrict lookups to
-- only a part of the generated /tid/ value.  This is useful for /tid/ types
-- that are especially large due to their use for other purposes, such as
-- secure nonces for encryption.
--
-- Every key is converted with the supplied function before it reaches the
-- underlying table, so two keys that convert to the same value share one
-- table entry.
contramapT :: (tid' -> tid) -> TableMethods t tid -> TableMethods t tid'
contramapT f (TableMethods ins del look) = TableMethods
    { tblInsert = \k v t -> ins (f k) v t
    , tblDelete = \k t   -> del (f k) t
    , tblLookup = \k t   -> look (f k) t
    }

-- | Since 'Int' may be 32 or 64 bits, this function is provided as a
-- convenience to test if an integral type, such as 'Data.Word.Word64', can be
-- safely transformed into an 'Int' for use with 'IntMap'.
--
-- The check takes half of the proxied type's 'maxBound', converts it to 'Int'
-- and back with 'fromIntegral', and returns 'True' if the value survives the
-- round trip unchanged.
fitsInInt :: forall word. (Bounded word, Integral word) => Proxy word -> Bool
fitsInInt Proxy = viaInt probe == probe
 where
    -- Probe value: half of the largest representable value.
    probe :: word
    probe = maxBound `div` 2

    -- Convert to 'Int' and straight back again.
    viaInt :: word -> word
    viaInt w = fromIntegral (fromIntegral w :: Int)

-- | Construct 'TransactionMethods' out of three lookup table primitives and a
-- function for generating unique transaction ids.
--
-- The resulting state is a pair of the id generator state and a table that
-- maps each pending /tid/ to the 'MVar' awaiting its response.  A response
-- for a known /tid/ removes the entry and yields an action that attempts a
-- non-blocking write to the 'MVar'; a response for an unknown /tid/ leaves
-- the state unchanged and yields a no-op.
transactionMethods ::
    TableMethods t tid  -- ^ Table methods to lookup values by /tid/.
    -> (g -> (tid,g))   -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t (MVar x)) tid x
transactionMethods (TableMethods tInsert tDelete tLookup) generate = TransactionMethods
    { dispatchRegister = \waiter (gen, table) ->
        let (tid, gen') = generate gen
        in  (tid, (gen', tInsert tid waiter table))
    , dispatchResponse = \tid packet st@(gen, table) ->
        maybe (st, return ())
              (\waiter -> ((gen, tDelete tid table), void (tryPutMVar waiter packet)))
              (tLookup tid table)
    , dispatchCancel = \tid (gen, table) -> (gen, tDelete tid table)
    }

-- | A set of methods necessary for dispatching incoming packets.
data DispatchMethods tbl err meth tid addr x = DispatchMethods
    { -- | Classify an inbound packet as a query or response.
      classifyInbound :: x -> MessageClass err meth tid
      -- | Lookup the handler for an inbound query.
    , lookupHandler :: meth -> Maybe (MethodHandler err tid addr x)
      -- | Methods for handling incoming responses.
    , tableMethods :: TransactionMethods tbl tid x
    }

-- | These methods indicate what should be done upon various conditions.  Write
-- to a log file, make debug prints, or simply ignore them.
--
--   [ /addr/ ]  Address of remote peer.
--
--   [ /x/ ]     Incoming or outgoing packet.
--
--   [ /meth/ ]  Method id of incoming or outgoing request.
--
--   [ /tid/ ]   Transaction id for outgoing packet.
--
--   [ /err/ ]   Error information, typically a 'String'.
data ErrorReporter addr x meth tid err = ErrorReporter
    { -- | Incoming: failed to parse packet.
      reportParseError :: err -> IO ()
      -- | Incoming: no handler for request.
    , reportMissingHandler :: meth -> addr -> x -> IO ()
      -- | Incoming: unable to identify request.
    , reportUnknown :: addr -> x -> err -> IO ()
      -- | Outgoing: remote peer is not responding.
    , reportTimeout :: meth -> tid -> addr -> IO ()
    }

-- | An 'ErrorReporter' that silently discards every reported condition.
ignoreErrors :: ErrorReporter addr x meth tid err
ignoreErrors = ErrorReporter
    { reportParseError     = const silence
    , reportMissingHandler = \_ _ _ -> silence
    , reportUnknown        = \_ _ _ -> silence
    , reportTimeout        = \_ _ _ -> silence
    }
  where
    silence :: IO ()
    silence = pure ()

-- | An 'ErrorReporter' that writes a one-line description of each condition
-- to the given 'Handle'.  Parse errors are printed verbatim; the other
-- conditions are prefixed with the remote address.
printErrors :: ( Show addr
               , Show meth
               ) => Handle -> ErrorReporter addr x meth tid String
printErrors h = ErrorReporter
    { reportParseError     = say
    , reportMissingHandler = \meth addr _ -> say (show addr ++ " --> Missing handler ("++show meth++")")
    , reportUnknown        = \addr _ err  -> say (show addr ++ " --> " ++ err)
    , reportTimeout        = \meth _ addr -> say (show addr ++ " --> Timeout ("++show meth++")")
    }
  where
    say = hPutStrLn h

-- | Change the /err/ type for an 'ErrorReporter'.
--
-- Errors passed to 'reportParseError' and 'reportUnknown' are converted with
-- the supplied function before being forwarded to the original reporter.
-- The other two methods do not involve /err/ and are reused as they are.
contramapE :: (err' -> err)
           -> ErrorReporter addr x meth tid err
           -> ErrorReporter addr x meth tid err'
contramapE f (ErrorReporter pe mh unk tim)
    = ErrorReporter (\e -> pe (f e))
                    mh
                    (\addr x e -> unk addr x (f e))
                    tim

-- | Classify and process a single inbound packet.  The result has the shape
-- expected by 'addHandler':
--
--   * 'Nothing' means the packet was consumed here: either it was a response,
--     which is routed through the pending-transaction table, or it was a
--     query that parsed and was run by its handler (sending the reply, if
--     the handler produced one, back to the origin address).
--
--   * @'Just' 'id'@ means the packet should be passed along unchanged.  This
--     happens, after reporting through the client's 'ErrorReporter', for
--     packets that could not be classified, for queries without a handler,
--     and for queries that failed to parse.
handleMessage ::
    Client err meth tid addr x
    -> addr
    -> x
    -> IO (Maybe (x -> x))
handleMessage (Client net d err pending whoami responseID) addr plain =
    case classifyInbound d plain of
        IsUnknown e -> passAlong $ reportUnknown err addr plain e
        IsResponse tid -> do
            -- Update the table atomically; deliver the packet afterwards.
            deliver <- atomically $ do
                table <- readTVar pending
                let (table', act) = dispatchResponse (tableMethods d) tid plain table
                writeTVar pending table'
                return act
            deliver
            consumed
        IsQuery meth tid -> case lookupHandler d meth of
            Nothing -> passAlong $ reportMissingHandler err meth addr plain
            Just handler -> do
                self <- whoami (Just addr)
                tid' <- responseID tid
                case dispatchQuery handler tid' self plain addr of
                    Left e -> passAlong $ reportParseError err e
                    Right runHandler -> do
                        reply <- runHandler
                        mapM_ (sendMessage net addr) reply
                        consumed
  where
    -- Report a condition, then let the packet through untouched.
    passAlong report = report >> (return $! Just id)
    consumed = return $! Nothing

-- * UDP Datagrams.

-- | Determine the address 'Family' that a given 'SockAddr' belongs to.  This
-- convenient accessor is missing from "Network.Socket", so it is provided
-- here.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily = \case
    SockAddrInet  {} -> AF_INET
    SockAddrInet6 {} -> AF_INET6
    SockAddrUnix  {} -> AF_UNIX
    SockAddrCan   {} -> AF_CAN

-- | Packets with an empty payload may trigger an EOF exception.
-- 'udpTransport' uses this function as an exception handler to avoid
-- throwing in that case.
--
-- Yields the given default when the error is an end-of-file error; any other
-- 'IOError' is re-thrown.
ignoreEOF :: a -> IOError -> IO a
ignoreEOF def e
    | isEOFError e = pure def
    | otherwise    = throwIO e

-- | Hardcoded maximum packet size, in bytes, for incoming UDP packets
-- received via 'udpTransport'.
udpBufferSize :: Int
udpBufferSize = 64 * 1024

-- | A 'udpTransport' uses a UDP socket to send and receive 'ByteString's.  The
-- argument is the listen-address for incoming packets.  This is a useful
-- low-level 'Transport' that can be transformed for higher-level protocols
-- using 'layerTransport'.
--
-- When bound to an IPv6 address, the @IPv6Only@ option is cleared and IPv4
-- destinations are rewritten as IPv4-mapped IPv6 addresses.  When bound to an
-- IPv4 address, IPv4-mapped IPv6 destinations are rewritten as plain IPv4 and
-- other IPv6 destinations are discarded with a message on 'stderr'.
--
-- If receiving raises an EOF error, an empty packet from @SockAddrInet 0 0@
-- is delivered instead of throwing.
--
-- If configuring or binding the socket fails, the socket is closed before the
-- exception is propagated.
udpTransport :: SockAddr -> IO (Transport err SockAddr ByteString)
udpTransport bind_address = do
  let family = sockAddrFamily bind_address
  -- Close the socket again if setup fails, so a failed call does not leak it.
  sock <- bracketOnError (socket family Datagram defaultProtocol) close $ \s -> do
    when (family == AF_INET6) $ do
      setSocketOption s IPv6Only 0
    bind s bind_address
    return s
  return Transport
    { awaitMessage = \kont -> do
        r <- handle (ignoreEOF $ Just $ Right (B.empty, SockAddrInet 0 0)) $ do
                Just . Right <$!> B.recvFrom sock udpBufferSize
        kont $! r
    , sendMessage = case family of
        -- TODO: sendTo: does not exist (Network is unreachable)
        --       Occurs when IPv6 network is not available.
        --       Currently, we require -threaded to prevent a forever-hang in this case.
        AF_INET6 -> \case
            (SockAddrInet port addr) -> \bs ->
                -- Change IPv4 to 4mapped6 address.
                void $ B.sendTo sock bs $ SockAddrInet6 port 0 (0,0,0x0000ffff,fromBE32 addr) 0
            addr6 -> \bs -> void $ B.sendTo sock bs addr6
        AF_INET  -> \case
            (SockAddrInet6 port 0 (0,0,0x0000ffff,raw4) 0) -> \bs -> do
                let host4 = toBE32 raw4
                -- Change 4mapped6 to ordinary IPv4.
                -- hPutStrLn stderr $ "4mapped6 -> "++show (SockAddrInet port host4)
                void $ B.sendTo sock bs (SockAddrInet port host4)
            addr@(SockAddrInet6 {}) -> \bs -> hPutStrLn stderr ("Discarding packet to "++show addr)
            addr4 -> \bs -> void $ B.sendTo sock bs addr4
        _ -> \addr bs -> void $ B.sendTo sock bs addr
    , closeTransport = close sock
    }