diff --git a/peewee_async.py b/peewee_async.py index aaf475f..5b3ee12 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -13,6 +13,7 @@ Copyright (c) 2014, Alexey Kinëv """ +import abc import asyncio import contextlib import functools @@ -712,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None): timeout=timeout, **self.connect_params_async ) - await conn.connect() + await conn.create() self._async_conn = conn async def cursor_async(self): @@ -734,7 +735,7 @@ async def close_async(self): if self._async_conn: conn = self._async_conn self._async_conn = None - await conn.close() + await conn.terminate() async def push_transaction_async(self): """Increment async transaction depth. @@ -851,19 +852,14 @@ async def aio_execute(self, query): return (await coroutine(query)) -############## -# PostgreSQL # -############## - - -class AsyncPostgresqlConnection: +class AioPool(metaclass=abc.ABCMeta): """Asynchronous database connection pool. """ def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): self.pool = None self.loop = loop self.database = database - self.timeout = timeout or aiopg.DEFAULT_TIMEOUT + self.timeout = timeout self.connect_params = kwargs async def acquire(self): @@ -876,24 +872,20 @@ def release(self, conn): """ self.pool.release(conn) - async def connect(self): + @abc.abstractmethod + async def create(self): """Create connection pool asynchronously. """ - self.pool = await aiopg.create_pool( - loop=self.loop, - timeout=self.timeout, - database=self.database, - **self.connect_params) + raise NotImplementedError - async def close(self): + async def terminate(self): """Terminate all pool connections. """ self.pool.terminate() await self.pool.wait_closed() async def cursor(self, conn=None, *args, **kwargs): - """Get a cursor for the specified transaction connection - or acquire from the pool. + """Get cursor for connection from pool. """ in_transaction = conn is not None if not conn: @@ -914,10 +906,44 @@ async def release_cursor(self, cursor, in_transaction=False): the connection is also released back to the pool. """ conn = cursor.connection - cursor.close() + await self.close_cursor(cursor) if not in_transaction: self.release(conn) + @abc.abstractmethod + async def close_cursor(self, cursor): + raise NotImplementedError + + + +############## +# PostgreSQL # +############## + + +class AioPostgresqlPool(AioPool): + """Asynchronous database connection pool. + """ + def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): + super().__init__( + database=database, + loop=loop, + timeout=timeout or aiopg.DEFAULT_TIMEOUT, + **kwargs, + ) + + async def create(self): + """Create connection pool asynchronously. + """ + self.pool = await aiopg.create_pool( + loop=self.loop, + timeout=self.timeout, + database=self.database, + **self.connect_params) + + async def close_cursor(self, cursor): + cursor.close() + class AsyncPostgresqlMixin(AsyncDatabase): """Mixin for `peewee.PostgresqlDatabase` providing extra methods @@ -926,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase): if psycopg2: Error = psycopg2.Error - def init_async(self, conn_cls=AsyncPostgresqlConnection, + def init_async(self, conn_cls=AioPostgresqlPool, enable_json=False, enable_hstore=False): if not aiopg: raise Exception("Error, aiopg is not installed!") @@ -1027,27 +1053,11 @@ def use_speedups(self, value): ######### -class AsyncMySQLConnection: +class AioMysqlPool(AioPool): """Asynchronous database connection pool. """ - def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): - self.pool = None - self.loop = loop - self.database = database - self.timeout = timeout - self.connect_params = kwargs - - async def acquire(self): - """Acquire connection from pool. - """ - return (await self.pool.acquire()) - - def release(self, conn): - """Release connection to pool. - """ - self.pool.release(conn) - async def connect(self): + async def create(self): """Create connection pool asynchronously. """ self.pool = await aiomysql.create_pool( @@ -1056,37 +1066,8 @@ async def connect(self): connect_timeout=self.timeout, **self.connect_params) - async def close(self): - """Terminate all pool connections. - """ - self.pool.terminate() - await self.pool.wait_closed() - - async def cursor(self, conn=None, *args, **kwargs): - """Get cursor for connection from pool. - """ - in_transaction = conn is not None - if not conn: - conn = await self.acquire() - try: - cursor = await conn.cursor(*args, **kwargs) - except: - if not in_transaction: - self.release(conn) - raise - cursor.release = functools.partial( - self.release_cursor, cursor, - in_transaction=in_transaction) - return cursor - - async def release_cursor(self, cursor, in_transaction=False): - """Release cursor coroutine. Unless in transaction, - the connection is also released back to the pool. - """ - conn = cursor.connection + async def close_cursor(self, cursor): await cursor.close() - if not in_transaction: - self.release(conn) class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase): @@ -1108,7 +1089,7 @@ def init(self, database, **kwargs): raise Exception("Error, aiomysql is not installed!") self.min_connections = 1 self.max_connections = 1 - self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection) + self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool) super().init(database, **kwargs) @property