diff --git a/Crypto/Internal/Builder.hs b/Crypto/Internal/Builder.hs new file mode 100644 index 0000000..d33ebfd --- /dev/null +++ b/Crypto/Internal/Builder.hs @@ -0,0 +1,50 @@ +-- | +-- Module : Crypto.Internal.Builder +-- License : BSD-style +-- Maintainer : Olivier Chéron +-- Stability : stable +-- Portability : Good +-- +-- Delaying and merging ByteArray allocations. This is similar to module +-- "Data.ByteArray.Pack" except the total length is computed automatically based +-- on what is appended. +-- +{-# LANGUAGE BangPatterns #-} +module Crypto.Internal.Builder + ( Builder + , buildAndFreeze + , builderLength + , (<+>) + , byte + , bytes + , zero + ) where + +import Data.ByteArray (ByteArray, ByteArrayAccess) +import qualified Data.ByteArray as B +import Data.Memory.PtrMethods (memSet) +import Data.Word (Word8) + +import Foreign.Ptr (Ptr, plusPtr) +import Foreign.Storable (poke) + +data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer + +(<+>) :: Builder -> Builder -> Builder +(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f + where f p = f1 p >> f2 (p `plusPtr` s1) + +builderLength :: Builder -> Int +builderLength (Builder s _) = s + +buildAndFreeze :: ByteArray ba => Builder -> ba +buildAndFreeze (Builder s f) = B.allocAndFreeze s f + +byte :: Word8 -> Builder +byte !b = Builder 1 (`poke` b) + +bytes :: ByteArrayAccess ba => ba -> Builder +bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs) + +zero :: Int -> Builder +zero s = Builder s (\p -> memSet p 0 s) diff --git a/Crypto/MAC/KMAC.hs b/Crypto/MAC/KMAC.hs index f07e9e9..def8b98 100644 --- a/Crypto/MAC/KMAC.hs +++ b/Crypto/MAC/KMAC.hs @@ -27,13 +27,11 @@ import qualified Crypto.Hash as H import Crypto.Hash.SHAKE (HashSHAKE(..)) import Crypto.Hash.Types (HashAlgorithm(..), Digest(..)) import qualified Crypto.Hash.Types as H -import Foreign.Ptr (Ptr, plusPtr) -import Foreign.Storable (poke) +import Crypto.Internal.Builder +import Foreign.Ptr (Ptr) import Data.Bits (shiftR) -import Data.ByteArray (ByteArray, ByteArrayAccess) +import Data.ByteArray (ByteArrayAccess) import qualified Data.ByteArray as B -import Data.Word (Word8) -import Data.Memory.PtrMethods (memSet) -- cSHAKE @@ -48,7 +46,7 @@ cshakeInit n s p = H.Context $ B.allocAndFreeze c $ \(ptr :: Ptr (H.Context a)) c = hashInternalContextSize (undefined :: a) w = hashBlockSize (undefined :: a) x = encodeString n <+> encodeString s - b = builderAllocAndFreeze (bytepad x w) :: B.Bytes + b = buildAndFreeze (bytepad x w) :: B.Bytes cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba) => H.Context a -> ba -> H.Context a @@ -99,7 +97,7 @@ initialize str key = Context $ cshakeInit n str p where n = B.pack [75,77,65,67] :: B.Bytes -- "KMAC" w = hashBlockSize (undefined :: a) - p = builderAllocAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes + p = buildAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes -- | Incrementally update a KMAC context. update :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> ba -> Context a @@ -114,7 +112,7 @@ finalize :: forall a . HashSHAKE a => Context a -> KMAC a finalize (Context ctx) = KMAC $ cshakeFinalize ctx suffix where l = cshakeOutputLength (undefined :: a) - suffix = builderAllocAndFreeze (rightEncode l) :: B.Bytes + suffix = buildAndFreeze (rightEncode l) :: B.Bytes -- Utilities @@ -143,27 +141,3 @@ rightEncode x = digits <+> byte len i2osp :: Int -> Builder i2osp i | i >= 256 = i2osp (shiftR i 8) <+> byte (fromIntegral i) | otherwise = byte (fromIntegral i) - - --- Delaying and merging ByteArray allocations - -data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer - -(<+>) :: Builder -> Builder -> Builder -(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f - where f p = f1 p >> f2 (p `plusPtr` s1) - -builderLength :: Builder -> Int -builderLength (Builder s _) = s - -builderAllocAndFreeze :: ByteArray ba => Builder -> ba -builderAllocAndFreeze (Builder s f) = B.allocAndFreeze s f - -byte :: Word8 -> Builder -byte !b = Builder 1 (`poke` b) - -bytes :: ByteArrayAccess ba => ba -> Builder -bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs) - -zero :: Int -> Builder -zero s = Builder s (\p -> memSet p 0 s) diff --git a/Crypto/PubKey/EdDSA.hs b/Crypto/PubKey/EdDSA.hs index f0fd8ec..67b733c 100644 --- a/Crypto/PubKey/EdDSA.hs +++ b/Crypto/PubKey/EdDSA.hs @@ -48,7 +48,7 @@ module Crypto.PubKey.EdDSA ) where import Data.Bits -import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes) +import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes, View) import qualified Data.ByteArray as B import Data.ByteString (ByteString) import Data.Proxy @@ -62,6 +62,7 @@ import Crypto.Random import GHC.TypeLits (KnownNat, Nat) +import Crypto.Internal.Builder import Crypto.Internal.Compat import Crypto.Internal.Imports import Crypto.Internal.Nat (integralNatVal) @@ -96,7 +97,7 @@ class ( EllipticCurveBasepointArith curve -- hash with specified parameters hashWithDom :: (HashAlgorithm hash, ByteArrayAccess ctx, ByteArrayAccess msg) - => proxy curve -> hash -> Bool -> ctx -> [Bytes] -> msg -> Bytes + => proxy curve -> hash -> Bool -> ctx -> Builder -> msg -> Bytes -- conversion between scalar, point and public key pointPublic :: proxy curve -> Point curve -> PublicKey curve hash @@ -111,7 +112,7 @@ class ( EllipticCurveBasepointArith curve => proxy curve -> hash -> SecretKey curve - -> (Scalar curve, Bytes) + -> (Scalar curve, View Bytes) -- | Size of public keys for this curve (in bytes) publicKeySize :: EllipticCurveEdDSA curve => proxy curve -> Int @@ -255,7 +256,7 @@ signPhCtx :: forall proxy curve hash ctx msg . signPhCtx prx ph ctx priv pub msg = let alg = undefined :: hash (s, prefix) = scheduleSecret prx alg priv - digR = hashWithDom prx alg ph ctx [prefix] msg + digR = hashWithDom prx alg ph ctx (bytes prefix) msg r = decodeScalarNoErr prx digR pR = pointBaseSmul prx r bsR = encodePoint prx pR @@ -295,19 +296,18 @@ getK :: forall proxy curve hash ctx msg . => proxy curve -> Bool -> ctx -> PublicKey curve hash -> Bytes -> msg -> Scalar curve getK prx ph ctx (PublicKey pub) bsR msg = let alg = undefined :: hash - digK = hashWithDom prx alg ph ctx [bsR, pub] msg + digK = hashWithDom prx alg ph ctx (bytes bsR <+> bytes pub) msg in decodeScalarNoErr prx digK encodeSignature :: EllipticCurveEdDSA curve => proxy curve -> (Bytes, Point curve, Scalar curve) -> Signature curve hash -encodeSignature prx (bsR, _, sS) = Signature $ - if len0 > 0 then B.concat [ bsR, bsS, pad0 ] else B.append bsR bsS +encodeSignature prx (bsR, _, sS) = Signature $ buildAndFreeze $ + bytes bsR <+> bytes bsS <+> zero len0 where - bsS = encodeScalarLE prx sS + bsS = encodeScalarLE prx sS :: Bytes len0 = signatureSize prx - B.length bsR - B.length bsS - pad0 = B.zero len0 decodeSignature :: ( EllipticCurveEdDSA curve , HashDigestSize hash ~ CurveDigestSize curve @@ -339,12 +339,11 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where hashWithDom _ alg ph ctx bss | not ph && B.null ctx = digestDomMsg alg bss - | otherwise = digestDomMsg alg (bs:bss) - where bs = B.concat [ "SigEd25519 no Ed25519 collisions" :: ByteString - , B.singleton $ if ph then 1 else 0 - , B.singleton $ fromIntegral $ B.length ctx - , B.convert ctx - ] + | otherwise = digestDomMsg alg (dom <+> bss) + where dom = bytes ("SigEd25519 no Ed25519 collisions" :: ByteString) <+> + byte (if ph then 1 else 0) <+> + byte (fromIntegral $ B.length ctx) <+> + bytes ctx pointPublic _ = PublicKey . Edwards25519.pointEncode publicPoint _ = Edwards25519.pointDecode @@ -352,7 +351,7 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where decodeScalarLE _ = Edwards25519.scalarDecodeLong scheduleSecret prx alg priv = - (decodeScalarNoErr prx clamped, B.drop 32 hashed) + (decodeScalarNoErr prx clamped, B.dropView hashed 32) where hashed = digest alg ($ priv) @@ -377,9 +376,9 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where -} digestDomMsg :: (HashAlgorithm alg, ByteArrayAccess msg) - => alg -> [Bytes] -> msg -> Bytes + => alg -> Builder -> msg -> Bytes digestDomMsg alg bss bs = digest alg $ \update -> - update (B.concat bss :: Bytes) >> update bs + update (buildAndFreeze bss :: Bytes) >> update bs digest :: HashAlgorithm alg => alg diff --git a/cryptonite.cabal b/cryptonite.cabal index 245d8c2..619c0f3 100644 --- a/cryptonite.cabal +++ b/cryptonite.cabal @@ -230,6 +230,7 @@ Library Crypto.PubKey.ElGamal Crypto.ECC.Simple.Types Crypto.ECC.Simple.Prim + Crypto.Internal.Builder Crypto.Internal.ByteArray Crypto.Internal.Compat Crypto.Internal.CompatPrim