From 2367c9c0ab45a843cb4d1b8762af4ca06291348c Mon Sep 17 00:00:00 2001 From: Sam Truzjan Date: Thu, 5 Dec 2013 20:00:01 +0400 Subject: Make extended caps mutable --- src/Network/BitTorrent/Exchange/Wire.hs | 73 +++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 17 deletions(-) (limited to 'src/Network/BitTorrent/Exchange/Wire.hs') diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 68c9b355..6a161762 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs @@ -4,20 +4,29 @@ -- Duplex channell -- This module control /integrity/ of data send and received. -- - +-- {-# LANGUAGE DeriveDataTypeable #-} module Network.BitTorrent.Exchange.Wire - ( -- * Exception - ProtocolError (..) + ( -- * Wire + Wire + + -- ** Exceptions + , ProtocolError (..) , WireFailure (..) , isWireFailure + , disconnectPeer - -- * Wire + -- ** Connection , Connection (..) - , Wire + + -- ** Setup , runWire , connectWire , acceptWire + + -- ** Query + , getConnection + , getExtCaps ) where import Control.Exception @@ -27,6 +36,7 @@ import Data.Conduit import Data.Conduit.Cereal as S import Data.Conduit.Network import Data.Default +import Data.IORef import Data.Maybe import Data.Monoid import Data.Serialize as S @@ -56,6 +66,7 @@ data ProtocolError = UnexpectedTopic InfoHash -- ^ peer replied with unexpected infohash. | UnexpectedPeerId PeerId -- ^ peer replied with unexpected peer id. | UnknownTopic InfoHash -- ^ peer requested unknown torrent. + | HandshakeRefused -- ^ peer do not send an extended handshake back. | InvalidMessage { violentSender :: ChannelSide -- ^ endpoint sent invalid message , extensionRequired :: Extension -- ^ @@ -87,11 +98,11 @@ isWireFailure _ = return () data Connection = Connection { connCaps :: !Caps - , connExtCaps :: !ExtendedCaps -- TODO caps can be enabled during communication + , connExtCaps :: !(IORef ExtendedCaps) , connTopic :: !InfoHash , connRemotePeerId :: !PeerId , connThisPeerId :: !PeerId - } deriving Show + } instance Pretty Connection where pretty Connection {..} = "Connection" @@ -146,6 +157,25 @@ connectToPeer p = do type Wire = ConduitM Message Message (ReaderT Connection IO) +protocolError :: ProtocolError -> Wire a +protocolError = monadThrow . ProtocolError + +disconnectPeer :: Wire a +disconnectPeer = monadThrow DisconnectPeer + +getExtCaps :: Wire ExtendedCaps +getExtCaps = do + capsRef <- lift $ asks connExtCaps + liftIO $ readIORef capsRef + +setExtCaps :: ExtendedCaps -> Wire () +setExtCaps caps = do + capsRef <- lift $ asks connExtCaps + liftIO $ writeIORef capsRef caps + +getConnection :: Wire Connection +getConnection = lift ask + validate :: ChannelSide -> Wire () validate side = await >>= maybe (return ()) yieldCheck where @@ -155,9 +185,10 @@ validate side = await >>= maybe (return ()) yieldCheck Nothing -> return () Just ext | allowed caps ext -> yield msg - | otherwise -> monadThrow $ ProtocolError $ InvalidMessage side ext + | otherwise -> protocolError $ InvalidMessage side ext -validate' action = do +validateBoth :: Wire () -> Wire () +validateBoth action = do validate RemotePeer action validate ThisPeer @@ -172,17 +203,23 @@ runWire action sock = runReaderT $ sendMessage :: PeerMessage msg => msg -> Wire () sendMessage msg = do - ecaps <- lift $ asks connExtCaps + ecaps <- getExtCaps yield $ envelop ecaps msg recvMessage :: Wire Message recvMessage = undefined -extendedHandshake :: Wire () -extendedHandshake = undefined +extendedHandshake :: ExtendedCaps -> Wire () +extendedHandshake caps = do + sendMessage $ nullExtendedHandshake caps + msg <- recvMessage + case msg of + Extended (EHandshake ExtendedHandshake {..}) -> + setExtCaps $ ehsCaps <> caps + _ -> protocolError HandshakeRefused connectWire :: Handshake -> PeerAddr -> ExtendedCaps -> Wire () -> IO () -connectWire hs addr caps wire = +connectWire hs addr extCaps wire = bracket (connectToPeer addr) close $ \ sock -> do hs' <- initiateHandshake sock hs @@ -193,12 +230,14 @@ connectWire hs addr caps wire = throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') let caps = hsReserved hs <> hsReserved hs' - if allowed caps ExtExtended - then return () else return () + let wire' = if allowed caps ExtExtended + then extendedHandshake extCaps >> wire + else wire - runWire wire sock $ Connection + extCapsRef <- newIORef def + runWire wire' sock $ Connection { connCaps = caps - , connExtCaps = def + , connExtCaps = extCapsRef , connTopic = hsInfoHash hs , connRemotePeerId = hsPeerId hs' , connThisPeerId = hsPeerId hs -- cgit v1.2.3