[Salsa] opaquify the state just like for hash functions

add more tests
This commit is contained in:
Vincent Hanquez 2015-05-22 14:04:54 +01:00
parent 9a69c61e84
commit 1dacb7fa94
4 changed files with 196 additions and 125 deletions

View File

@ -13,30 +13,15 @@ module Crypto.Cipher.Salsa
, State
) where
import Data.ByteString (ByteString)
import Data.Memory.PtrMethods (memXor)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString as BS
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.Ptr
import Foreign.C.Types
-- | Salsa context
data State = State Int -- number of rounds
ScrubbedBytes -- Salsa's state
Int -- offset of data in previously generated chunk
ByteString -- previous generated chunk
round64 :: Int -> (Bool, Int)
round64 len
| len == 0 = (True, 0)
| m == 0 = (True, len)
| otherwise = (False, len + (64 - m))
where m = len `mod` 64
newtype State = State ScrubbedBytes
-- | Initialize a new Salsa context with the number of rounds,
-- the key and the nonce associated.
@ -50,11 +35,11 @@ initialize nbRounds key nonce
| not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits"
| not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20"
| otherwise = unsafeDoIO $ do
stPtr <- B.alloc 64 $ \stPtr ->
stPtr <- B.alloc 132 $ \stPtr ->
B.withByteArray nonce $ \noncePtr ->
B.withByteArray key $ \keyPtr ->
ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr
return $ State nbRounds stPtr 0 B.empty
ccryptonite_salsa_init stPtr (fromIntegral nbRounds) kLen keyPtr nonceLen noncePtr
return $ State stPtr
where kLen = B.length key
nonceLen = B.length nonce
@ -64,59 +49,33 @@ combine :: ByteArray ba
=> State -- ^ the current Salsa state
-> ba -- ^ the source to xor with the generator
-> (ba, State)
combine prev@(State nbRounds prevSt prevOffset prevOut) src
| outputLen == 0 = (B.empty, prev)
| outputLen <= prevBufLen = unsafeDoIO $ do
-- we have enough byte in the previous buffer to complete the query
-- without having to generate any extra bytes
output <- B.copy src $ \dst ->
B.withByteArray prevOut $ \prevPtr ->
memXor dst dst (prevPtr `plusPtr` prevOffset) outputLen
return (output, State nbRounds prevSt (prevOffset + outputLen) prevOut)
| otherwise = unsafeDoIO $ do
-- adjusted len is the number of bytes lefts to generate after
-- copying from the previous buffer.
let adjustedLen = outputLen - prevBufLen
(roundedAlready, newBytesToGenerate) = round64 adjustedLen
nextBufLen = newBytesToGenerate - adjustedLen
fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen)
newSt <- withForeignPtr fptr $ \dstPtr ->
B.withByteArray src $ \srcPtr -> do
-- copy the previous buffer by xor if any
B.withByteArray prevOut $ \prevPtr ->
memXor dstPtr srcPtr (prevPtr `plusPtr` prevOffset) prevBufLen
-- then create a new mutable copy of state
B.copy prevSt $ \stPtr ->
ccryptonite_salsa_combine nbRounds
(dstPtr `plusPtr` prevBufLen)
(castPtr stPtr)
(srcPtr `plusPtr` prevBufLen)
(fromIntegral adjustedLen)
-- return combined byte
return ( B.convert (BS.PS fptr 0 outputLen)
, State nbRounds newSt 0 (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))
where
outputLen = B.length src
prevBufLen = B.length prevOut - prevOffset
-- The previous context is handed to 'B.copyRet', so the C routine advances
-- the buffer it is given there and the updated bytes come back as the new
-- 'State' next to the xor'ed output.
combine prevSt@(State prevStMem) src
    | B.null src = (B.empty, prevSt)
    | otherwise  = unsafeDoIO $ fmap rewrap $ B.copyRet prevStMem xorSource
  where
    srcLen = B.length src
    -- pair the produced bytes with the advanced context
    rewrap (out, st) = (out, State st)
    -- allocate the output and let the C code xor the source into it
    xorSource ctx =
        B.alloc srcLen $ \dstPtr ->
        B.withByteArray src $ \srcPtr ->
            ccryptonite_salsa_combine dstPtr ctx srcPtr (fromIntegral srcLen)
-- | Generate a number of bytes from the Salsa output directly
--
-- The bytes are written by @cryptonite_salsa_generate@ itself; no
-- zero-filled buffer is xor'ed through 'combine'.
generate :: ByteArray ba
=> State -- ^ the current Salsa state
-> Int -- ^ the length of data to generate
-> (ba, State)
generate st len = combine st (B.zero len)
-- A non-positive length yields an empty result and the unchanged state;
-- otherwise the C generator fills a fresh buffer of @len@ bytes while
-- advancing the context passed through 'B.copyRet'.
generate prevSt@(State prevStMem) len
    | len > 0   = unsafeDoIO $ fmap rewrap $ B.copyRet prevStMem fillOutput
    | otherwise = (B.empty, prevSt)
  where
    -- pair the produced bytes with the advanced context
    rewrap (out, st) = (out, State st)
    -- allocate the output and let the C code write the keystream into it
    fillOutput ctx =
        B.alloc len $ \dstPtr ->
            ccryptonite_salsa_generate dstPtr ctx (fromIntegral len)
foreign import ccall "cryptonite_salsa_init"
ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
ccryptonite_salsa_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
foreign import ccall "cryptonite_salsa_combine"
ccryptonite_salsa_combine :: Int -> Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
ccryptonite_salsa_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
{-
foreign import ccall "cryptonite_salsa_generate"
ccryptonite_salsa_generate :: Int -> Ptr Word8 -> Ptr State -> CUInt -> IO ()
-}
ccryptonite_salsa_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2014 Vincent Hanquez <vincent@snarc.org>
* Copyright (c) 2014-2015 Vincent Hanquez <vincent@snarc.org>
*
* All rights reserved.
*
@ -29,11 +29,10 @@
*/
#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include "cryptonite_salsa.h"
#include "cryptonite_bitfn.h"
#include <stdio.h>
#define USE_8BITS 0
static const uint8_t sigma[16] = "expand 32-byte k";
static const uint8_t tau[16] = "expand 16-byte k";
@ -44,6 +43,9 @@ static const uint8_t tau[16] = "expand 16-byte k";
d ^= rol32(c+b, 13); \
a ^= rol32(d+c, 18);
/* Non-zero when the address held in PTR is a multiple of 8 bytes.
 * cryptonite_salsa_generate uses it to decide whether salsa_core may write
 * straight into the caller's buffer through a (block *) cast. */
#define ALIGNED64(PTR) \
(((uintptr_t)(const void *)(PTR)) % 8 == 0)
#define SALSA_CORE_LOOP \
for (i = rounds; i > 0; i -= 2) { \
QR (x0,x4,x8,x12); \
@ -117,9 +119,9 @@ void cryptonite_salsa_core_xor(int rounds, block *out, block *in)
}
/* only 2 valid values for keylen are 256 (32) and 128 (16) */
void cryptonite_salsa_init(cryptonite_salsa_state *st,
uint32_t keylen, const uint8_t *key,
uint32_t ivlen, const uint8_t *iv)
void cryptonite_salsa_init_core(cryptonite_salsa_state *st,
uint32_t keylen, const uint8_t *key,
uint32_t ivlen, const uint8_t *iv)
{
const uint8_t *constants = (keylen == 32) ? sigma : tau;
int i;
@ -157,67 +159,139 @@ void cryptonite_salsa_init(cryptonite_salsa_state *st,
}
}
void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes)
/* Initialise a streaming Salsa context.
 *
 * ctx       context to set up; every byte of it is overwritten
 * nb_rounds round count later handed to salsa_core by combine/generate
 * keylen    key length in bytes, forwarded to cryptonite_salsa_init_core
 * key       key bytes
 * ivlen     nonce length in bytes, forwarded to cryptonite_salsa_init_core
 * iv        nonce bytes
 */
void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds,
uint32_t keylen, const uint8_t *key,
uint32_t ivlen, const uint8_t *iv)
{
/* zero everything first so prev, prev_ofs and prev_len start out empty */
memset(ctx, 0, sizeof(*ctx));
ctx->nb_rounds = nb_rounds;
/* the key/nonce layout of the core state is delegated to the core initialiser */
cryptonite_salsa_init_core(&ctx->st, keylen, key, ivlen, iv);
}
/* XOR `bytes` bytes of `src` with the next keystream bytes into `dst`.
 *
 * Keystream left over from an earlier call (ctx->prev) is consumed first;
 * after that whole 64-byte blocks are produced, and a final partial block
 * stores its unused tail back into ctx->prev for the next call.
 * Both `dst` and `src` are indexed up to `bytes`, so each must hold at least
 * that many bytes. A zero `bytes` is a no-op.
 */
void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *ctx, const uint8_t *src, uint32_t bytes)
{
block out;
cryptonite_salsa_state *st;
int i;
if (!bytes)
return;
/* xor the previous buffer first (if any) */
if (ctx->prev_len > 0) {
/* take no more than what is buffered and no more than what is asked for */
int to_copy = (ctx->prev_len < bytes) ? ctx->prev_len : bytes;
for (i = 0; i < to_copy; i++)
dst[i] = src[i] ^ ctx->prev[ctx->prev_ofs+i];
/* wipe the keystream bytes that have just been used */
memset(ctx->prev + ctx->prev_ofs, 0, to_copy);
ctx->prev_len -= to_copy;
ctx->prev_ofs += to_copy;
src += to_copy;
dst += to_copy;
bytes -= to_copy;
}
/* the buffered bytes were enough to serve the whole request */
if (bytes == 0)
return;
st = &ctx->st;
/* xor new 64-bytes chunks and store the left over if any */
for (; bytes >= 64; bytes -= 64, src += 64, dst += 64) {
/* generate new chunk and update state */
salsa_core(ctx->nb_rounds, &out, st);
/* bump word 8 once per block and carry into word 9 when it wraps */
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
for (i = 0; i < 64; ++i)
dst[i] = src[i] ^ out.b[i];
}
if (bytes > 0) {
/* generate new chunk and update state */
salsa_core(ctx->nb_rounds, &out, st);
/* same word 8 / word 9 increment as in the full-block loop */
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
/* xor as much as needed */
for (i = 0; i < bytes; i++)
dst[i] = src[i] ^ out.b[i];
/* copy the left over in the buffer */
/* i == bytes here: the tail keeps its position inside the 64-byte block */
ctx->prev_len = 64 - bytes;
ctx->prev_ofs = i;
for (; i < 64; i++) {
ctx->prev[i] = out.b[i];
}
}
}
void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *ctx, uint32_t bytes)
{
cryptonite_salsa_state *st;
block out;
int i;
if (!bytes)
return;
for (;; bytes -= 64, src += 1, dst += 1) {
salsa_core(rounds, &out, st);
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
if (bytes <= 64) {
for (i = 0; i < bytes; i++)
dst->b[i] = src->b[i] ^ out.b[i];
for (; i < 64; i++)
dst->b[i] = out.b[i];
return;
}
#if USE_8BITS
for (i = 0; i < 64; ++i)
dst->b[i] = src->b[i] ^ out.b[i];
#else
/* fast copy using 64 bits */
for (i = 0; i < 8; i++)
dst->q[i] = src->q[i] ^ out.q[i];
#endif
/* xor the previous buffer first (if any) */
if (ctx->prev_len > 0) {
int to_copy = (ctx->prev_len < bytes) ? ctx->prev_len : bytes;
for (i = 0; i < to_copy; i++)
dst[i] = ctx->prev[ctx->prev_ofs+i];
memset(ctx->prev + ctx->prev_ofs, 0, to_copy);
ctx->prev_len -= to_copy;
ctx->prev_ofs += to_copy;
dst += to_copy;
bytes -= to_copy;
}
}
void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes)
{
block out;
int i;
if (!bytes)
if (bytes == 0)
return;
for (;; bytes -= 64, dst += 1) {
salsa_core(rounds, &out, st);
st = &ctx->st;
if (ALIGNED64(dst)) {
/* xor new 64-bytes chunks and store the left over if any */
for (; bytes >= 64; bytes -= 64, dst += 64) {
/* generate new chunk and update state */
salsa_core(ctx->nb_rounds, (block *) dst, st);
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
}
} else {
/* xor new 64-bytes chunks and store the left over if any */
for (; bytes >= 64; bytes -= 64, dst += 64) {
/* generate new chunk and update state */
salsa_core(ctx->nb_rounds, &out, st);
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
for (i = 0; i < 64; ++i)
dst[i] = out.b[i];
}
}
if (bytes > 0) {
/* generate new chunk and update state */
salsa_core(ctx->nb_rounds, &out, st);
st->d[8] += 1;
if (st->d[8] == 0)
st->d[9] += 1;
if (bytes <= 64) {
for (i = 0; i < bytes; ++i)
dst->b[i] = out.b[i];
return;
}
#if USE_8BITS
for (i = 0; i < 64; ++i)
dst->b[i] = out.b[i];
#else
for (i = 0; i < 8; i++)
dst->q[i] = out.q[i];
#endif
/* xor as much as needed */
for (i = 0; i < bytes; i++)
dst[i] = out.b[i];
/* copy the left over in the buffer */
ctx->prev_len = 64 - bytes;
ctx->prev_ofs = i;
for (; i < 64; i++)
ctx->prev[i] = out.b[i];
}
}

View File

@ -38,11 +38,20 @@ typedef union {
typedef block cryptonite_salsa_state;
/* Streaming context used by cryptonite_salsa_init/combine/generate. */
typedef struct {
cryptonite_salsa_state st; /* core state passed to salsa_core */
uint8_t prev[64]; /* last generated block; holds the bytes not yet handed out */
uint8_t prev_ofs; /* index in prev of the first unused byte */
uint8_t prev_len; /* number of unused bytes remaining in prev */
uint8_t nb_rounds; /* round count given to cryptonite_salsa_init */
} cryptonite_salsa_context;
/* for scrypt */
void cryptonite_salsa_core_xor(int rounds, block *out, block *in);
void cryptonite_salsa_init(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes);
void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes);
/* set up the bare core state from key and nonce */
void cryptonite_salsa_init_core(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
/* set up a streaming context: zeroed buffer, round count, then the core state */
void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
/* dst = src xor keystream, for `bytes` bytes; advances ctx */
void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *ctx, const uint8_t *src, uint32_t bytes);
/* dst = keystream, for `bytes` bytes; advances ctx */
void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *ctx, uint32_t bytes);
#endif

View File

@ -38,7 +38,9 @@ instance Arbitrary RandomVector where
tests = testGroup "Salsa"
[ testGroup "KAT" $
map (\(i,f) -> testCase (show (i :: Int)) f) $ zip [1..] $ map (\(r, k,i,e) -> salsaRunSimple e r k i) vectors
, testProperty "chunking" salsaChunks
, testProperty "generate-combine" salsaGenerateCombine
, testProperty "chunking-generate" salsaGenerateChunks
, testProperty "chunking-combine" salsaCombineChunks
]
where
salsaRunSimple expected rounds key nonce =
@ -55,8 +57,20 @@ tests = testGroup "Salsa"
in e : salsaLoop (current + B.length expectBs) salsaNext rs
| otherwise = error "internal error in salsaLoop"
salsaChunks :: ChunkingLen -> RandomVector -> Bool
salsaChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) =
-- Walk through nbBytes of output in chunks taken (cyclically) from ckLen and
-- check, chunk by chunk, that Salsa.generate yields the same bytes as
-- Salsa.combine applied to an all-zero input of the same length.
salsaGenerateCombine :: ChunkingLen0_127 -> RandomVector -> Int0_2901 -> Bool
salsaGenerateCombine (ChunkingLen0_127 ckLen) (RandomVector (rounds, key, iv, _)) (Int0_2901 nbBytes) =
    loop nbBytes ckLen (Salsa.initialize rounds key iv)
  where
    -- nothing left to produce: every chunk matched (checked first so an
    -- exhausted chunk list cannot be recycled forever)
    loop 0 _ _ = True
    -- chunk lengths exhausted: start over; when none of them is positive,
    -- finish in a single chunk so the walk still makes progress
    loop n [] salsa
        | any (> 0) ckLen = loop n ckLen salsa
        | otherwise       = loop n [n] salsa
    loop n (x:xs) salsa =
        let len        = min x n
            (c1, next) = Salsa.generate salsa len
            (c2, _)    = Salsa.combine salsa (B.replicate len 0)
         in c1 == c2 && loop (n - len) xs next
salsaGenerateChunks :: ChunkingLen -> RandomVector -> Bool
salsaGenerateChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) =
let initSalsa = Salsa.initialize rounds key iv
nbBytes = 1048
(expected,_) = Salsa.generate initSalsa nbBytes
@ -69,3 +83,18 @@ tests = testGroup "Salsa"
let len = min x n
(c, next) = Salsa.generate salsa len
in c : loop (n - len) xs next
-- Combining the whole input in one call must give the same bytes as
-- combining it piecewise, with piece lengths taken (cyclically) from ckLen
-- and the state threaded from one piece to the next.
salsaCombineChunks :: ChunkingLen -> RandomVector -> ArbitraryBS0_2901 -> Bool
salsaCombineChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) (ArbitraryBS0_2901 wholebs) =
    expected `propertyEq` B.concat (loop wholebs ckLen initSalsa)
  where
    initSalsa     = Salsa.initialize rounds key iv
    (expected, _) = Salsa.combine initSalsa wholebs
    loop bs lens salsa
        -- input consumed: stop before looking at the chunk list, so an
        -- exhausted list is never recycled for nothing
        | B.null bs = []
        | otherwise =
            case lens of
                -- chunk lengths exhausted: start over; when none of them is
                -- positive, take the rest of the input as one piece
                [] | any (> 0) ckLen -> loop bs ckLen salsa
                   | otherwise       -> loop bs [B.length bs] salsa
                (x:xs) ->
                    let (bs1, bs2) = B.splitAt (min x (B.length bs)) bs
                        (c, next)  = Salsa.combine salsa bs1
                     in c : loop bs2 xs next