use byte array in more places

This commit is contained in:
Vincent Hanquez 2015-04-24 17:22:13 +01:00
parent 6722a02a74
commit 9dd17fc0c4
8 changed files with 108 additions and 108 deletions

View File

@ -90,7 +90,7 @@ combine prev@(State nbRounds prevSt prevOut) src
-- we have enough byte in the previous buffer to complete the query -- we have enough byte in the previous buffer to complete the query
-- without having to generate any extra bytes -- without having to generate any extra bytes
let (b1,b2) = BS.splitAt outputLen prevOut let (b1,b2) = BS.splitAt outputLen prevOut
in (BS.pack $ BS.zipWith xor b1 src, State nbRounds prevSt b2) in (B.xor b1 src, State nbRounds prevSt b2)
| otherwise = unsafeDoIO $ do | otherwise = unsafeDoIO $ do
-- adjusted len is the number of bytes lefts to generate after -- adjusted len is the number of bytes lefts to generate after
-- copying from the previous buffer. -- copying from the previous buffer.
@ -106,14 +106,13 @@ combine prev@(State nbRounds prevSt prevOut) src
loopXor dstPtr srcPtr prevPtr prevBufLen loopXor dstPtr srcPtr prevPtr prevBufLen
-- then create a new mutable copy of state -- then create a new mutable copy of state
st <- B.copy prevSt (\_ -> return ()) B.copy prevSt $ \stPtr ->
withByteArray st $ \stPtr ->
ccryptonite_chacha_combine nbRounds ccryptonite_chacha_combine nbRounds
(dstPtr `plusPtr` prevBufLen) (dstPtr `plusPtr` prevBufLen)
(castPtr stPtr) (castPtr stPtr)
(srcPtr `plusPtr` prevBufLen) (srcPtr `plusPtr` prevBufLen)
(fromIntegral newBytesToGenerate) (fromIntegral newBytesToGenerate)
return st
-- return combined byte -- return combined byte
return ( BS.PS fptr 0 outputLen return ( BS.PS fptr 0 outputLen
, State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen)) , State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))

View File

@ -1,4 +1,3 @@
{-# LANGUAGE ForeignFunctionInterface #-}
-- | -- |
-- Module : Crypto.Cipher.RC4 -- Module : Crypto.Cipher.RC4
-- License : BSD-style -- License : BSD-style
@ -13,6 +12,8 @@
-- --
-- Reorganized and simplified to have an opaque context. -- Reorganized and simplified to have an opaque context.
-- --
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.Cipher.RC4 module Crypto.Cipher.RC4
( initialize ( initialize
, combine , combine
@ -20,19 +21,16 @@ module Crypto.Cipher.RC4
, State , State
) where ) where
import Data.Word import Data.Word
import Data.Byteable import Foreign.Ptr
import Data.SecureMem import Crypto.Internal.ByteArray (SecureBytes, ByteArray, ByteArrayAccess)
import Foreign.Ptr import qualified Crypto.Internal.ByteArray as B
import Foreign.ForeignPtr
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import Crypto.Internal.Compat import Crypto.Internal.Compat
-- | The encryption state for RC4 -- | The encryption state for RC4
newtype State = State SecureMem newtype State = State SecureBytes
deriving (ByteArrayAccess)
-- | C Call for initializing the encryptor -- | C Call for initializing the encryptor
foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_init" foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_init"
@ -52,29 +50,29 @@ foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_combine"
-- --
-- seed the context with an initial key. the key size need to be -- seed the context with an initial key. the key size need to be
-- adequate otherwise security takes a hit. -- adequate otherwise security takes a hit.
initialize :: Byteable key initialize :: ByteArrayAccess key
=> key -- ^ The key => key -- ^ The key
-> State -- ^ The RC4 context with the key mixed in -> State -- ^ The RC4 context with the key mixed in
initialize key = unsafeDoIO $ do initialize key = unsafeDoIO $ do
st <- createSecureMem 264 $ \stPtr -> st <- B.alloc 264 $ \stPtr ->
withBytePtr key $ \keyPtr -> c_rc4_init keyPtr (fromIntegral $ byteableLength key) (castPtr stPtr) B.withByteArray key $ \keyPtr -> c_rc4_init keyPtr (fromIntegral $ B.length key) (castPtr stPtr)
return $ State st return $ State st
-- | generate the next len bytes of the rc4 stream without combining -- | generate the next len bytes of the rc4 stream without combining
-- it to anything. -- it to anything.
generate :: State -> Int -> (State, ByteString) generate :: ByteArray ba => State -> Int -> (State, ba)
generate ctx len = combine ctx (B.replicate len 0) generate ctx len = combine ctx (B.zero len)
-- | RC4 xor combination of the rc4 stream with an input -- | RC4 xor combination of the rc4 stream with an input
combine :: State -- ^ rc4 context combine :: ByteArray ba
-> ByteString -- ^ input => State -- ^ rc4 context
-> (State, ByteString) -- ^ new rc4 context, and the output -> ba -- ^ input
combine (State prevSt) clearText = unsafeDoIO $ do -> (State, ba) -- ^ new rc4 context, and the output
outfptr <- B.mallocByteString len combine (State prevSt) clearText = unsafeDoIO $
st <- secureMemCopy prevSt B.allocRet len $ \outptr ->
withSecureMemPtr st $ \stPtr -> B.withByteArray clearText $ \clearPtr -> do
withForeignPtr outfptr $ \outptr -> st <- B.copy prevSt $ \stPtr ->
withBytePtr clearText $ \clearPtr -> c_rc4_combine (castPtr stPtr) clearPtr (fromIntegral len) outptr
c_rc4_combine (castPtr stPtr) clearPtr (fromIntegral len) outptr return $! State st
return $! (State st, B.PS outfptr 0 len) --return $! (State st, B.PS outfptr 0 len)
where len = B.length clearText where len = B.length clearText

View File

@ -20,7 +20,7 @@ module Crypto.Cipher.Types.Base
import Data.Word import Data.Word
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, SecureBytes, withByteArray) import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray)
import qualified Crypto.Internal.ByteArray as B import qualified Crypto.Internal.ByteArray as B
import Crypto.Error import Crypto.Error

View File

@ -36,7 +36,6 @@ module Crypto.Cipher.Types.Block
--, cfb8Decrypt --, cfb8Decrypt
) where ) where
import Data.Byteable
import Data.Word import Data.Word
import Crypto.Error import Crypto.Error
import Crypto.Cipher.Types.Base import Crypto.Cipher.Types.Base
@ -146,12 +145,12 @@ class BlockCipher cipher => BlockCipher128 cipher where
xtsDecrypt = xtsDecryptGeneric xtsDecrypt = xtsDecryptGeneric
-- | Create an IV for a specified block cipher -- | Create an IV for a specified block cipher
makeIV :: (Byteable b, BlockCipher c) => b -> Maybe (IV c) makeIV :: (ByteArrayAccess b, BlockCipher c) => b -> Maybe (IV c)
makeIV b = toIV undefined makeIV b = toIV undefined
where toIV :: BlockCipher c => c -> Maybe (IV c) where toIV :: BlockCipher c => c -> Maybe (IV c)
toIV cipher toIV cipher
| byteableLength b == sz = Just (IV $ toBytes b) | B.length b == sz = Just $ IV (B.convert b :: Bytes)
| otherwise = Nothing | otherwise = Nothing
where sz = blockSize cipher where sz = blockSize cipher
-- | Create an IV that is effectively representing the number 0 -- | Create an IV that is effectively representing the number 0

View File

@ -21,6 +21,8 @@ module Crypto.Data.AFIS
import Crypto.Hash import Crypto.Hash
import Crypto.Random.Types import Crypto.Random.Types
import Crypto.Internal.Memory (Bytes) import Crypto.Internal.Memory (Bytes)
import Crypto.Internal.Bytes (bufSet, bufCopy)
import Crypto.Internal.Compat
import Crypto.Internal.ByteArray (withByteArray) import Crypto.Internal.ByteArray (withByteArray)
import Control.Monad (forM_, foldM) import Control.Monad (forM_, foldM)
import Data.Byteable import Data.Byteable
@ -29,11 +31,10 @@ import Data.Word
import Data.Bits import Data.Bits
import Foreign.Storable import Foreign.Storable
import Foreign.Ptr import Foreign.Ptr
import Foreign.ForeignPtr (withForeignPtr, newForeignPtr_) import Foreign.ForeignPtr (newForeignPtr_)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B import qualified Data.ByteString.Internal as B
import System.IO.Unsafe (unsafePerformIO) import qualified Crypto.Internal.ByteArray as B
-- | Split data to diffused data, using a random generator and -- | Split data to diffused data, using a random generator and
-- an hash algorithm. -- an hash algorithm.
@ -62,15 +63,14 @@ split :: (HashAlgorithm a, DRG rng)
{-# NOINLINE split #-} {-# NOINLINE split #-}
split hashF rng expandTimes src split hashF rng expandTimes src
| expandTimes <= 1 = error "invalid expandTimes value" | expandTimes <= 1 = error "invalid expandTimes value"
| otherwise = unsafePerformIO $ do | otherwise = unsafeDoIO $ do
fptr <- B.mallocByteString diffusedLen (rng', bs) <- B.allocRet diffusedLen runOp
rng' <- withForeignPtr fptr runOp return (bs, rng')
return (B.fromForeignPtr fptr 0 diffusedLen, rng')
where diffusedLen = blockSize * expandTimes where diffusedLen = blockSize * expandTimes
blockSize = B.length src blockSize = B.length src
runOp dstPtr = do runOp dstPtr = do
let lastBlock = dstPtr `plusPtr` (blockSize * (expandTimes-1)) let lastBlock = dstPtr `plusPtr` (blockSize * (expandTimes-1))
_ <- B.memset lastBlock 0 (fromIntegral blockSize) bufSet lastBlock 0 blockSize
let randomBlockPtrs = map (plusPtr dstPtr . (*) blockSize) [0..(expandTimes-2)] let randomBlockPtrs = map (plusPtr dstPtr . (*) blockSize) [0..(expandTimes-2)]
rng' <- foldM fillRandomBlock rng randomBlockPtrs rng' <- foldM fillRandomBlock rng randomBlockPtrs
mapM_ (addRandomBlock lastBlock) randomBlockPtrs mapM_ (addRandomBlock lastBlock) randomBlockPtrs
@ -81,7 +81,7 @@ split hashF rng expandTimes src
diffuse hashF lastBlock blockSize diffuse hashF lastBlock blockSize
fillRandomBlock g blockPtr = do fillRandomBlock g blockPtr = do
let (rand :: Bytes, g') = randomBytesGenerate blockSize g let (rand :: Bytes, g') = randomBytesGenerate blockSize g
withByteArray rand $ \randPtr -> B.memcpy blockPtr randPtr (fromIntegral blockSize) withByteArray rand $ \randPtr -> bufCopy blockPtr randPtr (fromIntegral blockSize)
return g' return g'
-- | Merge previously diffused data back to the original data. -- | Merge previously diffused data back to the original data.
@ -94,9 +94,9 @@ merge :: HashAlgorithm a
merge hashF expandTimes bs merge hashF expandTimes bs
| r /= 0 = error "diffused data not a multiple of expandTimes" | r /= 0 = error "diffused data not a multiple of expandTimes"
| originalSize <= 0 = error "diffused data null" | originalSize <= 0 = error "diffused data null"
| otherwise = unsafePerformIO $ B.create originalSize $ \dstPtr -> | otherwise = B.allocAndFreeze originalSize $ \dstPtr ->
withBytePtr bs $ \srcPtr -> do B.withByteArray bs $ \srcPtr -> do
_ <- B.memset dstPtr 0 (fromIntegral originalSize) bufSet dstPtr 0 originalSize
forM_ [0..(expandTimes-2)] $ \i -> do forM_ [0..(expandTimes-2)] $ \i -> do
xorMem (srcPtr `plusPtr` (i * originalSize)) dstPtr originalSize xorMem (srcPtr `plusPtr` (i * originalSize)) dstPtr originalSize
diffuse hashF dstPtr originalSize diffuse hashF dstPtr originalSize
@ -126,10 +126,10 @@ diffuse :: HashAlgorithm a
diffuse hashF src sz = loop src 0 diffuse hashF src sz = loop src 0
where (full,pad) = sz `quotRem` digestSize where (full,pad) = sz `quotRem` digestSize
loop s i | i < full = do h <- hashBlock i `fmap` byteStringOfPtr s digestSize loop s i | i < full = do h <- hashBlock i `fmap` byteStringOfPtr s digestSize
withBytePtr h $ \hPtr -> B.memcpy s hPtr (fromIntegral digestSize) B.withByteArray h $ \hPtr -> bufCopy s hPtr digestSize
loop (s `plusPtr` digestSize) (i+1) loop (s `plusPtr` digestSize) (i+1)
| pad /= 0 = do h <- hashBlock i `fmap` byteStringOfPtr s pad | pad /= 0 = do h <- hashBlock i `fmap` byteStringOfPtr s pad
withBytePtr h $ \hPtr -> B.memcpy s hPtr (fromIntegral pad) B.withByteArray h $ \hPtr -> bufCopy s hPtr pad
return () return ()
| otherwise = return () | otherwise = return ()
@ -139,12 +139,12 @@ diffuse hashF src sz = loop src 0
byteStringOfPtr ptr digestSz = newForeignPtr_ ptr >>= \fptr -> return $ B.fromForeignPtr fptr 0 digestSz byteStringOfPtr ptr digestSz = newForeignPtr_ ptr >>= \fptr -> return $ B.fromForeignPtr fptr 0 digestSz
hashBlock n b = hashBlock n b =
toBytes $ hashF $ B.unsafeCreate (B.length b+4) $ \ptr -> do toBytes $ hashF $ B.allocAndFreeze (B.length b+4) $ \ptr -> do
poke ptr (f8 (n `shiftR` 24)) poke ptr (f8 (n `shiftR` 24))
poke (ptr `plusPtr` 1) (f8 (n `shiftR` 16)) poke (ptr `plusPtr` 1) (f8 (n `shiftR` 16))
poke (ptr `plusPtr` 2) (f8 (n `shiftR` 8)) poke (ptr `plusPtr` 2) (f8 (n `shiftR` 8))
poke (ptr `plusPtr` 3) (f8 n) poke (ptr `plusPtr` 3) (f8 n)
--putWord32BE (fromIntegral n) >> putBytes src) --putWord32BE (fromIntegral n) >> putBytes src)
withBytePtr b $ \srcPtr -> B.memcpy (ptr `plusPtr` 4) srcPtr (fromIntegral $ B.length b) withByteArray b $ \srcPtr -> bufCopy (ptr `plusPtr` 4) srcPtr (B.length b)
where f8 :: Int -> Word8 where f8 :: Int -> Word8
f8 = fromIntegral f8 = fromIntegral

View File

@ -18,12 +18,14 @@ module Crypto.Internal.ByteArray
-- * Inhabitants -- * Inhabitants
, Bytes , Bytes
, SecureBytes , SecureBytes
, MemView(..)
-- * methods -- * methods
, alloc , alloc
, allocAndFreeze , allocAndFreeze
, empty , empty
, zero , zero
, copy , copy
, take
, convert , convert
, copyRet , copyRet
, copyAndFreeze , copyAndFreeze
@ -57,6 +59,12 @@ import qualified Data.ByteString.Internal as B
import Prelude (flip, return, div, (-), ($), (==), (/=), (<=), (>=), Int, Bool(..), IO, otherwise, sum, map, fmap, snd, (.), min) import Prelude (flip, return, div, (-), ($), (==), (/=), (<=), (>=), Int, Bool(..), IO, otherwise, sum, map, fmap, snd, (.), min)
data MemView = MemView !(Ptr Word8) !Int
instance ByteArrayAccess MemView where
length (MemView _ l) = l
withByteArray (MemView p _) f = f (castPtr p)
class ByteArrayAccess ba where class ByteArrayAccess ba where
length :: ba -> Int length :: ba -> Int
withByteArray :: ba -> (Ptr p -> IO a) -> IO a withByteArray :: ba -> (Ptr p -> IO a) -> IO a
@ -126,6 +134,13 @@ split n bs
return (b1, b2) return (b1, b2)
where len = length bs where len = length bs
take :: ByteArray bs => Int -> bs -> bs
take n bs =
allocAndFreeze m $ \d -> withByteArray bs $ \s -> bufCopy d s m
where
m = min len n
len = length bs
concat :: ByteArray bs => [bs] -> bs concat :: ByteArray bs => [bs] -> bs
concat [] = empty concat [] = empty
concat allBs = allocAndFreeze total (loop allBs) concat allBs = allocAndFreeze total (loop allBs)

View File

@ -1,4 +1,3 @@
{-# LANGUAGE ForeignFunctionInterface #-}
-- | -- |
-- Module : Crypto.MAC.Poly1305 -- Module : Crypto.MAC.Poly1305
@ -9,6 +8,8 @@
-- --
-- Poly1305 implementation -- Poly1305 implementation
-- --
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.MAC.Poly1305 module Crypto.MAC.Poly1305
( Ctx ( Ctx
, Auth(..) , Auth(..)
@ -22,27 +23,22 @@ module Crypto.MAC.Poly1305
, auth , auth
) where ) where
import Control.Monad (void) import Foreign.Ptr
import Foreign.Ptr import Foreign.C.Types
import Foreign.C.Types import Data.Word
import qualified Data.ByteString as B import Crypto.Internal.ByteArray (ByteArrayAccess, SecureBytes, Bytes)
import qualified Data.ByteString.Internal as B import qualified Crypto.Internal.ByteArray as B
import Data.ByteString (ByteString)
import Data.Word
import Data.Byteable
import System.IO.Unsafe
import Data.SecureMem
-- | Poly1305 Context -- | Poly1305 Context
newtype Ctx = Ctx SecureMem newtype Ctx = Ctx SecureBytes
deriving (ByteArrayAccess)
-- | Poly1305 Auth -- | Poly1305 Auth
newtype Auth = Auth ByteString newtype Auth = Auth Bytes
deriving (ByteArrayAccess)
instance Eq Auth where instance Eq Auth where
(Auth a1) == (Auth a2) = constEqBytes a1 a2 (Auth a1) == (Auth a2) = B.constEq a1 a2
instance Byteable Auth where
toBytes (Auth b) = b
foreign import ccall unsafe "cryptonite_poly1305.h cryptonite_poly1305_init" foreign import ccall unsafe "cryptonite_poly1305.h cryptonite_poly1305_init"
c_poly1305_init :: Ptr Ctx -> Ptr Word8 -> IO () c_poly1305_init :: Ptr Ctx -> Ptr Word8 -> IO ()
@ -54,55 +50,50 @@ foreign import ccall unsafe "cryptonite_poly1305.h cryptonite_poly1305_finalize"
c_poly1305_finalize :: Ptr Word8 -> Ptr Ctx -> IO () c_poly1305_finalize :: Ptr Word8 -> Ptr Ctx -> IO ()
-- | initialize a Poly1305 context -- | initialize a Poly1305 context
initialize :: Byteable key initialize :: ByteArrayAccess key
=> key => key
-> Ctx -> Ctx
initialize key initialize key
| byteableLength key /= 32 = error "Poly1305: key length expected 32 bytes" | B.length key /= 32 = error "Poly1305: key length expected 32 bytes"
| otherwise = Ctx $ unsafePerformIO $ do | otherwise = Ctx $ B.allocAndFreeze 84 $ \ctxPtr ->
withBytePtr key $ \keyPtr -> B.withByteArray key $ \keyPtr ->
createSecureMem 84 $ \ctxPtr -> c_poly1305_init (castPtr ctxPtr) keyPtr
c_poly1305_init (castPtr ctxPtr) keyPtr
{-# NOINLINE initialize #-} {-# NOINLINE initialize #-}
-- | update a context with a bytestring -- | update a context with a bytestring
update :: Ctx -> ByteString -> Ctx update :: ByteArrayAccess ba => Ctx -> ba -> Ctx
update (Ctx prevCtx) d = unsafePerformIO $ do update (Ctx prevCtx) d = Ctx $ B.copyAndFreeze prevCtx $ \ctxPtr ->
ctx <- secureMemCopy prevCtx B.withByteArray d $ \dataPtr ->
withSecureMemPtr ctx $ \ctxPtr -> c_poly1305_update (castPtr ctxPtr) dataPtr (fromIntegral $ B.length d)
withBytePtr d $ \dataPtr ->
c_poly1305_update (castPtr ctxPtr) dataPtr (fromIntegral $ B.length d)
return $ Ctx ctx
{-# NOINLINE update #-} {-# NOINLINE update #-}
-- | updates a context with multiples bytestring -- | updates a context with multiples bytestring
updates :: Ctx -> [ByteString] -> Ctx updates :: ByteArrayAccess ba => Ctx -> [ba] -> Ctx
updates (Ctx prevCtx) d = unsafePerformIO $ do updates (Ctx prevCtx) d = Ctx $ B.copyAndFreeze prevCtx (loop d)
ctx <- secureMemCopy prevCtx
withSecureMemPtr ctx (loop d . castPtr)
return $ Ctx ctx
where loop [] _ = return () where loop [] _ = return ()
loop (x:xs) ctxPtr = do loop (x:xs) ctxPtr = do
withBytePtr x $ \dataPtr -> c_poly1305_update ctxPtr dataPtr (fromIntegral $ B.length x) B.withByteArray x $ \dataPtr -> c_poly1305_update ctxPtr dataPtr (fromIntegral $ B.length x)
loop xs ctxPtr loop xs ctxPtr
{-# NOINLINE updates #-} {-# NOINLINE updates #-}
-- | finalize the context into a digest bytestring -- | finalize the context into a digest bytestring
finalize :: Ctx -> Auth finalize :: Ctx -> Auth
finalize (Ctx prevCtx) = Auth $ B.unsafeCreate 16 $ \dst -> do finalize (Ctx prevCtx) = Auth $ B.allocAndFreeze 16 $ \dst -> do
ctx <- secureMemCopy prevCtx _ <- B.copy prevCtx (\ctxPtr -> c_poly1305_finalize dst (castPtr ctxPtr)) :: IO SecureBytes
withSecureMemPtr ctx $ \ctxPtr -> c_poly1305_finalize dst (castPtr ctxPtr) return ()
{-# NOINLINE finalize #-} {-# NOINLINE finalize #-}
-- | One-pass authorization creation -- | One-pass authorization creation
auth :: Byteable key => key -> ByteString -> Auth auth :: (ByteArrayAccess key, ByteArrayAccess ba) => key -> ba -> Auth
auth key d auth key d
| byteableLength key /= 32 = error "Poly1305: key length expected 32 bytes" | B.length key /= 32 = error "Poly1305: key length expected 32 bytes"
| otherwise = Auth $ B.unsafeCreate 16 $ \dst -> do | otherwise = Auth $ B.allocAndFreeze 16 $ \dst -> do
-- initialize the context _ <- B.alloc 84 (onCtx dst) :: IO SecureBytes
void $ createSecureMem 84 $ \ctxPtr -> withBytePtr key $ \keyPtr -> do return ()
c_poly1305_init (castPtr ctxPtr) keyPtr where
withBytePtr d $ \dataPtr -> onCtx dst ctxPtr =
c_poly1305_update (castPtr ctxPtr) dataPtr (fromIntegral $ B.length d) B.withByteArray key $ \keyPtr -> do
-- finalize c_poly1305_init (castPtr ctxPtr) keyPtr
c_poly1305_finalize dst (castPtr ctxPtr) B.withByteArray d $ \dataPtr ->
c_poly1305_update (castPtr ctxPtr) dataPtr (fromIntegral $ B.length d)
c_poly1305_finalize dst (castPtr ctxPtr)

View File

@ -14,9 +14,8 @@ module Crypto.Random.EntropyPool
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Crypto.Random.Entropy.Unsafe import Crypto.Random.Entropy.Unsafe
import Crypto.Internal.ByteArray (ByteArray) import Crypto.Internal.ByteArray (ByteArray, SecureBytes)
import qualified Crypto.Internal.ByteArray as B import qualified Crypto.Internal.ByteArray as B
import Data.SecureMem
import Data.Word (Word8) import Data.Word (Word8)
import Data.Maybe (catMaybes) import Data.Maybe (catMaybes)
import Foreign.Marshal.Utils (copyBytes) import Foreign.Marshal.Utils (copyBytes)
@ -24,7 +23,7 @@ import Foreign.Ptr (plusPtr, Ptr)
-- | Pool of Entropy. contains a self mutating pool of entropy, -- | Pool of Entropy. contains a self mutating pool of entropy,
-- that is always guarantee to contains data. -- that is always guarantee to contains data.
data EntropyPool = EntropyPool [EntropyBackend] (MVar Int) SecureMem data EntropyPool = EntropyPool [EntropyBackend] (MVar Int) SecureBytes
-- size of entropy pool by default -- size of entropy pool by default
defaultPoolSize :: Int defaultPoolSize :: Int
@ -35,9 +34,8 @@ defaultPoolSize = 4096
-- While you can create as many entropy pool as you want, the pool can be shared between multiples RNGs. -- While you can create as many entropy pool as you want, the pool can be shared between multiples RNGs.
createEntropyPoolWith :: Int -> [EntropyBackend] -> IO EntropyPool createEntropyPoolWith :: Int -> [EntropyBackend] -> IO EntropyPool
createEntropyPoolWith poolSize backends = do createEntropyPoolWith poolSize backends = do
sm <- allocateSecureMem poolSize
m <- newMVar 0 m <- newMVar 0
withSecureMemPtr sm $ replenish poolSize backends sm <- B.alloc poolSize (replenish poolSize backends)
return $ EntropyPool backends m sm return $ EntropyPool backends m sm
-- | Create a new entropy pool with a default size. -- | Create a new entropy pool with a default size.
@ -51,10 +49,10 @@ createEntropyPool = do
-- | Put a chunk of the entropy pool into a buffer -- | Put a chunk of the entropy pool into a buffer
getEntropyPtr :: EntropyPool -> Int -> Ptr Word8 -> IO () getEntropyPtr :: EntropyPool -> Int -> Ptr Word8 -> IO ()
getEntropyPtr (EntropyPool backends posM sm) n outPtr = getEntropyPtr (EntropyPool backends posM sm) n outPtr =
withSecureMemPtr sm $ \entropyPoolPtr -> B.withByteArray sm $ \entropyPoolPtr ->
modifyMVar_ posM $ \pos -> modifyMVar_ posM $ \pos ->
copyLoop outPtr entropyPoolPtr pos n copyLoop outPtr entropyPoolPtr pos n
where poolSize = secureMemGetSize sm where poolSize = B.length sm
copyLoop d s pos left copyLoop d s pos left
| left == 0 = return pos | left == 0 = return pos
| otherwise = do | otherwise = do