Skip to content

Commit

Permalink
Speed up PostgresConnection fetch() and iterate()
Browse files Browse the repository at this point in the history
  • Loading branch information
vmarkovtsev committed Apr 26, 2020
1 parent 15ecd04 commit 24e3242
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,30 @@ def connection(self) -> "PostgresConnection":


class Record(Mapping):
__slots__ = (
"_row",
"_result_columns",
"_dialect",
"_column_map",
"_column_map_int",
"_column_map_full",
)

def __init__(
self, row: asyncpg.Record, result_columns: tuple, dialect: Dialect
self,
row: asyncpg.Record,
result_columns: tuple,
dialect: Dialect,
column_maps: typing.Tuple[
typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
typing.Mapping[int, typing.Tuple[int, TypeEngine]],
typing.Mapping[str, typing.Tuple[int, TypeEngine]],
],
) -> None:
self._row = row
self._result_columns = result_columns
self._dialect = dialect
self._column_map = (
{}
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
self._column_map_int = (
{}
) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]]
self._column_map_full = (
{}
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
for idx, (column_name, _, column, datatype) in enumerate(self._result_columns):
self._column_map[column_name] = (idx, datatype)
self._column_map_int[idx] = (idx, datatype)
self._column_map_full[str(column[0])] = (idx, datatype)
self._column_map, self._column_map_int, self._column_map_full = column_maps

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0: # raw query
Expand Down Expand Up @@ -142,15 +147,22 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
rows = await self._connection.fetch(query, *args)
return [Record(row, result_columns, self._dialect) for row in rows]
dialect = self._dialect
column_maps = self._create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
row = await self._connection.fetchrow(query, *args)
if row is None:
return None
return Record(row, result_columns, self._dialect)
return Record(
row,
result_columns,
self._dialect,
self._create_column_maps(result_columns),
)

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
Expand All @@ -171,8 +183,9 @@ async def iterate(
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
column_maps = self._create_column_maps(result_columns)
async for row in self._connection.cursor(query, *args):
yield Record(row, result_columns, self._dialect)
yield Record(row, result_columns, self._dialect, column_maps)

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
Expand All @@ -198,6 +211,21 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
)
return compiled_query, args, compiled._result_columns

@staticmethod
def _create_column_maps(
result_columns: tuple,
) -> typing.Tuple[
typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
typing.Mapping[int, typing.Tuple[int, TypeEngine]],
typing.Mapping[str, typing.Tuple[int, TypeEngine]],
]:
column_map, column_map_int, column_map_full = {}, {}, {}
for idx, (column_name, _, column, datatype) in enumerate(result_columns):
column_map[column_name] = (idx, datatype)
column_map_int[idx] = (idx, datatype)
column_map_full[str(column[0])] = (idx, datatype)
return column_map, column_map_int, column_map_full

@property
def raw_connection(self) -> asyncpg.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
Expand Down

0 comments on commit 24e3242

Please sign in to comment.