summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2020-01-03 15:35:23 -0500
committerJoe Crayne <joe@jerkface.net>2020-01-03 17:26:06 -0500
commit31b799222cb76cd0002d9a3cc5b340a7b6fed139 (patch)
tree8b834e455529fb270375e4967d1acad56553544f /server
parent1e03ed3670a8386ede93a09fa0c67785e7da6478 (diff)
server library.
Diffstat (limited to 'server')
-rw-r--r--server/server.cabal25
-rw-r--r--server/src/Connection.hs135
-rw-r--r--server/src/Connection/Tcp.hs825
-rw-r--r--server/src/Control/Concurrent/Delay.hs50
-rw-r--r--server/src/Control/Concurrent/PingMachine.hs163
-rw-r--r--server/src/ControlMaybe.hs64
-rw-r--r--server/src/DNSCache.hs286
-rw-r--r--server/src/Data/TableMethods.hs105
-rw-r--r--server/src/DebugTag.hs24
-rw-r--r--server/src/ForkLabeled.hs16
-rw-r--r--server/src/GetHostByAddr.hs78
-rw-r--r--server/src/Network/QueryResponse.hs716
-rw-r--r--server/src/Network/QueryResponse/TCP.hs225
-rw-r--r--server/src/Network/SocketLike.hs98
-rw-r--r--server/src/Network/StreamServer.hs167
-rw-r--r--server/src/SockAddr.hs14
16 files changed, 2991 insertions, 0 deletions
diff --git a/server/server.cabal b/server/server.cabal
new file mode 100644
index 00000000..95d7aacf
--- /dev/null
+++ b/server/server.cabal
@@ -0,0 +1,25 @@
1cabal-version: 2.2
2-- Initial package description 'server.cabal' generated by 'cabal init'.
3-- For further documentation, see http://haskell.org/cabal/users-guide/
4
5name: server
6version: 0.1.0.0
7synopsis: TCP/UDP server library.
8-- description:
9-- bug-reports:
10license: NONE
11-- license-file: LICENSE
12author: Joe Crayne
13maintainer: joe@jerkface.net
14-- copyright:
15category: Network
16extra-source-files: CHANGELOG.md
17
18library
19 exposed-modules: Network.QueryResponse, Network.StreamServer, Network.SocketLike, Network.QueryResponse.TCP, Data.TableMethods, Connection.Tcp, Control.Concurrent.Delay, DNSCache, GetHostByAddr, ControlMaybe, SockAddr, Control.Concurrent.PingMachine, Connection
20 other-modules: ForkLabeled, DebugTag
21 other-extensions: CPP, GADTs, LambdaCase, PartialTypeSignatures, RankNTypes, ScopedTypeVariables, TupleSections, TypeFamilies, TypeOperators, OverloadedStrings, GeneralizedNewtypeDeriving, DoAndIfThenElse, FlexibleInstances, StandaloneDeriving
22 build-depends: base, stm, bytestring, dependent-map, dependent-sum, contravariant, containers, time, network, cpu, dput-hslogger, directory, lifted-base, hashable, conduit, text, psq-wrap, minmax-psq, lifted-concurrent, word64-map, network-addr
23 hs-source-dirs: src
24 default-language: Haskell2010
25 cpp-options: -DTHREAD_DEBUG
diff --git a/server/src/Connection.hs b/server/src/Connection.hs
new file mode 100644
index 00000000..ea86f4bb
--- /dev/null
+++ b/server/src/Connection.hs
@@ -0,0 +1,135 @@
1{-# LANGUAGE DeriveFunctor #-}
2{-# LANGUAGE LambdaCase #-}
3module Connection where
4
5import Control.Applicative
6import Control.Arrow
7import Control.Concurrent.STM
8import Data.Bits
9import Data.Word
10import qualified Data.Map as Map
11 ;import Data.Map (Map)
12import Network.Socket (SockAddr(..))
13
14import Control.Concurrent.PingMachine
15
16-- | This type indicates the current status of a connection. The type
17-- parameter indicates protocol-specific status information. To present
18-- information as a user-comprehensible string, use 'showStatus'.
19data Status status
20 = Dormant
21 | InProgress status
22 | Established
23 deriving (Show,Eq,Ord,Functor)
24
25-- | A policy indicates a desired connection status.
26data Policy
27 = RefusingToConnect -- ^ We desire no connection.
28 | OpenToConnect -- ^ We will cooperate if a remote side initiates.
29 | TryingToConnect -- ^ We desire to be connected.
30 deriving (Eq,Ord,Show)
31
32-- | Information obtained via the 'connectionStatus' interface to
33-- 'Manager'.
34data Connection status = Connection
35 { connStatus :: Status status
36 , connPolicy :: Policy
37 }
38 deriving Functor
39
40-- | A 'PeerAddress' identifies an active session. For inactive sessions, multiple
41-- values may be feasible.
42
43-- We use a 'SockAddr' as it is convenient for TCP and UDP connections. But if
44-- that is not your use case, see 'uniqueAsKey'.
45newtype PeerAddress = PeerAddress { peerAddress :: SockAddr }
46 deriving (Eq,Ord,Show)
47
48-- | A 24-byte word.
49data Uniq24 = Uniq24 !Word64 !Word64 !Word64
50 deriving (Eq,Ord,Show)
51
52-- | Coerce a 'Uniq24' to a useable 'PeerAddress'. Note that this stores the
53-- special value 0 into the port number of the underlying 'SockAddr' and thus
54-- should be compatible for mixing together with TCP/UDP peers.
55uniqueAsKey :: Uniq24 -> PeerAddress
56uniqueAsKey (Uniq24 x y z) = PeerAddress $ SockAddrInet6 (fromIntegral 0) a bcde f
57 where
58 a = fromIntegral (x `shiftR` 32)
59 b = fromIntegral x
60 c = fromIntegral (y `shiftR` 32)
61 d = fromIntegral y
62 e = fromIntegral (z `shiftR` 32)
63 f = fromIntegral z
64 bcde = (b,c,d,e)
65
66-- | Inverse of 'uniqueAsKey'
67keyAsUnique :: PeerAddress -> Maybe Uniq24
68keyAsUnique (PeerAddress (SockAddrInet6 0 a bcde f)) = Just $ Uniq24 x y z
69 where
70 (b,c,d,e) = bcde
71 x = (fromIntegral a `shiftL` 32) .|. fromIntegral b
72 y = (fromIntegral c `shiftL` 32) .|. fromIntegral d
73 z = (fromIntegral e `shiftL` 32) .|. fromIntegral f
74keyAsUniq _ = Nothing
75
76
77-- | This is an interface to make or query status information about connections
78-- of a specific kind.
79--
80-- Type parameters:
81--
82-- /k/ names a connection. It should implement Ord, and can be parsed and
83-- displayed using 'stringToKey' and 'showKey'.
84--
85-- /status/ indicates the progress of a connection. It is intended as a
86-- parameter to the 'InProgress' constructor of 'Status'.
87--
88data Manager status k = Manager
89 { -- | Connect or disconnect a connection.
90 setPolicy :: k -> Policy -> IO ()
91 -- | Lookup a connection status.
92 , status :: k -> STM (Connection status)
93 -- | Obtain a list of all known connections.
94 , connections :: STM [k]
95 -- | Parse a connection key out of a string. Inverse of 'showKey'.
96 , stringToKey :: String -> Maybe k
97 -- | Convert a progress value to a string.
98 , showProgress :: status -> String
99 -- | Show a connection key as a string.
100 , showKey :: k -> String
101 -- | Obtain an address from a human-friendly name. For TCP/UDP
102 -- connections, this might be a forward-resolving DNS query.
103 , resolvePeer :: k -> IO [PeerAddress]
104 -- | This is the reverse of 'resolvePeer'. For TCP/UDP connections, this
105 -- might be a reverse-resolve DNS query.
106 , reverseAddress :: PeerAddress -> IO [k]
107 }
108
109-- | Present status information (visible in a UI) for a connection.
110showStatus :: Manager status k -> Status status -> String
111showStatus mgr Dormant = "dormant"
112showStatus mgr Established = "established"
113showStatus mgr (InProgress s) = "in progress ("++showProgress mgr s++")"
114
115
116-- | Combine two different species of 'Manager' into a single interface using
117-- 'Either' to combine key and status types.
118addManagers :: (Ord kA, Ord kB) =>
119 Manager statusA kA
120 -> Manager statusB kB
121 -> Manager (Either statusA statusB) (Either kA kB)
122addManagers mgrA mgrB = Manager
123 { setPolicy = either (setPolicy mgrA) (setPolicy mgrB)
124 , status = \case
125 Left k -> fmap Left <$> status mgrA k
126 Right k -> fmap Right <$> status mgrB k
127 , connections = do
128 as <- connections mgrA
129 bs <- connections mgrB
130 return $ map Left as ++ map Right bs
131 , stringToKey = \str -> Left <$> stringToKey mgrA str
132 <|> Right <$> stringToKey mgrB str
133 , showProgress = either (showProgress mgrA) (showProgress mgrB)
134 , showKey = either (showKey mgrA) (showKey mgrB)
135 }
diff --git a/server/src/Connection/Tcp.hs b/server/src/Connection/Tcp.hs
new file mode 100644
index 00000000..7d93e7de
--- /dev/null
+++ b/server/src/Connection/Tcp.hs
@@ -0,0 +1,825 @@
1{-# OPTIONS_HADDOCK prune #-}
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE DoAndIfThenElse #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE LambdaCase #-}
6{-# LANGUAGE NondecreasingIndentation #-}
7{-# LANGUAGE OverloadedStrings #-}
8{-# LANGUAGE RankNTypes #-}
9{-# LANGUAGE StandaloneDeriving #-}
10{-# LANGUAGE TupleSections #-}
11-----------------------------------------------------------------------------
12-- |
13-- Module : Connection.Tcp
14--
15-- Maintainer : joe@jerkface.net
16-- Stability : experimental
17--
18-- A TCP client/server library.
19--
20-- TODO:
21--
22-- * interface tweaks
23--
24module Connection.Tcp
25 ( module Connection.Tcp
26 , module Control.Concurrent.PingMachine ) where
27
28import Data.ByteString (ByteString,hGetNonBlocking)
29import qualified Data.ByteString.Char8 as S -- ( hPutStrLn, hPutStr, pack)
30import Data.Conduit ( ConduitT, Void, Flush )
31#if MIN_VERSION_containers(0,5,0)
32import qualified Data.Map.Strict as Map
33import Data.Map.Strict (Map)
34#else
35import qualified Data.Map as Map
36import Data.Map (Map)
37#endif
38import Data.Monoid ( (<>) )
39import Control.Concurrent.ThreadUtil
40
41import Control.Arrow
42import Control.Concurrent.STM
43-- import Control.Concurrent.STM.TMVar
44-- import Control.Concurrent.STM.TChan
45-- import Control.Concurrent.STM.Delay
46import Control.Exception ({-evaluate,-}handle,SomeException(..),ErrorCall(..),onException)
47import Control.Monad
48import Control.Monad.Fix
49-- import Control.Monad.STM
50-- import Control.Monad.Trans.Resource
51import Control.Monad.IO.Class (MonadIO (liftIO))
52import Data.Maybe
53import System.IO.Error (isDoesNotExistError)
54import System.IO
55 ( IOMode(..)
56 , hSetBuffering
57 , BufferMode(..)
58 , hWaitForInput
59 , hClose
60 , hIsEOF
61 , Handle
62 )
63import Network.Socket as Socket
64import Network.BSD
65 ( getProtocolNumber
66 )
67import Debug.Trace
68import Data.Time.Clock (getCurrentTime,diffUTCTime)
69-- import SockAddr ()
70-- import System.Locale (defaultTimeLocale)
71
72import qualified Data.Text as Text
73 ;import Data.Text (Text)
74import DNSCache
75import Control.Concurrent.Delay
76import Control.Concurrent.PingMachine
77import Network.StreamServer
78import Network.SocketLike hiding (sClose)
79import qualified Connection as G
80 ;import Connection (Manager (..), PeerAddress (..), Policy (..))
81import Network.Address (localhost4)
82import DPut
83import DebugTag
84
85
86type Microseconds = Int
87
88-- | This object is passed with the 'Listen' and 'Connect'
89-- instructions in order to control the behavior of the
90-- connections that are established. It is parameterized
91-- by a user-suplied type @conkey@ that is used as a lookup
92-- key for connections.
93data ConnectionParameters conkey u =
94 ConnectionParameters
95 { pingInterval :: PingInterval
96 -- ^ The miliseconds of idle to allow before a 'RequiresPing'
97 -- event is signaled.
98 , timeout :: TimeOut
99 -- ^ The miliseconds of idle after 'RequiresPing' is signaled
100 -- that are necessary for the connection to be considered
101 -- lost and signalling 'EOF'.
102 , makeConnKey :: (RestrictedSocket,(Local SockAddr, Remote SockAddr)) -> IO (conkey,u)
103 -- ^ This action creates a lookup key for a new connection. If 'duplex'
104 -- is 'True' and the result is already assocatied with an established
105 -- connection, then an 'EOF' will be forced before the the new
106 -- connection becomes active.
107 --
108 , duplex :: Bool
109 -- ^ If True, then the connection will be treated as a normal
110 -- two-way socket. Otherwise, a readable socket is established
111 -- with 'Listen' and a writable socket is established with
112 -- 'Connect' and they are associated when 'makeConnKey' yields
113 -- same value for each.
114 }
115
116-- | Use this function to select appropriate default values for
117-- 'ConnectionParameters' other than 'makeConnKey'.
118--
119-- Current defaults:
120--
121-- * 'pingInterval' = 28000
122--
123-- * 'timeout' = 2000
124--
125-- * 'duplex' = True
126--
127connectionDefaults
128 :: ((RestrictedSocket,(Local SockAddr,Remote SockAddr)) -> IO (conkey,u)) -> ConnectionParameters conkey u
129connectionDefaults f = ConnectionParameters
130 { pingInterval = 28000
131 , timeout = 2000
132 , makeConnKey = f
133 , duplex = True
134 }
135
136-- | Instructions for a 'Server' object
137--
138-- To issue a command, put it into the 'serverCommand' TMVar.
139data ServerInstruction conkey u
140 = Quit
141 -- ^ kill the server. This command is automatically issued when
142 -- the server is released.
143 | Listen SockAddr (ConnectionParameters conkey u)
144 -- ^ listen for incoming connections on the given bind address.
145 | Connect SockAddr (ConnectionParameters conkey u)
146 -- ^ connect to addresses
147 | ConnectWithEndlessRetry SockAddr
148 (ConnectionParameters conkey u)
149 Miliseconds
150 -- ^ keep retrying the connection
151 | Ignore SockAddr
152 -- ^ stop listening on specified bind address
153 | Send conkey ByteString
154 -- ^ send bytes to an established connection
155
156#ifdef TEST
157deriving instance Show conkey => Show (ServerInstruction conkey u)
158instance Show (a -> b) where show _ = "<function>"
159deriving instance Show conkey => Show (ConnectionParameters conkey u)
160#endif
161
162-- | This type specifies which which half of a half-duplex
163-- connection is of interest.
164data InOrOut = In | Out
165 deriving (Enum,Eq,Ord,Show,Read)
166
167-- | These events may be read from 'serverEvent' TChannel.
168--
169data ConnectionEvent b
170 = Connection (STM Bool) (ConduitT () b IO ()) (ConduitT (Flush b) Void IO ())
171 -- ^ A new connection was established
172 | ConnectFailure SockAddr
173 -- ^ A 'Connect' command failed.
174 | HalfConnection InOrOut
175 -- ^ Half of a half-duplex connection is avaliable.
176 | EOF
177 -- ^ A connection was terminated
178 | RequiresPing
179 -- ^ 'pingInterval' miliseconds of idle was experienced
180
181#ifdef TEST
182instance Show (IO a) where show _ = "<IO action>"
183instance Show (STM a) where show _ = "<STM action>"
184instance Eq (ByteString -> IO Bool) where (==) _ _ = True
185instance Eq (IO (Maybe ByteString)) where (==) _ _ = True
186instance Eq (STM Bool) where (==) _ _ = True
187deriving instance Show b => Show (ConnectionEvent b)
188deriving instance Eq b => Eq (ConnectionEvent b)
189#endif
190
191-- | This is the per-connection state.
192data ConnectionRecord u
193 = ConnectionRecord { ckont :: TMVar (STM (IO ())) -- ^ used to pass a continuation to update the eof-handler
194 , cstate :: ConnectionState -- ^ used to send/receive data to the connection
195 , cdata :: u -- ^ user data, stored in the connection map for convenience
196 }
197
198-- | This object accepts commands and signals events and maintains
199-- the list of currently listening ports and established connections.
200data Server a u releaseKey b
201 = Server { serverCommand :: TMVar (ServerInstruction a u)
202 , serverEvent :: TChan ((a,u), ConnectionEvent b)
203 , serverReleaseKey :: releaseKey
204 , conmap :: TVar (Map a (ConnectionRecord u))
205 , listenmap :: TVar (Map SockAddr ServerHandle)
206 , retrymap :: TVar (Map SockAddr (TVar Bool,InterruptibleDelay))
207 }
208
209control :: Server a u releaseKey b -> ServerInstruction a u -> IO ()
210control sv = atomically . putTMVar (serverCommand sv)
211
212type Allocate releaseKey m = forall b. IO b -> (b -> IO ()) -> m (releaseKey, b)
213
214noCleanUp :: MonadIO m => Allocate () m
215noCleanUp io _ = ( (,) () ) `liftM` liftIO io
216
217-- | Construct a 'Server' object. Use 'Control.Monad.Trans.Resource.ResourceT'
218-- to ensure proper cleanup. For example,
219--
220-- > import Connection.Tcp
221-- > import Control.Monad.Trans.Resource (runResourceT)
222-- > import Control.Monad.IO.Class (liftIO)
223-- > import Control.Monad.STM (atomically)
224-- > import Control.Concurrent.STM.TMVar (putTMVar)
225-- > import Control.Concurrent.STM.TChan (readTChan)
226-- >
227-- > main = runResourceT $ do
228-- > sv <- server allocate
229-- > let params = connectionDefaults (return . snd)
230-- > liftIO . atomically $ putTMVar (serverCommand sv) (Listen 2942 params)
231-- > let loop = do
232-- > (_,event) <- atomically $ readTChan (serverEvent sv)
233-- > case event of
234-- > Connection getPingFlag readData writeData -> do
235-- > forkIO $ do
236-- > fix $ \readLoop -> do
237-- > readData >>= mapM $ \bytes ->
238-- > putStrLn $ "got: " ++ show bytes
239-- > readLoop
240-- > case event of EOF -> return ()
241-- > _ -> loop
242-- > liftIO loop
243--
244-- Using 'Control.Monad.Trans.Resource.ResourceT' is optional. Pass 'noCleanUp'
245-- to do without automatic cleanup and be sure to remember to write 'Quit' to
246-- the 'serverCommand' variable.
247server ::
248 -- forall (m :: * -> *) a u conkey releaseKey.
249 (Show conkey, MonadIO m, Ord conkey) =>
250 Allocate releaseKey m
251 -> ( IO (Maybe ByteString) -> (ByteString -> IO Bool) -> ( ConduitT () x IO (), ConduitT (Flush x) Void IO () ) )
252 -> m (Server conkey u releaseKey x)
253server allocate sessionConduits = do
254 (key,cmds) <- allocate (atomically newEmptyTMVar)
255 (atomically . flip putTMVar Quit)
256 server <- liftIO . atomically $ do
257 tchan <- newTChan
258 conmap <- newTVar Map.empty
259 listenmap<- newTVar Map.empty
260 retrymap <- newTVar Map.empty
261 return Server { serverCommand = cmds
262 , serverEvent = tchan
263 , serverReleaseKey = key
264 , conmap = conmap
265 , listenmap = listenmap
266 , retrymap = retrymap
267 }
268 liftIO $ do
269 forkLabeled "server" $ fix $ \loop -> do
270 instr <- atomically $ takeTMVar cmds
271 -- warn $ "instr = " <> bshow instr
272 let again = do doit server instr
273 -- warn $ "finished " <> bshow instr
274 loop
275 case instr of Quit -> closeAll server
276 _ -> again
277 return server
278 where
279 closeAll server = do
280 listening <- atomically . readTVar $ listenmap server
281 mapM_ quitListening (Map.elems listening)
282 let stopRetry (v,d) = do atomically $ writeTVar v False
283 interruptDelay d
284 retriers <- atomically $ do
285 rmap <- readTVar $ retrymap server
286 writeTVar (retrymap server) Map.empty
287 return rmap
288 mapM_ stopRetry (Map.elems retriers)
289 cons <- atomically . readTVar $ conmap server
290 atomically $ mapM_ (connClose . cstate) (Map.elems cons)
291 atomically $ mapM_ (connWait . cstate) (Map.elems cons)
292 atomically $ writeTVar (conmap server) Map.empty
293
294
295 doit server (Listen port params) = do
296
297 listening <- Map.member port
298 `fmap` atomically (readTVar $ listenmap server)
299 when (not listening) $ do
300
301 dput XMisc $ "Started listening on "++show port
302
303 sserv <- flip streamServer [port] ServerConfig
304 { serverWarn = dput XMisc
305 , serverSession = \sock _ h -> do
306 (conkey,u) <- makeConnKey params sock
307 _ <- newConnection server sessionConduits params conkey u h In
308 return ()
309 }
310
311 atomically $ listenmap server `modifyTVar'` Map.insert port sserv
312
313 doit server (Ignore port) = do
314 dput XMisc $ "Stopping listen on "++show port
315 mb <- atomically $ do
316 map <- readTVar $ listenmap server
317 modifyTVar' (listenmap server) $ Map.delete port
318 return $ Map.lookup port map
319 maybe (return ()) quitListening $ mb
320
321 doit server (Send con bs) = do -- . void . forkIO $ do
322 map <- atomically $ readTVar (conmap server)
323 let post False = (trace ("cant send: "++show bs) $ return ())
324 post True = return ()
325 maybe (post False)
326 (post <=< flip connWrite bs . cstate)
327 $ Map.lookup con map
328
329 doit server (Connect addr params) = join $ atomically $ do
330 Map.lookup addr <$> readTVar (retrymap server)
331 >>= return . \case
332 Nothing -> forkit
333 Just (v,d) -> do b <- atomically $ readTVar v
334 interruptDelay d
335 when (not b) forkit
336 where
337 forkit = void . forkLabeled ( "Connect." ++ show addr ) $ do
338 proto <- getProtocolNumber "tcp"
339 sock <- socket (socketFamily addr) Stream proto
340 handle (\e -> do -- let t = ioeGetErrorType e
341 when (isDoesNotExistError e) $ return () -- warn "GOTCHA"
342 -- warn $ "connect-error: " <> bshow e
343 (conkey,u) <- makeConnKey params (restrictSocket sock,(Local localhost4, Remote addr)) -- XXX: ?
344 Socket.close sock
345 atomically
346 $ writeTChan (serverEvent server)
347 $ ((conkey,u),ConnectFailure addr))
348 $ do
349 connect sock addr
350 laddr <- Socket.getSocketName sock
351 (conkey,u) <- makeConnKey params (restrictSocket sock, (Local laddr, Remote addr))
352 h <- socketToHandle sock ReadWriteMode
353 newConnection server sessionConduits params conkey u h Out
354 return ()
355
356 doit server (ConnectWithEndlessRetry addr params interval) = do
357 proto <- getProtocolNumber "tcp"
358 void . forkLabeled ("ConnectWithEndlessRetry." ++ show addr) $ do
359 timer <- interruptibleDelay
360 (retryVar,action) <- atomically $ do
361 map <- readTVar (retrymap server)
362 action <- case Map.lookup addr map of
363 Nothing -> return $ return ()
364 Just (v,d) -> do writeTVar v False
365 return $ interruptDelay d
366 v <- newTVar True
367 writeTVar (retrymap server) $! Map.insert addr (v,timer) map
368 return (v,action :: IO ())
369 action
370 fix $ \retryLoop -> do
371 utc <- getCurrentTime
372 shouldRetry <- do
373 handle (\(SomeException e) -> do
374 -- Exceptions thrown by 'socket' need to be handled specially
375 -- since we don't have enough information to broadcast a ConnectFailure
376 -- on serverEvent.
377 warn $ "Failed to create socket: " <> bshow e
378 atomically $ readTVar retryVar) $ do
379 sock <- socket (socketFamily addr) Stream proto
380 handle (\(SomeException e) -> do
381 -- Any thing else goes wrong and we broadcast ConnectFailure.
382 do (conkey,u) <- makeConnKey params (restrictSocket sock,(Local localhost4, Remote addr))
383 Socket.close sock
384 atomically $ writeTChan (serverEvent server) ((conkey,u),ConnectFailure addr)
385 `onException` return ()
386 atomically $ readTVar retryVar) $ do
387 connect sock addr
388 laddr <- Socket.getSocketName sock
389 (conkey,u) <- makeConnKey params (restrictSocket sock, (Local laddr, Remote addr))
390 h <- socketToHandle sock ReadWriteMode
391 threads <- newConnection server sessionConduits params conkey u h Out
392 atomically $ do threadsWait threads
393 readTVar retryVar
394 fin_utc <- getCurrentTime
395 when shouldRetry $ do
396 let elapsed = 1000.0 * (fin_utc `diffUTCTime` utc)
397 expected = fromIntegral interval
398 when (shouldRetry && elapsed < expected) $ do
399 debugNoise $ "Waiting to retry " <> bshow addr
400 void $ startDelay timer (round $ 1000 * (expected-elapsed))
401 debugNoise $ "retry " <> bshow (shouldRetry,addr)
402 when shouldRetry $ retryLoop
403
404
405-- INTERNAL ----------------------------------------------------------
406
407{-
408hWriteUntilNothing h outs =
409 fix $ \loop -> do
410 mb <- atomically $ takeTMVar outs
411 case mb of Just bs -> do S.hPutStrLn h bs
412 warn $ "wrote " <> bs
413 loop
414 Nothing -> do warn $ "wrote Nothing"
415 hClose h
416
417-}
418connRead :: ConnectionState -> IO (Maybe ByteString)
419connRead (WriteOnlyConnection w) = do
420 -- atomically $ discardContents (threadsChannel w)
421 return Nothing
422connRead conn = do
423 c <- atomically $ getThreads
424 threadsRead c
425 where
426 getThreads =
427 case conn of SaneConnection c -> return c
428 ReadOnlyConnection c -> return c
429 ConnectionPair c w -> do
430 -- discardContents (threadsChannel w)
431 return c
432
433socketFamily :: SockAddr -> Family
434socketFamily (SockAddrInet _ _) = AF_INET
435socketFamily (SockAddrInet6 _ _ _ _) = AF_INET6
436socketFamily (SockAddrUnix _) = AF_UNIX
437
438
439conevent :: ( IO (Maybe ByteString) -> (ByteString -> IO Bool) -> ( ConduitT () x IO (), ConduitT (Flush x) Void IO () ) )
440 -> ConnectionState
441 -> ConnectionEvent x
442conevent sessionConduits con = Connection pingflag read write
443 where
444 pingflag = swapTVar (pingFlag (connPingTimer con)) False
445 (read,write) = sessionConduits (connRead con) (connWrite con)
446
447newConnection :: Ord a
448 => Server a u1 releaseKey x
449 -> ( IO (Maybe ByteString) -> (ByteString -> IO Bool) -> ( ConduitT () x IO (), ConduitT (Flush x) Void IO () ) )
450 -> ConnectionParameters conkey u
451 -> a
452 -> u1
453 -> Handle
454 -> InOrOut
455 -> IO ConnectionThreads
456newConnection server sessionConduits params conkey u h inout = do
457 hSetBuffering h NoBuffering
458 let (idle_ms,timeout_ms) =
459 case (inout,duplex params) of
460 (Out,False) -> ( 0, 0 )
461 _ -> ( pingInterval params
462 , timeout params )
463
464 new <- do pinglogic <- forkPingMachine "newConnection" idle_ms timeout_ms
465 connectionThreads h pinglogic
466 started <- atomically $ newEmptyTMVar
467 kontvar <- atomically newEmptyTMVar
468 -- XXX: Why does kontvar store STM (IO ()) instead of just IO () ?
469 let _ = kontvar :: TMVar (STM (IO ()))
470 forkLabeled ("connecting...") $ do
471 getkont <- atomically $ takeTMVar kontvar
472 kont <- atomically getkont
473 kont
474
475 atomically $ do
476 current <- fmap (Map.lookup conkey) $ readTVar (conmap server)
477 case current of
478 Nothing -> do
479 (newCon,e) <- return $
480 if duplex params
481 then let newcon = SaneConnection new
482 in ( newcon, ((conkey,u), conevent sessionConduits newcon) )
483 else ( case inout of
484 In -> ReadOnlyConnection new
485 Out -> WriteOnlyConnection new
486 , ((conkey,u), HalfConnection inout) )
487 modifyTVar' (conmap server) $ Map.insert conkey
488 ConnectionRecord { ckont = kontvar
489 , cstate = newCon
490 , cdata = u }
491 announce e
492 putTMVar kontvar $ return $ do
493 myThreadId >>= flip labelThread ("connection."++show inout) -- XXX: more info would be nice.
494 atomically $ putTMVar started ()
495 -- Wait for something interesting.
496 handleEOF conkey u kontvar newCon
497 Just what@ConnectionRecord { ckont =mvar }-> do
498 putTMVar kontvar $ return $ return () -- Kill redundant "connecting..." thread.
499 putTMVar mvar $ do
500 -- The action returned by updateConMap, eventually invokes handleEOF,
501 -- so the sequencer thread will not be terminated.
502 kont <- updateConMap conkey u new what
503 putTMVar started ()
504 return kont
505 return new
506 where
507
508 announce e = writeTChan (serverEvent server) e
509
510 -- This function loops and will not quit unless an action is posted to the
511 -- mvar that does not in turn invoke this function, or if an EOF occurs.
512 handleEOF conkey u mvar newCon = do
513 action <- atomically . foldr1 orElse $
514 [ takeTMVar mvar >>= id -- passed continuation
515 , connWait newCon >> return eof
516 , connWaitPing newCon >>= return . sendPing
517 -- , pingWait pingTimer >>= return . sendPing
518 ]
519 action :: IO ()
520 where
521 eof = do
522 -- warn $ "EOF " <>bshow conkey
523 connCancelPing newCon
524 atomically $ do connFlush newCon
525 announce ((conkey,u),EOF)
526 modifyTVar' (conmap server)
527 $ Map.delete conkey
528 -- warn $ "fin-EOF "<>bshow conkey
529
530 sendPing PingTimeOut = do
531 {-
532 utc <- getCurrentTime
533 let utc' = formatTime defaultTimeLocale "%s" utc
534 warn $ "ping:TIMEOUT " <> bshow utc'
535 -}
536 atomically (connClose newCon)
537 eof
538
539 sendPing PingIdle = do
540 {-
541 utc <- getCurrentTime
542 let utc' = formatTime defaultTimeLocale "%s" utc
543 -- warn $ "ping:IDLE " <> bshow utc'
544 -}
545 atomically $ announce ((conkey,u),RequiresPing)
546 handleEOF conkey u mvar newCon
547
548
549 updateConMap conkey u new (ConnectionRecord { ckont=mvar, cstate=replaced, cdata=u0 }) = do
550 new' <-
551 if duplex params then do
552 announce ((conkey,u),EOF)
553 connClose replaced
554 let newcon = SaneConnection new
555 announce $ ((conkey,u),conevent sessionConduits newcon)
556 return $ newcon
557 else
558 case replaced of
559 WriteOnlyConnection w | inout==In ->
560 do let newcon = ConnectionPair new w
561 announce ((conkey,u),conevent sessionConduits newcon)
562 return newcon
563 ReadOnlyConnection r | inout==Out ->
564 do let newcon = ConnectionPair r new
565 announce ((conkey,u),conevent sessionConduits newcon)
566 return newcon
567 _ -> do -- connFlush todo
568 announce ((conkey,u0), EOF)
569 connClose replaced
570 announce ((conkey,u), HalfConnection inout)
571 return $ case inout of
572 In -> ReadOnlyConnection new
573 Out -> WriteOnlyConnection new
574 modifyTVar' (conmap server) $ Map.insert conkey
575 ConnectionRecord { ckont = mvar
576 , cstate = new'
577 , cdata = u }
578 return $ handleEOF conkey u mvar new'
579
580
581getPacket :: Handle -> IO ByteString
582getPacket h = do hWaitForInput h (-1)
583 hGetNonBlocking h 1024
584
585
586
587-- | 'ConnectionThreads' is an interface to a pair of threads
588-- that are reading and writing a 'Handle'.
589data ConnectionThreads = ConnectionThreads
590 { threadsWriter :: TMVar (Maybe ByteString)
591 , threadsChannel :: TChan ByteString
592 , threadsWait :: STM () -- ^ waits for a 'ConnectionThreads' object to close
593 , threadsPing :: PingMachine
594 }
595
596-- | This spawns the reader and writer threads and returns a newly
597-- constructed 'ConnectionThreads' object.
598connectionThreads :: Handle -> PingMachine -> IO ConnectionThreads
599connectionThreads h pinglogic = do
600
601 (donew,outs) <- atomically $ liftM2 (,) newEmptyTMVar newEmptyTMVar
602
603 (doner,incomming) <- atomically $ liftM2 (,) newEmptyTMVar newTChan
604 readerThread <- forkLabeled "readerThread" $ do
605 let finished e = do
606 hClose h
607 -- warn $ "finished read: " <> bshow (fmap ioeGetErrorType e)
608 -- let _ = fmap ioeGetErrorType e -- type hint
609 let _ = fmap what e where what (SomeException _) = undefined
610 atomically $ do tryTakeTMVar outs
611 putTMVar outs Nothing -- quit writer
612 putTMVar doner ()
613 handle (finished . Just) $ do
614 pingBump pinglogic -- start the ping timer
615 fix $ \loop -> do
616 packet <- getPacket h
617 -- warn $ "read: " <> S.take 60 packet
618 atomically $ writeTChan incomming packet
619 pingBump pinglogic
620 -- warn $ "bumped: " <> S.take 60 packet
621 isEof <- hIsEOF h
622 if isEof then finished Nothing else loop
623
624 writerThread <- forkLabeled "writerThread" . fix $ \loop -> do
625 let finished = do -- warn $ "finished write"
626 -- hClose h -- quit reader
627 throwTo readerThread (ErrorCall "EOF")
628 atomically $ putTMVar donew ()
629 mb <- atomically $ readTMVar outs
630 case mb of Just bs -> handle (\(SomeException e)->finished)
631 (do -- warn $ "writing: " <> S.take 60 bs
632 S.hPutStr h bs
633 -- warn $ "wrote: " <> S.take 60 bs
634 atomically $ takeTMVar outs
635 loop)
636 Nothing -> finished
637
638 let wait = do readTMVar donew
639 readTMVar doner
640 return ()
641 return ConnectionThreads { threadsWriter = outs
642 , threadsChannel = incomming
643 , threadsWait = wait
644 , threadsPing = pinglogic }
645
646
647-- | 'threadsWrite' writes the given 'ByteString' to the
648-- 'ConnectionThreads' object. It blocks until the ByteString
649-- is written and 'True' is returned, or the connection is
650-- interrupted and 'False' is returned.
651threadsWrite :: ConnectionThreads -> ByteString -> IO Bool
652threadsWrite c bs = atomically $
653 orElse (const False `fmap` threadsWait c)
654 (const True `fmap` putTMVar (threadsWriter c) (Just bs))
655
656-- | 'threadsClose' signals for the 'ConnectionThreads' object
657-- to quit and close the associated 'Handle'. This operation
658-- is non-blocking, follow it with 'threadsWait' if you want
659-- to wait for the operation to complete.
660threadsClose :: ConnectionThreads -> STM ()
661threadsClose c = do
662 let mvar = threadsWriter c
663 v <- tryReadTMVar mvar
664 case v of
665 Just Nothing -> return () -- already closed
666 _ -> putTMVar mvar Nothing
667
668-- | 'threadsRead' blocks until a 'ByteString' is available which
669-- is returned to the caller, or the connection is interrupted and
670-- 'Nothing' is returned.
671threadsRead :: ConnectionThreads -> IO (Maybe ByteString)
672threadsRead c = atomically $
673 orElse (const Nothing `fmap` threadsWait c)
674 (Just `fmap` readTChan (threadsChannel c))
675
676-- | A 'ConnectionState' is an interface to a single 'ConnectionThreads'
677-- or to a pair of 'ConnectionThreads' objects that are considered as one
678-- connection.
679data ConnectionState =
680 SaneConnection ConnectionThreads
681 -- ^ ordinary read/write connection
682 | WriteOnlyConnection ConnectionThreads
683 | ReadOnlyConnection ConnectionThreads
684 | ConnectionPair ConnectionThreads ConnectionThreads
685 -- ^ Two 'ConnectionThreads' objects, read operations use the
686 -- first, write operations use the second.
687
688
689
690connWrite :: ConnectionState -> ByteString -> IO Bool
691connWrite (ReadOnlyConnection _) bs = return False
692connWrite conn bs = threadsWrite c bs
693 where
694 c = case conn of SaneConnection c -> c
695 WriteOnlyConnection c -> c
696 ConnectionPair _ c -> c
697
698
699mapConn :: Bool ->
700 (ConnectionThreads -> STM ()) -> ConnectionState -> STM ()
701mapConn both action c =
702 case c of
703 SaneConnection rw -> action rw
704 ReadOnlyConnection r -> action r
705 WriteOnlyConnection w -> action w
706 ConnectionPair r w -> do
707 rem <- orElse (const w `fmap` action r)
708 (const r `fmap` action w)
709 when both $ action rem
710
711connClose :: ConnectionState -> STM ()
712connClose c = mapConn True threadsClose c
713
714connWait :: ConnectionState -> STM ()
715connWait c = doit -- mapConn False threadsWait c
716 where
717 action = threadsWait
718 doit =
719 case c of
720 SaneConnection rw -> action rw
721 ReadOnlyConnection r -> action r
722 WriteOnlyConnection w -> action w
723 ConnectionPair r w -> do
724 rem <- orElse (const w `fmap` action r)
725 (const r `fmap` action w)
726 threadsClose rem
727
728connPingTimer :: ConnectionState -> PingMachine
729connPingTimer c =
730 case c of
731 SaneConnection rw -> threadsPing rw
732 ReadOnlyConnection r -> threadsPing r
733 WriteOnlyConnection w -> threadsPing w -- should be disabled.
734 ConnectionPair r w -> threadsPing r
735
736connCancelPing :: ConnectionState -> IO ()
737connCancelPing c = pingCancel (connPingTimer c)
738
739connWaitPing :: ConnectionState -> STM PingEvent
740connWaitPing c = pingWait (connPingTimer c)
741
742connFlush :: ConnectionState -> STM ()
743connFlush c =
744 case c of
745 SaneConnection rw -> waitChan rw
746 ReadOnlyConnection r -> waitChan r
747 WriteOnlyConnection w -> return ()
748 ConnectionPair r w -> waitChan r
749 where
750 waitChan t = do
751 b <- isEmptyTChan (threadsChannel t)
752 when (not b) retry
753
754bshow :: Show a => a -> ByteString
755bshow e = S.pack . show $ e
756
757warn :: ByteString -> IO ()
758warn str =dputB XMisc str
759
760debugNoise :: Monad m => t -> m ()
761debugNoise str = return ()
762
763data TCPStatus = Resolving | AwaitingRead | AwaitingWrite
764
765-- SockAddr -> (SockAddr, ConnectionParameters SockAddr ConnectionData, Miliseconds)
766
767
768tcpManager :: (PeerAddress -> (SockAddr, ConnectionParameters PeerAddress u, Miliseconds))
769 -- -> (String -> Maybe Text)
770 -- -> (Text -> IO (Maybe PeerAddress))
771 -> Server PeerAddress u releaseKey x
772 -> IO (Manager TCPStatus Text)
773tcpManager grokKey sv = do
774 rmap <- atomically $ newTVar Map.empty -- Map k (Maybe conkey)
775 nullping <- forkPingMachine "tcpManager" 0 0
776 (rslv,rev) <- do
777 dns <- newDNSCache
778 let rslv k = map PeerAddress <$> forwardResolve dns k
779 rev (PeerAddress addr) = reverseResolve dns addr
780 return (rslv,rev)
781 return Manager {
782 setPolicy = \k -> \case
783 TryingToConnect -> join $ atomically $ do
784 r <- readTVar rmap
785 case Map.lookup k r of
786 Just {} -> return $ return () -- Connection already in progress.
787 Nothing -> do
788 modifyTVar' rmap $ Map.insert k Nothing
789 return $ void $ forkLabeled ("resolve."++show k) $ do
790 mconkey <- listToMaybe <$> rslv k
791 case mconkey of
792 Nothing -> atomically $ modifyTVar' rmap $ Map.delete k
793 Just conkey -> do
794 control sv $ case grokKey conkey of
795 (saddr,params,ms) -> ConnectWithEndlessRetry saddr params ms
796 OpenToConnect -> dput XMisc "TODO: TCP OpenToConnect"
797 RefusingToConnect -> dput XMisc "TODO: TCP RefusingToConnect"
798 , status = \k -> do
799 c <- readTVar (conmap sv)
800 ck <- Map.lookup k <$> readTVar rmap
801 return $ exportConnection c (join ck)
802 , connections = Map.keys <$> readTVar rmap
803 , stringToKey = Just . Text.pack
804 , showProgress = \case
805 Resolving -> "resolving"
806 AwaitingRead -> "awaiting inbound"
807 AwaitingWrite -> "awaiting outbound"
808 , showKey = show
809 , resolvePeer = rslv
810 , reverseAddress = rev
811 }
812
813exportConnection :: Ord conkey => Map conkey (ConnectionRecord u) -> Maybe conkey -> G.Connection TCPStatus
814exportConnection conmap mkey = G.Connection
815 { G.connStatus = case mkey of
816 Nothing -> G.Dormant
817 Just conkey -> case Map.lookup conkey conmap of
818 Nothing -> G.InProgress Resolving
819 Just (ConnectionRecord ckont cstate cdata) -> case cstate of
820 SaneConnection {} -> G.Established
821 ConnectionPair {} -> G.Established
822 ReadOnlyConnection {} -> G.InProgress AwaitingWrite
823 WriteOnlyConnection {} -> G.InProgress AwaitingRead
824 , G.connPolicy = TryingToConnect
825 }
diff --git a/server/src/Control/Concurrent/Delay.hs b/server/src/Control/Concurrent/Delay.hs
new file mode 100644
index 00000000..5cc1f99a
--- /dev/null
+++ b/server/src/Control/Concurrent/Delay.hs
@@ -0,0 +1,50 @@
1{-# LANGUAGE NondecreasingIndentation #-}
2module Control.Concurrent.Delay where
3
4import Control.Concurrent
5import Control.Monad
6import Control.Exception ({-evaluate,-}handle,finally,throwIO)
7import Data.Time.Clock (NominalDiffTime)
8import System.IO.Error
9
10type Microseconds = Int
11
12microseconds :: NominalDiffTime -> Microseconds
13microseconds d = round $ 1000000 * d
14
15data InterruptibleDelay = InterruptibleDelay
16 { delayThread :: MVar ThreadId
17 }
18
19interruptibleDelay :: IO InterruptibleDelay
20interruptibleDelay = do
21 fmap InterruptibleDelay newEmptyMVar
22
23-- | Delay for the given number of microseconds and return 'True' if the delay
24-- is not interrupted.
25--
26-- Note: If a thread is already waiting on the given 'InterruptibleDelay'
27-- object, then this will block until it becomes available and only then start
28-- the delay timer.
29startDelay :: InterruptibleDelay -> Microseconds -> IO Bool
30startDelay d interval = do
31 thread <- myThreadId
32 handle (\e -> do when (not $ isUserError e) (throwIO e)
33 return False) $ do
34 putMVar (delayThread d) thread
35 threadDelay interval
36 void $ takeMVar (delayThread d)
37 return True
38 -- The following cleanup shouldn't be necessary, but I'm paranoid.
39 `finally` tryTakeMVar (delayThread d)
40
41 where debugNoise str = return ()
42
43-- | Signal the thread waiting on the given 'InterruptibleDelay' object to
44-- continue even though the timeout has not elapsed. If no thread is waiting,
45-- then this is a no-op.
46interruptDelay :: InterruptibleDelay -> IO ()
47interruptDelay d = do
48 mthread <- tryTakeMVar (delayThread d)
49 forM_ mthread $ \thread -> do
50 throwTo thread (userError "Interrupted delay")
diff --git a/server/src/Control/Concurrent/PingMachine.hs b/server/src/Control/Concurrent/PingMachine.hs
new file mode 100644
index 00000000..5de0e2e5
--- /dev/null
+++ b/server/src/Control/Concurrent/PingMachine.hs
@@ -0,0 +1,163 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE NondecreasingIndentation #-}
3{-# LANGUAGE TupleSections #-}
4module Control.Concurrent.PingMachine where
5
6import Control.Monad
7import Data.Function
8#ifdef THREAD_DEBUG
9import Control.Concurrent.Lifted.Instrument
10#else
11import Control.Concurrent (forkIO)
12import Control.Concurrent.Lifted
13import GHC.Conc (labelThread)
14#endif
15import Control.Concurrent.STM
16
17import Control.Concurrent.Delay
18
19type Miliseconds = Int
20type TimeOut = Miliseconds
21type PingInterval = Miliseconds
22
23-- | Events that occur as a result of the 'PingMachine' watchdog.
24--
25-- Use 'pingWait' to wait for one of these to occur.
26data PingEvent
27 = PingIdle -- ^ You should send a ping if you observe this event.
28 | PingTimeOut -- ^ You should give up on the connection in case of this event.
29
30data PingMachine = PingMachine
31 { pingFlag :: TVar Bool
32 , pingInterruptible :: InterruptibleDelay
33 , pingEvent :: TMVar PingEvent
34 , pingStarted :: TMVar Bool
35 }
36
37-- | Fork a thread to monitor a connection for a ping timeout.
38--
39-- If 'pingBump' is not invoked after a idle is signaled, a timeout event will
40-- occur. When that happens, even if the caller chooses to ignore this event,
41-- the watchdog thread will be terminated and no more ping events will be
42-- signaled.
43--
44-- An idle connection will be signaled by:
45--
46-- (1) 'pingFlag' is set 'True'
47--
48-- (2) 'pingWait' returns 'PingIdle'
49--
50-- Either may be tested to determine whether a ping should be sent but
51-- 'pingFlag' is difficult to use properly because it is up to the caller to
52-- remember that the ping is already in progress.
53forkPingMachine
54 :: String
55 -> PingInterval -- ^ Milliseconds of idle before a ping is considered necessary.
56 -> TimeOut -- ^ Milliseconds after 'PingIdle' before we signal 'PingTimeOut'.
57 -> IO PingMachine
58forkPingMachine label idle timeout = do
59 d <- interruptibleDelay
60 flag <- atomically $ newTVar False
61 canceled <- atomically $ newTVar False
62 event <- atomically newEmptyTMVar
63 started <- atomically $ newEmptyTMVar
64 when (idle/=0) $ void . forkIO $ do
65 myThreadId >>= flip labelThread ("Ping." ++ label) -- ("ping.watchdog")
66 (>>=) (atomically (readTMVar started)) $ flip when $ do
67 fix $ \loop -> do
68 atomically $ writeTVar flag False
69 fin <- startDelay d (1000*idle)
70 (>>=) (atomically (readTMVar started)) $ flip when $ do
71 if (not fin) then loop
72 else do
73 -- Idle event
74 atomically $ do
75 tryTakeTMVar event
76 putTMVar event PingIdle
77 writeTVar flag True
78 fin <- startDelay d (1000*timeout)
79 (>>=) (atomically (readTMVar started)) $ flip when $ do
80 me <- myThreadId
81 if (not fin) then loop
82 else do
83 -- Timeout event
84 atomically $ do
85 tryTakeTMVar event
86 writeTVar flag False
87 putTMVar event PingTimeOut
88 return PingMachine
89 { pingFlag = flag
90 , pingInterruptible = d
91 , pingEvent = event
92 , pingStarted = started
93 }
94
95-- | like 'forkPingMachine' but the timeout and idle parameters can be changed dynamically
96-- Unlike 'forkPingMachine', 'forkPingMachineDynamic' always launches a thread
97-- regardless of idle value.
98forkPingMachineDynamic
99 :: String
100 -> TVar PingInterval -- ^ Milliseconds of idle before a ping is considered necessary.
101 -> TVar TimeOut -- ^ Milliseconds after 'PingIdle' before we signal 'PingTimeOut'.
102 -> IO PingMachine
103forkPingMachineDynamic label idleV timeoutV = do
104 d <- interruptibleDelay
105 flag <- atomically $ newTVar False
106 canceled <- atomically $ newTVar False
107 event <- atomically newEmptyTMVar
108 started <- atomically $ newEmptyTMVar
109 void . forkIO $ do
110 myThreadId >>= flip labelThread ("Ping." ++ label) -- ("ping.watchdog")
111 (>>=) (atomically (readTMVar started)) $ flip when $ do
112 fix $ \loop -> do
113 atomically $ writeTVar flag False
114 (idle,timeout) <- atomically $ (,) <$> readTVar idleV <*> readTVar timeoutV
115 fin <- startDelay d (1000*idle)
116 (>>=) (atomically (readTMVar started)) $ flip when $ do
117 if (not fin) then loop
118 else do
119 -- Idle event
120 atomically $ do
121 tryTakeTMVar event
122 putTMVar event PingIdle
123 writeTVar flag True
124 fin <- startDelay d (1000*timeout)
125 (>>=) (atomically (readTMVar started)) $ flip when $ do
126 me <- myThreadId
127 if (not fin) then loop
128 else do
129 -- Timeout event
130 atomically $ do
131 tryTakeTMVar event
132 writeTVar flag False
133 putTMVar event PingTimeOut
134 return PingMachine
135 { pingFlag = flag
136 , pingInterruptible = d
137 , pingEvent = event
138 , pingStarted = started
139 }
140
141-- | Terminate the watchdog thread. Call this upon connection close.
142--
143-- You should ensure no threads are waiting on 'pingWait' because there is no
144-- 'PingEvent' signaling termination.
145pingCancel :: PingMachine -> IO ()
146pingCancel me = do
147 atomically $ do tryTakeTMVar (pingStarted me)
148 putTMVar (pingStarted me) False
149 interruptDelay (pingInterruptible me)
150
151-- | Reset the ping timer. Call this regularly to prevent 'PingTimeOut'.
152pingBump :: PingMachine -> IO ()
153pingBump me = do
154 atomically $ do
155 b <- tryReadTMVar (pingStarted me)
156 when (b/=Just False) $ do
157 tryTakeTMVar (pingStarted me)
158 putTMVar (pingStarted me) True
159 interruptDelay (pingInterruptible me)
160
161-- | Retries until a 'PingEvent' occurs.
162pingWait :: PingMachine -> STM PingEvent
163pingWait me = takeTMVar (pingEvent me)
diff --git a/server/src/ControlMaybe.hs b/server/src/ControlMaybe.hs
new file mode 100644
index 00000000..a101d667
--- /dev/null
+++ b/server/src/ControlMaybe.hs
@@ -0,0 +1,64 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE ScopedTypeVariables #-}
3module ControlMaybe
4 ( module ControlMaybe
5 , module Data.Functor
6 ) where
7
8-- import GHC.IO.Exception (IOException(..))
9import Control.Monad
10import Data.Functor
11import System.IO.Error
12
13
14-- forM_ with less polymorphism.
15withJust :: Monad m => Maybe x -> (x -> m ()) -> m ()
16withJust m f = forM_ m f
17{-# INLINE withJust #-}
18
19whenJust :: Monad m => m (Maybe x) -> (x -> m ()) -> m ()
20whenJust acn f = acn >>= mapM_ f
21{-# INLINE whenJust #-}
22
23
24catchIO_ :: IO a -> IO a -> IO a
25catchIO_ body catcher = catchIOError body (\_ -> catcher)
26{-# INLINE catchIO_ #-}
27
28handleIO_ :: IO a -> IO a -> IO a
29handleIO_ catcher body = catchIOError body (\_ -> catcher)
30{-# INLINE handleIO_ #-}
31
32
33handleIO :: (IOError -> IO a) -> IO a -> IO a
34handleIO catcher body = catchIOError body catcher
35{-# INLINE handleIO #-}
36
37#if !MIN_VERSION_base(4,11,0)
38-- | Flipped version of '<$>'.
39--
40-- @
41-- ('<&>') = 'flip' 'fmap'
42-- @
43--
44-- @since 4.11.0.0
45--
46-- ==== __Examples__
47-- Apply @(+1)@ to a list, a 'Data.Maybe.Just' and a 'Data.Either.Right':
48--
49-- >>> Just 2 <&> (+1)
50-- Just 3
51--
52-- >>> [1,2,3] <&> (+1)
53-- [2,3,4]
54--
55-- >>> Right 3 <&> (+1)
56-- Right 4
57--
58(<&>) :: Functor f => f a -> (a -> b) -> f b
59as <&> f = f <$> as
60
61infixl 1 <&>
62#endif
63
64
diff --git a/server/src/DNSCache.hs b/server/src/DNSCache.hs
new file mode 100644
index 00000000..f539c71f
--- /dev/null
+++ b/server/src/DNSCache.hs
@@ -0,0 +1,286 @@
1-- | Both 'getAddrInfo' and 'getHostByAddr' have hard-coded timeouts for
2-- waiting upon network queries that can be a little too long for some use
3-- cases. This module wraps both of them so that they block for at most one
4-- second. It caches late-arriving results so that they can be returned by
5-- repeated timed-out queries.
6--
7-- In order to achieve the shorter timeout, it is likely that the you will need
8-- to build with GHC's -threaded option. Otherwise, if the wrapped FFI calls
9-- to resolve the address will block Haskell threads. Note: I didn't verify
10-- this.
11{-# LANGUAGE CPP #-}
12{-# LANGUAGE NondecreasingIndentation #-}
13{-# LANGUAGE RankNTypes #-}
14{-# LANGUAGE TupleSections #-}
15module DNSCache
16 ( DNSCache
17 , reverseResolve
18 , forwardResolve
19 , newDNSCache
20 , parseAddress
21 , unsafeParseAddress
22 , strip_brackets
23 , withPort
24 ) where
25
26import Control.Concurrent.ThreadUtil
27import Control.Arrow
28import Control.Concurrent.STM
29import Data.Text ( Text )
30import Network.Socket ( SockAddr(..), AddrInfoFlag(..), defaultHints, getAddrInfo, AddrInfo(..) )
31import Data.Time.Clock ( UTCTime, getCurrentTime, diffUTCTime )
32import System.IO.Error ( isDoesNotExistError )
33import System.Endian ( fromBE32, toBE32 )
34import Control.Exception ( handle )
35import Data.Map ( Map )
36import qualified Data.Map as Map
37import qualified Network.BSD as BSD
38import qualified Data.Text as Text
39import Control.Monad
40import Data.Function
41import Data.List
42import Data.Ord
43import Data.Maybe
44import System.IO.Error
45import System.IO.Unsafe
46
47import SockAddr ()
48import ControlMaybe ( handleIO_ )
49import GetHostByAddr ( getHostByAddr )
50import Control.Concurrent.Delay
51import DPut
52import DebugTag
53
54type TimeStamp = UTCTime
55
56data DNSCache =
57 DNSCache
58 { fcache :: TVar (Map Text [(TimeStamp, SockAddr)])
59 , rcache :: TVar (Map SockAddr [(TimeStamp, Text)])
60 }
61
62
63newDNSCache :: IO DNSCache
64newDNSCache = do
65 fcache <- newTVarIO Map.empty
66 rcache <- newTVarIO Map.empty
67 return DNSCache { fcache=fcache, rcache=rcache }
68
69updateCache :: Eq x =>
70 Bool -> TimeStamp -> [x] -> Maybe [(TimeStamp,x)] -> Maybe [(TimeStamp,x)]
71updateCache withScrub utc xs mys = do
72 let ys = maybe [] id mys
73 ys' = filter scrub ys
74 ys'' = map (utc,) xs ++ ys'
75 minute = 60
76 scrub (t,x) | withScrub && diffUTCTime utc t < minute = False
77 scrub (t,x) | x `elem` xs = False
78 scrub _ = True
79 guard $ not (null ys'')
80 return ys''
81
82dnsObserve :: DNSCache -> Bool -> TimeStamp -> [(Text,SockAddr)] -> STM ()
83dnsObserve dns withScrub utc obs = do
84 f <- readTVar $ fcache dns
85 r <- readTVar $ rcache dns
86 let obs' = map (\(n,a)->(n,a `withPort` 0)) obs
87 gs = do
88 g <- groupBy ((==) `on` fst) $ sortBy (comparing fst) obs'
89 (n,_) <- take 1 g
90 return (n,map snd g)
91 f' = foldl' updatef f gs
92 hs = do
93 h <- groupBy ((==) `on` snd) $ sortBy (comparing snd) obs'
94 (_,a) <- take 1 h
95 return (a,map fst h)
96 r' = foldl' updater r hs
97 writeTVar (fcache dns) f'
98 writeTVar (rcache dns) r'
99 where
100 updatef f (n,addrs) = Map.alter (updateCache withScrub utc addrs) n f
101 updater r (a,ns) = Map.alter (updateCache withScrub utc ns) a r
102
103make6mapped4 :: SockAddr -> SockAddr
104make6mapped4 addr@(SockAddrInet6 {}) = addr
105make6mapped4 addr@(SockAddrInet port a) = SockAddrInet6 port 0 (0,0,0xFFFF,fromBE32 a) 0
106
107tryForkOS :: String -> IO () -> IO ThreadId
108tryForkOS lbl action = catchIOError (forkOSLabeled lbl action) $ \e -> do
109 dput XMisc $ "DNSCache: Link with -threaded to avoid excessively long time-out."
110 forkLabeled lbl action
111
112
113-- Attempt to resolve the given domain name. Returns an empty list if the
114-- resolve operation takes longer than the timeout, but the 'DNSCache' will be
115-- updated when the resolve completes.
116--
117-- When the resolve operation does complete, any entries less than a minute old
118-- will be overwritten with the new results. Older entries are allowed to
119-- persist for reasons I don't understand as of this writing. (See 'updateCache')
120rawForwardResolve ::
121 DNSCache -> (Text -> IO ()) -> Int -> Text -> IO [SockAddr]
122rawForwardResolve dns onFail timeout addrtext = do
123 r <- atomically newEmptyTMVar
124 mvar <- interruptibleDelay
125 rt <- tryForkOS ("resolve."++show addrtext) $ do
126 resolver r mvar
127 startDelay mvar timeout
128 did <- atomically $ tryPutTMVar r []
129 when did (onFail addrtext)
130 atomically $ readTMVar r
131 where
132 resolver r mvar = do
133 xs <- handle (\e -> let _ = isDoesNotExistError e in return [])
134 $ do fmap (nub . map (make6mapped4 . addrAddress)) $
135 getAddrInfo (Just $ defaultHints { addrFlags = [ AI_CANONNAME, AI_V4MAPPED ]})
136 (Just $ Text.unpack $ strip_brackets addrtext)
137 (Just "5269")
138 did <- atomically $ tryPutTMVar r xs
139 when did $ do
140 interruptDelay mvar
141 utc <- getCurrentTime
142 atomically $ dnsObserve dns True utc $ map (addrtext,) xs
143 return ()
144
145strip_brackets :: Text -> Text
146strip_brackets s =
147 case Text.uncons s of
148 Just ('[',t) -> Text.takeWhile (/=']') t
149 _ -> s
150
151
152reportTimeout :: forall a. Show a => a -> IO ()
153reportTimeout addrtext = do
154 dput XMisc $ "timeout resolving: "++show addrtext
155 -- killThread rt
156
157unmap6mapped4 :: SockAddr -> SockAddr
158unmap6mapped4 addr@(SockAddrInet6 port _ (0,0,0xFFFF,a) _) =
159 SockAddrInet port (toBE32 a)
160unmap6mapped4 addr = addr
161
162rawReverseResolve ::
163 DNSCache -> (SockAddr -> IO ()) -> Int -> SockAddr -> IO [Text]
164rawReverseResolve dns onFail timeout addr = do
165 r <- atomically newEmptyTMVar
166 mvar <- interruptibleDelay
167 rt <- forkOS $ resolver r mvar
168 startDelay mvar timeout
169 did <- atomically $ tryPutTMVar r []
170 when did (onFail addr)
171 atomically $ readTMVar r
172 where
173 resolver r mvar =
174 handleIO_ (return ()) $ do
175 ent <- getHostByAddr (unmap6mapped4 addr) -- AF_UNSPEC addr
176 let names = BSD.hostName ent : BSD.hostAliases ent
177 xs = map Text.pack $ nub names
178 forkIO $ do
179 utc <- getCurrentTime
180 atomically $ dnsObserve dns False utc $ map (,addr) xs
181 atomically $ putTMVar r xs
182
183-- Returns expired (older than a minute) cached reverse-dns results
184-- and removes them from the cache.
185expiredReverse :: DNSCache -> SockAddr -> IO [Text]
186expiredReverse dns addr = do
187 utc <- getCurrentTime
188 addr <- return $ addr `withPort` 0
189 es <- atomically $ do
190 r <- readTVar $ rcache dns
191 let ns = maybe [] id $ Map.lookup addr r
192 minute = 60 -- seconds
193 -- XXX: Is this right? flip diffUTCTime utc returns the age of the
194 -- cache entry?
195 (es0,ns') = partition ( (>=minute) . flip diffUTCTime utc . fst ) ns
196 es = map snd es0
197 modifyTVar' (rcache dns) $ Map.insert addr ns'
198 f <- readTVar $ fcache dns
199 let f' = foldl' (flip $ Map.alter (expire utc)) f es
200 expire utc Nothing = Nothing
201 expire utc (Just as) = if null as' then Nothing else Just as'
202 where as' = filter ( (<minute) . flip diffUTCTime utc . fst) as
203 writeTVar (fcache dns) f'
204 return es
205 return es
206
207cachedReverse :: DNSCache -> SockAddr -> IO [Text]
208cachedReverse dns addr = do
209 utc <- getCurrentTime
210 addr <- return $ addr `withPort` 0
211 atomically $ do
212 r <- readTVar (rcache dns)
213 let ns = maybe [] id $ Map.lookup addr r
214 {-
215 ns' = filter ( (<minute) . flip diffUTCTime utc . fst) ns
216 minute = 60 -- seconds
217 modifyTVar' (rcache dns) $ Map.insert addr ns'
218 return $ map snd ns'
219 -}
220 return $ map snd ns
221
222-- Returns any dns query results for the given name that were observed less
223-- than a minute ago and updates the forward-cache to remove any results older
224-- than that.
225cachedForward :: DNSCache -> Text -> IO [SockAddr]
226cachedForward dns n = do
227 utc <- getCurrentTime
228 atomically $ do
229 f <- readTVar (fcache dns)
230 let as = maybe [] id $ Map.lookup n f
231 as' = filter ( (<minute) . flip diffUTCTime utc . fst) as
232 minute = 60 -- seconds
233 modifyTVar' (fcache dns) $ Map.insert n as'
234 return $ map snd as'
235
236-- Reverse-resolves an address to a domain name. Returns both the result of a
237-- new query and any freshly cached results. Cache entries older than a minute
238-- will not be returned, but will be refreshed in spawned threads so that they
239-- may be available for the next call.
240reverseResolve :: DNSCache -> SockAddr -> IO [Text]
241reverseResolve dns addr = do
242 expired <- expiredReverse dns addr
243 forM_ expired $ \n -> forkIO $ do
244 rawForwardResolve dns (const $ return ()) 1000000 n
245 return ()
246 xs <- rawReverseResolve dns (const $ return ()) 1000000 addr
247 cs <- cachedReverse dns addr
248 return $ xs ++ filter (not . flip elem xs) cs
249
250-- Resolves a name, if there's no result within one second, then any cached
251-- results that are less than a minute old are returned.
252forwardResolve :: DNSCache -> Text -> IO [SockAddr]
253forwardResolve dns n = do
254 as <- rawForwardResolve dns (const $ return ()) 1000000 n
255 if null as
256 then cachedForward dns n
257 else return as
258
259parseAddress :: Text -> IO (Maybe SockAddr)
260parseAddress addr_str = do
261 info <- getAddrInfo (Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] })
262 (Just . Text.unpack $ addr_str)
263 (Just "0")
264 return . listToMaybe $ map addrAddress info
265
266
267splitAtPort :: String -> (String,String)
268splitAtPort s = second sanitizePort $ case s of
269 ('[':t) -> break (==']') t
270 _ -> break (==':') s
271 where
272 sanitizePort (']':':':p) = p
273 sanitizePort (':':p) = p
274 sanitizePort _ = "0"
275
276unsafeParseAddress :: String -> Maybe SockAddr
277unsafeParseAddress addr_str = unsafePerformIO $ do
278 let (ipstr,portstr) = splitAtPort addr_str
279 info <- getAddrInfo (Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] })
280 (Just ipstr)
281 (Just portstr)
282 return . listToMaybe $ map addrAddress info
283
284withPort :: SockAddr -> Int -> SockAddr
285withPort (SockAddrInet _ a) port = SockAddrInet (toEnum port) a
286withPort (SockAddrInet6 _ a b c) port = SockAddrInet6 (toEnum port) a b c
diff --git a/server/src/Data/TableMethods.hs b/server/src/Data/TableMethods.hs
new file mode 100644
index 00000000..e4208a69
--- /dev/null
+++ b/server/src/Data/TableMethods.hs
@@ -0,0 +1,105 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE GADTs #-}
3{-# LANGUAGE LambdaCase #-}
4{-# LANGUAGE PartialTypeSignatures #-}
5{-# LANGUAGE RankNTypes #-}
6{-# LANGUAGE ScopedTypeVariables #-}
7{-# LANGUAGE TupleSections #-}
8module Data.TableMethods where
9
10import Data.Functor.Contravariant
11import Data.Time.Clock.POSIX
12import Data.Word
13import qualified Data.IntMap.Strict as IntMap
14 ;import Data.IntMap.Strict (IntMap)
15import qualified Data.Map.Strict as Map
16 ;import Data.Map.Strict (Map)
17import qualified Data.Word64Map as W64Map
18 ;import Data.Word64Map (Word64Map)
19
20import Data.Wrapper.PSQ as PSQ
21
22type Priority = POSIXTime
23
24data OptionalPriority t tid x
25 = NoPriority
26 | HasPriority (Priority -> t x -> ([(tid, Priority, x)], t x))
27
28-- | The standard lookup table methods.
29data TableMethods t tid = TableMethods
30 { -- | Insert a new /tid/ entry into the transaction table.
31 tblInsert :: forall a. tid -> a -> Priority -> t a -> t a
32 -- | Delete transaction /tid/ from the transaction table.
33 , tblDelete :: forall a. tid -> t a -> t a
34 -- | Lookup the value associated with transaction /tid/.
35 , tblLookup :: forall a. tid -> t a -> Maybe a
36 }
37
38data QMethods t tid x = QMethods
39 { qTbl :: TableMethods t tid
40 , qAtMostView :: OptionalPriority t tid x
41 }
42
43vanillaTable :: TableMethods t tid -> QMethods t tid x
44vanillaTable tbl = QMethods tbl NoPriority
45
46priorityTable :: TableMethods t tid
47 -> (Priority -> t x -> ([(k, Priority, x)], t x))
48 -> (k -> x -> tid)
49 -> QMethods t tid x
50priorityTable tbl atmost f = QMethods
51 { qTbl = tbl
52 , qAtMostView = HasPriority $ \p t -> case atmost p t of
53 (es,t') -> (map (\(k,p,a) -> (f k a, p, a)) es, t')
54 }
55
56-- | Methods for using 'Data.IntMap'.
57intMapMethods :: TableMethods IntMap Int
58intMapMethods = TableMethods
59 { tblInsert = \tid a p -> IntMap.insert tid a
60 , tblDelete = IntMap.delete
61 , tblLookup = IntMap.lookup
62 }
63
64-- | Methods for using 'Data.Word64Map'.
65w64MapMethods :: TableMethods Word64Map Word64
66w64MapMethods = TableMethods
67 { tblInsert = \tid a p -> W64Map.insert tid a
68 , tblDelete = W64Map.delete
69 , tblLookup = W64Map.lookup
70 }
71
72-- | Methods for using 'Data.Map'
73mapMethods :: Ord tid => TableMethods (Map tid) tid
74mapMethods = TableMethods
75 { tblInsert = \tid a p -> Map.insert tid a
76 , tblDelete = Map.delete
77 , tblLookup = Map.lookup
78 }
79
80-- psqMethods :: PSQKey tid => QMethods (HashPSQ tid Priority) tid x
81psqMethods :: PSQKey k => (tid -> k) -> (k -> x -> tid) -> QMethods (PSQ' k Priority) tid x
82psqMethods g f = priorityTable (contramap g tbl) PSQ.atMostView f
83 where
84 tbl :: PSQKey tid => TableMethods (PSQ' tid Priority) tid
85 tbl = TableMethods
86 { tblInsert = PSQ.insert'
87 , tblDelete = PSQ.delete
88 , tblLookup = \tid t -> case PSQ.lookup tid t of
89 Just (p,a) -> Just a
90 Nothing -> Nothing
91 }
92
93
94-- | Change the key type for a lookup table implementation.
95--
96-- This can be used with 'intMapMethods' or 'mapMethods' to restrict lookups to
97-- only a part of the generated /tid/ value. This is useful for /tid/ types
98-- that are especially large due their use for other purposes, such as secure
99-- nonces for encryption.
100instance Contravariant (TableMethods t) where
101 -- contramap :: (tid -> t1) -> TableMethods t t1 -> TableMethods t tid
102 contramap f (TableMethods ins del lookup) =
103 TableMethods (\k p v t -> ins (f k) p v t)
104 (\k t -> del (f k) t)
105 (\k t -> lookup (f k) t)
diff --git a/server/src/DebugTag.hs b/server/src/DebugTag.hs
new file mode 100644
index 00000000..9ac04bb0
--- /dev/null
+++ b/server/src/DebugTag.hs
@@ -0,0 +1,24 @@
1module DebugTag where
2
3import Data.Typeable
4
5-- | Debug Tags, add more as needed, but ensure XAnnounce is always first, XMisc last
6data DebugTag
7 = XAnnounce
8 | XBitTorrent
9 | XDHT
10 | XLan
11 | XMan
12 | XNetCrypto
13 | XNetCryptoOut
14 | XOnion
15 | XRoutes
16 | XPing
17 | XRefresh
18 | XJabber
19 | XTCP
20 | XMisc
21 | XNodeinfoSearch
22 | XUnexpected -- Used only for special anomalous errors that we didn't expect to happen.
23 | XUnused -- Never commit code that uses XUnused.
24 deriving (Eq, Ord, Show, Read, Enum, Bounded,Typeable)
diff --git a/server/src/ForkLabeled.hs b/server/src/ForkLabeled.hs
new file mode 100644
index 00000000..50b5d76c
--- /dev/null
+++ b/server/src/ForkLabeled.hs
@@ -0,0 +1,16 @@
1{-# LANGUAGE CPP #-}
2module ForkLabeled where
3
4#ifdef THREAD_DEBUG
5import Control.Concurrent.Lifted.Instrument
6#else
7import Control.Concurrent.Lifted
8import GHC.Conc (labelThread,forkIO)
9#endif
10
11forkLabeled :: String -> IO () -> IO ThreadId
12forkLabeled s io = do
13 t <- forkIO io
14 labelThread t s
15 return t
16
diff --git a/server/src/GetHostByAddr.hs b/server/src/GetHostByAddr.hs
new file mode 100644
index 00000000..068fc93d
--- /dev/null
+++ b/server/src/GetHostByAddr.hs
@@ -0,0 +1,78 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE NondecreasingIndentation #-}
3module GetHostByAddr where
4
5import Network.BSD
6import Foreign.Ptr
7import Foreign.C.Types
8import Foreign.Storable (Storable(..))
9import Foreign.Marshal.Utils (with)
10import Foreign.Marshal.Alloc
11import Control.Concurrent
12import System.IO.Unsafe
13import System.IO.Error (ioeSetErrorString, mkIOError)
14import Network.Socket
15import GHC.IO.Exception
16
17
18throwNoSuchThingIfNull :: String -> String -> IO (Ptr a) -> IO (Ptr a)
19throwNoSuchThingIfNull loc desc act = do
20 ptr <- act
21 if (ptr == nullPtr)
22 then ioError (ioeSetErrorString (mkIOError NoSuchThing loc Nothing Nothing) desc)
23 else return ptr
24
25{-# NOINLINE lock #-}
26lock :: MVar ()
27lock = unsafePerformIO $ newMVar ()
28
29withLock :: IO a -> IO a
30withLock act = withMVar lock (\_ -> act)
31
32trySysCall :: IO a -> IO a
33trySysCall act = act
34
35{-
36-- The locking of gethostbyaddr is similar to gethostbyname.
37-- | Get a 'HostEntry' corresponding to the given address and family.
38-- Note that only IPv4 is currently supported.
39getHostByAddr :: Family -> SockAddr -> IO HostEntry
40getHostByAddr family addr = do
41 withSockAddr addr $ \ ptr_addr len -> withLock $ do
42 throwNoSuchThingIfNull "getHostByAddr" "no such host entry"
43 $ trySysCall $ c_gethostbyaddr ptr_addr (fromIntegral len) (packFamily family)
44 >>= peek
45-}
46
47
48-- The locking of gethostbyaddr is similar to gethostbyname.
49-- | Get a 'HostEntry' corresponding to the given address and family.
50-- Note that only IPv4 is currently supported.
51-- getHostByAddr :: Family -> HostAddress -> IO HostEntry
52-- getHostByAddr family addr = do
53getHostByAddr :: SockAddr -> IO HostEntry
54getHostByAddr (SockAddrInet port addr ) = do
55 let family = AF_INET
56 with addr $ \ ptr_addr -> withLock $ do
57 throwNoSuchThingIfNull "getHostByAddr" "no such host entry"
58 $ trySysCall $ c_gethostbyaddr ptr_addr (fromIntegral (sizeOf addr)) (packFamily family)
59 >>= peek
60getHostByAddr (SockAddrInet6 port flow (a,b,c,d) scope) = do
61 let family = AF_INET6
62 allocaBytes 16 $ \ ptr_addr -> do
63 pokeElemOff ptr_addr 0 a
64 pokeElemOff ptr_addr 1 b
65 pokeElemOff ptr_addr 2 c
66 pokeElemOff ptr_addr 3 d
67 withLock $ do
68 throwNoSuchThingIfNull "getHostByAddr" "no such host entry"
69 $ trySysCall $ c_gethostbyaddr ptr_addr 16 (packFamily family)
70 >>= peek
71
72
73foreign import ccall safe "gethostbyaddr"
74 c_gethostbyaddr :: Ptr a -> CInt -> CInt -> IO (Ptr HostEntry)
75
76
77
78-- vim:ft=haskell:
diff --git a/server/src/Network/QueryResponse.hs b/server/src/Network/QueryResponse.hs
new file mode 100644
index 00000000..20e7ecf0
--- /dev/null
+++ b/server/src/Network/QueryResponse.hs
@@ -0,0 +1,716 @@
1-- | This module can implement any query\/response protocol. It was written
2-- with Kademlia implementations in mind.
3
4{-# LANGUAGE CPP #-}
5{-# LANGUAGE GADTs #-}
6{-# LANGUAGE LambdaCase #-}
7{-# LANGUAGE PartialTypeSignatures #-}
8{-# LANGUAGE RankNTypes #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE TupleSections #-}
11module Network.QueryResponse where
12
13#ifdef THREAD_DEBUG
14import Control.Concurrent.Lifted.Instrument
15#else
16import Control.Concurrent
17import GHC.Conc (labelThread)
18#endif
19import Control.Concurrent.STM
20import Control.Exception
21import Control.Monad
22import qualified Data.ByteString as B
23 ;import Data.ByteString (ByteString)
24import Data.Dependent.Map as DMap
25import Data.Dependent.Sum
26import Data.Function
27import Data.Functor.Contravariant
28import Data.Functor.Identity
29import Data.GADT.Show
30import qualified Data.IntMap.Strict as IntMap
31 ;import Data.IntMap.Strict (IntMap)
32import qualified Data.Map.Strict as Map
33 ;import Data.Map.Strict (Map)
34import Data.Time.Clock.POSIX
35import qualified Data.Word64Map as W64Map
36 ;import Data.Word64Map (Word64Map)
37import Data.Word
38import Data.Maybe
39import GHC.Conc (closeFdWith)
40import GHC.Event
41import Network.Socket
42import Network.Socket.ByteString as B
43import System.Endian
44import System.IO
45import System.IO.Error
46import System.Timeout
47
48import DPut
49import DebugTag
50import Data.TableMethods
51
52-- | An inbound packet or condition raised while monitoring a connection.
53data Arrival err addr x
54 = Terminated -- ^ Virtual message that signals EOF.
55 | ParseError !err -- ^ A badly-formed message was received.
56 | Arrival { arrivedFrom :: !addr , arrivedMsg :: !x } -- ^ Inbound message.
57
58-- | Three methods are required to implement a datagram based query\/response protocol.
59data TransportA err addr x y = Transport
60 { -- | Blocks until an inbound packet is available. Then calls the provided
61 -- continuation with the packet and origin addres or an error condition.
62 awaitMessage :: forall a. (Arrival err addr x -> IO a) -> STM (IO a)
63 -- | Send an /y/ packet to the given destination /addr/.
64 , sendMessage :: addr -> y -> IO ()
65 -- | Shutdown and clean up any state related to this 'Transport'.
66 , setActive :: Bool -> IO ()
67 }
68
69type Transport err addr x = TransportA err addr x x
70
71closeTransport :: TransportA err addr x y -> IO ()
72closeTransport tr = setActive tr False
73
74-- | This function modifies a 'Transport' to use higher-level addresses and
75-- packet representations. It could be used to change UDP 'ByteString's into
76-- bencoded syntax trees or to add an encryption layer in which addresses have
77-- associated public keys.
78layerTransportM ::
79 (x -> addr -> IO (Either err (x', addr')))
80 -- ^ Function that attempts to transform a low-level address/packet
81 -- pair into a higher level representation.
82 -> (y' -> addr' -> IO (y, addr))
83 -- ^ Function to encode a high-level address/packet into a lower level
84 -- representation.
85 -> TransportA err addr x y
86 -- ^ The low-level transport to be transformed.
87 -> TransportA err addr' x' y'
88layerTransportM parse encode tr =
89 tr { awaitMessage = \kont ->
90 awaitMessage tr $ \case
91 Terminated -> kont $ Terminated
92 ParseError e -> kont $ ParseError e
93 Arrival addr x -> parse x addr >>= \case
94 Left e -> kont $ ParseError e
95 Right (x',addr') -> kont $ Arrival addr' x'
96 , sendMessage = \addr' msg' -> do
97 (msg,addr) <- encode msg' addr'
98 sendMessage tr addr msg
99 }
100
101
102-- | This function modifies a 'Transport' to use higher-level addresses and
103-- packet representations. It could be used to change UDP 'ByteString's into
104-- bencoded syntax trees or to add an encryption layer in which addresses have
105-- associated public keys.
106layerTransport ::
107 (x -> addr -> Either err (x', addr'))
108 -- ^ Function that attempts to transform a low-level address/packet
109 -- pair into a higher level representation.
110 -> (y' -> addr' -> (y, addr))
111 -- ^ Function to encode a high-level address/packet into a lower level
112 -- representation.
113 -> TransportA err addr x y
114 -- ^ The low-level transport to be transformed.
115 -> TransportA err addr' x' y'
116layerTransport parse encode tr =
117 layerTransportM (\x addr -> return $ parse x addr)
118 (\x' addr' -> return $ encode x' addr')
119 tr
120
121-- | Paritions a 'Transport' into two higher-level transports. Note: A 'TChan'
122-- is used to share the same underlying socket, so be sure to fork a thread for
123-- both returned 'Transport's to avoid hanging.
124partitionTransportM :: ((b,a) -> IO (Either (x,xaddr) (b,a)))
125 -> ((y,xaddr) -> IO (Maybe (c,a)))
126 -> TransportA err a b c
127 -> IO (TransportA err xaddr x y, TransportA err a b c)
128partitionTransportM parse encodex tr = do
129 tchan <- atomically newTChan
130 let ytr = tr { awaitMessage = \kont -> fix $ \again -> do
131 awaitMessage tr $ \m -> case m of
132 Arrival adr msg -> parse (msg,adr) >>= \case
133 Left x -> atomically (writeTChan tchan (Just x)) >> join (atomically again)
134 Right (y,yaddr) -> kont $ Arrival yaddr y
135 ParseError e -> kont $ ParseError e
136 Terminated -> atomically (writeTChan tchan Nothing) >> kont Terminated
137 , sendMessage = sendMessage tr
138 }
139 xtr = Transport
140 { awaitMessage = \kont -> readTChan tchan >>= pure . kont . \case
141 Nothing -> Terminated
142 Just (x,xaddr) -> Arrival xaddr x
143 , sendMessage = \addr' msg' -> do
144 msg_addr <- encodex (msg',addr')
145 mapM_ (uncurry . flip $ sendMessage tr) msg_addr
146 , setActive = const $ return ()
147 }
148 return (xtr, ytr)
149
150-- | Paritions a 'Transport' into two higher-level transports. Note: An 'TChan'
151-- is used to share the same underlying socket, so be sure to fork a thread for
152-- both returned 'Transport's to avoid hanging.
153partitionTransport :: ((b,a) -> Either (x,xaddr) (b,a))
154 -> ((y,xaddr) -> Maybe (c,a))
155 -> TransportA err a b c
156 -> IO (TransportA err xaddr x y, TransportA err a b c)
157partitionTransport parse encodex tr =
158 partitionTransportM (return . parse) (return . encodex) tr
159
160-- |
161-- * f add x --> Nothing, consume x
162-- --> Just id, leave x to a different handler
163-- --> Just g, apply g to x and leave that to a different handler
164--
165-- Note: If you add a handler to one of the branches before applying a
166-- 'mergeTransports' combinator, then this handler may not block or return
167-- Nothing.
168addHandler :: (err -> IO ()) -> (addr -> x -> IO (Maybe (x -> x))) -> TransportA err addr x y -> TransportA err addr x y
169addHandler onParseError f tr = tr
170 { awaitMessage = \kont -> fix $ \eat -> awaitMessage tr $ \case
171 Arrival addr x -> f addr x >>= maybe (join $ atomically eat) (kont . Arrival addr . ($ x))
172 ParseError e -> onParseError e >> kont (ParseError e)
173 Terminated -> kont Terminated
174 }
175
176-- | Modify a 'Transport' to invoke an action upon every received packet.
177onInbound :: (addr -> x -> IO ()) -> Transport err addr x -> Transport err addr x
178onInbound f tr = addHandler (const $ return ()) (\addr x -> f addr x >> return (Just id)) tr
179
180-- * Using a query\/response client.
181
182-- | Fork a thread that handles inbound packets. The returned action may be used
183-- to terminate the thread and clean up any related state.
184--
185-- Example usage:
186--
187-- > -- Start client.
188-- > quitServer <- forkListener "listener" (clientNet client)
189-- > -- Send a query q, recieve a response r.
190-- > r <- sendQuery client method q
191-- > -- Quit client.
192-- > quitServer
193forkListener :: String -> Transport err addr x -> IO (IO ())
194forkListener name client = do
195 setActive client True
196 thread_id <- forkIO $ do
197 myThreadId >>= flip labelThread ("listener."++name)
198 fix $ \loop -> join $ atomically $ awaitMessage client $ \case
199 Terminated -> return ()
200 _ -> loop
201 dput XMisc $ "Listener died: " ++ name
202 return $ do
203 setActive client False
204 -- killThread thread_id
205
206-- * Implementing a query\/response 'Client'.
207
208-- | These methods indicate what should be done upon various conditions. Write
209-- to a log file, make debug prints, or simply ignore them.
210--
211-- [ /addr/ ] Address of remote peer.
212--
213-- [ /x/ ] Incoming or outgoing packet.
214--
215-- [ /meth/ ] Method id of incoming or outgoing request.
216--
217-- [ /tid/ ] Transaction id for outgoing packet.
218--
219-- [ /err/ ] Error information, typically a 'String'.
220data ErrorReporter addr x meth tid err = ErrorReporter
221 { -- | Incoming: failed to parse packet.
222 reportParseError :: err -> IO ()
223 -- | Incoming: no handler for request.
224 , reportMissingHandler :: meth -> addr -> x -> IO ()
225 -- | Incoming: unable to identify request.
226 , reportUnknown :: addr -> x -> err -> IO ()
227 }
228
229ignoreErrors :: ErrorReporter addr x meth tid err
230ignoreErrors = ErrorReporter
231 { reportParseError = \_ -> return ()
232 , reportMissingHandler = \_ _ _ -> return ()
233 , reportUnknown = \_ _ _ -> return ()
234 }
235
236logErrors :: ( Show addr
237 , Show meth
238 ) => ErrorReporter addr x meth tid String
239logErrors = ErrorReporter
240 { reportParseError = \err -> dput XMisc err
241 , reportMissingHandler = \meth addr x -> dput XMisc $ show addr ++ " --> Missing handler ("++show meth++")"
242 , reportUnknown = \addr x err -> dput XMisc $ show addr ++ " --> " ++ err
243 }
244
245printErrors :: ( Show addr
246 , Show meth
247 ) => Handle -> ErrorReporter addr x meth tid String
248printErrors h = ErrorReporter
249 { reportParseError = \err -> hPutStrLn h err
250 , reportMissingHandler = \meth addr x -> hPutStrLn h $ show addr ++ " --> Missing handler ("++show meth++")"
251 , reportUnknown = \addr x err -> hPutStrLn h $ show addr ++ " --> " ++ err
252 }
253
254-- Change the /err/ type for an 'ErrorReporter'.
255instance Contravariant (ErrorReporter addr x meth tid) where
256 -- contramap :: (t5 -> t4) -> ErrorReporter t3 t2 t1 t t4 -> ErrorReporter t3 t2 t1 t t5
257 contramap f (ErrorReporter pe mh unk)
258 = ErrorReporter (\e -> pe (f e))
259 mh
260 (\addr x e -> unk addr x (f e))
261
262-- | An incoming message can be classified into three cases.
263data MessageClass err meth tid addr x
264 = IsQuery meth tid -- ^ An unsolicited query is handled based on it's /meth/ value. Any response
265 -- should include the provided /tid/ value.
266 | IsResponse tid -- ^ A response to a outgoing query we associated with a /tid/ value.
267 | IsUnsolicited (addr -> addr -> IO (Maybe (x -> x))) -- ^ Transactionless informative packet. The io action will be invoked
268 -- with the source and destination address of a message. If it handles the
269 -- message, it should return Nothing. Otherwise, it should return a transform
270 -- (usually /id/) to apply before the next handler examines it.
271 | IsUnknown err -- ^ None of the above.
272
273-- | Handler for an inbound query of type /x/ from an address of type _addr_.
274type MethodHandler err tid addr x = MethodHandlerA err tid addr x x
275
276-- | Handler for an inbound query of type /x/ with outbound response of type
277-- /y/ to an address of type /addr/.
278data MethodHandlerA err tid addr x y = forall a b. MethodHandler
279 { -- | Parse the query into a more specific type for this method.
280 methodParse :: x -> Either err a
281 -- | Serialize the response for transmission, given a context /ctx/ and the origin
282 -- and destination addresses.
283 , methodSerialize :: tid -> addr -> addr -> b -> y
284 -- | Fully typed action to perform upon the query. The remote origin
285 -- address of the query is provided to the handler.
286 --
287 -- TODO: Allow queries to be ignored?
288 , methodAction :: addr -> a -> IO b
289 }
290 -- | See also 'IsUnsolicited' which likely makes this constructor unnecessary.
291 | forall a. NoReply
292 { -- | Parse the query into a more specific type for this method.
293 methodParse :: x -> Either err a
294 -- | Fully typed action to perform upon the query. The remote origin
295 -- address of the query is provided to the handler.
296 , noreplyAction :: addr -> a -> IO ()
297 }
298
299
300-- | To dispatch responses to our outbound queries, we require three
301-- primitives. See the 'transactionMethods' function to create these
302-- primitives out of a lookup table and a generator for transaction ids.
303--
304-- The type variable /d/ is used to represent the current state of the
305-- transaction generator and the table of pending transactions.
306data TransactionMethods d qid addr x = TransactionMethods
307 {
308 -- | Before a query is sent, this function stores an 'MVar' to which the
309 -- response will be written too. The returned /qid/ is a transaction id
310 -- that can be used to forget the 'MVar' if the remote peer is not
311 -- responding.
312 dispatchRegister :: POSIXTime -- time of expiry
313 -> (Maybe x -> IO ()) -- callback upon response (or timeout)
314 -> addr
315 -> d
316 -> STM (qid, d)
317 -- | This method is invoked when an incoming packet /x/ indicates it is
318 -- a response to the transaction with id /qid/. The returned IO action
319 -- will write the packet to the correct 'MVar' thus completing the
320 -- dispatch.
321 , dispatchResponse :: qid -> x -> d -> STM (d, IO ())
322 -- | When a timeout interval elapses, this method is called to remove the
323 -- transaction from the table.
324 , dispatchCancel :: qid -> d -> STM d
325 }
326
327-- | A set of methods necessary for dispatching incoming packets.
328type DispatchMethods tbl err meth tid addr x = DispatchMethodsA tbl err meth tid addr x x
329
330-- | A set of methods necessary for dispatching incoming packets.
331data DispatchMethodsA tbl err meth tid addr x y = DispatchMethods
332 { -- | Classify an inbound packet as a query or response.
333 classifyInbound :: x -> MessageClass err meth tid addr x
334 -- | Lookup the handler for a inbound query.
335 , lookupHandler :: meth -> Maybe (MethodHandlerA err tid addr x y)
336 -- | Methods for handling incoming responses.
337 , tableMethods :: TransactionMethods tbl tid addr x
338 }
339
340-- | All inputs required to implement a query\/response client.
341type Client err meth tid addr x = ClientA err meth tid addr x x
342
343-- | All inputs required to implement a query\/response client.
344data ClientA err meth tid addr x y = forall tbl. Client
345 { -- | The 'Transport' used to dispatch and receive packets.
346 clientNet :: TransportA err addr x y
347 -- | Methods for handling inbound packets.
348 , clientDispatcher :: DispatchMethodsA tbl err meth tid addr x y
349 -- | Methods for reporting various conditions.
350 , clientErrorReporter :: ErrorReporter addr x meth tid err
351 -- | State necessary for routing inbound responses and assigning unique
352 -- /tid/ values for outgoing queries.
353 , clientPending :: TVar tbl
354 -- | An action yielding this client\'s own address. It is invoked once
355 -- on each outbound and inbound packet. It is valid for this to always
356 -- return the same value.
357 --
358 -- The argument, if supplied, is the remote address for the transaction.
359 -- This can be used to maintain consistent aliases for specific peers.
360 , clientAddress :: Maybe addr -> IO addr
361 -- | Transform a query /tid/ value to an appropriate response /tid/
362 -- value. Normally, this would be the identity transformation, but if
363 -- /tid/ includes a unique cryptographic nonce, then it should be
364 -- generated here.
365 , clientResponseId :: tid -> IO tid
366 }
367
368-- | These four parameters are required to implement an outgoing query. A
369-- peer-to-peer algorithm will define a 'MethodSerializer' for every 'MethodHandler' that
370-- might be returned by 'lookupHandler'.
371data MethodSerializerA tid addr x y meth a b = MethodSerializer
372 { -- | Returns the microseconds to wait for a response to this query being
373 -- sent to the given address. The /addr/ may also be modified to add
374 -- routing information.
375 methodTimeout :: addr -> STM (addr,Int)
376 -- | A method identifier used for error reporting. This needn't be the
377 -- same as the /meth/ argument to 'MethodHandler', but it is suggested.
378 , method :: meth
379 -- | Serialize the outgoing query /a/ into a transmittable packet /x/.
380 -- The /addr/ arguments are, respectively, our own origin address and the
381 -- destination of the request. The /tid/ argument is useful for attaching
382 -- auxiliary notations on all outgoing packets.
383 , wrapQuery :: tid -> addr -> addr -> a -> x
384 -- | Parse an inbound packet /x/ into a response /b/ for this query.
385 , unwrapResponse :: y -> b
386 }
387
388type MethodSerializer tid addr x meth a b = MethodSerializerA tid addr x x meth a b
389
390microsecondsDiff :: Int -> POSIXTime
391microsecondsDiff us = fromIntegral us / 1000000
392
393asyncQuery_ :: Client err meth tid addr x
394 -> MethodSerializer tid addr x meth a b
395 -> a
396 -> addr
397 -> (Maybe b -> IO ())
398 -> IO (tid,POSIXTime,Int)
399asyncQuery_ (Client net d err pending whoami _) meth q addr0 withResponse = do
400 now <- getPOSIXTime
401 (tid,addr,expiry) <- atomically $ do
402 tbl <- readTVar pending
403 (addr,expiry) <- methodTimeout meth addr0
404 (tid, tbl') <- dispatchRegister (tableMethods d)
405 (now + microsecondsDiff expiry)
406 (withResponse . fmap (unwrapResponse meth))
407 addr -- XXX: Should be addr0 or addr?
408 tbl
409 -- (addr,expiry) <- methodTimeout meth tid addr0
410 writeTVar pending tbl'
411 return (tid,addr,expiry)
412 self <- whoami (Just addr)
413 mres <- do sendMessage net addr (wrapQuery meth tid self addr q)
414 return $ Just ()
415 `catchIOError` (\e -> return Nothing)
416 return (tid,now,expiry)
417
418asyncQuery :: Show meth => Client err meth tid addr x
419 -> MethodSerializer tid addr x meth a b
420 -> a
421 -> addr
422 -> (Maybe b -> IO ())
423 -> IO ()
424asyncQuery client meth q addr withResponse0 = do
425 tm <- getSystemTimerManager
426 tidvar <- newEmptyMVar
427 timedout <- registerTimeout tm 1000000 $ do
428 dput XMisc $ "async TIMEDOUT " ++ show (method meth)
429 withResponse0 Nothing
430 tid <- takeMVar tidvar
431 dput XMisc $ "async TIMEDOUT mvar " ++ show (method meth)
432 case client of
433 Client { clientDispatcher = d, clientPending = pending } -> do
434 atomically $ readTVar pending >>= dispatchCancel (tableMethods d) tid >>= writeTVar pending
435 (tid,now,expiry) <- asyncQuery_ client meth q addr $ \x -> do
436 unregisterTimeout tm timedout
437 withResponse0 x
438 putMVar tidvar tid
439 updateTimeout tm timedout expiry
440 dput XMisc $ "FIN asyncQuery "++show (method meth)++" TIMEOUT="++show expiry
441
442-- | Send a query to a remote peer. Note that this function will always time
443-- out if 'forkListener' was never invoked to spawn a thread to receive and
444-- dispatch the response.
445sendQuery ::
446 forall err a b tbl x meth tid addr.
447 Client err meth tid addr x -- ^ A query/response implementation.
448 -> MethodSerializer tid addr x meth a b -- ^ Information for marshaling the query.
449 -> a -- ^ The outbound query.
450 -> addr -- ^ Destination address of query.
451 -> IO (Maybe b) -- ^ The response, or 'Nothing' if it timed out.
452sendQuery c@(Client net d err pending whoami _) meth q addr0 = do
453 mvar <- newEmptyMVar
454 (tid,now,expiry) <- asyncQuery_ c meth q addr0 $ mapM_ (putMVar mvar)
455 mres <- timeout expiry $ takeMVar mvar
456 case mres of
457 Just b -> return $ Just b
458 Nothing -> do
459 atomically $ readTVar pending >>= dispatchCancel (tableMethods d) tid >>= writeTVar pending
460 return Nothing
461
462contramapAddr :: (a -> b) -> MethodHandler err tid b x -> MethodHandler err tid a x
463contramapAddr f (MethodHandler p s a)
464 = MethodHandler
465 p
466 (\tid src dst result -> s tid (f src) (f dst) result)
467 (\addr arg -> a (f addr) arg)
468contramapAddr f (NoReply p a)
469 = NoReply p (\addr arg -> a (f addr) arg)
470
471-- | Query handlers can throw this to ignore a query instead of responding to
472-- it.
473data DropQuery = DropQuery
474 deriving Show
475
476instance Exception DropQuery
477
478-- | Attempt to invoke a 'MethodHandler' upon a given inbound query. If the
479-- parse is successful, the returned IO action will construct our reply if
480-- there is one. Otherwise, a parse err is returned.
481dispatchQuery :: MethodHandlerA err tid addr x y -- ^ Handler to invoke.
482 -> tid -- ^ The transaction id for this query\/response session.
483 -> addr -- ^ Our own address, to which the query was sent.
484 -> x -- ^ The query packet.
485 -> addr -- ^ The origin address of the query.
486 -> Either err (IO (Maybe y))
487dispatchQuery (MethodHandler unwrapQ wrapR f) tid self x addr =
488 fmap (\a -> catch (Just . wrapR tid self addr <$> f addr a)
489 (\DropQuery -> return Nothing))
490 $ unwrapQ x
491dispatchQuery (NoReply unwrapQ f) tid self x addr =
492 fmap (\a -> f addr a >> return Nothing) $ unwrapQ x
493
494-- | Like 'transactionMethods' but allows extra information to be stored in the
495-- table of pending transactions. This also enables multiple 'Client's to
496-- share a single transaction table.
497transactionMethods' ::
498 ((Maybe x -> IO ()) -> a) -- ^ store MVar into table entry
499 -> (a -> Maybe x -> IO void) -- ^ load MVar from table entry
500 -> TableMethods t tid -- ^ Table methods to lookup values by /tid/.
501 -> (g -> (tid,g)) -- ^ Generate a new unique /tid/ value and update the generator state /g/.
502 -> TransactionMethods (g,t a) tid addr x
503transactionMethods' store load (TableMethods insert delete lookup) generate = TransactionMethods
504 { dispatchCancel = \tid (g,t) -> return (g, delete tid t)
505 , dispatchRegister = \nowPlusExpiry v a (g,t) -> do
506 let (tid,g') = generate g
507 let t' = insert tid (store v) nowPlusExpiry t -- (now + microsecondsDiff expiry) t
508 return ( tid, (g',t') )
509 , dispatchResponse = \tid x (g,t) ->
510 case lookup tid t of
511 Just v -> let t' = delete tid t
512 in return ((g,t'),void $ load v $ Just x)
513 Nothing -> return ((g,t), return ())
514 }
515
516-- | Construct 'TransactionMethods' methods out of 3 lookup table primitives and a
517-- function for generating unique transaction ids.
518transactionMethods ::
519 TableMethods t tid -- ^ Table methods to lookup values by /tid/.
520 -> (g -> (tid,g)) -- ^ Generate a new unique /tid/ value and update the generator state /g/.
521 -> TransactionMethods (g,t (Maybe x -> IO ())) tid addr x
522transactionMethods methods generate = transactionMethods' id id methods generate
523
524-- | Handle a single inbound packet and then invoke the given continuation.
525-- The 'forkListener' function is implemented by passing this function to 'fix'
526-- in a forked thread that loops until 'awaitMessage' returns 'Nothing' or
527-- throws an exception.
528handleMessage ::
529 ClientA err meth tid addr x y
530 -> addr
531 -> x
532 -> IO (Maybe (x -> x))
533handleMessage (Client net d err pending whoami responseID) addr plain = do
534 -- Just (Left e) -> do reportParseError err e
535 -- return $! Just id
536 -- Just (Right (plain, addr)) -> do
537 case classifyInbound d plain of
538 IsQuery meth tid -> case lookupHandler d meth of
539 Nothing -> do reportMissingHandler err meth addr plain
540 return $! Just id
541 Just m -> do
542 self <- whoami (Just addr)
543 tid' <- responseID tid
544 either (\e -> do reportParseError err e
545 return $! Just id)
546 (>>= \m -> do mapM_ (sendMessage net addr) m
547 return $! Nothing)
548 (dispatchQuery m tid' self plain addr)
549 IsUnsolicited action -> do
550 self <- whoami (Just addr)
551 action self addr
552 return Nothing
553 IsResponse tid -> do
554 action <- atomically $ do
555 ts0 <- readTVar pending
556 (ts, action) <- dispatchResponse (tableMethods d) tid plain ts0
557 writeTVar pending ts
558 return action
559 action
560 return $! Nothing
561 IsUnknown e -> do reportUnknown err addr plain e
562 return $! Just id
563 -- Nothing -> return $! id
564
565-- * UDP Datagrams.
566
567-- | Access the address family of a given 'SockAddr'. This convenient accessor
568-- is missing from 'Network.Socket', so I implemented it here.
569sockAddrFamily :: SockAddr -> Family
570sockAddrFamily (SockAddrInet _ _ ) = AF_INET
571sockAddrFamily (SockAddrInet6 _ _ _ _) = AF_INET6
572sockAddrFamily (SockAddrUnix _ ) = AF_UNIX
573#if !MIN_VERSION_network(3,0,0)
574sockAddrFamily _ = AF_CAN -- SockAddrCan constructor deprecated
575#endif
576
577-- | Packets with an empty payload may trigger EOF exception.
578-- 'udpTransport' uses this function to avoid throwing in that
579-- case.
580ignoreEOF :: Socket -> MVar () -> Arrival e a x -> IOError -> IO (Arrival e a x)
581ignoreEOF sock isClosed def e = do
582 done <- tryReadMVar isClosed
583 case done of
584 Just () -> do close sock
585 dput XMisc "Closing UDP socket."
586 pure Terminated
587 _ -> if isEOFError e then pure def
588 else throwIO e
589
590-- | Hard-coded maximum packet size for incoming UDP Packets received via
591-- 'udpTransport'.
592udpBufferSize :: Int
593udpBufferSize = 65536
594
595-- | Wrapper around 'B.sendTo' that silently ignores DoesNotExistError.
596saferSendTo :: Socket -> ByteString -> SockAddr -> IO ()
597saferSendTo sock bs saddr = void (B.sendTo sock bs saddr)
598 `catch` \e ->
599 -- sendTo: does not exist (Network is unreachable)
600 -- Occurs when IPv6 or IPv4 network is not available.
601 -- Currently, we require -threaded to prevent a forever-hang in this case.
602 if isDoesNotExistError e
603 then return ()
604 else throw e
605
606-- | Like 'udpTransport' except also returns the raw socket (for broadcast use).
607udpTransport' :: Show err => SockAddr -> IO (Transport err SockAddr ByteString, Socket)
608udpTransport' bind_address = do
609 let family = sockAddrFamily bind_address
610 sock <- socket family Datagram defaultProtocol
611 when (family == AF_INET6) $ do
612 setSocketOption sock IPv6Only 0
613 setSocketOption sock Broadcast 1
614 bind sock bind_address
615 isClosed <- newEmptyMVar
616 udpTChan <- atomically newTChan
617 let tr = Transport {
618 awaitMessage = \kont -> do
619 r <- readTChan udpTChan
620 return $ kont $! r
621 , sendMessage = case family of
622 AF_INET6 -> \case
623 (SockAddrInet port addr) -> \bs ->
624 -- Change IPv4 to 4mapped6 address.
625 saferSendTo sock bs $ SockAddrInet6 port 0 (0,0,0x0000ffff,fromBE32 addr) 0
626 addr6 -> \bs -> saferSendTo sock bs addr6
627 AF_INET -> \case
628 (SockAddrInet6 port 0 (0,0,0x0000ffff,raw4) 0) -> \bs -> do
629 let host4 = toBE32 raw4
630 -- Change 4mapped6 to ordinary IPv4.
631 -- dput XMisc $ "4mapped6 -> "++show (SockAddrInet port host4)
632 saferSendTo sock bs (SockAddrInet port host4)
633 addr@(SockAddrInet6 {}) -> \bs -> dput XMisc ("Discarding packet to "++show addr)
634 addr4 -> \bs -> saferSendTo sock bs addr4
635 _ -> \addr bs -> saferSendTo sock bs addr
636 , setActive = \case
637 False -> do
638 dput XMisc $ "closeTransport for udpTransport' called. " ++ show bind_address
639 tryPutMVar isClosed () -- signal awaitMessage that the transport is closed.
640#if MIN_VERSION_network (3,1,0)
641#elif MIN_VERSION_network(3,0,0)
642 let withFdSocket sock f = fdSocket sock >>= f >>= seq sock . return
643#else
644 let withFdSocket sock f = f (fdSocket sock) >>= seq sock . return
645#endif
646 withFdSocket sock $ \fd -> do
647 let sorryGHCButIAmNotFuckingClosingTheSocketYet fd = return ()
648 -- This call is necessary to interrupt the blocking recvFrom call in awaitMessage.
649 closeFdWith sorryGHCButIAmNotFuckingClosingTheSocketYet (fromIntegral fd)
650 True -> do
651 udpThread <- forkIO $ fix $ \again -> do
652 r <- handle (ignoreEOF sock isClosed $ Arrival (SockAddrInet 0 0) B.empty) $ do
653 uncurry (flip Arrival) <$!> B.recvFrom sock udpBufferSize
654 atomically $ writeTChan udpTChan r
655 case r of Terminated -> return ()
656 _ -> again
657 labelThread udpThread ("udp.io."++show bind_address)
658 }
659 return (tr, sock)
660
661-- | A 'udpTransport' uses a UDP socket to send and receive 'ByteString's. The
662-- argument is the listen-address for incoming packets. This is a useful
663-- low-level 'Transport' that can be transformed for higher-level protocols
664-- using 'layerTransport'.
665udpTransport :: Show err => SockAddr -> IO (Transport err SockAddr ByteString)
666udpTransport bind_address = fst <$> udpTransport' bind_address
667
668chanTransport :: (addr -> TChan (x, addr)) -> addr -> TChan (x, addr) -> TVar Bool -> Transport err addr x
669chanTransport chanFromAddr self achan aclosed = Transport
670 { awaitMessage = \kont -> do
671 x <- (uncurry (flip Arrival) <$> readTChan achan)
672 `orElse`
673 (readTVar aclosed >>= check >> return Terminated)
674 return $ kont x
675 , sendMessage = \them bs -> do
676 atomically $ writeTChan (chanFromAddr them) (bs,self)
677 , setActive = \case
678 False -> atomically $ writeTVar aclosed True
679 True -> return ()
680 }
681
682-- | Returns a pair of transports linked together to simulate two computers talking to each other.
683testPairTransport :: IO (Transport err SockAddr ByteString, Transport err SockAddr ByteString)
684testPairTransport = do
685 achan <- atomically newTChan
686 bchan <- atomically newTChan
687 aclosed <- atomically $ newTVar False
688 bclosed <- atomically $ newTVar False
689 let a = SockAddrInet 1 1
690 b = SockAddrInet 2 2
691 return ( chanTransport (const bchan) a achan aclosed
692 , chanTransport (const achan) b bchan bclosed )
693
694newtype ByAddress err x addr = ByAddress (Transport err addr x)
695
696newtype Tagged x addr = Tagged x
697
698decorateAddr :: tag addr -> Arrival e addr x -> Arrival e (DSum tag Identity) x
699decorateAddr tag Terminated = Terminated
700decorateAddr tag (ParseError e) = ParseError e
701decorateAddr tag (Arrival addr x) = Arrival (tag ==> addr) x
702
703mergeTransports :: GCompare tag => DMap tag (ByAddress err x) -> IO (Transport err (DSum tag Identity) x)
704mergeTransports tmap = do
705 -- vmap <- traverseWithKey (\k v -> Tagged <$> newEmptyMVar) tmap
706 -- foldrWithKey (\k v n -> forkMergeBranch k v >> n) (return ()) vmap
707 return Transport
708 { awaitMessage = \kont ->
709 foldrWithKey (\k (ByAddress tr) n -> awaitMessage tr (kont . decorateAddr k) `orElse` n)
710 retry
711 tmap
712 , sendMessage = \(tag :=> Identity addr) x -> case DMap.lookup tag tmap of
713 Just (ByAddress tr) -> sendMessage tr addr x
714 Nothing -> return ()
715 , setActive = \toggle -> foldrWithKey (\_ (ByAddress tr) next -> setActive tr toggle >> next) (return ()) tmap
716 }
diff --git a/server/src/Network/QueryResponse/TCP.hs b/server/src/Network/QueryResponse/TCP.hs
new file mode 100644
index 00000000..8b1b432b
--- /dev/null
+++ b/server/src/Network/QueryResponse/TCP.hs
@@ -0,0 +1,225 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE GeneralizedNewtypeDeriving #-}
3{-# LANGUAGE LambdaCase #-}
4{-# LANGUAGE OverloadedStrings #-}
5module Network.QueryResponse.TCP where
6
7#ifdef THREAD_DEBUG
8import Control.Concurrent.Lifted.Instrument
9#else
10import Control.Concurrent.Lifted
11import GHC.Conc (labelThread,forkIO)
12import ForkLabeled
13#endif
14
15import Control.Arrow
16import Control.Concurrent.STM
17import Control.Concurrent.STM.TMVar
18import Control.Monad
19import Data.ByteString (ByteString,hPut)
20import Data.Function
21import Data.Hashable
22import Data.Maybe
23import Data.Ord
24import Data.Time.Clock.POSIX
25import Data.Word
26import Data.String (IsString(..))
27import Network.BSD
28import Network.Socket as Socket
29import System.Timeout
30import System.IO
31import System.IO.Error
32
33import DebugTag
34import DebugUtil
35import DPut
36import Connection.Tcp (socketFamily)
37import qualified Data.MinMaxPSQ as MM
38import Network.QueryResponse
39
40data TCPSession st
41 = PendingTCPSession
42 | TCPSession
43 { tcpHandle :: Handle
44 , tcpState :: st
45 , tcpThread :: ThreadId
46 }
47
48newtype TCPAddress = TCPAddress SockAddr
49 deriving (Eq,Ord,Show)
50
51instance Hashable TCPAddress where
52 hashWithSalt salt (TCPAddress x) = case x of
53 SockAddrInet port addr -> hashWithSalt salt (fromIntegral port :: Word16,addr)
54 SockAddrInet6 port b c d -> hashWithSalt salt (fromIntegral port :: Word16,b,c,d)
55 _ -> 0
56
57data TCPCache st = TCPCache
58 { lru :: TVar (MM.MinMaxPSQ' TCPAddress (Down POSIXTime) (TCPSession st))
59 , tcpMax :: Int
60 }
61
62-- This is a suitable /st/ parameter to 'TCPCache'
63data SessionProtocol x y = SessionProtocol
64 { streamGoodbye :: IO () -- ^ "Goodbye" protocol upon termination.
65 , streamDecode :: IO (Maybe x) -- ^ Parse inbound messages.
66 , streamEncode :: y -> IO () -- ^ Serialize outbound messages.
67 }
68
69data StreamHandshake addr x y = StreamHandshake
70 { streamHello :: addr -> Handle -> IO (SessionProtocol x y) -- ^ "Hello" protocol upon fresh connection.
71 , streamAddr :: addr -> SockAddr
72 }
73
74killSession :: TCPSession st -> IO ()
75killSession PendingTCPSession = return ()
76killSession TCPSession{tcpThread=t} = killThread t
77
78showStat :: IsString p => TCPSession st -> p
79showStat r = case r of PendingTCPSession -> "pending."
80 TCPSession {} -> "established."
81
82tcp_timeout :: Int
83tcp_timeout = 10000000
84
85acquireConnection :: TMVar (Arrival a addr x)
86 -> TCPCache (SessionProtocol x y)
87 -> StreamHandshake addr x y
88 -> addr
89 -> Bool
90 -> IO (Maybe (y -> IO ()))
91acquireConnection mvar tcpcache stream addr bDoCon = do
92 now <- getPOSIXTime
93 -- dput XTCP $ "acquireConnection 0 " ++ show (streamAddr stream addr)
94 entry <- atomically $ do
95 c <- readTVar (lru tcpcache)
96 let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c
97 case v of
98 Nothing | bDoCon -> writeTVar (lru tcpcache)
99 $ MM.insert' (TCPAddress $ streamAddr stream addr) PendingTCPSession (Down now) c
100 | otherwise -> return ()
101 Just (tm, v) -> writeTVar (lru tcpcache)
102 $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down now) c
103 return v
104 -- dput XTCP $ "acquireConnection 1 " ++ show (streamAddr stream addr, fmap (second showStat) entry)
105 case entry of
106 Nothing -> fmap join $ forM (guard bDoCon) $ \() -> do
107 proto <- getProtocolNumber "tcp"
108 sock <- socket (socketFamily $ streamAddr stream addr) Stream proto
109 mh <- catchIOError (do h <- timeout tcp_timeout $ do
110 connect sock (streamAddr stream addr) `catchIOError` (\e -> close sock)
111 h <- socketToHandle sock ReadWriteMode
112 hSetBuffering h NoBuffering
113 return h
114 return h)
115 $ \e -> return Nothing
116 when (isNothing mh) $ do
117 atomically $ modifyTVar' (lru tcpcache)
118 $ MM.delete (TCPAddress $ streamAddr stream addr)
119 Socket.close sock
120 ret <- fmap join $ forM mh $ \h -> do
121 mst <- catchIOError (Just <$> streamHello stream addr h)
122 (\e -> return Nothing)
123 case mst of
124 Nothing -> do
125 atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr)
126 return Nothing
127 Just st -> do
128 dput XTCP $ "TCP Connected! " ++ show (streamAddr stream addr)
129 signal <- newTVarIO False
130 let showAddr a = show (streamAddr stream a)
131 rthread <- forkLabeled ("tcp:"++showAddr addr) $ do
132 atomically (readTVar signal >>= check)
133 fix $ \loop -> do
134 x <- streamDecode st
135 dput XTCP $ "TCP streamDecode " ++ show (streamAddr stream addr) ++ " --> " ++ maybe "Nothing" (const "got") x
136 case x of
137 Just u -> do
138 m <- timeout tcp_timeout $ atomically (putTMVar mvar $ Arrival addr u)
139 when (isNothing m) $ do
140 dput XTCP $ "TCP "++show (streamAddr stream addr) ++ " dropped packet."
141 atomically $ tryTakeTMVar mvar
142 return ()
143 loop
144 Nothing -> do
145 dput XTCP $ "TCP disconnected: " ++ show (streamAddr stream addr)
146 do atomically $ modifyTVar' (lru tcpcache)
147 $ MM.delete (TCPAddress $ streamAddr stream addr)
148 c <- atomically $ readTVar (lru tcpcache)
149 now <- getPOSIXTime
150 forM_ (zip [1..] $ MM.toList c) $ \(i,MM.Binding (TCPAddress addr) r (Down tm)) -> do
151 dput XTCP $ unwords [show i ++ ".", "Still connected:", show addr, show (now - tm), showStat r]
152 mreport <- timeout tcp_timeout $ threadReport False -- XXX: Paranoid timeout
153 case mreport of
154 Just treport -> dput XTCP treport
155 Nothing -> dput XTCP "TCP ERROR: threadReport timed out."
156 hClose h `catchIOError` \e -> return ()
157 let v = TCPSession
158 { tcpHandle = h
159 , tcpState = st
160 , tcpThread = rthread
161 }
162 t <- getPOSIXTime
163 retires <- atomically $ do
164 c <- readTVar (lru tcpcache)
165 let (rs,c') = MM.takeView (tcpMax tcpcache)
166 $ MM.insert' (TCPAddress $ streamAddr stream addr) v (Down t) c
167 writeTVar (lru tcpcache) c'
168 writeTVar signal True
169 return rs
170 forM_ retires $ \(MM.Binding (TCPAddress k) r _) -> void $ forkLabeled ("tcp-close:"++show k) $ do
171 dput XTCP $ "TCP dropped: " ++ show k
172 killSession r
173 case r of TCPSession {tcpState=st,tcpHandle=h} -> do
174 streamGoodbye st
175 hClose h
176 `catchIOError` \e -> return ()
177 _ -> return ()
178
179 return $ Just $ streamEncode st
180 when (isNothing ret) $ do
181 atomically $ modifyTVar' (lru tcpcache) $ MM.delete (TCPAddress $ streamAddr stream addr)
182 return ret
183 Just (tm, PendingTCPSession)
184 | not bDoCon -> return Nothing
185 | otherwise -> fmap join $ timeout tcp_timeout $ atomically $ do
186 c <- readTVar (lru tcpcache)
187 let v = MM.lookup' (TCPAddress $ streamAddr stream addr) c
188 case v of
189 Just (_,TCPSession{tcpState=st}) -> return $ Just $ streamEncode st
190 Nothing -> return Nothing
191 _ -> retry
192 Just (tm, v@TCPSession {tcpState=st}) -> return $ Just $ streamEncode st
193
194closeAll :: TCPCache (SessionProtocol x y) -> StreamHandshake addr x y -> IO ()
195closeAll tcpcache stream = do
196 dput XTCP "TCP.closeAll called."
197 cache <- atomically $ swapTVar (lru tcpcache) MM.empty
198 forM_ (MM.toList cache) $ \(MM.Binding (TCPAddress addr) r tm) -> do
199 killSession r
200 case r of TCPSession{tcpState=st,tcpHandle=h} -> catchIOError (streamGoodbye st >> hClose h)
201 (\e -> return ())
202 _ -> return ()
203
204-- Use a cache of TCP client connections for sending (and receiving) packets.
205-- The boolean value prepended to the message allows the sender to specify
206-- whether or not a new connection will be initiated if neccessary. If 'False'
207-- is passed, then the packet will be sent only if there already exists a
208-- connection.
209tcpTransport :: Int -- ^ maximum number of TCP links to maintain.
210 -> StreamHandshake addr x y
211 -> IO (TCPCache (SessionProtocol x y), TransportA err addr x (Bool,y))
212tcpTransport maxcon stream = do
213 msgvar <- atomically newEmptyTMVar
214 tcpcache <- atomically $ (`TCPCache` maxcon) <$> newTVar (MM.empty)
215 return $ (,) tcpcache Transport
216 { awaitMessage = \f -> takeTMVar msgvar >>= \x -> return $ do
217 f x `catchIOError` (\e -> dput XTCP ("TCP transport stopped. " ++ show e) >> f Terminated)
218 , sendMessage = \addr (bDoCon,y) -> do
219 void . forkLabeled "tcp-send" $ do
220 msock <- acquireConnection msgvar tcpcache stream addr bDoCon
221 mapM_ ($ y) msock
222 `catchIOError` \e -> dput XTCP $ "TCP-send: " ++ show e
223 , setActive = \case False -> closeAll tcpcache stream >> atomically (putTMVar msgvar Terminated)
224 True -> return ()
225 }
diff --git a/server/src/Network/SocketLike.hs b/server/src/Network/SocketLike.hs
new file mode 100644
index 00000000..37891cfd
--- /dev/null
+++ b/server/src/Network/SocketLike.hs
@@ -0,0 +1,98 @@
1{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2{-# LANGUAGE TupleSections #-}
3{-# LANGUAGE CPP #-}
4-- |
5--
6-- A socket could be used indirectly via a 'System.IO.Handle' or a conduit from
7-- Michael Snoyman's conduit package. But doing so presents an encapsulation
8-- problem. Do we allow access to the underlying socket and trust that it wont
9-- be used in an unsafe way? Or do we protect it at the higher level and deny
10-- access to various state information?
11--
12-- The 'SocketLike' class enables the approach that provides a safe wrapper to
13-- the underlying socket and gives access to various state information without
14-- enabling direct reads or writes.
15module Network.SocketLike
16 ( SocketLike(..)
17 , RestrictedSocket
18 , restrictSocket
19 , restrictHandleSocket
20 -- * Re-exports
21 --
22 -- | To make the 'SocketLike' methods less awkward to use, the types
23 -- 'CUInt', 'SockAddr', and 'PortNumber' are re-exported.
24 , CUInt
25 , PortNumber
26 , SockAddr(..)
27 ) where
28
29import Network.Socket
30 ( PortNumber
31 , SockAddr
32 )
33import Foreign.C.Types ( CUInt )
34
35import qualified Network.Socket as NS
36import System.IO (Handle,hClose,hIsOpen)
37import Control.Arrow
38
39-- | A safe (mostly read-only) interface to a 'NS.Socket'. Note that despite
40-- how this class is named, it provides no access to typical 'NS.Socket' uses
41-- like sending or receiving network packets.
42class SocketLike sock where
43 -- | See 'NS.getSocketName'
44 getSocketName :: sock -> IO SockAddr
45 -- | See 'NS.getPeerName'
46 getPeerName :: sock -> IO SockAddr
47 -- | See 'NS.getPeerCred'
48-- getPeerCred :: sock -> IO (CUInt, CUInt, CUInt)
49
50 -- | Is the socket still valid? Connected
51 --
52 -- In order to give the instance writer
53 -- the option to do book-keeping in a pure
54 -- type, a conceptually modified version of
55 -- the 'SocketLike' is returned.
56 --
57 isValidSocket :: sock -> IO (sock,Bool)
58
59
60instance SocketLike NS.Socket where
61 getSocketName = NS.getSocketName
62 getPeerName = NS.getPeerName
63-- getPeerCred = NS.getPeerCred
64#if MIN_VERSION_network(3,1,0)
65 isValidSocket s = (s,) <$> NS.withFdSocket s (return . (/= (-1)))
66#else
67#if MIN_VERSION_network(3,0,0)
68 isValidSocket s = (s,) . (/= (-1)) <$> NS.fdSocket s
69#else
70#if MIN_VERSION_network(2,4,0)
71 isValidSocket s = (s,) <$> NS.isConnected s -- warning: this is always False if the socket
72 -- was converted to a Handle
73#else
74 isValidSocket s = (s,) <$> NS.sIsConnected s -- warning: this is always False if the socket
75 -- was converted to a Handle
76#endif
77#endif
78#endif
79
80-- | An encapsulated socket. Data reads and writes are not possible.
81data RestrictedSocket = Restricted (Maybe Handle) NS.Socket deriving Show
82
83instance SocketLike RestrictedSocket where
84 getSocketName (Restricted mb sock) = NS.getSocketName sock
85 getPeerName (Restricted mb sock) = NS.getPeerName sock
86-- getPeerCred (Restricted mb sock) = NS.getPeerCred sock
87 isValidSocket rs@(Restricted mb sock) = maybe (first (Restricted mb) <$> isValidSocket sock) (((rs,) <$>) . hIsOpen) mb
88
89-- | Create a 'RestrictedSocket' that explicitly disallows sending or
90-- receiving data.
91restrictSocket :: NS.Socket -> RestrictedSocket
92restrictSocket socket = Restricted Nothing socket
93
94-- | Build a 'RestrictedSocket' for which 'sClose' will close the given
95-- 'Handle'. It is intended that this 'Handle' was obtained via
96-- 'NS.socketToHandle'.
97restrictHandleSocket :: Handle -> NS.Socket -> RestrictedSocket
98restrictHandleSocket h socket = Restricted (Just h) socket
diff --git a/server/src/Network/StreamServer.hs b/server/src/Network/StreamServer.hs
new file mode 100644
index 00000000..1da612ce
--- /dev/null
+++ b/server/src/Network/StreamServer.hs
@@ -0,0 +1,167 @@
1-- | This module implements a bare-bones TCP or Unix socket server.
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE TypeFamilies #-}
4{-# LANGUAGE TypeOperators #-}
5{-# LANGUAGE OverloadedStrings #-}
6{-# LANGUAGE RankNTypes #-}
7module Network.StreamServer
8 ( streamServer
9 , ServerHandle
10 , getAcceptLoopThreadId
11 , ServerConfig(..)
12 , withSession
13 , quitListening
14 --, dummyServerHandle
15 , listenSocket
16 , Local(..)
17 , Remote(..)
18 ) where
19
20import Data.Monoid
21import Network.Socket as Socket
22import System.Directory (removeFile)
23import System.IO
24 ( IOMode(..)
25 , stderr
26 , hFlush
27 )
28import Control.Monad
29import Control.Monad.Fix (fix)
30#ifdef THREAD_DEBUG
31import Control.Concurrent.Lifted.Instrument
32 ( forkIO, threadDelay, ThreadId, mkWeakThreadId, labelThread, myThreadId
33 , killThread )
34#else
35import GHC.Conc (labelThread)
36import Control.Concurrent
37 ( forkIO, threadDelay, ThreadId, mkWeakThreadId, myThreadId
38 , killThread )
39#endif
40import Control.Exception (handle,finally)
41import System.IO.Error (tryIOError)
42import System.Mem.Weak
43import System.IO.Error
44
45-- import Data.Conduit
46import System.IO (Handle)
47import Control.Concurrent.MVar (newMVar)
48
49import Network.SocketLike
50import DPut
51import DebugTag
52
53data ServerHandle = ServerHandle Socket (Weak ThreadId)
54
55-- | Useful for testing.
56getAcceptLoopThreadId :: ServerHandle -> IO (Weak ThreadId)
57getAcceptLoopThreadId (ServerHandle _ t) = return t
58
59listenSocket :: ServerHandle -> RestrictedSocket
60listenSocket (ServerHandle sock _) = restrictSocket sock
61
62{- // Removed, bit-rotted and there are no call sites
63-- | Create a useless do-nothing 'ServerHandle'.
64dummyServerHandle :: IO ServerHandle
65dummyServerHandle = do
66 mvar <- newMVar Closed
67 let sock = MkSocket 0 AF_UNSPEC NoSocketType 0 mvar
68 thread <- mkWeakThreadId <=< forkIO $ return ()
69 return (ServerHandle sock thread)
70-}
71
72removeSocketFile :: SockAddr -> IO ()
73removeSocketFile (SockAddrUnix fname) = removeFile fname
74removeSocketFile _ = return ()
75
76-- | Terminate the server accept-loop. Call this to shut down the server.
77quitListening :: ServerHandle -> IO ()
78quitListening (ServerHandle socket acceptThread) =
79 finally (Socket.getSocketName socket >>= removeSocketFile)
80 (do mapM_ killThread =<< deRefWeak acceptThread
81 Socket.close socket)
82
83
84-- | It's 'bshow' instead of 'show' to enable swapping in a 'ByteString'
85-- variation. (This is not exported.)
86bshow :: Show a => a -> String
87bshow e = show e
88
89-- | Send a string to stderr. Not exported. Default 'serverWarn' when
90-- 'withSession' is used to configure the server.
91warnStderr :: String -> IO ()
92warnStderr str = dput XMisc str >> hFlush stderr
93
94newtype Local a = Local a deriving (Eq,Ord,Show)
95newtype Remote a = Remote a deriving (Eq,Ord,Show)
96
97data ServerConfig = ServerConfig
98 { serverWarn :: String -> IO ()
99 -- ^ Action to report warnings and errors.
100 , serverSession :: ( RestrictedSocket, (Local SockAddr, Remote SockAddr)) -> Int -> Handle -> IO ()
101 -- ^ Action to handle interaction with a client
102 }
103
104-- | Initialize a 'ServerConfig' using the provided session handler.
105withSession :: ((RestrictedSocket,(Local SockAddr,Remote SockAddr)) -> Int -> Handle -> IO ()) -> ServerConfig
106withSession session = ServerConfig warnStderr session
107
108-- | Launch a thread to listen at the given bind address and dispatch
109-- to session handler threads on every incoming connection. Supports
110-- IPv4 and IPv6, TCP and unix sockets.
111--
112-- The returned handle can be used with 'quitListening' to terminate the
113-- thread and prevent any new sessions from starting. Currently active
114-- session threads will not be terminated or signaled in any way.
115streamServer :: ServerConfig -> [SockAddr] -> IO ServerHandle
116streamServer cfg addrs = do
117 let warn = serverWarn cfg
118 family = case addrs of
119 SockAddrInet {}:_ -> AF_INET
120 SockAddrInet6 {}:_ -> AF_INET6
121 SockAddrUnix {}:_ -> AF_UNIX
122 [] -> AF_INET6
123 sock <- socket family Stream 0
124 setSocketOption sock ReuseAddr 1
125 let tryBind addr next _ = do
126 tryIOError (removeSocketFile addr)
127 bind sock addr
128 `catchIOError` \e -> next (Just e)
129 fix $ \loop -> let again mbe = do
130 forM_ mbe $ \e -> warn $ "bind-error: " <> bshow addrs <> " " <> bshow e
131 threadDelay 5000000
132 loop
133 in foldr tryBind again addrs Nothing
134 listen sock maxListenQueue
135 thread <- mkWeakThreadId <=< forkIO $ do
136 bindaddr <- Socket.getSocketName sock
137 myThreadId >>= flip labelThread ("StreamServer.acceptLoop." <> bshow bindaddr)
138 acceptLoop cfg sock 0
139 return (ServerHandle sock thread)
140
141-- | Not exported. This, combined with 'acceptException' form a mutually
142-- recursive loop that handles incoming connections. To quit the loop, the
143-- socket must be closed by 'quitListening'.
144acceptLoop :: ServerConfig -> Socket -> Int -> IO ()
145acceptLoop cfg sock n = handle (acceptException cfg n sock) $ do
146 (con,raddr) <- accept sock
147 let conkey = n + 1
148 laddr <- Socket.getSocketName con
149 h <- socketToHandle con ReadWriteMode
150 forkIO $ do
151 myThreadId >>= flip labelThread "StreamServer.session"
152 serverSession cfg (restrictHandleSocket h con, (Local laddr, Remote raddr)) conkey h
153 acceptLoop cfg sock (n + 1)
154
155acceptException :: ServerConfig -> Int -> Socket -> IOError -> IO ()
156acceptException cfg n sock ioerror = do
157 case show (ioeGetErrorType ioerror) of
158 "resource exhausted" -> do -- try again (ioeGetErrorType ioerror == fullErrorType)
159 serverWarn cfg $ ("acceptLoop: resource exhasted")
160 threadDelay 500000
161 acceptLoop cfg sock (n + 1)
162 "invalid argument" -> do -- quit on closed socket
163 Socket.close sock
164 message -> do -- unexpected exception
165 serverWarn cfg $ ("acceptLoop: "<>bshow message)
166 Socket.close sock
167
diff --git a/server/src/SockAddr.hs b/server/src/SockAddr.hs
new file mode 100644
index 00000000..b5fbf16e
--- /dev/null
+++ b/server/src/SockAddr.hs
@@ -0,0 +1,14 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE StandaloneDeriving #-}
3module SockAddr () where
4
5#if MIN_VERSION_network(2,4,0)
6import Network.Socket ()
7#else
8import Network.Socket ( SockAddr(..) )
9
10deriving instance Ord SockAddr
11#endif
12
13
14