diff options
Diffstat (limited to 'Presence/DNSCache.hs')
-rw-r--r-- | Presence/DNSCache.hs | 270 |
1 files changed, 270 insertions, 0 deletions
diff --git a/Presence/DNSCache.hs b/Presence/DNSCache.hs new file mode 100644 index 00000000..aaf1a7be --- /dev/null +++ b/Presence/DNSCache.hs | |||
@@ -0,0 +1,270 @@ | |||
1 | -- | Both 'getAddrInfo' and 'getHostByAddr' have hard-coded timeouts for | ||
2 | -- waiting upon network queries that can be a little too long for some use | ||
3 | -- cases. This module wraps both of them so that they block for at most one | ||
4 | -- second. It caches late-arriving results so that they can be returned by | ||
5 | -- repeated timed-out queries. | ||
6 | -- | ||
7 | -- In order to achieve the shorter timeout, it is likely that the you will need | ||
8 | -- to build with GHC's -threaded option. Otherwise, if the wrapped FFI calls | ||
9 | -- to resolve the address will block Haskell threads. Note: I didn't verify | ||
10 | -- this. | ||
11 | {-# LANGUAGE TupleSections #-} | ||
12 | {-# LANGUAGE RankNTypes #-} | ||
13 | {-# LANGUAGE CPP #-} | ||
14 | module DNSCache | ||
15 | ( DNSCache | ||
16 | , reverseResolve | ||
17 | , forwardResolve | ||
18 | , newDNSCache | ||
19 | , parseAddress | ||
20 | , strip_brackets | ||
21 | , withPort | ||
22 | ) where | ||
23 | |||
24 | #ifdef THREAD_DEBUG | ||
25 | import Control.Concurrent.Lifted.Instrument | ||
26 | #else | ||
27 | import Control.Concurrent.Lifted | ||
28 | import GHC.Conc (labelThread) | ||
29 | #endif | ||
30 | import Control.Concurrent.STM | ||
31 | import Data.Text ( Text ) | ||
32 | import Network.Socket ( SockAddr(..), AddrInfoFlag(..), defaultHints, getAddrInfo, AddrInfo(..) ) | ||
33 | import Data.Time.Clock ( UTCTime, getCurrentTime, diffUTCTime ) | ||
34 | import System.IO.Error ( isDoesNotExistError ) | ||
35 | import System.Endian ( fromBE32, toBE32 ) | ||
36 | import Control.Exception ( handle ) | ||
37 | import Data.Map ( Map ) | ||
38 | import qualified Data.Map as Map | ||
39 | import qualified Network.BSD as BSD | ||
40 | import qualified Data.Text as Text | ||
41 | import Control.Monad | ||
42 | import Data.Function | ||
43 | import Data.List | ||
44 | import Data.Ord | ||
45 | import Data.Maybe | ||
46 | import System.IO | ||
47 | import System.IO.Error | ||
48 | |||
49 | import SockAddr () | ||
50 | import ControlMaybe ( handleIO_ ) | ||
51 | import GetHostByAddr ( getHostByAddr ) | ||
52 | import InterruptibleDelay | ||
53 | |||
54 | type TimeStamp = UTCTime | ||
55 | |||
56 | data DNSCache = | ||
57 | DNSCache | ||
58 | { fcache :: TVar (Map Text [(TimeStamp, SockAddr)]) | ||
59 | , rcache :: TVar (Map SockAddr [(TimeStamp, Text)]) | ||
60 | } | ||
61 | |||
62 | |||
63 | newDNSCache :: IO DNSCache | ||
64 | newDNSCache = do | ||
65 | fcache <- newTVarIO Map.empty | ||
66 | rcache <- newTVarIO Map.empty | ||
67 | return DNSCache { fcache=fcache, rcache=rcache } | ||
68 | |||
69 | updateCache :: Eq x => | ||
70 | Bool -> TimeStamp -> [x] -> Maybe [(TimeStamp,x)] -> Maybe [(TimeStamp,x)] | ||
71 | updateCache withScrub utc xs mys = do | ||
72 | let ys = maybe [] id mys | ||
73 | ys' = filter scrub ys | ||
74 | ys'' = map (utc,) xs ++ ys' | ||
75 | minute = 60 | ||
76 | scrub (t,x) | withScrub && diffUTCTime utc t < minute = False | ||
77 | scrub (t,x) | x `elem` xs = False | ||
78 | scrub _ = True | ||
79 | guard $ not (null ys'') | ||
80 | return ys'' | ||
81 | |||
82 | dnsObserve :: DNSCache -> Bool -> TimeStamp -> [(Text,SockAddr)] -> STM () | ||
83 | dnsObserve dns withScrub utc obs = do | ||
84 | f <- readTVar $ fcache dns | ||
85 | r <- readTVar $ rcache dns | ||
86 | let obs' = map (\(n,a)->(n,a `withPort` 0)) obs | ||
87 | gs = do | ||
88 | g <- groupBy ((==) `on` fst) $ sortBy (comparing fst) obs' | ||
89 | (n,_) <- take 1 g | ||
90 | return (n,map snd g) | ||
91 | f' = foldl' updatef f gs | ||
92 | hs = do | ||
93 | h <- groupBy ((==) `on` snd) $ sortBy (comparing snd) obs' | ||
94 | (_,a) <- take 1 h | ||
95 | return (a,map fst h) | ||
96 | r' = foldl' updater r hs | ||
97 | writeTVar (fcache dns) f' | ||
98 | writeTVar (rcache dns) r' | ||
99 | where | ||
100 | updatef f (n,addrs) = Map.alter (updateCache withScrub utc addrs) n f | ||
101 | updater r (a,ns) = Map.alter (updateCache withScrub utc ns) a r | ||
102 | |||
103 | make6mapped4 :: SockAddr -> SockAddr | ||
104 | make6mapped4 addr@(SockAddrInet6 {}) = addr | ||
105 | make6mapped4 addr@(SockAddrInet port a) = SockAddrInet6 port 0 (0,0,0xFFFF,fromBE32 a) 0 | ||
106 | |||
107 | tryForkOS :: IO () -> IO ThreadId | ||
108 | tryForkOS action = catchIOError (forkOS action) $ \e -> do | ||
109 | hPutStrLn stderr $ "DNSCache: Link with -threaded to avoid excessively long time-out." | ||
110 | forkIO action | ||
111 | |||
112 | |||
113 | -- Attempt to resolve the given domain name. Returns an empty list if the | ||
114 | -- resolve operation takes longer than the timeout, but the 'DNSCache' will be | ||
115 | -- updated when the resolve completes. | ||
116 | -- | ||
117 | -- When the resolve operation does complete, any entries less than a minute old | ||
118 | -- will be overwritten with the new results. Older entries are allowed to | ||
119 | -- persist for reasons I don't understand as of this writing. (See 'updateCache') | ||
120 | rawForwardResolve :: | ||
121 | DNSCache -> (Text -> IO ()) -> Int -> Text -> IO [SockAddr] | ||
122 | rawForwardResolve dns onFail timeout addrtext = do | ||
123 | r <- atomically newEmptyTMVar | ||
124 | mvar <- interruptibleDelay | ||
125 | rt <- tryForkOS $ do | ||
126 | myThreadId >>= flip labelThread ("resolve."++show addrtext) | ||
127 | resolver r mvar | ||
128 | startDelay mvar timeout | ||
129 | did <- atomically $ tryPutTMVar r [] | ||
130 | when did (onFail addrtext) | ||
131 | atomically $ readTMVar r | ||
132 | where | ||
133 | resolver r mvar = do | ||
134 | xs <- handle (\e -> let _ = isDoesNotExistError e in return []) | ||
135 | $ do fmap (nub . map (make6mapped4 . addrAddress)) $ | ||
136 | getAddrInfo (Just $ defaultHints { addrFlags = [ AI_CANONNAME, AI_V4MAPPED ]}) | ||
137 | (Just $ Text.unpack $ strip_brackets addrtext) | ||
138 | (Just "5269") | ||
139 | did <- atomically $ tryPutTMVar r xs | ||
140 | when did $ do | ||
141 | interruptDelay mvar | ||
142 | utc <- getCurrentTime | ||
143 | atomically $ dnsObserve dns True utc $ map (addrtext,) xs | ||
144 | return () | ||
145 | |||
146 | strip_brackets :: Text -> Text | ||
147 | strip_brackets s = | ||
148 | case Text.uncons s of | ||
149 | Just ('[',t) -> Text.takeWhile (/=']') t | ||
150 | _ -> s | ||
151 | |||
152 | |||
153 | reportTimeout :: forall a. Show a => a -> IO () | ||
154 | reportTimeout addrtext = do | ||
155 | hPutStrLn stderr $ "timeout resolving: "++show addrtext | ||
156 | -- killThread rt | ||
157 | |||
158 | unmap6mapped4 :: SockAddr -> SockAddr | ||
159 | unmap6mapped4 addr@(SockAddrInet6 port _ (0,0,0xFFFF,a) _) = | ||
160 | SockAddrInet port (toBE32 a) | ||
161 | unmap6mapped4 addr = addr | ||
162 | |||
163 | rawReverseResolve :: | ||
164 | DNSCache -> (SockAddr -> IO ()) -> Int -> SockAddr -> IO [Text] | ||
165 | rawReverseResolve dns onFail timeout addr = do | ||
166 | r <- atomically newEmptyTMVar | ||
167 | mvar <- interruptibleDelay | ||
168 | rt <- forkOS $ resolver r mvar | ||
169 | startDelay mvar timeout | ||
170 | did <- atomically $ tryPutTMVar r [] | ||
171 | when did (onFail addr) | ||
172 | atomically $ readTMVar r | ||
173 | where | ||
174 | resolver r mvar = | ||
175 | handleIO_ (return ()) $ do | ||
176 | ent <- getHostByAddr (unmap6mapped4 addr) -- AF_UNSPEC addr | ||
177 | let names = BSD.hostName ent : BSD.hostAliases ent | ||
178 | xs = map Text.pack $ nub names | ||
179 | forkIO $ do | ||
180 | utc <- getCurrentTime | ||
181 | atomically $ dnsObserve dns False utc $ map (,addr) xs | ||
182 | atomically $ putTMVar r xs | ||
183 | |||
184 | -- Returns expired (older than a minute) cached reverse-dns results | ||
185 | -- and removes them from the cache. | ||
186 | expiredReverse :: DNSCache -> SockAddr -> IO [Text] | ||
187 | expiredReverse dns addr = do | ||
188 | utc <- getCurrentTime | ||
189 | addr <- return $ addr `withPort` 0 | ||
190 | es <- atomically $ do | ||
191 | r <- readTVar $ rcache dns | ||
192 | let ns = maybe [] id $ Map.lookup addr r | ||
193 | minute = 60 -- seconds | ||
194 | -- XXX: Is this right? flip diffUTCTime utc returns the age of the | ||
195 | -- cache entry? | ||
196 | (es0,ns') = partition ( (>=minute) . flip diffUTCTime utc . fst ) ns | ||
197 | es = map snd es0 | ||
198 | modifyTVar' (rcache dns) $ Map.insert addr ns' | ||
199 | f <- readTVar $ fcache dns | ||
200 | let f' = foldl' (flip $ Map.alter (expire utc)) f es | ||
201 | expire utc Nothing = Nothing | ||
202 | expire utc (Just as) = if null as' then Nothing else Just as' | ||
203 | where as' = filter ( (<minute) . flip diffUTCTime utc . fst) as | ||
204 | writeTVar (fcache dns) f' | ||
205 | return es | ||
206 | return es | ||
207 | |||
208 | cachedReverse :: DNSCache -> SockAddr -> IO [Text] | ||
209 | cachedReverse dns addr = do | ||
210 | utc <- getCurrentTime | ||
211 | addr <- return $ addr `withPort` 0 | ||
212 | atomically $ do | ||
213 | r <- readTVar (rcache dns) | ||
214 | let ns = maybe [] id $ Map.lookup addr r | ||
215 | {- | ||
216 | ns' = filter ( (<minute) . flip diffUTCTime utc . fst) ns | ||
217 | minute = 60 -- seconds | ||
218 | modifyTVar' (rcache dns) $ Map.insert addr ns' | ||
219 | return $ map snd ns' | ||
220 | -} | ||
221 | return $ map snd ns | ||
222 | |||
223 | -- Returns any dns query results for the given name that were observed less | ||
224 | -- than a minute ago and updates the forward-cache to remove any results older | ||
225 | -- than that. | ||
226 | cachedForward :: DNSCache -> Text -> IO [SockAddr] | ||
227 | cachedForward dns n = do | ||
228 | utc <- getCurrentTime | ||
229 | atomically $ do | ||
230 | f <- readTVar (fcache dns) | ||
231 | let as = maybe [] id $ Map.lookup n f | ||
232 | as' = filter ( (<minute) . flip diffUTCTime utc . fst) as | ||
233 | minute = 60 -- seconds | ||
234 | modifyTVar' (fcache dns) $ Map.insert n as' | ||
235 | return $ map snd as' | ||
236 | |||
237 | -- Reverse-resolves an address to a domain name. Returns both the result of a | ||
238 | -- new query and any freshly cached results. Cache entries older than a minute | ||
239 | -- will not be returned, but will be refreshed in spawned threads so that they | ||
240 | -- may be available for the next call. | ||
241 | reverseResolve :: DNSCache -> SockAddr -> IO [Text] | ||
242 | reverseResolve dns addr = do | ||
243 | expired <- expiredReverse dns addr | ||
244 | forM_ expired $ \n -> forkIO $ do | ||
245 | rawForwardResolve dns (const $ return ()) 1000000 n | ||
246 | return () | ||
247 | xs <- rawReverseResolve dns (const $ return ()) 1000000 addr | ||
248 | cs <- cachedReverse dns addr | ||
249 | return $ xs ++ filter (not . flip elem xs) cs | ||
250 | |||
251 | -- Resolves a name, if there's no result within one second, then any cached | ||
252 | -- results that are less than a minute old are returned. | ||
253 | forwardResolve :: DNSCache -> Text -> IO [SockAddr] | ||
254 | forwardResolve dns n = do | ||
255 | as <- rawForwardResolve dns (const $ return ()) 1000000 n | ||
256 | if null as | ||
257 | then cachedForward dns n | ||
258 | else return as | ||
259 | |||
260 | parseAddress :: Text -> IO (Maybe SockAddr) | ||
261 | parseAddress addr_str = do | ||
262 | info <- getAddrInfo (Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] }) | ||
263 | (Just . Text.unpack $ addr_str) | ||
264 | (Just "0") | ||
265 | return . listToMaybe $ map addrAddress info | ||
266 | |||
267 | |||
268 | withPort :: SockAddr -> Int -> SockAddr | ||
269 | withPort (SockAddrInet _ a) port = SockAddrInet (toEnum port) a | ||
270 | withPort (SockAddrInet6 _ a b c) port = SockAddrInet6 (toEnum port) a b c | ||