[Salsa] opaquify the state just like for hash functions
add more tests
This commit is contained in:
parent
9a69c61e84
commit
1dacb7fa94
@ -13,30 +13,15 @@ module Crypto.Cipher.Salsa
|
||||
, State
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Memory.PtrMethods (memXor)
|
||||
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
|
||||
import qualified Crypto.Internal.ByteArray as B
|
||||
import qualified Data.ByteString.Internal as BS
|
||||
import qualified Data.ByteString as BS
|
||||
import Crypto.Internal.Compat
|
||||
import Crypto.Internal.Imports
|
||||
import Foreign.Ptr
|
||||
import Foreign.ForeignPtr
|
||||
import Foreign.C.Types
|
||||
import Crypto.Internal.Compat
|
||||
import Crypto.Internal.Imports
|
||||
import Foreign.Ptr
|
||||
import Foreign.C.Types
|
||||
|
||||
-- | Salsa context
|
||||
data State = State Int -- number of rounds
|
||||
ScrubbedBytes -- Salsa's state
|
||||
Int -- offset of data in previously generated chunk
|
||||
ByteString -- previous generated chunk
|
||||
|
||||
round64 :: Int -> (Bool, Int)
|
||||
round64 len
|
||||
| len == 0 = (True, 0)
|
||||
| m == 0 = (True, len)
|
||||
| otherwise = (False, len + (64 - m))
|
||||
where m = len `mod` 64
|
||||
newtype State = State ScrubbedBytes
|
||||
|
||||
-- | Initialize a new Salsa context with the number of rounds,
|
||||
-- the key and the nonce associated.
|
||||
@ -50,11 +35,11 @@ initialize nbRounds key nonce
|
||||
| not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits"
|
||||
| not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20"
|
||||
| otherwise = unsafeDoIO $ do
|
||||
stPtr <- B.alloc 64 $ \stPtr ->
|
||||
stPtr <- B.alloc 132 $ \stPtr ->
|
||||
B.withByteArray nonce $ \noncePtr ->
|
||||
B.withByteArray key $ \keyPtr ->
|
||||
ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr
|
||||
return $ State nbRounds stPtr 0 B.empty
|
||||
ccryptonite_salsa_init stPtr (fromIntegral nbRounds) kLen keyPtr nonceLen noncePtr
|
||||
return $ State stPtr
|
||||
where kLen = B.length key
|
||||
nonceLen = B.length nonce
|
||||
|
||||
@ -64,59 +49,33 @@ combine :: ByteArray ba
|
||||
=> State -- ^ the current Salsa state
|
||||
-> ba -- ^ the source to xor with the generator
|
||||
-> (ba, State)
|
||||
combine prev@(State nbRounds prevSt prevOffset prevOut) src
|
||||
| outputLen == 0 = (B.empty, prev)
|
||||
| outputLen <= prevBufLen = unsafeDoIO $ do
|
||||
-- we have enough byte in the previous buffer to complete the query
|
||||
-- without having to generate any extra bytes
|
||||
output <- B.copy src $ \dst ->
|
||||
B.withByteArray prevOut $ \prevPtr ->
|
||||
memXor dst dst (prevPtr `plusPtr` prevOffset) outputLen
|
||||
return (output, State nbRounds prevSt (prevOffset + outputLen) prevOut)
|
||||
| otherwise = unsafeDoIO $ do
|
||||
-- adjusted len is the number of bytes lefts to generate after
|
||||
-- copying from the previous buffer.
|
||||
let adjustedLen = outputLen - prevBufLen
|
||||
(roundedAlready, newBytesToGenerate) = round64 adjustedLen
|
||||
nextBufLen = newBytesToGenerate - adjustedLen
|
||||
|
||||
fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen)
|
||||
newSt <- withForeignPtr fptr $ \dstPtr ->
|
||||
B.withByteArray src $ \srcPtr -> do
|
||||
-- copy the previous buffer by xor if any
|
||||
B.withByteArray prevOut $ \prevPtr ->
|
||||
memXor dstPtr srcPtr (prevPtr `plusPtr` prevOffset) prevBufLen
|
||||
|
||||
-- then create a new mutable copy of state
|
||||
B.copy prevSt $ \stPtr ->
|
||||
ccryptonite_salsa_combine nbRounds
|
||||
(dstPtr `plusPtr` prevBufLen)
|
||||
(castPtr stPtr)
|
||||
(srcPtr `plusPtr` prevBufLen)
|
||||
(fromIntegral adjustedLen)
|
||||
-- return combined byte
|
||||
return ( B.convert (BS.PS fptr 0 outputLen)
|
||||
, State nbRounds newSt 0 (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))
|
||||
where
|
||||
outputLen = B.length src
|
||||
prevBufLen = B.length prevOut - prevOffset
|
||||
-- Xor the source with the next keystream bytes. The C routine runs on the
-- context produced by 'B.copyRet' from the current one, and that updated
-- context is what the returned 'State' wraps.
combine st0@(State ctxMem) src
    | B.null src = (B.empty, st0)
    | otherwise  = unsafeDoIO $ do
        (xored, nextMem) <- B.copyRet ctxMem $ \ctxPtr ->
            B.alloc srcLen $ \outPtr ->
            B.withByteArray src $ \inPtr ->
                ccryptonite_salsa_combine outPtr ctxPtr inPtr (fromIntegral srcLen)
        return (xored, State nextMem)
  where
    -- number of bytes to process, shared by the allocation and the C call
    srcLen = B.length src
|
||||
|
||||
-- | Generate a number of bytes from the Salsa output directly
|
||||
--
|
||||
-- The bytes come straight from the C @cryptonite_salsa_generate@ routine; nothing is xor'ed through 'combine'.
|
||||
generate :: ByteArray ba
|
||||
=> State -- ^ the current Salsa state
|
||||
-> Int -- ^ the length of data to generate
|
||||
-> (ba, State)
|
||||
generate st len = combine st (B.zero len)
|
||||
-- A non-positive length yields an empty output and the unchanged state;
-- otherwise the C generator fills a fresh buffer of @len@ bytes while working
-- on the context produced by 'B.copyRet'.
generate st0@(State ctxMem) len
    | len <= 0  = (B.empty, st0)
    | otherwise = unsafeDoIO $ fmap rewrap $
        B.copyRet ctxMem $ \ctxPtr ->
            B.alloc len $ \outPtr ->
                ccryptonite_salsa_generate outPtr ctxPtr (fromIntegral len)
  where
    -- pair the generated bytes with the updated context wrapped as a 'State'
    rewrap (keystream, nextMem) = (keystream, State nextMem)
|
||||
|
||||
foreign import ccall "cryptonite_salsa_init"
|
||||
ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
|
||||
ccryptonite_salsa_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
|
||||
|
||||
foreign import ccall "cryptonite_salsa_combine"
|
||||
ccryptonite_salsa_combine :: Int -> Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
|
||||
ccryptonite_salsa_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO ()
|
||||
|
||||
{-
|
||||
foreign import ccall "cryptonite_salsa_generate"
|
||||
ccryptonite_salsa_generate :: Int -> Ptr Word8 -> Ptr State -> CUInt -> IO ()
|
||||
-}
|
||||
ccryptonite_salsa_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO ()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2014 Vincent Hanquez <vincent@snarc.org>
|
||||
* Copyright (c) 2014-2015 Vincent Hanquez <vincent@snarc.org>
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
@ -29,11 +29,10 @@
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include "cryptonite_salsa.h"
|
||||
#include "cryptonite_bitfn.h"
|
||||
#include <stdio.h>
|
||||
|
||||
#define USE_8BITS 0
|
||||
|
||||
static const uint8_t sigma[16] = "expand 32-byte k";
|
||||
static const uint8_t tau[16] = "expand 16-byte k";
|
||||
@ -44,6 +43,9 @@ static const uint8_t tau[16] = "expand 16-byte k";
|
||||
d ^= rol32(c+b, 13); \
|
||||
a ^= rol32(d+c, 18);
|
||||
|
||||
#define ALIGNED64(PTR) \
|
||||
(((uintptr_t)(const void *)(PTR)) % 8 == 0)
|
||||
|
||||
#define SALSA_CORE_LOOP \
|
||||
for (i = rounds; i > 0; i -= 2) { \
|
||||
QR (x0,x4,x8,x12); \
|
||||
@ -117,9 +119,9 @@ void cryptonite_salsa_core_xor(int rounds, block *out, block *in)
|
||||
}
|
||||
|
||||
/* only 2 valid values for keylen are 256 (32) and 128 (16) */
|
||||
void cryptonite_salsa_init(cryptonite_salsa_state *st,
|
||||
uint32_t keylen, const uint8_t *key,
|
||||
uint32_t ivlen, const uint8_t *iv)
|
||||
void cryptonite_salsa_init_core(cryptonite_salsa_state *st,
|
||||
uint32_t keylen, const uint8_t *key,
|
||||
uint32_t ivlen, const uint8_t *iv)
|
||||
{
|
||||
const uint8_t *constants = (keylen == 32) ? sigma : tau;
|
||||
int i;
|
||||
@ -157,67 +159,139 @@ void cryptonite_salsa_init(cryptonite_salsa_state *st,
|
||||
}
|
||||
}
|
||||
|
||||
void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes)
|
||||
/*
 * Initialize a Salsa streaming context.
 *
 * ctx       context to set up; every member is overwritten
 * nb_rounds number of rounds later handed to salsa_core
 * keylen    key length in bytes, forwarded to cryptonite_salsa_init_core
 * key       key material
 * ivlen     nonce length in bytes, forwarded to cryptonite_salsa_init_core
 * iv        nonce material
 */
void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds,
                           uint32_t keylen, const uint8_t *key,
                           uint32_t ivlen, const uint8_t *iv)
{
    /* Clear member by member rather than memset(ctx, 0, sizeof(*ctx)).
     * sizeof(*ctx) also counts tail padding after nb_rounds, and the members
     * themselves only span 64 + 64 + 3 = 131 bytes. The Haskell binding
     * allocates exactly 132 bytes for the context, so a whole-struct memset
     * can write past that buffer.
     * NOTE(review): assumes `block` is 8-byte aligned because of its 64-bit
     * lanes, which would make sizeof(*ctx) 136 -- confirm against the header. */
    memset(&ctx->st, 0, sizeof(ctx->st));
    memset(ctx->prev, 0, sizeof(ctx->prev));
    ctx->prev_ofs = 0;
    ctx->prev_len = 0;
    ctx->nb_rounds = nb_rounds;

    /* key / nonce setup of the cipher state proper */
    cryptonite_salsa_init_core(&ctx->st, keylen, key, ivlen, iv);
}
|
||||
|
||||
/*
 * XOR `bytes` bytes of `src` with the keystream and store the result in `dst`.
 *
 * Keystream left over from an earlier call (ctx->prev) is consumed first and
 * wiped as it is used. After that, fresh 64-byte blocks are produced with
 * salsa_core. When the request ends in the middle of a block, the unused tail
 * of that block is stashed in ctx->prev for the next call.
 */
void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *ctx, const uint8_t *src, uint32_t bytes)
{
    cryptonite_salsa_state *state = &ctx->st;
    block ks;
    uint32_t n;

    if (bytes == 0)
        return;

    /* drain the keystream kept from the previous call, if any */
    if (ctx->prev_len > 0) {
        uint8_t *pending = ctx->prev + ctx->prev_ofs;
        uint32_t take = (bytes < ctx->prev_len) ? bytes : ctx->prev_len;

        for (n = 0; n < take; n++)
            dst[n] = src[n] ^ pending[n];
        /* used keystream bytes are zeroed right away */
        memset(pending, 0, take);

        ctx->prev_len -= take;
        ctx->prev_ofs += take;
        src += take;
        dst += take;
        bytes -= take;

        if (bytes == 0)
            return;
    }

    /* full 64-byte blocks */
    while (bytes >= 64) {
        salsa_core(ctx->nb_rounds, &ks, state);
        /* bump the block counter: word 8, carrying into word 9 on wrap */
        if (++state->d[8] == 0)
            state->d[9]++;

        for (n = 0; n < 64; n++)
            dst[n] = src[n] ^ ks.b[n];

        src += 64;
        dst += 64;
        bytes -= 64;
    }

    if (bytes == 0)
        return;

    /* trailing partial block: use the first `bytes` bytes of a new block */
    salsa_core(ctx->nb_rounds, &ks, state);
    if (++state->d[8] == 0)
        state->d[9]++;

    for (n = 0; n < bytes; n++)
        dst[n] = src[n] ^ ks.b[n];

    /* keep the unused tail at the same offsets inside ctx->prev */
    ctx->prev_len = 64 - bytes;
    ctx->prev_ofs = bytes;
    memcpy(ctx->prev + bytes, ks.b + bytes, 64 - bytes);
}
|
||||
|
||||
void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *ctx, uint32_t bytes)
|
||||
{
|
||||
cryptonite_salsa_state *st;
|
||||
block out;
|
||||
int i;
|
||||
|
||||
if (!bytes)
|
||||
return;
|
||||
|
||||
for (;; bytes -= 64, src += 1, dst += 1) {
|
||||
salsa_core(rounds, &out, st);
|
||||
|
||||
st->d[8] += 1;
|
||||
if (st->d[8] == 0)
|
||||
st->d[9] += 1;
|
||||
|
||||
if (bytes <= 64) {
|
||||
for (i = 0; i < bytes; i++)
|
||||
dst->b[i] = src->b[i] ^ out.b[i];
|
||||
for (; i < 64; i++)
|
||||
dst->b[i] = out.b[i];
|
||||
return;
|
||||
}
|
||||
|
||||
#if USE_8BITS
|
||||
for (i = 0; i < 64; ++i)
|
||||
dst->b[i] = src->b[i] ^ out.b[i];
|
||||
#else
|
||||
/* fast copy using 64 bits */
|
||||
for (i = 0; i < 8; i++)
|
||||
dst->q[i] = src->q[i] ^ out.q[i];
|
||||
#endif
|
||||
/* output the previous buffer first (if any) */
|
||||
if (ctx->prev_len > 0) {
|
||||
int to_copy = (ctx->prev_len < bytes) ? ctx->prev_len : bytes;
|
||||
for (i = 0; i < to_copy; i++)
|
||||
dst[i] = ctx->prev[ctx->prev_ofs+i];
|
||||
memset(ctx->prev + ctx->prev_ofs, 0, to_copy);
|
||||
ctx->prev_len -= to_copy;
|
||||
ctx->prev_ofs += to_copy;
|
||||
dst += to_copy;
|
||||
bytes -= to_copy;
|
||||
}
|
||||
}
|
||||
|
||||
void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes)
|
||||
{
|
||||
block out;
|
||||
int i;
|
||||
|
||||
if (!bytes)
|
||||
if (bytes == 0)
|
||||
return;
|
||||
|
||||
for (;; bytes -= 64, dst += 1) {
|
||||
salsa_core(rounds, &out, st);
|
||||
st = &ctx->st;
|
||||
|
||||
if (ALIGNED64(dst)) {
|
||||
/* generate new 64-byte chunks directly into dst */
|
||||
for (; bytes >= 64; bytes -= 64, dst += 64) {
|
||||
/* generate new chunk and update state */
|
||||
salsa_core(ctx->nb_rounds, (block *) dst, st);
|
||||
st->d[8] += 1;
|
||||
if (st->d[8] == 0)
|
||||
st->d[9] += 1;
|
||||
}
|
||||
} else {
|
||||
/* generate new 64-byte chunks and copy them into dst */
|
||||
for (; bytes >= 64; bytes -= 64, dst += 64) {
|
||||
/* generate new chunk and update state */
|
||||
salsa_core(ctx->nb_rounds, &out, st);
|
||||
st->d[8] += 1;
|
||||
if (st->d[8] == 0)
|
||||
st->d[9] += 1;
|
||||
|
||||
for (i = 0; i < 64; ++i)
|
||||
dst[i] = out.b[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (bytes > 0) {
|
||||
/* generate new chunk and update state */
|
||||
salsa_core(ctx->nb_rounds, &out, st);
|
||||
st->d[8] += 1;
|
||||
if (st->d[8] == 0)
|
||||
st->d[9] += 1;
|
||||
|
||||
if (bytes <= 64) {
|
||||
for (i = 0; i < bytes; ++i)
|
||||
dst->b[i] = out.b[i];
|
||||
return;
|
||||
}
|
||||
#if USE_8BITS
|
||||
for (i = 0; i < 64; ++i)
|
||||
dst->b[i] = out.b[i];
|
||||
#else
|
||||
for (i = 0; i < 8; i++)
|
||||
dst->q[i] = out.q[i];
|
||||
#endif
|
||||
/* copy as much as needed */
|
||||
for (i = 0; i < bytes; i++)
|
||||
dst[i] = out.b[i];
|
||||
|
||||
/* copy the left over in the buffer */
|
||||
ctx->prev_len = 64 - bytes;
|
||||
ctx->prev_ofs = i;
|
||||
for (; i < 64; i++)
|
||||
ctx->prev[i] = out.b[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -38,11 +38,20 @@ typedef union {
|
||||
|
||||
typedef block cryptonite_salsa_state;
|
||||
|
||||
/* Streaming Salsa context: cipher state plus the bookkeeping needed to
 * serve requests that are not a multiple of 64 bytes. */
typedef struct {
|
||||
/* state block handed to salsa_core; words d[8]/d[9] are incremented after each generated block */
cryptonite_salsa_state st;
|
||||
/* keystream bytes of the last generated block that have not been consumed yet */
uint8_t prev[64];
|
||||
/* index inside prev of the first unconsumed byte */
uint8_t prev_ofs;
|
||||
/* number of unconsumed bytes remaining in prev */
uint8_t prev_len;
|
||||
/* round count passed to salsa_core */
uint8_t nb_rounds;
|
||||
} cryptonite_salsa_context;
|
||||
|
||||
/* for scrypt */
|
||||
void cryptonite_salsa_core_xor(int rounds, block *out, block *in);
|
||||
|
||||
void cryptonite_salsa_init(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
|
||||
void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes);
|
||||
void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes);
|
||||
void cryptonite_salsa_init_core(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
|
||||
void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv);
|
||||
void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *st, const uint8_t *src, uint32_t bytes);
|
||||
void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *st, uint32_t bytes);
|
||||
|
||||
#endif
|
||||
|
||||
@ -38,7 +38,9 @@ instance Arbitrary RandomVector where
|
||||
tests = testGroup "Salsa"
|
||||
[ testGroup "KAT" $
|
||||
map (\(i,f) -> testCase (show (i :: Int)) f) $ zip [1..] $ map (\(r, k,i,e) -> salsaRunSimple e r k i) vectors
|
||||
, testProperty "chunking" salsaChunks
|
||||
, testProperty "generate-combine" salsaGenerateCombine
|
||||
, testProperty "chunking-generate" salsaGenerateChunks
|
||||
, testProperty "chunking-combine" salsaCombineChunks
|
||||
]
|
||||
where
|
||||
salsaRunSimple expected rounds key nonce =
|
||||
@ -55,8 +57,20 @@ tests = testGroup "Salsa"
|
||||
in e : salsaLoop (current + B.length expectBs) salsaNext rs
|
||||
| otherwise = error "internal error in salsaLoop"
|
||||
|
||||
salsaChunks :: ChunkingLen -> RandomVector -> Bool
|
||||
salsaChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) =
|
||||
-- | 'Salsa.generate' must yield the same bytes as 'Salsa.combine' applied to
-- an all-zero input of the same length, chunk after chunk, from the same state.
salsaGenerateCombine :: ChunkingLen0_127 -> RandomVector -> Int0_2901 -> Bool
salsaGenerateCombine (ChunkingLen0_127 ckLen) (RandomVector (rounds, key, iv, _)) (Int0_2901 nbBytes) =
    go nbBytes ckLen (Salsa.initialize rounds key iv)
  where
    -- start the chunk-length list over once it is exhausted
    go remaining [] st = go remaining ckLen st
    go 0 _ _ = True
    go remaining (chunk:rest) st =
        let size             = min chunk remaining
            (generated, st') = Salsa.generate st size
            (combined, _)    = Salsa.combine st (B.replicate size 0)
         in generated == combined && go (remaining - size) rest st'
|
||||
|
||||
salsaGenerateChunks :: ChunkingLen -> RandomVector -> Bool
|
||||
salsaGenerateChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) =
|
||||
let initSalsa = Salsa.initialize rounds key iv
|
||||
nbBytes = 1048
|
||||
(expected,_) = Salsa.generate initSalsa nbBytes
|
||||
@ -69,3 +83,18 @@ tests = testGroup "Salsa"
|
||||
let len = min x n
|
||||
(c, next) = Salsa.generate salsa len
|
||||
in c : loop (n - len) xs next
|
||||
|
||||
-- | Combining a whole input in one call must equal the concatenation of
-- combining it piece by piece while threading the state through.
salsaCombineChunks :: ChunkingLen -> RandomVector -> ArbitraryBS0_2901 -> Bool
salsaCombineChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) (ArbitraryBS0_2901 wholebs) =
    fst (Salsa.combine initSalsa wholebs) `propertyEq` B.concat (go wholebs ckLen initSalsa)
  where
    initSalsa = Salsa.initialize rounds key iv

    -- start the chunk-length list over once it is exhausted
    go input [] st = go input ckLen st
    go input (chunk:rest) st
        | B.null input = []
        | otherwise    =
            let (piece, remainder) = B.splitAt (min chunk (B.length input)) input
                (xored, st')       = Salsa.combine st piece
             in xored : go remainder rest st'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user