{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
-- | This module can implement any query\/response protocol. It was written
-- with Kademlia implementations in mind.
module Network.QueryResponse where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent
import GHC.Conc (labelThread)
#endif
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import qualified Data.ByteString as B ;import Data.ByteString (ByteString)
import Data.Function
import Data.Functor.Contravariant
import qualified Data.IntMap.Strict as IntMap ;import Data.IntMap.Strict (IntMap)
import qualified Data.Map.Strict as Map ;import Data.Map.Strict (Map)
import Data.Maybe
import Data.Typeable
import Network.Socket
-- Imported under the same alias as Data.ByteString: code in this module uses
-- B.recvFrom / B.sendTo for the ByteString socket API and B.empty for the
-- empty ByteString.
import Network.Socket.ByteString as B
import System.Endian
import System.IO
import System.IO.Error
import System.Timeout

-- | Three methods are required to implement a datagram-based query\/response
-- protocol.
--
-- The type parameters are: /err/, the type describing a packet that could not
-- be parsed; /addr/, the address type of remote peers; and /x/, the packet
-- type.
data Transport err addr x = Transport
    { -- | Blocks until an inbound packet is available. Returns 'Nothing' when
      -- no more packets are expected due to a shutdown or close event.
      -- Otherwise, the packet will be parsed as type /x/ and an origin address
      -- /addr/. Parse failure is indicated by the type 'err'.
      --
      -- The result is handed to the supplied continuation rather than
      -- returned directly; whatever the continuation returns becomes the
      -- result of 'awaitMessage'.
      awaitMessage :: forall a. (Maybe (Either err (x, addr)) -> IO a) -> IO a
      -- | Send an /x/ packet to the given destination /addr/.
    , sendMessage :: addr -> x -> IO ()
      -- | Shutdown and clean up any state related to this 'Transport'.
    , closeTransport :: IO ()
    }

-- | This function modifies a 'Transport' to use higher-level addresses and
-- packet representations. It could be used to change UDP 'ByteString's into
-- bencoded syntax trees or to add an encryption layer in which addresses have
-- associated public keys.
layerTransport ::
    (x -> addr -> Either err (x', addr'))
                -- ^ Function that attempts to transform a low-level address/packet
                -- pair into a higher level representation.
    -> (x' -> addr' -> (x, addr))
                -- ^ Function to encode a high-level address/packet into a lower level
                -- representation.
    -> Transport err addr x
                -- ^ The low-level transport to be transformed.
    -> Transport err addr' x'
layerTransport parse encode tr =
  -- closeTransport is inherited unchanged from the low-level transport.
  tr { awaitMessage = \kont ->
        -- Parse failures of the lower layer are passed through untouched;
        -- successfully received packets are run through 'parse'.
        awaitMessage tr $ \m -> kont $ fmap (>>= uncurry parse) m
     , sendMessage = \addr' msg' -> do
        let (msg,addr) = encode msg' addr'
        sendMessage tr addr msg
     }

-- | Partitions a 'Transport' into two higher-level transports. Note: An 'MVar'
-- is used to share the same underlying socket, so be sure to fork a thread for
-- both returned 'Transport's to avoid hanging.
--
-- This is the pure-classifier variant of 'partitionTransportM'.
partitionTransport ::
    ((b,a) -> Either (x,xaddr) (y,yaddr))
                -- ^ Classify a low-level packet as belonging to the first
                -- ('Left') or second ('Right') transport.
    -> ((x,xaddr) -> (b,a))
                -- ^ Encode a packet of the first transport.
    -> ((y,yaddr) -> (b,a))
                -- ^ Encode a packet of the second transport.
    -> Transport err a b
                -- ^ The shared low-level transport.
    -> IO (Transport err xaddr x, Transport err yaddr y)
partitionTransport parse encodex encodey tr =
    partitionTransportM (return . parse) (return . encodex) (return . encodey) tr

-- | Partitions a 'Transport' into two higher-level transports. Note: An 'MVar'
-- is used to share the same underlying socket, so be sure to fork a thread for
-- both returned 'Transport's to avoid hanging.
--
-- Only the first transport reads from the underlying 'Transport'. Packets
-- classified as 'Right' are handed to the second transport through the
-- 'MVar', so the first transport's reader blocks until the second one takes
-- each such packet.
--
-- NOTE(review): the second transport never observes a shutdown. Its
-- 'awaitMessage' only ever delivers @Just (Right ...)@ and blocks forever once
-- the underlying transport stops producing packets. Its 'closeTransport' is a
-- no-op; closing is done through the first transport.
partitionTransportM ::
    ((b,a) -> IO (Either (x,xaddr) (y,yaddr)))
                -- ^ Classify a low-level packet as belonging to the first
                -- ('Left') or second ('Right') transport.
    -> ((x,xaddr) -> IO (b,a))
                -- ^ Encode a packet of the first transport.
    -> ((y,yaddr) -> IO (b,a))
                -- ^ Encode a packet of the second transport.
    -> Transport err a b
                -- ^ The shared low-level transport.
    -> IO (Transport err xaddr x, Transport err yaddr y)
partitionTransportM parse encodex encodey tr = do
    mvar <- newEmptyMVar
    let xtr = tr
            { awaitMessage = \kont -> fix $ \again -> do
                awaitMessage tr $ \m -> case m of
                    Just (Right msg) -> parse msg >>=
                        either (kont . Just . Right)
                               -- Hand the packet to the other transport, then
                               -- keep waiting for one of ours.
                               (\y -> putMVar mvar y >> again)
                    Just (Left e) -> kont $ Just (Left e)
                    Nothing -> kont Nothing
            , sendMessage = \addr' msg' -> do
                (msg,addr) <- encodex (msg',addr')
                sendMessage tr addr msg
            }
        ytr = Transport
            { awaitMessage = \kont -> takeMVar mvar >>= kont . Just . Right
            , sendMessage = \addr' msg' -> do
                (msg,addr) <- encodey (msg',addr')
                sendMessage tr addr msg
            , closeTransport = return ()
            }
    return (xtr, ytr)

-- | Install a filter on every successfully received packet.
--
-- For each packet the handler is given the origin address and the packet:
--
-- * @Nothing@ means the packet was consumed; it is dropped and the next
--   packet is awaited.
--
-- * @Just f@ means the packet is passed on, transformed by @f@.
--
-- Parse errors and the end-of-stream 'Nothing' are passed through unchanged.
addHandler :: (addr -> x -> IO (Maybe (x -> x))) -> Transport err addr x -> Transport err addr x
addHandler f tr = tr
    { awaitMessage = \kont -> fix $ \eat -> awaitMessage tr $ \m -> do
        case m of
            Just (Right (x, addr)) -> f addr x >>= maybe eat (kont . Just . Right . (, addr) . ($ x))
            Just (Left e )         -> kont $ Just (Left e)
            Nothing                -> kont $ Nothing
    }

-- | Modify a 'Transport' to invoke an action upon every received packet.
onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x
onInbound f tr = addHandler (\addr x -> f addr x >> return (Just id)) tr

-- * Using a query\/response client.

-- | Fork a thread that handles inbound packets. The returned action may be used
-- to terminate the thread and clean up any related state.
--
-- The forked thread repeatedly calls 'awaitMessage' and discards the result.
-- Any processing must be installed on the transport beforehand, for example
-- with 'addHandler'. The loop ends when 'awaitMessage' yields 'Nothing'
-- (shutdown) or throws.
--
-- Example usage:
--
-- > -- Start client.
-- > quitServer <- forkListener "listener" (clientNet client)
-- > -- Send a query q, receive a response r.
-- > r <- sendQuery client method q
-- > -- Quit client.
-- > quitServer
forkListener :: String -> Transport err addr x -> IO (IO ())
forkListener name client = do
    thread_id <- forkIO $ do
        myThreadId >>= flip labelThread ("listener."++name)
        -- Stop on 'Nothing': that is the transport's end-of-stream signal, and
        -- waiting again after it would only spin.
        fix $ \loop -> awaitMessage client $ \m -> when (isJust m) loop
    return $ do
        closeTransport client
        killThread thread_id

-- | Send a query to a remote peer. Note that this function will always time
-- out if 'forkListener' was never invoked to spawn a thread to receive and
-- dispatch the response.
sendQuery ::
    forall err a b tbl x meth tid addr.
    Client err meth tid addr x             -- ^ A query/response implementation.
    -> MethodSerializer tid addr x meth a b -- ^ Information for marshalling the query.
    -> a                                    -- ^ The outbound query.
    -> addr                                 -- ^ Destination address of query.
    -> IO (Maybe b)                         -- ^ The response, or 'Nothing' if it timed out.
sendQuery (Client net d err pending whoami _) meth q addr = do
    mvar <- newEmptyMVar
    -- Register the reply slot and obtain a fresh transaction id atomically.
    tid <- atomically $ do
        tbl <- readTVar pending
        (newTid, tbl') <- dispatchRegister (tableMethods d) mvar tbl
        writeTVar pending tbl'
        return newTid
    -- Remove this transaction from the pending table.
    let forget = atomically $ readTVar pending
                            >>= dispatchCancel (tableMethods d) tid
                            >>= writeTVar pending
    -- If resolving our own address or sending fails (e.g. the network is
    -- unreachable), drop the registration before re-throwing. Otherwise the
    -- entry would remain in the pending table forever.
    mres <- flip onException forget $ do
        self <- whoami (Just addr)
        sendMessage net addr (wrapQuery meth tid self addr q)
        timeout (1000000 * methodTimeout meth) $ takeMVar mvar
    case mres of
        Just x  -> return $ Just $ unwrapResponse meth x
        Nothing -> do
            forget
            reportTimeout err (method meth) tid addr
            return Nothing

-- * Implementing a query\/response 'Client'.

-- | All inputs required to implement a query\/response client.
data Client err meth tid addr x = forall tbl. Client
    { -- | The 'Transport' used to dispatch and receive packets.
      clientNet :: Transport err addr x
      -- | Methods for handling inbound packets.
    , clientDispatcher :: DispatchMethods tbl err meth tid addr x
      -- | Methods for reporting various conditions.
    , clientErrorReporter :: ErrorReporter addr x meth tid err
      -- | State necessary for routing inbound responses and assigning unique
      -- /tid/ values for outgoing queries.
    , clientPending :: TVar tbl
      -- | An action yielding this client\'s own address. It is invoked once
      -- for each outbound query (with the destination address) and once for
      -- each inbound query that has a handler (with the origin address). It
      -- is valid for this to always return the same value.
    , clientAddress :: Maybe addr -> IO addr
      -- | Transform a query /tid/ value to an appropriate response /tid/
      -- value. Normally, this would be the identity transformation, but if
      -- /tid/ includes a unique cryptographic nonce, then it should be
      -- generated here.
    , clientResponseId :: tid -> IO tid
    }

-- | An incoming message can be classified into three cases.
data MessageClass err meth tid
    = IsQuery meth tid
      -- ^ An unsolicited query is handled based on its /meth/ value. Any
      -- response should include the provided /tid/ value.
    | IsResponse tid
      -- ^ A response to an outgoing query we associated with a /tid/ value.
    | IsUnknown err
      -- ^ None of the above.

-- | Handler for an inbound query of type /x/ from an address of type /addr/.
data MethodHandler err tid addr x = forall a b. MethodHandler
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Serialize the response for transmission, given the transaction
      -- /tid/, our own (origin) address and the destination address of the
      -- response.
    , methodSerialize :: tid -> addr -> addr -> b -> x
      -- | Fully typed action to perform upon the query. The remote origin
      -- address of the query is provided to the handler.
    , methodAction :: addr -> a -> IO b
    }
    | forall a. NoReply
    { -- | Parse the query into a more specific type for this method.
      methodParse :: x -> Either err a
      -- | Fully typed action to perform upon the query. The remote origin
      -- address of the query is provided to the handler.
    , noreplyAction :: addr -> a -> IO ()
    }

-- | Change the address type of a 'MethodHandler' by converting incoming
-- addresses before they reach the wrapped handler.
contramapAddr :: (a -> b) -> MethodHandler err tid b x -> MethodHandler err tid a x
contramapAddr f (MethodHandler p s a)
    = MethodHandler
        p
        (\tid src dst result -> s tid (f src) (f dst) result)
        (\addr arg -> a (f addr) arg)
contramapAddr f (NoReply p a)
    = NoReply p (\addr arg -> a (f addr) arg)

-- | Attempt to invoke a 'MethodHandler' upon a given inbound query. If the
-- parse is successful, the returned IO action will construct our reply if
-- there is one. Otherwise, a parse err is returned.
dispatchQuery :: MethodHandler err tid addr x -- ^ Handler to invoke.
              -> tid                          -- ^ The transaction id for this query\/response session.
              -> addr                         -- ^ Our own address, to which the query was sent.
              -> x                            -- ^ The query packet.
              -> addr                         -- ^ The origin address of the query.
              -> Either err (IO (Maybe x))
dispatchQuery (MethodHandler unwrapQ wrapR f) tid self x addr =
    fmap (\a -> Just . wrapR tid self addr <$> f addr a) $ unwrapQ x
dispatchQuery (NoReply unwrapQ f) tid self x addr =
    fmap (\a -> f addr a >> return Nothing) $ unwrapQ x

-- | These four parameters are required to implement an outgoing query. A
-- peer-to-peer algorithm will define a 'MethodSerializer' for every
-- 'MethodHandler' that might be returned by 'lookupHandler'.
data MethodSerializer tid addr x meth a b = MethodSerializer
    { -- | Seconds to wait for a response.
      methodTimeout :: Int
      -- | A method identifier used for error reporting. This needn't be the
      -- same as the /meth/ argument to 'MethodHandler', but it is suggested.
    , method :: meth
      -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
      -- The /addr/ arguments are, respectively, our own origin address and the
      -- destination of the request. The /tid/ argument is useful for attaching
      -- auxiliary notations on all outgoing packets.
    , wrapQuery :: tid -> addr -> addr -> a -> x
      -- | Parse an inbound packet /x/ into a response /b/ for this query.
    , unwrapResponse :: x -> b
    }

-- | To dispatch responses to our outbound queries, we require three
-- primitives. See the 'transactionMethods' function to create these
-- primitives out of a lookup table and a generator for transaction ids.
--
-- The type variable /d/ is used to represent the current state of the
-- transaction generator and the table of pending transactions.
data TransactionMethods d tid x = TransactionMethods
    { -- | Before a query is sent, this function stores an 'MVar' to which the
      -- response will be written. The returned /tid/ is a transaction id
      -- that can be used to forget the 'MVar' if the remote peer is not
      -- responding.
      dispatchRegister :: MVar x -> d -> STM (tid, d)
      -- | This method is invoked when an incoming packet /x/ indicates it is
      -- a response to the transaction with id /tid/. The returned IO action
      -- will write the packet to the correct 'MVar', thus completing the
      -- dispatch.
    , dispatchResponse :: tid -> x -> d -> STM (d, IO ())
      -- | When a timeout interval elapses (or sending the query fails), this
      -- method is called to remove the transaction from the table.
    , dispatchCancel :: tid -> d -> STM d
    }

-- | The standard lookup table methods for use as input to 'transactionMethods'
-- in lieu of directly implementing 'TransactionMethods'.
data TableMethods t tid = TableMethods
    { -- | Insert a new /tid/ entry into the transaction table.
      tblInsert :: forall a. tid -> a -> t a -> t a
      -- | Delete transaction /tid/ from the transaction table.
    , tblDelete :: forall a. tid -> t a -> t a
      -- | Lookup the value associated with transaction /tid/.
    , tblLookup :: forall a. tid -> t a -> Maybe a
    }

-- | Methods for using 'Data.IntMap.Strict.IntMap'.
intMapMethods :: TableMethods IntMap Int
intMapMethods = TableMethods IntMap.insert IntMap.delete IntMap.lookup

-- | Methods for using 'Data.Map.Strict.Map'.
mapMethods :: Ord tid => TableMethods (Map tid) tid
mapMethods = TableMethods Map.insert Map.delete Map.lookup

-- | Change the key type for a lookup table implementation.
--
-- This can be used with 'intMapMethods' or 'mapMethods' to restrict lookups to
-- only a part of the generated /tid/ value. This is useful for /tid/ types
-- that are especially large due to their use for other purposes, such as
-- secure nonces for encryption.
instance Contravariant (TableMethods t) where
    -- contramap :: (tid -> t1) -> TableMethods t t1 -> TableMethods t tid
    contramap f (TableMethods ins del look) =
        TableMethods (\k v t -> ins (f k) v t)
                     (\k t -> del (f k) t)
                     (\k t -> look (f k) t)

-- | Since 'Int' may be 32 or 64 bits, this function is provided as a
-- convenience to test if an integral type, such as 'Data.Word.Word64', can be
-- safely transformed into an 'Int' for use with 'IntMap'.
--
-- Returns 'True' if the proxied type can be losslessly converted to 'Int' using
-- 'fromIntegral'.
fitsInInt :: forall word. (Bounded word, Integral word) => Proxy word -> Bool
fitsInInt Proxy = (original == casted)
  where
    -- Half of maxBound is used rather than maxBound itself: for an unsigned
    -- type, maxBound is all one bits. Truncating that to a narrower Int gives
    -- -1, which converts back to the same all-ones value and would falsely
    -- report a lossless round trip.
    original = div maxBound 2 :: word
    casted = fromIntegral (fromIntegral original :: Int) :: word

-- | Construct 'TransactionMethods' methods out of 3 lookup table primitives and
-- a function for generating unique transaction ids.
transactionMethods ::
    TableMethods t tid -- ^ Table methods to lookup values by /tid/.
    -> (g -> (tid,g))  -- ^ Generate a new unique /tid/ value and update the generator state /g/.
    -> TransactionMethods (g,t (MVar x)) tid x
transactionMethods (TableMethods ins del look) generate = TransactionMethods
    { dispatchCancel = \tid (g,t) -> return (g, del tid t)
    , dispatchRegister = \v (g,t) ->
        let (tid,g') = generate g
            t' = ins tid v t
        in return ( tid, (g',t') )
    , dispatchResponse = \tid x (g,t) ->
        case look tid t of
            Just v -> let t' = del tid t
                      -- tryPutMVar: never block the receiving thread, even
                      -- if the slot was somehow already filled.
                      in return ((g,t'),void $ tryPutMVar v x)
            Nothing -> return ((g,t), return ())
    }

-- | A set of methods necessary for dispatching incoming packets.
data DispatchMethods tbl err meth tid addr x = DispatchMethods
    { -- | Classify an inbound packet as a query or response.
      classifyInbound :: x -> MessageClass err meth tid
      -- | Lookup the handler for an inbound query.
    , lookupHandler :: meth -> Maybe (MethodHandler err tid addr x)
      -- | Methods for handling incoming responses.
    , tableMethods :: TransactionMethods tbl tid x
    }

-- | These methods indicate what should be done upon various conditions. Write
-- to a log file, make debug prints, or simply ignore them.
--
-- [ /addr/ ] Address of remote peer.
--
-- [ /x/ ] Incoming or outgoing packet.
--
-- [ /meth/ ] Method id of incoming or outgoing request.
--
-- [ /tid/ ] Transaction id for outgoing packet.
--
-- [ /err/ ] Error information, typically a 'String'.
data ErrorReporter addr x meth tid err = ErrorReporter
    { -- | Incoming: failed to parse packet.
      reportParseError :: err -> IO ()
      -- | Incoming: no handler for request.
    , reportMissingHandler :: meth -> addr -> x -> IO ()
      -- | Incoming: unable to identify request.
    , reportUnknown :: addr -> x -> err -> IO ()
      -- | Outgoing: remote peer is not responding.
    , reportTimeout :: meth -> tid -> addr -> IO ()
    }

-- | An 'ErrorReporter' that discards every report.
ignoreErrors :: ErrorReporter addr x meth tid err
ignoreErrors = ErrorReporter
    { reportParseError = \_ -> return ()
    , reportMissingHandler = \_ _ _ -> return ()
    , reportUnknown = \_ _ _ -> return ()
    , reportTimeout = \_ _ _ -> return ()
    }

-- | An 'ErrorReporter' that writes one line per report to the given 'Handle'.
printErrors :: ( Show addr
               , Show meth
               ) => Handle -> ErrorReporter addr x meth tid String
printErrors h = ErrorReporter
    { reportParseError = \err -> hPutStrLn h err
    , reportMissingHandler = \meth addr x -> hPutStrLn h $ show addr ++ " --> Missing handler ("++show meth++")"
    , reportUnknown = \addr x err -> hPutStrLn h $ show addr ++ " --> " ++ err
    , reportTimeout = \meth tid addr -> hPutStrLn h $ show addr ++ " --> Timeout ("++show meth++")"
    }

-- | Change the /err/ type for an 'ErrorReporter'.
instance Contravariant (ErrorReporter addr x meth tid) where
    -- contramap :: (t5 -> t4) -> ErrorReporter t3 t2 t1 t t4 -> ErrorReporter t3 t2 t1 t t5
    contramap f (ErrorReporter pe mh unk tim)
        = ErrorReporter (\e -> pe (f e))
                        mh
                        (\addr x e -> unk addr x (f e))
                        tim

-- | Handle a single inbound packet. The result follows the convention of
-- 'addHandler', so @addHandler (handleMessage client)@ installs this
-- dispatcher on a transport:
--
-- * A query with a known method is parsed and its reply (if any) is sent back
--   to the origin address. The result is 'Nothing': the packet is consumed.
--
-- * A response is written to the 'MVar' of the matching pending transaction.
--   The result is 'Nothing': the packet is consumed.
--
-- * A query with no handler, a query that fails to parse, or an unclassifiable
--   packet is reported through the client's 'ErrorReporter'. The result is
--   @Just id@: the packet is passed on unchanged.
handleMessage ::
    Client err meth tid addr x
    -> addr
    -> x
    -> IO (Maybe (x -> x))
handleMessage (Client net d err pending whoami responseID) addr plain =
    case classifyInbound d plain of
        IsQuery meth tid -> case lookupHandler d meth of
            Nothing -> do
                reportMissingHandler err meth addr plain
                return $! Just id
            Just m -> do
                self <- whoami (Just addr)
                tid' <- responseID tid
                case dispatchQuery m tid' self plain addr of
                    Left e -> do
                        reportParseError err e
                        return $! Just id
                    Right respond -> do
                        reply <- respond
                        mapM_ (sendMessage net addr) reply
                        return $! Nothing
        IsResponse tid -> do
            action <- atomically $ do
                ts0 <- readTVar pending
                (ts, action) <- dispatchResponse (tableMethods d) tid plain ts0
                writeTVar pending ts
                return action
            -- Run the MVar write outside the STM transaction.
            action
            return $! Nothing
        IsUnknown e -> do
            reportUnknown err addr plain e
            return $! Just id

-- * UDP Datagrams.

-- | Access the address family of a given 'SockAddr'. This convenient accessor
-- is missing from 'Network.Socket', so I implemented it here.
--
-- NOTE(review): 'SockAddrCan' only exists in older releases of the network
-- package (it appears to have been removed in network-3). Confirm the
-- supported network version before upgrading.
sockAddrFamily :: SockAddr -> Family
sockAddrFamily (SockAddrInet _ _      ) = AF_INET
sockAddrFamily (SockAddrInet6 _ _ _ _ ) = AF_INET6
sockAddrFamily (SockAddrUnix _        ) = AF_UNIX
sockAddrFamily (SockAddrCan _         ) = AF_CAN

-- | Packets with an empty payload may trigger eof exception.
-- 'udpTransport' uses this function to avoid throwing in that
-- case.
ignoreEOF :: a -> IOError -> IO a
ignoreEOF def e | isEOFError e  = pure def
                | otherwise     = throwIO e

-- | Hardcoded maximum packet size for incoming udp packets received via
-- 'udpTransport'.
udpBufferSize :: Int
udpBufferSize = 65536

-- | A 'udpTransport' uses a UDP socket to send and receive 'ByteString's. The
-- argument is the listen-address for incoming packets. This is a useful
-- low-level 'Transport' that can be transformed for higher-level protocols
-- using 'layerTransport'.
--
-- When bound to an IPv6 address, the socket is made dual-stack (IPv6Only
-- disabled): IPv4 destinations are sent as 4mapped6 addresses. When bound to
-- an IPv4 address, 4mapped6 destinations are converted back to plain IPv4 and
-- other IPv6 destinations are discarded with a message on stderr.
udpTransport :: SockAddr -> IO (Transport err SockAddr ByteString)
udpTransport bind_address = do
    let family = sockAddrFamily bind_address
    sock <- socket family Datagram defaultProtocol
    -- Close the socket if configuring or binding it fails (for example when
    -- the address is already in use); otherwise it would leak.
    flip onException (close sock) $ do
        when (family == AF_INET6) $ do
            setSocketOption sock IPv6Only 0
        bind sock bind_address
    return Transport
        { awaitMessage = \kont -> do
            r <- handle (ignoreEOF $ Just $ Right (B.empty, SockAddrInet 0 0)) $ do
                Just . Right <$!> B.recvFrom sock udpBufferSize
            kont $! r
        , sendMessage = case family of
            -- TODO: sendTo: does not exist (Network is unreachable)
            --       Occurs when IPv6 network is not available.
            --       Currently, we require -threaded to prevent a forever-hang in this case.
            AF_INET6 -> \case
                (SockAddrInet port addr) -> \bs ->
                    -- Change IPv4 to 4mapped6 address.
                    void $ B.sendTo sock bs $ SockAddrInet6 port 0 (0,0,0x0000ffff,fromBE32 addr) 0
                addr6 -> \bs -> void $ B.sendTo sock bs addr6
            AF_INET -> \case
                (SockAddrInet6 port 0 (0,0,0x0000ffff,raw4) 0) -> \bs -> do
                    let host4 = toBE32 raw4
                    -- Change 4mapped6 to ordinary IPv4.
                    -- hPutStrLn stderr $ "4mapped6 -> "++show (SockAddrInet port host4)
                    void $ B.sendTo sock bs (SockAddrInet port host4)
                addr@(SockAddrInet6 {}) -> \bs -> hPutStrLn stderr ("Discarding packet to "++show addr)
                addr4 -> \bs -> void $ B.sendTo sock bs addr4
            _ -> \addr bs -> void $ B.sendTo sock bs addr
        , closeTransport = close sock
        }