diff --git a/Crypto/PubKey/RSA/PSS.hs b/Crypto/PubKey/RSA/PSS.hs index c449a97..8abd228 100644 --- a/Crypto/PubKey/RSA/PSS.hs +++ b/Crypto/PubKey/RSA/PSS.hs @@ -11,9 +11,13 @@ module Crypto.PubKey.RSA.PSS , defaultPSSParamsSHA1 -- * Sign and verify functions , signWithSalt + , signDigestWithSalt , sign + , signDigest , signSafer + , signDigestSafer , verify + , verifyDigest ) where import Crypto.Random.Types @@ -27,6 +31,7 @@ import Data.Word import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray) import qualified Crypto.Internal.ByteArray as B (convert) + import Data.ByteString (ByteString) import qualified Data.ByteString as B @@ -53,6 +58,32 @@ defaultPSSParams hashAlg = defaultPSSParamsSHA1 :: PSSParams SHA1 ByteString ByteString defaultPSSParamsSHA1 = defaultPSSParams SHA1 +-- | Sign using the PSS parameters and the salt explicitely passed as parameters. +-- +-- the function ignore SaltLength from the PSS Parameters +signDigestWithSalt :: HashAlgorithm hash + => ByteString -- ^ Salt to use + -> Maybe Blinder -- ^ optional blinder to use + -> PSSParams hash ByteString ByteString -- ^ PSS Parameters to use + -> PrivateKey -- ^ RSA Private Key + -> Digest hash -- ^ Message digest + -> Either Error ByteString +signDigestWithSalt salt blinder params pk digest + | k < hashLen + saltLen + 2 = Left InvalidParameters + | otherwise = Right $ dp blinder pk em + where k = private_size pk + mHash = B.convert digest + dbLen = k - hashLen - 1 + saltLen = B.length salt + hashLen = hashDigestSize (pssHash params) + pubBits = private_size pk * 8 -- to change if public_size is converted in bytes + m' = B.concat [B.replicate 8 0,mHash,salt] + h = B.convert $ hashWith (pssHash params) m' + db = B.concat [B.replicate (dbLen - saltLen - 1) 0,B.singleton 1,salt] + dbmask = (pssMaskGenAlg params) h dbLen + maskedDB = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor db dbmask + em = B.concat [maskedDB, h, B.singleton (pssTrailerField params)] + -- | Sign using the PSS parameters and the salt explicitely passed as parameters. -- -- the function ignore SaltLength from the PSS Parameters @@ -63,22 +94,8 @@ signWithSalt :: HashAlgorithm hash -> PrivateKey -- ^ RSA Private Key -> ByteString -- ^ Message to sign -> Either Error ByteString -signWithSalt salt blinder params pk m - | k < hashLen + saltLen + 2 = Left InvalidParameters - | otherwise = Right $ dp blinder pk em - where mHash = B.convert $ hashWith (pssHash params) m - k = private_size pk - dbLen = k - hashLen - 1 - saltLen = B.length salt - hashLen = hashDigestSize (pssHash params) - pubBits = private_size pk * 8 -- to change if public_size is converted in bytes - - m' = B.concat [B.replicate 8 0,mHash,salt] - h = B.convert $ hashWith (pssHash params) m' - db = B.concat [B.replicate (dbLen - saltLen - 1) 0,B.singleton 1,salt] - dbmask = (pssMaskGenAlg params) h dbLen - maskedDB = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor db dbmask - em = B.concat [maskedDB, h, B.singleton (pssTrailerField params)] +signWithSalt salt blinder params pk m = signDigestWithSalt salt blinder params pk mHash + where mHash = hashWith (pssHash params) m -- | Sign using the PSS Parameters sign :: (HashAlgorithm hash, MonadRandom m) @@ -91,6 +108,17 @@ sign blinder params pk m = do salt <- getRandomBytes (pssSaltLength params) return (signWithSalt salt blinder params pk m) +-- | Sign using the PSS Parameters +signDigest :: (HashAlgorithm hash, MonadRandom m) + => Maybe Blinder -- ^ optional blinder to use + -> PSSParams hash ByteString ByteString -- ^ PSS Parameters to use + -> PrivateKey -- ^ RSA Private Key + -> Digest hash -- ^ Message digest + -> m (Either Error ByteString) +signDigest blinder params pk digest = do + salt <- getRandomBytes (pssSaltLength params) + return (signDigestWithSalt salt blinder params pk digest) + -- | Sign using the PSS Parameters and an automatically generated blinder. signSafer :: (HashAlgorithm hash, MonadRandom m) => PSSParams hash ByteString ByteString -- ^ PSS Parameters to use @@ -101,6 +129,16 @@ signSafer params pk m = do blinder <- generateBlinder (private_n pk) sign (Just blinder) params pk m +-- | Sign using the PSS Parameters and an automatically generated blinder. +signDigestSafer :: (HashAlgorithm hash, MonadRandom m) + => PSSParams hash ByteString ByteString -- ^ PSS Parameters to use + -> PrivateKey -- ^ private key + -> Digest hash -- ^ message digst + -> m (Either Error ByteString) +signDigestSafer params pk digest = do + blinder <- generateBlinder (private_n pk) + signDigest (Just blinder) params pk digest + -- | Verify a signature using the PSS Parameters verify :: HashAlgorithm hash => PSSParams hash ByteString ByteString @@ -110,7 +148,19 @@ verify :: HashAlgorithm hash -> ByteString -- ^ Message to verify -> ByteString -- ^ Signature -> Bool -verify params pk m s +verify params pk m s = verifyDigest params pk mHash s + where mHash = hashWith (pssHash params) m + +-- | Verify a signature using the PSS Parameters +verifyDigest :: HashAlgorithm hash + => PSSParams hash ByteString ByteString + -- ^ PSS Parameters to use to verify, + -- this need to be identical to the parameters when signing + -> PublicKey -- ^ RSA Public Key + -> Digest hash -- ^ Digest to verify + -> ByteString -- ^ Signature + -> Bool +verifyDigest params pk digest s | public_size pk /= B.length s = False | B.last em /= pssTrailerField params = False | not (B.all (== 0) ps0) = False @@ -118,6 +168,7 @@ verify params pk m s | otherwise = h == B.convert h' where -- parameters hashLen = hashDigestSize (pssHash params) + mHash = B.convert digest dbLen = public_size pk - hashLen - 1 pubBits = public_size pk * 8 -- to change if public_size is converted in bytes -- unmarshall fields @@ -128,7 +179,6 @@ verify params pk m s db = B.pack $ normalizeToKeySize pubBits $ B.zipWith xor maskedDB dbmask (ps0,z) = B.break (== 1) db (b1,salt) = B.splitAt 1 z - mHash = B.convert $ hashWith (pssHash params) m m' = B.concat [B.replicate 8 0,mHash,salt] h' = hashWith (pssHash params) m' @@ -137,3 +187,4 @@ normalizeToKeySize _ [] = [] -- very unlikely normalizeToKeySize bits (x:xs) = x .&. mask : xs where mask = if sh > 0 then 0xff `shiftR` (8-sh) else 0xff sh = ((bits-1) .&. 0x7) +