From 9953d0a9ba7e992062ae60ae8e24054b0883b50e Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Fri, 24 Jan 2020 23:08:14 -0500 Subject: QueryResponse rework: awaitMessage :: STM (Arrival err addr x, IO ()) --- server/src/Network/QueryResponse.hs | 127 +++++++++++++++----------------- server/src/Network/QueryResponse/TCP.hs | 5 +- 2 files changed, 63 insertions(+), 69 deletions(-) diff --git a/server/src/Network/QueryResponse.hs b/server/src/Network/QueryResponse.hs index 69cc6f50..0c200916 100644 --- a/server/src/Network/QueryResponse.hs +++ b/server/src/Network/QueryResponse.hs @@ -19,6 +19,7 @@ import Control.Concurrent.Lifted.Instrument import Control.Concurrent import GHC.Conc (labelThread) #endif +import Control.Arrow import Control.Concurrent.STM import Control.Exception import Control.Monad @@ -73,7 +74,7 @@ data Arrival err addr x data TransportA err addr x y = Transport { -- | Blocks until an inbound packet is available. Then calls the provided -- continuation with the packet and origin address or an error condition. - awaitMessage :: forall a. (Arrival err addr x -> IO a) -> STM (IO a) + awaitMessage :: STM (Arrival err addr x, IO ()) -- | Send an /y/ packet to the given destination /addr/. , sendMessage :: addr -> y -> IO () -- | Shutdown and clean up any state related to this 'Transport'. @@ -84,7 +85,7 @@ type Transport err addr x = TransportA err addr x x nullTransport :: TransportA err addr x y nullTransport = Transport - { awaitMessage = \_ -> retry + { awaitMessage = retry , sendMessage = \_ _ -> return () , setActive = \_ -> return () } @@ -97,7 +98,7 @@ closeTransport tr = setActive tr False -- bencoded syntax trees or to add an encryption layer in which addresses have -- associated public keys. layerTransportM :: - (x -> addr -> IO (Either err (x', addr'))) + (x -> addr -> STM (Either err (x', addr'))) -- ^ Function that attempts to transform a low-level address/packet -- pair into a higher level representation. -> (y' -> addr' -> IO (y, addr)) @@ -107,14 +108,15 @@ layerTransportM :: -- ^ The low-level transport to be transformed. -> TransportA err addr' x' y' layerTransportM parse encode tr = - tr { awaitMessage = \kont -> - awaitMessage tr $ \case - Terminated -> kont $ Terminated - Discarded -> kont $ Discarded - ParseError e -> kont $ ParseError e - Arrival addr x -> parse x addr >>= \case - Left e -> kont $ ParseError e - Right (x',addr') -> kont $ Arrival addr' x' + tr { awaitMessage = do + (m,io) <- awaitMessage tr + case m of + Terminated -> return $ (,) Terminated io + Discarded -> return $ (,) Discarded io + ParseError e -> return $ (ParseError e,io) + Arrival addr x -> parse x addr >>= \case + Left e -> return (ParseError e, io) + Right (x',addr') -> return (Arrival addr' x', io) , sendMessage = \addr' msg' -> do (msg,addr) <- encode msg' addr' sendMessage tr addr msg @@ -143,26 +145,26 @@ layerTransport parse encode tr = -- | Paritions a 'Transport' into two higher-level transports. Note: A 'TChan' -- is used to share the same underlying socket, so be sure to fork a thread for -- both returned 'Transport's to avoid hanging. -partitionTransportM :: ((b,a) -> IO (Either (x,xaddr) (b,a))) +partitionTransportM :: ((b,a) -> STM (Either (x,xaddr) (b,a))) -> ((y,xaddr) -> IO (Maybe (c,a))) -> TransportA err a b c -> IO (TransportA err xaddr x y, TransportA err a b c) partitionTransportM parse encodex tr = do tchan <- atomically newTChan - let ytr = tr { awaitMessage = \kont -> - awaitMessage tr $ \m -> case m of + let ytr = tr { awaitMessage = do + (m,io) <- awaitMessage tr + case m of Arrival adr msg -> parse (msg,adr) >>= \case - Left x -> atomically (writeTChan tchan (Just x)) >> kont Discarded - Right (y,yaddr) -> kont $ Arrival yaddr y - ParseError e -> kont $ ParseError e - Discarded -> kont $ Discarded - Terminated -> atomically (writeTChan tchan Nothing) >> kont Terminated + Left x -> return (Discarded, io >> atomically (writeTChan tchan (Just x))) + Right (y,yaddr) -> return (Arrival yaddr y, io) + Terminated -> return (Terminated, io >> atomically (writeTChan tchan Nothing)) + _ -> return (m,io) , sendMessage = sendMessage tr } xtr = Transport - { awaitMessage = \kont -> readTChan tchan >>= pure . kont . \case - Nothing -> Terminated - Just (x,xaddr) -> Arrival xaddr x + { awaitMessage = readTChan tchan >>= \case + Nothing -> return (Terminated, return ()) + Just (x,xaddr) -> return (Arrival xaddr x, return ()) , sendMessage = \addr' msg' -> do msg_addr <- encodex (msg',addr') mapM_ (uncurry . flip $ sendMessage tr) msg_addr @@ -180,24 +182,21 @@ partitionTransport :: ((b,a) -> Either (x,xaddr) (b,a)) partitionTransport parse encodex tr = partitionTransportM (return . parse) (return . encodex) tr --- | --- * f add x --> Nothing, consume x --- --> Just id, leave x to a different handler --- --> Just g, apply g to x and leave that to a different handler --- --- Note: If you add a handler to one of the branches before applying a --- 'mergeTransports' combinator, then this handler may not block or return --- Nothing. -addHandler :: (addr -> x -> IO (Maybe (x -> x))) -> TransportA err addr x y -> TransportA err addr x y +addHandler :: (Arrival err addr x -> STM (Arrival err addr x, IO ())) -> TransportA err addr x y -> TransportA err addr x y addHandler f tr = tr - { awaitMessage = \kont -> fix $ \eat -> awaitMessage tr $ \case - Arrival addr x -> f addr x >>= maybe (join $ atomically eat) (kont . Arrival addr . ($ x)) - m -> kont m + { awaitMessage = do + (m,io1) <- awaitMessage tr + (m', io2) <- f m + return (m', io1 >> io2) } +forArrival :: Applicative m => (addr -> x -> IO ()) -> Arrival err addr x -> m (Arrival err addr x, IO ()) +forArrival f (Arrival addr x) = pure (Arrival addr x, f addr x) +forArrival _ m = pure (m, return ()) + -- | Modify a 'Transport' to invoke an action upon every received packet. onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x -onInbound f tr = addHandler (\addr x -> f addr x >> return (Just id)) tr +onInbound f tr = addHandler (forArrival f) tr -- * Using a query\/response client. @@ -217,10 +216,13 @@ forkListener name onParseError client = do setActive client True thread_id <- forkIO $ do myThreadId >>= flip labelThread ("listener."++name) - fix $ \loop -> join $ atomically $ awaitMessage client $ \case - Terminated -> return () - ParseError e -> onParseError e >> loop - _ -> loop + fix $ \loop -> do + (m,io) <- atomically $ awaitMessage client + io + case m of + Terminated -> return () + ParseError e -> onParseError e >> loop + _ -> loop dput XMisc $ "Listener died: " ++ name return $ do setActive client False @@ -539,40 +541,35 @@ transactionMethods methods generate = transactionMethods' id id methods generate -- throws an exception. handleMessage :: ClientA err meth tid addr x y - -> addr - -> x - -> IO (Maybe (x -> x)) -handleMessage (Client net d err pending whoami responseID) addr plain = do + -> Arrival err addr x + -> STM (Arrival err addr x, IO ()) +handleMessage (Client net d err pending whoami responseID) msg@(Arrival addr plain) = do -- Just (Left e) -> do reportParseError err e -- return $! Just id -- Just (Right (plain, addr)) -> do case classifyInbound d plain of IsQuery meth tid -> case lookupHandler d meth of - Nothing -> do reportMissingHandler err meth addr plain - return $! Just id - Just m -> do + Nothing -> return (msg, reportMissingHandler err meth addr plain) + Just m -> return $ (,) Discarded $ do self <- whoami (Just addr) tid' <- responseID tid - either (\e -> do reportParseError err e - return $! Just id) - (>>= \m -> do mapM_ (sendMessage net addr) m - return $! Nothing) + either (\e -> reportParseError err e) + (\iom -> iom >>= mapM_ (sendMessage net addr)) (dispatchQuery m tid' self plain addr) - IsUnsolicited action -> do + IsUnsolicited action -> return $ (,) Discarded $ do self <- whoami (Just addr) - action self addr - return Nothing - IsResponse tid -> do + _ <- action self addr + return () + IsResponse tid -> return $ (,) Discarded $ do action <- atomically $ do ts0 <- readTVar pending (ts, action) <- dispatchResponse (tableMethods d) tid (Success plain) ts0 writeTVar pending ts return action action - return $! Nothing - IsUnknown e -> do reportUnknown err addr plain e - return $! Just id + IsUnknown e -> return (msg, reportUnknown err addr plain e) -- Nothing -> return $! id +handleMessage _ msg = return (msg, return ()) -- * UDP Datagrams. @@ -629,9 +626,7 @@ udpTransport' bind_address = do isClosed <- newEmptyMVar udpTChan <- atomically newTChan let tr = Transport { - awaitMessage = \kont -> do - r <- readTChan udpTChan - return $ kont $! r + awaitMessage = fmap (,return()) $ readTChan udpTChan , sendMessage = case family of AF_INET6 -> \case (SockAddrInet port addr) -> \bs -> @@ -681,11 +676,9 @@ udpTransport bind_address = fst <$> udpTransport' bind_address chanTransport :: (addr -> TChan (x, addr)) -> addr -> TChan (x, addr) -> TVar Bool -> Transport err addr x chanTransport chanFromAddr self achan aclosed = Transport - { awaitMessage = \kont -> do - x <- (uncurry (flip Arrival) <$> readTChan achan) - `orElse` - (readTVar aclosed >>= check >> return Terminated) - return $ kont x + { awaitMessage = fmap (, return ()) $ + orElse (uncurry (flip Arrival) <$> readTChan achan) + (readTVar aclosed >>= check >> return Terminated) , sendMessage = \them bs -> do atomically $ writeTChan (chanFromAddr them) (bs,self) , setActive = \case @@ -720,8 +713,8 @@ mergeTransports tmap = do -- vmap <- traverseWithKey (\k v -> Tagged <$> newEmptyMVar) tmap -- foldrWithKey (\k v n -> forkMergeBranch k v >> n) (return ()) vmap return Transport - { awaitMessage = \kont -> - foldrWithKey (\k (ByAddress tr) n -> awaitMessage tr (kont . decorateAddr k) `orElse` n) + { awaitMessage = + foldrWithKey (\k (ByAddress tr) n -> (first (decorateAddr k) <$> awaitMessage tr) `orElse` n) retry tmap , sendMessage = \(tag :=> Identity addr) x -> case DMap.lookup tag tmap of diff --git a/server/src/Network/QueryResponse/TCP.hs b/server/src/Network/QueryResponse/TCP.hs index 8b1b432b..24aacd98 100644 --- a/server/src/Network/QueryResponse/TCP.hs +++ b/server/src/Network/QueryResponse/TCP.hs @@ -213,8 +213,9 @@ tcpTransport maxcon stream = do msgvar <- atomically newEmptyTMVar tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar (MM.empty) return $ (,) tcpcache Transport - { awaitMessage = \f -> takeTMVar msgvar >>= \x -> return $ do - f x `catchIOError` (\e -> dput XTCP ("TCP transport stopped. " ++ show e) >> f Terminated) + { awaitMessage = do + x <- takeTMVar msgvar + return (x, return ()) , sendMessage = \addr (bDoCon,y) -> do void . forkLabeled "tcp-send" $ do msock <- acquireConnection msgvar tcpcache stream addr bDoCon -- cgit v1.2.3