Skip to content

Commit

Permalink
feat: use ContextVar[dict] to track connections and transactions per …
Browse files Browse the repository at this point in the history
…task
  • Loading branch information
zevisert committed Apr 11, 2023
1 parent f3078aa commit 75969d3
Showing 1 changed file with 46 additions and 36 deletions.
82 changes: 46 additions & 36 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
from contextvars import ContextVar
from types import TracebackType
from typing import Optional
from typing import Dict, Optional
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit

from sqlalchemy import text
Expand Down Expand Up @@ -35,6 +35,21 @@

logger = logging.getLogger("databases")

# Connections are stored as task-local state, but care must be taken to ensure
# that two database instances in the same task overwrite each other's connections.
# For this reason, key comprises the database instance and the current task.
_connection_contextmap: ContextVar[
Dict[tuple["Database", asyncio.Task], "Connection"]
] = ContextVar("databases:Connection")


def _get_connection_contextmap() -> Dict[tuple["Database", asyncio.Task], "Connection"]:
connections = _connection_contextmap.get(None)
if connections is None:
connections = {}
_connection_contextmap.set(connections)
return connections


class Database:
SUPPORTED_BACKENDS = {
Expand Down Expand Up @@ -64,14 +79,6 @@ def __init__(
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state, and cannot be garbage collected,
# since the immutable global Context stores a strong reference to each ContextVar
# that is created. We need these local ContextVars since two Database objects
# could run in the same asyncio.Task with connections to different databases.
self._connection_contextvar: ContextVar[Optional["Connection"]] = ContextVar(
f"databases:Database:{id(self)}"
)

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
self._global_connection: typing.Optional[Connection] = None
Expand Down Expand Up @@ -119,7 +126,10 @@ async def disconnect(self) -> None:
self._global_transaction = None
self._global_connection = None
else:
self._connection_contextvar.set(None)
task = asyncio.current_task()
connections = _get_connection_contextmap()
if (self, task) in connections:
del connections[self, task]

await self._backend.disconnect()
logger.info(
Expand Down Expand Up @@ -193,12 +203,12 @@ def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

connection = self._connection_contextvar.get(None)
if connection is None:
connection = Connection(self._backend)
self._connection_contextvar.set(connection)
task = asyncio.current_task()
connections = _get_connection_contextmap()
if (self, task) not in connections:
connections[self, task] = Connection(self._backend)

return connection
return connections[self, task]

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
Expand Down Expand Up @@ -339,6 +349,19 @@ def _build_query(

_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)

_transaction_contextmap: ContextVar[
Dict["Transaction", TransactionBackend]
] = ContextVar("databases:Transactions")


def _get_transaction_contextmap() -> Dict["Transaction", TransactionBackend]:
transactions = _transaction_contextmap.get(None)
if transactions is None:
transactions = {}
_transaction_contextmap.set(transactions)

return transactions


class Transaction:
def __init__(
Expand All @@ -351,15 +374,6 @@ def __init__(
self._force_rollback = force_rollback
self._extra_options = kwargs

# This ContextVar can never be garbage collected - similar to the ContextVar
# at Database._connection_contextvar - since the current Context has a strong
# reference to every ContextVar that is created. We need local ContextVars since
# there may be multiple (even nested) transactions in a single asyncio.Task,
# which each need their own unique TransactionBackend object.
self._transaction_contextvar: ContextVar[
Optional[TransactionBackend]
] = ContextVar(f"databases:Transaction:{id(self)}")

async def __aenter__(self) -> "Transaction":
"""
Called when entering `async with database.transaction()`
Expand Down Expand Up @@ -402,12 +416,8 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def start(self) -> "Transaction":
connection = self._connection_callable()
transaction = connection._connection.transaction()

# Cannot store returned reset token anywhere, for the same reason
# we need a ContextVar in the first place - `self` is not
# a safe object on which to store references for concurrent code.
self._transaction_contextvar.set(transaction)

transactions = _get_transaction_contextmap()
transactions[self] = transaction
async with connection._transaction_lock:
is_root = not connection._transaction_stack
await connection.__aenter__()
Expand All @@ -417,27 +427,27 @@ async def start(self) -> "Transaction":

async def commit(self) -> None:
connection = self._connection_callable()
transaction = self._transaction_contextvar.get(None)
transactions = _get_transaction_contextmap()
transaction = transactions.get(self, None)
assert transaction is not None, "Transaction not found in current task"
async with connection._transaction_lock:
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await transaction.commit()
await connection.__aexit__()
# Have no reset token, set to None instead
self._transaction_contextvar.set(None)
del transactions[self]

async def rollback(self) -> None:
connection = self._connection_callable()
transaction = self._transaction_contextvar.get(None)
transactions = _get_transaction_contextmap()
transaction = transactions.get(self, None)
assert transaction is not None, "Transaction not found in current task"
async with connection._transaction_lock:
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await transaction.rollback()
await connection.__aexit__()
# Have no reset token, set to None instead
self._transaction_contextvar.set(None)
del transactions[self]


class _EmptyNetloc(str):
Expand Down

0 comments on commit 75969d3

Please sign in to comment.