summaryrefslogtreecommitdiff
path: root/dht/src/Network/StreamServer.hs
blob: 1da612ce9a135a7f69fca2a5e77abbcc314238a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
-- | This module implements a bare-bones TCP or Unix socket server.
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Network.StreamServer
    ( streamServer
    , ServerHandle
    , getAcceptLoopThreadId
    , ServerConfig(..)
    , withSession
    , quitListening
  --, dummyServerHandle
    , listenSocket
    , Local(..)
    , Remote(..)
    ) where

import Data.Monoid
import Network.Socket as Socket
import System.Directory (removeFile)
import System.IO
    ( IOMode(..)
    , stderr
    , hFlush
    )
import Control.Monad
import Control.Monad.Fix (fix)
#ifdef THREAD_DEBUG
import Control.Concurrent.Lifted.Instrument
           ( forkIO, threadDelay, ThreadId, mkWeakThreadId, labelThread, myThreadId
           , killThread )
#else
import GHC.Conc (labelThread)
import Control.Concurrent
           ( forkIO, threadDelay, ThreadId, mkWeakThreadId, myThreadId
           , killThread )
#endif
import Control.Exception (handle,finally)
import System.IO.Error (tryIOError)
import System.Mem.Weak
import System.IO.Error

-- import Data.Conduit
import System.IO (Handle)
import Control.Concurrent.MVar (newMVar)

import Network.SocketLike
import DPut
import DebugTag

-- | Handle to a server started by 'streamServer': the listening socket
-- together with a weak reference to the thread running the accept loop.
-- Pass it to 'quitListening' to shut the server down.
data ServerHandle = ServerHandle Socket (Weak ThreadId)

-- | Retrieve the weak reference to the accept-loop thread that
-- 'streamServer' forked.  Useful for testing.
getAcceptLoopThreadId :: ServerHandle -> IO (Weak ThreadId)
getAcceptLoopThreadId srv = case srv of
    ServerHandle _ weakAcceptor -> pure weakAcceptor

-- | The server's listening socket, exposed only as a 'RestrictedSocket'
-- (built with 'restrictSocket').
listenSocket :: ServerHandle -> RestrictedSocket
listenSocket srv = case srv of
    ServerHandle listener _ -> restrictSocket listener

{- // Removed, bit-rotted and there are no call sites
-- | Create a useless do-nothing 'ServerHandle'.
dummyServerHandle :: IO ServerHandle
dummyServerHandle = do
    mvar <- newMVar Closed
    let sock = MkSocket 0 AF_UNSPEC NoSocketType 0 mvar
    thread <- mkWeakThreadId <=< forkIO $ return ()
    return (ServerHandle sock thread)
-}

-- | Delete the filesystem entry named by a unix-domain socket address.
-- Addresses of any other kind are ignored.  Errors raised by 'removeFile'
-- (for example, a missing file) are not caught here.
removeSocketFile :: SockAddr -> IO ()
removeSocketFile addr = case addr of
    SockAddrUnix path -> removeFile path
    _                 -> pure ()

-- | Terminate the server accept-loop.  Call this to shut down the server.
--
-- First removes the unix socket file backing the listening address (a no-op
-- for non-unix addresses).  Then, whether or not that step failed, it kills
-- the accept-loop thread (if it is still alive) and closes the listening
-- socket.  An exception from the first step is rethrown after this cleanup.
quitListening :: ServerHandle -> IO ()
quitListening (ServerHandle listener weakAcceptor) =
    removeBoundFile `finally` stopAcceptor
  where
    removeBoundFile = Socket.getSocketName listener >>= removeSocketFile
    stopAcceptor = do
        deRefWeak weakAcceptor >>= mapM_ killThread
        Socket.close listener


-- | Render a value for log messages.  Named 'bshow' rather than using 'show'
-- directly so that a 'ByteString'-producing variant can be swapped in later.
-- (This is not exported.)
bshow :: Show a => a -> String
bshow = show

-- | Default 'serverWarn' installed by 'withSession'.  Not exported.
-- Emits the message through 'dput' under the 'XMisc' debug tag, then
-- flushes 'stderr'.  NOTE(review): whether the text actually reaches stderr
-- depends on how 'dput' / 'XMisc' are configured in "DPut" — confirm.
warnStderr :: String -> IO ()
warnStderr msg = do
    dput XMisc msg
    hFlush stderr

-- | Marks a value as belonging to the local end of a connection; the
-- server passes the accepted socket's own address as @Local SockAddr@.
newtype Local a = Local a deriving (Eq,Ord,Show)
-- | Marks a value as belonging to the remote (peer) end of a connection;
-- the server passes the address returned by 'accept' as @Remote SockAddr@.
newtype Remote a = Remote a deriving (Eq,Ord,Show)

-- | Configuration for 'streamServer'.  Usually built with 'withSession'.
data ServerConfig = ServerConfig
    { serverWarn    :: String -> IO ()
    -- ^ Action to report warnings and errors.
    , serverSession :: ( RestrictedSocket, (Local SockAddr, Remote SockAddr)) -> Int -> Handle -> IO ()
    -- ^ Action to handle interaction with a client.  It runs in a freshly
    -- forked thread for every accepted connection and receives:
    --
    -- * a restricted view of the connected socket, with its local and
    --   remote addresses;
    -- * an increasing connection number (the first connection gets 1);
    -- * a read/write 'Handle' wrapping the connected socket.
    --
    -- The server itself never closes this 'Handle'.
    }

-- | Build a 'ServerConfig' around the given per-connection session handler,
-- using the module's default warning reporter for 'serverWarn'.
withSession :: ((RestrictedSocket,(Local SockAddr,Remote SockAddr)) -> Int -> Handle -> IO ()) -> ServerConfig
withSession session = ServerConfig
    { serverWarn    = warnStderr
    , serverSession = session
    }

-- | Launch a thread to listen at the given bind address and dispatch
-- to session handler threads on every incoming connection. Supports
-- IPv4 and IPv6, TCP and unix sockets.
--
-- The socket family is chosen from the first address in the list.  The
-- addresses are tried in order.  For a unix-socket address, any existing
-- file at that path is removed first (errors ignored).  If no address can
-- be bound, the last bind error is reported through 'serverWarn' and the
-- whole list is retried after 5 seconds, indefinitely.
--
-- An empty address list is rejected with an 'IOError', since there would be
-- nothing to bind.  If socket setup, binding or 'listen' raises an
-- 'IOError', the socket is closed before the error is rethrown.
--
-- The returned handle can be used with 'quitListening' to terminate the
-- thread and prevent any new sessions from starting.  Currently active
-- session threads will not be terminated or signaled in any way.
streamServer :: ServerConfig -> [SockAddr] -> IO ServerHandle
streamServer _ [] = ioError (userError "streamServer: no bind addresses given")
streamServer cfg addrs@(firstAddr:_) = do
    let warn   = serverWarn cfg
        family = case firstAddr of
                    SockAddrInet  {} -> AF_INET
                    SockAddrInet6 {} -> AF_INET6
                    SockAddrUnix  {} -> AF_UNIX
    sock <- socket family Stream 0
    -- Try each address in turn.  When all have failed, report the last
    -- error, wait 5 seconds and start over.  Each attempt is caught with
    -- 'tryIOError', so retries neither stack up nested handlers nor run
    -- inside a (masked) exception handler.
    let bindAny [] lastErr = do
            forM_ lastErr $ \e -> warn $ "bind-error: " <> bshow addrs <> " " <> bshow e
            threadDelay 5000000
            bindAny addrs Nothing
        bindAny (addr:rest) _ = do
            result <- tryIOError $ do
                _ <- tryIOError (removeSocketFile addr)
                bind sock addr
            either (bindAny rest . Just) return result
        setup = do
            setSocketOption sock ReuseAddr 1
            bindAny addrs Nothing
            listen sock maxListenQueue
    -- Don't leak the socket if any setup step fails.
    setup `catchIOError` \e -> Socket.close sock >> ioError e
    thread <- mkWeakThreadId <=< forkIO $ do
        bindaddr <- Socket.getSocketName sock
        myThreadId >>= flip labelThread ("StreamServer.acceptLoop." <> bshow bindaddr)
        acceptLoop cfg sock 0
    return (ServerHandle sock thread)

-- | Not exported.  This, combined with 'acceptException', forms a mutually
-- recursive loop that handles incoming connections.  To quit the loop, the
-- socket must be closed by 'quitListening'.
--
-- Each iteration accepts one connection on the listening socket @sock@ and
-- forks a thread running 'serverSession'.  @n@ counts previous iterations,
-- and the session receives @n + 1@ as its connection key.
--
-- An 'IOError' from 'accept' is passed to 'acceptException', which decides
-- whether to continue.  Errors are caught per iteration with 'tryIOError',
-- so the recursive call is not nested inside an exception handler.  Nesting
-- it would add one handler frame per connection for the server's lifetime.
--
-- If preparing one accepted connection fails (reading its local address or
-- wrapping it in a 'Handle'), the error is reported via 'serverWarn', that
-- connection is closed, and the loop keeps accepting.  Such an error does
-- not shut down the listening socket.
acceptLoop :: ServerConfig -> Socket -> Int -> IO ()
acceptLoop cfg sock n = do
    accepted <- tryIOError (accept sock)
    case accepted of
        Left err -> acceptException cfg n sock err
        Right (con, raddr) -> do
            let conkey = n + 1
            prepared <- tryIOError $ do
                laddr <- Socket.getSocketName con
                h <- socketToHandle con ReadWriteMode
                return (laddr, h)
            case prepared of
                Left err -> do
                    serverWarn cfg $ "acceptLoop: connection setup failed: " <> bshow err
                    -- Best effort: the connection is being abandoned anyway.
                    void $ tryIOError (Socket.close con)
                Right (laddr, h) ->
                    void $ forkIO $ do
                        myThreadId >>= flip labelThread "StreamServer.session"
                        serverSession cfg (restrictHandleSocket h con, (Local laddr, Remote raddr)) conkey h
            acceptLoop cfg sock conkey

-- | Not exported.  Decide what to do after 'accept' on the listening socket
-- raised an 'IOError'.
--
-- * Resource exhausted: warn, wait half a second, then resume 'acceptLoop'.
-- * Invalid argument: the socket has been closed (by 'quitListening').
--   Close it quietly, which ends the loop.
-- * Anything else: report the full error via 'serverWarn', close the
--   socket and end the loop.
acceptException :: ServerConfig -> Int -> Socket -> IOError -> IO ()
acceptException cfg n sock ioerror
    | isFullError ioerror = do
        serverWarn cfg "acceptLoop: resource exhausted"
        threadDelay 500000
        acceptLoop cfg sock (n + 1)
    | show (ioeGetErrorType ioerror) == "invalid argument" =
        -- Quit on closed socket.
        Socket.close sock
    | otherwise = do
        -- Unexpected exception: log the whole error, not just its category.
        serverWarn cfg $ "acceptLoop: " <> bshow ioerror
        Socket.close sock