[number] generate number with bounds more effectively

This commit is contained in:
Vincent Hanquez 2015-05-23 11:59:10 +01:00
parent a4baf9383b
commit 2153e5690f

View File

@ -21,7 +21,7 @@ import Crypto.Random.Types
import Control.Monad (when) import Control.Monad (when)
import Foreign.Ptr import Foreign.Ptr
import Foreign.Storable import Foreign.Storable
import Data.Bits ((.|.), (.&.), shiftL, shiftR, complement) import Data.Bits ((.|.), (.&.), shiftL, complement, testBit)
import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes) import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B import qualified Crypto.Internal.ByteArray as B
@ -79,23 +79,48 @@ generateParams bits genTopPolicy generateOdd
bit = (bits - 1) `mod` 8; bit = (bits - 1) `mod` 8;
mask = 0xff `shiftL` (bit + 1); mask = 0xff `shiftL` (bit + 1);
-- | generate a positive integer x, s.t. 0 <= x < m -- | Generate a positive integer x, s.t. 0 <= x < range
generateMax :: MonadRandom m => Integer -> m Integer generateMax :: MonadRandom m
generateMax 1 = return 0 => Integer -- ^ range
generateMax m -> m Integer
| m <= 0 = error "negative value for generateMax" generateMax range
| otherwise = do | range <= 1 = return 0
result <- randomInt bytesLength | range < 127 = generateSimple
let result' = result `shiftR` bitsPoppedOff | canOverGenerate = loopGenerateOver tries
if result' >= m | otherwise = loopGenerate tries
then generateMax m
else return result'
where where
bytesLength = lengthBytes m -- this "generator" is mostly for quickcheck benefits. it'll be biased if
bitsLength = log2 (m-1) + 1 -- range is not a multiple of 2, but overall, no security should be
bitsPoppedOff = 8 - (bitsLength `mod` 8) -- assumed for a number between 0 and 127.
generateSimple = flip mod range `fmap` generateParams bits Nothing False
randomInt nbBytes = os2ipBytes <$> getRandomBytes nbBytes loopGenerate count
| count == 0 = error "internal: generateMax (normal) doesn't seems to work properly"
| otherwise = do
r <- generateParams bits Nothing False
if isValid r then return r else loopGenerate (count-1)
loopGenerateOver count
| count == 0 = error "internal: generateMax (over) doesn't seems to work properly"
| otherwise = do
r <- generateParams (bits+1) Nothing False
let r2 = r - range
r3 = r2 - range
if isValid r
then return r
else if isValid r2
then return r2
else if isValid r3
then return r3
else loopGenerateOver (count-1)
bits = numBits range
canOverGenerate = bits > 3 && not (range `testBit` (bits-2)) && not (range `testBit` (bits-3))
isValid n = n < range
tries :: Int
tries = 100
-- | generate a number between the inclusive bound [low,high]. -- | generate a number between the inclusive bound [low,high].
generateBetween :: MonadRandom m => Integer -> Integer -> m Integer generateBetween :: MonadRandom m => Integer -> Integer -> m Integer