[PBKDF2] make the code more friendly to a future mutable PRF.

Prevent doing B.take . B.concat by allocating only once the output buffer
This commit is contained in:
Vincent Hanquez 2015-05-03 08:17:03 +01:00
parent c1ed30b20e
commit c23ddb2eaa
3 changed files with 37 additions and 29 deletions

View File

@ -19,16 +19,15 @@ import Data.Word
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B (unsafeCreate, memset)
import Foreign.Storable
import Foreign.Ptr (Ptr, plusPtr)
import Control.Applicative
import Control.Monad (forM_, void)
import Foreign.Marshal.Alloc
import Foreign.Ptr (plusPtr)
import Crypto.Hash (HashAlgorithm)
import qualified Crypto.MAC.HMAC as HMAC
import qualified Crypto.Internal.ByteArray as B (convert, withByteArray)
import Crypto.Internal.ByteArray (ByteArray)
import qualified Crypto.Internal.ByteArray as B (allocAndFreeze, convert, withByteArray)
import Crypto.Internal.Bytes
-- | The PRF used for PBKDF2
type PRF = B.ByteString -- ^ the password parameters
@ -52,36 +51,44 @@ data Parameters = Parameters
}
-- | generate the pbkdf2 key derivation function from the output
generate :: PRF -> Parameters -> B.ByteString
generate :: ByteArray ba => PRF -> Parameters -> ba
generate prf params =
B.take (outputLength params) $ B.concat $ map f [1..l]
B.allocAndFreeze (outputLength params) $ \p -> do
bufSet p 0 (outputLength params)
loop 1 (outputLength params) p
where
!runPRF = prf (password params)
!hLen = B.length $ runPRF B.empty
-- run the following f function on each complete chunk.
-- when having an incomplete chunk, we call partial.
-- partial need to be the last call.
--
-- f(pass,salt,c,i) = U1 xor U2 xor .. xor Uc
-- U1 = PRF(pass,salt || BE32(i))
-- Uc = PRF(pass,Uc-1)
f iterNb = B.unsafeCreate hLen $ \dst -> do
loop iterNb len p
| len == 0 = return ()
| len < hLen = partial iterNb len p
| otherwise = do
let applyMany 0 _ = return ()
applyMany i uprev = do
let uData = runPRF uprev
B.withByteArray uData $ \u -> bufXor p p u hLen
applyMany (i-1) uData
applyMany (iterCounts params) (salt params `B.append` toBS iterNb)
loop (iterNb+1) (len - hLen) (p `plusPtr` hLen)
partial iterNb len p = allocaBytesAligned hLen 8 $ \tmp -> do
let applyMany 0 _ = return ()
applyMany i uprev =
let u = runPRF uprev
in bsXor dst u >> applyMany (i-1) u
void $ B.memset dst 0 (fromIntegral hLen)
applyMany i uprev = do
let uData = runPRF uprev
B.withByteArray uData $ \u -> bufXor tmp tmp u hLen
applyMany (i-1) uData
bufSet tmp 0 hLen
applyMany (iterCounts params) (salt params `B.append` toBS iterNb)
bufCopy p tmp len
-- a mutable version of xor, that allow to not reallocate
-- the accumulate buffer.
bsXor :: Ptr Word8 -> ByteString -> IO ()
bsXor d sBs = B.withByteArray sBs $ \s ->
forM_ [0..hLen-1] $ \i -> do
v <- xor <$> peek (s `plusPtr` i) <*> peek (d `plusPtr` i)
poke (d `plusPtr` i) (v :: Word8)
-- count the number of blocks necessary
l = let (q,rema) = (outputLength params) `divMod` hLen
in fromIntegral (q + if rema > 0 then 1 else 0)
-- big endian encoding of Word32
toBS :: Word32 -> ByteString
toBS w = B.pack [a,b,c,d]
@ -89,3 +96,4 @@ generate prf params =
b = fromIntegral ((w `shiftR` 16) .&. 0xff)
c = fromIntegral ((w `shiftR` 8) .&. 0xff)
d = fromIntegral (w .&. 0xff)
{-# NOINLINE generate #-}

View File

@ -48,11 +48,10 @@ generate params
| popCount (n params) /= 1 =
error "Scrypt: invalid parameters: n not a power of 2"
| otherwise = unsafeDoIO $ do
let b = PBKDF2.generate prf (PBKDF2.Parameters (password params) (salt params) 1 intLen)
newSalt <- B.alloc intLen $ \bPtr ->
let b = PBKDF2.generate prf (PBKDF2.Parameters (password params) (salt params) 1 intLen) :: B.Bytes
newSalt <- B.copy b $ \bPtr ->
allocaBytesAligned (128*(fromIntegral $ n params)*(r params)) 8 $ \v ->
allocaBytesAligned (256*r params) 8 $ \xy -> do
B.withByteArray b $ \bOrig -> bufCopy bPtr bOrig intLen
forM_ [0..(p params-1)] $ \i ->
ccryptonite_scrypt_smix (bPtr `plusPtr` (i * 128 * (r params)))
(fromIntegral $ r params) (n params) v xy

View File

@ -32,6 +32,7 @@ vectors_hmac_sha1 =
)
]
vectors_hmac_sha256 :: [ (VectParams, ByteString) ]
vectors_hmac_sha256 =
[ ( ("password", "salt", 2, 32)
, "\xae\x4d\x0c\x95\xaf\x6b\x46\xd3\x2d\x0a\xdf\xf9\x28\xf0\x6d\xd0\x2a\x30\x3f\x8e\xf3\xc2\x51\xdf\xd6\xe2\xd8\x5a\x95\x47\x4c\x43"