Skip to content

Commit

Permalink
Fix prepared statement handling
Browse files Browse the repository at this point in the history
The prepared statement handling code assumed that for each query we'll
always receive some non-empty response even after the initial response
which is not a valid assumption.

This assumption worked because earlier Trino used to send empty fake
results even for queries which don't return results (like PREPARE and
DEALLOCATE) but is now invalid with
trinodb/trino@bc794cd.

The other problem with the code was that it leaked HTTP protocol details
into dbapi.py and worked around it by keeping a deep copy of the request
object from the PREPARE execution and re-using it for the actual query
execution.

The new code fixes both issues by processing the prepared statement
headers as they are received and storing the resulting set of active
prepared statements on the ClientSession object. The ClientSession's set
of prepared statements is then rendered into the prepared statement
request header in TrinoRequest. Since the ClientSession is created and
reused for the entire Connection this also means that we can now
actually implement re-use of prepared statements within a single
Connection.
  • Loading branch information
hashhar committed Sep 30, 2022
1 parent efb6680 commit 49e051e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 108 deletions.
15 changes: 0 additions & 15 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs):
return http_response


def test_trino_result_response_headers():
"""
Validates that the `TrinoResult.response_headers` property returns the
headers associated to the TrinoQuery instance provided to the `TrinoResult`
class.
"""
mock_trino_query = mock.Mock(respone_headers={
'X-Trino-Fake-1': 'one',
'X-Trino-Fake-2': 'two',
})

result = TrinoResult(query=mock_trino_query, rows=[])
assert result.response_headers == mock_trino_query.response_headers


def test_trino_query_response_headers(sample_get_response_data):
"""
Validates that the `TrinoQuery.execute` function can take addtional headers
Expand Down
47 changes: 36 additions & 11 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
self._extra_credential = extra_credential
self._client_tags = client_tags
self._role = role
self._prepared_statements: Dict[str, str] = {}
self._object_lock = threading.Lock()

@property
Expand Down Expand Up @@ -206,6 +207,15 @@ def role(self, role):
with self._object_lock:
self._role = role

@property
def prepared_statements(self):
return self._prepared_statements

@prepared_statements.setter
def prepared_statements(self, prepared_statements):
with self._object_lock:
self._prepared_statements = prepared_statements


def get_header_values(headers, header):
return [val.strip() for val in headers[header].split(",")]
Expand All @@ -218,6 +228,12 @@ def get_session_property_values(headers, header):
for k, v in (kv.split("=", 1) for kv in kvs)
]

def get_prepared_statement_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
]

class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
Expand Down Expand Up @@ -392,6 +408,13 @@ def http_headers(self) -> Dict[str, str]:
for name, value in self._client_session.properties.items()
)

if len(self._client_session.prepared_statements) != 0:
# ``name`` must not contain ``=``
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(
"{}={}".format(name, urllib.parse.quote_plus(statement))
for name, statement in self._client_session.prepared_statements.items()
)

# merge custom http headers
for key in self._client_session.headers:
if key in headers.keys():
Expand Down Expand Up @@ -556,6 +579,18 @@ def process(self, http_response) -> TrinoStatus:
if constants.HEADER_SET_ROLE in http_response.headers:
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]

if constants.HEADER_ADDED_PREPARE in http_response.headers:
for name, statement in get_prepared_statement_values(
http_response.headers, constants.HEADER_ADDED_PREPARE
):
self._client_session.prepared_statements[name] = statement

if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
for name in get_header_values(
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
):
self._client_session.prepared_statements.pop(name)

self._next_uri = response.get("nextUri")

return TrinoStatus(
Expand Down Expand Up @@ -622,10 +657,6 @@ def __iter__(self):

self._rows = next_rows

@property
def response_headers(self):
return self._query.response_headers


class TrinoQuery(object):
"""Represent the execution of a SQL statement by Trino."""
Expand All @@ -648,7 +679,6 @@ def __init__(
self._update_type = None
self._sql = sql
self._result: Optional[TrinoResult] = None
self._response_headers = None
self._experimental_python_types = experimental_python_types
self._row_mapper: Optional[RowMapper] = None

Expand Down Expand Up @@ -705,7 +735,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received
# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
return self._result
Expand All @@ -725,7 +755,6 @@ def fetch(self) -> List[List[Any]]:
status = self._request.process(response)
self._update_state(status)
logger.debug(status)
self._response_headers = response.headers
if status.next_uri is None:
self._finished = True

Expand Down Expand Up @@ -763,10 +792,6 @@ def finished(self) -> bool:
def cancelled(self) -> bool:
return self._cancelled

@property
def response_headers(self):
return self._response_headers


def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
def wrapper(func):
Expand Down
94 changes: 28 additions & 66 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,58 +295,42 @@ def warnings(self):
return self._query.warnings
return None

def _new_request_with_session_from(self, request):
"""
Returns a new request with the `ClientSession` set to the one from the
given request.
"""
request = self.connection._create_request()
request._client_session = request._client_session
return request

def setinputsizes(self, sizes):
raise trino.exceptions.NotSupportedError

def setoutputsize(self, size, column):
raise trino.exceptions.NotSupportedError

def _prepare_statement(self, operation, statement_name):
def _prepare_statement(self, statement, name):
"""
Prepends the given `operation` with "PREPARE <statement_name> FROM" and
executes as a prepare statement.
:param operation: sql to be executed.
:param statement_name: name that will be assigned to the prepare
statement.
:raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised
when unable to find the 'X-Trino-Added-Prepare' for the PREPARE
statement request.
Registers a prepared statement for the provided `operation` with the
`name` assigned to it.
:return: string representing the value of the 'X-Trino-Added-Prepare'
header.
:param statement: sql to be executed.
:param name: name that will be assigned to the prepared statement.
"""
sql = 'PREPARE {statement_name} FROM {operation}'.format(
statement_name=statement_name,
operation=operation
)

# Send prepare statement. Copy the _request object to avoid polluting the
# one that is going to be used to execute the actual operation.
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
sql = f"PREPARE {name} FROM {statement}"
# TODO: Evaluate whether we can avoid the piggybacking on current request
query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql,
experimental_python_types=self._experimental_pyton_types)
result = query.execute()

# Iterate until the 'X-Trino-Added-Prepare' header is found or
# until there are no more results
for _ in result:
response_headers = result.response_headers

if constants.HEADER_ADDED_PREPARE in response_headers:
return response_headers[constants.HEADER_ADDED_PREPARE]
query.execute()

raise trino.exceptions.FailedToObtainAddedPrepareHeader

def _get_added_prepare_statement_trino_query(
def _execute_prepared_statement(
self,
statement_name,
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))

# No need to deepcopy _request here because this is the actual request
# operation
return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types)

def _format_prepared_param(self, param):
Expand Down Expand Up @@ -422,28 +406,12 @@ def _format_prepared_param(self, param):

raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))

def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
def _deallocate_prepared_statement(self, statement_name):
sql = 'DEALLOCATE PREPARE ' + statement_name

# Send deallocate statement. Copy the _request object to avoid poluting the
# one that is going to be used to execute the actual operation.
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
# TODO: Evaluate whether we can avoid the piggybacking on current request
query = trino.client.TrinoQuery(self._new_request_with_session_from(self._request), sql=sql,
experimental_python_types=self._experimental_pyton_types)
result = query.execute(
additional_http_headers={
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
}
)

# Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
# until there are no more results
for _ in result:
response_headers = result.response_headers

if constants.HEADER_DEALLOCATED_PREPARE in response_headers:
return response_headers[constants.HEADER_DEALLOCATED_PREPARE]

raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
query.execute()

def _generate_unique_statement_name(self):
return 'st_' + uuid.uuid4().hex.replace('-', '')
Expand All @@ -456,27 +424,21 @@ def execute(self, operation, params=None):
)

statement_name = self._generate_unique_statement_name()
# Send prepare statement
added_prepare_header = self._prepare_statement(
operation, statement_name
)
self._prepare_statement(operation, statement_name)

try:
# Send execute statement and assign the return value to `results`
# as it will be returned by the function
self._query = self._get_added_prepare_statement_trino_query(
self._query = self._execute_prepared_statement(
statement_name, params
)
result = self._query.execute(
additional_http_headers={
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
}
)
result = self._query.execute()
finally:
# Send deallocate statement
# At this point the query can be deallocated since it has already
# been executed
self._deallocate_prepare_statement(added_prepare_header, statement_name)
# TODO: Consider caching prepared statements if requested by caller
self._deallocate_prepared_statement(statement_name)

else:
self._query = trino.client.TrinoQuery(self._request, sql=operation,
Expand Down
16 changes: 0 additions & 16 deletions trino/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError):
pass


class FailedToObtainAddedPrepareHeader(Error):
"""
Raise this exception when unable to find the 'X-Trino-Added-Prepare'
header in the response of a PREPARE statement request.
"""
pass


class FailedToObtainDeallocatedPrepareHeader(Error):
"""
Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare'
header in the response of a DEALLOCATED statement request.
"""
pass


# client module errors
class HttpError(Exception):
pass
Expand Down

0 comments on commit 49e051e

Please sign in to comment.