Performance improvements

This commit is contained in:
Sam Protas 2017-04-02 19:36:58 -04:00
parent 7eedbaa112
commit b1a9c7c047

View File

@ -29,8 +29,8 @@ mdsPolynomial, rsPolynomial :: Word32
mdsPolynomial = 0x169 -- x^8 + x^6 + x^5 + x^3 + 1, see [TWOFISH] 4.2 mdsPolynomial = 0x169 -- x^8 + x^6 + x^5 + x^3 + 1, see [TWOFISH] 4.2
rsPolynomial = 0x14d -- x^8 + x^6 + x^3 + x^2 + 1, see [TWOFISH] 4.3 rsPolynomial = 0x14d -- x^8 + x^6 + x^3 + x^2 + 1, see [TWOFISH] 4.3
data Twofish = Twofish { s :: ([Word32], [Word32], [Word32], [Word32]) data Twofish = Twofish { s :: (Array32, Array32, Array32, Array32)
, k :: [Word32] } , k :: Array32 }
-- CONFIRMED -- CONFIRMED
-- | Initialize a 128-bit key -- | Initialize a 128-bit key
@ -43,7 +43,7 @@ initTwofish :: ByteArray key
initTwofish key initTwofish key
| B.length key /= blockSize = CryptoFailed CryptoError_KeySizeInvalid | B.length key /= blockSize = CryptoFailed CryptoError_KeySizeInvalid
| otherwise = CryptoPassed Twofish { k = generatedK, s = generatedS } | otherwise = CryptoPassed Twofish { k = generatedK, s = generatedS }
where generatedK = genK key where generatedK = array32 40 $ genK key
generatedS = genSboxes $ sWords key generatedS = genSboxes $ sWords key
@ -64,25 +64,28 @@ encrypt cipher = mapBlocks (encryptBlock cipher)
encryptBlock :: ByteArray ba => Twofish -> ba -> ba encryptBlock :: ByteArray ba => Twofish -> ba -> ba
encryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ts encryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ts
where (a, b, c, d) = load32ls message where (a, b, c, d) = load32ls message
[a', b', c', d'] = zipWith xor [a, b, c, d] ks a' = a `xor` arrayRead32 ks 0
b' = b `xor` arrayRead32 ks 1
c' = c `xor` arrayRead32 ks 2
d' = d `xor` arrayRead32 ks 3
(a'', b'', c'', d'') = foldl' shuffle (a', b', c', d') [0..7] (a'', b'', c'', d'') = foldl' shuffle (a', b', c', d') [0..7]
ts = (c'' `xor` ks !! 4, d'' `xor` ks !! 5, a'' `xor` ks !! 6, b'' `xor` ks !! 7) ts = (c'' `xor` arrayRead32 ks 4, d'' `xor` arrayRead32 ks 5, a'' `xor` arrayRead32 ks 6, b'' `xor` arrayRead32 ks 7)
shuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32) shuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32)
shuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD') shuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD')
where ks' = take 4 $ drop (8 + 4 * ind) ks where [k0, k1, k2, k3] = fmap (\offset -> arrayRead32 ks $ (8 + 4 * ind) + offset) [0..3]
t2 = byteIndex s2 retB `xor` byteIndex s3 (shiftR retB 8) `xor` byteIndex s4 (shiftR retB 16) `xor` byteIndex s1 (shiftR retB 24) t2 = byteIndex s2 retB `xor` byteIndex s3 (shiftR retB 8) `xor` byteIndex s4 (shiftR retB 16) `xor` byteIndex s1 (shiftR retB 24)
t1 = (byteIndex s1 retA `xor` byteIndex s2 (shiftR retA 8) `xor` byteIndex s3 (shiftR retA 16) `xor` byteIndex s4 (shiftR retA 24)) + t2 t1 = (byteIndex s1 retA `xor` byteIndex s2 (shiftR retA 8) `xor` byteIndex s3 (shiftR retA 16) `xor` byteIndex s4 (shiftR retA 24)) + t2
retC' = rotateR (retC `xor` (t1 + head ks')) 1 retC' = rotateR (retC `xor` (t1 + k0)) 1
retD' = rotateL retD 1 `xor` (t1 + t2 + (ks' !! 1)) retD' = rotateL retD 1 `xor` (t1 + t2 + k1)
t2' = byteIndex s2 retD' `xor` byteIndex s3 (shiftR retD' 8) `xor` byteIndex s4 (shiftR retD' 16) `xor` byteIndex s1 (shiftR retD' 24) t2' = byteIndex s2 retD' `xor` byteIndex s3 (shiftR retD' 8) `xor` byteIndex s4 (shiftR retD' 16) `xor` byteIndex s1 (shiftR retD' 24)
t1' = (byteIndex s1 retC' `xor` byteIndex s2 (shiftR retC' 8) `xor` byteIndex s3 (shiftR retC' 16) `xor` byteIndex s4 (shiftR retC' 24)) + t2' t1' = (byteIndex s1 retC' `xor` byteIndex s2 (shiftR retC' 8) `xor` byteIndex s3 (shiftR retC' 16) `xor` byteIndex s4 (shiftR retC' 24)) + t2'
retA' = rotateR (retA `xor` (t1' + (ks' !! 2))) 1 retA' = rotateR (retA `xor` (t1' + k2)) 1
retB' = rotateL retB 1 `xor` (t1' + t2' + (ks' !! 3)) retB' = rotateL retB 1 `xor` (t1' + t2' + k3)
-- Unsafe, no bounds checking -- Unsafe, no bounds checking
byteIndex :: Integral a => [b] -> a -> b byteIndex :: Integral a => Array32 -> a -> Word32
byteIndex xs ind = xs !! fromIntegral byte byteIndex xs ind = arrayRead32 xs $ fromIntegral byte
where byte = fromIntegral ind :: Word8 where byte = fromIntegral ind :: Word8
-- | Decrypts the given ByteString using the given Key -- | Decrypts the given ByteString using the given Key
@ -96,21 +99,21 @@ decrypt cipher = mapBlocks (decryptBlock cipher)
decryptBlock :: ByteArray ba => Twofish -> ba -> ba decryptBlock :: ByteArray ba => Twofish -> ba -> ba
decryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ixs decryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ixs
where (a, b, c, d) = load32ls message where (a, b, c, d) = load32ls message
(a', b', c', d') = (c `xor` ks !! 6, d `xor` ks !! 7, a `xor` ks !! 4, b `xor` ks !! 5) (a', b', c', d') = (c `xor` arrayRead32 ks 6, d `xor` arrayRead32 ks 7, a `xor` arrayRead32 ks 4, b `xor` arrayRead32 ks 5)
(a'', b'', c'', d'') = foldl' unshuffle (a', b', c', d') [8, 7..1] (a'', b'', c'', d'') = foldl' unshuffle (a', b', c', d') [8, 7..1]
ixs = (a'' `xor` head ks, b'' `xor` ks !! 1, c'' `xor` ks !! 2, d'' `xor` ks !! 3) ixs = (a'' `xor` arrayRead32 ks 0, b'' `xor` arrayRead32 ks 1, c'' `xor` arrayRead32 ks 2, d'' `xor` arrayRead32 ks 3)
unshuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32) unshuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32)
unshuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD') unshuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD')
where ks' = take 4 $ drop (4 + 4 * ind) ks where [k0, k1, k2, k3] = fmap (\offset -> arrayRead32 ks $ (4 + 4 * ind) + offset) [0..3]
t2 = byteIndex s2 retD `xor` byteIndex s3 (shiftR retD 8) `xor` byteIndex s4 (shiftR retD 16) `xor` byteIndex s1 (shiftR retD 24) t2 = byteIndex s2 retD `xor` byteIndex s3 (shiftR retD 8) `xor` byteIndex s4 (shiftR retD 16) `xor` byteIndex s1 (shiftR retD 24)
t1 = (byteIndex s1 retC `xor` byteIndex s2 (shiftR retC 8) `xor` byteIndex s3 (shiftR retC 16) `xor` byteIndex s4 (shiftR retC 24)) + t2 t1 = (byteIndex s1 retC `xor` byteIndex s2 (shiftR retC 8) `xor` byteIndex s3 (shiftR retC 16) `xor` byteIndex s4 (shiftR retC 24)) + t2
retA' = rotateL retA 1 `xor` (t1 + (ks' !! 2)) retA' = rotateL retA 1 `xor` (t1 + k2)
retB' = rotateR (retB `xor` (t2 + t1 + (ks' !! 3))) 1 retB' = rotateR (retB `xor` (t2 + t1 + k3)) 1
t2' = byteIndex s2 retB' `xor` byteIndex s3 (shiftR retB' 8) `xor` byteIndex s4 (shiftR retB' 16) `xor` byteIndex s1 (shiftR retB' 24) t2' = byteIndex s2 retB' `xor` byteIndex s3 (shiftR retB' 8) `xor` byteIndex s4 (shiftR retB' 16) `xor` byteIndex s1 (shiftR retB' 24)
t1' = (byteIndex s1 retA' `xor` byteIndex s2 (shiftR retA' 8) `xor` byteIndex s3 (shiftR retA' 16) `xor` byteIndex s4 (shiftR retA' 24)) + t2' t1' = (byteIndex s1 retA' `xor` byteIndex s2 (shiftR retA' 8) `xor` byteIndex s3 (shiftR retA' 16) `xor` byteIndex s4 (shiftR retA' 24)) + t2'
retC' = rotateL retC 1 `xor` (t1' + head ks') retC' = rotateL retC 1 `xor` (t1' + k0)
retD' = rotateR (retD `xor` (t2' + t1' + (ks' !! 1))) 1 retD' = rotateR (retD `xor` (t2' + t1' + k1)) 1
sbox0 :: Int -> Word8 sbox0 :: Int -> Word8
sbox0 = arrayRead8 t sbox0 = arrayRead8 t
@ -190,9 +193,10 @@ sWords key = sWord
data Column = Zero | One | Two | Three deriving (Show, Eq, Enum, Bounded) data Column = Zero | One | Two | Three deriving (Show, Eq, Enum, Bounded)
-- Only implemented for 128-bit key (so far) -- Only implemented for 128-bit key (so far)
genSboxes :: [Word8] -> ([Word32], [Word32], [Word32], [Word32]) genSboxes :: [Word8] -> (Array32, Array32, Array32, Array32)
genSboxes ws = (b0, b1, b2, b3) genSboxes ws = (mkArray b0, mkArray b1, mkArray b2, mkArray b3)
where range = [0..255] where range = [0..255]
mkArray = array32 256
b0 = fmap mapper range b0 = fmap mapper range
where mapper :: Int -> Word32 where mapper :: Int -> Word32
mapper byte = mdsColumnMult ((sbox1 . fromIntegral) ((sbox0 . fromIntegral $ sbox0 byte `xor` head ws) `xor` ws !! 4)) Zero mapper byte = mdsColumnMult ((sbox1 . fromIntegral) ((sbox0 . fromIntegral $ sbox0 byte `xor` head ws) `xor` ws !! 4)) Zero