From c30f1424352769e19021e5540eb9b851f34ee967 Mon Sep 17 00:00:00 2001 From: Justin Smith Date: Tue, 7 May 2024 15:54:45 -0400 Subject: [PATCH] Stream cipher API --- aws-lc-rs/src/cipher.rs | 96 +++++- aws-lc-rs/src/cipher/streaming.rs | 527 ++++++++++++++++++++++++++++++ aws-lc-rs/src/ptr.rs | 6 +- 3 files changed, 626 insertions(+), 3 deletions(-) create mode 100644 aws-lc-rs/src/cipher/streaming.rs diff --git a/aws-lc-rs/src/cipher.rs b/aws-lc-rs/src/cipher.rs index 1eb4e6e2f3b..008edfd43e4 100644 --- a/aws-lc-rs/src/cipher.rs +++ b/aws-lc-rs/src/cipher.rs @@ -83,6 +83,53 @@ //! # } //! ``` //! +//! ### AES-128 CBC Streaming Cipher +//! +//! ```rust +//! # use std::error::Error; +//! # +//! # fn main() -> Result<(), Box> { +//! use aws_lc_rs::cipher::{StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey, AES_128}; +//! +//! let original_message = "This is a secret message!".as_bytes(); +//! +//! let key_bytes: &[u8] = &[ +//! 0xff, 0x0b, 0xe5, 0x84, 0x64, 0x0b, 0x00, 0xc8, 0x90, 0x7a, 0x4b, 0xbf, 0x82, 0x7c, 0xb6, +//! 0xd1, +//! ]; +//! +//! // Encrypt +//! let mut ciphertext_buffer = vec![0u8; original_message.len() + AES_128.block_len()]; +//! let ciphertext_slice = ciphertext_buffer.as_mut_slice(); +//! +//! let key = UnboundCipherKey::new(&AES_128, key_bytes)?; +//! let mut encrypting_key = StreamingEncryptingKey::cbc_pkcs7(key)?; +//! let written_slice = encrypting_key.update(original_message, ciphertext_slice)?; +//! let written_len = written_slice.len(); +//! let remaining_slice = &mut ciphertext_slice[written_len..]; +//! let (context, written_slice) = encrypting_key.finish(remaining_slice)?; +//! let ciphertext_len = written_len + written_slice.len(); +//! let ciphertext = &ciphertext_slice[0..ciphertext_len]; +//! +//! // Decrypt +//! let mut plaintext_buffer = vec![0u8; ciphertext_len + AES_128.block_len()]; +//! let plaintext_slice = plaintext_buffer.as_mut_slice(); +//! +//! let key = UnboundCipherKey::new(&AES_128, key_bytes)?; +//! let mut decrypting_key = StreamingDecryptingKey::cbc_pkcs7(key, context)?; +//! let written_slice = decrypting_key.update(ciphertext, plaintext_slice)?; +//! let written_len = written_slice.len(); +//! let remaining_slice = &mut plaintext_slice[written_len..]; +//! let written_slice = decrypting_key.finish(remaining_slice)?; +//! let plaintext_len = written_len + written_slice.len(); +//! let plaintext = &plaintext_slice[0..plaintext_len]; +//! +//! assert_eq!(original_message, plaintext); +//! # +//! # Ok(()) +//! # } +//! ``` +//! //! ## Constructing a `DecryptionContext` for decryption. //! //! ```rust @@ -142,8 +189,10 @@ pub(crate) mod block; pub(crate) mod chacha; pub(crate) mod key; mod padded; +mod streaming; pub use padded::{PaddedBlockDecryptingKey, PaddedBlockEncryptingKey}; +pub use streaming::{StreamingDecryptingKey, StreamingEncryptingKey}; use crate::buffer::Buffer; use crate::error::Unspecified; @@ -280,7 +329,9 @@ impl Algorithm { &self.id } - const fn block_len(&self) -> usize { + /// The block length of this cipher algorithm. + #[must_use] + pub const fn block_len(&self) -> usize { self.block_len } @@ -905,4 +956,47 @@ mod tests { "eca7285d19f3c20e295378460e8729", "b5098e5e788de6ac2f2098eb2fc6f8" ); + + #[test] + fn streaming_cipher() { + use crate::cipher::{ + StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey, AES_128, + }; + + let original_message = "This is a secret message!".as_bytes(); + + let key_bytes: &[u8] = &[ + 0xff, 0x0b, 0xe5, 0x84, 0x64, 0x0b, 0x00, 0xc8, 0x90, 0x7a, 0x4b, 0xbf, 0x82, 0x7c, + 0xb6, 0xd1, + ]; + + // Encrypt + let mut ciphertext_buffer = vec![0u8; original_message.len() + AES_128.block_len()]; + let ciphertext_slice = ciphertext_buffer.as_mut_slice(); + + let key = UnboundCipherKey::new(&AES_128, key_bytes).unwrap(); + let mut encrypting_key = StreamingEncryptingKey::cbc_pkcs7(key).unwrap(); + let written_slice = encrypting_key + .update(original_message, ciphertext_slice) + .unwrap(); + let written_len = written_slice.len(); + let remaining_slice = &mut ciphertext_slice[written_len..]; + let (context, written_slice) = encrypting_key.finish(remaining_slice).unwrap(); + let ciphertext_len = written_len + written_slice.len(); + let ciphertext = &ciphertext_slice[0..ciphertext_len]; + + // Decrypt + let mut plaintext_buffer = vec![0u8; ciphertext_len + AES_128.block_len()]; + let plaintext_slice = plaintext_buffer.as_mut_slice(); + let key = UnboundCipherKey::new(&AES_128, key_bytes).unwrap(); + let mut decrypting_key = StreamingDecryptingKey::cbc_pkcs7(key, context).unwrap(); + let written_slice = decrypting_key.update(ciphertext, plaintext_slice).unwrap(); + let written_len = written_slice.len(); + let remaining_slice = &mut plaintext_slice[written_len..]; + let written_slice = decrypting_key.finish(remaining_slice).unwrap(); + let plaintext_len = written_len + written_slice.len(); + let plaintext = &plaintext_slice[0..plaintext_len]; + + assert_eq!(original_message, plaintext); + } } diff --git a/aws-lc-rs/src/cipher/streaming.rs b/aws-lc-rs/src/cipher/streaming.rs new file mode 100644 index 00000000000..ef5b5d0f2fc --- /dev/null +++ b/aws-lc-rs/src/cipher/streaming.rs @@ -0,0 +1,527 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC + +use crate::cipher::{ + Algorithm, DecryptionContext, EncryptionContext, OperatingMode, UnboundCipherKey, +}; +use crate::error::Unspecified; +use crate::ptr::{LcPtr, Pointer}; +use aws_lc::{ + EVP_CIPHER_CTX_new, EVP_CIPHER_iv_length, EVP_CIPHER_key_length, EVP_DecryptFinal_ex, + EVP_DecryptInit_ex, EVP_DecryptUpdate, EVP_EncryptFinal_ex, EVP_EncryptInit_ex, + EVP_EncryptUpdate, EVP_CIPHER_CTX, +}; +use std::ptr::null_mut; + +/// A cipher encryption key for streaming encryption operations. +pub struct StreamingEncryptingKey { + algorithm: &'static Algorithm, + mode: OperatingMode, + cipher_ctx: LcPtr, + context: EncryptionContext, +} + +impl StreamingEncryptingKey { + #[allow(clippy::needless_pass_by_value)] + fn new( + key: UnboundCipherKey, + mode: OperatingMode, + context: EncryptionContext, + ) -> Result { + let algorithm = key.algorithm(); + let cipher_ctx = LcPtr::new(unsafe { EVP_CIPHER_CTX_new() })?; + let cipher = mode.evp_cipher(key.algorithm); + let key_bytes = key.key_bytes.as_ref(); + debug_assert_eq!( + key_bytes.len(), + ::try_from(unsafe { EVP_CIPHER_key_length(*cipher) }).unwrap() + ); + let iv = <&[u8]>::try_from(&context)?; + debug_assert_eq!( + iv.len(), + ::try_from(unsafe { EVP_CIPHER_iv_length(*cipher) }).unwrap() + ); + + if 1 != unsafe { + EVP_EncryptInit_ex( + cipher_ctx.as_mut_ptr(), + *cipher, + null_mut(), + key_bytes.as_ptr(), + iv.as_ptr(), + ) + } { + return Err(Unspecified); + } + + Ok(Self { + algorithm, + mode, + cipher_ctx, + context, + }) + } + + /// Encrypt the input and return the output. + /// # Errors + /// Returns an error if the output buffer is too small. + pub fn update<'a>( + &mut self, + input: &[u8], + output: &'a mut [u8], + ) -> Result<&'a [u8], Unspecified> { + if output.len() < (input.len() + self.algorithm.block_len) { + return Err(Unspecified); + } + + let mut outlen: i32 = output.len().try_into()?; + let inlen: i32 = input.len().try_into()?; + if 1 != unsafe { + EVP_EncryptUpdate( + self.cipher_ctx.as_mut_ptr(), + output.as_mut_ptr(), + &mut outlen, + input.as_ptr(), + inlen, + ) + } { + return Err(Unspecified); + } + let outlen: usize = outlen.try_into()?; + Ok(&output[0..outlen]) + } + + /// Finish the encryption and return the output. + /// # Errors + /// Returns an error if the output buffer is too small. + pub fn finish(self, output: &mut [u8]) -> Result<(DecryptionContext, &[u8]), Unspecified> { + if output.len() < self.algorithm.block_len { + return Err(Unspecified); + } + let mut outlen: i32 = output.len().try_into()?; + if 1 != unsafe { + EVP_EncryptFinal_ex( + self.cipher_ctx.as_mut_ptr(), + output.as_mut_ptr(), + &mut outlen, + ) + } { + return Err(Unspecified); + } + let outlen: usize = outlen.try_into()?; + Ok((self.context.into(), &output[0..outlen])) + } + + /// Returns the cipher operating mode. + #[must_use] + pub fn mode(&self) -> OperatingMode { + self.mode + } + + /// Returns the cipher algorithm + #[must_use] + pub fn algorithm(&self) -> &'static Algorithm { + self.algorithm + } + + /// CTR cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn ctr(key: UnboundCipherKey) -> Result { + let context = key.algorithm().new_encryption_context(OperatingMode::CTR)?; + Self::less_safe_ctr(key, context) + } + + /// CTR cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn less_safe_ctr( + key: UnboundCipherKey, + context: EncryptionContext, + ) -> Result { + Self::new(key, OperatingMode::CTR, context) + } + + /// CBC cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn cbc_pkcs7(key: UnboundCipherKey) -> Result { + let context = key.algorithm().new_encryption_context(OperatingMode::CBC)?; + Self::less_safe_cbc_pkcs7(key, context) + } + + /// CBC cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn less_safe_cbc_pkcs7( + key: UnboundCipherKey, + context: EncryptionContext, + ) -> Result { + Self::new(key, OperatingMode::CBC, context) + } +} + +/// A cipher decryption key for streaming encryption operations. +pub struct StreamingDecryptingKey { + algorithm: &'static Algorithm, + mode: OperatingMode, + cipher_ctx: LcPtr, +} +impl StreamingDecryptingKey { + #[allow(clippy::needless_pass_by_value)] + fn new( + key: UnboundCipherKey, + mode: OperatingMode, + context: DecryptionContext, + ) -> Result { + let cipher_ctx = LcPtr::new(unsafe { EVP_CIPHER_CTX_new() })?; + let algorithm = key.algorithm(); + let cipher = mode.evp_cipher(key.algorithm); + let key_bytes = key.key_bytes.as_ref(); + debug_assert_eq!( + key_bytes.len(), + ::try_from(unsafe { EVP_CIPHER_key_length(*cipher) }).unwrap() + ); + let iv = <&[u8]>::try_from(&context)?; + debug_assert_eq!( + iv.len(), + ::try_from(unsafe { EVP_CIPHER_iv_length(*cipher) }).unwrap() + ); + + if 1 != unsafe { + EVP_DecryptInit_ex( + cipher_ctx.as_mut_ptr(), + *cipher, + null_mut(), + key_bytes.as_ptr(), + iv.as_ptr(), + ) + } { + return Err(Unspecified); + } + + Ok(Self { + algorithm, + mode, + cipher_ctx, + }) + } + + /// Decrypt the input and return the output. + /// # Errors + /// Returns an error if the output buffer is too small. + pub fn update<'a>( + &mut self, + input: &[u8], + output: &'a mut [u8], + ) -> Result<&'a [u8], Unspecified> { + if output.len() < (input.len() + self.algorithm.block_len) { + return Err(Unspecified); + } + + let mut outlen: i32 = output.len().try_into()?; + let inlen: i32 = input.len().try_into()?; + if 1 != unsafe { + EVP_DecryptUpdate( + self.cipher_ctx.as_mut_ptr(), + output.as_mut_ptr(), + &mut outlen, + input.as_ptr(), + inlen, + ) + } { + return Err(Unspecified); + } + let outlen: usize = outlen.try_into()?; + Ok(&output[0..outlen]) + } + + /// Finish the decryption and return the output. + /// # Errors + /// Returns an error if the output buffer is too small. + pub fn finish(self, output: &mut [u8]) -> Result<&[u8], Unspecified> { + let mut outlen: i32 = output.len().try_into()?; + if 1 != unsafe { EVP_DecryptFinal_ex(*self.cipher_ctx, output.as_mut_ptr(), &mut outlen) } { + return Err(Unspecified); + } + let outlen: usize = outlen.try_into()?; + Ok(&output[0..outlen]) + } + + /// Returns the cipher operating mode. + #[must_use] + pub fn mode(&self) -> OperatingMode { + self.mode + } + + /// Returns the cipher algorithm + #[must_use] + pub fn algorithm(&self) -> &'static Algorithm { + self.algorithm + } + + /// CTR cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn ctr(key: UnboundCipherKey, context: DecryptionContext) -> Result { + Self::new(key, OperatingMode::CTR, context) + } + + /// CBC cipher mode + /// # Errors + /// If the key is not valid for the cipher algorithm + pub fn cbc_pkcs7( + key: UnboundCipherKey, + context: DecryptionContext, + ) -> Result { + Self::new(key, OperatingMode::CBC, context) + } +} + +#[cfg(test)] +mod tests { + use crate::cipher::streaming::{StreamingDecryptingKey, StreamingEncryptingKey}; + use crate::cipher::{ + DecryptionContext, OperatingMode, UnboundCipherKey, AES_256, AES_256_KEY_LEN, + }; + use crate::rand::{SecureRandom, SystemRandom}; + use paste::*; + + fn step_encrypt( + mut encrypting_key: StreamingEncryptingKey, + plaintext: &[u8], + step: usize, + ) -> (Box<[u8]>, DecryptionContext) { + let alg = encrypting_key.algorithm(); + let mode = encrypting_key.mode(); + let n = plaintext.len(); + let mut ciphertext = vec![0u8; n + alg.block_len()]; + + let mut in_idx: usize = 0; + let mut out_idx: usize = 0; + loop { + let mut in_end = in_idx + step; + if in_end > n { + in_end = n; + } + let out_end = out_idx + (in_end - in_idx) + alg.block_len(); + let output = encrypting_key + .update( + &plaintext[in_idx..in_end], + &mut ciphertext[out_idx..out_end], + ) + .unwrap(); + in_idx += step; + out_idx += output.len(); + if in_idx >= n { + break; + } + } + let out_end = out_idx + alg.block_len(); + let (decrypt_iv, output) = encrypting_key + .finish(&mut ciphertext[out_idx..out_end]) + .unwrap(); + let outlen = output.len(); + ciphertext.truncate(out_idx + outlen); + match mode { + OperatingMode::CBC => { + assert!(ciphertext.len() > plaintext.len()); + assert!(ciphertext.len() <= plaintext.len() + alg.block_len()); + } + OperatingMode::CTR => { + assert_eq!(ciphertext.len(), plaintext.len()); + } + } + + (ciphertext.into_boxed_slice(), decrypt_iv) + } + + fn step_decrypt( + mut decrypting_key: StreamingDecryptingKey, + ciphertext: &[u8], + step: usize, + ) -> Box<[u8]> { + let alg = decrypting_key.algorithm; + let mode = decrypting_key.mode; + let n = ciphertext.len(); + let mut plaintext = vec![0u8; n + alg.block_len()]; + + let mut in_idx: usize = 0; + let mut out_idx: usize = 0; + loop { + let mut in_end = in_idx + step; + if in_end > n { + in_end = n; + } + let out_end = out_idx + (in_end - in_idx) + alg.block_len(); + let output = decrypting_key + .update( + &ciphertext[in_idx..in_end], + &mut plaintext[out_idx..out_end], + ) + .unwrap(); + in_idx += step; + out_idx += output.len(); + if in_idx >= n { + break; + } + } + let out_end = out_idx + alg.block_len(); + let output = decrypting_key + .finish(&mut plaintext[out_idx..out_end]) + .unwrap(); + let outlen = output.len(); + plaintext.truncate(out_idx + outlen); + match mode { + OperatingMode::CBC => { + assert!(ciphertext.len() > plaintext.len()); + assert!(ciphertext.len() <= plaintext.len() + alg.block_len()); + } + OperatingMode::CTR => { + assert_eq!(ciphertext.len(), plaintext.len()); + } + } + plaintext.into_boxed_slice() + } + + macro_rules! helper_stream_step_encrypt_test { + ($mode:ident) => { + paste! { + fn []( + encrypting_key_creator: impl Fn() -> StreamingEncryptingKey, + decrypting_key_creator: impl Fn(DecryptionContext) -> StreamingDecryptingKey, + n: usize, + step: usize, + ) { + let mut input = vec![0u8; n]; + let random = SystemRandom::new(); + random.fill(&mut input).unwrap(); + + let encrypting_key = encrypting_key_creator(); + + let (ciphertext, decrypt_iv) = step_encrypt(encrypting_key, &input, step); + + let decrypting_key = decrypting_key_creator(decrypt_iv); + + let plaintext = step_decrypt(decrypting_key, &ciphertext, step); + + assert_eq!(input.as_slice(), &*plaintext); + } + } + }; + } + + helper_stream_step_encrypt_test!(cbc_pkcs7); + helper_stream_step_encrypt_test!(ctr); + + #[test] + fn test_step_cbc() { + let random = SystemRandom::new(); + let mut key = [0u8; AES_256_KEY_LEN]; + random.fill(&mut key).unwrap(); + let key = key; + + let encrypting_key_creator = || { + let key = UnboundCipherKey::new(&AES_256, &key.clone()).unwrap(); + StreamingEncryptingKey::cbc_pkcs7(key).unwrap() + }; + let decrypting_key_creator = |decryption_ctx: DecryptionContext| { + let key = UnboundCipherKey::new(&AES_256, &key.clone()).unwrap(); + StreamingDecryptingKey::cbc_pkcs7(key, decryption_ctx).unwrap() + }; + + for i in 13..=21 { + for j in 124..=131 { + helper_test_cbc_pkcs7_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + i, + ); + } + for j in 124..=131 { + helper_test_cbc_pkcs7_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + j - i, + ); + } + } + for j in 124..=131 { + helper_test_cbc_pkcs7_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + j, + ); + helper_test_cbc_pkcs7_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + 256, + ); + helper_test_cbc_pkcs7_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + 1, + ); + } + } + + #[test] + fn test_step_ctr() { + let random = SystemRandom::new(); + let mut key = [0u8; AES_256_KEY_LEN]; + random.fill(&mut key).unwrap(); + + let encrypting_key_creator = || { + let key = UnboundCipherKey::new(&AES_256, &key.clone()).unwrap(); + StreamingEncryptingKey::ctr(key).unwrap() + }; + let decrypting_key_creator = |decryption_ctx: DecryptionContext| { + let key = UnboundCipherKey::new(&AES_256, &key.clone()).unwrap(); + StreamingDecryptingKey::ctr(key, decryption_ctx).unwrap() + }; + + for i in 13..=21 { + for j in 124..=131 { + helper_test_ctr_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + i, + ); + } + for j in 124..=131 { + helper_test_ctr_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + j - i, + ); + } + } + for j in 124..=131 { + helper_test_ctr_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + j, + ); + helper_test_ctr_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + 256, + ); + helper_test_ctr_stream_encrypt_step_n_bytes( + encrypting_key_creator, + decrypting_key_creator, + j, + 1, + ); + } + } +} diff --git a/aws-lc-rs/src/ptr.rs b/aws-lc-rs/src/ptr.rs index 0177aae62a2..ca90b680f02 100644 --- a/aws-lc-rs/src/ptr.rs +++ b/aws-lc-rs/src/ptr.rs @@ -5,8 +5,9 @@ use core::ops::Deref; use aws_lc::{ BN_free, ECDSA_SIG_free, EC_GROUP_free, EC_KEY_free, EC_POINT_free, EVP_AEAD_CTX_free, - EVP_PKEY_CTX_free, EVP_PKEY_free, OPENSSL_free, RSA_free, BIGNUM, ECDSA_SIG, EC_GROUP, EC_KEY, - EC_POINT, EVP_AEAD_CTX, EVP_PKEY, EVP_PKEY_CTX, RSA, + EVP_CIPHER_CTX_free, EVP_PKEY_CTX_free, EVP_PKEY_free, OPENSSL_free, RSA_free, BIGNUM, + ECDSA_SIG, EC_GROUP, EC_KEY, EC_POINT, EVP_AEAD_CTX, EVP_CIPHER_CTX, EVP_PKEY, EVP_PKEY_CTX, + RSA, }; use mirai_annotations::verify_unreachable; @@ -207,6 +208,7 @@ create_pointer!(EVP_PKEY, EVP_PKEY_free); create_pointer!(EVP_PKEY_CTX, EVP_PKEY_CTX_free); create_pointer!(RSA, RSA_free); create_pointer!(EVP_AEAD_CTX, EVP_AEAD_CTX_free); +create_pointer!(EVP_CIPHER_CTX, EVP_CIPHER_CTX_free); #[cfg(test)] mod tests {