Unified DSA and ECDSA truncate&hash function.

This commit is contained in:
Crockett 2019-02-03 13:30:56 -08:00
parent d5003a46a6
commit c71a6733dd
3 changed files with 15 additions and 40 deletions

View File

@ -39,6 +39,7 @@ import Crypto.Number.Generate
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, convert, index, dropView, takeView, pack, unpack) import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, convert, index, dropView, takeView, pack, unpack)
import Crypto.Internal.Imports import Crypto.Internal.Imports
import Crypto.Hash import Crypto.Hash
import Crypto.PubKey.Internal (dsaTruncHash)
import Prelude import Prelude
-- | DSA Public Number, usually embedded in DSA Public Key -- | DSA Public Number, usually embedded in DSA Public Key
@ -126,7 +127,7 @@ signWith k pk hashAlg msg
x = private_x pk x = private_x pk
-- compute r,s -- compute r,s
kInv = fromJust $ inverse k q kInv = fromJust $ inverse k q
hm = dsaHash q hashAlg msg hm = dsaTruncHash hashAlg msg q
r = expSafe g k p `mod` q r = expSafe g k p `mod` q
s = (kInv * (hm + x * r)) `mod` q s = (kInv * (hm + x * r)) `mod` q
@ -148,36 +149,8 @@ verify hashAlg pk (Signature r s) m
| otherwise = v == r | otherwise = v == r
where (Params p g q) = public_params pk where (Params p g q) = public_params pk
y = public_y pk y = public_y pk
hm = dsaHash q hashAlg m hm = dsaTruncHash hashAlg m q
w = fromJust $ inverse s q w = fromJust $ inverse s q
u1 = (hm*w) `mod` q u1 = (hm*w) `mod` q
u2 = (r*w) `mod` q u2 = (r*w) `mod` q
v = ((expFast g u1 p) * (expFast y u2 p)) `mod` p `mod` q v = ((expFast g u1 p) * (expFast y u2 p)) `mod` p `mod` q
dsaHash :: (ByteArrayAccess msg, HashAlgorithm hash) => Integer -> hash -> msg -> Integer
dsaHash q hashAlg msg =
-- if the hash is larger than the size of q, truncate it; FIXME: deal with the case of a q not evenly divisible by 8
let numDropBits = (hashDigestSize hashAlg)*8 - numBits q
rawHash = hashWith hashAlg msg
in case compare numDropBits 0 of
GT -> -- hash output is larger than modulus
let (nq,nr) = numDropBits `divMod` 8
in if nr == 0 -- difference is 0 mod 8 => numBits is 0 `mod` 8
then os2ip $ takeView rawHash $ (numBits q) `div` 8
else os2ip $ shiftR rawHash numDropBits
_ -> os2ip rawHash
-- shift right by a given number of bits, dropping full bytes of leading zeros
-- based on code from the `bits-bytestring` package
shiftR :: (ByteArrayAccess m) => m -> Int -> ScrubbedBytes
shiftR bs i =
let ws = unpack bs
in pack $ go 0 $ take (length ws - q) ws
where
(q,r) = i `divMod` 8
go _ [] = []
go w1 (w2:wst) = (maskR w1 w2) : go w2 wst
-- given [w1,w2], constructs w2', which is left by j bits to get the
-- bottom j bits of w1 || top (8-j) bits of w2
maskR w1 w2 = (Bits.shiftL w1 (8-r)) Bits..|. (Bits.shiftR w2 r)

View File

@ -26,6 +26,7 @@ import Crypto.Number.Serialize
import Crypto.Number.Generate import Crypto.Number.Generate
import Crypto.PubKey.ECC.Types import Crypto.PubKey.ECC.Types
import Crypto.PubKey.ECC.Prim import Crypto.PubKey.ECC.Prim
import Crypto.PubKey.Internal (dsaTruncHash)
import Crypto.Hash import Crypto.Hash
import Crypto.Hash.Types (hashDigestSize) import Crypto.Hash.Types (hashDigestSize)
@ -69,7 +70,7 @@ signWith :: (ByteArrayAccess msg, HashAlgorithm hash)
-> msg -- ^ message to sign -> msg -- ^ message to sign
-> Maybe Signature -> Maybe Signature
signWith k (PrivateKey curve d) hashAlg msg = do signWith k (PrivateKey curve d) hashAlg msg = do
let z = tHash hashAlg msg n let z = dsaTruncHash hashAlg msg n
CurveCommon _ _ g n _ = common_curve curve CurveCommon _ _ g n _ = common_curve curve
let point = pointMul curve k g let point = pointMul curve k g
r <- case point of r <- case point of
@ -99,7 +100,7 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
| r < 1 || r >= n || s < 1 || s >= n = False | r < 1 || r >= n || s < 1 || s >= n = False
| otherwise = maybe False (r ==) $ do | otherwise = maybe False (r ==) $ do
w <- inverse s n w <- inverse s n
let z = tHash hashAlg msg n let z = dsaTruncHash hashAlg msg n
u1 = z * w `mod` n u1 = z * w `mod` n
u2 = r * w `mod` n u2 = r * w `mod` n
x = pointAddTwoMuls curve u1 g u2 q x = pointAddTwoMuls curve u1 g u2 q
@ -109,11 +110,3 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
where n = ecc_n cc where n = ecc_n cc
g = ecc_g cc g = ecc_g cc
cc = common_curve $ public_curve pk cc = common_curve $ public_curve pk
-- | Truncate and hash.
tHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
tHash hashAlg m n
| d > 0 = shiftR e d
| otherwise = e
where e = os2ip $ hashWith hashAlg m
d = hashDigestSize hashAlg * 8 - numBits n

View File

@ -8,6 +8,7 @@
module Crypto.PubKey.Internal module Crypto.PubKey.Internal
( and' ( and'
, (&&!) , (&&!)
, dsaTruncHash
) where ) where
import Data.List (foldl') import Data.List (foldl')
@ -22,3 +23,11 @@ True &&! True = True
True &&! False = False True &&! False = False
False &&! True = False False &&! True = False
False &&! False = False False &&! False = False
-- | Truncate and hash for DSA and ECDSA.
dsaTruncHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
dsaTruncHash hashAlg m n
| d > 0 = shiftR e d
| otherwise = e
where e = os2ip $ hashWith hashAlg m
d = hashDigestSize hashAlg * 8 - numBits n