summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Gröber <dxld@darkboxed.org>2014-01-04 04:16:42 +0000
committerSam Truzjan <pxqr.sta@gmail.com>2014-01-04 21:43:14 +0400
commitac421efc3db225d3d965580286552541f51dbb68 (patch)
tree82ebde98894d5373d3dbb60d9dcd44a2e5bcfb9a
parent9000a995bddfd85a2e2a25e23eb23ebc53489a1d (diff)
Move mutable state in Connection to single field and make a MonadState instance for (Connected IO a)
..also add lenses
-rw-r--r--src/Network/BitTorrent/Exchange/Wire.hs105
1 files changed, 45 insertions, 60 deletions
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs
index 0897f482..a3b60b99 100644
--- a/src/Network/BitTorrent/Exchange/Wire.hs
+++ b/src/Network/BitTorrent/Exchange/Wire.hs
@@ -10,6 +10,9 @@
10-- This module control /integrity/ of data send and received. 10-- This module control /integrity/ of data send and received.
11-- 11--
12{-# LANGUAGE DeriveDataTypeable #-} 12{-# LANGUAGE DeriveDataTypeable #-}
13{-# LANGUAGE TemplateHaskell #-}
14{-# LANGUAGE MultiParamTypeClasses #-}
15{-# LANGUAGE GeneralizedNewtypeDeriving #-}
13module Network.BitTorrent.Exchange.Wire 16module Network.BitTorrent.Exchange.Wire
14 ( -- * Wire 17 ( -- * Wire
15 Wire 18 Wire
@@ -61,7 +64,10 @@ module Network.BitTorrent.Exchange.Wire
61import Control.Applicative 64import Control.Applicative
62import Control.Exception 65import Control.Exception
63import Control.Monad.Reader 66import Control.Monad.Reader
67import Control.Monad.State
68import Control.Lens
64import Data.ByteString as BS 69import Data.ByteString as BS
70import Data.ByteString.Lazy as BSL
65import Data.Conduit 71import Data.Conduit
66import Data.Conduit.Cereal 72import Data.Conduit.Cereal
67import Data.Conduit.List 73import Data.Conduit.List
@@ -73,7 +79,7 @@ import Data.Monoid
73import Data.Serialize as S 79import Data.Serialize as S
74import Data.Typeable 80import Data.Typeable
75import Network 81import Network
76import Network.Socket 82import Network.Socket hiding (Connected)
77import Network.Socket.ByteString as BS 83import Network.Socket.ByteString as BS
78import Text.PrettyPrint as PP hiding (($$), (<>)) 84import Text.PrettyPrint as PP hiding (($$), (<>))
79import Text.PrettyPrint.Class 85import Text.PrettyPrint.Class
@@ -218,10 +224,6 @@ isWireFailure _ = return ()
218protocolError :: MonadThrow m => ProtocolError -> m a 224protocolError :: MonadThrow m => ProtocolError -> m a
219protocolError = monadThrow . ProtocolError 225protocolError = monadThrow . ProtocolError
220 226
221-- | Forcefully terminate wire session and close socket.
222disconnectPeer :: Wire a
223disconnectPeer = monadThrow DisconnectPeer
224
225{----------------------------------------------------------------------- 227{-----------------------------------------------------------------------
226-- Stats 228-- Stats
227-----------------------------------------------------------------------} 229-----------------------------------------------------------------------}
@@ -251,13 +253,6 @@ instance Monoid FlowStats where
251 , messageCount = messageCount a + messageCount b 253 , messageCount = messageCount a + messageCount b
252 } 254 }
253 255
254-- | Aggregate one more message stats in this direction.
255addFlowStats :: ByteStats -> FlowStats -> FlowStats
256addFlowStats x FlowStats {..} = FlowStats
257 { messageBytes = messageBytes <> x
258 , messageCount = succ messageCount
259 }
260
261-- | Find average length of byte sequences per message. 256-- | Find average length of byte sequences per message.
262avgByteStats :: FlowStats -> ByteStats 257avgByteStats :: FlowStats -> ByteStats
263avgByteStats (FlowStats n ByteStats {..}) = ByteStats 258avgByteStats (FlowStats n ByteStats {..}) = ByteStats
@@ -302,8 +297,8 @@ instance Monoid ConnectionStats where
302 297
303-- | Aggregate one more message stats in the /specified/ direction. 298-- | Aggregate one more message stats in the /specified/ direction.
304addStats :: ChannelSide -> ByteStats -> ConnectionStats -> ConnectionStats 299addStats :: ChannelSide -> ByteStats -> ConnectionStats -> ConnectionStats
305addStats ThisPeer x s = s { outcomingFlow = addFlowStats x (outcomingFlow s) } 300addStats ThisPeer x s = s { outcomingFlow = (FlowStats 1 x) <> (outcomingFlow s) }
306addStats RemotePeer x s = s { incomingFlow = addFlowStats x (incomingFlow s) } 301addStats RemotePeer x s = s { incomingFlow = (FlowStats 1 x) <> (incomingFlow s) }
307 302
308-- | Sum of overhead and control bytes in both directions. 303-- | Sum of overhead and control bytes in both directions.
309wastedBytes :: ConnectionStats -> Int 304wastedBytes :: ConnectionStats -> Int
@@ -448,6 +443,8 @@ data ConnectionState = ConnectionState {
448 , _connMetadata :: Maybe (Cached InfoDict) 443 , _connMetadata :: Maybe (Cached InfoDict)
449 } 444 }
450 445
446makeLenses ''ConnectionState
447
451-- | Connection keep various info about both peers. 448-- | Connection keep various info about both peers.
452data Connection = Connection 449data Connection = Connection
453 { -- | /Both/ peers handshaked with this protocol string. The only 450 { -- | /Both/ peers handshaked with this protocol string. The only
@@ -474,17 +471,8 @@ data Connection = Connection
474 -- | 471 -- |
475 , connOptions :: !Options 472 , connOptions :: !Options
476 473
477 -- | If @not (allowed ExtExtended connCaps)@ then this set is always 474 -- | Mutable connection state, see 'ConnectionState'
478 -- empty. Otherwise it has the BEP10 extension protocol mandated mapping of 475 , connState :: !(IORef ConnectionState)
479 -- 'MessageId' to the message type for the remote peer.
480 , connExtCaps :: !(IORef ExtendedCaps)
481
482 -- | Current extended handshake information from the remote peer
483 , connRemoteEhs :: !(IORef ExtendedHandshake)
484
485 -- | Various stats about messages sent and received. Stats can be
486 -- used to protect /this/ peer against flood attacks.
487 , connStats :: !(IORef ConnectionStats)
488 476
489-- -- | Max request queue length. 477-- -- | Max request queue length.
490-- , connMaxQueueLen :: !Int 478-- , connMaxQueueLen :: !Int
@@ -506,7 +494,6 @@ isAllowed Connection {..} msg
506sendHandshake :: Socket -> Handshake -> IO () 494sendHandshake :: Socket -> Handshake -> IO ()
507sendHandshake sock hs = sendAll sock (S.encode hs) 495sendHandshake sock hs = sendAll sock (S.encode hs)
508 496
509-- TODO drop connection if protocol string do not match
510recvHandshake :: Socket -> IO Handshake 497recvHandshake :: Socket -> IO Handshake
511recvHandshake sock = do 498recvHandshake sock = do
512 header <- BS.recv sock 1 499 header <- BS.recv sock 1
@@ -543,7 +530,15 @@ connectToPeer p = do
543-----------------------------------------------------------------------} 530-----------------------------------------------------------------------}
544 531
545-- | do not expose this so we can change it without breaking api 532-- | do not expose this so we can change it without breaking api
546type Connected = ReaderT Connection 533newtype Connected m a = Connected { runConnected :: (ReaderT Connection m a) }
534 deriving (Functor, Applicative, Monad, MonadIO, MonadReader Connection, MonadThrow)
535
536instance (MonadIO m) => MonadState ConnectionState (Connected m) where
537 get = Connected (asks connState) >>= liftIO . readIORef
538 put x = Connected (asks connState) >>= liftIO . flip writeIORef x
539
540instance MonadTrans Connected where
541 lift = Connected . lift
547 542
548-- | A duplex channel connected to a remote peer which keep tracks 543-- | A duplex channel connected to a remote peer which keep tracks
549-- connection parameters. 544-- connection parameters.
@@ -553,40 +548,25 @@ type Wire a = ConduitM Message Message (Connected IO) a
553-- Query 548-- Query
554-----------------------------------------------------------------------} 549-----------------------------------------------------------------------}
555 550
556readRef :: (Connection -> IORef a) -> Connected IO a
557readRef f = do
558 ref <- asks f
559 liftIO (readIORef ref)
560
561writeRef :: (Connection -> IORef a) -> a -> Connected IO ()
562writeRef f v = do
563 ref <- asks f
564 liftIO (writeIORef ref v)
565
566modifyRef :: (Connection -> IORef a) -> (a -> a) -> Connected IO ()
567modifyRef f m = do
568 ref <- asks f
569 liftIO (atomicModifyIORef' ref (\x -> (m x, ())))
570
571setExtCaps :: ExtendedCaps -> Wire () 551setExtCaps :: ExtendedCaps -> Wire ()
572setExtCaps = lift . writeRef connExtCaps 552setExtCaps x = lift $ connExtCaps .= x
573 553
574-- | Get current extended capabilities. Note that this value can 554-- | Get current extended capabilities. Note that this value can
575-- change in current session if either this or remote peer will 555-- change in current session if either this or remote peer will
576-- initiate rehandshaking. 556-- initiate rehandshaking.
577getExtCaps :: Wire ExtendedCaps 557getExtCaps :: Wire ExtendedCaps
578getExtCaps = lift $ readRef connExtCaps 558getExtCaps = lift $ use connExtCaps
579 559
580setRemoteEhs :: ExtendedHandshake -> Wire () 560setRemoteEhs :: ExtendedHandshake -> Wire ()
581setRemoteEhs = lift . writeRef connRemoteEhs 561setRemoteEhs x = lift $ connRemoteEhs .= x
582 562
583getRemoteEhs :: Wire ExtendedHandshake 563getRemoteEhs :: Wire ExtendedHandshake
584getRemoteEhs = lift $ readRef connRemoteEhs 564getRemoteEhs = lift $ use connRemoteEhs
585 565
586-- | Get current stats. Note that this value will change with the next 566-- | Get current stats. Note that this value will change with the next
587-- sent or received message. 567-- sent or received message.
588getStats :: Wire ConnectionStats 568getStats :: Wire ConnectionStats
589getStats = lift $ readRef connStats 569getStats = lift $ use connStats
590 570
591-- | See the 'Connection' section for more info. 571-- | See the 'Connection' section for more info.
592getConnection :: Wire Connection 572getConnection :: Wire Connection
@@ -597,7 +577,7 @@ getConnection = lift ask
597-----------------------------------------------------------------------} 577-----------------------------------------------------------------------}
598 578
599putStats :: ChannelSide -> Message -> Connected IO () 579putStats :: ChannelSide -> Message -> Connected IO ()
600putStats side msg = modifyRef connStats (addStats side (stats msg)) 580putStats side msg = connStats %= addStats side (stats msg)
601 581
602validate :: ChannelSide -> Message -> Connected IO () 582validate :: ChannelSide -> Message -> Connected IO ()
603validate side msg = do 583validate side msg = do
@@ -619,13 +599,13 @@ trackFlow side = iterM $ do
619 599
620-- | Normally you should use 'connectWire' or 'acceptWire'. 600-- | Normally you should use 'connectWire' or 'acceptWire'.
621runWire :: Wire () -> Socket -> Connection -> IO () 601runWire :: Wire () -> Socket -> Connection -> IO ()
622runWire action sock = runReaderT $ 602runWire action sock conn = flip runReaderT conn $ runConnected $
623 sourceSocket sock $= 603 sourceSocket sock $=
624 conduitGet get $= 604 conduitGet S.get $=
625 trackFlow RemotePeer $= 605 trackFlow RemotePeer $=
626 action $= 606 action $=
627 trackFlow ThisPeer $= 607 trackFlow ThisPeer $=
628 conduitPut put $$ 608 conduitPut S.put $$
629 sinkSocket sock 609 sinkSocket sock
630 610
631-- | This function will block until a peer send new message. You can 611-- | This function will block until a peer send new message. You can
@@ -636,9 +616,13 @@ recvMessage = await >>= maybe (monadThrow PeerDisconnected) return
636-- | You can also use 'yield'. 616-- | You can also use 'yield'.
637sendMessage :: PeerMessage msg => msg -> Wire () 617sendMessage :: PeerMessage msg => msg -> Wire ()
638sendMessage msg = do 618sendMessage msg = do
639 ecaps <- getExtCaps 619 ecaps <- use connExtCaps
640 yield $ envelop ecaps msg 620 yield $ envelop ecaps msg
641 621
622-- | Forcefully terminate wire session and close socket.
623disconnectPeer :: Wire a
624disconnectPeer = monadThrow DisconnectPeer
625
642extendedHandshake :: ExtendedCaps -> Wire () 626extendedHandshake :: ExtendedCaps -> Wire ()
643extendedHandshake caps = do 627extendedHandshake caps = do
644 -- TODO add other params to the handshake 628 -- TODO add other params to the handshake
@@ -683,11 +667,14 @@ connectWire hs addr extCaps wire =
683 then extendedHandshake extCaps >> wire 667 then extendedHandshake extCaps >> wire
684 else wire 668 else wire
685 669
686 extCapsRef <- newIORef def 670 cstate <- newIORef $ ConnectionState {
687 remoteEhs <- newIORef def 671 _connExtCaps = def
688 statsRef <- newIORef ConnectionStats 672 , _connRemoteEhs = def
689 { outcomingFlow = FlowStats 1 $ handshakeStats hs 673 , _connStats = ConnectionStats {
690 , incomingFlow = FlowStats 1 $ handshakeStats hs' 674 outcomingFlow = FlowStats 1 $ handshakeStats hs
675 , incomingFlow = FlowStats 1 $ handshakeStats hs'
676 }
677 , _connMetadata = Nothing
691 } 678 }
692 679
693 runWire wire' sock $ Connection 680 runWire wire' sock $ Connection
@@ -697,9 +684,7 @@ connectWire hs addr extCaps wire =
697 , connRemotePeerId = hsPeerId hs' 684 , connRemotePeerId = hsPeerId hs'
698 , connThisPeerId = hsPeerId hs 685 , connThisPeerId = hsPeerId hs
699 , connOptions = def 686 , connOptions = def
700 , connExtCaps = extCapsRef 687 , connState = cstate
701 , connRemoteEhs = remoteEhs
702 , connStats = statsRef
703 } 688 }
704 689
705-- | Accept 'Wire' connection using already 'Network.Socket.accept'ed 690-- | Accept 'Wire' connection using already 'Network.Socket.accept'ed