diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 018a1c98a..de45e3aa8 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -10,7 +10,7 @@ use std::{fmt, sync::Arc}; use tokio::net::TcpStream; #[cfg(feature = "rustls")] use tokio_rustls::{ - rustls::{internal::pemfile, ClientConfig, NoClientAuth, ServerConfig, Session}, + rustls::{internal::pemfile, ClientConfig, NoClientAuth, PrivateKey, ServerConfig, Session}, webpki::DNSNameRef, TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, }; @@ -34,6 +34,10 @@ pub(crate) struct Cert { enum TlsError { #[allow(dead_code)] H2NotNegotiated, + #[cfg(feature = "rustls")] + CertificateParseError, + #[cfg(feature = "rustls")] + PrivateKeyParseError, } #[derive(Clone)] @@ -181,16 +185,47 @@ impl TlsAcceptor { }) } + #[cfg(feature = "rustls")] + fn load_rustls_private_key( + mut cursor: std::io::Cursor<&[u8]>, + ) -> Result { + // First attempt to load the private key assuming it is PKCS8-encoded + if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) { + if keys.len() > 0 { + return Ok(keys.remove(0)); + } + } + + // If it not, try loading the private key as an RSA key + cursor.set_position(0); + if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) { + if keys.len() > 0 { + return Ok(keys.remove(0)); + } + } + + // Otherwise we have a Private Key parsing problem + Err(Box::new(TlsError::PrivateKeyParseError)) + } + #[cfg(feature = "rustls")] pub(crate) fn new_with_rustls(identity: Identity) -> Result { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); - pemfile::certs(&mut cert).unwrap() + match pemfile::certs(&mut cert) { + Ok(certs) => certs, + Err(_) => return Err(Box::new(TlsError::CertificateParseError)), + } }; let key = { - let mut key = std::io::Cursor::new(&identity.key[..]); - pemfile::pkcs8_private_keys(&mut key).unwrap().remove(0) + let key = std::io::Cursor::new(&identity.key[..]); + match Self::load_rustls_private_key(key) { + Ok(key) => key, + Err(e) => { + return Err(e); + } + } }; let mut config = ServerConfig::new(NoClientAuth::new()); @@ -248,6 +283,13 @@ impl fmt::Display for TlsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."), + #[cfg(feature = "rustls")] + TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."), + #[cfg(feature = "rustls")] + TlsError::PrivateKeyParseError => write!( + f, + "Error parsing TLS private key - no RSA or PKCS8-encoded keys found." + ), } } }