diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 330b6b39..38547e32 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -74,25 +74,30 @@ def connection(self) -> "PostgresConnection": class Record(Mapping): + __slots__ = ( + "_row", + "_result_columns", + "_dialect", + "_column_map", + "_column_map_int", + "_column_map_full", + ) + def __init__( - self, row: asyncpg.Record, result_columns: tuple, dialect: Dialect + self, + row: asyncpg.Record, + result_columns: tuple, + dialect: Dialect, + column_maps: typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ], ) -> None: self._row = row self._result_columns = result_columns self._dialect = dialect - self._column_map = ( - {} - ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]] - self._column_map_int = ( - {} - ) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]] - self._column_map_full = ( - {} - ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]] - for idx, (column_name, _, column, datatype) in enumerate(self._result_columns): - self._column_map[column_name] = (idx, datatype) - self._column_map_int[idx] = (idx, datatype) - self._column_map_full[str(column[0])] = (idx, datatype) + self._column_map, self._column_map_int, self._column_map_full = column_maps def __getitem__(self, key: typing.Any) -> typing.Any: if len(self._column_map) == 0: # raw query @@ -142,7 +147,9 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]: assert self._connection is not None, "Connection is not acquired" query, args, result_columns = self._compile(query) rows = await self._connection.fetch(query, *args) - return [Record(row, result_columns, self._dialect) for row in rows] + dialect = self._dialect + column_maps = self._create_column_maps(result_columns) + return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]: assert self._connection is not None, "Connection is not acquired" @@ -150,7 +157,12 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin row = await self._connection.fetchrow(query, *args) if row is None: return None - return Record(row, result_columns, self._dialect) + return Record( + row, + result_columns, + self._dialect, + self._create_column_maps(result_columns), + ) async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" @@ -171,8 +183,9 @@ async def iterate( ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query, args, result_columns = self._compile(query) + column_maps = self._create_column_maps(result_columns) async for row in self._connection.cursor(query, *args): - yield Record(row, result_columns, self._dialect) + yield Record(row, result_columns, self._dialect, column_maps) def transaction(self) -> TransactionBackend: return PostgresTransaction(connection=self) @@ -198,6 +211,21 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, compiled._result_columns + @staticmethod + def _create_column_maps( + result_columns: tuple, + ) -> typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ]: + column_map, column_map_int, column_map_full = {}, {}, {} + for idx, (column_name, _, column, datatype) in enumerate(result_columns): + column_map[column_name] = (idx, datatype) + column_map_int[idx] = (idx, datatype) + column_map_full[str(column[0])] = (idx, datatype) + return column_map, column_map_int, column_map_full + @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired"