diff options
Diffstat (limited to 'src/Network')
-rw-r--r-- | src/Network/BitTorrent/Exchange/Session.hs | 36 | ||||
-rw-r--r-- | src/Network/BitTorrent/Exchange/Wire.hs | 122 |
2 files changed, 60 insertions, 98 deletions
diff --git a/src/Network/BitTorrent/Exchange/Session.hs b/src/Network/BitTorrent/Exchange/Session.hs index 0d4f3d02..416e00fd 100644 --- a/src/Network/BitTorrent/Exchange/Session.hs +++ b/src/Network/BitTorrent/Exchange/Session.hs | |||
@@ -110,7 +110,7 @@ newSession logFun addr rootPath dict = do | |||
110 | closeSession :: Session -> IO () | 110 | closeSession :: Session -> IO () |
111 | closeSession = undefined | 111 | closeSession = undefined |
112 | 112 | ||
113 | instance MonadIO m => MonadLogger (Connected Session m) where | 113 | instance MonadLogger (Connected Session) where |
114 | monadLoggerLog loc src lvl msg = do | 114 | monadLoggerLog loc src lvl msg = do |
115 | conn <- ask | 115 | conn <- ask |
116 | ses <- asks connSession | 116 | ses <- asks connSession |
@@ -143,9 +143,9 @@ insert addr ses @ Session {..} = do | |||
143 | let hs = Handshake def caps infohash tpeerId | 143 | let hs = Handshake def caps infohash tpeerId |
144 | chan <- dupChan broadcast | 144 | chan <- dupChan broadcast |
145 | connectWire ses hs addr ecaps chan $ do | 145 | connectWire ses hs addr ecaps chan $ do |
146 | conn <- getConnection | 146 | conn <- ask |
147 | -- liftIO $ modifyMVar_ connections $ pure . M.insert addr conn | 147 | -- liftIO $ modifyMVar_ connections $ pure . M.insert addr conn |
148 | resizeBitfield (totalPieces storage) | 148 | lift $ resizeBitfield (totalPieces storage) |
149 | logEvent "Connection established" | 149 | logEvent "Connection established" |
150 | exchange | 150 | exchange |
151 | -- liftIO $ modifyMVar_ connections $ pure . M.delete addr | 151 | -- liftIO $ modifyMVar_ connections $ pure . M.delete addr |
@@ -162,12 +162,12 @@ deleteAll = undefined | |||
162 | 162 | ||
163 | withStatusUpdates :: StatusUpdates a -> Wire Session a | 163 | withStatusUpdates :: StatusUpdates a -> Wire Session a |
164 | withStatusUpdates m = do | 164 | withStatusUpdates m = do |
165 | Session {..} <- getSession | 165 | Session {..} <- asks connSession |
166 | liftIO $ runStatusUpdates status m | 166 | liftIO $ runStatusUpdates status m |
167 | 167 | ||
168 | getThisBitfield :: Wire Session Bitfield | 168 | getThisBitfield :: Wire Session Bitfield |
169 | getThisBitfield = do | 169 | getThisBitfield = do |
170 | ses <- getSession | 170 | ses <- asks connSession |
171 | liftIO $ SS.getBitfield (status ses) | 171 | liftIO $ SS.getBitfield (status ses) |
172 | 172 | ||
173 | readBlock :: BlockIx -> Storage -> IO (Block BL.ByteString) | 173 | readBlock :: BlockIx -> Storage -> IO (Block BL.ByteString) |
@@ -181,8 +181,8 @@ readBlock bix @ BlockIx {..} s = do | |||
181 | 181 | ||
182 | sendBroadcast :: PeerMessage msg => msg -> Wire Session () | 182 | sendBroadcast :: PeerMessage msg => msg -> Wire Session () |
183 | sendBroadcast msg = do | 183 | sendBroadcast msg = do |
184 | Session {..} <- getSession | 184 | Session {..} <- asks connSession |
185 | ecaps <- getExtCaps | 185 | ecaps <- use connExtCaps |
186 | liftIO $ writeChan broadcast (envelop ecaps msg) | 186 | liftIO $ writeChan broadcast (envelop ecaps msg) |
187 | 187 | ||
188 | {----------------------------------------------------------------------- | 188 | {----------------------------------------------------------------------- |
@@ -191,9 +191,9 @@ sendBroadcast msg = do | |||
191 | 191 | ||
192 | fillRequestQueue :: Wire Session () | 192 | fillRequestQueue :: Wire Session () |
193 | fillRequestQueue = do | 193 | fillRequestQueue = do |
194 | maxN <- getAdvertisedQueueLength | 194 | maxN <- lift $ getAdvertisedQueueLength |
195 | rbf <- getRemoteBitfield | 195 | rbf <- use connBitfield |
196 | addr <- connRemoteAddr <$> getConnection | 196 | addr <- asks connRemoteAddr |
197 | blks <- withStatusUpdates $ do | 197 | blks <- withStatusUpdates $ do |
198 | n <- getRequestQueueLength addr | 198 | n <- getRequestQueueLength addr |
199 | scheduleBlocks addr rbf (maxN - n) | 199 | scheduleBlocks addr rbf (maxN - n) |
@@ -201,13 +201,13 @@ fillRequestQueue = do | |||
201 | 201 | ||
202 | tryFillRequestQueue :: Wire Session () | 202 | tryFillRequestQueue :: Wire Session () |
203 | tryFillRequestQueue = do | 203 | tryFillRequestQueue = do |
204 | allowed <- canDownload <$> getStatus | 204 | allowed <- canDownload <$> use connStatus |
205 | when allowed $ do | 205 | when allowed $ do |
206 | fillRequestQueue | 206 | fillRequestQueue |
207 | 207 | ||
208 | interesting :: Wire Session () | 208 | interesting :: Wire Session () |
209 | interesting = do | 209 | interesting = do |
210 | addr <- connRemoteAddr <$> getConnection | 210 | addr <- asks connRemoteAddr |
211 | logMessage (Status (Interested True)) | 211 | logMessage (Status (Interested True)) |
212 | sendMessage (Interested True) | 212 | sendMessage (Interested True) |
213 | logMessage (Status (Choking False)) | 213 | logMessage (Status (Choking False)) |
@@ -220,17 +220,17 @@ interesting = do | |||
220 | 220 | ||
221 | handleStatus :: StatusUpdate -> Wire Session () | 221 | handleStatus :: StatusUpdate -> Wire Session () |
222 | handleStatus s = do | 222 | handleStatus s = do |
223 | updateConnStatus RemotePeer s | 223 | connStatus %= over remoteStatus (updateStatus s) |
224 | case s of | 224 | case s of |
225 | Interested _ -> return () | 225 | Interested _ -> return () |
226 | Choking True -> do | 226 | Choking True -> do |
227 | addr <- connRemoteAddr <$> getConnection | 227 | addr <- asks connRemoteAddr |
228 | withStatusUpdates (resetPending addr) | 228 | withStatusUpdates (resetPending addr) |
229 | Choking False -> tryFillRequestQueue | 229 | Choking False -> tryFillRequestQueue |
230 | 230 | ||
231 | handleAvailable :: Available -> Wire Session () | 231 | handleAvailable :: Available -> Wire Session () |
232 | handleAvailable msg = do | 232 | handleAvailable msg = do |
233 | updateRemoteBitfield $ case msg of | 233 | connBitfield %= case msg of |
234 | Have ix -> BF.insert ix | 234 | Have ix -> BF.insert ix |
235 | Bitfield bf -> const bf | 235 | Bitfield bf -> const bf |
236 | 236 | ||
@@ -245,15 +245,15 @@ handleAvailable msg = do | |||
245 | 245 | ||
246 | handleTransfer :: Transfer -> Wire Session () | 246 | handleTransfer :: Transfer -> Wire Session () |
247 | handleTransfer (Request bix) = do | 247 | handleTransfer (Request bix) = do |
248 | Session {..} <- getSession | 248 | Session {..} <- asks connSession |
249 | bitfield <- getThisBitfield | 249 | bitfield <- getThisBitfield |
250 | upload <- canUpload <$> getStatus | 250 | upload <- canUpload <$> use connStatus |
251 | when (upload && ixPiece bix `BF.member` bitfield) $ do | 251 | when (upload && ixPiece bix `BF.member` bitfield) $ do |
252 | blk <- liftIO $ readBlock bix storage | 252 | blk <- liftIO $ readBlock bix storage |
253 | sendMessage (Piece blk) | 253 | sendMessage (Piece blk) |
254 | 254 | ||
255 | handleTransfer (Piece blk) = do | 255 | handleTransfer (Piece blk) = do |
256 | Session {..} <- getSession | 256 | Session {..} <- asks connSession |
257 | isSuccess <- withStatusUpdates (pushBlock blk storage) | 257 | isSuccess <- withStatusUpdates (pushBlock blk storage) |
258 | case isSuccess of | 258 | case isSuccess of |
259 | Nothing -> liftIO $ throwIO $ userError "block is not requested" | 259 | Nothing -> liftIO $ throwIO $ userError "block is not requested" |
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 64fa3295..e88b3ae5 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs | |||
@@ -18,41 +18,46 @@ module Network.BitTorrent.Exchange.Wire | |||
18 | Connected | 18 | Connected |
19 | , Wire | 19 | , Wire |
20 | 20 | ||
21 | -- ** Connection | 21 | -- * Connection |
22 | , Connection | 22 | , Connection |
23 | |||
24 | -- ** Identity | ||
23 | , connRemoteAddr | 25 | , connRemoteAddr |
24 | , connProtocol | ||
25 | , connCaps | ||
26 | , connTopic | 26 | , connTopic |
27 | , connRemotePeerId | 27 | , connRemotePeerId |
28 | , connThisPeerId | 28 | , connThisPeerId |
29 | |||
30 | -- ** Capabilities | ||
31 | , connProtocol | ||
32 | , connCaps | ||
33 | , connExtCaps | ||
34 | , connRemoteEhs | ||
35 | |||
36 | -- ** State | ||
37 | , connStatus | ||
38 | , connBitfield | ||
39 | |||
40 | -- ** Env | ||
29 | , connOptions | 41 | , connOptions |
30 | , connSession | 42 | , connSession |
43 | , connStats | ||
31 | 44 | ||
32 | -- ** Setup | 45 | -- * Setup |
33 | , runWire | 46 | , runWire |
34 | , connectWire | 47 | , connectWire |
35 | , acceptWire | 48 | , acceptWire |
36 | , resizeBitfield | 49 | , resizeBitfield |
37 | 50 | ||
38 | -- ** Messaging | 51 | -- * Messaging |
39 | , recvMessage | 52 | , recvMessage |
40 | , sendMessage | 53 | , sendMessage |
41 | , filterQueue | 54 | , filterQueue |
42 | , getAdvertisedQueueLength | 55 | , getAdvertisedQueueLength |
43 | 56 | ||
44 | -- ** Query | 57 | -- * Query |
45 | , getConnection | ||
46 | , getSession | ||
47 | , getStatus | ||
48 | , getRemoteBitfield | ||
49 | , updateConnStatus | ||
50 | , updateRemoteBitfield | ||
51 | , getExtCaps | ||
52 | , getStats | ||
53 | , getMetadata | 58 | , getMetadata |
54 | 59 | ||
55 | -- ** Exceptions | 60 | -- * Exceptions |
56 | , ChannelSide (..) | 61 | , ChannelSide (..) |
57 | , ProtocolError (..) | 62 | , ProtocolError (..) |
58 | , WireFailure (..) | 63 | , WireFailure (..) |
@@ -60,15 +65,15 @@ module Network.BitTorrent.Exchange.Wire | |||
60 | , isWireFailure | 65 | , isWireFailure |
61 | , disconnectPeer | 66 | , disconnectPeer |
62 | 67 | ||
63 | -- ** Stats | 68 | -- * Stats |
64 | , ByteStats (..) | 69 | , ByteStats (..) |
65 | , FlowStats (..) | 70 | , FlowStats (..) |
66 | , ConnectionStats (..) | 71 | , ConnectionStats (..) |
67 | 72 | ||
68 | -- ** Flood detection | 73 | -- * Flood detection |
69 | , FloodDetector (..) | 74 | , FloodDetector (..) |
70 | 75 | ||
71 | -- ** Options | 76 | -- * Options |
72 | , Options (..) | 77 | , Options (..) |
73 | ) where | 78 | ) where |
74 | 79 | ||
@@ -449,6 +454,10 @@ data ConnectionState = ConnectionState { | |||
449 | -- | If @not (allowed ExtExtended connCaps)@ then this set is always | 454 | -- | If @not (allowed ExtExtended connCaps)@ then this set is always |
450 | -- empty. Otherwise it has the BEP10 extension protocol mandated mapping of | 455 | -- empty. Otherwise it has the BEP10 extension protocol mandated mapping of |
451 | -- 'MessageId' to the message type for the remote peer. | 456 | -- 'MessageId' to the message type for the remote peer. |
457 | -- | ||
458 | -- Note that this value can change in current session if either | ||
459 | -- this or remote peer will initiate rehandshaking. | ||
460 | -- | ||
452 | _connExtCaps :: !ExtendedCaps | 461 | _connExtCaps :: !ExtendedCaps |
453 | 462 | ||
454 | -- | Current extended handshake information from the remote peer | 463 | -- | Current extended handshake information from the remote peer |
@@ -456,6 +465,9 @@ data ConnectionState = ConnectionState { | |||
456 | 465 | ||
457 | -- | Various stats about messages sent and received. Stats can be | 466 | -- | Various stats about messages sent and received. Stats can be |
458 | -- used to protect /this/ peer against flood attacks. | 467 | -- used to protect /this/ peer against flood attacks. |
468 | -- | ||
469 | -- Note that this value will change with the next sent or received | ||
470 | -- message. | ||
459 | , _connStats :: !ConnectionStats | 471 | , _connStats :: !ConnectionStats |
460 | 472 | ||
461 | , _connStatus :: !ConnectionStatus | 473 | , _connStatus :: !ConnectionStatus |
@@ -565,70 +577,40 @@ initiateHandshake sock hs = do | |||
565 | -----------------------------------------------------------------------} | 577 | -----------------------------------------------------------------------} |
566 | 578 | ||
567 | -- | do not expose this so we can change it without breaking api | 579 | -- | do not expose this so we can change it without breaking api |
568 | newtype Connected s m a = Connected { runConnected :: (ReaderT (Connection s) m a) } | 580 | newtype Connected s a = Connected { runConnected :: (ReaderT (Connection s) IO a) } |
569 | deriving (Functor, Applicative, Monad | 581 | deriving (Functor, Applicative, Monad |
570 | , MonadIO, MonadReader (Connection s), MonadThrow | 582 | , MonadIO, MonadReader (Connection s), MonadThrow |
571 | ) | 583 | ) |
572 | 584 | ||
573 | instance MonadIO m => MonadState ConnectionState (Connected s m) where | 585 | instance MonadState ConnectionState (Connected s) where |
574 | get = Connected (asks connState) >>= liftIO . readIORef | 586 | get = Connected (asks connState) >>= liftIO . readIORef |
575 | put x = Connected (asks connState) >>= liftIO . flip writeIORef x | 587 | put x = Connected (asks connState) >>= liftIO . flip writeIORef x |
576 | 588 | ||
577 | instance MonadTrans (Connected s) where | ||
578 | lift = Connected . lift | ||
579 | |||
580 | -- | A duplex channel connected to a remote peer which keep tracks | 589 | -- | A duplex channel connected to a remote peer which keep tracks |
581 | -- connection parameters. | 590 | -- connection parameters. |
582 | type Wire s a = ConduitM Message Message (Connected s IO) a | 591 | type Wire s a = ConduitM Message Message (Connected s) a |
583 | 592 | ||
584 | {----------------------------------------------------------------------- | 593 | {----------------------------------------------------------------------- |
585 | -- Query | 594 | -- Query |
586 | -----------------------------------------------------------------------} | 595 | -----------------------------------------------------------------------} |
587 | 596 | ||
588 | setExtCaps :: ExtendedCaps -> Wire s () | ||
589 | setExtCaps x = lift $ connExtCaps .= x | ||
590 | |||
591 | -- | Get current extended capabilities. Note that this value can | ||
592 | -- change in current session if either this or remote peer will | ||
593 | -- initiate rehandshaking. | ||
594 | getExtCaps :: Wire s ExtendedCaps | ||
595 | getExtCaps = lift $ use connExtCaps | ||
596 | |||
597 | setRemoteEhs :: ExtendedHandshake -> Wire s () | ||
598 | setRemoteEhs x = lift $ connRemoteEhs .= x | ||
599 | |||
600 | getRemoteEhs :: Wire s ExtendedHandshake | ||
601 | getRemoteEhs = lift $ use connRemoteEhs | ||
602 | |||
603 | -- | Get current stats. Note that this value will change with the next | ||
604 | -- sent or received message. | ||
605 | getStats :: Wire s ConnectionStats | ||
606 | getStats = lift $ use connStats | ||
607 | |||
608 | -- | See the 'Connection' section for more info. | ||
609 | getConnection :: Wire s (Connection s) | ||
610 | getConnection = lift ask | ||
611 | |||
612 | getSession :: Wire s s | ||
613 | getSession = lift (asks connSession) | ||
614 | |||
615 | -- TODO configurable | 597 | -- TODO configurable |
616 | defQueueLength :: Int | 598 | defQueueLength :: Int |
617 | defQueueLength = 1 | 599 | defQueueLength = 1 |
618 | 600 | ||
619 | getAdvertisedQueueLength :: Wire s Int | 601 | getAdvertisedQueueLength :: Connected s Int |
620 | getAdvertisedQueueLength = do | 602 | getAdvertisedQueueLength = do |
621 | ExtendedHandshake {..} <- getRemoteEhs | 603 | ExtendedHandshake {..} <- use connRemoteEhs |
622 | return $ fromMaybe defQueueLength ehsQueueLength | 604 | return $ fromMaybe defQueueLength ehsQueueLength |
623 | 605 | ||
624 | {----------------------------------------------------------------------- | 606 | {----------------------------------------------------------------------- |
625 | -- Wrapper | 607 | -- Wrapper |
626 | -----------------------------------------------------------------------} | 608 | -----------------------------------------------------------------------} |
627 | 609 | ||
628 | putStats :: ChannelSide -> Message -> Connected s IO () | 610 | putStats :: ChannelSide -> Message -> Connected s () |
629 | putStats side msg = connStats %= addStats side (stats msg) | 611 | putStats side msg = connStats %= addStats side (stats msg) |
630 | 612 | ||
631 | validate :: ChannelSide -> Message -> Connected s IO () | 613 | validate :: ChannelSide -> Message -> Connected s () |
632 | validate side msg = do | 614 | validate side msg = do |
633 | caps <- asks connCaps | 615 | caps <- asks connCaps |
634 | case requires msg of | 616 | case requires msg of |
@@ -696,8 +678,8 @@ extendedHandshake caps = do | |||
696 | msg <- recvMessage | 678 | msg <- recvMessage |
697 | case msg of | 679 | case msg of |
698 | Extended (EHandshake remoteEhs@(ExtendedHandshake {..})) -> do | 680 | Extended (EHandshake remoteEhs@(ExtendedHandshake {..})) -> do |
699 | setExtCaps $ ehsCaps <> caps | 681 | connExtCaps .= (ehsCaps <> caps) |
700 | setRemoteEhs remoteEhs | 682 | connRemoteEhs .= remoteEhs |
701 | _ -> protocolError HandshakeRefused | 683 | _ -> protocolError HandshakeRefused |
702 | 684 | ||
703 | rehandshake :: ExtendedCaps -> Wire s () | 685 | rehandshake :: ExtendedCaps -> Wire s () |
@@ -776,29 +758,9 @@ acceptWire sock peerAddr wire = do | |||
776 | bracket (return sock) close $ \ _ -> do | 758 | bracket (return sock) close $ \ _ -> do |
777 | error "acceptWire: not implemented" | 759 | error "acceptWire: not implemented" |
778 | 760 | ||
779 | {----------------------------------------------------------------------- | ||
780 | -- Connection Status | ||
781 | -----------------------------------------------------------------------} | ||
782 | |||
783 | getStatus :: Wire s ConnectionStatus | ||
784 | getStatus = lift $ use connStatus | ||
785 | |||
786 | updateConnStatus :: ChannelSide -> StatusUpdate -> Wire s () | ||
787 | updateConnStatus side u = lift $ do | ||
788 | connStatus %= (over (statusSide side) (updateStatus u)) | ||
789 | where | ||
790 | statusSide ThisPeer = clientStatus | ||
791 | statusSide RemotePeer = remoteStatus | ||
792 | |||
793 | getRemoteBitfield :: Wire s Bitfield | ||
794 | getRemoteBitfield = lift $ use connBitfield | ||
795 | |||
796 | updateRemoteBitfield :: (Bitfield -> Bitfield) -> Wire s () | ||
797 | updateRemoteBitfield f = lift $ connBitfield %= f | ||
798 | |||
799 | -- | Used when size of bitfield becomes known. | 761 | -- | Used when size of bitfield becomes known. |
800 | resizeBitfield :: Int -> Wire s () | 762 | resizeBitfield :: Int -> Connected s () |
801 | resizeBitfield n = updateRemoteBitfield (adjustSize n) | 763 | resizeBitfield n = connBitfield %= adjustSize n |
802 | 764 | ||
803 | {----------------------------------------------------------------------- | 765 | {----------------------------------------------------------------------- |
804 | -- Metadata exchange | 766 | -- Metadata exchange |
@@ -833,7 +795,7 @@ fetchMetadata = loop 0 | |||
833 | getMetadata :: Wire s InfoDict | 795 | getMetadata :: Wire s InfoDict |
834 | getMetadata = do | 796 | getMetadata = do |
835 | chunks <- fetchMetadata | 797 | chunks <- fetchMetadata |
836 | Connection {..} <- getConnection | 798 | Connection {..} <- ask |
837 | case BE.decode (BS.concat chunks) of | 799 | case BE.decode (BS.concat chunks) of |
838 | Right (infodict @ InfoDict {..}) | 800 | Right (infodict @ InfoDict {..}) |
839 | | connTopic == idInfoHash -> return infodict | 801 | | connTopic == idInfoHash -> return infodict |