Skip to content

Commit

Permalink
Fix rustls feature
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Mar 18, 2024
1 parent 78fdac5 commit 0c551f8
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ readme = "README.md"
repository = "https://github.com/blackbeam/mysql_async"
version = "0.34.0"
exclude = ["test/*"]
edition = "2018"
edition = "2021"
categories = ["asynchronous", "database"]

[dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

pub use url::ParseError;

mod tls;
pub mod tls;

use mysql_common::{
named_params::MixedParamsError, params::MissingNamedParameterError,
Expand Down
11 changes: 11 additions & 0 deletions src/error/tls/rustls_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

use std::fmt::Display;

use rustls::server::VerifierBuilderError;

#[derive(Debug)]
pub enum TlsError {
Tls(rustls::Error),
Pki(webpki::Error),
InvalidDnsName(webpki::InvalidDnsNameError),
VerifierBuilderError(VerifierBuilderError),
}

impl From<TlsError> for crate::Error {
Expand All @@ -15,6 +18,12 @@ impl From<TlsError> for crate::Error {
}
}

impl From<VerifierBuilderError> for TlsError {
fn from(e: VerifierBuilderError) -> Self {
TlsError::VerifierBuilderError(e)
}
}

impl From<rustls::Error> for TlsError {
fn from(e: rustls::Error) -> Self {
TlsError::Tls(e)
Expand Down Expand Up @@ -57,6 +66,7 @@ impl std::error::Error for TlsError {
TlsError::Tls(e) => Some(e),
TlsError::Pki(e) => Some(e),
TlsError::InvalidDnsName(e) => Some(e),
TlsError::VerifierBuilderError(e) => Some(e),
}
}
}
Expand All @@ -67,6 +77,7 @@ impl Display for TlsError {
TlsError::Tls(e) => e.fmt(f),
TlsError::Pki(e) => e.fmt(f),
TlsError::InvalidDnsName(e) => e.fmt(f),
TlsError::VerifierBuilderError(e) => e.fmt(f),
}
}
}
120 changes: 85 additions & 35 deletions src/io/tls/rustls_io.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
#![cfg(feature = "rustls-tls")]

use std::{convert::TryInto, sync::Arc};
use std::sync::Arc;

use rustls::{
client::{ServerCertVerifier, WebPkiVerifier},
Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore,
client::{
danger::{ServerCertVerified, ServerCertVerifier},
WebPkiServerVerifier,
},
pki_types::{CertificateDer, ServerName},
ClientConfig, RootCertStore,
};

use rustls_pemfile::certs;
use tokio_rustls::TlsConnector;

use crate::{io::Endpoint, Result, SslOpts};
use crate::{io::Endpoint, Result, SslOpts, TlsError};

impl SslOpts {
async fn load_root_certs(&self) -> crate::Result<Vec<Certificate>> {
async fn load_root_certs(&self) -> crate::Result<Vec<CertificateDer<'static>>> {
let mut output = Vec::new();

for root_cert in self.root_certs() {
let root_cert_data = root_cert.read().await?;
let mut seen = false;
for cert in certs(&mut &*root_cert_data)? {
for cert in certs(&mut &*root_cert_data) {
seen = true;
output.push(Certificate(cert));
output.push(cert?);
}

if !seen && !root_cert_data.is_empty() {
output.push(Certificate(root_cert_data.into_owned()));
output.push(CertificateDer::from(root_cert_data.into_owned()));
}
}

Expand All @@ -42,21 +46,13 @@ impl Endpoint {
}

let mut root_store = RootCertStore::empty();
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned()));

for cert in ssl_opts.load_root_certs().await? {
root_store.add(&cert)?;
root_store.add(cert)?;
}

let config_builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store.clone());
let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone());

let mut config = if let Some(identity) = ssl_opts.client_identity() {
let (cert_chain, priv_key) = identity.load().await?;
Expand All @@ -65,12 +61,13 @@ impl Endpoint {
config_builder.with_no_client_auth()
};

let server_name = domain
.as_str()
.try_into()
.map_err(|_| webpki::InvalidDnsNameError)?;
let server_name = ServerName::try_from(domain.as_str())
.map_err(|_| webpki::InvalidDnsNameError)?
.to_owned();
let mut dangerous = config.dangerous();
let web_pki_verifier = WebPkiVerifier::new(root_store, None);
let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store))
.build()
.map_err(TlsError::from)?;
let dangerous_verifier = DangerousVerifier::new(
ssl_opts.accept_invalid_certs(),
ssl_opts.skip_domain_validation(),
Expand All @@ -97,17 +94,18 @@ impl Endpoint {
}
}

#[derive(Debug)]
struct DangerousVerifier {
accept_invalid_certs: bool,
skip_domain_validation: bool,
verifier: WebPkiVerifier,
verifier: Arc<WebPkiServerVerifier>,
}

impl DangerousVerifier {
fn new(
accept_invalid_certs: bool,
skip_domain_validation: bool,
verifier: WebPkiVerifier,
verifier: Arc<WebPkiServerVerifier>,
) -> Self {
Self {
accept_invalid_certs,
Expand All @@ -118,34 +116,86 @@ impl DangerousVerifier {
}

impl ServerCertVerifier for DangerousVerifier {
// fn verify_server_cert(
// &self,
// end_entity: &Certificate,
// intermediates: &[Certificate],
// server_name: &rustls::ServerName,
// scts: &mut dyn Iterator<Item = &[u8]>,
// ocsp_response: &[u8],
// now: std::time::SystemTime,
// ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
// if self.accept_invalid_certs {
// Ok(rustls::client::ServerCertVerified::assertion())
// } else {
// match self.verifier.verify_server_cert(
// end_entity,
// intermediates,
// server_name,
// scts,
// ocsp_response,
// now,
// ) {
// Ok(assertion) => Ok(assertion),
// Err(ref e)
// if e.to_string().contains("NotValidForName") && self.skip_domain_validation =>
// {
// Ok(rustls::client::ServerCertVerified::assertion())
// }
// Err(e) => Err(e),
// }
// }
// }
fn verify_server_cert(
&self,
end_entity: &Certificate,
intermediates: &[Certificate],
server_name: &rustls::ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &rustls::pki_types::ServerName<'_>,
ocsp_response: &[u8],
now: std::time::SystemTime,
) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
now: rustls::pki_types::UnixTime,
) -> std::prelude::v1::Result<ServerCertVerified, rustls::Error> {
if self.accept_invalid_certs {
Ok(rustls::client::ServerCertVerified::assertion())
Ok(ServerCertVerified::assertion())
} else {
match self.verifier.verify_server_cert(
end_entity,
intermediates,
server_name,
scts,
ocsp_response,
now,
) {
Ok(assertion) => Ok(assertion),
Err(ref e)
if e.to_string().contains("NotValidForName") && self.skip_domain_validation =>
{
Ok(rustls::client::ServerCertVerified::assertion())
Ok(ServerCertVerified::assertion())
}
Err(e) => Err(e),
}
}
}

fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
self.verifier.verify_tls12_signature(message, cert, dss)
}

fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
self.verifier.verify_tls13_signature(message, cert, dss)
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.verifier.supported_verify_schemes()
}
}
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ pub use self::conn::pool::Pool;

#[doc(inline)]
pub use self::error::{
DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError,
tls::TlsError, DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError,
UrlError,
};

#[doc(inline)]
Expand Down
26 changes: 15 additions & 11 deletions src/opts/rustls_opts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![cfg(feature = "rustls-tls")]

use rustls::{Certificate, PrivateKey};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
use rustls_pemfile::{certs, rsa_private_keys};

use std::{borrow::Cow, path::Path};
Expand Down Expand Up @@ -50,27 +50,31 @@ impl ClientIdentity {
self.priv_key.borrow()
}

pub(crate) async fn load(&self) -> crate::Result<(Vec<Certificate>, PrivateKey)> {
pub(crate) async fn load(
&self,
) -> crate::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
let cert_data = self.cert_chain.read().await?;
let key_data = self.priv_key.read().await?;

let mut cert_chain = Vec::new();
if std::str::from_utf8(&cert_data).is_err() {
cert_chain.push(Certificate(cert_data.into_owned()));
cert_chain.push(CertificateDer::from(cert_data.into_owned()));
} else {
for cert in certs(&mut &*cert_data)? {
cert_chain.push(Certificate(cert));
for cert in certs(&mut &*cert_data) {
cert_chain.push(cert?);
}
}

let priv_key = if std::str::from_utf8(&key_data).is_err() {
Some(PrivateKey(key_data.into_owned()))
Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(
key_data.into_owned(),
)))
} else {
rsa_private_keys(&mut &*key_data)?
.into_iter()
.take(1)
.map(PrivateKey)
.next()
let mut priv_key = None;
for key in rsa_private_keys(&mut &*key_data).take(1) {
priv_key = Some(PrivateKeyDer::Pkcs1(key?.clone_key()));
}
priv_key
};

Ok((
Expand Down

0 comments on commit 0c551f8

Please sign in to comment.