diff --git a/Crypto/Number/Generate.hs b/Crypto/Number/Generate.hs index 48365c3..912f569 100644 --- a/Crypto/Number/Generate.hs +++ b/Crypto/Number/Generate.hs @@ -16,9 +16,9 @@ import Crypto.Internal.Imports import Crypto.Number.Basic import Crypto.Number.Serialize import Crypto.Random.Types -import qualified Data.ByteString as B -import Crypto.Internal.ByteArray (Bytes) import Data.Bits ((.|.), (.&.), shiftR) +import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes) +import qualified Crypto.Internal.ByteArray as B -- | generate a positive integer x, s.t. 0 <= x < m @@ -47,9 +47,12 @@ generateBetween low high = (low +) <$> generateMax (high - low + 1) -- the number of bits need to be multiple of 8. It will always returns -- an integer that is close to 2^(1+bits/8) by setting the 2 highest bits to 1. generateOfSize :: MonadRandom m => Int -> m Integer -generateOfSize bits = unmarshall <$> getRandomBytes (bits `div` 8) +generateOfSize bits = os2ip . setHighest <$> getRandomBytes (bits `div` 8) where - unmarshall bs = os2ip $ snd $ B.mapAccumL (\acc w -> (0, w .|. acc)) 0xc0 bs + setHighest :: ScrubbedBytes -> ScrubbedBytes + setHighest ran = case B.unpack ran of + [] -> B.empty + (w:ws) -> B.pack ((w .|. 0xc0) : ws) -- | Generate a number with the specified number of bits generateBits :: MonadRandom m => Int -> m Integer