From 967e902b869e7d9f3143df87d4d396e5e611cfd6 Mon Sep 17 00:00:00 2001 From: Sam Truzjan Date: Tue, 4 Feb 2014 01:12:06 +0400 Subject: Use a single socket in all UDP tracker queries --- src/Network/BitTorrent/Tracker/RPC/UDP.hs | 133 ++++++++++++++++++------------ 1 file changed, 82 insertions(+), 51 deletions(-) (limited to 'src/Network/BitTorrent/Tracker/RPC/UDP.hs') diff --git a/src/Network/BitTorrent/Tracker/RPC/UDP.hs b/src/Network/BitTorrent/Tracker/RPC/UDP.hs index a132524c..a3927c2c 100644 --- a/src/Network/BitTorrent/Tracker/RPC/UDP.hs +++ b/src/Network/BitTorrent/Tracker/RPC/UDP.hs @@ -14,22 +14,26 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeFamilies #-} module Network.BitTorrent.Tracker.RPC.UDP - ( UDPTracker - , putTracker + ( -- * Manager + Options (..) + , Manager + , newManager + , closeManager + , withManager -- * RPC - , connect , announce , scrape - , retransmission ) where import Control.Applicative import Control.Exception import Control.Monad import Data.ByteString (ByteString) +import Data.Default import Data.IORef import Data.List as L +import Data.Map as M import Data.Maybe import Data.Monoid import Data.Serialize @@ -47,6 +51,49 @@ import Numeric import Network.BitTorrent.Tracker.Message +{----------------------------------------------------------------------- +-- Manager +-----------------------------------------------------------------------} + +sec :: Int +sec = 1000000 + +defMinTimeout :: Int +defMinTimeout = 15 * sec + +defMaxTimeout :: Int +defMaxTimeout = 15 * 2 ^ (8 :: Int) * sec + +data Options = Options + { optMinTimeout :: {-# UNPACK #-} !Int + , optMaxTimeout :: {-# UNPACK #-} !Int + } deriving (Show, Eq) + +instance Default Options where + def = Options + { optMinTimeout = defMinTimeout + , optMaxTimeout = defMaxTimeout + } + +data Manager = Manager + { options :: !Options + , sock :: !Socket +-- , dnsCache :: !(IORef (Map URI SockAddr)) + , connectionCache :: !(IORef (Map SockAddr Connection)) +-- , pendingResps :: !(IORef (Map Connection [MessageId])) + } + +newManager :: Options -> IO Manager +newManager opts = Manager opts + <$> socket AF_INET Datagram defaultProtocol + <*> newIORef M.empty + +closeManager :: Manager -> IO () +closeManager Manager {..} = close sock + +withManager :: Options -> (Manager -> IO a) -> IO a +withManager opts = bracket (newManager opts) closeManager + {----------------------------------------------------------------------- Tokens -----------------------------------------------------------------------} @@ -235,16 +282,13 @@ getTrackerAddr URI { uriAuthority = Just (URIAuth {..}) } = do _ -> fail "getTrackerAddr: unable to lookup host addr" getTrackerAddr _ = fail "getTrackerAddr: hostname unknown" -call :: SockAddr -> ByteString -> IO ByteString -call addr arg = bracket open close rpc - where - open = socket AF_INET Datagram defaultProtocol - rpc sock = do - BS.sendAllTo sock arg addr - (res, addr') <- BS.recvFrom sock maxPacketSize - unless (addr' == addr) $ do - throwIO $ userError "address mismatch" - return res +call :: Manager -> SockAddr -> ByteString -> IO ByteString +call Manager {..} addr arg = do + BS.sendAllTo sock arg addr + (res, addr') <- BS.recvFrom sock maxPacketSize + unless (addr' == addr) $ do + throwIO $ userError "address mismatch" + return res data UDPTracker = UDPTracker { trackerURI :: URI @@ -265,77 +309,64 @@ putTracker UDPTracker {..} = do print trackerURI print =<< readIORef trackerConnection -transaction :: UDPTracker -> Request -> IO Response -transaction tracker @ UDPTracker {..} request = do +transaction :: Manager -> UDPTracker -> Request -> IO Response +transaction m tracker @ UDPTracker {..} request = do cid <- getConnectionId tracker tid <- genTransactionId let trans = TransactionQ cid tid request addr <- getTrackerAddr trackerURI - res <- call addr (encode trans) + res <- call m addr (encode trans) case decode res of Right (TransactionR {..}) | tid == transIdR -> return response | otherwise -> throwIO $ userError "transaction id mismatch" Left msg -> throwIO $ userError msg -connectUDP :: UDPTracker -> IO ConnectionId -connectUDP tracker = do - resp <- transaction tracker Connect +connectUDP :: Manager -> UDPTracker -> IO ConnectionId +connectUDP m tracker = do + resp <- transaction m tracker Connect case resp of Connected cid -> return cid Failed msg -> throwIO $ userError $ T.unpack msg _ -> throwIO $ userError "connect: response type mismatch" -connect :: URI -> IO UDPTracker -connect uri = do +connect :: Manager -> URI -> IO UDPTracker +connect m uri = do tracker <- UDPTracker uri <$> (newIORef =<< initialConnection) - connId <- connectUDP tracker + connId <- connectUDP m tracker updateConnection connId tracker return tracker -freshConnection :: UDPTracker -> IO () -freshConnection tracker @ UDPTracker {..} = do +freshConnection :: Manager -> UDPTracker -> IO () +freshConnection m tracker @ UDPTracker {..} = do conn <- readIORef trackerConnection expired <- isExpired conn when expired $ do - connId <- connectUDP tracker + connId <- connectUDP m tracker updateConnection connId tracker -announce :: AnnounceQuery -> UDPTracker -> IO AnnounceInfo -announce ann tracker = do - freshConnection tracker - resp <- transaction tracker (Announce ann) +announce :: Manager -> AnnounceQuery -> UDPTracker -> IO AnnounceInfo +announce m ann tracker = do + freshConnection m tracker + resp <- transaction m tracker (Announce ann) case resp of Announced info -> return info _ -> fail "announce: response type mismatch" -scrape :: ScrapeQuery -> UDPTracker -> IO ScrapeInfo -scrape ihs tracker = do - freshConnection tracker - resp <- transaction tracker (Scrape ihs) +scrape :: Manager -> ScrapeQuery -> UDPTracker -> IO ScrapeInfo +scrape m ihs tracker = do + freshConnection m tracker + resp <- transaction m tracker (Scrape ihs) case resp of Scraped info -> return $ L.zip ihs info _ -> fail "scrape: response type mismatch" -{----------------------------------------------------------------------- - Retransmission ------------------------------------------------------------------------} - -sec :: Int -sec = 1000000 - -minTimeout :: Int -minTimeout = 15 * sec - -maxTimeout :: Int -maxTimeout = 15 * 2 ^ (8 :: Int) * sec - -retransmission :: IO a -> IO a -retransmission action = go minTimeout +retransmission :: Options -> IO a -> IO a +retransmission Options {..} action = go optMinTimeout where go curTimeout - | maxTimeout < curTimeout = throwIO $ userError "tracker down" - | otherwise = do + | curTimeout > optMaxTimeout = throwIO $ userError "tracker down" + | otherwise = do r <- timeout curTimeout action maybe (go (2 * curTimeout)) return r -- cgit v1.2.3