Skip to content

Commit

Permalink
fix: make asyncio.Lock() run in background thread (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Jan 31, 2022
1 parent 7a8a130 commit f52ba7e
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 19 deletions.
16 changes: 9 additions & 7 deletions google/cloud/sql/connector/instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def _client_session(self) -> aiohttp.ClientSession:
_project: str
_region: str

_refresh_rate_limiter: AsyncRateLimiter
_refresh_in_progress: asyncio.locks.Event
_current: asyncio.Task # task wraps coroutine that returns InstanceMetadata
_next: asyncio.Task # task wraps coroutine that returns another task
Expand Down Expand Up @@ -274,17 +275,18 @@ def __init__(

self._auth_init(credentials)

self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2, rate=1 / 30, loop=self._loop
)

async def _set_instance_data() -> None:
logger.debug("Updating instance data")
async def _async_init() -> None:
"""Initialize InstanceConnectionManager's variables that require the
event loop running in background thread.
"""
self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2, rate=1 / 30, loop=self._loop
)
self._refresh_in_progress = asyncio.locks.Event()
self._current = self._loop.create_task(self._get_instance_data())
self._next = self._loop.create_task(self._schedule_refresh())

init_future = asyncio.run_coroutine_threadsafe(_set_instance_data(), self._loop)
init_future = asyncio.run_coroutine_threadsafe(_async_init(), self._loop)
init_future.result()

def __del__(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/sql/connector/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def __init__(
self.rate = rate
self.max_capacity = max_capacity
self._loop = loop or asyncio.get_event_loop()
self._lock = asyncio.Lock()
self._tokens: float = max_capacity
self._last_token_update = self._loop.time()
self._lock = asyncio.Lock()

def _update_token_count(self) -> None:
"""
Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ def async_loop() -> Generator:
Creates a loop in a background thread and returns it to use for testing.
"""
loop = asyncio.new_event_loop()
thr = threading.Thread(target=loop.run_forever)
thr = threading.Thread(target=loop.run_forever, daemon=True)
thr.start()
yield loop
loop.stop()
thr.join()


@pytest.fixture
Expand Down
24 changes: 24 additions & 0 deletions tests/system/test_connector_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import logging
import google.auth
from google.cloud.sql.connector import connector
import datetime
import concurrent.futures


def init_connection_engine(
Expand Down Expand Up @@ -80,3 +82,25 @@ def test_multiple_connectors() -> None:
)
except Exception as e:
logging.exception("Failed to connect with multiple Connector objects!", e)


def test_connector_in_ThreadPoolExecutor() -> None:
"""Test that Connector can connect from ThreadPoolExecutor thread.
This helps simulate how connector works in Cloud Run and Cloud Functions.
"""

def get_time() -> datetime.datetime:
"""Helper method for getting current time from database."""
default_connector = connector.Connector()
pool = init_connection_engine(default_connector)

# connect to database and get current time
with pool.connect() as conn:
current_time = conn.execute("SELECT NOW()").fetchone()
return current_time[0]

# try running connector in ThreadPoolExecutor as Cloud Run does
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(get_time)
return_value = future.result()
assert isinstance(return_value, datetime.datetime)
11 changes: 10 additions & 1 deletion tests/unit/test_instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,16 @@ def icm(

@pytest.fixture
def test_rate_limiter(async_loop: asyncio.AbstractEventLoop) -> AsyncRateLimiter:
return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop)
async def rate_limiter_in_loop(
async_loop: asyncio.AbstractEventLoop,
) -> AsyncRateLimiter:
return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop)

limiter_future = asyncio.run_coroutine_threadsafe(
rate_limiter_in_loop(async_loop), async_loop
)
limiter = limiter_future.result()
return limiter


class MockMetadata:
Expand Down
45 changes: 37 additions & 8 deletions tests/unit/test_rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,70 @@
)


async def rate_limiter_in_loop(
max_capacity: int, rate: float, loop: asyncio.AbstractEventLoop
) -> AsyncRateLimiter:
"""Helper function to create AsyncRateLimiter object inside given event loop."""
limiter = AsyncRateLimiter(max_capacity=max_capacity, rate=rate, loop=loop)
return limiter


@pytest.mark.asyncio
async def test_rate_limiter_throttles_requests() -> None:
async def test_rate_limiter_throttles_requests(
async_loop: asyncio.AbstractEventLoop,
) -> None:
"""Test to check whether rate limiter will throttle incoming requests."""
counter = 0
# allow 2 requests to go through every 5 seconds
limiter = AsyncRateLimiter(max_capacity=2, rate=1 / 5)
limiter_future = asyncio.run_coroutine_threadsafe(
rate_limiter_in_loop(max_capacity=2, rate=1 / 5, loop=async_loop), async_loop
)
limiter = limiter_future.result()

async def increment() -> None:
await limiter.acquire()
nonlocal counter
counter += 1

tasks = [increment() for _ in range(10)]
# create 10 tasks calling increment()
tasks = [async_loop.create_task(increment()) for _ in range(10)]

done, pending = await asyncio.wait(tasks, timeout=11)
# wait 10 seconds and check tasks
done, pending = asyncio.run_coroutine_threadsafe(
asyncio.wait(tasks, timeout=11), async_loop
).result()

# verify 4 tasks completed and 6 pending due to rate limiter
assert counter == 4
assert len(done) == 4
assert len(pending) == 6


@pytest.mark.asyncio
async def test_rate_limiter_completes_all_tasks() -> None:
async def test_rate_limiter_completes_all_tasks(
async_loop: asyncio.AbstractEventLoop,
) -> None:
"""Test to check all requests will go through rate limiter successfully."""
counter = 0
# allow 1 request to go through per second
limiter = AsyncRateLimiter(max_capacity=1, rate=1)
limiter_future = asyncio.run_coroutine_threadsafe(
rate_limiter_in_loop(max_capacity=1, rate=1, loop=async_loop), async_loop
)
limiter = limiter_future.result()

async def increment() -> None:
await limiter.acquire()
nonlocal counter
counter += 1

tasks = [increment() for _ in range(10)]
# create 10 tasks calling increment()
tasks = [async_loop.create_task(increment()) for _ in range(10)]

done, pending = await asyncio.wait(tasks, timeout=30)
done, pending = asyncio.run_coroutine_threadsafe(
asyncio.wait(tasks, timeout=30), async_loop
).result()

# verify all tasks done and none pending
assert counter == 10
assert len(done) == 10
assert len(pending) == 0

0 comments on commit f52ba7e

Please sign in to comment.