re-enable number generation
This commit is contained in:
parent
be3eacc068
commit
d9b16a529e
@ -6,51 +6,55 @@
|
|||||||
-- Portability : Good
|
-- Portability : Good
|
||||||
|
|
||||||
module Crypto.Number.Generate
|
module Crypto.Number.Generate
|
||||||
( {-generateMax
|
( generateMax
|
||||||
, generateBetween
|
, generateBetween
|
||||||
, generateOfSize
|
, generateOfSize
|
||||||
, generateBits-}
|
, generateBits
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Control.Applicative
|
||||||
import Crypto.Number.Basic
|
import Crypto.Number.Basic
|
||||||
import Crypto.Number.Serialize
|
import Crypto.Number.Serialize
|
||||||
|
import Crypto.Random.Types
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import Data.Bits ((.|.), (.&.), shiftR)
|
import Data.Bits ((.|.), (.&.), shiftR)
|
||||||
|
|
||||||
{-
|
|
||||||
-- | generate a positive integer x, s.t. 0 <= x < m
|
-- | generate a positive integer x, s.t. 0 <= x < m
|
||||||
generateMax :: CPRG g => g -> Integer -> (Integer, g)
|
generateMax :: MonadRandom m => Integer -> m Integer
|
||||||
generateMax rng 1 = (0, rng)
|
generateMax 1 = return 0
|
||||||
generateMax rng m
|
generateMax m
|
||||||
| (result' >= m) = generateMax rng' m
|
| m <= 0 = error "negative value for generateMax"
|
||||||
| otherwise = (result', rng')
|
| otherwise = do
|
||||||
|
result <- randomInt bytesLength
|
||||||
|
let result' = result `shiftR` bitsPoppedOff
|
||||||
|
if result' >= m
|
||||||
|
then generateMax m
|
||||||
|
else return result'
|
||||||
where
|
where
|
||||||
bytesLength = lengthBytes m
|
bytesLength = lengthBytes m
|
||||||
bitsLength = (log2 (m-1) + 1)
|
bitsLength = log2 (m-1) + 1
|
||||||
bitsPoppedOff = 8 - (bitsLength `mod` 8)
|
bitsPoppedOff = 8 - (bitsLength `mod` 8)
|
||||||
randomInt bytes = withRandomBytes rng bytes $ \bs -> os2ip bs
|
|
||||||
|
|
||||||
(result, rng') = randomInt bytesLength
|
randomInt nbBytes = os2ip <$> getRandomBytes nbBytes
|
||||||
result' = result `shiftR` bitsPoppedOff
|
|
||||||
|
|
||||||
-- | generate a number between the inclusive bound [low,high].
|
-- | generate a number between the inclusive bound [low,high].
|
||||||
generateBetween :: CPRG g => g -> Integer -> Integer -> (Integer, g)
|
generateBetween :: MonadRandom m => Integer -> Integer -> m Integer
|
||||||
generateBetween rng low high = (low + v, rng')
|
generateBetween low high = (low +) <$> generateMax (high - low + 1)
|
||||||
where (v, rng') = generateMax rng (high - low + 1)
|
|
||||||
|
|
||||||
-- | generate a positive integer of a specific size in bits.
|
-- | generate a positive integer of a specific size in bits.
|
||||||
-- the number of bits need to be multiple of 8. It will always returns
|
-- the number of bits need to be multiple of 8. It will always returns
|
||||||
-- an integer that is close to 2^(1+bits/8) by setting the 2 highest bits to 1.
|
-- an integer that is close to 2^(1+bits/8) by setting the 2 highest bits to 1.
|
||||||
generateOfSize :: CPRG g => g -> Int -> (Integer, g)
|
generateOfSize :: MonadRandom m => Int -> m Integer
|
||||||
generateOfSize rng bits = withRandomBytes rng (bits `div` 8) $ \bs ->
|
generateOfSize bits = toInteger <$> getRandomBytes (bits `div` 8)
|
||||||
os2ip $ snd $ B.mapAccumL (\acc w -> (0, w .|. acc)) 0xc0 bs
|
where
|
||||||
|
toInteger bs = os2ip $ snd $ B.mapAccumL (\acc w -> (0, w .|. acc)) 0xc0 bs
|
||||||
|
|
||||||
-- | Generate a number with the specified number of bits
|
-- | Generate a number with the specified number of bits
|
||||||
generateBits :: CPRG g => g -> Int -> (Integer, g)
|
generateBits :: MonadRandom m => Int -> m Integer
|
||||||
generateBits rng nbBits = withRandomBytes rng nbBytes' $ \bs -> modF (os2ip bs)
|
generateBits nbBits = modF . os2ip <$> getRandomBytes nbBytes'
|
||||||
where (nbBytes, strayBits) = nbBits `divMod` 8
|
where (nbBytes, strayBits) = nbBits `divMod` 8
|
||||||
nbBytes' | strayBits == 0 = nbBytes
|
nbBytes' | strayBits == 0 = nbBytes
|
||||||
| otherwise = nbBytes + 1
|
| otherwise = nbBytes + 1
|
||||||
modF | strayBits == 0 = id
|
modF | strayBits == 0 = id
|
||||||
| otherwise = (.&.) (2^nbBits - 1)
|
| otherwise = (.&.) (2^nbBits - 1)
|
||||||
-}
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user