diff --git a/Crypto/Number/Generate.hs b/Crypto/Number/Generate.hs index 1f356e5..b27fae7 100644 --- a/Crypto/Number/Generate.hs +++ b/Crypto/Number/Generate.hs @@ -21,7 +21,7 @@ import Crypto.Random.Types import Control.Monad (when) import Foreign.Ptr import Foreign.Storable -import Data.Bits ((.|.), (.&.), shiftL, shiftR, complement) +import Data.Bits ((.|.), (.&.), shiftL, complement, testBit) import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes) import qualified Crypto.Internal.ByteArray as B @@ -79,23 +79,48 @@ generateParams bits genTopPolicy generateOdd bit = (bits - 1) `mod` 8; mask = 0xff `shiftL` (bit + 1); --- | generate a positive integer x, s.t. 0 <= x < m -generateMax :: MonadRandom m => Integer -> m Integer -generateMax 1 = return 0 -generateMax m - | m <= 0 = error "negative value for generateMax" - | otherwise = do - result <- randomInt bytesLength - let result' = result `shiftR` bitsPoppedOff - if result' >= m - then generateMax m - else return result' +-- | Generate a positive integer x, s.t. 0 <= x < range +generateMax :: MonadRandom m + => Integer -- ^ range + -> m Integer +generateMax range + | range <= 1 = return 0 + | range < 127 = generateSimple + | canOverGenerate = loopGenerateOver tries + | otherwise = loopGenerate tries where - bytesLength = lengthBytes m - bitsLength = log2 (m-1) + 1 - bitsPoppedOff = 8 - (bitsLength `mod` 8) + -- this "generator" is mostly for quickcheck benefits. it'll be biased if + -- range is not a multiple of 2, but overall, no security should be + -- assumed for a number between 0 and 127. + generateSimple = flip mod range `fmap` generateParams bits Nothing False - randomInt nbBytes = os2ipBytes <$> getRandomBytes nbBytes + loopGenerate count + | count == 0 = error "internal: generateMax (normal) doesn't seems to work properly" + | otherwise = do + r <- generateParams bits Nothing False + if isValid r then return r else loopGenerate (count-1) + + loopGenerateOver count + | count == 0 = error "internal: generateMax (over) doesn't seems to work properly" + | otherwise = do + r <- generateParams (bits+1) Nothing False + let r2 = r - range + r3 = r2 - range + if isValid r + then return r + else if isValid r2 + then return r2 + else if isValid r3 + then return r3 + else loopGenerateOver (count-1) + + bits = numBits range + canOverGenerate = bits > 3 && not (range `testBit` (bits-2)) && not (range `testBit` (bits-3)) + + isValid n = n < range + + tries :: Int + tries = 100 -- | generate a number between the inclusive bound [low,high]. generateBetween :: MonadRandom m => Integer -> Integer -> m Integer