Skip to content

Commit

Permalink
Fix prepared statement handling
Browse files Browse the repository at this point in the history
The prepared statement handling code assumed that for each query we'll
always receive some non-empty response even after the initial response
which is not a valid assumption.

This assumption worked because earlier Trino used to send empty fake
results even for queries which don't return results (like PREPARE and
DEALLOCATE) but is now invalid with
trinodb/trino@bc794cd.

The other problem with the code was that it leaked HTTP protocol details
into dbapi.py and worked around it by keeping a deep copy of the request
object from the PREPARE execution and re-using it for the actual query
execution.

The new code fixes both issues by processing the prepared statement
headers as they are received and storing the resulting set of active
prepared statements on the ClientSession object. The ClientSession's set
of prepared statements is then rendered into the prepared statement
request header in TrinoRequest. Since the ClientSession is created and
reused for the entire Connection this also means that we can now
actually implement re-use of prepared statements within a single
Connection.
  • Loading branch information
hashhar committed Sep 30, 2022
1 parent efb6680 commit b2ff5ff
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
3 changes: 3 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,9 @@ def __call__(self, *args, **kwargs):
return http_response


# TODO: What was this test added to verify?
# test_trino_query_response_headers already verifies that custom headers can be passed.
# Possibly we can remove this.
def test_trino_result_response_headers():
"""
Validates that the `TrinoResult.response_headers` property returns the
Expand Down
45 changes: 35 additions & 10 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
self._extra_credential = extra_credential
self._client_tags = client_tags
self._role = role
self._prepared_statements = {}
self._object_lock = threading.Lock()

@property
Expand Down Expand Up @@ -206,6 +207,15 @@ def role(self, role):
with self._object_lock:
self._role = role

@property
def prepared_statements(self):
return self._prepared_statements

@prepared_statements.setter
def prepared_statements(self, prepared_statements):
with self._object_lock:
self._prepared_statements = prepared_statements


def get_header_values(headers, header):
return [val.strip() for val in headers[header].split(",")]
Expand All @@ -218,6 +228,12 @@ def get_session_property_values(headers, header):
for k, v in (kv.split("=", 1) for kv in kvs)
]

def get_prepared_statement_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
]

class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
Expand Down Expand Up @@ -392,6 +408,13 @@ def http_headers(self) -> Dict[str, str]:
for name, value in self._client_session.properties.items()
)

if len(self._client_session.prepared_statements) != 0:
# ``name`` must not contain ``==``
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(
"{}={}".format(name, urllib.parse.quote(statement))
for name, statement in self._client_session.prepared_statements
)

# merge custom http headers
for key in self._client_session.headers:
if key in headers.keys():
Expand Down Expand Up @@ -556,6 +579,18 @@ def process(self, http_response) -> TrinoStatus:
if constants.HEADER_SET_ROLE in http_response.headers:
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]

if constants.HEADER_ADDED_PREPARE in http_response.headers:
for name, statement in get_prepared_statement_values(
http_response.headers, constants.HEADER_ADDED_PREPARE
):
self._client_session.prepared_statements[name] = statement

if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
for name in get_header_values(
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
):
self._client_session.prepared_statements.pop(name)

self._next_uri = response.get("nextUri")

return TrinoStatus(
Expand Down Expand Up @@ -622,10 +657,6 @@ def __iter__(self):

self._rows = next_rows

@property
def response_headers(self):
return self._query.response_headers


class TrinoQuery(object):
"""Represent the execution of a SQL statement by Trino."""
Expand All @@ -648,7 +679,6 @@ def __init__(
self._update_type = None
self._sql = sql
self._result: Optional[TrinoResult] = None
self._response_headers = None
self._experimental_python_types = experimental_python_types
self._row_mapper: Optional[RowMapper] = None

Expand Down Expand Up @@ -725,7 +755,6 @@ def fetch(self) -> List[List[Any]]:
status = self._request.process(response)
self._update_state(status)
logger.debug(status)
self._response_headers = response.headers
if status.next_uri is None:
self._finished = True

Expand Down Expand Up @@ -763,10 +792,6 @@ def finished(self) -> bool:
def cancelled(self) -> bool:
return self._cancelled

@property
def response_headers(self):
return self._response_headers


def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
def wrapper(func):
Expand Down

0 comments on commit b2ff5ff

Please sign in to comment.