summaryrefslogtreecommitdiff
path: root/Presence/DNSCache.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Presence/DNSCache.hs')
-rw-r--r--Presence/DNSCache.hs270
1 files changed, 270 insertions, 0 deletions
diff --git a/Presence/DNSCache.hs b/Presence/DNSCache.hs
new file mode 100644
index 00000000..aaf1a7be
--- /dev/null
+++ b/Presence/DNSCache.hs
@@ -0,0 +1,270 @@
1-- | Both 'getAddrInfo' and 'getHostByAddr' have hard-coded timeouts for
2-- waiting upon network queries that can be a little too long for some use
3-- cases. This module wraps both of them so that they block for at most one
4-- second. It caches late-arriving results so that they can be returned by
5-- repeated timed-out queries.
6--
7-- In order to achieve the shorter timeout, it is likely that the you will need
8-- to build with GHC's -threaded option. Otherwise, if the wrapped FFI calls
9-- to resolve the address will block Haskell threads. Note: I didn't verify
10-- this.
11{-# LANGUAGE TupleSections #-}
12{-# LANGUAGE RankNTypes #-}
13{-# LANGUAGE CPP #-}
14module DNSCache
15 ( DNSCache
16 , reverseResolve
17 , forwardResolve
18 , newDNSCache
19 , parseAddress
20 , strip_brackets
21 , withPort
22 ) where
23
24#ifdef THREAD_DEBUG
25import Control.Concurrent.Lifted.Instrument
26#else
27import Control.Concurrent.Lifted
28import GHC.Conc (labelThread)
29#endif
30import Control.Concurrent.STM
31import Data.Text ( Text )
32import Network.Socket ( SockAddr(..), AddrInfoFlag(..), defaultHints, getAddrInfo, AddrInfo(..) )
33import Data.Time.Clock ( UTCTime, getCurrentTime, diffUTCTime )
34import System.IO.Error ( isDoesNotExistError )
35import System.Endian ( fromBE32, toBE32 )
36import Control.Exception ( handle )
37import Data.Map ( Map )
38import qualified Data.Map as Map
39import qualified Network.BSD as BSD
40import qualified Data.Text as Text
41import Control.Monad
42import Data.Function
43import Data.List
44import Data.Ord
45import Data.Maybe
46import System.IO
47import System.IO.Error
48
49import SockAddr ()
50import ControlMaybe ( handleIO_ )
51import GetHostByAddr ( getHostByAddr )
52import InterruptibleDelay
53
54type TimeStamp = UTCTime
55
56data DNSCache =
57 DNSCache
58 { fcache :: TVar (Map Text [(TimeStamp, SockAddr)])
59 , rcache :: TVar (Map SockAddr [(TimeStamp, Text)])
60 }
61
62
63newDNSCache :: IO DNSCache
64newDNSCache = do
65 fcache <- newTVarIO Map.empty
66 rcache <- newTVarIO Map.empty
67 return DNSCache { fcache=fcache, rcache=rcache }
68
69updateCache :: Eq x =>
70 Bool -> TimeStamp -> [x] -> Maybe [(TimeStamp,x)] -> Maybe [(TimeStamp,x)]
71updateCache withScrub utc xs mys = do
72 let ys = maybe [] id mys
73 ys' = filter scrub ys
74 ys'' = map (utc,) xs ++ ys'
75 minute = 60
76 scrub (t,x) | withScrub && diffUTCTime utc t < minute = False
77 scrub (t,x) | x `elem` xs = False
78 scrub _ = True
79 guard $ not (null ys'')
80 return ys''
81
82dnsObserve :: DNSCache -> Bool -> TimeStamp -> [(Text,SockAddr)] -> STM ()
83dnsObserve dns withScrub utc obs = do
84 f <- readTVar $ fcache dns
85 r <- readTVar $ rcache dns
86 let obs' = map (\(n,a)->(n,a `withPort` 0)) obs
87 gs = do
88 g <- groupBy ((==) `on` fst) $ sortBy (comparing fst) obs'
89 (n,_) <- take 1 g
90 return (n,map snd g)
91 f' = foldl' updatef f gs
92 hs = do
93 h <- groupBy ((==) `on` snd) $ sortBy (comparing snd) obs'
94 (_,a) <- take 1 h
95 return (a,map fst h)
96 r' = foldl' updater r hs
97 writeTVar (fcache dns) f'
98 writeTVar (rcache dns) r'
99 where
100 updatef f (n,addrs) = Map.alter (updateCache withScrub utc addrs) n f
101 updater r (a,ns) = Map.alter (updateCache withScrub utc ns) a r
102
103make6mapped4 :: SockAddr -> SockAddr
104make6mapped4 addr@(SockAddrInet6 {}) = addr
105make6mapped4 addr@(SockAddrInet port a) = SockAddrInet6 port 0 (0,0,0xFFFF,fromBE32 a) 0
106
107tryForkOS :: IO () -> IO ThreadId
108tryForkOS action = catchIOError (forkOS action) $ \e -> do
109 hPutStrLn stderr $ "DNSCache: Link with -threaded to avoid excessively long time-out."
110 forkIO action
111
112
113-- Attempt to resolve the given domain name. Returns an empty list if the
114-- resolve operation takes longer than the timeout, but the 'DNSCache' will be
115-- updated when the resolve completes.
116--
117-- When the resolve operation does complete, any entries less than a minute old
118-- will be overwritten with the new results. Older entries are allowed to
119-- persist for reasons I don't understand as of this writing. (See 'updateCache')
120rawForwardResolve ::
121 DNSCache -> (Text -> IO ()) -> Int -> Text -> IO [SockAddr]
122rawForwardResolve dns onFail timeout addrtext = do
123 r <- atomically newEmptyTMVar
124 mvar <- interruptibleDelay
125 rt <- tryForkOS $ do
126 myThreadId >>= flip labelThread ("resolve."++show addrtext)
127 resolver r mvar
128 startDelay mvar timeout
129 did <- atomically $ tryPutTMVar r []
130 when did (onFail addrtext)
131 atomically $ readTMVar r
132 where
133 resolver r mvar = do
134 xs <- handle (\e -> let _ = isDoesNotExistError e in return [])
135 $ do fmap (nub . map (make6mapped4 . addrAddress)) $
136 getAddrInfo (Just $ defaultHints { addrFlags = [ AI_CANONNAME, AI_V4MAPPED ]})
137 (Just $ Text.unpack $ strip_brackets addrtext)
138 (Just "5269")
139 did <- atomically $ tryPutTMVar r xs
140 when did $ do
141 interruptDelay mvar
142 utc <- getCurrentTime
143 atomically $ dnsObserve dns True utc $ map (addrtext,) xs
144 return ()
145
146strip_brackets :: Text -> Text
147strip_brackets s =
148 case Text.uncons s of
149 Just ('[',t) -> Text.takeWhile (/=']') t
150 _ -> s
151
152
153reportTimeout :: forall a. Show a => a -> IO ()
154reportTimeout addrtext = do
155 hPutStrLn stderr $ "timeout resolving: "++show addrtext
156 -- killThread rt
157
158unmap6mapped4 :: SockAddr -> SockAddr
159unmap6mapped4 addr@(SockAddrInet6 port _ (0,0,0xFFFF,a) _) =
160 SockAddrInet port (toBE32 a)
161unmap6mapped4 addr = addr
162
163rawReverseResolve ::
164 DNSCache -> (SockAddr -> IO ()) -> Int -> SockAddr -> IO [Text]
165rawReverseResolve dns onFail timeout addr = do
166 r <- atomically newEmptyTMVar
167 mvar <- interruptibleDelay
168 rt <- forkOS $ resolver r mvar
169 startDelay mvar timeout
170 did <- atomically $ tryPutTMVar r []
171 when did (onFail addr)
172 atomically $ readTMVar r
173 where
174 resolver r mvar =
175 handleIO_ (return ()) $ do
176 ent <- getHostByAddr (unmap6mapped4 addr) -- AF_UNSPEC addr
177 let names = BSD.hostName ent : BSD.hostAliases ent
178 xs = map Text.pack $ nub names
179 forkIO $ do
180 utc <- getCurrentTime
181 atomically $ dnsObserve dns False utc $ map (,addr) xs
182 atomically $ putTMVar r xs
183
184-- Returns expired (older than a minute) cached reverse-dns results
185-- and removes them from the cache.
186expiredReverse :: DNSCache -> SockAddr -> IO [Text]
187expiredReverse dns addr = do
188 utc <- getCurrentTime
189 addr <- return $ addr `withPort` 0
190 es <- atomically $ do
191 r <- readTVar $ rcache dns
192 let ns = maybe [] id $ Map.lookup addr r
193 minute = 60 -- seconds
194 -- XXX: Is this right? flip diffUTCTime utc returns the age of the
195 -- cache entry?
196 (es0,ns') = partition ( (>=minute) . flip diffUTCTime utc . fst ) ns
197 es = map snd es0
198 modifyTVar' (rcache dns) $ Map.insert addr ns'
199 f <- readTVar $ fcache dns
200 let f' = foldl' (flip $ Map.alter (expire utc)) f es
201 expire utc Nothing = Nothing
202 expire utc (Just as) = if null as' then Nothing else Just as'
203 where as' = filter ( (<minute) . flip diffUTCTime utc . fst) as
204 writeTVar (fcache dns) f'
205 return es
206 return es
207
208cachedReverse :: DNSCache -> SockAddr -> IO [Text]
209cachedReverse dns addr = do
210 utc <- getCurrentTime
211 addr <- return $ addr `withPort` 0
212 atomically $ do
213 r <- readTVar (rcache dns)
214 let ns = maybe [] id $ Map.lookup addr r
215 {-
216 ns' = filter ( (<minute) . flip diffUTCTime utc . fst) ns
217 minute = 60 -- seconds
218 modifyTVar' (rcache dns) $ Map.insert addr ns'
219 return $ map snd ns'
220 -}
221 return $ map snd ns
222
223-- Returns any dns query results for the given name that were observed less
224-- than a minute ago and updates the forward-cache to remove any results older
225-- than that.
226cachedForward :: DNSCache -> Text -> IO [SockAddr]
227cachedForward dns n = do
228 utc <- getCurrentTime
229 atomically $ do
230 f <- readTVar (fcache dns)
231 let as = maybe [] id $ Map.lookup n f
232 as' = filter ( (<minute) . flip diffUTCTime utc . fst) as
233 minute = 60 -- seconds
234 modifyTVar' (fcache dns) $ Map.insert n as'
235 return $ map snd as'
236
237-- Reverse-resolves an address to a domain name. Returns both the result of a
238-- new query and any freshly cached results. Cache entries older than a minute
239-- will not be returned, but will be refreshed in spawned threads so that they
240-- may be available for the next call.
241reverseResolve :: DNSCache -> SockAddr -> IO [Text]
242reverseResolve dns addr = do
243 expired <- expiredReverse dns addr
244 forM_ expired $ \n -> forkIO $ do
245 rawForwardResolve dns (const $ return ()) 1000000 n
246 return ()
247 xs <- rawReverseResolve dns (const $ return ()) 1000000 addr
248 cs <- cachedReverse dns addr
249 return $ xs ++ filter (not . flip elem xs) cs
250
251-- Resolves a name, if there's no result within one second, then any cached
252-- results that are less than a minute old are returned.
253forwardResolve :: DNSCache -> Text -> IO [SockAddr]
254forwardResolve dns n = do
255 as <- rawForwardResolve dns (const $ return ()) 1000000 n
256 if null as
257 then cachedForward dns n
258 else return as
259
260parseAddress :: Text -> IO (Maybe SockAddr)
261parseAddress addr_str = do
262 info <- getAddrInfo (Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] })
263 (Just . Text.unpack $ addr_str)
264 (Just "0")
265 return . listToMaybe $ map addrAddress info
266
267
268withPort :: SockAddr -> Int -> SockAddr
269withPort (SockAddrInet _ a) port = SockAddrInet (toEnum port) a
270withPort (SockAddrInet6 _ a b c) port = SockAddrInet6 (toEnum port) a b c