summaryrefslogtreecommitdiff
path: root/Data/OpenPGP/Util/DecryptSecretKey.hs
blob: 9f9e42a03f93414b1d525093a65de870753b8c0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
module Data.OpenPGP.Util.DecryptSecretKey where

import qualified Data.OpenPGP as OpenPGP
import Data.OpenPGP.Internal (decode_s2k_count,checksumForKey)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LZ
import Data.Word
import Data.Int
import Data.Maybe
import Control.Monad (foldM)
import Data.Binary (get,Binary,Get,encode,put)
#if MIN_VERSION_binary(0,6,4)
import Data.Binary.Get (runGetOrFail)
#else
import Control.Exception as Exception (handle,ErrorCall(..))
import System.IO.Unsafe
import Data.Binary.Get (runGet)
#endif
import Control.Exception as Exception (IOException(..),catch)
import Data.Binary.Put (runPut)
import qualified Data.Serialize as Cereal
import Control.Applicative ( (<$>) )

import qualified Crypto.Cipher.AES as Vincent
import qualified Crypto.Cipher.Blowfish as Vincent

import qualified Crypto.Cipher.Types as Vincent
import Data.OpenPGP.Util.Decrypt

#if defined(VERSION_cryptonite)
import qualified Data.ByteArray as Bytes
import Crypto.Hash.Algorithms
import Crypto.Hash
import Crypto.Error
#else
import qualified Data.Byteable as Vincent
import Crypto.Hash.SHA1 as SHA1
#endif

import qualified Crypto.Random as Vincent

import Crypto.Cipher.Cast5 (CAST5_128)
import Crypto.Cipher.ThomasToVincent
import Data.OpenPGP.Util.Base
import Data.OpenPGP.Util.Gen (makeGen)

-- | A ciphertext payload, tagged with where its initialisation vector
-- comes from.  Consumed by 'withIV'.
data Enciphered
    = EncipheredWithIV !LZ.ByteString
      -- ^ The IV occupies the front of the ByteString and the ciphertext follows it.
    | EncipheredZeroIV !LZ.ByteString
      -- ^ The IV is all zeros; the ByteString holds only the ciphertext.

-- | Run an IV-taking codec over an 'Enciphered' payload, supplying the IV.
--
-- * 'EncipheredWithIV': the first IV-length bytes of the payload are used as
--   the IV, and the codec is applied to the remaining bytes.
-- * 'EncipheredZeroIV': the codec is applied to the whole payload with
--   'Vincent.nullIV'.
--
-- An 'EncipheredWithIV' payload may be too short to hold a full IV, in which
-- case 'Vincent.makeIV' yields 'Nothing'.  The result is then 'LZ.empty'
-- rather than a runtime pattern-match failure, so a truncated packet cannot
-- crash the caller.
withIV :: forall k. (Vincent.BlockCipher k) => (Vincent.IV k -> LZ.ByteString -> LZ.ByteString) -> Enciphered -> LZ.ByteString
withIV f (EncipheredWithIV s) = maybe LZ.empty (\iv -> f iv bs) miv
    where
    -- Nothing when ivbs is not a valid IV for k (e.g. the payload was short).
    miv = Vincent.makeIV (LZ.toStrict ivbs)
    (ivbs,bs) = LZ.splitAt (fromIntegral ivlen) s
#if defined(VERSION_cryptonite)
    ivlen = Bytes.length (Vincent.nullIV :: Vincent.IV k)
#else
    ivlen = Vincent.byteableLength z
    -- Never evaluated; only forces 'z' to have type @IV k@.
    _ = fmap (Vincent.constEqBytes z) miv
    z = Vincent.nullIV
#endif
withIV f (EncipheredZeroIV s) = f Vincent.nullIV s


-- | Decrypt the secret-key material of a 'OpenPGP.SecretKeyPacket'.
--
-- * A packet whose 'OpenPGP.symmetric_algorithm' is already
--   'OpenPGP.Unencrypted' is returned unchanged.
-- * A version-4 packet is decrypted with a key derived from the passphrase.
--   The trailing checksum/hash (size and function chosen by 'checksumForKey'
--   from the s2k usage octet) is compared against the decrypted material.
--   On a match, the secret fields are parsed and placed in front of the
--   packet's existing key fields.  The packet is then marked unencrypted:
--   usage octet 0, 'OpenPGP.Unencrypted', empty 'OpenPGP.encrypted_data'.
-- * A checksum mismatch, a parse failure, or any other packet gives 'Nothing'.
--   An unsupported symmetric algorithm makes 'withS2K' call 'error'.
decryptSecretKey ::
    BS.ByteString           -- ^ Passphrase
    -> OpenPGP.Packet       -- ^ Encrypted SecretKeyPacket
    -> Maybe OpenPGP.Packet -- ^ Decrypted SecretKeyPacket
decryptSecretKey _ k@(OpenPGP.SecretKeyPacket { OpenPGP.symmetric_algorithm = OpenPGP.Unencrypted })
    = Just k
decryptSecretKey pass k@(OpenPGP.SecretKeyPacket {
        OpenPGP.version = 4, OpenPGP.key_algorithm = kalgo,
        OpenPGP.s2k = s2k, OpenPGP.symmetric_algorithm = salgo,
        OpenPGP.key = existing, OpenPGP.encrypted_data = encd
    }) = if checkOk then asPlaintext <$> secretFields else Nothing
    where
    -- The ciphertext carries its IV in front (EncipheredWithIV).
    plaintext = withS2K simpleUnCFB salgo (Just s2k) (toLazyBS pass) (EncipheredWithIV encd)
    (trailerLen, digest) = checksumForKey (OpenPGP.s2k_useage k)
    -- The checksum/hash occupies the last trailerLen bytes of the plaintext.
    (body, trailer) = LZ.splitAt (LZ.length plaintext - trailerLen) plaintext
    checkOk = digest body == LZ.toStrict trailer
    -- Read one MPI per secret field of this key algorithm, consing each
    -- (field, value) pair onto the accumulated key list.
    secretFields = maybeGet (foldM addField existing (OpenPGP.secret_key_fields kalgo)) body
    addField acc field = do
        mpi <- get
        return ((field, mpi) : acc)
    asPlaintext fields = k {
        OpenPGP.s2k_useage = 0,
        OpenPGP.symmetric_algorithm = OpenPGP.Unencrypted,
        OpenPGP.encrypted_data = LZ.empty,
        OpenPGP.key = fields
    }

decryptSecretKey _ _ = Nothing

-- | Two-octet checksum of secret-key material: the sum of all octets,
-- modulo 65536.  'Word16' addition wraps modulo 2^16, so summing directly
-- in 'Word16' yields the already-reduced value.
checksum :: BS.ByteString -> Word16
checksum = BS.foldl' addOctet 0
  where
    addOctet acc octet = acc + fromIntegral octet


#if MIN_VERSION_binary(0,6,4)
-- | Run a 'Get' decoder over a lazy ByteString.  Returns 'Nothing' when the
-- decoder fails, instead of throwing.  Input left unconsumed after a
-- successful decode is ignored.
maybeGet :: (Binary a) => Get a -> LZ.ByteString -> Maybe a
maybeGet g bs = either (const Nothing) (\(_, _, x) -> Just x) (runGetOrFail g bs)
#else
-- | Fallback for @binary@ < 0.6.4, which lacks 'runGetOrFail'.  There,
-- 'runGet' reports failure by throwing an 'ErrorCall', and this function
-- catches it.
--
-- The decoded value is forced to WHNF with '$!' *inside* 'handle'.  Returning
-- @Just (runGet g bs)@ unevaluated would let the failure escape the handler
-- and be thrown later by whoever inspects the result.  Failures hidden deeper
-- than WHNF can still escape.
maybeGet :: (Binary a) => Get a -> LZ.ByteString -> Maybe a
maybeGet g bs = unsafePerformIO $
    handle (\(ErrorCall _) -> return Nothing) $ (return . Just) $! runGet g bs
#endif



-- | Derive a cipher key from a passphrase with 'string2key', then run
-- @codec@ with that key over an 'Enciphered' payload via 'withIV'.
--
-- The cipher type is selected from the 'OpenPGP.SymmetricAlgorithm':
-- AES-128/192/256, Blowfish (128-bit key) and CAST5.  Any other algorithm
-- is rejected with 'error'.
withS2K :: (forall k. (Vincent.BlockCipher k) => k -> Vincent.IV k -> LZ.ByteString -> LZ.ByteString)
           -> OpenPGP.SymmetricAlgorithm
           -> Maybe OpenPGP.S2K
           -> LZ.ByteString -> Enciphered -> LZ.ByteString
withS2K codec OpenPGP.AES128 s2k s   = withIV $ codec (string2key s2k s :: Vincent.AES128)
withS2K codec OpenPGP.AES192 s2k s   = withIV $ codec (string2key s2k s :: Vincent.AES192)
withS2K codec OpenPGP.AES256 s2k s   = withIV $ codec (string2key s2k s :: Vincent.AES256)
withS2K codec OpenPGP.Blowfish s2k s = withIV $ codec (string2key s2k s :: Vincent.Blowfish128)
withS2K codec OpenPGP.CAST5 s2k s    = withIV $ codec (string2key s2k s :: ThomasToVincent CAST5_128)
-- The message names this module so the failure can be traced to its source.
withS2K _ algo _ _ = error $ "Unsupported symmetric algorithm : " ++ show algo ++ " in Data.OpenPGP.Util.DecryptSecretKey.withS2K"

-- | CFB-encrypt a lazy ByteString under cipher key @k@ with a fresh random
-- IV drawn from generator @g@.  The IV is emitted in front of the
-- ciphertext, the layout 'EncipheredWithIV' describes.  Returns the output
-- together with the advanced generator.  The raw CFB step is wrapped by
-- 'padThenUnpad', which is defined elsewhere; presumably it handles
-- block-size padding — see its definition.
simpleCFB :: forall k g. (Vincent.BlockCipher k, RG g) => g -> k -> LZ.ByteString -> (LZ.ByteString, g)
simpleCFB g k bs = (padThenUnpad k encryptWithIV bs, g')
  where
    encryptWithIV plain =
        LZ.fromChunks [ivbs, Vincent.cfbEncrypt k iv (LZ.toStrict plain)]
    -- ivbs is generated with exactly the IV length, so makeIV succeeds.
    Just iv = Vincent.makeIV ivbs
#if defined(VERSION_cryptonite)
    (ivbs,g') = Vincent.randomBytesGenerate ivlen g
    ivlen = Bytes.length (Vincent.nullIV :: Vincent.IV k)
#else
    z = Vincent.nullIV
    (ivbs,g') = Vincent.cprgGenerate ivlen g
    ivlen = Vincent.byteableLength z
    -- Never evaluated; only forces 'z' to have type @IV k@.
    _ = Vincent.constEqBytes z iv
#endif

-- | Run @action@; if it throws an 'IOException', run @fallback@ instead.
-- Exceptions of any other type propagate unchanged.
catchIO_ :: IO a -> IO a -> IO a
catchIO_ action fallback = Exception.catch action onIOException
  where
    onIOException (_ :: IOException) = fallback

-- | Encrypt the secret-key material of a secret-key packet with a passphrase.
--
-- The packet's secret fields are serialised in the order given by
-- 'OpenPGP.secret_key_fields' for its key algorithm.  A two-octet checksum
-- (sum of octets mod 65536) is appended.  The result is CFB-encrypted via
-- 'simpleCFB' under a key derived from the passphrase with the given S2K
-- specifier and symmetric algorithm.  The returned packet keeps only the
-- public fields in 'OpenPGP.key', stores the ciphertext in
-- 'OpenPGP.encrypted_data', and sets the s2k usage octet to 255.
--
-- Returns 'Nothing' if an 'IOException' is thrown while the action runs,
-- i.e. while creating the random generator with 'makeGen'.  The ciphertext
-- is a lazy value computed outside the handler.
encryptSecretKey :: BS.ByteString ->  OpenPGP.S2K -> OpenPGP.SymmetricAlgorithm -> OpenPGP.Packet -> IO (Maybe OpenPGP.Packet)
encryptSecretKey passphrase s2k salgo plain =
    -- The IO action is a separate binding rather than a nested 'do' at the
    -- same column, so this parses without NondecreasingIndentation (which
    -- is not part of Haskell2010 or GHC2021).
    catchIO_ encryptWithFreshGen (return Nothing)
  where
    encryptWithFreshGen = do
        g <- makeGen Nothing
        return $ Just plain
            { OpenPGP.key = [ x | x <- OpenPGP.key plain
                                , fst x `elem` OpenPGP.public_key_fields (OpenPGP.key_algorithm plain) ]
            , OpenPGP.symmetric_algorithm = salgo
            , OpenPGP.s2k = s2k
            , OpenPGP.s2k_useage = s2k_usage_octet
            , OpenPGP.encrypted_data = encd g
            }

    -- Secret fields present in the key, in algorithm-defined order.
    material = runPut $ mapM_ put $ do
        f <- OpenPGP.secret_key_fields (OpenPGP.key_algorithm plain)
        maybeToList $ lookup f (OpenPGP.key plain)
    chk = LZ.fromChunks [ chkF material ]
    decd = LZ.append material chk
    encd g = fst $ withS2K' salgo (Just s2k) (toLazyBS passphrase) (simpleCFB g) decd

    -- RFC 4880, section 5.5.3: if the string-to-key usage octet is zero or
    -- 255, a two-octet checksum of the plaintext of the algorithm-specific
    -- portion (sum of all octets, mod 65536) follows it.  If the usage octet
    -- is 254, a 20-octet SHA-1 hash of that plaintext follows instead.  The
    -- checksum or hash is encrypted together with the algorithm-specific
    -- fields when the usage octet is not zero.  For all other usage values a
    -- two-octet checksum is required.  Usage 255 is used here, hence the
    -- two-octet checksum below.
    s2k_usage_octet = 255
    chkF = LZ.toStrict . encode . checksum . LZ.toStrict

-- | Build an iterated-and-salted S2K specifier for the given hash algorithm
-- from 9 random bytes.  The first byte is the coded iteration count,
-- expanded by 'decode_s2k_count'.  The remaining 8 bytes are decoded with
-- "Data.Serialize" as the salt.
randomS2K :: OpenPGP.HashAlgorithm -> IO OpenPGP.S2K
randomS2K hash = do
    gen <- makeGen Nothing
#if defined(VERSION_cryptonite)
    let randomOctets = fst (Vincent.randomBytesGenerate 9 gen)
#else
    let randomOctets = fst (Vincent.cprgGenerate 9 gen)
#endif
    let countOctet = BS.head randomOctets
        Right salt = Cereal.decode (BS.drop 1 randomOctets)
    return (OpenPGP.IteratedSaltedS2K hash salt (decode_s2k_count countOctet))