{-# LANGUAGE OverloadedStrings, ViewPatterns #-}
-- validatecert.hs
--
-- translation of cert_valid.pl into haskell
-- from squid/helpers/ssl/cert_valid.pl

import Data.Char
import Data.Monoid
import Data.List
import Data.Maybe
import qualified Data.Map as Map
import qualified Data.ByteString.Char8 as S
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.X509 as X509
import Control.Monad
import Control.Monad.Fix
import System.IO.Error
import System.IO
import Data.Map ( Map )
import Data.Time.LocalTime ( getZonedTime )
import Data.Time.Format ( formatTime )
import Data.X509 as X509 ( SignedCertificate, decodeSignedObject )
import System.Exit
import System.Posix.Process ( getProcessID )
import System.Locale ( defaultTimeLocale )
import System.Environment ( getProgName, getArgs )
import ScanningParser
import PEM

-- | Run @body@ on a 'Right' value; do nothing on a 'Left'.
continue :: Monad m => Either e a -> (a -> m ()) -> m ()
continue e body = either (const $ return ()) body e

-- | True when every character of the string is a decimal digit
-- (vacuously true for the empty string).
digits :: S.ByteString -> Bool
digits s = S.all isDigit s

-- | 'show' a value into a strict ByteString.
bshow :: Show x => x -> S.ByteString
bshow = S.pack . show

-- | Flatten a lazy ByteString into a strict one.
--
-- Uses 'S.concat' rather than a 'foldl1'' over the chunks: an empty lazy
-- ByteString has no chunks, and 'foldl1'' would crash on it.
toS :: L.ByteString -> S.ByteString
toS = S.concat . L.toChunks

-- | Parse the first line of a request:
-- @channelId code bodylen body-start ...@.
--
-- Returns 'Right' @(channelId, code, bodylen, firstBodyLine)@, where the
-- fourth word has a newline appended because 'S.words' stripped it, or
-- 'Left' with the message describing why the header was rejected.
-- Words after the fourth are ignored.
parseHeader :: S.ByteString -> Either S.ByteString (S.ByteString, S.ByteString, Int, S.ByteString)
parseHeader first_line = parseHeaderWords $ S.words first_line
  where
    parseHeaderWords (channelId:code:bodylen:body:_ignored)
        | not (digits channelId) =
            Left $ channelId <> " BH message=\"This helper is concurrent and requires\
                                \ the concurrency option to be specified.\"\1"
        | not (digits bodylen) =
            Left $ channelId <> " BH message=\"cert validator request syntax error.\" \1"
        | otherwise =
            Right ( channelId
                  , code
                    -- 'read' cannot fail here: bodylen is non-empty (it came
                    -- from 'S.words') and was just checked to be all digits.
                  , read $ S.unpack bodylen
                  , body <> "\n" )
    parseHeaderWords (channelId:_) =
        Left $ channelId <> " BH message=\"Insufficient words in message.\"\1"
    parseHeaderWords [] = Left ""

-- | One validation error, as named in the request and echoed in the response.
data ValidationError = ValidationError
    { veName   :: S.ByteString  -- ^ value of @error_name_N@
    , veCert   :: S.ByteString  -- ^ value of @error_cert_N@
    , veReason :: S.ByteString  -- ^ reason text placed in the response
    }

type Cert = SignedCertificate -- PEMBlob

-- | Decode the payload of a PEM blob as a signed certificate;
-- 'Nothing' when decoding fails.
pemToCert :: PEMBlob -> Maybe Cert
pemToCert pem = either (const Nothing) Just $ decodeSignedObject obj
  where
    -- 'toS' is total, so an empty blob yields a decode failure, not a crash.
    obj = toS $ pemBlob pem

-- | First available of the subject's common name, organization unit and
-- organization (tried in that order), or the empty string if none is set.
certSubject :: Cert -> S.ByteString
certSubject cert = maybe "" X509.getCharacterStringRawData $ msum [cn, ou, o]
  where
    dn = X509.certSubjectDN $ X509.getCertificate cert
    cn = X509.getDnElement X509.DnCommonName dn
    ou = X509.getDnElement X509.DnOrganizationUnit dn
    o  = X509.getDnElement X509.DnOrganization dn

-- | Everything extracted from one request body.
data ValidationRequest = ValidationRequest
    { vrHostname     :: S.ByteString                      -- ^ value of @host=@
    , vrErrors       :: Map S.ByteString ValidationError  -- ^ keyed by the error's numeric suffix
    , vrCerts        :: Map S.ByteString Cert             -- ^ keyed by the @cert_*@ variable name
    , vrSyntaxErrors :: [L.ByteString]                    -- ^ unparsed remainders
    , vrPeerCertId   :: Maybe S.ByteString                -- ^ name of the first certificate accepted
    }

-- | Entry point: handle @-h@/@--help@ and @-d@/@--debug@, then answer
-- requests read from stdin until reading a line fails.
main :: IO ()
main = do
    debug <- do
        args <- getArgs
        when (not $ null $ ["-h","--help"] `intersect` args) $ do
            me <- getProgName
            hPutStr stderr $ usage me [(["-h","--help"], "brief help message")
                                      ,(["-d","--debug"], "enable debug messages to stderr")]
            exitSuccess
        return $ not $ null $ ["-d","--debug"] `intersect` args

    fix $ \next -> do
        e <- tryIOError S.getLine
        -- An IOError from getLine (such as end of input) ends the loop.
        continue e $ \first_line ->
            if S.all isSpace first_line
                -- Skip blank lines.  An explicit if/else keeps the skip from
                -- falling through into header parsing once the recursive
                -- call returns.
                then next
                else do
                    -- A rejected header is only logged, and only in debug
                    -- mode; nothing is written to stdout for it.
                    let wlog' s | debug && not (S.null s) = wlog s
                                | otherwise               = return ()
                    either wlog' (handleRequest debug) (parseHeader first_line)
                    next

-- | Read the remainder of one request body from stdin, parse it, and write
-- the response to stdout.  With @debug@ set, the request and the response
-- are also logged to stderr.
handleRequest :: Bool -> (S.ByteString, S.ByteString, Int, S.ByteString) -> IO ()
handleRequest debug (channelId, code, bodylen, body0) = do
    when debug $ wlog $ "GOT " <> "Code=" <> code <> " " <> bshow bodylen <> "\n"
    -- body0 was already consumed as part of the header line; fetch the rest.
    -- A failed read is treated as an empty remainder.
    body1 <- L.hGet stdin (bodylen - S.length body0)
                `catchIOError` (const $ return "")
    let body = L.fromChunks $ body0 : L.toChunks body1
        req  = parseRequest body
    when debug $ forM_ (vrSyntaxErrors req) $ \request -> do
        wlog $ "ParseError on \"" <> toS request <> "\"\n"
    when debug $ do
        wlog $ "Parse result:\n"
        wlog $ "\tHOST: " <> vrHostname req <> "\n"
        maybe (return ()) (\certid -> wlog $ "\tCERT: " <> certid <> "\n") $ vrPeerCertId req
        let estr = S.intercalate "," $ map showe $ Map.elems $ vrErrors req
            showe ve = veName ve <> "/" <> veCert ve
        wlog $ "\tERRORS: " <> estr <> "\n"
        forM_ (Map.toList $ vrCerts req) $ \(key,cert) -> do
            wlog $ "\tCHAIN " <> key <> ": " <> certSubject cert <> "\n"
    -- Every error named in the request is echoed back with a fixed reason.
    let responseErrors = fmap (\ve -> ve { veReason = "Checked by validatecert.hs" }) $ vrErrors req
        response0 = createResponse req responseErrors
        len = bshow $ S.length response0
        response = if Map.null responseErrors
                    then channelId <> " OK " <> len <> " " <> response0 <> "\1"
                    else channelId <> " ERR " <> len <> " " <> response0 <> "\1"
    S.putStr response
    hFlush stdout
    -- 'S.init' drops the trailing "\1"; response is never empty here.
    when debug $ forM_ (S.lines $ S.init response) $ \msg -> do
        wlog $ ">> " <> msg <> "\n"

-- | Render the response body: an @error_name_i@ / @error_reason_i@ /
-- @error_cert_i@ triple per error, numbered from 0 in map-key order.
-- The request argument is currently unused.
createResponse :: ValidationRequest -> Map S.ByteString ValidationError -> S.ByteString
createResponse vr responseErrors = S.concat $ zipWith mkresp [0 :: Int ..] $ Map.elems responseErrors
  where
    mkresp i err = "error_name_" <> bshow i <> "=" <> veName err <> "\n"
                <> "error_reason_" <> bshow i <> "=" <> veReason err <> "\n"
                <> "error_cert_" <> bshow i <> "=" <> veCert err <> "\n"
    -- vrCertFromErr err = vrCerts vr Map.! veCert err

-- | Parse a request body made of @host=@, @cert_*=@ (PEM), @error_name_N=@
-- and @error_cert_N=@ entries.  Parsing stops at the first unrecognised
-- entry, which is recorded (with everything after it) in 'vrSyntaxErrors'.
parseRequest :: L.ByteString -> ValidationRequest
parseRequest body = parseRequest0 vr0 body
  where
    vr0 = ValidationRequest { vrHostname     = ""
                            , vrErrors       = Map.empty
                            , vrCerts        = Map.empty
                            , vrSyntaxErrors = []
                            , vrPeerCertId   = Nothing
                            }

    ve0 = ValidationError { veName   = ""
                          , veCert   = ""
                          , veReason = ""
                          }

    parseRequest0 :: ValidationRequest -> L.ByteString -> ValidationRequest
    -- Nothing but whitespace left: done.
    parseRequest0 vr request | L.all isSpace request = vr
    parseRequest0 vr (splitEq -> Just ("host", (L.break (=='\n') -> (hostname,rs))))
        = parseRequest0 vr' rs
      where
        vr' = vr { vrHostname = toS hostname }
    parseRequest0 vr (splitEq -> Just (var,cert))
        | "cert_" `L.isPrefixOf` var = parseRequest0 vr' $ L.unlines rs
      where
        -- A blob that fails to parse or decode leaves the request unchanged.
        vr' = maybe vr upd $ mb >>= pemToCert
        upd c = vr { vrCerts      = Map.insert (toS var) c $ vrCerts vr
                     -- Only the first certificate seen sets the peer cert id.
                   , vrPeerCertId = Just $ fromMaybe (toS var) $ vrPeerCertId vr
                   }
        p = pemParser (Just "CERTIFICATE")
        (mb,rs) = scanAndParse1 p $ L.lines cert
    parseRequest0 vr (digitsId . splitEq -> Just (("error_name",d), (L.break (=='\n') -> (errorName,rs))))
        = parseRequest0 vr' rs
      where
        vr' = vr { vrErrors = Map.alter (setErrorName errorName) (toS d) $ vrErrors vr }
    parseRequest0 vr (digitsId . splitEq -> Just (("error_cert",d), (L.break (=='\n') -> (certId,rs))))
        = parseRequest0 vr' rs
      where
        vr' = vr { vrErrors = Map.alter (setErrorCert certId) (toS d) $ vrErrors vr }
    -- Anything else: record the remainder as a syntax error and stop.
    parseRequest0 vr req = vr'
      where
        vr' = vr { vrSyntaxErrors = syntaxError $ vrSyntaxErrors vr }
        syntaxError es = es ++ [ req ]

    -- Set the name of an existing error entry, or create one from ve0.
    setErrorName :: L.ByteString -> Maybe ValidationError -> Maybe ValidationError
    setErrorName x mb = maybe (Just $ ve0 { veName = toS x })
                              (\ve -> Just $ ve { veName = toS x })
                              mb

    -- Set the cert id of an existing error entry, or create one from ve0.
    setErrorCert :: L.ByteString -> Maybe ValidationError -> Maybe ValidationError
    setErrorCert x mb = maybe (Just $ ve0 { veCert = toS x })
                              (\ve -> Just $ ve { veCert = toS x })
                              mb

    -- Split a variable name "prefix_123" into ("prefix","123"); Nothing when
    -- the trailing digit run is not preceded by an underscore.
    digitsId :: Maybe (L.ByteString, v) -> Maybe ((L.ByteString, L.ByteString), v)
    digitsId mb = do
        (n,v) <- mb
        let (n',tl) = L.span isDigit $ L.reverse n
        if "_" `L.isPrefixOf` tl
            then Just ( (L.reverse $ L.drop 1 tl, L.reverse n'), v )
            else Nothing

    -- Skip leading whitespace and split at the first '='; Nothing when the
    -- input contains no '='.
    splitEq :: L.ByteString -> Maybe (L.ByteString, L.ByteString)
    splitEq request = if L.null tl then Nothing
                                   else Just (hd, L.drop 1 tl)
      where
        (hd,tl) = L.break (=='=') $ L.dropWhile isSpace request

-- | Write a message to stderr, prefixed with a timestamp, the program name
-- and the process id.  The message is written as given; callers supply the
-- trailing newline.
wlog :: S.ByteString -> IO ()
wlog msg = do
    now  <- getZonedTime
    pid  <- getProcessID
    self <- getProgName
    hPutStr stderr $ formatTime defaultTimeLocale "%Y/%m/%d %H:%M:%S.0" now
                     <> " " <> self <> " " <> show pid <> " | " <> S.unpack msg
    hFlush stderr

-- | Build the help text: a usage line listing every option group in
-- brackets, then one paragraph per option with its description.
usage :: String -> [([String],String)] -> String
usage cmdname argspec = unlines $ intercalate [""] $
    [ "Usage:"
    , tab <> cmdname <> " " <> brief argspec
    ] : map helptext argspec
  where
    tab = " "
    tabbb = tab <> tab <> tab
    alts as = intercalate " | " as
    bracket s = "[" <> s <> "]"
    brief spec = intercalate " " $ map (bracket . alts . fst) spec
    helptext (as,help) = [ tab <> alts as
                         , tabbb <> help ]