diff --git a/Crypto/KDF/PBKDF2.hs b/Crypto/KDF/PBKDF2.hs index abeab04..6e9f09a 100644 --- a/Crypto/KDF/PBKDF2.hs +++ b/Crypto/KDF/PBKDF2.hs @@ -19,16 +19,15 @@ import Data.Word import Data.Bits import Data.ByteString (ByteString) import qualified Data.ByteString as B -import qualified Data.ByteString.Internal as B (unsafeCreate, memset) -import Foreign.Storable -import Foreign.Ptr (Ptr, plusPtr) -import Control.Applicative -import Control.Monad (forM_, void) +import Foreign.Marshal.Alloc +import Foreign.Ptr (plusPtr) import Crypto.Hash (HashAlgorithm) import qualified Crypto.MAC.HMAC as HMAC -import qualified Crypto.Internal.ByteArray as B (convert, withByteArray) +import Crypto.Internal.ByteArray (ByteArray) +import qualified Crypto.Internal.ByteArray as B (allocAndFreeze, convert, withByteArray) +import Crypto.Internal.Bytes -- | The PRF used for PBKDF2 type PRF = B.ByteString -- ^ the password parameters @@ -52,36 +51,44 @@ data Parameters = Parameters } -- | generate the pbkdf2 key derivation function from the output -generate :: PRF -> Parameters -> B.ByteString +generate :: ByteArray ba => PRF -> Parameters -> ba generate prf params = - B.take (outputLength params) $ B.concat $ map f [1..l] + B.allocAndFreeze (outputLength params) $ \p -> do + bufSet p 0 (outputLength params) + loop 1 (outputLength params) p where !runPRF = prf (password params) !hLen = B.length $ runPRF B.empty - + + -- run the following f function on each complete chunk. + -- when having an incomplete chunk, we call partial. + -- partial need to be the last call. + -- -- f(pass,salt,c,i) = U1 xor U2 xor .. xor Uc -- U1 = PRF(pass,salt || BE32(i)) -- Uc = PRF(pass,Uc-1) - f iterNb = B.unsafeCreate hLen $ \dst -> do + loop iterNb len p + | len == 0 = return () + | len < hLen = partial iterNb len p + | otherwise = do + let applyMany 0 _ = return () + applyMany i uprev = do + let uData = runPRF uprev + B.withByteArray uData $ \u -> bufXor p p u hLen + applyMany (i-1) uData + applyMany (iterCounts params) (salt params `B.append` toBS iterNb) + loop (iterNb+1) (len - hLen) (p `plusPtr` hLen) + + partial iterNb len p = allocaBytesAligned hLen 8 $ \tmp -> do let applyMany 0 _ = return () - applyMany i uprev = - let u = runPRF uprev - in bsXor dst u >> applyMany (i-1) u - void $ B.memset dst 0 (fromIntegral hLen) + applyMany i uprev = do + let uData = runPRF uprev + B.withByteArray uData $ \u -> bufXor tmp tmp u hLen + applyMany (i-1) uData + bufSet tmp 0 hLen applyMany (iterCounts params) (salt params `B.append` toBS iterNb) + bufCopy p tmp len - -- a mutable version of xor, that allow to not reallocate - -- the accumulate buffer. - bsXor :: Ptr Word8 -> ByteString -> IO () - bsXor d sBs = B.withByteArray sBs $ \s -> - forM_ [0..hLen-1] $ \i -> do - v <- xor <$> peek (s `plusPtr` i) <*> peek (d `plusPtr` i) - poke (d `plusPtr` i) (v :: Word8) - - -- count the number of blocks necessary - l = let (q,rema) = (outputLength params) `divMod` hLen - in fromIntegral (q + if rema > 0 then 1 else 0) - -- big endian encoding of Word32 toBS :: Word32 -> ByteString toBS w = B.pack [a,b,c,d] @@ -89,3 +96,4 @@ generate prf params = b = fromIntegral ((w `shiftR` 16) .&. 0xff) c = fromIntegral ((w `shiftR` 8) .&. 0xff) d = fromIntegral (w .&. 0xff) +{-# NOINLINE generate #-} diff --git a/Crypto/KDF/Scrypt.hs b/Crypto/KDF/Scrypt.hs index 21d430d..557172b 100644 --- a/Crypto/KDF/Scrypt.hs +++ b/Crypto/KDF/Scrypt.hs @@ -48,11 +48,10 @@ generate params | popCount (n params) /= 1 = error "Scrypt: invalid parameters: n not a power of 2" | otherwise = unsafeDoIO $ do - let b = PBKDF2.generate prf (PBKDF2.Parameters (password params) (salt params) 1 intLen) - newSalt <- B.alloc intLen $ \bPtr -> + let b = PBKDF2.generate prf (PBKDF2.Parameters (password params) (salt params) 1 intLen) :: B.Bytes + newSalt <- B.copy b $ \bPtr -> allocaBytesAligned (128*(fromIntegral $ n params)*(r params)) 8 $ \v -> allocaBytesAligned (256*r params) 8 $ \xy -> do - B.withByteArray b $ \bOrig -> bufCopy bPtr bOrig intLen forM_ [0..(p params-1)] $ \i -> ccryptonite_scrypt_smix (bPtr `plusPtr` (i * 128 * (r params))) (fromIntegral $ r params) (n params) v xy diff --git a/tests/KAT_PBKDF2.hs b/tests/KAT_PBKDF2.hs index 978f9e2..6f2cd79 100644 --- a/tests/KAT_PBKDF2.hs +++ b/tests/KAT_PBKDF2.hs @@ -32,6 +32,7 @@ vectors_hmac_sha1 = ) ] +vectors_hmac_sha256 :: [ (VectParams, ByteString) ] vectors_hmac_sha256 = [ ( ("password", "salt", 2, 32) , "\xae\x4d\x0c\x95\xaf\x6b\x46\xd3\x2d\x0a\xdf\xf9\x28\xf0\x6d\xd0\x2a\x30\x3f\x8e\xf3\xc2\x51\xdf\xd6\xe2\xd8\x5a\x95\x47\x4c\x43"