From 7d5cf5919c2bd11f835baf243bca341521c03879 Mon Sep 17 00:00:00 2001 From: Sam Truzjan Date: Fri, 6 Dec 2013 04:22:36 +0400 Subject: Unify capabilities operations --- src/Network/BitTorrent/Exchange/Message.hs | 82 +++++++++++++++++------------- src/Network/BitTorrent/Exchange/Wire.hs | 8 +-- 2 files changed, 52 insertions(+), 38 deletions(-) diff --git a/src/Network/BitTorrent/Exchange/Message.hs b/src/Network/BitTorrent/Exchange/Message.hs index 33937a93..c614b1ae 100644 --- a/src/Network/BitTorrent/Exchange/Message.hs +++ b/src/Network/BitTorrent/Exchange/Message.hs @@ -27,17 +27,17 @@ -- {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE TemplateHaskell #-} {-# OPTIONS -fno-warn-orphans #-} module Network.BitTorrent.Exchange.Message ( -- * Capabilities - Extension (..) + Capabilities (..) + , Extension (..) , Caps - , toCaps - , fromCaps - , allowed -- * Handshake , Handshake(..) @@ -65,9 +65,6 @@ module Network.BitTorrent.Exchange.Message -- *** Capabilities , ExtendedExtension (..) , ExtendedCaps (..) - , toExtCaps - , fromExtCaps - , extendedAllowed -- *** Handshake , ExtendedHandshake (..) @@ -108,6 +105,26 @@ import Data.Torrent.InfoHash import Network.BitTorrent.Core import Network.BitTorrent.Exchange.Block +{----------------------------------------------------------------------- +-- Capabilities +-----------------------------------------------------------------------} + +-- | +class Capabilities caps where + type Ext caps :: * + + -- | Pack extensions to caps. + toCaps :: [Ext caps] -> caps + + -- | Unpack extensions from caps. + fromCaps :: caps -> [Ext caps] + + -- | Check if an extension is a member of the specified set. + allowed :: Ext caps -> caps -> Bool + +ppCaps :: Capabilities caps => Pretty (Ext caps) => caps -> Doc +ppCaps = hcat . punctuate ", " . L.map pretty . fromCaps + {----------------------------------------------------------------------- -- Extensions -----------------------------------------------------------------------} @@ -129,10 +146,10 @@ instance Pretty Extension where pretty ExtExtended = "Extension Protocol" -- | Extension bitmask as specified by BEP 4. -capMask :: Extension -> Caps -capMask ExtDHT = Caps 0x01 -capMask ExtFast = Caps 0x04 -capMask ExtExtended = Caps 0x100000 +extMask :: Extension -> Word64 +extMask ExtDHT = 0x01 +extMask ExtFast = 0x04 +extMask ExtExtended = 0x100000 {----------------------------------------------------------------------- -- Capabilities @@ -140,12 +157,13 @@ capMask ExtExtended = Caps 0x100000 -- | Capabilities is a set of 'Extension's usually sent in 'Handshake' -- messages. -newtype Caps = Caps { unCaps :: Word64 } +newtype Caps = Caps Word64 deriving (Show, Eq) -- | Render set of extensions as comma separated list. instance Pretty Caps where - pretty = hcat . punctuate ", " . L.map pretty . fromCaps + pretty = ppCaps + {-# INLINE pretty #-} -- | The empty set. instance Default Caps where @@ -168,19 +186,14 @@ instance Serialize Caps where get = Caps <$> S.getWord64be {-# INLINE get #-} --- | Check if an extension is a member of the specified set. -allowed :: Caps -> Extension -> Bool -allowed (Caps caps) = testMask . capMask - where - testMask (Caps bits) = (bits .&. caps) == bits +instance Capabilities Caps where + type Ext Caps = Extension --- | Pack extensions to caps. -toCaps :: [Extension] -> Caps -toCaps = Caps . L.foldr (.|.) 0 . L.map (unCaps . capMask) + allowed e (Caps caps) = (extMask e .&. caps) /= 0 + {-# INLINE allowed #-} --- | Unpack extensions from caps. -fromCaps :: Caps -> [Extension] -fromCaps caps = L.filter (allowed caps) [minBound..maxBound] + toCaps = Caps . L.foldr (.|.) 0 . L.map extMask + fromCaps caps = L.filter (`allowed` caps) [minBound..maxBound] {----------------------------------------------------------------------- Handshake @@ -449,7 +462,8 @@ newtype ExtendedCaps = ExtendedCaps { extendedCaps :: ExtendedMap } deriving (Show, Eq) instance Pretty ExtendedCaps where - pretty = hcat . punctuate ", " . L.map pretty . fromExtCaps + pretty = ppCaps + {-# INLINE pretty #-} -- | The empty set. instance Default ExtendedCaps where @@ -463,7 +477,7 @@ instance Default ExtendedCaps where -- id from the first caps for the extensions existing in both caps. -- instance Monoid ExtendedCaps where - mempty = toExtCaps [minBound..maxBound] + mempty = toCaps [minBound..maxBound] mappend (ExtendedCaps a) (ExtendedCaps b) = ExtendedCaps (M.intersection a b) @@ -482,16 +496,16 @@ instance BEncode ExtendedCaps where fromBEncode (BDict bd) = pure $ ExtendedCaps $ appendBDict bd M.empty fromBEncode _ = decodingError "ExtendedCaps" -toExtCaps :: [ExtendedExtension] -> ExtendedCaps -toExtCaps = ExtendedCaps . M.fromList . L.map (id &&& extId) +instance Capabilities ExtendedCaps where + type Ext ExtendedCaps = ExtendedExtension + + toCaps = ExtendedCaps . M.fromList . L.map (id &&& extId) -fromExtCaps :: ExtendedCaps -> [ExtendedExtension] -fromExtCaps = M.keys . extendedCaps -{-# INLINE fromExtCaps #-} + fromCaps = M.keys . extendedCaps + {-# INLINE fromCaps #-} -extendedAllowed :: ExtendedExtension -> ExtendedCaps -> Bool -extendedAllowed e (ExtendedCaps caps) = M.member e caps -{-# INLINE extendedAllowed #-} + allowed e (ExtendedCaps caps) = M.member e caps + {-# INLINE allowed #-} {----------------------------------------------------------------------- -- Extended handshake diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 8f6e1d58..6f80a567 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs @@ -114,7 +114,7 @@ instance Pretty Connection where isAllowed :: Connection -> Message -> Bool isAllowed Connection {..} msg - | Just ext <- requires msg = allowed connCaps ext + | Just ext <- requires msg = ext `allowed` connCaps | otherwise = True {----------------------------------------------------------------------- @@ -189,8 +189,8 @@ validate side = await >>= maybe (return ()) yieldCheck case requires msg of Nothing -> return () Just ext - | allowed caps ext -> yield msg - | otherwise -> protocolError $ InvalidMessage side ext + | ext `allowed` caps -> yield msg + | otherwise -> protocolError $ InvalidMessage side ext validateBoth :: Wire () -> Wire () validateBoth action = do @@ -235,7 +235,7 @@ connectWire hs addr extCaps wire = throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') let caps = hsReserved hs <> hsReserved hs' - let wire' = if allowed caps ExtExtended + let wire' = if ExtExtended `allowed` caps then extendedHandshake extCaps >> wire else wire -- cgit v1.2.3