diff options
-rw-r--r-- | src/Network/BitTorrent/Exchange/Message.hs | 67 | ||||
-rw-r--r-- | tests/Network/BitTorrent/Exchange/MessageSpec.hs | 17 |
2 files changed, 58 insertions, 26 deletions
diff --git a/src/Network/BitTorrent/Exchange/Message.hs b/src/Network/BitTorrent/Exchange/Message.hs index 4d1694c6..e0a7dad7 100644 --- a/src/Network/BitTorrent/Exchange/Message.hs +++ b/src/Network/BitTorrent/Exchange/Message.hs | |||
@@ -40,9 +40,9 @@ module Network.BitTorrent.Exchange.Message | |||
40 | , Caps | 40 | , Caps |
41 | 41 | ||
42 | -- * Handshake | 42 | -- * Handshake |
43 | , ProtocolString | ||
43 | , Handshake(..) | 44 | , Handshake(..) |
44 | , defaultHandshake | 45 | , defaultHandshake |
45 | , defaultBTProtocol | ||
46 | , handshakeSize | 46 | , handshakeSize |
47 | , handshakeMaxSize | 47 | , handshakeMaxSize |
48 | 48 | ||
@@ -203,12 +203,48 @@ instance Capabilities Caps where | |||
203 | Handshake | 203 | Handshake |
204 | -----------------------------------------------------------------------} | 204 | -----------------------------------------------------------------------} |
205 | 205 | ||
206 | maxProtocolStringSize :: Word8 | ||
207 | maxProtocolStringSize = maxBound | ||
208 | |||
209 | -- | The protocol name is used to identify to the local peer which | ||
210 | -- version of BTP the remote peer uses. | ||
211 | newtype ProtocolString = ProtocolString BS.ByteString | ||
212 | deriving (Eq, Ord, Typeable) | ||
213 | |||
214 | -- | In BTP/1.0 the name is 'BitTorrent protocol'. If this string is | ||
215 | -- different from the local peers own protocol name, then the | ||
216 | -- connection is to be dropped. | ||
217 | instance Default ProtocolString where | ||
218 | def = ProtocolString "BitTorrent protocol" | ||
219 | |||
220 | instance Show ProtocolString where | ||
221 | show (ProtocolString bs) = show bs | ||
222 | |||
223 | instance Pretty ProtocolString where | ||
224 | pretty (ProtocolString bs) = PP.text $ BC.unpack bs | ||
225 | |||
226 | instance IsString ProtocolString where | ||
227 | fromString str | ||
228 | | L.length str <= fromIntegral maxProtocolStringSize | ||
229 | = ProtocolString (fromString str) | ||
230 | | otherwise = error $ "fromString: ProtocolString too long: " ++ str | ||
231 | |||
232 | instance Serialize ProtocolString where | ||
233 | put (ProtocolString bs) = do | ||
234 | putWord8 $ fromIntegral $ BS.length bs | ||
235 | putByteString bs | ||
236 | |||
237 | get = do | ||
238 | len <- getWord8 | ||
239 | bs <- getByteString $ fromIntegral len | ||
240 | return (ProtocolString bs) | ||
241 | |||
206 | -- | Handshake message is used to exchange all information necessary | 242 | -- | Handshake message is used to exchange all information necessary |
207 | -- to establish connection between peers. | 243 | -- to establish connection between peers. |
208 | -- | 244 | -- |
209 | data Handshake = Handshake { | 245 | data Handshake = Handshake { |
210 | -- | Identifier of the protocol. This is usually equal to defaultProtocol | 246 | -- | Identifier of the protocol. This is usually equal to defaultProtocol |
211 | hsProtocol :: BS.ByteString | 247 | hsProtocol :: ProtocolString |
212 | 248 | ||
213 | -- | Reserved bytes used to specify supported BEP's. | 249 | -- | Reserved bytes used to specify supported BEP's. |
214 | , hsReserved :: Caps | 250 | , hsReserved :: Caps |
@@ -229,23 +265,16 @@ data Handshake = Handshake { | |||
229 | 265 | ||
230 | instance Serialize Handshake where | 266 | instance Serialize Handshake where |
231 | put Handshake {..} = do | 267 | put Handshake {..} = do |
232 | S.putWord8 (fromIntegral (BS.length hsProtocol)) | 268 | put hsProtocol |
233 | S.putByteString hsProtocol | 269 | put hsReserved |
234 | S.put hsReserved | 270 | put hsInfoHash |
235 | S.put hsInfoHash | 271 | put hsPeerId |
236 | S.put hsPeerId | 272 | get = Handshake <$> get <*> get <*> get <*> get |
237 | |||
238 | get = do | ||
239 | len <- S.getWord8 | ||
240 | Handshake <$> S.getBytes (fromIntegral len) | ||
241 | <*> S.get | ||
242 | <*> S.get | ||
243 | <*> S.get | ||
244 | 273 | ||
245 | -- | Show handshake protocol string, caps and fingerprint. | 274 | -- | Show handshake protocol string, caps and fingerprint. |
246 | instance Pretty Handshake where | 275 | instance Pretty Handshake where |
247 | pretty Handshake {..} | 276 | pretty Handshake {..} |
248 | = text (BC.unpack hsProtocol) $$ | 277 | = pretty hsProtocol $$ |
249 | pretty hsReserved $$ | 278 | pretty hsReserved $$ |
250 | pretty (fingerprint hsPeerId) | 279 | pretty (fingerprint hsPeerId) |
251 | 280 | ||
@@ -256,15 +285,11 @@ handshakeSize n = 1 + fromIntegral n + 8 + 20 + 20 | |||
256 | 285 | ||
257 | -- | Maximum size of handshake message in bytes. | 286 | -- | Maximum size of handshake message in bytes. |
258 | handshakeMaxSize :: Int | 287 | handshakeMaxSize :: Int |
259 | handshakeMaxSize = handshakeSize maxBound | 288 | handshakeMaxSize = handshakeSize maxProtocolStringSize |
260 | |||
261 | -- | Default protocol string "BitTorrent protocol" as is. | ||
262 | defaultBTProtocol :: BS.ByteString | ||
263 | defaultBTProtocol = "BitTorrent protocol" | ||
264 | 289 | ||
265 | -- | Handshake with default protocol string and reserved bitmask. | 290 | -- | Handshake with default protocol string and reserved bitmask. |
266 | defaultHandshake :: InfoHash -> PeerId -> Handshake | 291 | defaultHandshake :: InfoHash -> PeerId -> Handshake |
267 | defaultHandshake = Handshake defaultBTProtocol def | 292 | defaultHandshake = Handshake def def |
268 | 293 | ||
269 | {----------------------------------------------------------------------- | 294 | {----------------------------------------------------------------------- |
270 | -- Regular messages | 295 | -- Regular messages |
diff --git a/tests/Network/BitTorrent/Exchange/MessageSpec.hs b/tests/Network/BitTorrent/Exchange/MessageSpec.hs index 8d1041dd..38a20112 100644 --- a/tests/Network/BitTorrent/Exchange/MessageSpec.hs +++ b/tests/Network/BitTorrent/Exchange/MessageSpec.hs | |||
@@ -1,10 +1,12 @@ | |||
1 | module Network.BitTorrent.Exchange.MessageSpec (spec) where | 1 | module Network.BitTorrent.Exchange.MessageSpec (spec) where |
2 | import Control.Applicative | 2 | import Control.Applicative |
3 | import Control.Exception | ||
3 | import Data.ByteString as BS | 4 | import Data.ByteString as BS |
4 | import Data.Default | 5 | import Data.Default |
5 | import Data.List as L | 6 | import Data.List as L |
6 | import Data.Set as S | 7 | import Data.Set as S |
7 | import Data.Serialize as S | 8 | import Data.Serialize as S |
9 | import Data.String | ||
8 | import Test.Hspec | 10 | import Test.Hspec |
9 | import Test.QuickCheck | 11 | import Test.QuickCheck |
10 | 12 | ||
@@ -19,6 +21,9 @@ instance Arbitrary Extension where | |||
19 | instance Arbitrary Caps where | 21 | instance Arbitrary Caps where |
20 | arbitrary = toCaps <$> arbitrary | 22 | arbitrary = toCaps <$> arbitrary |
21 | 23 | ||
24 | instance Arbitrary ProtocolString where | ||
25 | arbitrary = fromString <$> (arbitrary `suchThat` ((200 <) . L.length)) | ||
26 | |||
22 | instance Arbitrary Handshake where | 27 | instance Arbitrary Handshake where |
23 | arbitrary = Handshake <$> arbitrary <*> arbitrary | 28 | arbitrary = Handshake <$> arbitrary <*> arbitrary |
24 | <*> arbitrary <*> arbitrary | 29 | <*> arbitrary <*> arbitrary |
@@ -33,11 +38,13 @@ spec = do | |||
33 | S.fromList (fromCaps (toCaps (S.toList extSet) :: Caps)) | 38 | S.fromList (fromCaps (toCaps (S.toList extSet) :: Caps)) |
34 | `shouldBe` extSet | 39 | `shouldBe` extSet |
35 | 40 | ||
41 | describe "ProtocolString" $ do | ||
42 | it "fail to construct invalid string" $ do | ||
43 | let str = L.replicate 500 'x' | ||
44 | evaluate (fromString str :: ProtocolString) | ||
45 | `shouldThrow` | ||
46 | errorCall ("fromString: ProtocolString too long: " ++ str) | ||
47 | |||
36 | describe "Handshake" $ do | 48 | describe "Handshake" $ do |
37 | it "properly serialized" $ property $ \ hs -> | 49 | it "properly serialized" $ property $ \ hs -> |
38 | S.decode (S.encode hs ) `shouldBe` Right (hs :: Handshake) | 50 | S.decode (S.encode hs ) `shouldBe` Right (hs :: Handshake) |
39 | |||
40 | it "fail if protocol string is too long" $ do | ||
41 | pid <- genPeerId | ||
42 | let hs = (defaultHandshake def pid) {hsProtocol = BS.replicate 256 0} | ||
43 | S.decode (S.encode hs) `shouldBe` Right hs \ No newline at end of file | ||