From 8d9f493fe2dc0c6db81a3dfddac449ae4b80f8a1 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Wed, 8 Apr 2015 20:42:15 +0100 Subject: [PATCH] add fast and time constant Eq function for bytearray --- Crypto/Internal/ByteArray.hs | 51 +++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/Crypto/Internal/ByteArray.hs b/Crypto/Internal/ByteArray.hs index ecdb60c..4a2fda9 100644 --- a/Crypto/Internal/ByteArray.hs +++ b/Crypto/Internal/ByteArray.hs @@ -18,6 +18,8 @@ module Crypto.Internal.ByteArray , byteArrayCopyAndFreeze , byteArraySplit , byteArrayXor + , byteArrayEq + , byteArrayConstEq , byteArrayConcat , byteArrayToBS , byteArrayFromBS @@ -25,7 +27,7 @@ module Crypto.Internal.ByteArray , byteArrayToW64LE ) where -import Control.Applicative ((<$>)) +import Control.Applicative ((<$>), (<*>)) import Data.Word import Data.SecureMem import Crypto.Internal.Memory @@ -120,6 +122,53 @@ byteArrayCopyAndFreeze bs f = withByteArray bs $ \s -> bufCopy d s (byteArrayLength bs) f (castPtr d) +byteArrayEq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool +byteArrayEq b1 b2 + | l1 /= l2 = False + | otherwise = unsafeDoIO $ + withByteArray b1 $ \p1 -> + withByteArray b2 $ \p2 -> + loop l1 p1 p2 + where + l1 = byteArrayLength b1 + l2 = byteArrayLength b2 + loop :: Int -> Ptr Word8 -> Ptr Word8 -> IO Bool + loop 0 _ _ = return True + loop i p1 p2 = do + e <- (==) <$> peek p1 <*> peek p2 + if e then loop (i-1) (p1 `plusPtr` 1) (p2 `plusPtr` 1) else return False + +-- | A constant time equality test for 2 ByteArrayAccess values. +-- +-- If values are of 2 different sizes, the function will abort early +-- without comparing any bytes. +-- +-- compared to == , this function will go over all the bytes +-- present before yielding a result even when knowing the +-- overall result early in the processing. +byteArrayConstEq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool +byteArrayConstEq b1 b2 + | l1 /= l2 = False + | otherwise = unsafeDoIO $ + withByteArray b1 $ \p1 -> + withByteArray b2 $ \p2 -> + loop l1 True p1 p2 + where + l1 = byteArrayLength b1 + l2 = byteArrayLength b2 + loop :: Int -> Bool -> Ptr Word8 -> Ptr Word8 -> IO Bool + loop 0 !ret _ _ = return ret + loop i !ret p1 p2 = do + e <- (==) <$> peek p1 <*> peek p2 + loop (i-1) (ret &&! e) (p1 `plusPtr` 1) (p2 `plusPtr` 1) + + -- Bool == Bool + (&&!) :: Bool -> Bool -> Bool + True &&! True = True + True &&! False = False + False &&! True = False + False &&! False = False + byteArrayToBS :: ByteArray bs => bs -> ByteString byteArrayToBS bs = byteArrayCopyAndFreeze bs (\_ -> return ())