summaryrefslogtreecommitdiff
path: root/src/Network/QueryResponse.hs
blob: f6f2807d030d5e3430e437412e9160ca2104582a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
-- | This module can implement any query\/response protocol.  It was written
-- with Kademlia implementations in mind.

{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE GADTs #-}
module Network.QueryResponse where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import GHC.Conc (labelThread)
import Control.Concurrent
#endif
import Control.Concurrent.STM
import System.Timeout
import Data.Function
import Control.Exception
import Control.Monad
import qualified Data.ByteString as B
         ;import Data.ByteString (ByteString)
import Network.Socket
import Network.Socket.ByteString as B
import System.IO.Error
import Data.Maybe

-- * Using a query\/response 'Client'.

-- | Start a background thread, labelled @\"listener\"@, that repeatedly
-- receives and dispatches inbound packets for the given 'Client'.  The
-- returned action shuts the client down: it first closes the client's
-- 'Transport' and then kills the listener thread.
--
--  Example usage:
--
--  > -- Start client.
--  > quitServer <- forkListener client
--  > -- Send a query q, receive a response r.
--  > r <- sendQuery client method q
--  > -- Quit client.
--  > quitServer
forkListener :: Client err meth tid addr x ctx -> IO (IO ())
forkListener client = do
    listener <- forkIO $ do
        me <- myThreadId
        labelThread me "listener"
        fix (handleMessage client)
    let shutdown = do
            closeTransport (clientNet client)
            killThread listener
    return shutdown

-- | Send a query to a remote peer and wait for its response.
--
-- Note that this function will always time out if 'forkListener' was never
-- invoked to spawn a thread to receive and dispatch the response.
--
-- A transaction entry is registered in the client's pending table before the
-- query is sent.  It is removed again when the wait times out, or when
-- building, sending or waiting for the query is interrupted by an exception
-- (which is then rethrown).  A timeout is reported through 'reportTimeout'.
sendQuery ::
    forall err a b tbl x ctx meth tid addr.
        Client err meth tid addr x ctx          -- ^ A query/response implementation.
        -> MethodSerializer addr x ctx meth a b -- ^ Information for marshalling the query.
        -> a                                    -- ^ The outbound query.
        -> addr                                 -- ^ Destination address of query.
        -> IO (Maybe b)                         -- ^ The response, or 'Nothing' if it timed out.
sendQuery (Client net d err pending whoami) meth q addr = do
    mvar <- newEmptyMVar
    tid <- atomically $ do
        tbl <- readTVar pending
        let (tid, tbl') = dispatchRegister (tableMethods d) mvar tbl
        writeTVar pending tbl'
        return tid
    -- Forget this transaction so the pending table does not grow without bound.
    let cancel = atomically $ modifyTVar' pending (dispatchCancel (tableMethods d) tid)
    mres <- flip onException cancel $ do
        (self,ctx) <- whoami
        sendMessage net addr (wrapQuery meth ctx self addr q)
        -- 'methodTimeout' is documented in seconds, but 'timeout' expects
        -- microseconds.
        timeout (1000000 * methodTimeout meth) $ takeMVar mvar
    case mres of
        Just x -> return $ Just $ unwrapResponse meth x
        Nothing -> do
            cancel
            reportTimeout err (method meth) tid addr
            return Nothing

-- * Implementing a query\/response 'Client'.

-- | All inputs required to implement a query\/response client.
--
-- The transaction-table type /tbl/ is existentially quantified: it is fixed
-- when the 'Client' is constructed and appears only in the types of
-- 'clientDispatcher' and 'clientPending'.  Because of that, those two fields
-- cannot be used as ordinary record selectors; pattern-match on 'Client'
-- instead (as 'sendQuery' and 'handleMessage' do).
data Client err meth tid addr x ctx = forall tbl. Client
    { -- | The 'Transport' used to dispatch and receive packets.
      clientNet :: Transport err addr x
      -- | Methods for handling inbound packets.
    , clientDispatcher :: DispatchMethods tbl err meth tid addr x ctx
      -- | Methods for reporting various conditions.
    , clientErrorReporter :: ErrorReporter addr x meth tid err
      -- | State necessary for routing inbound responses and assigning unique
      -- /tid/ values for outgoing queries.
    , clientPending :: TVar tbl
      -- | An action yielding this client\'s own address along with some
      -- context necessary for serializing outgoing packets.  It is invoked
      -- once for each outbound query and once for each inbound query (but
      -- not for inbound responses).  It is valid for this to always return
      -- the same value.
    , clientContext :: IO (addr,ctx)
    }

-- | An incoming message can be classified into three cases.
data MessageClass err meth tid
    = IsQuery meth tid -- ^ An unsolicited query is handled based on its /meth/ value.  Any response
                       -- should include the provided /tid/ value.
    | IsResponse tid   -- ^ A response to an outgoing query we associated with a /tid/ value.
    | IsUnknown err    -- ^ None of the above.

-- | Handler for an inbound query of type /x/ from an address of type /addr/.
--
-- The parsed query type /a/ and the reply type /b/ are existentially
-- quantified, so handlers for different methods share the single type
-- 'MethodHandler' (see 'lookupHandler').
data MethodHandler err tid addr x ctx = forall a b. MethodHandler
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Serialize the response for transmission, given a context /ctx/, the
      -- transaction id /tid/, and the response's origin address (our own)
      -- followed by its destination (the peer that sent the query).
    , methodSerialize :: ctx -> tid -> addr -> addr -> b -> x
      -- | Fully typed action to perform upon the query.  The remote origin
      -- address of the query is provided to the handler.
    , methodAction :: addr -> a -> IO b
    }

-- | Try to run a 'MethodHandler' on an inbound query.  When the handler's
-- parser accepts the packet, the result is an IO action that runs the
-- handler's action and serializes its reply.  Otherwise the parser's error
-- is returned unchanged.
dispatchQuery :: MethodHandler err tid addr x ctx -- ^ Handler to invoke.
                 -> ctx                           -- ^ Arbitrary context used during serialization.
                 -> tid                           -- ^ The transaction id for this query\/response session.
                 -> addr                          -- ^ Our own address, to which the query was sent.
                 -> x                             -- ^ The query packet.
                 -> addr                          -- ^ The origin address of the query.
                 -> Either err (IO x)
dispatchQuery (MethodHandler parse serialize act) ctx tid self x addr =
    case parse x of
        Left e      -> Left e
        -- The reply travels from us (self) back to the querying peer (addr).
        Right query -> Right $ serialize ctx tid self addr <$> act addr query

-- | These four parameters are required to implement an outgoing query.  A
-- peer-to-peer algorithm will define a 'MethodSerializer' for every 'MethodHandler' that
-- might be returned by 'lookupHandler'.
data MethodSerializer addr x ctx meth a b = MethodSerializer
    { -- | Seconds to wait for a response.
      methodTimeout :: Int
      -- | A method identifier used for error reporting.  This needn't be the
      -- same as the /meth/ value 'lookupHandler' uses to select the remote
      -- 'MethodHandler', but it is suggested.
    , method :: meth
      -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
      -- The /addr/ arguments are, respectively, our own origin address and the
      -- destination of the request.  The /ctx/ argument is useful for attaching
      -- auxiliary notations on all outgoing packets.
    , wrapQuery :: ctx -> addr -> addr -> a -> x
      -- | Parse an inbound packet /x/ into a response /b/ for this query.
    , unwrapResponse :: x -> b
    }


-- | Three methods are required to implement a datagram-based query\/response protocol.
data Transport err addr x = Transport
    { -- | Blocks until an inbound packet is available. Returns 'Nothing' when
      -- no more packets are expected due to a shutdown or close event.
      -- Otherwise, the packet will be parsed as type /x/ and an origin address
      -- /addr/.  Parse failure is indicated by a 'Left' carrying an /err/.
      awaitMessage :: IO (Maybe (Either err (x, addr)))
      -- | Send an /x/ packet to the given destination /addr/.
    , sendMessage :: addr -> x -> IO ()
      -- | Shut down and clean up any state related to this 'Transport'.
    , closeTransport :: IO ()
    }

-- | Wrap a 'Transport' so it speaks higher-level addresses and packet
-- representations.  For example, this could turn UDP 'ByteString's into
-- bencoded syntax trees, or add an encryption layer in which addresses carry
-- associated public keys.  The wrapped 'closeTransport' is inherited as is.
layerTransport ::
        (x -> addr -> Either err (x', addr'))
        -- ^ Function that attempts to transform a low-level address/packet
        -- pair into a higher level representation.
        -> (x' -> addr' -> (x, addr))
        -- ^ Function to encode a high-level address/packet into a lower level
        -- representation.
        -> Transport err addr x
        -- ^ The low-level transport to be transformed.
        -> Transport err addr' x'
layerTransport parse encode tr =
    tr { awaitMessage = liftInbound <$> awaitMessage tr
       , sendMessage  = sendLowered
       }
  where
    -- Re-parse each successfully received low-level packet.  A low-level
    -- 'Left' error and the end-of-stream 'Nothing' pass through untouched.
    liftInbound received = fmap (>>= uncurry parse) received
    -- Encode to the low-level form, then hand off to the wrapped transport.
    sendLowered addr' msg' = sendMessage tr addr msg
      where (msg, addr) = encode msg' addr'


-- | To dispatch responses to our outbound queries, we require three primitives.
-- See the 'transactionTableMethods' function to create these primitives out of a
-- lookup table and a generator for transaction ids.
--
-- The type variable /d/ is used to represent the current state of the
-- transaction generator and the table of pending transactions.  Each
-- primitive is a pure state transition; callers (e.g. 'sendQuery' and
-- 'handleMessage') apply them inside an STM transaction on a 'TVar'.
data TableMethods d tid x = TableMethods
    {
      -- | Before a query is sent, this function stores an 'MVar' to which the
      -- response will be written.  The returned /tid/ is a transaction id
      -- that can be used to forget the 'MVar' if the remote peer is not
      -- responding.
      dispatchRegister :: MVar x -> d -> (tid, d)
      -- | This method is invoked when an incoming packet /x/ indicates it is
      -- a response to the transaction with id /tid/.  The returned IO action
      -- will write the packet to the correct 'MVar', thus completing the
      -- dispatch.  It is run after the state update has been committed.
    , dispatchResponse :: tid -> x -> d -> (d, IO ())
      -- | When a timeout interval elapses, this method is called to remove the
      -- transaction from the table.
    , dispatchCancel :: tid -> d -> d
    }

-- | Build 'TableMethods' from three lookup-table primitives and a
-- transaction-id generator.  The resulting state is a pair of the generator
-- state /g/ and a table /t/ that maps each pending /tid/ to the 'MVar'
-- awaiting its response.
transactionTableMethods ::
    (forall a. tid -> a -> t a -> t a)
    -- ^ Insert a new /tid/ entry into the transaction table.
    -> (forall a. tid -> t a -> t a)
    -- ^ Delete transaction /tid/ from the transaction table.
    -> (forall a. tid -> t a -> Maybe a)
    -- ^ Lookup the value associated with transaction /tid/.
    -> (g -> (tid,g))
    -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TableMethods (g,t (MVar x)) tid x
transactionTableMethods insert delete lookup generate = TableMethods
    { dispatchRegister = register
    , dispatchResponse = respond
    , dispatchCancel   = cancel
    }
  where
    -- Draw a fresh id from the generator and file the 'MVar' under it.
    register var (gen, tbl) =
        let (tid, gen') = generate gen
        in  (tid, (gen', insert tid var tbl))
    -- Deliver a response at most once: a known entry is removed from the
    -- table, and 'tryPutMVar' never blocks even if the 'MVar' is already
    -- full.  Responses for unknown ids leave the state unchanged.
    respond tid x (gen, tbl) =
        maybe ((gen, tbl), return ())
              (\var -> ((gen, delete tid tbl), void $ tryPutMVar var x))
              (lookup tid tbl)
    -- Forget a transaction without touching the generator.
    cancel tid (gen, tbl) = (gen, delete tid tbl)

-- | A set of methods necessary for dispatching incoming packets.  The
-- parameter /tbl/ is the state type manipulated by 'tableMethods'.
data DispatchMethods tbl err meth tid addr x ctx = DispatchMethods
    { -- | Classify an inbound packet as a query or response.
      classifyInbound :: x -> MessageClass err meth tid
      -- | Lookup the handler for an inbound query.
    , lookupHandler :: meth -> Maybe (MethodHandler err tid addr x ctx)
      -- | Methods for handling incoming responses.
    , tableMethods :: TableMethods tbl tid x
    }

-- | These methods indicate what should be done upon various conditions.  Write
-- to a log file, make debug prints, or simply ignore them.
--
--   [ /addr/ ]  Address of remote peer.
--
--   [ /x/ ]     Incoming or outgoing packet.
--
--   [ /meth/ ]  Method id of incoming or outgoing request.
--
--   [ /tid/ ]   Transaction id for outgoing packet.
--
--   [ /err/ ]   Error information, typically a 'String'.
data ErrorReporter addr x meth tid err = ErrorReporter
    { -- | Incoming: failed to parse packet.
      reportParseError :: err -> IO ()
      -- | Incoming: no handler for request.
    , reportMissingHandler :: meth -> addr -> x -> IO ()
      -- | Incoming: unable to identify request.
    , reportUnknown :: addr -> x -> err -> IO ()
      -- | Outgoing: remote peer is not responding.
    , reportTimeout :: meth -> tid -> addr -> IO ()
    }

-- | Handle a single inbound packet and then invoke the given continuation.
-- The 'forkListener' function is implemented by passing this function to
-- 'fix' in a forked thread.  That thread therefore loops until
-- 'awaitMessage' returns 'Nothing' or an exception escapes, either from
-- 'awaitMessage' itself or from an action run while handling a packet.
handleMessage ::
    Client err meth tid addr x ctx
    -> IO ()
    -> IO ()
handleMessage (Client net d err pending whoami) again =
    awaitMessage net >>= maybe (return ()) (\packet -> process packet >> again)
  where
    -- The transport could not parse this packet.
    process (Left e) = reportParseError err e
    -- A well-formed packet: route it according to its classification.
    process (Right (plain, addr)) =
        case classifyInbound d plain of
            IsQuery meth tid -> answerQuery meth tid plain addr
            IsResponse tid   -> deliverResponse tid plain
            IsUnknown e      -> reportUnknown err addr plain e

    -- Find the handler, run it, and send its reply back to the sender.
    answerQuery meth tid plain addr =
        case lookupHandler d meth of
            Nothing -> reportMissingHandler err meth addr plain
            Just m  -> do
                (self, ctx) <- whoami
                case dispatchQuery m ctx tid self plain addr of
                    Left e      -> reportParseError err e
                    Right reply -> reply >>= sendMessage net addr

    -- Update the pending table atomically, then run the returned action
    -- outside STM to hand the packet to the waiting query.
    deliverResponse tid plain =
        join . atomically $ do
            ts0 <- readTVar pending
            let (ts, action) = dispatchResponse (tableMethods d) tid plain ts0
            writeTVar pending ts
            return action

-- * UDP Datagrams.

-- | The address 'Family' that a 'SockAddr' belongs to.  'Network.Socket'
-- does not provide this accessor, so it is defined here.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
    SockAddrInet{}  -> AF_INET
    SockAddrInet6{} -> AF_INET6
    SockAddrUnix{}  -> AF_UNIX
    SockAddrCan{}   -> AF_CAN

-- | Packets with an empty payload may trigger an EOF exception.
-- 'udpTransport' installs this function as an exception handler so that
-- such a packet yields the given default value instead of throwing.  Any
-- other 'IOError' is rethrown unchanged.
ignoreEOF :: a        -- ^ Value to return when the error is an EOF error.
          -> IOError  -- ^ The caught exception.
          -> IO a
ignoreEOF def e | isEOFError e = pure def
                | otherwise    = throwIO e

-- | Size, in bytes, of the receive buffer used by 'udpTransport' for
-- incoming UDP packets (64 KiB).  This is large enough for any UDP datagram
-- that fits in a non-jumbo IP packet.
udpBufferSize :: Int
udpBufferSize = 0x10000

-- | A 'udpTransport' uses a UDP socket to send and receive 'ByteString's.  The
-- argument is the listen-address for incoming packets.  This is a useful
-- low-level 'Transport' that can be transformed for higher-level protocols
-- using 'layerTransport'.
--
-- For an IPv6 bind address the 'IPv6Only' socket option is cleared, so the
-- socket may also carry IPv4 traffic where the OS supports it.  If
-- configuring or binding the socket fails, the socket is closed before the
-- exception propagates.
--
-- A receive that fails with an EOF error (see 'ignoreEOF') is reported as an
-- empty payload from @SockAddrInet 0 0@.
udpTransport :: SockAddr -> IO (Transport err SockAddr ByteString)
udpTransport bind_address = do
  let family = sockAddrFamily bind_address
  -- Close the descriptor only if setup fails; on success ownership passes to
  -- the returned 'Transport' via 'closeTransport'.
  bracketOnError (socket family Datagram defaultProtocol) close $ \sock -> do
    when (family == AF_INET6) $ do
      setSocketOption sock IPv6Only 0
    bind sock bind_address
    return Transport
      { awaitMessage = handle (ignoreEOF $ Just $ Right (B.empty, SockAddrInet 0 0)) $ do
          r <- B.recvFrom sock udpBufferSize
          return $ Just $ Right r
      , sendMessage = \addr bs -> void $ B.sendTo sock bs addr
      , closeTransport = close sock
      }