From 343b7593b595003fb547c77c9419e342e2110d3f Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Tue, 14 Feb 2017 15:26:06 +0000 Subject: [PATCH] add Constraint for divisibility --- Crypto/Hash/SHAKE.hs | 26 +++++++----- Crypto/Internal/Nat.hs | 95 ++++++++++++++++++++++++++++++++++++++++++ cryptonite.cabal | 1 + 3 files changed, 111 insertions(+), 11 deletions(-) create mode 100644 Crypto/Internal/Nat.hs diff --git a/Crypto/Hash/SHAKE.hs b/Crypto/Hash/SHAKE.hs index 2262f74..2e4f3e4 100644 --- a/Crypto/Hash/SHAKE.hs +++ b/Crypto/Hash/SHAKE.hs @@ -12,7 +12,11 @@ {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} module Crypto.Hash.SHAKE ( SHAKE128 (..), SHAKE256 (..) ) where @@ -24,6 +28,7 @@ import Data.Word (Word8, Word32) import Data.Proxy (Proxy(..)) import GHC.TypeLits (Nat, KnownNat, natVal) +import Crypto.Internal.Nat -- | SHAKE128 (128 bits) extendable output function. Supports an arbitrary -- digest size (multiple of 8 bits), to be specified as a type parameter @@ -35,13 +40,13 @@ import GHC.TypeLits (Nat, KnownNat, natVal) data SHAKE128 (bitlen :: Nat) = SHAKE128 deriving (Show, Typeable) -instance KnownNat bitlen => HashAlgorithm (SHAKE128 bitlen) where +instance (IsDivisibleBy8 bitLen, KnownNat bitLen) => HashAlgorithm (SHAKE128 bitLen) where hashBlockSize _ = 168 - hashDigestSize _ = byteLen (Proxy :: Proxy bitlen) + hashDigestSize _ = byteLen (Proxy :: Proxy bitLen) hashInternalContextSize _ = 376 hashInternalInit p = c_sha3_init p 128 hashInternalUpdate = c_sha3_update - hashInternalFinalize = shakeFinalizeOutput (Proxy :: Proxy bitlen) + hashInternalFinalize = shakeFinalizeOutput (Proxy :: Proxy bitLen) -- | SHAKE256 (256 bits) extendable output function. Supports an arbitrary -- digest size (multiple of 8 bits), to be specified as a type parameter @@ -53,27 +58,26 @@ instance KnownNat bitlen => HashAlgorithm (SHAKE128 bitlen) where data SHAKE256 (bitlen :: Nat) = SHAKE256 deriving (Show, Typeable) -instance KnownNat bitlen => HashAlgorithm (SHAKE256 bitlen) where +instance (IsDivisibleBy8 bitLen, KnownNat bitLen) => HashAlgorithm (SHAKE256 bitLen) where hashBlockSize _ = 136 - hashDigestSize _ = byteLen (Proxy :: Proxy bitlen) + hashDigestSize _ = byteLen (Proxy :: Proxy bitLen) hashInternalContextSize _ = 344 hashInternalInit p = c_sha3_init p 256 hashInternalUpdate = c_sha3_update - hashInternalFinalize = shakeFinalizeOutput (Proxy :: Proxy bitlen) + hashInternalFinalize = shakeFinalizeOutput (Proxy :: Proxy bitLen) -shakeFinalizeOutput :: KnownNat bitlen - => proxy bitlen +shakeFinalizeOutput :: (IsDivisibleBy8 bitLen, KnownNat bitLen) + => proxy bitLen -> Ptr (Context a) -> Ptr (Digest a) -> IO () shakeFinalizeOutput d ctx dig = do c_sha3_finalize_shake ctx - c_sha3_output ctx dig (byteLen d) + c_sha3_output ctx dig (fromInteger (natVal d `div` 8)) -byteLen :: (KnownNat bitlen, Num a) => proxy bitlen -> a +byteLen :: (KnownNat bitlen, IsDivisibleBy8 bitlen, Num a) => proxy bitlen -> a byteLen d = fromInteger (natVal d `div` 8) - foreign import ccall unsafe "cryptonite_sha3_init" c_sha3_init :: Ptr (Context a) -> Word32 -> IO () diff --git a/Crypto/Internal/Nat.hs b/Crypto/Internal/Nat.hs new file mode 100644 index 0000000..a01b1d2 --- /dev/null +++ b/Crypto/Internal/Nat.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module Crypto.Internal.Nat + ( type IsDivisibleBy8 + ) where + +import GHC.TypeLits (Nat, KnownNat, natVal, type (+), type (-), TypeError, ErrorMessage(..)) + +type family IsDiv8 (bitLen :: Nat) (n :: Nat) where + IsDiv8 bitLen 0 = 'True + IsDiv8 bitLen 1 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 2 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 3 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 4 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 5 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 6 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen 7 = TypeError ('Text "bitLen " ':<>: 'ShowType bitLen ':<>: 'Text " is not divisible by 8") + IsDiv8 bitLen n = IsDiv8 n (Mod8 n) + +type family Mod8 (n :: Nat) where + Mod8 0 = 0 + Mod8 1 = 1 + Mod8 2 = 2 + Mod8 3 = 3 + Mod8 4 = 4 + Mod8 5 = 5 + Mod8 6 = 6 + Mod8 7 = 7 + Mod8 8 = 0 + Mod8 9 = 1 + Mod8 10 = 2 + Mod8 11 = 3 + Mod8 12 = 4 + Mod8 13 = 5 + Mod8 14 = 6 + Mod8 15 = 7 + Mod8 16 = 0 + Mod8 17 = 1 + Mod8 18 = 2 + Mod8 19 = 3 + Mod8 20 = 4 + Mod8 21 = 5 + Mod8 22 = 6 + Mod8 23 = 7 + Mod8 24 = 0 + Mod8 25 = 1 + Mod8 26 = 2 + Mod8 27 = 3 + Mod8 28 = 4 + Mod8 29 = 5 + Mod8 30 = 6 + Mod8 31 = 7 + Mod8 32 = 0 + Mod8 33 = 1 + Mod8 34 = 2 + Mod8 35 = 3 + Mod8 36 = 4 + Mod8 37 = 5 + Mod8 38 = 6 + Mod8 39 = 7 + Mod8 40 = 0 + Mod8 41 = 1 + Mod8 42 = 2 + Mod8 43 = 3 + Mod8 44 = 4 + Mod8 45 = 5 + Mod8 46 = 6 + Mod8 47 = 7 + Mod8 48 = 0 + Mod8 49 = 1 + Mod8 50 = 2 + Mod8 51 = 3 + Mod8 52 = 4 + Mod8 53 = 5 + Mod8 54 = 6 + Mod8 55 = 7 + Mod8 56 = 0 + Mod8 57 = 1 + Mod8 58 = 2 + Mod8 59 = 3 + Mod8 60 = 4 + Mod8 61 = 5 + Mod8 62 = 6 + Mod8 63 = 7 + Mod8 n = Mod8 (n - 64) + +type IsDivisibleBy8 bitLen = IsDiv8 bitLen bitLen ~ 'True + diff --git a/cryptonite.cabal b/cryptonite.cabal index 5dc616c..de95e37 100644 --- a/cryptonite.cabal +++ b/cryptonite.cabal @@ -205,6 +205,7 @@ Library Crypto.Internal.WordArray if impl(ghc >= 7.8) Other-modules: Crypto.Hash.SHAKE + Crypto.Internal.Nat Build-depends: base >= 4.3 && < 5 , bytestring , memory >= 0.12