diff options
author | Sam Truzjan <pxqr.sta@gmail.com> | 2013-12-05 20:00:01 +0400 |
---|---|---|
committer | Sam Truzjan <pxqr.sta@gmail.com> | 2013-12-05 20:00:01 +0400 |
commit | 2367c9c0ab45a843cb4d1b8762af4ca06291348c (patch) | |
tree | 5124617196b8558cf876d62e2dd7f7052c49fccb | |
parent | 6f092fb275367b6afe4f0745f975e8ee53012d56 (diff) |
Make extended caps mutable
-rw-r--r-- | src/Network/BitTorrent/Exchange/Message.hs | 7 | ||||
-rw-r--r-- | src/Network/BitTorrent/Exchange/Wire.hs | 73 |
2 files changed, 62 insertions, 18 deletions
diff --git a/src/Network/BitTorrent/Exchange/Message.hs b/src/Network/BitTorrent/Exchange/Message.hs index b879e212..17ec7da6 100644 --- a/src/Network/BitTorrent/Exchange/Message.hs +++ b/src/Network/BitTorrent/Exchange/Message.hs | |||
@@ -62,6 +62,7 @@ module Network.BitTorrent.Exchange.Message | |||
62 | , ExtendedExtension | 62 | , ExtendedExtension |
63 | , ExtendedCaps (..) | 63 | , ExtendedCaps (..) |
64 | , ExtendedHandshake (..) | 64 | , ExtendedHandshake (..) |
65 | , nullExtendedHandshake | ||
65 | , ExtendedMetadata (..) | 66 | , ExtendedMetadata (..) |
66 | ) where | 67 | ) where |
67 | 68 | ||
@@ -411,7 +412,7 @@ data ExtendedHandshake = ExtendedHandshake | |||
411 | } deriving (Show, Eq, Typeable) | 412 | } deriving (Show, Eq, Typeable) |
412 | 413 | ||
413 | instance Default ExtendedHandshake where | 414 | instance Default ExtendedHandshake where |
414 | def = ExtendedHandshake Nothing Nothing def Nothing Nothing Nothing | 415 | def = nullExtendedHandshake def |
415 | 416 | ||
416 | instance BEncode ExtendedHandshake where | 417 | instance BEncode ExtendedHandshake where |
417 | toBEncode ExtendedHandshake {..} = toDict $ | 418 | toBEncode ExtendedHandshake {..} = toDict $ |
@@ -439,6 +440,10 @@ instance Pretty ExtendedHandshake where | |||
439 | instance PeerMessage ExtendedHandshake where | 440 | instance PeerMessage ExtendedHandshake where |
440 | envelop c = envelop c . EHandshake | 441 | envelop c = envelop c . EHandshake |
441 | 442 | ||
443 | nullExtendedHandshake :: ExtendedCaps -> ExtendedHandshake | ||
444 | nullExtendedHandshake caps | ||
445 | = ExtendedHandshake Nothing Nothing caps Nothing Nothing Nothing | ||
446 | |||
442 | {----------------------------------------------------------------------- | 447 | {----------------------------------------------------------------------- |
443 | -- Metadata exchange | 448 | -- Metadata exchange |
444 | -----------------------------------------------------------------------} | 449 | -----------------------------------------------------------------------} |
diff --git a/src/Network/BitTorrent/Exchange/Wire.hs b/src/Network/BitTorrent/Exchange/Wire.hs index 68c9b355..6a161762 100644 --- a/src/Network/BitTorrent/Exchange/Wire.hs +++ b/src/Network/BitTorrent/Exchange/Wire.hs | |||
@@ -4,20 +4,29 @@ | |||
4 | -- Duplex channell | 4 | -- Duplex channell |
5 | -- This module control /integrity/ of data send and received. | 5 | -- This module control /integrity/ of data send and received. |
6 | -- | 6 | -- |
7 | 7 | -- | |
8 | {-# LANGUAGE DeriveDataTypeable #-} | 8 | {-# LANGUAGE DeriveDataTypeable #-} |
9 | module Network.BitTorrent.Exchange.Wire | 9 | module Network.BitTorrent.Exchange.Wire |
10 | ( -- * Exception | 10 | ( -- * Wire |
11 | ProtocolError (..) | 11 | Wire |
12 | |||
13 | -- ** Exceptions | ||
14 | , ProtocolError (..) | ||
12 | , WireFailure (..) | 15 | , WireFailure (..) |
13 | , isWireFailure | 16 | , isWireFailure |
17 | , disconnectPeer | ||
14 | 18 | ||
15 | -- * Wire | 19 | -- ** Connection |
16 | , Connection (..) | 20 | , Connection (..) |
17 | , Wire | 21 | |
22 | -- ** Setup | ||
18 | , runWire | 23 | , runWire |
19 | , connectWire | 24 | , connectWire |
20 | , acceptWire | 25 | , acceptWire |
26 | |||
27 | -- ** Query | ||
28 | , getConnection | ||
29 | , getExtCaps | ||
21 | ) where | 30 | ) where |
22 | 31 | ||
23 | import Control.Exception | 32 | import Control.Exception |
@@ -27,6 +36,7 @@ import Data.Conduit | |||
27 | import Data.Conduit.Cereal as S | 36 | import Data.Conduit.Cereal as S |
28 | import Data.Conduit.Network | 37 | import Data.Conduit.Network |
29 | import Data.Default | 38 | import Data.Default |
39 | import Data.IORef | ||
30 | import Data.Maybe | 40 | import Data.Maybe |
31 | import Data.Monoid | 41 | import Data.Monoid |
32 | import Data.Serialize as S | 42 | import Data.Serialize as S |
@@ -56,6 +66,7 @@ data ProtocolError | |||
56 | = UnexpectedTopic InfoHash -- ^ peer replied with unexpected infohash. | 66 | = UnexpectedTopic InfoHash -- ^ peer replied with unexpected infohash. |
57 | | UnexpectedPeerId PeerId -- ^ peer replied with unexpected peer id. | 67 | | UnexpectedPeerId PeerId -- ^ peer replied with unexpected peer id. |
58 | | UnknownTopic InfoHash -- ^ peer requested unknown torrent. | 68 | | UnknownTopic InfoHash -- ^ peer requested unknown torrent. |
69 | | HandshakeRefused -- ^ peer do not send an extended handshake back. | ||
59 | | InvalidMessage | 70 | | InvalidMessage |
60 | { violentSender :: ChannelSide -- ^ endpoint sent invalid message | 71 | { violentSender :: ChannelSide -- ^ endpoint sent invalid message |
61 | , extensionRequired :: Extension -- ^ | 72 | , extensionRequired :: Extension -- ^ |
@@ -87,11 +98,11 @@ isWireFailure _ = return () | |||
87 | 98 | ||
88 | data Connection = Connection | 99 | data Connection = Connection |
89 | { connCaps :: !Caps | 100 | { connCaps :: !Caps |
90 | , connExtCaps :: !ExtendedCaps -- TODO caps can be enabled during communication | 101 | , connExtCaps :: !(IORef ExtendedCaps) |
91 | , connTopic :: !InfoHash | 102 | , connTopic :: !InfoHash |
92 | , connRemotePeerId :: !PeerId | 103 | , connRemotePeerId :: !PeerId |
93 | , connThisPeerId :: !PeerId | 104 | , connThisPeerId :: !PeerId |
94 | } deriving Show | 105 | } |
95 | 106 | ||
96 | instance Pretty Connection where | 107 | instance Pretty Connection where |
97 | pretty Connection {..} = "Connection" | 108 | pretty Connection {..} = "Connection" |
@@ -146,6 +157,25 @@ connectToPeer p = do | |||
146 | 157 | ||
147 | type Wire = ConduitM Message Message (ReaderT Connection IO) | 158 | type Wire = ConduitM Message Message (ReaderT Connection IO) |
148 | 159 | ||
160 | protocolError :: ProtocolError -> Wire a | ||
161 | protocolError = monadThrow . ProtocolError | ||
162 | |||
163 | disconnectPeer :: Wire a | ||
164 | disconnectPeer = monadThrow DisconnectPeer | ||
165 | |||
166 | getExtCaps :: Wire ExtendedCaps | ||
167 | getExtCaps = do | ||
168 | capsRef <- lift $ asks connExtCaps | ||
169 | liftIO $ readIORef capsRef | ||
170 | |||
171 | setExtCaps :: ExtendedCaps -> Wire () | ||
172 | setExtCaps caps = do | ||
173 | capsRef <- lift $ asks connExtCaps | ||
174 | liftIO $ writeIORef capsRef caps | ||
175 | |||
176 | getConnection :: Wire Connection | ||
177 | getConnection = lift ask | ||
178 | |||
149 | validate :: ChannelSide -> Wire () | 179 | validate :: ChannelSide -> Wire () |
150 | validate side = await >>= maybe (return ()) yieldCheck | 180 | validate side = await >>= maybe (return ()) yieldCheck |
151 | where | 181 | where |
@@ -155,9 +185,10 @@ validate side = await >>= maybe (return ()) yieldCheck | |||
155 | Nothing -> return () | 185 | Nothing -> return () |
156 | Just ext | 186 | Just ext |
157 | | allowed caps ext -> yield msg | 187 | | allowed caps ext -> yield msg |
158 | | otherwise -> monadThrow $ ProtocolError $ InvalidMessage side ext | 188 | | otherwise -> protocolError $ InvalidMessage side ext |
159 | 189 | ||
160 | validate' action = do | 190 | validateBoth :: Wire () -> Wire () |
191 | validateBoth action = do | ||
161 | validate RemotePeer | 192 | validate RemotePeer |
162 | action | 193 | action |
163 | validate ThisPeer | 194 | validate ThisPeer |
@@ -172,17 +203,23 @@ runWire action sock = runReaderT $ | |||
172 | 203 | ||
173 | sendMessage :: PeerMessage msg => msg -> Wire () | 204 | sendMessage :: PeerMessage msg => msg -> Wire () |
174 | sendMessage msg = do | 205 | sendMessage msg = do |
175 | ecaps <- lift $ asks connExtCaps | 206 | ecaps <- getExtCaps |
176 | yield $ envelop ecaps msg | 207 | yield $ envelop ecaps msg |
177 | 208 | ||
178 | recvMessage :: Wire Message | 209 | recvMessage :: Wire Message |
179 | recvMessage = undefined | 210 | recvMessage = undefined |
180 | 211 | ||
181 | extendedHandshake :: Wire () | 212 | extendedHandshake :: ExtendedCaps -> Wire () |
182 | extendedHandshake = undefined | 213 | extendedHandshake caps = do |
214 | sendMessage $ nullExtendedHandshake caps | ||
215 | msg <- recvMessage | ||
216 | case msg of | ||
217 | Extended (EHandshake ExtendedHandshake {..}) -> | ||
218 | setExtCaps $ ehsCaps <> caps | ||
219 | _ -> protocolError HandshakeRefused | ||
183 | 220 | ||
184 | connectWire :: Handshake -> PeerAddr -> ExtendedCaps -> Wire () -> IO () | 221 | connectWire :: Handshake -> PeerAddr -> ExtendedCaps -> Wire () -> IO () |
185 | connectWire hs addr caps wire = | 222 | connectWire hs addr extCaps wire = |
186 | bracket (connectToPeer addr) close $ \ sock -> do | 223 | bracket (connectToPeer addr) close $ \ sock -> do |
187 | hs' <- initiateHandshake sock hs | 224 | hs' <- initiateHandshake sock hs |
188 | 225 | ||
@@ -193,12 +230,14 @@ connectWire hs addr caps wire = | |||
193 | throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') | 230 | throwIO $ ProtocolError $ UnexpectedPeerId (hsPeerId hs') |
194 | 231 | ||
195 | let caps = hsReserved hs <> hsReserved hs' | 232 | let caps = hsReserved hs <> hsReserved hs' |
196 | if allowed caps ExtExtended | 233 | let wire' = if allowed caps ExtExtended |
197 | then return () else return () | 234 | then extendedHandshake extCaps >> wire |
235 | else wire | ||
198 | 236 | ||
199 | runWire wire sock $ Connection | 237 | extCapsRef <- newIORef def |
238 | runWire wire' sock $ Connection | ||
200 | { connCaps = caps | 239 | { connCaps = caps |
201 | , connExtCaps = def | 240 | , connExtCaps = extCapsRef |
202 | , connTopic = hsInfoHash hs | 241 | , connTopic = hsInfoHash hs |
203 | , connRemotePeerId = hsPeerId hs' | 242 | , connRemotePeerId = hsPeerId hs' |
204 | , connThisPeerId = hsPeerId hs | 243 | , connThisPeerId = hsPeerId hs |