summaryrefslogtreecommitdiff
path: root/src/Network/BitTorrent/DHT/Search.hs
blob: 356f6fd92d6f271b06cd255bb6006335cc1a670d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
module Network.BitTorrent.DHT.Search where

import Control.Concurrent
import Control.Concurrent.Async.Pool
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Bool
import Data.Function
import Data.List
import qualified Data.Map.Strict     as Map
         ;import Data.Map.Strict     (Map)
import Data.Maybe
import qualified Data.Set            as Set
         ;import Data.Set            (Set)
import System.IO

import qualified Data.MinMaxPSQ   as MM
         ;import Data.MinMaxPSQ   (MinMaxPSQ)
import qualified Data.Wrapper.PSQ as PSQ
         ;import Data.Wrapper.PSQ (pattern (:->), Binding, PSQ)
import Network.Address hiding (NodeId)
import Network.DatagramServer.Types
import Data.Bits

-- | Mutable (STM) state of one iterative node lookup.
--
-- @NodeId dht@ is the node-id type whose 'distance' orders all candidates,
-- @u@ and @ip@ parameterise 'NodeInfo', and @r@ is the type of the results
-- collected from query replies.
data IterativeSearch dht u ip r = IterativeSearch
    { searchTarget       :: NodeId dht
      -- ^ The id being looked up; every distance in the search is measured
      -- from it.
    , searchQuery        :: NodeInfo dht ip u -> IO ([NodeInfo dht ip u], [r])
      -- ^ Queries a single node, yielding the nodes it reported (these are
      -- queued) and any results it supplied (these go to 'searchResults').
      -- Presumably a network request — confirm at the construction site.
    , searchPendingCount :: TVar Int
      -- ^ Number of queries dispatched by 'search' that 'sendQuery' has not
      -- yet accounted for.
    , searchQueued       :: TVar (MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht)))
      -- ^ Candidate nodes not yet dispatched, keyed by distance to
      -- 'searchTarget'. Insertions made by 'sendQuery' go through
      -- @MM.insertTake searchK@.
    , searchInformant    :: TVar (MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht)))
      -- ^ Nodes that 'sendQuery' has recorded as informants, keyed by
      -- distance to 'searchTarget', maintained with @MM.insertTake searchK@.
    , searchVisited      :: TVar (Set (NodeAddr ip))
      -- ^ Addresses of nodes already dispatched for querying; reported nodes
      -- whose address is in here are not queued again.
    , searchResults      :: TVar (Set r)
      -- ^ Union of all results returned by queried nodes.
    }

-- | Allocate the state for a fresh search towards @target@.
--
-- Each seed node in @ns@ is placed in the queue, keyed by its distance to
-- @target@. The pending counter starts at zero, and the informant, visited
-- and result collections start empty. @qry@ becomes 'searchQuery'.
newSearch :: ( Eq ip
             , Ord (NodeId dht)
             , FiniteBits (NodeId dht)
             ) => (NodeInfo dht ip u -> IO ([NodeInfo dht ip u], [r]))
                      -> NodeId dht -> [NodeInfo dht ip u] -> IO (IterativeSearch dht u ip r)
newSearch qry target ns = atomically $
    IterativeSearch target qry
        <$> newTVar 0                       -- searchPendingCount
        <*> newTVar (MM.fromList seeds)     -- searchQueued
        <*> newTVar MM.empty                -- searchInformant
        <*> newTVar Set.empty               -- searchVisited
        <*> newTVar Set.empty               -- searchResults
  where
    -- Seed bindings: each initial node paired with its distance to the target.
    seeds = [ n :-> distance target (nodeId n) | n <- ns ]

-- | Size of the task group 'search' runs its queries in (passed to
-- 'withTaskGroup'). Presumably this bounds how many queries are in flight at
-- once (the usual "alpha" parameter) — confirm against async-pool semantics.
searchAlpha :: Int
searchAlpha = 3

-- | Bound shared by the whole search.
--
-- 'sendQuery' inserts into both the queue and the informant set with
-- @MM.insertTake searchK@. 'search' and 'searchIsFinished' compare the number
-- of informants against it.
searchK :: Int
searchK = 8

-- | Query one node and fold its reply into the search state.
--
-- Runs 'searchQuery' on @ni@, then, in a single STM transaction:
--
--   * decrements 'searchPendingCount';
--   * queues every reported node whose address is not in 'searchVisited',
--     keyed by its distance to 'searchTarget' (via @MM.insertTake searchK@);
--   * records @ni@ as an informant at distance @d@;
--   * adds the returned results to 'searchResults'.
--
-- A synchronous exception from 'searchQuery' is treated as an empty reply, so
-- one misbehaving node cannot abort the whole search. Asynchronous exceptions
-- (e.g. cancellation of the worker running this query) are not swallowed:
-- the pending counter is decremented and the exception is re-thrown.
--
-- NOTE(review): a node whose query failed is still recorded as an informant,
-- exactly like one that answered with no nodes — confirm this is intended.
sendQuery :: forall a ip dht u.
            ( Ord a
            , Ord ip
            , Ord (NodeId dht)
            , FiniteBits (NodeId dht)
            ) =>
            IterativeSearch dht u ip a
            -> Binding (NodeInfo dht ip u) (NodeDistance (NodeId dht))
            -> IO ()
sendQuery IterativeSearch{..} (ni :-> d) = do
    (ns,rs) <- searchQuery ni `catch` onQueryError
    atomically $ do
        modifyTVar' searchPendingCount pred
        vs <- readTVar searchVisited
        -- We only queue a node if it is not yet visited
        let insertFoundNode :: NodeInfo dht ip u
                           -> MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht))
                           -> MinMaxPSQ (NodeInfo dht ip u) (NodeDistance (NodeId dht))
            insertFoundNode n q
             | nodeAddr n `Set.member` vs = q
             | otherwise                  = MM.insertTake searchK n (distance searchTarget $ nodeId n) q
        -- Strict modifications so no thunk chains build up inside the TVars.
        modifyTVar' searchQueued    $ \q -> foldr insertFoundNode q ns
        modifyTVar' searchInformant $ MM.insertTake searchK ni d
        modifyTVar' searchResults   $ \s -> foldr Set.insert s rs
  where
    -- Synchronous failures become an empty reply; asynchronous exceptions
    -- must propagate so the task can actually be cancelled, but the
    -- in-flight counter is kept consistent first.
    onQueryError :: SomeException -> IO ([NodeInfo dht ip u], [a])
    onQueryError e
        | Just (_ :: SomeAsyncException) <- fromException e = do
            atomically $ modifyTVar' searchPendingCount pred
            throwIO e
        | otherwise = return ([], [])


-- | Has the search converged?
--
-- True when no query is pending and either the queue is empty, or at least
-- 'searchK' informants are known and even the farthest of them is no farther
-- from the target than the nearest node still queued.
searchIsFinished :: ( Ord ip
                    , Ord (NodeId dht)
                    ) => IterativeSearch dht u ip r -> STM Bool
searchIsFinished IterativeSearch{..} = do
    queued     <- readTVar searchQueued
    pending    <- readTVar searchPendingCount
    informants <- readTVar searchInformant
    let -- Matching on both lookups keeps this total. Where the result matters
        -- (queue non-empty and at least 'searchK' informants) both are 'Just'.
        informantsCloserThanQueue =
            case (MM.findMax informants, MM.findMin queued) of
                (Just farthest, Just nearest) -> PSQ.prio farthest <= PSQ.prio nearest
                _                             -> False
    return $ pending == 0
             && ( MM.null queued
                  || ( MM.size informants >= searchK && informantsCloserThanQueue ))

-- | Drive the iterative lookup until it converges.
--
-- Repeatedly pops the closest queued node. If it is worth querying, the node
-- is marked visited, the pending count is bumped, and 'sendQuery' is run for
-- it in a task group of size 'searchAlpha'. A node is worth querying when
-- either:
--
--   * fewer than 'searchK' informants are known and other work is still
--     outstanding (a query in flight or another node queued), or
--   * there are no informants yet, or it is strictly closer to the target
--     than the farthest informant.
--
-- Otherwise the loop waits (STM 'check') until no query is pending. If a
-- finishing query queues new nodes, the transaction is retried and the
-- decision is made again. Results accumulate in 'searchResults'.
search ::
    (Ord r, Ord ip, Ord (NodeId dht), FiniteBits (NodeId dht)) =>
    IterativeSearch dht u ip r -> IO ()
search s@IterativeSearch{..} = withTaskGroup searchAlpha $ \g -> do
    fix $ \again -> do
        join $ atomically $ do
            cnt <- readTVar searchPendingCount
            informants <- readTVar searchInformant
            found <- MM.minView <$> readTVar searchQueued
            case found of
                Just (ni :-> d, q)
                  | (MM.size informants < searchK && (cnt > 0 || not (MM.null q)))
                    -- With no informant recorded yet there is nothing to
                    -- compare against, so keep searching. (Previously
                    -- 'fromJust' crashed here, e.g. for a single seed node.)
                    || maybe True (\farthest -> PSQ.prio farthest > d)
                                  (MM.findMax informants)
                  -> do writeTVar searchQueued q
                        modifyTVar' searchVisited $ Set.insert (nodeAddr ni)
                        modifyTVar' searchPendingCount succ
                        -- The rest of the loop runs as withAsync's continuation;
                        -- presumably this keeps the query alive until the
                        -- search finishes — confirm against async-pool docs.
                        return $ withAsync g (sendQuery s (ni :-> d)) (const again)
                _ -> do check (cnt == 0)
                        return $ return ()