From f52ba7ec4aa916bc6bb0062eb1b29ac0611b45f5 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 31 Jan 2022 14:55:52 -0500 Subject: [PATCH] fix: make asyncio.Lock() run in background thread (#252) --- .../connector/instance_connection_manager.py | 16 ++++--- google/cloud/sql/connector/rate_limiter.py | 2 +- tests/conftest.py | 3 +- tests/system/test_connector_object.py | 24 ++++++++++ .../unit/test_instance_connection_manager.py | 11 ++++- tests/unit/test_rate_limiter.py | 45 +++++++++++++++---- 6 files changed, 82 insertions(+), 19 deletions(-) diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index eb7365bb..f32cd0c1 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -233,6 +233,7 @@ def _client_session(self) -> aiohttp.ClientSession: _project: str _region: str + _refresh_rate_limiter: AsyncRateLimiter _refresh_in_progress: asyncio.locks.Event _current: asyncio.Task # task wraps coroutine that returns InstanceMetadata _next: asyncio.Task # task wraps coroutine that returns another task @@ -274,17 +275,18 @@ def __init__( self._auth_init(credentials) - self._refresh_rate_limiter = AsyncRateLimiter( - max_capacity=2, rate=1 / 30, loop=self._loop - ) - - async def _set_instance_data() -> None: - logger.debug("Updating instance data") + async def _async_init() -> None: + """Initialize InstanceConnectionManager's variables that require the + event loop running in background thread. + """ + self._refresh_rate_limiter = AsyncRateLimiter( + max_capacity=2, rate=1 / 30, loop=self._loop + ) self._refresh_in_progress = asyncio.locks.Event() self._current = self._loop.create_task(self._get_instance_data()) self._next = self._loop.create_task(self._schedule_refresh()) - init_future = asyncio.run_coroutine_threadsafe(_set_instance_data(), self._loop) + init_future = asyncio.run_coroutine_threadsafe(_async_init(), self._loop) init_future.result() def __del__(self) -> None: diff --git a/google/cloud/sql/connector/rate_limiter.py b/google/cloud/sql/connector/rate_limiter.py index cc119e17..21a3ff74 100644 --- a/google/cloud/sql/connector/rate_limiter.py +++ b/google/cloud/sql/connector/rate_limiter.py @@ -46,9 +46,9 @@ def __init__( self.rate = rate self.max_capacity = max_capacity self._loop = loop or asyncio.get_event_loop() - self._lock = asyncio.Lock() self._tokens: float = max_capacity self._last_token_update = self._loop.time() + self._lock = asyncio.Lock() def _update_token_count(self) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index 84959d09..b86f4965 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,11 +51,10 @@ def async_loop() -> Generator: Creates a loop in a background thread and returns it to use for testing. """ loop = asyncio.new_event_loop() - thr = threading.Thread(target=loop.run_forever) + thr = threading.Thread(target=loop.run_forever, daemon=True) thr.start() yield loop loop.stop() - thr.join() @pytest.fixture diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index 9d84028e..123a3552 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -19,6 +19,8 @@ import logging import google.auth from google.cloud.sql.connector import connector +import datetime +import concurrent.futures def init_connection_engine( @@ -80,3 +82,25 @@ def test_multiple_connectors() -> None: ) except Exception as e: logging.exception("Failed to connect with multiple Connector objects!", e) + + +def test_connector_in_ThreadPoolExecutor() -> None: + """Test that Connector can connect from ThreadPoolExecutor thread. + This helps simulate how connector works in Cloud Run and Cloud Functions. + """ + + def get_time() -> datetime.datetime: + """Helper method for getting current time from database.""" + default_connector = connector.Connector() + pool = init_connection_engine(default_connector) + + # connect to database and get current time + with pool.connect() as conn: + current_time = conn.execute("SELECT NOW()").fetchone() + return current_time[0] + + # try running connector in ThreadPoolExecutor as Cloud Run does + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(get_time) + return_value = future.result() + assert isinstance(return_value, datetime.datetime) diff --git a/tests/unit/test_instance_connection_manager.py b/tests/unit/test_instance_connection_manager.py index bcc4f02e..33cd9735 100644 --- a/tests/unit/test_instance_connection_manager.py +++ b/tests/unit/test_instance_connection_manager.py @@ -46,7 +46,16 @@ def icm( @pytest.fixture def test_rate_limiter(async_loop: asyncio.AbstractEventLoop) -> AsyncRateLimiter: - return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop) + async def rate_limiter_in_loop( + async_loop: asyncio.AbstractEventLoop, + ) -> AsyncRateLimiter: + return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop) + + limiter_future = asyncio.run_coroutine_threadsafe( + rate_limiter_in_loop(async_loop), async_loop + ) + limiter = limiter_future.result() + return limiter class MockMetadata: diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 2a33cfe4..6e9eee2b 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -22,41 +22,70 @@ ) +async def rate_limiter_in_loop( + max_capacity: int, rate: float, loop: asyncio.AbstractEventLoop +) -> AsyncRateLimiter: + """Helper function to create AsyncRateLimiter object inside given event loop.""" + limiter = AsyncRateLimiter(max_capacity=max_capacity, rate=rate, loop=loop) + return limiter + + @pytest.mark.asyncio -async def test_rate_limiter_throttles_requests() -> None: +async def test_rate_limiter_throttles_requests( + async_loop: asyncio.AbstractEventLoop, +) -> None: + """Test to check whether rate limiter will throttle incoming requests.""" counter = 0 # allow 2 requests to go through every 5 seconds - limiter = AsyncRateLimiter(max_capacity=2, rate=1 / 5) + limiter_future = asyncio.run_coroutine_threadsafe( + rate_limiter_in_loop(max_capacity=2, rate=1 / 5, loop=async_loop), async_loop + ) + limiter = limiter_future.result() async def increment() -> None: await limiter.acquire() nonlocal counter counter += 1 - tasks = [increment() for _ in range(10)] + # create 10 tasks calling increment() + tasks = [async_loop.create_task(increment()) for _ in range(10)] - done, pending = await asyncio.wait(tasks, timeout=11) + # wait 10 seconds and check tasks + done, pending = asyncio.run_coroutine_threadsafe( + asyncio.wait(tasks, timeout=11), async_loop + ).result() + # verify 4 tasks completed and 6 pending due to rate limiter assert counter == 4 assert len(done) == 4 assert len(pending) == 6 @pytest.mark.asyncio -async def test_rate_limiter_completes_all_tasks() -> None: +async def test_rate_limiter_completes_all_tasks( + async_loop: asyncio.AbstractEventLoop, +) -> None: + """Test to check all requests will go through rate limiter successfully.""" counter = 0 # allow 1 request to go through per second - limiter = AsyncRateLimiter(max_capacity=1, rate=1) + limiter_future = asyncio.run_coroutine_threadsafe( + rate_limiter_in_loop(max_capacity=1, rate=1, loop=async_loop), async_loop + ) + limiter = limiter_future.result() async def increment() -> None: await limiter.acquire() nonlocal counter counter += 1 - tasks = [increment() for _ in range(10)] + # create 10 tasks calling increment() + tasks = [async_loop.create_task(increment()) for _ in range(10)] - done, pending = await asyncio.wait(tasks, timeout=30) + done, pending = asyncio.run_coroutine_threadsafe( + asyncio.wait(tasks, timeout=30), async_loop + ).result() + # verify all tasks done and none pending assert counter == 10 assert len(done) == 10 assert len(pending) == 0