From 9567fa2526a49b0cdbf1dd8e30e14168c499d0e2 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Mon, 11 May 2015 09:28:48 +0100 Subject: [PATCH] [number] almost rewrite all serialization to be faster and less depends on random versions --- Crypto/Number/Serialize.hs | 178 +++++++++++-------------------------- 1 file changed, 54 insertions(+), 124 deletions(-) diff --git a/Crypto/Number/Serialize.hs b/Crypto/Number/Serialize.hs index 97a613a..ea1798e 100644 --- a/Crypto/Number/Serialize.hs +++ b/Crypto/Number/Serialize.hs @@ -1,10 +1,3 @@ -{-# LANGUAGE CPP #-} -#ifndef MIN_VERSION_integer_gmp -#define MIN_VERSION_integer_gmp(a,b,c) 0 -#endif -#if MIN_VERSION_integer_gmp(0,5,1) -{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns #-} -#endif -- | -- Module : Crypto.Number.Serialize -- License : BSD-style @@ -13,6 +6,7 @@ -- Portability : Good -- -- fast serialization primitives for integer +{-# LANGUAGE BangPatterns #-} module Crypto.Number.Serialize ( i2osp , os2ip @@ -21,147 +15,83 @@ module Crypto.Number.Serialize , lengthBytes ) where -import Data.ByteString (ByteString) -import qualified Data.ByteString.Internal as B -import qualified Data.ByteString as B hiding (length) -import Foreign.Ptr - -#if MIN_VERSION_integer_gmp(0,5,1) -#if __GLASGOW_HASKELL__ >= 710 -import Control.Monad (void) -#endif -import GHC.Integer.GMP.Internals -import GHC.Base -import GHC.Ptr -import System.IO.Unsafe -import Foreign.ForeignPtr -#else -import Foreign.Storable -import Data.Bits -#endif - +import Data.Bits +import Data.Word +import Foreign.Storable +import Foreign.Ptr +import Crypto.Number.Compat +import Crypto.Internal.Compat (unsafeDoIO) import qualified Crypto.Internal.ByteArray as B +import Data.Memory.PtrMethods -#if !MIN_VERSION_integer_gmp(0,5,1) -{-# INLINE divMod256 #-} -divMod256 :: Integer -> (Integer, Integer) -divMod256 n = (n `shiftR` 8, n .&. 0xff) -#endif +divMod256 :: Integer -> (Integer, Word8) +divMod256 n = (n `shiftR` 8, fromIntegral n) -- | os2ip converts a byte string into a positive integer os2ip :: B.ByteArrayAccess ba => ba -> Integer -#if MIN_VERSION_integer_gmp(0,5,1) -os2ip bs = unsafePerformIO $ B.withByteArray fptr $ \ptr -> - let !(Ptr ad) = (ptr `plusPtr` ofs) -#if __GLASGOW_HASKELL__ >= 710 - in importIntegerFromAddr ad (int2Word# n) 1# -#else - in IO $ \s -> importIntegerFromAddr ad (int2Word# n) 1# s -#endif -{-# NOINLINE os2ip #-} -#else -os2ip = B.foldl' (\a b -> (256 * a) .|. (fromIntegral b)) 0 . B.convert -{-# INLINE os2ip #-} -#endif +os2ip bs = unsafeDoIO $ B.withByteArray bs (loop 0 0) + where + len = B.length bs + + loop :: Integer -> Int -> Ptr Word8 -> IO Integer + loop !acc i p + | i == len = return acc + | otherwise = do + w <- peekByteOff p i :: IO Word8 + loop ((acc `shiftL` 8) .|. fromIntegral w) (i+1) p -- | i2osp converts a positive integer into a byte string +-- +-- first byte is MSB (most significant byte), last byte is the LSB (least significant byte) i2osp :: B.ByteArray ba => Integer -> ba -#if MIN_VERSION_integer_gmp(0,5,1) -i2osp 0 = B.allocAndFreeze 1 $ \p -> poke p (0 :: Word8) -i2osp m = B.allocAndFreeze (I# (word2Int# sz)) fillPtr - where !sz = sizeInBaseInteger m 256# -#if __GLASGOW_HASKELL__ >= 710 - fillPtr (Ptr srcAddr) = void $ exportIntegerToAddr m srcAddr 1# -#else - fillPtr (Ptr srcAddr) = IO $ \s -> case exportIntegerToAddr m srcAddr 1# s of - (# s2, _ #) -> (# s2, () #) -#endif -{-# NOINLINE i2osp #-} -#else -i2osp m - | m < 0 = error "i2osp: cannot convert a negative integer to a bytestring" - | otherwise = B.convert $ B.reverse $ B.unfoldr fdivMod256 m - where fdivMod256 0 = Nothing - fdivMod256 n = Just (fromIntegral a,b) where (b,a) = divMod256 n -#endif +i2osp 0 = B.allocAndFreeze 1 $ \p -> pokeByteOff p 0 (0 :: Word8) +i2osp m = B.allocAndFreeze sz (\p -> fillPtr p >> return ()) + where + !sz = lengthBytes m + fillPtr p = gmpExportInteger m p `onGmpUnsupported` export p (sz-1) m + export p ofs i + | ofs == 0 = pokeByteOff p ofs (fromIntegral i :: Word8) + | otherwise = do + let (i', b) = divMod256 i + pokeByteOff p ofs b + export p (ofs-1) i' -- | just like i2osp, but take an extra parameter for size. -- if the number is too big to fit in @len bytes, nothing is returned -- otherwise the number is padded with 0 to fit the @len required. --- --- FIXME: use unsafeCreate to fill the bytestring i2ospOf :: B.ByteArray ba => Int -> Integer -> Maybe ba -#if MIN_VERSION_integer_gmp(0,5,1) +i2ospOf 0 _ = error "cannot create integer serialization in 0 bytes" +i2ospOf len 0 = Just $ B.allocAndFreeze len $ \p -> memSet p 0 len i2ospOf len m - | sz <= len = Just $ i2ospOf_ len m - | otherwise = Nothing - where !sz = I# (word2Int# (sizeInBaseInteger m 256#)) -#else -i2ospOf len m - | lenbytes < len = Just $ B.convert $ B.replicate (len - lenbytes) 0 `B.append` bytes - | lenbytes == len = Just $ B.convert bytes - | otherwise = Nothing - where lenbytes = B.length bytes - bytes = i2osp m -#endif + | sz > len = Nothing + | otherwise = Just $ B.allocAndFreeze len $ \p -> memSet p 0 len >> fillPtr (p `plusPtr` (len - sz)) + where + !sz = lengthBytes m + fillPtr p = gmpExportInteger m p `onGmpUnsupported` export p (sz-1) m + export p ofs i + | ofs == 0 = pokeByteOff p ofs (fromIntegral i :: Word8) + | otherwise = do + let (i', b) = divMod256 i + pokeByteOff p ofs b + export p (ofs-1) i' + +-- -- | just like i2ospOf except that it doesn't expect a failure: i.e. -- an integer larger than the number of output bytes requested -- -- for example if you just took a modulo of the number that represent -- the size (example the RSA modulo n). i2ospOf_ :: B.ByteArray ba => Int -> Integer -> ba -#if MIN_VERSION_integer_gmp(0,5,1) -i2ospOf_ len m = B.allocAndFreeze len fillPtr - where !sz = (sizeInBaseInteger m 256#) - isz = I# (word2Int# sz) - fillPtr ptr - | len < isz = error "cannot compute i2ospOf_ with integer larger than output bytes" - | len == isz = - let !(Ptr srcAddr) = ptr in -#if __GLASGOW_HASKELL__ >= 710 - void (exportIntegerToAddr m srcAddr 1#) -#else - IO $ \s -> case exportIntegerToAddr m srcAddr 1# s of - (# s2, _ #) -> (# s2, () #) -#endif - | otherwise = do - let z = len-isz - _ <- B.memset ptr 0 (fromIntegral len) - let !(Ptr addr) = ptr `plusPtr` z -#if __GLASGOW_HASKELL__ >= 710 - void (exportIntegerToAddr m addr 1#) -#else - IO $ \s -> case exportIntegerToAddr m addr 1# s of - (# s2, _ #) -> (# s2, () #) -#endif -{-# NOINLINE i2ospOf_ #-} -#else -i2ospOf_ len m = B.convert $ B.unsafeCreate len fillPtr - where fillPtr srcPtr = loop m (srcPtr `plusPtr` (len-1)) - where loop n ptr = do - let (nn,a) = divMod256 n - poke ptr (fromIntegral a) - if ptr == srcPtr - then return () - else (if nn == 0 then fillerLoop else loop nn) (ptr `plusPtr` (-1)) - fillerLoop ptr = do - poke ptr 0 - if ptr == srcPtr - then return () - else fillerLoop (ptr `plusPtr` (-1)) -{-# INLINE i2ospOf_ #-} -#endif +i2ospOf_ len = maybe (error "i2ospOf_: integer is larger than expected") id . i2ospOf len -- | returns the number of bytes to store an integer with i2osp -- -- with integer-simple, this function is really slow. lengthBytes :: Integer -> Int -#if MIN_VERSION_integer_gmp(0,5,1) -lengthBytes n = I# (word2Int# (sizeInBaseInteger n 256#)) -#else -lengthBytes n - | n < 256 = 1 - | otherwise = 1 + lengthBytes (n `shiftR` 8) -#endif +lengthBytes n = gmpSizeInBytes n `onGmpUnsupported` nbBytes n + where + nbBytes !v + | v < 256 = 1 + | otherwise = 1 + nbBytes (v `shiftR` 8)