summaryrefslogtreecommitdiff
path: root/lib/Compat.hs
blob: 3b778519194e17f4abf4989f327a67b5f780505f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
{-# LANGUAGE CPP #-}
module Compat where

import Data.Bits
import Data.Word
import Data.ASN1.Types
import Data.ASN1.Encoding
import Data.ASN1.BinaryEncoding
import Crypto.PubKey.RSA as RSA

#if defined(VERSION_cryptonite)

-- | ASN.1 conversion for RSA public keys.
--
-- Encoding always produces the bare two-integer sequence (modulus, public
-- exponent). Decoding accepts that form, and also a wrapped form in which
-- the key is carried BER-encoded inside an OCTET STRING.
instance ASN1Object PublicKey where
    -- SEQUENCE { modulus, publicExponent } prepended to the remaining stream.
    toASN1 pubKey rest =
        [ Start Sequence
        , IntVal (public_n pubKey)
        , IntVal (public_e pubKey)
        , End Sequence
        ] ++ rest

    -- Bare form: SEQUENCE { modulus, publicExponent }.
    fromASN1 (Start Sequence : IntVal rawModulus : IntVal pubexp : End Sequence : rest) =
        Right (key, rest)
      where
        key = PublicKey
            { public_size = byteLength 1
            , public_n    = modulus
            , public_e    = pubexp
            }
        -- Some encoders do not serialize the ASN.1 integer properly, which
        -- makes the modulus decode as a negative number; undo that here.
        modulus = toPositive rawModulus
        -- Key size in bytes: the smallest i such that 2^(8*i) exceeds the modulus.
        byteLength i
            | 2 ^ (i * 8) > modulus = i
            | otherwise             = byteLength (i + 1)

    -- Wrapped form: version 0, an algorithm identifier carrying the
    -- rsaEncryption OID (1.2.840.113549.1.1.1) with NULL parameters, then an
    -- OCTET STRING whose content is itself a BER-encoded key.
    fromASN1 ( Start Sequence
             : IntVal 0
             : Start Sequence
             : OID [1, 2, 840, 113549, 1, 1, 1]
             : Null
             : End Sequence
             : OctetString bs
             : xs
             ) =
        case decodeASN1' BER bs of
            Left decodeErr -> Left ("fromASN1: RSA.PublicKey: " ++ show decodeErr)
            Right nested   ->
                case fromASN1 nested of
                    Left parseErr -> Left parseErr
                    -- Whatever follows the key inside the octet string is
                    -- dropped; the outer remainder is what gets returned.
                    Right (k, _)  -> Right (k, xs)

    -- Anything else is not a recognised RSA public key layout.
    fromASN1 _ =
        Left "fromASN1: RSA.PublicKey: unexpected format"

#endif

-- | Reinterpret a negative 'Integer' as the unsigned value of its minimal
-- big-endian two's-complement byte encoding. Non-negative inputs are
-- returned unchanged.
--
-- For a negative @n@ whose shortest two's-complement encoding occupies @k@
-- bytes the result is @n + 2^(8*k)@, for example:
--
-- > toPositive (-1)   == 0xff
-- > toPositive (-128) == 0x80
-- > toPositive (-129) == 0xff7f
--
-- This is computed arithmetically, so the function is total: it needs no
-- intermediate byte list, no partial 'head', and no lazy left fold.
toPositive :: Integer -> Integer
toPositive int
    | int < 0   = int + bit (8 * byteCount 1)
    | otherwise = int
  where
    -- Smallest byte count k (starting from 1) whose signed range reaches
    -- down to @int@, i.e. the first k with -(2^(8*k-1)) <= int.
    byteCount :: Int -> Int
    byteCount k
        | negate int <= bit (8 * k - 1) = k
        | otherwise                     = byteCount (k + 1)