cryptonite/Data/ByteArray/Methods.hs
2015-05-14 10:04:16 +01:00

233 lines
6.9 KiB
Haskell

-- |
-- Module : Data.ByteArray.Methods
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : stable
-- Portability : Good
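--
-- Generic operations over any 'ByteArray' / 'ByteArrayAccess' instance:
-- allocation, packing and unpacking, slicing, concatenation, comparison,
-- XOR, hex encoding and conversion between byte array types.
--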
{-# LANGUAGE BangPatterns #-}
module Data.ByteArray.Methods
( alloc
, allocAndFreeze
, create
, unsafeCreate
, pack
, unpack
, empty
, replicate
, zero
, copy
, take
, convert
, convertHex
, copyRet
, copyAndFreeze
, split
, xor
, index
, eq
, constEq
, append
, concat
, toW64BE
, toW64LE
, mapAsWord64
, mapAsWord128
) where
import Data.Memory.Internal.Compat
import Data.Memory.Internal.Imports hiding (empty)
import Data.ByteArray.Types
import Data.Memory.Endian
import Data.Memory.PtrMethods
import Data.Memory.ExtendedWords
import Data.Memory.Encoding.Base16
import Foreign.Storable
import Foreign.Ptr
import Prelude hiding (length, take, concat, replicate)
import qualified Prelude
-- | Allocate a new byte array of @n@ bytes and run the initialiser @f@ on a
-- pointer to its memory. The value returned by 'allocRet' alongside the
-- array is discarded; only the array is returned.
alloc :: ByteArray ba => Int -> (Ptr p -> IO ()) -> IO ba
alloc n f = allocRet n f >>= return . snd
-- | Allocate @n@ bytes and initialise them with @f@; exactly the same
-- operation as 'alloc'.
create :: ByteArray ba => Int -> (Ptr p -> IO ()) -> IO ba
create = alloc
-- | Pure counterpart of 'alloc': allocate @size@ bytes, run @initialise@ on
-- them and hand back the finished array outside of 'IO' via 'unsafeDoIO'.
allocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a
allocAndFreeze size initialise = unsafeDoIO $ alloc size initialise
{-# NOINLINE allocAndFreeze #-}
-- | Allocate @size@ bytes, fill them with @initialise@ and return the
-- resulting array as a pure value (runs the 'IO' action via 'unsafeDoIO').
-- Behaves the same as 'allocAndFreeze'.
unsafeCreate :: ByteArray a => Int -> (Ptr p -> IO ()) -> a
unsafeCreate size initialise = unsafeDoIO $ alloc size initialise
{-# NOINLINE unsafeCreate #-}
-- | A byte array of length zero.
empty :: ByteArray a => a
empty = unsafeCreate 0 (const (return ()))
-- | Pack a list of bytes into a new byte array. The list element at
-- position @i@ is stored at byte offset @i@.
pack :: ByteArray a => [Word8] -> a
pack bytes = unsafeCreate (Prelude.length bytes) $ \dst ->
    mapM_ (\(ofs, w) -> pokeByteOff dst ofs w) (zip [0 ..] bytes)
-- | Unpack a byte array into a list of its bytes, in order.
--
-- The list is produced lazily. Each byte is read with its own
-- 'withByteArray' call, and is forced as soon as its cons cell is
-- demanded.
unpack :: ByteArrayAccess a => a -> [Word8]
unpack bs = foldr cons [] [0 .. len - 1]
  where
    !len = length bs

    cons :: Int -> [Word8] -> [Word8]
    cons i rest = let !w = byteAt i in w : rest

    byteAt :: Int -> Word8
    byteAt i = unsafeDoIO $ withByteArray bs $ \p -> peekByteOff p i
-- | Byte-wise XOR of two byte arrays.
--
-- The result is as long as the shorter input. Bytes of the longer input
-- beyond that length are ignored.
xor :: (ByteArrayAccess a, ByteArrayAccess b, ByteArray c) => a -> b -> c
xor a b = unsafeCreate n xorInto
  where
    n = min (length a) (length b)
    xorInto pc =
        withByteArray a $ \pa ->
        withByteArray b $ \pb ->
            memXor pc pa pb n
-- | Return the byte stored at offset @i@.
--
-- No bounds check is performed. Callers must ensure
-- @0 <= i < length b@, or the read goes outside the array.
index :: ByteArrayAccess a => a -> Int -> Word8
index b i = unsafeDoIO $ withByteArray b $ \p -> peekByteOff p i
-- | Split a byte array at offset @n@.
--
-- The first component holds the first @n@ bytes and the second holds the
-- rest. A non-positive @n@ gives @(empty, bs)@; an @n@ at or past the end
-- gives @(bs, empty)@. Otherwise both halves are freshly allocated copies.
split :: ByteArray bs => Int -> bs -> (bs, bs)
split n bs
    | n <= 0    = (empty, bs)
    | n >= len  = (bs, empty)
    | otherwise = unsafeDoIO $ withByteArray bs $ \src -> do
        front <- alloc n $ \dst -> memCopy dst src n
        back  <- alloc rest $ \dst -> memCopy dst (src `plusPtr` n) rest
        return (front, back)
  where
    len  = length bs
    rest = len - n
-- | Return a copy of the first @n@ bytes of a byte array.
--
-- If @n@ exceeds the length of @bs@, the whole array is copied. If @n@ is
-- zero or negative, the result is 'empty'. Without that guard a negative
-- @n@ would produce a negative allocation size.
take :: ByteArray bs => Int -> bs -> bs
take n bs
    | n <= 0    = empty
    | otherwise = unsafeCreate m $ \d -> withByteArray bs $ \s -> memCopy d s m
  where
    m = min (length bs) n
-- | Concatenate a list of byte arrays into one newly allocated array.
--
-- An empty list yields 'empty'. Otherwise a single buffer of the summed
-- length is allocated, and each chunk is copied into it right after the
-- previous one.
concat :: ByteArray bs => [bs] -> bs
concat [] = empty
concat chunks = unsafeCreate total $ \start ->
    foldr step (\_ -> return ()) chunks start
  where
    total = sum (map length chunks)
    -- copy one chunk at dst, then continue with the next chunk right after it
    step chunk next dst = do
        let sz = length chunk
        withByteArray chunk $ \src -> memCopy dst src sz
        next (dst `plusPtr` sz)
-- | Join two byte arrays: the bytes of @lhs@ followed by those of @rhs@.
append :: ByteArray bs => bs -> bs -> bs
append lhs rhs = concat [lhs, rhs]
-- | Allocate a new byte array containing a copy of @bs@. Then run @f@ on
-- the new array's memory and return the (possibly modified) copy.
copy :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> IO bs2
copy bs f = alloc n $ \dst -> do
    withByteArray bs $ \src -> memCopy dst src n
    f (castPtr dst)
  where
    n = length bs
-- | Like 'copy', but the result of @f@ is returned alongside the new array.
copyRet :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO a) -> IO (a, bs2)
copyRet bs f =
    allocRet n $ \dst ->
        withByteArray bs (\src -> memCopy dst src n) >> f (castPtr dst)
  where
    n = length bs
-- | Pure variant of 'copy'. It duplicates @bs@, lets @f@ modify the
-- duplicate in place, and returns the duplicate via 'unsafeCreate'.
copyAndFreeze :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> bs2
copyAndFreeze bs f = unsafeCreate n fill
  where
    n = length bs
    fill dst = do
        withByteArray bs $ \src -> memCopy dst src n
        f (castPtr dst)
-- | Create a byte array of @n@ bytes, each equal to @b@.
--
-- A zero or negative @n@ yields 'empty'. Previously only @0@ was handled,
-- so a negative size reached 'unsafeCreate' / 'memSet'.
replicate :: ByteArray ba => Int -> Word8 -> ba
replicate n b
    | n <= 0    = empty
    | otherwise = unsafeCreate n $ \ptr -> memSet ptr b n
{-# NOINLINE replicate #-}
-- | Create a byte array of @n@ zero bytes.
--
-- A zero or negative @n@ yields 'empty'. Previously only @0@ was handled,
-- so a negative size reached 'unsafeCreate' / 'memSet'.
zero :: ByteArray ba => Int -> ba
zero n
    | n <= 0    = empty
    | otherwise = unsafeCreate n $ \ptr -> memSet ptr 0 n
{-# NOINLINE zero #-}
-- | Byte-wise equality between two byte-array-like values, which may be of
-- different types.
--
-- Values of different lengths are unequal, and no bytes are read in that
-- case. Otherwise the contents are compared with 'memEqual'.
eq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool
eq b1 b2
    | n1 == n2  = unsafeDoIO $
        withByteArray b1 $ \p1 ->
        withByteArray b2 $ \p2 ->
            memEqual p1 p2 n1
    | otherwise = False
  where
    n1 = length b1
    n2 = length b2
-- | Constant-time equality test for two 'ByteArrayAccess' values.
--
-- If the values differ in length, the function returns 'False' at once
-- without comparing any bytes.
--
-- Unlike 'eq' (or '=='), when the lengths match every byte is examined
-- (via 'memConstEqual') before a result is produced. This holds even if the
-- outcome is already known early in the comparison.
constEq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool
constEq b1 b2
    | n1 == n2  = unsafeDoIO $
        withByteArray b1 $ \p1 ->
        withByteArray b2 $ \p2 ->
            memConstEqual p1 p2 n1
    | otherwise = False
  where
    n1 = length b1
    n2 = length b2
-- | Read the 8 bytes at byte offset @ofs@ as a 'Word64' tagged 'BE'. Apply
-- 'fromBE' to obtain a native value. No bounds check is performed.
toW64BE :: ByteArrayAccess bs => bs -> Int -> BE Word64
toW64BE bs ofs = unsafeDoIO (withByteArray bs (peek . (`plusPtr` ofs)))
-- | Read the 8 bytes at byte offset @ofs@ as a 'Word64' tagged 'LE'. No
-- bounds check is performed.
toW64LE :: ByteArrayAccess bs => bs -> Int -> LE Word64
toW64LE bs ofs = unsafeDoIO (withByteArray bs (peek . (`plusPtr` ofs)))
-- | Map @f@ over @bs@, viewed as a sequence of 16-byte blocks.
--
-- Each complete block is read as two big-endian 'Word64's forming a
-- 'Word128', then transformed with @f@. The result is written back
-- big-endian at the same offset of a new array as long as @bs@.
--
-- When @length bs@ is not a multiple of 16, the trailing partial block is
-- copied to the result unchanged. It is not left uninitialised in the
-- freshly allocated output.
mapAsWord128 :: ByteArray bs => (Word128 -> Word128) -> bs -> bs
mapAsWord128 f bs =
    unsafeCreate len $ \dst ->
    withByteArray bs $ \src -> do
        loop blocks dst src
        copyTail (dst `plusPtr` covered) (src `plusPtr` covered)
  where
    len     = length bs
    blocks  = len `div` 16
    covered = blocks * 16   -- bytes handled by whole blocks
    rest    = len - covered -- bytes of the incomplete final block

    -- bytes past the last whole block are copied through verbatim
    copyTail :: Ptr Word8 -> Ptr Word8 -> IO ()
    copyTail d s
        | rest > 0  = memCopy d s rest
        | otherwise = return ()

    loop :: Int -> Ptr (BE Word64) -> Ptr (BE Word64) -> IO ()
    loop 0 _ _ = return ()
    loop i d s = do
        w1 <- peek s
        w2 <- peek (s `plusPtr` 8)
        let (Word128 r1 r2) = f (Word128 (fromBE w1) (fromBE w2))
        poke d (toBE r1)
        poke (d `plusPtr` 8) (toBE r2)
        loop (i - 1) (d `plusPtr` 16) (s `plusPtr` 16)
-- | Map @f@ over @bs@, viewed as a sequence of big-endian 'Word64's.
--
-- Each complete 8-byte word is read big-endian and transformed with @f@.
-- The result is written back big-endian at the same offset of a new array
-- as long as @bs@.
--
-- When @length bs@ is not a multiple of 8, the trailing partial word is
-- copied to the result unchanged. It is not left uninitialised in the
-- freshly allocated output.
mapAsWord64 :: ByteArray bs => (Word64 -> Word64) -> bs -> bs
mapAsWord64 f bs =
    unsafeCreate len $ \dst ->
    withByteArray bs $ \src -> do
        loop words64 dst src
        copyTail (dst `plusPtr` covered) (src `plusPtr` covered)
  where
    len     = length bs
    words64 = len `div` 8
    covered = words64 * 8   -- bytes handled by whole words
    rest    = len - covered -- bytes of the incomplete final word

    -- bytes past the last whole word are copied through verbatim
    copyTail :: Ptr Word8 -> Ptr Word8 -> IO ()
    copyTail d s
        | rest > 0  = memCopy d s rest
        | otherwise = return ()

    loop :: Int -> Ptr (BE Word64) -> Ptr (BE Word64) -> IO ()
    loop 0 _ _ = return ()
    loop i d s = do
        w <- peek s
        let r = f (fromBE w)
        poke d (toBE r)
        loop (i - 1) (d `plusPtr` 8) (s `plusPtr` 8)
-- | Copy any byte-array-like value into a newly allocated 'ByteArray' of a
-- possibly different type, with identical contents.
convert :: (ByteArrayAccess bin, ByteArray bout) => bin -> bout
convert bs = copyAndFreeze bs (\_ -> return ())
-- | Hexadecimal-encode a byte array with 'toHexadecimal'.
--
-- The output is exactly twice as long as the input: two output bytes per
-- input byte.
convertHex :: (ByteArrayAccess bin, ByteArray bout) => bin -> bout
convertHex b = unsafeCreate (2 * n) $ \dst ->
    withByteArray b $ \src -> toHexadecimal dst src n
  where
    n = length b