refactor some stuff

This commit is contained in:
Vincent Hanquez 2015-04-08 14:58:49 +01:00
parent ca125f3e66
commit 39ee0a4aa2
3 changed files with 104 additions and 21 deletions

View File

@ -50,8 +50,9 @@ import Foreign.Storable
-- | an IV parametrized by the cipher
data IV c = forall byteArray . ByteArray byteArray => IV byteArray
instance BlockCipher c => ByteArray (IV c) where
instance BlockCipher c => ByteArrayAccess (IV c) where
withByteArray (IV z) f = withByteArray z f
byteArrayLength (IV z) = byteArrayLength z
type XTS cipher = (cipher, cipher)
-> IV cipher -- ^ Usually represent the Data Unit (e.g. disk sector)

View File

@ -12,6 +12,7 @@
{-# LANGUAGE UnboxedTuples #-}
module Crypto.Internal.ByteArray
( ByteArray(..)
, ByteArrayAccess(..)
, byteArrayAllocAndFreeze
, empty
, byteArrayCopyAndFreeze
@ -36,32 +37,37 @@ import Data.ByteString (ByteString)
import qualified Data.ByteString as B (length)
import qualified Data.ByteString.Internal as B
class ByteArray ba where
byteArrayAlloc :: Int -> (Ptr p -> IO ()) -> IO ba
class ByteArrayAccess ba where
byteArrayLength :: ba -> Int
withByteArray :: ba -> (Ptr p -> IO a) -> IO a
instance ByteArray Bytes where
byteArrayAlloc = bytesAlloc
class ByteArrayAccess ba => ByteArray ba where
byteArrayAlloc :: Int -> (Ptr p -> IO ()) -> IO ba
instance ByteArrayAccess Bytes where
byteArrayLength = bytesLength
withByteArray = withBytes
instance ByteArray Bytes where
byteArrayAlloc = bytesAlloc
instance ByteArrayAccess ByteString where
byteArrayLength = B.length
withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off)
where (fptr, off, _) = B.toForeignPtr b
instance ByteArray ByteString where
byteArrayAlloc sz f = do
fptr <- B.mallocByteString sz
withForeignPtr fptr (f . castPtr)
return $! B.PS fptr 0 sz
byteArrayLength = B.length
withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off)
where (fptr, off, _) = B.toForeignPtr b
instance ByteArrayAccess SecureMem where
byteArrayLength = secureMemGetSize
withByteArray b f = withSecureMemPtr b (f . castPtr)
instance ByteArray SecureMem where
byteArrayAlloc sz f = do
out <- allocateSecureMem sz
withSecureMemPtr out (f . castPtr)
return out
byteArrayLength = secureMemGetSize
withByteArray b f = withSecureMemPtr b (f . castPtr)
byteArrayAllocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a
byteArrayAllocAndFreeze sz f = unsafeDoIO (byteArrayAlloc sz f)
@ -72,7 +78,7 @@ empty = unsafeDoIO (byteArrayAlloc 0 $ \_ -> return ())
-- | Create a xor of bytes between a and b.
--
-- the returns byte array is the size of the smallest input.
byteArrayXor :: (ByteArray a, ByteArray b, ByteArray c) => a -> b -> c
byteArrayXor :: (ByteArrayAccess a, ByteArrayAccess b, ByteArray c) => a -> b -> c
byteArrayXor a b =
byteArrayAllocAndFreeze n $ \pc ->
withByteArray a $ \pa ->
@ -122,4 +128,5 @@ byteArrayToW64BE :: ByteArray bs => bs -> Int -> Word64
byteArrayToW64BE bs ofs = unsafeDoIO $ withByteArray bs $ \p -> fromBE64 <$> peek (p `plusPtr` ofs)
-- move me elsewhere. not working properly for big endian machine, as it should be id
fromBE64 :: Word64 -> Word64
fromBE64 = byteSwap64

View File

@ -92,6 +92,7 @@ data KATs = KATs
defaultKATs = KATs [] [] [] [] [] []
{-
testECB (_, _, cipherInit) ecbEncrypt ecbDecrypt kats =
testGroup "ECB" (concatMap katTest (zip is kats) {- ++ propTests-})
where katTest (i,d) =
@ -145,6 +146,67 @@ testKatAEAD cipherInit aeadInit aeadAppendHeader aeadEncrypt aeadDecrypt aeadFin
(dbs,aeadDFinal) = aeadDecrypt aeadHeaded (aeadCiphertext d)
etag = aeadFinalize aeadEFinal (aeadTaglen d)
dtag = aeadFinalize aeadDFinal (aeadTaglen d)
-}
testKATs :: BlockCipher cipher
=> KATs
-> cipher
-> TestTree
testKATs kats cipher = testGroup "KAT"
( maybeGroup makeECBTest "ECB" (kat_ECB kats)
++ maybeGroup makeCBCTest "CBC" (kat_CBC kats)
++ maybeGroup makeCFBTest "CFB" (kat_CFB kats)
++ maybeGroup makeCTRTest "CTR" (kat_CTR kats)
-- ++ maybeGroup makeXTSTest "XTS" (kat_XTS kats)
-- ++ maybeGroup makeAEADTest "AEAD" (kat_AEAD kats)
)
where makeECBTest i d =
[ testCase ("E" ++ i) (ecbEncrypt ctx (ecbPlaintext d) @?= ecbCiphertext d)
, testCase ("D" ++ i) (ecbDecrypt ctx (ecbCiphertext d) @?= ecbPlaintext d)
]
where ctx = cipherInit (cipherMakeKey cipher $ ecbKey d)
makeCBCTest i d =
[ testCase ("E" ++ i) (cbcEncrypt ctx iv (cbcPlaintext d) @?= cbcCiphertext d)
, testCase ("D" ++ i) (cbcDecrypt ctx iv (cbcCiphertext d) @?= cbcPlaintext d)
]
where ctx = cipherInit (cipherMakeKey cipher $ cbcKey d)
iv = cipherMakeIV cipher $ cbcIV d
makeCFBTest i d =
[ testCase ("E" ++ i) (cfbEncrypt ctx iv (cfbPlaintext d) @?= cfbCiphertext d)
, testCase ("D" ++ i) (cfbDecrypt ctx iv (cfbCiphertext d) @?= cfbPlaintext d)
]
where ctx = cipherInit (cipherMakeKey cipher $ cfbKey d)
iv = cipherMakeIV cipher $ cfbIV d
makeCTRTest i d =
[ testCase ("E" ++ i) (ctrCombine ctx iv (ctrPlaintext d) @?= ctrCiphertext d)
, testCase ("D" ++ i) (ctrCombine ctx iv (ctrCiphertext d) @?= ctrPlaintext d)
]
where ctx = cipherInit (cipherMakeKey cipher $ ctrKey d)
iv = cipherMakeIV cipher $ ctrIV d
{-
makeXTSTest i d =
[ testCase ("E" ++ i) (xtsEncrypt ctx iv 0 (xtsPlaintext d) @?= xtsCiphertext d)
, testCase ("D" ++ i) (xtsDecrypt ctx iv 0 (xtsCiphertext d) @?= xtsPlaintext d)
]
where ctx1 = cipherInit (cipherMakeKey cipher $ xtsKey1 d)
ctx2 = cipherInit (cipherMakeKey cipher $ xtsKey2 d)
ctx = (ctx1, ctx2)
iv = cipherMakeIV cipher $ xtsIV d
makeAEADTest i d =
[ testCase ("AE" ++ i) (etag @?= aeadTag d)
, testCase ("AD" ++ i) (dtag @?= aeadTag d)
, testCase ("E" ++ i) (ebs @?= aeadCiphertext d)
, testCase ("D" ++ i) (dbs @?= aeadPlaintext d)
]
where ctx = cipherInit (cipherMakeKey cipher $ aeadKey d)
aead = maybe (error $ "cipher doesn't support aead mode: " ++ show (aeadMode d)) id
$ aeadInit (aeadMode d) ctx (aeadIV d)
aeadHeaded = aeadAppendHeader aead (aeadHeader d)
(ebs,aeadEFinal) = aeadEncrypt aeadHeaded (aeadPlaintext d)
(dbs,aeadDFinal) = aeadDecrypt aeadHeaded (aeadCiphertext d)
etag = aeadFinalize aeadEFinal (aeadTaglen d)
dtag = aeadFinalize aeadDFinal (aeadTaglen d)
-}
------------------------------------------------------------------------
-- Properties
@ -203,11 +265,11 @@ instance Show (CFB8Unit a) where
instance Show (CTRUnit a) where
show (CTRUnit key iv b) = "CTR(key=" ++ show key ++ ",iv=" ++ show (unPlaintext iv) ++ ",input=" ++ show b ++ ")"
instance Show (XTSUnit a) where
show (XTSUnit key1 key2 iv b) = "XTS(key1=" ++ show (toBytes key1) ++ ",key2=" ++ show (toBytes key2) ++ ",iv=" ++ show (toBytes iv) ++ ",input=" ++ show b ++ ")"
show (XTSUnit key1 key2 iv b) = "XTS(key1=" ++ show key1 ++ ",key2=" ++ show key2 ++ ",iv=" ++ show (unPlaintext iv) ++ ",input=" ++ show b ++ ")"
instance Show (AEADUnit a) where
show (AEADUnit key iv aad b) = "AEAD(key=" ++ show (toBytes key) ++ ",iv=" ++ show iv ++ ",aad=" ++ show (toBytes aad) ++ ",input=" ++ show b ++ ")"
show (AEADUnit key iv aad b) = "AEAD(key=" ++ show key ++ ",iv=" ++ show iv ++ ",aad=" ++ show (unPlaintext aad) ++ ",input=" ++ show b ++ ")"
instance Show (StreamUnit a) where
show (StreamUnit key b) = "Stream(key=" ++ show (toBytes key) ++ ",input=" ++ show b ++ ")"
show (StreamUnit key b) = "Stream(key=" ++ show key ++ ",input=" ++ show b ++ ")"
-- | Generate an arbitrary valid key for a specific block cipher
generateKey :: Cipher a => Gen (Key a)
@ -281,7 +343,7 @@ testBlockCipherBasic cipher = [ testProperty "ECB" ecbProp ]
where ecbProp = toTests cipher
toTests :: BlockCipher a => a -> (ECBUnit a -> Bool)
toTests _ = testProperty_ECB
testProperty_ECB (ECBUnit (cipherInit -> ctx) (toBytes -> plaintext)) =
testProperty_ECB (ECBUnit (cipherInit -> ctx) (unPlaintext -> plaintext)) =
plaintext `assertEq` ecbDecrypt ctx (ecbEncrypt ctx plaintext)
testBlockCipherModes :: BlockCipher a => a -> [TestTree]
@ -300,18 +362,18 @@ testBlockCipherModes cipher =
--,testProperty_CFB8
,testProperty_CTR
)
testProperty_CBC (CBCUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) =
testProperty_CBC (CBCUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) =
plaintext `assertEq` cbcDecrypt ctx testIV (cbcEncrypt ctx testIV plaintext)
testProperty_CFB (CFBUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) =
testProperty_CFB (CFBUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) =
plaintext `assertEq` cfbDecrypt ctx testIV (cfbEncrypt ctx testIV plaintext)
{-
testProperty_CFB8 (CFB8Unit (cipherInit -> ctx) testIV (toBytes -> plaintext)) =
testProperty_CFB8 (CFB8Unit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) =
plaintext `assertEq` cfb8Decrypt ctx testIV (cfb8Encrypt ctx testIV plaintext)
-}
testProperty_CTR (CTRUnit (cipherInit -> ctx) testIV (toBytes -> plaintext)) =
testProperty_CTR (CTRUnit (cipherInit -> ctx) testIV (unPlaintext -> plaintext)) =
plaintext `assertEq` ctrCombine ctx testIV (ctrCombine ctx testIV plaintext)
testBlockCipherAEAD :: BlockCipher a => a -> [TestTree]
@ -325,7 +387,7 @@ testBlockCipherAEAD cipher =
where aeadProp = toTests cipher
toTests :: BlockCipher a => a -> (AEADMode -> AEADUnit a -> Bool)
toTests _ = testProperty_AEAD
testProperty_AEAD mode (AEADUnit (cipherInit -> ctx) testIV (toBytes -> aad) (toBytes -> plaintext)) =
testProperty_AEAD mode (AEADUnit (cipherInit -> ctx) testIV (unPlaintext -> aad) (unPlaintext -> plaintext)) =
case aeadInit mode ctx testIV of
Just iniAead ->
let aead = aeadAppendHeader iniAead aad
@ -368,5 +430,18 @@ assertEq :: ByteString -> ByteString -> Bool
assertEq b1 b2 | b1 /= b2 = error ("b1: " ++ show b1 ++ " b2: " ++ show b2)
| otherwise = True
cipherMakeKey :: Cipher cipher => cipher -> ByteString -> Key cipher
cipherMakeKey c bs = bs
cipherMakeIV :: BlockCipher cipher => cipher -> ByteString -> IV cipher
cipherMakeIV _ bs = fromJust $ makeIV bs
maybeGroup :: (String -> t -> [TestTree]) -> TestName -> [t] -> [TestTree]
maybeGroup mkTest groupName l
| null l = []
| otherwise = [testGroup groupName (concatMap (\(i, d) -> mkTest (show i) d) $ zip nbs l)]
where nbs :: [Int]
nbs = [0..]
is :: [Int]
is = [1..]