diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Network/BitTorrent/Exchange/Message.hs | 15 | ||||
-rw-r--r-- | src/Network/BitTorrent/Exchange/Wire.hs | 140 |
2 files changed, 97 insertions, 58 deletions
diff --git a/src/Network/BitTorrent/Exchange/Message.hs b/src/Network/BitTorrent/Exchange/Message.hs index 070a0e42..e1e17e6e 100644 --- a/src/Network/BitTorrent/Exchange/Message.hs +++ b/src/Network/BitTorrent/Exchange/Message.hs | |||
@@ -45,6 +45,7 @@ module Network.BitTorrent.Exchange.Message | |||
45 | , defaultHandshake | 45 | , defaultHandshake |
46 | , handshakeSize | 46 | , handshakeSize |
47 | , handshakeMaxSize | 47 | , handshakeMaxSize |
48 | , handshakeStats | ||
48 | 49 | ||
49 | -- * Stats | 50 | -- * Stats |
50 | , ByteCount | 51 | , ByteCount |
@@ -296,6 +297,10 @@ handshakeMaxSize = handshakeSize maxProtocolNameSize | |||
296 | defaultHandshake :: InfoHash -> PeerId -> Handshake | 297 | defaultHandshake :: InfoHash -> PeerId -> Handshake |
297 | defaultHandshake = Handshake def def | 298 | defaultHandshake = Handshake def def |
298 | 299 | ||
300 | handshakeStats :: Handshake -> ByteStats | ||
301 | handshakeStats (Handshake (ProtocolName bs) _ _ _) | ||
302 | = ByteStats 1 (BS.length bs + 8 + 20 + 20) 0 | ||
303 | |||
299 | {----------------------------------------------------------------------- | 304 | {----------------------------------------------------------------------- |
300 | -- Stats | 305 | -- Stats |
301 | -----------------------------------------------------------------------} | 306 | -----------------------------------------------------------------------} |
@@ -320,6 +325,16 @@ data ByteStats = ByteStats | |||
320 | , payload :: {-# UNPACK #-} !ByteCount | 325 | , payload :: {-# UNPACK #-} !ByteCount |
321 | } deriving Show | 326 | } deriving Show |
322 | 327 | ||
328 | instance Pretty ByteStats where | ||
329 | pretty s @ ByteStats {..} = fsep | ||
330 | [ PP.int overhead, "overhead" | ||
331 | , PP.int control, "control" | ||
332 | , PP.int payload, "payload" | ||
333 | , "bytes" | ||
334 | ] $+$ fsep | ||
335 | [ PP.int (byteLength s), "total bytes" | ||
336 | ] | ||
337 | |||
323 | -- | Empty byte sequences. | 338 | -- | Empty byte sequences. |
324 | instance Default ByteStats where | 339 | instance Default ByteStats where |
325 | def = ByteStats 0 0 0 | 340 | def = ByteStats 0 0 0 |
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index a0f683c8..239358d9 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs | |||
@@ -47,11 +47,6 @@ module Network.BitTorrent.Exchange.Wire | |||
47 | , getConnection | 47 | , getConnection |
48 | , getExtCaps | 48 | , getExtCaps |
49 | , getStats | 49 | , getStats |
50 | |||
51 | -- ** Conduits | ||
52 | , validate | ||
53 | , validateBoth | ||
54 | , trackStats | ||
55 | ) where | 50 | ) where |
56 | 51 | ||
57 | import Control.Applicative | 52 | import Control.Applicative |
@@ -59,7 +54,8 @@ import Control.Exception | |||
59 | import Control.Monad.Reader | 54 | import Control.Monad.Reader |
60 | import Data.ByteString as BS | 55 | import Data.ByteString as BS |
61 | import Data.Conduit | 56 | import Data.Conduit |
62 | import Data.Conduit.Cereal as S | 57 | import Data.Conduit.Cereal |
58 | import Data.Conduit.List | ||
63 | import Data.Conduit.Network | 59 | import Data.Conduit.Network |
64 | import Data.Default | 60 | import Data.Default |
65 | import Data.IORef | 61 | import Data.IORef |
@@ -185,21 +181,33 @@ instance Pretty WireFailure where | |||
185 | isWireFailure :: Monad m => WireFailure -> m () | 181 | isWireFailure :: Monad m => WireFailure -> m () |
186 | isWireFailure _ = return () | 182 | isWireFailure _ = return () |
187 | 183 | ||
184 | protocolError :: MonadThrow m => ProtocolError -> m a | ||
185 | protocolError = monadThrow . ProtocolError | ||
186 | |||
187 | -- | Forcefully terminate wire session and close socket. | ||
188 | disconnectPeer :: Wire a | ||
189 | disconnectPeer = monadThrow DisconnectPeer | ||
190 | |||
188 | {----------------------------------------------------------------------- | 191 | {----------------------------------------------------------------------- |
189 | -- Stats | 192 | -- Stats |
190 | -----------------------------------------------------------------------} | 193 | -----------------------------------------------------------------------} |
191 | 194 | ||
192 | -- | Message stats in one direction. | 195 | -- | Message stats in one direction. |
193 | data FlowStats = FlowStats | 196 | data FlowStats = FlowStats |
194 | { -- | Sum of byte sequences of all messages. | 197 | { -- | Number of the messages sent or received. |
195 | messageBytes :: {-# UNPACK #-} !ByteStats | 198 | messageCount :: {-# UNPACK #-} !Int |
196 | -- | Number of the messages sent or received. | 199 | -- | Sum of byte sequences of all messages. |
197 | , messageCount :: {-# UNPACK #-} !Int | 200 | , messageBytes :: {-# UNPACK #-} !ByteStats |
198 | } deriving Show | 201 | } deriving Show |
199 | 202 | ||
203 | instance Pretty FlowStats where | ||
204 | pretty FlowStats {..} = | ||
205 | PP.int messageCount <+> "messages" $+$ | ||
206 | pretty messageBytes | ||
207 | |||
200 | -- | Zeroed stats. | 208 | -- | Zeroed stats. |
201 | instance Default FlowStats where | 209 | instance Default FlowStats where |
202 | def = FlowStats def 0 | 210 | def = FlowStats 0 def |
203 | 211 | ||
204 | -- | Monoid under addition. | 212 | -- | Monoid under addition. |
205 | instance Monoid FlowStats where | 213 | instance Monoid FlowStats where |
@@ -216,6 +224,14 @@ addFlowStats x FlowStats {..} = FlowStats | |||
216 | , messageCount = succ messageCount | 224 | , messageCount = succ messageCount |
217 | } | 225 | } |
218 | 226 | ||
227 | -- | Find average length of byte sequences per message. | ||
228 | avgByteStats :: FlowStats -> ByteStats | ||
229 | avgByteStats (FlowStats n ByteStats {..}) = ByteStats | ||
230 | { overhead = overhead `quot` n | ||
231 | , control = control `quot` n | ||
232 | , payload = payload `quot` n | ||
233 | } | ||
234 | |||
219 | -- | Message stats in both directions. This data can be retrieved | 235 | -- | Message stats in both directions. This data can be retrieved |
220 | -- using 'getStats' function. | 236 | -- using 'getStats' function. |
221 | -- | 237 | -- |
@@ -231,6 +247,13 @@ data ConnectionStats = ConnectionStats | |||
231 | , outcomingFlow :: !FlowStats | 247 | , outcomingFlow :: !FlowStats |
232 | } deriving Show | 248 | } deriving Show |
233 | 249 | ||
250 | instance Pretty ConnectionStats where | ||
251 | pretty ConnectionStats {..} = vcat | ||
252 | [ "Recv:" <+> pretty incomingFlow | ||
253 | , "Sent:" <+> pretty outcomingFlow | ||
254 | , "Both:" <+> pretty (incomingFlow <> outcomingFlow) | ||
255 | ] | ||
256 | |||
234 | -- | Zeroed stats. | 257 | -- | Zeroed stats. |
235 | instance Default ConnectionStats where | 258 | instance Default ConnectionStats where |
236 | def = ConnectionStats def def | 259 | def = ConnectionStats def def |
@@ -337,86 +360,83 @@ connectToPeer p = do | |||
337 | -----------------------------------------------------------------------} | 360 | -----------------------------------------------------------------------} |
338 | 361 | ||
339 | -- | do not expose this so we can change it without breaking api | 362 | -- | do not expose this so we can change it without breaking api |
340 | type Connectivity = ReaderT Connection | 363 | type Connected = ReaderT Connection |
341 | 364 | ||
342 | -- | A duplex channel connected to a remote peer which keep tracks | 365 | -- | A duplex channel connected to a remote peer which keep tracks |
343 | -- connection parameters. | 366 | -- connection parameters. |
344 | type Wire a = ConduitM Message Message (Connectivity IO) a | 367 | type Wire a = ConduitM Message Message (Connected IO) a |
345 | |||
346 | protocolError :: ProtocolError -> Wire a | ||
347 | protocolError = monadThrow . ProtocolError | ||
348 | 368 | ||
349 | -- | Forcefully terminate wire session and close socket. | 369 | {----------------------------------------------------------------------- |
350 | disconnectPeer :: Wire a | 370 | -- Query |
351 | disconnectPeer = monadThrow DisconnectPeer | 371 | -----------------------------------------------------------------------} |
352 | 372 | ||
353 | readRef :: (Connection -> IORef a) -> Wire a | 373 | readRef :: (Connection -> IORef a) -> Connected IO a |
354 | readRef f = do | 374 | readRef f = do |
355 | ref <- lift (asks f) | 375 | ref <- asks f |
356 | liftIO (readIORef ref) | 376 | liftIO (readIORef ref) |
357 | 377 | ||
358 | writeRef :: (Connection -> IORef a) -> a -> Wire () | 378 | writeRef :: (Connection -> IORef a) -> a -> Connected IO () |
359 | writeRef f v = do | 379 | writeRef f v = do |
360 | ref <- lift (asks f) | 380 | ref <- asks f |
361 | liftIO (writeIORef ref v) | 381 | liftIO (writeIORef ref v) |
362 | 382 | ||
363 | modifyRef :: (Connection -> IORef a) -> (a -> a) -> Wire () | 383 | modifyRef :: (Connection -> IORef a) -> (a -> a) -> Connected IO () |
364 | modifyRef f m = do | 384 | modifyRef f m = do |
365 | ref <- lift (asks f) | 385 | ref <- asks f |
366 | liftIO (atomicModifyIORef' ref (\x -> (m x, ()))) | 386 | liftIO (atomicModifyIORef' ref (\x -> (m x, ()))) |
367 | 387 | ||
368 | setExtCaps :: ExtendedCaps -> Wire () | 388 | setExtCaps :: ExtendedCaps -> Wire () |
369 | setExtCaps = writeRef connExtCaps | 389 | setExtCaps = lift . writeRef connExtCaps |
370 | 390 | ||
371 | -- | Get current extended capabilities. Note that this value can | 391 | -- | Get current extended capabilities. Note that this value can |
372 | -- change in current session if either this or remote peer will | 392 | -- change in current session if either this or remote peer will |
373 | -- initiate rehandshaking. | 393 | -- initiate rehandshaking. |
374 | getExtCaps :: Wire ExtendedCaps | 394 | getExtCaps :: Wire ExtendedCaps |
375 | getExtCaps = readRef connExtCaps | 395 | getExtCaps = lift $ readRef connExtCaps |
376 | 396 | ||
377 | -- | Get current stats. Note that this value will change with the next | 397 | -- | Get current stats. Note that this value will change with the next |
378 | -- sent or received message. | 398 | -- sent or received message. |
379 | getStats :: Wire ConnectionStats | 399 | getStats :: Wire ConnectionStats |
380 | getStats = readRef connStats | 400 | getStats = lift $ readRef connStats |
381 | |||
382 | putStats :: ChannelSide -> Message -> Wire () | ||
383 | putStats side msg = modifyRef connStats (addStats side (stats msg)) | ||
384 | 401 | ||
385 | -- | See the 'Connection' section for more info. | 402 | -- | See the 'Connection' section for more info. |
386 | getConnection :: Wire Connection | 403 | getConnection :: Wire Connection |
387 | getConnection = lift ask | 404 | getConnection = lift ask |
388 | 405 | ||
389 | validate :: ChannelSide -> Wire () | 406 | {----------------------------------------------------------------------- |
390 | validate side = await >>= maybe (return ()) yieldCheck | 407 | -- Wrapper |
391 | where | 408 | -----------------------------------------------------------------------} |
392 | yieldCheck msg = do | 409 | |
393 | caps <- lift $ asks connCaps | 410 | putStats :: ChannelSide -> Message -> Connected IO () |
394 | case requires msg of | 411 | putStats side msg = modifyRef connStats (addStats side (stats msg)) |
395 | Nothing -> return () | 412 | |
396 | Just ext | 413 | validate :: ChannelSide -> Message -> Connected IO () |
397 | | ext `allowed` caps -> yield msg | 414 | validate side msg = do |
398 | | otherwise -> protocolError $ DisallowedMessage side ext | 415 | caps <- asks connCaps |
399 | 416 | case requires msg of | |
400 | validateBoth :: Wire () -> Wire () | ||
401 | validateBoth action = do | ||
402 | validate RemotePeer | ||
403 | action | ||
404 | validate ThisPeer | ||
405 | |||
406 | trackStats :: Wire () | ||
407 | trackStats = do | ||
408 | mmsg <- await | ||
409 | case mmsg of | ||
410 | Nothing -> return () | 417 | Nothing -> return () |
411 | Just msg -> putStats ThisPeer msg -- FIXME not really ThisPeer | 418 | Just ext |
419 | | ext `allowed` caps -> return () | ||
420 | | otherwise -> protocolError $ DisallowedMessage side ext | ||
421 | |||
422 | trackFlow :: ChannelSide -> Wire () | ||
423 | trackFlow side = iterM $ do | ||
424 | validate side | ||
425 | putStats side | ||
426 | |||
427 | {----------------------------------------------------------------------- | ||
428 | -- Setup | ||
429 | -----------------------------------------------------------------------} | ||
412 | 430 | ||
413 | -- | Normally you should use 'connectWire' or 'acceptWire'. | 431 | -- | Normally you should use 'connectWire' or 'acceptWire'. |
414 | runWire :: Wire () -> Socket -> Connection -> IO () | 432 | runWire :: Wire () -> Socket -> Connection -> IO () |
415 | runWire action sock = runReaderT $ | 433 | runWire action sock = runReaderT $ |
416 | sourceSocket sock $= | 434 | sourceSocket sock $= |
417 | S.conduitGet S.get $= | 435 | conduitGet get $= |
418 | action $= | 436 | trackFlow RemotePeer $= |
419 | S.conduitPut S.put $$ | 437 | action $= |
438 | trackFlow ThisPeer $= | ||
439 | conduitPut put $$ | ||
420 | sinkSocket sock | 440 | sinkSocket sock |
421 | 441 | ||
422 | -- | This function will block until a peer send new message. You can | 442 | -- | This function will block until a peer send new message. You can |
@@ -475,7 +495,11 @@ connectWire hs addr extCaps wire = | |||
475 | else wire | 495 | else wire |
476 | 496 | ||
477 | extCapsRef <- newIORef def | 497 | extCapsRef <- newIORef def |
478 | statsRef <- newIORef def | 498 | statsRef <- newIORef ConnectionStats |
499 | { outcomingFlow = FlowStats 1 $ handshakeStats hs | ||
500 | , incomingFlow = FlowStats 1 $ handshakeStats hs' | ||
501 | } | ||
502 | |||
479 | runWire wire' sock $ Connection | 503 | runWire wire' sock $ Connection |
480 | { connProtocol = hsProtocol hs | 504 | { connProtocol = hsProtocol hs |
481 | , connCaps = caps | 505 | , connCaps = caps |