diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d7da4983..4100c39e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs): return http_response -def test_trino_result_response_headers(): - """ - Validates that the `TrinoResult.response_headers` property returns the - headers associated to the TrinoQuery instance provided to the `TrinoResult` - class. - """ - mock_trino_query = mock.Mock(respone_headers={ - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', - }) - - result = TrinoResult(query=mock_trino_query, rows=[]) - assert result.response_headers == mock_trino_query.response_headers - - def test_trino_query_response_headers(sample_get_response_data): """ Validates that the `TrinoQuery.execute` function can take addtional headers diff --git a/trino/client.py b/trino/client.py index b973ff6b..82d973c4 100644 --- a/trino/client.py +++ b/trino/client.py @@ -125,6 +125,7 @@ def __init__( self._extra_credential = extra_credential self._client_tags = client_tags self._role = role + self._prepared_statements: Dict[str, str] = {} self._object_lock = threading.Lock() @property @@ -206,6 +207,15 @@ def role(self, role): with self._object_lock: self._role = role + @property + def prepared_statements(self): + return self._prepared_statements + + @prepared_statements.setter + def prepared_statements(self, prepared_statements): + with self._object_lock: + self._prepared_statements = prepared_statements + def get_header_values(headers, header): return [val.strip() for val in headers[header].split(",")] @@ -218,6 +228,12 @@ def get_session_property_values(headers, header): for k, v in (kv.split("=", 1) for kv in kvs) ] +def get_prepared_statement_values(headers, header): + kvs = get_header_values(headers, header) + return [ + (k.strip(), urllib.parse.unquote_plus(v.strip())) + for k, v in (kv.split("=", 1) for kv in kvs) + ] class TrinoStatus(object): def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): @@ -392,6 +408,13 @@ def http_headers(self) -> Dict[str, str]: for name, value in self._client_session.properties.items() ) + if len(self._client_session.prepared_statements) != 0: + # ``name`` must not contain ``=`` + headers[constants.HEADER_PREPARED_STATEMENT] = ",".join( + "{}={}".format(name, urllib.parse.quote_plus(statement)) + for name, statement in self._client_session.prepared_statements.items() + ) + # merge custom http headers for key in self._client_session.headers: if key in headers.keys(): @@ -556,6 +579,18 @@ def process(self, http_response) -> TrinoStatus: if constants.HEADER_SET_ROLE in http_response.headers: self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE] + if constants.HEADER_ADDED_PREPARE in http_response.headers: + for name, statement in get_prepared_statement_values( + http_response.headers, constants.HEADER_ADDED_PREPARE + ): + self._client_session.prepared_statements[name] = statement + + if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers: + for name in get_header_values( + http_response.headers, constants.HEADER_DEALLOCATED_PREPARE + ): + self._client_session.prepared_statements.pop(name) + self._next_uri = response.get("nextUri") return TrinoStatus( @@ -622,10 +657,6 @@ def __iter__(self): self._rows = next_rows - @property - def response_headers(self): - return self._query.response_headers - class TrinoQuery(object): """Represent the execution of a SQL statement by Trino.""" @@ -648,7 +679,6 @@ def __init__( self._update_type = None self._sql = sql self._result: Optional[TrinoResult] = None - self._response_headers = None self._experimental_python_types = experimental_python_types self._row_mapper: Optional[RowMapper] = None @@ -705,7 +735,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) - # Execute should block until at least one row is received + # Execute should block until at least one row is received or query is finished or cancelled while not self.finished and not self.cancelled and len(self._result.rows) == 0: self._result.rows += self.fetch() return self._result @@ -725,7 +755,6 @@ def fetch(self) -> List[List[Any]]: status = self._request.process(response) self._update_state(status) logger.debug(status) - self._response_headers = response.headers if status.next_uri is None: self._finished = True @@ -763,10 +792,6 @@ def finished(self) -> bool: def cancelled(self) -> bool: return self._cancelled - @property - def response_headers(self): - return self._response_headers - def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): def wrapper(func): diff --git a/trino/dbapi.py b/trino/dbapi.py index 70fb43bb..4d9f038a 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -295,58 +295,42 @@ def warnings(self): return self._query.warnings return None + def _new_request_with_session_from(self, request): + """ + Returns a new request with the `ClientSession` set to the one from the + given request. + """ + request = self.connection._create_request() + request._client_session = request._client_session + return request + def setinputsizes(self, sizes): raise trino.exceptions.NotSupportedError def setoutputsize(self, size, column): raise trino.exceptions.NotSupportedError - def _prepare_statement(self, operation, statement_name): + def _prepare_statement(self, statement, name): """ - Prepends the given `operation` with "PREPARE FROM" and - executes as a prepare statement. - - :param operation: sql to be executed. - :param statement_name: name that will be assigned to the prepare - statement. - - :raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised - when unable to find the 'X-Trino-Added-Prepare' for the PREPARE - statement request. + Registers a prepared statement for the provided `operation` with the + `name` assigned to it. - :return: string representing the value of the 'X-Trino-Added-Prepare' - header. + :param statement: sql to be executed. + :param name: name that will be assigned to the prepared statement. """ - sql = 'PREPARE {statement_name} FROM {operation}'.format( - statement_name=statement_name, - operation=operation - ) - - # Send prepare statement. Copy the _request object to avoid polluting the - # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, + sql = f"PREPARE {name} FROM {statement}" + # TODO: Evaluate whether we can avoid the piggybacking on current request + query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql, experimental_python_types=self._experimental_pyton_types) - result = query.execute() - - # Iterate until the 'X-Trino-Added-Prepare' header is found or - # until there are no more results - for _ in result: - response_headers = result.response_headers - - if constants.HEADER_ADDED_PREPARE in response_headers: - return response_headers[constants.HEADER_ADDED_PREPARE] + query.execute() - raise trino.exceptions.FailedToObtainAddedPrepareHeader - def _get_added_prepare_statement_trino_query( + def _execute_prepared_statement( self, statement_name, params ): sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) - - # No need to deepcopy _request here because this is the actual request - # operation return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types) def _format_prepared_param(self, param): @@ -422,28 +406,12 @@ def _format_prepared_param(self, param): raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param)) - def _deallocate_prepare_statement(self, added_prepare_header, statement_name): + def _deallocate_prepared_statement(self, statement_name): sql = 'DEALLOCATE PREPARE ' + statement_name - - # Send deallocate statement. Copy the _request object to avoid poluting the - # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, + # TODO: Evaluate whether we can avoid the piggybacking on current request + query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql, experimental_python_types=self._experimental_pyton_types) - result = query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } - ) - - # Iterate until the 'X-Trino-Deallocated-Prepare' header is found or - # until there are no more results - for _ in result: - response_headers = result.response_headers - - if constants.HEADER_DEALLOCATED_PREPARE in response_headers: - return response_headers[constants.HEADER_DEALLOCATED_PREPARE] - - raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader + query.execute() def _generate_unique_statement_name(self): return 'st_' + uuid.uuid4().hex.replace('-', '') @@ -456,27 +424,21 @@ def execute(self, operation, params=None): ) statement_name = self._generate_unique_statement_name() - # Send prepare statement - added_prepare_header = self._prepare_statement( - operation, statement_name - ) + self._prepare_statement(operation, statement_name) try: # Send execute statement and assign the return value to `results` # as it will be returned by the function - self._query = self._get_added_prepare_statement_trino_query( + self._query = self._execute_prepared_statement( statement_name, params ) - result = self._query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } - ) + result = self._query.execute() finally: # Send deallocate statement # At this point the query can be deallocated since it has already # been executed - self._deallocate_prepare_statement(added_prepare_header, statement_name) + # TODO: Consider caching prepared statements if requested by caller + self._deallocate_prepared_statement(statement_name) else: self._query = trino.client.TrinoQuery(self._request, sql=operation, diff --git a/trino/exceptions.py b/trino/exceptions.py index 86708fd0..bfd4fef4 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError): pass -class FailedToObtainAddedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Added-Prepare' - header in the response of a PREPARE statement request. - """ - pass - - -class FailedToObtainDeallocatedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' - header in the response of a DEALLOCATED statement request. - """ - pass - - # client module errors class HttpError(Exception): pass