{-# LANGUAGE TupleSections #-}
-- | A small cache in front of forward ('getAddrInfo') and reverse
-- ('getHostByAddr') lookups.
--
-- Results of live lookups are recorded with a timestamp in two maps:
-- name to addresses, and address to names. Every live lookup runs on its
-- own OS thread and is raced against a timer thread.
--
-- NOTE(review): this file reached review as a whitespace-collapsed paste,
-- and part of it is missing (see the note inside @expiredReverse@). The
-- line breaks below were reconstructed from the token stream. Where more
-- than one layout was plausible and the choice changes behaviour, it is
-- flagged.
module DNSCache
    ( DNSCache
    , reverseResolve
    , forwardResolve
    , newDNSCache
    , parseAddress
    , strip_brackets
    , withPort
    ) where

import Control.Concurrent
import Control.Concurrent.STM
import Data.Text ( Text )
import Network.Socket ( SockAddr(..), AddrInfoFlag(..), defaultHints, getAddrInfo, AddrInfo(..) )
import Data.Time.Clock ( UTCTime, getCurrentTime, diffUTCTime )
import System.IO.Error ( isDoesNotExistError )
import System.Endian ( fromBE32, toBE32 )
import Control.Exception ( handle, ErrorCall(..) )
import Data.Map ( Map )
import qualified Data.Map as Map
import qualified Network.BSD as BSD
import qualified Data.Text as Text
import Control.Monad
import Data.List
import Data.Ord
import Data.Maybe
-- Empty import list: imported for instances only. Presumably these are
-- instances on SockAddr (e.g. for use as a Map key); confirm.
import SockAddr ()
import ControlMaybe ( handleIO_ )
import GetHostByAddr ( getHostByAddr )

-- | Time at which a cache entry was observed (taken from 'getCurrentTime').
type TimeStamp = UTCTime

-- | Forward and reverse lookup results. Each entry is tagged with the time
-- it was observed, newest first (see 'updateCache').
data DNSCache = DNSCache
    { fcache :: TVar (Map Text [(TimeStamp, SockAddr)])
      -- ^ Forward cache: host name to addresses. Addresses are stored with
      -- port 0 by 'dnsObserve'.
    , rcache :: TVar (Map SockAddr [(TimeStamp, Text)])
      -- ^ Reverse cache: address (port 0) to host names.
    }

-- | Create an empty cache.
newDNSCache :: IO DNSCache
newDNSCache = do
    fcache <- newTVarIO Map.empty
    rcache <- newTVarIO Map.empty
    return DNSCache { fcache=fcache, rcache=rcache }

-- | @equivBy f a b@ holds when @a@ and @b@ have equal projections under @f@.
-- This is the same as ((==) `on` f).
equivBy f a b = f a == f b

-- | Merge freshly observed values @xs@, all stamped @utc@, into an existing
-- cache entry. The argument order suits 'Map.alter'.
--
-- The new values come first. An old entry is dropped when:
--
-- * its value is among @xs@ (so it is not duplicated), or
-- * @withScrub@ is set and its age (@utc - t@) is under 60 seconds.
--
-- Returns 'Nothing' when the merged list is empty, so 'Map.alter' deletes
-- the key.
--
-- NOTE(review): with @withScrub@ it is the /recent/ old entries that are
-- dropped, while entries older than a minute survive. Confirm that this
-- direction is intended.
updateCache :: Eq x => Bool -> TimeStamp -> [x] -> Maybe [(TimeStamp,x)] -> Maybe [(TimeStamp,x)]
updateCache withScrub utc xs mys = do
    let ys = maybe [] id mys
        ys' = filter scrub ys
        ys'' = map (utc,) xs ++ ys'
        minute = 60
        -- False means "drop this old entry".
        scrub (t,x) | withScrub && diffUTCTime utc t < minute = False
        scrub (t,x) | x `elem` xs = False
        scrub _ = True
    guard $ not (null ys'')
    return ys''

-- | Record a batch of (name, address) observations, stamped @utc@, in both
-- directions of the cache within one STM transaction.
--
-- Ports are zeroed first, so entries are keyed by address alone.
-- @withScrub@ is passed through to 'updateCache'.
dnsObserve :: DNSCache -> Bool -> TimeStamp -> [(Text,SockAddr)] -> STM ()
dnsObserve dns withScrub utc obs = do
    f <- readTVar $ fcache dns
    r <- readTVar $ rcache dns
    let obs' = map (\(n,a)->(n,a `withPort` 0)) obs
        -- One (name, [address]) pair per distinct name.
        gs = do
            g <- groupBy (equivBy fst) $ sortBy (comparing fst) obs'
            (n,_) <- take 1 g
            return (n,map snd g)
        f' = foldl' updatef f gs
        -- One (address, [name]) pair per distinct address.
        hs = do
            h <- groupBy (equivBy snd) $ sortBy (comparing snd) obs'
            (_,a) <- take 1 h
            return (a,map fst h)
        r' = foldl' updater r hs
    writeTVar (fcache dns) f'
    writeTVar (rcache dns) r'
  where
    updatef f (n,addrs) = Map.alter (updateCache withScrub utc addrs) n f
    updater r (a,ns) = Map.alter (updateCache withScrub utc ns) a r

-- | Turn an IPv4 address into its IPv4-mapped IPv6 form (::ffff:a.b.c.d),
-- keeping the port. IPv6 addresses pass through unchanged.
--
-- No equation covers other 'SockAddr' constructors, so they cause a
-- pattern-match failure.
make6mapped4 addr@(SockAddrInet6 {}) = addr
make6mapped4 addr@(SockAddrInet port a) = SockAddrInet6 port 0 (0,0,0xFFFF,fromBE32 a) 0

-- | Forward-resolve @addrtext@ with 'getAddrInfo' on a thread started by
-- 'forkOS', racing it against a 'timer' of @timeout@ microseconds.
--
-- Before the lookup, surrounding square brackets are stripped
-- ('strip_brackets'). The service "5269" is passed to 'getAddrInfo'.
--
-- Returns the lookup's addresses, with IPv4 results converted to
-- IPv4-mapped IPv6 and duplicates removed. Returns @[]@ if an 'IOError'
-- is raised or the timer fills the result first. In the timer case,
-- @fail addrtext@ is run.
--
-- Results are also recorded with 'dnsObserve', with scrubbing enabled.
rawForwardResolve dns fail timeout addrtext = do
    r <- atomically newEmptyTMVar
    mvar <- atomically newEmptyTMVar
    rt <- forkOS $ resolver r mvar
    tt <- forkIO $ timer (fail addrtext) timeout r rt
    -- Give the resolver the timer's ThreadId so it can cancel the delay.
    atomically $ putTMVar mvar tt
    -- Whichever thread fills r first decides the result.
    atomically $ readTMVar r
  where
    resolver r mvar = do
        -- Any IOError from the lookup yields []. The let only pins the
        -- handler's exception type to IOError; its value is never used.
        xs <- handle (\e -> let _ = isDoesNotExistError e in return []) $ do
            fmap (nub . map (make6mapped4 . addrAddress)) $
                getAddrInfo (Just $ defaultHints { addrFlags = [ AI_CANONNAME, AI_V4MAPPED ]})
                    (Just $ Text.unpack $ strip_brackets addrtext)
                    (Just "5269")
        did <- atomically $ tryPutTMVar r xs
        -- We answered first, so interrupt the timer's threadDelay.
        -- NOTE(review): possible race. If this runs before the timer
        -- thread has entered its 'handle', the ErrorCall is not caught
        -- there.
        when did $ do
            tt <- atomically $ readTMVar mvar
            throwTo tt (ErrorCall "Interrupted delay")
        -- NOTE(review): layout reconstructed. Confirm whether the next two
        -- statements belonged inside `when did` (i.e. cache only when we
        -- beat the timer).
        utc <- getCurrentTime
        atomically $ dnsObserve dns True utc $ map (addrtext,) xs
        return ()

-- | Remove a leading '[' and everything from the first ']' onwards, as in
-- "[::1]" -> "::1". If there is no ']', the rest of the text is kept.
-- Text that does not start with '[' is returned unchanged.
strip_brackets s = case Text.uncons s of
    Just ('[',t) -> Text.takeWhile (/=']') t
    _ -> s

-- | Print a timeout message for @addrtext@ to stdout. Its shape matches the
-- @fail@ callback of the raw resolvers.
-- NOTE(review): not referenced anywhere in the visible text of this module.
reportTimeout addrtext = do
    putStrLn $ "timeout resolving: "++show addrtext
    -- killThread rt

-- | Sleep @timeout@ microseconds, then try to fill the result slot @r@ with
-- @[]@. If that succeeds (nobody answered yet), run @fail@.
--
-- An 'ErrorCall' thrown at this thread (as 'rawForwardResolve's resolver
-- does) cancels the wait quietly. The @rt@ argument is unused.
timer fail timeout r rt = do
    handle (\(ErrorCall _)-> return ()) $ do
        threadDelay timeout
        did <- atomically $ tryPutTMVar r []
        when did fail

-- | Inverse of 'make6mapped4': turn an IPv4-mapped IPv6 address back into a
-- plain IPv4 address, keeping the port. Anything else is returned unchanged.
unmap6mapped4 addr@(SockAddrInet6 port _ (0,0,0xFFFF,a) _) = SockAddrInet port (toBE32 a)
unmap6mapped4 addr = addr

-- | Reverse-resolve @addr@ with 'getHostByAddr' on a thread started by
-- 'forkOS', racing it against a 'timer' of @timeout@ microseconds.
--
-- IPv4-mapped addresses are unmapped first. Returns the host name followed
-- by its aliases, with duplicates removed. Returns @[]@ if the timer fills
-- the result first, in which case @fail addr@ is run.
--
-- Results are recorded with 'dnsObserve', without scrubbing, from a
-- separately forked thread.
--
-- NOTE(review): unlike 'rawForwardResolve', the resolver ignores @mvar@, so
-- the timer is never cancelled. It also uses 'putTMVar' rather than
-- 'tryPutTMVar'. If the timer has already filled @r@, that final 'putTMVar'
-- blocks (retries) instead of giving up, and the thread never finishes.
-- The lookup runs under @handleIO_ (return ())@. Presumably an IO exception
-- ends the resolver without filling @r@, so the caller waits out the full
-- timeout; confirm handleIO_'s semantics.
rawReverseResolve dns fail timeout addr = do
    r <- atomically newEmptyTMVar
    mvar <- atomically newEmptyTMVar
    rt <- forkOS $ resolver r mvar
    tt <- forkIO $ timer (fail addr) timeout r rt
    atomically $ putTMVar mvar tt
    atomically $ readTMVar r
  where
    resolver r mvar = handleIO_ (return ()) $ do
        ent <- getHostByAddr (unmap6mapped4 addr) -- AF_UNSPEC addr
        let names = BSD.hostName ent : BSD.hostAliases ent
            xs = map Text.pack $ nub names
        forkIO $ do
            utc <- getCurrentTime
            atomically $ dnsObserve dns False utc $ map (,addr) xs
        atomically $ putTMVar r xs

-- | Judging from the visible part, this moves reverse-cache entries for
-- @addr@ (port zeroed) that are at least a minute old out of 'rcache'. It
-- then prunes the forward entries of the names removed. @es@ collects
-- those names.
--
-- NOTE(review): @flip diffUTCTime utc t@ is @t - utc@, which is negative
-- for any timestamp in the past. So @(>=minute)@ only selects entries
-- stamped more than a minute in the future; as written, nothing ordinarily
-- expires. @diffUTCTime utc t >= minute@ looks like the intended test;
-- confirm.
expiredReverse dns addr = do
    utc <- getCurrentTime
    addr <- return $ addr `withPort` 0
    es <- atomically $ do
        r <- readTVar $ rcache dns
        let ns = maybe [] id $ Map.lookup addr r
            minute = 60 -- seconds
            (es0,ns') = partition ( (>=minute) . flip diffUTCTime utc . fst ) ns
            es = map snd es0
        modifyTVar' (rcache dns) $ Map.insert addr ns'
        f <- readTVar $ fcache dns
        let f' = foldl' (flip $ Map.alter (expire utc)) f es
            expire utc Nothing = Nothing
            expire utc (Just as) = if null as' then Nothing else Just as'
              where
                -- NOTE(review): SOURCE IS CORRUPTED HERE and does not parse.
                -- The text between `filter ( (` below and the `forkIO`
                -- after it has been lost.
                --
                -- The module exports `reverseResolve` and calls
                -- `cachedForward` / `cachedReverse`, yet none of them is
                -- defined in the visible text. The statements after the gap
                -- use `n`, which is not bound here. Presumably the lost span
                -- held:
                --   * the rest of this predicate and of expiredReverse,
                --   * cachedForward and cachedReverse,
                --   * the head of reverseResolve, whose tail the statements
                --     below appear to be.
                -- Restore this region from version control rather than
                -- guessing.
                as' = filter ( (
    forkIO $ do
        rawForwardResolve dns (const $ return ()) 1000000 n
        return ()
    xs <- rawReverseResolve dns (const $ return ()) 1000000 addr
    cs <- cachedReverse dns addr
    return $ xs ++ filter (not . flip elem xs) cs

-- | Forward-resolve @n@ live, with a timeout of 1000000 microseconds (1 s)
-- and a no-op timeout callback. If the live lookup returns no addresses,
-- fall back to @cachedForward@ (not defined in the visible text).
forwardResolve dns n = do
    as <- rawForwardResolve dns (const $ return ()) 1000000 n
    if null as
        then cachedForward dns n
        else return as

-- | Parse a numeric address literal into a 'SockAddr' with port 0, taking
-- the first result. 'AI_NUMERICHOST' means no DNS lookup is performed.
--
-- NOTE(review): 'getAddrInfo' throws an IO exception on failure instead of
-- returning an empty list. Unparsable input therefore throws rather than
-- yielding 'Nothing'. Confirm that callers expect this.
parseAddress :: Text -> IO (Maybe SockAddr)
parseAddress addr_str = do
    info <- getAddrInfo (Just $ defaultHints { addrFlags = [ AI_NUMERICHOST ] }) (Just . Text.unpack $ addr_str) (Just "0")
    return . listToMaybe $ map addrAddress info

-- | Replace the port of an IPv4 or IPv6 socket address. No equation covers
-- other 'SockAddr' constructors, so they cause a pattern-match failure.
withPort :: SockAddr -> Int -> SockAddr
withPort (SockAddrInet _ a) port = SockAddrInet (toEnum port) a
withPort (SockAddrInet6 _ a b c) port = SockAddrInet6 (toEnum port) a b c