{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Iterative node search: repeatedly query the nearest not-yet-visited
-- node, feed the nodes it returns back into a distance-ordered queue, and
-- stop once no queued node could improve on the closest informants.
module Network.BitTorrent.DHT.Search where

import Control.Concurrent
import Control.Concurrent.Async.Pool
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Bool
import Data.Function
import Data.List
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import System.IO

import qualified Data.MinMaxPSQ as MM
import Data.MinMaxPSQ (MinMaxPSQ)
import qualified Data.Wrapper.PSQ as PSQ
import Data.Wrapper.PSQ (pattern (:->), Binding, PSQ, PSQKey)
import Network.Address hiding (NodeId)
import Network.DatagramServer.Types
import Data.Bits

-- | Mutable state shared by all workers of one iterative search.
data IterativeSearch dht u ip r = IterativeSearch
    { searchTarget       :: NodeId dht
      -- ^ Identifier being searched for; nodes are ranked by their
      -- 'distance' to it.
    , searchQuery        :: NodeInfo dht ip u -> IO ([NodeInfo dht ip u], [r])
      -- ^ Asks one node; yields further nodes to try and any results.
    , searchPendingCount :: TVar Int
      -- ^ Queries that have been started but not yet folded back in.
    , searchQueued       :: TVar (MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht)))
      -- ^ Nodes waiting to be queried, prioritised by distance to the target.
    , searchInformant    :: TVar (MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht)))
      -- ^ Nodes whose query has completed (successfully or not), capped at
      -- 'searchK' entries via 'MM.insertTake'.
    , searchVisited      :: TVar (Set (NodeAddr ip))
      -- ^ Addresses already taken off the queue; they are never re-queued.
    , searchResults      :: TVar (Set r)
      -- ^ Every result reported by any queried node.
    }

-- | Create the state for a new search towards @target@, seeding the queue
-- with the given starting nodes.
newSearch :: ( Eq ip
             , PSQKey (NodeId dht)
             , PSQKey (NodeInfo dht ip u)
             , FiniteBits (NodeId dht)
             ) =>
    (NodeInfo dht ip u -> IO ([NodeInfo dht ip u], [r]))
    -> NodeId dht
    -> [NodeInfo dht ip u]
    -> IO (IterativeSearch dht u ip r)
newSearch qry target ns = atomically $ do
    c <- newTVar 0
    q <- newTVar $ MM.fromList $ map (\n -> n :-> distance target (nodeId n)) ns
    i <- newTVar MM.empty
    v <- newTVar Set.empty
    r <- newTVar Set.empty
    return $ IterativeSearch target qry c q i v r

-- | Number of task-group slots used by 'search', i.e. how many queries may
-- be running at once.
searchAlpha :: Int
searchAlpha = 3

-- | Maximum number of entries kept in the queue and informant sets.
searchK :: Int
searchK = 8

-- | Query one node and fold its answer into the search state.
--
-- A synchronous exception from 'searchQuery' is treated as an empty answer
-- (no nodes, no results); the node is still recorded as an informant.
-- An asynchronous exception (e.g. cancellation) is re-thrown after the
-- pending counter has been decremented, so the state stays consistent.
sendQuery :: forall a ip dht u.
             ( Ord a
             , Ord ip
             , PSQKey (NodeId dht)
             , PSQKey (NodeInfo dht ip u)
             , FiniteBits (NodeId dht)
             ) =>
    IterativeSearch dht u ip a
    -> Binding (NodeInfo dht ip u) (NodeDistance (NodeId dht))
    -> IO ()
sendQuery IterativeSearch{..} (ni :-> d) = do
    reply <- try (searchQuery ni)
    (ns, rs) <- case reply of
        Right answer -> return answer
        Left err
            -- Never swallow asynchronous exceptions: release our pending slot
            -- and let the cancellation proceed.
            | Just (_ :: SomeAsyncException) <- fromException err -> do
                atomically $ modifyTVar' searchPendingCount pred
                throwIO err
            -- Best effort: a failing node simply contributes nothing.
            | otherwise -> return ([], [])
    atomically $ do
        modifyTVar' searchPendingCount pred
        vs <- readTVar searchVisited
        -- We only queue a node if it is not yet visited
        let insertFoundNode :: NodeInfo dht ip u
                            -> MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht))
                            -> MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht))
            insertFoundNode n q
                | nodeAddr n `Set.member` vs = q
                | otherwise = MM.insertTake searchK n (distance searchTarget $ nodeId n) q
        modifyTVar' searchQueued $ \q -> foldr insertFoundNode q ns
        modifyTVar' searchInformant $ MM.insertTake searchK ni d
        modifyTVar' searchResults $ \s -> foldr Set.insert s rs

-- | True when no query is outstanding and either the queue is empty, or
-- there are at least 'searchK' informants and the farthest of them is no
-- farther than the nearest queued node.
searchIsFinished :: ( Ord ip
                    , PSQKey (NodeId dht)
                    , PSQKey (NodeInfo dht ip u)
                    ) => IterativeSearch dht u ip r -> STM Bool
searchIsFinished IterativeSearch{..} = do
    q          <- readTVar searchQueued
    cnt        <- readTVar searchPendingCount
    informants <- readTVar searchInformant
    -- Total replacement for comparing fromJust'ed extremes; only consulted
    -- when both structures are known to be non-empty.
    let informantsCloser =
            case (MM.findMax informants, MM.findMin q) of
                (Just far, Just near) -> PSQ.prio far <= PSQ.prio near
                _                     -> False
    return $ cnt == 0
          && ( MM.null q
               || (MM.size informants >= searchK && informantsCloser) )

-- | Drive the search until no queued node is worth querying and every
-- started query has finished.
--
-- Each iteration atomically pops the nearest queued node. It is queried if
-- there are fewer than 'searchK' informants and more work is in flight or
-- queued, or if it is nearer than the farthest informant (or there are no
-- informants at all). Otherwise the loop blocks (STM 'check') until the
-- pending count drops to zero, then returns.
search :: ( Ord r
          , Ord ip
          , PSQKey (NodeId dht)
          , PSQKey (NodeInfo dht ip u)
          , FiniteBits (NodeId dht)
          ) => IterativeSearch dht u ip r -> IO ()
search s@IterativeSearch{..} = withTaskGroup searchAlpha $ \g -> do
    fix $ \again -> do
        join $ atomically $ do
            cnt        <- readTVar searchPendingCount
            informants <- readTVar searchInformant
            found      <- MM.minView <$> readTVar searchQueued
            case found of
                Just (ni :-> d, q)
                    | (MM.size informants < searchK && (cnt > 0 || not (MM.null q)))
                        -- With no informants yet there is nothing to compare
                        -- against, so always query (previously fromJust Nothing).
                        || maybe True (\far -> PSQ.prio far > d) (MM.findMax informants)
                    -> do
                        writeTVar searchQueued q
                        modifyTVar' searchVisited $ Set.insert (nodeAddr ni)
                        modifyTVar' searchPendingCount succ
                        -- The rest of the loop runs inside withAsync so the
                        -- spawned query is cancelled if the loop aborts.
                        return $ withAsync g (sendQuery s (ni :-> d)) (const again)
                _ -> do
                    -- Retry until all outstanding queries have reported back;
                    -- their answers may enqueue new nodes and wake us up.
                    check (cnt == 0)
                    return $ return ()