Unified DSA and ECDSA truncate&hash function.

This commit is contained in:
Crockett 2019-02-03 13:30:56 -08:00
parent d5003a46a6
commit c71a6733dd
3 changed files with 15 additions and 40 deletions

View File

@ -39,6 +39,7 @@ import Crypto.Number.Generate
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, convert, index, dropView, takeView, pack, unpack)
import Crypto.Internal.Imports
import Crypto.Hash
import Crypto.PubKey.Internal (dsaTruncHash)
import Prelude
-- | DSA Public Number, usually embedded in DSA Public Key
@ -126,7 +127,7 @@ signWith k pk hashAlg msg
x = private_x pk
-- compute r,s
kInv = fromJust $ inverse k q
hm = dsaHash q hashAlg msg
hm = dsaTruncHash hashAlg msg q
r = expSafe g k p `mod` q
s = (kInv * (hm + x * r)) `mod` q
@ -148,36 +149,8 @@ verify hashAlg pk (Signature r s) m
| otherwise = v == r
where (Params p g q) = public_params pk
y = public_y pk
hm = dsaHash q hashAlg m
hm = dsaTruncHash hashAlg m q
w = fromJust $ inverse s q
u1 = (hm*w) `mod` q
u2 = (r*w) `mod` q
v = ((expFast g u1 p) * (expFast y u2 p)) `mod` p `mod` q
dsaHash :: (ByteArrayAccess msg, HashAlgorithm hash) => Integer -> hash -> msg -> Integer
dsaHash q hashAlg msg =
-- if the hash is larger than the size of q, truncate it; FIXME: deal with the case of a q not evenly divisible by 8
let numDropBits = (hashDigestSize hashAlg)*8 - numBits q
rawHash = hashWith hashAlg msg
in case compare numDropBits 0 of
GT -> -- hash output is larger than modulus
let (nq,nr) = numDropBits `divMod` 8
in if nr == 0 -- difference is 0 mod 8 => numBits is 0 `mod` 8
then os2ip $ takeView rawHash $ (numBits q) `div` 8
else os2ip $ shiftR rawHash numDropBits
_ -> os2ip rawHash
-- shift right by a given number of bits, dropping full bytes of leading zeros
-- based on code from the `bits-bytestring` package
shiftR :: (ByteArrayAccess m) => m -> Int -> ScrubbedBytes
shiftR bs i =
let ws = unpack bs
in pack $ go 0 $ take (length ws - q) ws
where
(q,r) = i `divMod` 8
go _ [] = []
go w1 (w2:wst) = (maskR w1 w2) : go w2 wst
-- given [w1,w2], constructs w2', which is left by j bits to get the
-- bottom j bits of w1 || top (8-j) bits of w2
maskR w1 w2 = (Bits.shiftL w1 (8-r)) Bits..|. (Bits.shiftR w2 r)

View File

@ -26,6 +26,7 @@ import Crypto.Number.Serialize
import Crypto.Number.Generate
import Crypto.PubKey.ECC.Types
import Crypto.PubKey.ECC.Prim
import Crypto.PubKey.Internal (dsaTruncHash)
import Crypto.Hash
import Crypto.Hash.Types (hashDigestSize)
@ -69,7 +70,7 @@ signWith :: (ByteArrayAccess msg, HashAlgorithm hash)
-> msg -- ^ message to sign
-> Maybe Signature
signWith k (PrivateKey curve d) hashAlg msg = do
let z = tHash hashAlg msg n
let z = dsaTruncHash hashAlg msg n
CurveCommon _ _ g n _ = common_curve curve
let point = pointMul curve k g
r <- case point of
@ -99,7 +100,7 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
| r < 1 || r >= n || s < 1 || s >= n = False
| otherwise = maybe False (r ==) $ do
w <- inverse s n
let z = tHash hashAlg msg n
let z = dsaTruncHash hashAlg msg n
u1 = z * w `mod` n
u2 = r * w `mod` n
x = pointAddTwoMuls curve u1 g u2 q
@ -109,11 +110,3 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
where n = ecc_n cc
g = ecc_g cc
cc = common_curve $ public_curve pk
-- | Truncate and hash.
tHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
tHash hashAlg m n
| d > 0 = shiftR e d
| otherwise = e
where e = os2ip $ hashWith hashAlg m
d = hashDigestSize hashAlg * 8 - numBits n

View File

@ -8,6 +8,7 @@
module Crypto.PubKey.Internal
( and'
, (&&!)
, dsaTruncHash
) where
import Data.List (foldl')
@ -22,3 +23,11 @@ True &&! True = True
True &&! False = False
False &&! True = False
False &&! False = False
-- | Truncate and hash for DSA and ECDSA.
dsaTruncHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
dsaTruncHash hashAlg m n
| d > 0 = shiftR e d
| otherwise = e
where e = os2ip $ hashWith hashAlg m
d = hashDigestSize hashAlg * 8 - numBits n