From 7eedbaa112a91ec82d5657594e5f4bce786f3444 Mon Sep 17 00:00:00 2001 From: Sam Protas Date: Sun, 2 Apr 2017 18:34:10 -0400 Subject: [PATCH] Initial implementaiton with passing tests --- Crypto/Cipher/Twofish.hs | 18 ++ Crypto/Cipher/Twofish/Primitive.hs | 260 +++++++++++++++++++++++++++++ cryptonite.cabal | 2 + tests/KAT_Twofish.hs | 19 +++ tests/Tests.hs | 2 + 5 files changed, 301 insertions(+) create mode 100644 Crypto/Cipher/Twofish.hs create mode 100644 Crypto/Cipher/Twofish/Primitive.hs create mode 100644 tests/KAT_Twofish.hs diff --git a/Crypto/Cipher/Twofish.hs b/Crypto/Cipher/Twofish.hs new file mode 100644 index 0000000..df61a5a --- /dev/null +++ b/Crypto/Cipher/Twofish.hs @@ -0,0 +1,18 @@ +module Crypto.Cipher.Twofish + ( Twofish128 (..) + ) where + +import Crypto.Cipher.Twofish.Primitive +import Crypto.Cipher.Types + +newtype Twofish128 = Twofish128 Twofish + +instance Cipher Twofish128 where + cipherName _ = "Twofish128" + cipherKeySize _ = KeySizeFixed 16 + cipherInit k = Twofish128 `fmap` initTwofish k + +instance BlockCipher Twofish128 where + blockSize _ = 16 + ecbEncrypt (Twofish128 key) = encrypt key + ecbDecrypt (Twofish128 key) = decrypt key \ No newline at end of file diff --git a/Crypto/Cipher/Twofish/Primitive.hs b/Crypto/Cipher/Twofish/Primitive.hs new file mode 100644 index 0000000..949b629 --- /dev/null +++ b/Crypto/Cipher/Twofish/Primitive.hs @@ -0,0 +1,260 @@ +{-# LANGUAGE MagicHash #-} +module Crypto.Cipher.Twofish.Primitive + ( Twofish (..) + , initTwofish + , encrypt + , decrypt + ) where + +import Crypto.Error +import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes) +import qualified Crypto.Internal.ByteArray as B +import Crypto.Internal.WordArray +import Crypto.Internal.Words +import Data.Word +import Data.Int +import Data.Bits +import Data.List +import Control.Monad + +-- Based on the Golang referance implementation +-- https://github.com/golang/crypto/blob/master/twofish/twofish.go + + +-- BlockSize is the constant block size of Twofish. +blockSize :: Int +blockSize = 16 + +mdsPolynomial, rsPolynomial :: Word32 +mdsPolynomial = 0x169 -- x^8 + x^6 + x^5 + x^3 + 1, see [TWOFISH] 4.2 +rsPolynomial = 0x14d -- x^8 + x^6 + x^3 + x^2 + 1, see [TWOFISH] 4.3 + +data Twofish = Twofish { s :: ([Word32], [Word32], [Word32], [Word32]) + , k :: [Word32] } + +-- CONFIRMED +-- | Initialize a 128-bit key +-- +-- Return the initialized key or a error message if the given +-- keyseed was not 16-bytes in length. +initTwofish :: ByteArray key + => key -- ^ The key to create the camellia context + -> CryptoFailable Twofish +initTwofish key + | B.length key /= blockSize = CryptoFailed CryptoError_KeySizeInvalid + | otherwise = CryptoPassed Twofish { k = generatedK, s = generatedS } + where generatedK = genK key + generatedS = genSboxes $ sWords key + + +mapBlocks :: ByteArray ba => (ba -> ba) -> ba -> ba +mapBlocks operation input + | B.null rest = blockOutput + | otherwise = blockOutput `B.append` mapBlocks operation rest + where (block, rest) = B.splitAt blockSize input + blockOutput = operation block + +-- | Encrypts the given ByteString using the given Key +encrypt :: ByteArray ba + => Twofish -- ^ The key to use + -> ba -- ^ The data to encrypt + -> ba +encrypt cipher = mapBlocks (encryptBlock cipher) + +encryptBlock :: ByteArray ba => Twofish -> ba -> ba +encryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ts + where (a, b, c, d) = load32ls message + [a', b', c', d'] = zipWith xor [a, b, c, d] ks + (a'', b'', c'', d'') = foldl' shuffle (a', b', c', d') [0..7] + ts = (c'' `xor` ks !! 4, d'' `xor` ks !! 5, a'' `xor` ks !! 6, b'' `xor` ks !! 7) + + shuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32) + shuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD') + where ks' = take 4 $ drop (8 + 4 * ind) ks + t2 = byteIndex s2 retB `xor` byteIndex s3 (shiftR retB 8) `xor` byteIndex s4 (shiftR retB 16) `xor` byteIndex s1 (shiftR retB 24) + t1 = (byteIndex s1 retA `xor` byteIndex s2 (shiftR retA 8) `xor` byteIndex s3 (shiftR retA 16) `xor` byteIndex s4 (shiftR retA 24)) + t2 + retC' = rotateR (retC `xor` (t1 + head ks')) 1 + retD' = rotateL retD 1 `xor` (t1 + t2 + (ks' !! 1)) + t2' = byteIndex s2 retD' `xor` byteIndex s3 (shiftR retD' 8) `xor` byteIndex s4 (shiftR retD' 16) `xor` byteIndex s1 (shiftR retD' 24) + t1' = (byteIndex s1 retC' `xor` byteIndex s2 (shiftR retC' 8) `xor` byteIndex s3 (shiftR retC' 16) `xor` byteIndex s4 (shiftR retC' 24)) + t2' + retA' = rotateR (retA `xor` (t1' + (ks' !! 2))) 1 + retB' = rotateL retB 1 `xor` (t1' + t2' + (ks' !! 3)) + +-- Unsafe, no bounds checking +byteIndex :: Integral a => [b] -> a -> b +byteIndex xs ind = xs !! fromIntegral byte + where byte = fromIntegral ind :: Word8 + +-- | Decrypts the given ByteString using the given Key +decrypt :: ByteArray ba + => Twofish -- ^ The key to use + -> ba -- ^ The data to decrypt + -> ba +decrypt cipher = mapBlocks (decryptBlock cipher) + +{- decryption for 128 bits blocks -} +decryptBlock :: ByteArray ba => Twofish -> ba -> ba +decryptBlock Twofish { s = (s1, s2, s3, s4), k = ks } message = store32ls ixs + where (a, b, c, d) = load32ls message + (a', b', c', d') = (c `xor` ks !! 6, d `xor` ks !! 7, a `xor` ks !! 4, b `xor` ks !! 5) + (a'', b'', c'', d'') = foldl' unshuffle (a', b', c', d') [8, 7..1] + ixs = (a'' `xor` head ks, b'' `xor` ks !! 1, c'' `xor` ks !! 2, d'' `xor` ks !! 3) + + unshuffle :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32) + unshuffle (retA, retB, retC, retD) ind = (retA', retB', retC', retD') + where ks' = take 4 $ drop (4 + 4 * ind) ks + t2 = byteIndex s2 retD `xor` byteIndex s3 (shiftR retD 8) `xor` byteIndex s4 (shiftR retD 16) `xor` byteIndex s1 (shiftR retD 24) + t1 = (byteIndex s1 retC `xor` byteIndex s2 (shiftR retC 8) `xor` byteIndex s3 (shiftR retC 16) `xor` byteIndex s4 (shiftR retC 24)) + t2 + retA' = rotateL retA 1 `xor` (t1 + (ks' !! 2)) + retB' = rotateR (retB `xor` (t2 + t1 + (ks' !! 3))) 1 + t2' = byteIndex s2 retB' `xor` byteIndex s3 (shiftR retB' 8) `xor` byteIndex s4 (shiftR retB' 16) `xor` byteIndex s1 (shiftR retB' 24) + t1' = (byteIndex s1 retA' `xor` byteIndex s2 (shiftR retA' 8) `xor` byteIndex s3 (shiftR retA' 16) `xor` byteIndex s4 (shiftR retA' 24)) + t2' + retC' = rotateL retC 1 `xor` (t1' + head ks') + retD' = rotateR (retD `xor` (t2' + t1' + (ks' !! 1))) 1 + +sbox0 :: Int -> Word8 +sbox0 = arrayRead8 t + where t = array8 + "\xa9\x67\xb3\xe8\x04\xfd\xa3\x76\x9a\x92\x80\x78\xe4\xdd\xd1\x38\ + \\x0d\xc6\x35\x98\x18\xf7\xec\x6c\x43\x75\x37\x26\xfa\x13\x94\x48\ + \\xf2\xd0\x8b\x30\x84\x54\xdf\x23\x19\x5b\x3d\x59\xf3\xae\xa2\x82\ + \\x63\x01\x83\x2e\xd9\x51\x9b\x7c\xa6\xeb\xa5\xbe\x16\x0c\xe3\x61\ + \\xc0\x8c\x3a\xf5\x73\x2c\x25\x0b\xbb\x4e\x89\x6b\x53\x6a\xb4\xf1\ + \\xe1\xe6\xbd\x45\xe2\xf4\xb6\x66\xcc\x95\x03\x56\xd4\x1c\x1e\xd7\ + \\xfb\xc3\x8e\xb5\xe9\xcf\xbf\xba\xea\x77\x39\xaf\x33\xc9\x62\x71\ + \\x81\x79\x09\xad\x24\xcd\xf9\xd8\xe5\xc5\xb9\x4d\x44\x08\x86\xe7\ + \\xa1\x1d\xaa\xed\x06\x70\xb2\xd2\x41\x7b\xa0\x11\x31\xc2\x27\x90\ + \\x20\xf6\x60\xff\x96\x5c\xb1\xab\x9e\x9c\x52\x1b\x5f\x93\x0a\xef\ + \\x91\x85\x49\xee\x2d\x4f\x8f\x3b\x47\x87\x6d\x46\xd6\x3e\x69\x64\ + \\x2a\xce\xcb\x2f\xfc\x97\x05\x7a\xac\x7f\xd5\x1a\x4b\x0e\xa7\x5a\ + \\x28\x14\x3f\x29\x88\x3c\x4c\x02\xb8\xda\xb0\x17\x55\x1f\x8a\x7d\ + \\x57\xc7\x8d\x74\xb7\xc4\x9f\x72\x7e\x15\x22\x12\x58\x07\x99\x34\ + \\x6e\x50\xde\x68\x65\xbc\xdb\xf8\xc8\xa8\x2b\x40\xdc\xfe\x32\xa4\ + \\xca\x10\x21\xf0\xd3\x5d\x0f\x00\x6f\x9d\x36\x42\x4a\x5e\xc1\xe0"# + +sbox1 :: Int -> Word8 +sbox1 = arrayRead8 t + where t = array8 + "\x75\xf3\xc6\xf4\xdb\x7b\xfb\xc8\x4a\xd3\xe6\x6b\x45\x7d\xe8\x4b\ + \\xd6\x32\xd8\xfd\x37\x71\xf1\xe1\x30\x0f\xf8\x1b\x87\xfa\x06\x3f\ + \\x5e\xba\xae\x5b\x8a\x00\xbc\x9d\x6d\xc1\xb1\x0e\x80\x5d\xd2\xd5\ + \\xa0\x84\x07\x14\xb5\x90\x2c\xa3\xb2\x73\x4c\x54\x92\x74\x36\x51\ + \\x38\xb0\xbd\x5a\xfc\x60\x62\x96\x6c\x42\xf7\x10\x7c\x28\x27\x8c\ + \\x13\x95\x9c\xc7\x24\x46\x3b\x70\xca\xe3\x85\xcb\x11\xd0\x93\xb8\ + \\xa6\x83\x20\xff\x9f\x77\xc3\xcc\x03\x6f\x08\xbf\x40\xe7\x2b\xe2\ + \\x79\x0c\xaa\x82\x41\x3a\xea\xb9\xe4\x9a\xa4\x97\x7e\xda\x7a\x17\ + \\x66\x94\xa1\x1d\x3d\xf0\xde\xb3\x0b\x72\xa7\x1c\xef\xd1\x53\x3e\ + \\x8f\x33\x26\x5f\xec\x76\x2a\x49\x81\x88\xee\x21\xc4\x1a\xeb\xd9\ + \\xc5\x39\x99\xcd\xad\x31\x8b\x01\x18\x23\xdd\x1f\x4e\x2d\xf9\x48\ + \\x4f\xf2\x65\x8e\x78\x5c\x58\x19\x8d\xe5\x98\x57\x67\x7f\x05\x64\ + \\xaf\x63\xb6\xfe\xf5\xb7\x3c\xa5\xce\xe9\x68\x44\xe0\x4d\x43\x69\ + \\x29\x2e\xac\x15\x59\xa8\x0a\x9e\x6e\x47\xdf\x34\x35\x6a\xcf\xdc\ + \\x22\xc9\xc0\x9b\x89\xd4\xed\xab\x12\xa2\x0d\x52\xbb\x02\x2f\xa9\ + \\xd7\x61\x1e\xb4\x50\x04\xf6\xc2\x16\x25\x86\x56\x55\x09\xbe\x91"# + +rs :: [[Word8]] +rs = [ [0x01, 0xA4, 0x55, 0x87, 0x5A, 0x58, 0xDB, 0x9E] + , [0xA4, 0x56, 0x82, 0xF3, 0x1E, 0xC6, 0x68, 0xE5] + , [0x02, 0xA1, 0xFC, 0xC1, 0x47, 0xAE, 0x3D, 0x19] + , [0xA4, 0x55, 0x87, 0x5A, 0x58, 0xDB, 0x9E, 0x03] ] + + + +load32ls :: ByteArray ba => ba -> (Word32, Word32, Word32, Word32) +load32ls message = (intify q1, intify q2, intify q3, intify q4) + where (half1, half2) = B.splitAt 8 message + (q1, q2) = B.splitAt 4 half1 + (q3, q4) = B.splitAt 4 half2 + + intify :: ByteArray ba => ba -> Word32 + intify bytes = foldl' (\int (word, ind) -> int .|. shiftL (fromIntegral word) (ind * 8) ) 0 (zip (B.unpack bytes) [0..]) + +store32ls :: ByteArray ba => (Word32, Word32, Word32, Word32) -> ba +store32ls (a, b, c, d) = B.pack $ concatMap splitWordl [a, b, c, d] + where splitWordl :: Word32 -> [Word8] + splitWordl w = fmap (\ind -> fromIntegral $ shiftR w (8 * ind)) [0..3] + + +-- Create S words +sWords :: ByteArray ba => ba -> [Word8] +sWords key = sWord + where word64Count = B.length key `div` 2 + sWord = concatMap (\wordIndex -> + map (\rsRow -> + foldl' (\acc (rsVal, colIndex) -> + acc `xor` gfMult rsPolynomial (B.index key $ 8 * wordIndex + colIndex) rsVal + ) 0 (zip rsRow [0..]) + ) rs + ) [0..word64Count - 1] + +data Column = Zero | One | Two | Three deriving (Show, Eq, Enum, Bounded) + +-- Only implemented for 128-bit key (so far) +genSboxes :: [Word8] -> ([Word32], [Word32], [Word32], [Word32]) +genSboxes ws = (b0, b1, b2, b3) + where range = [0..255] + b0 = fmap mapper range + where mapper :: Int -> Word32 + mapper byte = mdsColumnMult ((sbox1 . fromIntegral) ((sbox0 . fromIntegral $ sbox0 byte `xor` head ws) `xor` ws !! 4)) Zero + b1 = fmap mapper range + where mapper byte = mdsColumnMult ((sbox0 . fromIntegral) ((sbox0 . fromIntegral $ sbox1 byte `xor` ws !! 1) `xor` ws !! 5)) One + b2 = fmap mapper range + where mapper byte = mdsColumnMult ((sbox1 . fromIntegral) ((sbox1 . fromIntegral $ sbox0 byte `xor` ws !! 2) `xor` ws !! 6)) Two + b3 = fmap mapper range + where mapper byte = mdsColumnMult ((sbox0 . fromIntegral) ((sbox1 . fromIntegral $ sbox1 byte `xor` ws !! 3) `xor` ws !! 7)) Three + +genK :: (ByteArray ba) => ba -> [Word32] +genK key = concatMap (tupToList . makeTuple) [0..19] + where makeTuple :: Word8 -> (Word32, Word32) + makeTuple idx = (a + b', rotateL (2 * b' + a) 9) + where tmp1 = replicate 4 $ 2 * idx + tmp2 = fmap (+1) tmp1 + a = h (B.pack tmp1 :: Bytes) key 0 + b = h (B.pack tmp2 :: Bytes) key 1 + b' = rotateL b 8 + + tupToList :: (a, a) -> [a] + tupToList (a, b) = [a, b] + + +-- ONLY implemented for 128-bit key (so far) +h :: (Show ba1, ByteArray ba1, ByteArray ba2) => ba1 -> ba2 -> Int -> Word32 +h input key offset = foldl' xorMdsColMult 0 $ zip [y0', y1', y2', y3'] $ enumFrom Zero + where [y0, y1, y2, y3] = B.unpack $ B.take 4 input + y0' = sbox1 . fromIntegral $ (sbox0 . fromIntegral $ (sbox0 (fromIntegral y0) `xor` B.index key (4 * (2 + offset) + 0))) `xor` B.index key (4 * (0 + offset) + 0) :: Word8 + y1' = sbox0 . fromIntegral $ (sbox0 . fromIntegral $ (sbox1 (fromIntegral y1) `xor` B.index key (4 * (2 + offset) + 1))) `xor` B.index key (4 * (0 + offset) + 1) + y2' = sbox1 . fromIntegral $ (sbox1 . fromIntegral $ (sbox0 (fromIntegral y2) `xor` B.index key (4 * (2 + offset) + 2))) `xor` B.index key (4 * (0 + offset) + 2) + y3' = sbox0 . fromIntegral $ (sbox1 . fromIntegral $ (sbox1 (fromIntegral y3) `xor` B.index key (4 * (2 + offset) + 3))) `xor` B.index key (4 * (0 + offset) + 3) + + xorMdsColMult :: Word32 -> (Word8, Column) -> Word32 + xorMdsColMult acc wordAndIndex = acc `xor` uncurry mdsColumnMult wordAndIndex + +mdsColumnMult :: Word8 -> Column -> Word32 +mdsColumnMult byte col = + case col of Zero -> input .|. rotateL mul5B 8 .|. rotateL mulEF 16 .|. rotateL mulEF 24 + One -> mulEF .|. rotateL mulEF 8 .|. rotateL mul5B 16 .|. rotateL input 24 + Two -> mul5B .|. rotateL mulEF 8 .|. rotateL input 16 .|. rotateL mulEF 24 + Three -> mul5B .|. rotateL input 8 .|. rotateL mulEF 16 .|. rotateL mul5B 24 + where input = fromIntegral byte + mul5B = fromIntegral $ gfMult mdsPolynomial byte 0x5B + mulEF = fromIntegral $ gfMult mdsPolynomial byte 0xEF + +tupInd :: (Bits b) => b -> (a, a) -> a +tupInd b + | testBit b 0 = snd + | otherwise = fst + +gfMult :: Word32 -> Word8 -> Word8 -> Word8 +gfMult p a b = fromIntegral $ run a b' p' result 0 + where b' = (0, fromIntegral b) + p' = (0, p) + result = 0 + + run :: Word8 -> (Word32, Word32) -> (Word32, Word32) -> Word32 -> Int -> Word32 + run a' b'' p'' result' count = + if count == 7 + then result'' + else run a'' b''' p'' result'' (count + 1) + where result'' = result' `xor` tupInd (a' .&. 1) b'' + a'' = shiftR a' 1 + b''' = (fst b'', tupInd (shiftR (snd b'') 7) p'' `xor` shiftL (snd b'') 1) diff --git a/cryptonite.cabal b/cryptonite.cabal index 9569a18..d9ed097 100644 --- a/cryptonite.cabal +++ b/cryptonite.cabal @@ -108,6 +108,7 @@ Library Crypto.Cipher.RC4 Crypto.Cipher.Salsa Crypto.Cipher.TripleDES + Crypto.Cipher.Twofish Crypto.Cipher.Types Crypto.Cipher.XSalsa Crypto.ConstructHash.MiyaguchiPreneel @@ -165,6 +166,7 @@ Library Crypto.Cipher.Blowfish.Primitive Crypto.Cipher.Camellia.Primitive Crypto.Cipher.DES.Primitive + Crypto.Cipher.Twofish.Primitive Crypto.Cipher.Types.AEAD Crypto.Cipher.Types.Base Crypto.Cipher.Types.Block diff --git a/tests/KAT_Twofish.hs b/tests/KAT_Twofish.hs new file mode 100644 index 0000000..a37b688 --- /dev/null +++ b/tests/KAT_Twofish.hs @@ -0,0 +1,19 @@ +module KAT_Twofish (tests) where + +import Imports +import BlockCipher + +import qualified Data.ByteString as B +import Crypto.Cipher.Twofish + + +vectors_twofish128 = + [ KAT_ECB (B.replicate 16 0x00) (B.replicate 16 0x00) (B.pack [0x9F,0x58,0x9F,0x5C,0xF6,0x12,0x2C,0x32,0xB6,0xBF,0xEC,0x2F,0x2A,0xE8,0xC3,0x5A]) + , KAT_ECB (B.pack [0x9F,0x58,0x9F,0x5C,0xF6,0x12,0x2C,0x32,0xB6,0xBF,0xEC,0x2F,0x2A,0xE8,0xC3,0x5A]) + (B.pack [0xD4, 0x91, 0xDB, 0x16, 0xE7, 0xB1, 0xC3, 0x9E, 0x86, 0xCB, 0x08, 0x6B, 0x78, 0x9F, 0x54, 0x19]) + (B.pack [0x01, 0x9F, 0x98, 0x09, 0xDE, 0x17, 0x11, 0x85, 0x8F, 0xAA, 0xC3, 0xA3, 0xBA, 0x20, 0xFB, 0xC3]) + ] + +kats128 = defaultKATs { kat_ECB = vectors_twofish128 } + +tests = testBlockCipher kats128 (undefined :: Twofish128) diff --git a/tests/Tests.hs b/tests/Tests.hs index 68d79f4..b6ecf3a 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -31,6 +31,7 @@ import qualified KAT_Camellia import qualified KAT_DES import qualified KAT_RC4 import qualified KAT_TripleDES +import qualified KAT_Twofish -- misc -------------------------------- import qualified KAT_AFIS import qualified Padding @@ -66,6 +67,7 @@ tests = testGroup "cryptonite" , KAT_Camellia.tests , KAT_DES.tests , KAT_TripleDES.tests + , KAT_Twofish.tests ] , testGroup "stream-cipher" [ KAT_RC4.tests