diff options
Diffstat (limited to 'src/Network/BitTorrent/Tracker/RPC/UDP.hs')
-rw-r--r-- | src/Network/BitTorrent/Tracker/RPC/UDP.hs | 133 |
1 files changed, 82 insertions, 51 deletions
diff --git a/src/Network/BitTorrent/Tracker/RPC/UDP.hs b/src/Network/BitTorrent/Tracker/RPC/UDP.hs index a132524c..a3927c2c 100644 --- a/src/Network/BitTorrent/Tracker/RPC/UDP.hs +++ b/src/Network/BitTorrent/Tracker/RPC/UDP.hs | |||
@@ -14,22 +14,26 @@ | |||
14 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | 14 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} |
15 | {-# LANGUAGE TypeFamilies #-} | 15 | {-# LANGUAGE TypeFamilies #-} |
16 | module Network.BitTorrent.Tracker.RPC.UDP | 16 | module Network.BitTorrent.Tracker.RPC.UDP |
17 | ( UDPTracker | 17 | ( -- * Manager |
18 | , putTracker | 18 | Options (..) |
19 | , Manager | ||
20 | , newManager | ||
21 | , closeManager | ||
22 | , withManager | ||
19 | 23 | ||
20 | -- * RPC | 24 | -- * RPC |
21 | , connect | ||
22 | , announce | 25 | , announce |
23 | , scrape | 26 | , scrape |
24 | , retransmission | ||
25 | ) where | 27 | ) where |
26 | 28 | ||
27 | import Control.Applicative | 29 | import Control.Applicative |
28 | import Control.Exception | 30 | import Control.Exception |
29 | import Control.Monad | 31 | import Control.Monad |
30 | import Data.ByteString (ByteString) | 32 | import Data.ByteString (ByteString) |
33 | import Data.Default | ||
31 | import Data.IORef | 34 | import Data.IORef |
32 | import Data.List as L | 35 | import Data.List as L |
36 | import Data.Map as M | ||
33 | import Data.Maybe | 37 | import Data.Maybe |
34 | import Data.Monoid | 38 | import Data.Monoid |
35 | import Data.Serialize | 39 | import Data.Serialize |
@@ -48,6 +52,49 @@ import Numeric | |||
48 | import Network.BitTorrent.Tracker.Message | 52 | import Network.BitTorrent.Tracker.Message |
49 | 53 | ||
50 | {----------------------------------------------------------------------- | 54 | {----------------------------------------------------------------------- |
55 | -- Manager | ||
56 | -----------------------------------------------------------------------} | ||
57 | |||
58 | sec :: Int | ||
59 | sec = 1000000 | ||
60 | |||
61 | defMinTimeout :: Int | ||
62 | defMinTimeout = 15 * sec | ||
63 | |||
64 | defMaxTimeout :: Int | ||
65 | defMaxTimeout = 15 * 2 ^ (8 :: Int) * sec | ||
66 | |||
67 | data Options = Options | ||
68 | { optMinTimeout :: {-# UNPACK #-} !Int | ||
69 | , optMaxTimeout :: {-# UNPACK #-} !Int | ||
70 | } deriving (Show, Eq) | ||
71 | |||
72 | instance Default Options where | ||
73 | def = Options | ||
74 | { optMinTimeout = defMinTimeout | ||
75 | , optMaxTimeout = defMaxTimeout | ||
76 | } | ||
77 | |||
78 | data Manager = Manager | ||
79 | { options :: !Options | ||
80 | , sock :: !Socket | ||
81 | -- , dnsCache :: !(IORef (Map URI SockAddr)) | ||
82 | , connectionCache :: !(IORef (Map SockAddr Connection)) | ||
83 | -- , pendingResps :: !(IORef (Map Connection [MessageId])) | ||
84 | } | ||
85 | |||
86 | newManager :: Options -> IO Manager | ||
87 | newManager opts = Manager opts | ||
88 | <$> socket AF_INET Datagram defaultProtocol | ||
89 | <*> newIORef M.empty | ||
90 | |||
91 | closeManager :: Manager -> IO () | ||
92 | closeManager Manager {..} = close sock | ||
93 | |||
94 | withManager :: Options -> (Manager -> IO a) -> IO a | ||
95 | withManager opts = bracket (newManager opts) closeManager | ||
96 | |||
97 | {----------------------------------------------------------------------- | ||
51 | Tokens | 98 | Tokens |
52 | -----------------------------------------------------------------------} | 99 | -----------------------------------------------------------------------} |
53 | 100 | ||
@@ -235,16 +282,13 @@ getTrackerAddr URI { uriAuthority = Just (URIAuth {..}) } = do | |||
235 | _ -> fail "getTrackerAddr: unable to lookup host addr" | 282 | _ -> fail "getTrackerAddr: unable to lookup host addr" |
236 | getTrackerAddr _ = fail "getTrackerAddr: hostname unknown" | 283 | getTrackerAddr _ = fail "getTrackerAddr: hostname unknown" |
237 | 284 | ||
238 | call :: SockAddr -> ByteString -> IO ByteString | 285 | call :: Manager -> SockAddr -> ByteString -> IO ByteString |
239 | call addr arg = bracket open close rpc | 286 | call Manager {..} addr arg = do |
240 | where | 287 | BS.sendAllTo sock arg addr |
241 | open = socket AF_INET Datagram defaultProtocol | 288 | (res, addr') <- BS.recvFrom sock maxPacketSize |
242 | rpc sock = do | 289 | unless (addr' == addr) $ do |
243 | BS.sendAllTo sock arg addr | 290 | throwIO $ userError "address mismatch" |
244 | (res, addr') <- BS.recvFrom sock maxPacketSize | 291 | return res |
245 | unless (addr' == addr) $ do | ||
246 | throwIO $ userError "address mismatch" | ||
247 | return res | ||
248 | 292 | ||
249 | data UDPTracker = UDPTracker | 293 | data UDPTracker = UDPTracker |
250 | { trackerURI :: URI | 294 | { trackerURI :: URI |
@@ -265,77 +309,64 @@ putTracker UDPTracker {..} = do | |||
265 | print trackerURI | 309 | print trackerURI |
266 | print =<< readIORef trackerConnection | 310 | print =<< readIORef trackerConnection |
267 | 311 | ||
268 | transaction :: UDPTracker -> Request -> IO Response | 312 | transaction :: Manager -> UDPTracker -> Request -> IO Response |
269 | transaction tracker @ UDPTracker {..} request = do | 313 | transaction m tracker @ UDPTracker {..} request = do |
270 | cid <- getConnectionId tracker | 314 | cid <- getConnectionId tracker |
271 | tid <- genTransactionId | 315 | tid <- genTransactionId |
272 | let trans = TransactionQ cid tid request | 316 | let trans = TransactionQ cid tid request |
273 | 317 | ||
274 | addr <- getTrackerAddr trackerURI | 318 | addr <- getTrackerAddr trackerURI |
275 | res <- call addr (encode trans) | 319 | res <- call m addr (encode trans) |
276 | case decode res of | 320 | case decode res of |
277 | Right (TransactionR {..}) | 321 | Right (TransactionR {..}) |
278 | | tid == transIdR -> return response | 322 | | tid == transIdR -> return response |
279 | | otherwise -> throwIO $ userError "transaction id mismatch" | 323 | | otherwise -> throwIO $ userError "transaction id mismatch" |
280 | Left msg -> throwIO $ userError msg | 324 | Left msg -> throwIO $ userError msg |
281 | 325 | ||
282 | connectUDP :: UDPTracker -> IO ConnectionId | 326 | connectUDP :: Manager -> UDPTracker -> IO ConnectionId |
283 | connectUDP tracker = do | 327 | connectUDP m tracker = do |
284 | resp <- transaction tracker Connect | 328 | resp <- transaction m tracker Connect |
285 | case resp of | 329 | case resp of |
286 | Connected cid -> return cid | 330 | Connected cid -> return cid |
287 | Failed msg -> throwIO $ userError $ T.unpack msg | 331 | Failed msg -> throwIO $ userError $ T.unpack msg |
288 | _ -> throwIO $ userError "connect: response type mismatch" | 332 | _ -> throwIO $ userError "connect: response type mismatch" |
289 | 333 | ||
290 | connect :: URI -> IO UDPTracker | 334 | connect :: Manager -> URI -> IO UDPTracker |
291 | connect uri = do | 335 | connect m uri = do |
292 | tracker <- UDPTracker uri <$> (newIORef =<< initialConnection) | 336 | tracker <- UDPTracker uri <$> (newIORef =<< initialConnection) |
293 | connId <- connectUDP tracker | 337 | connId <- connectUDP m tracker |
294 | updateConnection connId tracker | 338 | updateConnection connId tracker |
295 | return tracker | 339 | return tracker |
296 | 340 | ||
297 | freshConnection :: UDPTracker -> IO () | 341 | freshConnection :: Manager -> UDPTracker -> IO () |
298 | freshConnection tracker @ UDPTracker {..} = do | 342 | freshConnection m tracker @ UDPTracker {..} = do |
299 | conn <- readIORef trackerConnection | 343 | conn <- readIORef trackerConnection |
300 | expired <- isExpired conn | 344 | expired <- isExpired conn |
301 | when expired $ do | 345 | when expired $ do |
302 | connId <- connectUDP tracker | 346 | connId <- connectUDP m tracker |
303 | updateConnection connId tracker | 347 | updateConnection connId tracker |
304 | 348 | ||
305 | announce :: AnnounceQuery -> UDPTracker -> IO AnnounceInfo | 349 | announce :: Manager -> AnnounceQuery -> UDPTracker -> IO AnnounceInfo |
306 | announce ann tracker = do | 350 | announce m ann tracker = do |
307 | freshConnection tracker | 351 | freshConnection m tracker |
308 | resp <- transaction tracker (Announce ann) | 352 | resp <- transaction m tracker (Announce ann) |
309 | case resp of | 353 | case resp of |
310 | Announced info -> return info | 354 | Announced info -> return info |
311 | _ -> fail "announce: response type mismatch" | 355 | _ -> fail "announce: response type mismatch" |
312 | 356 | ||
313 | scrape :: ScrapeQuery -> UDPTracker -> IO ScrapeInfo | 357 | scrape :: Manager -> ScrapeQuery -> UDPTracker -> IO ScrapeInfo |
314 | scrape ihs tracker = do | 358 | scrape m ihs tracker = do |
315 | freshConnection tracker | 359 | freshConnection m tracker |
316 | resp <- transaction tracker (Scrape ihs) | 360 | resp <- transaction m tracker (Scrape ihs) |
317 | case resp of | 361 | case resp of |
318 | Scraped info -> return $ L.zip ihs info | 362 | Scraped info -> return $ L.zip ihs info |
319 | _ -> fail "scrape: response type mismatch" | 363 | _ -> fail "scrape: response type mismatch" |
320 | 364 | ||
321 | {----------------------------------------------------------------------- | 365 | retransmission :: Options -> IO a -> IO a |
322 | Retransmission | 366 | retransmission Options {..} action = go optMinTimeout |
323 | -----------------------------------------------------------------------} | ||
324 | |||
325 | sec :: Int | ||
326 | sec = 1000000 | ||
327 | |||
328 | minTimeout :: Int | ||
329 | minTimeout = 15 * sec | ||
330 | |||
331 | maxTimeout :: Int | ||
332 | maxTimeout = 15 * 2 ^ (8 :: Int) * sec | ||
333 | |||
334 | retransmission :: IO a -> IO a | ||
335 | retransmission action = go minTimeout | ||
336 | where | 367 | where |
337 | go curTimeout | 368 | go curTimeout |
338 | | maxTimeout < curTimeout = throwIO $ userError "tracker down" | 369 | | curTimeout > optMaxTimeout = throwIO $ userError "tracker down" |
339 | | otherwise = do | 370 | | otherwise = do |
340 | r <- timeout curTimeout action | 371 | r <- timeout curTimeout action |
341 | maybe (go (2 * curTimeout)) return r | 372 | maybe (go (2 * curTimeout)) return r |