diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 190663df..1d19a736 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -74,25 +74,30 @@ def connection(self) -> "PostgresConnection": class Record(Mapping): + __slots__ = ( + "_row", + "_result_columns", + "_dialect", + "_column_map", + "_column_map_int", + "_column_map_full", + ) + def __init__( - self, row: asyncpg.Record, result_columns: tuple, dialect: Dialect + self, + row: asyncpg.Record, + result_columns: tuple, + dialect: Dialect, + column_maps: typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ], ) -> None: self._row = row self._result_columns = result_columns self._dialect = dialect - self._column_map = ( - {} - ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]] - self._column_map_int = ( - {} - ) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]] - self._column_map_full = ( - {} - ) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]] - for idx, (column_name, _, column, datatype) in enumerate(self._result_columns): - self._column_map[column_name] = (idx, datatype) - self._column_map_int[idx] = (idx, datatype) - self._column_map_full[str(column[0])] = (idx, datatype) + self._column_map, self._column_map_int, self._column_map_full = column_maps def values(self) -> typing.ValuesView: return self._row.values() @@ -100,9 +105,9 @@ def values(self) -> typing.ValuesView: def __getitem__(self, key: typing.Any) -> typing.Any: if len(self._column_map) == 0: # raw query return self._row[tuple(self._row.keys()).index(key)] - elif type(key) is Column: + elif isinstance(key, Column): idx, datatype = self._column_map_full[str(key)] - elif type(key) is int: + elif isinstance(key, int): idx, datatype = self._column_map_int[key] else: idx, datatype = self._column_map[key] @@ -145,7 +150,9 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]: assert self._connection is not None, "Connection is not acquired" query, args, result_columns = self._compile(query) rows = await self._connection.fetch(query, *args) - return [Record(row, result_columns, self._dialect) for row in rows] + dialect = self._dialect + column_maps = self._create_column_maps(result_columns) + return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]: assert self._connection is not None, "Connection is not acquired" @@ -153,7 +160,12 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin row = await self._connection.fetchrow(query, *args) if row is None: return None - return Record(row, result_columns, self._dialect) + return Record( + row, + result_columns, + self._dialect, + self._create_column_maps(result_columns), + ) async def fetch_val( self, query: ClauseElement, column: typing.Any = 0 @@ -181,8 +193,9 @@ async def iterate( ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query, args, result_columns = self._compile(query) + column_maps = self._create_column_maps(result_columns) async for row in self._connection.cursor(query, *args): - yield Record(row, result_columns, self._dialect) + yield Record(row, result_columns, self._dialect, column_maps) def transaction(self) -> TransactionBackend: return PostgresTransaction(connection=self) @@ -208,6 +221,35 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, compiled._result_columns + @staticmethod + def _create_column_maps( + result_columns: tuple, + ) -> typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ]: + """ + Generate column -> datatype mappings from the column definitions. + + These mappings are used throughout PostgresConnection methods + to initialize Record-s. The underlying DB driver does not do type + conversion for us so we have wrap the returned asyncpg.Record-s. + + :return: Three mappings from different ways to address a column to \ + corresponding column indexes and datatypes: \ + 1. by column identifier; \ + 2. by column index; \ + 3. by column name in Column sqlalchemy objects. + """ + column_map, column_map_int, column_map_full = {}, {}, {} + breakpoint() + for idx, (column_name, _, column, datatype) in enumerate(result_columns): + column_map[column_name] = (idx, datatype) + column_map_int[idx] = (idx, datatype) + column_map_full[str(column[0])] = (idx, datatype) + return column_map, column_map_int, column_map_full + @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired"