From ac421efc3db225d3d965580286552541f51dbb68 Mon Sep 17 00:00:00 2001 From: Daniel Gröber Date: Sat, 4 Jan 2014 04:16:42 +0000 Subject: Move mutable state in Connection to single field and make a MonadState instance for (Connected IO a) ..also add lenses --- src/Network/BitTorrent/Exchange/Wire.hs | 105 ++++++++++++++------------------ 1 file changed, 45 insertions(+), 60 deletions(-) (limited to 'src/Network') diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 0897f482..a3b60b99 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs @@ -10,6 +10,9 @@ -- This module control /integrity/ of data send and received. -- {-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} module Network.BitTorrent.Exchange.Wire ( -- * Wire Wire @@ -61,7 +64,10 @@ module Network.BitTorrent.Exchange.Wire import Control.Applicative import Control.Exception import Control.Monad.Reader +import Control.Monad.State +import Control.Lens import Data.ByteString as BS +import Data.ByteString.Lazy as BSL import Data.Conduit import Data.Conduit.Cereal import Data.Conduit.List @@ -73,7 +79,7 @@ import Data.Monoid import Data.Serialize as S import Data.Typeable import Network -import Network.Socket +import Network.Socket hiding (Connected) import Network.Socket.ByteString as BS import Text.PrettyPrint as PP hiding (($$), (<>)) import Text.PrettyPrint.Class @@ -218,10 +224,6 @@ isWireFailure _ = return () protocolError :: MonadThrow m => ProtocolError -> m a protocolError = monadThrow . ProtocolError --- | Forcefully terminate wire session and close socket. -disconnectPeer :: Wire a -disconnectPeer = monadThrow DisconnectPeer - {----------------------------------------------------------------------- -- Stats -----------------------------------------------------------------------} @@ -251,13 +253,6 @@ instance Monoid FlowStats where , messageCount = messageCount a + messageCount b } --- | Aggregate one more message stats in this direction. -addFlowStats :: ByteStats -> FlowStats -> FlowStats -addFlowStats x FlowStats {..} = FlowStats - { messageBytes = messageBytes <> x - , messageCount = succ messageCount - } - -- | Find average length of byte sequences per message. avgByteStats :: FlowStats -> ByteStats avgByteStats (FlowStats n ByteStats {..}) = ByteStats @@ -302,8 +297,8 @@ instance Monoid ConnectionStats where -- | Aggregate one more message stats in the /specified/ direction. addStats :: ChannelSide -> ByteStats -> ConnectionStats -> ConnectionStats -addStats ThisPeer x s = s { outcomingFlow = addFlowStats x (outcomingFlow s) } -addStats RemotePeer x s = s { incomingFlow = addFlowStats x (incomingFlow s) } +addStats ThisPeer x s = s { outcomingFlow = (FlowStats 1 x) <> (outcomingFlow s) } +addStats RemotePeer x s = s { incomingFlow = (FlowStats 1 x) <> (incomingFlow s) } -- | Sum of overhead and control bytes in both directions. wastedBytes :: ConnectionStats -> Int @@ -448,6 +443,8 @@ data ConnectionState = ConnectionState { , _connMetadata :: Maybe (Cached InfoDict) } +makeLenses ''ConnectionState + -- | Connection keep various info about both peers. data Connection = Connection { -- | /Both/ peers handshaked with this protocol string. The only @@ -474,17 +471,8 @@ data Connection = Connection -- | , connOptions :: !Options - -- | If @not (allowed ExtExtended connCaps)@ then this set is always - -- empty. Otherwise it has the BEP10 extension protocol mandated mapping of - -- 'MessageId' to the message type for the remote peer. - , connExtCaps :: !(IORef ExtendedCaps) - - -- | Current extended handshake information from the remote peer - , connRemoteEhs :: !(IORef ExtendedHandshake) - - -- | Various stats about messages sent and received. Stats can be - -- used to protect /this/ peer against flood attacks. - , connStats :: !(IORef ConnectionStats) + -- | Mutable connection state, see 'ConnectionState' + , connState :: !(IORef ConnectionState) -- -- | Max request queue length. -- , connMaxQueueLen :: !Int @@ -506,7 +494,6 @@ isAllowed Connection {..} msg sendHandshake :: Socket -> Handshake -> IO () sendHandshake sock hs = sendAll sock (S.encode hs) --- TODO drop connection if protocol string do not match recvHandshake :: Socket -> IO Handshake recvHandshake sock = do header <- BS.recv sock 1 @@ -543,7 +530,15 @@ connectToPeer p = do -----------------------------------------------------------------------} -- | do not expose this so we can change it without breaking api -type Connected = ReaderT Connection +newtype Connected m a = Connected { runConnected :: (ReaderT Connection m a) } + deriving (Functor, Applicative, Monad, MonadIO, MonadReader Connection, MonadThrow) + +instance (MonadIO m) => MonadState ConnectionState (Connected m) where + get = Connected (asks connState) >>= liftIO . readIORef + put x = Connected (asks connState) >>= liftIO . flip writeIORef x + +instance MonadTrans Connected where + lift = Connected . lift -- | A duplex channel connected to a remote peer which keep tracks -- connection parameters. @@ -553,40 +548,25 @@ type Wire a = ConduitM Message Message (Connected IO) a -- Query -----------------------------------------------------------------------} -readRef :: (Connection -> IORef a) -> Connected IO a -readRef f = do - ref <- asks f - liftIO (readIORef ref) - -writeRef :: (Connection -> IORef a) -> a -> Connected IO () -writeRef f v = do - ref <- asks f - liftIO (writeIORef ref v) - -modifyRef :: (Connection -> IORef a) -> (a -> a) -> Connected IO () -modifyRef f m = do - ref <- asks f - liftIO (atomicModifyIORef' ref (\x -> (m x, ()))) - setExtCaps :: ExtendedCaps -> Wire () -setExtCaps = lift . writeRef connExtCaps +setExtCaps x = lift $ connExtCaps .= x -- | Get current extended capabilities. Note that this value can -- change in current session if either this or remote peer will -- initiate rehandshaking. getExtCaps :: Wire ExtendedCaps -getExtCaps = lift $ readRef connExtCaps +getExtCaps = lift $ use connExtCaps setRemoteEhs :: ExtendedHandshake -> Wire () -setRemoteEhs = lift . writeRef connRemoteEhs +setRemoteEhs x = lift $ connRemoteEhs .= x getRemoteEhs :: Wire ExtendedHandshake -getRemoteEhs = lift $ readRef connRemoteEhs +getRemoteEhs = lift $ use connRemoteEhs -- | Get current stats. Note that this value will change with the next -- sent or received message. getStats :: Wire ConnectionStats -getStats = lift $ readRef connStats +getStats = lift $ use connStats -- | See the 'Connection' section for more info. getConnection :: Wire Connection @@ -597,7 +577,7 @@ getConnection = lift ask -----------------------------------------------------------------------} putStats :: ChannelSide -> Message -> Connected IO () -putStats side msg = modifyRef connStats (addStats side (stats msg)) +putStats side msg = connStats %= addStats side (stats msg) validate :: ChannelSide -> Message -> Connected IO () validate side msg = do @@ -619,13 +599,13 @@ trackFlow side = iterM $ do -- | Normally you should use 'connectWire' or 'acceptWire'. runWire :: Wire () -> Socket -> Connection -> IO () -runWire action sock = runReaderT $ +runWire action sock conn = flip runReaderT conn $ runConnected $ sourceSocket sock $= - conduitGet get $= + conduitGet S.get $= trackFlow RemotePeer $= action $= trackFlow ThisPeer $= - conduitPut put $$ + conduitPut S.put $$ sinkSocket sock -- | This function will block until a peer send new message. You can @@ -636,9 +616,13 @@ recvMessage = await >>= maybe (monadThrow PeerDisconnected) return -- | You can also use 'yield'. sendMessage :: PeerMessage msg => msg -> Wire () sendMessage msg = do - ecaps <- getExtCaps + ecaps <- use connExtCaps yield $ envelop ecaps msg +-- | Forcefully terminate wire session and close socket. +disconnectPeer :: Wire a +disconnectPeer = monadThrow DisconnectPeer + extendedHandshake :: ExtendedCaps -> Wire () extendedHandshake caps = do -- TODO add other params to the handshake @@ -683,11 +667,14 @@ connectWire hs addr extCaps wire = then extendedHandshake extCaps >> wire else wire - extCapsRef <- newIORef def - remoteEhs <- newIORef def - statsRef <- newIORef ConnectionStats - { outcomingFlow = FlowStats 1 $ handshakeStats hs - , incomingFlow = FlowStats 1 $ handshakeStats hs' + cstate <- newIORef $ ConnectionState { + _connExtCaps = def + , _connRemoteEhs = def + , _connStats = ConnectionStats { + outcomingFlow = FlowStats 1 $ handshakeStats hs + , incomingFlow = FlowStats 1 $ handshakeStats hs' + } + , _connMetadata = Nothing } runWire wire' sock $ Connection @@ -697,9 +684,7 @@ connectWire hs addr extCaps wire = , connRemotePeerId = hsPeerId hs' , connThisPeerId = hsPeerId hs , connOptions = def - , connExtCaps = extCapsRef - , connRemoteEhs = remoteEhs - , connStats = statsRef + , connState = cstate } -- | Accept 'Wire' connection using already 'Network.Socket.accept'ed -- cgit v1.2.3