summaryrefslogtreecommitdiff
path: root/server/src/Network/StreamServer.hs
blob: 8ebdf67810a0dffff1ef63dc5f2e86641fb72cb7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
-- | This module implements a bare-bones TCP or Unix socket server.
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Network.StreamServer
    ( forkStreamServer
    , ServerHandle
    , getAcceptLoopThreadId
    , ServerConfig(..)
    , withSession
    , quitListening
  --, dummyServerHandle
    , listenSocket
    , Local(..)
    , Remote(..)
    ) where

import Data.Monoid
import Network.Socket as Socket
import System.Directory (removeFile)
import System.IO
    ( IOMode(..)
    , stderr
    , hFlush
    )
import Control.Monad
import Control.Monad.Fix (fix)
#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
           ( forkIO, threadDelay, ThreadId, mkWeakThreadId, labelThread, myThreadId
           , killThread )
#else
import GHC.Conc (labelThread)
import Control.Concurrent
           ( forkIO, threadDelay, ThreadId, mkWeakThreadId, myThreadId
           , killThread )
#endif
import Control.Exception (handle,finally,onException)
import System.IO.Error (tryIOError)
import System.Mem.Weak
import System.IO.Error

-- import Data.Conduit
import System.IO (Handle)
import Control.Concurrent.MVar (newMVar)

import Network.Address (make6mapped4, canonize)
import Network.SocketLike
import DPut
import DebugTag

-- | A running server: its listening socket together with a weak reference to
-- the thread executing the accept loop.
data ServerHandle
    = ServerHandle
        Socket          -- listening socket
        (Weak ThreadId) -- accept-loop thread

-- | Fetch the weak reference to the accept-loop thread of a server.
-- Useful for testing.
getAcceptLoopThreadId :: ServerHandle -> IO (Weak ThreadId)
getAcceptLoopThreadId srv = case srv of
    ServerHandle _ acceptor -> pure acceptor

-- | The server's listening socket, wrapped with 'restrictSocket' before it is
-- handed out.
listenSocket :: ServerHandle -> RestrictedSocket
listenSocket srv = case srv of
    ServerHandle listener _ -> restrictSocket listener

{- // Removed, bit-rotted and there are no call sites
-- | Create a useless do-nothing 'ServerHandle'.
dummyServerHandle :: IO ServerHandle
dummyServerHandle = do
    mvar <- newMVar Closed
    let sock = MkSocket 0 AF_UNSPEC NoSocketType 0 mvar
    thread <- mkWeakThreadId <=< forkIO $ return ()
    return (ServerHandle sock thread)
-}

-- | Delete the filesystem entry behind a unix-domain socket address, so the
-- path can be bound again or is cleaned up on shutdown.
--
-- Nothing is done for non-unix addresses, for the empty path, or for names
-- beginning with a NUL byte (presumably Linux abstract-namespace sockets,
-- which have no file).  A path that no longer exists is not treated as an
-- error.  Any other 'IOError' raised by 'removeFile' is rethrown.
removeSocketFile :: SockAddr -> IO ()
removeSocketFile (SockAddrUnix "")       = return ()
removeSocketFile (SockAddrUnix ('\0':_)) = return ()
removeSocketFile (SockAddrUnix fname)    =
    removeFile fname `catchIOError` \e ->
        -- Someone else (or an earlier shutdown) already removed it: fine.
        unless (isDoesNotExistError e) (ioError e)
removeSocketFile _                       = return ()

-- | Terminate the server accept-loop.  Call this to shut down the server.
--
-- This first removes the unix socket file (if any) backing the listening
-- socket.  It then kills the accept-loop thread, provided the weak reference
-- still resolves, and closes the listening socket.  The kill and the close
-- run even if looking up or removing the socket file throws; that exception
-- still propagates afterwards.
quitListening :: ServerHandle -> IO ()
quitListening (ServerHandle listener weakAcceptor) =
    unlinkSocketFile `finally` stopAccepting
  where
    unlinkSocketFile = Socket.getSocketName listener >>= removeSocketFile
    stopAccepting = do
        deRefWeak weakAcceptor >>= mapM_ killThread
        Socket.close listener


-- | Render a value for log messages.  The name is 'bshow' rather than plain
-- 'show' so that a 'ByteString'-producing variant can be swapped in later.
-- (This is not exported.)
bshow :: Show a => a -> String
bshow = show

-- | Default 'serverWarn' installed by 'withSession'.  Not exported.
--
-- The message is logged through 'dput' under the 'XMisc' tag, and then
-- 'stderr' is flushed.  Presumably 'dput' writes to stderr, hence the flush;
-- confirm against "DPut".
warnStderr :: String -> IO ()
warnStderr msg = do
    dput XMisc msg
    hFlush stderr

-- | Marks a value (e.g. a 'SockAddr') as belonging to the local end of a
-- connection.
newtype Local a = Local a
    deriving (Eq, Ord, Show)

-- | Marks a value (e.g. a 'SockAddr') as belonging to the remote end of a
-- connection.
newtype Remote a = Remote a
    deriving (Eq, Ord, Show)

-- | Callbacks that drive a server started with 'forkStreamServer'.
data ServerConfig = ServerConfig
    { serverWarn    :: String -> IO ()
      -- ^ Action to report warnings and errors.
    , serverSession :: (RestrictedSocket, (Local SockAddr, Remote SockAddr))
                    -> Int
                    -> Handle
                    -> IO ()
      -- ^ Action to handle interaction with a client.  It receives the
      -- connection's socket together with its local and remote addresses,
      -- a per-connection key, and a read/write 'Handle' for the connection.
    }

-- | Build a 'ServerConfig' from a session handler.  Warnings go to the
-- module's default logger ('warnStderr').
withSession :: ((RestrictedSocket,(Local SockAddr,Remote SockAddr)) -> Int -> Handle -> IO ()) -> ServerConfig
withSession session = ServerConfig
    { serverWarn    = warnStderr
    , serverSession = session
    }

-- | Launch a thread to listen at the given bind address and dispatch to
-- session handler threads on every incoming connection. Supports IPv4 and
-- IPv6, TCP and unix sockets.
--
-- Arguments:
--
--   [cfg]   Functions for handling incoming sessions and logging prints.
--
--   [addrs] A list of bind addresses that will be tried one after another
--   until a successful listening socket is created.  The address family of
--   the listening socket is taken from the first address.  If none of them
--   can be bound, the whole list is retried after a five second pause,
--   indefinitely.
--
-- Binding happens in the calling thread, so this call blocks until some
-- address has been bound.  An empty address list can never bind and is
-- rejected with an 'IOError'.  If setup fails or is interrupted before the
-- accept thread starts, the listening socket is closed instead of leaked.
--
-- The returned handle can be used with 'quitListening' to terminate the thread
-- and prevent any new sessions from starting.  Currently active session
-- threads will not be terminated or signaled in any way.
forkStreamServer :: ServerConfig -> [SockAddr] -> IO ServerHandle
forkStreamServer _ [] =
    ioError $ userError "forkStreamServer: no bind addresses given"
forkStreamServer cfg addrs0@(firstAddr : _) = do
    let warn   = serverWarn cfg
        family = case firstAddr of
                    SockAddrInet  {} -> AF_INET
                    SockAddrInet6 {} -> AF_INET6
                    SockAddrUnix  {} -> AF_UNIX
        transport :: String
        transport = if family == AF_UNIX then "unix" else "TCP"
        addrs = map (if family == AF_INET6 then make6mapped4 else canonize) addrs0
    sock <- socket family Stream 0
    let -- Try one address; on failure log it and fall through to the next.
        tryBind addr next = do
                warn $ "Trying to bind to " ++ transport ++ " " ++ show addr
                -- Clear a stale unix socket file from a previous run.  A
                -- failure here is ignored; 'bind' will report real problems.
                _ <- tryIOError (removeSocketFile addr)
                bind sock addr
                return True
              `catchIOError` \e -> do
                warn $ "bind-error: " <> bshow addr <> " " <> bshow e
                next
        bindLoop = fix $ \retry -> do
            bound <- foldr tryBind (return False) addrs
            unless bound $ threadDelay 5000000 >> retry
        setup = do
            setSocketOption sock ReuseAddr 1
            bindLoop
            listen sock maxListenQueue
    setup `onException` Socket.close sock
    thread <- mkWeakThreadId <=< forkIO $ do
        bindaddr <- Socket.getSocketName sock
        myThreadId >>= flip labelThread ("stream.acceptLoop." <> bshow bindaddr)
        acceptLoop cfg sock 0
    return (ServerHandle sock thread)

-- | Not exported.  Accept connections on the listening socket and fork one
-- session thread per connection.  Connections get consecutive keys starting
-- at @n + 1@, and each key is passed to 'serverSession'.  When 'accept'
-- itself fails, 'acceptException' decides whether to retry or stop.  To quit
-- the loop, the socket must be closed by 'quitListening'.
--
-- The recursive call happens outside any exception handler.  Recursing
-- inside 'handle' would nest one catch frame per connection, a space leak.
-- It would also keep the loop running masked after a handler resumed it,
-- because handlers run with asynchronous exceptions masked, and session
-- threads forked from there would inherit that mask.
acceptLoop :: ServerConfig -> Socket -> Int -> IO ()
acceptLoop cfg sock n = do
    accepted <- tryIOError (accept sock)
    case accepted of
        Left ioerror -> acceptException cfg n sock ioerror
        Right (con, raddr) -> do
            let conkey = n + 1
            started <- tryIOError (startSession con raddr conkey)
            case started of
                Right () -> return ()
                Left err -> do
                    -- Only this connection is affected: drop it, keep serving.
                    serverWarn cfg ("acceptLoop: session setup: " <> bshow err)
                    Socket.close con
            acceptLoop cfg sock conkey
  where
    -- Wrap the accepted socket in a read/write 'Handle' and fork the session.
    startSession con raddr conkey = do
        laddr <- Socket.getSocketName con
        h <- socketToHandle con ReadWriteMode
        void $ forkIO $ do
            myThreadId >>= flip labelThread ("stream.session." ++ show (canonize raddr))
            serverSession cfg (restrictHandleSocket h con, (Local laddr, Remote raddr)) conkey h

-- | Not exported.  Decide what to do when 'accept' fails:
--
-- * resource exhausted (e.g. out of file descriptors): warn, pause for half
--   a second, then resume 'acceptLoop'.
--
-- * invalid argument: taken to mean the listening socket was closed (by
--   'quitListening'); close it and stop.
--
-- * anything else: log the full error, close the socket, and stop.
acceptException :: ServerConfig -> Int -> Socket -> IOError -> IO ()
acceptException cfg n sock ioerror
    | isFullErrorType etype = do
        serverWarn cfg "acceptLoop: resource exhausted"
        threadDelay 500000
        acceptLoop cfg sock (n + 1)
    | show etype == "invalid argument" =
        -- System.IO.Error has no predicate for InvalidArgument, so compare
        -- against its 'show' text.
        Socket.close sock
    | otherwise = do
        serverWarn cfg ("acceptLoop: " <> bshow ioerror)
        Socket.close sock
  where
    etype = ioeGetErrorType ioerror