From 9dded0e540876c9e928cfcb3c69666ce00b5852c Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Tue, 21 Aug 2018 02:13:10 -0400 Subject: Alternate session manager using IntervalSet for uniqs. --- src/Network/SessionTransports.hs | 97 +++++++++++++++++++++++++++ src/Network/Tox/Session.hs | 138 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 src/Network/SessionTransports.hs create mode 100644 src/Network/Tox/Session.hs (limited to 'src') diff --git a/src/Network/SessionTransports.hs b/src/Network/SessionTransports.hs new file mode 100644 index 00000000..17763e4e --- /dev/null +++ b/src/Network/SessionTransports.hs @@ -0,0 +1,97 @@ +{-# LANGUAGE NamedFieldPuns #-} +module Network.SessionTransports + ( Sessions + , initSessions + , newSession + , sessionHandler + ) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Monad +import qualified Data.IntMap.Strict as IntMap + ;import Data.IntMap.Strict (IntMap) +import qualified Data.Map.Strict as Map + ;import Data.Map.Strict (Map) + +import Network.Address (SockAddr,either4or6) +import Network.QueryResponse +import qualified Data.IntervalSet as S + ;import Data.IntervalSet (IntSet) + +data Sessions x = Sessions + { sessionsByAddr :: TVar (Map SockAddr (IntMap (x -> IO Bool))) + , sessionsById :: TVar (IntMap SockAddr) + , sessionIds :: TVar IntSet + , sessionsSendRaw :: SockAddr -> x -> IO () + } + +initSessions :: (SockAddr -> x -> IO ()) -> IO (Sessions x) +initSessions send = atomically $ do + byaddr <- newTVar Map.empty + byid <- newTVar IntMap.empty + idset <- newTVar S.empty + return Sessions { sessionsByAddr = byaddr + , sessionsById = byid + , sessionIds = idset + , sessionsSendRaw = send + } + + + +rmSession :: Int -> (Maybe (IntMap x)) -> (Maybe (IntMap x)) +rmSession sid Nothing = Nothing +rmSession sid (Just m) = case IntMap.delete sid m of + m' | IntMap.null m' -> Nothing + | otherwise -> Just m' + +newSession :: Sessions raw + -> (addr -> y -> IO raw) + -> (SockAddr -> raw -> IO (Maybe (x, addr))) + -> SockAddr + -> IO (Maybe (TransportA err addr x y)) +newSession Sessions{sessionsByAddr,sessionsById,sessionIds,sessionsSendRaw} unwrap wrap addr0 = do + mvar <- newEmptyMVar + let saddr = -- Canonical in case of 6-mapped-4 addresses. + either id id $ either4or6 addr0 + handlePacket x = do + m <- wrap saddr x + case m of + Nothing -> return False + Just x' -> do putMVar mvar $! Just $! x' + return True + msid <- atomically $ do + msid <- S.nearestOutsider 0 <$> readTVar sessionIds + forM msid $ \sid -> do + modifyTVar' sessionIds $ S.insert sid + modifyTVar' sessionsById $ IntMap.insert sid saddr + modifyTVar' sessionsByAddr $ Map.insertWith IntMap.union saddr + $ IntMap.singleton sid handlePacket + return sid + forM msid $ \sid -> do + return Transport + { awaitMessage = \kont -> do + x <- takeMVar mvar + kont $! Right <$> x + , sendMessage = \addr x -> do + x' <- unwrap addr x + sessionsSendRaw saddr x' + , closeTransport = do + tryTakeMVar mvar + putMVar mvar Nothing + atomically $ do + modifyTVar' sessionIds $ S.delete sid + modifyTVar' sessionsById $ IntMap.delete sid + modifyTVar' sessionsByAddr $ Map.alter (rmSession sid) saddr + } + +sessionHandler :: Sessions x -> (SockAddr -> x -> IO (Maybe (x -> x))) +sessionHandler Sessions{sessionsByAddr} = \addr0 x -> do + let addr = -- Canonical in case of 6-mapped-4 addresses. + either id id $ either4or6 addr0 + dispatch [] = return () + dispatch (f:fs) = do b <- f x + when (not b) $ dispatch fs + fs <- atomically $ Map.lookup addr <$> readTVar sessionsByAddr + mapM_ (dispatch . IntMap.elems) fs + return Nothing -- consume all packets. diff --git a/src/Network/Tox/Session.hs b/src/Network/Tox/Session.hs new file mode 100644 index 00000000..a52e9478 --- /dev/null +++ b/src/Network/Tox/Session.hs @@ -0,0 +1,138 @@ +{-# LANGUAGE TupleSections #-} +module Network.Tox.Session where + +import Control.Concurrent.STM +import Control.Monad +import Data.Functor.Identity +import Data.Word +import Network.Socket + +import Crypto.Tox +import Data.PacketBuffer (PacketInboundEvent (..)) +import Data.Tox.Message +import Network.Lossless +import Network.QueryResponse +import Network.SessionTransports +import Network.Tox.Crypto.Transport +import Network.Tox.DHT.Transport (Cookie) +import Network.Tox.Handshake + +type SessionKey = SecretKey + +data SessionParams = SessionParams + { spCrypto :: TransportCrypto + , spSessions :: Sessions (CryptoPacket Encrypted) + , spGetSentHandshake :: SecretKey -> SockAddr + -> Cookie Identity + -> Cookie Encrypted + -> IO (Maybe (SessionKey, HandshakeData)) + , spOnNewSession :: Session -> IO () + } + +data Session = Session + { sOurKey :: SecretKey + , sTheirAddr :: SockAddr + , sSentHandshake :: HandshakeData + , sReceivedHandshake :: Handshake Identity + , sResendPackets :: [Word32] -> IO () + -- ^ If they request that we re-send certain packets, this method is how + -- that is accomplished. + , sMissingInbound :: IO ([Word32],Word32) + -- ^ This list of sequence numbers should be periodically polled and if + -- it is not empty, we should request they re-send these packets. For + -- convenience, a lower bound for the numbers in the list is also + -- returned. Suggested polling interval: a few seconds. + , sTransport :: Transport String () CryptoMessage + } + +handshakeH :: SessionParams + -> SockAddr + -> Handshake Encrypted + -> IO (Maybe a) +handshakeH sp saddr handshake = do + decryptHandshake (spCrypto sp) handshake + >>= either (\err -> return ()) + (uncurry $ plainHandshakeH sp saddr) + return Nothing + + +plainHandshakeH :: SessionParams + -> SockAddr + -> SecretKey + -> Handshake Identity + -> IO () +plainHandshakeH sp saddr skey handshake = do + let hd = runIdentity $ handshakeData handshake + sent <- spGetSentHandshake sp skey saddr (handshakeCookie handshake) (otherCookie hd) + forM_ sent $ \(hd_skey,hd_sent) -> do + sk <- SessionKeys (spCrypto sp) + hd_skey + (sessionKey hd) + <$> atomically (newTVar $ baseNonce hd) + <*> atomically (newTVar $ baseNonce hd_sent) + m <- newSession (spSessions sp) (\() p -> return p) (decryptPacket sk) saddr + forM_ m $ \t -> do + (t2,resend,getMissing) + <- lossless (\cp a -> return $ fmap (,a) $ checkLossless $ runIdentity $ pktData cp) + (\seqno p _ -> encryptPacket sk $ bookKeeping seqno p) + () + t + let _ = t :: TransportA String () (CryptoPacket Identity) (CryptoPacket Encrypted) + _ = t2 :: Transport String () CryptoMessage + spOnNewSession sp Session + { sOurKey = skey + , sTheirAddr = saddr + , sSentHandshake = hd_sent + , sReceivedHandshake = handshake + , sResendPackets = resend + , sMissingInbound = getMissing + , sTransport = t2 + } + return () + +decryptPacket :: SessionKeys -> SockAddr -> CryptoPacket Encrypted -> IO (Maybe (CryptoPacket Identity, ())) +decryptPacket sk saddr (CryptoPacket n16 ciphered) = do + (n,δ) <- atomically $ do + n <- readTVar (skNonceIncoming sk) + let δ = n16 - nonce24ToWord16 n + return ( n `addtoNonce24` fromIntegral δ, δ ) + secret <- lookupSharedSecret (skCrypto sk) (skMe sk) (skThem sk) n + case decodePlain =<< decrypt secret ciphered of + Left e -> return Nothing + Right x -> do + when ( δ > 43690 ) + $ atomically $ writeTVar (skNonceIncoming sk) (n `addtoNonce24` 21845) + return $ Just ( CryptoPacket n16 (pure x), () ) + +encryptPacket :: SessionKeys -> CryptoData -> IO (CryptoPacket Encrypted) +encryptPacket sk plain = do + n24 <- atomically $ do + n24 <- readTVar (skNonceOutgoing sk) + modifyTVar' (skNonceOutgoing sk) incrementNonce24 + return n24 + secret <- lookupSharedSecret (skCrypto sk) (skMe sk) (skThem sk) n24 + let ciphered = encrypt secret $ encodePlain $ plain + return $ CryptoPacket (nonce24ToWord16 n24) ciphered + +data SessionKeys = SessionKeys + { skCrypto :: TransportCrypto + , skMe :: SecretKey + , skThem :: PublicKey + , skNonceIncoming :: TVar Nonce24 -- +21845 when a threshold is reached. + , skNonceOutgoing :: TVar Nonce24 -- +1 on every packet + } + +bookKeeping :: SequenceInfo -> CryptoMessage -> CryptoData +bookKeeping (SequenceInfo seqno ack) m = CryptoData + { bufferStart = seqno :: Word32 + , bufferEnd = ack :: Word32 + , bufferData = m + } + +checkLossless :: CryptoData -> PacketInboundEvent CryptoMessage +checkLossless CryptoData{ bufferStart = ack + , bufferEnd = no + , bufferData = x } = tag no x ack + where + tag = case lossyness (msgID x) of Lossy -> PacketReceivedLossy + _ -> PacketReceived -- cgit v1.2.3