Skip to content

Commit

Permalink
Speed up PostgresConnection fetch() and iterate()
Browse files Browse the repository at this point in the history
  • Loading branch information
vmarkovtsev committed Apr 30, 2020
1 parent 25e65ed commit 7cfe3e1
Showing 1 changed file with 61 additions and 19 deletions.
80 changes: 61 additions & 19 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,35 +74,40 @@ def connection(self) -> "PostgresConnection":


class Record(Mapping):
__slots__ = (
"_row",
"_result_columns",
"_dialect",
"_column_map",
"_column_map_int",
"_column_map_full",
)

def __init__(
self, row: asyncpg.Record, result_columns: tuple, dialect: Dialect
self,
row: asyncpg.Record,
result_columns: tuple,
dialect: Dialect,
column_maps: typing.Tuple[
typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
typing.Mapping[int, typing.Tuple[int, TypeEngine]],
typing.Mapping[str, typing.Tuple[int, TypeEngine]],
],
) -> None:
self._row = row
self._result_columns = result_columns
self._dialect = dialect
self._column_map = (
{}
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
self._column_map_int = (
{}
) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]]
self._column_map_full = (
{}
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
for idx, (column_name, _, column, datatype) in enumerate(self._result_columns):
self._column_map[column_name] = (idx, datatype)
self._column_map_int[idx] = (idx, datatype)
self._column_map_full[str(column[0])] = (idx, datatype)
self._column_map, self._column_map_int, self._column_map_full = column_maps

def values(self) -> typing.ValuesView:
return self._row.values()

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0: # raw query
return self._row[tuple(self._row.keys()).index(key)]
elif type(key) is Column:
elif isinstance(key, Column):
idx, datatype = self._column_map_full[str(key)]
elif type(key) is int:
elif isinstance(key, int):
idx, datatype = self._column_map_int[key]
else:
idx, datatype = self._column_map[key]
Expand Down Expand Up @@ -145,15 +150,22 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
rows = await self._connection.fetch(query, *args)
return [Record(row, result_columns, self._dialect) for row in rows]
dialect = self._dialect
column_maps = self._create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
row = await self._connection.fetchrow(query, *args)
if row is None:
return None
return Record(row, result_columns, self._dialect)
return Record(
row,
result_columns,
self._dialect,
self._create_column_maps(result_columns),
)

async def fetch_val(
self, query: ClauseElement, column: typing.Any = 0
Expand Down Expand Up @@ -181,8 +193,9 @@ async def iterate(
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query, args, result_columns = self._compile(query)
column_maps = self._create_column_maps(result_columns)
async for row in self._connection.cursor(query, *args):
yield Record(row, result_columns, self._dialect)
yield Record(row, result_columns, self._dialect, column_maps)

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
Expand All @@ -208,6 +221,35 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
)
return compiled_query, args, compiled._result_columns

@staticmethod
def _create_column_maps(
    result_columns: tuple,
) -> typing.Tuple[
    typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]],
    typing.Mapping[int, typing.Tuple[int, TypeEngine]],
    typing.Mapping[str, typing.Tuple[int, TypeEngine]],
]:
    """
    Generate column -> datatype mappings from the column definitions.

    These mappings are used throughout PostgresConnection methods
    to initialize Record-s. The underlying DB driver does not do type
    conversion for us, so we have to wrap the returned asyncpg.Record-s.

    The maps are built once per query and shared by every Record of the
    result set instead of being rebuilt for each row.

    :param result_columns: Sequence of 4-item entries; the first item is \
                           used as the column identifier, the second is \
                           ignored, the first element of the third is \
                           stringified as the full column name, and the \
                           fourth is the datatype.
    :return: Three mappings from different ways to address a column to \
             corresponding column indexes and datatypes: \
             1. by column identifier; \
             2. by column index; \
             3. by column name in Column sqlalchemy objects.
    """
    column_map, column_map_int, column_map_full = {}, {}, {}
    for idx, (column_name, _, column, datatype) in enumerate(result_columns):
        # All three maps share the same (index, datatype) payload so that a
        # lookup by any kind of key resolves to the same row position.
        entry = (idx, datatype)
        column_map[column_name] = entry
        column_map_int[idx] = entry
        column_map_full[str(column[0])] = entry
    return column_map, column_map_int, column_map_full

@property
def raw_connection(self) -> asyncpg.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
Expand Down

0 comments on commit 7cfe3e1

Please sign in to comment.