From fef0cd7b5fd56016ad220f3c7b8f3abd720ab81f Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 22 Sep 2022 12:28:20 -0400 Subject: [PATCH] fix: throw error if Auto IAM AuthN is unsupported (#476) --- google/cloud/sql/connector/instance.py | 29 ++++++++++++++---- google/cloud/sql/connector/refresh_utils.py | 1 + tests/conftest.py | 9 ++++-- tests/system/test_connector_object.py | 33 +++++++++++++++++++++ tests/unit/mocks.py | 10 +++++-- tests/unit/test_instance.py | 19 ++++++++++++ tests/unit/test_refresh_utils.py | 6 ++-- 7 files changed, 95 insertions(+), 12 deletions(-) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 31f5eedc..5bb420eb 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -100,6 +100,15 @@ def __init__(self, *args: Any) -> None: super(ExpiredInstanceMetadata, self).__init__(self, *args) +class AutoIAMAuthNotSupported(Exception): + """ + Exception to be raised when Automatic IAM Authentication is not + supported with database engine version. + """ + + pass + + class InstanceMetadata: ip_addrs: Dict[str, Any] context: ssl.SSLContext @@ -115,7 +124,6 @@ def __init__( enable_iam_auth: bool, ) -> None: self.ip_addrs = ip_addrs - self.context = ssl.SSLContext(ssl.PROTOCOL_TLS) # verify OpenSSL version supports TLSv1.3 @@ -366,10 +374,21 @@ async def _perform_refresh(self) -> InstanceMetadata: self._enable_iam_auth, ) ) - - metadata, ephemeral_cert = await asyncio.gather( - metadata_task, ephemeral_task - ) + try: + metadata = await metadata_task + # check if automatic IAM database authn is supported for database engine + if self._enable_iam_auth and not metadata[ + "database_version" + ].startswith("POSTGRES"): + raise AutoIAMAuthNotSupported( + f"'{metadata['database_version']}' does not support automatic IAM authentication. It is only supported with Cloud SQL Postgres instances." + ) + except Exception: + # cancel ephemeral cert task if exception occurs before it is awaited + ephemeral_task.cancel() + raise + + ephemeral_cert = await ephemeral_task x509 = load_pem_x509_certificate( ephemeral_cert.encode("UTF-8"), default_backend() diff --git a/google/cloud/sql/connector/refresh_utils.py b/google/cloud/sql/connector/refresh_utils.py index d009621b..abb70fc4 100644 --- a/google/cloud/sql/connector/refresh_utils.py +++ b/google/cloud/sql/connector/refresh_utils.py @@ -105,6 +105,7 @@ async def _get_metadata( metadata = { "ip_addresses": {ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]}, "server_ca_cert": ret_dict["serverCaCert"]["cert"], + "database_version": ret_dict["databaseVersion"], } return metadata diff --git a/tests/conftest.py b/tests/conftest.py index 11813b88..b2ca0483 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -162,20 +162,23 @@ async def instance( # mock Cloud SQL Admin API calls with aioresponses() as mocked: mocked.get( - "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings", + f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}/connectSettings", status=200, body=mock_instance.connect_settings(), repeat=True, ) mocked.post( - "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert", + f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}:generateEphemeralCert", status=200, body=mock_instance.generate_ephemeral(client_key), repeat=True, ) instance = Instance( - "my-project:my-region:my-instance", "pg8000", keys, event_loop + f"{mock_instance.project}:{mock_instance.region}:{mock_instance.name}", + "pg8000", + keys, + event_loop, ) yield instance diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index 909aa9c4..91ab10d1 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -15,11 +15,13 @@ """ import asyncio import os +import pytest import pymysql import sqlalchemy import logging import google.auth from google.cloud.sql.connector import Connector +from google.cloud.sql.connector.instance import AutoIAMAuthNotSupported import datetime import concurrent.futures from threading import Thread @@ -141,3 +143,34 @@ def test_connector_with_custom_loop() -> None: assert result[0] == 1 # assert that Connector does not start its own thread assert connector._thread is None + + +def test_connector_mysql_iam_auth_error() -> None: + """ + Test that connecting with enable_iam_auth set to True + for MySQL raises exception. + """ + with pytest.raises(AutoIAMAuthNotSupported): + with Connector(enable_iam_auth=True) as connector: + connector.connect( + os.environ["MYSQL_CONNECTION_NAME"], + "pymysql", + user="my-user", + db="my-db", + ) + + +def test_connector_sqlserver_iam_auth_error() -> None: + """ + Test that connecting with enable_iam_auth set to True + for SQL Server raises exception. + """ + with pytest.raises(AutoIAMAuthNotSupported): + with Connector(enable_iam_auth=True) as connector: + connector.connect( + os.environ["SQLSERVER_CONNECTION_NAME"], + "pytds", + user="my-user", + password="my-pass", + db="my-db", + ) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9e53e506..84b66faa 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -178,11 +178,17 @@ async def create_ssl_context() -> ssl.SSLContext: class FakeCSQLInstance: - def __init__(self, project: str, region: str, name: str) -> None: + def __init__( + self, + project: str = "my-project", + region: str = "my-region", + name: str = "my-instance", + db_version: str = "POSTGRES_14", + ) -> None: self.project = project self.region = region self.name = name - self.db_version = "POSTGRES_14" # arbitrary value + self.db_version = db_version self.ip_addrs = {"PRIMARY": "0.0.0.0", "PRIVATE": "1.1.1.1"} self.backend_type = "SECOND_GEN" diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 63cb6b09..fb31c184 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -20,6 +20,7 @@ from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.auth.credentials import Credentials from google.cloud.sql.connector.instance import ( + AutoIAMAuthNotSupported, IPTypes, Instance, CredentialsTypeError, @@ -402,3 +403,21 @@ async def test_ClientResponseError( ) finally: await instance.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_instance", + [ + mocks.FakeCSQLInstance(db_version="SQLSERVER_2019_STANDARD"), + mocks.FakeCSQLInstance(db_version="MYSQL_8_0"), + ], +) +async def test_AutoIAMAuthNotSupportedError(instance: Instance) -> None: + """ + Test that AutoIAMAuthNotSupported exception is raised + for SQL Server and MySQL instances. + """ + instance._enable_iam_auth = True + with pytest.raises(AutoIAMAuthNotSupported): + await instance._current diff --git a/tests/unit/test_refresh_utils.py b/tests/unit/test_refresh_utils.py index 010c059b..dd848307 100644 --- a/tests/unit/test_refresh_utils.py +++ b/tests/unit/test_refresh_utils.py @@ -164,8 +164,10 @@ async def test_get_metadata( instance, ) - assert result["ip_addresses"] is not None and isinstance( - result["server_ca_cert"], str + assert ( + result["ip_addresses"] is not None + and result["database_version"] == "POSTGRES_14" + and isinstance(result["server_ca_cert"], str) )