diff --git a/Crypto/Cipher/Salsa.hs b/Crypto/Cipher/Salsa.hs index 3252977..9aa2926 100644 --- a/Crypto/Cipher/Salsa.hs +++ b/Crypto/Cipher/Salsa.hs @@ -13,30 +13,15 @@ module Crypto.Cipher.Salsa , State ) where -import Data.ByteString (ByteString) -import Data.Memory.PtrMethods (memXor) import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes) import qualified Crypto.Internal.ByteArray as B -import qualified Data.ByteString.Internal as BS -import qualified Data.ByteString as BS -import Crypto.Internal.Compat -import Crypto.Internal.Imports -import Foreign.Ptr -import Foreign.ForeignPtr -import Foreign.C.Types +import Crypto.Internal.Compat +import Crypto.Internal.Imports +import Foreign.Ptr +import Foreign.C.Types -- | Salsa context -data State = State Int -- number of rounds - ScrubbedBytes -- Salsa's state - Int -- offset of data in previously generated chunk - ByteString -- previous generated chunk - -round64 :: Int -> (Bool, Int) -round64 len - | len == 0 = (True, 0) - | m == 0 = (True, len) - | otherwise = (False, len + (64 - m)) - where m = len `mod` 64 +newtype State = State ScrubbedBytes -- | Initialize a new Salsa context with the number of rounds, -- the key and the nonce associated. 
@@ -50,11 +35,13 @@ initialize nbRounds key nonce | not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits" | not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20" | otherwise = unsafeDoIO $ do - stPtr <- B.alloc 64 $ \stPtr -> + -- 136 rather than 132: cryptonite_salsa_init memsets sizeof(cryptonite_salsa_context), + -- i.e. 131 bytes of fields padded to 136 where block is 8-byte aligned (assumed from its 64-bit q[] view; confirm). + stPtr <- B.alloc 136 $ \stPtr -> B.withByteArray nonce $ \noncePtr -> B.withByteArray key $ \keyPtr -> - ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr - return $ State nbRounds stPtr 0 B.empty + ccryptonite_salsa_init stPtr (fromIntegral nbRounds) kLen keyPtr nonceLen noncePtr + return $ State stPtr where kLen = B.length key nonceLen = B.length nonce @@ -64,59 +51,33 @@ combine :: ByteArray ba => State -- ^ the current Salsa state -> ba -- ^ the source to xor with the generator -> (ba, State) -combine prev@(State nbRounds prevSt prevOffset prevOut) src - | outputLen == 0 = (B.empty, prev) - | outputLen <= prevBufLen = unsafeDoIO $ do - -- we have enough byte in the previous buffer to complete the query - -- without having to generate any extra bytes - output <- B.copy src $ \dst -> - B.withByteArray prevOut $ \prevPtr -> - memXor dst dst (prevPtr `plusPtr` prevOffset) outputLen - return (output, State nbRounds prevSt (prevOffset + outputLen) prevOut) - | otherwise = unsafeDoIO $ do - -- adjusted len is the number of bytes lefts to generate after - -- copying from the previous buffer. 
- let adjustedLen = outputLen - prevBufLen - (roundedAlready, newBytesToGenerate) = round64 adjustedLen - nextBufLen = newBytesToGenerate - adjustedLen - - fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen) - newSt <- withForeignPtr fptr $ \dstPtr -> - B.withByteArray src $ \srcPtr -> do - -- copy the previous buffer by xor if any - B.withByteArray prevOut $ \prevPtr -> - memXor dstPtr srcPtr (prevPtr `plusPtr` prevOffset) prevBufLen - - -- then create a new mutable copy of state - B.copy prevSt $ \stPtr -> - ccryptonite_salsa_combine nbRounds - (dstPtr `plusPtr` prevBufLen) - (castPtr stPtr) - (srcPtr `plusPtr` prevBufLen) - (fromIntegral adjustedLen) - -- return combined byte - return ( B.convert (BS.PS fptr 0 outputLen) - , State nbRounds newSt 0 (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen)) - where - outputLen = B.length src - prevBufLen = B.length prevOut - prevOffset +combine prevSt@(State prevStMem) src + | B.null src = (B.empty, prevSt) + | otherwise = unsafeDoIO $ do + (out, st) <- B.copyRet prevStMem $ \ctx -> + B.alloc (B.length src) $ \dstPtr -> + B.withByteArray src $ \srcPtr -> do + ccryptonite_salsa_combine dstPtr ctx srcPtr (fromIntegral $ B.length src) + return (out, State st) -- | Generate a number of bytes from the Salsa output directly --- --- TODO: use salsa_generate directly instead of using combine xor'ing with 0. 
generate :: ByteArray ba => State -- ^ the current Salsa state -> Int -- ^ the length of data to generate -> (ba, State) -generate st len = combine st (B.zero len) +generate prevSt@(State prevStMem) len + | len <= 0 = (B.empty, prevSt) + | otherwise = unsafeDoIO $ do + (out, st) <- B.copyRet prevStMem $ \ctx -> + B.alloc len $ \dstPtr -> + ccryptonite_salsa_generate dstPtr ctx (fromIntegral len) + return (out, State st) foreign import ccall "cryptonite_salsa_init" - ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO () + ccryptonite_salsa_init :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO () foreign import ccall "cryptonite_salsa_combine" - ccryptonite_salsa_combine :: Int -> Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO () + ccryptonite_salsa_combine :: Ptr Word8 -> Ptr State -> Ptr Word8 -> CUInt -> IO () -{- foreign import ccall "cryptonite_salsa_generate" - ccryptonite_salsa_generate :: Int -> Ptr Word8 -> Ptr State -> CUInt -> IO () --} + ccryptonite_salsa_generate :: Ptr Word8 -> Ptr State -> CUInt -> IO () diff --git a/cbits/cryptonite_salsa.c b/cbits/cryptonite_salsa.c index a82e792..0bd9660 100644 --- a/cbits/cryptonite_salsa.c +++ b/cbits/cryptonite_salsa.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014 Vincent Hanquez + * Copyright (c) 2014-2015 Vincent Hanquez * * All rights reserved. 
* @@ -29,11 +29,10 @@ */ #include +#include +#include #include "cryptonite_salsa.h" #include "cryptonite_bitfn.h" -#include - -#define USE_8BITS 0 static const uint8_t sigma[16] = "expand 32-byte k"; static const uint8_t tau[16] = "expand 16-byte k"; @@ -44,6 +43,9 @@ static const uint8_t tau[16] = "expand 16-byte k"; d ^= rol32(c+b, 13); \ a ^= rol32(d+c, 18); +#define ALIGNED64(PTR) \ + (((uintptr_t)(const void *)(PTR)) % 8 == 0) + #define SALSA_CORE_LOOP \ for (i = rounds; i > 0; i -= 2) { \ QR (x0,x4,x8,x12); \ @@ -117,9 +119,9 @@ void cryptonite_salsa_core_xor(int rounds, block *out, block *in) } /* only 2 valid values for keylen are 256 (32) and 128 (16) */ -void cryptonite_salsa_init(cryptonite_salsa_state *st, - uint32_t keylen, const uint8_t *key, - uint32_t ivlen, const uint8_t *iv) +void cryptonite_salsa_init_core(cryptonite_salsa_state *st, + uint32_t keylen, const uint8_t *key, + uint32_t ivlen, const uint8_t *iv) { const uint8_t *constants = (keylen == 32) ? sigma : tau; int i; @@ -157,67 +159,139 @@ void cryptonite_salsa_init(cryptonite_salsa_state *st, } } -void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes) +void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds, + uint32_t keylen, const uint8_t *key, + uint32_t ivlen, const uint8_t *iv) { + memset(ctx, 0, sizeof(*ctx)); + ctx->nb_rounds = nb_rounds; + cryptonite_salsa_init_core(&ctx->st, keylen, key, ivlen, iv); +} + +void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *ctx, const uint8_t *src, uint32_t bytes) +{ + block out; + cryptonite_salsa_state *st; + int i; + + if (!bytes) + return; + + /* xor the previous buffer first (if any) */ + if (ctx->prev_len > 0) { + int to_copy = (ctx->prev_len < bytes) ? 
ctx->prev_len : bytes; + for (i = 0; i < to_copy; i++) + dst[i] = src[i] ^ ctx->prev[ctx->prev_ofs+i]; + memset(ctx->prev + ctx->prev_ofs, 0, to_copy); + ctx->prev_len -= to_copy; + ctx->prev_ofs += to_copy; + src += to_copy; + dst += to_copy; + bytes -= to_copy; + } + + if (bytes == 0) + return; + + st = &ctx->st; + + /* xor new 64-bytes chunks and store the left over if any */ + for (; bytes >= 64; bytes -= 64, src += 64, dst += 64) { + /* generate new chunk and update state */ + salsa_core(ctx->nb_rounds, &out, st); + st->d[8] += 1; + if (st->d[8] == 0) + st->d[9] += 1; + + for (i = 0; i < 64; ++i) + dst[i] = src[i] ^ out.b[i]; + } + + if (bytes > 0) { + /* generate new chunk and update state */ + salsa_core(ctx->nb_rounds, &out, st); + st->d[8] += 1; + if (st->d[8] == 0) + st->d[9] += 1; + + /* xor as much as needed */ + for (i = 0; i < bytes; i++) + dst[i] = src[i] ^ out.b[i]; + + /* copy the left over in the buffer */ + ctx->prev_len = 64 - bytes; + ctx->prev_ofs = i; + for (; i < 64; i++) { + ctx->prev[i] = out.b[i]; + } + } +} + +void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *ctx, uint32_t bytes) +{ + cryptonite_salsa_state *st; block out; int i; if (!bytes) return; - for (;; bytes -= 64, src += 1, dst += 1) { - salsa_core(rounds, &out, st); - - st->d[8] += 1; - if (st->d[8] == 0) - st->d[9] += 1; - - if (bytes <= 64) { - for (i = 0; i < bytes; i++) - dst->b[i] = src->b[i] ^ out.b[i]; - for (; i < 64; i++) - dst->b[i] = out.b[i]; - return; - } - -#if USE_8BITS - for (i = 0; i < 64; ++i) - dst->b[i] = src->b[i] ^ out.b[i]; -#else - /* fast copy using 64 bits */ - for (i = 0; i < 8; i++) - dst->q[i] = src->q[i] ^ out.q[i]; -#endif + /* xor the previous buffer first (if any) */ + if (ctx->prev_len > 0) { + int to_copy = (ctx->prev_len < bytes) ? 
ctx->prev_len : bytes; + for (i = 0; i < to_copy; i++) + dst[i] = ctx->prev[ctx->prev_ofs+i]; + memset(ctx->prev + ctx->prev_ofs, 0, to_copy); + ctx->prev_len -= to_copy; + ctx->prev_ofs += to_copy; + dst += to_copy; + bytes -= to_copy; } -} -void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes) -{ - block out; - int i; - - if (!bytes) + if (bytes == 0) return; - for (;; bytes -= 64, dst += 1) { - salsa_core(rounds, &out, st); + st = &ctx->st; + if (ALIGNED64(dst)) { + /* xor new 64-bytes chunks and store the left over if any */ + for (; bytes >= 64; bytes -= 64, dst += 64) { + /* generate new chunk and update state */ + salsa_core(ctx->nb_rounds, (block *) dst, st); + st->d[8] += 1; + if (st->d[8] == 0) + st->d[9] += 1; + } + } else { + /* xor new 64-bytes chunks and store the left over if any */ + for (; bytes >= 64; bytes -= 64, dst += 64) { + /* generate new chunk and update state */ + salsa_core(ctx->nb_rounds, &out, st); + st->d[8] += 1; + if (st->d[8] == 0) + st->d[9] += 1; + + for (i = 0; i < 64; ++i) + dst[i] = out.b[i]; + } + } + + if (bytes > 0) { + /* generate new chunk and update state */ + salsa_core(ctx->nb_rounds, &out, st); st->d[8] += 1; if (st->d[8] == 0) st->d[9] += 1; - if (bytes <= 64) { - for (i = 0; i < bytes; ++i) - dst->b[i] = out.b[i]; - return; - } -#if USE_8BITS - for (i = 0; i < 64; ++i) - dst->b[i] = out.b[i]; -#else - for (i = 0; i < 8; i++) - dst->q[i] = out.q[i]; -#endif + /* xor as much as needed */ + for (i = 0; i < bytes; i++) + dst[i] = out.b[i]; + + /* copy the left over in the buffer */ + ctx->prev_len = 64 - bytes; + ctx->prev_ofs = i; + for (; i < 64; i++) + ctx->prev[i] = out.b[i]; } } diff --git a/cbits/cryptonite_salsa.h b/cbits/cryptonite_salsa.h index 981ac98..33e9cda 100644 --- a/cbits/cryptonite_salsa.h +++ b/cbits/cryptonite_salsa.h @@ -38,11 +38,20 @@ typedef union { typedef block cryptonite_salsa_state; +typedef struct { + cryptonite_salsa_state st; + uint8_t 
prev[64]; + uint8_t prev_ofs; + uint8_t prev_len; + uint8_t nb_rounds; +} cryptonite_salsa_context; + /* for scrypt */ void cryptonite_salsa_core_xor(int rounds, block *out, block *in); -void cryptonite_salsa_init(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv); -void cryptonite_salsa_combine(uint32_t rounds, block *dst, cryptonite_salsa_state *st, const block *src, uint32_t bytes); -void cryptonite_salsa_generate(uint32_t rounds, block *dst, cryptonite_salsa_state *st, uint32_t bytes); +void cryptonite_salsa_init_core(cryptonite_salsa_state *st, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv); +void cryptonite_salsa_init(cryptonite_salsa_context *ctx, uint8_t nb_rounds, uint32_t keylen, const uint8_t *key, uint32_t ivlen, const uint8_t *iv); +void cryptonite_salsa_combine(uint8_t *dst, cryptonite_salsa_context *st, const uint8_t *src, uint32_t bytes); +void cryptonite_salsa_generate(uint8_t *dst, cryptonite_salsa_context *st, uint32_t bytes); #endif diff --git a/tests/Salsa.hs b/tests/Salsa.hs index b111bd2..1ecda19 100644 --- a/tests/Salsa.hs +++ b/tests/Salsa.hs @@ -38,7 +38,9 @@ instance Arbitrary RandomVector where tests = testGroup "Salsa" [ testGroup "KAT" $ map (\(i,f) -> testCase (show (i :: Int)) f) $ zip [1..] 
$ map (\(r, k,i,e) -> salsaRunSimple e r k i) vectors - , testProperty "chunking" salsaChunks + , testProperty "generate-combine" salsaGenerateCombine + , testProperty "chunking-generate" salsaGenerateChunks + , testProperty "chunking-combine" salsaCombineChunks ] where salsaRunSimple expected rounds key nonce = @@ -55,8 +57,20 @@ tests = testGroup "Salsa" in e : salsaLoop (current + B.length expectBs) salsaNext rs | otherwise = error "internal error in salsaLoop" - salsaChunks :: ChunkingLen -> RandomVector -> Bool - salsaChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) = + salsaGenerateCombine :: ChunkingLen0_127 -> RandomVector -> Int0_2901 -> Bool + salsaGenerateCombine (ChunkingLen0_127 ckLen) (RandomVector (rounds, key, iv, _)) (Int0_2901 nbBytes) = + let initSalsa = Salsa.initialize rounds key iv + in loop nbBytes ckLen initSalsa + where loop n [] salsa = loop n ckLen salsa + loop 0 _ _ = True + loop n (x:xs) salsa = + let len = min x n + (c1, next) = Salsa.generate salsa len + (c2, _) = Salsa.combine salsa (B.replicate len 0) + in if c1 == c2 then loop (n - len) xs next else False + + salsaGenerateChunks :: ChunkingLen -> RandomVector -> Bool + salsaGenerateChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) = let initSalsa = Salsa.initialize rounds key iv nbBytes = 1048 (expected,_) = Salsa.generate initSalsa nbBytes @@ -69,3 +83,18 @@ tests = testGroup "Salsa" let len = min x n (c, next) = Salsa.generate salsa len in c : loop (n - len) xs next + + salsaCombineChunks :: ChunkingLen -> RandomVector -> ArbitraryBS0_2901 -> Bool + salsaCombineChunks (ChunkingLen ckLen) (RandomVector (rounds, key, iv, _)) (ArbitraryBS0_2901 wholebs) = + let initSalsa = Salsa.initialize rounds key iv + (expected,_) = Salsa.combine initSalsa wholebs + chunks = loop wholebs ckLen initSalsa + in expected `propertyEq` B.concat chunks + + where loop bs [] salsa = loop bs ckLen salsa + loop bs (x:xs) salsa + | B.null bs = [] + | otherwise = + let (bs1, bs2) = 
B.splitAt (min x (B.length bs)) bs + (c, next) = Salsa.combine salsa bs1 + in c : loop bs2 xs next