{-# LANGUAGE CPP           #-}
{-# LANGUAGE RankNTypes    #-}
{-# LANGUAGE TupleSections #-}
-- | Both 'getAddrInfo' and 'getHostByAddr' have hard-coded timeouts for
-- waiting upon network queries that can be a little too long for some use
-- cases.  This module wraps both of them so that they block for at most one
-- second.  It caches late-arriving results so that they can be returned by
-- repeated timed-out queries.
--
-- In order to achieve the shorter timeout, you will likely need to build with
-- GHC's -threaded option.  Otherwise, the wrapped FFI calls that resolve the
-- address may block other Haskell threads.  Note: I didn't verify this.
module DNSCache
    ( DNSCache
    , reverseResolve
    , forwardResolve
    , newDNSCache
    , parseAddress
    , unsafeParseAddress
    , strip_brackets
    , withPort
    ) where

import Control.Concurrent.ThreadUtil
import Control.Arrow
import Control.Concurrent.STM
import Data.Text ( Text )
import Network.Socket
    ( SockAddr(..), AddrInfoFlag(..), defaultHints, getAddrInfo, AddrInfo(..) )
import Data.Time.Clock ( UTCTime, NominalDiffTime, getCurrentTime, diffUTCTime )
import System.IO.Error ( isDoesNotExistError )
import System.Endian ( fromBE32, toBE32 )
import Control.Exception ( handle )
import Data.Map ( Map )
import qualified Data.Map as Map
import qualified Network.BSD as BSD
import qualified Data.Text as Text
import Control.Monad
import Data.Function
import Data.List
import Data.Ord
import Data.Maybe
import System.IO.Error
import System.IO.Unsafe

import SockAddr ()
import ControlMaybe ( handleIO_ )
import GetHostByAddr ( getHostByAddr )
import InterruptibleDelay
import DPut
import DebugTag

type TimeStamp = UTCTime

-- | Forward (name -> addresses) and reverse (address -> names) caches.  Every
-- cached value is paired with the time it was observed.  Addresses are stored
-- with their port set to 0 (see 'dnsObserve').
data DNSCache = DNSCache
    { fcache :: TVar (Map Text [(TimeStamp, SockAddr)])
    , rcache :: TVar (Map SockAddr [(TimeStamp, Text)])
    }

-- | Create an empty cache.
newDNSCache :: IO DNSCache
newDNSCache = do
    fcache <- newTVarIO Map.empty
    rcache <- newTVarIO Map.empty
    return DNSCache { fcache = fcache, rcache = rcache }

-- | How long a cache entry is considered fresh.
dnsCacheLifetime :: NominalDiffTime
dnsCacheLifetime = 60 -- seconds

-- | True when the entry was stamped less than 'dnsCacheLifetime' before @now@.
--
-- Note the argument order: @diffUTCTime now t@ is @now - t@, the age of an
-- entry stamped @t@.  (@flip diffUTCTime now t@ would be the negated age.)
isFreshEntry :: TimeStamp -> (TimeStamp, a) -> Bool
isFreshEntry now (t, _) = diffUTCTime now t < dnsCacheLifetime

-- | Merge freshly observed values @xs@, stamped @utc@, into an existing cache
-- entry.  Shaped for use with 'Map.alter'.
--
-- Existing entries are dropped when they duplicate one of @xs@, or, when
-- @withScrub@ is set, when they are less than a minute old.  Older entries
-- are kept.  Returns 'Nothing' when nothing remains, so that 'Map.alter'
-- removes the key instead of storing an empty list.
updateCache :: Eq x => Bool -> TimeStamp -> [x] -> Maybe [(TimeStamp,x)] -> Maybe [(TimeStamp,x)]
updateCache withScrub utc xs mys = do
    let keep entry@(_, x) = not (withScrub && isFreshEntry utc entry)
                            && x `notElem` xs
        merged = map (utc,) xs ++ filter keep (fromMaybe [] mys)
    guard $ not (null merged)
    return merged

-- | Record a batch of (name, address) observations in both the forward and
-- the reverse cache.  Ports are zeroed before storing.  See 'updateCache' for
-- the meaning of @withScrub@.
dnsObserve :: DNSCache -> Bool -> TimeStamp -> [(Text,SockAddr)] -> STM ()
dnsObserve dns withScrub utc obs = do
    modifyTVar' (fcache dns) $ \f -> foldl' updatef f (groupOn fst snd)
    modifyTVar' (rcache dns) $ \r -> foldl' updater r (groupOn snd fst)
  where
    obs' = map (\(n,a) -> (n, a `withPort` 0)) obs

    -- Group the observations by one component, pairing each key with every
    -- value that was observed alongside it.
    groupOn :: Ord k => ((Text,SockAddr) -> k) -> ((Text,SockAddr) -> v) -> [(k,[v])]
    groupOn key val =
        [ (key x, map val g)
        | g@(x:_) <- groupBy ((==) `on` key) (sortBy (comparing key) obs') ]

    updatef f (n,addrs) = Map.alter (updateCache withScrub utc addrs) n f
    updater r (a,ns)    = Map.alter (updateCache withScrub utc ns) a r

-- | Convert an IPv4 address to its IPv4-mapped IPv6 form (::ffff:a.b.c.d).
-- Every other address is returned unchanged.
make6mapped4 :: SockAddr -> SockAddr
make6mapped4 (SockAddrInet port a) = SockAddrInet6 port 0 (0,0,0xFFFF,fromBE32 a) 0
make6mapped4 addr                  = addr

-- | Fork the action on a bound OS thread via 'forkOSLabeled'.  If that fails
-- with an IO error (presumably because the program was not linked with
-- -threaded), log a debug message and fall back to 'forkLabeled'.
tryForkOS :: String -> IO () -> IO ThreadId
tryForkOS lbl action = catchIOError (forkOSLabeled lbl action) $ \_ -> do
    dput XMisc $ "DNSCache: Link with -threaded to avoid excessively long time-out."
    forkLabeled lbl action

-- | Attempt to resolve the given domain name.
--
-- @timeout@ is in microseconds (callers in this module pass 1000000 for one
-- second).  Returns an empty list and calls @onFail@ if the resolve operation
-- takes longer than the timeout.  The 'DNSCache' is still updated when the
-- resolve eventually completes.  Lookup IO errors yield an empty result.
--
-- Results are IPv6 addresses (IPv4 results are converted by 'make6mapped4').
-- Surrounding brackets on the name are stripped via 'strip_brackets'.
--
-- When the resolve operation does complete, any cache entries less than a
-- minute old are overwritten with the new results.  Older entries are allowed
-- to persist (see 'updateCache').
rawForwardResolve :: DNSCache -> (Text -> IO ()) -> Int -> Text -> IO [SockAddr]
rawForwardResolve dns onFail timeout addrtext = do
    r <- atomically newEmptyTMVar
    mvar <- interruptibleDelay
    _ <- tryForkOS ("resolve."++show addrtext) $ resolver r mvar
    startDelay mvar timeout
    -- If the resolver has not delivered yet, claim the slot with [].
    did <- atomically $ tryPutTMVar r []
    when did (onFail addrtext)
    atomically $ readTMVar r
  where
    resolver r mvar = do
        -- The isDoesNotExistError call only pins the handler's exception
        -- type to IOError; every IO error yields an empty result.
        xs <- handle (\e -> let _ = isDoesNotExistError e in return []) $
            fmap (nub . map (make6mapped4 . addrAddress)) $
                getAddrInfo
                    (Just $ defaultHints { addrFlags = [ AI_CANONNAME, AI_V4MAPPED ] })
                    (Just $ Text.unpack $ strip_brackets addrtext)
                    (Just "5269") -- presumably the XMPP server-to-server port
        did <- atomically $ tryPutTMVar r xs
        -- Wake the waiting caller early if our answer is the one it will see.
        when did $ interruptDelay mvar
        -- Cache the result even if it arrived too late for the caller.
        utc <- getCurrentTime
        atomically $ dnsObserve dns True utc $ map (addrtext,) xs

-- | Remove surrounding square brackets from an IPv6 literal such as @[::1]@.
-- Anything after the closing bracket is dropped too.  Text that does not
-- start with @[@ is returned unchanged.
strip_brackets :: Text -> Text
strip_brackets s =
    case Text.uncons s of
        Just ('[',t) -> Text.takeWhile (/=']') t
        _            -> s

-- | Log, at the 'XMisc' debug tag, that resolving the given value timed out.
reportTimeout :: forall a. Show a => a -> IO ()
reportTimeout addrtext = do
    dput XMisc $ "timeout resolving: "++show addrtext
    -- killThread rt

-- | Inverse of 'make6mapped4': turn an IPv4-mapped IPv6 address back into a
-- plain IPv4 address.  Every other address is returned unchanged.
unmap6mapped4 :: SockAddr -> SockAddr
unmap6mapped4 (SockAddrInet6 port _ (0,0,0xFFFF,a) _) = SockAddrInet port (toBE32 a)
unmap6mapped4 addr = addr

-- | Attempt to reverse-resolve the given address.
--
-- @timeout@ is in microseconds.  Returns an empty list and calls @onFail@ if
-- no answer arrives before the timeout.  A late answer is still recorded in
-- the 'DNSCache'.  IO errors from the lookup are swallowed.  In that case the
-- caller waits for the timeout and receives an empty list.
rawReverseResolve :: DNSCache -> (SockAddr -> IO ()) -> Int -> SockAddr -> IO [Text]
rawReverseResolve dns onFail timeout addr = do
    r <- atomically newEmptyTMVar
    mvar <- interruptibleDelay
    _ <- tryForkOS ("reverse."++show addr) $ resolver r mvar
    startDelay mvar timeout
    did <- atomically $ tryPutTMVar r []
    when did (onFail addr)
    atomically $ readTMVar r
  where
    resolver r mvar = handleIO_ (return ()) $ do
        ent <- getHostByAddr (unmap6mapped4 addr) -- AF_UNSPEC addr
        let xs = map Text.pack $ nub $ BSD.hostName ent : BSD.hostAliases ent
        -- tryPutTMVar, not putTMVar: after a timeout the slot is already
        -- full and a blocking put would hang this thread forever.
        did <- atomically $ tryPutTMVar r xs
        when did $ interruptDelay mvar
        utc <- getCurrentTime
        atomically $ dnsObserve dns False utc $ map (,addr) xs

-- | Returns expired (older than a minute) cached reverse-dns results for the
-- address (port ignored) and removes them from the reverse cache.  Forward
-- cache entries for those names are pruned to their fresh entries.  Names
-- left with none are dropped from the forward cache.
expiredReverse :: DNSCache -> SockAddr -> IO [Text]
expiredReverse dns addr0 = do
    utc <- getCurrentTime
    let addr = addr0 `withPort` 0
    atomically $ do
        r <- readTVar (rcache dns)
        let entries        = fromMaybe [] (Map.lookup addr r)
            (stale, fresh) = partition (not . isFreshEntry utc) entries
            names          = map snd stale
            nonNull xs     = if null xs then Nothing else Just xs
            -- NOTE(review): confirm that forward entries of an expired name
            -- should be pruned by age (rather than, say, by address).
            expire         = (>>= nonNull . filter (isFreshEntry utc))
        writeTVar (rcache dns) $ Map.alter (const $ nonNull fresh) addr r
        modifyTVar' (fcache dns) $ \f -> foldl' (flip $ Map.alter expire) f names
        return names

-- | All names currently cached for the address (port ignored), regardless of
-- age.  'reverseResolve' calls 'expiredReverse' first, which removes entries
-- older than a minute.
cachedReverse :: DNSCache -> SockAddr -> IO [Text]
cachedReverse dns addr = do
    r <- readTVarIO (rcache dns)
    return $ map snd $ fromMaybe [] $ Map.lookup (addr `withPort` 0) r

-- | Cached addresses for the name that were observed less than a minute ago.
cachedForward :: DNSCache -> Text -> IO [SockAddr]
cachedForward dns n = do
    utc <- getCurrentTime
    f <- readTVarIO (fcache dns)
    return [ a | entry@(_, a) <- fromMaybe [] (Map.lookup n f), isFreshEntry utc entry ]

-- | Reverse-resolves an address with a one-second timeout.
--
-- Cached names older than a minute are removed and re-resolved forward in
-- background threads.  Returns the fresh query results, followed by any
-- remaining cached names that the query did not return.
reverseResolve :: DNSCache -> SockAddr -> IO [Text]
reverseResolve dns addr = do
    expired <- expiredReverse dns addr
    forM_ expired $ \n ->
        forkIO $ void $ rawForwardResolve dns (const $ return ()) 1000000 n
    xs <- rawReverseResolve dns (const $ return ()) 1000000 addr
    cs <- cachedReverse dns addr
    return $ xs ++ filter (`notElem` xs) cs

-- | Resolves a name.  If there is no result within one second, any cached
-- results that are less than a minute old are returned.
forwardResolve :: DNSCache -> Text -> IO [SockAddr]
forwardResolve dns n = do
    as <- rawForwardResolve dns (const $ return ()) 1000000 n
    if null as
        then cachedForward dns n
        else return as

-- | Look up a numeric host and port.  With 'AI_NUMERICHOST' no name
-- resolution is performed.  Returns the first result, or 'Nothing' when
-- there is none or when 'getAddrInfo' rejects the input.  'getAddrInfo'
-- signals failure by throwing an IO error rather than by returning an empty
-- list.
numericAddrInfo :: String -> String -> IO (Maybe SockAddr)
numericAddrInfo host port =
    catchIOError
        (listToMaybe . map addrAddress <$> getAddrInfo hints (Just host) (Just port))
        (\_ -> return Nothing)
  where
    hints = Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] }

-- | Parse a numeric IP address (no port) into a 'SockAddr' with port 0.
-- Returns 'Nothing' for input that is not a numeric address.
parseAddress :: Text -> IO (Maybe SockAddr)
parseAddress addr_str = numericAddrInfo (Text.unpack addr_str) "0"

-- | Split @host:port@, @[v6host]:port@, @[v6host]@, a bare IPv6 literal, or a
-- bare host into (host, port).  The port defaults to "0" when absent.
splitAtPort :: String -> (String,String)
splitAtPort s = second sanitizePort $ case s of
    ('[':t) -> break (==']') t
    _ | length (filter (==':') s) > 1 -> (s, "")   -- bare IPv6 literal, no port
      | otherwise                     -> break (==':') s
  where
    sanitizePort (']':':':p) = p
    sanitizePort (':':p)     = p
    sanitizePort _           = "0"

-- | Pure variant of 'parseAddress' that also accepts an optional port (see
-- 'splitAtPort').  Returns 'Nothing' for input that does not parse.
--
-- 'unsafePerformIO' is used because the lookup is numeric-only
-- ('AI_NUMERICHOST'), so no name resolution is performed.  The result is
-- therefore expected to depend only on the input string.
unsafeParseAddress :: String -> Maybe SockAddr
unsafeParseAddress addr_str = unsafePerformIO $ numericAddrInfo ipstr portstr
  where
    (ipstr, portstr) = splitAtPort addr_str

-- | Replace the port of an IPv4 or IPv6 address.  Other kinds of address
-- (e.g. Unix sockets) have no port and are returned unchanged.
withPort :: SockAddr -> Int -> SockAddr
withPort (SockAddrInet _ a)      port = SockAddrInet (toEnum port) a
withPort (SockAddrInet6 _ a b c) port = SockAddrInet6 (toEnum port) a b c
withPort addr                    _    = addr