From 39ee0a4aa2799196da13fb39c0a3603ccaa09203 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Wed, 8 Apr 2015 14:58:49 +0100 Subject: [PATCH] refactor some stuff --- Crypto/Cipher/Types/Block.hs | 5 +- Crypto/Internal/ByteArray.hs | 27 +++++++---- tests/BlockCipher.hs | 93 ++++++++++++++++++++++++++++++++---- 3 files changed, 104 insertions(+), 21 deletions(-) diff --git a/Crypto/Cipher/Types/Block.hs b/Crypto/Cipher/Types/Block.hs index 5906ee2..ee438cf 100644 --- a/Crypto/Cipher/Types/Block.hs +++ b/Crypto/Cipher/Types/Block.hs @@ -50,8 +50,9 @@ import Foreign.Storable -- | an IV parametrized by the cipher data IV c = forall byteArray . ByteArray byteArray => IV byteArray -instance BlockCipher c => ByteArray (IV c) where - +instance BlockCipher c => ByteArrayAccess (IV c) where + withByteArray (IV z) f = withByteArray z f + byteArrayLength (IV z) = byteArrayLength z type XTS cipher = (cipher, cipher) -> IV cipher -- ^ Usually represent the Data Unit (e.g. disk sector) diff --git a/Crypto/Internal/ByteArray.hs b/Crypto/Internal/ByteArray.hs index 8cd5fce..c2ddb3e 100644 --- a/Crypto/Internal/ByteArray.hs +++ b/Crypto/Internal/ByteArray.hs @@ -12,6 +12,7 @@ {-# LANGUAGE UnboxedTuples #-} module Crypto.Internal.ByteArray ( ByteArray(..) + , ByteArrayAccess(..) , byteArrayAllocAndFreeze , empty , byteArrayCopyAndFreeze @@ -36,32 +37,37 @@ import Data.ByteString (ByteString) import qualified Data.ByteString as B (length) import qualified Data.ByteString.Internal as B -class ByteArray ba where - byteArrayAlloc :: Int -> (Ptr p -> IO ()) -> IO ba +class ByteArrayAccess ba where byteArrayLength :: ba -> Int withByteArray :: ba -> (Ptr p -> IO a) -> IO a -instance ByteArray Bytes where - byteArrayAlloc = bytesAlloc +class ByteArrayAccess ba => ByteArray ba where + byteArrayAlloc :: Int -> (Ptr p -> IO ()) -> IO ba + +instance ByteArrayAccess Bytes where byteArrayLength = bytesLength withByteArray = withBytes +instance ByteArray Bytes where + byteArrayAlloc = bytesAlloc +instance ByteArrayAccess ByteString where + byteArrayLength = B.length + withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off) + where (fptr, off, _) = B.toForeignPtr b instance ByteArray ByteString where byteArrayAlloc sz f = do fptr <- B.mallocByteString sz withForeignPtr fptr (f . castPtr) return $! B.PS fptr 0 sz - byteArrayLength = B.length - withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off) - where (fptr, off, _) = B.toForeignPtr b +instance ByteArrayAccess SecureMem where + byteArrayLength = secureMemGetSize + withByteArray b f = withSecureMemPtr b (f . castPtr) instance ByteArray SecureMem where byteArrayAlloc sz f = do out <- allocateSecureMem sz withSecureMemPtr out (f . castPtr) return out - byteArrayLength = secureMemGetSize - withByteArray b f = withSecureMemPtr b (f . castPtr) byteArrayAllocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a byteArrayAllocAndFreeze sz f = unsafeDoIO (byteArrayAlloc sz f) @@ -72,7 +78,7 @@ empty = unsafeDoIO (byteArrayAlloc 0 $ \_ -> return ()) -- | Create a xor of bytes between a and b. -- -- the returns byte array is the size of the smallest input. -byteArrayXor :: (ByteArray a, ByteArray b, ByteArray c) => a -> b -> c +byteArrayXor :: (ByteArrayAccess a, ByteArrayAccess b, ByteArray c) => a -> b -> c byteArrayXor a b = byteArrayAllocAndFreeze n $ \pc -> withByteArray a $ \pa -> @@ -122,4 +128,5 @@ byteArrayToW64BE :: ByteArray bs => bs -> Int -> Word64 byteArrayToW64BE bs ofs = unsafeDoIO $ withByteArray bs $ \p -> fromBE64 <$> peek (p `plusPtr` ofs) -- move me elsewhere. not working properly for big endian machine, as it should be id +fromBE64 :: Word64 -> Word64 fromBE64 = byteSwap64 diff --git a/tests/BlockCipher.hs b/tests/BlockCipher.hs index b96370e..c1f28c8 100644 --- a/tests/BlockCipher.hs +++ b/tests/BlockCipher.hs @@ -92,6 +92,7 @@ data KATs = KATs defaultKATs = KATs [] [] [] [] [] [] +{- testECB (_, _, cipherInit) ecbEncrypt ecbDecrypt kats = testGroup "ECB" (concatMap katTest (zip is kats) {- ++ propTests-}) where katTest (i,d) = @@ -145,6 +146,67 @@ testKatAEAD cipherInit aeadInit aeadAppendHeader aeadEncrypt aeadDecrypt aeadFin (dbs,aeadDFinal) = aeadDecrypt aeadHeaded (aeadCiphertext d) etag = aeadFinalize aeadEFinal (aeadTaglen d) dtag = aeadFinalize aeadDFinal (aeadTaglen d) +-} + +testKATs :: BlockCipher cipher + => KATs + -> cipher + -> TestTree +testKATs kats cipher = testGroup "KAT" + ( maybeGroup makeECBTest "ECB" (kat_ECB kats) + ++ maybeGroup makeCBCTest "CBC" (kat_CBC kats) + ++ maybeGroup makeCFBTest "CFB" (kat_CFB kats) + ++ maybeGroup makeCTRTest "CTR" (kat_CTR kats) + -- ++ maybeGroup makeXTSTest "XTS" (kat_XTS kats) + -- ++ maybeGroup makeAEADTest "AEAD" (kat_AEAD kats) + ) + where makeECBTest i d = + [ testCase ("E" ++ i) (ecbEncrypt ctx (ecbPlaintext d) @?= ecbCiphertext d) + , testCase ("D" ++ i) (ecbDecrypt ctx (ecbCiphertext d) @?= ecbPlaintext d) + ] + where ctx = cipherInit (cipherMakeKey cipher $ ecbKey d) + makeCBCTest i d = + [ testCase ("E" ++ i) (cbcEncrypt ctx iv (cbcPlaintext d) @?= cbcCiphertext d) + , testCase ("D" ++ i) (cbcDecrypt ctx iv (cbcCiphertext d) @?= cbcPlaintext d) + ] + where ctx = cipherInit (cipherMakeKey cipher $ cbcKey d) + iv = cipherMakeIV cipher $ cbcIV d + makeCFBTest i d = + [ testCase ("E" ++ i) (cfbEncrypt ctx iv (cfbPlaintext d) @?= cfbCiphertext d) + , testCase ("D" ++ i) (cfbDecrypt ctx iv (cfbCiphertext d) @?= cfbPlaintext d) + ] + where ctx = cipherInit (cipherMakeKey cipher $ cfbKey d) + iv = cipherMakeIV cipher $ cfbIV d + makeCTRTest i d = + [ testCase ("E" ++ i) (ctrCombine ctx iv (ctrPlaintext d) @?= ctrCiphertext d) + , testCase ("D" ++ i) (ctrCombine ctx iv (ctrCiphertext d) @?= ctrPlaintext d) + ] + where ctx = cipherInit (cipherMakeKey cipher $ ctrKey d) + iv = cipherMakeIV cipher $ ctrIV d +{- + makeXTSTest i d = + [ testCase ("E" ++ i) (xtsEncrypt ctx iv 0 (xtsPlaintext d) @?= xtsCiphertext d) + , testCase ("D" ++ i) (xtsDecrypt ctx iv 0 (xtsCiphertext d) @?= xtsPlaintext d) + ] + where ctx1 = cipherInit (cipherMakeKey cipher $ xtsKey1 d) + ctx2 = cipherInit (cipherMakeKey cipher $ xtsKey2 d) + ctx = (ctx1, ctx2) + iv = cipherMakeIV cipher $ xtsIV d + makeAEADTest i d = + [ testCase ("AE" ++ i) (etag @?= aeadTag d) + , testCase ("AD" ++ i) (dtag @?= aeadTag d) + , testCase ("E" ++ i) (ebs @?= aeadCiphertext d) + , testCase ("D" ++ i) (dbs @?= aeadPlaintext d) + ] + where ctx = cipherInit (cipherMakeKey cipher $ aeadKey d) + aead = maybe (error $ "cipher doesn't support aead mode: " ++ show (aeadMode d)) id + $ aeadInit (aeadMode d) ctx (aeadIV d) + aeadHeaded = aeadAppendHeader aead (aeadHeader d) + (ebs,aeadEFinal) = aeadEncrypt aeadHeaded (aeadPlaintext d) + (dbs,aeadDFinal) = aeadDecrypt aeadHeaded (aeadCiphertext d) + etag = aeadFinalize aeadEFinal (aeadTaglen d) + dtag = aeadFinalize aeadDFinal (aeadTaglen d) +-} ------------------------------------------------------------------------ -- Properties @@ -203,11 +265,11 @@ instance Show (CFB8Unit a) where instance Show (CTRUnit a) where show (CTRUnit key iv b) = "CTR(key=" ++ show key ++ ",iv=" ++ show (unPlaintext iv) ++ ",input=" ++ show b ++ ")" instance Show (XTSUnit a) where - show (XTSUnit key1 key2 iv b) = "XTS(key1=" ++ show (toBytes key1) ++ ",key2=" ++ show (toBytes key2) ++ ",iv=" ++ show (toBytes iv) ++ ",input=" ++ show b ++ ")" + show (XTSUnit key1 key2 iv b) = "XTS(key1=" ++ show key1 ++ ",key2=" ++ show key2 ++ ",iv=" ++ show (unPlaintext iv) ++ ",input=" ++ show b ++ ")" instance Show (AEADUnit a) where - show (AEADUnit key iv aad b) = "AEAD(key=" ++ show (toBytes key) ++ ",iv=" ++ show iv ++ ",aad=" ++ show (toBytes aad) ++ ",input=" ++ show b ++ ")" + show (AEADUnit key iv aad b) = "AEAD(key=" ++ show key ++ ",iv=" ++ show iv ++ ",aad=" ++ show (unPlaintext aad) ++ ",input=" ++ show b ++ ")" instance Show (StreamUnit a) where - show (StreamUnit key b) = "Stream(key=" ++ show (toBytes key) ++ ",input=" ++ show b ++ ")" + show (StreamUnit key b) = "Stream(key=" ++ show key ++ ",input=" ++ show b ++ ")" -- | Generate an arbitrary valid key for a specific block cipher generateKey :: Cipher a => Gen (Key a) @@ -281,7 +343,7 @@ testBlockCipherBasic cipher = [ testProperty "ECB" ecbProp ] where ecbProp = toTests cipher toTests :: BlockCipher a => a -> (ECBUnit a -> Bool) toTests _ = testProperty_ECB - testProperty_ECB (ECBUnit (cipherInit -> ctx) (toBytes -> plaintext)) = + testProperty_ECB (ECBUnit (cipherInit -> ctx) (unPlaintext -> plaintext)) = plaintext `assertEq` ecbDecrypt ctx (ecbEncrypt ctx plaintext) testBlockCipherModes :: BlockCipher a => a -> [TestTree] @@ -300,18 +362,18 @@ testBlockCipherModes cipher = --,testProperty_CFB8 ,testProperty_CTR ) - testProperty_CBC (CBCUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) = + testProperty_CBC (CBCUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) = plaintext `assertEq` cbcDecrypt ctx testIV (cbcEncrypt ctx testIV plaintext) - testProperty_CFB (CFBUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) = + testProperty_CFB (CFBUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) = plaintext `assertEq` cfbDecrypt ctx testIV (cfbEncrypt ctx testIV plaintext) {- - testProperty_CFB8 (CFB8Unit (cipherInit -> ctx) testIV (toBytes -> plaintext)) = + testProperty_CFB8 (CFB8Unit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) = plaintext `assertEq` cfb8Decrypt ctx testIV (cfb8Encrypt ctx testIV plaintext) -} - testProperty_CTR (CTRUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) = + testProperty_CTR (CTRUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) = plaintext `assertEq` ctrCombine ctx testIV (ctrCombine ctx testIV plaintext) testBlockCipherAEAD :: BlockCipher a => a -> [TestTree] @@ -325,7 +387,7 @@ testBlockCipherAEAD cipher = where aeadProp = toTests cipher toTests :: BlockCipher a => a -> (AEADMode -> AEADUnit a -> Bool) toTests _ = testProperty_AEAD - testProperty_AEAD mode (AEADUnit (cipherInit -> ctx) testIV (toBytes -> aad) (toBytes -> plaintext)) = + testProperty_AEAD mode (AEADUnit (cipherInit -> ctx) testIV (unPlaintext -> aad) (unPlaintext -> plaintext)) = case aeadInit mode ctx testIV of Just iniAead -> let aead = aeadAppendHeader iniAead aad @@ -368,5 +430,18 @@ assertEq :: ByteString -> ByteString -> Bool assertEq b1 b2 | b1 /= b2 = error ("b1: " ++ show b1 ++ " b2: " ++ show b2) | otherwise = True +cipherMakeKey :: Cipher cipher => cipher -> ByteString -> Key cipher +cipherMakeKey c bs = bs + +cipherMakeIV :: BlockCipher cipher => cipher -> ByteString -> IV cipher +cipherMakeIV _ bs = fromJust $ makeIV bs + +maybeGroup :: (String -> t -> [TestTree]) -> TestName -> [t] -> [TestTree] +maybeGroup mkTest groupName l + | null l = [] + | otherwise = [testGroup groupName (concatMap (\(i, d) -> mkTest (show i) d) $ zip nbs l)] + where nbs :: [Int] + nbs = [0..] + is :: [Int] is = [1..]