From b94f0971bb2445ff9bd3427a5ab5954eb9fa066d Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Fri, 26 May 2023 15:41:00 -0700 Subject: [PATCH] fix: remove connection inheritance, add more tests, update docs Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited. --- databases/core.py | 44 +++--- docs/connections_and_transactions.md | 23 ++- tests/test_databases.py | 226 ++++++++++++++++++++------- 3 files changed, 203 insertions(+), 90 deletions(-) diff --git a/databases/core.py b/databases/core.py index cf5a7aa0..795609ea 100644 --- a/databases/core.py +++ b/databases/core.py @@ -36,12 +36,9 @@ logger = logging.getLogger("databases") -_ACTIVE_CONNECTIONS: ContextVar[ - typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"] -] = ContextVar("databases:open_connections", default=None) _ACTIVE_TRANSACTIONS: ContextVar[ typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] -] = ContextVar("databases:open_transactions", default=None) +] = ContextVar("databases:active_transactions", default=None) class Database: @@ -54,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -64,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -78,28 +78,28 @@ def __init__( self._global_transaction: typing.Optional[Transaction] = None @property - def _connection(self) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - return None + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task - return connections.get(self, None) + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) @_connection.setter def _connection( self, connection: typing.Optional["Connection"] ) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - connections = weakref.WeakKeyDictionary() - _ACTIVE_CONNECTIONS.set(connections) + task = self._current_task if connection is None: - connections.pop(self, None) + self._connection_map.pop(task, None) else: - connections[self] = connection + self._connection_map[task] = connection - return connections.get(self, None) + return self._connection async def connect(self) -> None: """ @@ -119,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -218,7 +218,7 @@ def connection(self) -> "Connection": return self._global_connection if not self._connection: - self._connection = Connection(self._backend) + self._connection = Connection(self, self._backend) return self._connection @@ -243,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -277,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -393,13 +395,15 @@ def _transaction( transactions = _ACTIVE_TRANSACTIONS.get() if transactions is None: transactions = weakref.WeakKeyDictionary() - _ACTIVE_TRANSACTIONS.set(transactions) + else: + transactions = transactions.copy() if transaction is None: transactions.pop(self, None) else: transactions[self] = transaction + _ACTIVE_TRANSACTIONS.set(transactions) return transactions.get(self, None) async def __aenter__(self) -> "Transaction": diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index e52243e3..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -96,12 +98,13 @@ async def create_users(request): ... ``` -Transaction state is stored in the context of the currently executing asynchronous task. -This state is _inherited_ by tasks that are started from within an active transaction: +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. This state is _inherited_ by tasks that are share the same connection: ```python -async def add_excitement(database: Database, id: int): - await database.execute( +async def add_excitement(connnection: databases.core.Connection, id: int): + await connection.execute( "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", {"id": id} ) @@ -113,17 +116,13 @@ async with Database(database_url) as database: await database.execute( "INSERT INTO notes(id, text) values (1, 'databases is cool')" ) - # ...but child tasks inherit transaction state! - await asyncio.create_task(add_excitement(database, id=1)) + # ...but child tasks can use this connection now! + await asyncio.create_task(add_excitement(database.connection(), id=1)) await database.fetch_val("SELECT text FROM notes WHERE id=1") # ^ returns: "databases is cool!!!" ``` -!!! note - In python 3.11, you can opt-out of context propagation by providing a new context to - [`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks). - Nested transactions are fully supported, and are implemented using database savepoints: ```python diff --git a/tests/test_databases.py b/tests/test_databases.py index c78ce4f3..4d737261 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -482,11 +482,29 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_transaction_context_child_task_interaction(database_url): +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): """ Ensure that child tasks may influence inherited transactions. """ - # This is an practical example of the next test. + # This is an practical example of the above test. async with Database(database_url) as database: async with database.transaction(): # Create a note @@ -503,37 +521,19 @@ async def test_transaction_context_child_task_interaction(database_url): result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "prior" - async def run_update_from_child_task(): - # Chage the note from a child task - await database.execute( + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( notes.update().where(notes.c.id == 1).values(text="test") ) - await asyncio.create_task(run_update_from_child_task()) + await asyncio.create_task(run_update_from_child_task(database.connection())) # Confirm the child's change result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "test" -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_transaction_context_child_task_inheritance(database_url): - """ - Ensure that transactions are inherited by child tasks. - """ - async with Database(database_url) as database: - - async def check_transaction(transaction, active_transaction): - # Should have inherited the same transaction backend from the parent task - assert transaction._transaction is active_transaction - - async with database.transaction() as transaction: - await asyncio.create_task( - check_transaction(transaction, transaction._transaction) - ) - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_context_sibling_task_isolation(database_url): @@ -568,56 +568,99 @@ async def check_transaction(transaction): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_contextmanager(database_url): +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions are running in sibling tasks are isolated from eachother. + """ + # This is an practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): """ - Ensure that contextvar connections are not persisted unecessarily. + Ensure that task connections are not persisted unecessarily. """ - from databases.core import _ACTIVE_CONNECTIONS - assert _ACTIVE_CONNECTIONS.get() is None + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() async with Database(database_url) as database: + # Should have a connection in this task # .connect is lazy, it doesn't create a Connection, but .connection does connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection - # Context manager closes, open_connections is cleaned up - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database, None) is None + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_garbagecollector(database_url): +async def test_connection_cleanup_garbagecollector(database_url): """ - Ensure that contextvar connections are not persisted unecessarily, even + Ensure that connections for tasks are not persisted unecessarily, even if exit handlers are not called. """ - from databases.core import _ACTIVE_CONNECTIONS - - assert _ACTIVE_CONNECTIONS.get() is None - database = Database(database_url) await database.connect() - connection = database.connection() - # Should be tracking the connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() - # neither .disconnect nor .__aexit__ are called before deleting the reference - del database + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task gc.collect() - # Should have dropped reference to connection, even without proper cleanup - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert len(open_connections) == 0 + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -632,7 +675,6 @@ async def test_transaction_context_cleanup_contextmanager(database_url): async with Database(database_url) as database: async with database.transaction() as transaction: - open_transactions = _ACTIVE_TRANSACTIONS.get() assert isinstance(open_transactions, MutableMapping) assert open_transactions.get(transaction) is transaction._transaction @@ -818,17 +860,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. + """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -1007,7 +1076,7 @@ async def test_connection_context_same_task(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_multiple_tasks(database_url): +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -1037,6 +1106,47 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + @pytest.mark.parametrize( "database_url1,database_url2", (