diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index dfab5cdcc..86b91e7fd 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -832,6 +832,11 @@ impl Pk { sig: &mut [u8], rng: &mut F, ) -> Result { + // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to auto-detect size and cause an invalid write. + if hash.len() == 0 || sig.len() == 0 { + return Err(Error::PkBadInputData) + } + match self.pk_type() { Type::Rsa | Type::RsaAlt | Type::RsassaPss => { if sig.len() < (self.len() / 8) { @@ -868,6 +873,11 @@ impl Pk { sig: &mut [u8], rng: &mut F, ) -> Result { + // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to auto-detect size and cause an invalid write. + if hash.len() == 0 || sig.len() == 0 { + return Err(Error::PkBadInputData) + } + use crate::rng::RngCallbackMut; if self.pk_type() == Type::Ecdsa || self.pk_type() == Type::Eckey { @@ -913,6 +923,11 @@ impl Pk { } pub fn verify(&mut self, md: MdType, hash: &[u8], sig: &[u8]) -> Result<()> { + // If hash or sig are allowed with size 0 (&[]) then mbedtls will attempt to auto-detect size and cause an invalid write. + if hash.len() == 0 || sig.len() == 0 { + return Err(Error::PkBadInputData) + } + unsafe { pk_verify( &mut self.inner, @@ -1274,6 +1289,18 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi ) .unwrap(); pk.verify(digest, data, &signature[0..len]).unwrap(); + + assert_eq!(pk.verify(digest, data, &[]).unwrap_err(), Error::PkBadInputData); + assert_eq!(pk.verify(digest, &[], &signature[0..len]).unwrap_err(), Error::PkBadInputData); + + + let mut dummy_sig = []; + assert_eq!(pk.sign(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); + assert_eq!(pk.sign(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); + + assert_eq!(pk.sign_deterministic(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); + assert_eq!(pk.sign_deterministic(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData); + } } diff --git a/mbedtls/src/rng/ctr_drbg.rs b/mbedtls/src/rng/ctr_drbg.rs index bf55a615d..96e7e465e 100644 --- a/mbedtls/src/rng/ctr_drbg.rs +++ b/mbedtls/src/rng/ctr_drbg.rs @@ -17,7 +17,12 @@ use mbedtls_sys::types::size_t; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; use crate::error::{IntoResult, Result}; -use crate::rng::{EntropyCallback, RngCallback, RngCallbackMut}; +use crate::rng::{EntropyCallback, EntropyCallbackMut, RngCallback, RngCallbackMut}; + +enum EntropyHolder { + Shared(Arc), + Unique(Box), +} define!( // `ctr_drbg_context` inlines an `aes_context`, which is immovable. See @@ -30,7 +35,7 @@ define!( #[c_box_ty(ctr_drbg_context)] #[repr(C)] struct CtrDrbg { - entropy: Arc, + entropy: EntropyHolder, }; const drop: fn(&mut Self) = ctr_drbg_free; impl<'a> Into {} @@ -63,8 +68,28 @@ impl CtrDrbg { ).into_result()?; } - Ok(CtrDrbg { inner, entropy }) + Ok(CtrDrbg { inner, entropy: EntropyHolder::Shared(entropy) }) + } + + pub fn with_mut_entropy(entropy: T, additional_entropy: Option<&[u8]>) -> Result { + let mut inner = Box::new(ctr_drbg_context::default()); + + // We take sole ownership of entropy, all access is guarded via mutexes. + let mut entropy = Box::new(entropy); + unsafe { + ctr_drbg_init(&mut *inner); + ctr_drbg_seed( + &mut *inner, + Some(T::call_mut), + entropy.data_ptr_mut(), + additional_entropy.map(<[_]>::as_ptr).unwrap_or(::core::ptr::null()), + additional_entropy.map(<[_]>::len).unwrap_or(0) + ).into_result()?; + } + + Ok(CtrDrbg { inner, entropy: EntropyHolder::Unique(entropy) }) } + pub fn prediction_resistance(&self) -> bool { if self.inner.prediction_resistance == CTR_DRBG_PR_OFF { diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 3cb8bfa9d..fadbdbdc5 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -100,6 +100,46 @@ callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ()); callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>); callback!(CaCallback: Fn(&MbedtlsList) -> Result>); + +#[repr(transparent)] +pub struct NullTerminatedStrList { + c: Vec<*mut c_char>, +} + +unsafe impl Send for NullTerminatedStrList {} +unsafe impl Sync for NullTerminatedStrList {} + +impl NullTerminatedStrList { + #[cfg(feature = "std")] + pub fn new(list: &[&str]) -> Result { + let mut ret = NullTerminatedStrList { c: Vec::with_capacity(list.len() + 1) }; + + for item in list { + ret.c.push(::std::ffi::CString::new(*item).map_err(|_| Error::SslBadInputData)?.into_raw()); + } + + ret.c.push(core::ptr::null_mut()); + Ok(ret) + } + + pub fn as_ptr(&self) -> *const *const c_char { + self.c.as_ptr() as *const _ + } +} + +#[cfg(feature = "std")] +impl Drop for NullTerminatedStrList { + fn drop(&mut self) { + for i in self.c.iter() { + unsafe { + if !(*i).is_null() { + let _ = ::std::ffi::CString::from_raw(*i); + } + } + } + } +} + define!( #[c_ty(ssl_config)] #[repr(C)] @@ -116,9 +156,7 @@ define!( ciphersuites: Vec>>, curves: Option>>, - - #[allow(dead_code)] - dhm: Option>, + protocols: Option>, verify_callback: Option>, #[cfg(feature = "std")] @@ -154,7 +192,7 @@ impl Config { rng: None, ciphersuites: vec![], curves: None, - dhm: None, + protocols: None, verify_callback: None, #[cfg(feature = "std")] dbg_callback: None, @@ -184,6 +222,18 @@ impl Config { self.ciphersuites.push(list); } + /// Set the supported Application Layer Protocols. + pub fn set_alpn_protocols(&mut self, protocols: Arc) -> Result<()> { + unsafe { + ssl_conf_alpn_protocols(&mut self.inner, protocols.as_ptr() as *mut _) + .into_result() + .map(|_| ())?; + } + + self.protocols = Some(protocols); + Ok(()) + } + pub fn set_ciphersuites_for_version(&mut self, list: Arc>, major: c_int, minor: c_int) { Self::check_c_list(&list); unsafe { ssl_conf_ciphersuites_for_version(self.into(), list.as_ptr(), major, minor) } @@ -232,13 +282,13 @@ impl Config { /// Takes both DER and PEM forms of FFDH parameters in `DHParams` format. /// /// When calling on PEM-encoded data, `params` must be NULL-terminated - pub fn set_dh_params(&mut self, dhm: Arc) -> Result<()> { + pub fn set_dh_params(&mut self, dhm: &Dhm) -> Result<()> { unsafe { + // This copies the dhm parameters and does not store any pointer to it ssl_conf_dh_param_ctx(self.into(), dhm.inner_ffi_mut()) .into_result() .map(|_| ())?; } - self.dhm = Some(dhm); Ok(()) } @@ -316,12 +366,10 @@ impl Config { // - We can pointer cast to it to allow storing additional objects. // let cb = &mut *(closure as *mut F); - let context = UnsafeFrom::from(ctx).unwrap(); - - let mut ctx = HandshakeContext::init(context); + let ctx = UnsafeFrom::from(ctx).unwrap(); let name = from_raw_parts(name, name_len); - match cb(&mut ctx, name) { + match cb(ctx, name) { Ok(()) => 0, Err(_) => -1, } diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index f40a31f15..04a48a973 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -6,31 +6,32 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ - -use core::any::Any; use core::result::Result as StdResult; + +#[cfg(feature = "std")] +use { + std::io::{Read, Write, Result as IoResult}, + std::sync::Arc, +}; + #[cfg(not(feature = "std"))] use core_io::{Read, Write, Result as IoResult}; -#[cfg(feature = "std")] -use std::io::{Read, Write, Result as IoResult}; -#[cfg(feature = "std")] -use std::sync::Arc; + use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; use mbedtls_sys::types::size_t; use mbedtls_sys::*; -use crate::alloc::{List as MbedtlsList}; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; +use crate::alloc::{List as MbedtlsList}; use crate::error::{Error, Result, IntoResult}; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::ssl::config::{Config, Version, AuthMode}; use crate::x509::{Certificate, Crl, VerifyError}; - -pub trait IoCallback : Any { +pub trait IoCallback { unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; fn data_ptr(&mut self) -> *mut c_void; @@ -70,13 +71,7 @@ impl IoCallback for IO { define!( #[c_ty(ssl_context)] #[repr(C)] - struct Context { - // config is used read-only for mutliple contexts and is immutable once configured. - config: Arc, - - // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. - io: Option>, - + struct HandshakeContext { handshake_ca_cert: Option>>, handshake_crl: Option>, @@ -89,7 +84,31 @@ define!( impl<'a> UnsafeFrom {} ); -impl Context { +#[repr(C)] +pub struct Context { + // Base structure used in SNI callback where we cannot determine the io type. + inner: HandshakeContext, + + // config is used read-only for mutliple contexts and is immutable once configured. + config: Arc, + + // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. + io: Option>, +} + +impl<'a, T> Into<*const ssl_context> for &'a Context { + fn into(self) -> *const ssl_context { + self.handle() + } +} + +impl<'a, T> Into<*mut ssl_context> for &'a mut Context { + fn into(self) -> *mut ssl_context { + self.handle_mut() + } +} + +impl Context { pub fn new(config: Arc) -> Self { let mut inner = ssl_context::default(); @@ -99,19 +118,30 @@ impl Context { }; Context { - inner, + inner: HandshakeContext { + inner, + handshake_ca_cert: None, + handshake_crl: None, + + handshake_cert: vec![], + handshake_pk: vec![], + }, config: config.clone(), io: None, - - handshake_ca_cert: None, - handshake_crl: None, - - handshake_cert: vec![], - handshake_pk: vec![], } } - pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { + pub(crate) fn handle(&self) -> &::mbedtls_sys::ssl_context { + self.inner.handle() + } + + pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::ssl_context { + self.inner.handle_mut() + } +} + +impl Context { + pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { unsafe { let mut io = Box::new(io); ssl_session_reset(self.into()).into_result()?; @@ -127,23 +157,27 @@ impl Context { ); self.io = Some(io); + self.inner.reset_handshake(); + } - self.handshake_cert.clear(); - self.handshake_pk.clear(); - self.handshake_ca_cert = None; - self.handshake_crl = None; + match self.handshake() { + Ok(()) => Ok(()), + Err(Error::SslWantRead) => Err(Error::SslWantRead), + Err(Error::SslWantWrite) => Err(Error::SslWantWrite), + Err(e) => { + self.close(); + Err(e) + }, + } - match ssl_handshake(self.into()).into_result() { - Err(e) => { - // safely end borrow of io - ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); - self.io = None; - Err(e) - }, - Ok(_) => { - Ok(()) - } - } + } +} + +impl Context { + fn handshake(&mut self) -> Result<()> { + unsafe { + ssl_flush_output(self.into()).into_result()?; + ssl_handshake(self.into()).into_result_discard() } } @@ -187,22 +221,23 @@ impl Context { self.io = None; } } - - pub fn io(&self) -> Option<&dyn Any> { + + pub fn io(&self) -> Option<&T> { self.io.as_ref().map(|v| &**v) } - pub fn io_mut(&mut self) -> Option<&mut dyn Any> { + + pub fn io_mut(&mut self) -> Option<&mut T> { self.io.as_mut().map(|v| &mut **v) } /// Return the minor number of the negotiated TLS version pub fn minor_version(&self) -> i32 { - self.inner.minor_ver + self.handle().minor_ver } /// Return the major number of the negotiated TLS version pub fn major_version(&self) -> i32 { - self.inner.major_ver + self.handle().major_ver } /// Return the number of bytes currently available to read that @@ -231,27 +266,41 @@ impl Context { /// All assigned ciphersuites are listed by the IANA in /// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt pub fn ciphersuite(&self) -> Result { - if self.inner.session.is_null() { + if self.handle().session.is_null() { return Err(Error::SslBadInputData); } - Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 }) + Ok(unsafe { self.handle().session.as_ref().unwrap().ciphersuite as u16 }) } pub fn peer_cert(&self) -> Result>> { - if self.inner.session.is_null() { + if self.handle().session.is_null() { return Err(Error::SslBadInputData); } unsafe { // We cannot call the peer cert function as we need a pointer to a pointer to create the MbedtlsList, we need something in the heap / cannot use any local variable for that. - let peer_cert : &MbedtlsList = UnsafeFrom::from(&((*self.inner.session).peer_cert) as *const *mut x509_crt as *const *const x509_crt).ok_or(Error::SslBadInputData)?; + let peer_cert : &MbedtlsList = UnsafeFrom::from(&((*self.handle().session).peer_cert) as *const *mut x509_crt as *const *const x509_crt).ok_or(Error::SslBadInputData)?; Ok(Some(peer_cert)) } } + + + #[cfg(feature = "std")] + pub fn get_alpn_protocol(&self) -> Result> { + unsafe { + let ptr = ssl_get_alpn_protocol(self.handle()); + if ptr.is_null() { + Ok(None) + } else { + let s = std::ffi::CStr::from_ptr(ptr).to_str()?; + Ok(Some(s)) + } + } + } } -impl Drop for Context { +impl Drop for Context { fn drop(&mut self) { unsafe { self.close(); @@ -260,7 +309,7 @@ impl Drop for Context { } } -impl Read for Context { +impl Read for Context { fn read(&mut self, buf: &mut [u8]) -> IoResult { match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } { Err(Error::SslPeerCloseNotify) => Ok(0), @@ -270,7 +319,7 @@ impl Read for Context { } } -impl Write for Context { +impl Write for Context { fn write(&mut self, buf: &[u8]) -> IoResult { match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } { Err(Error::SslPeerCloseNotify) => Ok(0), @@ -283,12 +332,6 @@ impl Write for Context { Ok(()) } } - - -pub struct HandshakeContext<'ctx> { - pub context: &'ctx mut Context, -} - // // Class exists only during SNI callback that is configured from Config. // SNI Callback must provide input whos lifetime exceed the SNI closure to avoid memory corruptions. @@ -301,42 +344,44 @@ pub struct HandshakeContext<'ctx> { // - mbedtls not providing any callbacks on handshake finish. // - no reasonable way to obtain a storage within the sni callback tied to the handshake or to the rust Context. (without resorting to a unscalable map or pointer magic that mbedtls may invalidate) // -impl<'ctx> HandshakeContext<'ctx> { - - pub(crate) fn init(context: &'ctx mut Context) -> Self { - HandshakeContext { context } +impl HandshakeContext { + fn reset_handshake(&mut self) { + self.handshake_cert.clear(); + self.handshake_pk.clear(); + self.handshake_ca_cert = None; + self.handshake_crl = None; } pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } - unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) } + unsafe { ssl_set_hs_authmode(self.into(), am as i32) } Ok(()) } pub fn set_ca_list( &mut self, - chain: Arc>, + chain: Option>>, crl: Option>, ) -> Result<()> { // mbedtls_ssl_set_hs_ca_chain does not check for NULL handshake. - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } // This will override current handshake CA chain. unsafe { ssl_set_hs_ca_chain( - self.context.into(), - chain.inner_ffi_mut(), + self.into(), + chain.as_ref().map(|chain| chain.inner_ffi_mut()).unwrap_or(::core::ptr::null_mut()), crl.as_ref().map(|crl| crl.inner_ffi_mut()).unwrap_or(::core::ptr::null_mut()), ); } - self.context.handshake_ca_cert = Some(chain); - self.context.handshake_crl = crl; + self.handshake_ca_cert = chain; + self.handshake_crl = crl; Ok(()) } @@ -350,21 +395,31 @@ impl<'ctx> HandshakeContext<'ctx> { key: Arc, ) -> Result<()> { // mbedtls_ssl_set_hs_own_cert does not check for NULL handshake. - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } // This will append provided certificate pointers in internal structures. unsafe { - ssl_set_hs_own_cert(self.context.into(), chain.inner_ffi_mut(), key.inner_ffi_mut()).into_result()?; + ssl_set_hs_own_cert(self.into(), chain.inner_ffi_mut(), key.inner_ffi_mut()).into_result()?; } - self.context.handshake_cert.push(chain); - self.context.handshake_pk.push(key); + self.handshake_cert.push(chain); + self.handshake_pk.push(key); Ok(()) } } +#[cfg(test)] +mod tests { + use crate::ssl::context::HandshakeContext; + use crate::tests::TestTrait; + + #[test] + fn handshakecontext_sync() { + assert!(!TestTrait::::new().impls_trait(), "HandshakeContext must be !Sync"); + } +} // ssl_get_alpn_protocol // ssl_get_max_frag_len diff --git a/mbedtls/tests/hyper.rs b/mbedtls/tests/hyper.rs index 07a890ba7..475599066 100644 --- a/mbedtls/tests/hyper.rs +++ b/mbedtls/tests/hyper.rs @@ -12,12 +12,12 @@ use mbedtls::ssl::{Config, Context}; // Native TLS compatibility - to move to native tls client in the future #[derive(Clone)] pub struct TlsStream { - context: Arc>, + context: Arc>>, phantom: PhantomData, } impl TlsStream { - pub fn new(context: Arc>) -> Self { + pub fn new(context: Arc>>) -> Self { TlsStream { context: context, phantom: PhantomData, @@ -28,14 +28,15 @@ impl TlsStream { unsafe impl Send for TlsStream {} unsafe impl Sync for TlsStream {} -impl io::Read for TlsStream + +impl io::Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.context.lock().unwrap().read(buf) } } -impl io::Write for TlsStream +impl io::Write for TlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.context.lock().unwrap().write(buf) @@ -52,19 +53,19 @@ impl NetworkStream for TlsStream fn peer_addr(&mut self) -> io::Result { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().peer_addr() + .peer_addr() } fn set_read_timeout(&self, dur: Option) -> io::Result<()> { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().set_read_timeout(dur) + .set_read_timeout(dur) } fn set_write_timeout(&self, dur: Option) -> io::Result<()> { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().set_write_timeout(dur) + .set_write_timeout(dur) } }