Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Refactor Async(Postgresql/MySQL)Connection #213

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 42 additions & 63 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Copyright (c) 2014, Alexey Kinëv <[email protected]>

"""
import abc
import asyncio
import contextlib
import functools
Expand Down Expand Up @@ -851,19 +852,14 @@ async def aio_execute(self, query):
return (await coroutine(query))


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection:
class AsyncConnectionPool(metaclass=abc.ABCMeta):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
Expand All @@ -876,14 +872,11 @@ def release(self, conn):
"""
self.pool.release(conn)

@abc.abstractmethod
async def connect(self):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)
raise NotImplementedError

async def close(self):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Terminate all pool connections.
Expand All @@ -892,8 +885,7 @@ async def close(self):
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get a cursor for the specified transaction connection
or acquire from the pool.
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
Expand All @@ -909,6 +901,41 @@ async def cursor(self, conn=None, *args, **kwargs):
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
await cursor.close()
if not in_transaction:
self.release(conn)


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection(AsyncConnectionPool):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
super().__init__(
database=database,
loop=loop,
timeout=timeout or aiopg.DEFAULT_TIMEOUT,
**kwargs,
)

async def connect(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)

async def release_cursor(self, cursor, in_transaction=False):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
Expand Down Expand Up @@ -1027,25 +1054,9 @@ def use_speedups(self, value):
#########


class AsyncMySQLConnection:
class AsyncMySQLConnection(AsyncConnectionPool):
kalombos marked this conversation as resolved.
Show resolved Hide resolved
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
"""Acquire connection from pool.
"""
return (await self.pool.acquire())

def release(self, conn):
"""Release connection to pool.
"""
self.pool.release(conn)

async def connect(self):
"""Create connection pool asynchronously.
Expand All @@ -1056,38 +1067,6 @@ async def connect(self):
connect_timeout=self.timeout,
**self.connect_params)

async def close(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
conn = await self.acquire()
try:
cursor = await conn.cursor(*args, **kwargs)
except:
if not in_transaction:
self.release(conn)
raise
cursor.release = functools.partial(
self.release_cursor, cursor,
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
await cursor.close()
if not in_transaction:
self.release(conn)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
"""MySQL database driver providing **single drop-in sync** connection
Expand Down
Loading