diff --git a/Crypto/ECC/Ed25519.hs b/Crypto/ECC/Ed25519.hs index aea6f69..cf9fcd6 100644 --- a/Crypto/ECC/Ed25519.hs +++ b/Crypto/ECC/Ed25519.hs @@ -20,6 +20,8 @@ module Crypto.ECC.Ed25519 , pointEncode -- * Arithmetic functions , toPoint + , scalarAdd + , scalarMul , pointNegate , pointAdd , pointDouble @@ -119,6 +121,22 @@ scalarDecodeLong bs return $ CryptoPassed (Scalar s) {-# NOINLINE scalarDecodeLong #-} +-- | Add two scalars. +scalarAdd :: Scalar -> Scalar -> Scalar +scalarAdd (Scalar a) (Scalar b) = + Scalar $ B.allocAndFreeze scalarArraySize $ \out -> + withByteArray a $ \pa -> + withByteArray b $ \pb -> + ed25519_scalar_add out pa pb + +-- | Multiply two scalars. +scalarMul :: Scalar -> Scalar -> Scalar +scalarMul (Scalar a) (Scalar b) = + Scalar $ B.allocAndFreeze scalarArraySize $ \out -> + withByteArray a $ \pa -> + withByteArray b $ \pb -> + ed25519_scalar_mul out pa pb + -- | Multiplies a scalar with the curve base point. toPoint :: Scalar -> Point toPoint (Scalar scalar) = @@ -202,6 +220,18 @@ foreign import ccall "cryptonite_ed25519_scalar_decode_long" -> CSize -> IO () +foreign import ccall "cryptonite_ed25519_scalar_add" + ed25519_scalar_add :: Ptr Scalar -- sum + -> Ptr Scalar -- a + -> Ptr Scalar -- b + -> IO () + +foreign import ccall "cryptonite_ed25519_scalar_mul" + ed25519_scalar_mul :: Ptr Scalar -- out + -> Ptr Scalar -- a + -> Ptr Scalar -- b + -> IO () + foreign import ccall "cryptonite_ed25519_point_encode" ed25519_point_encode :: Ptr Word8 -> Ptr Point diff --git a/cbits/ed25519/ed25519-cryptonite-exts.h b/cbits/ed25519/ed25519-cryptonite-exts.h index 3c7fbd8..4bca444 100644 --- a/cbits/ed25519/ed25519-cryptonite-exts.h +++ b/cbits/ed25519/ed25519-cryptonite-exts.h @@ -30,6 +30,16 @@ ED25519_FN(ed25519_scalar_eq) (const bignum256modm a, const bignum256modm b) { return (int) (1 & ((e - 1) >> bignum256modm_bits_per_limb)); } +void +ED25519_FN(ed25519_scalar_add) (bignum256modm r, const bignum256modm x, const bignum256modm y) { + add256_modm(r, x, y); +} + +void +ED25519_FN(ed25519_scalar_mul) (bignum256modm r, const bignum256modm x, const bignum256modm y) { + mul256_modm(r, x, y); +} + /* Point functions