From 48770bf79f446ae088102edd6e0266dac2beb469 Mon Sep 17 00:00:00 2001 From: Baojun Wang Date: Thu, 1 Jun 2017 15:16:24 -0700 Subject: [PATCH] fix aes ccm decryption cbcmac mis-match --- Crypto/Cipher/AES/Primitive.hs | 35 ++++++++--------- cbits/cryptonite_aes.c | 71 ++++++++++++++++++++-------------- cbits/cryptonite_aes.h | 17 ++++---- 3 files changed, 67 insertions(+), 56 deletions(-) diff --git a/Crypto/Cipher/AES/Primitive.hs b/Crypto/Cipher/AES/Primitive.hs index 6c684b1..2e4f2ac 100644 --- a/Crypto/Cipher/AES/Primitive.hs +++ b/Crypto/Cipher/AES/Primitive.hs @@ -104,7 +104,7 @@ ocbMode aes = AEADModeImpl -- | Create an AES AEAD implementation for GCM ccmMode :: AES -> AEADModeImpl AESCCM ccmMode aes = AEADModeImpl - { aeadImplAppendHeader = ccmAppendAAD + { aeadImplAppendHeader = ccmAppendAAD aes , aeadImplEncrypt = ccmEncrypt aes , aeadImplDecrypt = ccmDecrypt aes , aeadImplFinalize = ccmFinish aes @@ -133,7 +133,7 @@ sizeOCB :: Int sizeOCB = 160 sizeCCM :: Int -sizeCCM = 544 +sizeCCM = 80 keyToPtr :: AES -> (Ptr AES -> IO a) -> IO a keyToPtr (AES b) f = withByteArray b (f . castPtr) @@ -179,9 +179,6 @@ withCCMKeyAndCopySt aes (AESCCM ccmSt) f = a <- withByteArray newSt $ \ccmStPtr -> f (castPtr ccmStPtr) aesPtr return (a, AESCCM newSt) -withNewCCMSt :: AESCCM -> (Ptr AESCCM -> IO ()) -> IO AESCCM -withNewCCMSt (AESCCM ccmSt) f = B.copy ccmSt (f . castPtr) >>= \sm2 -> return (AESCCM sm2) - -- | Initialize a new context with a key -- -- Key needs to be of length 16, 24 or 32 bytes. Any other values will return failure @@ -506,15 +503,10 @@ ccmInit ctx iv n m l = unsafeDoIO $ do -- -- needs to happen after initialization and before appending encryption/decryption data. {-# NOINLINE ccmAppendAAD #-} -ccmAppendAAD :: ByteArrayAccess aad => AESCCM -> aad -> AESCCM -ccmAppendAAD ccmSt input = unsafeDoIO doAppend - where doAppend = - withNewCCMSt ccmSt $ \ccmStPtr -> - withByteArray input $ \i -> - c_aes_ccm_aad ccmStPtr i (fromIntegral $ B.length input) - -doCTR :: (ByteArray ba, BlockCipher cipher) => cipher -> ba -> ba -> ba -doCTR ctx iv0 input = ctrCombine ctx (ivAdd (IV (B.convert iv0 :: B.Bytes)) 1) input +ccmAppendAAD :: ByteArrayAccess aad => AES -> AESCCM -> aad -> AESCCM +ccmAppendAAD ctx ccm input = unsafeDoIO $ snd <$> withCCMKeyAndCopySt ctx ccm doAppend + where doAppend ccmStPtr aesPtr = + withByteArray input $ \i -> c_aes_ccm_aad ccmStPtr aesPtr i (fromIntegral $ B.length input) -- | append data to encrypt and append to the CCM context -- @@ -522,10 +514,10 @@ doCTR ctx iv0 input = ctrCombine ctx (ivAdd (IV (B.convert iv0 :: B.Bytes)) 1) i -- needs to happen after AAD appending, or after initialization if no AAD data. {-# NOINLINE ccmEncrypt #-} ccmEncrypt :: ByteArray ba => AES -> AESCCM -> ba -> (ba, AESCCM) -ccmEncrypt ctx ccm input = unsafeDoIO $ (withCCMKeyAndCopySt ctx ccm cbcmacAndIv >>= \(iv0, cc) -> return (doCTR ctx iv0 input, cc)) +ccmEncrypt ctx ccm input = unsafeDoIO $ withCCMKeyAndCopySt ctx ccm cbcmacAndIv where len = B.length input cbcmacAndIv ccmStPtr aesPtr = - B.alloc 16 $ \o -> + B.alloc len $ \o -> withByteArray input $ \i -> c_aes_ccm_encrypt (castPtr o) ccmStPtr aesPtr i (fromIntegral len) @@ -535,7 +527,12 @@ ccmEncrypt ctx ccm input = unsafeDoIO $ (withCCMKeyAndCopySt ctx ccm cbcmacAndIv -- needs to happen after AAD appending, or after initialization if no AAD data. {-# NOINLINE ccmDecrypt #-} ccmDecrypt :: ByteArray ba => AES -> AESCCM -> ba -> (ba, AESCCM) -ccmDecrypt = ccmEncrypt +ccmDecrypt ctx ccm input = unsafeDoIO $ withCCMKeyAndCopySt ctx ccm cbcmacAndIv + where len = B.length input + cbcmacAndIv ccmStPtr aesPtr = + B.alloc len $ \o -> + withByteArray input $ \i -> + c_aes_ccm_decrypt (castPtr o) ccmStPtr aesPtr i (fromIntegral len) -- | Generate the Tag from CCM context {-# NOINLINE ccmFinish #-} @@ -606,10 +603,10 @@ foreign import ccall "cryptonite_aes.h cryptonite_aes_ocb_finish" c_aes_ocb_finish :: CString -> Ptr AESOCB -> Ptr AES -> IO () foreign import ccall "cryptonite_aes.h cryptonite_aes_ccm_init" - c_aes_ccm_init :: Ptr AESCCM -> Ptr AES -> Ptr Word8 -> CUInt -> CULong -> CInt -> CInt -> IO () + c_aes_ccm_init :: Ptr AESCCM -> Ptr AES -> Ptr Word8 -> CUInt -> CUInt -> CInt -> CInt -> IO () foreign import ccall "cryptonite_aes.h cryptonite_aes_ccm_aad" - c_aes_ccm_aad :: Ptr AESCCM -> CString -> CUInt -> IO () + c_aes_ccm_aad :: Ptr AESCCM -> Ptr AES -> CString -> CUInt -> IO () foreign import ccall "cryptonite_aes.h cryptonite_aes_ccm_encrypt" c_aes_ccm_encrypt :: CString -> Ptr AESCCM -> Ptr AES -> CString -> CUInt -> IO () diff --git a/cbits/cryptonite_aes.c b/cbits/cryptonite_aes.c index df6186d..1888566 100644 --- a/cbits/cryptonite_aes.c +++ b/cbits/cryptonite_aes.c @@ -30,6 +30,7 @@ #include #include +#include #include #include @@ -448,7 +449,7 @@ static void ccm_encode_b0(block128* output, aes_ccm* ccm, int has_adata) int last = 15; int m = ccm->length_M; int l = ccm->length_L; - uint64_t msg_len = ccm->length_input; + unsigned msg_len = ccm->length_input; block128_zero(output); block128_copy(output, &ccm->nonce); @@ -460,7 +461,7 @@ static void ccm_encode_b0(block128* output, aes_ccm* ccm, int has_adata) } /* encode adata length */ -static int ccm_encode_la(block128* output, uint64_t la) +static int ccm_encode_la(block128* output, unsigned la) { if (la < ( (1 << 16) - (1 << 8)) ) { output->b[0] = (la >> 8) & 0xff; @@ -474,18 +475,6 @@ static int ccm_encode_la(block128* output, uint64_t la) output->b[4] = (la >> 8) & 0xff; output->b[5] = la & 0xff; return 6; - } else { - output->b[0] = 0xff; - output->b[1] = 0xff; - output->b[2] = (la >> 56) & 0xff; - output->b[3] = (la >> 48) & 0xff; - output->b[4] = (la >> 40) & 0xff; - output->b[5] = (la >> 32) & 0xff; - output->b[6] = (la >> 24) & 0xff; - output->b[7] = (la >> 16) & 0xff; - output->b[8] = (la >> 8) & 0xff; - output->b[9] = la & 0xff; - return 10; } } @@ -508,7 +497,7 @@ static void ccm_cbcmac_add(aes_ccm* ccm, aes_key* key, block128* bi) } /* even though it is possible to support message size as large as 2^64, we support up to 2^32 only */ -void cryptonite_aes_ccm_init(aes_ccm *ccm, aes_key *key, uint8_t *nonce, uint32_t nonce_len, uint64_t input_size, int m, int l) +void cryptonite_aes_ccm_init(aes_ccm *ccm, aes_key *key, uint8_t *nonce, uint32_t nonce_len, uint32_t input_size, int m, int l) { memset(ccm, 0, sizeof(aes_ccm)); @@ -529,22 +518,19 @@ void cryptonite_aes_ccm_init(aes_ccm *ccm, aes_key *key, uint8_t *nonce, uint32_ ccm->length_input = input_size; memcpy(&ccm->nonce.b[1], nonce, 15 - l); - memcpy(&ccm->aad_key, key, sizeof(aes_key)); ccm_encode_b0(&ccm->b0, ccm, 1); /* assume aad is present */ - ccm_encode_ctr(&ccm->iv, ccm, 0); - cryptonite_aes_encrypt_block(&ccm->xi, key, &ccm->b0); } /* even though l(a) can be as large as 2^64, we only handle aad up to 2 ^ 32 for practical reasons. Also we don't support incremental aad add, because the 1st encoded adata has length information */ -void cryptonite_aes_ccm_aad(aes_ccm *ccm, uint8_t *input, uint32_t length) +void cryptonite_aes_ccm_aad(aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length) { block128 tmp; - aes_key* key = &ccm->aad_key; + assert (ccm->length_aad == 0); ccm->length_aad = length; int len_len; @@ -572,19 +558,17 @@ void cryptonite_aes_ccm_aad(aes_ccm *ccm, uint8_t *input, uint32_t length) block128_copy_bytes(&tmp, input, length); ccm_cbcmac_add(ccm, key, &tmp); } - - memset(&ccm->aad_key, 0, sizeof(aes_key)); + block128_copy(&ccm->header_cbcmac, &ccm->xi); } void cryptonite_aes_ccm_finish(uint8_t *tag, aes_ccm *ccm, aes_key *key) { block128 iv, s0; - block128 u; + block128_zero(&iv); ccm_encode_ctr(&iv, ccm, 0); cryptonite_aes_encrypt_block(&s0, key, &iv); - block128_vxor(&u, &ccm->xi, &s0); - memcpy(tag, u.b, ccm->length_M); + block128_vxor((block128*)tag, &ccm->xi, &s0); } static inline void ocb_block_double(block128 *d, block128 *s) @@ -922,18 +906,23 @@ static void ocb_generic_crypt(uint8_t *output, aes_ocb *ocb, aes_key *key, void cryptonite_aes_generic_ccm_encrypt(uint8_t *output, aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length) { - block128 tmp; + block128 tmp, ctr; /* when aad is absent, reset b0 block */ if (ccm->length_aad == 0) { ccm_encode_b0(&ccm->b0, ccm, 0); /* assume aad is present */ cryptonite_aes_encrypt_block(&ccm->xi, key, &ccm->b0); + block128_copy(&ccm->header_cbcmac, &ccm->xi); } + assert (length == ccm->length_input); if (length != ccm->length_input) { return; } + ccm_encode_ctr(&ctr, ccm, 1); + cryptonite_aes_encrypt_ctr(output, key, &ctr, input, length); + for (;length >= 16; input += 16, length -= 16) { block128_copy(&tmp, (block128*)input); ccm_cbcmac_add(ccm, key, &tmp); @@ -943,12 +932,38 @@ void cryptonite_aes_generic_ccm_encrypt(uint8_t *output, aes_ccm *ccm, aes_key * block128_copy_bytes(&tmp, input, length); ccm_cbcmac_add(ccm, key, &tmp); } - block128_copy((block128*)output, &ccm->iv); } void cryptonite_aes_generic_ccm_decrypt(uint8_t *output, aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length) { - cryptonite_aes_generic_ccm_encrypt(output, ccm, key, input, length); + block128 tmp, ctr; + + /* when aad is absent, reset b0 block */ + if (ccm->length_aad == 0) { + ccm_encode_b0(&ccm->b0, ccm, 0); /* assume aad is present */ + cryptonite_aes_encrypt_block(&ccm->xi, key, &ccm->b0); + block128_copy(&ccm->header_cbcmac, &ccm->xi); + } + + assert (length == ccm->length_input); + if (length != ccm->length_input) { + return; + } + + ccm_encode_ctr(&ctr, ccm, 1); + cryptonite_aes_encrypt_ctr(output, key, &ctr, input, length); + block128_copy(&ccm->xi, &ccm->header_cbcmac); + input = output; + + for (;length >= 16; input += 16, length -= 16) { + block128_copy(&tmp, (block128*)input); + ccm_cbcmac_add(ccm, key, &tmp); + } + if (length > 0) { + block128_zero(&tmp); + block128_copy_bytes(&tmp, input, length); + ccm_cbcmac_add(ccm, key, &tmp); + } } void cryptonite_aes_generic_ocb_encrypt(uint8_t *output, aes_ocb *ocb, aes_key *key, uint8_t *input, uint32_t length) diff --git a/cbits/cryptonite_aes.h b/cbits/cryptonite_aes.h index 70fbd47..0838a03 100644 --- a/cbits/cryptonite_aes.h +++ b/cbits/cryptonite_aes.h @@ -55,15 +55,14 @@ typedef struct { uint64_t length_input; } aes_gcm; -/* size = 544 */ +/* size = 80 */ typedef struct { - aes_block iv; /* iv with counter = 0 block */ - aes_block xi; /* X_i: cbc mac */ - aes_block b0; /* block b0 */ + aes_block xi; + aes_block header_cbcmac; + aes_block b0; aes_block nonce; - aes_key aad_key; - uint64_t length_aad; - uint64_t length_input; + unsigned length_aad; + unsigned length_input; int length_M; int length_L; } aes_ccm; @@ -110,8 +109,8 @@ void cryptonite_aes_ocb_encrypt(uint8_t *output, aes_ocb *ocb, aes_key *key, uin void cryptonite_aes_ocb_decrypt(uint8_t *output, aes_ocb *ocb, aes_key *key, uint8_t *input, uint32_t length); void cryptonite_aes_ocb_finish(uint8_t *tag, aes_ocb *ocb, aes_key *key); -void cryptonite_aes_ccm_init(aes_ccm *ccm, aes_key *key, uint8_t *nonce, uint32_t len, uint64_t msg_size, int m, int l); -void cryptonite_aes_ccm_aad(aes_ccm *ccm, uint8_t *input, uint32_t length); +void cryptonite_aes_ccm_init(aes_ccm *ccm, aes_key *key, uint8_t *nonce, uint32_t len, uint32_t msg_size, int m, int l); +void cryptonite_aes_ccm_aad(aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length); void cryptonite_aes_ccm_encrypt(uint8_t *output, aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length); void cryptonite_aes_ccm_decrypt(uint8_t *output, aes_ccm *ccm, aes_key *key, uint8_t *input, uint32_t length); void cryptonite_aes_ccm_finish(uint8_t *tag, aes_ccm *ccm, aes_key *key);