diff options
Diffstat (limited to 'src/Network/BitTorrent')
-rw-r--r-- | src/Network/BitTorrent/Tracker/RPC/UDP.hs | 158 |
1 files changed, 80 insertions, 78 deletions
diff --git a/src/Network/BitTorrent/Tracker/RPC/UDP.hs b/src/Network/BitTorrent/Tracker/RPC/UDP.hs index d7b359ed..a835dc23 100644 --- a/src/Network/BitTorrent/Tracker/RPC/UDP.hs +++ b/src/Network/BitTorrent/Tracker/RPC/UDP.hs | |||
@@ -40,6 +40,7 @@ import Data.Serialize | |||
40 | import Data.Text as T | 40 | import Data.Text as T |
41 | import Data.Text.Encoding | 41 | import Data.Text.Encoding |
42 | import Data.Time | 42 | import Data.Time |
43 | import Data.Time.Clock.POSIX | ||
43 | import Data.Word | 44 | import Data.Word |
44 | import Text.Read (readMaybe) | 45 | import Text.Read (readMaybe) |
45 | import Network.Socket hiding (Connected, connect) | 46 | import Network.Socket hiding (Connected, connect) |
@@ -64,15 +65,21 @@ defMinTimeout = 15 * sec | |||
64 | defMaxTimeout :: Int | 65 | defMaxTimeout :: Int |
65 | defMaxTimeout = 15 * 2 ^ (8 :: Int) * sec | 66 | defMaxTimeout = 15 * 2 ^ (8 :: Int) * sec |
66 | 67 | ||
68 | -- announce request packet | ||
69 | defMaxPacketSize :: Int | ||
70 | defMaxPacketSize = 98 | ||
71 | |||
67 | data Options = Options | 72 | data Options = Options |
68 | { optMinTimeout :: {-# UNPACK #-} !Int | 73 | { optMaxPacketSize :: {-# UNPACK #-} !Int |
69 | , optMaxTimeout :: {-# UNPACK #-} !Int | 74 | , optMinTimeout :: {-# UNPACK #-} !Int |
75 | , optMaxTimeout :: {-# UNPACK #-} !Int | ||
70 | } deriving (Show, Eq) | 76 | } deriving (Show, Eq) |
71 | 77 | ||
72 | instance Default Options where | 78 | instance Default Options where |
73 | def = Options | 79 | def = Options |
74 | { optMinTimeout = defMinTimeout | 80 | { optMaxPacketSize = defMaxPacketSize |
75 | , optMaxTimeout = defMaxTimeout | 81 | , optMinTimeout = defMinTimeout |
82 | , optMaxTimeout = defMaxTimeout | ||
76 | } | 83 | } |
77 | 84 | ||
78 | data Manager = Manager | 85 | data Manager = Manager |
@@ -134,9 +141,6 @@ newtype ConnectionId = ConnectionId Word64 | |||
134 | instance Show ConnectionId where | 141 | instance Show ConnectionId where |
135 | showsPrec _ (ConnectionId cid) = showString "0x" <> showHex cid | 142 | showsPrec _ (ConnectionId cid) = showString "0x" <> showHex cid |
136 | 143 | ||
137 | genConnectionId :: IO ConnectionId | ||
138 | genConnectionId = ConnectionId <$> genToken | ||
139 | |||
140 | initialConnectionId :: ConnectionId | 144 | initialConnectionId :: ConnectionId |
141 | initialConnectionId = ConnectionId 0x41727101980 | 145 | initialConnectionId = ConnectionId 0x41727101980 |
142 | 146 | ||
@@ -266,16 +270,17 @@ instance Serialize (Transaction Response) where | |||
266 | connectionLifetime :: NominalDiffTime | 270 | connectionLifetime :: NominalDiffTime |
267 | connectionLifetime = 60 | 271 | connectionLifetime = 60 |
268 | 272 | ||
269 | connectionLifetimeServer :: NominalDiffTime | ||
270 | connectionLifetimeServer = 120 | ||
271 | |||
272 | data Connection = Connection | 273 | data Connection = Connection |
273 | { connectionId :: ConnectionId | 274 | { connectionId :: ConnectionId |
274 | , connectionTimestamp :: UTCTime | 275 | , connectionTimestamp :: UTCTime |
275 | } deriving Show | 276 | } deriving Show |
277 | |||
278 | -- placeholder for the first 'connect' | ||
279 | initialConnection :: Connection | ||
280 | initialConnection = Connection initialConnectionId (posixSecondsToUTCTime 0) | ||
276 | 281 | ||
277 | initialConnection :: IO Connection | 282 | establishedConnection :: ConnectionId -> IO Connection |
278 | initialConnection = Connection initialConnectionId <$> getCurrentTime | 283 | establishedConnection cid = Connection cid <$> getCurrentTime |
279 | 284 | ||
280 | isExpired :: Connection -> IO Bool | 285 | isExpired :: Connection -> IO Bool |
281 | isExpired Connection {..} = do | 286 | isExpired Connection {..} = do |
@@ -284,46 +289,21 @@ isExpired Connection {..} = do | |||
284 | return $ timeDiff > connectionLifetime | 289 | return $ timeDiff > connectionLifetime |
285 | 290 | ||
286 | {----------------------------------------------------------------------- | 291 | {----------------------------------------------------------------------- |
287 | RPC | 292 | -- Basic transaction |
288 | -----------------------------------------------------------------------} | 293 | -----------------------------------------------------------------------} |
289 | 294 | ||
290 | maxPacketSize :: Int | ||
291 | maxPacketSize = 98 -- announce request packet | ||
292 | |||
293 | call :: Manager -> SockAddr -> ByteString -> IO ByteString | 295 | call :: Manager -> SockAddr -> ByteString -> IO ByteString |
294 | call Manager {..} addr arg = do | 296 | call Manager {..} addr arg = do |
295 | BS.sendAllTo sock arg addr | 297 | BS.sendAllTo sock arg addr |
296 | (res, addr') <- BS.recvFrom sock maxPacketSize | 298 | (res, addr') <- BS.recvFrom sock (optMaxPacketSize options) |
297 | unless (addr' == addr) $ do | 299 | unless (addr' == addr) $ do |
298 | throwIO $ userError "address mismatch" | 300 | throwIO $ userError "address mismatch" |
299 | return res | 301 | return res |
300 | 302 | ||
301 | data UDPTracker = UDPTracker | 303 | transaction :: Manager -> SockAddr -> Connection -> Request -> IO Response |
302 | { trackerURI :: URI | 304 | transaction m addr conn request = do |
303 | , trackerConnection :: IORef Connection | ||
304 | } | ||
305 | |||
306 | updateConnection :: ConnectionId -> UDPTracker -> IO () | ||
307 | updateConnection cid UDPTracker {..} = do | ||
308 | newConnection <- Connection cid <$> getCurrentTime | ||
309 | writeIORef trackerConnection newConnection | ||
310 | |||
311 | getConnectionId :: UDPTracker -> IO ConnectionId | ||
312 | getConnectionId UDPTracker {..} | ||
313 | = connectionId <$> readIORef trackerConnection | ||
314 | |||
315 | putTracker :: UDPTracker -> IO () | ||
316 | putTracker UDPTracker {..} = do | ||
317 | print trackerURI | ||
318 | print =<< readIORef trackerConnection | ||
319 | |||
320 | transaction :: Manager -> UDPTracker -> Request -> IO Response | ||
321 | transaction m tracker @ UDPTracker {..} request = do | ||
322 | cid <- getConnectionId tracker | ||
323 | tid <- genTransactionId | 305 | tid <- genTransactionId |
324 | let trans = TransactionQ cid tid request | 306 | let trans = TransactionQ (connectionId conn) tid request |
325 | |||
326 | addr <- getTrackerAddr m trackerURI | ||
327 | res <- call m addr (encode trans) | 307 | res <- call m addr (encode trans) |
328 | case decode res of | 308 | case decode res of |
329 | Right (TransactionR {..}) | 309 | Right (TransactionR {..}) |
@@ -331,47 +311,48 @@ transaction m tracker @ UDPTracker {..} request = do | |||
331 | | otherwise -> throwIO $ userError "transaction id mismatch" | 311 | | otherwise -> throwIO $ userError "transaction id mismatch" |
332 | Left msg -> throwIO $ userError msg | 312 | Left msg -> throwIO $ userError msg |
333 | 313 | ||
334 | connectUDP :: Manager -> UDPTracker -> IO ConnectionId | 314 | {----------------------------------------------------------------------- |
335 | connectUDP m tracker = do | 315 | -- Connection cache |
336 | resp <- transaction m tracker Connect | 316 | -----------------------------------------------------------------------} |
317 | |||
318 | connect :: Manager -> SockAddr -> Connection -> IO ConnectionId | ||
319 | connect m addr conn = do | ||
320 | resp <- transaction m addr conn Connect | ||
337 | case resp of | 321 | case resp of |
338 | Connected cid -> return cid | 322 | Connected cid -> return cid |
339 | Failed msg -> throwIO $ userError $ T.unpack msg | 323 | Failed msg -> throwIO $ userError $ T.unpack msg |
340 | _ -> throwIO $ userError "connect: response type mismatch" | 324 | _ -> throwIO $ userError "connect: response type mismatch" |
341 | 325 | ||
342 | connect :: Manager -> URI -> IO UDPTracker | 326 | newConnection :: Manager -> SockAddr -> IO Connection |
343 | connect m uri = do | 327 | newConnection m addr = do |
344 | tracker <- UDPTracker uri <$> (newIORef =<< initialConnection) | 328 | connId <- connect m addr initialConnection |
345 | connId <- connectUDP m tracker | 329 | establishedConnection connId |
346 | updateConnection connId tracker | ||
347 | return tracker | ||
348 | 330 | ||
349 | freshConnection :: Manager -> UDPTracker -> IO () | 331 | refreshConnection :: Manager -> SockAddr -> Connection -> IO Connection |
350 | freshConnection m tracker @ UDPTracker {..} = do | 332 | refreshConnection mgr addr conn = do |
351 | conn <- readIORef trackerConnection | ||
352 | expired <- isExpired conn | 333 | expired <- isExpired conn |
353 | when expired $ do | 334 | if expired |
354 | connId <- connectUDP m tracker | 335 | then do |
355 | updateConnection connId tracker | 336 | connId <- connect mgr addr conn |
356 | 337 | establishedConnection connId | |
357 | getConnection :: Manager -> URI -> IO Connection | 338 | else do |
358 | getConnection _ = undefined | 339 | return conn |
359 | 340 | ||
360 | announce :: Manager -> AnnounceQuery -> UDPTracker -> IO AnnounceInfo | 341 | withCache :: Manager -> SockAddr |
361 | announce m ann tracker = do | 342 | -> (Maybe Connection -> IO Connection) -> IO Connection |
362 | freshConnection m tracker | 343 | withCache mgr addr action = do |
363 | resp <- transaction m tracker (Announce ann) | 344 | cache <- readIORef (connectionCache mgr) |
364 | case resp of | 345 | conn <- action (M.lookup addr cache) |
365 | Announced info -> return info | 346 | writeIORef (connectionCache mgr) (M.insert addr conn cache) |
366 | _ -> fail "announce: response type mismatch" | 347 | return conn |
348 | |||
349 | getConnection :: Manager -> SockAddr -> IO Connection | ||
350 | getConnection mgr addr = withCache mgr addr $ | ||
351 | maybe (newConnection mgr addr) (refreshConnection mgr addr) | ||
367 | 352 | ||
368 | scrape :: Manager -> ScrapeQuery -> UDPTracker -> IO ScrapeInfo | 353 | {----------------------------------------------------------------------- |
369 | scrape m ihs tracker = do | 354 | -- RPC |
370 | freshConnection m tracker | 355 | -----------------------------------------------------------------------} |
371 | resp <- transaction m tracker (Scrape ihs) | ||
372 | case resp of | ||
373 | Scraped info -> return $ L.zip ihs info | ||
374 | _ -> fail "scrape: response type mismatch" | ||
375 | 356 | ||
376 | retransmission :: Options -> IO a -> IO a | 357 | retransmission :: Options -> IO a -> IO a |
377 | retransmission Options {..} action = go optMinTimeout | 358 | retransmission Options {..} action = go optMinTimeout |
@@ -381,3 +362,24 @@ retransmission Options {..} action = go optMinTimeout | |||
381 | | otherwise = do | 362 | | otherwise = do |
382 | r <- timeout curTimeout action | 363 | r <- timeout curTimeout action |
383 | maybe (go (2 * curTimeout)) return r | 364 | maybe (go (2 * curTimeout)) return r |
365 | |||
366 | queryTracker :: Manager -> URI -> Request -> IO Response | ||
367 | queryTracker mgr uri req = do | ||
368 | addr <- getTrackerAddr mgr uri | ||
369 | retransmission (options mgr) $ do | ||
370 | conn <- getConnection mgr addr | ||
371 | transaction mgr addr conn req | ||
372 | |||
373 | announce :: Manager -> URI -> AnnounceQuery -> IO AnnounceInfo | ||
374 | announce mgr uri q = do | ||
375 | resp <- queryTracker mgr uri (Announce q) | ||
376 | case resp of | ||
377 | Announced info -> return info | ||
378 | _ -> fail "announce: response type mismatch" | ||
379 | |||
380 | scrape :: Manager -> URI -> ScrapeQuery -> IO ScrapeInfo | ||
381 | scrape mgr uri ihs = do | ||
382 | resp <- queryTracker mgr uri (Scrape ihs) | ||
383 | case resp of | ||
384 | Scraped info -> return $ L.zip ihs info | ||
385 | _ -> fail "scrape: response type mismatch" | ||