diff options
author | Daniel Gröber <dxld@darkboxed.org> | 2014-01-04 04:16:42 +0000 |
---|---|---|
committer | Sam Truzjan <pxqr.sta@gmail.com> | 2014-01-04 21:43:14 +0400 |
commit | ac421efc3db225d3d965580286552541f51dbb68 (patch) | |
tree | 82ebde98894d5373d3dbb60d9dcd44a2e5bcfb9a | |
parent | 9000a995bddfd85a2e2a25e23eb23ebc53489a1d (diff) |
Move mutable state in Connection to single field and make a MonadState instance for (Connected IO a)
..also add lenses
-rw-r--r-- | src/Network/BitTorrent/Exchange/Wire.hs | 105 |
1 files changed, 45 insertions, 60 deletions
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 0897f482..a3b60b99 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs | |||
@@ -10,6 +10,9 @@ | |||
10 | -- This module control /integrity/ of data send and received. | 10 | -- This module control /integrity/ of data send and received. |
11 | -- | 11 | -- |
12 | {-# LANGUAGE DeriveDataTypeable #-} | 12 | {-# LANGUAGE DeriveDataTypeable #-} |
13 | {-# LANGUAGE TemplateHaskell #-} | ||
14 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
15 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
13 | module Network.BitTorrent.Exchange.Wire | 16 | module Network.BitTorrent.Exchange.Wire |
14 | ( -- * Wire | 17 | ( -- * Wire |
15 | Wire | 18 | Wire |
@@ -61,7 +64,10 @@ module Network.BitTorrent.Exchange.Wire | |||
61 | import Control.Applicative | 64 | import Control.Applicative |
62 | import Control.Exception | 65 | import Control.Exception |
63 | import Control.Monad.Reader | 66 | import Control.Monad.Reader |
67 | import Control.Monad.State | ||
68 | import Control.Lens | ||
64 | import Data.ByteString as BS | 69 | import Data.ByteString as BS |
70 | import Data.ByteString.Lazy as BSL | ||
65 | import Data.Conduit | 71 | import Data.Conduit |
66 | import Data.Conduit.Cereal | 72 | import Data.Conduit.Cereal |
67 | import Data.Conduit.List | 73 | import Data.Conduit.List |
@@ -73,7 +79,7 @@ import Data.Monoid | |||
73 | import Data.Serialize as S | 79 | import Data.Serialize as S |
74 | import Data.Typeable | 80 | import Data.Typeable |
75 | import Network | 81 | import Network |
76 | import Network.Socket | 82 | import Network.Socket hiding (Connected) |
77 | import Network.Socket.ByteString as BS | 83 | import Network.Socket.ByteString as BS |
78 | import Text.PrettyPrint as PP hiding (($$), (<>)) | 84 | import Text.PrettyPrint as PP hiding (($$), (<>)) |
79 | import Text.PrettyPrint.Class | 85 | import Text.PrettyPrint.Class |
@@ -218,10 +224,6 @@ isWireFailure _ = return () | |||
218 | protocolError :: MonadThrow m => ProtocolError -> m a | 224 | protocolError :: MonadThrow m => ProtocolError -> m a |
219 | protocolError = monadThrow . ProtocolError | 225 | protocolError = monadThrow . ProtocolError |
220 | 226 | ||
221 | -- | Forcefully terminate wire session and close socket. | ||
222 | disconnectPeer :: Wire a | ||
223 | disconnectPeer = monadThrow DisconnectPeer | ||
224 | |||
225 | {----------------------------------------------------------------------- | 227 | {----------------------------------------------------------------------- |
226 | -- Stats | 228 | -- Stats |
227 | -----------------------------------------------------------------------} | 229 | -----------------------------------------------------------------------} |
@@ -251,13 +253,6 @@ instance Monoid FlowStats where | |||
251 | , messageCount = messageCount a + messageCount b | 253 | , messageCount = messageCount a + messageCount b |
252 | } | 254 | } |
253 | 255 | ||
254 | -- | Aggregate one more message stats in this direction. | ||
255 | addFlowStats :: ByteStats -> FlowStats -> FlowStats | ||
256 | addFlowStats x FlowStats {..} = FlowStats | ||
257 | { messageBytes = messageBytes <> x | ||
258 | , messageCount = succ messageCount | ||
259 | } | ||
260 | |||
261 | -- | Find average length of byte sequences per message. | 256 | -- | Find average length of byte sequences per message. |
262 | avgByteStats :: FlowStats -> ByteStats | 257 | avgByteStats :: FlowStats -> ByteStats |
263 | avgByteStats (FlowStats n ByteStats {..}) = ByteStats | 258 | avgByteStats (FlowStats n ByteStats {..}) = ByteStats |
@@ -302,8 +297,8 @@ instance Monoid ConnectionStats where | |||
302 | 297 | ||
303 | -- | Aggregate one more message stats in the /specified/ direction. | 298 | -- | Aggregate one more message stats in the /specified/ direction. |
304 | addStats :: ChannelSide -> ByteStats -> ConnectionStats -> ConnectionStats | 299 | addStats :: ChannelSide -> ByteStats -> ConnectionStats -> ConnectionStats |
305 | addStats ThisPeer x s = s { outcomingFlow = addFlowStats x (outcomingFlow s) } | 300 | addStats ThisPeer x s = s { outcomingFlow = (FlowStats 1 x) <> (outcomingFlow s) } |
306 | addStats RemotePeer x s = s { incomingFlow = addFlowStats x (incomingFlow s) } | 301 | addStats RemotePeer x s = s { incomingFlow = (FlowStats 1 x) <> (incomingFlow s) } |
307 | 302 | ||
308 | -- | Sum of overhead and control bytes in both directions. | 303 | -- | Sum of overhead and control bytes in both directions. |
309 | wastedBytes :: ConnectionStats -> Int | 304 | wastedBytes :: ConnectionStats -> Int |
@@ -448,6 +443,8 @@ data ConnectionState = ConnectionState { | |||
448 | , _connMetadata :: Maybe (Cached InfoDict) | 443 | , _connMetadata :: Maybe (Cached InfoDict) |
449 | } | 444 | } |
450 | 445 | ||
446 | makeLenses ''ConnectionState | ||
447 | |||
451 | -- | Connection keep various info about both peers. | 448 | -- | Connection keep various info about both peers. |
452 | data Connection = Connection | 449 | data Connection = Connection |
453 | { -- | /Both/ peers handshaked with this protocol string. The only | 450 | { -- | /Both/ peers handshaked with this protocol string. The only |
@@ -474,17 +471,8 @@ data Connection = Connection | |||
474 | -- | | 471 | -- | |
475 | , connOptions :: !Options | 472 | , connOptions :: !Options |
476 | 473 | ||
477 | -- | If @not (allowed ExtExtended connCaps)@ then this set is always | 474 | -- | Mutable connection state, see 'ConnectionState' |
478 | -- empty. Otherwise it has the BEP10 extension protocol mandated mapping of | 475 | , connState :: !(IORef ConnectionState) |
479 | -- 'MessageId' to the message type for the remote peer. | ||
480 | , connExtCaps :: !(IORef ExtendedCaps) | ||
481 | |||
482 | -- | Current extended handshake information from the remote peer | ||
483 | , connRemoteEhs :: !(IORef ExtendedHandshake) | ||
484 | |||
485 | -- | Various stats about messages sent and received. Stats can be | ||
486 | -- used to protect /this/ peer against flood attacks. | ||
487 | , connStats :: !(IORef ConnectionStats) | ||
488 | 476 | ||
489 | -- -- | Max request queue length. | 477 | -- -- | Max request queue length. |
490 | -- , connMaxQueueLen :: !Int | 478 | -- , connMaxQueueLen :: !Int |
@@ -506,7 +494,6 @@ isAllowed Connection {..} msg | |||
506 | sendHandshake :: Socket -> Handshake -> IO () | 494 | sendHandshake :: Socket -> Handshake -> IO () |
507 | sendHandshake sock hs = sendAll sock (S.encode hs) | 495 | sendHandshake sock hs = sendAll sock (S.encode hs) |
508 | 496 | ||
509 | -- TODO drop connection if protocol string do not match | ||
510 | recvHandshake :: Socket -> IO Handshake | 497 | recvHandshake :: Socket -> IO Handshake |
511 | recvHandshake sock = do | 498 | recvHandshake sock = do |
512 | header <- BS.recv sock 1 | 499 | header <- BS.recv sock 1 |
@@ -543,7 +530,15 @@ connectToPeer p = do | |||
543 | -----------------------------------------------------------------------} | 530 | -----------------------------------------------------------------------} |
544 | 531 | ||
545 | -- | do not expose this so we can change it without breaking api | 532 | -- | do not expose this so we can change it without breaking api |
546 | type Connected = ReaderT Connection | 533 | newtype Connected m a = Connected { runConnected :: (ReaderT Connection m a) } |
534 | deriving (Functor, Applicative, Monad, MonadIO, MonadReader Connection, MonadThrow) | ||
535 | |||
536 | instance (MonadIO m) => MonadState ConnectionState (Connected m) where | ||
537 | get = Connected (asks connState) >>= liftIO . readIORef | ||
538 | put x = Connected (asks connState) >>= liftIO . flip writeIORef x | ||
539 | |||
540 | instance MonadTrans Connected where | ||
541 | lift = Connected . lift | ||
547 | 542 | ||
548 | -- | A duplex channel connected to a remote peer which keep tracks | 543 | -- | A duplex channel connected to a remote peer which keep tracks |
549 | -- connection parameters. | 544 | -- connection parameters. |
@@ -553,40 +548,25 @@ type Wire a = ConduitM Message Message (Connected IO) a | |||
553 | -- Query | 548 | -- Query |
554 | -----------------------------------------------------------------------} | 549 | -----------------------------------------------------------------------} |
555 | 550 | ||
556 | readRef :: (Connection -> IORef a) -> Connected IO a | ||
557 | readRef f = do | ||
558 | ref <- asks f | ||
559 | liftIO (readIORef ref) | ||
560 | |||
561 | writeRef :: (Connection -> IORef a) -> a -> Connected IO () | ||
562 | writeRef f v = do | ||
563 | ref <- asks f | ||
564 | liftIO (writeIORef ref v) | ||
565 | |||
566 | modifyRef :: (Connection -> IORef a) -> (a -> a) -> Connected IO () | ||
567 | modifyRef f m = do | ||
568 | ref <- asks f | ||
569 | liftIO (atomicModifyIORef' ref (\x -> (m x, ()))) | ||
570 | |||
571 | setExtCaps :: ExtendedCaps -> Wire () | 551 | setExtCaps :: ExtendedCaps -> Wire () |
572 | setExtCaps = lift . writeRef connExtCaps | 552 | setExtCaps x = lift $ connExtCaps .= x |
573 | 553 | ||
574 | -- | Get current extended capabilities. Note that this value can | 554 | -- | Get current extended capabilities. Note that this value can |
575 | -- change in current session if either this or remote peer will | 555 | -- change in current session if either this or remote peer will |
576 | -- initiate rehandshaking. | 556 | -- initiate rehandshaking. |
577 | getExtCaps :: Wire ExtendedCaps | 557 | getExtCaps :: Wire ExtendedCaps |
578 | getExtCaps = lift $ readRef connExtCaps | 558 | getExtCaps = lift $ use connExtCaps |
579 | 559 | ||
580 | setRemoteEhs :: ExtendedHandshake -> Wire () | 560 | setRemoteEhs :: ExtendedHandshake -> Wire () |
581 | setRemoteEhs = lift . writeRef connRemoteEhs | 561 | setRemoteEhs x = lift $ connRemoteEhs .= x |
582 | 562 | ||
583 | getRemoteEhs :: Wire ExtendedHandshake | 563 | getRemoteEhs :: Wire ExtendedHandshake |
584 | getRemoteEhs = lift $ readRef connRemoteEhs | 564 | getRemoteEhs = lift $ use connRemoteEhs |
585 | 565 | ||
586 | -- | Get current stats. Note that this value will change with the next | 566 | -- | Get current stats. Note that this value will change with the next |
587 | -- sent or received message. | 567 | -- sent or received message. |
588 | getStats :: Wire ConnectionStats | 568 | getStats :: Wire ConnectionStats |
589 | getStats = lift $ readRef connStats | 569 | getStats = lift $ use connStats |
590 | 570 | ||
591 | -- | See the 'Connection' section for more info. | 571 | -- | See the 'Connection' section for more info. |
592 | getConnection :: Wire Connection | 572 | getConnection :: Wire Connection |
@@ -597,7 +577,7 @@ getConnection = lift ask | |||
597 | -----------------------------------------------------------------------} | 577 | -----------------------------------------------------------------------} |
598 | 578 | ||
599 | putStats :: ChannelSide -> Message -> Connected IO () | 579 | putStats :: ChannelSide -> Message -> Connected IO () |
600 | putStats side msg = modifyRef connStats (addStats side (stats msg)) | 580 | putStats side msg = connStats %= addStats side (stats msg) |
601 | 581 | ||
602 | validate :: ChannelSide -> Message -> Connected IO () | 582 | validate :: ChannelSide -> Message -> Connected IO () |
603 | validate side msg = do | 583 | validate side msg = do |
@@ -619,13 +599,13 @@ trackFlow side = iterM $ do | |||
619 | 599 | ||
620 | -- | Normally you should use 'connectWire' or 'acceptWire'. | 600 | -- | Normally you should use 'connectWire' or 'acceptWire'. |
621 | runWire :: Wire () -> Socket -> Connection -> IO () | 601 | runWire :: Wire () -> Socket -> Connection -> IO () |
622 | runWire action sock = runReaderT $ | 602 | runWire action sock conn = flip runReaderT conn $ runConnected $ |
623 | sourceSocket sock $= | 603 | sourceSocket sock $= |
624 | conduitGet get $= | 604 | conduitGet S.get $= |
625 | trackFlow RemotePeer $= | 605 | trackFlow RemotePeer $= |
626 | action $= | 606 | action $= |
627 | trackFlow ThisPeer $= | 607 | trackFlow ThisPeer $= |
628 | conduitPut put $$ | 608 | conduitPut S.put $$ |
629 | sinkSocket sock | 609 | sinkSocket sock |
630 | 610 | ||
631 | -- | This function will block until a peer send new message. You can | 611 | -- | This function will block until a peer send new message. You can |
@@ -636,9 +616,13 @@ recvMessage = await >>= maybe (monadThrow PeerDisconnected) return | |||
636 | -- | You can also use 'yield'. | 616 | -- | You can also use 'yield'. |
637 | sendMessage :: PeerMessage msg => msg -> Wire () | 617 | sendMessage :: PeerMessage msg => msg -> Wire () |
638 | sendMessage msg = do | 618 | sendMessage msg = do |
639 | ecaps <- getExtCaps | 619 | ecaps <- use connExtCaps |
640 | yield $ envelop ecaps msg | 620 | yield $ envelop ecaps msg |
641 | 621 | ||
622 | -- | Forcefully terminate wire session and close socket. | ||
623 | disconnectPeer :: Wire a | ||
624 | disconnectPeer = monadThrow DisconnectPeer | ||
625 | |||
642 | extendedHandshake :: ExtendedCaps -> Wire () | 626 | extendedHandshake :: ExtendedCaps -> Wire () |
643 | extendedHandshake caps = do | 627 | extendedHandshake caps = do |
644 | -- TODO add other params to the handshake | 628 | -- TODO add other params to the handshake |
@@ -683,11 +667,14 @@ connectWire hs addr extCaps wire = | |||
683 | then extendedHandshake extCaps >> wire | 667 | then extendedHandshake extCaps >> wire |
684 | else wire | 668 | else wire |
685 | 669 | ||
686 | extCapsRef <- newIORef def | 670 | cstate <- newIORef $ ConnectionState { |
687 | remoteEhs <- newIORef def | 671 | _connExtCaps = def |
688 | statsRef <- newIORef ConnectionStats | 672 | , _connRemoteEhs = def |
689 | { outcomingFlow = FlowStats 1 $ handshakeStats hs | 673 | , _connStats = ConnectionStats { |
690 | , incomingFlow = FlowStats 1 $ handshakeStats hs' | 674 | outcomingFlow = FlowStats 1 $ handshakeStats hs |
675 | , incomingFlow = FlowStats 1 $ handshakeStats hs' | ||
676 | } | ||
677 | , _connMetadata = Nothing | ||
691 | } | 678 | } |
692 | 679 | ||
693 | runWire wire' sock $ Connection | 680 | runWire wire' sock $ Connection |
@@ -697,9 +684,7 @@ connectWire hs addr extCaps wire = | |||
697 | , connRemotePeerId = hsPeerId hs' | 684 | , connRemotePeerId = hsPeerId hs' |
698 | , connThisPeerId = hsPeerId hs | 685 | , connThisPeerId = hsPeerId hs |
699 | , connOptions = def | 686 | , connOptions = def |
700 | , connExtCaps = extCapsRef | 687 | , connState = cstate |
701 | , connRemoteEhs = remoteEhs | ||
702 | , connStats = statsRef | ||
703 | } | 688 | } |
704 | 689 | ||
705 | -- | Accept 'Wire' connection using already 'Network.Socket.accept'ed | 690 | -- | Accept 'Wire' connection using already 'Network.Socket.accept'ed |