Add and use Builder module

Avoids intermediate allocations and conversions when concatenating
byte arrays of different types.
This commit is contained in:
Olivier Chéron 2020-02-09 13:41:37 +01:00
parent ef880291e3
commit b01f610aa2
4 changed files with 74 additions and 50 deletions

View File

@ -0,0 +1,50 @@
-- |
-- Module : Crypto.Internal.Builder
-- License : BSD-style
-- Maintainer : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability : stable
-- Portability : Good
--
-- Delaying and merging ByteArray allocations. This is similar to module
-- "Data.ByteArray.Pack" except the total length is computed automatically based
-- on what is appended.
--
{-# LANGUAGE BangPatterns #-}
module Crypto.Internal.Builder
( Builder
, buildAndFreeze
, builderLength
, (<+>)
, byte
, bytes
, zero
) where
import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as B
import Data.Memory.PtrMethods (memSet)
import Data.Word (Word8)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke)
data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer
(<+>) :: Builder -> Builder -> Builder
(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f
where f p = f1 p >> f2 (p `plusPtr` s1)
builderLength :: Builder -> Int
builderLength (Builder s _) = s
buildAndFreeze :: ByteArray ba => Builder -> ba
buildAndFreeze (Builder s f) = B.allocAndFreeze s f
byte :: Word8 -> Builder
byte !b = Builder 1 (`poke` b)
bytes :: ByteArrayAccess ba => ba -> Builder
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)
zero :: Int -> Builder
zero s = Builder s (\p -> memSet p 0 s)

View File

@ -27,13 +27,11 @@ import qualified Crypto.Hash as H
import Crypto.Hash.SHAKE (HashSHAKE(..))
import Crypto.Hash.Types (HashAlgorithm(..), Digest(..))
import qualified Crypto.Hash.Types as H
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke)
import Crypto.Internal.Builder
import Foreign.Ptr (Ptr)
import Data.Bits (shiftR)
import Data.ByteArray (ByteArray, ByteArrayAccess)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as B
import Data.Word (Word8)
import Data.Memory.PtrMethods (memSet)
-- cSHAKE
@ -48,7 +46,7 @@ cshakeInit n s p = H.Context $ B.allocAndFreeze c $ \(ptr :: Ptr (H.Context a))
c = hashInternalContextSize (undefined :: a)
w = hashBlockSize (undefined :: a)
x = encodeString n <+> encodeString s
b = builderAllocAndFreeze (bytepad x w) :: B.Bytes
b = buildAndFreeze (bytepad x w) :: B.Bytes
cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba)
=> H.Context a -> ba -> H.Context a
@ -99,7 +97,7 @@ initialize str key = Context $ cshakeInit n str p
where
n = B.pack [75,77,65,67] :: B.Bytes -- "KMAC"
w = hashBlockSize (undefined :: a)
p = builderAllocAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes
p = buildAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes
-- | Incrementally update a KMAC context.
update :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> ba -> Context a
@ -114,7 +112,7 @@ finalize :: forall a . HashSHAKE a => Context a -> KMAC a
finalize (Context ctx) = KMAC $ cshakeFinalize ctx suffix
where
l = cshakeOutputLength (undefined :: a)
suffix = builderAllocAndFreeze (rightEncode l) :: B.Bytes
suffix = buildAndFreeze (rightEncode l) :: B.Bytes
-- Utilities
@ -143,27 +141,3 @@ rightEncode x = digits <+> byte len
i2osp :: Int -> Builder
i2osp i | i >= 256 = i2osp (shiftR i 8) <+> byte (fromIntegral i)
| otherwise = byte (fromIntegral i)
-- Delaying and merging ByteArray allocations
data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer
(<+>) :: Builder -> Builder -> Builder
(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f
where f p = f1 p >> f2 (p `plusPtr` s1)
builderLength :: Builder -> Int
builderLength (Builder s _) = s
builderAllocAndFreeze :: ByteArray ba => Builder -> ba
builderAllocAndFreeze (Builder s f) = B.allocAndFreeze s f
byte :: Word8 -> Builder
byte !b = Builder 1 (`poke` b)
bytes :: ByteArrayAccess ba => ba -> Builder
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)
zero :: Int -> Builder
zero s = Builder s (\p -> memSet p 0 s)

View File

@ -48,7 +48,7 @@ module Crypto.PubKey.EdDSA
) where
import Data.Bits
import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes)
import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes, View)
import qualified Data.ByteArray as B
import Data.ByteString (ByteString)
import Data.Proxy
@ -62,6 +62,7 @@ import Crypto.Random
import GHC.TypeLits (KnownNat, Nat)
import Crypto.Internal.Builder
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Crypto.Internal.Nat (integralNatVal)
@ -96,7 +97,7 @@ class ( EllipticCurveBasepointArith curve
-- hash with specified parameters
hashWithDom :: (HashAlgorithm hash, ByteArrayAccess ctx, ByteArrayAccess msg)
=> proxy curve -> hash -> Bool -> ctx -> [Bytes] -> msg -> Bytes
=> proxy curve -> hash -> Bool -> ctx -> Builder -> msg -> Bytes
-- conversion between scalar, point and public key
pointPublic :: proxy curve -> Point curve -> PublicKey curve hash
@ -111,7 +112,7 @@ class ( EllipticCurveBasepointArith curve
=> proxy curve
-> hash
-> SecretKey curve
-> (Scalar curve, Bytes)
-> (Scalar curve, View Bytes)
-- | Size of public keys for this curve (in bytes)
publicKeySize :: EllipticCurveEdDSA curve => proxy curve -> Int
@ -255,7 +256,7 @@ signPhCtx :: forall proxy curve hash ctx msg .
signPhCtx prx ph ctx priv pub msg =
let alg = undefined :: hash
(s, prefix) = scheduleSecret prx alg priv
digR = hashWithDom prx alg ph ctx [prefix] msg
digR = hashWithDom prx alg ph ctx (bytes prefix) msg
r = decodeScalarNoErr prx digR
pR = pointBaseSmul prx r
bsR = encodePoint prx pR
@ -295,19 +296,18 @@ getK :: forall proxy curve hash ctx msg .
=> proxy curve -> Bool -> ctx -> PublicKey curve hash -> Bytes -> msg -> Scalar curve
getK prx ph ctx (PublicKey pub) bsR msg =
let alg = undefined :: hash
digK = hashWithDom prx alg ph ctx [bsR, pub] msg
digK = hashWithDom prx alg ph ctx (bytes bsR <+> bytes pub) msg
in decodeScalarNoErr prx digK
encodeSignature :: EllipticCurveEdDSA curve
=> proxy curve
-> (Bytes, Point curve, Scalar curve)
-> Signature curve hash
encodeSignature prx (bsR, _, sS) = Signature $
if len0 > 0 then B.concat [ bsR, bsS, pad0 ] else B.append bsR bsS
encodeSignature prx (bsR, _, sS) = Signature $ buildAndFreeze $
bytes bsR <+> bytes bsS <+> zero len0
where
bsS = encodeScalarLE prx sS
bsS = encodeScalarLE prx sS :: Bytes
len0 = signatureSize prx - B.length bsR - B.length bsS
pad0 = B.zero len0
decodeSignature :: ( EllipticCurveEdDSA curve
, HashDigestSize hash ~ CurveDigestSize curve
@ -339,12 +339,11 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
hashWithDom _ alg ph ctx bss
| not ph && B.null ctx = digestDomMsg alg bss
| otherwise = digestDomMsg alg (bs:bss)
where bs = B.concat [ "SigEd25519 no Ed25519 collisions" :: ByteString
, B.singleton $ if ph then 1 else 0
, B.singleton $ fromIntegral $ B.length ctx
, B.convert ctx
]
| otherwise = digestDomMsg alg (dom <+> bss)
where dom = bytes ("SigEd25519 no Ed25519 collisions" :: ByteString) <+>
byte (if ph then 1 else 0) <+>
byte (fromIntegral $ B.length ctx) <+>
bytes ctx
pointPublic _ = PublicKey . Edwards25519.pointEncode
publicPoint _ = Edwards25519.pointDecode
@ -352,7 +351,7 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
decodeScalarLE _ = Edwards25519.scalarDecodeLong
scheduleSecret prx alg priv =
(decodeScalarNoErr prx clamped, B.drop 32 hashed)
(decodeScalarNoErr prx clamped, B.dropView hashed 32)
where
hashed = digest alg ($ priv)
@ -377,9 +376,9 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
-}
digestDomMsg :: (HashAlgorithm alg, ByteArrayAccess msg)
=> alg -> [Bytes] -> msg -> Bytes
=> alg -> Builder -> msg -> Bytes
digestDomMsg alg bss bs = digest alg $ \update ->
update (B.concat bss :: Bytes) >> update bs
update (buildAndFreeze bss :: Bytes) >> update bs
digest :: HashAlgorithm alg
=> alg

View File

@ -230,6 +230,7 @@ Library
Crypto.PubKey.ElGamal
Crypto.ECC.Simple.Types
Crypto.ECC.Simple.Prim
Crypto.Internal.Builder
Crypto.Internal.ByteArray
Crypto.Internal.Compat
Crypto.Internal.CompatPrim