summaryrefslogtreecommitdiff
path: root/dht/TCPProber.hs
blob: ff68ba07308b59345420df3d65cfe5b030113136 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
{-# LANGUAGE CPP        #-}
{-# LANGUAGE LambdaCase #-}
module TCPProber where

#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
#else
import Control.Concurrent
import GHC.Conc
#endif

import Control.Arrow
import Control.Concurrent.STM
import Control.Exception (finally)
import Control.Monad
import Data.Function
import Data.IP
import Data.Maybe
import Data.Time.Clock.POSIX
import Network.Socket
import System.Timeout

import DPut
import DebugTag
import Crypto.Tox
import Data.Wrapper.PSQ as PSQ
import Network.Kademlia.Search
import Network.Tox.NodeId
import qualified Network.Tox.TCP as TCP

-- | Find a TCP port on which the Tox node @ni@ answers relay pings.
--
-- Ports 443 and 80 are pinged in parallel first, and we wait up to one
-- second for either of them to answer.  If neither has answered by then,
-- the node's UDP port (unless it is one of the fixed candidates), 3389 and
-- 33445 are pinged as well, and we wait until some ping succeeds or every
-- ping has finished.  All probe threads are killed before returning.
--
-- Returns the first port that answered, or 'Nothing' if none did.
resolvePort :: TCP.RelayClient -> NodeInfo -> IO (Maybe PortNumber)
resolvePort tcp ni = do
    got <- newTVarIO Nothing  -- first port that answered, if any
    cnt <- newTVarIO 0        -- pings that have not finished yet
    let n port = TCP.NodeInfo ni port
        -- Ping one port; record it in 'got' on success unless another
        -- port already won.
        ping port = do
            dput XTCP $ "TCP-probe pinging " ++ show (n port)
            m <- TCP.tcpPing tcp $ n port
            atomically $ do
                mg <- readTVar got
                when (isNothing mg)
                    $ forM_ m $ \_ -> writeTVar got $ Just port
        forkPort port = do
            atomically $ modifyTVar' cnt succ
            -- Decrement in 'finally' so that a ping which throws still
            -- counts as finished; otherwise 'readResult' could wait forever.
            t <- forkIO $ ping port `finally` atomically (modifyTVar' cnt pred)
            labelThread t $ "probe." ++ show port ++ "." ++ show (nodeId ni)
            return t
        -- Block until some ping succeeded (Just port) or no ping is
        -- outstanding any more (Nothing).
        readResult = atomically $ do
            m <- readTVar got
            case m of
                Just v  -> return $ Just v
                Nothing -> readTVar cnt >>= check . (== 0) >> return Nothing
    t443 <- forkPort 443
    t80 <- forkPort 80
    timeout 1000000 readResult >>= \case
        Just (Just p) -> do
            mapM_ killThread [t443, t80]
            return $ Just p
        _ -> do
            -- 443/80 failed or are still pending: widen the search.
            let uport = nodePort ni
            tudp <- forM (guard $ uport `notElem` [443,80,3389,33445])
                        $ \() -> forkPort uport
            t3389 <- forkPort 3389
            t33445 <- forkPort 33445
            p <- readResult
            mapM_ killThread [t443,t80,t3389,t33445]
            mapM_ killThread (tudp :: Maybe ThreadId)
            return p

-- | Shared state of the TCP port prober.  Every map is keyed by 'NodeId'
-- and each entry carries a 'POSIXTime' timestamp as its priority.
data TCPProber = TCPProber
    { probeQueue      :: TVar (PSQ' NodeId POSIXTime SockAddr{-UDP-})
      -- ^ Nodes waiting to have their TCP port probed, with UDP address.
    , probeSpill      :: TVar (PSQ' NodeId POSIXTime SockAddr{-UDP-})
      -- ^ Nodes for which probing found no TCP port.
    , probeSpillCount :: TVar Int
      -- ^ Entry count of 'probeSpill'.  Data.HashPSQ has O(n) size, so we
      -- keep a count.
    , probeCache      :: TVar (PSQ' NodeId POSIXTime (SockAddr{-UDP-},PortNumber{-TCP-}))
      -- ^ Nodes whose TCP port was found: UDP address and TCP port.
    , probeCacheCount :: TVar Int
      -- ^ Entry count of 'probeCache'.
    }

-- | Allocate a prober whose queue, spill set and cache start out empty and
-- whose counters start at zero.
newProber :: STM TCPProber
newProber =
    TCPProber <$> newTVar PSQ.empty  -- probeQueue
              <*> newTVar PSQ.empty  -- probeSpill
              <*> newTVar 0          -- probeSpillCount
              <*> newTVar PSQ.empty  -- probeCache
              <*> newTVar 0          -- probeCacheCount


-- | Note that node @ni@ was just seen.
--
-- * Already in 'probeCache': its cached (address, port) entry is re-stamped
--   with the current time.
-- * Otherwise, already in 'probeSpill': its spill entry is re-stamped with
--   the current time and the node's UDP address.
-- * Otherwise, already in 'probeQueue': nothing changes (the queue
--   timestamp is not refreshed).
-- * Otherwise: it is added to 'probeQueue' with its UDP address.
enqueueProbe :: TCPProber -> NodeInfo -> IO ()
enqueueProbe prober ni = do
    now <- getPOSIXTime
    let nid = nodeId ni
        touch var val = modifyTVar' var (insert' nid val now)
    atomically $ do
        queue <- readTVar (probeQueue prober)
        cache <- readTVar (probeCache prober)
        spill <- readTVar (probeSpill prober)
        case PSQ.lookup nid cache of
            Just (_, cached) -> touch (probeCache prober) cached
            Nothing
                | nid `member` spill -> touch (probeSpill prober) (nodeAddr ni)
                | nid `member` queue -> return ()
                | otherwise          -> touch (probeQueue prober) (nodeAddr ni)

-- | Maximum number of entries kept in 'probeSpill' (nodes for which no TCP
-- port was found).
maxSpill :: Int
maxSpill = 100

-- | Maximum number of entries kept in 'probeCache' (nodes with a known TCP
-- port).
maxCache :: Int
maxCache = 50

-- | Worker loop that drains 'probeQueue' forever, running at most
-- @maxjobs@ 'resolvePort' probes at a time.
--
-- Each dequeued node is probed on its own thread.  A node for which no TCP
-- port was found is recorded in 'probeSpill'; one whose port was found is
-- recorded, together with that port, in 'probeCache'.  Each map is capped
-- at 'maxSpill' / 'maxCache' entries: once its counter has reached the cap,
-- inserting a new key is followed by a 'deleteMin'.
runProbeQueue :: TCPProber -> TCP.RelayClient -> Int -> IO ()
runProbeQueue prober client maxjobs = do
    jcnt <- newTVarIO 0  -- number of probe threads currently running
    fix $ \loop -> do
        -- Wait for a free job slot and a queued node, then dequeue it.
        (tm, mni) <- atomically $ do
            j <- readTVar jcnt
            check (j < maxjobs)
            q <- readTVar $ probeQueue prober
            case minView q of
                Nothing                        -> retry
                Just (Binding nid saddr tm,q') -> do
                    writeTVar (probeQueue prober) q'
                    return (tm, nodeInfo nid saddr)
        -- 'nodeInfo' yields a container (presumably Either/Maybe); when it
        -- is empty the dequeued entry is simply dropped.
        forM_ mni $ \ni -> do
            let probe = do
                    m <- resolvePort client ni
                    atomically $ case m of
                        Nothing -> remember (probeSpill prober) (probeSpillCount prober)
                                            maxSpill (nodeId ni) (nodeAddr ni) tm
                        Just p  -> remember (probeCache prober) (probeCacheCount prober)
                                            maxCache (nodeId ni) (nodeAddr ni, p) tm
                -- Release the job slot even if the probe throws, so failures
                -- cannot permanently use up the @maxjobs@ slots.
                release = atomically $ modifyTVar' jcnt pred
            atomically $ modifyTVar' jcnt succ
            t <- forkIO $ probe `finally` release
            labelThread t ("probe."++show ni)
        loop
  where
    -- Insert or refresh @nid@ in a capped map while keeping its entry
    -- counter in step with the map.  The counter only moves when the key is
    -- new: a node can be re-queued by 'enqueueProbe' while its probe is in
    -- flight and then be recorded twice, which must not inflate the count.
    remember var cntvar cap nid val tm = do
        present <- member nid <$> readTVar var
        modifyTVar' var $ insert' nid val tm
        unless present $ do
            n <- readTVar cntvar
            if n >= cap
                then modifyTVar' var deleteMin
                else modifyTVar' cntvar succ


-- | Search query used by 'nodeSearch': asks @dst@, through
-- 'TCP.getUDPNodes'', for nodes while searching for @seeking@.
--
-- Each returned node is wrapped as a 'TCP.NodeInfo' with TCP port 0, and
-- the resulting list fills both list slots of the result triple.  If the
-- reply came back through a gateway whose id differs from @dst@'s, @dst@
-- is handed to 'enqueueProbe'.  Yields 'Nothing' when
-- 'TCP.getUDPNodes'' does.
getNodes :: TCPProber -> TCP.TCPClient err Nonce8 -> NodeId -> TCP.NodeInfo -> IO (Maybe ([TCP.NodeInfo],[TCP.NodeInfo],Maybe ()))
getNodes prober tcp seeking dst = do
    r <- TCP.getUDPNodes' tcp seeking (TCP.udpNodeInfo dst)
    dput XTCP $ "Got via TCP nodes: " ++ show r
    forM r $ \((ns, _, mb), gw) -> do
        -- Answered via some other gateway: queue dst so its own TCP port
        -- gets probed.
        when (TCP.nodeId gw /= TCP.nodeId dst) $
            enqueueProbe prober (TCP.udpNodeInfo dst)
        -- Port 0 is presumably a placeholder for a not-yet-known TCP port.
        let ts = map (\n -> TCP.NodeInfo n 0) ns
        return (ts, ts, mb)

-- | Node search over TCP relays.  Nodes are addressed by their IP and TCP
-- port, and each query is answered by 'getNodes', which also feeds
-- @prober@.
nodeSearch :: TCPProber -> TCP.TCPClient err Nonce8 -> Search NodeId (IP, PortNumber) () TCP.NodeInfo TCP.NodeInfo
nodeSearch prober tcp =
    Search { searchSpace       = TCP.tcpSpace
           , searchNodeAddress = \ni -> (TCP.nodeIP ni, TCP.tcpPort ni)
           , searchQuery       = Left (getNodes prober tcp)
           , searchAlpha       = 8
           , searchK           = 16
           }