diff --git a/README.md b/README.md index d7688ef4..88d1a420 100644 --- a/README.md +++ b/README.md @@ -185,14 +185,15 @@ defaults for each connection to make, you can initialize a `Connector` object as follows: ```python -from google.cloud.sql.connector import Connector, IPTypes +from google.cloud.sql.connector import Connector # Note: all parameters below are optional connector = Connector( ip_type="public", # can also be "private" or "psc" enable_iam_auth=False, timeout=30, - credentials=custom_creds # google.auth.credentials.Credentials + credentials=custom_creds, # google.auth.credentials.Credentials + refresh_strategy="lazy", # can be "lazy" or "background" ) ``` @@ -254,6 +255,21 @@ with Connector() as connector: print(row) ``` +### Configuring a Lazy Refresh (Cloud Run, Cloud Functions etc.) + +The Connector's `refresh_strategy` argument can be set to `"lazy"` to configure +the Python Connector to retrieve connection info lazily and as-needed. +Otherwise, a background refresh cycle runs to retrieve the connection info +periodically. This setting is useful in environments where the CPU may be +throttled outside of a request context, e.g., Cloud Run, Cloud Functions, etc. + +To set the refresh strategy, set the `refresh_strategy` keyword argument when +initializing a `Connector`: + +```python +connector = Connector(refresh_strategy="lazy") +``` + ### Specifying IP Address Type The Cloud SQL Python Connector can be used to connect to Cloud SQL instances @@ -277,7 +293,7 @@ conn = connector.connect( ``` > [!IMPORTANT] -> +> > If specifying Private IP or Private Service Connect (PSC), your application must be > attached to the proper VPC network to connect to your Cloud SQL instance. For most > applications this will require the use of a [VPC Connector][vpc-connector]. @@ -355,6 +371,14 @@ The Python Connector can be used alongside popular Python web frameworks such as Flask, FastAPI, etc, to integrate Cloud SQL databases within your web applications. 
+> [!NOTE] +> +> For serverless environments such as Cloud Functions, Cloud Run, etc, it may be +> beneficial to initialize the `Connector` with the lazy refresh strategy. +> i.e. `Connector(refresh_strategy="lazy")` +> +> See [Configuring a Lazy Refresh](#configuring-a-lazy-refresh-cloud-run-cloud-functions-etc) + #### Flask-SQLAlchemy [Flask-SQLAlchemy](https://flask-sqlalchemy.palletsprojects.com/en/2.x/) diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 22ee9f32..b58c7760 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -17,6 +17,13 @@ from google.cloud.sql.connector.connector import Connector from google.cloud.sql.connector.connector import create_async_connector from google.cloud.sql.connector.instance import IPTypes +from google.cloud.sql.connector.instance import RefreshStrategy from google.cloud.sql.connector.version import __version__ -__all__ = ["__version__", "create_async_connector", "Connector", "IPTypes"] +__all__ = [ + "__version__", + "create_async_connector", + "Connector", + "IPTypes", + "RefreshStrategy", +] diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index a38fcafd..ad396760 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -22,7 +22,7 @@ import socket from threading import Thread from types import TracebackType -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, Union import google.auth from google.auth.credentials import Credentials @@ -34,6 +34,8 @@ from google.cloud.sql.connector.exceptions import DnsNameResolutionError from google.cloud.sql.connector.instance import IPTypes from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.instance import RefreshStrategy +from google.cloud.sql.connector.lazy import LazyRefreshCache import 
google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -62,6 +64,7 @@ def __init__( sqladmin_api_endpoint: Optional[str] = None, user_agent: Optional[str] = None, universe_domain: Optional[str] = None, + refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, ) -> None: """Initializes a Connector instance. @@ -98,6 +101,11 @@ def __init__( universe_domain (str): The universe domain for Cloud SQL API calls. Default: "googleapis.com". + refresh_strategy (str | RefreshStrategy): The default refresh strategy + used to refresh SSL/TLS cert and instance metadata. Can be one + of the following: RefreshStrategy.LAZY ("LAZY") or + RefreshStrategy.BACKGROUND ("BACKGROUND"). + Default: RefreshStrategy.BACKGROUND """ # if event loop is given, use for background tasks if loop: @@ -113,7 +121,7 @@ def __init__( asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), loop=self._loop, ) - self._cache: Dict[str, RefreshAheadCache] = {} + self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials @@ -139,6 +147,10 @@ def __init__( if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) self._ip_type = ip_type + # if refresh_strategy is str, convert to RefreshStrategy enum + if isinstance(refresh_strategy, str): + refresh_strategy = RefreshStrategy._from_str(refresh_strategy) + self._refresh_strategy = refresh_strategy self._universe_domain = universe_domain # construct service endpoint for Cloud SQL Admin API calls if not sqladmin_api_endpoint: @@ -265,12 +277,20 @@ async def connect_async( "connector.Connector object." 
) else: - cache = RefreshAheadCache( - instance_connection_string, - self._client, - self._keys, - enable_iam_auth, - ) + if self._refresh_strategy == RefreshStrategy.LAZY: + cache = LazyRefreshCache( + instance_connection_string, + self._client, + self._keys, + enable_iam_auth, + ) + else: + cache = RefreshAheadCache( + instance_connection_string, + self._client, + self._keys, + enable_iam_auth, + ) self._cache[instance_connection_string] = cache connect_func = { diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index a8e9f1bc..6251ae7c 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -52,6 +52,23 @@ def _parse_instance_connection_name(connection_name: str) -> Tuple[str, str, str return connection_name_split[1], connection_name_split[3], connection_name_split[4] +class RefreshStrategy(Enum): + LAZY: str = "LAZY" + BACKGROUND: str = "BACKGROUND" + + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError( + f"Incorrect value for refresh_strategy, got '{value}'. Want one of: " + f"{', '.join([repr(m.value) for m in cls])}." + ) + + @classmethod + def _from_str(cls, refresh_strategy: str) -> RefreshStrategy: + """Convert refresh strategy from a str into RefreshStrategy.""" + return cls(refresh_strategy.upper()) + + class IPTypes(Enum): PUBLIC: str = "PRIMARY" PRIVATE: str = "PRIVATE" diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py new file mode 100644 index 00000000..9b8cfa24 --- /dev/null +++ b/google/cloud/sql/connector/lazy.py @@ -0,0 +1,132 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import logging +from typing import Optional + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.instance import _parse_instance_connection_name +from google.cloud.sql.connector.refresh_utils import _refresh_buffer + +logger = logging.getLogger(name=__name__) + + +class LazyRefreshCache: + """Cache that refreshes connection info when a caller requests a connection. + + Only refreshes the cache when a new connection is requested and the current + certificate is close to or already expired. + + This is the recommended option for serverless environments. + """ + + def __init__( + self, + instance_connection_string: str, + client: CloudSQLClient, + keys: asyncio.Future, + enable_iam_auth: bool = False, + ) -> None: + """Initializes a LazyRefreshCache instance. + + Args: + instance_connection_string (str): The Cloud SQL Instance's + connection string (also known as an instance connection name). + client (CloudSQLClient): The Cloud SQL Client instance. + keys (asyncio.Future): A future to the client's public-private key + pair. + enable_iam_auth (bool): Enables automatic IAM database authentication + (Postgres and MySQL) as the default authentication method for all + connections. 
+ """ + # validate and parse instance connection name + self._project, self._region, self._instance = _parse_instance_connection_name( + instance_connection_string + ) + self._instance_connection_string = instance_connection_string + + self._enable_iam_auth = enable_iam_auth + self._keys = keys + self._client = client + self._lock = asyncio.Lock() + self._cached: Optional[ConnectionInfo] = None + self._needs_refresh = False + + async def force_refresh(self) -> None: + """ + Invalidates the cache and configures the next call to + connect_info() to retrieve a fresh ConnectionInfo instance. + """ + async with self._lock: + self._needs_refresh = True + + async def connect_info(self) -> ConnectionInfo: + """Retrieves ConnectionInfo instance for establishing a secure + connection to the Cloud SQL instance. + """ + async with self._lock: + # If connection info is cached, check expiration. + # Pad expiration with a buffer to give the client plenty of time to + # establish a connection to the server with the certificate. 
+ if ( + self._cached + and not self._needs_refresh + and datetime.now(timezone.utc) + < (self._cached.expiration - timedelta(seconds=_refresh_buffer)) + ): + logger.debug( + f"['{self._instance_connection_string}']: Connection info " + "is still valid, using cached info" + ) + return self._cached + logger.debug( + f"['{self._instance_connection_string}']: Connection info " + "refresh operation started" + ) + try: + conn_info = await self._client.get_connection_info( + self._project, + self._region, + self._instance, + self._keys, + self._enable_iam_auth, + ) + except Exception as e: + logger.debug( + f"['{self._instance_connection_string}']: Connection info " + f"refresh operation failed: {str(e)}" + ) + raise + logger.debug( + f"['{self._instance_connection_string}']: Connection info " + "refresh operation completed successfully" + ) + logger.debug( + f"['{self._instance_connection_string}']: Current certificate " + f"expiration = {str(conn_info.expiration)}" + ) + self._cached = conn_info + self._needs_refresh = False + return conn_info + + async def close(self) -> None: + """Close is a no-op and provided purely for a consistent interface with + other cache types. + """ + pass diff --git a/tests/system/test_pg8000_iam_auth.py b/tests/system/test_pg8000_iam_auth.py index f3c1c3cd..cfe385b5 100644 --- a/tests/system/test_pg8000_iam_auth.py +++ b/tests/system/test_pg8000_iam_auth.py @@ -13,37 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + import os from typing import Generator -import uuid -# [START cloud_sql_connector_postgres_pg8000_iam_auth] import pg8000 import pytest import sqlalchemy from google.cloud.sql.connector import Connector -# [END cloud_sql_connector_postgres_pg8000_iam_auth] - -table_name = f"books_{uuid.uuid4().hex}" - -# [START cloud_sql_connector_postgres_pg8000_iam_auth] # The Cloud SQL Python Connector can be used along with SQLAlchemy using the # 'creator' argument to 'create_engine' -def init_connection_engine() -> sqlalchemy.engine.Engine: +def init_connection_engine(connector: Connector) -> sqlalchemy.engine.Engine: # initialize Connector object for connections to Cloud SQL def getconn() -> pg8000.dbapi.Connection: - with Connector() as connector: - conn: pg8000.dbapi.Connection = connector.connect( - os.environ["POSTGRES_IAM_CONNECTION_NAME"], - "pg8000", - user=os.environ["POSTGRES_IAM_USER"], - db=os.environ["POSTGRES_DB"], - enable_iam_auth=True, - ) - return conn + conn: pg8000.dbapi.Connection = connector.connect( + os.environ["POSTGRES_IAM_CONNECTION_NAME"], + "pg8000", + user=os.environ["POSTGRES_IAM_USER"], + db=os.environ["POSTGRES_DB"], + enable_iam_auth=True, + ) + return conn # create SQLAlchemy connection pool pool = sqlalchemy.create_engine( @@ -55,38 +48,39 @@ def getconn() -> pg8000.dbapi.Connection: return pool -# [END cloud_sql_connector_postgres_pg8000_iam_auth] +@pytest.fixture +def pool() -> Generator: + connector = Connector() + pool = init_connection_engine(connector) + yield pool -@pytest.fixture(name="pool") -def setup() -> Generator: - pool = init_connection_engine() + connector.close() - with pool.connect() as conn: - conn.execute( - sqlalchemy.text( - f"CREATE TABLE IF NOT EXISTS {table_name}" - " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" - ) - ) - yield pool +@pytest.fixture +def lazy_pool() -> Generator: + connector = Connector(refresh_strategy="lazy") + pool = init_connection_engine(connector) - with pool.connect() as conn: - 
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS {table_name}")) + yield pool + connector.close() -def test_pooled_connection_with_pg8000_iam_auth(pool: sqlalchemy.engine.Engine) -> None: - insert_stmt = sqlalchemy.text( - f"INSERT INTO {table_name} (id, title) VALUES (:id, :title)", - ) - with pool.connect() as conn: - conn.execute(insert_stmt, parameters={"id": "book1", "title": "Book One"}) - conn.execute(insert_stmt, parameters={"id": "book2", "title": "Book Two"}) - select_stmt = sqlalchemy.text(f"SELECT title FROM {table_name} ORDER BY ID;") +def test_pooled_connection_with_pg8000_iam_auth( + pool: sqlalchemy.engine.Engine, +) -> None: with pool.connect() as conn: - rows = conn.execute(select_stmt).fetchall() - titles = [row[0] for row in rows] - - assert titles == ["Book One", "Book Two"] + result = conn.execute(sqlalchemy.text("SELECT 1;")).fetchone() + assert isinstance(result[0], int) + assert result[0] == 1 + + +def test_lazy_connection_with_pg8000_iam_auth( + lazy_pool: sqlalchemy.engine.Engine, +) -> None: + with lazy_pool.connect() as conn: + result = conn.execute(sqlalchemy.text("SELECT 1;")).fetchone() + assert isinstance(result[0], int) + assert result[0] == 1 diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py new file mode 100644 index 00000000..27cd80b4 --- /dev/null +++ b/tests/unit/test_lazy.py @@ -0,0 +1,64 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.utils import generate_keys + + +async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> None: + """ + Test that LazyRefreshCache.connect_info works as expected. + """ + keys = asyncio.create_task(generate_keys()) + cache = LazyRefreshCache( + "test-project:test-region:test-instance", + client=fake_client, + keys=keys, + enable_iam_auth=False, + ) + # check that cached connection info is empty + assert cache._cached is None + conn_info = await cache.connect_info() + # check that cached connection info is now set + assert isinstance(cache._cached, ConnectionInfo) + # check that calling connect_info uses cached info + conn_info2 = await cache.connect_info() + assert conn_info2 == conn_info + + +async def test_LazyRefreshCache_force_refresh(fake_client: CloudSQLClient) -> None: + """ + Test that LazyRefreshCache.force_refresh works as expected. + """ + keys = asyncio.create_task(generate_keys()) + cache = LazyRefreshCache( + "test-project:test-region:test-instance", + client=fake_client, + keys=keys, + enable_iam_auth=False, + ) + conn_info = await cache.connect_info() + # check that cached connection info is now set + assert isinstance(cache._cached, ConnectionInfo) + await cache.force_refresh() + # check that calling connect_info after force_refresh gets new ConnectionInfo + conn_info2 = await cache.connect_info() + # check that new connection info was retrieved + assert conn_info2 != conn_info + assert cache._cached == conn_info2 + await cache.close()