diff --git a/Data/Memory/ByteArray.hs b/Data/Memory/ByteArray.hs new file mode 100644 index 0000000..19ae285 --- /dev/null +++ b/Data/Memory/ByteArray.hs @@ -0,0 +1,22 @@ +-- | +-- Module : Data.Memory.ByteArray +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +-- Simple and efficient byte array types +-- +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE NoImplicitPrelude #-} +module Data.Memory.ByteArray + ( module X + ) where + +import Data.Memory.ByteArray.Types as X +import Data.Memory.ByteArray.Methods as X +import Data.Memory.ByteArray.ScrubbedBytes as X (ScrubbedBytes) +import Data.Memory.ByteArray.Bytes as X (Bytes) +import Data.Memory.ByteArray.MemView as X (MemView(..)) diff --git a/Data/Memory/ByteArray/Bytes.hs b/Data/Memory/ByteArray/Bytes.hs new file mode 100644 index 0000000..eaa5a35 --- /dev/null +++ b/Data/Memory/ByteArray/Bytes.hs @@ -0,0 +1,123 @@ +-- | +-- Module : Data.Memory.ByteArray.Bytes +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +-- Simple and efficient byte array types +-- +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +module Data.Memory.ByteArray.Bytes + ( Bytes + ) where + +import GHC.Types +import GHC.Prim +import GHC.Ptr +import Data.Memory.Internal.CompatPrim +import Data.Memory.Internal.Compat (unsafeDoIO) +import Data.Memory.ByteArray.Types +import Data.Memory.Encoding.Base16 (showHexadecimal) + +data Bytes = Bytes (MutableByteArray# RealWorld) + +instance Show Bytes where + show = bytesShowHex +instance Eq Bytes where + (==) = bytesEq + +instance ByteArrayAccess Bytes where + length = bytesLength + withByteArray = withBytes +instance ByteArray Bytes where + allocRet = bytesAllocRet + +------------------------------------------------------------------------ +newBytes :: Int -> IO Bytes +newBytes (I# sz) = IO $ \s -> + case newAlignedPinnedByteArray# sz 8# s of + (# s', mbarr #) -> (# s', Bytes mbarr #) + +touchBytes :: Bytes -> IO () +touchBytes (Bytes mba) = IO $ \s -> case touch# mba s of s' -> (# s', () #) + +sizeofBytes :: Bytes -> Int +sizeofBytes (Bytes mba) = I# (sizeofMutableByteArray# mba) + +withPtr :: Bytes -> (Ptr p -> IO a) -> IO a +withPtr b@(Bytes mba) f = do + a <- f (Ptr (byteArrayContents# (unsafeCoerce# mba))) + touchBytes b + return a +------------------------------------------------------------------------ + +{- +bytesCopyAndModify :: Bytes -> (Ptr a -> IO ()) -> IO Bytes +bytesCopyAndModify src f = do + dst <- newBytes sz + withPtr dst $ \d -> do + withPtr src $ \s -> copyBytes (castPtr d) s sz + f d + return dst + where sz = sizeofBytes src + +bytesTemporary :: Int -> (Ptr p -> IO a) -> IO a +bytesTemporary sz f = newBytes sz >>= \ba -> withPtr ba f + +bytesCopyTemporary :: Bytes -> (Ptr p -> IO a) -> IO a +bytesCopyTemporary src f = do + dst <- newBytes (sizeofBytes src) + withPtr dst $ \d -> do + withPtr src $ \s -> copyBytes (castPtr d) s (sizeofBytes src) + f d +bytesAlloc :: Int -> (Ptr p -> IO ()) -> IO Bytes +bytesAlloc sz f = do + ba <- newBytes sz + withPtr ba f + return ba +-} + +bytesAllocRet :: Int -> (Ptr p -> IO a) -> IO (a, Bytes) +bytesAllocRet sz f = do + ba <- newBytes sz + r <- withPtr ba f + return (r, ba) + +bytesLength :: Bytes -> Int +bytesLength = sizeofBytes + +withBytes :: Bytes -> (Ptr p -> IO a) -> IO a +withBytes = withPtr + +bytesEq :: Bytes -> Bytes -> Bool +bytesEq b1@(Bytes m1) b2@(Bytes m2) + | l1 /= l2 = False + | otherwise = unsafeDoIO $ IO $ \s -> loop 0# s + where + !l1@(I# len) = bytesLength b1 + !l2 = bytesLength b2 + + loop i s + | booleanPrim (i ==# len) = (# s, True #) + | otherwise = + case readWord8Array# m1 i s of + (# s', e1 #) -> case readWord8Array# m2 i s' of + (# s'', e2 #) -> + if booleanPrim (eqWord# e1 e2) + then loop (i +# 1#) s'' + else (# s', False #) + +{- +bytesIndex :: Bytes -> Int -> Word8 +bytesIndex (Bytes m) (I# i) = unsafeDoIO $ IO $ \s -> + case readWord8Array# m i s of + (# s', e #) -> (# s', W8# e #) +{-# NOINLINE bytesIndex #-} +-} + +bytesShowHex :: Bytes -> String +bytesShowHex b = showHexadecimal (withPtr b) (bytesLength b) +{-# NOINLINE bytesShowHex #-} diff --git a/Data/Memory/ByteArray/MemView.hs b/Data/Memory/ByteArray/MemView.hs new file mode 100644 index 0000000..5663d60 --- /dev/null +++ b/Data/Memory/ByteArray/MemView.hs @@ -0,0 +1,21 @@ +-- | +-- Module : Data.Memory.ByteArray.MemView +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +module Data.Memory.ByteArray.MemView + ( MemView(..) + ) where + +import Foreign.Ptr +import Data.Memory.ByteArray.Types +import Data.Memory.Internal.Imports + +data MemView = MemView !(Ptr Word8) !Int + +instance ByteArrayAccess MemView where + length (MemView _ l) = l + withByteArray (MemView p _) f = f (castPtr p) + diff --git a/Data/Memory/ByteArray/Methods.hs b/Data/Memory/ByteArray/Methods.hs new file mode 100644 index 0000000..9c5ea3e --- /dev/null +++ b/Data/Memory/ByteArray/Methods.hs @@ -0,0 +1,191 @@ +-- | +-- Module : Data.Memory.ByteArray.Methods +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +{-# LANGUAGE BangPatterns #-} +module Data.Memory.ByteArray.Methods + ( alloc + , allocAndFreeze + , empty + , zero + , copy + , take + , convert + , convertHex + , copyRet + , copyAndFreeze + , split + , xor + , eq + , index + , constEq + , concat + , toW64BE + , toW64LE + , mapAsWord64 + , mapAsWord128 + ) where + +import Data.Memory.Internal.Compat +import Data.Memory.Internal.Imports hiding (empty) +import Data.Memory.ByteArray.Types +import Data.Memory.Endian +import Data.Memory.PtrMethods +import Data.Memory.ExtendedWords +import Data.Memory.Encoding.Base16 +import Foreign.Storable +import Foreign.Ptr + +import Prelude hiding (length, take, concat) + +alloc :: ByteArray ba => Int -> (Ptr p -> IO ()) -> IO ba +alloc n f = snd `fmap` allocRet n f + +allocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a +allocAndFreeze sz f = unsafeDoIO (alloc sz f) + +empty :: ByteArray a => a +empty = unsafeDoIO (alloc 0 $ \_ -> return ()) + +-- | Create a xor of bytes between a and b. +-- +-- the returns byte array is the size of the smallest input. +xor :: (ByteArrayAccess a, ByteArrayAccess b, ByteArray c) => a -> b -> c +xor a b = + allocAndFreeze n $ \pc -> + withByteArray a $ \pa -> + withByteArray b $ \pb -> + memXor pc pa pb n + where + n = min la lb + la = length a + lb = length b + +index :: ByteArrayAccess a => a -> Int -> Word8 +index b i = unsafeDoIO $ withByteArray b $ \p -> peek (p `plusPtr` i) + +split :: ByteArray bs => Int -> bs -> (bs, bs) +split n bs + | n <= 0 = (empty, bs) + | n >= len = (bs, empty) + | otherwise = unsafeDoIO $ do + withByteArray bs $ \p -> do + b1 <- alloc n $ \r -> memCopy r p n + b2 <- alloc (len - n) $ \r -> memCopy r (p `plusPtr` n) (len - n) + return (b1, b2) + where len = length bs + +take :: ByteArray bs => Int -> bs -> bs +take n bs = + allocAndFreeze m $ \d -> withByteArray bs $ \s -> memCopy d s m + where + m = min len n + len = length bs + +concat :: ByteArray bs => [bs] -> bs +concat [] = empty +concat allBs = allocAndFreeze total (loop allBs) + where + total = sum $ map length allBs + + loop [] _ = return () + loop (b:bs) dst = do + let sz = length b + withByteArray b $ \p -> memCopy dst p sz + loop bs (dst `plusPtr` sz) + +copy :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> IO bs2 +copy bs f = + alloc (length bs) $ \d -> do + withByteArray bs $ \s -> memCopy d s (length bs) + f (castPtr d) + +copyRet :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO a) -> IO (a, bs2) +copyRet bs f = + allocRet (length bs) $ \d -> do + withByteArray bs $ \s -> memCopy d s (length bs) + f (castPtr d) + +copyAndFreeze :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> bs2 +copyAndFreeze bs f = + allocAndFreeze (length bs) $ \d -> do + withByteArray bs $ \s -> memCopy d s (length bs) + f (castPtr d) + +zero :: ByteArray ba => Int -> ba +zero n = allocAndFreeze n $ \ptr -> memSet ptr 0 n + +eq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool +eq b1 b2 + | l1 /= l2 = False + | otherwise = unsafeDoIO $ withByteArray b1 $ \p1 -> withByteArray b2 $ \p2 -> memEqual p1 p2 l1 + where + l1 = length b1 + l2 = length b2 + +-- | A constant time equality test for 2 ByteArrayAccess values. +-- +-- If values are of 2 different sizes, the function will abort early +-- without comparing any bytes. +-- +-- compared to == , this function will go over all the bytes +-- present before yielding a result even when knowing the +-- overall result early in the processing. +constEq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool +constEq b1 b2 + | l1 /= l2 = False + | otherwise = unsafeDoIO $ withByteArray b1 $ \p1 -> withByteArray b2 $ \p2 -> memConstEqual p1 p2 l1 + where + l1 = length b1 + l2 = length b2 + +toW64BE :: ByteArrayAccess bs => bs -> Int -> BE Word64 +toW64BE bs ofs = unsafeDoIO $ withByteArray bs $ \p -> peek (p `plusPtr` ofs) + +toW64LE :: ByteArrayAccess bs => bs -> Int -> LE Word64 +toW64LE bs ofs = unsafeDoIO $ withByteArray bs $ \p -> peek (p `plusPtr` ofs) + +mapAsWord128 :: ByteArray bs => (Word128 -> Word128) -> bs -> bs +mapAsWord128 f bs = + allocAndFreeze len $ \dst -> + withByteArray bs $ \src -> + loop (len `div` 16) dst src + where + len = length bs + loop :: Int -> Ptr (BE Word64) -> Ptr (BE Word64) -> IO () + loop 0 _ _ = return () + loop i d s = do + w1 <- peek s + w2 <- peek (s `plusPtr` 8) + let (Word128 r1 r2) = f (Word128 (fromBE w1) (fromBE w2)) + poke d (toBE r1) + poke (d `plusPtr` 8) (toBE r2) + loop (i-1) (d `plusPtr` 16) (s `plusPtr` 16) + +mapAsWord64 :: ByteArray bs => (Word64 -> Word64) -> bs -> bs +mapAsWord64 f bs = + allocAndFreeze len $ \dst -> + withByteArray bs $ \src -> + loop (len `div` 8) dst src + where + len = length bs + + loop :: Int -> Ptr (BE Word64) -> Ptr (BE Word64) -> IO () + loop 0 _ _ = return () + loop i d s = do + w <- peek s + let r = f (fromBE w) + poke d (toBE r) + loop (i-1) (d `plusPtr` 8) (s `plusPtr` 8) + +convert :: (ByteArrayAccess bin, ByteArray bout) => bin -> bout +convert = flip copyAndFreeze (\_ -> return ()) + +convertHex :: (ByteArrayAccess bin, ByteArray bout) => bin -> bout +convertHex b = + allocAndFreeze (length b * 2) $ \bout -> + withByteArray b $ \bin -> + toHexadecimal bout bin (length b) diff --git a/Data/Memory/ByteArray/ScrubbedBytes.hs b/Data/Memory/ByteArray/ScrubbedBytes.hs new file mode 100644 index 0000000..7840122 --- /dev/null +++ b/Data/Memory/ByteArray/ScrubbedBytes.hs @@ -0,0 +1,103 @@ +-- | +-- Module : Data.Memory.ByteArray.ScrubbedBytes +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : Stable +-- Portability : GHC +-- +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE CPP #-} +module Data.Memory.ByteArray.ScrubbedBytes + ( ScrubbedBytes + ) where + +import GHC.Types +import GHC.Prim +import GHC.Ptr +import Data.Memory.Internal.CompatPrim +import Data.Memory.Internal.Compat (unsafeDoIO) +import Data.Memory.PtrMethods (memConstEqual) +import Data.Memory.ByteArray.Types + +-- | ScrubbedBytes is a memory chunk which have the properties of: +-- +-- * Being scrubbed after its goes out of scope. +-- +-- * A Show instance that doesn't actually show any content +-- +-- * A Eq instance that is constant time +-- +data ScrubbedBytes = ScrubbedBytes (MutableByteArray# RealWorld) + +instance Show ScrubbedBytes where + show _ = "" + +instance Eq ScrubbedBytes where + (==) = scrubbedBytesEq + +instance ByteArrayAccess ScrubbedBytes where + length = sizeofScrubbedBytes + withByteArray = withPtr + +instance ByteArray ScrubbedBytes where + allocRet = scrubbedBytesAllocRet + +newScrubbedBytes :: Int -> IO ScrubbedBytes +newScrubbedBytes (I# sz) + | booleanPrim (sz <=# 0#) = error "negative or null size for scrubbed array" -- TODO raise a proper exception + | otherwise = IO $ \s -> + case newAlignedPinnedByteArray# sz 8# s of + (# s1, mbarr #) -> + let !scrubber = getScrubber + !mba = ScrubbedBytes mbarr + in case mkWeak# mbarr () (scrubber (byteArrayContents# (unsafeCoerce# mbarr)) >> touchScrubbedBytes mba) s1 of + (# s2, _ #) -> (# s2, mba #) + where + getScrubber :: Addr# -> IO () + getScrubber = eitherDivideBy8# sz scrubber64 scrubber8 + + scrubber64 :: Int# -> Addr# -> IO () + scrubber64 sz64 addr = IO $ \s -> (# loop sz64 addr s, () #) + where loop :: Int# -> Addr# -> State# RealWorld -> State# RealWorld + loop n a s + | booleanPrim (n ==# 0#) = s + | otherwise = + case writeWord64OffAddr# a 0# 0## s of + s' -> loop (n -# 1#) (plusAddr# a 8#) s' + + scrubber8 :: Int# -> Addr# -> IO () + scrubber8 sz8 addr = IO $ \s -> (# loop sz8 addr s, () #) + where loop :: Int# -> Addr# -> State# RealWorld -> State# RealWorld + loop n a s + | booleanPrim (n ==# 0#) = s + | otherwise = + case writeWord8OffAddr# a 0# 0## s of + s' -> loop (n -# 1#) (plusAddr# a 1#) s' + +scrubbedBytesAllocRet :: Int -> (Ptr p -> IO a) -> IO (a, ScrubbedBytes) +scrubbedBytesAllocRet sz f = do + ba <- newScrubbedBytes sz + r <- withPtr ba f + return (r, ba) + +sizeofScrubbedBytes :: ScrubbedBytes -> Int +sizeofScrubbedBytes (ScrubbedBytes mba) = I# (sizeofMutableByteArray# mba) + +withPtr :: ScrubbedBytes -> (Ptr p -> IO a) -> IO a +withPtr b@(ScrubbedBytes mba) f = do + a <- f (Ptr (byteArrayContents# (unsafeCoerce# mba))) + touchScrubbedBytes b + return a + +touchScrubbedBytes :: ScrubbedBytes -> IO () +touchScrubbedBytes (ScrubbedBytes mba) = IO $ \s -> case touch# mba s of s' -> (# s', () #) + +scrubbedBytesEq :: ScrubbedBytes -> ScrubbedBytes -> Bool +scrubbedBytesEq a b + | l1 /= l2 = False + | otherwise = unsafeDoIO $ withPtr a $ \p1 -> withPtr b $ \p2 -> memConstEqual p1 p2 l1 + where + l1 = sizeofScrubbedBytes a + l2 = sizeofScrubbedBytes b diff --git a/Data/Memory/ByteArray/Types.hs b/Data/Memory/ByteArray/Types.hs new file mode 100644 index 0000000..143e76f --- /dev/null +++ b/Data/Memory/ByteArray/Types.hs @@ -0,0 +1,41 @@ +-- | +-- Module : Data.Memory.ByteArray.Types +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +{-# LANGUAGE CPP #-} +module Data.Memory.ByteArray.Types + ( ByteArrayAccess(..) + , ByteArray(..) + ) where + +import Foreign.Ptr + +#ifdef WITH_BYTESTRING_SUPPORT +import qualified Data.ByteString as B (length) +import qualified Data.ByteString.Internal as B +import Foreign.ForeignPtr (withForeignPtr) +#endif + +class ByteArrayAccess ba where + length :: ba -> Int + withByteArray :: ba -> (Ptr p -> IO a) -> IO a + +class ByteArrayAccess ba => ByteArray ba where + allocRet :: Int -> (Ptr p -> IO a) -> IO (a, ba) + +#ifdef WITH_BYTESTRING_SUPPORT +instance ByteArrayAccess B.ByteString where + length = B.length + withByteArray b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off) + where (fptr, off, _) = B.toForeignPtr b + +instance ByteArray B.ByteString where + allocRet sz f = do + fptr <- B.mallocByteString sz + r <- withForeignPtr fptr (f . castPtr) + return (r, B.PS fptr 0 sz) +#endif + diff --git a/Data/Memory/Encoding/Base16.hs b/Data/Memory/Encoding/Base16.hs new file mode 100644 index 0000000..d7a3e9b --- /dev/null +++ b/Data/Memory/Encoding/Base16.hs @@ -0,0 +1,99 @@ +-- | +-- Module : Data.Memory.Encoding.Base16 +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +-- Hexadecimal escaper +-- +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE Rank2Types #-} +module Data.Memory.Encoding.Base16 + ( showHexadecimal + , toHexadecimal + ) where + +import Data.Memory.Internal.Compat +import Data.Word +import GHC.Prim +import GHC.Types +import GHC.Word +import Control.Monad +import Foreign.Storable +import Foreign.Ptr (Ptr) + +showHexadecimal :: (forall a . (Ptr Word8 -> IO a) -> IO a) + -> Int + -> String +showHexadecimal withPtr = doChunks 0 + where + doChunks ofs len + | len < 4 = doUnique ofs len + | otherwise = do + let !(W8# a, W8# b, W8# c, W8# d) = unsafeDoIO $ withPtr (read4 ofs) + !(# w1, w2 #) = convertByte a + !(# w3, w4 #) = convertByte b + !(# w5, w6 #) = convertByte c + !(# w7, w8 #) = convertByte d + in wToChar w1 : wToChar w2 : wToChar w3 : wToChar w4 + : wToChar w5 : wToChar w6 : wToChar w7 : wToChar w8 + : doChunks (ofs + 4) (len - 4) + + doUnique ofs len + | len == 0 = [] + | otherwise = + let !(W8# b) = unsafeDoIO $ withPtr (byteIndex ofs) + !(# w1, w2 #) = convertByte b + in wToChar w1 : wToChar w2 : doUnique (ofs + 1) (len - 1) + + read4 :: Int -> Ptr Word8 -> IO (Word8, Word8, Word8, Word8) + read4 ofs p = + liftM4 (,,,) (byteIndex ofs p) (byteIndex (ofs+1) p) + (byteIndex (ofs+2) p) (byteIndex (ofs+3) p) + + wToChar :: Word# -> Char + wToChar w = toEnum (I# (word2Int# w)) + + byteIndex :: Int -> Ptr Word8 -> IO Word8 + byteIndex i p = peekByteOff p i + +toHexadecimal :: Ptr Word8 -> Ptr Word8 -> Int -> IO () +toHexadecimal bout bin n = loop 0 + where loop i + | i == n = return () + | otherwise = do + (W8# w) <- peekByteOff bin i + let !(# w1, w2 #) = convertByte w + pokeByteOff bout (i * 2) (W8# w1) + pokeByteOff bout (i * 2 + 1) (W8# w2) + loop (i+1) + +convertByte :: Word# -> (# Word#, Word# #) +convertByte b = (# r tableHi b, r tableLo b #) + where + r :: Addr# -> Word# -> Word# + r table index = indexWord8OffAddr# table (word2Int# index) + + !tableLo = + "0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef\ + \0123456789abcdef0123456789abcdef"# + !tableHi = + "00000000000000001111111111111111\ + \22222222222222223333333333333333\ + \44444444444444445555555555555555\ + \66666666666666667777777777777777\ + \88888888888888889999999999999999\ + \aaaaaaaaaaaaaaaabbbbbbbbbbbbbbbb\ + \ccccccccccccccccdddddddddddddddd\ + \eeeeeeeeeeeeeeeeffffffffffffffff"# +{-# INLINE convertByte #-} diff --git a/Data/Memory/Endian.hs b/Data/Memory/Endian.hs new file mode 100644 index 0000000..0735194 --- /dev/null +++ b/Data/Memory/Endian.hs @@ -0,0 +1,114 @@ +-- | +-- Module : Data.Memory.Endian +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : good +-- +{-# LANGUAGE CPP #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module Data.Memory.Endian + ( Endianness(..) + , getSystemEndianness + , BE(..), LE(..) + , fromBE, toBE + , fromLE, toLE + ) where + +import Data.Word (Word16, Word32, Word64) +import Foreign.Storable +#if !defined(ARCH_IS_LITTLE_ENDIAN) && !defined(ARCH_IS_BIG_ENDIAN) +import Data.Memory.Internal.Compat (unsafeDoIO) +#endif + +import Data.Memory.Internal.Compat (byteSwap64, byteSwap32, byteSwap16) + +-- | represent the CPU endianness +-- +-- Big endian system stores bytes with the MSB as the first byte. +-- Little endian system stores bytes with the LSB as the first byte. +-- +-- middle endian is purposely avoided. +data Endianness = LittleEndian + | BigEndian + deriving (Show,Eq) + +-- | Return the system endianness +getSystemEndianness :: Endianness +#ifdef ARCH_IS_LITTLE_ENDIAN +getSystemEndianness = LittleEndian +#elif ARCH_IS_BIG_ENDIAN +getSystemEndianness = BigEndian +#else +getSystemEndianness + | isLittleEndian = LittleEndian + | isBigEndian = BigEndian + | otherwise = error "cannot determine endianness" + where + isLittleEndian = endianCheck == 2 + isBigEndian = endianCheck == 1 + endianCheck = unsafeDoIO $ alloca $ \p -> do + poke p (0x01000002 :: Word32) + peek (castPtr p :: Ptr Word8) +#endif + +-- | Little Endian value +newtype LE a = LE { unLE :: a } + deriving (Show,Eq,Storable) + +-- | Big Endian value +newtype BE a = BE { unBE :: a } + deriving (Show,Eq,Storable) + +-- | Convert a value in cpu endianess to big endian +toBE :: ByteSwap a => a -> BE a +#ifdef ARCH_IS_LITTLE_ENDIAN +toBE = BE . byteSwap +#elif ARCH_IS_BIG_ENDIAN +toBE = BE +#else +toBE = BE . (if getSystemEndianness == LittleEndian then byteSwap else id) +#endif +{-# INLINE toBE #-} + +-- | Convert from a big endian value to the cpu endianness +fromBE :: ByteSwap a => BE a -> a +#ifdef ARCH_IS_LITTLE_ENDIAN +fromBE (BE a) = byteSwap a +#elif ARCH_IS_BIG_ENDIAN +fromBE (BE a) = a +#else +fromBE (BE a) = if getSystemEndianness == LittleEndian then byteSwap a else a +#endif +{-# INLINE fromBE #-} + +-- | Convert a value in cpu endianess to little endian +toLE :: ByteSwap a => a -> LE a +#ifdef ARCH_IS_LITTLE_ENDIAN +toLE = LE +#elif ARCH_IS_BIG_ENDIAN +toLE = LE . byteSwap +#else +toLE = LE . (if getSystemEndianness == LittleEndian then id else byteSwap) +#endif +{-# INLINE toLE #-} + +-- | Convert from a little endian value to the cpu endianness +fromLE :: ByteSwap a => LE a -> a +#ifdef ARCH_IS_LITTLE_ENDIAN +fromLE (LE a) = a +#elif ARCH_IS_BIG_ENDIAN +fromLE (LE a) = byteSwap a +#else +fromLE (LE a) = if getSystemEndianness == LittleEndian then a else byteSwap a +#endif +{-# INLINE fromLE #-} + +class Storable a => ByteSwap a where + byteSwap :: a -> a +instance ByteSwap Word16 where + byteSwap = byteSwap16 +instance ByteSwap Word32 where + byteSwap = byteSwap32 +instance ByteSwap Word64 where + byteSwap = byteSwap64 diff --git a/Data/Memory/ExtendedWords.hs b/Data/Memory/ExtendedWords.hs new file mode 100644 index 0000000..6e2052b --- /dev/null +++ b/Data/Memory/ExtendedWords.hs @@ -0,0 +1,16 @@ +-- | +-- Module : Data.Memory.ExtendedWords +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +-- Extra Word size +-- +module Data.Memory.ExtendedWords + ( Word128(..) + ) where + +import Data.Word (Word64) + +data Word128 = Word128 !Word64 !Word64 deriving (Show, Eq) diff --git a/Data/Memory/Internal/Compat.hs b/Data/Memory/Internal/Compat.hs new file mode 100644 index 0000000..2b94112 --- /dev/null +++ b/Data/Memory/Internal/Compat.hs @@ -0,0 +1,65 @@ +-- | +-- Module : Data.Memory.Internal.Compat +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Good +-- +-- This module try to keep all the difference between versions of base +-- or other needed packages, so that modules don't need to use CPP +-- +{-# LANGUAGE CPP #-} +module Data.Memory.Internal.Compat + ( unsafeDoIO + , popCount + , byteSwap64 + , byteSwap32 + , byteSwap16 + ) where + +import System.IO.Unsafe +import Data.Word +import Data.Bits + +-- | perform io for hashes that do allocation and ffi. +-- unsafeDupablePerformIO is used when possible as the +-- computation is pure and the output is directly linked +-- to the input. we also do not modify anything after it has +-- been returned to the user. +unsafeDoIO :: IO a -> a +#if __GLASGOW_HASKELL__ > 704 +unsafeDoIO = unsafeDupablePerformIO +#else +unsafeDoIO = unsafePerformIO +#endif + +#if !(MIN_VERSION_base(4,5,0)) +popCount :: Word64 -> Int +popCount n = loop 0 n + where loop c 0 = c + loop c i = loop (c + if testBit c 0 then 1 else 0) (i `shiftR` 1) +#endif + +#if !(MIN_VERSION_base(4,7,0)) +byteSwap64 :: Word64 -> Word64 +byteSwap64 w = + (w `shiftR` 56) .|. (w `shiftL` 56) + .|. ((w `shiftR` 40) .&. 0xff00) .|. ((w .&. 0xff00) `shiftL` 40) + .|. ((w `shiftR` 24) .&. 0xff0000) .|. ((w .&. 0xff0000) `shiftL` 24) + .|. ((w `shiftR` 8) .&. 0xff000000) .|. ((w .&. 0xff000000) `shiftL` 8) +#endif + +#if !(MIN_VERSION_base(4,7,0)) +byteSwap32 :: Word32 -> Word32 +byteSwap32 w = + (w `shiftR` 24) + .|. (w `shiftL` 24) + .|. ((w `shiftR` 8) .&. 0xff00) + .|. ((w .&. 0xff00) `shiftL` 8) +#endif + +#if !(MIN_VERSION_base(4,7,0)) +byteSwap16 :: Word16 -> Word16 +byteSwap16 w = + (w `shiftR` 8) .|. (w `shiftL` 8) +#endif diff --git a/Data/Memory/Internal/CompatPrim.hs b/Data/Memory/Internal/CompatPrim.hs new file mode 100644 index 0000000..1b7bb15 --- /dev/null +++ b/Data/Memory/Internal/CompatPrim.hs @@ -0,0 +1,84 @@ +-- | +-- Module : Data.Memory.Internal.CompatPrim +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : stable +-- Portability : Compat +-- +-- This module try to keep all the difference between versions of ghc primitive +-- or other needed packages, so that modules don't need to use CPP. +-- +-- Note that MagicHash and CPP conflicts in places, making it "more interesting" +-- to write compat code for primitives +-- +{-# LANGUAGE CPP #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +module Data.Memory.Internal.CompatPrim + ( be32Prim + , le32Prim + , byteswap32Prim + , booleanPrim + , eitherDivideBy8# + ) where + +import GHC.Prim + +-- | byteswap Word# to or from Big Endian +-- +-- on a big endian machine, this function is a nop. +be32Prim :: Word# -> Word# +#ifdef ARCH_IS_LITTLE_ENDIAN +be32Prim = byteswap32Prim +#else +be32Prim w = w +#endif + +-- | byteswap Word# to or from Little Endian +-- +-- on a little endian machine, this function is a nop. +le32Prim :: Word# -> Word# +#ifdef ARCH_IS_LITTLE_ENDIAN +le32Prim w = w +#else +le32Prim = byteswap32Prim +#endif + +byteswap32Prim :: Word# -> Word# +#if __GLASGOW_HASKELL__ >= 708 +byteswap32Prim w = byteSwap32# w +#else +byteswap32Prim w = + let !a = uncheckedShiftL# w 24# + !b = and# (uncheckedShiftL# w 8#) 0x00ff0000## + !c = and# (uncheckedShiftRL# w 8#) 0x0000ff00## + !d = and# (uncheckedShiftRL# w 24#) 0x000000ff## + in or# a (or# b (or# c d)) +#endif + +#if __GLASGOW_HASKELL__ >= 708 +booleanPrim :: Int# -> Bool +booleanPrim v = tagToEnum# v +#else +booleanPrim :: Bool -> Bool +booleanPrim b = b +#endif + +-- | Apply or or another function if 8 divides the number of bytes +eitherDivideBy8# :: Int# -- ^ number of bytes + -> (Int# -> a) -- ^ if it divided by 8, the argument is the number of 8 bytes words + -> (Int# -> a) -- ^ if it doesn't, just the number of bytes + -> a +#if __GLASGOW_HASKELL__ >= 740 +eitherDivideBy8# v f8 f1 = + let !(# q, r #) = quotRemInt v 8# + in if booleanPrim (r ==# 0) + then f8 q + else f1 v +#else +eitherDivideBy8# v f8 f1 = + if booleanPrim ((remInt# v 8#) ==# 0#) + then f8 (quotInt# v 8#) + else f1 v +#endif diff --git a/Data/Memory/Internal/Imports.hs b/Data/Memory/Internal/Imports.hs new file mode 100644 index 0000000..6a4d830 --- /dev/null +++ b/Data/Memory/Internal/Imports.hs @@ -0,0 +1,15 @@ +-- | +-- Module : Data.Memory.Internal.Imports +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +module Data.Memory.Internal.Imports + ( module X + ) where + +import Data.Word as X +import Control.Applicative as X +import Control.Monad as X (forM, forM_, void) +import Control.Arrow as X (first, second) diff --git a/Data/Memory/PtrMethods.hs b/Data/Memory/PtrMethods.hs new file mode 100644 index 0000000..eb940c3 --- /dev/null +++ b/Data/Memory/PtrMethods.hs @@ -0,0 +1,110 @@ +-- | +-- Module : Data.Memory.PtrMethods +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +-- methods to manipulate raw memory representation +-- +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE ForeignFunctionInterface #-} +module Data.Memory.PtrMethods + ( memCreateTemporary + , memXor + , memXorWith + , memCopy + , memSet + , memEqual + , memConstEqual + , memCompare + ) where + +import Data.Memory.Internal.Imports +import Foreign.Ptr (Ptr, plusPtr) +import Foreign.Storable (peek, poke, pokeByteOff, peekByteOff) +import Foreign.C.Types +import Foreign.Marshal.Alloc (allocaBytesAligned) +import Data.Bits (xor) + +-- | Create a new temporary buffer +memCreateTemporary :: Int -> (Ptr Word8 -> IO a) -> IO a +memCreateTemporary size f = allocaBytesAligned size 8 f + +-- | xor bytes from source1 and source2 to destination +-- +-- d = s1 xor s2 +-- +-- s1, nor s2 are modified unless d point to s1 or s2 +memXor :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO () +memXor _ _ _ 0 = return () +memXor d s1 s2 n = do + (xor <$> peek s1 <*> peek s2) >>= poke d + memXor (d `plusPtr` 1) (s1 `plusPtr` 1) (s2 `plusPtr` 1) (n-1) + +-- | xor bytes from source with a specific value to destination +-- +-- d = replicate (sizeof s) v `xor` s +memXorWith :: Ptr Word8 -> Word8 -> Ptr Word8 -> Int -> IO () +memXorWith d v s n = loop 0 + where + loop i + | i == n = return () + | otherwise = do + (xor v <$> peekByteOff s i) >>= pokeByteOff d i + loop (i+1) + +-- | Copy a set number of bytes from @src to @dst +memCopy :: Ptr Word8 -> Ptr Word8 -> Int -> IO () +memCopy dst src n = c_memcpy dst src (fromIntegral n) + +-- | Set @n number of bytes to the same value @v +memSet :: Ptr Word8 -> Word8 -> Int -> IO () +memSet start v n = c_memset start (fromIntegral v) (fromIntegral n) >>= \_ -> return () + +memEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool +memEqual p1 p2 n = loop 0 + where + loop i + | i == n = return True + | otherwise = do + e <- (==) <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8) + if e then loop (i+1) else return False + +memCompare :: Ptr Word8 -> Ptr Word8 -> Int -> IO Ordering +memCompare p1 p2 n = loop 0 + where + loop i + | i == n = return EQ + | otherwise = do + e <- compare <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8) + if e == EQ then loop (i+1) else return e + +-- | A constant time equality test for 2 Memory buffers +-- +-- compared to normal equality function, this function will go +-- over all the bytes present before yielding a result even when +-- knowing the overall result early in the processing. +memConstEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool +memConstEqual p1 p2 n = loop 0 True + where + loop i !ret + | i == n = return ret + | otherwise = do + e <- (==) <$> peek p1 <*> peek p2 + loop (i+1) (ret &&! e) + + -- Bool == Bool + (&&!) :: Bool -> Bool -> Bool + True &&! True = True + True &&! False = False + False &&! True = False + False &&! False = False + +foreign import ccall unsafe "memset" + c_memset :: Ptr Word8 -> Word8 -> CSize -> IO () + +foreign import ccall unsafe "memcpy" + c_memcpy :: Ptr Word8 -> Ptr Word8 -> CSize -> IO () diff --git a/cryptonite.cabal b/cryptonite.cabal index 002445d..3b5ac23 100644 --- a/cryptonite.cabal +++ b/cryptonite.cabal @@ -35,6 +35,11 @@ Flag support_pclmuldq Default: False Manual: True +Flag builtin_memory + Description: Build with a local snapshot of the memory package + Default: True + Manual: True + Library Exposed-modules: Crypto.Cipher.AES Crypto.Cipher.Blowfish @@ -164,6 +169,24 @@ Library , cbits/cryptonite_scrypt.c include-dirs: cbits cbits/ed25519 + if flag(builtin_memory) + Exposed-modules: Data.Memory.ByteArray + Data.Memory.Endian + Data.Memory.PtrMethods + Data.Memory.ExtendedWords + Data.Memory.Encoding.Base16 + Other-modules: Data.Memory.Internal.Compat + Data.Memory.Internal.CompatPrim + Data.Memory.Internal.Imports + Data.Memory.ByteArray.Types + Data.Memory.ByteArray.Bytes + Data.Memory.ByteArray.ScrubbedBytes + Data.Memory.ByteArray.Methods + Data.Memory.ByteArray.MemView + CPP-options: -DWITH_BYTESTRING_SUPPORT + else + build-depends: memory + -- FIXME armel or mispel is also little endian. -- might be a good idea to also add a runtime autodetect mode. -- ARCH_ENDIAN_UNKNOWN