From 95130019de7865443971f4dff890de4907896319 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Mon, 22 Aug 2022 09:46:01 +0200 Subject: [PATCH] Reformatted files using black and isort --- setup.py | 4 +- tests/integration/conftest.py | 37 +-- tests/integration/test_dbapi_integration.py | 269 +++++++++--------- .../test_sqlalchemy_integration.py | 202 +++++++------ tests/unit/conftest.py | 6 +- tests/unit/oauth_test_utils.py | 49 +++- tests/unit/sqlalchemy/conftest.py | 2 +- tests/unit/sqlalchemy/test_compiler.py | 46 +-- tests/unit/sqlalchemy/test_datatype_parse.py | 20 +- tests/unit/sqlalchemy/test_dialect.py | 59 ++-- tests/unit/test_client.py | 211 ++++++++------ tests/unit/test_dbapi.py | 89 +++--- tests/unit/test_http.py | 2 +- tests/unit/test_transaction.py | 7 +- trino/__init__.py | 9 +- trino/auth.py | 73 ++--- trino/client.py | 74 ++--- trino/constants.py | 7 +- trino/dbapi.py | 119 +++----- trino/exceptions.py | 2 + trino/sqlalchemy/compiler.py | 31 +- trino/sqlalchemy/datatype.py | 2 +- trino/sqlalchemy/dialect.py | 21 +- trino/transaction.py | 15 +- 24 files changed, 656 insertions(+), 700 deletions(-) diff --git a/setup.py b/setup.py index 2695a530..0bc76ec7 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ import re import textwrap -from setuptools import setup, find_packages +from setuptools import find_packages, setup _version_re = re.compile(r"__version__\s+=\s+(.*)") @@ -75,7 +75,7 @@ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Database :: Front-Ends", ], - python_requires='>=3.7', + python_requires=">=3.7", install_requires=["pytz", "requests"], extras_require={ "all": all_require, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 840ee8f3..91b6fd95 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,11 +18,11 @@ from uuid import uuid4 import click -import trino.logging import pytest -from trino.client import TrinoQuery, TrinoRequest, ClientSession -from trino.constants import DEFAULT_PORT +import trino.logging +from trino.client import ClientSession, TrinoQuery, TrinoRequest +from trino.constants import DEFAULT_PORT logger = trino.logging.get_logger(__name__) @@ -64,13 +64,7 @@ def start_trino(image_tag=None): def wait_for_trino_workers(host, port, timeout=180): - request = TrinoRequest( - host=host, - port=port, - client_session=ClientSession( - user="test_fixture" - ) - ) + request = TrinoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture")) sql = "SELECT state FROM system.runtime.nodes" t0 = time.time() while True: @@ -116,9 +110,7 @@ def start_trino_and_wait(image_tag=None): if host: port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT) else: - container_id, proc, host, port = start_local_trino_server( - image_tag - ) + container_id, proc, host, port = start_local_trino_server(image_tag) print("trino.server.hostname {}".format(host)) print("trino.server.port {}".format(port)) @@ -135,9 +127,7 @@ def stop_trino(container_id, proc): def find_images(name): assert name - output = subprocess.check_output( - ["docker", "images", "--format", "{{.Repository}}:{{.Tag}}", name] - ) + output = subprocess.check_output(["docker", "images", "--format", "{{.Repository}}:{{.Tag}}", name]) return [line.decode() for line in output.splitlines()] @@ -167,9 +157,7 @@ def cli(): pass -@click.option( - "--cache/--no-cache", default=True, help="enable/disable Docker build cache" -) +@click.option("--cache/--no-cache", default=True, help="enable/disable Docker build cache") @click.command() def trino_server(): container_id, _, _, _ = start_trino_and_wait() @@ -198,19 +186,14 @@ def trino_cli(container_id=None): @cli.command("list") def list_(): - subprocess.check_call( - ["docker", "ps", "--filter", "name=trino-python-client-tests-"] - ) + subprocess.check_call(["docker", "ps", "--filter", "name=trino-python-client-tests-"]) @cli.command() def clean(): cmd = ( - "docker ps " - "--filter name=trino-python-client-tests- " - "--format={{.Names}} | " - "xargs -n 1 docker kill" # NOQA deliberate additional indent - ) + "docker ps --filter name=trino-python-client-tests- --format={{.Names}} | xargs -n 1 docker kill" + ) # NOQA deliberate additional indent subprocess.check_output(cmd, shell=True) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index c5d0dda0..bf7cddf8 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from datetime import datetime, time, date, timezone, timedelta +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal import pytest @@ -19,7 +19,7 @@ import trino from tests.integration.conftest import trino_version -from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError +from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError from trino.transaction import IsolationLevel @@ -27,9 +27,7 @@ def trino_connection(run_trino): _, host, port = run_trino - yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 - ) + yield trino.dbapi.Connection(host=host, port=port, user="test", source="test", max_attempts=1) @pytest.fixture @@ -93,8 +91,7 @@ def test_select_query_result_iteration(trino_connection): def test_select_query_result_iteration_statement_params(trino_connection): cur = trino_connection.cursor() cur.execute( - """ - SELECT * FROM ( + """SELECT * FROM ( values (1, 'one', 'a'), (2, 'two', 'b'), @@ -104,7 +101,7 @@ def test_select_query_result_iteration_statement_params(trino_connection): ) x (id, name, letter) WHERE id >= ? """, - params=(3,) # expecting all the rows with id >= 3 + params=(3,), # expecting all the rows with id >= 3 ) @@ -166,23 +163,22 @@ def test_execute_many_select(trino_connection): assert "Query must return update type" in str(e.value) -@pytest.mark.parametrize("connection_experimental_python_types,cursor_experimental_python_types,expected", - [ - (None, None, False), - (None, False, False), - (None, True, True), - (False, None, False), - (False, False, False), - (False, True, True), - (True, None, True), - (True, False, False), - (True, True, True), - ]) +@pytest.mark.parametrize( + "connection_experimental_python_types,cursor_experimental_python_types,expected", + [ + (None, None, False), + (None, False, False), + (None, True, True), + (False, None, False), + (False, False, False), + (False, True, True), + (True, None, True), + (True, False, False), + (True, True, True), + ], +) def test_experimental_python_types_with_connection_and_cursor( - connection_experimental_python_types, - cursor_experimental_python_types, - expected, - run_trino + connection_experimental_python_types, cursor_experimental_python_types, expected, run_trino ): _, host, port = run_trino @@ -195,43 +191,44 @@ def test_experimental_python_types_with_connection_and_cursor( cur = connection.cursor(experimental_python_types=cursor_experimental_python_types) - cur.execute(""" - SELECT + cur.execute( + """SELECT DECIMAL '0.142857', DATE '2018-01-01', TIMESTAMP '2019-01-01 00:00:00.000+01:00', TIMESTAMP '2019-01-01 00:00:00.000 UTC', TIMESTAMP '2019-01-01 00:00:00.000', TIME '00:00:00.000' - """) + """ + ) rows = cur.fetchall() if expected: - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal("0.142857") assert rows[0][1] == date(2018, 1, 1) assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1))) - assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC')) + assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone("UTC")) assert rows[0][4] == datetime(2019, 1, 1) assert rows[0][5] == time(0, 0, 0, 0) else: for value in rows[0]: assert isinstance(value, str) - assert rows[0][0] == '0.142857' - assert rows[0][1] == '2018-01-01' - assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00' - assert rows[0][3] == '2019-01-01 00:00:00.000 UTC' - assert rows[0][4] == '2019-01-01 00:00:00.000' - assert rows[0][5] == '00:00:00.000' + assert rows[0][0] == "0.142857" + assert rows[0][1] == "2018-01-01" + assert rows[0][2] == "2019-01-01 00:00:00.000 +01:00" + assert rows[0][3] == "2019-01-01 00:00:00.000 UTC" + assert rows[0][4] == "2019-01-01 00:00:00.000" + assert rows[0][5] == "00:00:00.000" def test_decimal_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - cur.execute("SELECT ?", params=(Decimal('0.142857'),)) + cur.execute("SELECT ?", params=(Decimal("0.142857"),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal("0.142857") def test_null_decimal(trino_connection): @@ -246,7 +243,7 @@ def test_null_decimal(trino_connection): def test_biggest_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('99999999999999999999999999999999999999') + params = Decimal("99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -256,7 +253,7 @@ def test_biggest_decimal(trino_connection): def test_smallest_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('-99999999999999999999999999999999999999') + params = Decimal("-99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -266,7 +263,7 @@ def test_smallest_decimal(trino_connection): def test_highest_precision_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('0.99999999999999999999999999999999999999') + params = Decimal("0.99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -288,7 +285,7 @@ def test_datetime_query_param(trino_connection): def test_datetime_with_utc_time_zone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("UTC")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -314,7 +311,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): def test_datetime_with_named_time_zone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("America/Los_Angeles")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -354,7 +351,7 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) # This is a datetime that lies within a DST transition and not actually exists. - params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels')) + params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone("Europe/Brussels")) with pytest.raises(trino.exceptions.TrinoUserError): cur.execute("SELECT ?", params=(params,)) cur.fetchall() @@ -365,21 +362,21 @@ def test_doubled_datetimes(trino_connection): # See also https://github.com/trinodb/trino/issues/5781 cur = trino_connection.cursor(experimental_python_types=True) - params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True) + params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern")) cur = trino_connection.cursor(experimental_python_types=True) - params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False) + params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern")) def test_date_query_param(trino_connection): @@ -407,11 +404,11 @@ def test_unsupported_python_dates(trino_connection): # dates below python min (1-1-1) or above max date (9999-12-31) are not supported for unsupported_date in [ - '-0001-01-01', - '0000-01-01', - '10000-01-01', - '-4999999-01-01', # Trino min date - '5000000-12-31', # Trino max date + "-0001-01-01", + "0000-01-01", + "10000-01-01", + "-4999999-01-01", # Trino min date + "5000000-12-31", # Trino max date ]: with pytest.raises(trino.exceptions.TrinoDataError): cur.execute(f"SELECT DATE '{unsupported_date}'") @@ -422,26 +419,26 @@ def test_supported_special_dates_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) for params in ( - # min python date - date(1, 1, 1), - # before julian->gregorian switch - date(1500, 1, 1), - # During julian->gregorian switch - date(1752, 9, 4), - # before epoch - date(1952, 4, 3), - date(1970, 1, 1), - date(1970, 2, 3), - # summer on northern hemisphere (possible DST) - date(2017, 7, 1), - # winter on northern hemisphere (possible DST on southern hemisphere) - date(2017, 1, 1), - # winter on southern hemisphere (possible DST on northern hemisphere) - date(2017, 12, 31), - date(1983, 4, 1), - date(1983, 10, 1), - # max python date - date(9999, 12, 31), + # min python date + date(1, 1, 1), + # before julian->gregorian switch + date(1500, 1, 1), + # During julian->gregorian switch + date(1752, 9, 4), + # before epoch + date(1952, 4, 3), + date(1970, 1, 1), + date(1970, 2, 3), + # summer on northern hemisphere (possible DST) + date(2017, 7, 1), + # winter on northern hemisphere (possible DST on southern hemisphere) + date(2017, 1, 1), + # winter on southern hemisphere (possible DST on northern hemisphere) + date(2017, 12, 31), + date(1983, 4, 1), + date(1983, 10, 1), + # max python date + date(9999, 12, 31), ): cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -464,7 +461,7 @@ def test_time_with_named_time_zone_query_param(trino_connection): with pytest.raises(trino.exceptions.NotSupportedError): cur = trino_connection.cursor() - params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai')) + params = time(16, 43, 22, 320000, tzinfo=pytz.timezone("Asia/Shanghai")) cur.execute("SELECT ?", params=(params,)) @@ -599,7 +596,10 @@ def test_array_timestamp_query_param(trino_connection): def test_array_timestamp_with_timezone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)] + params = [ + datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), + datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc), + ] cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -723,15 +723,18 @@ def test_int_query_param(trino_connection): assert cur.description[0][1] == "bigint" -@pytest.mark.parametrize('params', [ - 'NOT A LIST OR TUPPLE', - {'invalid', 'params'}, - object, -]) +@pytest.mark.parametrize( + "params", + [ + "NOT A LIST OR TUPPLE", + {"invalid", "params"}, + object, + ], +) def test_select_query_invalid_params(trino_connection, params): cur = trino_connection.cursor() with pytest.raises(AssertionError): - cur.execute('SELECT ?', params=params) + cur.execute("SELECT ?", params=params) def test_select_cursor_iteration(trino_connection): @@ -877,27 +880,20 @@ def test_transaction_multiple(trino_connection_with_transaction): assert len(rows2) == 1000 -@pytest.mark.skipif(trino_version() == '351', reason="Autocommit behaves " - "differently in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="Autocommit behaves differently in older Trino versions") def test_transaction_autocommit(trino_connection_in_autocommit): with trino_connection_in_autocommit as connection: connection.start_transaction() cur = connection.cursor() - cur.execute( - """ - CREATE TABLE memory.default.nation - AS SELECT * from tpch.tiny.nation - """) + cur.execute("CREATE TABLE memory.default.nation AS SELECT * from tpch.tiny.nation") with pytest.raises(TrinoUserError) as transaction_error: cur.fetchall() - assert "Catalog only supports writes using autocommit: memory" \ - in str(transaction_error.value) + assert "Catalog only supports writes using autocommit: memory" in str(transaction_error.value) def test_invalid_query_throws_correct_error(trino_connection): - """Tests that an invalid query raises the correct exception - """ + """Tests that an invalid query raises the correct exception""" cur = trino_connection.cursor() with pytest.raises(TrinoQueryError): cur.execute( @@ -910,14 +906,14 @@ def test_invalid_query_throws_correct_error(trino_connection): def test_eager_loading_cursor_description(trino_connection): description_expected = [ - ('node_id', 'varchar', None, None, None, None, None), - ('http_uri', 'varchar', None, None, None, None, None), - ('node_version', 'varchar', None, None, None, None, None), - ('coordinator', 'boolean', None, None, None, None, None), - ('state', 'varchar', None, None, None, None, None), + ("node_id", "varchar", None, None, None, None, None), + ("http_uri", "varchar", None, None, None, None, None), + ("node_version", "varchar", None, None, None, None, None), + ("coordinator", "boolean", None, None, None, None, None), + ("state", "varchar", None, None, None, None, None), ] cur = trino_connection.cursor() - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") description_before = cur.description assert description_before is not None @@ -934,7 +930,7 @@ def test_eager_loading_cursor_description(trino_connection): def test_info_uri(trino_connection): cur = trino_connection.cursor() assert cur.info_uri is None - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") assert cur.info_uri is not None assert cur._query.query_id in cur.info_uri cur.fetchall() @@ -971,44 +967,47 @@ def retrieve_client_tags_from_query(run_trino, client_tags): ) cur = trino_connection.cursor() - cur.execute('SELECT 1') + cur.execute("SELECT 1") cur.fetchall() api_url = "http://" + trino_connection.host + ":" + str(trino_connection.port) - query_info = requests.post(api_url + "/ui/login", data={ - "username": "admin", - "password": "", - "redirectPath": api_url + '/ui/api/query/' + cur._query.query_id - }).json() - - query_client_tags = query_info['session']['clientTags'] + query_info = requests.post( + api_url + "/ui/login", + data={ + "username": "admin", + "password": "", + "redirectPath": api_url + "/ui/api/query/" + cur._query.query_id, + }, + ).json() + + query_client_tags = query_info["session"]["clientTags"] return query_client_tags -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="current_catalog not supported in older Trino versions") def test_use_catalog_schema(trino_connection): cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() assert result[0][0] is None assert result[0][1] is None - cur.execute('USE tpch.tiny') + cur.execute("USE tpch.tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE tpcds.sf1') + cur.execute("USE tpcds.sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpcds' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpcds" + assert result[0][1] == "sf1" -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="current_catalog not supported in older Trino versions") def test_use_catalog(run_trino): _, host, port = run_trino @@ -1016,35 +1015,33 @@ def test_use_catalog(run_trino): host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1 ) cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' + assert result[0][0] == "tpch" assert result[0][1] is None - cur.execute('USE tiny') + cur.execute("USE tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE sf1') + cur.execute("USE sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpch" + assert result[0][1] == "sf1" -@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role") +@pytest.mark.skipif(trino_version() == "351", reason="Newer Trino versions return the system role") def test_set_role_trino_higher_351(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert cur._request._client_session.role is None @@ -1053,15 +1050,13 @@ def test_set_role_trino_higher_351(run_trino): assert cur._request._client_session.role == "system=ALL" -@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog") +@pytest.mark.skipif(trino_version() != "351", reason="Trino 351 returns the role for the current catalog") def test_set_role_trino_351(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert cur._request._client_session.role is None diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1dc8f05a..bf79ef52 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -11,22 +11,24 @@ # limitations under the License import pytest import sqlalchemy as sqla -from sqlalchemy.sql import and_, or_, not_ +from sqlalchemy.sql import and_, not_, or_ @pytest.fixture def trino_connection(run_trino, request): _, host, port = run_trino - engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", - connect_args={"source": "test", "max_attempts": 1}) + engine = sqla.create_engine( + f"trino://test@{host}:{port}/{request.param}", + connect_args={"source": "test", "max_attempts": 1}, + ) yield engine, engine.connect() -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_select_query(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) assert_column(nations, "nationkey", sqla.sql.sqltypes.BigInteger) assert_column(nations, "name", sqla.sql.sqltypes.String) assert_column(nations, "regionkey", sqla.sql.sqltypes.BigInteger) @@ -36,10 +38,10 @@ def test_select_query(trino_connection): rows = result.fetchall() assert len(rows) == 25 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["regionkey"], int) + assert isinstance(row["comment"], str) def assert_column(table, column_name, column_type): @@ -47,11 +49,11 @@ def assert_column(table, column_name, column_type): assert isinstance(getattr(table.c, column_name).type, column_type) -@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["system"], indirect=True) def test_select_specific_columns(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nodes = sqla.Table('nodes', metadata, schema='runtime', autoload_with=conn) + nodes = sqla.Table("nodes", metadata, schema="runtime", autoload_with=conn) assert_column(nodes, "node_id", sqla.sql.sqltypes.String) assert_column(nodes, "state", sqla.sql.sqltypes.String) query = sqla.select(nodes.c.node_id, nodes.c.state) @@ -59,26 +61,28 @@ def test_select_specific_columns(trino_connection): rows = result.fetchall() assert len(rows) > 0 for row in rows: - assert isinstance(row['node_id'], str) - assert isinstance(row['state'], str) + assert isinstance(row["node_id"], str) + assert isinstance(row["state"], str) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_define_and_create_table(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(engine, "test"): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) - assert sqla.inspect(engine).has_table('users', schema="test") - users = sqla.Table('users', metadata, schema='test', autoload_with=conn) + assert sqla.inspect(engine).has_table("users", schema="test") + users = sqla.Table("users", metadata, schema="test", autoload_with=conn) assert_column(users, "id", sqla.sql.sqltypes.Integer) assert_column(users, "name", sqla.sql.sqltypes.String) assert_column(users, "fullname", sqla.sql.sqltypes.String) @@ -86,7 +90,7 @@ def test_define_and_create_table(trino_connection): metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert(trino_connection): engine, conn = trino_connection @@ -94,12 +98,14 @@ def test_insert(trino_connection): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() conn.execute(ins, {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}) @@ -112,72 +118,79 @@ def test_insert(trino_connection): metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert_multiple_statements(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(engine, "test"): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() - conn.execute(ins, [ - {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, - {"id": 3, "name": "john", "fullname": "John Doe"}, - {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, - ]) + conn.execute( + ins, + [ + {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, + {"id": 3, "name": "john", "fullname": "John Doe"}, + {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, + ], + ) query = sqla.select(users) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 3 - assert frozenset(rows) == frozenset([ - (2, "wendy", "Wendy Williams"), - (3, "john", "John Doe"), - (4, "mary", "Mary Hopkins"), - ]) + assert frozenset(rows) == frozenset( + [ + (2, "wendy", "Wendy Williams"), + (3, "john", "John Doe"), + (4, "mary", "Mary Hopkins"), + ] + ) metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_operators(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + customers = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) query = sqla.select(customers).where(customers.c.nationkey == 2) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["regionkey"], int) + assert isinstance(row["comment"], str) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_conjunctions(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - query = sqla.select(customers).where(and_( - customers.c.name.like('%12%'), - customers.c.nationkey == 15, - or_( - customers.c.mktsegment == 'AUTOMOBILE', - customers.c.mktsegment == 'HOUSEHOLD' - ), - not_(customers.c.acctbal < 0))) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + query = sqla.select(customers).where( + and_( + customers.c.name.like("%12%"), + customers.c.nationkey == 15, + or_(customers.c.mktsegment == "AUTOMOBILE", customers.c.mktsegment == "HOUSEHOLD"), + not_(customers.c.acctbal < 0), + ) + ) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_textual_sql(trino_connection): _, conn = trino_connection s = sqla.text("SELECT * from tiny.customer where nationkey = :e1 AND acctbal < :e2") @@ -185,41 +198,46 @@ def test_textual_sql(trino_connection): rows = result.fetchall() assert len(rows) == 3 for row in rows: - assert isinstance(row['custkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['address'], str) - assert isinstance(row['nationkey'], int) - assert isinstance(row['phone'], str) - assert isinstance(row['acctbal'], float) - assert isinstance(row['mktsegment'], str) - assert isinstance(row['comment'], str) + assert isinstance(row["custkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["address"], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["phone"], str) + assert isinstance(row["acctbal"], float) + assert isinstance(row["mktsegment"], str) + assert isinstance(row["comment"], str) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_alias(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) nations1 = nations.alias("o1") nations2 = nations.alias("o2") - s = sqla.select(nations1) \ - .join(nations2, and_( - nations1.c.regionkey == nations2.c.regionkey, - nations1.c.nationkey != nations2.c.nationkey, - nations1.c.regionkey == 1 - )) \ + s = ( + sqla.select(nations1) + .join( + nations2, + and_( + nations1.c.regionkey == nations2.c.regionkey, + nations1.c.nationkey != nations2.c.nationkey, + nations1.c.regionkey == 1, + ), + ) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 5 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_subquery(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_subquery = automobile_customers.subquery() s = sqla.select(nations.c.name).where(nations.c.nationkey.in_(sqla.select(automobile_customers_subquery))) @@ -228,27 +246,29 @@ def test_subquery(trino_connection): assert len(rows) == 15 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_joins(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - s = sqla.select(nations.c.name) \ - .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) \ - .where(customers.c.acctbal < -900) \ + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + s = ( + sqla.select(nations.c.name) + .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) + .where(customers.c.acctbal < -900) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 15 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_cte(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_cte = automobile_customers.cte() s = sqla.select(nations).where(nations.c.nationkey.in_(sqla.select(automobile_customers_cte))) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9d968e07..18cb2e48 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import MagicMock, patch +import pytest + @pytest.fixture(scope="session") def sample_post_response_data(): @@ -194,8 +195,7 @@ def sample_get_error_response_data(): "errorType": "USER_ERROR", "failureInfo": { "errorLocation": {"columnNumber": 15, "lineNumber": 1}, - "message": "line 1:15: Schema must be specified " - "when session schema is not set", + "message": "line 1:15: Schema must be specified when session schema is not set", "stack": [ "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py index 77b891ce..2a0bf45e 100644 --- a/tests/unit/oauth_test_utils.py +++ b/tests/unit/oauth_test_utils.py @@ -45,9 +45,14 @@ def __call__(self, request, uri, response_headers): authorization = request.headers.get("Authorization") if authorization and authorization.replace("Bearer ", "") in self.tokens: return [200, response_headers, json.dumps(self.sample_post_response_data)] - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": f'Bearer x_redirect_server="{self.redirect_server}", x_token_server="{self.token_server}"', + "Basic realm": '"Trino"', + }, + "", + ] class GetTokenCallback: @@ -66,19 +71,25 @@ def __call__(self, request, uri, response_headers): def _get_token_requests(challenge_id): - return list(filter( - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", - httpretty.latest_requests())) + return list( + filter( + lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", + httpretty.latest_requests(), + ) + ) def _post_statement_requests(): - return list(filter( - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, - httpretty.latest_requests())) + return list( + filter( + lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, + httpretty.latest_requests(), + ) + ) class MultithreadedTokenServer: - Challenge = namedtuple('Challenge', ['token', 'attempts']) + Challenge = namedtuple("Challenge", ["token", "attempts"]) def __init__(self, sample_post_response_data, attempts=1): self.tokens = set() @@ -90,13 +101,15 @@ def __init__(self, sample_post_response_data, attempts=1): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=self.post_statement_callback) + body=self.post_statement_callback, + ) # bind get token httpretty.register_uri( method=httpretty.GET, uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), - body=self.get_token_callback) + body=self.get_token_callback, + ) # noinspection PyUnusedLocal def post_statement_callback(self, request, uri, response_headers): @@ -111,9 +124,15 @@ def post_statement_callback(self, request, uri, response_headers): self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' - f'x_token_server="{token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": f'Bearer x_redirect_server="{redirect_server}", ' + f'x_token_server="{token_server}"', + "Basic realm": '"Trino"', + }, + "", + ] # noinspection PyUnusedLocal def get_token_callback(self, request, uri, response_headers): diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py index e80f19b8..71d6f74d 100644 --- a/tests/unit/sqlalchemy/conftest.py +++ b/tests/unit/sqlalchemy/conftest.py @@ -12,7 +12,7 @@ import pytest from sqlalchemy.sql.sqltypes import ARRAY -from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP, SQLType @pytest.fixture(scope="session") diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 00bd686c..d7f5acf4 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -10,33 +10,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy import ( - Column, - insert, - Integer, - MetaData, - select, - String, - Table, -) +from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select from sqlalchemy.schema import CreateTable from trino.sqlalchemy.dialect import TrinoDialect metadata = MetaData() table = Table( - 'table', - metadata, - Column('id', Integer), - Column('name', String), -) -table_with_catalog = Table( - 'table', + "table", metadata, - Column('id', Integer), - schema='default', - trino_catalog='other' + Column("id", Integer), + Column("name", String), ) +table_with_catalog = Table("table", metadata, Column("id", Integer), schema="default", trino_catalog="other") @pytest.fixture @@ -63,15 +49,16 @@ def test_offset(dialect): def test_cte_insert_order(dialect): - cte = select(table).cte('cte') + cte = select(table).cte("cte") statement = insert(table).from_select(table.columns, cte) query = statement.compile(dialect=dialect) - assert str(query) == \ - 'INSERT INTO "table" (id, name) WITH cte AS \n'\ - '(SELECT "table".id AS id, "table".name AS name \n'\ - 'FROM "table")\n'\ - ' SELECT cte.id, cte.name \n'\ - 'FROM cte' + assert ( + str(query) == 'INSERT INTO "table" (id, name) WITH cte AS \n' + '(SELECT "table".id AS id, "table".name AS name \n' + 'FROM "table")\n' + " SELECT cte.id, cte.name \n" + "FROM cte" + ) def test_catalogs_argument(dialect): @@ -83,9 +70,4 @@ def test_catalogs_argument(dialect): def test_catalogs_create_table(dialect): statement = CreateTable(table_with_catalog) query = statement.compile(dialect=dialect) - assert str(query) == \ - '\n'\ - 'CREATE TABLE "other".default."table" (\n'\ - '\tid INTEGER\n'\ - ')\n'\ - '\n' + assert str(query) == """\nCREATE TABLE "other".default."table" (\n\tid INTEGER\n)\n\n""" diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py index 66a2f6b0..66c98823 100644 --- a/tests/unit/sqlalchemy/test_datatype_parse.py +++ b/tests/unit/sqlalchemy/test_datatype_parse.py @@ -10,23 +10,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy.sql.sqltypes import ( - CHAR, - VARCHAR, - ARRAY, - INTEGER, - DECIMAL, - DATE -) +from sqlalchemy.sql.sqltypes import ARRAY, CHAR, DATE, DECIMAL, INTEGER, VARCHAR from sqlalchemy.sql.type_api import TypeEngine from trino.sqlalchemy import datatype -from trino.sqlalchemy.datatype import ( - MAP, - ROW, - TIME, - TIMESTAMP -) +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP @pytest.mark.parametrize( @@ -68,7 +56,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype): "CHAR(10)": CHAR(10), "VARCHAR(10)": VARCHAR(10), "DECIMAL(20)": DECIMAL(20), - "DECIMAL(20, 3)": DECIMAL(20, 3) + "DECIMAL(20, 3)": DECIMAL(20, 3), } @@ -187,7 +175,7 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype): "timestamp(3)": TIMESTAMP(3, timezone=False), "timestamp(6)": TIMESTAMP(6), "timestamp(12) with time zone": TIMESTAMP(12, timezone=True), - "timestamp with time zone": TIMESTAMP(timezone=True) + "timestamp with time zone": TIMESTAMP(timezone=True), } diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index b17f8cfe..6ac58e19 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -7,7 +7,11 @@ from trino.auth import BasicAuthentication from trino.dbapi import Connection -from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect +from trino.sqlalchemy.dialect import ( + CertificateAuthentication, + JWTAuthentication, + TrinoDialect, +) from trino.transaction import IsolationLevel @@ -26,7 +30,13 @@ def setup(self): ( make_url("trino://user@localhost:8080"), list(), - dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + source="trino-sqlalchemy", + ), ), ( make_url("trino://user:pass@localhost:8080?source=trino-rulez"), @@ -38,17 +48,18 @@ def setup(self): user="user", auth=BasicAuthentication("user", "pass"), http_scheme="https", - source="trino-rulez" + source="trino-rulez", ), ), ( make_url( - 'trino://user@localhost:8080?' + "trino://user@localhost:8080?" 'session_properties={"query_max_run_time": "1d"}' '&http_headers={"trino": 1}' '&extra_credential=[("a", "b"), ("c", "d")]' '&client_tags=[1, "sql"]' - '&experimental_python_types=true'), + "&experimental_python_types=true" + ), list(), dict( host="localhost", @@ -97,36 +108,36 @@ def test_isolation_level(self): def test_trino_connection_basic_auth(): dialect = TrinoDialect() - username = 'trino-user' - password = 'trino-bunny' - url = make_url(f'trino://{username}:{password}@host') + username = "trino-user" + password = "trino-bunny" + url = make_url(f"trino://{username}:{password}@host") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], BasicAuthentication) - assert cparams['auth']._username == username - assert cparams['auth']._password == password + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], BasicAuthentication) + assert cparams["auth"]._username == username + assert cparams["auth"]._password == password def test_trino_connection_jwt_auth(): dialect = TrinoDialect() - access_token = 'sample-token' - url = make_url(f'trino://host/?access_token={access_token}') + access_token = "sample-token" + url = make_url(f"trino://host/?access_token={access_token}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], JWTAuthentication) - assert cparams['auth'].token == access_token + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], JWTAuthentication) + assert cparams["auth"].token == access_token def test_trino_connection_certificate_auth(): dialect = TrinoDialect() - cert = '/path/to/cert.pem' - key = '/path/to/key.pem' - url = make_url(f'trino://host/?cert={cert}&key={key}') + cert = "/path/to/cert.pem" + key = "/path/to/key.pem" + url = make_url(f"trino://host/?cert={cert}&key={key}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], CertificateAuthentication) - assert cparams['auth']._cert == cert - assert cparams['auth']._key == key + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], CertificateAuthentication) + assert cparams["auth"]._cert == cert + assert cparams["auth"]._key == key diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 12089305..0c274ab3 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -23,13 +23,28 @@ from requests_kerberos.exceptions import KerberosExchangeError import trino.exceptions -from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ - MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ - SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + MultithreadedTokenServer, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer -from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \ - _RetryWithExponentialBackoff +from trino.client import ( + ClientSession, + TrinoQuery, + TrinoRequest, + TrinoResult, + _DelayExponential, + _retry_with, + _RetryWithExponentialBackoff, +) @mock.patch("trino.client.TrinoRequest.http") @@ -80,7 +95,7 @@ def test_request_headers(mock_get_and_post): headers={ accept_encoding_header: accept_encoding_value, client_info_header: client_info_value, - } + }, ), http_scheme="http", redirect_handler=None, @@ -112,14 +127,7 @@ def test_request_session_properties_headers(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - properties={ - "a": "1", - "b": "2", - "c": "more=v1,v2" - } - ) + client_session=ClientSession(user="test_user", properties={"a": "1", "b": "2", "c": "more=v1,v2"}), ) def assert_headers(headers): @@ -154,10 +162,10 @@ def test_additional_request_post_headers(mock_get_and_post): http_scheme="http", ) - sql = 'select 1' + sql = "select 1" additional_headers = { - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", } combined_headers = req.http_headers @@ -167,7 +175,7 @@ def test_additional_request_post_headers(mock_get_and_post): # Validate that the post call was performed including the addtional headers _, post_kwargs = post.call_args - assert post_kwargs['headers'] == combined_headers + assert post_kwargs["headers"] == combined_headers def test_request_invalid_http_headers(): @@ -189,10 +197,7 @@ def test_request_client_tags_headers(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - client_tags=["tag1", "tag2"] - ), + client_session=ClientSession(user="test_user", client_tags=["tag1", "tag2"]), ) def assert_headers(headers): @@ -215,7 +220,7 @@ def test_request_client_tags_headers_no_client_tags(mock_get_and_post): port=8080, client_session=ClientSession( user="test_user", - ) + ), ) def assert_headers(headers): @@ -341,14 +346,12 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token, attempts) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() @@ -359,10 +362,11 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) response = request.post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert redirect_handler.redirect_server == redirect_server assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -384,14 +388,12 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token, attempts) - httpretty.register_uri( - method=httpretty.GET, - uri=f"{TOKEN_RESOURCE}/{challenge_id}", - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", body=get_token_callback) redirect_handler = RedirectHandler() @@ -402,7 +404,8 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -413,21 +416,34 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): assert len(_get_token_requests(challenge_id)) == _OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS -@pytest.mark.parametrize("header,error", [ - ("", "Error: header WWW-Authenticate not available in the response."), - ('Bearer"', 'Error: header info didn\'t have x_redirect_server'), - ('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501 - ('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'), - ('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'), -]) +@pytest.mark.parametrize( + "header,error", + [ + ("", "Error: header WWW-Authenticate not available in the response."), + ('Bearer"', "Error: header info didn't have x_redirect_server"), + ( + 'x_redirect_server="redirect_server", x_token_server="token_server"', + 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"', + ), # noqa: E501 + ( + 'Bearer x_redirect_server="redirect_server"', + "Error: header info didn't have x_token_server", + ), + ( + 'Bearer x_token_server="token_server"', + "Error: header info didn't have x_redirect_server", + ), + ], +) @httprettified def test_oauth2_authentication_missing_headers(header, error): # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - adding_headers={'WWW-Authenticate': header}, - status=401) + adding_headers={"WWW-Authenticate": header}, + status=401, + ) request = TrinoRequest( host="coordinator", @@ -436,7 +452,8 @@ def test_oauth2_authentication_missing_headers(header, error): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler())) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -444,13 +461,16 @@ def test_oauth2_authentication_missing_headers(header, error): assert str(exp.value) == error -@pytest.mark.parametrize("header", [ - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge', - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"', - 'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"', - 'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"', - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', -]) +@pytest.mark.parametrize( + "header", + [ + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge', + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"', + 'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"', + 'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"', + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', + ], +) @httprettified def test_oauth2_header_parsing(header, sample_post_response_data): token = str(uuid.uuid4()) @@ -464,21 +484,25 @@ def post_statement(request, uri, response_headers): authorization = request.headers.get("Authorization") if authorization and authorization.replace("Bearer ", "") in token: return [200, response_headers, json.dumps(sample_post_response_data)] - return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server), - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": header.format(redirect_server=redirect_server, token_server=token_server), + "Basic realm": '"Trino"', + }, + "", + ] # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement) + body=post_statement, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() @@ -489,10 +513,10 @@ def post_statement(request, uri, response_headers): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), ).post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert redirect_handler.redirect_server == redirect_server assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -514,13 +538,15 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) httpretty.register_uri( method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", status=http_status, - body="error") + body="error", + ) redirect_handler = RedirectHandler() @@ -531,7 +557,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -564,7 +591,8 @@ def run(self) -> None: user="test", ), http_scheme=constants.HTTPS, - auth=auth) + auth=auth, + ) for i in range(10): # apparently HTTPretty in the current version is not thread-safe # https://github.com/gabrielfalcao/HTTPretty/issues/209 @@ -650,10 +678,7 @@ def test_trino_fetch_error(mock_requests, sample_get_error_response_data): assert "stack" in error.failure_info assert len(error.failure_info["stack"]) == 36 assert "suppressed" in error.failure_info - assert ( - error.message - == "line 1:15: Schema must be specified when session schema is not set" - ) + assert error.message == "line 1:15: Schema must be specified when session schema is not set" assert error.error_location == (1, 15) assert error.query_id == "20210817_140827_00000_arvdv" @@ -803,11 +828,14 @@ def test_authentication_fail_retry(monkeypatch): assert post_retry.retry_count == attempts -@pytest.mark.parametrize("status_code, attempts", [ - (502, 3), - (503, 3), - (504, 3), -]) +@pytest.mark.parametrize( + "status_code, attempts", + [ + (502, 3), + (503, 3), + (504, 3), + ], +) def test_5XX_error_retry(status_code, attempts, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -824,7 +852,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): client_session=ClientSession( user="test", ), - max_attempts=attempts + max_attempts=attempts, ) req.post("URL") @@ -834,9 +862,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): assert post_retry.retry_count == attempts -@pytest.mark.parametrize("status_code", [ - 501 -]) +@pytest.mark.parametrize("status_code", [501]) def test_error_no_retry(status_code, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -887,10 +913,12 @@ def test_trino_result_response_headers(): headers associated to the TrinoQuery instance provided to the `TrinoResult` class. """ - mock_trino_query = mock.Mock(respone_headers={ - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', - }) + mock_trino_query = mock.Mock( + respone_headers={ + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", + } + ) result = TrinoResult( query=mock_trino_query, @@ -910,8 +938,8 @@ class MockResponse(mock.Mock): @property def headers(self): return { - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", } def json(self): @@ -930,18 +958,13 @@ def json(self): http_scheme="http", ) - sql = 'execute my_stament using 1, 2, 3' - additional_headers = { - constants.HEADER_PREPARED_STATEMENT: 'my_statement=added_prepare_statement_header' - } + sql = "execute my_stament using 1, 2, 3" + additional_headers = {constants.HEADER_PREPARED_STATEMENT: "my_statement=added_prepare_statement_header"} # Patch the post function to avoid making the requests, as well as to # validate that the function was called with the right arguments. - with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post: - query = TrinoQuery( - request=req, - sql=sql - ) + with mock.patch.object(req, "post", return_value=MockResponse()) as mock_post: + query = TrinoQuery(request=req, sql=sql) result = query.execute(additional_http_headers=additional_headers) # Validate the the post function was called with the right argguments diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 7b1c72c2..7a9a46c0 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -17,8 +17,16 @@ from httpretty import httprettified from requests import Session -from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \ - GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import OAuth2Authentication from trino.dbapi import connect @@ -64,22 +72,20 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") @@ -87,18 +93,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") @@ -121,42 +124,27 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) - with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) as conn: + with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") conn.cursor().execute("SELECT 3") # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) - with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) as conn2: + with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") conn2.cursor().execute("SELECT 3") @@ -179,25 +167,18 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) - conn = connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) + conn = connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) class RunningThread(threading.Thread): lock = threading.Lock() @@ -209,11 +190,7 @@ def run(self) -> None: with RunningThread.lock: conn.cursor().execute("SELECT 1") - threads = [ - RunningThread(), - RunningThread(), - RunningThread() - ] + threads = [RunningThread(), RunningThread(), RunningThread()] # run and join all threads for thread in threads: diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index fd3c5e2d..9753bab5 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -10,8 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.client import get_header_values, get_session_property_values from trino import constants +from trino.client import get_header_values, get_session_property_values def test_get_header_values(): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e97903cf..02adeb9d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.transaction import IsolationLevel import pytest +from trino.transaction import IsolationLevel + def test_isolation_level_levels() -> None: levels = { @@ -27,9 +28,7 @@ def test_isolation_level_levels() -> None: def test_isolation_level_values() -> None: - values = { - 0, 1, 2, 3, 4 - } + values = {0, 1, 2, 3, 4} assert IsolationLevel.values() == values diff --git a/trino/__init__.py b/trino/__init__.py index 4ff3e55b..42db0646 100644 --- a/trino/__init__.py +++ b/trino/__init__.py @@ -10,13 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import auth -from . import dbapi -from . import client -from . import constants -from . import exceptions -from . import logging +from . import auth, client, constants, dbapi, exceptions, logging -__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging'] +__all__ = ["auth", "dbapi", "client", "constants", "exceptions", "logging"] __version__ = "0.315.0" diff --git a/trino/auth.py b/trino/auth.py index e6b4f04c..0ec15366 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -11,18 +11,18 @@ # limitations under the License. import abc +import importlib import json import os import re import threading import webbrowser -from typing import Optional, List, Callable +from typing import Callable, List, Optional from urllib.parse import urlparse from requests import Request from requests.auth import AuthBase, extract_cookies_to_jar from requests.utils import parse_dict_header -import importlib import trino.logging from trino.client import exceptions @@ -95,15 +95,17 @@ def get_exceptions(self): def __eq__(self, other): if not isinstance(other, KerberosAuthentication): return False - return (self._config == other._config - and self._service_name == other._service_name - and self._mutual_authentication == other._mutual_authentication - and self._force_preemptive == other._force_preemptive - and self._hostname_override == other._hostname_override - and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response - and self._principal == other._principal - and self._delegate == other._delegate - and self._ca_bundle == other._ca_bundle) + return ( + self._config == other._config + and self._service_name == other._service_name + and self._mutual_authentication == other._mutual_authentication + and self._force_preemptive == other._force_preemptive + and self._hostname_override == other._hostname_override + and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response + and self._principal == other._principal + and self._delegate == other._delegate + and self._ca_bundle == other._ca_bundle + ) class BasicAuthentication(Authentication): @@ -143,7 +145,6 @@ def __call__(self, r): class JWTAuthentication(Authentication): - def __init__(self, token): self.token = token @@ -251,24 +252,29 @@ def get_token_from_cache(self, host: str) -> Optional[str]: try: return self._keyring.get_password(host, "token") except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." + ) from e def store_token_to_cache(self, host: str, token: str) -> None: try: # keyring is installed, so we can store the token for reuse within multiple threads self._keyring.set_password(host, "token", token) except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." + ) from e class _OAuth2TokenBearer(AuthBase): """ Custom implementation of Trino Oauth2 based authorization to get the token """ + MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) @@ -285,9 +291,9 @@ def __call__(self, r): token = self._get_token_from_cache(host) if token is not None: - r.headers['Authorization'] = "Bearer " + token + r.headers["Authorization"] = "Bearer " + token - r.register_hook('response', self._authenticate) + r.register_hook("response", self._authenticate) return r @@ -312,7 +318,7 @@ def _authenticate(self, response, **kwargs): def _attempt_oauth(self, response, **kwargs): # we have to handle the authentication, may be token the token expired, or it wasn't there at all - auth_info = response.headers.get('WWW-Authenticate') + auth_info = response.headers.get("WWW-Authenticate") if not auth_info: raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.") @@ -321,11 +327,11 @@ def _attempt_oauth(self, response, **kwargs): auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) - auth_server = auth_info_headers.get('x_redirect_server') + auth_server = auth_info_headers.get("x_redirect_server") if auth_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server") - token_server = auth_info_headers.get('x_token_server') + token_server = auth_info_headers.get("x_token_server") if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") @@ -349,7 +355,7 @@ def _retry_request(self, response, **kwargs): request.prepare_cookies(request._cookies) host = self._determine_host(response.request.url) - request.headers['Authorization'] = "Bearer " + self._get_token_from_cache(host) + request.headers["Authorization"] = "Bearer " + self._get_token_from_cache(host) retry_response = response.connection.send(request, **kwargs) retry_response.history.append(response) retry_response.request = request @@ -359,23 +365,24 @@ def _get_token(self, token_server, response, **kwargs): attempts = 0 while attempts < self.MAX_OAUTH_ATTEMPTS: attempts += 1 - with response.connection.send(Request(method='GET', url=token_server).prepare(), **kwargs) as response: + with response.connection.send(Request(method="GET", url=token_server).prepare(), **kwargs) as response: if response.status_code == 200: token_response = json.loads(response.text) - token = token_response.get('token') + token = token_response.get("token") if token: return token - error = token_response.get('error') + error = token_response.get("error") if error: raise exceptions.TrinoAuthError(f"Error while getting the token: {error}") else: - token_server = token_response.get('nextUri') + token_server = token_response.get("nextUri") logger.debug(f"nextURi auth token server: {token_server}") else: raise exceptions.TrinoAuthError( f"Error while getting the token response " f"status code: {response.status_code}, " - f"body: {response.text}") + f"body: {response.text}" + ) raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") @@ -393,10 +400,10 @@ def _determine_host(url) -> Optional[str]: class OAuth2Authentication(Authentication): - def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([ - WebBrowserRedirectHandler(), - ConsoleRedirectHandler() - ])): + def __init__( + self, + redirect_auth_url_handler=CompositeRedirectHandler([WebBrowserRedirectHandler(), ConsoleRedirectHandler()]), + ): self._redirect_auth_url = redirect_auth_url_handler self._bearer = _OAuth2TokenBearer(self._redirect_auth_url) diff --git a/trino/client.py b/trino/client.py index 211ce9c0..49364ad5 100644 --- a/trino/client.py +++ b/trino/client.py @@ -62,7 +62,7 @@ else: PROXIES = {} -_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$") INF = float("inf") NEGATIVE_INF = float("-inf") @@ -213,10 +213,7 @@ def get_header_values(headers, header): def get_session_property_values(headers, header): kvs = get_header_values(headers, header) - return [ - (k.strip(), urllib.parse.unquote(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs) - ] + return [(k.strip(), urllib.parse.unquote(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs)] class TrinoStatus(object): @@ -245,16 +242,14 @@ def __repr__(self): class _DelayExponential(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600): # 100ms # 2 hours self._base = base self._exponent = exponent self._jitter = jitter self._max_delay = max_delay def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) + delay = float(self._base) * (self._exponent**attempt) if self._jitter: delay *= random.random() delay = min(float(self._max_delay), delay) @@ -262,9 +257,7 @@ def __call__(self, attempt): class _RetryWithExponentialBackoff(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600): # 100ms # 2 hours self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) def retry(self, func, args, kwargs, err, attempt): @@ -401,8 +394,7 @@ def http_headers(self) -> Dict[str, str]: transaction_id = self._client_session.transaction_id headers[constants.HEADER_TRANSACTION] = transaction_id - if self._client_session.extra_credential is not None and \ - len(self._client_session.extra_credential) > 0: + if self._client_session.extra_credential is not None and len(self._client_session.extra_credential) > 0: for tup in self._client_session.extra_credential: self._verify_extra_credential(tup) @@ -410,9 +402,9 @@ def http_headers(self) -> Dict[str, str]: # HTTP 1.1 section 4.2 combine multiple extra credentials into a # comma-separated value # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) - headers[constants.HEADER_EXTRA_CREDENTIAL] = \ - ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join( + [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential] + ) return headers @@ -536,15 +528,11 @@ def process(self, http_response) -> TrinoStatus: raise self._process_error(response["error"], response.get("id")) if constants.HEADER_CLEAR_SESSION in http_response.headers: - for prop in get_header_values( - http_response.headers, constants.HEADER_CLEAR_SESSION - ): + for prop in get_header_values(http_response.headers, constants.HEADER_CLEAR_SESSION): self._client_session.properties.pop(prop, None) if constants.HEADER_SET_SESSION in http_response.headers: - for key, value in get_session_property_values( - http_response.headers, constants.HEADER_SET_SESSION - ): + for key, value in get_session_property_values(http_response.headers, constants.HEADER_SET_SESSION): self._client_session.properties[key] = value if constants.HEADER_SET_CATALOG in http_response.headers: @@ -579,7 +567,7 @@ def _verify_extra_credential(self, header): raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'") try: - key.encode().decode('ascii') + key.encode().decode("ascii") except UnicodeDecodeError: raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") @@ -640,9 +628,7 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any: try: if isinstance(value, list): if raw_type == "array": - raw_type = { - "typeSignature": data_type["typeSignature"]["arguments"][0]["value"] - } + raw_type = {"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]} return [cls._map_to_python_type((array_item, raw_type)) for array_item in value] if raw_type == "row": raw_types = map(lambda arg: arg["value"], data_type["typeSignature"]["arguments"]) @@ -652,41 +638,36 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any: ) return value if isinstance(value, dict): - raw_key_type = { - "typeSignature": data_type["typeSignature"]["arguments"][0]["value"] - } - raw_value_type = { - "typeSignature": data_type["typeSignature"]["arguments"][1]["value"] - } + raw_key_type = {"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]} + raw_value_type = {"typeSignature": data_type["typeSignature"]["arguments"][1]["value"]} return { - cls._map_to_python_type((key, raw_key_type)): - cls._map_to_python_type((value[key], raw_value_type)) + cls._map_to_python_type((key, raw_key_type)): cls._map_to_python_type((value[key], raw_value_type)) for key in value } elif "decimal" in raw_type: return Decimal(value) elif raw_type == "double": - if value == 'Infinity': + if value == "Infinity": return INF - elif value == '-Infinity': + elif value == "-Infinity": return NEGATIVE_INF - elif value == 'NaN': + elif value == "NaN": return NAN return value elif raw_type == "date": return datetime.strptime(value, "%Y-%m-%d").date() elif raw_type == "timestamp with time zone": - dt, tz = value.rsplit(' ', 1) - if tz.startswith('+') or tz.startswith('-'): + dt, tz = value.rsplit(" ", 1) + if tz.startswith("+") or tz.startswith("-"): return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z") return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz)) elif "timestamp" in raw_type: return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f") elif "time with time zone" in raw_type: - matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value) + matches = re.match(r"^(.*)([\+\-])(\d{2}):(\d{2})$", value) assert matches is not None assert len(matches.groups()) == 4 - if matches.group(2) == '-': + if matches.group(2) == "-": tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4))) else: tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4))) @@ -707,10 +688,10 @@ class TrinoQuery(object): """Represent the execution of a SQL statement by Trino.""" def __init__( - self, - request: TrinoRequest, - sql: str, - experimental_python_types: bool = False, + self, + request: TrinoRequest, + sql: str, + experimental_python_types: bool = False, ) -> None: self.query_id: Optional[str] = None @@ -814,6 +795,7 @@ def cancel(self) -> None: def is_finished(self) -> bool: import warnings + warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning) return self.finished diff --git a/trino/constants.py b/trino/constants.py index 30046908..41104fdb 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -12,7 +12,6 @@ from typing import Any, Optional - DEFAULT_PORT = 8080 DEFAULT_TLS_PORT = 443 DEFAULT_SOURCE = "trino-python-client" @@ -45,9 +44,9 @@ HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id" HEADER_TRANSACTION = "X-Trino-Transaction-Id" -HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement' -HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare' -HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare' +HEADER_PREPARED_STATEMENT = "X-Trino-Prepared-Statement" +HEADER_ADDED_PREPARE = "X-Trino-Added-Prepare" +HEADER_DEALLOCATED_PREPARE = "X-Trino-Deallocated-Prepare" HEADER_SET_SCHEMA = "X-Trino-Set-Schema" HEADER_SET_CATALOG = "X-Trino-Set-Catalog" diff --git a/trino/dbapi.py b/trino/dbapi.py index 44813168..3553d03a 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -17,31 +17,30 @@ Fetch methods returns rows as a list of lists on purpose to let the caller decide to convert then to a list of tuples. """ -from decimal import Decimal -from typing import Any, List, Optional # NOQA for mypy types - import copy -import uuid import datetime import math +import uuid +from decimal import Decimal +from typing import Any, List, Optional # NOQA for mypy types -from trino import constants -import trino.exceptions import trino.client +import trino.exceptions import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION +from trino import constants from trino.exceptions import ( - Warning, - Error, - InterfaceError, DatabaseError, DataError, - OperationalError, + Error, IntegrityError, + InterfaceError, InternalError, - ProgrammingError, NotSupportedError, + OperationalError, + ProgrammingError, + Warning, ) +from trino.transaction import NO_TRANSACTION, IsolationLevel, Transaction __all__ = [ # https://www.python.org/dev/peps/pep-0249/#globals @@ -87,7 +86,6 @@ class Connection(object): a sequence of SQL statements. A single query i.e. the execution of a SQL statement, can also be cancelled. Transactions are not supported by this client implementation yet. - """ def __init__( @@ -128,7 +126,7 @@ def __init__( headers=http_headers, transaction_id=NO_TRANSACTION, extra_credential=extra_credential, - client_tags=client_tags + client_tags=client_tags, ) # mypy cannot follow module import if http_session is None: @@ -217,7 +215,7 @@ def cursor(self, experimental_python_types: bool = None): self, request, # if experimental_python_types is not explicitly set in Cursor, take from Connection - experimental_python_types if experimental_python_types is not None else self.experimental_python_types + experimental_python_types if experimental_python_types is not None else self.experimental_python_types, ) @@ -231,9 +229,7 @@ class Cursor(object): def __init__(self, connection, request, experimental_python_types: bool = False): if not isinstance(connection, Connection): - raise ValueError( - "connection must be a Connection object: {}".format(type(connection)) - ) + raise ValueError("connection must be a Connection object: {}".format(type(connection))) self._connection = connection self._request = request @@ -267,10 +263,7 @@ def description(self): return None # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ] - return [ - (col["name"], col["type"], None, None, None, None, None) - for col in self._query.columns - ] + return [(col["name"], col["type"], None, None, None, None, None) for col in self._query.columns] @property def rowcount(self): @@ -317,15 +310,15 @@ def _prepare_statement(self, operation, statement_name): :return: string representing the value of the 'X-Trino-Added-Prepare' header. """ - sql = 'PREPARE {statement_name} FROM {operation}'.format( - statement_name=statement_name, - operation=operation - ) + sql = "PREPARE {statement_name} FROM {operation}".format(statement_name=statement_name, operation=operation) # Send prepare statement. Copy the _request object to avoid poluting the # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, - experimental_python_types=self._experimental_pyton_types) + query = trino.client.TrinoQuery( + copy.deepcopy(self._request), + sql=sql, + experimental_python_types=self._experimental_pyton_types, + ) result = query.execute() # Iterate until the 'X-Trino-Added-Prepare' header is found or @@ -338,12 +331,8 @@ def _prepare_statement(self, operation, statement_name): raise trino.exceptions.FailedToObtainAddedPrepareHeader - def _get_added_prepare_statement_trino_query( - self, - statement_name, - params - ): - sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) + def _get_added_prepare_statement_trino_query(self, statement_name, params): + sql = "EXECUTE " + statement_name + " USING " + ",".join(map(self._format_prepared_param, params)) # No need to deepcopy _request here because this is the actual request # operation @@ -374,7 +363,7 @@ def _format_prepared_param(self, param): return "DOUBLE '%s'" % param if isinstance(param, str): - return ("'%s'" % param.replace("'", "''")) + return "'%s'" % param.replace("'", "''") if isinstance(param, bytes): return "X'%s'" % param.hex() @@ -386,7 +375,7 @@ def _format_prepared_param(self, param): if isinstance(param, datetime.datetime) and param.tzinfo is not None: datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f") # named timezones - if hasattr(param.tzinfo, 'zone'): + if hasattr(param.tzinfo, "zone"): return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone) # offset-based timezones return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param)) @@ -401,18 +390,15 @@ def _format_prepared_param(self, param): return "DATE '%s'" % date_str if isinstance(param, list): - return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param)) + return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, tuple): - return "ROW(%s)" % ','.join(map(self._format_prepared_param, param)) + return "ROW(%s)" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, dict): keys = list(param.keys()) values = [param[key] for key in keys] - return "MAP({}, {})".format( - self._format_prepared_param(keys), - self._format_prepared_param(values) - ) + return "MAP({}, {})".format(self._format_prepared_param(keys), self._format_prepared_param(values)) if isinstance(param, uuid.UUID): return "UUID '%s'" % param @@ -423,17 +409,16 @@ def _format_prepared_param(self, param): raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param)) def _deallocate_prepare_statement(self, added_prepare_header, statement_name): - sql = 'DEALLOCATE PREPARE ' + statement_name + sql = "DEALLOCATE PREPARE " + statement_name # Send deallocate statement. Copy the _request object to avoid poluting the # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, - experimental_python_types=self._experimental_pyton_types) - result = query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } + query = trino.client.TrinoQuery( + copy.deepcopy(self._request), + sql=sql, + experimental_python_types=self._experimental_pyton_types, ) + result = query.execute(additional_http_headers={constants.HEADER_PREPARED_STATEMENT: added_prepare_header}) # Iterate until the 'X-Trino-Deallocated-Prepare' header is found or # until there are no more results @@ -446,31 +431,24 @@ def _deallocate_prepare_statement(self, added_prepare_header, statement_name): raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader def _generate_unique_statement_name(self): - return 'st_' + uuid.uuid4().hex.replace('-', '') + return "st_" + uuid.uuid4().hex.replace("-", "") def execute(self, operation, params=None): if params: assert isinstance(params, (list, tuple)), ( - 'params must be a list or tuple containing the query ' - 'parameter values' + "params must be a list or tuple containing the query parameter values" ) statement_name = self._generate_unique_statement_name() # Send prepare statement - added_prepare_header = self._prepare_statement( - operation, statement_name - ) + added_prepare_header = self._prepare_statement(operation, statement_name) try: # Send execute statement and assign the return value to `results` # as it will be returned by the function - self._query = self._get_added_prepare_statement_trino_query( - statement_name, params - ) + self._query = self._get_added_prepare_statement_trino_query(statement_name, params) result = self._query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } + additional_http_headers={constants.HEADER_PREPARED_STATEMENT: added_prepare_header} ) finally: # Send deallocate statement @@ -479,8 +457,11 @@ def execute(self, operation, params=None): self._deallocate_prepare_statement(added_prepare_header, statement_name) else: - self._query = trino.client.TrinoQuery(self._request, sql=operation, - experimental_python_types=self._experimental_pyton_types) + self._query = trino.client.TrinoQuery( + self._request, + sql=operation, + experimental_python_types=self._experimental_pyton_types, + ) result = self._query.execute() self._iterator = iter(result) return result @@ -570,9 +551,7 @@ def fetchall(self) -> List[List[Any]]: def cancel(self): if self._query is None: - raise trino.exceptions.OperationalError( - "Cancel query failed; no running query" - ) + raise trino.exceptions.OperationalError("Cancel query failed; no running query") self._query.cancel() def close(self): @@ -606,13 +585,9 @@ def __eq__(self, other): STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS") -BINARY = DBAPITypeObject( - "ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest" -) +BINARY = DBAPITypeObject("ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest") -NUMBER = DBAPITypeObject( - "BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL" -) +NUMBER = DBAPITypeObject("BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL") DATETIME = DBAPITypeObject( "DATE", diff --git a/trino/exceptions.py b/trino/exceptions.py index 86708fd0..f8eddba3 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -139,6 +139,7 @@ class FailedToObtainAddedPrepareHeader(Error): Raise this exception when unable to find the 'X-Trino-Added-Prepare' header in the response of a PREPARE statement request. """ + pass @@ -147,6 +148,7 @@ class FailedToObtainDeallocatedPrepareHeader(Error): Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' header in the response of a DEALLOCATED statement request. """ + pass diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index a085fbf3..fbd151fc 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -10,15 +10,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.sql import compiler + try: - from sqlalchemy.sql.expression import ( - Alias, - CTE, - Subquery, - ) + from sqlalchemy.sql.expression import CTE, Alias, Subquery except ImportError: # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist from sqlalchemy.sql.expression import Alias + CTE = type(None) Subquery = type(None) @@ -113,11 +111,17 @@ def limit_clause(self, select, **kw): text += "\nLIMIT " + self.process(select._limit_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): - sql = super(TrinoSQLCompiler, self).visit_table( - table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs - ) + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs, + ): + sql = super(TrinoSQLCompiler, self).visit_table(table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs) return self.add_catalog(sql, table) @staticmethod @@ -128,13 +132,10 @@ def add_catalog(sql, table): if isinstance(table, (Alias, CTE, Subquery)): return sql - if ( - 'trino' not in table.dialect_options - or 'catalog' not in table.dialect_options['trino'] - ): + if "trino" not in table.dialect_options or "catalog" not in table.dialect_options["trino"]: return sql - catalog = table.dialect_options['trino']['catalog'] + catalog = table.dialect_options["trino"]["catalog"] sql = f'"{catalog}".{sql}' return sql diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 8284ba9c..d0880283 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from sqlalchemy import util from sqlalchemy.sql import sqltypes diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7c4409a0..3bf60c1f 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -19,7 +19,8 @@ from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.engine.url import URL -from trino import dbapi as trino_dbapi, logging +from trino import dbapi as trino_dbapi +from trino import logging from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication from trino.dbapi import Cursor from trino.sqlalchemy import compiler, datatype, error @@ -102,7 +103,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "cert" and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key']) + kwargs["auth"] = CertificateAuthentication(url.query["cert"], url.query["key"]) if "source" in url.query: kwargs["source"] = url.query["source"] @@ -243,7 +244,7 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non partition_index = dict( name="partition", column_names=[col["name"] for col in partitioned_columns], - unique=False + unique=False, ) return [partition_index] @@ -282,13 +283,13 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str try: res = connection.execute( sql.text(query), - catalog_name=catalog_name, schema_name=schema_name, table_name=table_name + catalog_name=catalog_name, + schema_name=schema_name, + table_name=table_name, ) return dict(text=res.scalar()) except error.TrinoQueryError as e: - if e.error_name in ( - error.PERMISSION_DENIED, - ): + if e.error_name in (error.PERMISSION_DENIED,): return dict(text=None) raise @@ -341,7 +342,11 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: return dbapi_connection.schema def do_execute( - self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None + self, + cursor: Cursor, + statement: str, + parameters: Tuple[Any, ...], + context: DefaultExecutionContext = None, ): cursor.execute(statement, parameters) if context and context.should_autocommit: diff --git a/trino/transaction.py b/trino/transaction.py index e6c85234..716c5953 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -12,11 +12,10 @@ from enum import Enum, unique from typing import Iterable -from trino import constants import trino.client import trino.exceptions import trino.logging - +from trino import constants logger = trino.logging.get_logger(__name__) @@ -66,9 +65,7 @@ def request(self): def begin(self): response = self._request.post(START_TRANSACTION) if not response.ok: - raise trino.exceptions.DatabaseError( - "failed to start transaction: {}".format(response.status_code) - ) + raise trino.exceptions.DatabaseError("failed to start transaction: {}".format(response.status_code)) transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION) if transaction_id and transaction_id != NO_TRANSACTION: self._id = response.headers[constants.HEADER_STARTED_TRANSACTION] @@ -87,9 +84,7 @@ def commit(self): try: list(query.execute()) except Exception as err: - raise trino.exceptions.DatabaseError( - "failed to commit transaction {}: {}".format(self._id, err) - ) + raise trino.exceptions.DatabaseError("failed to commit transaction {}: {}".format(self._id, err)) self._id = NO_TRANSACTION self._request.transaction_id = self._id @@ -98,8 +93,6 @@ def rollback(self): try: list(query.execute()) except Exception as err: - raise trino.exceptions.DatabaseError( - "failed to rollback transaction {}: {}".format(self._id, err) - ) + raise trino.exceptions.DatabaseError("failed to rollback transaction {}: {}".format(self._id, err)) self._id = NO_TRANSACTION self._request.transaction_id = self._id