Merge pull request #274 from ocheron/p256-add-sub
Improve P256.scalarAdd and P256.scalarSub
This commit is contained in:
commit
c9f8dac6b0
@ -45,7 +45,6 @@ module Crypto.PubKey.ECC.P256
|
||||
import Data.Word
|
||||
import Foreign.Ptr
|
||||
import Foreign.C.Types
|
||||
import Control.Monad
|
||||
|
||||
import Crypto.Internal.Compat
|
||||
import Crypto.Internal.Imports
|
||||
@ -222,34 +221,21 @@ scalarIsZero s = unsafeDoIO $ withScalar s $ \d -> do
|
||||
result <- ccryptonite_p256_is_zero d
|
||||
return $ result /= 0
|
||||
|
||||
scalarNeedReducing :: Ptr P256Scalar -> IO Bool
|
||||
scalarNeedReducing d = do
|
||||
c <- ccryptonite_p256_cmp d ccryptonite_SECP256r1_n
|
||||
return (c >= 0)
|
||||
|
||||
-- | Perform addition between two scalars
|
||||
--
|
||||
-- > a + b
|
||||
scalarAdd :: Scalar -> Scalar -> Scalar
|
||||
scalarAdd a b =
|
||||
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do
|
||||
carry <- ccryptonite_p256_add pa pb d
|
||||
when (carry /= 0) $ void $ ccryptonite_p256_sub d ccryptonite_SECP256r1_n d
|
||||
needReducing <- scalarNeedReducing d
|
||||
when needReducing $ do
|
||||
ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
|
||||
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb ->
|
||||
ccryptonite_p256e_modadd ccryptonite_SECP256r1_n pa pb d
|
||||
|
||||
-- | Perform subtraction between two scalars
|
||||
--
|
||||
-- > a - b
|
||||
scalarSub :: Scalar -> Scalar -> Scalar
|
||||
scalarSub a b =
|
||||
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do
|
||||
borrow <- ccryptonite_p256_sub pa pb d
|
||||
when (borrow /= 0) $ void $ ccryptonite_p256_add d ccryptonite_SECP256r1_n d
|
||||
--needReducing <- scalarNeedReducing d
|
||||
--when needReducing $ do
|
||||
-- ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
|
||||
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb ->
|
||||
ccryptonite_p256e_modsub ccryptonite_SECP256r1_n pa pb d
|
||||
|
||||
-- | Give the inverse of the scalar
|
||||
--
|
||||
@ -352,12 +338,12 @@ foreign import ccall "cryptonite_p256_is_zero"
|
||||
ccryptonite_p256_is_zero :: Ptr P256Scalar -> IO CInt
|
||||
foreign import ccall "cryptonite_p256_clear"
|
||||
ccryptonite_p256_clear :: Ptr P256Scalar -> IO ()
|
||||
foreign import ccall "cryptonite_p256_add"
|
||||
ccryptonite_p256_add :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
|
||||
foreign import ccall "cryptonite_p256e_modadd"
|
||||
ccryptonite_p256e_modadd :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
|
||||
foreign import ccall "cryptonite_p256_add_d"
|
||||
ccryptonite_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt
|
||||
foreign import ccall "cryptonite_p256_sub"
|
||||
ccryptonite_p256_sub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
|
||||
foreign import ccall "cryptonite_p256e_modsub"
|
||||
ccryptonite_p256e_modsub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
|
||||
foreign import ccall "cryptonite_p256_cmp"
|
||||
ccryptonite_p256_cmp :: Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
|
||||
foreign import ccall "cryptonite_p256_mod"
|
||||
|
||||
@ -386,3 +386,25 @@ void cryptonite_p256_to_bin(const cryptonite_p256_int* src, uint8_t dst[P256_NBY
|
||||
p += 4;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
"p256e" functions are not part of the original source
|
||||
*/
|
||||
|
||||
#define MSB_COMPLEMENT(x) (((x) >> (P256_BITSPERDIGIT - 1)) - 1)
|
||||
|
||||
// c = a + b mod MOD
|
||||
void cryptonite_p256e_modadd(const cryptonite_p256_int* MOD, const cryptonite_p256_int* a, const cryptonite_p256_int* b, cryptonite_p256_int* c) {
|
||||
cryptonite_p256_digit top = cryptonite_p256_add(a, b, c);
|
||||
top = subM(MOD, top, P256_DIGITS(c), -1);
|
||||
top = subM(MOD, top, P256_DIGITS(c), MSB_COMPLEMENT(top));
|
||||
addM(MOD, 0, P256_DIGITS(c), top);
|
||||
}
|
||||
|
||||
// c = a - b mod MOD
|
||||
void cryptonite_p256e_modsub(const cryptonite_p256_int* MOD, const cryptonite_p256_int* a, const cryptonite_p256_int* b, cryptonite_p256_int* c) {
|
||||
cryptonite_p256_digit top = cryptonite_p256_sub(a, b, c);
|
||||
top = addM(MOD, top, P256_DIGITS(c), ~MSB_COMPLEMENT(top));
|
||||
top = subM(MOD, top, P256_DIGITS(c), MSB_COMPLEMENT(top));
|
||||
addM(MOD, 0, P256_DIGITS(c), top);
|
||||
}
|
||||
|
||||
@ -17,7 +17,19 @@ newtype P256Scalar = P256Scalar Integer
|
||||
deriving (Show,Eq,Ord)
|
||||
|
||||
instance Arbitrary P256Scalar where
|
||||
arbitrary = P256Scalar . getQAInteger <$> arbitrary
|
||||
-- Cover the full range up to 2^256-1 except 0 and curveN. To test edge
|
||||
-- cases with arithmetic functions, some values close to 0, curveN and
|
||||
-- 2^256 are given higher frequency.
|
||||
arbitrary = P256Scalar <$> oneof
|
||||
[ choose (1, w)
|
||||
, choose (w + 1, curveN - w - 1)
|
||||
, choose (curveN - w, curveN - 1)
|
||||
, choose (curveN + 1, curveN + w)
|
||||
, choose (curveN + w + 1, high - w - 1)
|
||||
, choose (high - w, high - 1)
|
||||
]
|
||||
where high = 2^(256 :: Int)
|
||||
w = 100
|
||||
|
||||
curve = ECC.getCurveByName ECC.SEC_p256r1
|
||||
curveN = ECC.ecc_n . ECC.common_curve $ curve
|
||||
@ -26,22 +38,21 @@ curveGen = ECC.ecc_g . ECC.common_curve $ curve
|
||||
pointP256ToECC :: P256.Point -> ECC.Point
|
||||
pointP256ToECC = uncurry ECC.Point . P256.pointToIntegers
|
||||
|
||||
i2ospScalar :: Integer -> Bytes
|
||||
i2ospScalar i =
|
||||
case i2ospOf 32 i of
|
||||
Nothing -> error "invalid size of P256 scalar"
|
||||
Just b -> b
|
||||
|
||||
unP256Scalar :: P256Scalar -> P256.Scalar
|
||||
unP256Scalar (P256Scalar r') =
|
||||
let r = if r' == 0 then 0x2901 else (r' `mod` curveN)
|
||||
rBytes = i2ospScalar r
|
||||
unP256Scalar (P256Scalar r) =
|
||||
let rBytes = i2ospScalar r
|
||||
in case P256.scalarFromBinary rBytes of
|
||||
CryptoFailed err -> error ("cannot convert scalar: " ++ show err)
|
||||
CryptoPassed scalar -> scalar
|
||||
where
|
||||
i2ospScalar :: Integer -> Bytes
|
||||
i2ospScalar i =
|
||||
case i2ospOf 32 i of
|
||||
Nothing -> error "invalid size of P256 scalar"
|
||||
Just b -> b
|
||||
|
||||
unP256 :: P256Scalar -> Integer
|
||||
unP256 (P256Scalar r') = if r' == 0 then 0x2901 else (r' `mod` curveN)
|
||||
unP256 (P256Scalar r) = r
|
||||
|
||||
p256ScalarToInteger :: P256.Scalar -> Integer
|
||||
p256ScalarToInteger s = os2ip (P256.scalarToBinary s :: Bytes)
|
||||
@ -55,9 +66,8 @@ yR = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264
|
||||
|
||||
tests = testGroup "P256"
|
||||
[ testGroup "scalar"
|
||||
[ testProperty "marshalling" $ \(QAInteger r') ->
|
||||
let r = r' `mod` curveN
|
||||
rBytes = i2ospScalar r
|
||||
[ testProperty "marshalling" $ \(QAInteger r) ->
|
||||
let rBytes = i2ospScalar r
|
||||
in case P256.scalarFromBinary rBytes of
|
||||
CryptoFailed err -> error (show err)
|
||||
CryptoPassed scalar -> rBytes `propertyEq` P256.scalarToBinary scalar
|
||||
@ -66,14 +76,9 @@ tests = testGroup "P256"
|
||||
r' = P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2)
|
||||
in r `propertyEq` p256ScalarToInteger r'
|
||||
, testProperty "add0" $ \r ->
|
||||
let v = unP256 r
|
||||
let v = unP256 r `mod` curveN
|
||||
v' = P256.scalarAdd (unP256Scalar r) P256.scalarZero
|
||||
in v `propertyEq` p256ScalarToInteger v'
|
||||
, testProperty "add-n-1" $ \r ->
|
||||
let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1)
|
||||
v = unP256 r
|
||||
v' = P256.scalarAdd (unP256Scalar r) nm1
|
||||
in (((curveN - 1) + v) `mod` curveN) `propertyEq` p256ScalarToInteger v'
|
||||
, testProperty "sub" $ \r1 r2 ->
|
||||
let r = (unP256 r1 - unP256 r2) `mod` curveN
|
||||
r' = P256.scalarSub (unP256Scalar r1) (unP256Scalar r2)
|
||||
@ -83,11 +88,10 @@ tests = testGroup "P256"
|
||||
[ eqTest "r1-r2" r (p256ScalarToInteger r')
|
||||
, eqTest "r2-r1" v (p256ScalarToInteger v')
|
||||
]
|
||||
, testProperty "sub-n-1" $ \r ->
|
||||
let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1)
|
||||
v = unP256 r
|
||||
v' = P256.scalarSub (unP256Scalar r) nm1
|
||||
in ((v - (curveN - 1)) `mod` curveN) `propertyEq` p256ScalarToInteger v'
|
||||
, testProperty "sub0" $ \r ->
|
||||
let v = unP256 r `mod` curveN
|
||||
v' = P256.scalarSub (unP256Scalar r) P256.scalarZero
|
||||
in v `propertyEq` p256ScalarToInteger v'
|
||||
, testProperty "inv" $ \r' ->
|
||||
let inv = inverseCoprimes (unP256 r') curveN
|
||||
inv' = P256.scalarInv (unP256Scalar r')
|
||||
@ -133,7 +137,8 @@ tests = testGroup "P256"
|
||||
pe2 = ECC.pointMul curve (unP256 r2) curveGen
|
||||
pR = P256.toPoint (P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2))
|
||||
peR = ECC.pointAdd curve pe1 pe2
|
||||
in propertyHold [ eqTest "p256" pR (P256.pointAdd p1 p2)
|
||||
in (unP256 r1 + unP256 r2) `mod` curveN /= 0 ==>
|
||||
propertyHold [ eqTest "p256" pR (P256.pointAdd p1 p2)
|
||||
, eqTest "ecc" peR (pointP256ToECC pR)
|
||||
]
|
||||
|
||||
@ -142,9 +147,3 @@ tests = testGroup "P256"
|
||||
pe = ECC.pointMul curve (unP256 r) curveGen
|
||||
pR = P256.pointNegate p
|
||||
in ECC.pointNegate curve pe `propertyEq` (pointP256ToECC pR)
|
||||
|
||||
i2ospScalar :: Integer -> Bytes
|
||||
i2ospScalar i =
|
||||
case i2ospOf 32 i of
|
||||
Nothing -> error "invalid size of P256 scalar"
|
||||
Just b -> b
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
module Utils where
|
||||
|
||||
import Control.Applicative
|
||||
import Control.Monad (replicateM)
|
||||
import Data.Char
|
||||
import Data.Word
|
||||
import Data.List
|
||||
@ -28,13 +27,13 @@ newtype ChunkingLen = ChunkingLen [Int]
|
||||
deriving (Show,Eq)
|
||||
|
||||
instance Arbitrary ChunkingLen where
|
||||
arbitrary = ChunkingLen `fmap` replicateM 16 (choose (0,14))
|
||||
arbitrary = ChunkingLen `fmap` vectorOf 16 (choose (0,14))
|
||||
|
||||
newtype ChunkingLen0_127 = ChunkingLen0_127 [Int]
|
||||
deriving (Show,Eq)
|
||||
|
||||
instance Arbitrary ChunkingLen0_127 where
|
||||
arbitrary = ChunkingLen0_127 `fmap` replicateM 16 (choose (0,127))
|
||||
arbitrary = ChunkingLen0_127 `fmap` vectorOf 16 (choose (0,127))
|
||||
|
||||
|
||||
newtype ArbitraryBS0_2901 = ArbitraryBS0_2901 ByteString
|
||||
@ -63,7 +62,7 @@ instance Arbitrary QAInteger where
|
||||
arbitrary = oneof
|
||||
[ QAInteger . fromIntegral <$> (choose (0, 65536) :: Gen Int) -- small integer
|
||||
, larger <$> choose (0,4096) <*> choose (0, 65536) -- medium integer
|
||||
, QAInteger . os2ip . B.pack <$> (choose (0,32) >>= \n -> replicateM n arbitrary) -- [ 0 .. 2^32 ] sized integer
|
||||
, QAInteger . os2ip <$> arbitraryBSof 0 32 -- [ 0 .. 2^32 ] sized integer
|
||||
]
|
||||
where
|
||||
larger :: Int -> Int -> QAInteger
|
||||
@ -73,10 +72,10 @@ instance Arbitrary QAInteger where
|
||||
somePrime = 18446744073709551557
|
||||
|
||||
arbitraryBS :: Int -> Gen ByteString
|
||||
arbitraryBS n = B.pack `fmap` replicateM n arbitrary
|
||||
arbitraryBS = fmap B.pack . vector
|
||||
|
||||
arbitraryBSof :: Int -> Int -> Gen ByteString
|
||||
arbitraryBSof minSize maxSize = choose (minSize, maxSize) >>= \n -> (B.pack `fmap` replicateM n arbitrary)
|
||||
arbitraryBSof minSize maxSize = choose (minSize, maxSize) >>= arbitraryBS
|
||||
|
||||
chunkS :: ChunkingLen -> ByteString -> [ByteString]
|
||||
chunkS (ChunkingLen originalChunks) = loop originalChunks
|
||||
|
||||
Loading…
Reference in New Issue
Block a user