diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 06168b29..1fc092dd 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -21,9 +21,9 @@ IPTypes, ) from google.cloud.sql.connector.utils import generate_keys - +from google.auth.credentials import Credentials from threading import Thread -from typing import Any, Dict +from typing import Any, Dict, Optional logger = logging.getLogger(name=__name__) @@ -43,9 +43,13 @@ class Connector: Enables IAM based authentication (Postgres only). :type timeout: int - :param timeout: + :param timeout The time limit for a connection before raising a TimeoutError. + :type credentials: google.auth.credentials.Credentials + :param credentials + Credentials object used to authenticate connections to Cloud SQL server. + If not specified, Application Default Credentials are used. """ def __init__( @@ -53,6 +57,7 @@ def __init__( ip_types: IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, timeout: int = 30, + credentials: Optional[Credentials] = None, ) -> None: self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) @@ -66,6 +71,7 @@ def __init__( self._timeout = timeout self._enable_iam_auth = enable_iam_auth self._ip_types = ip_types + self._credentials = credentials def connect( self, instance_connection_string: str, driver: str, **kwargs: Any @@ -112,6 +118,7 @@ def connect( driver, self._keys, self._loop, + self._credentials, enable_iam_auth, ) self._instances[instance_connection_string] = icm diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index e4ddb19d..eb7365bb 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -27,7 +27,7 @@ import datetime from enum import Enum import google.auth -from google.auth.credentials import Credentials +from google.auth.credentials import Credentials, with_scopes_if_required import google.auth.transport.requests import OpenSSL import platform @@ -117,6 +117,15 @@ def __init__(self, *args: Any) -> None: super(PlatformNotSupportedError, self).__init__(self, *args) +class CredentialsTypeError(Exception): + """ + Raised when credentials parameter is not proper type. + """ + + def __init__(self, *args: Any) -> None: + super(CredentialsTypeError, self).__init__(self, *args) + + class InstanceMetadata: ip_addrs: Dict[str, Any] context: ssl.SSLContext @@ -177,6 +186,11 @@ class InstanceConnectionManager: The user agent string to append to SQLAdmin API requests :type user_agent_string: str + :type credentials: google.auth.credentials.Credentials + :param credentials + Credentials object used to authenticate connections to Cloud SQL server. + If not specified, Application Default Credentials are used. + :param enable_iam_auth Enables IAM based authentication for Postgres instances. :type enable_iam_auth: bool @@ -229,6 +243,7 @@ def __init__( driver_name: str, keys: concurrent.futures.Future, loop: asyncio.AbstractEventLoop, + credentials: Optional[Credentials] = None, enable_iam_auth: bool = False, ) -> None: # Validate connection string @@ -250,7 +265,14 @@ def __init__( self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}" self._loop = loop self._keys = asyncio.wrap_future(keys, loop=self._loop) - self._auth_init() + # validate credentials type + if not isinstance(credentials, Credentials) and credentials is not None: + raise CredentialsTypeError( + "Arg credentials must be type 'google.auth.credentials.Credentials' " + "or None (to use Application Default Credentials)" + ) + + self._auth_init(credentials) self._refresh_rate_limiter = AsyncRateLimiter( max_capacity=2, rate=1 / 30, loop=self._loop @@ -343,17 +365,25 @@ async def _get_instance_data(self) -> InstanceMetadata: self._enable_iam_auth, ) - def _auth_init(self) -> None: + def _auth_init(self, credentials: Optional[Credentials]) -> None: """Creates and assigns a Google Python API service object for Google Cloud SQL Admin API. - """ - credentials, project = google.auth.default( - scopes=[ - "https://www.googleapis.com/auth/sqlservice.admin", - "https://www.googleapis.com/auth/cloud-platform", - ] - ) + :type credentials: google.auth.credentials.Credentials + :param credentials + Credentials object used to authenticate connections to Cloud SQL server. + If not specified, Application Default Credentials are used. + """ + scopes = [ + "https://www.googleapis.com/auth/sqlservice.admin", + "https://www.googleapis.com/auth/cloud-platform", + ] + # if Credentials object is passed in, use for authentication + if isinstance(credentials, Credentials): + credentials = with_scopes_if_required(credentials, scopes=scopes) + # otherwise use application default credentials + else: + credentials, project = google.auth.default(scopes=scopes) self._credentials = credentials diff --git a/tests/unit/test_instance_connection_manager.py b/tests/unit/test_instance_connection_manager.py index 2a906ecc..830c0aef 100644 --- a/tests/unit/test_instance_connection_manager.py +++ b/tests/unit/test_instance_connection_manager.py @@ -15,16 +15,24 @@ """ import asyncio +from unittest.mock import Mock, patch import datetime from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from typing import Any import pytest # noqa F401 Needed to run the tests +from google.auth.credentials import Credentials from google.cloud.sql.connector.instance_connection_manager import ( InstanceConnectionManager, + CredentialsTypeError, ) from google.cloud.sql.connector.utils import generate_keys +@pytest.fixture +def mock_credentials() -> Credentials: + return Mock(spec=Credentials) + + @pytest.fixture def icm( async_loop: asyncio.AbstractEventLoop, connect_string: str @@ -73,6 +81,21 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) - ) +def test_InstanceConnectionManager_init_bad_credentials( + async_loop: asyncio.AbstractEventLoop, +) -> None: + """ + Test to check whether the __init__ method of InstanceConnectionManager + throws proper error for bad credentials arg type. + """ + connect_string = "test-project:test-region:test-instance" + keys = asyncio.run_coroutine_threadsafe(generate_keys(), async_loop) + with pytest.raises(CredentialsTypeError): + assert InstanceConnectionManager( + connect_string, "pymysql", keys, async_loop, credentials=1 + ) + + @pytest.mark.asyncio async def test_perform_refresh_replaces_result( icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter @@ -171,3 +194,35 @@ async def test_force_refresh_cancels_pending_refresh( assert pending_refresh.cancelled() is True assert isinstance(icm._current.result(), MockMetadata) + + +def test_auth_init_with_credentials_object( + icm: InstanceConnectionManager, mock_credentials: Credentials +) -> None: + """ + Test that InstanceConnectionManager's _auth_init initializes _credentials + when passed a google.auth.credentials.Credentials object. + """ + setattr(icm, "_credentials", None) + with patch( + "google.cloud.sql.connector.instance_connection_manager.with_scopes_if_required" + ) as mock_auth: + mock_auth.return_value = mock_credentials + icm._auth_init(credentials=mock_credentials) + assert isinstance(icm._credentials, Credentials) + mock_auth.assert_called_once() + + +def test_auth_init_with_default_credentials( + icm: InstanceConnectionManager, mock_credentials: Credentials +) -> None: + """ + Test that InstanceConnectionManager's _auth_init initializes _credentials + with application default credentials when credentials are not specified. + """ + setattr(icm, "_credentials", None) + with patch("google.auth.default") as mock_auth: + mock_auth.return_value = mock_credentials, None + icm._auth_init(credentials=None) + assert isinstance(icm._credentials, Credentials) + mock_auth.assert_called_once()