{-# LANGUAGE CPP                 #-}
{-# LANGUAGE PatternSynonyms     #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Iterative search over the DHT: starting from a set of known nodes,
-- repeatedly query the queued node nearest to a target id, queue the nodes
-- it reports, and collect the results it returns, until no queued node looks
-- worth asking and no query is outstanding.
module Network.BitTorrent.DHT.Search where

import Control.Concurrent
import Control.Concurrent.Async.Pool
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Bool
import Data.Function
import Data.List
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import System.IO

import qualified Data.MinMaxPSQ as MM
import Data.MinMaxPSQ (MinMaxPSQ)
import qualified Data.Wrapper.PSQ as PSQ
import Data.Wrapper.PSQ (pattern (:->), Binding, PSQ)
import Network.BitTorrent.Address hiding (NodeId)
import Network.DatagramServer.Types
#ifdef VERSION_bencoding
import Network.DHT.Mainline ()
import Network.DatagramServer.Mainline (KMessageOf)
type Ann = ()
#else
import Network.DatagramServer.Tox as Tox
type KMessageOf = Tox.Message
type Ann = Bool
#endif

-- | Mutable (STM) state of one iterative search.
data IterativeSearch ip r = IterativeSearch
    { searchTarget       :: NodeId KMessageOf
      -- ^ Id the search converges towards; all distances are measured from it.
    , searchQuery        :: NodeInfo KMessageOf ip Ann -> IO ([NodeInfo KMessageOf ip Ann], [r])
      -- ^ Query one node. The first component is the list of nodes it
      -- reports (candidates to queue), the second the results it returned.
    , searchPendingCount :: TVar Int
      -- ^ Number of queries started by 'search' that have not completed yet.
    , searchQueued       :: TVar (MinMaxPSQ (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf)))
      -- ^ Candidates not yet queried, prioritised by distance to the target.
    , searchInformant    :: TVar (MinMaxPSQ (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf)))
      -- ^ Nodes whose query completed without an exception, prioritised by
      -- distance to the target.
    , searchVisited      :: TVar (Set (NodeAddr ip))
      -- ^ Addresses of nodes already dequeued for querying; reported nodes
      -- with one of these addresses are not queued again.
    , searchResults      :: TVar (Set r)
      -- ^ Union of all results returned so far.
    }

-- | Create the state for a new search towards @target@.
--
-- The seed nodes @ns@ are queued with their distance to @target@ as
-- priority; the pending count starts at 0 and the informant, visited and
-- result collections start empty.
newSearch :: Eq ip
          => (NodeInfo KMessageOf ip Ann -> IO ([NodeInfo KMessageOf ip Ann], [r]))
          -> NodeId KMessageOf
          -> [NodeInfo KMessageOf ip Ann]
          -> IO (IterativeSearch ip r)
newSearch qry target ns = atomically $
    IterativeSearch target qry
        <$> newTVar 0
        <*> newTVar (MM.fromList [ n :-> distance target (nodeId n) | n <- ns ])
        <*> newTVar MM.empty
        <*> newTVar Set.empty
        <*> newTVar Set.empty

-- | Number of slots in the task group used by 'search', i.e. how many
-- queries may be in flight at once.
searchAlpha :: Int
searchAlpha = 3

-- | Size bound passed to 'MM.insertTake' for the queued and informant sets,
-- and the number of informants 'search' tries to gather.
searchK :: Int
searchK = 8

-- | Query the node @ni@ (at distance @d@ from the target) and merge its
-- answer into the search state.
--
-- The pending count is always decremented. If 'searchQuery' throws, the
-- failure is swallowed (best effort, so one bad node cannot abort the
-- search) and nothing else is recorded. In particular, a failed node is NOT
-- counted as an informant, so unresponsive nodes cannot make the search look
-- converged. On success, reported nodes whose address has not been visited
-- are queued, @ni@ becomes an informant, and the results are added.
--
-- NOTE(review): catching 'SomeException' also intercepts asynchronous
-- exceptions (e.g. cancellation); confirm that is intended.
sendQuery :: forall a ip. (Ord a, Ord ip)
          => IterativeSearch ip a
          -> Binding (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf))
          -> IO ()
sendQuery IterativeSearch{..} (ni :-> d) = do
    reply <- try (searchQuery ni)
    atomically $ do
        modifyTVar' searchPendingCount pred
        case reply of
            Left (SomeException _) -> return ()
            Right (ns, rs) -> do
                vs <- readTVar searchVisited
                -- We only queue a node if it is not yet visited.
                let insertFoundNode :: NodeInfo KMessageOf ip u
                                    -> MinMaxPSQ (NodeInfo KMessageOf ip u) (NodeDistance (NodeId KMessageOf))
                                    -> MinMaxPSQ (NodeInfo KMessageOf ip u) (NodeDistance (NodeId KMessageOf))
                    insertFoundNode n q
                        | nodeAddr n `Set.member` vs = q
                        | otherwise = MM.insertTake searchK n (distance searchTarget $ nodeId n) q
                modifyTVar' searchQueued $ \q -> foldr insertFoundNode q ns
                modifyTVar' searchInformant $ MM.insertTake searchK ni d
                modifyTVar' searchResults $ \s -> foldr Set.insert s rs

-- | True when no query is pending and either nothing is queued, or at least
-- 'searchK' informants are known and the farthest of them is no farther
-- from the target than the nearest queued candidate.
searchIsFinished :: Ord ip => IterativeSearch ip r -> STM Bool
searchIsFinished IterativeSearch{..} = do
    q          <- readTVar searchQueued
    cnt        <- readTVar searchPendingCount
    informants <- readTVar searchInformant
    let noCloserCandidate =
            case (MM.findMax informants, MM.findMin q) of
                (Just farthest, Just nearest) -> PSQ.prio farthest <= PSQ.prio nearest
                _                             -> False
    return $ cnt == 0
          && ( MM.null q
             || ( MM.size informants >= searchK && noCloserCandidate ) )

-- | Run the search until it converges.
--
-- Each iteration atomically inspects the state. The nearest queued node is
-- dequeued, marked visited, and queried asynchronously (at most
-- 'searchAlpha' at once) when either:
--
--   * fewer than 'searchK' informants are known and some query is still
--     pending or more candidates are queued, or
--   * no informant is known yet, or the candidate is nearer than the
--     farthest informant.
--
-- Otherwise the loop blocks (STM retry) until no query is pending, then
-- returns. Any state change made by a completing query re-runs the check.
search :: (Ord r, Ord ip) => IterativeSearch ip r -> IO ()
search s@IterativeSearch{..} = withTaskGroup searchAlpha $ \g ->
    fix $ \again -> join $ atomically $ do
        cnt        <- readTVar searchPendingCount
        informants <- readTVar searchInformant
        found      <- MM.minView <$> readTVar searchQueued
        case found of
            Just (ni :-> d, q)
                | (MM.size informants < searchK && (cnt > 0 || not (MM.null q)))
                  -- Total replacement for @fromJust (MM.findMax informants)@:
                  -- with no informant yet there is nothing to compare
                  -- against, so keep querying instead of crashing.
                  || maybe True (\farthest -> d < PSQ.prio farthest) (MM.findMax informants)
                -> do
                    writeTVar searchQueued q
                    modifyTVar' searchVisited $ Set.insert (nodeAddr ni)
                    modifyTVar' searchPendingCount succ
                    return $ withAsync g (sendQuery s (ni :-> d)) (const again)
            _ -> do
                check (cnt == 0)
                return $ return ()