Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add arg for specifying credentials #226

Merged
merged 10 commits into from
Jan 4, 2022
15 changes: 12 additions & 3 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
IPTypes,
)
from google.cloud.sql.connector.utils import generate_keys

from google.auth.credentials import Credentials
from threading import Thread
from typing import Any, Dict
from typing import Any, Dict, Union

logger = logging.getLogger(name=__name__)

Expand All @@ -43,16 +43,23 @@ class Connector:
Enables IAM based authentication (Postgres only).

:type timeout: int
:param timeout:
:param timeout
The time limit for a connection before raising a TimeoutError.

:type service_account_creds:
Optional [str | google.auth.credentials.Credentials]
:param service_account_creds
Path to JSON service account key file to be used for authentication
or google.auth.credentials.Credentials object.
If not specified, Application Default Credentials are used.
"""

def __init__(
self,
ip_types: IPTypes = IPTypes.PUBLIC,
enable_iam_auth: bool = False,
timeout: int = 30,
service_account_creds: Union[str, Credentials, None] = None,
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True)
Expand All @@ -66,6 +73,7 @@ def __init__(
self._timeout = timeout
self._enable_iam_auth = enable_iam_auth
self._ip_types = ip_types
self._service_account_creds = service_account_creds

def connect(
self, instance_connection_string: str, driver: str, **kwargs: Any
Expand Down Expand Up @@ -112,6 +120,7 @@ def connect(
driver,
self._keys,
self._loop,
self._service_account_creds,
enable_iam_auth,
)
self._instances[instance_connection_string] = icm
Expand Down
64 changes: 54 additions & 10 deletions google/cloud/sql/connector/instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import datetime
from enum import Enum
import google.auth
from google.auth.credentials import Credentials
from google.auth.credentials import Credentials, with_scopes_if_required
import google.auth.transport.requests
import OpenSSL
import platform
Expand All @@ -39,6 +39,7 @@
Awaitable,
Dict,
Optional,
Union,
TYPE_CHECKING,
)

Expand Down Expand Up @@ -117,6 +118,15 @@ def __init__(self, *args: Any) -> None:
super(PlatformNotSupportedError, self).__init__(self, *args)


class ServiceAccountCredentialsTypeError(Exception):
"""
Raised when service account credentials type is not proper type.
"""

def __init__(self, *args: Any) -> None:
super(ServiceAccountCredentialsTypeError, self).__init__(self, *args)


class InstanceMetadata:
ip_addrs: Dict[str, Any]
context: ssl.SSLContext
Expand Down Expand Up @@ -177,6 +187,13 @@ class InstanceConnectionManager:
The user agent string to append to SQLAdmin API requests
:type user_agent_string: str

:type service_account_creds:
Optional [str | google.auth.credentials.Credentials]
:param service_account_creds
Path to JSON service account key file to be used for authentication
or google.auth.credentials.Credentials object.
If not specified, Application Default Credentials are used.

:param enable_iam_auth
Enables IAM based authentication for Postgres instances.
:type enable_iam_auth: bool
Expand Down Expand Up @@ -229,6 +246,7 @@ def __init__(
driver_name: str,
keys: concurrent.futures.Future,
loop: asyncio.AbstractEventLoop,
service_account_creds: Union[str, Credentials, None] = None,
enable_iam_auth: bool = False,
) -> None:
# Validate connection string
Expand All @@ -250,7 +268,19 @@ def __init__(
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
self._loop = loop
self._keys = asyncio.wrap_future(keys, loop=self._loop)
self._auth_init()
# validate credentials type
if (
not isinstance(service_account_creds, str)
and not isinstance(service_account_creds, Credentials)
and service_account_creds is not None
):
raise ServiceAccountCredentialsTypeError(
"Arg service_account_creds must be type 'str' (path to valid credentials "
"key file), or type 'google.auth.credentials.Credentials' or None "
"(Application Default Credentials)"
)

self._auth_init(service_account_creds)

self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2, rate=1 / 30, loop=self._loop
Expand Down Expand Up @@ -343,17 +373,31 @@ async def _get_instance_data(self) -> InstanceMetadata:
self._enable_iam_auth,
)

def _auth_init(self) -> None:
def _auth_init(self, service_account_creds: Union[str, Credentials, None]) -> None:
"""Creates and assigns a Google Python API service object for
Google Cloud SQL Admin API.
"""

credentials, project = google.auth.default(
scopes=[
"https://www.googleapis.com/auth/sqlservice.admin",
"https://www.googleapis.com/auth/cloud-platform",
]
)
:type service_account_creds: str | google.auth.credentials.Credentials | None
:param service_account_creds
Path to JSON service account key file to be used for authentication
or google.auth.credentials.Credentials object.
If not specified, Application Default Credentials are used.
"""
scopes = [
"https://www.googleapis.com/auth/sqlservice.admin",
"https://www.googleapis.com/auth/cloud-platform",
]
# if Credentials object is passed in, use for authentication
if isinstance(service_account_creds, Credentials):
credentials = with_scopes_if_required(service_account_creds, scopes=scopes)
# if string is passed in, load credentials from file
elif isinstance(service_account_creds, str):
credentials, project = google.auth.load_credentials_from_file(
filename=service_account_creds, scopes=scopes
)
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
# otherwise use application default credentials
else:
credentials, project = google.auth.default(scopes=scopes)

self._credentials = credentials

Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@
"""

import asyncio
from unittest.mock import Mock, patch
import datetime
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from typing import Any
import pytest # noqa F401 Needed to run the tests
from google.auth.credentials import Credentials
from google.cloud.sql.connector.instance_connection_manager import (
InstanceConnectionManager,
ServiceAccountCredentialsTypeError,
)
from google.cloud.sql.connector.utils import generate_keys


@pytest.fixture
def mock_credentials() -> Credentials:
return Mock(spec=Credentials)


@pytest.fixture
def icm(
async_loop: asyncio.AbstractEventLoop, connect_string: str
Expand Down Expand Up @@ -73,6 +81,21 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
)


def test_InstanceConnectionManager_init_bad_service_account_creds(
async_loop: asyncio.AbstractEventLoop,
) -> None:
"""
Test to check whether the __init__ method of InstanceConnectionManager
throws proper error for bad service_account_creds arg type.
"""
connect_string = "test-project:test-region:test-instance"
keys = asyncio.run_coroutine_threadsafe(generate_keys(), async_loop)
with pytest.raises(ServiceAccountCredentialsTypeError):
assert InstanceConnectionManager(
connect_string, "pymysql", keys, async_loop, service_account_creds=1
)


@pytest.mark.asyncio
async def test_perform_refresh_replaces_result(
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
Expand Down Expand Up @@ -171,3 +194,50 @@ async def test_force_refresh_cancels_pending_refresh(

assert pending_refresh.cancelled() is True
assert isinstance(icm._current.result(), MockMetadata)


def test_auth_init_with_credentials_object(
icm: InstanceConnectionManager, mock_credentials: Credentials
) -> None:
"""
Test that InstanceConnectionManager's _auth_init initializes _credentials
when passed a google.auth.credentials.Credentials object.
"""
setattr(icm, "_credentials", None)
with patch(
"google.cloud.sql.connector.instance_connection_manager.with_scopes_if_required"
) as mock_auth:
mock_auth.return_value = mock_credentials
icm._auth_init(service_account_creds=mock_credentials)
assert isinstance(icm._credentials, Credentials)
mock_auth.assert_called_once()


def test_auth_init_with_credentials_file(
icm: InstanceConnectionManager, mock_credentials: Credentials
) -> None:
"""
Test that InstanceConnectionManager's _auth_init initializes _credentials
when passed a service account key file.
"""
setattr(icm, "_credentials", None)
with patch("google.auth.load_credentials_from_file") as mock_auth:
mock_auth.return_value = mock_credentials, None
icm._auth_init(service_account_creds="credentials.json")
assert isinstance(icm._credentials, Credentials)
mock_auth.assert_called_once()


def test_auth_init_with_default_credentials(
icm: InstanceConnectionManager, mock_credentials: Credentials
) -> None:
"""
Test that InstanceConnectionManager's _auth_init initializes _credentials
with application default credentials when credentials are not specified.
"""
setattr(icm, "_credentials", None)
with patch("google.auth.default") as mock_auth:
mock_auth.return_value = mock_credentials, None
icm._auth_init(service_account_creds=None)
assert isinstance(icm._credentials, Credentials)
mock_auth.assert_called_once()