diff --git a/Crypto/Cipher/ChaCha.hs b/Crypto/Cipher/ChaCha.hs index af53d21..aeac3fc 100644 --- a/Crypto/Cipher/ChaCha.hs +++ b/Crypto/Cipher/ChaCha.hs @@ -110,7 +110,7 @@ combine prev@(State nbRounds prevSt prevOut) src (dstPtr `plusPtr` prevBufLen) (castPtr stPtr) (srcPtr `plusPtr` prevBufLen) - (fromIntegral newBytesToGenerate) + (fromIntegral adjustedLen) -- return combined byte return ( BS.PS fptr 0 outputLen diff --git a/Crypto/Cipher/Salsa.hs b/Crypto/Cipher/Salsa.hs index 75be3e5..3252977 100644 --- a/Crypto/Cipher/Salsa.hs +++ b/Crypto/Cipher/Salsa.hs @@ -21,7 +21,6 @@ import qualified Data.ByteString.Internal as BS import qualified Data.ByteString as BS import Crypto.Internal.Compat import Crypto.Internal.Imports -import Data.Bits (xor) import Foreign.Ptr import Foreign.ForeignPtr import Foreign.C.Types @@ -29,6 +28,7 @@ import Foreign.C.Types -- | Salsa context data State = State Int -- number of rounds ScrubbedBytes -- Salsa's state + Int -- offset of data in previously generated chunk ByteString -- previous generated chunk round64 :: Int -> (Bool, Int) @@ -54,7 +54,7 @@ initialize nbRounds key nonce B.withByteArray nonce $ \noncePtr -> B.withByteArray key $ \keyPtr -> ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr - return $ State nbRounds stPtr B.empty + return $ State nbRounds stPtr 0 B.empty where kLen = B.length key nonceLen = B.length nonce @@ -64,13 +64,15 @@ combine :: ByteArray ba => State -- ^ the current Salsa state -> ba -- ^ the source to xor with the generator -> (ba, State) -combine prev@(State nbRounds prevSt prevOut) src - | outputLen == 0 = (B.empty, prev) - | outputLen <= prevBufLen = +combine prev@(State nbRounds prevSt prevOffset prevOut) src + | outputLen == 0 = (B.empty, prev) + | outputLen <= prevBufLen = unsafeDoIO $ do -- we have enough byte in the previous buffer to complete the query -- without having to generate any extra bytes - let (b1,b2) = BS.splitAt outputLen prevOut - in (B.convert $ BS.pack $ BS.zipWith xor b1 (B.convert src), State nbRounds prevSt b2) + output <- B.copy src $ \dst -> + B.withByteArray prevOut $ \prevPtr -> + memXor dst dst (prevPtr `plusPtr` prevOffset) outputLen + return (output, State nbRounds prevSt (prevOffset + outputLen) prevOut) | otherwise = unsafeDoIO $ do -- adjusted len is the number of bytes lefts to generate after -- copying from the previous buffer. @@ -83,7 +85,7 @@ combine prev@(State nbRounds prevSt prevOut) src B.withByteArray src $ \srcPtr -> do -- copy the previous buffer by xor if any B.withByteArray prevOut $ \prevPtr -> - memXor dstPtr srcPtr prevPtr prevBufLen + memXor dstPtr srcPtr (prevPtr `plusPtr` prevOffset) prevBufLen -- then create a new mutable copy of state B.copy prevSt $ \stPtr -> @@ -91,13 +93,13 @@ combine prev@(State nbRounds prevSt prevOut) src (dstPtr `plusPtr` prevBufLen) (castPtr stPtr) (srcPtr `plusPtr` prevBufLen) - (fromIntegral newBytesToGenerate) + (fromIntegral adjustedLen) -- return combined byte return ( B.convert (BS.PS fptr 0 outputLen) - , State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen)) + , State nbRounds newSt 0 (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen)) where outputLen = B.length src - prevBufLen = B.length prevOut + prevBufLen = B.length prevOut - prevOffset -- | Generate a number of bytes from the Salsa output directly -- diff --git a/tests/ChaCha.hs b/tests/ChaCha.hs index 50816ea..9e10896 100644 --- a/tests/ChaCha.hs +++ b/tests/ChaCha.hs @@ -23,6 +23,14 @@ b12_256_k0_i0 = b20_256_k0_i0 = "\x76\xb8\xe0\xad\xa0\xf1\x3d\x90\x40\x5d\x6a\xe5\x53\x86\xbd\x28\xbd\xd2\x19\xb8\xa0\x8d\xed\x1a\xa8\x36\xef\xcc\x8b\x77\x0d\xc7\xda\x41\x59\x7c\x51\x57\x48\x8d\x77\x24\xe0\x3f\xb8\xd8\x4a\x37\x6a\x43\xb8\xf4\x15\x18\xa1\x1c\xc3\x87\xb6\x69\xb2\xee\x65\x86\x9f\x07\xe7\xbe\x55\x51\x38\x7a\x98\xba\x97\x7c\x73\x2d\x08\x0d\xcb\x0f\x29\xa0\x48\xe3\x65\x69\x12\xc6\x53\x3e\x32\xee\x7a\xed\x29\xb7\x21\x76\x9c\xe6\x4e\x43\xd5\x71\x33\xb0\x74\xd8\x39\xd5\x31\xed\x1f\x28\x51\x0a\xfb\x45\xac\xe1\x0a\x1f\x4b\x79\x4d\x6f" +data Vector = Vector Int -- rounds + ByteString -- key + ByteString -- nonce + deriving (Show,Eq) + +instance Arbitrary Vector where + arbitrary = Vector 20 <$> arbitraryBS 16 <*> arbitraryBS 12 + tests = testGroup "ChaCha" [ testCase "8-128-K0-I0" (chachaRunSimple b8_128_k0_i0 8 16 8) , testCase "12-128-K0-I0" (chachaRunSimple b12_128_k0_i0 12 16 8) @@ -30,7 +38,23 @@ tests = testGroup "ChaCha" , testCase "8-256-K0-I0" (chachaRunSimple b8_256_k0_i0 8 32 8) , testCase "12-256-K0-I0" (chachaRunSimple b12_256_k0_i0 12 32 8) , testCase "20-256-K0-I0" (chachaRunSimple b20_256_k0_i0 20 32 8) + , testProperty "chunking" chachaChunks ] where chachaRunSimple expected rounds klen nonceLen = let chacha = ChaCha.initialize rounds (B.replicate klen 0) (B.replicate nonceLen 0) in expected @=? fst (ChaCha.generate chacha (B.length expected)) + + chachaChunks :: ChunkingLen -> Vector -> Bool + chachaChunks (ChunkingLen ckLen) (Vector rounds key iv) = + let initChaCha = ChaCha.initialize rounds key iv + nbBytes = 1048 + (expected,_) = ChaCha.generate initChaCha nbBytes + chunks = loop nbBytes ckLen (ChaCha.initialize rounds key iv) + in expected == B.concat chunks + + where loop n [] chacha = loop n ckLen chacha + loop 0 _ _ = [] + loop n (x:xs) chacha = + let len = min x n + (c, next) = ChaCha.generate chacha len + in c : loop (n - len) xs next