diff --git a/Cargo.lock b/Cargo.lock index 427636f53..992bb96e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -822,6 +822,7 @@ dependencies = [ "proptest", "reqwest", "rstest", + "rustls-cng", "serde", "serde_derive", "serde_json", @@ -844,7 +845,6 @@ dependencies = [ "url", "utoipa", "uuid", - "x509-cert", "zeroize", ] @@ -3103,6 +3103,18 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls-cng" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3612780e9bada1c1f6cf3ad98604c86f403947a78e91943bbec58f690f77ce09" +dependencies = [ + "rustls 0.21.7", + "sha2 0.10.8", + "thiserror", + "windows-sys 0.48.0", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" diff --git a/README.md b/README.md index a5969571f..f03b4fffb 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,12 @@ Currently, stable options are: - `TlsPrivateKeyFile`: path to the private key to use for TLS, +- `UseWindowsCertificateStore`: enable usage of the Windows Certificate Store, + +- `WindowsCertificateStoreType`: type of the Windows Certificate Store to use, + +- `WindowsCertificateStoreName`: name of the Windows Certificate Store to use, + - `Listeners`: array of listener URLs. Each element has the following schema: diff --git a/devolutions-gateway/Cargo.toml b/devolutions-gateway/Cargo.toml index 0d606edca..1d8bcc04d 100644 --- a/devolutions-gateway/Cargo.toml +++ b/devolutions-gateway/Cargo.toml @@ -51,7 +51,6 @@ backoff = "0.4" # Security, crypto… picky = { version = "7.0.0-rc.8", default-features = false, features = ["jose", "x509"] } zeroize = { version = "1.6", features = ["derive"] } -x509-cert = { version = "0.2", features = ["std"] } multibase = "0.9" # Logging @@ -94,6 +93,9 @@ packet = { git = "https://github.com/fdubois1/rust-packet.git" } # For KDC proxy portpicker = "0.1" +[target.'cfg(windows)'.dependencies] +rustls-cng = "0.3" + [target.'cfg(windows)'.build-dependencies] embed-resource = "2.4" diff --git a/devolutions-gateway/src/config.rs b/devolutions-gateway/src/config.rs index 9d73ea228..91802d73f 100644 --- a/devolutions-gateway/src/config.rs +++ b/devolutions-gateway/src/config.rs @@ -43,45 +43,21 @@ pub struct TlsPublicKey(pub Vec); #[derive(Clone)] pub struct Tls { pub acceptor: tokio_rustls::TlsAcceptor, - pub leaf_certificate: rustls::Certificate, - pub leaf_public_key: TlsPublicKey, } impl fmt::Debug for Tls { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsConfig") - .field("certificate", &self.leaf_certificate) - .field("public_key", &self.leaf_public_key) - .finish_non_exhaustive() + f.debug_struct("TlsConfig").finish_non_exhaustive() } } impl Tls { - fn init(certificates: Vec, private_key: rustls::PrivateKey) -> anyhow::Result { - use x509_cert::der::Decode as _; + fn init(cert_source: crate::tls::CertificateSource) -> anyhow::Result { + let tls_server_config = crate::tls::build_server_config(cert_source).context("failed build TLS config")?; - let leaf_certificate = certificates.last().context("TLS leaf certificate is missing")?.clone(); + let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_server_config)); - let leaf_public_key = x509_cert::Certificate::from_der(&leaf_certificate.0) - .context("failed to parse leaf TLS certificate")? - .tbs_certificate - .subject_public_key_info - .subject_public_key - .as_bytes() - .context("subject public key BIT STRING is not aligned")? - .to_owned() - .pipe(TlsPublicKey); - - let rustls_config = - crate::tls::build_server_config(certificates, private_key).context("failed build TLS config")?; - - let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(rustls_config)); - - Ok(Self { - acceptor, - leaf_certificate, - leaf_public_key, - }) + Ok(Self { acceptor }) } } @@ -144,11 +120,27 @@ impl Conf { .tls_certificate_file .as_ref() .zip(conf_file.tls_private_key_file.as_ref()) - .map(|(cert_file, key_file)| { - let tls_certificate = read_rustls_certificate_file(cert_file).context("TLS certificate")?; - let tls_private_key = read_rustls_priv_key_file(key_file).context("TLS private key")?; - Tls::init(tls_certificate, tls_private_key).context("failed to init TLS config") + .map(|(cert_file, key_file)| -> anyhow::Result<_> { + let certificates = read_rustls_certificate_file(cert_file).context("TLS certificate")?; + let private_key = read_rustls_priv_key_file(key_file).context("TLS private key")?; + Ok(crate::tls::CertificateSource::External { + certificates, + private_key, + }) }) + .transpose()? + .or_else(|| { + conf_file.use_windows_certificate_store.unwrap_or(false).then(|| { + crate::tls::CertificateSource::WindowsCertificateStore { + store_type: conf_file.windows_certificate_store_type.unwrap_or_default(), + store_name: conf_file + .windows_certificate_store_name + .clone() + .unwrap_or_else(|| String::from("my")), + } + }) + }) + .map(|cert_source| Tls::init(cert_source).context("failed to init TLS config")) .transpose()?; let requires_tls = listeners @@ -217,7 +209,7 @@ impl Conf { sogar: conf_file.sogar.clone().unwrap_or_default(), jrl_file, ngrok: conf_file.ngrok.clone(), - verbosity_profile: conf_file.verbosity_profile, + verbosity_profile: conf_file.verbosity_profile.unwrap_or_default(), debug: conf_file.debug.clone().unwrap_or_default(), }) } @@ -561,8 +553,6 @@ fn to_listener_urls(conf: &dto::ListenerConf, hostname: &str, auto_ipv6: bool) - pub mod dto { use std::collections::HashMap; - use serde::{de, ser}; - use super::*; /// Source of truth for Gateway configuration @@ -603,6 +593,15 @@ pub mod dto { /// Private key to use for TLS #[serde(alias = "PrivateKeyFile")] pub tls_private_key_file: Option, + /// Enable usage of the Windows Certificate Store + #[serde(skip_serializing_if = "Option::is_none")] + pub use_windows_certificate_store: Option, + /// Type of the Windows Certificate Store to use + #[serde(skip_serializing_if = "Option::is_none")] + pub windows_certificate_store_type: Option, + /// Name of the Windows Certificate Store to use + #[serde(skip_serializing_if = "Option::is_none")] + pub windows_certificate_store_name: Option, /// Listeners to launch at startup #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -621,8 +620,8 @@ pub mod dto { pub ngrok: Option, /// Verbosity profile - #[serde(default, skip_serializing_if = "VerbosityProfile::is_default")] - pub verbosity_profile: VerbosityProfile, + #[serde(skip_serializing_if = "Option::is_none")] + pub verbosity_profile: Option, /// (Unstable) Folder and prefix for log files #[serde(skip_serializing_if = "Option::is_none")] @@ -663,6 +662,9 @@ pub mod dto { delegation_private_key_data: None, tls_certificate_file: None, tls_private_key_file: None, + use_windows_certificate_store: None, + windows_certificate_store_type: None, + windows_certificate_store_name: None, listeners: vec![ ListenerConf { internal_url: "tcp://*:8181".try_into().unwrap(), @@ -675,7 +677,7 @@ pub mod dto { ], subscriber: None, ngrok: None, - verbosity_profile: VerbosityProfile::default(), + verbosity_profile: None, log_file: None, jrl_file: None, plugins: None, @@ -703,12 +705,6 @@ pub mod dto { Quiet, } - impl VerbosityProfile { - pub fn is_default(&self) -> bool { - Self::default().eq(self) - } - } - /// Unsafe debug options that should only ever be used at development stage /// /// These options might change or get removed without further notice. @@ -796,7 +792,7 @@ pub mod dto { } } - #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[derive(PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)] pub enum SogarPermission { Push, Pull, @@ -810,7 +806,7 @@ pub mod dto { } #[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))] - #[derive(PartialEq, Eq, Debug, Clone, Default, Serialize, Deserialize)] + #[derive(PartialEq, Eq, Debug, Clone, Copy, Default, Serialize, Deserialize)] pub enum DataEncoding { #[default] Multibase, @@ -820,13 +816,13 @@ pub mod dto { Base64UrlPad, } - #[derive(PartialEq, Eq, Debug, Clone, Default, Serialize, Deserialize)] + #[derive(PartialEq, Eq, Debug, Clone, Copy, Default, Serialize, Deserialize)] pub enum CertFormat { #[default] X509, } - #[derive(PartialEq, Eq, Debug, Clone, Default, Serialize, Deserialize)] + #[derive(PartialEq, Eq, Debug, Clone, Copy, Default, Serialize, Deserialize)] pub enum PrivKeyFormat { #[default] Pkcs8, @@ -835,7 +831,7 @@ pub mod dto { } #[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))] - #[derive(PartialEq, Eq, Debug, Clone, Default, Serialize, Deserialize)] + #[derive(PartialEq, Eq, Debug, Clone, Copy, Default, Serialize, Deserialize)] pub enum PubKeyFormat { #[default] Spki, @@ -946,4 +942,12 @@ pub mod dto { #[serde(default, skip_serializing_if = "Vec::is_empty")] pub deny_cidrs: Vec, } + + #[derive(PartialEq, Eq, Debug, Clone, Copy, Default, Serialize, Deserialize)] + pub enum WindowsCertStoreType { + #[default] + LocalMachine, + CurrentUser, + CurrentService, + } } diff --git a/devolutions-gateway/src/tls.rs b/devolutions-gateway/src/tls.rs index 6936f7d96..639aaa127 100644 --- a/devolutions-gateway/src/tls.rs +++ b/devolutions-gateway/src/tls.rs @@ -52,18 +52,105 @@ pub async fn connect(dns_name: &str, stream: TcpStream) -> io::Result, - private_key: rustls::PrivateKey, -) -> anyhow::Result { - rustls::ServerConfig::builder() +pub enum CertificateSource { + External { + certificates: Vec, + private_key: rustls::PrivateKey, + }, + #[cfg(windows)] + WindowsCertificateStore { + store_type: crate::config::dto::WindowsCertStoreType, + store_name: String, + }, +} + +pub fn build_server_config(cert_source: CertificateSource) -> anyhow::Result { + let builder = rustls::ServerConfig::builder() .with_cipher_suites(rustls::DEFAULT_CIPHER_SUITES) // = with_safe_default_cipher_suites, but explicit, just to show we are using rustls's default cipher suites .with_safe_default_kx_groups() .with_protocol_versions(rustls::DEFAULT_VERSIONS) // = with_safe_default_protocol_versions, but explicit as well .context("couldn't set supported TLS protocol versions")? - .with_no_client_auth() - .with_single_cert(certificates, private_key) - .context("couldn't set server config cert") + .with_no_client_auth(); + + match cert_source { + CertificateSource::External { + certificates, + private_key, + } => builder + .with_single_cert(certificates, private_key) + .context("couldn't set server config cert"), + #[cfg(windows)] + CertificateSource::WindowsCertificateStore { store_type, store_name } => { + let resolver = windows::ServerCertResolver::open_store(store_type, &store_name) + .context("create ServerCertResolver")?; + Ok(builder.with_cert_resolver(Arc::new(resolver))) + } + } +} + +#[cfg(windows)] +pub mod windows { + use std::sync::Arc; + + use anyhow::Context as _; + use rustls_cng::{ + signer::CngSigningKey, + store::{CertStore, CertStoreType}, + }; + use tokio_rustls::rustls::{ + server::{ClientHello, ResolvesServerCert}, + sign::CertifiedKey, + Certificate, + }; + + use crate::config::dto; + + pub struct ServerCertResolver(CertStore); + + impl ServerCertResolver { + pub fn open_store(store_type: dto::WindowsCertStoreType, store_name: &str) -> anyhow::Result { + let store_type = match store_type { + dto::WindowsCertStoreType::LocalMachine => CertStoreType::LocalMachine, + dto::WindowsCertStoreType::CurrentUser => CertStoreType::CurrentUser, + dto::WindowsCertStoreType::CurrentService => CertStoreType::CurrentService, + }; + + let store = CertStore::open(store_type, store_name).context("open Windows certificate store")?; + + Ok(Self(store)) + } + } + + impl ResolvesServerCert for ServerCertResolver { + fn resolve(&self, client_hello: ClientHello) -> Option> { + trace!(server_name = ?client_hello.server_name()); + let name = client_hello.server_name()?; + + // look up certificate by subject + let contexts = self.0.find_by_subject_str(name).ok()?; + + // attempt to acquire a private key and construct CngSigningKey + let (context, key) = contexts.into_iter().find_map(|ctx| { + let key = ctx.acquire_key().ok()?; + CngSigningKey::new(key).ok().map(|key| (ctx, key)) + })?; + + trace!(key_algorithm_group = ?key.key().algorithm_group()); + trace!(key_algorithm = ?key.key().algorithm()); + + // attempt to acquire a full certificate chain + let chain = context.as_chain_der().ok()?; + let certs = chain.into_iter().map(Certificate).collect(); + + // return CertifiedKey instance + Some(Arc::new(CertifiedKey { + cert: certs, + key: Arc::new(key), + ocsp: None, + sct_list: None, + })) + } + } } pub mod sanity {