diff --git a/Crypto/PubKey/ECC/ECDSA.hs b/Crypto/PubKey/ECC/ECDSA.hs index 23c9180..bb12ce7 100644 --- a/Crypto/PubKey/ECC/ECDSA.hs +++ b/Crypto/PubKey/ECC/ECDSA.hs @@ -102,10 +102,7 @@ verify hashAlg pk@(PublicKey curve q) (Signature r s) msg let z = tHash hashAlg msg n u1 = z * w `mod` n u2 = r * w `mod` n - -- TODO: Use Shamir's trick - g' = pointMul curve u1 g - q' = pointMul curve u2 q - x = pointAdd curve g' q' + x = pointAddTwoMuls curve u1 g u2 q case x of PointO -> Nothing Point x1 _ -> return $ x1 `mod` n diff --git a/Crypto/PubKey/ECC/Prim.hs b/Crypto/PubKey/ECC/Prim.hs index a3a8324..2428fc8 100644 --- a/Crypto/PubKey/ECC/Prim.hs +++ b/Crypto/PubKey/ECC/Prim.hs @@ -7,6 +7,7 @@ module Crypto.PubKey.ECC.Prim , pointDouble , pointBaseMul , pointMul + , pointAddTwoMuls , isPointAtInfinity , isPointValid ) where @@ -108,6 +109,33 @@ pointMul c n p | odd n = pointAdd c p (pointMul c (n - 1) p) | otherwise = pointMul c (n `div` 2) (pointDouble c p) +-- | Elliptic curve double-scalar multiplication (uses Shamir's trick). +-- +-- > pointAddTwoMuls c n1 p1 n2 p2 == pointAdd c (pointMul c n1 p1) +-- > (pointMul c n2 p2) +-- +-- /WARNING:/ Vulnerable to timing attacks. +pointAddTwoMuls :: Curve -> Integer -> Point -> Integer -> Point -> Point +pointAddTwoMuls _ _ PointO _ PointO = PointO +pointAddTwoMuls c _ PointO n2 p2 = pointMul c n2 p2 +pointAddTwoMuls c n1 p1 _ PointO = pointMul c n1 p1 +pointAddTwoMuls c n1 p1 n2 p2 + | n1 < 0 = pointAddTwoMuls c (-n1) (pointNegate c p1) n2 p2 + | n2 < 0 = pointAddTwoMuls c n1 p1 (-n2) (pointNegate c p2) + | otherwise = go (n1, n2) + + where + p0 = pointAdd c p1 p2 + + go (0, 0 ) = PointO + go (k1, k2) = + let q = pointDouble c $ go (k1 `div` 2, k2 `div` 2) + in case (odd k1, odd k2) of + (True , True ) -> pointAdd c p0 q + (True , False ) -> pointAdd c p1 q + (False , True ) -> pointAdd c p2 q + (False , False ) -> q + -- | Check if a point is the point at infinity. isPointAtInfinity :: Point -> Bool isPointAtInfinity PointO = True diff --git a/benchs/Bench.hs b/benchs/Bench.hs index 13047cc..2be6c95 100644 --- a/benchs/Bench.hs +++ b/benchs/Bench.hs @@ -7,16 +7,17 @@ import Criterion.Main import "cryptonite" Crypto.Hash import "cryptonite" Crypto.Error import "cryptonite" Crypto.Cipher.DES -import "cryptonite" Crypto.Cipher.Camellia import "cryptonite" Crypto.Cipher.AES import "cryptonite" Crypto.Cipher.Blowfish import "cryptonite" Crypto.Cipher.Types import qualified "cryptonite" Crypto.Cipher.ChaChaPoly1305 as CP -import "cryptonite" Crypto.Hash (SHA512(..)) import qualified "cryptonite" Crypto.KDF.PBKDF2 as PBKDF2 -import Data.ByteArray (ByteArray, Bytes) +import qualified "cryptonite" Crypto.PubKey.ECC.Types as ECC +import qualified "cryptonite" Crypto.PubKey.ECC.Prim as ECC + +import Data.ByteArray (ByteArray) import qualified Data.ByteString as B @@ -103,9 +104,27 @@ benchAE = key32 = B.replicate 32 0 +benchECC = + [ bench "pointAddTwoMuls-baseline" $ nf run_b (n1, p1, n2, p2) + , bench "pointAddTwoMuls-optimized" $ nf run_o (n1, p1, n2, p2) + ] + where run_b (n, p, k, q) = ECC.pointAdd c (ECC.pointMul c n p) + (ECC.pointMul c k q) + + run_o (n, p, k, q) = ECC.pointAddTwoMuls c n p k q + + c = ECC.getCurveByName ECC.SEC_p256r1 + r1 = 7 + r2 = 11 + p1 = ECC.pointBaseMul c r1 + p2 = ECC.pointBaseMul c r2 + n1 = 0x2ba9daf2363b2819e69b34a39cf496c2458a9b2a21505ea9e7b7cbca42dc7435 + n2 = 0xf054a7f60d10b8c2cf847ee90e9e029f8b0e971b09ca5f55c4d49921a11fadc1 + main = defaultMain [ bgroup "hash" benchHash , bgroup "block-cipher" benchBlockCipher , bgroup "AE" benchAE , bgroup "pbkdf2" benchPBKDF2 + , bgroup "ECC" benchECC ] diff --git a/tests/KAT_PubKey/ECC.hs b/tests/KAT_PubKey/ECC.hs index d17055c..9c6a923 100644 --- a/tests/KAT_PubKey/ECC.hs +++ b/tests/KAT_PubKey/ECC.hs @@ -138,10 +138,16 @@ vectorsPoint = doPointValidTest (i, vector) = testCase (show i) (valid vector @=? ECC.isPointValid (curve vector) (ECC.Point (x vector) (y vector))) +arbitraryPoint :: ECC.Curve -> Gen ECC.Point +arbitraryPoint aCurve = + frequency [(5, return ECC.PointO), (95, pointGen)] + where + n = ECC.ecc_n (ECC.common_curve aCurve) + pointGen = ECC.pointBaseMul aCurve <$> choose (1, n - 1) eccTests = testGroup "ECC" [ testGroup "valid-point" $ map doPointValidTest (zip [katZero..] vectorsPoint) - , testGroup "property" $ + , testGroup "property" [ testProperty "point-add" $ \aCurve (QAInteger r1) (QAInteger r2) -> let curveN = ECC.ecc_n . ECC.common_curve $ aCurve curveGen = ECC.ecc_g . ECC.common_curve $ aCurve @@ -149,6 +155,19 @@ eccTests = testGroup "ECC" p2 = ECC.pointMul aCurve r2 curveGen pR = ECC.pointMul aCurve ((r1 + r2) `mod` curveN) curveGen in pR `propertyEq` ECC.pointAdd aCurve p1 p2 + , localOption (QuickCheckTests 20) $ + testProperty "point-mul-mul" $ \aCurve (QAInteger n1) (QAInteger n2) -> do + p <- arbitraryPoint aCurve + let pRes = ECC.pointMul aCurve (n1 * n2) p + let pDef = ECC.pointMul aCurve n1 (ECC.pointMul aCurve n2 p) + return $ pRes `propertyEq` pDef + , localOption (QuickCheckTests 20) $ + testProperty "double-scalar-mult" $ \aCurve (QAInteger n1) (QAInteger n2) -> do + p1 <- arbitraryPoint aCurve + p2 <- arbitraryPoint aCurve + let pRes = ECC.pointAddTwoMuls aCurve n1 p1 n2 p2 + let pDef = ECC.pointAdd aCurve (ECC.pointMul aCurve n1 p1) (ECC.pointMul aCurve n2 p2) + return $ pRes `propertyEq` pDef ] ]