[ChaCha] make the state opaque so that it is handled directly at the C level, just like for the hash functions

increase the number of tests
This commit is contained in:
Vincent Hanquez 2015-05-22 14:04:27 +01:00
parent 5d85834264
commit 9a69c61e84
5 changed files with 217 additions and 124 deletions

View File

@ -17,33 +17,19 @@ module Crypto.Cipher.ChaCha
, StateSimple
) where
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes, withByteArray)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Data.Memory.PtrMethods (memXor)
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
-- | ChaCha context
data State = State Int -- number of rounds
ScrubbedBytes -- ChaCha's state
ByteString -- previous generated chunk
newtype State = State ScrubbedBytes
-- | ChaCha context for DRG purpose (see Crypto.Random.ChaChaDRG)
newtype StateSimple = StateSimple ScrubbedBytes -- just ChaCha's state
-- | Round a byte count up to the next multiple of 64.
--
-- The 'Bool' is 'True' when the count was already a multiple of 64
-- (zero included); the 'Int' is the rounded count.
round64 :: Int -> (Bool, Int)
round64 len =
    case len `mod` 64 of
        0        -> (True, len)
        leftover -> (False, len + (64 - leftover))
-- | Initialize a new ChaCha context with the number of rounds,
-- the key and the nonce associated.
initialize :: (ByteArrayAccess key, ByteArrayAccess nonce)
@ -55,12 +41,12 @@ initialize nbRounds key nonce
| not (kLen `elem` [16,32]) = error "ChaCha: key length should be 128 or 256 bits"
| not (nonceLen `elem` [8,12]) = error "ChaCha: nonce length should be 64 or 96 bits"
| not (nbRounds `elem` [8,12,20]) = error "ChaCha: rounds should be 8, 12 or 20"
| otherwise = unsafeDoIO $ do
stPtr <- B.alloc 64 $ \stPtr ->
withByteArray nonce $ \noncePtr ->
withByteArray key $ \keyPtr ->
ccryptonite_chacha_init (castPtr stPtr) kLen keyPtr nonceLen noncePtr
return $ State nbRounds stPtr B.empty
| otherwise = unsafeDoIO $ do
    -- The buffer has to hold a complete C @cryptonite_chacha_context@:
    -- @cryptonite_chacha_init@ begins with memset(ctx, 0, sizeof(*ctx)).
    -- The struct is 64 (state) + 64 (prev) + 3 one-byte fields = 131 bytes,
    -- and the state block is accessed as 64-bit words on the C side, so with
    -- 8-byte alignment sizeof is padded to 136. Allocating only 132 bytes
    -- would let that memset write past the end of the buffer.
    -- NOTE(review): keep this size in sync with cryptonite_chacha.h.
    ctxMem <- B.alloc 136 $ \ctxPtr ->
        B.withByteArray nonce $ \noncePtr ->
            B.withByteArray key $ \keyPtr ->
                ccryptonite_chacha_init ctxPtr (fromIntegral nbRounds) kLen keyPtr nonceLen noncePtr
    return $ State ctxMem
where kLen = B.length key
nonceLen = B.length nonce
@ -72,60 +58,39 @@ initializeSimple seed
| sLen /= 40 = error "ChaCha Random: seed length should be 40 bytes"
| otherwise = unsafeDoIO $ do
stPtr <- B.alloc 64 $ \stPtr ->
withByteArray seed $ \seedPtr ->
ccryptonite_chacha_init (castPtr stPtr) 32 seedPtr 8 (seedPtr `plusPtr` 32)
B.withByteArray seed $ \seedPtr ->
ccryptonite_chacha_init_core stPtr 32 seedPtr 8 (seedPtr `plusPtr` 32)
return $ StateSimple stPtr
where
sLen = B.length seed
-- | Combine the chacha output and an arbitrary message with a xor,
-- and return the combined output and the new state.
combine :: State -- ^ the current ChaCha state
-> ByteString -- ^ the source to xor with the generator
-> (ByteString, State)
combine prev@(State nbRounds prevSt prevOut) src
| outputLen == 0 = (B.empty, prev)
| outputLen <= prevBufLen =
-- we have enough byte in the previous buffer to complete the query
-- without having to generate any extra bytes
let (b1,b2) = BS.splitAt outputLen prevOut
in (B.xor b1 src, State nbRounds prevSt b2)
| otherwise = unsafeDoIO $ do
-- adjusted len is the number of bytes lefts to generate after
-- copying from the previous buffer.
let adjustedLen = outputLen - prevBufLen
(roundedAlready, newBytesToGenerate) = round64 adjustedLen
nextBufLen = newBytesToGenerate - adjustedLen
fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen)
newSt <- withForeignPtr fptr $ \dstPtr ->
withByteArray src $ \srcPtr -> do
-- copy the previous buffer by xor if any
withByteArray prevOut $ \prevPtr ->
memXor dstPtr srcPtr prevPtr prevBufLen
-- then create a new mutable copy of state
B.copy prevSt $ \stPtr ->
ccryptonite_chacha_combine nbRounds
(dstPtr `plusPtr` prevBufLen)
(castPtr stPtr)
(srcPtr `plusPtr` prevBufLen)
(fromIntegral adjustedLen)
-- return combined byte
return ( BS.PS fptr 0 outputLen
, State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))
where
outputLen = B.length src
prevBufLen = B.length prevOut
combine :: ByteArray ba
        => State        -- ^ the current ChaCha state
        -> ba           -- ^ the source to xor with the generator
        -> (ba, State)
combine current@(State ctxMem) input
    | B.null input = (B.empty, current)
    | otherwise    = unsafeDoIO $ do
        -- The C routine updates the context it is given, so it runs on a
        -- fresh copy and the caller's state value stays untouched.
        (xored, ctxMem') <- B.copyRet ctxMem $ \ctxPtr ->
            B.alloc inputLen $ \outPtr ->
                B.withByteArray input $ \inPtr ->
                    ccryptonite_chacha_combine outPtr ctxPtr inPtr (fromIntegral inputLen)
        return (xored, State ctxMem')
  where
    inputLen = B.length input
-- | Generate a number of bytes from the ChaCha output directly
--
-- TODO: use chacha_generate directly instead of using combine xor'ing with 0.
generate :: State -- ^ the current ChaCha state
generate :: ByteArray ba
=> State -- ^ the current ChaCha state
-> Int -- ^ the length of data to generate
-> (ByteString, State)
generate st len = combine st (BS.replicate len 0)
-> (ba, State)
generate current@(State ctxMem) len
    | len > 0   = unsafeDoIO produce
    | otherwise = (B.empty, current)
  where
    -- The C routine updates the context it is given, so it runs on a fresh
    -- copy; the copy becomes the returned state.
    produce = do
        (keystream, ctxMem') <- B.copyRet ctxMem $ \ctxPtr ->
            B.alloc len $ \outPtr ->
                ccryptonite_chacha_generate outPtr ctxPtr (fromIntegral len)
        return (keystream, State ctxMem')
-- | Similar to 'generate' but assumes certain values (e.g. the number of rounds is fixed to 8)
generateSimple :: ByteArray ba
@ -135,20 +100,22 @@ generateSimple :: ByteArray ba
generateSimple (StateSimple prevSt) nbBytes = unsafeDoIO $ do
newSt <- B.copy prevSt (\_ -> return ())
output <- B.alloc nbBytes $ \dstPtr ->
withByteArray newSt $ \stPtr ->
ccryptonite_chacha_random 8 dstPtr (castPtr stPtr) (fromIntegral nbBytes)
B.withByteArray newSt $ \stPtr ->
ccryptonite_chacha_random 8 dstPtr stPtr (fromIntegral nbBytes)
return (output, StateSimple newSt)
foreign import ccall "cryptonite_chacha_init_core"
ccryptonite_chacha_init_core :: Ptr StateSimple -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
foreign import ccall "cryptonite_chacha_init"
ccryptonite_chacha_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
ccryptonite_chacha_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
foreign import ccall "cryptonite_chacha_combine"
ccryptonite_chacha_combine :: Int -> Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
ccryptonite_chacha_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
foreign import ccall "cryptonite_chacha_generate"
ccryptonite_chacha_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()
foreign import ccall "cryptonite_chacha_random"
ccryptonite_chacha_random :: Int -> Ptr Word8 -> Ptr StateSimple -> CUInt -> IO ()
{-
foreign import ccall "cryptonite_chacha_generate"
ccryptonite_chacha_generate :: Int -> Ptr Word8 -> Ptr State -> CUInt -> IO ()
-}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2014 Vincent Hanquez <vincent@snarc.org>
* Copyright (c) 2014-2015 Vincent Hanquez <vincent@snarc.org>
*
* All rights reserved.
*
@ -34,14 +34,15 @@
#include "cryptonite_bitfn.h"
#include <stdio.h>
#define USE_8BITS 0
/*
 * ChaCha quarter-round over the four words a, b, c, d, updated in place:
 * four add / xor / rol32 steps using rotation amounts 16, 12, 8 and 7.
 * rol32 comes from outside this view (presumably the 32-bit rotate-left of
 * cryptonite_bitfn.h -- confirm).
 *
 * The macro expands to several statements and mentions each argument more
 * than once: pass plain lvalues without side effects, and do not use it as
 * the unbraced body of an if/for/while.
 */
#define QR(a,b,c,d) \
a += b; d = rol32(d ^ a,16); \
c += d; b = rol32(b ^ c,12); \
a += b; d = rol32(d ^ a, 8); \
c += d; b = rol32(b ^ c, 7);
/* Non-zero when the address PTR lies on an 8-byte (64-bit) boundary. */
#define ALIGNED64(PTR) \
	((((uintptr_t)(const void *)(PTR)) & (uintptr_t)7) == 0)
/*
 * Constant strings selected by cryptonite_chacha_init_core: sigma when the
 * key length is 32 bytes, tau for any other key length.
 * Each literal is exactly 16 characters long, so the 16-byte arrays hold no
 * terminating NUL.
 */
static const uint8_t sigma[16] = "expand 32-byte k";
static const uint8_t tau[16] = "expand 16-byte k";
@ -96,9 +97,9 @@ static void chacha_core(int rounds, block *out, const cryptonite_chacha_state *i
}
/* the only 2 valid key lengths are 256 bits (32 bytes) and 128 bits (16 bytes) */
void cryptonite_chacha_init(cryptonite_chacha_state *st,
uint32_t keylen, const uint8_t *key,
uint32_t ivlen, const uint8_t *iv)
void cryptonite_chacha_init_core(cryptonite_chacha_state *st,
uint32_t keylen, const uint8_t *key,
uint32_t ivlen, const uint8_t *iv)
{
const uint8_t *constants = (keylen == 32) ? sigma : tau;
int i;
@ -135,66 +136,139 @@ void cryptonite_chacha_init(cryptonite_chacha_state *st,
}
}
void cryptonite_chacha_combine(uint32_t rounds, block *dst, cryptonite_chacha_state *st, const block *src, uint32_t bytes)
/*
 * Set up a complete streaming context.
 *
 * The whole context is cleared first (so the leftover buffer, its offset and
 * its length all start at zero), then the embedded cipher state is keyed via
 * cryptonite_chacha_init_core and the round count is recorded for later
 * calls to chacha_core.
 */
void cryptonite_chacha_init(cryptonite_chacha_context *ctx, uint8_t nb_rounds,
                            uint32_t keylen, const uint8_t *key,
                            uint32_t ivlen, const uint8_t *iv)
{
	memset(ctx, 0, sizeof(cryptonite_chacha_context));

	cryptonite_chacha_init_core(&ctx->st, keylen, key, ivlen, iv);
	ctx->nb_rounds = nb_rounds;
}
void cryptonite_chacha_combine(uint8_t *dst, cryptonite_chacha_context *ctx, const uint8_t *src, uint32_t bytes)
{
block out;
cryptonite_chacha_state *st;
int i;
if (!bytes)
return;
for (;; bytes -= 64, src += 1, dst += 1) {
chacha_core(rounds, &out, st);
/* xor the previous buffer first (if any) */
if (ctx->prev_len > 0) {
int to_copy = (ctx->prev_len < bytes) ? ctx->prev_len : bytes;
for (i = 0; i < to_copy; i++)
dst[i] = src[i] ^ ctx->prev[ctx->prev_ofs+i];
memset(ctx->prev + ctx->prev_ofs, 0, to_copy);
ctx->prev_len -= to_copy;
ctx->prev_ofs += to_copy;
src += to_copy;
dst += to_copy;
bytes -= to_copy;
}
if (bytes == 0)
return;
st = &ctx->st;
/* xor new 64-bytes chunks and store the left over if any */
for (; bytes >= 64; bytes -= 64, src += 64, dst += 64) {
/* generate new chunk and update state */
chacha_core(ctx->nb_rounds, &out, st);
st->d[12] += 1;
if (st->d[12] == 0)
st->d[13] += 1;
if (bytes <= 64) {
for (i = 0; i < bytes; i++)
dst->b[i] = src->b[i] ^ out.b[i];
for (; i < 64; i++)
dst->b[i] = out.b[i];
return;
}
#if USE_8BITS
for (i = 0; i < 64; ++i)
dst->b[i] = src->b[i] ^ out.b[i];
#else
/* fast copy using 64 bits */
for (i = 0; i < 8; i++)
dst->q[i] = src->q[i] ^ out.q[i];
#endif
dst[i] = src[i] ^ out.b[i];
}
if (bytes > 0) {
/* generate new chunk and update state */
chacha_core(ctx->nb_rounds, &out, st);
st->d[12] += 1;
if (st->d[12] == 0)
st->d[13] += 1;
/* xor as much as needed */
for (i = 0; i < bytes; i++)
dst[i] = src[i] ^ out.b[i];
/* copy the left over in the buffer */
ctx->prev_len = 64 - bytes;
ctx->prev_ofs = i;
for (; i < 64; i++) {
ctx->prev[i] = out.b[i];
}
}
}
void cryptonite_chacha_generate(uint32_t rounds, block *dst, cryptonite_chacha_state *st, uint32_t bytes)
void cryptonite_chacha_generate(uint8_t *dst, cryptonite_chacha_context *ctx, uint32_t bytes)
{
cryptonite_chacha_state *st;
block out;
int i;
if (!bytes)
return;
for (;; bytes -= 64, dst += 1) {
chacha_core(rounds, &out, st);
/* copy out the keystream bytes left over from the previous call first (if any) */
if (ctx->prev_len > 0) {
int to_copy = (ctx->prev_len < bytes) ? ctx->prev_len : bytes;
for (i = 0; i < to_copy; i++)
dst[i] = ctx->prev[ctx->prev_ofs+i];
memset(ctx->prev + ctx->prev_ofs, 0, to_copy);
ctx->prev_len -= to_copy;
ctx->prev_ofs += to_copy;
dst += to_copy;
bytes -= to_copy;
}
if (bytes == 0)
return;
st = &ctx->st;
if (ALIGNED64(dst)) {
/* dst is 8-byte aligned: generate whole 64-byte chunks directly into it */
for (; bytes >= 64; bytes -= 64, dst += 64) {
/* generate new chunk and update state */
chacha_core(ctx->nb_rounds, (block *) dst, st);
st->d[12] += 1;
if (st->d[12] == 0)
st->d[13] += 1;
}
} else {
/* dst is not 8-byte aligned: generate each 64-byte chunk into a local block, then copy it out */
for (; bytes >= 64; bytes -= 64, dst += 64) {
/* generate new chunk and update state */
chacha_core(ctx->nb_rounds, &out, st);
st->d[12] += 1;
if (st->d[12] == 0)
st->d[13] += 1;
for (i = 0; i < 64; ++i)
dst[i] = out.b[i];
}
}
if (bytes > 0) {
/* generate new chunk and update state */
chacha_core(ctx->nb_rounds, &out, st);
st->d[12] += 1;
if (st->d[12] == 0)
st->d[13] += 1;
if (bytes <= 64) {
for (i = 0; i < bytes; ++i)
dst->b[i] = out.b[i];
return;
}
#if USE_8BITS
for (i = 0; i < 64; ++i)
dst->b[i] = out.b[i];
#else
for (i = 0; i < 8; i++)
dst->q[i] = out.q[i];
#endif
/* copy as much as needed */
for (i = 0; i < bytes; i++)
dst[i] = out.b[i];
/* copy the left over in the buffer */
ctx->prev_len = 64 - bytes;
ctx->prev_ofs = i;
for (; i < 64; i++)
ctx->prev[i] = out.b[i];
}
}
@ -207,11 +281,11 @@ void cryptonite_chacha_random(uint32_t rounds, uint8_t *dst, cryptonite_chacha_s
for (; bytes >= 16; bytes -= 16, dst += 16) {
chacha_core(rounds, &out, st);
memcpy(dst, out.b + 40, 16);
cryptonite_chacha_init(st, 32, out.b, 8, out.b + 32);
cryptonite_chacha_init_core(st, 32, out.b, 8, out.b + 32);
}
if (bytes) {
chacha_core(rounds, &out, st);
memcpy(dst, out.b + 40, bytes);
cryptonite_chacha_init(st, 32, out.b, 8, out.b + 32);
cryptonite_chacha_init_core(st, 32, out.b, 8, out.b + 32);
}
}

View File

@ -38,8 +38,17 @@ typedef union {
typedef block cryptonite_chacha_state;
void cryptonite_chacha_init(cryptonite_chacha_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
void cryptonite_chacha_combine(uint32_t rounds, block *dst, cryptonite_chacha_state *st, const block *src, uint32_t bytes);
void cryptonite_chacha_generate(uint32_t rounds, block *dst, cryptonite_chacha_state *st, uint32_t bytes);
/*
 * Streaming ChaCha context used by cryptonite_chacha_init / _combine /
 * _generate.
 *
 * NOTE: the Haskell binding allocates this structure with a hard-coded byte
 * count, so any change to the fields or their order must be mirrored there.
 */
typedef struct {
cryptonite_chacha_state st; /* cipher state fed to chacha_core; words 12 and 13 are incremented after each generated block */
uint8_t prev[64]; /* keystream bytes of the last generated block that have not been handed out yet */
uint8_t prev_ofs; /* index in prev of the first unconsumed byte */
uint8_t prev_len; /* number of unconsumed bytes available starting at prev_ofs */
uint8_t nb_rounds; /* round count passed to chacha_core */
} cryptonite_chacha_context;
/* Key the bare cipher state from a 16- or 32-byte key and an IV. */
void cryptonite_chacha_init_core(cryptonite_chacha_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
/* Clear a streaming context, record the round count and key its cipher state. */
void cryptonite_chacha_init(cryptonite_chacha_context *ctx, uint8_t nb_rounds, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
/* Write src xor keystream into dst for `bytes` bytes, consuming buffered leftover keystream first. */
void cryptonite_chacha_combine(uint8_t *dst, cryptonite_chacha_context *ctx, const uint8_t *src, uint32_t bytes);
/* Write `bytes` bytes of raw keystream into dst, consuming buffered leftover keystream first. */
void cryptonite_chacha_generate(uint8_t *dst, cryptonite_chacha_context *ctx, uint32_t bytes);
#endif

View File

@ -38,19 +38,21 @@ tests = testGroup "ChaCha"
, testCase "8-256-K0-I0" (chachaRunSimple b8_256_k0_i0 8 32 8)
, testCase "12-256-K0-I0" (chachaRunSimple b12_256_k0_i0 12 32 8)
, testCase "20-256-K0-I0" (chachaRunSimple b20_256_k0_i0 20 32 8)
, testProperty "chunking" chachaChunks
, testProperty "generate-combine" chachaGenerateCombine
, testProperty "chunking-generate" chachaGenerateChunks
, testProperty "chunking-combine" chachaCombineChunks
]
where chachaRunSimple expected rounds klen nonceLen =
let chacha = ChaCha.initialize rounds (B.replicate klen 0) (B.replicate nonceLen 0)
in expected @=? fst (ChaCha.generate chacha (B.length expected))
chachaChunks :: ChunkingLen -> Vector -> Bool
chachaChunks (ChunkingLen ckLen) (Vector rounds key iv) =
chachaGenerateChunks :: ChunkingLen -> Vector -> Bool
chachaGenerateChunks (ChunkingLen ckLen) (Vector rounds key iv) =
let initChaCha = ChaCha.initialize rounds key iv
nbBytes = 1048
(expected,_) = ChaCha.generate initChaCha nbBytes
chunks = loop nbBytes ckLen (ChaCha.initialize rounds key iv)
in expected == B.concat chunks
chunks = loop nbBytes ckLen initChaCha
in expected `propertyEq` B.concat chunks
where loop n [] chacha = loop n ckLen chacha
loop 0 _ _ = []
@ -58,3 +60,31 @@ tests = testGroup "ChaCha"
let len = min x n
(c, next) = ChaCha.generate chacha len
in c : loop (n - len) xs next
-- Property: for every chunk of a randomly chunked stream, 'ChaCha.generate'
-- returns the same bytes as 'ChaCha.combine' applied to that many zero bytes.
chachaGenerateCombine :: ChunkingLen0_127 -> Vector -> Int0_2901 -> Bool
chachaGenerateCombine (ChunkingLen0_127 ckLen) (Vector rounds key iv) (Int0_2901 nbBytes) =
    go nbBytes ckLen (ChaCha.initialize rounds key iv)
  where
    -- chunk lengths exhausted: start again from the full list
    go remaining [] st = go remaining ckLen st
    go 0 _ _ = True
    go remaining (want:rest) st =
        fromGenerate == fromCombine && go (remaining - size) rest next
      where
        size                 = min want remaining
        (fromGenerate, next) = ChaCha.generate st size
        (fromCombine, _)     = ChaCha.combine st (B.replicate size 0)
-- Property: combining a message in one call gives the same bytes as combining
-- it piece by piece while threading the state through each call.
chachaCombineChunks :: ChunkingLen0_127 -> Vector -> ArbitraryBS0_2901 -> Bool
chachaCombineChunks (ChunkingLen0_127 ckLen) (Vector rounds key iv) (ArbitraryBS0_2901 wholebs) =
    oneShot `propertyEq` B.concat (pieces wholebs ckLen initChaCha)
  where
    initChaCha = ChaCha.initialize rounds key iv
    oneShot    = fst (ChaCha.combine initChaCha wholebs)

    -- chunk lengths exhausted: start again from the full list
    pieces pending [] st = pieces pending ckLen st
    pieces pending (want:rest) st
        | B.null pending = []
        | otherwise      = out : pieces later rest st'
      where
        (now, later) = B.splitAt (min want (B.length pending)) pending
        (out, st')   = ChaCha.combine st now

View File

@ -14,12 +14,25 @@ newtype ChunkingLen = ChunkingLen [Int]
instance Arbitrary ChunkingLen where
arbitrary = ChunkingLen `fmap` replicateM 16 (choose (0,14))
-- | A list of 16 chunk lengths, each drawn from the range 0..127.
newtype ChunkingLen0_127 = ChunkingLen0_127 [Int]
    deriving (Show, Eq)

instance Arbitrary ChunkingLen0_127 where
    arbitrary = do
        lens <- replicateM 16 (choose (0, 127))
        return (ChunkingLen0_127 lens)
-- | A 'ByteString' produced by @arbitraryBSof 0 2901@ (presumably a length
-- between 0 and 2901 bytes; the helper is defined outside this view).
newtype ArbitraryBS0_2901 = ArbitraryBS0_2901 ByteString
    deriving (Show, Eq, Ord)

instance Arbitrary ArbitraryBS0_2901 where
    arbitrary = fmap ArbitraryBS0_2901 (arbitraryBSof 0 2901)
-- | An 'Int' drawn from the range 0..2901.
newtype Int0_2901 = Int0_2901 Int
    deriving (Show, Eq, Ord)

instance Arbitrary Int0_2901 where
    arbitrary = fmap Int0_2901 (choose (0, 2901))
-- | Generator for a 'ByteString' built from @n@ arbitrary bytes.
arbitraryBS :: Int -> Gen ByteString
arbitraryBS n = do
    bytes <- replicateM n arbitrary
    return (B.pack bytes)