Unified DSA and ECDSA truncate&hash function.
This commit is contained in:
parent
d5003a46a6
commit
c71a6733dd
@ -39,6 +39,7 @@ import Crypto.Number.Generate
|
|||||||
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, convert, index, dropView, takeView, pack, unpack)
|
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, convert, index, dropView, takeView, pack, unpack)
|
||||||
import Crypto.Internal.Imports
|
import Crypto.Internal.Imports
|
||||||
import Crypto.Hash
|
import Crypto.Hash
|
||||||
|
import Crypto.PubKey.Internal (dsaTruncHash)
|
||||||
import Prelude
|
import Prelude
|
||||||
|
|
||||||
-- | DSA Public Number, usually embedded in DSA Public Key
|
-- | DSA Public Number, usually embedded in DSA Public Key
|
||||||
@ -126,7 +127,7 @@ signWith k pk hashAlg msg
|
|||||||
x = private_x pk
|
x = private_x pk
|
||||||
-- compute r,s
|
-- compute r,s
|
||||||
kInv = fromJust $ inverse k q
|
kInv = fromJust $ inverse k q
|
||||||
hm = dsaHash q hashAlg msg
|
hm = dsaTruncHash hashAlg msg q
|
||||||
r = expSafe g k p `mod` q
|
r = expSafe g k p `mod` q
|
||||||
s = (kInv * (hm + x * r)) `mod` q
|
s = (kInv * (hm + x * r)) `mod` q
|
||||||
|
|
||||||
@ -148,36 +149,8 @@ verify hashAlg pk (Signature r s) m
|
|||||||
| otherwise = v == r
|
| otherwise = v == r
|
||||||
where (Params p g q) = public_params pk
|
where (Params p g q) = public_params pk
|
||||||
y = public_y pk
|
y = public_y pk
|
||||||
hm = dsaHash q hashAlg m
|
hm = dsaTruncHash hashAlg m q
|
||||||
|
|
||||||
w = fromJust $ inverse s q
|
w = fromJust $ inverse s q
|
||||||
u1 = (hm*w) `mod` q
|
u1 = (hm*w) `mod` q
|
||||||
u2 = (r*w) `mod` q
|
u2 = (r*w) `mod` q
|
||||||
v = ((expFast g u1 p) * (expFast y u2 p)) `mod` p `mod` q
|
v = ((expFast g u1 p) * (expFast y u2 p)) `mod` p `mod` q
|
||||||
|
|
||||||
dsaHash :: (ByteArrayAccess msg, HashAlgorithm hash) => Integer -> hash -> msg -> Integer
|
|
||||||
dsaHash q hashAlg msg =
|
|
||||||
-- if the hash is larger than the size of q, truncate it; FIXME: deal with the case of a q not evenly divisible by 8
|
|
||||||
let numDropBits = (hashDigestSize hashAlg)*8 - numBits q
|
|
||||||
rawHash = hashWith hashAlg msg
|
|
||||||
in case compare numDropBits 0 of
|
|
||||||
GT -> -- hash output is larger than modulus
|
|
||||||
let (nq,nr) = numDropBits `divMod` 8
|
|
||||||
in if nr == 0 -- difference is 0 mod 8 => numBits is 0 `mod` 8
|
|
||||||
then os2ip $ takeView rawHash $ (numBits q) `div` 8
|
|
||||||
else os2ip $ shiftR rawHash numDropBits
|
|
||||||
_ -> os2ip rawHash
|
|
||||||
|
|
||||||
-- shift right by a given number of bits, dropping full bytes of leading zeros
|
|
||||||
-- based on code from the `bits-bytestring` package
|
|
||||||
shiftR :: (ByteArrayAccess m) => m -> Int -> ScrubbedBytes
|
|
||||||
shiftR bs i =
|
|
||||||
let ws = unpack bs
|
|
||||||
in pack $ go 0 $ take (length ws - q) ws
|
|
||||||
where
|
|
||||||
(q,r) = i `divMod` 8
|
|
||||||
go _ [] = []
|
|
||||||
go w1 (w2:wst) = (maskR w1 w2) : go w2 wst
|
|
||||||
-- given [w1,w2], constructs w2', which is left by j bits to get the
|
|
||||||
-- bottom j bits of w1 || top (8-j) bits of w2
|
|
||||||
maskR w1 w2 = (Bits.shiftL w1 (8-r)) Bits..|. (Bits.shiftR w2 r)
|
|
||||||
@ -26,6 +26,7 @@ import Crypto.Number.Serialize
|
|||||||
import Crypto.Number.Generate
|
import Crypto.Number.Generate
|
||||||
import Crypto.PubKey.ECC.Types
|
import Crypto.PubKey.ECC.Types
|
||||||
import Crypto.PubKey.ECC.Prim
|
import Crypto.PubKey.ECC.Prim
|
||||||
|
import Crypto.PubKey.Internal (dsaTruncHash)
|
||||||
import Crypto.Hash
|
import Crypto.Hash
|
||||||
import Crypto.Hash.Types (hashDigestSize)
|
import Crypto.Hash.Types (hashDigestSize)
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ signWith :: (ByteArrayAccess msg, HashAlgorithm hash)
|
|||||||
-> msg -- ^ message to sign
|
-> msg -- ^ message to sign
|
||||||
-> Maybe Signature
|
-> Maybe Signature
|
||||||
signWith k (PrivateKey curve d) hashAlg msg = do
|
signWith k (PrivateKey curve d) hashAlg msg = do
|
||||||
let z = tHash hashAlg msg n
|
let z = dsaTruncHash hashAlg msg n
|
||||||
CurveCommon _ _ g n _ = common_curve curve
|
CurveCommon _ _ g n _ = common_curve curve
|
||||||
let point = pointMul curve k g
|
let point = pointMul curve k g
|
||||||
r <- case point of
|
r <- case point of
|
||||||
@ -99,7 +100,7 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
|
|||||||
| r < 1 || r >= n || s < 1 || s >= n = False
|
| r < 1 || r >= n || s < 1 || s >= n = False
|
||||||
| otherwise = maybe False (r ==) $ do
|
| otherwise = maybe False (r ==) $ do
|
||||||
w <- inverse s n
|
w <- inverse s n
|
||||||
let z = tHash hashAlg msg n
|
let z = dsaTruncHash hashAlg msg n
|
||||||
u1 = z * w `mod` n
|
u1 = z * w `mod` n
|
||||||
u2 = r * w `mod` n
|
u2 = r * w `mod` n
|
||||||
x = pointAddTwoMuls curve u1 g u2 q
|
x = pointAddTwoMuls curve u1 g u2 q
|
||||||
@ -109,11 +110,3 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg
|
|||||||
where n = ecc_n cc
|
where n = ecc_n cc
|
||||||
g = ecc_g cc
|
g = ecc_g cc
|
||||||
cc = common_curve $ public_curve pk
|
cc = common_curve $ public_curve pk
|
||||||
|
|
||||||
-- | Truncate and hash.
|
|
||||||
tHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
|
|
||||||
tHash hashAlg m n
|
|
||||||
| d > 0 = shiftR e d
|
|
||||||
| otherwise = e
|
|
||||||
where e = os2ip $ hashWith hashAlg m
|
|
||||||
d = hashDigestSize hashAlg * 8 - numBits n
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
module Crypto.PubKey.Internal
|
module Crypto.PubKey.Internal
|
||||||
( and'
|
( and'
|
||||||
, (&&!)
|
, (&&!)
|
||||||
|
, dsaTruncHash
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.List (foldl')
|
import Data.List (foldl')
|
||||||
@ -22,3 +23,11 @@ True &&! True = True
|
|||||||
True &&! False = False
|
True &&! False = False
|
||||||
False &&! True = False
|
False &&! True = False
|
||||||
False &&! False = False
|
False &&! False = False
|
||||||
|
|
||||||
|
-- | Truncate and hash for DSA and ECDSA.
|
||||||
|
dsaTruncHash :: (ByteArrayAccess msg, HashAlgorithm hash) => hash -> msg -> Integer -> Integer
|
||||||
|
dsaTruncHash hashAlg m n
|
||||||
|
| d > 0 = shiftR e d
|
||||||
|
| otherwise = e
|
||||||
|
where e = os2ip $ hashWith hashAlg m
|
||||||
|
d = hashDigestSize hashAlg * 8 - numBits n
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user