1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.BitTorrent.DHT.Search where
import Control.Concurrent
import Control.Concurrent.Async.Pool
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Bool
import Data.Function
import Data.List
import qualified Data.Map.Strict as Map
;import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Set as Set
;import Data.Set (Set)
import System.IO
import qualified Data.MinMaxPSQ as MM
;import Data.MinMaxPSQ (MinMaxPSQ)
import qualified Data.Wrapper.PSQ as PSQ
;import Data.Wrapper.PSQ (pattern (:->), Binding, PSQ)
import Network.BitTorrent.Address hiding (NodeId)
import Network.RPC
#ifdef VERSION_bencoding
import Network.DHT.Mainline ()
import Network.KRPC.Message (KMessageOf)
type Ann = ()
#else
import Data.Tox as Tox
type KMessageOf = Tox.Message
type Ann = Bool
#endif
-- | Mutable state of one iterative lookup towards 'searchTarget'.
--
-- The 'TVar' fields are read and updated by 'search', 'sendQuery' and
-- 'searchIsFinished' inside STM transactions.
data IterativeSearch ip r = IterativeSearch
    { searchTarget :: NodeId KMessageOf
      -- ^ Id being looked up; every candidate node is ranked by its
      -- 'distance' to this id.
    , searchQuery :: NodeInfo KMessageOf ip Ann -> IO ([NodeInfo KMessageOf ip Ann], [r])
      -- ^ Query one node, yielding the nodes it reported and any results.
    , searchPendingCount :: TVar Int
      -- ^ Number of queries dispatched by 'search' whose 'sendQuery' has
      -- not yet recorded its outcome.
    , searchQueued :: TVar (MinMaxPSQ (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf)))
      -- ^ Candidates not yet queried, prioritised by distance to the
      -- target; 'search' pops the minimum (closest) one first.
    , searchInformant :: TVar (MinMaxPSQ (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf)))
      -- ^ Nodes whose query has completed (successfully or not), by
      -- distance; 'sendQuery' inserts with @MM.insertTake searchK@.
    , searchVisited :: TVar (Set (NodeAddr ip))
      -- ^ Addresses already dispatched; 'sendQuery' never re-queues them.
    , searchResults :: TVar (Set r)
      -- ^ Results accumulated from every completed query.
    }
-- | Allocate the state for a fresh search towards @target@.
--
-- @qry@ becomes 'searchQuery'.  Every seed node in @ns@ is queued with its
-- distance to @target@ as priority.  No query is pending, and the informant
-- set, visited set and result set all start empty.
newSearch :: Eq ip => (NodeInfo KMessageOf ip Ann -> IO ([NodeInfo KMessageOf ip Ann], [r]))
          -> NodeId KMessageOf -> [NodeInfo KMessageOf ip Ann] -> IO (IterativeSearch ip r)
newSearch qry target ns = atomically $
    IterativeSearch target qry
        <$> newTVar 0             -- pending query count
        <*> newTVar seedQueue     -- candidates to query
        <*> newTVar MM.empty      -- informants
        <*> newTVar Set.empty     -- visited addresses
        <*> newTVar Set.empty     -- results
  where
    seedQueue = MM.fromList [ n :-> distance target (nodeId n) | n <- ns ]
-- | Tuning constants for 'search'.
searchAlpha, searchK :: Int
-- 'searchAlpha': number of workers in the 'withTaskGroup' pool that
-- 'search' dispatches queries into.
searchAlpha = 3
-- 'searchK': bound passed to @MM.insertTake@ for the candidate queue and the
-- informant set, and the informant count 'searchIsFinished' requires.
searchK = 8
-- | Query a single node and merge the outcome into the search state.
--
-- Runs 'searchQuery' on the node @ni@, whose distance to the target is
-- @d@.  Then, in one STM transaction, it:
--
--   * decrements 'searchPendingCount',
--   * queues every returned node whose address is not in 'searchVisited',
--     keyed by its distance to 'searchTarget' (via @MM.insertTake searchK@),
--   * records @ni@ as an informant at distance @d@, and
--   * adds the returned results to 'searchResults'.
--
-- The decrement shares a transaction with the queue update.  That way
-- 'search' can never see "nothing pending, nothing queued" before the new
-- candidates are visible.
--
-- A synchronous exception from 'searchQuery' is treated as an empty reply,
-- so one unreachable node does not abort the search.  Asynchronous
-- exceptions (e.g. cancellation of the worker) are re-thrown instead of
-- being swallowed.  The pending counter is still decremented on that path,
-- so it is not left permanently raised.
sendQuery :: forall a ip. (Ord a, Ord ip) =>
    IterativeSearch ip a
    -> Binding (NodeInfo KMessageOf ip Ann) (NodeDistance (NodeId KMessageOf))
    -> IO ()
sendQuery IterativeSearch{..} (ni :-> d) = do
    (ns, rs) <- queryNode `onException`
                    atomically (modifyTVar' searchPendingCount pred)
    atomically $ do
        modifyTVar' searchPendingCount pred
        vs <- readTVar searchVisited
        -- We only queue a node if it is not yet visited.
        let insertFoundNode :: NodeInfo KMessageOf ip u
                            -> MinMaxPSQ (NodeInfo KMessageOf ip u) (NodeDistance (NodeId KMessageOf))
                            -> MinMaxPSQ (NodeInfo KMessageOf ip u) (NodeDistance (NodeId KMessageOf))
            insertFoundNode n q
                | nodeAddr n `Set.member` vs = q
                | otherwise = MM.insertTake searchK n (distance searchTarget $ nodeId n) q
        modifyTVar searchQueued $ \q -> foldr insertFoundNode q ns
        modifyTVar searchInformant $ MM.insertTake searchK ni d
        modifyTVar searchResults $ \s -> foldr Set.insert s rs
  where
    -- Best-effort query: synchronous failures become an empty reply, while
    -- asynchronous exceptions propagate so cancellation still works.
    queryNode = searchQuery ni `catch` \(e :: SomeException) ->
        case fromException e :: Maybe SomeAsyncException of
            Just _  -> throwIO e
            Nothing -> return ([], [])
-- | Has the search converged?
--
-- 'True' when no query is pending and one of the following holds:
--
--   * the candidate queue is empty, or
--   * at least 'searchK' informants are known and the farthest of them is
--     no farther from the target than the closest remaining candidate.
--
-- Written without 'fromJust'.  A missing min/max, which the size checks
-- already rule out, yields 'False' instead of crashing the transaction.
searchIsFinished :: Ord ip => IterativeSearch ip r -> STM Bool
searchIsFinished IterativeSearch{..} = do
    q          <- readTVar searchQueued
    cnt        <- readTVar searchPendingCount
    informants <- readTVar searchInformant
    return $ cnt == 0 && converged q informants
  where
    converged q informants
        | MM.null q                    = True
        | MM.size informants < searchK = False
        | otherwise = fromMaybe False $ do
            farthest <- MM.findMax informants
            nearest  <- MM.findMin q
            return (PSQ.prio farthest <= PSQ.prio nearest)
-- | Run an iterative search until it stops dispatching queries.
--
-- Queries run in a task group of 'searchAlpha' workers.  Each step pops the
-- closest queued candidate and dispatches 'sendQuery' for it when either:
--
--   * fewer than 'searchK' informants are known and there is other work
--     (a query in flight, or another queued candidate), or
--   * the candidate is closer than the farthest informant, or no informant
--     is known yet.
--
-- Otherwise the transaction retries while any query is still pending.  The
-- retry re-runs when the state changes, e.g. new candidates arrive.  Once
-- nothing is pending, the search returns; results are left in
-- 'searchResults'.
--
-- NOTE(review): the @cnt > 0 || not (MM.null q)@ clause means that, with
-- fewer than 'searchK' informants, the final candidate is only queried if
-- it beats the farthest informant.  This is preserved as-is; confirm it is
-- intended.
search ::
    (Ord r, Ord ip) =>
    IterativeSearch ip r -> IO ()
search s@IterativeSearch{..} = withTaskGroup searchAlpha $ \g -> do
    fix $ \again -> do
        join $ atomically $ do
            cnt        <- readTVar searchPendingCount
            informants <- readTVar searchInformant
            found      <- MM.minView <$> readTVar searchQueued
            -- Total comparison against the farthest informant.  An empty
            -- informant set counts as "closer", so e.g. a search seeded with
            -- a single node still queries it.  The previous 'fromJust' threw
            -- here.
            let closerThanFarthest dist =
                    maybe True (\far -> PSQ.prio far > dist) (MM.findMax informants)
            case found of
                Just (ni :-> d, q)
                    | (MM.size informants < searchK) && (cnt > 0 || not (MM.null q))
                      || closerThanFarthest d
                    -> do writeTVar searchQueued q
                          modifyTVar searchVisited $ Set.insert (nodeAddr ni)
                          modifyTVar' searchPendingCount succ
                          return $ withAsync g (sendQuery s (ni :-> d)) (const again)
                _ -> do check (cnt == 0)
                        return $ return ()
|