From b3a7be20b973974317b7974ee9799a403e3cf8b4 Mon Sep 17 00:00:00 2001 From: Sam Truzjan Date: Tue, 11 Feb 2014 15:28:53 +0400 Subject: Add reference to session from connection --- examples/MkTorrent.hs | 2 +- src/Network/BitTorrent/Exchange/Session.hs | 13 +++--- src/Network/BitTorrent/Exchange/Wire.hs | 75 ++++++++++++++++-------------- 3 files changed, 48 insertions(+), 42 deletions(-) diff --git a/examples/MkTorrent.hs b/examples/MkTorrent.hs index 93ac639b..e9eb7f1a 100644 --- a/examples/MkTorrent.hs +++ b/examples/MkTorrent.hs @@ -361,7 +361,7 @@ exchangeTorrent ih addr = do pid <- genPeerId var <- newEmptyMVar let hs = Handshake def (toCaps [ExtExtended]) ih pid - connectWire hs addr (toCaps [ExtMetadata]) $ do + connectWire () hs addr (toCaps [ExtMetadata]) $ do infodict <- getMetadata liftIO $ putMVar var infodict takeMVar var diff --git a/src/Network/BitTorrent/Exchange/Session.hs b/src/Network/BitTorrent/Exchange/Session.hs index 5bfc2a71..885dcb13 100644 --- a/src/Network/BitTorrent/Exchange/Session.hs +++ b/src/Network/BitTorrent/Exchange/Session.hs @@ -8,7 +8,7 @@ module Network.BitTorrent.Exchange.Session , Network.BitTorrent.Exchange.Session.insert ) where -import Control.Concurrent.STM +import Control.Concurrent import Control.Exception import Control.Lens import Control.Monad.Reader @@ -37,24 +37,23 @@ data ExchangeError | CorruptedPiece PieceIx data Session = Session - { peerId :: PeerId + { tpeerId :: PeerId , bitfield :: Bitfield , assembler :: Assembler , storage :: Storage , unchoked :: [PeerAddr IP] - , handler :: Exchange () - , connections :: Map (PeerAddr IP) Connection + , connections :: MVar (Map (PeerAddr IP) (Connection Session)) } + newSession :: PeerAddr IP -> Storage -> Bitfield -> IO Session newSession addr st bf = do return Session - { peerId = undefined + { tpeerId = undefined , bitfield = undefined , assembler = undefined , storage = undefined , unchoked = undefined - , handler = undefined , connections = undefined } @@ -76,7 +75,7 @@ deleteAll = undefined -- Event loop -----------------------------------------------------------------------} -type Exchange = StateT Session (ReaderT Connection IO) +type Exchange = StateT Session (ReaderT (Connection Session) IO) --runExchange :: Exchange () -> [PeerAddr] -> IO () --runExchange exchange peers = do diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 31da3f0c..4bd342ca 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs @@ -67,7 +67,7 @@ import Control.Monad.Reader import Control.Monad.State import Control.Lens import Data.ByteString as BS -import Data.ByteString.Lazy as BSL +import Data.ByteString.Lazy as BSL import Data.Conduit import Data.Conduit.Cereal import Data.Conduit.List @@ -85,12 +85,13 @@ import Text.PrettyPrint as PP hiding (($$), (<>)) import Text.PrettyPrint.Class import Text.Show.Functions +import Data.BEncode as BE +import Data.Torrent +import Data.Torrent.Bitfield import Data.Torrent.InfoHash +import Data.Torrent.Piece import Network.BitTorrent.Core import Network.BitTorrent.Exchange.Message -import Data.Torrent -import Data.Torrent.Piece -import Data.BEncode as BE -- TODO handle port message? -- TODO handle limits? @@ -446,7 +447,7 @@ data ConnectionState = ConnectionState { makeLenses ''ConnectionState -- | Connection keep various info about both peers. -data Connection = Connection +data Connection s = Connection { -- | /Both/ peers handshaked with this protocol string. The only -- value is \"Bittorrent Protocol\" but this can be changed in -- future. @@ -476,13 +477,16 @@ data Connection = Connection -- -- | Max request queue length. -- , connMaxQueueLen :: !Int + + -- | Environment data. + , connSession :: !s } -instance Pretty Connection where +instance Pretty (Connection s) where pretty Connection {..} = "Connection" -- TODO check extended messages too -isAllowed :: Connection -> Message -> Bool +isAllowed :: Connection s -> Message -> Bool isAllowed Connection {..} msg | Just ext <- requires msg = ext `allowed` connCaps | otherwise = True @@ -523,56 +527,58 @@ initiateHandshake sock hs = do -----------------------------------------------------------------------} -- | do not expose this so we can change it without breaking api -newtype Connected m a = Connected { runConnected :: (ReaderT Connection m a) } - deriving (Functor, Applicative, Monad, MonadIO, MonadReader Connection, MonadThrow) +newtype Connected s m a = Connected { runConnected :: (ReaderT (Connection s) m a) } + deriving (Functor, Applicative, Monad + , MonadIO, MonadReader (Connection s), MonadThrow + ) -instance (MonadIO m) => MonadState ConnectionState (Connected m) where - get = Connected (asks connState) >>= liftIO . readIORef +instance MonadIO m => MonadState ConnectionState (Connected s m) where + get = Connected (asks connState) >>= liftIO . readIORef put x = Connected (asks connState) >>= liftIO . flip writeIORef x -instance MonadTrans Connected where +instance MonadTrans (Connected s) where lift = Connected . lift -- | A duplex channel connected to a remote peer which keep tracks -- connection parameters. -type Wire a = ConduitM Message Message (Connected IO) a +type Wire s a = ConduitM Message Message (Connected s IO) a {----------------------------------------------------------------------- -- Query -----------------------------------------------------------------------} -setExtCaps :: ExtendedCaps -> Wire () +setExtCaps :: ExtendedCaps -> Wire s () setExtCaps x = lift $ connExtCaps .= x -- | Get current extended capabilities. Note that this value can -- change in current session if either this or remote peer will -- initiate rehandshaking. -getExtCaps :: Wire ExtendedCaps +getExtCaps :: Wire s ExtendedCaps getExtCaps = lift $ use connExtCaps -setRemoteEhs :: ExtendedHandshake -> Wire () +setRemoteEhs :: ExtendedHandshake -> Wire s () setRemoteEhs x = lift $ connRemoteEhs .= x -getRemoteEhs :: Wire ExtendedHandshake +getRemoteEhs :: Wire s ExtendedHandshake getRemoteEhs = lift $ use connRemoteEhs -- | Get current stats. Note that this value will change with the next -- sent or received message. -getStats :: Wire ConnectionStats +getStats :: Wire s ConnectionStats getStats = lift $ use connStats -- | See the 'Connection' section for more info. -getConnection :: Wire Connection +getConnection :: Wire s (Connection s) getConnection = lift ask {----------------------------------------------------------------------- -- Wrapper -----------------------------------------------------------------------} -putStats :: ChannelSide -> Message -> Connected IO () +putStats :: ChannelSide -> Message -> Connected s IO () putStats side msg = connStats %= addStats side (stats msg) -validate :: ChannelSide -> Message -> Connected IO () +validate :: ChannelSide -> Message -> Connected s IO () validate side msg = do caps <- asks connCaps case requires msg of @@ -581,7 +587,7 @@ validate side msg = do | ext `allowed` caps -> return () | otherwise -> protocolError $ DisallowedMessage side ext -trackFlow :: ChannelSide -> Wire () +trackFlow :: ChannelSide -> Wire s () trackFlow side = iterM $ do validate side putStats side @@ -591,7 +597,7 @@ trackFlow side = iterM $ do -----------------------------------------------------------------------} -- | Normally you should use 'connectWire' or 'acceptWire'. -runWire :: Wire () -> Socket -> Connection -> IO () +runWire :: Wire s () -> Socket -> Connection s -> IO () runWire action sock conn = flip runReaderT conn $ runConnected $ sourceSocket sock $= conduitGet S.get $= @@ -603,20 +609,20 @@ runWire action sock conn = flip runReaderT conn $ runConnected $ -- | This function will block until a peer send new message. You can -- also use 'await'. -recvMessage :: Wire Message +recvMessage :: Wire s Message recvMessage = await >>= maybe (monadThrow PeerDisconnected) return -- | You can also use 'yield'. -sendMessage :: PeerMessage msg => msg -> Wire () +sendMessage :: PeerMessage msg => msg -> Wire s () sendMessage msg = do ecaps <- use connExtCaps yield $ envelop ecaps msg -- | Forcefully terminate wire session and close socket. -disconnectPeer :: Wire a +disconnectPeer :: Wire s a disconnectPeer = monadThrow DisconnectPeer -extendedHandshake :: ExtendedCaps -> Wire () +extendedHandshake :: ExtendedCaps -> Wire s () extendedHandshake caps = do -- TODO add other params to the handshake sendMessage $ nullExtendedHandshake caps @@ -627,10 +633,10 @@ extendedHandshake caps = do setRemoteEhs remoteEhs _ -> protocolError HandshakeRefused -rehandshake :: ExtendedCaps -> Wire () +rehandshake :: ExtendedCaps -> Wire s () rehandshake caps = undefined -reconnect :: Wire () +reconnect :: Wire s () reconnect = undefined -- | Initiate 'Wire' connection and handshake with a peer. This function will @@ -639,8 +645,8 @@ reconnect = undefined -- -- This function can throw 'WireFailure' exception. -- -connectWire :: Handshake -> PeerAddr IP -> ExtendedCaps -> Wire () -> IO () -connectWire hs addr extCaps wire = +connectWire :: s -> Handshake -> PeerAddr IP -> ExtendedCaps -> Wire s () -> IO () +connectWire session hs addr extCaps wire = bracket (peerSocket Stream addr) close $ \ sock -> do hs' <- initiateHandshake sock hs @@ -678,6 +684,7 @@ connectWire hs addr extCaps wire = , connThisPeerId = hsPeerId hs , connOptions = def , connState = cstate + , connSession = session } -- | Accept 'Wire' connection using already 'Network.Socket.accept'ed @@ -686,7 +693,7 @@ connectWire hs addr extCaps wire = -- -- This function can throw 'WireFailure' exception. -- -acceptWire :: Socket -> PeerAddr IP -> Wire () -> IO () +acceptWire :: Socket -> PeerAddr IP -> Wire s () -> IO () acceptWire sock peerAddr wire = do bracket (return sock) close $ \ _ -> do error "acceptWire: not implemented" @@ -696,7 +703,7 @@ acceptWire sock peerAddr wire = do -----------------------------------------------------------------------} -- TODO introduce new metadata exchange specific exceptions -fetchMetadata :: Wire [BS.ByteString] +fetchMetadata :: Wire s [BS.ByteString] fetchMetadata = loop 0 where recvData = recvMessage >>= inspect @@ -721,7 +728,7 @@ fetchMetadata = loop 0 then pure [pieceData piece] else (pieceData piece :) <$> loop (succ i) -getMetadata :: Wire InfoDict +getMetadata :: Wire s InfoDict getMetadata = do chunks <- fetchMetadata Connection {..} <- getConnection -- cgit v1.2.3