diff --git a/Crypto/Cipher/Blowfish/Primitive.hs b/Crypto/Cipher/Blowfish/Primitive.hs index f2b8cd8..6e1af82 100644 --- a/Crypto/Cipher/Blowfish/Primitive.hs +++ b/Crypto/Cipher/Blowfish/Primitive.hs @@ -18,10 +18,9 @@ module Crypto.Cipher.Blowfish.Primitive , decrypt ) where -import Control.Monad (forM_) +import Control.Monad (forM_, when) import Data.Bits import Data.Word -import qualified Data.ByteString as B import Crypto.Error import Crypto.Internal.Compat @@ -53,23 +52,9 @@ cipher ctx b initBlowfish :: ByteArray key => key -> CryptoFailable Context initBlowfish key | len > (448 `div` 8) = CryptoFailed $ CryptoError_KeySizeInvalid - | len == 0 = keyFromByteString (B.replicate (18*4) 0) - | otherwise = keyFromByteString . B.pack . take (18*4) . cycle . B.unpack . byteArrayToBS $ key + | otherwise = CryptoPassed $ makeKeySchedule key where len = byteArrayLength key -keyFromByteString :: B.ByteString -> CryptoFailable Context -keyFromByteString k - | B.length k /= (18 * 4) = CryptoFailed CryptoError_KeySizeInvalid - | otherwise = CryptoPassed . makeKeySchedule . w8tow32 . B.unpack $ k - where - w8tow32 :: [Word8] -> [Word32] - w8tow32 [] = [] - w8tow32 (a:b:c:d:xs) = ( (fromIntegral a `shiftL` 24) .|. - (fromIntegral b `shiftL` 16) .|. - (fromIntegral c `shiftL` 8) .|. - (fromIntegral d) ) : w8tow32 xs - w8tow32 _ = error $ "internal error: Crypto.Cipher.Blowfish:keyFromByteString" - coreCrypto :: Context -> Word64 -> Word64 coreCrypto (BF p s0 s1 s2 s3) input = doRound input 0 where @@ -91,11 +76,21 @@ coreCrypto (BF p s0 s1 s2 s3) input = doRound input 0 d = s3 (fromIntegral $ t .&. 0xff) in fromIntegral (((a + b) `xor` c) + d) `shiftL` 32 -makeKeySchedule :: [Word32] -> Context +makeKeySchedule :: ByteArray key => key -> Context makeKeySchedule key = let v = unsafeDoIO $ do + let len = byteArrayLength key mv <- createKeySchedule - forM_ (zip key [0..17]) $ \(k, i) -> mutableArrayWriteXor32 mv i k + when (len > 0) $ forM_ [0..17] $ \i -> do + let a = byteArrayIndex key ((i * 4 + 0) `mod` len) + b = byteArrayIndex key ((i * 4 + 1) `mod` len) + c = byteArrayIndex key ((i * 4 + 2) `mod` len) + d = byteArrayIndex key ((i * 4 + 3) `mod` len) + k = (fromIntegral a `shiftL` 24) .|. + (fromIntegral b `shiftL` 16) .|. + (fromIntegral c `shiftL` 8) .|. + (fromIntegral d) + mutableArrayWriteXor32 mv i k prepare mv mutableArray32Freeze mv in BF (\i -> arrayRead32 v i)