From f63a3c6025b421d43017581875f3f5a3fb277cf6 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Mon, 1 Jun 2015 07:48:31 +0100 Subject: [PATCH] [p256] fix all the bugs found by the now useful P256 test suite --- Crypto/PubKey/ECC/P256.hs | 132 ++++++++++++++++++++++++++++++++------ tests/KAT_PubKey/P256.hs | 126 +++++++++++++++++++++++++++++++++--- 2 files changed, 230 insertions(+), 28 deletions(-) diff --git a/Crypto/PubKey/ECC/P256.hs b/Crypto/PubKey/ECC/P256.hs index 1bf5cb7..8e54df8 100644 --- a/Crypto/PubKey/ECC/P256.hs +++ b/Crypto/PubKey/ECC/P256.hs @@ -11,8 +11,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE EmptyDataDecls #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} -{-# OPTIONS_GHC -fno-warn-unused-matches #-} -{-# OPTIONS_GHC -fno-warn-unused-imports #-} module Crypto.PubKey.ECC.P256 ( Scalar , Point @@ -22,37 +20,51 @@ module Crypto.PubKey.ECC.P256 , pointsMulVarTime , pointIsValid , toPoint + , pointToIntegers + , pointFromIntegers + , pointToBinary + , pointFromBinary -- * scalar arithmetic , scalarZero + , scalarIsZero , scalarAdd , scalarSub , scalarInv , scalarCmp , scalarFromBinary , scalarToBinary + , scalarFromInteger + , scalarToInteger ) where import Data.Word import Foreign.Ptr import Foreign.C.Types +import Control.Monad import Crypto.Internal.Compat import Crypto.Internal.Imports import Crypto.Internal.ByteArray import qualified Crypto.Internal.ByteArray as B +import Data.Memory.PtrMethods (memSet) import Crypto.Error +import Crypto.Number.Serialize.Internal (os2ip, i2ospOf) +import qualified Crypto.Number.Serialize as S (os2ip, i2ospOf) -- | A P256 scalar -newtype Scalar = Scalar ScrubbedBytes +newtype Scalar = Scalar Bytes deriving (Eq,ByteArrayAccess) -- | A P256 point -data Point = Point !Bytes !Bytes +newtype Point = Point Bytes deriving (Show,Eq) scalarSize :: Int scalarSize = 32 +pointSize :: Int +pointSize = 64 + type P256Digit = Word32 data P256Scalar @@ -71,8 +83,11 @@ data P256X -- > scalar * G -- toPoint :: Scalar -> Point -toPoint s = withNewPoint $ \px py -> withScalar s $ \p -> - ccryptonite_p256_basepoint_mul p px py +toPoint s + | scalarIsZero s = error "cannot create point from zero" + | otherwise = + withNewPoint $ \px py -> withScalar s $ \p -> + ccryptonite_p256_basepoint_mul p px py -- | Add a point to another point pointAdd :: Point -> Point -> Point @@ -104,6 +119,46 @@ pointIsValid p = unsafeDoIO $ withPoint p $ \px py -> do r <- ccryptonite_p256_is_valid_point px py return (r /= 0) +pointToIntegers :: Point -> (Integer, Integer) +pointToIntegers p = unsafeDoIO $ withPoint p $ \px py -> + allocTemp 32 (serialize (castPtr px) (castPtr py)) + where + serialize px py temp = do + ccryptonite_p256_to_bin px temp + x <- os2ip temp scalarSize + ccryptonite_p256_to_bin py temp + y <- os2ip temp scalarSize + return (x,y) + +pointFromIntegers :: (Integer, Integer) -> Point +pointFromIntegers (x,y) = withNewPoint $ \dx dy -> + allocTemp scalarSize (\temp -> fill temp (castPtr dx) x >> fill temp (castPtr dy) y) + where + -- put @n to @temp in big endian format, then from @temp to @dest in p256 scalar format + fill :: Ptr Word8 -> Ptr P256Scalar -> Integer -> IO () + fill temp dest n = do + -- write the integer in big endian format to temp + memSet temp 0 scalarSize + e <- i2ospOf n temp scalarSize + if e == 0 + then error "pointFromIntegers: filling failed" + else return () + -- then fill dest with the P256 scalar from temp + ccryptonite_p256_from_bin temp dest + +pointToBinary :: ByteArray ba => Point -> ba +pointToBinary p = B.unsafeCreate pointSize $ \dst -> withPoint p $ \px py -> do + ccryptonite_p256_to_bin (castPtr px) dst + ccryptonite_p256_to_bin (castPtr py) (dst `plusPtr` 32) + +pointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point +pointFromBinary ba + | B.length ba /= pointSize = CryptoFailed $ CryptoError_PublicKeySizeInvalid + | otherwise = + CryptoPassed $ withNewPoint $ \px py -> B.withByteArray ba $ \src -> do + ccryptonite_p256_from_bin src (castPtr px) + ccryptonite_p256_from_bin (src `plusPtr` scalarSize) (castPtr py) + ------------------------------------------------------------------------ -- Scalar methods ------------------------------------------------------------------------ @@ -112,14 +167,27 @@ pointIsValid p = unsafeDoIO $ withPoint p $ \px py -> do scalarZero :: Scalar scalarZero = withNewScalarFreeze $ \d -> ccryptonite_p256_init d +scalarIsZero :: Scalar -> Bool +scalarIsZero s = unsafeDoIO $ withScalar s $ \d -> do + result <- ccryptonite_p256_is_zero d + return $ result /= 0 + +scalarNeedReducing :: Ptr P256Scalar -> IO Bool +scalarNeedReducing d = do + c <- ccryptonite_p256_cmp d ccryptonite_SECP256r1_n + return (c >= 0) + -- | Perform addition between two scalars -- -- > a + b scalarAdd :: Scalar -> Scalar -> Scalar scalarAdd a b = withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do - void $ ccryptonite_p256_add pa pb d - ccryptonite_p256_mod ccryptonite_SECP256r1_n d d + carry <- ccryptonite_p256_add pa pb d + when (carry /= 0) $ void $ ccryptonite_p256_sub d ccryptonite_SECP256r1_n d + needReducing <- scalarNeedReducing d + when needReducing $ do + ccryptonite_p256_mod ccryptonite_SECP256r1_n d d -- | Perform subtraction between two scalars -- @@ -127,8 +195,11 @@ scalarAdd a b = scalarSub :: Scalar -> Scalar -> Scalar scalarSub a b = withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do - void $ ccryptonite_p256_sub pa pb d - ccryptonite_p256_mod ccryptonite_SECP256r1_n d d + borrow <- ccryptonite_p256_sub pa pb d + when (borrow /= 0) $ void $ ccryptonite_p256_add d ccryptonite_SECP256r1_n d + --needReducing <- scalarNeedReducing d + --when needReducing $ do + -- ccryptonite_p256_mod ccryptonite_SECP256r1_n d d -- | Give the inverse of the scalar -- @@ -154,33 +225,40 @@ scalarFromBinary ba | otherwise = CryptoPassed $ withNewScalarFreeze $ \p -> B.withByteArray ba $ \b -> ccryptonite_p256_from_bin b p +{-# NOINLINE scalarFromBinary #-} -- | convert a scalar to binary scalarToBinary :: ByteArray ba => Scalar -> ba -scalarToBinary s = B.allocAndFreeze scalarSize $ \b -> withScalar s $ \p -> +scalarToBinary s = B.unsafeCreate scalarSize $ \b -> withScalar s $ \p -> ccryptonite_p256_to_bin p b +{-# NOINLINE scalarToBinary #-} + +scalarFromInteger :: Integer -> CryptoFailable Scalar +scalarFromInteger i = + maybe (CryptoFailed CryptoError_SecretKeySizeInvalid) scalarFromBinary (S.i2ospOf 32 i :: Maybe Bytes) + +scalarToInteger :: Scalar -> Integer +scalarToInteger s = S.os2ip (scalarToBinary s :: Bytes) ------------------------------------------------------------------------ -- Memory Helpers ------------------------------------------------------------------------ withNewPoint :: (Ptr P256X -> Ptr P256Y -> IO ()) -> Point -withNewPoint f = unsafeDoIO $ do - (x,y) <- B.allocRet pointCoordSize $ \py -> B.alloc pointCoordSize $ \px -> f px py - return $! Point x y - where pointCoordSize = 32 +withNewPoint f = Point $ B.unsafeCreate pointSize $ \px -> f px (pxToPy px) {-# NOINLINE withNewPoint #-} withPoint :: Point -> (Ptr P256X -> Ptr P256Y -> IO a) -> IO a -withPoint (Point x y) f = B.withByteArray x $ \px -> B.withByteArray y $ \py -> f px py +withPoint (Point d) f = B.withByteArray d $ \px -> f px (pxToPy px) + +pxToPy :: Ptr P256X -> Ptr P256Y +pxToPy px = castPtr (px `plusPtr` scalarSize) withNewScalarFreeze :: (Ptr P256Scalar -> IO ()) -> Scalar withNewScalarFreeze f = Scalar $ B.allocAndFreeze scalarSize f {-# NOINLINE withNewScalarFreeze #-} withTempScalar :: (Ptr P256Scalar -> IO a) -> IO a -withTempScalar f = ignoreSnd <$> B.allocRet scalarSize f - where ignoreSnd :: (a, ScrubbedBytes) -> a - ignoreSnd = fst +withTempScalar f = allocTempScrubbed scalarSize (f . castPtr) withScalar :: Scalar -> (Ptr P256Scalar -> IO a) -> IO a withScalar (Scalar d) f = B.withByteArray d f @@ -191,6 +269,18 @@ withScalarZero f = ccryptonite_p256_init d f d +allocTemp :: Int -> (Ptr Word8 -> IO a) -> IO a +allocTemp n f = ignoreSnd <$> B.allocRet n f + where + ignoreSnd :: (a, Bytes) -> a + ignoreSnd = fst + +allocTempScrubbed :: Int -> (Ptr Word8 -> IO a) -> IO a +allocTempScrubbed n f = ignoreSnd <$> B.allocRet n f + where + ignoreSnd :: (a, ScrubbedBytes) -> a + ignoreSnd = fst + ------------------------------------------------------------------------ -- Foreign bindings ------------------------------------------------------------------------ @@ -203,10 +293,14 @@ foreign import ccall "&cryptonite_SECP256r1_b" foreign import ccall "cryptonite_p256_init" ccryptonite_p256_init :: Ptr P256Scalar -> IO () +foreign import ccall "cryptonite_p256_is_zero" + ccryptonite_p256_is_zero :: Ptr P256Scalar -> IO CInt foreign import ccall "cryptonite_p256_clear" ccryptonite_p256_clear :: Ptr P256Scalar -> IO () foreign import ccall "cryptonite_p256_add" ccryptonite_p256_add :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt +foreign import ccall "cryptonite_p256_add_d" + ccryptonite_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt foreign import ccall "cryptonite_p256_sub" ccryptonite_p256_sub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt foreign import ccall "cryptonite_p256_cmp" diff --git a/tests/KAT_PubKey/P256.hs b/tests/KAT_PubKey/P256.hs index cb0b978..e9ed142 100644 --- a/tests/KAT_PubKey/P256.hs +++ b/tests/KAT_PubKey/P256.hs @@ -1,33 +1,141 @@ -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} module KAT_PubKey.P256 (tests) where -import Control.Arrow (second) - import qualified Crypto.PubKey.ECC.Types as ECC import qualified Crypto.PubKey.ECC.Prim as ECC import qualified Crypto.PubKey.ECC.P256 as P256 -import Test.Tasty.KAT -import Test.Tasty.KAT.FileLoader import Data.ByteArray (Bytes) -import Crypto.Number.Serialize (i2ospOf) +import Crypto.Number.Serialize (i2ospOf, os2ip) +import Crypto.Number.ModArithmetic (inverseCoprimes) import Crypto.Error import Imports +newtype P256Scalar = P256Scalar Integer + deriving (Show,Eq,Ord) + +instance Arbitrary P256Scalar where + arbitrary = P256Scalar . getQAInteger <$> arbitrary + +curve = ECC.getCurveByName ECC.SEC_p256r1 +curveN = ECC.ecc_n . ECC.common_curve $ curve +curveGen = ECC.ecc_g . ECC.common_curve $ curve + +pointP256ToECC :: P256.Point -> ECC.Point +pointP256ToECC = uncurry ECC.Point . P256.pointToIntegers + +unP256Scalar :: P256Scalar -> P256.Scalar +unP256Scalar (P256Scalar r') = + let r = if r' == 0 then 0x2901 else (r' `mod` curveN) + rBytes = i2ospScalar r + in case P256.scalarFromBinary rBytes of + CryptoFailed err -> error ("cannot convert scalar: " ++ show err) + CryptoPassed scalar -> scalar + where + i2ospScalar :: Integer -> Bytes + i2ospScalar i = + case i2ospOf 32 i of + Nothing -> error "invalid size of P256 scalar" + Just b -> b + +unP256 :: P256Scalar -> Integer +unP256 (P256Scalar r') = if r' == 0 then 0x2901 else (r' `mod` curveN) + +p256ScalarToInteger :: P256.Scalar -> Integer +p256ScalarToInteger s = os2ip (P256.scalarToBinary s :: Bytes) + +xS = 0xde2444bebc8d36e682edd27e0f271508617519b3221a8fa0b77cab3989da97c9 +yS = 0xc093ae7ff36e5380fc01a5aad1e66659702de80f53cec576b6350b243042a256 +xT = 0x55a8b00f8da1d44e62f6b3b25316212e39540dc861c89575bb8cf92e35e0986b +yT = 0x5421c3209c2d6c704835d82ac4c3dd90f61a8a52598b9e7ab656e9d8c8b24316 +xR = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e +yR = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264 + tests = testGroup "P256" [ testGroup "scalar" - [ testProperty "marshalling" $ \(Positive r') -> + [ testProperty "marshalling" $ \(QAInteger r') -> let r = r' `mod` curveN rBytes = i2ospScalar r in case P256.scalarFromBinary rBytes of CryptoFailed err -> error (show err) CryptoPassed scalar -> rBytes `propertyEq` P256.scalarToBinary scalar + , testProperty "add" $ \r1 r2 -> + let r = (unP256 r1 + unP256 r2) `mod` curveN + r' = P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2) + in r `propertyEq` p256ScalarToInteger r' + , testProperty "add0" $ \r -> + let v = unP256 r + v' = P256.scalarAdd (unP256Scalar r) P256.scalarZero + in v `propertyEq` p256ScalarToInteger v' + , testProperty "add-n-1" $ \r -> + let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1) + v = unP256 r + v' = P256.scalarAdd (unP256Scalar r) nm1 + in (((curveN - 1) + v) `mod` curveN) `propertyEq` p256ScalarToInteger v' + , testProperty "sub" $ \r1 r2 -> + let r = (unP256 r1 - unP256 r2) `mod` curveN + r' = P256.scalarSub (unP256Scalar r1) (unP256Scalar r2) + v = (unP256 r2 - unP256 r1) `mod` curveN + v' = P256.scalarSub (unP256Scalar r2) (unP256Scalar r1) + in propertyHold + [ eqTest "r1-r2" r (p256ScalarToInteger r') + , eqTest "r2-r1" v (p256ScalarToInteger v') + ] + , testProperty "sub-n-1" $ \r -> + let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1) + v = unP256 r + v' = P256.scalarSub (unP256Scalar r) nm1 + in ((v - (curveN - 1)) `mod` curveN) `propertyEq` p256ScalarToInteger v' + , testProperty "inv" $ \r' -> + let inv = inverseCoprimes (unP256 r') curveN + inv' = P256.scalarInv (unP256Scalar r') + in if unP256 r' == 0 then True else inv `propertyEq` p256ScalarToInteger inv' + ] + , testGroup "point" + [ testProperty "marshalling" $ \rx ry -> + let p = P256.pointFromIntegers (unP256 rx, unP256 ry) + b = P256.pointToBinary p :: Bytes + p' = P256.pointFromBinary b + in propertyHold [ eqTest "point" (CryptoPassed p) p' ] + , testProperty "marshalling-integer" $ \rx ry -> + let p = P256.pointFromIntegers (unP256 rx, unP256 ry) + (x,y) = P256.pointToIntegers p + in propertyHold [ eqTest "x" (unP256 rx) x, eqTest "y" (unP256 ry) y ] + , testCase "valid-point-1" $ casePointIsValid (xS,yS) + , testCase "valid-point-2" $ casePointIsValid (xR,yR) + , testCase "valid-point-3" $ casePointIsValid (xT,yT) + , testCase "point-add-1" $ + let s = P256.pointFromIntegers (xS, yS) + t = P256.pointFromIntegers (xT, yT) + r = P256.pointFromIntegers (xR, yR) + in r @=? P256.pointAdd s t + , testProperty "lift-to-curve" $ propertyLiftToCurve + , testProperty "point-add" $ propertyPointAdd ] ] where - curve = ECC.getCurveByName ECC.SEC_p256r1 - curveN = ECC.ecc_n . ECC.common_curve $ curve + casePointIsValid pointTuple = + let s = P256.pointFromIntegers pointTuple in True @=? P256.pointIsValid s + + propertyLiftToCurve r = + let p = P256.toPoint (unP256Scalar r) + (x,y) = P256.pointToIntegers p + pEcc = ECC.pointMul curve (unP256 r) curveGen + in pEcc `propertyEq` ECC.Point x y + + propertyPointAdd r1 r2 = + let p1 = P256.toPoint (unP256Scalar r1) + p2 = P256.toPoint (unP256Scalar r2) + pe1 = ECC.pointMul curve (unP256 r1) curveGen + pe2 = ECC.pointMul curve (unP256 r2) curveGen + pR = P256.toPoint (P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2)) + peR = ECC.pointAdd curve pe1 pe2 + (x,y) = P256.pointToIntegers (P256.pointAdd p1 p2) -- P256.pointToIntegers pR + in propertyHold [ eqTest "p256" pR (P256.pointAdd p1 p2) + , eqTest "ecc" peR (pointP256ToECC pR) + ] i2ospScalar :: Integer -> Bytes i2ospScalar i =