diff --git a/Crypto/PubKey/RSA/PSS.hs b/Crypto/PubKey/RSA/PSS.hs index b221262..2f6ddf3 100644 --- a/Crypto/PubKey/RSA/PSS.hs +++ b/Crypto/PubKey/RSA/PSS.hs @@ -26,6 +26,7 @@ import Crypto.PubKey.RSA.Prim import Crypto.PubKey.RSA (generateBlinder) import Crypto.PubKey.MaskGenFunction import Crypto.Hash +import Crypto.Number.Basic (numBits) import Data.Bits (xor, shiftR, (.&.)) import Data.Word @@ -69,14 +70,15 @@ signDigestWithSalt :: HashAlgorithm hash -> Digest hash -- ^ Message digest -> Either Error ByteString signDigestWithSalt salt blinder params pk digest - | k < hashLen + saltLen + 2 = Left InvalidParameters - | otherwise = Right $ dp blinder pk em + | emLen < hashLen + saltLen + 2 = Left InvalidParameters + | otherwise = Right $ dp blinder pk em where k = private_size pk + emLen = if emTruncate pubBits then k - 1 else k mHash = B.convert digest - dbLen = k - hashLen - 1 + dbLen = emLen - hashLen - 1 saltLen = B.length salt hashLen = hashDigestSize (pssHash params) - pubBits = private_size pk * 8 -- to change if public_size is converted in bytes + pubBits = numBits (private_n pk) m' = B.concat [B.replicate 8 0,mHash,salt] h = B.convert $ hashWith (pssHash params) m' db = B.concat [B.replicate (dbLen - saltLen - 1) 0,B.singleton 1,salt] @@ -161,7 +163,8 @@ verifyDigest :: HashAlgorithm hash -> ByteString -- ^ Signature -> Bool verifyDigest params pk digest s - | public_size pk /= B.length s = False + | B.length s /= k = False + | B.any (/= 0) pre = False | B.last em /= pssTrailerField params = False | not (B.all (== 0) ps0) = False | b1 /= B.singleton 1 = False @@ -169,11 +172,13 @@ verifyDigest params pk digest s where -- parameters hashLen = hashDigestSize (pssHash params) mHash = B.convert digest - dbLen = public_size pk - hashLen - 1 - pubBits = public_size pk * 8 -- to change if public_size is converted in bytes + k = public_size pk + emLen = if emTruncate pubBits then k - 1 else k + dbLen = emLen - hashLen - 1 + pubBits = numBits (public_n pk) -- unmarshall fields - em = ep pk s - maskedDB = B.take (B.length em - hashLen - 1) em + (pre, em) = B.splitAt (k - emLen) (ep pk s) -- drop 0..1 byte + maskedDB = B.take dbLen em h = B.take hashLen $ B.drop (B.length maskedDB) em dbmask = pssMaskGenAlg params h dbLen db = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask @@ -182,6 +187,10 @@ verifyDigest params pk digest s m' = B.concat [B.replicate 8 0,mHash,salt] h' = hashWith (pssHash params) m' +-- When the modulus has bit length 1 modulo 8 we drop the first byte. +emTruncate :: Int -> Bool +emTruncate bits = ((bits-1) .&. 0x7) == 0 + normalizeToKeySize :: Int -> [Word8] -> [Word8] normalizeToKeySize _ [] = [] -- very unlikely normalizeToKeySize bits (x:xs) = x .&. mask : xs