From 71980c797f0fa242f544f6bf706999983b0bcf68 Mon Sep 17 00:00:00 2001 From: Sam T Date: Wed, 8 May 2013 10:52:23 +0400 Subject: ~ Fix handshake. --- src/Network/BitTorrent/PeerWire/Handshake.hs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) (limited to 'src/Network/BitTorrent') diff --git a/src/Network/BitTorrent/PeerWire/Handshake.hs b/src/Network/BitTorrent/PeerWire/Handshake.hs index 62d7d7f4..770ca3ce 100644 --- a/src/Network/BitTorrent/PeerWire/Handshake.hs +++ b/src/Network/BitTorrent/PeerWire/Handshake.hs @@ -20,6 +20,8 @@ module Network.BitTorrent.PeerWire.Handshake ) where import Control.Applicative +import Control.Monad +import Control.Exception import Data.Word import Data.ByteString (ByteString) import qualified Data.ByteString as B @@ -74,7 +76,6 @@ instance Serialize Handshake where handshakeCaps :: Handshake -> Capabilities handshakeCaps = hsReserved --- TODO add reserved bits info -- | Format handshake in human readable form. ppHandshake :: Handshake -> String ppHandshake hs = BC.unpack (hsProtocol hs) ++ " " @@ -100,23 +101,23 @@ defaultReserved = 0 defaultHandshake :: InfoHash -> PeerID -> Handshake defaultHandshake = Handshake defaultBTProtocol defaultReserved --- TODO exceptions instead of Either -- | Handshaking with a peer specified by the second argument. --- -handshake :: Socket -> Handshake -> IO (Either String Handshake) +handshake :: Socket -> Handshake -> IO Handshake handshake sock hs = do sendAll sock (S.encode hs) header <- recv sock 1 - if B.length header == 0 then - return $ Left "" - else do - let protocolLen = B.head header - let restLen = handshakeSize protocolLen - 1 - body <- recv sock restLen - let resp = B.cons protocolLen body - - return (checkIH (S.decode resp)) + when (B.length header == 0) $ + throw $ userError "Unable to receive handshake." + + let protocolLen = B.head header + let restLen = handshakeSize protocolLen - 1 + body <- recv sock restLen + let resp = B.cons protocolLen body + + case checkIH (S.decode resp) of + Right hs' -> return hs' + Left msg -> throw $ userError msg where checkIH (Right hs') | hsInfoHash hs /= hsInfoHash hs' = Left "Handshake info hash do not match." -- cgit v1.2.3