Add and use Builder module
Avoids intermediate allocations and conversions when concatenating byte arrays of different types.
This commit is contained in:
parent
ef880291e3
commit
b01f610aa2
50
Crypto/Internal/Builder.hs
Normal file
50
Crypto/Internal/Builder.hs
Normal file
@ -0,0 +1,50 @@
|
||||
-- |
|
||||
-- Module : Crypto.Internal.Builder
|
||||
-- License : BSD-style
|
||||
-- Maintainer : Olivier Chéron <olivier.cheron@gmail.com>
|
||||
-- Stability : stable
|
||||
-- Portability : Good
|
||||
--
|
||||
-- Delaying and merging ByteArray allocations. This is similar to module
|
||||
-- "Data.ByteArray.Pack" except the total length is computed automatically based
|
||||
-- on what is appended.
|
||||
--
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
module Crypto.Internal.Builder
|
||||
( Builder
|
||||
, buildAndFreeze
|
||||
, builderLength
|
||||
, (<+>)
|
||||
, byte
|
||||
, bytes
|
||||
, zero
|
||||
) where
|
||||
|
||||
import Data.ByteArray (ByteArray, ByteArrayAccess)
|
||||
import qualified Data.ByteArray as B
|
||||
import Data.Memory.PtrMethods (memSet)
|
||||
import Data.Word (Word8)
|
||||
|
||||
import Foreign.Ptr (Ptr, plusPtr)
|
||||
import Foreign.Storable (poke)
|
||||
|
||||
data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer
|
||||
|
||||
(<+>) :: Builder -> Builder -> Builder
|
||||
(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f
|
||||
where f p = f1 p >> f2 (p `plusPtr` s1)
|
||||
|
||||
builderLength :: Builder -> Int
|
||||
builderLength (Builder s _) = s
|
||||
|
||||
buildAndFreeze :: ByteArray ba => Builder -> ba
|
||||
buildAndFreeze (Builder s f) = B.allocAndFreeze s f
|
||||
|
||||
byte :: Word8 -> Builder
|
||||
byte !b = Builder 1 (`poke` b)
|
||||
|
||||
bytes :: ByteArrayAccess ba => ba -> Builder
|
||||
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)
|
||||
|
||||
zero :: Int -> Builder
|
||||
zero s = Builder s (\p -> memSet p 0 s)
|
||||
@ -27,13 +27,11 @@ import qualified Crypto.Hash as H
|
||||
import Crypto.Hash.SHAKE (HashSHAKE(..))
|
||||
import Crypto.Hash.Types (HashAlgorithm(..), Digest(..))
|
||||
import qualified Crypto.Hash.Types as H
|
||||
import Foreign.Ptr (Ptr, plusPtr)
|
||||
import Foreign.Storable (poke)
|
||||
import Crypto.Internal.Builder
|
||||
import Foreign.Ptr (Ptr)
|
||||
import Data.Bits (shiftR)
|
||||
import Data.ByteArray (ByteArray, ByteArrayAccess)
|
||||
import Data.ByteArray (ByteArrayAccess)
|
||||
import qualified Data.ByteArray as B
|
||||
import Data.Word (Word8)
|
||||
import Data.Memory.PtrMethods (memSet)
|
||||
|
||||
|
||||
-- cSHAKE
|
||||
@ -48,7 +46,7 @@ cshakeInit n s p = H.Context $ B.allocAndFreeze c $ \(ptr :: Ptr (H.Context a))
|
||||
c = hashInternalContextSize (undefined :: a)
|
||||
w = hashBlockSize (undefined :: a)
|
||||
x = encodeString n <+> encodeString s
|
||||
b = builderAllocAndFreeze (bytepad x w) :: B.Bytes
|
||||
b = buildAndFreeze (bytepad x w) :: B.Bytes
|
||||
|
||||
cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba)
|
||||
=> H.Context a -> ba -> H.Context a
|
||||
@ -99,7 +97,7 @@ initialize str key = Context $ cshakeInit n str p
|
||||
where
|
||||
n = B.pack [75,77,65,67] :: B.Bytes -- "KMAC"
|
||||
w = hashBlockSize (undefined :: a)
|
||||
p = builderAllocAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes
|
||||
p = buildAndFreeze (bytepad (encodeString key) w) :: B.ScrubbedBytes
|
||||
|
||||
-- | Incrementally update a KMAC context.
|
||||
update :: (HashSHAKE a, ByteArrayAccess ba) => Context a -> ba -> Context a
|
||||
@ -114,7 +112,7 @@ finalize :: forall a . HashSHAKE a => Context a -> KMAC a
|
||||
finalize (Context ctx) = KMAC $ cshakeFinalize ctx suffix
|
||||
where
|
||||
l = cshakeOutputLength (undefined :: a)
|
||||
suffix = builderAllocAndFreeze (rightEncode l) :: B.Bytes
|
||||
suffix = buildAndFreeze (rightEncode l) :: B.Bytes
|
||||
|
||||
|
||||
-- Utilities
|
||||
@ -143,27 +141,3 @@ rightEncode x = digits <+> byte len
|
||||
i2osp :: Int -> Builder
|
||||
i2osp i | i >= 256 = i2osp (shiftR i 8) <+> byte (fromIntegral i)
|
||||
| otherwise = byte (fromIntegral i)
|
||||
|
||||
|
||||
-- Delaying and merging ByteArray allocations
|
||||
|
||||
data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer
|
||||
|
||||
(<+>) :: Builder -> Builder -> Builder
|
||||
(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f
|
||||
where f p = f1 p >> f2 (p `plusPtr` s1)
|
||||
|
||||
builderLength :: Builder -> Int
|
||||
builderLength (Builder s _) = s
|
||||
|
||||
builderAllocAndFreeze :: ByteArray ba => Builder -> ba
|
||||
builderAllocAndFreeze (Builder s f) = B.allocAndFreeze s f
|
||||
|
||||
byte :: Word8 -> Builder
|
||||
byte !b = Builder 1 (`poke` b)
|
||||
|
||||
bytes :: ByteArrayAccess ba => ba -> Builder
|
||||
bytes bs = Builder (B.length bs) (B.copyByteArrayToPtr bs)
|
||||
|
||||
zero :: Int -> Builder
|
||||
zero s = Builder s (\p -> memSet p 0 s)
|
||||
|
||||
@ -48,7 +48,7 @@ module Crypto.PubKey.EdDSA
|
||||
) where
|
||||
|
||||
import Data.Bits
|
||||
import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes)
|
||||
import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes, View)
|
||||
import qualified Data.ByteArray as B
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Proxy
|
||||
@ -62,6 +62,7 @@ import Crypto.Random
|
||||
|
||||
import GHC.TypeLits (KnownNat, Nat)
|
||||
|
||||
import Crypto.Internal.Builder
|
||||
import Crypto.Internal.Compat
|
||||
import Crypto.Internal.Imports
|
||||
import Crypto.Internal.Nat (integralNatVal)
|
||||
@ -96,7 +97,7 @@ class ( EllipticCurveBasepointArith curve
|
||||
|
||||
-- hash with specified parameters
|
||||
hashWithDom :: (HashAlgorithm hash, ByteArrayAccess ctx, ByteArrayAccess msg)
|
||||
=> proxy curve -> hash -> Bool -> ctx -> [Bytes] -> msg -> Bytes
|
||||
=> proxy curve -> hash -> Bool -> ctx -> Builder -> msg -> Bytes
|
||||
|
||||
-- conversion between scalar, point and public key
|
||||
pointPublic :: proxy curve -> Point curve -> PublicKey curve hash
|
||||
@ -111,7 +112,7 @@ class ( EllipticCurveBasepointArith curve
|
||||
=> proxy curve
|
||||
-> hash
|
||||
-> SecretKey curve
|
||||
-> (Scalar curve, Bytes)
|
||||
-> (Scalar curve, View Bytes)
|
||||
|
||||
-- | Size of public keys for this curve (in bytes)
|
||||
publicKeySize :: EllipticCurveEdDSA curve => proxy curve -> Int
|
||||
@ -255,7 +256,7 @@ signPhCtx :: forall proxy curve hash ctx msg .
|
||||
signPhCtx prx ph ctx priv pub msg =
|
||||
let alg = undefined :: hash
|
||||
(s, prefix) = scheduleSecret prx alg priv
|
||||
digR = hashWithDom prx alg ph ctx [prefix] msg
|
||||
digR = hashWithDom prx alg ph ctx (bytes prefix) msg
|
||||
r = decodeScalarNoErr prx digR
|
||||
pR = pointBaseSmul prx r
|
||||
bsR = encodePoint prx pR
|
||||
@ -295,19 +296,18 @@ getK :: forall proxy curve hash ctx msg .
|
||||
=> proxy curve -> Bool -> ctx -> PublicKey curve hash -> Bytes -> msg -> Scalar curve
|
||||
getK prx ph ctx (PublicKey pub) bsR msg =
|
||||
let alg = undefined :: hash
|
||||
digK = hashWithDom prx alg ph ctx [bsR, pub] msg
|
||||
digK = hashWithDom prx alg ph ctx (bytes bsR <+> bytes pub) msg
|
||||
in decodeScalarNoErr prx digK
|
||||
|
||||
encodeSignature :: EllipticCurveEdDSA curve
|
||||
=> proxy curve
|
||||
-> (Bytes, Point curve, Scalar curve)
|
||||
-> Signature curve hash
|
||||
encodeSignature prx (bsR, _, sS) = Signature $
|
||||
if len0 > 0 then B.concat [ bsR, bsS, pad0 ] else B.append bsR bsS
|
||||
encodeSignature prx (bsR, _, sS) = Signature $ buildAndFreeze $
|
||||
bytes bsR <+> bytes bsS <+> zero len0
|
||||
where
|
||||
bsS = encodeScalarLE prx sS
|
||||
bsS = encodeScalarLE prx sS :: Bytes
|
||||
len0 = signatureSize prx - B.length bsR - B.length bsS
|
||||
pad0 = B.zero len0
|
||||
|
||||
decodeSignature :: ( EllipticCurveEdDSA curve
|
||||
, HashDigestSize hash ~ CurveDigestSize curve
|
||||
@ -339,12 +339,11 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
|
||||
|
||||
hashWithDom _ alg ph ctx bss
|
||||
| not ph && B.null ctx = digestDomMsg alg bss
|
||||
| otherwise = digestDomMsg alg (bs:bss)
|
||||
where bs = B.concat [ "SigEd25519 no Ed25519 collisions" :: ByteString
|
||||
, B.singleton $ if ph then 1 else 0
|
||||
, B.singleton $ fromIntegral $ B.length ctx
|
||||
, B.convert ctx
|
||||
]
|
||||
| otherwise = digestDomMsg alg (dom <+> bss)
|
||||
where dom = bytes ("SigEd25519 no Ed25519 collisions" :: ByteString) <+>
|
||||
byte (if ph then 1 else 0) <+>
|
||||
byte (fromIntegral $ B.length ctx) <+>
|
||||
bytes ctx
|
||||
|
||||
pointPublic _ = PublicKey . Edwards25519.pointEncode
|
||||
publicPoint _ = Edwards25519.pointDecode
|
||||
@ -352,7 +351,7 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
|
||||
decodeScalarLE _ = Edwards25519.scalarDecodeLong
|
||||
|
||||
scheduleSecret prx alg priv =
|
||||
(decodeScalarNoErr prx clamped, B.drop 32 hashed)
|
||||
(decodeScalarNoErr prx clamped, B.dropView hashed 32)
|
||||
where
|
||||
hashed = digest alg ($ priv)
|
||||
|
||||
@ -377,9 +376,9 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where
|
||||
-}
|
||||
|
||||
digestDomMsg :: (HashAlgorithm alg, ByteArrayAccess msg)
|
||||
=> alg -> [Bytes] -> msg -> Bytes
|
||||
=> alg -> Builder -> msg -> Bytes
|
||||
digestDomMsg alg bss bs = digest alg $ \update ->
|
||||
update (B.concat bss :: Bytes) >> update bs
|
||||
update (buildAndFreeze bss :: Bytes) >> update bs
|
||||
|
||||
digest :: HashAlgorithm alg
|
||||
=> alg
|
||||
|
||||
@ -230,6 +230,7 @@ Library
|
||||
Crypto.PubKey.ElGamal
|
||||
Crypto.ECC.Simple.Types
|
||||
Crypto.ECC.Simple.Prim
|
||||
Crypto.Internal.Builder
|
||||
Crypto.Internal.ByteArray
|
||||
Crypto.Internal.Compat
|
||||
Crypto.Internal.CompatPrim
|
||||
|
||||
Loading…
Reference in New Issue
Block a user