diff options
-rw-r--r-- | src/Network/BitTorrent/PeerWire/Bitfield.hs | 143 |
1 files changed, 116 insertions, 27 deletions
diff --git a/src/Network/BitTorrent/PeerWire/Bitfield.hs b/src/Network/BitTorrent/PeerWire/Bitfield.hs index 2d2bbd59..b375c1f5 100644 --- a/src/Network/BitTorrent/PeerWire/Bitfield.hs +++ b/src/Network/BitTorrent/PeerWire/Bitfield.hs | |||
@@ -10,7 +10,9 @@ | |||
10 | -- piece indexes any peer have. All associated operations should be | 10 | -- piece indexes any peer have. All associated operations should be |
11 | -- defined here as well. | 11 | -- defined here as well. |
12 | -- | 12 | -- |
13 | {-# LANGUAGE BangPatterns #-} | ||
13 | module Network.BitTorrent.PeerWire.Bitfield | 14 | module Network.BitTorrent.PeerWire.Bitfield |
15 | -- TODO: move to Data.Bitfield | ||
14 | ( Bitfield(..) | 16 | ( Bitfield(..) |
15 | 17 | ||
16 | -- * Construction | 18 | -- * Construction |
@@ -19,11 +21,12 @@ module Network.BitTorrent.PeerWire.Bitfield | |||
19 | 21 | ||
20 | -- * Query | 22 | -- * Query |
21 | , findMin, findMax | 23 | , findMin, findMax |
22 | , union, intersection, difference | 24 | , union, intersection, difference, combine |
23 | , frequencies | 25 | , frequencies |
24 | 26 | ||
25 | -- * Serialization | 27 | -- * Serialization |
26 | , getBitfield, putBitfield, bitfieldByteCount | 28 | , getBitfield, putBitfield, bitfieldByteCount |
29 | , aligned | ||
27 | ) where | 30 | ) where |
28 | 31 | ||
29 | import Control.Applicative hiding (empty) | 32 | import Control.Applicative hiding (empty) |
@@ -31,11 +34,14 @@ import Data.Array.Unboxed | |||
31 | import Data.Bits | 34 | import Data.Bits |
32 | import Data.ByteString (ByteString) | 35 | import Data.ByteString (ByteString) |
33 | import qualified Data.ByteString as B | 36 | import qualified Data.ByteString as B |
37 | import qualified Data.ByteString.Internal as B | ||
34 | import Data.List as L hiding (union) | 38 | import Data.List as L hiding (union) |
35 | import Data.Maybe | 39 | import Data.Maybe |
36 | import Data.Serialize | 40 | import Data.Serialize |
37 | import Data.Word | 41 | import Data.Word |
38 | 42 | ||
43 | import Foreign | ||
44 | |||
39 | import Network.BitTorrent.PeerWire.Block | 45 | import Network.BitTorrent.PeerWire.Block |
40 | import Data.Torrent | 46 | import Data.Torrent |
41 | 47 | ||
@@ -61,25 +67,115 @@ toByteString :: Bitfield -> ByteString | |||
61 | toByteString = bfBits | 67 | toByteString = bfBits |
62 | {-# INLINE toByteString #-} | 68 | {-# INLINE toByteString #-} |
63 | 69 | ||
64 | combine :: [ByteString] -> Maybe ByteString | 70 | getBitfield :: Int -> Get Bitfield |
65 | combine [] = Nothing | 71 | getBitfield n = MkBitfield <$> getBytes n |
66 | combine as@(a : _) = return $ foldr andBS empty as | 72 | {-# INLINE getBitfield #-} |
67 | where | ||
68 | andBS x acc = B.pack (B.zipWith (.&.) x acc) | ||
69 | empty = B.replicate (B.length a) 0 | ||
70 | 73 | ||
71 | frequencies :: [Bitfield] -> UArray PieceIx Int | 74 | putBitfield :: Bitfield -> Put |
72 | frequencies = undefined | 75 | putBitfield = putByteString . bfBits |
76 | {-# INLINE putBitfield #-} | ||
77 | |||
78 | bitfieldByteCount :: Bitfield -> Int | ||
79 | bitfieldByteCount = B.length . bfBits | ||
80 | {-# INLINE bitfieldByteCount #-} | ||
81 | |||
82 | |||
83 | |||
84 | type Mem a = (Ptr a, Int) | ||
85 | |||
86 | aligned :: Storable a => Mem Word8 -> (Mem Word8, Mem a, Mem Word8) | ||
87 | aligned (ptr, len) = | ||
88 | let lowPtr = ptr | ||
89 | lowLen = midPtr `minusPtr` ptr | ||
90 | midOff = lowLen | ||
91 | (midPtr, alg) = align (castPtr ptr) | ||
92 | midLen = alg * div (len - midOff) alg | ||
93 | midLenA = midLen `div` alg | ||
94 | hghOff = midOff + midLen | ||
95 | hghPtr = ptr `advancePtr` hghOff | ||
96 | hghLen = len - hghOff | ||
97 | in | ||
98 | ((lowPtr, lowLen), (midPtr, midLenA), (hghPtr, hghLen)) | ||
99 | where | ||
100 | align :: Storable a => Ptr a -> (Ptr a, Int) | ||
101 | align p = tie (alignPtr p) undefined | ||
102 | where | ||
103 | tie :: Storable a => (Int -> Ptr a) -> a -> (Ptr a, Int) | ||
104 | tie f a = (f (alignment a), (alignment a)) | ||
105 | {-# INLINE aligned #-} | ||
106 | |||
107 | zipWithBS :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString -> ByteString | ||
108 | zipWithBS f a b = | ||
109 | let (afptr, aoff, asize) = B.toForeignPtr a | ||
110 | (bfptr, boff, bsize) = B.toForeignPtr b | ||
111 | size = min asize bsize in | ||
112 | B.unsafeCreate size $ \ptr -> do | ||
113 | withForeignPtr afptr $ \aptr -> do | ||
114 | withForeignPtr bfptr $ \bptr -> | ||
115 | zipBytes (aptr `plusPtr` aoff) (bptr `plusPtr` boff) ptr size | ||
116 | where | ||
117 | zipBytes :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO () | ||
118 | zipBytes aptr bptr rptr n = go 0 | ||
119 | where | ||
120 | go :: Int -> IO () | ||
121 | go i | i < n = do -- TODO unfold | ||
122 | av <- peekByteOff aptr i | ||
123 | bv <- peekByteOff bptr i | ||
124 | pokeByteOff rptr i (f av bv) | ||
125 | go (succ i) | ||
126 | | otherwise = return () | ||
73 | 127 | ||
74 | zipWithBF :: (Word8 -> Word8 -> Word8) -> Bitfield -> Bitfield -> Bitfield | 128 | zipWithBF :: (Word8 -> Word8 -> Word8) -> Bitfield -> Bitfield -> Bitfield |
75 | zipWithBF f a b = MkBitfield $ B.pack $ B.zipWith f (bfBits a) (bfBits b) | 129 | zipWithBF f a b = MkBitfield $ zipWithBS f (bfBits a) (bfBits b) |
76 | {-# INLINE zipWithBF #-} | 130 | {-# INLINE zipWithBF #-} |
77 | 131 | ||
132 | findSet :: ByteString -> Maybe Int | ||
133 | findSet b = | ||
134 | let (fptr, off, len) = B.toForeignPtr b in | ||
135 | B.inlinePerformIO $ withForeignPtr fptr $ \_ptr -> do | ||
136 | let ptr = _ptr `advancePtr` off | ||
137 | |||
138 | let (low, mid, hgh) = aligned (ptr, len) | ||
139 | let lowOff = fst low `minusPtr` ptr | ||
140 | let midOff = fst mid `minusPtr` ptr | ||
141 | let hghOff = fst hgh `minusPtr` ptr | ||
142 | |||
143 | let resL = (lowOff +) <$> goFind low | ||
144 | let resM = (midOff +) <$> goFind (mid :: Mem Word) -- tune size here | ||
145 | -- TODO: with Word8 | ||
146 | -- bytestring findIndex works 2 | ||
147 | -- times faster. | ||
148 | let resH = (hghOff +) <$> goFind hgh | ||
149 | |||
150 | let res = resL <|> resM <|> resH | ||
151 | |||
152 | -- computation of res should not escape withForeignPtr | ||
153 | case res of | ||
154 | Nothing -> return () | ||
155 | Just _ -> return () | ||
156 | |||
157 | return res | ||
158 | |||
159 | where | ||
160 | goFind :: (Storable a, Eq a, Num a) => Mem a -> Maybe Int | ||
161 | goFind (ptr, n) = go 0 | ||
162 | where | ||
163 | go :: Int -> Maybe Int | ||
164 | go i | i < n = | ||
165 | let v = B.inlinePerformIO (peekElemOff ptr i) in | ||
166 | if v /= 0 | ||
167 | then Just i | ||
168 | else go (succ i) | ||
169 | | otherwise = Nothing | ||
170 | |||
171 | |||
78 | union :: Bitfield -> Bitfield -> Bitfield | 172 | union :: Bitfield -> Bitfield -> Bitfield |
79 | union = zipWithBF (.|.) | 173 | union = zipWithBF (.|.) |
174 | {-# INLINE union #-} | ||
80 | 175 | ||
81 | intersection :: Bitfield -> Bitfield -> Bitfield | 176 | intersection :: Bitfield -> Bitfield -> Bitfield |
82 | intersection = zipWithBF (.&.) | 177 | intersection = zipWithBF (.&.) |
178 | {-# INLINE intersection #-} | ||
83 | 179 | ||
84 | difference :: Bitfield -> Bitfield -> Bitfield | 180 | difference :: Bitfield -> Bitfield -> Bitfield |
85 | difference = zipWithBF diffWord8 | 181 | difference = zipWithBF diffWord8 |
@@ -89,45 +185,38 @@ difference = zipWithBF diffWord8 | |||
89 | {-# INLINE diffWord8 #-} | 185 | {-# INLINE diffWord8 #-} |
90 | {-# INLINE difference #-} | 186 | {-# INLINE difference #-} |
91 | 187 | ||
92 | 188 | combine :: [Bitfield] -> Maybe Bitfield | |
189 | combine [] = Nothing | ||
190 | combine as = return $ foldr1 intersection as | ||
93 | 191 | ||
94 | -- | Get min index of piece that the peer have. | 192 | -- | Get min index of piece that the peer have. |
95 | findMin :: Bitfield -> Maybe PieceIx | 193 | findMin :: Bitfield -> Maybe PieceIx |
96 | findMin (MkBitfield b) = do | 194 | findMin (MkBitfield b) = do |
97 | byteIx <- B.findIndex (0 /=) b | 195 | byteIx <- findSet b |
98 | bitIx <- findMinWord8 (B.index b byteIx) | 196 | bitIx <- findMinWord8 (B.index b byteIx) |
99 | return $ byteIx * bitSize (undefined :: Word8) + bitIx | 197 | return $ byteIx * bitSize (undefined :: Word8) + bitIx |
100 | where | 198 | where |
101 | -- TODO: bit tricks | 199 | -- TODO: bit tricks |
102 | findMinWord8 :: Word8 -> Maybe Int | 200 | findMinWord8 :: Word8 -> Maybe Int |
103 | findMinWord8 b = L.find (testBit b) [0..bitSize (undefined :: Word8) - 1] | 201 | findMinWord8 byte = L.find (testBit byte) [0..bitSize (undefined :: Word8) - 1] |
104 | {-# INLINE findMinWord8 #-} | 202 | {-# INLINE findMinWord8 #-} |
105 | {-# INLINE findMin #-} | 203 | {-# INLINE findMin #-} |
106 | 204 | ||
107 | 205 | ||
108 | findMax :: Bitfield -> Maybe PieceIx | 206 | findMax :: Bitfield -> Maybe PieceIx |
109 | findMax (MkBitfield b) = do | 207 | findMax (MkBitfield b) = do |
110 | byteIx <- (pred (B.length b) -) <$> B.findIndex (0 /=) (B.reverse b) | 208 | -- TODO avoid reverse |
209 | byteIx <- (pred (B.length b) -) <$> findSet (B.reverse b) | ||
111 | bitIx <- findMaxWord8 (B.index b byteIx) | 210 | bitIx <- findMaxWord8 (B.index b byteIx) |
112 | return $ byteIx * bitSize (undefined :: Word8) + bitIx | 211 | return $ byteIx * bitSize (undefined :: Word8) + bitIx |
113 | where | 212 | where |
114 | -- TODO: bit tricks | 213 | -- TODO: bit tricks |
115 | findMaxWord8 :: Word8 -> Maybe Int | 214 | findMaxWord8 :: Word8 -> Maybe Int |
116 | findMaxWord8 b = L.find (testBit b) | 215 | findMaxWord8 byte = L.find (testBit byte) |
117 | (reverse [0 :: Int .. | 216 | (reverse [0 :: Int .. |
118 | bitSize (undefined :: Word8) - 1]) | 217 | bitSize (undefined :: Word8) - 1]) |
119 | 218 | ||
120 | {-# INLINE findMax #-} | 219 | {-# INLINE findMax #-} |
121 | 220 | ||
122 | 221 | frequencies :: [Bitfield] -> UArray PieceIx Int | |
123 | getBitfield :: Int -> Get Bitfield | 222 | frequencies = undefined |
124 | getBitfield n = MkBitfield <$> getBytes n | ||
125 | {-# INLINE getBitfield #-} | ||
126 | |||
127 | putBitfield :: Bitfield -> Put | ||
128 | putBitfield = putByteString . bfBits | ||
129 | {-# INLINE putBitfield #-} | ||
130 | |||
131 | bitfieldByteCount :: Bitfield -> Int | ||
132 | bitfieldByteCount = B.length . bfBits | ||
133 | {-# INLINE bitfieldByteCount #-} | ||