diff options
author | Sam Truzjan <pxqr.sta@gmail.com> | 2013-12-06 04:22:36 +0400 |
---|---|---|
committer | Sam Truzjan <pxqr.sta@gmail.com> | 2013-12-06 04:22:36 +0400 |
commit | 7d5cf5919c2bd11f835baf243bca341521c03879 (patch) | |
tree | 46e9294e478de6b84be870899a84004b0c82488c /src | |
parent | be414f0ef8d2bd5078177b7334045b3b7eedc482 (diff) |
Unify capabilities operations
Diffstat (limited to 'src')
-rw-r--r-- | src/Network/BitTorrent/Exchange/Message.hs | 82 | ||||
-rw-r--r-- | src/Network/BitTorrent/Exchange/Wire.hs | 8 |
2 files changed, 52 insertions, 38 deletions
diff --git a/src/Network/BitTorrent/Exchange/Message.hs b/src/Network/BitTorrent/Exchange/Message.hs index 33937a93..c614b1ae 100644 --- a/src/Network/BitTorrent/Exchange/Message.hs +++ b/src/Network/BitTorrent/Exchange/Message.hs | |||
@@ -27,17 +27,17 @@ | |||
27 | -- | 27 | -- |
28 | {-# LANGUAGE ViewPatterns #-} | 28 | {-# LANGUAGE ViewPatterns #-} |
29 | {-# LANGUAGE FlexibleInstances #-} | 29 | {-# LANGUAGE FlexibleInstances #-} |
30 | {-# LANGUAGE FlexibleContexts #-} | ||
31 | {-# LANGUAGE TypeFamilies #-} | ||
30 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | 32 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} |
31 | {-# LANGUAGE DeriveDataTypeable #-} | 33 | {-# LANGUAGE DeriveDataTypeable #-} |
32 | {-# LANGUAGE TemplateHaskell #-} | 34 | {-# LANGUAGE TemplateHaskell #-} |
33 | {-# OPTIONS -fno-warn-orphans #-} | 35 | {-# OPTIONS -fno-warn-orphans #-} |
34 | module Network.BitTorrent.Exchange.Message | 36 | module Network.BitTorrent.Exchange.Message |
35 | ( -- * Capabilities | 37 | ( -- * Capabilities |
36 | Extension (..) | 38 | Capabilities (..) |
39 | , Extension (..) | ||
37 | , Caps | 40 | , Caps |
38 | , toCaps | ||
39 | , fromCaps | ||
40 | , allowed | ||
41 | 41 | ||
42 | -- * Handshake | 42 | -- * Handshake |
43 | , Handshake(..) | 43 | , Handshake(..) |
@@ -65,9 +65,6 @@ module Network.BitTorrent.Exchange.Message | |||
65 | -- *** Capabilities | 65 | -- *** Capabilities |
66 | , ExtendedExtension (..) | 66 | , ExtendedExtension (..) |
67 | , ExtendedCaps (..) | 67 | , ExtendedCaps (..) |
68 | , toExtCaps | ||
69 | , fromExtCaps | ||
70 | , extendedAllowed | ||
71 | 68 | ||
72 | -- *** Handshake | 69 | -- *** Handshake |
73 | , ExtendedHandshake (..) | 70 | , ExtendedHandshake (..) |
@@ -109,6 +106,26 @@ import Network.BitTorrent.Core | |||
109 | import Network.BitTorrent.Exchange.Block | 106 | import Network.BitTorrent.Exchange.Block |
110 | 107 | ||
111 | {----------------------------------------------------------------------- | 108 | {----------------------------------------------------------------------- |
109 | -- Capabilities | ||
110 | -----------------------------------------------------------------------} | ||
111 | |||
112 | -- | | ||
113 | class Capabilities caps where | ||
114 | type Ext caps :: * | ||
115 | |||
116 | -- | Pack extensions to caps. | ||
117 | toCaps :: [Ext caps] -> caps | ||
118 | |||
119 | -- | Unpack extensions from caps. | ||
120 | fromCaps :: caps -> [Ext caps] | ||
121 | |||
122 | -- | Check if an extension is a member of the specified set. | ||
123 | allowed :: Ext caps -> caps -> Bool | ||
124 | |||
125 | ppCaps :: Capabilities caps => Pretty (Ext caps) => caps -> Doc | ||
126 | ppCaps = hcat . punctuate ", " . L.map pretty . fromCaps | ||
127 | |||
128 | {----------------------------------------------------------------------- | ||
112 | -- Extensions | 129 | -- Extensions |
113 | -----------------------------------------------------------------------} | 130 | -----------------------------------------------------------------------} |
114 | 131 | ||
@@ -129,10 +146,10 @@ instance Pretty Extension where | |||
129 | pretty ExtExtended = "Extension Protocol" | 146 | pretty ExtExtended = "Extension Protocol" |
130 | 147 | ||
131 | -- | Extension bitmask as specified by BEP 4. | 148 | -- | Extension bitmask as specified by BEP 4. |
132 | capMask :: Extension -> Caps | 149 | extMask :: Extension -> Word64 |
133 | capMask ExtDHT = Caps 0x01 | 150 | extMask ExtDHT = 0x01 |
134 | capMask ExtFast = Caps 0x04 | 151 | extMask ExtFast = 0x04 |
135 | capMask ExtExtended = Caps 0x100000 | 152 | extMask ExtExtended = 0x100000 |
136 | 153 | ||
137 | {----------------------------------------------------------------------- | 154 | {----------------------------------------------------------------------- |
138 | -- Capabilities | 155 | -- Capabilities |
@@ -140,12 +157,13 @@ capMask ExtExtended = Caps 0x100000 | |||
140 | 157 | ||
141 | -- | Capabilities is a set of 'Extension's usually sent in 'Handshake' | 158 | -- | Capabilities is a set of 'Extension's usually sent in 'Handshake' |
142 | -- messages. | 159 | -- messages. |
143 | newtype Caps = Caps { unCaps :: Word64 } | 160 | newtype Caps = Caps Word64 |
144 | deriving (Show, Eq) | 161 | deriving (Show, Eq) |
145 | 162 | ||
146 | -- | Render set of extensions as comma separated list. | 163 | -- | Render set of extensions as comma separated list. |
147 | instance Pretty Caps where | 164 | instance Pretty Caps where |
148 | pretty = hcat . punctuate ", " . L.map pretty . fromCaps | 165 | pretty = ppCaps |
166 | {-# INLINE pretty #-} | ||
149 | 167 | ||
150 | -- | The empty set. | 168 | -- | The empty set. |
151 | instance Default Caps where | 169 | instance Default Caps where |
@@ -168,19 +186,14 @@ instance Serialize Caps where | |||
168 | get = Caps <$> S.getWord64be | 186 | get = Caps <$> S.getWord64be |
169 | {-# INLINE get #-} | 187 | {-# INLINE get #-} |
170 | 188 | ||
171 | -- | Check if an extension is a member of the specified set. | 189 | instance Capabilities Caps where |
172 | allowed :: Caps -> Extension -> Bool | 190 | type Ext Caps = Extension |
173 | allowed (Caps caps) = testMask . capMask | ||
174 | where | ||
175 | testMask (Caps bits) = (bits .&. caps) == bits | ||
176 | 191 | ||
177 | -- | Pack extensions to caps. | 192 | allowed e (Caps caps) = (extMask e .&. caps) /= 0 |
178 | toCaps :: [Extension] -> Caps | 193 | {-# INLINE allowed #-} |
179 | toCaps = Caps . L.foldr (.|.) 0 . L.map (unCaps . capMask) | ||
180 | 194 | ||
181 | -- | Unpack extensions from caps. | 195 | toCaps = Caps . L.foldr (.|.) 0 . L.map extMask |
182 | fromCaps :: Caps -> [Extension] | 196 | fromCaps caps = L.filter (`allowed` caps) [minBound..maxBound] |
183 | fromCaps caps = L.filter (allowed caps) [minBound..maxBound] | ||
184 | 197 | ||
185 | {----------------------------------------------------------------------- | 198 | {----------------------------------------------------------------------- |
186 | Handshake | 199 | Handshake |
@@ -449,7 +462,8 @@ newtype ExtendedCaps = ExtendedCaps { extendedCaps :: ExtendedMap } | |||
449 | deriving (Show, Eq) | 462 | deriving (Show, Eq) |
450 | 463 | ||
451 | instance Pretty ExtendedCaps where | 464 | instance Pretty ExtendedCaps where |
452 | pretty = hcat . punctuate ", " . L.map pretty . fromExtCaps | 465 | pretty = ppCaps |
466 | {-# INLINE pretty #-} | ||
453 | 467 | ||
454 | -- | The empty set. | 468 | -- | The empty set. |
455 | instance Default ExtendedCaps where | 469 | instance Default ExtendedCaps where |
@@ -463,7 +477,7 @@ instance Default ExtendedCaps where | |||
463 | -- id from the first caps for the extensions existing in both caps. | 477 | -- id from the first caps for the extensions existing in both caps. |
464 | -- | 478 | -- |
465 | instance Monoid ExtendedCaps where | 479 | instance Monoid ExtendedCaps where |
466 | mempty = toExtCaps [minBound..maxBound] | 480 | mempty = toCaps [minBound..maxBound] |
467 | mappend (ExtendedCaps a) (ExtendedCaps b) = | 481 | mappend (ExtendedCaps a) (ExtendedCaps b) = |
468 | ExtendedCaps (M.intersection a b) | 482 | ExtendedCaps (M.intersection a b) |
469 | 483 | ||
@@ -482,16 +496,16 @@ instance BEncode ExtendedCaps where | |||
482 | fromBEncode (BDict bd) = pure $ ExtendedCaps $ appendBDict bd M.empty | 496 | fromBEncode (BDict bd) = pure $ ExtendedCaps $ appendBDict bd M.empty |
483 | fromBEncode _ = decodingError "ExtendedCaps" | 497 | fromBEncode _ = decodingError "ExtendedCaps" |
484 | 498 | ||
485 | toExtCaps :: [ExtendedExtension] -> ExtendedCaps | 499 | instance Capabilities ExtendedCaps where |
486 | toExtCaps = ExtendedCaps . M.fromList . L.map (id &&& extId) | 500 | type Ext ExtendedCaps = ExtendedExtension |
501 | |||
502 | toCaps = ExtendedCaps . M.fromList . L.map (id &&& extId) | ||
487 | 503 | ||
488 | fromExtCaps :: ExtendedCaps -> [ExtendedExtension] | 504 | fromCaps = M.keys . extendedCaps |
489 | fromExtCaps = M.keys . extendedCaps | 505 | {-# INLINE fromCaps #-} |
490 | {-# INLINE fromExtCaps #-} | ||
491 | 506 | ||
492 | extendedAllowed :: ExtendedExtension -> ExtendedCaps -> Bool | 507 | allowed e (ExtendedCaps caps) = M.member e caps |
493 | extendedAllowed e (ExtendedCaps caps) = M.member e caps | 508 | {-# INLINE allowed #-} |
494 | {-# INLINE extendedAllowed #-} | ||
495 | 509 | ||
496 | {----------------------------------------------------------------------- | 510 | {----------------------------------------------------------------------- |
497 | -- Extended handshake | 511 | -- Extended handshake |
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 8f6e1d58..6f80a567 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs | |||
@@ -114,7 +114,7 @@ instance Pretty Connection where | |||
114 | 114 | ||
115 | isAllowed :: Connection -> Message -> Bool | 115 | isAllowed :: Connection -> Message -> Bool |
116 | isAllowed Connection {..} msg | 116 | isAllowed Connection {..} msg |
117 | | Just ext <- requires msg = allowed connCaps ext | 117 | | Just ext <- requires msg = ext `allowed` connCaps |
118 | | otherwise = True | 118 | | otherwise = True |
119 | 119 | ||
120 | {----------------------------------------------------------------------- | 120 | {----------------------------------------------------------------------- |
@@ -189,8 +189,8 @@ validate side = await >>= maybe (return ()) yieldCheck | |||
189 | case requires msg of | 189 | case requires msg of |
190 | Nothing -> return () | 190 | Nothing -> return () |
191 | Just ext | 191 | Just ext |
192 | | allowed caps ext -> yield msg | 192 | | ext `allowed` caps -> yield msg |
193 | | otherwise -> protocolError $ InvalidMessage side ext | 193 | | otherwise -> protocolError $ InvalidMessage side ext |
194 | 194 | ||
195 | validateBoth :: Wire () -> Wire () | 195 | validateBoth :: Wire () -> Wire () |
196 | validateBoth action = do | 196 | validateBoth action = do |
@@ -235,7 +235,7 @@ connectWire hs addr extCaps wire = | |||
235 | throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') | 235 | throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') |
236 | 236 | ||
237 | let caps = hsReserved hs <> hsReserved hs' | 237 | let caps = hsReserved hs <> hsReserved hs' |
238 | let wire' = if allowed caps ExtExtended | 238 | let wire' = if ExtExtended `allowed` caps |
239 | then extendedHandshake extCaps >> wire | 239 | then extendedHandshake extCaps >> wire |
240 | else wire | 240 | else wire |
241 | 241 | ||