diff --git a/Crypto/Internal/ByteArray.hs b/Crypto/Internal/ByteArray.hs index d7b1fbc..612833c 100644 --- a/Crypto/Internal/ByteArray.hs +++ b/Crypto/Internal/ByteArray.hs @@ -13,8 +13,11 @@ module Crypto.Internal.ByteArray ( ByteArray(..) , ByteArrayAccess(..) + , byteArrayAlloc , byteArrayAllocAndFreeze , empty + , byteArrayCopy + , byteArrayCopyRet , byteArrayCopyAndFreeze , byteArraySplit , byteArrayXor @@ -50,32 +53,35 @@ class ByteArrayAccess ba where withByteArray :: ba -> (Ptr p -> IO a) -> IO a class ByteArrayAccess ba => ByteArray ba where - byteArrayAlloc :: Int -> (Ptr p -> IO ()) -> IO ba + byteArrayAllocRet :: Int -> (Ptr p -> IO a) -> IO (a, ba) + +byteArrayAlloc :: ByteArray ba => Int -> (Ptr p -> IO ()) -> IO ba +byteArrayAlloc n f = snd `fmap` byteArrayAllocRet n f instance ByteArrayAccess Bytes where byteArrayLength = bytesLength withByteArray = withBytes instance ByteArray Bytes where - byteArrayAlloc = bytesAlloc + byteArrayAllocRet = bytesAllocRet instance ByteArrayAccess ByteString where byteArrayLength = B.length withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off) where (fptr, off, _) = B.toForeignPtr b instance ByteArray ByteString where - byteArrayAlloc sz f = do + byteArrayAllocRet sz f = do fptr <- B.mallocByteString sz - withForeignPtr fptr (f . castPtr) - return $! B.PS fptr 0 sz + r <- withForeignPtr fptr (f . castPtr) + return (r, B.PS fptr 0 sz) instance ByteArrayAccess SecureMem where byteArrayLength = secureMemGetSize withByteArray b f = withSecureMemPtr b (f . castPtr) instance ByteArray SecureMem where - byteArrayAlloc sz f = do + byteArrayAllocRet sz f = do out <- allocateSecureMem sz - withSecureMemPtr out (f . castPtr) - return out + r <- withSecureMemPtr out (f . castPtr) + return (r, out) byteArrayAllocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a byteArrayAllocAndFreeze sz f = unsafeDoIO (byteArrayAlloc sz f) @@ -123,6 +129,18 @@ byteArrayConcat allBs = byteArrayAllocAndFreeze total (loop allBs) withByteArray b $ \p -> bufCopy dst p sz loop bs (dst `plusPtr` sz) +byteArrayCopy :: (ByteArray bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> IO bs2 +byteArrayCopy bs f = + byteArrayAlloc (byteArrayLength bs) $ \d -> do + withByteArray bs $ \s -> bufCopy d s (byteArrayLength bs) + f (castPtr d) + +byteArrayCopyRet :: (ByteArray bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO a) -> IO (a, bs2) +byteArrayCopyRet bs f = + byteArrayAllocRet (byteArrayLength bs) $ \d -> do + withByteArray bs $ \s -> bufCopy d s (byteArrayLength bs) + f (castPtr d) + byteArrayCopyAndFreeze :: (ByteArray bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> bs2 byteArrayCopyAndFreeze bs f = byteArrayAllocAndFreeze (byteArrayLength bs) $ \d -> do diff --git a/Crypto/Internal/Memory.hs b/Crypto/Internal/Memory.hs index f603f94..d72928d 100644 --- a/Crypto/Internal/Memory.hs +++ b/Crypto/Internal/Memory.hs @@ -16,6 +16,7 @@ module Crypto.Internal.Memory , bytesTemporary , bytesCopyTemporary , bytesAlloc + , bytesAllocRet , bytesLength , withBytes , SecureBytes @@ -75,6 +76,12 @@ bytesAlloc sz f = do withPtr ba f return ba +bytesAllocRet :: Int -> (Ptr p -> IO a) -> IO (a, Bytes) +bytesAllocRet sz f = do + ba <- newBytes sz + r <- withPtr ba f + return (r, ba) + bytesLength :: Bytes -> Int bytesLength = sizeofBytes