diff --git a/crypto/s2n_rsa_pss.c b/crypto/s2n_rsa_pss.c index da034d6ad30..a865f0a66a2 100644 --- a/crypto/s2n_rsa_pss.c +++ b/crypto/s2n_rsa_pss.c @@ -17,10 +17,6 @@ #include #include -#include "error/s2n_errno.h" - -#include "stuffer/s2n_stuffer.h" - #include "crypto/s2n_evp_signing.h" #include "crypto/s2n_hash.h" #include "crypto/s2n_openssl.h" @@ -28,7 +24,8 @@ #include "crypto/s2n_rsa_pss.h" #include "crypto/s2n_rsa_signing.h" #include "crypto/s2n_pkey.h" - +#include "error/s2n_errno.h" +#include "stuffer/s2n_stuffer.h" #include "utils/s2n_blob.h" #include "utils/s2n_random.h" #include "utils/s2n_safety.h" @@ -144,7 +141,7 @@ static int s2n_rsa_validate_params_match(const struct s2n_pkey *pub, const struc POSIX_ENSURE_REF(priv); /* OpenSSL Documentation Links: - * - https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_get0_RSA.html + * - https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_get1_RSA.html * - https://www.openssl.org/docs/manmaster/man3/RSA_get0_key.html */ RSA *pub_rsa_key = pub->key.rsa_key.rsa; @@ -176,14 +173,21 @@ static int s2n_rsa_pss_keys_match(const struct s2n_pkey *pub, const struct s2n_p static int s2n_rsa_pss_key_free(struct s2n_pkey *pkey) { - /* This object does not own the reference to the key -- - * s2n_pkey handles it. */ + POSIX_ENSURE_REF(pkey); + struct s2n_rsa_key *rsa_key = &pkey->key.rsa_key; + if (rsa_key->rsa == NULL) { + return 0; + } + + RSA_free(rsa_key->rsa); + rsa_key->rsa = NULL; return 0; } int s2n_evp_pkey_to_rsa_pss_public_key(struct s2n_rsa_key *rsa_key, EVP_PKEY *pkey) { - RSA *pub_rsa_key = EVP_PKEY_get0_RSA(pkey); + RSA *pub_rsa_key = EVP_PKEY_get1_RSA(pkey); + POSIX_ENSURE_REF(pub_rsa_key); S2N_ERROR_IF(s2n_rsa_is_private_key(pub_rsa_key), S2N_ERR_KEY_MISMATCH); @@ -193,7 +197,7 @@ int s2n_evp_pkey_to_rsa_pss_public_key(struct s2n_rsa_key *rsa_key, EVP_PKEY *pk int s2n_evp_pkey_to_rsa_pss_private_key(struct s2n_rsa_key *rsa_key, EVP_PKEY *pkey) { - RSA *priv_rsa_key = EVP_PKEY_get0_RSA(pkey); + RSA *priv_rsa_key = EVP_PKEY_get1_RSA(pkey); POSIX_ENSURE_REF(priv_rsa_key); /* Documentation: https://www.openssl.org/docs/man1.1.1/man3/RSA_check_key.html */ diff --git a/tests/unit/s2n_rsa_pss_rsae_test.c b/tests/unit/s2n_rsa_pss_rsae_test.c index 49ac7af8998..20983950bf5 100644 --- a/tests/unit/s2n_rsa_pss_rsae_test.c +++ b/tests/unit/s2n_rsa_pss_rsae_test.c @@ -22,9 +22,11 @@ #include "crypto/s2n_hash.h" #include "crypto/s2n_rsa.h" #include "crypto/s2n_rsa_pss.h" +#include "error/s2n_errno.h" #include "stuffer/s2n_stuffer.h" #include "tls/s2n_connection.h" #include "tls/s2n_config.h" +#include "utils/s2n_safety.h" #include "utils/s2n_random.h" #define HASH_ALG S2N_HASH_SHA256 @@ -195,7 +197,7 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_pkey_free(&rsa_public_key)); } - #if RSA_PSS_CERTS_SUPPORTED +#if RSA_PSS_CERTS_SUPPORTED struct s2n_cert_chain_and_key *rsa_pss_cert_chain; EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&rsa_pss_cert_chain, @@ -245,11 +247,10 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_asn1der_to_public_key_and_type(&rsa_pss_public_key, &rsa_pss_pkey_type, &rsa_pss_cert_chain->cert_chain->head->raw)); EXPECT_EQUAL(rsa_pss_pkey_type, S2N_PKEY_TYPE_RSA_PSS); - /* Set the keys equal */ - const BIGNUM *n, *e, *d; - RSA_get0_key(EVP_PKEY_get0_RSA(rsa_public_key.pkey), &n, &e, &d); - EXPECT_SUCCESS(RSA_set0_key(EVP_PKEY_get0_RSA(rsa_pss_public_key.pkey), - BN_dup(n), BN_dup(e), BN_dup(d))); + /* Set the keys equal. */ + RSA *rsa_key_copy = EVP_PKEY_get1_RSA(rsa_public_key.pkey); + POSIX_GUARD_OSSL(EVP_PKEY_set1_RSA(rsa_pss_public_key.pkey, rsa_key_copy), S2N_ERR_KEY_INIT); + RSA_free(rsa_key_copy); /* RSA signed with PSS, RSA_PSS verified with PSS */ { @@ -285,12 +286,15 @@ int main(int argc, char **argv) &rsa_cert_chain->cert_chain->head->raw)); EXPECT_EQUAL(rsa_pkey_type, S2N_PKEY_TYPE_RSA); - RSA *rsa_key = EVP_PKEY_get0_RSA(rsa_public_key.pkey); + /* Modify the rsa_public_key for each test_case. */ + RSA *rsa_key_copy = EVP_PKEY_get1_RSA(rsa_public_key.pkey); BIGNUM *n = BN_new(), *e = BN_new(), *d = BN_new(); EXPECT_SUCCESS(BN_hex2bn(&n, test_case.key_param_n)); EXPECT_SUCCESS(BN_hex2bn(&e, test_case.key_param_e)); EXPECT_SUCCESS(BN_hex2bn(&d, test_case.key_param_d)); - EXPECT_SUCCESS(RSA_set0_key(rsa_key, n, e, d)); + EXPECT_SUCCESS(RSA_set0_key(rsa_key_copy, n, e, d)); + POSIX_GUARD_OSSL(EVP_PKEY_set1_RSA(rsa_public_key.pkey, rsa_key_copy), S2N_ERR_KEY_INIT); + RSA_free(rsa_key_copy); struct s2n_stuffer message_stuffer = { 0 }, signature_stuffer = { 0 }; s2n_stuffer_alloc_ro_from_hex_string(&message_stuffer, test_case.message); @@ -311,7 +315,8 @@ int main(int argc, char **argv) } EXPECT_SUCCESS(s2n_cert_chain_and_key_free(rsa_pss_cert_chain)); - #endif +#endif + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(rsa_cert_chain)); END_TEST(); }