From d8a7ad88bfdb76b7c481c0ce89de63528a06e734 Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Fri, 27 Dec 2019 01:20:59 -0500 Subject: Made the BucketRefresher state accessible from CommonAPI. --- kad/src/Network/Kademlia/Bootstrap.hs | 44 +++++++++++++++++++++++++++++---- kad/src/Network/Kademlia/CommonAPI.hs | 8 ++++-- kad/src/Network/Kademlia/Persistence.hs | 2 +- kad/src/Network/Kademlia/Search.hs | 9 ++++--- 4 files changed, 52 insertions(+), 11 deletions(-) (limited to 'kad') diff --git a/kad/src/Network/Kademlia/Bootstrap.hs b/kad/src/Network/Kademlia/Bootstrap.hs index 08ba3318..c07b3c6c 100644 --- a/kad/src/Network/Kademlia/Bootstrap.hs +++ b/kad/src/Network/Kademlia/Bootstrap.hs @@ -16,6 +16,8 @@ module Network.Kademlia.Bootstrap where import Data.Function +import qualified Data.IntMap.Strict as IntMap + ;import Data.IntMap.Strict (IntMap) import Data.Maybe import qualified Data.Set as Set import Data.Time.Clock.POSIX (getPOSIXTime) @@ -71,10 +73,12 @@ data BucketRefresher nid ni = forall tok addr. Ord addr => BucketRefresher , -- | Timestamp of last bucket event. refreshLastTouch :: TVar POSIXTime , -- | This variable indicates whether or not we are in bootstrapping mode. - bootstrapMode :: TVar Bool + bootstrapMode :: TVar Bool , -- | When this countdown reaches 0, we exit bootstrap mode. It is decremented on -- every finished refresh. bootstrapCountdown :: TVar (Maybe Int) + -- | Internal state of background searches. Exposed for debugging purposes. + , refreshState :: TVar (IntMap [BucketSearch nid ni]) } newBucketRefresher :: ( Ord addr, Hashable addr @@ -91,6 +95,7 @@ newBucketRefresher bkts sch ping = do lasttouch <- newTVar 0 -- Would use getPOSIXTime here, or minBound, but alas... bootstrapVar <- newTVar True -- Start in bootstrapping mode. bootstrapCnt <- newTVar Nothing + st <- newTVar IntMap.empty return BucketRefresher { refreshInterval = 15 * 60 , refreshQueue = sched @@ -100,6 +105,7 @@ newBucketRefresher bkts sch ping = do , refreshLastTouch = lasttouch , bootstrapMode = bootstrapVar , bootstrapCountdown = bootstrapCnt + , refreshState = st } -- | This was added to avoid the compile error "Record update for @@ -118,6 +124,7 @@ updateRefresherIO sch ping BucketRefresher{..} = BucketRefresher , refreshLastTouch = refreshLastTouch , bootstrapMode = bootstrapMode , bootstrapCountdown = bootstrapCountdown + , refreshState = refreshState } -- | Fork a refresh loop. Kill the returned thread to terminate it. @@ -228,10 +235,29 @@ onFinishedRefresh BucketRefresher { bootstrapCountdown return $ do action ; dput XRefresh $ "BOOTSTRAP complete (" ++ show (R.shape tbl) ++ ")." else return $ do action ; dput XRefresh $ "BOOTSTRAP progress " ++ show (num,R.shape tbl,cnt) +data BucketSearch nid ni = forall addr tok. BucketSearch + { bucketSample :: nid + , bucketResults :: TVar (Set.Set ni) + , bucketFinFlag :: TVar Bool + , bucketState :: SearchState nid addr tok ni ni + , bucketThread :: ThreadId + } + +removeBucketState :: BucketSearch nid ni -> Maybe [BucketSearch nid ni] -> Maybe [BucketSearch nid ni] +removeBucketState bst Nothing = Nothing +removeBucketState bst (Just xs) = case filter (\b -> bucketThread b /= bucketThread bst) xs of + [] -> Nothing + ys -> Just ys + +insertBucketState :: BucketSearch nid ni -> Maybe [BucketSearch nid ni] -> Maybe [BucketSearch nid ni] +insertBucketState bst Nothing = Just [bst] +insertBucketState bst (Just xs) = Just (bst : xs) + refreshBucket :: (Show nid, Ord ni, Ord nid, Hashable nid, Hashable ni) => BucketRefresher nid ni -> Int -> IO Int refreshBucket r@BucketRefresher{ refreshSearch = sch - , refreshBuckets = var } + , refreshBuckets = var + , refreshState = rstate } n = do tbl <- atomically (readTVar var) let count = bktCount tbl @@ -248,13 +274,21 @@ refreshBucket r@BucketRefresher{ refreshSearch = sch dput XRefresh $ "Start refresh " ++ show (n,sample) -- Set 15 minute timeout in order to avoid overlapping refreshes. - s <- search sch tbl sample $ if n+1 == R.defaultBucketCount - then const $ return True -- Never short-circuit the last bucket. - else checkBucketFull (searchSpace sch) var resultCounter fin n + (s,thread) <- search sch tbl sample $ if n+1 == R.defaultBucketCount + then const $ return True -- Never short-circuit the last bucket. + else checkBucketFull (searchSpace sch) var resultCounter fin n + let bstate = BucketSearch sample resultCounter fin s thread + atomically $ modifyTVar' rstate $ IntMap.alter (insertBucketState bstate) n _ <- timeout (15*60*1000000) $ do atomically $ searchIsFinished s >>= check atomically $ searchCancel s dput XDHT $ "Finish refresh " ++ show (n,sample) + bg <- forkIO $ do + atomically $ do + searchIsFinished s >>= check + modifyTVar' rstate $ IntMap.alter (removeBucketState bstate) n + + labelThread bg ("backgrounded." ++ show n ++ "." ++ show sample) now <- getPOSIXTime join $ atomically $ onFinishedRefresh r n now rcount <- atomically $ do diff --git a/kad/src/Network/Kademlia/CommonAPI.hs b/kad/src/Network/Kademlia/CommonAPI.hs index 4714cecc..6d3fd16c 100644 --- a/kad/src/Network/Kademlia/CommonAPI.hs +++ b/kad/src/Network/Kademlia/CommonAPI.hs @@ -1,5 +1,8 @@ {-# LANGUAGE ExistentialQuantification #-} -module Network.Kademlia.CommonAPI where +module Network.Kademlia.CommonAPI + ( module Network.Kademlia.CommonAPI + , refreshBuckets + ) where import Control.Concurrent @@ -12,6 +15,7 @@ import qualified Data.Set as Set import Data.Time.Clock.POSIX import Data.Typeable +import Network.Kademlia.Bootstrap import Network.Kademlia.Search import Network.Kademlia.Routing as R import Crypto.Tox (SecretKey,PublicKey) @@ -29,7 +33,7 @@ data DHT = forall nid ni. ( Show ni , S.Serialize nid ) => DHT - { dhtBuckets :: TVar (BucketList ni) + { dhtBuckets :: BucketRefresher nid ni , dhtSecretKey :: STM (Maybe SecretKey) , dhtPing :: Map.Map String (DHTPing ni) , dhtQuery :: Map.Map String (DHTQuery nid ni) diff --git a/kad/src/Network/Kademlia/Persistence.hs b/kad/src/Network/Kademlia/Persistence.hs index 32ec169d..f89287fe 100644 --- a/kad/src/Network/Kademlia/Persistence.hs +++ b/kad/src/Network/Kademlia/Persistence.hs @@ -16,7 +16,7 @@ import System.IO.Error saveNodes :: String -> DHT -> IO () saveNodes netname DHT{dhtBuckets} = do - bkts <- atomically $ readTVar dhtBuckets + bkts <- atomically $ readTVar (refreshBuckets dhtBuckets) let ns = map fst $ concat $ R.toList bkts bs = J.encode ns fname = nodesFileName netname diff --git a/kad/src/Network/Kademlia/Search.hs b/kad/src/Network/Kademlia/Search.hs index 5b60c303..856a7cfc 100644 --- a/kad/src/Network/Kademlia/Search.hs +++ b/kad/src/Network/Kademlia/Search.hs @@ -194,12 +194,15 @@ search :: , PSQKey nid , PSQKey ni , Show nid - ) => Search nid addr tok ni r -> R.BucketList ni -> nid -> (r -> STM Bool) -> IO (SearchState nid addr tok ni r) + ) => Search nid addr tok ni r -> R.BucketList ni -> nid -> (r -> STM Bool) -> IO (SearchState nid addr tok ni r, ThreadId) search sch buckets target result = do let ns = R.kclosest (searchSpace sch) (searchK sch) target buckets st <- atomically $ newSearch sch target ns - t <- forkIO $ searchLoop sch target result st - return st + v <- newTVarIO False + t <- forkIO $ atomically (check =<< readTVar v) >> searchLoop sch target result st + labelThread t ("search.pending." ++ show target) + atomically $ writeTVar v True + return (st,t) searchLoop :: ( Ord addr, Ord nid, Ord ni, Show nid, Hashable nid, Hashable ni ) => Search nid addr tok ni r -- ^ Query and distance methods. -- cgit v1.2.3