From b1e3351d7f3147afab9d29596d324d93e2198ba3 Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Fri, 26 Jul 2024 01:03:34 +0800 Subject: [PATCH] new extension for TLS cert selection (#32465) fix https://github.com/envoyproxy/envoy/issues/30600 Commit Message: Add an extension point to allow overriding TLS certificate selection behavior. An extension can select certificate base on the incoming SNI, in both sync and async mode. Signed-off-by: doujiang24 --- .../transport_sockets/tls/v3/tls.proto | 9 +- .../network/test/postgres_integration_test.cc | 2 +- envoy/ssl/context_config.h | 5 + envoy/ssl/context_manager.h | 3 - envoy/ssl/handshaker.h | 112 +++++ envoy/ssl/ssl_socket_extended_info.h | 29 ++ mobile/test/common/integration/test_server.cc | 4 +- mobile/test/performance/files_em_does_not_use | 1 + .../quic_server_transport_socket_factory.cc | 2 +- source/common/tls/BUILD | 3 + source/common/tls/context_config_impl.cc | 1 + source/common/tls/context_impl.cc | 2 +- source/common/tls/context_impl.h | 2 +- .../tls/default_tls_certificate_selector.cc | 269 ++++++++++++ .../tls/default_tls_certificate_selector.h | 85 ++++ .../common/tls/server_context_config_impl.cc | 34 +- .../common/tls/server_context_config_impl.h | 8 +- source/common/tls/server_context_impl.cc | 313 +++----------- source/common/tls/server_context_impl.h | 39 +- source/common/tls/ssl_handshaker.cc | 57 +++ source/common/tls/ssl_handshaker.h | 33 ++ source/common/tls/ssl_socket.cc | 9 + source/common/tls/ssl_socket.h | 1 + .../tls/downstream_config.cc | 2 +- .../grpc_client_integration_test_harness.h | 2 +- test/common/quic/BUILD | 1 + .../quic/envoy_quic_proof_source_test.cc | 11 + test/common/tls/BUILD | 59 +++ test/common/tls/cert_selector/BUILD | 24 ++ .../tls/cert_selector/async_cert_selector.cc | 60 +++ .../tls/cert_selector/async_cert_selector.h | 90 ++++ test/common/tls/cert_selector/stats.cc | 21 + test/common/tls/cert_selector/stats.h | 32 ++ test/common/tls/cert_validator/test_common.h | 15 + test/common/tls/context_impl_test.cc | 135 +++--- test/common/tls/handshaker_factory_test.cc | 2 +- test/common/tls/handshaker_test.cc | 1 + test/common/tls/integration/BUILD | 1 + .../tls/integration/ssl_integration_test.cc | 113 +++++ .../integration/ssl_integration_test_base.cc | 1 + .../integration/ssl_integration_test_base.h | 1 + test/common/tls/ssl_socket_test.cc | 42 +- .../tls/tls_certificate_selector_test.cc | 405 ++++++++++++++++++ test/config/utility.cc | 10 + test/config/utility.h | 6 + .../grpc/xds_failover_integration_test.cc | 2 +- .../http/router/auto_sni_integration_test.cc | 2 +- .../upstream_starttls_integration_test.cc | 2 +- .../alpn_selection_integration_test.cc | 2 +- test/integration/base_integration_test.cc | 4 +- .../sds_dynamic_integration_test.cc | 2 +- test/integration/ssl_utility.cc | 8 +- test/integration/xfcc_integration_test.cc | 4 +- test/mocks/ssl/mocks.h | 2 + test/per_file_coverage.sh | 2 +- 55 files changed, 1708 insertions(+), 379 deletions(-) create mode 100644 source/common/tls/default_tls_certificate_selector.cc create mode 100644 source/common/tls/default_tls_certificate_selector.h create mode 100644 test/common/tls/cert_selector/BUILD create mode 100644 test/common/tls/cert_selector/async_cert_selector.cc create mode 100644 test/common/tls/cert_selector/async_cert_selector.h create mode 100644 test/common/tls/cert_selector/stats.cc create mode 100644 test/common/tls/cert_selector/stats.h create mode 100644 test/common/tls/tls_certificate_selector_test.cc diff --git a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto index c22d580e6f5d..c305ff74f42a 100644 --- a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -163,7 +163,7 @@ message TlsKeyLog { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 16] +// [#next-free-field: 17] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; @@ -274,6 +274,13 @@ message CommonTlsContext { // [#not-implemented-hide:] CertificateProviderPluginInstance tls_certificate_provider_instance = 14; + // Custom TLS certificate selector. + // + // Select TLS certificate based on TLS client hello. + // If empty, defaults to native TLS certificate selection behavior: + // DNS SANs or Subject Common Name in TLS certificates is extracted as server name pattern to match SNI. + config.core.v3.TypedExtensionConfig custom_tls_certificate_selector = 16; + // Certificate provider for fetching TLS certificates. // [#not-implemented-hide:] CertificateProvider tls_certificate_certificate_provider = 9 diff --git a/contrib/postgres_proxy/filters/network/test/postgres_integration_test.cc b/contrib/postgres_proxy/filters/network/test/postgres_integration_test.cc index 3f3b2dd5a1c5..e94f720b33bc 100644 --- a/contrib/postgres_proxy/filters/network/test/postgres_integration_test.cc +++ b/contrib/postgres_proxy/filters/network/test/postgres_integration_test.cc @@ -318,7 +318,7 @@ class UpstreamSSLBaseIntegrationTest : public PostgresBaseIntegrationTest { NiceMock mock_factory_ctx; ON_CALL(mock_factory_ctx.server_context_, api()).WillByDefault(testing::ReturnRef(*api_)); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - downstream_tls_context, mock_factory_ctx); + downstream_tls_context, mock_factory_ctx, false); static auto* client_stats_store = new Stats::TestIsolatedStoreImpl(); Network::DownstreamTransportSocketFactoryPtr tls_context = Network::DownstreamTransportSocketFactoryPtr{ diff --git a/envoy/ssl/context_config.h b/envoy/ssl/context_config.h index a247c71533ad..53619a58805b 100644 --- a/envoy/ssl/context_config.h +++ b/envoy/ssl/context_config.h @@ -204,6 +204,11 @@ class ServerContextConfig : public virtual ContextConfig { * @return true if the client cipher preference is enabled, false otherwise. */ virtual bool preferClientCiphers() const PURE; + + /** + * @return a factory which can be used to create TLS context provider instances. + */ + virtual TlsCertificateSelectorFactory tlsCertificateSelectorFactory() const PURE; }; using ServerContextConfigPtr = std::unique_ptr; diff --git a/envoy/ssl/context_manager.h b/envoy/ssl/context_manager.h index af3837e9451c..8a7d2ff6736c 100644 --- a/envoy/ssl/context_manager.h +++ b/envoy/ssl/context_manager.h @@ -19,9 +19,6 @@ class CommonFactoryContext; namespace Ssl { -// Opaque type defined and used by the ``ServerContext``. -struct TlsContext; - using ContextAdditionalInitFunc = std::function; diff --git a/envoy/ssl/handshaker.h b/envoy/ssl/handshaker.h index 3cd4972c8508..30b6b07d1403 100644 --- a/envoy/ssl/handshaker.h +++ b/envoy/ssl/handshaker.h @@ -12,8 +12,20 @@ #include "openssl/ssl.h" namespace Envoy { + +namespace Server { +namespace Configuration { +class CommonFactoryContext; +} // namespace Configuration +} // namespace Server + namespace Ssl { +// Opaque type defined and used by the ``ServerContext``. +struct TlsContext; + +class ServerContextConfig; + class HandshakeCallbacks { public: virtual ~HandshakeCallbacks() = default; @@ -45,6 +57,12 @@ class HandshakeCallbacks { * asynchronous. */ virtual void onAsynchronousCertValidationComplete() PURE; + + /** + * A callback to be called upon certificate selection completion if the selection is + * asynchronous. + */ + virtual void onAsynchronousCertificateSelectionComplete() PURE; }; /** @@ -156,5 +174,99 @@ class HandshakerFactory : public Config::TypedFactory { virtual SslCtxCb sslctxCb(HandshakerFactoryContext& handshaker_factory_context) const PURE; }; +struct SelectionResult { + enum class SelectionStatus { + // A certificate was successfully selected. + Success, + // Certificate selection will complete asynchronously later. + Pending, + // Certificate selection failed. + Failed, + }; + SelectionStatus status; // Status of the certificate selection. + // Selected TLS context which it only be non-null when status is Success. + const Ssl::TlsContext* selected_ctx; + // True if OCSP stapling should be enabled. + bool staple; +}; + +/** + * Used to return the result from an asynchronous cert selection. + */ +class CertificateSelectionCallback { +public: + virtual ~CertificateSelectionCallback() = default; + + virtual Event::Dispatcher& dispatcher() PURE; + + /** + * Called when the asynchronous cert selection completes. + * @param selected_ctx selected Ssl::TlsContext, it's empty when selection failed. + * @param staple true when need to set OCSP response. + */ + virtual void onCertificateSelectionResult(OptRef selected_ctx, + bool staple) PURE; +}; + +using CertificateSelectionCallbackPtr = std::unique_ptr; + +enum class OcspStapleAction { Staple, NoStaple, Fail, ClientNotCapable }; + +class TlsCertificateSelector { +public: + virtual ~TlsCertificateSelector() = default; + + /** + * Select TLS context based on the client hello in non-QUIC TLS handshake. + * + * @return selected_ctx should only not be null when status is SelectionStatus::Success, and it + * will have the same lifetime as ``ServerContextImpl``. + */ + virtual SelectionResult selectTlsContext(const SSL_CLIENT_HELLO& ssl_client_hello, + CertificateSelectionCallbackPtr cb) PURE; + + /** + * Finds the best matching context in QUIC TLS handshake, which doesn't support async mode yet. + * + * @return context will have the same lifetime as ``ServerContextImpl``. + */ + virtual std::pair + findTlsContext(absl::string_view sni, bool client_ecdsa_capable, bool client_ocsp_capable, + bool* cert_matched_sni) PURE; +}; + +using TlsCertificateSelectorPtr = std::unique_ptr; + +class TlsCertificateSelectorContext { +public: + virtual ~TlsCertificateSelectorContext() = default; + + /** + * @return reference to the initialized Tls Contexts. + */ + virtual const std::vector& getTlsContexts() const PURE; +}; + +using TlsCertificateSelectorFactory = std::function; + +class TlsCertificateSelectorConfigFactory : public Config::TypedFactory { +public: + /** + * @param for_quic true when in quic context, which does not support selecting certificate + * asynchronously. + * @returns a factory to create a TlsCertificateSelector. Accepts the |config| and + * |validation_visitor| for early validation. This virtual base doesn't + * perform MessageUtil::downcastAndValidate, but an implementation should. + */ + virtual TlsCertificateSelectorFactory + createTlsCertificateSelectorFactory(const Protobuf::Message& config, + Server::Configuration::CommonFactoryContext& factory_context, + ProtobufMessage::ValidationVisitor& validation_visitor, + absl::Status& creation_status, bool for_quic) PURE; + + std::string category() const override { return "envoy.tls.certificate_selectors"; } +}; + } // namespace Ssl } // namespace Envoy diff --git a/envoy/ssl/ssl_socket_extended_info.h b/envoy/ssl/ssl_socket_extended_info.h index b26bc96ce851..192f395c204e 100644 --- a/envoy/ssl/ssl_socket_extended_info.h +++ b/envoy/ssl/ssl_socket_extended_info.h @@ -11,6 +11,9 @@ namespace Envoy { namespace Ssl { +// Opaque type defined and used by the ``ServerContext``. +struct TlsContext; + enum class ClientValidationStatus { NotValidated, NoClientCertificate, Validated, Failed }; enum class ValidateStatus { @@ -20,6 +23,13 @@ enum class ValidateStatus { Failed, }; +enum class CertificateSelectionStatus { + NotStarted, + Pending, + Successful, + Failed, +}; + /** * Used to return the result from an asynchronous cert validation. */ @@ -82,6 +92,25 @@ class SslExtendedSocketInfo { * case of failure. */ virtual uint8_t certificateValidationAlert() const PURE; + + /** + * @return CertificateSelectionCallbackPtr a callback used to return the cert selection result. + */ + virtual CertificateSelectionCallbackPtr createCertificateSelectionCallback() PURE; + + /** + * Called after the cert selection completes either synchronously or asynchronously. + * @param selected_ctx selected Ssl::TlsContext, it's empty when selection failed. + * @param async true if the validation is completed asynchronously. + * @param staple true when need to set OCSP response. + */ + virtual void onCertificateSelectionCompleted(OptRef selected_ctx, + bool staple, bool async) PURE; + + /** + * @return CertificateSelectionStatus the cert selection status. + */ + virtual CertificateSelectionStatus certificateSelectionResult() const PURE; }; } // namespace Ssl diff --git a/mobile/test/common/integration/test_server.cc b/mobile/test/common/integration/test_server.cc index 593859ee8354..c8727516a5eb 100644 --- a/mobile/test/common/integration/test_server.cc +++ b/mobile/test/common/integration/test_server.cc @@ -316,8 +316,8 @@ Network::DownstreamTransportSocketFactoryPtr TestServer::createUpstreamTlsContex ctx->mutable_trusted_ca()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/upstreamcacert.pem")); tls_context.mutable_common_tls_context()->add_alpn_protocols("h2"); - auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create(tls_context, - factory_context); + auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( + tls_context, factory_context, false); static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( std::move(cfg), context_manager_, *upstream_stats_store->rootScope(), diff --git a/mobile/test/performance/files_em_does_not_use b/mobile/test/performance/files_em_does_not_use index 2d52ec1d9c29..9d2c0e093dbd 100644 --- a/mobile/test/performance/files_em_does_not_use +++ b/mobile/test/performance/files_em_does_not_use @@ -17,3 +17,4 @@ source/common/router/vhds.h source/common/tls/server_context_config_impl.h source/common/tls/server_context_impl.h source/common/tls/ocsp/ocsp.h +source/common/tls/default_tls_certificate_selector.h diff --git a/source/common/quic/quic_server_transport_socket_factory.cc b/source/common/quic/quic_server_transport_socket_factory.cc index 8c636746c7cc..f7ee3199c84c 100644 --- a/source/common/quic/quic_server_transport_socket_factory.cc +++ b/source/common/quic/quic_server_transport_socket_factory.cc @@ -21,7 +21,7 @@ QuicServerTransportSocketConfigFactory::createTransportSocketFactory( config, context.messageValidationVisitor()); absl::StatusOr> server_config_or_error = Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - quic_transport.downstream_tls_context(), context); + quic_transport.downstream_tls_context(), context, true); RETURN_IF_NOT_OK(server_config_or_error.status()); auto server_config = std::move(server_config_or_error.value()); // TODO(RyanTheOptimist): support TLS client authentication. diff --git a/source/common/tls/BUILD b/source/common/tls/BUILD index bcbf3e40d803..ca18f49206ae 100644 --- a/source/common/tls/BUILD +++ b/source/common/tls/BUILD @@ -168,6 +168,7 @@ envoy_cc_library( hdrs = ["server_context_config_impl.h"], deps = [ ":context_config_lib", + ":server_context_lib", "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg_cc_proto", ], ) @@ -222,9 +223,11 @@ envoy_cc_library( envoy_cc_library( name = "server_context_lib", srcs = [ + "default_tls_certificate_selector.cc", "server_context_impl.cc", ], hdrs = [ + "default_tls_certificate_selector.h", "server_context_impl.h", ], deps = [ diff --git a/source/common/tls/context_config_impl.cc b/source/common/tls/context_config_impl.cc index 1a6675eb0662..3b5247eb848b 100644 --- a/source/common/tls/context_config_impl.cc +++ b/source/common/tls/context_config_impl.cc @@ -9,6 +9,7 @@ #include "source/common/common/empty_string.h" #include "source/common/config/datasource.h" #include "source/common/network/cidr_range.h" +#include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/common/secret/sds_api.h" #include "source/common/ssl/certificate_validation_context_config_impl.h" diff --git a/source/common/tls/context_impl.cc b/source/common/tls/context_impl.cc index 8e997df07399..fed9720a46a3 100644 --- a/source/common/tls/context_impl.cc +++ b/source/common/tls/context_impl.cc @@ -692,7 +692,7 @@ ValidationResults ContextImpl::customVerifyCertChainForQuic( namespace Ssl { -bool TlsContext::isCipherEnabled(uint16_t cipher_id, uint16_t client_version) { +bool TlsContext::isCipherEnabled(uint16_t cipher_id, uint16_t client_version) const { const SSL_CIPHER* c = SSL_get_cipher_by_value(cipher_id); if (c == nullptr) { return false; diff --git a/source/common/tls/context_impl.h b/source/common/tls/context_impl.h index dfeef9288243..080f1ecb11a8 100644 --- a/source/common/tls/context_impl.h +++ b/source/common/tls/context_impl.h @@ -57,7 +57,7 @@ struct TlsContext { #endif std::string getCertChainFileName() const { return cert_chain_file_path_; }; - bool isCipherEnabled(uint16_t cipher_id, uint16_t client_version); + bool isCipherEnabled(uint16_t cipher_id, uint16_t client_version) const; Envoy::Ssl::PrivateKeyMethodProviderSharedPtr getPrivateKeyMethodProvider() { return private_key_method_provider_; } diff --git a/source/common/tls/default_tls_certificate_selector.cc b/source/common/tls/default_tls_certificate_selector.cc new file mode 100644 index 000000000000..61dac47be2e5 --- /dev/null +++ b/source/common/tls/default_tls_certificate_selector.cc @@ -0,0 +1,269 @@ +#include "source/common/tls/default_tls_certificate_selector.h" + +#include "source/common/tls/utility.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +DefaultTlsCertificateSelector::DefaultTlsCertificateSelector( + const Ssl::ServerContextConfig& config, Ssl::TlsCertificateSelectorContext& selector_ctx) + : server_ctx_(dynamic_cast(selector_ctx)), + tls_contexts_(selector_ctx.getTlsContexts()), ocsp_staple_policy_(config.ocspStaplePolicy()), + full_scan_certs_on_sni_mismatch_(config.fullScanCertsOnSNIMismatch()) { + for (auto& ctx : tls_contexts_) { + if (ctx.cert_chain_ == nullptr) { + continue; + } + bssl::UniquePtr public_key(X509_get_pubkey(ctx.cert_chain_.get())); + const int pkey_id = EVP_PKEY_id(public_key.get()); + // Load DNS SAN entries and Subject Common Name as server name patterns after certificate + // chain loaded, and populate ServerNamesMap which will be used to match SNI. + has_rsa_ |= (pkey_id == EVP_PKEY_RSA); + populateServerNamesMap(ctx, pkey_id); + } +}; + +void DefaultTlsCertificateSelector::populateServerNamesMap(const Ssl::TlsContext& ctx, + int pkey_id) { + if (ctx.cert_chain_ == nullptr) { + return; + } + + auto populate = [&](const std::string& sn) { + std::string sn_pattern = sn; + if (absl::StartsWith(sn, "*.")) { + sn_pattern = sn.substr(1); + } + PkeyTypesMap pkey_types_map; + // Multiple certs with different key type are allowed for one server name pattern. + auto sn_match = server_names_map_.try_emplace(sn_pattern, pkey_types_map).first; + auto pt_match = sn_match->second.find(pkey_id); + if (pt_match != sn_match->second.end()) { + // When there are duplicate names, prefer the earlier one. + // + // If all of the SANs in a certificate are unused due to duplicates, it could be useful + // to issue a warning, but that would require additional tracking that hasn't been + // implemented. + return; + } + sn_match->second.emplace( + std::pair>(pkey_id, ctx)); + }; + + bssl::UniquePtr san_names(static_cast( + X509_get_ext_d2i(ctx.cert_chain_.get(), NID_subject_alt_name, nullptr, nullptr))); + if (san_names != nullptr) { + auto dns_sans = Utility::getSubjectAltNames(*ctx.cert_chain_, GEN_DNS); + // https://www.rfc-editor.org/rfc/rfc6066#section-3 + // Currently, the only server names supported are DNS hostnames, so we + // only save dns san entries to match SNI. + for (const auto& san : dns_sans) { + populate(san); + } + } else { + // https://www.rfc-editor.org/rfc/rfc6125#section-6.4.4 + // As noted, a client MUST NOT seek a match for a reference identifier + // of CN-ID if the presented identifiers include a DNS-ID, SRV-ID, + // URI-ID, or any application-specific identifier types supported by the + // client. + X509_NAME* cert_subject = X509_get_subject_name(ctx.cert_chain_.get()); + const int cn_index = X509_NAME_get_index_by_NID(cert_subject, NID_commonName, -1); + if (cn_index >= 0) { + X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); + if (cn_entry) { + ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); + if (ASN1_STRING_length(cn_asn1) > 0) { + std::string subject_cn(reinterpret_cast(ASN1_STRING_data(cn_asn1)), + ASN1_STRING_length(cn_asn1)); + populate(subject_cn); + } + } + } + } +} + +Ssl::SelectionResult +DefaultTlsCertificateSelector::selectTlsContext(const SSL_CLIENT_HELLO& ssl_client_hello, + Ssl::CertificateSelectionCallbackPtr) { + absl::string_view sni = + absl::NullSafeStringView(SSL_get_servername(ssl_client_hello.ssl, TLSEXT_NAMETYPE_host_name)); + const bool client_ecdsa_capable = server_ctx_.isClientEcdsaCapable(ssl_client_hello); + const bool client_ocsp_capable = server_ctx_.isClientOcspCapable(ssl_client_hello); + + auto [selected_ctx, ocsp_staple_action] = + findTlsContext(sni, client_ecdsa_capable, client_ocsp_capable, nullptr); + + auto stats = server_ctx_.stats(); + if (client_ocsp_capable) { + stats.ocsp_staple_requests_.inc(); + } + + switch (ocsp_staple_action) { + case Ssl::OcspStapleAction::Staple: + stats.ocsp_staple_responses_.inc(); + break; + case Ssl::OcspStapleAction::NoStaple: + stats.ocsp_staple_omitted_.inc(); + break; + case Ssl::OcspStapleAction::Fail: + stats.ocsp_staple_failed_.inc(); + return {Ssl::SelectionResult::SelectionStatus::Failed, nullptr, false}; + case Ssl::OcspStapleAction::ClientNotCapable: + // This happens when client does not support OCSP, do nothing. + break; + } + + return {Ssl::SelectionResult::SelectionStatus::Success, &selected_ctx, + ocsp_staple_action == Ssl::OcspStapleAction::Staple}; +} + +Ssl::OcspStapleAction DefaultTlsCertificateSelector::ocspStapleAction(const Ssl::TlsContext& ctx, + bool client_ocsp_capable) { + if (!client_ocsp_capable) { + return Ssl::OcspStapleAction::ClientNotCapable; + } + + auto& response = ctx.ocsp_response_; + + auto policy = ocsp_staple_policy_; + if (ctx.is_must_staple_) { + // The certificate has the must-staple extension, so upgrade the policy to match. + policy = Ssl::ServerContextConfig::OcspStaplePolicy::MustStaple; + } + + const bool valid_response = response && !response->isExpired(); + + switch (policy) { + case Ssl::ServerContextConfig::OcspStaplePolicy::LenientStapling: + if (!valid_response) { + return Ssl::OcspStapleAction::NoStaple; + } + return Ssl::OcspStapleAction::Staple; + + case Ssl::ServerContextConfig::OcspStaplePolicy::StrictStapling: + if (valid_response) { + return Ssl::OcspStapleAction::Staple; + } + if (response) { + // Expired response. + return Ssl::OcspStapleAction::Fail; + } + return Ssl::OcspStapleAction::NoStaple; + + case Ssl::ServerContextConfig::OcspStaplePolicy::MustStaple: + if (!valid_response) { + return Ssl::OcspStapleAction::Fail; + } + return Ssl::OcspStapleAction::Staple; + } + PANIC_DUE_TO_CORRUPT_ENUM; +} + +std::pair +DefaultTlsCertificateSelector::findTlsContext(absl::string_view sni, bool client_ecdsa_capable, + bool client_ocsp_capable, bool* cert_matched_sni) { + bool unused = false; + if (cert_matched_sni == nullptr) { + // Avoid need for nullptr checks when this is set. + cert_matched_sni = &unused; + } + + // selected_ctx represents the final selected certificate, it should meet all requirements or pick + // a candidate. + const Ssl::TlsContext* selected_ctx = nullptr; + const Ssl::TlsContext* candidate_ctx = nullptr; + Ssl::OcspStapleAction ocsp_staple_action; + + auto selected = [&](const Ssl::TlsContext& ctx) -> bool { + auto action = ocspStapleAction(ctx, client_ocsp_capable); + if (action == Ssl::OcspStapleAction::Fail) { + // The selected ctx must adhere to OCSP policy + return false; + } + + if (client_ecdsa_capable == ctx.is_ecdsa_) { + selected_ctx = &ctx; + ocsp_staple_action = action; + return true; + } + + if (client_ecdsa_capable && !ctx.is_ecdsa_ && candidate_ctx == nullptr) { + // ECDSA cert is preferred if client is ECDSA capable, so RSA cert is marked as a candidate, + // searching will continue until exhausting all certs or find a exact match. + candidate_ctx = &ctx; + ocsp_staple_action = action; + return false; + } + + return false; + }; + + auto select_from_map = [this, &selected](absl::string_view server_name) -> void { + auto it = server_names_map_.find(server_name); + if (it == server_names_map_.end()) { + return; + } + const auto& pkey_types_map = it->second; + for (const auto& entry : pkey_types_map) { + if (selected(entry.second.get())) { + break; + } + } + }; + + auto tail_select = [&](bool go_to_next_phase) { + if (selected_ctx == nullptr) { + selected_ctx = candidate_ctx; + } + + if (selected_ctx == nullptr && !go_to_next_phase) { + selected_ctx = &tls_contexts_[0]; + ocsp_staple_action = ocspStapleAction(*selected_ctx, client_ocsp_capable); + } + }; + + // Select cert based on SNI if SNI is provided by client. + if (!sni.empty()) { + // Match on exact server name, i.e. "www.example.com" for "www.example.com". + select_from_map(sni); + tail_select(true); + + if (selected_ctx == nullptr) { + // Match on wildcard domain, i.e. ".example.com" for "www.example.com". + // https://datatracker.ietf.org/doc/html/rfc6125#section-6.4 + size_t pos = sni.find('.', 1); + if (pos < sni.size() - 1 && pos != std::string::npos) { + absl::string_view wildcard = sni.substr(pos); + select_from_map(wildcard); + } + } + *cert_matched_sni = (selected_ctx != nullptr || candidate_ctx != nullptr); + // tail_select(full_scan_certs_on_sni_mismatch_); + tail_select(full_scan_certs_on_sni_mismatch_); + } + // Full scan certs if SNI is not provided by client; + // Full scan certs if client provides SNI but no cert matches to it, + // it requires full_scan_certs_on_sni_mismatch is enabled. + if (selected_ctx == nullptr) { + candidate_ctx = nullptr; + // Skip loop when there is no cert compatible to key type + if (client_ecdsa_capable || (!client_ecdsa_capable && has_rsa_)) { + for (const auto& ctx : tls_contexts_) { + if (selected(ctx)) { + break; + } + } + } + tail_select(false); + } + + ASSERT(selected_ctx != nullptr); + return {*selected_ctx, ocsp_staple_action}; +} + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/common/tls/default_tls_certificate_selector.h b/source/common/tls/default_tls_certificate_selector.h new file mode 100644 index 000000000000..6a462c55301a --- /dev/null +++ b/source/common/tls/default_tls_certificate_selector.h @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include "envoy/ssl/handshaker.h" + +#include "source/common/tls/context_impl.h" +#include "source/common/tls/server_context_impl.h" +#include "source/common/tls/stats.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +// Defined in server_context_impl.h +class ServerContextImpl; + +/** + * The default TLS context provider, selecting certificate based on SNI. + */ +class DefaultTlsCertificateSelector : public Ssl::TlsCertificateSelector, + protected Logger::Loggable { +public: + DefaultTlsCertificateSelector(const Ssl::ServerContextConfig& config, + Ssl::TlsCertificateSelectorContext& selector_ctx); + + Ssl::SelectionResult selectTlsContext(const SSL_CLIENT_HELLO& ssl_client_hello, + Ssl::CertificateSelectionCallbackPtr cb) override; + + // Finds the best matching context. The returned context will have the same lifetime as + // ``ServerContextImpl``. + std::pair + findTlsContext(absl::string_view sni, bool client_ecdsa_capable, bool client_ocsp_capable, + bool* cert_matched_sni) override; + +private: + // Currently, at most one certificate of a given key type may be specified for each exact + // server name or wildcard domain name. + using PkeyTypesMap = absl::flat_hash_map>; + // Both exact server names and wildcard domains are part of the same map, in which wildcard + // domains are prefixed with "." (i.e. ".example.com" for "*.example.com") to differentiate + // between exact and wildcard entries. + using ServerNamesMap = absl::flat_hash_map; + + void populateServerNamesMap(const Ssl::TlsContext& ctx, const int pkey_id); + + Ssl::OcspStapleAction ocspStapleAction(const Ssl::TlsContext& ctx, bool client_ocsp_capable); + + // ServerContext own this selector, it's safe to use itself here. + ServerContextImpl& server_ctx_; + const std::vector& tls_contexts_; + + ServerNamesMap server_names_map_; + bool has_rsa_{false}; + + const Ssl::ServerContextConfig::OcspStaplePolicy ocsp_staple_policy_; + bool full_scan_certs_on_sni_mismatch_; +}; + +class TlsCertificateSelectorConfigFactoryImpl : public Ssl::TlsCertificateSelectorConfigFactory { +public: + std::string name() const override { return "envoy.tls.certificate_selectors.default"; } + Ssl::TlsCertificateSelectorFactory createTlsCertificateSelectorFactory( + const Protobuf::Message&, Server::Configuration::CommonFactoryContext&, + ProtobufMessage::ValidationVisitor&, absl::Status&, bool) override { + return [](const Ssl::ServerContextConfig& config, + Ssl::TlsCertificateSelectorContext& selector_ctx) { + return std::make_unique(config, selector_ctx); + }; + } + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + static Ssl::TlsCertificateSelectorConfigFactory* getDefaultTlsCertificateSelectorConfigFactory() { + static TlsCertificateSelectorConfigFactoryImpl default_tls_certificate_selector_config_factory; + return &default_tls_certificate_selector_config_factory; + } +}; + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/common/tls/server_context_config_impl.cc b/source/common/tls/server_context_config_impl.cc index e40ebd9d15ce..1a61f090332f 100644 --- a/source/common/tls/server_context_config_impl.cc +++ b/source/common/tls/server_context_config_impl.cc @@ -9,9 +9,11 @@ #include "source/common/common/empty_string.h" #include "source/common/config/datasource.h" #include "source/common/network/cidr_range.h" +#include "source/common/protobuf/message_validator_impl.h" #include "source/common/protobuf/utility.h" #include "source/common/secret/sds_api.h" #include "source/common/ssl/certificate_validation_context_config_impl.h" +#include "source/common/tls/default_tls_certificate_selector.h" #include "source/common/tls/ssl_handshaker.h" #include "openssl/ssl.h" @@ -101,10 +103,10 @@ const std::string ServerContextConfigImpl::DEFAULT_CURVES = absl::StatusOr> ServerContextConfigImpl::create( const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext& config, - Server::Configuration::TransportSocketFactoryContext& secret_provider_context) { + Server::Configuration::TransportSocketFactoryContext& secret_provider_context, bool for_quic) { absl::Status creation_status = absl::OkStatus(); std::unique_ptr ret = absl::WrapUnique( - new ServerContextConfigImpl(config, secret_provider_context, creation_status)); + new ServerContextConfigImpl(config, secret_provider_context, creation_status, for_quic)); RETURN_IF_NOT_OK(creation_status); return ret; } @@ -112,7 +114,7 @@ absl::StatusOr> ServerContextConfigImpl ServerContextConfigImpl::ServerContextConfigImpl( const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext& config, Server::Configuration::TransportSocketFactoryContext& factory_context, - absl::Status& creation_status) + absl::Status& creation_status, bool for_quic) : ContextConfigImpl(config.common_tls_context(), DEFAULT_MIN_VERSION, DEFAULT_MAX_VERSION, DEFAULT_CIPHER_SUITES, DEFAULT_CURVES, factory_context, creation_status), require_client_certificate_( @@ -156,6 +158,25 @@ ServerContextConfigImpl::ServerContextConfigImpl( session_timeout_ = std::chrono::seconds(DurationUtil::durationToSeconds(config.session_timeout())); } + + if (config.common_tls_context().has_custom_tls_certificate_selector()) { + // If a custom tls context provider is configured, derive the factory from the config. + const auto& provider_config = config.common_tls_context().custom_tls_certificate_selector(); + Ssl::TlsCertificateSelectorConfigFactory* provider_factory = + &Config::Utility::getAndCheckFactory( + provider_config); + tls_certificate_selector_factory_ = provider_factory->createTlsCertificateSelectorFactory( + provider_config.typed_config(), factory_context.serverFactoryContext(), + factory_context.messageValidationVisitor(), creation_status, for_quic); + return; + } + + auto factory = + TlsCertificateSelectorConfigFactoryImpl::getDefaultTlsCertificateSelectorConfigFactory(); + const ProtobufWkt::Any any; + tls_certificate_selector_factory_ = factory->createTlsCertificateSelectorFactory( + any, factory_context.serverFactoryContext(), ProtobufMessage::getNullValidationVisitor(), + creation_status, for_quic); } void ServerContextConfigImpl::setSecretUpdateCallback(std::function callback) { @@ -230,6 +251,13 @@ Ssl::ServerContextConfig::OcspStaplePolicy ServerContextConfigImpl::ocspStaplePo PANIC_DUE_TO_CORRUPT_ENUM; } +Ssl::TlsCertificateSelectorFactory ServerContextConfigImpl::tlsCertificateSelectorFactory() const { + if (!tls_certificate_selector_factory_) { + IS_ENVOY_BUG("No envoy.tls.certificate_selectors registered"); + } + return tls_certificate_selector_factory_; +} + } // namespace Tls } // namespace TransportSockets } // namespace Extensions diff --git a/source/common/tls/server_context_config_impl.h b/source/common/tls/server_context_config_impl.h index b11ddba70802..5bb69aab69b4 100644 --- a/source/common/tls/server_context_config_impl.h +++ b/source/common/tls/server_context_config_impl.h @@ -14,7 +14,8 @@ class ServerContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::Ser public: static absl::StatusOr> create(const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext& config, - Server::Configuration::TransportSocketFactoryContext& secret_provider_context); + Server::Configuration::TransportSocketFactoryContext& secret_provider_context, + bool for_quic); // Ssl::ServerContextConfig bool requireClientCertificate() const override { return require_client_certificate_; } @@ -42,11 +43,13 @@ class ServerContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::Ser bool fullScanCertsOnSNIMismatch() const override { return full_scan_certs_on_sni_mismatch_; } bool preferClientCiphers() const override { return prefer_client_ciphers_; } + Ssl::TlsCertificateSelectorFactory tlsCertificateSelectorFactory() const override; + private: ServerContextConfigImpl( const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext& config, Server::Configuration::TransportSocketFactoryContext& secret_provider_context, - absl::Status& creation_status); + absl::Status& creation_status, bool for_quic); static const unsigned DEFAULT_MIN_VERSION; static const unsigned DEFAULT_MAX_VERSION; @@ -68,6 +71,7 @@ class ServerContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::Ser const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext::OcspStaplePolicy& policy); + Ssl::TlsCertificateSelectorFactory tls_certificate_selector_factory_; absl::optional session_timeout_; const bool disable_stateless_session_resumption_; const bool disable_stateful_session_resumption_; diff --git a/source/common/tls/server_context_impl.cc b/source/common/tls/server_context_impl.cc index b7ac27c718a6..7b76f958f138 100644 --- a/source/common/tls/server_context_impl.cc +++ b/source/common/tls/server_context_impl.cc @@ -96,29 +96,19 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, absl::Status& creation_status) : ContextImpl(scope, config, factory_context, additional_init, creation_status), session_ticket_keys_(config.sessionTicketKeys()), - ocsp_staple_policy_(config.ocspStaplePolicy()), - full_scan_certs_on_sni_mismatch_(config.fullScanCertsOnSNIMismatch()) { + ocsp_staple_policy_(config.ocspStaplePolicy()) { if (!creation_status.ok()) { return; } + // If creation failed, do not create the selector. + tls_certificate_selector_ = config.tlsCertificateSelectorFactory()(config, *this); + if (config.tlsCertificates().empty() && !config.capabilities().provides_certificates) { creation_status = absl::InvalidArgumentError("Server TlsCertificates must have a certificate specified"); return; } - for (auto& ctx : tls_contexts_) { - if (ctx.cert_chain_ == nullptr) { - continue; - } - bssl::UniquePtr public_key(X509_get_pubkey(ctx.cert_chain_.get())); - const int pkey_id = EVP_PKEY_id(public_key.get()); - // Load DNS SAN entries and Subject Common Name as server name patterns after certificate - // chain loaded, and populate ServerNamesMap which will be used to match SNI. - has_rsa_ |= (pkey_id == EVP_PKEY_RSA); - populateServerNamesMap(ctx, pkey_id); - } - // Compute the session context ID hash. We use all the certificate identities, // since we should have a common ID for session resumption no matter what cert // is used. We do this early because it can fail. @@ -221,63 +211,6 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, } } -void ServerContextImpl::populateServerNamesMap(Ssl::TlsContext& ctx, int pkey_id) { - if (ctx.cert_chain_ == nullptr) { - return; - } - - auto populate = [&](const std::string& sn) { - std::string sn_pattern = sn; - if (absl::StartsWith(sn, "*.")) { - sn_pattern = sn.substr(1); - } - PkeyTypesMap pkey_types_map; - // Multiple certs with different key type are allowed for one server name pattern. - auto sn_match = server_names_map_.try_emplace(sn_pattern, pkey_types_map).first; - auto pt_match = sn_match->second.find(pkey_id); - if (pt_match != sn_match->second.end()) { - // When there are duplicate names, prefer the earlier one. - // - // If all of the SANs in a certificate are unused due to duplicates, it could be useful - // to issue a warning, but that would require additional tracking that hasn't been - // implemented. - return; - } - sn_match->second.emplace(std::pair>(pkey_id, ctx)); - }; - - bssl::UniquePtr san_names(static_cast( - X509_get_ext_d2i(ctx.cert_chain_.get(), NID_subject_alt_name, nullptr, nullptr))); - if (san_names != nullptr) { - auto dns_sans = Utility::getSubjectAltNames(*ctx.cert_chain_, GEN_DNS); - // https://www.rfc-editor.org/rfc/rfc6066#section-3 - // Currently, the only server names supported are DNS hostnames, so we - // only save dns san entries to match SNI. - for (const auto& san : dns_sans) { - populate(san); - } - } else { - // https://www.rfc-editor.org/rfc/rfc6125#section-6.4.4 - // As noted, a client MUST NOT seek a match for a reference identifier - // of CN-ID if the presented identifiers include a DNS-ID, SRV-ID, - // URI-ID, or any application-specific identifier types supported by the - // client. - X509_NAME* cert_subject = X509_get_subject_name(ctx.cert_chain_.get()); - const int cn_index = X509_NAME_get_index_by_NID(cert_subject, NID_commonName, -1); - if (cn_index >= 0) { - X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); - if (cn_entry) { - ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); - if (ASN1_STRING_length(cn_asn1) > 0) { - std::string subject_cn(reinterpret_cast(ASN1_STRING_data(cn_asn1)), - ASN1_STRING_length(cn_asn1)); - populate(subject_cn); - } - } - } - } -} - absl::StatusOr ServerContextImpl::generateHashForSessionContextId(const std::vector& server_names) { uint8_t hash_buffer[EVP_MAX_MD_SIZE]; @@ -443,23 +376,23 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv } } -bool ServerContextImpl::isClientEcdsaCapable(const SSL_CLIENT_HELLO* ssl_client_hello) { +bool ServerContextImpl::isClientEcdsaCapable(const SSL_CLIENT_HELLO& ssl_client_hello) const { CBS client_hello; - CBS_init(&client_hello, ssl_client_hello->client_hello, ssl_client_hello->client_hello_len); + CBS_init(&client_hello, ssl_client_hello.client_hello, ssl_client_hello.client_hello_len); // This is the TLSv1.3 case (TLSv1.2 on the wire and the supported_versions extensions present). // We just need to look at signature algorithms. - const uint16_t client_version = ssl_client_hello->version; + const uint16_t client_version = ssl_client_hello.version; if (client_version == TLS1_2_VERSION && tls_max_version_ == TLS1_3_VERSION) { // If the supported_versions extension is found then we assume that the client is competent // enough that just checking the signature_algorithms is sufficient. const uint8_t* supported_versions_data; size_t supported_versions_len; - if (SSL_early_callback_ctx_extension_get(ssl_client_hello, TLSEXT_TYPE_supported_versions, + if (SSL_early_callback_ctx_extension_get(&ssl_client_hello, TLSEXT_TYPE_supported_versions, &supported_versions_data, &supported_versions_len)) { const uint8_t* signature_algorithms_data; size_t signature_algorithms_len; - if (SSL_early_callback_ctx_extension_get(ssl_client_hello, TLSEXT_TYPE_signature_algorithms, + if (SSL_early_callback_ctx_extension_get(&ssl_client_hello, TLSEXT_TYPE_signature_algorithms, &signature_algorithms_data, &signature_algorithms_len)) { CBS signature_algorithms_ext, signature_algorithms; @@ -481,7 +414,7 @@ bool ServerContextImpl::isClientEcdsaCapable(const SSL_CLIENT_HELLO* ssl_client_ // ECDSA and also for a compatible cipher suite. https://tools.ietf.org/html/rfc4492#section-5.1.1 const uint8_t* curvelist_data; size_t curvelist_len; - if (!SSL_early_callback_ctx_extension_get(ssl_client_hello, TLSEXT_TYPE_supported_groups, + if (!SSL_early_callback_ctx_extension_get(&ssl_client_hello, TLSEXT_TYPE_supported_groups, &curvelist_data, &curvelist_len)) { return false; } @@ -496,7 +429,7 @@ bool ServerContextImpl::isClientEcdsaCapable(const SSL_CLIENT_HELLO* ssl_client_ // The client must have offered an ECDSA ciphersuite that we like. CBS cipher_suites; - CBS_init(&cipher_suites, ssl_client_hello->cipher_suites, ssl_client_hello->cipher_suites_len); + CBS_init(&cipher_suites, ssl_client_hello.cipher_suites, ssl_client_hello.cipher_suites_len); while (CBS_len(&cipher_suites) > 0) { uint16_t cipher_id; @@ -513,10 +446,10 @@ bool ServerContextImpl::isClientEcdsaCapable(const SSL_CLIENT_HELLO* ssl_client_ return false; } -bool ServerContextImpl::isClientOcspCapable(const SSL_CLIENT_HELLO* ssl_client_hello) { +bool ServerContextImpl::isClientOcspCapable(const SSL_CLIENT_HELLO& ssl_client_hello) const { const uint8_t* status_request_data; size_t status_request_len; - if (SSL_early_callback_ctx_extension_get(ssl_client_hello, TLSEXT_TYPE_status_request, + if (SSL_early_callback_ctx_extension_get(&ssl_client_hello, TLSEXT_TYPE_status_request, &status_request_data, &status_request_len)) { return true; } @@ -524,190 +457,66 @@ bool ServerContextImpl::isClientOcspCapable(const SSL_CLIENT_HELLO* ssl_client_h return false; } -OcspStapleAction ServerContextImpl::ocspStapleAction(const Ssl::TlsContext& ctx, - bool client_ocsp_capable) { - if (!client_ocsp_capable) { - return OcspStapleAction::ClientNotCapable; - } - - auto& response = ctx.ocsp_response_; - - auto policy = ocsp_staple_policy_; - if (ctx.is_must_staple_) { - // The certificate has the must-staple extension, so upgrade the policy to match. - policy = Ssl::ServerContextConfig::OcspStaplePolicy::MustStaple; - } - - const bool valid_response = response && !response->isExpired(); - - switch (policy) { - case Ssl::ServerContextConfig::OcspStaplePolicy::LenientStapling: - if (!valid_response) { - return OcspStapleAction::NoStaple; - } - return OcspStapleAction::Staple; - - case Ssl::ServerContextConfig::OcspStaplePolicy::StrictStapling: - if (valid_response) { - return OcspStapleAction::Staple; - } - if (response) { - // Expired response. - return OcspStapleAction::Fail; - } - return OcspStapleAction::NoStaple; - - case Ssl::ServerContextConfig::OcspStaplePolicy::MustStaple: - if (!valid_response) { - return OcspStapleAction::Fail; - } - return OcspStapleAction::Staple; - } - PANIC_DUE_TO_CORRUPT_ENUM; -} - -std::pair +std::pair ServerContextImpl::findTlsContext(absl::string_view sni, bool client_ecdsa_capable, bool client_ocsp_capable, bool* cert_matched_sni) { - bool unused = false; - if (cert_matched_sni == nullptr) { - // Avoid need for nullptr checks when this is set. - cert_matched_sni = &unused; - } - - // selected_ctx represents the final selected certificate, it should meet all requirements or pick - // a candidate. - const Ssl::TlsContext* selected_ctx = nullptr; - const Ssl::TlsContext* candidate_ctx = nullptr; - OcspStapleAction ocsp_staple_action; - - auto selected = [&](const Ssl::TlsContext& ctx) -> bool { - auto action = ocspStapleAction(ctx, client_ocsp_capable); - if (action == OcspStapleAction::Fail) { - // The selected ctx must adhere to OCSP policy - return false; - } - - if (client_ecdsa_capable == ctx.is_ecdsa_) { - selected_ctx = &ctx; - ocsp_staple_action = action; - return true; - } - - if (client_ecdsa_capable && !ctx.is_ecdsa_ && candidate_ctx == nullptr) { - // ECDSA cert is preferred if client is ECDSA capable, so RSA cert is marked as a candidate, - // searching will continue until exhausting all certs or find a exact match. - candidate_ctx = &ctx; - ocsp_staple_action = action; - return false; - } + return tls_certificate_selector_->findTlsContext(sni, client_ecdsa_capable, client_ocsp_capable, + cert_matched_sni); +} - return false; - }; +enum ssl_select_cert_result_t +ServerContextImpl::selectTlsContext(const SSL_CLIENT_HELLO* ssl_client_hello) { + ASSERT(tls_certificate_selector_ != nullptr); - auto select_from_map = [this, &selected](absl::string_view server_name) -> void { - auto it = server_names_map_.find(server_name); - if (it == server_names_map_.end()) { - return; - } - const auto& pkey_types_map = it->second; - for (const auto& entry : pkey_types_map) { - if (selected(entry.second.get())) { - break; - } - } - }; + auto* extended_socket_info = reinterpret_cast( + SSL_get_ex_data(ssl_client_hello->ssl, ContextImpl::sslExtendedSocketInfoIndex())); - auto tail_select = [&](bool go_to_next_phase) { - if (selected_ctx == nullptr) { - selected_ctx = candidate_ctx; - } + auto selection_result = extended_socket_info->certificateSelectionResult(); + switch (selection_result) { + case Ssl::CertificateSelectionStatus::NotStarted: + // continue + break; - if (selected_ctx == nullptr && !go_to_next_phase) { - selected_ctx = &tls_contexts_[0]; - ocsp_staple_action = ocspStapleAction(*selected_ctx, client_ocsp_capable); - } - }; - - // Select cert based on SNI if SNI is provided by client. - if (!sni.empty()) { - // Match on exact server name, i.e. "www.example.com" for "www.example.com". - select_from_map(sni); - tail_select(true); - - if (selected_ctx == nullptr) { - // Match on wildcard domain, i.e. ".example.com" for "www.example.com". - // https://datatracker.ietf.org/doc/html/rfc6125#section-6.4 - size_t pos = sni.find('.', 1); - if (pos < sni.size() - 1 && pos != std::string::npos) { - absl::string_view wildcard = sni.substr(pos); - select_from_map(wildcard); - } - } - *cert_matched_sni = (selected_ctx != nullptr || candidate_ctx != nullptr); - tail_select(full_scan_certs_on_sni_mismatch_); - } - // Full scan certs if SNI is not provided by client; - // Full scan certs if client provides SNI but no cert matches to it, - // it requires full_scan_certs_on_sni_mismatch is enabled. - if (selected_ctx == nullptr) { - candidate_ctx = nullptr; - // Skip loop when there is no cert compatible to key type - if (client_ecdsa_capable || (!client_ecdsa_capable && has_rsa_)) { - for (const auto& ctx : tls_contexts_) { - if (selected(ctx)) { - break; - } - } - } - tail_select(false); - } + case Ssl::CertificateSelectionStatus::Pending: + ENVOY_LOG(trace, "already waiting certificate"); + return ssl_select_cert_retry; - ASSERT(selected_ctx != nullptr); - return {*selected_ctx, ocsp_staple_action}; -} + case Ssl::CertificateSelectionStatus::Successful: + ENVOY_LOG(trace, "wait certificate success"); + return ssl_select_cert_success; -enum ssl_select_cert_result_t -ServerContextImpl::selectTlsContext(const SSL_CLIENT_HELLO* ssl_client_hello) { - absl::string_view sni = absl::NullSafeStringView( - SSL_get_servername(ssl_client_hello->ssl, TLSEXT_NAMETYPE_host_name)); - const bool client_ecdsa_capable = isClientEcdsaCapable(ssl_client_hello); - const bool client_ocsp_capable = isClientOcspCapable(ssl_client_hello); - - auto [selected_ctx, ocsp_staple_action] = - findTlsContext(sni, client_ecdsa_capable, client_ocsp_capable, nullptr); - - // Apply the selected context. This must be done before OCSP stapling below - // since applying the context can remove the previously-set OCSP response. - // This will only return NULL if memory allocation fails. - RELEASE_ASSERT(SSL_set_SSL_CTX(ssl_client_hello->ssl, selected_ctx.ssl_ctx_.get()) != nullptr, - ""); - - if (client_ocsp_capable) { - stats_.ocsp_staple_requests_.inc(); + default: + ENVOY_LOG(trace, "wait certificate failed"); + return ssl_select_cert_error; } - switch (ocsp_staple_action) { - case OcspStapleAction::Staple: { - // We avoid setting the OCSP response if the client didn't request it, but doing so is safe. - RELEASE_ASSERT(selected_ctx.ocsp_response_, - "OCSP response must be present under OcspStapleAction::Staple"); - auto& resp_bytes = selected_ctx.ocsp_response_->rawBytes(); - int rc = SSL_set_ocsp_response(ssl_client_hello->ssl, resp_bytes.data(), resp_bytes.size()); - RELEASE_ASSERT(rc != 0, ""); - stats_.ocsp_staple_responses_.inc(); - } break; - case OcspStapleAction::NoStaple: - stats_.ocsp_staple_omitted_.inc(); - break; - case OcspStapleAction::Fail: - stats_.ocsp_staple_failed_.inc(); + ENVOY_LOG(trace, "TLS context selection result: {}, before selectTlsContext", + static_cast(selection_result)); + + const auto result = tls_certificate_selector_->selectTlsContext( + *ssl_client_hello, extended_socket_info->createCertificateSelectionCallback()); + + ENVOY_LOG(trace, + "TLS context selection result: {}, after selectTlsContext, selection result status: {}", + static_cast(extended_socket_info->certificateSelectionResult()), + static_cast(result.status)); + ASSERT(extended_socket_info->certificateSelectionResult() == + Ssl::CertificateSelectionStatus::Pending, + "invalid selection result"); + + switch (result.status) { + case Ssl::SelectionResult::SelectionStatus::Success: + extended_socket_info->onCertificateSelectionCompleted(*result.selected_ctx, result.staple, + false); + return ssl_select_cert_success; + case Ssl::SelectionResult::SelectionStatus::Pending: + return ssl_select_cert_retry; + case Ssl::SelectionResult::SelectionStatus::Failed: + extended_socket_info->onCertificateSelectionCompleted(OptRef(), false, + false); return ssl_select_cert_error; - case OcspStapleAction::ClientNotCapable: - break; } - - return ssl_select_cert_success; + PANIC_DUE_TO_CORRUPT_ENUM; } absl::StatusOr ServerContextFactoryImpl::createServerContext( diff --git a/source/common/tls/server_context_impl.h b/source/common/tls/server_context_impl.h index 8b296a97e475..ed3c359a17f3 100644 --- a/source/common/tls/server_context_impl.h +++ b/source/common/tls/server_context_impl.h @@ -22,6 +22,7 @@ #include "source/common/tls/cert_validator/cert_validator.h" #include "source/common/tls/context_impl.h" #include "source/common/tls/context_manager_impl.h" +#include "source/common/tls/default_tls_certificate_selector.h" #include "source/common/tls/ocsp/ocsp.h" #include "source/common/tls/stats.h" @@ -38,9 +39,9 @@ namespace Extensions { namespace TransportSockets { namespace Tls { -enum class OcspStapleAction { Staple, NoStaple, Fail, ClientNotCapable }; - -class ServerContextImpl : public ContextImpl, public Envoy::Ssl::ServerContext { +class ServerContextImpl : public ContextImpl, + public Envoy::Ssl::ServerContext, + public Envoy::Ssl::TlsCertificateSelectorContext { public: static absl::StatusOr> create(Stats::Scope& scope, const Envoy::Ssl::ServerContextConfig& config, @@ -48,6 +49,10 @@ class ServerContextImpl : public ContextImpl, public Envoy::Ssl::ServerContext { Server::Configuration::CommonFactoryContext& factory_context, Ssl::ContextAdditionalInitFunc additional_init); + // Ssl::TlsCertificateSelectorContext + // The returned vector has the same life-time as the Ssl::TlsCertificateSelectorContext. + const std::vector& getTlsContexts() const override { return tls_contexts_; }; + // Select the TLS certificate context in SSL_CTX_set_select_certificate_cb() callback with // ClientHello details. This is made public for use by custom TLS extensions who want to // manually create and use this as a client hello callback. @@ -55,45 +60,31 @@ class ServerContextImpl : public ContextImpl, public Envoy::Ssl::ServerContext { // Finds the best matching context. The returned context will have the same lifetime as // this ``ServerContextImpl``. - std::pair findTlsContext(absl::string_view sni, - bool client_ecdsa_capable, - bool client_ocsp_capable, - bool* cert_matched_sni); + std::pair findTlsContext(absl::string_view sni, + bool client_ecdsa_capable, + bool client_ocsp_capable, + bool* cert_matched_sni); + bool isClientEcdsaCapable(const SSL_CLIENT_HELLO& ssl_client_hello) const; + bool isClientOcspCapable(const SSL_CLIENT_HELLO& ssl_client_hello) const; private: ServerContextImpl(Stats::Scope& scope, const Envoy::Ssl::ServerContextConfig& config, const std::vector& server_names, Server::Configuration::CommonFactoryContext& factory_context, Ssl::ContextAdditionalInitFunc additional_init, absl::Status& creation_status); - - // Currently, at most one certificate of a given key type may be specified for each exact - // server name or wildcard domain name. - using PkeyTypesMap = absl::flat_hash_map>; - // Both exact server names and wildcard domains are part of the same map, in which wildcard - // domains are prefixed with "." (i.e. ".example.com" for "*.example.com") to differentiate - // between exact and wildcard entries. - using ServerNamesMap = absl::flat_hash_map; - - void populateServerNamesMap(Ssl::TlsContext& ctx, const int pkey_id); - using SessionContextID = std::array; int alpnSelectCallback(const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen); int sessionTicketProcess(SSL* ssl, uint8_t* key_name, uint8_t* iv, EVP_CIPHER_CTX* ctx, HMAC_CTX* hmac_ctx, int encrypt); - bool isClientEcdsaCapable(const SSL_CLIENT_HELLO* ssl_client_hello); - bool isClientOcspCapable(const SSL_CLIENT_HELLO* ssl_client_hello); - OcspStapleAction ocspStapleAction(const Ssl::TlsContext& ctx, bool client_ocsp_capable); absl::StatusOr generateHashForSessionContextId(const std::vector& server_names); + Ssl::TlsCertificateSelectorPtr tls_certificate_selector_; const std::vector session_ticket_keys_; const Ssl::ServerContextConfig::OcspStaplePolicy ocsp_staple_policy_; - ServerNamesMap server_names_map_; - bool has_rsa_{false}; - bool full_scan_certs_on_sni_mismatch_; }; class ServerContextFactoryImpl : public ServerContextFactory { diff --git a/source/common/tls/ssl_handshaker.cc b/source/common/tls/ssl_handshaker.cc index ab2c6f1fc0f8..df9d7cbf8382 100644 --- a/source/common/tls/ssl_handshaker.cc +++ b/source/common/tls/ssl_handshaker.cc @@ -6,6 +6,7 @@ #include "source/common/common/empty_string.h" #include "source/common/http/headers.h" #include "source/common/runtime/runtime_features.h" +#include "source/common/tls/context_impl.h" #include "source/common/tls/utility.h" using Envoy::Network::PostIoAction; @@ -29,10 +30,23 @@ void ValidateResultCallbackImpl::onCertValidationResult(bool succeeded, extended_socket_info_->onCertificateValidationCompleted(succeeded, true); } +void CertificateSelectionCallbackImpl::onSslHandshakeCancelled() { extended_socket_info_.reset(); } + +void CertificateSelectionCallbackImpl::onCertificateSelectionResult( + OptRef selected_ctx, bool staple) { + if (!extended_socket_info_.has_value()) { + return; + } + extended_socket_info_->onCertificateSelectionCompleted(selected_ctx, staple, true); +} + SslExtendedSocketInfoImpl::~SslExtendedSocketInfoImpl() { if (cert_validate_result_callback_.has_value()) { cert_validate_result_callback_->onSslHandshakeCancelled(); } + if (cert_selection_callback_.has_value()) { + cert_selection_callback_->onSslHandshakeCancelled(); + } } void SslExtendedSocketInfoImpl::setCertificateValidationStatus( @@ -64,6 +78,48 @@ Ssl::ValidateResultCallbackPtr SslExtendedSocketInfoImpl::createValidateResultCa return callback; } +void SslExtendedSocketInfoImpl::onCertificateSelectionCompleted( + OptRef selected_ctx, bool staple, bool async) { + RELEASE_ASSERT(cert_selection_result_ == Ssl::CertificateSelectionStatus::Pending, + "onCertificateSelectionCompleted twice"); + if (!selected_ctx.has_value()) { + cert_selection_result_ = Ssl::CertificateSelectionStatus::Failed; + } else { + cert_selection_result_ = Ssl::CertificateSelectionStatus::Successful; + // Apply the selected context. This must be done before OCSP stapling below + // since applying the context can remove the previously-set OCSP response. + // This will only return NULL if memory allocation fails. + RELEASE_ASSERT(SSL_set_SSL_CTX(ssl_handshaker_.ssl(), selected_ctx->ssl_ctx_.get()) != nullptr, + ""); + + if (staple) { + // We avoid setting the OCSP response if the client didn't request it, but doing so is safe. + RELEASE_ASSERT(selected_ctx->ocsp_response_, + "OCSP response must be present under OcspStapleAction::Staple"); + const std::vector& resp_bytes = selected_ctx->ocsp_response_->rawBytes(); + const int rc = + SSL_set_ocsp_response(ssl_handshaker_.ssl(), resp_bytes.data(), resp_bytes.size()); + RELEASE_ASSERT(rc != 0, ""); + } + } + if (cert_selection_callback_.has_value()) { + cert_selection_callback_.reset(); + // Resume handshake. + if (async) { + ssl_handshaker_.handshakeCallbacks()->onAsynchronousCertificateSelectionComplete(); + } + } +} + +Ssl::CertificateSelectionCallbackPtr +SslExtendedSocketInfoImpl::createCertificateSelectionCallback() { + auto callback = std::make_unique( + ssl_handshaker_.handshakeCallbacks()->connection().dispatcher(), *this); + cert_selection_callback_ = *callback; + cert_selection_result_ = Ssl::CertificateSelectionStatus::Pending; + return callback; +} + SslHandshakerImpl::SslHandshakerImpl(bssl::UniquePtr ssl, int ssl_extended_socket_info_index, Ssl::HandshakeCallbacks* handshake_callbacks) : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), @@ -95,6 +151,7 @@ Network::PostIoAction SslHandshakerImpl::doHandshake() { case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: return PostIoAction::KeepOpen; + case SSL_ERROR_PENDING_CERTIFICATE: case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: case SSL_ERROR_WANT_CERTIFICATE_VERIFY: state_ = Ssl::SocketState::HandshakeInProgress; diff --git a/source/common/tls/ssl_handshaker.h b/source/common/tls/ssl_handshaker.h index 1a61777e3b24..3a5e162d99ad 100644 --- a/source/common/tls/ssl_handshaker.h +++ b/source/common/tls/ssl_handshaker.h @@ -50,6 +50,25 @@ class ValidateResultCallbackImpl : public Ssl::ValidateResultCallback { OptRef extended_socket_info_; }; +class CertificateSelectionCallbackImpl : public Ssl::CertificateSelectionCallback, + protected Logger::Loggable { +public: + CertificateSelectionCallbackImpl(Event::Dispatcher& dispatcher, + SslExtendedSocketInfoImpl& extended_socket_info) + : dispatcher_(dispatcher), extended_socket_info_(extended_socket_info) {} + + Event::Dispatcher& dispatcher() override { return dispatcher_; } + + void onCertificateSelectionResult(OptRef selected_ctx, + bool staple) override; + + void onSslHandshakeCancelled(); + +private: + Event::Dispatcher& dispatcher_; + OptRef extended_socket_info_; +}; + class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { public: explicit SslExtendedSocketInfoImpl(SslHandshakerImpl& handshaker) : ssl_handshaker_(handshaker) {} @@ -67,6 +86,13 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { void setCertificateValidationAlert(uint8_t alert) { cert_validation_alert_ = alert; } + Ssl::CertificateSelectionCallbackPtr createCertificateSelectionCallback() override; + void onCertificateSelectionCompleted(OptRef selected_ctx, bool staple, + bool async) override; + Ssl::CertificateSelectionStatus certificateSelectionResult() const override { + return cert_selection_result_; + } + private: Envoy::Ssl::ClientValidationStatus certificate_validation_status_{ Envoy::Ssl::ClientValidationStatus::NotValidated}; @@ -79,6 +105,13 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { // Stores the validation result if there is any. // nullopt if no validation has ever been kicked off. Ssl::ValidateStatus cert_validation_result_{Ssl::ValidateStatus::NotStarted}; + // Latch the in-flight cert selection callback. + // nullopt if there is none. + OptRef cert_selection_callback_{absl::nullopt}; + // Stores the cert selection result if there is any. + // NotStarted if no cert selection has ever been kicked off. + Ssl::CertificateSelectionStatus cert_selection_result_{ + Ssl::CertificateSelectionStatus::NotStarted}; }; class SslHandshakerImpl : public ConnectionInfoImplBase, diff --git a/source/common/tls/ssl_socket.cc b/source/common/tls/ssl_socket.cc index dfbf77899ec3..0e279330fd1f 100644 --- a/source/common/tls/ssl_socket.cc +++ b/source/common/tls/ssl_socket.cc @@ -376,6 +376,15 @@ void SslSocket::onAsynchronousCertValidationComplete() { } } +void SslSocket::onAsynchronousCertificateSelectionComplete() { + ENVOY_CONN_LOG(debug, "Async cert selection completed", callbacks_->connection()); + if (info_->state() != Ssl::SocketState::HandshakeInProgress) { + IS_ENVOY_BUG(fmt::format("unexpected handshake state: {}", static_cast(info_->state()))); + return; + } + resumeHandshake(); +} + } // namespace Tls } // namespace TransportSockets } // namespace Extensions diff --git a/source/common/tls/ssl_socket.h b/source/common/tls/ssl_socket.h index d5254a556a47..4cdb34393f3f 100644 --- a/source/common/tls/ssl_socket.h +++ b/source/common/tls/ssl_socket.h @@ -73,6 +73,7 @@ class SslSocket : public Network::TransportSocket, void onFailure() override; Network::TransportSocketCallbacks* transportSocketCallbacks() override { return callbacks_; } void onAsynchronousCertValidationComplete() override; + void onAsynchronousCertificateSelectionComplete() override; SSL* rawSslForTest() const { return rawSsl(); } diff --git a/source/extensions/transport_sockets/tls/downstream_config.cc b/source/extensions/transport_sockets/tls/downstream_config.cc index 79a8d07118d3..0821dc9dbcd7 100644 --- a/source/extensions/transport_sockets/tls/downstream_config.cc +++ b/source/extensions/transport_sockets/tls/downstream_config.cc @@ -21,7 +21,7 @@ DownstreamSslSocketFactory::createTransportSocketFactory( MessageUtil::downcastAndValidate< const envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext&>( message, context.messageValidationVisitor()), - context); + context, false); RETURN_IF_NOT_OK(server_config_or_error.status()); return ServerSslSocketFactory::create(std::move(server_config_or_error.value()), context.sslContextManager(), context.statsScope(), diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index d60e93ab2764..6cf0eb2034bb 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -672,7 +672,7 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { } auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); static auto* upstream_stats_store = new Stats::IsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( diff --git a/test/common/quic/BUILD b/test/common/quic/BUILD index 748958802567..dc2df66c3174 100644 --- a/test/common/quic/BUILD +++ b/test/common/quic/BUILD @@ -63,6 +63,7 @@ envoy_cc_test( "//source/common/quic:envoy_quic_proof_source_lib", "//source/common/quic:envoy_quic_proof_verifier_lib", "//source/common/tls:context_config_lib", + "//source/common/tls:server_context_lib", "//test/mocks/network:network_mocks", "//test/mocks/server:server_factory_context_mocks", "//test/mocks/ssl:ssl_mocks", diff --git a/test/common/quic/envoy_quic_proof_source_test.cc b/test/common/quic/envoy_quic_proof_source_test.cc index af94c48380a5..c831f9b71334 100644 --- a/test/common/quic/envoy_quic_proof_source_test.cc +++ b/test/common/quic/envoy_quic_proof_source_test.cc @@ -7,6 +7,7 @@ #include "source/common/quic/envoy_quic_utils.h" #include "source/common/tls/client_context_impl.h" #include "source/common/tls/context_config_impl.h" +#include "source/common/tls/default_tls_certificate_selector.h" #include "test/common/quic/test_utils.h" #include "test/mocks/network/mocks.h" @@ -191,6 +192,16 @@ class EnvoyQuicProofSourceTest : public ::testing::Test { EXPECT_CALL(filter_chain_, transportSocketFactory()) .WillRepeatedly(ReturnRef(*transport_socket_factory_)); + auto factory = Extensions::TransportSockets::Tls::TlsCertificateSelectorConfigFactoryImpl:: + getDefaultTlsCertificateSelectorConfigFactory(); + ASSERT_TRUE(factory); + const ProtobufWkt::Any any; + absl::Status creation_status = absl::OkStatus(); + auto tls_certificate_selector_factory_cb = factory->createTlsCertificateSelectorFactory( + any, factory_context_, ProtobufMessage::getNullValidationVisitor(), creation_status, true); + EXPECT_CALL(*mock_context_config_, tlsCertificateSelectorFactory()) + .WillRepeatedly(Return(tls_certificate_selector_factory_cb)); + EXPECT_CALL(*mock_context_config_, isReady()).WillRepeatedly(Return(true)); std::vector> tls_cert_configs{ std::reference_wrapper(tls_cert_config_)}; diff --git a/test/common/tls/BUILD b/test/common/tls/BUILD index af168e029f10..055c56ab68b1 100644 --- a/test/common/tls/BUILD +++ b/test/common/tls/BUILD @@ -16,6 +16,65 @@ envoy_cc_test_library( hdrs = ["ssl_certs_test.h"], ) +envoy_cc_test( + name = "tls_certificate_selector_test", + size = "large", + srcs = [ + "tls_certificate_selector_test.cc", + ], + data = [ + # TODO(mattklein123): We should consolidate all of our test certs in a single place as + # right now we have a bunch of duplication which is confusing. + "//test/config/integration/certs", + "//test/common/tls/ocsp/test_data:certs", + "//test/common/tls/test_data:certs", + ], + external_deps = ["ssl"], + deps = [ + ":ssl_certs_test_lib", + ":test_private_key_method_provider_test_lib", + "//envoy/network:transport_socket_interface", + "//source/common/buffer:buffer_lib", + "//source/common/common:empty_string", + "//source/common/event:dispatcher_includes", + "//source/common/event:dispatcher_lib", + "//source/common/json:json_loader_lib", + "//source/common/network:listen_socket_lib", + "//source/common/network:transport_socket_options_lib", + "//source/common/network:utility_lib", + "//source/common/stats:isolated_store_lib", + "//source/common/stats:stats_lib", + "//source/common/stream_info:stream_info_lib", + "//source/common/tls:context_config_lib", + "//source/common/tls:context_lib", + "//source/common/tls:server_context_config_lib", + "//source/common/tls:server_context_lib", + "//source/common/tls:ssl_socket_lib", + "//source/common/tls:utility_lib", + "//source/common/tls/private_key:private_key_manager_lib", + "//test/common/tls/cert_validator:timed_cert_validator", + "//test/common/tls/test_data:cert_infos", + "//test/mocks/buffer:buffer_mocks", + "//test/mocks/init:init_mocks", + "//test/mocks/local_info:local_info_mocks", + "//test/mocks/network:io_handle_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/runtime:runtime_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/ssl:ssl_mocks", + "//test/mocks/stats:stats_mocks", + "//test/test_common:environment_lib", + "//test/test_common:logging_lib", + "//test/test_common:network_utility_lib", + "//test/test_common:registry_lib", + "//test/test_common:simulated_time_system_lib", + "//test/test_common:test_runtime_lib", + "//test/test_common:utility_lib", + "@envoy_api//envoy/config/listener/v3:pkg_cc_proto", + "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg_cc_proto", + ], +) + envoy_cc_test( name = "ssl_socket_test", size = "large", diff --git a/test/common/tls/cert_selector/BUILD b/test/common/tls/cert_selector/BUILD new file mode 100644 index 000000000000..c0b97bab585a --- /dev/null +++ b/test/common/tls/cert_selector/BUILD @@ -0,0 +1,24 @@ +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_test_library", + "envoy_package", +) + +licenses(["notice"]) # Apache 2 + +envoy_package() + +envoy_cc_test_library( + name = "async_cert_selector", + srcs = [ + "async_cert_selector.cc", + "stats.cc", + ], + hdrs = [ + "async_cert_selector.h", + "stats.h", + ], + deps = [ + "//source/common/tls:server_context_lib", + ], +) diff --git a/test/common/tls/cert_selector/async_cert_selector.cc b/test/common/tls/cert_selector/async_cert_selector.cc new file mode 100644 index 000000000000..70d70536c257 --- /dev/null +++ b/test/common/tls/cert_selector/async_cert_selector.cc @@ -0,0 +1,60 @@ +#include "test/common/tls/cert_selector/async_cert_selector.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +Ssl::SelectionResult +AsyncTlsCertificateSelector::selectTlsContext(const SSL_CLIENT_HELLO&, + Ssl::CertificateSelectionCallbackPtr cb) { + ENVOY_LOG_MISC(info, "debug: select context"); + + if (mode_ == "sync") { + stats_.cert_selection_sync_.inc(); + auto& tls_context = selector_ctx_.getTlsContexts()[0]; + return {Ssl::SelectionResult::SelectionStatus::Success, &tls_context, false}; + } + + if (mode_ == "async") { + ENVOY_LOG_MISC(info, "debug: select cert async"); + stats_.cert_selection_async_.inc(); + cb_ = std::move(cb); + cb_->dispatcher().post([this] { + selectTlsContextAsync(); + stats_.cert_selection_async_finished_.inc(); + }); + return {Ssl::SelectionResult::SelectionStatus::Pending, nullptr, false}; + } + + if (mode_ == "sleep") { + ENVOY_LOG_MISC(info, "debug: select cert sleep"); + // select cert async after 20ms + stats_.cert_selection_sleep_.inc(); + cb_ = std::move(cb); + selection_timer_ = cb_->dispatcher().createTimer([this] { + selectTlsContextAsync(); + stats_.cert_selection_sleep_finished_.inc(); + }); + selection_timer_->enableTimer(std::chrono::milliseconds(20)); + return {Ssl::SelectionResult::SelectionStatus::Pending, nullptr, false}; + } + + stats_.cert_selection_failed_.inc(); + return {Ssl::SelectionResult::SelectionStatus::Failed, nullptr, false}; +}; + +void AsyncTlsCertificateSelector::selectTlsContextAsync() { + ENVOY_LOG_MISC(info, "debug: select cert async done"); + // choose the first one. + auto& tls_context = selector_ctx_.getTlsContexts()[0]; + cb_->onCertificateSelectionResult(tls_context, false); + selection_timer_.reset(); +} + +REGISTER_FACTORY(AsyncTlsCertificateSelectorFactory, Ssl::TlsCertificateSelectorConfigFactory); + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/common/tls/cert_selector/async_cert_selector.h b/test/common/tls/cert_selector/async_cert_selector.h new file mode 100644 index 000000000000..b525c32179d3 --- /dev/null +++ b/test/common/tls/cert_selector/async_cert_selector.h @@ -0,0 +1,90 @@ +#pragma once + +#include + +#include "envoy/ssl/handshaker.h" + +#include "source/common/tls/context_impl.h" +#include "source/common/tls/server_context_impl.h" + +#include "stats.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +class AsyncTlsCertificateSelector : public Ssl::TlsCertificateSelector, + protected Logger::Loggable { +public: + AsyncTlsCertificateSelector(Stats::Scope& store, Ssl::TlsCertificateSelectorContext& selector_ctx, + std::string mode) + : stats_(generateCertSelectionStats(store)), selector_ctx_(selector_ctx), mode_(mode) {} + + ~AsyncTlsCertificateSelector() override { + ENVOY_LOG(info, "debug: ~AsyncTlsCertificateSelector"); + } + + Ssl::SelectionResult selectTlsContext(const SSL_CLIENT_HELLO&, + Ssl::CertificateSelectionCallbackPtr cb) override; + + // It's only for quic. + std::pair findTlsContext(absl::string_view, bool, + bool, bool*) override { + PANIC("unreachable"); + }; + + void selectTlsContextAsync(); + +private: + CertSelectionStats stats_; + Ssl::TlsCertificateSelectorContext& selector_ctx_; + Ssl::CertificateSelectionCallbackPtr cb_; + std::string mode_; + Event::TimerPtr selection_timer_; +}; + +class AsyncTlsCertificateSelectorFactory : public Ssl::TlsCertificateSelectorConfigFactory { +public: + Ssl::TlsCertificateSelectorFactory createTlsCertificateSelectorFactory( + const Protobuf::Message& config, Server::Configuration::CommonFactoryContext& factory_context, + ProtobufMessage::ValidationVisitor&, absl::Status& creation_status, bool for_quic) override { + if (for_quic) { + creation_status = absl::InvalidArgumentError("does not support for quic"); + return Ssl::TlsCertificateSelectorFactory(); + } + + std::string mode; + const ProtobufWkt::Any* any_config = dynamic_cast(&config); + if (any_config) { + ProtobufWkt::StringValue string_value; + if (any_config->UnpackTo(&string_value)) { + mode = string_value.value(); + } + } + if (mode.empty()) { + creation_status = absl::InvalidArgumentError("invalid cert selection mode"); + return Ssl::TlsCertificateSelectorFactory(); + } + + auto& scope = factory_context.scope(); + + return [mode, &scope](const Ssl::ServerContextConfig&, + Ssl::TlsCertificateSelectorContext& selector_ctx) { + return std::make_unique(scope, selector_ctx, mode); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return ProtobufTypes::MessagePtr{new ProtobufWkt::StringValue()}; + } + + std::string name() const override { return "test-tls-context-provider"; }; +}; + +DECLARE_FACTORY(AsyncTlsCertificateSelectorFactory); + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/common/tls/cert_selector/stats.cc b/test/common/tls/cert_selector/stats.cc new file mode 100644 index 000000000000..8ae98e6cf98d --- /dev/null +++ b/test/common/tls/cert_selector/stats.cc @@ -0,0 +1,21 @@ +#include "test/common/tls/cert_selector/stats.h" + +#include "envoy/stats/scope.h" +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +CertSelectionStats generateCertSelectionStats(Stats::Scope& store) { + std::string prefix("aysnc_cert_selection."); + return {ALL_CERT_SELECTION_STATS(POOL_COUNTER_PREFIX(store, prefix), + POOL_GAUGE_PREFIX(store, prefix), + POOL_HISTOGRAM_PREFIX(store, prefix))}; +} + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/common/tls/cert_selector/stats.h b/test/common/tls/cert_selector/stats.h new file mode 100644 index 000000000000..894bcdf01dfd --- /dev/null +++ b/test/common/tls/cert_selector/stats.h @@ -0,0 +1,32 @@ +#pragma once + +#include "envoy/stats/scope.h" +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +#define ALL_CERT_SELECTION_STATS(COUNTER, GAUGE, HISTOGRAM) \ + COUNTER(cert_selection_sync) \ + COUNTER(cert_selection_async) \ + COUNTER(cert_selection_async_finished) \ + COUNTER(cert_selection_sleep) \ + COUNTER(cert_selection_sleep_finished) \ + COUNTER(cert_selection_failed) + +/** + * Wrapper struct for SSL stats. @see stats_macros.h + */ +struct CertSelectionStats { + ALL_CERT_SELECTION_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, + GENERATE_HISTOGRAM_STRUCT) +}; + +CertSelectionStats generateCertSelectionStats(Stats::Scope& store); + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/common/tls/cert_validator/test_common.h b/test/common/tls/cert_validator/test_common.h index 2437644aa518..b948977a3e39 100644 --- a/test/common/tls/cert_validator/test_common.h +++ b/test/common/tls/cert_validator/test_common.h @@ -2,6 +2,7 @@ #include +#include "envoy/ssl/context.h" #include "envoy/ssl/context_config.h" #include "envoy/ssl/ssl_socket_extended_info.h" @@ -34,9 +35,23 @@ class TestSslExtendedSocketInfo : public Envoy::Ssl::SslExtendedSocketInfo { Ssl::ValidateStatus certificateValidationResult() const override { return validate_result_; } uint8_t certificateValidationAlert() const override { return SSL_AD_CERTIFICATE_UNKNOWN; } + Ssl::CertificateSelectionCallbackPtr createCertificateSelectionCallback() override { + return nullptr; + } + void onCertificateSelectionCompleted(OptRef selected_ctx, bool, + bool) override { + cert_selection_result_ = selected_ctx.has_value() ? Ssl::CertificateSelectionStatus::Successful + : Ssl::CertificateSelectionStatus::Failed; + } + Ssl::CertificateSelectionStatus certificateSelectionResult() const override { + return cert_selection_result_; + } + private: Envoy::Ssl::ClientValidationStatus status_; Ssl::ValidateStatus validate_result_{Ssl::ValidateStatus::NotStarted}; + Ssl::CertificateSelectionStatus cert_selection_result_{ + Ssl::CertificateSelectionStatus::NotStarted}; }; class TestCertificateValidationContextConfig diff --git a/test/common/tls/context_impl_test.cc b/test/common/tls/context_impl_test.cc index 0bbd94dbd4ae..206514f17323 100644 --- a/test/common/tls/context_impl_test.cc +++ b/test/common/tls/context_impl_test.cc @@ -155,7 +155,7 @@ TEST_F(SslContextImplTest, TestServerCipherPreference) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(yaml), tls_context); - auto cfg = ServerContextConfigImpl::create(tls_context, factory_context_).value(); + auto cfg = ServerContextConfigImpl::create(tls_context, factory_context_, false).value(); ASSERT_FALSE(cfg.get()->preferClientCiphers()); auto socket_factory = *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( @@ -180,7 +180,7 @@ TEST_F(SslContextImplTest, TestPreferClientCiphers) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(yaml), tls_context); - auto cfg = ServerContextConfigImpl::create(tls_context, factory_context_).value(); + auto cfg = ServerContextConfigImpl::create(tls_context, factory_context_, false).value(); ASSERT_TRUE(cfg.get()->preferClientCiphers()); auto socket_factory = *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( @@ -532,7 +532,8 @@ TEST_F(SslContextImplTest, DuplicateRsaCertSameExactDNSSan) { filename: "{{ test_rundir }}/test/common/tls/test_data/selfsigned_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -552,7 +553,8 @@ TEST_F(SslContextImplTest, DuplicateRsaCertSameWildcardDNSSan) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_multiple_dns_1_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -572,7 +574,8 @@ TEST_F(SslContextImplTest, AcceptableMultipleRsaCerts) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_dns_rsa_2_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -592,7 +595,8 @@ TEST_F(SslContextImplTest, DuplicateEcdsaCert) { filename: "{{ test_rundir }}/test/common/tls/test_data/selfsigned_ecdsa_p256_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -612,7 +616,8 @@ TEST_F(SslContextImplTest, AcceptableMultipleEcdsaCerts) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_dns_ecdsa_2_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -631,7 +636,8 @@ TEST_F(SslContextImplTest, CertDuplicatedSansAndCN) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_multiple_dns_1_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -655,7 +661,8 @@ TEST_F(SslContextImplTest, MultipleCertsSansAndCN) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_wildcard_dns_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_NO_THROW(loadConfig(*server_context_config)); } @@ -671,7 +678,8 @@ TEST_F(SslContextImplTest, MustHaveSubjectOrSAN) { filename: "{{ test_rundir }}/test/common/tls/test_data/no_subject_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_EQ( manager_.createSslServerContext(*store_.rootScope(), *server_context_config, {}, nullptr) .status() @@ -690,8 +698,9 @@ class SslServerContextImplOcspTest : public SslContextImplTest { Envoy::Ssl::ServerContextSharedPtr loadConfigYaml(const std::string& yaml) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(yaml), tls_context); - auto cfg = THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(tls_context, factory_context_), - std::unique_ptr); + auto cfg = + THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(tls_context, factory_context_, false), + std::unique_ptr); return loadConfig(*cfg); } }; @@ -901,7 +910,7 @@ class SslServerContextImplTicketTest : public SslContextImplTest { "{{ test_rundir }}/test/common/tls/test_data/unittest_key.pem")); auto server_context_config = - THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(cfg, factory_context_), + THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(cfg, factory_context_, false), std::unique_ptr); loadConfig(*server_context_config); } @@ -909,8 +918,9 @@ class SslServerContextImplTicketTest : public SslContextImplTest { void loadConfigYaml(const std::string& yaml) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(yaml), tls_context); - auto cfg = THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(tls_context, factory_context_), - std::unique_ptr); + auto cfg = + THROW_OR_RETURN_VALUE(ServerContextConfigImpl::create(tls_context, factory_context_, false), + std::unique_ptr); loadConfig(*cfg); } }; @@ -1028,7 +1038,8 @@ TEST_F(SslServerContextImplTicketTest, TicketKeySdsNotReady) { auto* sds_secret_configs = tls_context.mutable_session_ticket_keys_sds_secret_config(); sds_secret_configs->set_name("abc.com"); sds_secret_configs->mutable_sds_config(); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); // When sds secret is not downloaded, config is not ready. EXPECT_FALSE(server_context_config->isReady()); // Set various callbacks to config. @@ -1064,7 +1075,8 @@ name: "abc.com" tls_context.mutable_session_ticket_keys_sds_secret_config()->set_name("abc.com"); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_TRUE(server_context_config->isReady()); ASSERT_EQ(server_context_config->sessionTicketKeys().size(), 2); @@ -1152,7 +1164,8 @@ TEST_F(SslServerContextImplTicketTest, StatelessSessionResumptionEnabledByDefaul )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_FALSE(server_context_config->disableStatelessSessionResumption()); } @@ -1169,7 +1182,8 @@ TEST_F(SslServerContextImplTicketTest, StatelessSessionResumptionExplicitlyEnabl )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_FALSE(server_context_config->disableStatelessSessionResumption()); } @@ -1186,7 +1200,8 @@ TEST_F(SslServerContextImplTicketTest, StatelessSessionResumptionDisabled) { )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_TRUE(server_context_config->disableStatelessSessionResumption()); } @@ -1205,7 +1220,8 @@ TEST_F(SslServerContextImplTicketTest, StatelessSessionResumptionEnabledWhenKeyI )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_FALSE(server_context_config->disableStatelessSessionResumption()); } @@ -1840,8 +1856,9 @@ class ServerContextConfigImplTest : public SslCertsTest { // Multiple TLS certificates are supported. TEST_F(ServerContextConfigImplTest, MultipleTlsCertificates) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "No TLS certificates found for server context"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "No TLS certificates found for server context"); const std::string rsa_tls_certificate_yaml = R"EOF( certificate_chain: filename: "{{ test_rundir }}/test/common/tls/test_data/selfsigned_cert.pem" @@ -1858,7 +1875,8 @@ TEST_F(ServerContextConfigImplTest, MultipleTlsCertificates) { *tls_context.mutable_common_tls_context()->add_tls_certificates()); TestUtility::loadFromYaml(TestEnvironment::substitute(ecdsa_tls_certificate_yaml), *tls_context.mutable_common_tls_context()->add_tls_certificates()); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); auto tls_certs = server_context_config->tlsCertificates(); ASSERT_EQ(2, tls_certs.size()); EXPECT_THAT(tls_certs[0].get().privateKeyPath(), EndsWith("selfsigned_key.pem")); @@ -1867,8 +1885,9 @@ TEST_F(ServerContextConfigImplTest, MultipleTlsCertificates) { TEST_F(ServerContextConfigImplTest, TlsCertificatesAndSdsConfig) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "No TLS certificates found for server context"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "No TLS certificates found for server context"); const std::string tls_certificate_yaml = R"EOF( certificate_chain: filename: "{{ test_rundir }}/test/common/tls/test_data/selfsigned_cert.pem" @@ -1878,8 +1897,9 @@ TEST_F(ServerContextConfigImplTest, TlsCertificatesAndSdsConfig) { TestUtility::loadFromYaml(TestEnvironment::substitute(tls_certificate_yaml), *tls_context.mutable_common_tls_context()->add_tls_certificates()); tls_context.mutable_common_tls_context()->add_tls_certificate_sds_secret_configs(); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "SDS and non-SDS TLS certificates may not be mixed in server contexts"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "SDS and non-SDS TLS certificates may not be mixed in server contexts"); } TEST_F(ServerContextConfigImplTest, SdsConfigNoName) { @@ -1917,7 +1937,8 @@ TEST_F(ServerContextConfigImplTest, SecretNotReady) { tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); sds_secret_configs->set_name("abc.com"); sds_secret_configs->mutable_sds_config(); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); // When sds secret is not downloaded, config is not ready. EXPECT_FALSE(server_context_config->isReady()); // Set various callbacks to config. @@ -1948,7 +1969,8 @@ TEST_F(ServerContextConfigImplTest, ValidationContextNotReady) { tls_context.mutable_common_tls_context()->mutable_validation_context_sds_secret_config(); sds_secret_configs->set_name("abc.com"); sds_secret_configs->mutable_sds_config(); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); // When sds secret is not downloaded, config is not ready. EXPECT_FALSE(server_context_config->isReady()); // Set various callbacks to config. @@ -1962,7 +1984,8 @@ TEST_F(ServerContextConfigImplTest, ValidationContextNotReady) { TEST_F(ServerContextConfigImplTest, TlsCertificateNonEmpty) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; tls_context.mutable_common_tls_context()->add_tls_certificates(); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); ContextManagerImpl manager(server_factory_context_); Stats::IsolatedStoreImpl store; EXPECT_EQ(manager @@ -1983,8 +2006,9 @@ TEST_F(ServerContextConfigImplTest, InvalidIgnoreCertsNoCA) { server_validation_ctx->set_allow_expired_certificate(true); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Certificate validity period is always ignored without trusted CA"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Certificate validity period is always ignored without trusted CA"); envoy::extensions::transport_sockets::tls::v3::TlsCertificate* server_cert = tls_context.mutable_common_tls_context()->add_tls_certificates(); @@ -1996,19 +2020,20 @@ TEST_F(ServerContextConfigImplTest, InvalidIgnoreCertsNoCA) { server_validation_ctx->set_allow_expired_certificate(false); EXPECT_NO_THROW(auto server_context_config = - *ServerContextConfigImpl::create(tls_context, factory_context_)); + *ServerContextConfigImpl::create(tls_context, factory_context_, false)); server_validation_ctx->set_allow_expired_certificate(true); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Certificate validity period is always ignored without trusted CA"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Certificate validity period is always ignored without trusted CA"); // But once you add a trusted CA, you should be able to create the context. server_validation_ctx->mutable_trusted_ca()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/tls/test_data/ca_cert.pem")); EXPECT_NO_THROW(auto server_context_config = - *ServerContextConfigImpl::create(tls_context, factory_context_)); + *ServerContextConfigImpl::create(tls_context, factory_context_, false)); } TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoProvider) { @@ -2031,8 +2056,9 @@ TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoProvider) { test_value: 100 )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Failed to load private key provider: mock_provider"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Failed to load private key provider: mock_provider"); } TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoProviderFallback) { @@ -2056,8 +2082,9 @@ TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoProviderFallbac fallback: true )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Failed to load private key provider: mock_provider"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Failed to load private key provider: mock_provider"); } TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoMethod) { @@ -2088,7 +2115,8 @@ TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadFailureNoMethod) { test_value: 100 )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_EQ(manager .createSslServerContext(*store.rootScope(), *server_context_config, std::vector{}, nullptr) @@ -2122,7 +2150,8 @@ TEST_F(ServerContextConfigImplTest, PrivateKeyMethodLoadSuccess) { test_value: 100 )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); } TEST_F(ServerContextConfigImplTest, PrivateKeyMethodFallback) { @@ -2153,7 +2182,8 @@ TEST_F(ServerContextConfigImplTest, PrivateKeyMethodFallback) { fallback: true )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - auto server_context_config = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_context_config = + *ServerContextConfigImpl::create(tls_context, factory_context_, false); } // Test that if both typed and untyped matchers for sans are specified, we @@ -2189,7 +2219,7 @@ TEST_F(ServerContextConfigImplTest, DeprecatedSanMatcher) { "Ignoring match_subject_alt_names as match_typed_subject_alt_names is also " "specified, and the former is deprecated.", server_context_config = - *ServerContextConfigImpl::create(tls_context, factory_context_)); + *ServerContextConfigImpl::create(tls_context, factory_context_, false)); EXPECT_EQ(server_context_config->certificateValidationContext()->subjectAltNameMatchers().size(), 1); EXPECT_EQ( @@ -2221,8 +2251,9 @@ TEST_F(ServerContextConfigImplTest, Pkcs12LoadFailureBothPkcs12AndMethod) { test_value: 100 )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Certificate configuration can't have both pkcs12 and private_key_provider"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Certificate configuration can't have both pkcs12 and private_key_provider"); } TEST_F(ServerContextConfigImplTest, Pkcs12LoadFailureBothPkcs12AndKey) { @@ -2236,8 +2267,9 @@ TEST_F(ServerContextConfigImplTest, Pkcs12LoadFailureBothPkcs12AndKey) { filename: "{{ test_rundir }}/test/common/tls/test_data/selfsigned_key.pem" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Certificate configuration can't have both pkcs12 and private_key"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Certificate configuration can't have both pkcs12 and private_key"); } TEST_F(ServerContextConfigImplTest, Pkcs12LoadFailureBothPkcs12AndCertChain) { @@ -2251,8 +2283,9 @@ TEST_F(ServerContextConfigImplTest, Pkcs12LoadFailureBothPkcs12AndCertChain) { filename: "{{ test_rundir }}/test/common/tls/test_data/san_dns3_certkeychain.p12" )EOF"; TestUtility::loadFromYaml(TestEnvironment::substitute(tls_context_yaml), tls_context); - EXPECT_EQ(ServerContextConfigImpl::create(tls_context, factory_context_).status().message(), - "Certificate configuration can't have both pkcs12 and certificate_chain"); + EXPECT_EQ( + ServerContextConfigImpl::create(tls_context, factory_context_, false).status().message(), + "Certificate configuration can't have both pkcs12 and certificate_chain"); } // TODO: test throw from additional_init diff --git a/test/common/tls/handshaker_factory_test.cc b/test/common/tls/handshaker_factory_test.cc index a22f85832321..72b49a505b0e 100644 --- a/test/common/tls/handshaker_factory_test.cc +++ b/test/common/tls/handshaker_factory_test.cc @@ -289,7 +289,7 @@ TEST_F(HandshakerFactoryDownstreamTest, ServerHandshakerProvidesCertificates) { .WillRepeatedly(Return(std::reference_wrapper(*process_context_impl))); auto server_context_config = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context_, mock_factory_ctx); + tls_context_, mock_factory_ctx, false); EXPECT_TRUE(server_context_config->isReady()); EXPECT_NO_THROW(*context_manager_->createSslServerContext( *stats_store_.rootScope(), *server_context_config, std::vector{}, nullptr)); diff --git a/test/common/tls/handshaker_test.cc b/test/common/tls/handshaker_test.cc index 0ac2f7d66b2f..db2c4bb3302c 100644 --- a/test/common/tls/handshaker_test.cc +++ b/test/common/tls/handshaker_test.cc @@ -46,6 +46,7 @@ class MockHandshakeCallbacks : public Ssl::HandshakeCallbacks { MOCK_METHOD(void, onFailure, (), (override)); MOCK_METHOD(Network::TransportSocketCallbacks*, transportSocketCallbacks, (), (override)); MOCK_METHOD(void, onAsynchronousCertValidationComplete, (), (override)); + MOCK_METHOD(void, onAsynchronousCertificateSelectionComplete, (), (override)); }; class HandshakerTest : public SslCertsTest { diff --git a/test/common/tls/integration/BUILD b/test/common/tls/integration/BUILD index 9ecd12407c87..66ff5b459239 100644 --- a/test/common/tls/integration/BUILD +++ b/test/common/tls/integration/BUILD @@ -44,6 +44,7 @@ envoy_cc_test( "//source/common/tls:ssl_handshaker_lib", "//source/extensions/transport_sockets/tls:config", "//test/common/config:dummy_config_proto_cc_proto", + "//test/common/tls/cert_selector:async_cert_selector", "//test/common/tls/cert_validator:timed_cert_validator", "//test/integration/filters:stream_info_to_headers_filter_lib", "//test/mocks/secret:secret_mocks", diff --git a/test/common/tls/integration/ssl_integration_test.cc b/test/common/tls/integration/ssl_integration_test.cc index 53773b7081e1..d2718ea31c06 100644 --- a/test/common/tls/integration/ssl_integration_test.cc +++ b/test/common/tls/integration/ssl_integration_test.cc @@ -1176,5 +1176,118 @@ TEST_P(SslKeyLogTest, SetMultipleIps) { logCheck(); } +TEST_P(SslIntegrationTest, SyncCertSelectorSucceeds) { + tls_cert_selector_yaml_ = R"EOF( +name: test-tls-context-provider +typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: sync + )EOF"; + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + return makeSslClientConnection({}); + }; + testRouterRequestAndResponseWithBody(16 * 1024 * 1024, 16 * 1024 * 1024, false, false, &creator); + checkStats(); + EXPECT_EQ(test_server_->counter("aysnc_cert_selection.cert_selection_sync")->value(), 1); +} + +TEST_P(SslIntegrationTest, AsyncCertSelectorSucceeds) { + tls_cert_selector_yaml_ = R"EOF( +name: test-tls-context-provider +typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: async + )EOF"; + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + return makeSslClientConnection({}); + }; + testRouterRequestAndResponseWithBody(16 * 1024 * 1024, 16 * 1024 * 1024, false, false, &creator); + checkStats(); + + EXPECT_EQ(test_server_->counter("aysnc_cert_selection.cert_selection_async")->value(), 1); + EXPECT_EQ(test_server_->counter("aysnc_cert_selection.cert_selection_async_finished")->value(), + 1); +} + +TEST_P(SslIntegrationTest, AsyncSleepCertSelectorSucceeds) { + tls_cert_selector_yaml_ = R"EOF( +name: test-tls-context-provider +typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: sleep + )EOF"; + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + return makeSslClientConnection({}); + }; + testRouterRequestAndResponseWithBody(16 * 1024 * 1024, 16 * 1024 * 1024, false, false, &creator); + checkStats(); + + EXPECT_EQ(test_server_->counter("aysnc_cert_selection.cert_selection_sleep")->value(), 1); + EXPECT_EQ(test_server_->counter("aysnc_cert_selection.cert_selection_sleep_finished")->value(), + 1); +} + +TEST_P(SslIntegrationTest, AsyncSleepCertSelectionAfterTearDown) { + tls_cert_selector_yaml_ = R"EOF( +name: test-tls-context-provider +typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: sleep + )EOF"; + initialize(); + + Network::ClientConnectionPtr connection = makeSslClientConnection({}); + ConnectionStatusCallbacks callbacks; + connection->addConnectionCallbacks(callbacks); + connection->connect(); + const auto* socket = dynamic_cast( + connection->ssl().get()); + ASSERT(socket); + + // wait for the server tls handshake into sleep state. + test_server_->waitForCounterEq("aysnc_cert_selection.cert_selection_sleep", 1, + TestUtility::DefaultTimeout, dispatcher_.get()); + + ASSERT_EQ(connection->state(), Network::Connection::State::Open); + ENVOY_LOG_MISC(debug, "debug: closing connection"); + connection->close(Network::ConnectionCloseType::NoFlush); + connection.reset(); + + // wait the sleep timer in cert selector is triggered. + test_server_->waitForCounterEq("aysnc_cert_selection.cert_selection_sleep_finished", 1, + TestUtility::DefaultTimeout, dispatcher_.get()); +} + +TEST_P(SslIntegrationTest, AsyncCertSelectionAfterSslShutdown) { + tls_cert_selector_yaml_ = R"EOF( +name: test-tls-context-provider +typed_config: + "@type": type.googleapis.com/google.protobuf.StringValue + value: sleep + )EOF"; + initialize(); + + Network::ClientConnectionPtr connection = makeSslClientConnection({}); + ConnectionStatusCallbacks callbacks; + connection->addConnectionCallbacks(callbacks); + connection->connect(); + const auto* socket = dynamic_cast( + connection->ssl().get()); + ASSERT(socket); + + // wait for the server tls handshake into sleep state. + test_server_->waitForCounterEq("aysnc_cert_selection.cert_selection_sleep", 1, + TestUtility::DefaultTimeout, dispatcher_.get()); + + ASSERT_EQ(connection->state(), Network::Connection::State::Open); + connection->close(Network::ConnectionCloseType::NoFlush); + + // wait the sleep timer in cert selector is triggered. + test_server_->waitForCounterEq("aysnc_cert_selection.cert_selection_sleep_finished", 1, + TestUtility::DefaultTimeout, dispatcher_.get()); + + connection.reset(); +} + } // namespace Ssl } // namespace Envoy diff --git a/test/common/tls/integration/ssl_integration_test_base.cc b/test/common/tls/integration/ssl_integration_test_base.cc index 0539b833a9c2..53ae8d35027b 100644 --- a/test/common/tls/integration/ssl_integration_test_base.cc +++ b/test/common/tls/integration/ssl_integration_test_base.cc @@ -15,6 +15,7 @@ void SslIntegrationTestBase::initialize() { .setCurves(server_curves_) .setCiphers(server_ciphers_) .setExpectClientEcdsaCert(client_ecdsa_cert_) + .setTlsCertSelector(tls_cert_selector_yaml_) .setTlsKeyLogFilter(keylog_local_, keylog_remote_, keylog_local_negative_, keylog_remote_negative_, keylog_path_, diff --git a/test/common/tls/integration/ssl_integration_test_base.h b/test/common/tls/integration/ssl_integration_test_base.h index 31bbda465023..4eb3fb99fb5f 100644 --- a/test/common/tls/integration/ssl_integration_test_base.h +++ b/test/common/tls/integration/ssl_integration_test_base.h @@ -39,6 +39,7 @@ class SslIntegrationTestBase : public HttpIntegrationTest { bool ocsp_staple_required_{false}; bool prefer_client_ciphers_{false}; bool client_ecdsa_cert_{false}; + std::string tls_cert_selector_yaml_{""}; // Set this true to debug SSL handshake issues with openssl s_client. The // verbose trace will be in the logs, openssl must be installed separately. bool debug_with_s_client_{false}; diff --git a/test/common/tls/ssl_socket_test.cc b/test/common/tls/ssl_socket_test.cc index 1919fdf24252..4ee59996b4d9 100644 --- a/test/common/tls/ssl_socket_test.cc +++ b/test/common/tls/ssl_socket_test.cc @@ -372,7 +372,7 @@ void testUtil(const TestUtilOptions& options) { TestUtility::loadFromYaml(TestEnvironment::substitute(options.serverCtxYaml()), server_tls_context); auto server_cfg = THROW_OR_RETURN_VALUE( - ServerContextConfigImpl::create(server_tls_context, transport_socket_factory_context), + ServerContextConfigImpl::create(server_tls_context, transport_socket_factory_context, false), std::unique_ptr); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); @@ -736,7 +736,8 @@ void testUtilV2(const TestUtilOptionsV2& options) { ASSERT(transport_socket.has_typed_config()); transport_socket.typed_config().UnpackTo(&tls_context); - auto server_cfg = *ServerContextConfigImpl::create(tls_context, transport_socket_factory_context); + auto server_cfg = + *ServerContextConfigImpl::create(tls_context, transport_socket_factory_context, false); auto factory_or_error = ServerSslSocketFactory::create( std::move(server_cfg), manager, *server_stats_store.rootScope(), server_names); @@ -1030,7 +1031,7 @@ TEST_P(SslSocketTest, ServerTransportSocketOptions) { ; envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), tls_context); - auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); auto server_ssl_socket_factory = @@ -3076,7 +3077,7 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), tls_context); - auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -3136,7 +3137,7 @@ TEST_P(SslSocketTest, HalfClose) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); - auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -3222,7 +3223,7 @@ TEST_P(SslSocketTest, ShutdownWithCloseNotify) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); - auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -3314,7 +3315,7 @@ TEST_P(SslSocketTest, ShutdownWithoutCloseNotify) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); - auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -3422,7 +3423,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); - auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -3519,13 +3520,13 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context1; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml1), server_tls_context1); - auto server_cfg1 = - *ServerContextConfigImpl::create(server_tls_context1, transport_socket_factory_context); + auto server_cfg1 = *ServerContextConfigImpl::create(server_tls_context1, + transport_socket_factory_context, false); envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context2; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml2), server_tls_context2); - auto server_cfg2 = - *ServerContextConfigImpl::create(server_tls_context2, transport_socket_factory_context); + auto server_cfg2 = *ServerContextConfigImpl::create(server_tls_context2, + transport_socket_factory_context, false); auto server_ssl_socket_factory1 = *ServerSslSocketFactory::create( std::move(server_cfg1), manager, *server_stats_store.rootScope(), server_names1); auto server_ssl_socket_factory2 = *ServerSslSocketFactory::create( @@ -3681,7 +3682,7 @@ void testSupportForSessionResumption(const std::string& server_ctx_yaml, envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); auto server_cfg = - *ServerContextConfigImpl::create(server_tls_context, transport_socket_factory_context); + *ServerContextConfigImpl::create(server_tls_context, transport_socket_factory_context, false); auto server_ssl_socket_factory = *ServerSslSocketFactory::create( std::move(server_cfg), manager, *server_stats_store.rootScope(), {}); @@ -4323,10 +4324,10 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context1; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), tls_context1); - auto server_cfg = *ServerContextConfigImpl::create(tls_context1, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(tls_context1, factory_context_, false); envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context2; TestUtility::loadFromYaml(TestEnvironment::substitute(server2_ctx_yaml), tls_context2); - auto server2_cfg = *ServerContextConfigImpl::create(tls_context2, factory_context_); + auto server2_cfg = *ServerContextConfigImpl::create(tls_context2, factory_context_, false); NiceMock server_factory_context; ContextManagerImpl manager(server_factory_context); Stats::TestUtil::TestStore server_stats_store; @@ -4458,7 +4459,7 @@ void SslSocketTest::testClientSessionResumption(const std::string& server_ctx_ya envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_ctx_proto; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_ctx_proto); auto server_cfg = - *ServerContextConfigImpl::create(server_ctx_proto, transport_socket_factory_context); + *ServerContextConfigImpl::create(server_ctx_proto, transport_socket_factory_context, false); auto server_ssl_socket_factory = *ServerSslSocketFactory::create( std::move(server_cfg), manager, *server_stats_store.rootScope(), std::vector{}); @@ -4721,7 +4722,7 @@ TEST_P(SslSocketTest, SslError) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), tls_context); - auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_, false); ContextManagerImpl manager(factory_context_.serverFactoryContext()); Stats::TestUtil::TestStore server_stats_store; auto server_ssl_socket_factory = *ServerSslSocketFactory::create( @@ -5253,7 +5254,7 @@ TEST_P(SslSocketTest, SetSignatureAlgorithms) { envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); - auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, factory_context_, false); ContextManagerImpl manager(factory_context_.serverFactoryContext()); Stats::TestUtil::TestStore server_stats_store; auto server_ssl_socket_factory = *ServerSslSocketFactory::create( @@ -5865,7 +5866,7 @@ TEST_P(SslSocketTest, DownstreamNotReadySslSocket) { tls_context.mutable_common_tls_context()->mutable_tls_certificate_sds_secret_configs()->Add(); sds_secret_configs->set_name("abc.com"); sds_secret_configs->mutable_sds_config(); - auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_); + auto server_cfg = *ServerContextConfigImpl::create(tls_context, factory_context_, false); EXPECT_TRUE(server_cfg->tlsCertificates().empty()); EXPECT_FALSE(server_cfg->isReady()); @@ -5957,7 +5958,8 @@ class SslReadBufferLimitTest : public SslSocketTest { void initialize() { TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml_), downstream_tls_context_); - auto server_cfg = *ServerContextConfigImpl::create(downstream_tls_context_, factory_context_); + auto server_cfg = + *ServerContextConfigImpl::create(downstream_tls_context_, factory_context_, false); manager_ = std::make_unique(factory_context_.serverFactoryContext()); server_ssl_socket_factory_ = *ServerSslSocketFactory::create(std::move(server_cfg), *manager_, *server_stats_store_.rootScope(), diff --git a/test/common/tls/tls_certificate_selector_test.cc b/test/common/tls/tls_certificate_selector_test.cc new file mode 100644 index 000000000000..bc78f264f7e4 --- /dev/null +++ b/test/common/tls/tls_certificate_selector_test.cc @@ -0,0 +1,405 @@ +#include +#include +#include + +#include "envoy/config/listener/v3/listener.pb.h" +#include "envoy/config/listener/v3/listener_components.pb.h" +#include "envoy/extensions/transport_sockets/tls/v3/cert.pb.h" +#include "envoy/network/transport_socket.h" + +#include "source/common/buffer/buffer_impl.h" +#include "source/common/common/empty_string.h" +#include "source/common/event/dispatcher_impl.h" +#include "source/common/json/json_loader.h" +#include "source/common/network/address_impl.h" +#include "source/common/network/listen_socket_impl.h" +#include "source/common/network/tcp_listener_impl.h" +#include "source/common/network/transport_socket_options_impl.h" +#include "source/common/network/utility.h" +#include "source/common/stream_info/stream_info_impl.h" +#include "source/common/tls/client_ssl_socket.h" +#include "source/common/tls/context_config_impl.h" +#include "source/common/tls/context_impl.h" +#include "source/common/tls/private_key/private_key_manager_impl.h" +#include "source/common/tls/server_context_config_impl.h" +#include "source/common/tls/server_ssl_socket.h" + +#include "test/common/tls/cert_validator/timed_cert_validator.h" +#include "test/common/tls/ssl_certs_test.h" +#include "test/common/tls/test_data/ca_cert_info.h" +#include "test/common/tls/test_data/extensions_cert_info.h" +#include "test/common/tls/test_data/no_san_cert_info.h" +#include "test/common/tls/test_data/password_protected_cert_info.h" +#include "test/common/tls/test_data/san_dns2_cert_info.h" +#include "test/common/tls/test_data/san_dns3_cert_info.h" +#include "test/common/tls/test_data/san_dns4_cert_info.h" +#include "test/common/tls/test_data/san_dns_cert_info.h" +#include "test/common/tls/test_data/san_dns_ecdsa_1_cert_info.h" +#include "test/common/tls/test_data/san_dns_rsa_1_cert_info.h" +#include "test/common/tls/test_data/san_dns_rsa_2_cert_info.h" +#include "test/common/tls/test_data/san_multiple_dns_1_cert_info.h" +#include "test/common/tls/test_data/san_multiple_dns_cert_info.h" +#include "test/common/tls/test_data/san_uri_cert_info.h" +#include "test/common/tls/test_data/selfsigned_ecdsa_p256_cert_info.h" +#include "test/common/tls/test_private_key_method_provider.h" +#include "test/mocks/buffer/mocks.h" +#include "test/mocks/init/mocks.h" +#include "test/mocks/local_info/mocks.h" +#include "test/mocks/network/io_handle.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/runtime/mocks.h" +#include "test/mocks/secret/mocks.h" +#include "test/mocks/server/server_factory_context.h" +#include "test/mocks/server/transport_socket_factory_context.h" +#include "test/mocks/ssl/mocks.h" +#include "test/mocks/stats/mocks.h" +#include "test/test_common/environment.h" +#include "test/test_common/network_utility.h" +#include "test/test_common/registry.h" +#include "test/test_common/test_runtime.h" +#include "test/test_common/utility.h" + +#include "absl/strings/str_replace.h" +#include "absl/types/optional.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "openssl/ssl.h" +#include "xds/type/v3/typed_struct.pb.h" + +using testing::_; +using testing::Invoke; +using testing::MockFunction; +using testing::NiceMock; +using testing::Ref; +using testing::ReturnRef; +using testing::WithArg; + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +class TestTlsCertificateSelector : public virtual Ssl::TlsCertificateSelector { +public: + TestTlsCertificateSelector(Ssl::TlsCertificateSelectorContext& selector_ctx, + const Protobuf::Message&) + : selector_ctx_(selector_ctx) {} + ~TestTlsCertificateSelector() override { + ENVOY_LOG_MISC(info, "debug: ~TestTlsCertificateSelector"); + } + Ssl::SelectionResult selectTlsContext(const SSL_CLIENT_HELLO&, + Ssl::CertificateSelectionCallbackPtr cb) override { + ENVOY_LOG_MISC(info, "debug: select context"); + + switch (mod_) { + case Ssl::SelectionResult::SelectionStatus::Success: + ENVOY_LOG_MISC(info, "debug: select cert done"); + return {mod_, &getTlsContext(), false}; + break; + case Ssl::SelectionResult::SelectionStatus::Pending: + ENVOY_LOG_MISC(info, "debug: select cert async"); + cb_ = std::move(cb); + cb_->dispatcher().post([this] { selectTlsContextAsync(); }); + break; + default: + break; + } + return {mod_, nullptr, false}; + }; + + std::pair findTlsContext(absl::string_view, bool, + bool, bool*) override { + PANIC("unreachable"); + }; + + void selectTlsContextAsync() { + ENVOY_LOG_MISC(info, "debug: select cert async done"); + cb_->onCertificateSelectionResult(getTlsContext(), false); + } + + const Ssl::TlsContext& getTlsContext() { return selector_ctx_.getTlsContexts()[0]; } + + Ssl::SelectionResult::SelectionStatus mod_; + +private: + Ssl::TlsCertificateSelectorContext& selector_ctx_; + Ssl::CertificateSelectionCallbackPtr cb_; +}; + +class TestTlsCertificateSelectorFactory : public Ssl::TlsCertificateSelectorConfigFactory { +public: + using CreateProviderHook = + std::function; + + Ssl::TlsCertificateSelectorFactory + createTlsCertificateSelectorFactory(const Protobuf::Message& config, + Server::Configuration::CommonFactoryContext& factory_context, + ProtobufMessage::ValidationVisitor& validation_visitor, + absl::Status& creation_status, bool for_quic) override { + if (selector_cb_) { + selector_cb_(config, factory_context, validation_visitor); + } + if (for_quic) { + creation_status = absl::InvalidArgumentError("does not supported for quic"); + return Ssl::TlsCertificateSelectorFactory(); + } + return [&config, this](const Ssl::ServerContextConfig&, + Ssl::TlsCertificateSelectorContext& selector_ctx) { + ENVOY_LOG_MISC(info, "debug: init provider"); + auto provider = std::make_unique(selector_ctx, config); + provider->mod_ = mod_; + return provider; + }; + } + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + std::string name() const override { return "test-tls-context-provider"; }; + + CreateProviderHook selector_cb_; + Ssl::SelectionResult::SelectionStatus mod_; +}; + +Network::ListenerPtr createListener(Network::SocketSharedPtr&& socket, + Network::TcpListenerCallbacks& cb, Runtime::Loader& runtime, + const Network::ListenerConfig& listener_config, + Server::ThreadLocalOverloadStateOptRef overload_state, + Random::RandomGenerator& rng, Event::Dispatcher& dispatcher) { + return std::make_unique( + dispatcher, rng, runtime, std::move(socket), cb, listener_config.bindToPort(), + listener_config.ignoreGlobalConnLimit(), listener_config.shouldBypassOverloadManager(), + listener_config.maxConnectionsToAcceptPerSocketEvent(), overload_state); +} + +class TlsCertificateSelectorFactoryTest + : public testing::Test, + public testing::WithParamInterface { +protected: + TlsCertificateSelectorFactoryTest() + : registered_factory_(provider_factory_), version_(GetParam()) { + scoped_runtime_.mergeValues( + {{"envoy.reloadable_features.no_extension_lookup_by_name", "false"}}); + } + + void testUtil(Ssl::SelectionResult::SelectionStatus mod) { + const std::string server_ctx_yaml = R"EOF( + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_cert.pem" + private_key: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_key.pem" + validation_context: + trusted_ca: + filename: "{{ test_rundir }}/test/common/tls/test_data/ca_cert.pem" + custom_tls_certificate_selector: + name: test-tls-context-provider + typed_config: + "@type": type.googleapis.com/xds.type.v3.TypedStruct + value: + foo: bar +)EOF"; + const std::string client_ctx_yaml = R"EOF( + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_cert.pem" + private_key: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_key.pem" +)EOF"; + + Event::SimulatedTimeSystem time_system; + + Stats::TestUtil::TestStore server_stats_store; + Api::ApiPtr server_api = Api::createApiForTest(server_stats_store, time_system); + NiceMock runtime; + testing::NiceMock + transport_socket_factory_context; + ON_CALL(transport_socket_factory_context.server_context_, api()) + .WillByDefault(ReturnRef(*server_api)); + + MockFunction mock_factory_cb; + provider_factory_.selector_cb_ = mock_factory_cb.AsStdFunction(); + + EXPECT_CALL(mock_factory_cb, Call) + .WillOnce(WithArg<1>([&](Server::Configuration::CommonFactoryContext& context) { + // Check that the objects available via the context are the same ones + // provided to the parent context. + EXPECT_THAT(context.api(), Ref(*server_api)); + })); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); + // provider factory callback will be Called here. + auto server_cfg = *ServerContextConfigImpl::create(server_tls_context, + transport_socket_factory_context, false); + + Event::DispatcherPtr dispatcher = server_api->allocateDispatcher("test_thread"); + provider_factory_.mod_ = mod; + + NiceMock server_factory_context; + Tls::ContextManagerImpl manager(server_factory_context); + auto server_ssl_socket_factory = *ServerSslSocketFactory::create( + std::move(server_cfg), manager, *server_stats_store.rootScope(), + std::vector{}); + + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(version_)); + Network::MockTcpListenerCallbacks callbacks; + NiceMock listener_config; + Server::ThreadLocalOverloadStateOptRef overload_state; + Network::ListenerPtr listener = + createListener(socket, callbacks, runtime, listener_config, overload_state, + server_api->randomGenerator(), *dispatcher); + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext client_tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), client_tls_context); + + Stats::TestUtil::TestStore client_stats_store; + Api::ApiPtr client_api = Api::createApiForTest(client_stats_store, time_system); + testing::NiceMock + client_factory_context; + ON_CALL(client_factory_context.server_context_, api()).WillByDefault(ReturnRef(*client_api)); + + auto client_cfg = *ClientContextConfigImpl::create(client_tls_context, client_factory_context); + auto client_ssl_socket_factory = *ClientSslSocketFactory::create( + std::move(client_cfg), manager, *client_stats_store.rootScope()); + Network::ClientConnectionPtr client_connection = dispatcher->createClientConnection( + socket->connectionInfoProvider().localAddress(), Network::Address::InstanceConstSharedPtr(), + client_ssl_socket_factory->createTransportSocket(nullptr, nullptr), nullptr, nullptr); + Network::ConnectionPtr server_connection; + Network::MockConnectionCallbacks server_connection_callbacks; + NiceMock stream_info; + EXPECT_CALL(callbacks, onAccept_(_)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { + auto ssl_socket = server_ssl_socket_factory->createDownstreamTransportSocket(); + // configureInitialCongestionWindow is an unimplemented empty function, this is just to + // increase code coverage. + ssl_socket->configureInitialCongestionWindow(100, std::chrono::microseconds(123)); + server_connection = dispatcher->createServerConnection( + std::move(socket), std::move(ssl_socket), stream_info); + server_connection->addConnectionCallbacks(server_connection_callbacks); + })); + EXPECT_CALL(callbacks, recordConnectionsAcceptedOnSocketEvent(_)); + + Network::MockConnectionCallbacks client_connection_callbacks; + client_connection->addConnectionCallbacks(client_connection_callbacks); + client_connection->connect(); + + size_t connect_count = 0; + auto connect_second_time = [&]() { + ENVOY_LOG_MISC(debug, "connect count: {}", connect_count); + if (++connect_count == 2) { + // By default, the session is not created with session resumption. The + // client should see a session ID but the server should not. + EXPECT_EQ(EMPTY_STRING, server_connection->ssl()->sessionId()); + EXPECT_NE(EMPTY_STRING, client_connection->ssl()->sessionId()); + + server_connection->close(Network::ConnectionCloseType::NoFlush); + client_connection->close(Network::ConnectionCloseType::NoFlush); + dispatcher->exit(); + } + }; + + size_t close_count = 0; + auto close_second_time = [&close_count, &dispatcher]() { + if (++close_count == 2) { + dispatcher->exit(); + } + }; + + if (false) { + EXPECT_CALL(client_connection_callbacks, onEvent) + .WillRepeatedly(Invoke([&](Network::ConnectionEvent e) -> void { + ENVOY_LOG_MISC(info, "client onEvent {}", static_cast(e)); + connect_second_time(); + })); + + EXPECT_CALL(server_connection_callbacks, onEvent) + .WillRepeatedly(Invoke([&](Network::ConnectionEvent e) -> void { + ENVOY_LOG_MISC(info, "server onEvent {}", static_cast(e)); + connect_second_time(); + })); + } else { + if (mod == Ssl::SelectionResult::SelectionStatus::Failed) { + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { close_second_time(); })); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { close_second_time(); })); + } else { + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { connect_second_time(); })); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { connect_second_time(); })); + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); + } + } + + dispatcher->run(Event::Dispatcher::RunType::Block); + } + + TestTlsCertificateSelectorFactory provider_factory_; + Registry::InjectFactory registered_factory_; + TestScopedRuntime scoped_runtime_; + + Network::Address::IpVersion version_; +}; + +TEST_P(TlsCertificateSelectorFactoryTest, Success) { + testUtil(Ssl::SelectionResult::SelectionStatus::Success); +} + +TEST_P(TlsCertificateSelectorFactoryTest, Failed) { + testUtil(Ssl::SelectionResult::SelectionStatus::Failed); +} + +TEST_P(TlsCertificateSelectorFactoryTest, Pending) { + testUtil(Ssl::SelectionResult::SelectionStatus::Pending); +} + +TEST_P(TlsCertificateSelectorFactoryTest, QUICFactory) { + const std::string server_ctx_yaml = R"EOF( + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_cert.pem" + private_key: + filename: "{{ test_rundir }}/test/common/tls/test_data/no_san_key.pem" + validation_context: + trusted_ca: + filename: "{{ test_rundir }}/test/common/tls/test_data/ca_cert.pem" + custom_tls_certificate_selector: + name: test-tls-context-provider + typed_config: + "@type": type.googleapis.com/xds.type.v3.TypedStruct + value: + foo: bar +)EOF"; + + Event::SimulatedTimeSystem time_system; + Stats::TestUtil::TestStore server_stats_store; + Api::ApiPtr server_api = Api::createApiForTest(server_stats_store, time_system); + testing::NiceMock + transport_socket_factory_context; + ON_CALL(transport_socket_factory_context.server_context_, api()) + .WillByDefault(ReturnRef(*server_api)); + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); + // provider factory callback will be Called here. + auto server_cfg = + ServerContextConfigImpl::create(server_tls_context, transport_socket_factory_context, true); + + EXPECT_FALSE(server_cfg.ok()); +} + +INSTANTIATE_TEST_SUITE_P(IpVersions, TlsCertificateSelectorFactoryTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + TestUtility::ipTestParamsToString); + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/config/utility.cc b/test/config/utility.cc index 0f908a434394..69891e86f60b 100644 --- a/test/config/utility.cc +++ b/test/config/utility.cc @@ -1537,6 +1537,16 @@ void ConfigHelper::initializeTls( *validation_context->mutable_match_typed_subject_alt_names() = {options.san_matchers_.begin(), options.san_matchers_.end()}; } + if (!options.tls_cert_selector_yaml_.empty()) { + auto* cert_selector = common_tls_context.mutable_custom_tls_certificate_selector(); +#ifdef ENVOY_ENABLE_YAML + TestUtility::loadFromYaml(TestEnvironment::substitute(options.tls_cert_selector_yaml_), + *cert_selector); +#else + UNREFERENCED_PARAMETER(cert_selector); + PANIC("YAML support compiled out"); +#endif + } initializeTlsKeyLog(common_tls_context, options); } diff --git a/test/config/utility.h b/test/config/utility.h index c9016cb15691..eb34d534bbec 100644 --- a/test/config/utility.h +++ b/test/config/utility.h @@ -131,6 +131,11 @@ class ConfigHelper { return *this; } + ServerSslOptions& setTlsCertSelector(std::string yaml) { + tls_cert_selector_yaml_ = yaml; + return *this; + } + bool allow_expired_certificate_{}; envoy::config::core::v3::TypedExtensionConfig* custom_validator_config_{nullptr}; bool rsa_cert_{true}; @@ -152,6 +157,7 @@ class ConfigHelper { Network::Address::IpVersion ip_version_{Network::Address::IpVersion::v4}; std::vector san_matchers_{}; + std::string tls_cert_selector_yaml_{""}; bool client_with_intermediate_cert_{false}; bool trust_root_only_{false}; absl::optional max_verify_depth_{absl::nullopt}; diff --git a/test/extensions/config_subscription/grpc/xds_failover_integration_test.cc b/test/extensions/config_subscription/grpc/xds_failover_integration_test.cc index 61ed6134fbc2..d48a51b16980 100644 --- a/test/extensions/config_subscription/grpc/xds_failover_integration_test.cc +++ b/test/extensions/config_subscription/grpc/xds_failover_integration_test.cc @@ -147,7 +147,7 @@ class XdsFailoverAdsIntegrationTest : public AdsDeltaSotwIntegrationSubStatePara tls_cert->mutable_private_key()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/upstreamkey.pem")); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); // upstream_stats_store_ should have been initialized be prior call to // BaseIntegrationTest::createXdsUpstream(). ASSERT(upstream_stats_store_ != nullptr); diff --git a/test/extensions/filters/http/router/auto_sni_integration_test.cc b/test/extensions/filters/http/router/auto_sni_integration_test.cc index 4ae07e01a362..325c64af38af 100644 --- a/test/extensions/filters/http/router/auto_sni_integration_test.cc +++ b/test/extensions/filters/http/router/auto_sni_integration_test.cc @@ -60,7 +60,7 @@ class AutoSniIntegrationTest : public testing::TestWithParam mock_factory_ctx; ON_CALL(mock_factory_ctx.server_context_, api()).WillByDefault(testing::ReturnRef(*api_)); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - downstream_tls_context, mock_factory_ctx); + downstream_tls_context, mock_factory_ctx, false); static auto* client_stats_store = new Stats::TestIsolatedStoreImpl(); tls_context_ = Network::DownstreamTransportSocketFactoryPtr{ *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( diff --git a/test/integration/alpn_selection_integration_test.cc b/test/integration/alpn_selection_integration_test.cc index 2b5421fe880d..5812e2fafff4 100644 --- a/test/integration/alpn_selection_integration_test.cc +++ b/test/integration/alpn_selection_integration_test.cc @@ -69,7 +69,7 @@ require_client_certificate: true TestEnvironment::runfilesPath("test/config/integration/certs/cacert.pem")); TestUtility::loadFromYaml(yaml, tls_context); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); static auto* upstream_stats_store = new Stats::IsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( std::move(cfg), context_manager_, *upstream_stats_store->rootScope(), diff --git a/test/integration/base_integration_test.cc b/test/integration/base_integration_test.cc index ce040b423743..457f03b0a011 100644 --- a/test/integration/base_integration_test.cc +++ b/test/integration/base_integration_test.cc @@ -153,7 +153,7 @@ BaseIntegrationTest::createUpstreamTlsContext(const FakeUpstreamConfig& upstream } if (upstream_config.upstream_protocol_ != Http::CodecType::HTTP3) { auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( std::move(cfg), context_manager_, *upstream_stats_store->rootScope(), @@ -581,7 +581,7 @@ void BaseIntegrationTest::createXdsUpstream() { tls_cert->mutable_private_key()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/upstreamkey.pem")); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); upstream_stats_store_ = std::make_unique(); auto context = *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( diff --git a/test/integration/sds_dynamic_integration_test.cc b/test/integration/sds_dynamic_integration_test.cc index e346767efe9a..9a189c65053b 100644 --- a/test/integration/sds_dynamic_integration_test.cc +++ b/test/integration/sds_dynamic_integration_test.cc @@ -712,7 +712,7 @@ class SdsDynamicDownstreamCertValidationContextTest : public SdsDynamicDownstrea TestEnvironment::runfilesPath("test/config/integration/certs/clientkey.pem")); auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( - tls_context, factory_context_); + tls_context, factory_context_, false); static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); return Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( std::move(cfg), context_manager_, *upstream_stats_store->rootScope(), diff --git a/test/integration/ssl_utility.cc b/test/integration/ssl_utility.cc index ff7d48b98201..6c48ffdd5cff 100644 --- a/test/integration/ssl_utility.cc +++ b/test/integration/ssl_utility.cc @@ -118,8 +118,8 @@ createUpstreamSslContext(ContextManager& context_manager, Api::Api& api, bool us NiceMock mock_factory_ctx; ON_CALL(mock_factory_ctx.server_context_, api()).WillByDefault(ReturnRef(api)); - auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create(tls_context, - mock_factory_ctx); + auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( + tls_context, mock_factory_ctx, false); static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); if (!use_http3) { @@ -151,8 +151,8 @@ Network::DownstreamTransportSocketFactoryPtr createFakeUpstreamSslContext( tls_cert->mutable_private_key()->set_filename(TestEnvironment::runfilesPath( fmt::format("test/config/integration/certs/{}key.pem", upstream_cert_name))); - auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create(tls_context, - factory_context); + auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( + tls_context, factory_context, false); static auto* upstream_stats_store = new Stats::IsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index e614a50ea7cd..b70ab7aac9ed 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -112,8 +112,8 @@ Network::DownstreamTransportSocketFactoryPtr XfccIntegrationTest::createUpstream tls_cert->mutable_private_key()->set_filename( TestEnvironment::runfilesPath("test/config/integration/certs/upstreamkey.pem")); - auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create(tls_context, - factory_context_); + auto cfg = *Extensions::TransportSockets::Tls::ServerContextConfigImpl::create( + tls_context, factory_context_, false); static auto* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); return *Extensions::TransportSockets::Tls::ServerSslSocketFactory::create( std::move(cfg), *context_manager_, *(upstream_stats_store->rootScope()), diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 7a3e35a95707..211c7afae603 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -137,6 +137,8 @@ class MockServerContextConfig : public ServerContextConfig { MOCK_METHOD(void, setSecretUpdateCallback, (std::function callback)); MOCK_METHOD(Ssl::HandshakerFactoryCb, createHandshaker, (), (const, override)); + MOCK_METHOD(Ssl::TlsCertificateSelectorFactory, tlsCertificateSelectorFactory, (), + (const, override)); MOCK_METHOD(Ssl::HandshakerCapabilities, capabilities, (), (const, override)); MOCK_METHOD(Ssl::SslCtxCb, sslctxCb, (), (const, override)); diff --git a/test/per_file_coverage.sh b/test/per_file_coverage.sh index a387bfb6e6af..5258c8403c84 100755 --- a/test/per_file_coverage.sh +++ b/test/per_file_coverage.sh @@ -50,7 +50,7 @@ declare -a KNOWN_LOW_COVERAGE=( "source/extensions/tracers/opencensus:94.0" "source/extensions/tracers/zipkin:95.8" "source/extensions/transport_sockets:97.4" -"source/common/tls:94.9" +"source/common/tls:94.7" "source/common/tls/cert_validator:94.2" "source/common/tls/private_key:88.9" "source/extensions/wasm_runtime/wamr:0.0" # Not enabled in coverage build