RSASSA-PSS with key of arbitrary length

Instead of public_size / private_size which are in bytes only, this
uses function numBits to recover the effective length of the modulus
in bits.  The patch also handles removal of unneeded initial byte when
the length is 1 modulo 8.
This commit is contained in:
Olivier Chéron 2019-01-12 09:22:15 +01:00
parent 274911c608
commit f9ae52327c

View File

@ -26,6 +26,7 @@ import Crypto.PubKey.RSA.Prim
import Crypto.PubKey.RSA (generateBlinder) import Crypto.PubKey.RSA (generateBlinder)
import Crypto.PubKey.MaskGenFunction import Crypto.PubKey.MaskGenFunction
import Crypto.Hash import Crypto.Hash
import Crypto.Number.Basic (numBits)
import Data.Bits (xor, shiftR, (.&.)) import Data.Bits (xor, shiftR, (.&.))
import Data.Word import Data.Word
@ -69,14 +70,15 @@ signDigestWithSalt :: HashAlgorithm hash
-> Digest hash -- ^ Message digest -> Digest hash -- ^ Message digest
-> Either Error ByteString -> Either Error ByteString
signDigestWithSalt salt blinder params pk digest signDigestWithSalt salt blinder params pk digest
| k < hashLen + saltLen + 2 = Left InvalidParameters | emLen < hashLen + saltLen + 2 = Left InvalidParameters
| otherwise = Right $ dp blinder pk em | otherwise = Right $ dp blinder pk em
where k = private_size pk where k = private_size pk
emLen = if emTruncate pubBits then k - 1 else k
mHash = B.convert digest mHash = B.convert digest
dbLen = k - hashLen - 1 dbLen = emLen - hashLen - 1
saltLen = B.length salt saltLen = B.length salt
hashLen = hashDigestSize (pssHash params) hashLen = hashDigestSize (pssHash params)
pubBits = private_size pk * 8 -- to change if public_size is converted in bytes pubBits = numBits (private_n pk)
m' = B.concat [B.replicate 8 0,mHash,salt] m' = B.concat [B.replicate 8 0,mHash,salt]
h = B.convert $ hashWith (pssHash params) m' h = B.convert $ hashWith (pssHash params) m'
db = B.concat [B.replicate (dbLen - saltLen - 1) 0,B.singleton 1,salt] db = B.concat [B.replicate (dbLen - saltLen - 1) 0,B.singleton 1,salt]
@ -161,7 +163,8 @@ verifyDigest :: HashAlgorithm hash
-> ByteString -- ^ Signature -> ByteString -- ^ Signature
-> Bool -> Bool
verifyDigest params pk digest s verifyDigest params pk digest s
| public_size pk /= B.length s = False | B.length s /= k = False
| B.any (/= 0) pre = False
| B.last em /= pssTrailerField params = False | B.last em /= pssTrailerField params = False
| not (B.all (== 0) ps0) = False | not (B.all (== 0) ps0) = False
| b1 /= B.singleton 1 = False | b1 /= B.singleton 1 = False
@ -169,11 +172,13 @@ verifyDigest params pk digest s
where -- parameters where -- parameters
hashLen = hashDigestSize (pssHash params) hashLen = hashDigestSize (pssHash params)
mHash = B.convert digest mHash = B.convert digest
dbLen = public_size pk - hashLen - 1 k = public_size pk
pubBits = public_size pk * 8 -- to change if public_size is converted in bytes emLen = if emTruncate pubBits then k - 1 else k
dbLen = emLen - hashLen - 1
pubBits = numBits (public_n pk)
-- unmarshall fields -- unmarshall fields
em = ep pk s (pre, em) = B.splitAt (k - emLen) (ep pk s) -- drop 0..1 byte
maskedDB = B.take (B.length em - hashLen - 1) em maskedDB = B.take dbLen em
h = B.take hashLen $ B.drop (B.length maskedDB) em h = B.take hashLen $ B.drop (B.length maskedDB) em
dbmask = pssMaskGenAlg params h dbLen dbmask = pssMaskGenAlg params h dbLen
db = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask db = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask
@ -182,6 +187,10 @@ verifyDigest params pk digest s
m' = B.concat [B.replicate 8 0,mHash,salt] m' = B.concat [B.replicate 8 0,mHash,salt]
h' = hashWith (pssHash params) m' h' = hashWith (pssHash params) m'
-- When the modulus has bit length 1 modulo 8 we drop the first byte.
emTruncate :: Int -> Bool
emTruncate bits = ((bits-1) .&. 0x7) == 0
normalizeToKeySize :: Int -> [Word8] -> [Word8] normalizeToKeySize :: Int -> [Word8] -> [Word8]
normalizeToKeySize _ [] = [] -- very unlikely normalizeToKeySize _ [] = [] -- very unlikely
normalizeToKeySize bits (x:xs) = x .&. mask : xs normalizeToKeySize bits (x:xs) = x .&. mask : xs