From 9d898a8a8d723a1a4ac59e90b6963b328fbfa02f Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Sat, 13 Aug 2022 01:38:04 +0200 Subject: [PATCH] Acknowledge reception of data in `TrinoResult` --- trino/client.py | 22 ++++++++++------------ trino/dbapi.py | 2 +- trino/sqlalchemy/dialect.py | 6 +++--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/trino/client.py b/trino/client.py index fd0eeaeb..399ec9bc 100644 --- a/trino/client.py +++ b/trino/client.py @@ -594,7 +594,8 @@ class TrinoResult(object): def __init__(self, query, rows=None): self._query = query - self._rows = rows or [] + # Initial rows from the first POST request + self._rows = rows self._rownumber = 0 @property @@ -602,20 +603,17 @@ def rownumber(self) -> int: return self._rownumber def __iter__(self): - # Initial fetch from the first POST request - for row in self._rows: - self._rownumber += 1 - yield row - self._rows = None - - # Subsequent fetches from GET requests until next_uri is empty. - while not self._query.finished: - rows = self._query.fetch() - for row in rows: + # A query only transitions to a FINISHED state when the results are fully consumed: + # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. + while not self._query.finished or self._rows is not None: + next_rows = self._query.fetch() if not self._query.finished else None + for row in self._rows: self._rownumber += 1 logger.debug("row %s", row) yield row + self._rows = next_rows + @property def response_headers(self): return self._query.response_headers @@ -641,7 +639,7 @@ def __init__( self._request = request self._update_type = None self._sql = sql - self._result = TrinoResult(self) + self._result: Optional[TrinoResult] = None self._response_headers = None self._experimental_python_types = experimental_python_types self._row_mapper: Optional[RowMapper] = None diff --git a/trino/dbapi.py b/trino/dbapi.py index 44813168..70fb43bb 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -322,7 +322,7 @@ def _prepare_statement(self, operation, statement_name): operation=operation ) - # Send prepare statement. Copy the _request object to avoid poluting the + # Send prepare statement. Copy the _request object to avoid polluting the # one that is going to be used to execute the actual operation. query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, experimental_python_types=self._experimental_pyton_types) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7c4409a0..e967cb6b 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -231,7 +231,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st """ ).strip() res = connection.execute(sql.text(query), schema=schema, view=view_name) - return res.scalar() + return res.scalar_one_or_none() def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): @@ -284,7 +284,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str sql.text(query), catalog_name=catalog_name, schema_name=schema_name, table_name=table_name ) - return dict(text=res.scalar()) + return dict(text=res.scalar_one_or_none()) except error.TrinoQueryError as e: if e.error_name in ( error.PERMISSION_DENIED, @@ -326,7 +326,7 @@ def _get_server_version_info(self, connection: Connection) -> Any: query = "SELECT version()" try: res = connection.execute(sql.text(query)) - version = res.scalar() + version = res.scalar_one() return tuple([version]) except exc.ProgrammingError as e: logger.debug(f"Failed to get server version: {e.orig.message}")