From 073acb752c541c242a3398ca44080610ab23f1f2 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Fri, 19 Aug 2022 14:38:16 +0200 Subject: [PATCH] Added black (code formatting) and isort (sorting imports) --- .flake8 | 11 + .pre-commit-config.yaml | 13 + README.md | 12 + setup.py | 7 +- tests/integration/conftest.py | 26 +- tests/integration/test_dbapi_integration.py | 280 +++++++++--------- .../test_sqlalchemy_integration.py | 206 +++++++------ tests/unit/conftest.py | 6 +- tests/unit/oauth_test_utils.py | 50 +++- tests/unit/sqlalchemy/conftest.py | 2 +- tests/unit/sqlalchemy/test_compiler.py | 51 ++-- tests/unit/sqlalchemy/test_datatype_parse.py | 24 +- tests/unit/sqlalchemy/test_dialect.py | 67 +++-- tests/unit/test_client.py | 228 ++++++++------ tests/unit/test_dbapi.py | 97 +++--- tests/unit/test_http.py | 2 +- tests/unit/test_transaction.py | 7 +- trino/__init__.py | 9 +- trino/auth.py | 89 +++--- trino/client.py | 91 +++--- trino/constants.py | 7 +- trino/dbapi.py | 121 ++++---- trino/exceptions.py | 2 + trino/sqlalchemy/compiler.py | 27 +- trino/sqlalchemy/datatype.py | 4 +- trino/sqlalchemy/dialect.py | 61 ++-- trino/transaction.py | 3 +- 27 files changed, 831 insertions(+), 672 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..21c33d71 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +select = + E + W + F +ignore = + W503 # makes Flake8 work like black + W504 + E203 # makes Flake8 work like black + E741 + E501 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32bf21cc..b815fc8a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,3 +13,16 @@ repos: additional_dependencies: - "types-pytz" - "types-requests" + + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + args: + - "--line-length=99" + + - repo: https://github.com/pycqa/isort + rev: 5.6.4 + hooks: + - id: isort + args: [ "--profile", "black", "--filter-files" ] diff --git a/README.md b/README.md index d45c6554..98b01322 100644 --- a/README.md +++ b/README.md @@ -424,6 +424,7 @@ We recommend that you use Python3's `venv` for development: $ python3 -m venv .venv $ . .venv/bin/activate $ pip install -e '.[tests]' +$ pre-commit install ``` With `-e` passed to `pip install` above pip can reference the code you are @@ -441,6 +442,17 @@ When the code is ready, submit a Pull Request. See also Trino's [guidelines](https://github.com/trinodb/trino/blob/master/.github/DEVELOPMENT.md). Most of them also apply to code in trino-python-client. +### `pre-commit` checks + +Code is automatically checked on commit by a [pre-commit](https://pre-commit.com/) git hook. + +Following checks are performed: + +- [`flake8`](https://flake8.pycqa.org/en/latest/) for code linting +- [`black`](https://github.com/psf/black) for code formatting +- [`isort`](https://pycqa.github.io/isort/) for sorting imports +- [`mypy`](https://mypy.readthedocs.io/en/stable/) for static type checking + ### Running tests `trino-python-client` uses [pytest](https://pytest.org/) for its tests. To run diff --git a/setup.py b/setup.py index 2695a530..2ee59cb2 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ import re import textwrap -from setuptools import setup, find_packages +from setuptools import find_packages, setup _version_re = re.compile(r"__version__\s+=\s+(.*)") @@ -39,6 +39,9 @@ "pytest", "pytest-runner", "click", + "pre-commit", + "black", + "isort", ] setup( @@ -75,7 +78,7 @@ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Database :: Front-Ends", ], - python_requires='>=3.7', + python_requires=">=3.7", install_requires=["pytz", "requests"], extras_require={ "all": all_require, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 840ee8f3..0d3ea56e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,11 +18,11 @@ from uuid import uuid4 import click -import trino.logging import pytest -from trino.client import TrinoQuery, TrinoRequest, ClientSession -from trino.constants import DEFAULT_PORT +import trino.logging +from trino.client import ClientSession, TrinoQuery, TrinoRequest +from trino.constants import DEFAULT_PORT logger = trino.logging.get_logger(__name__) @@ -64,13 +64,7 @@ def start_trino(image_tag=None): def wait_for_trino_workers(host, port, timeout=180): - request = TrinoRequest( - host=host, - port=port, - client_session=ClientSession( - user="test_fixture" - ) - ) + request = TrinoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture")) sql = "SELECT state FROM system.runtime.nodes" t0 = time.time() while True: @@ -116,9 +110,7 @@ def start_trino_and_wait(image_tag=None): if host: port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT) else: - container_id, proc, host, port = start_local_trino_server( - image_tag - ) + container_id, proc, host, port = start_local_trino_server(image_tag) print("trino.server.hostname {}".format(host)) print("trino.server.port {}".format(port)) @@ -167,9 +159,7 @@ def cli(): pass -@click.option( - "--cache/--no-cache", default=True, help="enable/disable Docker build cache" -) +@click.option("--cache/--no-cache", default=True, help="enable/disable Docker build cache") @click.command() def trino_server(): container_id, _, _, _ = start_trino_and_wait() @@ -198,9 +188,7 @@ def trino_cli(container_id=None): @cli.command("list") def list_(): - subprocess.check_call( - ["docker", "ps", "--filter", "name=trino-python-client-tests-"] - ) + subprocess.check_call(["docker", "ps", "--filter", "name=trino-python-client-tests-"]) @cli.command() diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index c5d0dda0..99d21ff2 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from datetime import datetime, time, date, timezone, timedelta +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal import pytest @@ -19,7 +19,7 @@ import trino from tests.integration.conftest import trino_version -from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError +from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError from trino.transaction import IsolationLevel @@ -27,9 +27,7 @@ def trino_connection(run_trino): _, host, port = run_trino - yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 - ) + yield trino.dbapi.Connection(host=host, port=port, user="test", source="test", max_attempts=1) @pytest.fixture @@ -104,7 +102,7 @@ def test_select_query_result_iteration_statement_params(trino_connection): ) x (id, name, letter) WHERE id >= ? """, - params=(3,) # expecting all the rows with id >= 3 + params=(3,), # expecting all the rows with id >= 3 ) @@ -153,7 +151,9 @@ def test_execute_many_without_params(trino_connection): cur = trino_connection.cursor() cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)") cur.fetchall() - cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []) + cur.executemany( + "INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [] + ) with pytest.raises(TrinoUserError) as e: cur.fetchall() assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value) @@ -166,23 +166,22 @@ def test_execute_many_select(trino_connection): assert "Query must return update type" in str(e.value) -@pytest.mark.parametrize("connection_experimental_python_types,cursor_experimental_python_types,expected", - [ - (None, None, False), - (None, False, False), - (None, True, True), - (False, None, False), - (False, False, False), - (False, True, True), - (True, None, True), - (True, False, False), - (True, True, True), - ]) +@pytest.mark.parametrize( + "connection_experimental_python_types,cursor_experimental_python_types,expected", + [ + (None, None, False), + (None, False, False), + (None, True, True), + (False, None, False), + (False, False, False), + (False, True, True), + (True, None, True), + (True, False, False), + (True, True, True), + ], +) def test_experimental_python_types_with_connection_and_cursor( - connection_experimental_python_types, - cursor_experimental_python_types, - expected, - run_trino + connection_experimental_python_types, cursor_experimental_python_types, expected, run_trino ): _, host, port = run_trino @@ -195,7 +194,8 @@ def test_experimental_python_types_with_connection_and_cursor( cur = connection.cursor(experimental_python_types=cursor_experimental_python_types) - cur.execute(""" + cur.execute( + """ SELECT DECIMAL '0.142857', DATE '2018-01-01', @@ -203,35 +203,36 @@ def test_experimental_python_types_with_connection_and_cursor( TIMESTAMP '2019-01-01 00:00:00.000 UTC', TIMESTAMP '2019-01-01 00:00:00.000', TIME '00:00:00.000' - """) + """ + ) rows = cur.fetchall() if expected: - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal("0.142857") assert rows[0][1] == date(2018, 1, 1) assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1))) - assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC')) + assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone("UTC")) assert rows[0][4] == datetime(2019, 1, 1) assert rows[0][5] == time(0, 0, 0, 0) else: for value in rows[0]: assert isinstance(value, str) - assert rows[0][0] == '0.142857' - assert rows[0][1] == '2018-01-01' - assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00' - assert rows[0][3] == '2019-01-01 00:00:00.000 UTC' - assert rows[0][4] == '2019-01-01 00:00:00.000' - assert rows[0][5] == '00:00:00.000' + assert rows[0][0] == "0.142857" + assert rows[0][1] == "2018-01-01" + assert rows[0][2] == "2019-01-01 00:00:00.000 +01:00" + assert rows[0][3] == "2019-01-01 00:00:00.000 UTC" + assert rows[0][4] == "2019-01-01 00:00:00.000" + assert rows[0][5] == "00:00:00.000" def test_decimal_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - cur.execute("SELECT ?", params=(Decimal('0.142857'),)) + cur.execute("SELECT ?", params=(Decimal("0.142857"),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal("0.142857") def test_null_decimal(trino_connection): @@ -246,7 +247,7 @@ def test_null_decimal(trino_connection): def test_biggest_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('99999999999999999999999999999999999999') + params = Decimal("99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -256,7 +257,7 @@ def test_biggest_decimal(trino_connection): def test_smallest_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('-99999999999999999999999999999999999999') + params = Decimal("-99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -266,7 +267,7 @@ def test_smallest_decimal(trino_connection): def test_highest_precision_decimal(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = Decimal('0.99999999999999999999999999999999999999') + params = Decimal("0.99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -288,7 +289,7 @@ def test_datetime_query_param(trino_connection): def test_datetime_with_utc_time_zone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("UTC")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -314,7 +315,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): def test_datetime_with_named_time_zone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("America/Los_Angeles")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -347,14 +348,16 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection): cur.execute("SELECT TIMESTAMP '2001-08-22 03:04:05.321 -08:00'") rows = cur.fetchall() - assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z") + assert rows[0][0] == datetime.strptime( + "2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z" + ) def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) # This is a datetime that lies within a DST transition and not actually exists. - params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels')) + params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone("Europe/Brussels")) with pytest.raises(trino.exceptions.TrinoUserError): cur.execute("SELECT ?", params=(params,)) cur.fetchall() @@ -365,21 +368,21 @@ def test_doubled_datetimes(trino_connection): # See also https://github.com/trinodb/trino/issues/5781 cur = trino_connection.cursor(experimental_python_types=True) - params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True) + params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern")) cur = trino_connection.cursor(experimental_python_types=True) - params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False) + params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern")) def test_date_query_param(trino_connection): @@ -407,11 +410,11 @@ def test_unsupported_python_dates(trino_connection): # dates below python min (1-1-1) or above max date (9999-12-31) are not supported for unsupported_date in [ - '-0001-01-01', - '0000-01-01', - '10000-01-01', - '-4999999-01-01', # Trino min date - '5000000-12-31', # Trino max date + "-0001-01-01", + "0000-01-01", + "10000-01-01", + "-4999999-01-01", # Trino min date + "5000000-12-31", # Trino max date ]: with pytest.raises(trino.exceptions.TrinoDataError): cur.execute(f"SELECT DATE '{unsupported_date}'") @@ -422,26 +425,26 @@ def test_supported_special_dates_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) for params in ( - # min python date - date(1, 1, 1), - # before julian->gregorian switch - date(1500, 1, 1), - # During julian->gregorian switch - date(1752, 9, 4), - # before epoch - date(1952, 4, 3), - date(1970, 1, 1), - date(1970, 2, 3), - # summer on northern hemisphere (possible DST) - date(2017, 7, 1), - # winter on northern hemisphere (possible DST on southern hemisphere) - date(2017, 1, 1), - # winter on southern hemisphere (possible DST on northern hemisphere) - date(2017, 12, 31), - date(1983, 4, 1), - date(1983, 10, 1), - # max python date - date(9999, 12, 31), + # min python date + date(1, 1, 1), + # before julian->gregorian switch + date(1500, 1, 1), + # During julian->gregorian switch + date(1752, 9, 4), + # before epoch + date(1952, 4, 3), + date(1970, 1, 1), + date(1970, 2, 3), + # summer on northern hemisphere (possible DST) + date(2017, 7, 1), + # winter on northern hemisphere (possible DST on southern hemisphere) + date(2017, 1, 1), + # winter on southern hemisphere (possible DST on northern hemisphere) + date(2017, 12, 31), + date(1983, 4, 1), + date(1983, 10, 1), + # max python date + date(9999, 12, 31), ): cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -464,7 +467,7 @@ def test_time_with_named_time_zone_query_param(trino_connection): with pytest.raises(trino.exceptions.NotSupportedError): cur = trino_connection.cursor() - params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai')) + params = time(16, 43, 22, 320000, tzinfo=pytz.timezone("Asia/Shanghai")) cur.execute("SELECT ?", params=(params,)) @@ -599,7 +602,10 @@ def test_array_timestamp_query_param(trino_connection): def test_array_timestamp_with_timezone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)] + params = [ + datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), + datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc), + ] cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -723,15 +729,18 @@ def test_int_query_param(trino_connection): assert cur.description[0][1] == "bigint" -@pytest.mark.parametrize('params', [ - 'NOT A LIST OR TUPPLE', - {'invalid', 'params'}, - object, -]) +@pytest.mark.parametrize( + "params", + [ + "NOT A LIST OR TUPPLE", + {"invalid", "params"}, + object, + ], +) def test_select_query_invalid_params(trino_connection, params): cur = trino_connection.cursor() with pytest.raises(AssertionError): - cur.execute('SELECT ?', params=params) + cur.execute("SELECT ?", params=params) def test_select_cursor_iteration(trino_connection): @@ -877,8 +886,9 @@ def test_transaction_multiple(trino_connection_with_transaction): assert len(rows2) == 1000 -@pytest.mark.skipif(trino_version() == '351', reason="Autocommit behaves " - "differently in older Trino versions") +@pytest.mark.skipif( + trino_version() == "351", reason="Autocommit behaves " "differently in older Trino versions" +) def test_transaction_autocommit(trino_connection_in_autocommit): with trino_connection_in_autocommit as connection: connection.start_transaction() @@ -887,17 +897,18 @@ def test_transaction_autocommit(trino_connection_in_autocommit): """ CREATE TABLE memory.default.nation AS SELECT * from tpch.tiny.nation - """) + """ + ) with pytest.raises(TrinoUserError) as transaction_error: cur.fetchall() - assert "Catalog only supports writes using autocommit: memory" \ - in str(transaction_error.value) + assert "Catalog only supports writes using autocommit: memory" in str( + transaction_error.value + ) def test_invalid_query_throws_correct_error(trino_connection): - """Tests that an invalid query raises the correct exception - """ + """Tests that an invalid query raises the correct exception""" cur = trino_connection.cursor() with pytest.raises(TrinoQueryError): cur.execute( @@ -910,14 +921,14 @@ def test_invalid_query_throws_correct_error(trino_connection): def test_eager_loading_cursor_description(trino_connection): description_expected = [ - ('node_id', 'varchar', None, None, None, None, None), - ('http_uri', 'varchar', None, None, None, None, None), - ('node_version', 'varchar', None, None, None, None, None), - ('coordinator', 'boolean', None, None, None, None, None), - ('state', 'varchar', None, None, None, None, None), + ("node_id", "varchar", None, None, None, None, None), + ("http_uri", "varchar", None, None, None, None, None), + ("node_version", "varchar", None, None, None, None, None), + ("coordinator", "boolean", None, None, None, None, None), + ("state", "varchar", None, None, None, None, None), ] cur = trino_connection.cursor() - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") description_before = cur.description assert description_before is not None @@ -934,7 +945,7 @@ def test_eager_loading_cursor_description(trino_connection): def test_info_uri(trino_connection): cur = trino_connection.cursor() assert cur.info_uri is None - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") assert cur.info_uri is not None assert cur._query.query_id in cur.info_uri cur.fetchall() @@ -971,44 +982,51 @@ def retrieve_client_tags_from_query(run_trino, client_tags): ) cur = trino_connection.cursor() - cur.execute('SELECT 1') + cur.execute("SELECT 1") cur.fetchall() api_url = "http://" + trino_connection.host + ":" + str(trino_connection.port) - query_info = requests.post(api_url + "/ui/login", data={ - "username": "admin", - "password": "", - "redirectPath": api_url + '/ui/api/query/' + cur._query.query_id - }).json() - - query_client_tags = query_info['session']['clientTags'] + query_info = requests.post( + api_url + "/ui/login", + data={ + "username": "admin", + "password": "", + "redirectPath": api_url + "/ui/api/query/" + cur._query.query_id, + }, + ).json() + + query_client_tags = query_info["session"]["clientTags"] return query_client_tags -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif( + trino_version() == "351", reason="current_catalog not supported in older Trino versions" +) def test_use_catalog_schema(trino_connection): cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() assert result[0][0] is None assert result[0][1] is None - cur.execute('USE tpch.tiny') + cur.execute("USE tpch.tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE tpcds.sf1') + cur.execute("USE tpcds.sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpcds' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpcds" + assert result[0][1] == "sf1" -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif( + trino_version() == "351", reason="current_catalog not supported in older Trino versions" +) def test_use_catalog(run_trino): _, host, port = run_trino @@ -1016,35 +1034,33 @@ def test_use_catalog(run_trino): host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1 ) cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' + assert result[0][0] == "tpch" assert result[0][1] is None - cur.execute('USE tiny') + cur.execute("USE tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE sf1') + cur.execute("USE sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpch" + assert result[0][1] == "sf1" -@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role") +@pytest.mark.skipif(trino_version() == "351", reason="Newer Trino versions return the system role") def test_set_role_trino_higher_351(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert cur._request._client_session.role is None @@ -1053,15 +1069,15 @@ def test_set_role_trino_higher_351(run_trino): assert cur._request._client_session.role == "system=ALL" -@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog") +@pytest.mark.skipif( + trino_version() != "351", reason="Trino 351 returns the role for the current catalog" +) def test_set_role_trino_351(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert cur._request._client_session.role is None diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1dc8f05a..cdac58ab 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -11,22 +11,24 @@ # limitations under the License import pytest import sqlalchemy as sqla -from sqlalchemy.sql import and_, or_, not_ +from sqlalchemy.sql import and_, not_, or_ @pytest.fixture def trino_connection(run_trino, request): _, host, port = run_trino - engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", - connect_args={"source": "test", "max_attempts": 1}) + engine = sqla.create_engine( + f"trino://test@{host}:{port}/{request.param}", + connect_args={"source": "test", "max_attempts": 1}, + ) yield engine, engine.connect() -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_select_query(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) assert_column(nations, "nationkey", sqla.sql.sqltypes.BigInteger) assert_column(nations, "name", sqla.sql.sqltypes.String) assert_column(nations, "regionkey", sqla.sql.sqltypes.BigInteger) @@ -36,10 +38,10 @@ def test_select_query(trino_connection): rows = result.fetchall() assert len(rows) == 25 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["regionkey"], int) + assert isinstance(row["comment"], str) def assert_column(table, column_name, column_type): @@ -47,11 +49,11 @@ def assert_column(table, column_name, column_type): assert isinstance(getattr(table.c, column_name).type, column_type) -@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["system"], indirect=True) def test_select_specific_columns(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nodes = sqla.Table('nodes', metadata, schema='runtime', autoload_with=conn) + nodes = sqla.Table("nodes", metadata, schema="runtime", autoload_with=conn) assert_column(nodes, "node_id", sqla.sql.sqltypes.String) assert_column(nodes, "state", sqla.sql.sqltypes.String) query = sqla.select(nodes.c.node_id, nodes.c.state) @@ -59,26 +61,28 @@ def test_select_specific_columns(trino_connection): rows = result.fetchall() assert len(rows) > 0 for row in rows: - assert isinstance(row['node_id'], str) - assert isinstance(row['state'], str) + assert isinstance(row["node_id"], str) + assert isinstance(row["state"], str) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_define_and_create_table(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(engine, "test"): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) - assert sqla.inspect(engine).has_table('users', schema="test") - users = sqla.Table('users', metadata, schema='test', autoload_with=conn) + assert sqla.inspect(engine).has_table("users", schema="test") + users = sqla.Table("users", metadata, schema="test", autoload_with=conn) assert_column(users, "id", sqla.sql.sqltypes.Integer) assert_column(users, "name", sqla.sql.sqltypes.String) assert_column(users, "fullname", sqla.sql.sqltypes.String) @@ -86,7 +90,7 @@ def test_define_and_create_table(trino_connection): metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert(trino_connection): engine, conn = trino_connection @@ -94,12 +98,14 @@ def test_insert(trino_connection): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() conn.execute(ins, {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}) @@ -112,72 +118,79 @@ def test_insert(trino_connection): metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert_multiple_statements(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(engine, "test"): engine.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() - conn.execute(ins, [ - {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, - {"id": 3, "name": "john", "fullname": "John Doe"}, - {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, - ]) + conn.execute( + ins, + [ + {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, + {"id": 3, "name": "john", "fullname": "John Doe"}, + {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, + ], + ) query = sqla.select(users) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 3 - assert frozenset(rows) == frozenset([ - (2, "wendy", "Wendy Williams"), - (3, "john", "John Doe"), - (4, "mary", "Mary Hopkins"), - ]) + assert frozenset(rows) == frozenset( + [ + (2, "wendy", "Wendy Williams"), + (3, "john", "John Doe"), + (4, "mary", "Mary Hopkins"), + ] + ) metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_operators(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + customers = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) query = sqla.select(customers).where(customers.c.nationkey == 2) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["regionkey"], int) + assert isinstance(row["comment"], str) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_conjunctions(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - query = sqla.select(customers).where(and_( - customers.c.name.like('%12%'), - customers.c.nationkey == 15, - or_( - customers.c.mktsegment == 'AUTOMOBILE', - customers.c.mktsegment == 'HOUSEHOLD' - ), - not_(customers.c.acctbal < 0))) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + query = sqla.select(customers).where( + and_( + customers.c.name.like("%12%"), + customers.c.nationkey == 15, + or_(customers.c.mktsegment == "AUTOMOBILE", customers.c.mktsegment == "HOUSEHOLD"), + not_(customers.c.acctbal < 0), + ) + ) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_textual_sql(trino_connection): _, conn = trino_connection s = sqla.text("SELECT * from tiny.customer where nationkey = :e1 AND acctbal < :e2") @@ -185,70 +198,79 @@ def test_textual_sql(trino_connection): rows = result.fetchall() assert len(rows) == 3 for row in rows: - assert isinstance(row['custkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['address'], str) - assert isinstance(row['nationkey'], int) - assert isinstance(row['phone'], str) - assert isinstance(row['acctbal'], float) - assert isinstance(row['mktsegment'], str) - assert isinstance(row['comment'], str) + assert isinstance(row["custkey"], int) + assert isinstance(row["name"], str) + assert isinstance(row["address"], str) + assert isinstance(row["nationkey"], int) + assert isinstance(row["phone"], str) + assert isinstance(row["acctbal"], float) + assert isinstance(row["mktsegment"], str) + assert isinstance(row["comment"], str) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_alias(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) nations1 = nations.alias("o1") nations2 = nations.alias("o2") - s = sqla.select(nations1) \ - .join(nations2, and_( - nations1.c.regionkey == nations2.c.regionkey, - nations1.c.nationkey != nations2.c.nationkey, - nations1.c.regionkey == 1 - )) \ + s = ( + sqla.select(nations1) + .join( + nations2, + and_( + nations1.c.regionkey == nations2.c.regionkey, + nations1.c.nationkey != nations2.c.nationkey, + nations1.c.regionkey == 1, + ), + ) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 5 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_subquery(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_subquery = automobile_customers.subquery() - s = sqla.select(nations.c.name).where(nations.c.nationkey.in_(sqla.select(automobile_customers_subquery))) + s = sqla.select(nations.c.name).where( + nations.c.nationkey.in_(sqla.select(automobile_customers_subquery)) + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 15 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_joins(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - s = sqla.select(nations.c.name) \ - .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) \ - .where(customers.c.acctbal < -900) \ + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + s = ( + sqla.select(nations.c.name) + .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) + .where(customers.c.acctbal < -900) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 15 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_cte(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_cte = automobile_customers.cte() s = sqla.select(nations).where(nations.c.nationkey.in_(sqla.select(automobile_customers_cte))) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9d968e07..59e312d4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import MagicMock, patch +import pytest + @pytest.fixture(scope="session") def sample_post_response_data(): @@ -194,8 +195,7 @@ def sample_get_error_response_data(): "errorType": "USER_ERROR", "failureInfo": { "errorLocation": {"columnNumber": 15, "lineNumber": 1}, - "message": "line 1:15: Schema must be specified " - "when session schema is not set", + "message": "line 1:15: Schema must be specified " "when session schema is not set", "stack": [ "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py index 77b891ce..1467a489 100644 --- a/tests/unit/oauth_test_utils.py +++ b/tests/unit/oauth_test_utils.py @@ -45,9 +45,15 @@ def __call__(self, request, uri, response_headers): authorization = request.headers.get("Authorization") if authorization and authorization.replace("Bearer ", "") in self.tokens: return [200, response_headers, json.dumps(self.sample_post_response_data)] - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": f'Bearer x_redirect_server="{self.redirect_server}", ' + f'x_token_server="{self.token_server}"', + "Basic realm": '"Trino"', + }, + "", + ] class GetTokenCallback: @@ -66,19 +72,25 @@ def __call__(self, request, uri, response_headers): def _get_token_requests(challenge_id): - return list(filter( - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", - httpretty.latest_requests())) + return list( + filter( + lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", + httpretty.latest_requests(), + ) + ) def _post_statement_requests(): - return list(filter( - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, - httpretty.latest_requests())) + return list( + filter( + lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, + httpretty.latest_requests(), + ) + ) class MultithreadedTokenServer: - Challenge = namedtuple('Challenge', ['token', 'attempts']) + Challenge = namedtuple("Challenge", ["token", "attempts"]) def __init__(self, sample_post_response_data, attempts=1): self.tokens = set() @@ -90,13 +102,15 @@ def __init__(self, sample_post_response_data, attempts=1): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=self.post_statement_callback) + body=self.post_statement_callback, + ) # bind get token httpretty.register_uri( method=httpretty.GET, uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), - body=self.get_token_callback) + body=self.get_token_callback, + ) # noinspection PyUnusedLocal def post_statement_callback(self, request, uri, response_headers): @@ -111,9 +125,15 @@ def post_statement_callback(self, request, uri, response_headers): self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' - f'x_token_server="{token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": f'Bearer x_redirect_server="{redirect_server}", ' + f'x_token_server="{token_server}"', + "Basic realm": '"Trino"', + }, + "", + ] # noinspection PyUnusedLocal def get_token_callback(self, request, uri, response_headers): diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py index e80f19b8..71d6f74d 100644 --- a/tests/unit/sqlalchemy/conftest.py +++ b/tests/unit/sqlalchemy/conftest.py @@ -12,7 +12,7 @@ import pytest from sqlalchemy.sql.sqltypes import ARRAY -from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP, SQLType @pytest.fixture(scope="session") diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 00bd686c..6139affc 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -10,32 +10,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy import ( - Column, - insert, - Integer, - MetaData, - select, - String, - Table, -) +from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select from sqlalchemy.schema import CreateTable from trino.sqlalchemy.dialect import TrinoDialect metadata = MetaData() table = Table( - 'table', + "table", metadata, - Column('id', Integer), - Column('name', String), + Column("id", Integer), + Column("name", String), ) table_with_catalog = Table( - 'table', - metadata, - Column('id', Integer), - schema='default', - trino_catalog='other' + "table", metadata, Column("id", Integer), schema="default", trino_catalog="other" ) @@ -47,7 +35,10 @@ def dialect(): def test_limit_offset(dialect): statement = select(table).limit(10).offset(0) query = statement.compile(dialect=dialect) - assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2' + assert ( + str(query) + == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2' + ) def test_limit(dialect): @@ -63,15 +54,16 @@ def test_offset(dialect): def test_cte_insert_order(dialect): - cte = select(table).cte('cte') + cte = select(table).cte("cte") statement = insert(table).from_select(table.columns, cte) query = statement.compile(dialect=dialect) - assert str(query) == \ - 'INSERT INTO "table" (id, name) WITH cte AS \n'\ - '(SELECT "table".id AS id, "table".name AS name \n'\ - 'FROM "table")\n'\ - ' SELECT cte.id, cte.name \n'\ - 'FROM cte' + assert ( + str(query) == 'INSERT INTO "table" (id, name) WITH cte AS \n' + '(SELECT "table".id AS id, "table".name AS name \n' + 'FROM "table")\n' + " SELECT cte.id, cte.name \n" + "FROM cte" + ) def test_catalogs_argument(dialect): @@ -83,9 +75,6 @@ def test_catalogs_argument(dialect): def test_catalogs_create_table(dialect): statement = CreateTable(table_with_catalog) query = statement.compile(dialect=dialect) - assert str(query) == \ - '\n'\ - 'CREATE TABLE "other".default."table" (\n'\ - '\tid INTEGER\n'\ - ')\n'\ - '\n' + assert ( + str(query) == "\n" 'CREATE TABLE "other".default."table" (\n' "\tid INTEGER\n" ")\n" "\n" + ) diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py index 66a2f6b0..90e007b9 100644 --- a/tests/unit/sqlalchemy/test_datatype_parse.py +++ b/tests/unit/sqlalchemy/test_datatype_parse.py @@ -10,23 +10,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy.sql.sqltypes import ( - CHAR, - VARCHAR, - ARRAY, - INTEGER, - DECIMAL, - DATE -) +from sqlalchemy.sql.sqltypes import ARRAY, CHAR, DATE, DECIMAL, INTEGER, VARCHAR from sqlalchemy.sql.type_api import TypeEngine from trino.sqlalchemy import datatype -from trino.sqlalchemy.datatype import ( - MAP, - ROW, - TIME, - TIMESTAMP -) +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP @pytest.mark.parametrize( @@ -68,7 +56,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype): "CHAR(10)": CHAR(10), "VARCHAR(10)": VARCHAR(10), "DECIMAL(20)": DECIMAL(20), - "DECIMAL(20, 3)": DECIMAL(20, 3) + "DECIMAL(20, 3)": DECIMAL(20, 3), } @@ -108,7 +96,9 @@ def test_parse_array(type_str: str, sql_type: ARRAY, assert_sqltype): "map(varchar(10), decimal(20,3))": MAP(VARCHAR(10), DECIMAL(20, 3)), "map(char, array(varchar(10)))": MAP(CHAR(), ARRAY(VARCHAR(10))), "map(varchar(10), array(varchar(10)))": MAP(VARCHAR(10), ARRAY(VARCHAR(10))), - "map(varchar(10), array(array(varchar(10))))": MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)), + "map(varchar(10), array(array(varchar(10))))": MAP( + VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2) + ), } @@ -187,7 +177,7 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype): "timestamp(3)": TIMESTAMP(3, timezone=False), "timestamp(6)": TIMESTAMP(6), "timestamp(12) with time zone": TIMESTAMP(12, timezone=True), - "timestamp with time zone": TIMESTAMP(timezone=True) + "timestamp with time zone": TIMESTAMP(timezone=True), } diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index b17f8cfe..4cff4c27 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -7,7 +7,11 @@ from trino.auth import BasicAuthentication from trino.dbapi import Connection -from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect +from trino.sqlalchemy.dialect import ( + CertificateAuthentication, + JWTAuthentication, + TrinoDialect, +) from trino.transaction import IsolationLevel @@ -26,7 +30,13 @@ def setup(self): ( make_url("trino://user@localhost:8080"), list(), - dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + source="trino-sqlalchemy", + ), ), ( make_url("trino://user:pass@localhost:8080?source=trino-rulez"), @@ -38,17 +48,18 @@ def setup(self): user="user", auth=BasicAuthentication("user", "pass"), http_scheme="https", - source="trino-rulez" + source="trino-rulez", ), ), ( make_url( - 'trino://user@localhost:8080?' + "trino://user@localhost:8080?" 'session_properties={"query_max_run_time": "1d"}' '&http_headers={"trino": 1}' '&extra_credential=[("a", "b"), ("c", "d")]' '&client_tags=[1, "sql"]' - '&experimental_python_types=true'), + "&experimental_python_types=true" + ), list(), dict( host="localhost", @@ -65,7 +76,9 @@ def setup(self): ), ], ) - def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): + def test_create_connect_args( + self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any] + ): actual_args, actual_kwargs = self.dialect.create_connect_args(url) assert actual_args == expected_args @@ -73,7 +86,9 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_ def test_create_connect_args_missing_user_when_specify_password(self): url = make_url("trino://:pass@localhost") - with pytest.raises(ValueError, match="Username is required when specify password in connection URL"): + with pytest.raises( + ValueError, match="Username is required when specify password in connection URL" + ): self.dialect.create_connect_args(url) def test_create_connect_args_wrong_db_format(self): @@ -97,36 +112,36 @@ def test_isolation_level(self): def test_trino_connection_basic_auth(): dialect = TrinoDialect() - username = 'trino-user' - password = 'trino-bunny' - url = make_url(f'trino://{username}:{password}@host') + username = "trino-user" + password = "trino-bunny" + url = make_url(f"trino://{username}:{password}@host") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], BasicAuthentication) - assert cparams['auth']._username == username - assert cparams['auth']._password == password + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], BasicAuthentication) + assert cparams["auth"]._username == username + assert cparams["auth"]._password == password def test_trino_connection_jwt_auth(): dialect = TrinoDialect() - access_token = 'sample-token' - url = make_url(f'trino://host/?access_token={access_token}') + access_token = "sample-token" + url = make_url(f"trino://host/?access_token={access_token}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], JWTAuthentication) - assert cparams['auth'].token == access_token + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], JWTAuthentication) + assert cparams["auth"].token == access_token def test_trino_connection_certificate_auth(): dialect = TrinoDialect() - cert = '/path/to/cert.pem' - key = '/path/to/key.pem' - url = make_url(f'trino://host/?cert={cert}&key={key}') + cert = "/path/to/cert.pem" + key = "/path/to/key.pem" + url = make_url(f"trino://host/?cert={cert}&key={key}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], CertificateAuthentication) - assert cparams['auth']._cert == cert - assert cparams['auth']._key == key + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], CertificateAuthentication) + assert cparams["auth"]._cert == cert + assert cparams["auth"]._key == key diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 12089305..ba30349b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -23,13 +23,28 @@ from requests_kerberos.exceptions import KerberosExchangeError import trino.exceptions -from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ - MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ - SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + MultithreadedTokenServer, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer -from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \ - _RetryWithExponentialBackoff +from trino.client import ( + ClientSession, + TrinoQuery, + TrinoRequest, + TrinoResult, + _DelayExponential, + _retry_with, + _RetryWithExponentialBackoff, +) @mock.patch("trino.client.TrinoRequest.http") @@ -80,7 +95,7 @@ def test_request_headers(mock_get_and_post): headers={ accept_encoding_header: accept_encoding_value, client_info_header: client_info_value, - } + }, ), http_scheme="http", redirect_handler=None, @@ -113,13 +128,8 @@ def test_request_session_properties_headers(mock_get_and_post): host="coordinator", port=8080, client_session=ClientSession( - user="test_user", - properties={ - "a": "1", - "b": "2", - "c": "more=v1,v2" - } - ) + user="test_user", properties={"a": "1", "b": "2", "c": "more=v1,v2"} + ), ) def assert_headers(headers): @@ -154,10 +164,10 @@ def test_additional_request_post_headers(mock_get_and_post): http_scheme="http", ) - sql = 'select 1' + sql = "select 1" additional_headers = { - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", } combined_headers = req.http_headers @@ -167,7 +177,7 @@ def test_additional_request_post_headers(mock_get_and_post): # Validate that the post call was performed including the addtional headers _, post_kwargs = post.call_args - assert post_kwargs['headers'] == combined_headers + assert post_kwargs["headers"] == combined_headers def test_request_invalid_http_headers(): @@ -189,10 +199,7 @@ def test_request_client_tags_headers(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - client_tags=["tag1", "tag2"] - ), + client_session=ClientSession(user="test_user", client_tags=["tag1", "tag2"]), ) def assert_headers(headers): @@ -215,7 +222,7 @@ def test_request_client_tags_headers_no_client_tags(mock_get_and_post): port=8080, client_session=ClientSession( user="test_user", - ) + ), ) def assert_headers(headers): @@ -335,20 +342,20 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data): redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token, attempts) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() @@ -359,10 +366,11 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) response = request.post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert redirect_handler.redirect_server == redirect_server assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -378,20 +386,22 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token, attempts) httpretty.register_uri( - method=httpretty.GET, - uri=f"{TOKEN_RESOURCE}/{challenge_id}", - body=get_token_callback) + method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", body=get_token_callback + ) redirect_handler = RedirectHandler() @@ -402,7 +412,8 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -413,21 +424,34 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): assert len(_get_token_requests(challenge_id)) == _OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS -@pytest.mark.parametrize("header,error", [ - ("", "Error: header WWW-Authenticate not available in the response."), - ('Bearer"', 'Error: header info didn\'t have x_redirect_server'), - ('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501 - ('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'), - ('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'), -]) +@pytest.mark.parametrize( + "header,error", + [ + ("", "Error: header WWW-Authenticate not available in the response."), + ('Bearer"', "Error: header info didn't have x_redirect_server"), + ( + 'x_redirect_server="redirect_server", x_token_server="token_server"', + 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"', + ), # noqa: E501 + ( + 'Bearer x_redirect_server="redirect_server"', + "Error: header info didn't have x_token_server", + ), + ( + 'Bearer x_token_server="token_server"', + "Error: header info didn't have x_redirect_server", + ), + ], +) @httprettified def test_oauth2_authentication_missing_headers(header, error): # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - adding_headers={'WWW-Authenticate': header}, - status=401) + adding_headers={"WWW-Authenticate": header}, + status=401, + ) request = TrinoRequest( host="coordinator", @@ -436,7 +460,8 @@ def test_oauth2_authentication_missing_headers(header, error): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler())) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -444,13 +469,16 @@ def test_oauth2_authentication_missing_headers(header, error): assert str(exp.value) == error -@pytest.mark.parametrize("header", [ - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge', - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"', - 'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"', - 'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"', - 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', -]) +@pytest.mark.parametrize( + "header", + [ + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge', + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"', + 'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"', + 'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"', + 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', + ], +) @httprettified def test_oauth2_header_parsing(header, sample_post_response_data): token = str(uuid.uuid4()) @@ -464,21 +492,27 @@ def post_statement(request, uri, response_headers): authorization = request.headers.get("Authorization") if authorization and authorization.replace("Bearer ", "") in token: return [200, response_headers, json.dumps(sample_post_response_data)] - return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server), - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": header.format( + redirect_server=redirect_server, token_server=token_server + ), + "Basic realm": '"Trino"', + }, + "", + ] # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement) + body=post_statement, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() @@ -489,10 +523,10 @@ def post_statement(request, uri, response_headers): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), ).post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert redirect_handler.redirect_server == redirect_server assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -508,19 +542,23 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) httpretty.register_uri( method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", status=http_status, - body="error") + body="error", + ) redirect_handler = RedirectHandler() @@ -531,13 +569,17 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") assert redirect_handler.redirect_server == redirect_server - assert str(exp.value) == f"Error while getting the token response status code: {http_status}, body: error" + assert ( + str(exp.value) + == f"Error while getting the token response status code: {http_status}, body: error" + ) assert len(_post_statement_requests()) == 1 assert len(_get_token_requests(challenge_id)) == 1 @@ -564,7 +606,8 @@ def run(self) -> None: user="test", ), http_scheme=constants.HTTPS, - auth=auth) + auth=auth, + ) for i in range(10): # apparently HTTPretty in the current version is not thread-safe # https://github.com/gabrielfalcao/HTTPretty/issues/209 @@ -650,10 +693,7 @@ def test_trino_fetch_error(mock_requests, sample_get_error_response_data): assert "stack" in error.failure_info assert len(error.failure_info["stack"]) == 36 assert "suppressed" in error.failure_info - assert ( - error.message - == "line 1:15: Schema must be specified when session schema is not set" - ) + assert error.message == "line 1:15: Schema must be specified when session schema is not set" assert error.error_location == (1, 15) assert error.query_id == "20210817_140827_00000_arvdv" @@ -803,11 +843,14 @@ def test_authentication_fail_retry(monkeypatch): assert post_retry.retry_count == attempts -@pytest.mark.parametrize("status_code, attempts", [ - (502, 3), - (503, 3), - (504, 3), -]) +@pytest.mark.parametrize( + "status_code, attempts", + [ + (502, 3), + (503, 3), + (504, 3), + ], +) def test_5XX_error_retry(status_code, attempts, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -824,7 +867,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): client_session=ClientSession( user="test", ), - max_attempts=attempts + max_attempts=attempts, ) req.post("URL") @@ -834,9 +877,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): assert post_retry.retry_count == attempts -@pytest.mark.parametrize("status_code", [ - 501 -]) +@pytest.mark.parametrize("status_code", [501]) def test_error_no_retry(status_code, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -887,10 +928,12 @@ def test_trino_result_response_headers(): headers associated to the TrinoQuery instance provided to the `TrinoResult` class. """ - mock_trino_query = mock.Mock(respone_headers={ - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', - }) + mock_trino_query = mock.Mock( + respone_headers={ + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", + } + ) result = TrinoResult( query=mock_trino_query, @@ -910,8 +953,8 @@ class MockResponse(mock.Mock): @property def headers(self): return { - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", } def json(self): @@ -930,18 +973,15 @@ def json(self): http_scheme="http", ) - sql = 'execute my_stament using 1, 2, 3' + sql = "execute my_stament using 1, 2, 3" additional_headers = { - constants.HEADER_PREPARED_STATEMENT: 'my_statement=added_prepare_statement_header' + constants.HEADER_PREPARED_STATEMENT: "my_statement=added_prepare_statement_header" } # Patch the post function to avoid making the requests, as well as to # validate that the function was called with the right arguments. - with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post: - query = TrinoQuery( - request=req, - sql=sql - ) + with mock.patch.object(req, "post", return_value=MockResponse()) as mock_post: + query = TrinoQuery(request=req, sql=sql) result = query.execute(additional_http_headers=additional_headers) # Validate the the post function was called with the right argguments diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 7b1c72c2..873c75ba 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -17,8 +17,16 @@ from httpretty import httprettified from requests import Session -from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \ - GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import OAuth2Authentication from trino.dbapi import connect @@ -58,28 +66,28 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") @@ -87,18 +95,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") @@ -115,30 +120,27 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS + "coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS ) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") @@ -146,16 +148,10 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS + "coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS ) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") @@ -173,31 +169,26 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data): redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + post_statement_callback = PostStatementCallback( + redirect_server, token_server, [token], sample_post_response_data + ) # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) - conn = connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) + conn = connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) class RunningThread(threading.Thread): lock = threading.Lock() @@ -209,11 +200,7 @@ def run(self) -> None: with RunningThread.lock: conn.cursor().execute("SELECT 1") - threads = [ - RunningThread(), - RunningThread(), - RunningThread() - ] + threads = [RunningThread(), RunningThread(), RunningThread()] # run and join all threads for thread in threads: diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index fd3c5e2d..9753bab5 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -10,8 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.client import get_header_values, get_session_property_values from trino import constants +from trino.client import get_header_values, get_session_property_values def test_get_header_values(): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e97903cf..02adeb9d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.transaction import IsolationLevel import pytest +from trino.transaction import IsolationLevel + def test_isolation_level_levels() -> None: levels = { @@ -27,9 +28,7 @@ def test_isolation_level_levels() -> None: def test_isolation_level_values() -> None: - values = { - 0, 1, 2, 3, 4 - } + values = {0, 1, 2, 3, 4} assert IsolationLevel.values() == values diff --git a/trino/__init__.py b/trino/__init__.py index 4ff3e55b..42db0646 100644 --- a/trino/__init__.py +++ b/trino/__init__.py @@ -10,13 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import auth -from . import dbapi -from . import client -from . import constants -from . import exceptions -from . import logging +from . import auth, client, constants, dbapi, exceptions, logging -__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging'] +__all__ = ["auth", "dbapi", "client", "constants", "exceptions", "logging"] __version__ = "0.315.0" diff --git a/trino/auth.py b/trino/auth.py index e6b4f04c..e2c5d3fa 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -11,18 +11,18 @@ # limitations under the License. import abc +import importlib import json import os import re import threading import webbrowser -from typing import Optional, List, Callable +from typing import Callable, List, Optional from urllib.parse import urlparse from requests import Request from requests.auth import AuthBase, extract_cookies_to_jar from requests.utils import parse_dict_header -import importlib import trino.logging from trino.client import exceptions @@ -95,15 +95,17 @@ def get_exceptions(self): def __eq__(self, other): if not isinstance(other, KerberosAuthentication): return False - return (self._config == other._config - and self._service_name == other._service_name - and self._mutual_authentication == other._mutual_authentication - and self._force_preemptive == other._force_preemptive - and self._hostname_override == other._hostname_override - and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response - and self._principal == other._principal - and self._delegate == other._delegate - and self._ca_bundle == other._ca_bundle) + return ( + self._config == other._config + and self._service_name == other._service_name + and self._mutual_authentication == other._mutual_authentication + and self._force_preemptive == other._force_preemptive + and self._hostname_override == other._hostname_override + and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response + and self._principal == other._principal + and self._delegate == other._delegate + and self._ca_bundle == other._ca_bundle + ) class BasicAuthentication(Authentication): @@ -143,7 +145,6 @@ def __call__(self, r): class JWTAuthentication(Authentication): - def __init__(self, token): self.token = token @@ -251,31 +252,38 @@ def get_token_from_cache(self, host: str) -> Optional[str]: try: return self._keyring.get_password(host, "token") except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." + ) from e def store_token_to_cache(self, host: str, token: str) -> None: try: # keyring is installed, so we can store the token for reuse within multiple threads self._keyring.set_password(host, "token", token) except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." + ) from e class _OAuth2TokenBearer(AuthBase): """ Custom implementation of Trino Oauth2 based authorization to get the token """ + MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) def __init__(self, redirect_auth_url_handler: Callable[[str], None]): self._redirect_auth_url = redirect_auth_url_handler keyring_cache = _OAuth2KeyRingTokenCache() - self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache() + self._token_cache = ( + keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache() + ) self._token_lock = threading.Lock() self._inside_oauth_attempt_lock = threading.Lock() self._inside_oauth_attempt_blocker = threading.Event() @@ -285,9 +293,9 @@ def __call__(self, r): token = self._get_token_from_cache(host) if token is not None: - r.headers['Authorization'] = "Bearer " + token + r.headers["Authorization"] = "Bearer " + token - r.register_hook('response', self._authenticate) + r.register_hook("response", self._authenticate) return r @@ -312,20 +320,24 @@ def _authenticate(self, response, **kwargs): def _attempt_oauth(self, response, **kwargs): # we have to handle the authentication, may be token the token expired, or it wasn't there at all - auth_info = response.headers.get('WWW-Authenticate') + auth_info = response.headers.get("WWW-Authenticate") if not auth_info: - raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.") + raise exceptions.TrinoAuthError( + "Error: header WWW-Authenticate not available in the response." + ) if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info): raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}") - auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) + auth_info_headers = parse_dict_header( + _OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1) + ) - auth_server = auth_info_headers.get('x_redirect_server') + auth_server = auth_info_headers.get("x_redirect_server") if auth_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server") - token_server = auth_info_headers.get('x_token_server') + token_server = auth_info_headers.get("x_token_server") if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") @@ -349,7 +361,7 @@ def _retry_request(self, response, **kwargs): request.prepare_cookies(request._cookies) host = self._determine_host(response.request.url) - request.headers['Authorization'] = "Bearer " + self._get_token_from_cache(host) + request.headers["Authorization"] = "Bearer " + self._get_token_from_cache(host) retry_response = response.connection.send(request, **kwargs) retry_response.history.append(response) retry_response.request = request @@ -359,23 +371,26 @@ def _get_token(self, token_server, response, **kwargs): attempts = 0 while attempts < self.MAX_OAUTH_ATTEMPTS: attempts += 1 - with response.connection.send(Request(method='GET', url=token_server).prepare(), **kwargs) as response: + with response.connection.send( + Request(method="GET", url=token_server).prepare(), **kwargs + ) as response: if response.status_code == 200: token_response = json.loads(response.text) - token = token_response.get('token') + token = token_response.get("token") if token: return token - error = token_response.get('error') + error = token_response.get("error") if error: raise exceptions.TrinoAuthError(f"Error while getting the token: {error}") else: - token_server = token_response.get('nextUri') + token_server = token_response.get("nextUri") logger.debug(f"nextURi auth token server: {token_server}") else: raise exceptions.TrinoAuthError( f"Error while getting the token response " f"status code: {response.status_code}, " - f"body: {response.text}") + f"body: {response.text}" + ) raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") @@ -393,10 +408,12 @@ def _determine_host(url) -> Optional[str]: class OAuth2Authentication(Authentication): - def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([ - WebBrowserRedirectHandler(), - ConsoleRedirectHandler() - ])): + def __init__( + self, + redirect_auth_url_handler=CompositeRedirectHandler( + [WebBrowserRedirectHandler(), ConsoleRedirectHandler()] + ), + ): self._redirect_auth_url = redirect_auth_url_handler self._bearer = _OAuth2TokenBearer(self._redirect_auth_url) diff --git a/trino/client.py b/trino/client.py index 211ce9c0..ba7aca34 100644 --- a/trino/client.py +++ b/trino/client.py @@ -62,7 +62,7 @@ else: PROXIES = {} -_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$") INF = float("inf") NEGATIVE_INF = float("-inf") @@ -214,8 +214,7 @@ def get_header_values(headers, header): def get_session_property_values(headers, header): kvs = get_header_values(headers, header) return [ - (k.strip(), urllib.parse.unquote(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs) + (k.strip(), urllib.parse.unquote(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs) ] @@ -245,16 +244,14 @@ def __repr__(self): class _DelayExponential(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600): # 100ms # 2 hours self._base = base self._exponent = exponent self._jitter = jitter self._max_delay = max_delay def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) + delay = float(self._base) * (self._exponent**attempt) if self._jitter: delay *= random.random() delay = min(float(self._max_delay), delay) @@ -262,9 +259,7 @@ def __call__(self, attempt): class _RetryWithExponentialBackoff(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600): # 100ms # 2 hours self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) def retry(self, func, args, kwargs, err, attempt): @@ -383,7 +378,10 @@ def http_headers(self) -> Dict[str, str]: headers[constants.HEADER_SOURCE] = self._client_session.source headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_ROLE] = self._client_session.role - if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0: + if ( + self._client_session.client_tags is not None + and len(self._client_session.client_tags) > 0 + ): headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags) headers[constants.HEADER_SESSION] = ",".join( @@ -401,8 +399,10 @@ def http_headers(self) -> Dict[str, str]: transaction_id = self._client_session.transaction_id headers[constants.HEADER_TRANSACTION] = transaction_id - if self._client_session.extra_credential is not None and \ - len(self._client_session.extra_credential) > 0: + if ( + self._client_session.extra_credential is not None + and len(self._client_session.extra_credential) > 0 + ): for tup in self._client_session.extra_credential: self._verify_extra_credential(tup) @@ -410,9 +410,12 @@ def http_headers(self) -> Dict[str, str]: # HTTP 1.1 section 4.2 combine multiple extra credentials into a # comma-separated value # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) - headers[constants.HEADER_EXTRA_CREDENTIAL] = \ - ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join( + [ + f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" + for tup in self._client_session.extra_credential + ] + ) return headers @@ -536,9 +539,7 @@ def process(self, http_response) -> TrinoStatus: raise self._process_error(response["error"], response.get("id")) if constants.HEADER_CLEAR_SESSION in http_response.headers: - for prop in get_header_values( - http_response.headers, constants.HEADER_CLEAR_SESSION - ): + for prop in get_header_values(http_response.headers, constants.HEADER_CLEAR_SESSION): self._client_session.properties.pop(prop, None) if constants.HEADER_SET_SESSION in http_response.headers: @@ -579,7 +580,7 @@ def _verify_extra_credential(self, header): raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'") try: - key.encode().decode('ascii') + key.encode().decode("ascii") except UnicodeDecodeError: raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") @@ -643,9 +644,13 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any: raw_type = { "typeSignature": data_type["typeSignature"]["arguments"][0]["value"] } - return [cls._map_to_python_type((array_item, raw_type)) for array_item in value] + return [ + cls._map_to_python_type((array_item, raw_type)) for array_item in value + ] if raw_type == "row": - raw_types = map(lambda arg: arg["value"], data_type["typeSignature"]["arguments"]) + raw_types = map( + lambda arg: arg["value"], data_type["typeSignature"]["arguments"] + ) return tuple( cls._map_to_python_type((array_item, raw_type)) for (array_item, raw_type) in zip(value, raw_types) @@ -659,44 +664,53 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any: "typeSignature": data_type["typeSignature"]["arguments"][1]["value"] } return { - cls._map_to_python_type((key, raw_key_type)): - cls._map_to_python_type((value[key], raw_value_type)) + cls._map_to_python_type((key, raw_key_type)): cls._map_to_python_type( + (value[key], raw_value_type) + ) for key in value } elif "decimal" in raw_type: return Decimal(value) elif raw_type == "double": - if value == 'Infinity': + if value == "Infinity": return INF - elif value == '-Infinity': + elif value == "-Infinity": return NEGATIVE_INF - elif value == 'NaN': + elif value == "NaN": return NAN return value elif raw_type == "date": return datetime.strptime(value, "%Y-%m-%d").date() elif raw_type == "timestamp with time zone": - dt, tz = value.rsplit(' ', 1) - if tz.startswith('+') or tz.startswith('-'): + dt, tz = value.rsplit(" ", 1) + if tz.startswith("+") or tz.startswith("-"): return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z") - return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz)) + return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace( + tzinfo=pytz.timezone(tz) + ) elif "timestamp" in raw_type: return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f") elif "time with time zone" in raw_type: - matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value) + matches = re.match(r"^(.*)([\+\-])(\d{2}):(\d{2})$", value) assert matches is not None assert len(matches.groups()) == 4 - if matches.group(2) == '-': + if matches.group(2) == "-": tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4))) else: tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4))) - return datetime.strptime(matches.group(1), "%H:%M:%S.%f").time().replace(tzinfo=timezone(tz)) + return ( + datetime.strptime(matches.group(1), "%H:%M:%S.%f") + .time() + .replace(tzinfo=timezone(tz)) + ) elif "time" in raw_type: return datetime.strptime(value, "%H:%M:%S.%f").time() else: return value except ValueError as e: - error_str = f"Could not convert '{value}' into the associated python type for '{raw_type}'" + error_str = ( + f"Could not convert '{value}' into the associated python type for '{raw_type}'" + ) raise trino.exceptions.TrinoDataError(error_str) from e def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]: @@ -707,10 +721,10 @@ class TrinoQuery(object): """Represent the execution of a SQL statement by Trino.""" def __init__( - self, - request: TrinoRequest, - sql: str, - experimental_python_types: bool = False, + self, + request: TrinoRequest, + sql: str, + experimental_python_types: bool = False, ) -> None: self.query_id: Optional[str] = None @@ -814,6 +828,7 @@ def cancel(self) -> None: def is_finished(self) -> bool: import warnings + warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning) return self.finished diff --git a/trino/constants.py b/trino/constants.py index 30046908..41104fdb 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -12,7 +12,6 @@ from typing import Any, Optional - DEFAULT_PORT = 8080 DEFAULT_TLS_PORT = 443 DEFAULT_SOURCE = "trino-python-client" @@ -45,9 +44,9 @@ HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id" HEADER_TRANSACTION = "X-Trino-Transaction-Id" -HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement' -HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare' -HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare' +HEADER_PREPARED_STATEMENT = "X-Trino-Prepared-Statement" +HEADER_ADDED_PREPARE = "X-Trino-Added-Prepare" +HEADER_DEALLOCATED_PREPARE = "X-Trino-Deallocated-Prepare" HEADER_SET_SCHEMA = "X-Trino-Set-Schema" HEADER_SET_CATALOG = "X-Trino-Set-Catalog" diff --git a/trino/dbapi.py b/trino/dbapi.py index 44813168..6828455d 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -17,31 +17,30 @@ Fetch methods returns rows as a list of lists on purpose to let the caller decide to convert then to a list of tuples. """ -from decimal import Decimal -from typing import Any, List, Optional # NOQA for mypy types - import copy -import uuid import datetime import math +import uuid +from decimal import Decimal +from typing import Any, List, Optional # NOQA for mypy types -from trino import constants -import trino.exceptions import trino.client +import trino.exceptions import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION +from trino import constants from trino.exceptions import ( - Warning, - Error, - InterfaceError, DatabaseError, DataError, - OperationalError, + Error, IntegrityError, + InterfaceError, InternalError, - ProgrammingError, NotSupportedError, + OperationalError, + ProgrammingError, + Warning, ) +from trino.transaction import NO_TRANSACTION, IsolationLevel, Transaction __all__ = [ # https://www.python.org/dev/peps/pep-0249/#globals @@ -128,7 +127,7 @@ def __init__( headers=http_headers, transaction_id=NO_TRANSACTION, extra_credential=extra_credential, - client_tags=client_tags + client_tags=client_tags, ) # mypy cannot follow module import if http_session is None: @@ -217,7 +216,9 @@ def cursor(self, experimental_python_types: bool = None): self, request, # if experimental_python_types is not explicitly set in Cursor, take from Connection - experimental_python_types if experimental_python_types is not None else self.experimental_python_types + experimental_python_types + if experimental_python_types is not None + else self.experimental_python_types, ) @@ -231,9 +232,7 @@ class Cursor(object): def __init__(self, connection, request, experimental_python_types: bool = False): if not isinstance(connection, Connection): - raise ValueError( - "connection must be a Connection object: {}".format(type(connection)) - ) + raise ValueError("connection must be a Connection object: {}".format(type(connection))) self._connection = connection self._request = request @@ -268,8 +267,7 @@ def description(self): # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ] return [ - (col["name"], col["type"], None, None, None, None, None) - for col in self._query.columns + (col["name"], col["type"], None, None, None, None, None) for col in self._query.columns ] @property @@ -317,15 +315,17 @@ def _prepare_statement(self, operation, statement_name): :return: string representing the value of the 'X-Trino-Added-Prepare' header. """ - sql = 'PREPARE {statement_name} FROM {operation}'.format( - statement_name=statement_name, - operation=operation + sql = "PREPARE {statement_name} FROM {operation}".format( + statement_name=statement_name, operation=operation ) # Send prepare statement. Copy the _request object to avoid poluting the # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, - experimental_python_types=self._experimental_pyton_types) + query = trino.client.TrinoQuery( + copy.deepcopy(self._request), + sql=sql, + experimental_python_types=self._experimental_pyton_types, + ) result = query.execute() # Iterate until the 'X-Trino-Added-Prepare' header is found or @@ -338,16 +338,19 @@ def _prepare_statement(self, operation, statement_name): raise trino.exceptions.FailedToObtainAddedPrepareHeader - def _get_added_prepare_statement_trino_query( - self, - statement_name, - params - ): - sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) + def _get_added_prepare_statement_trino_query(self, statement_name, params): + sql = ( + "EXECUTE " + + statement_name + + " USING " + + ",".join(map(self._format_prepared_param, params)) + ) # No need to deepcopy _request here because this is the actual request # operation - return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types) + return trino.client.TrinoQuery( + self._request, sql=sql, experimental_python_types=self._experimental_pyton_types + ) def _format_prepared_param(self, param): """ @@ -374,7 +377,7 @@ def _format_prepared_param(self, param): return "DOUBLE '%s'" % param if isinstance(param, str): - return ("'%s'" % param.replace("'", "''")) + return "'%s'" % param.replace("'", "''") if isinstance(param, bytes): return "X'%s'" % param.hex() @@ -386,7 +389,7 @@ def _format_prepared_param(self, param): if isinstance(param, datetime.datetime) and param.tzinfo is not None: datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f") # named timezones - if hasattr(param.tzinfo, 'zone'): + if hasattr(param.tzinfo, "zone"): return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone) # offset-based timezones return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param)) @@ -401,17 +404,16 @@ def _format_prepared_param(self, param): return "DATE '%s'" % date_str if isinstance(param, list): - return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param)) + return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, tuple): - return "ROW(%s)" % ','.join(map(self._format_prepared_param, param)) + return "ROW(%s)" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, dict): keys = list(param.keys()) values = [param[key] for key in keys] return "MAP({}, {})".format( - self._format_prepared_param(keys), - self._format_prepared_param(values) + self._format_prepared_param(keys), self._format_prepared_param(values) ) if isinstance(param, uuid.UUID): @@ -420,19 +422,22 @@ def _format_prepared_param(self, param): if isinstance(param, Decimal): return "DECIMAL '%s'" % param - raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param)) + raise trino.exceptions.NotSupportedError( + "Query parameter of type '%s' is not supported." % type(param) + ) def _deallocate_prepare_statement(self, added_prepare_header, statement_name): - sql = 'DEALLOCATE PREPARE ' + statement_name + sql = "DEALLOCATE PREPARE " + statement_name # Send deallocate statement. Copy the _request object to avoid poluting the # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, - experimental_python_types=self._experimental_pyton_types) + query = trino.client.TrinoQuery( + copy.deepcopy(self._request), + sql=sql, + experimental_python_types=self._experimental_pyton_types, + ) result = query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } + additional_http_headers={constants.HEADER_PREPARED_STATEMENT: added_prepare_header} ) # Iterate until the 'X-Trino-Deallocated-Prepare' header is found or @@ -446,27 +451,22 @@ def _deallocate_prepare_statement(self, added_prepare_header, statement_name): raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader def _generate_unique_statement_name(self): - return 'st_' + uuid.uuid4().hex.replace('-', '') + return "st_" + uuid.uuid4().hex.replace("-", "") def execute(self, operation, params=None): if params: assert isinstance(params, (list, tuple)), ( - 'params must be a list or tuple containing the query ' - 'parameter values' + "params must be a list or tuple containing the query " "parameter values" ) statement_name = self._generate_unique_statement_name() # Send prepare statement - added_prepare_header = self._prepare_statement( - operation, statement_name - ) + added_prepare_header = self._prepare_statement(operation, statement_name) try: # Send execute statement and assign the return value to `results` # as it will be returned by the function - self._query = self._get_added_prepare_statement_trino_query( - statement_name, params - ) + self._query = self._get_added_prepare_statement_trino_query(statement_name, params) result = self._query.execute( additional_http_headers={ constants.HEADER_PREPARED_STATEMENT: added_prepare_header @@ -479,8 +479,11 @@ def execute(self, operation, params=None): self._deallocate_prepare_statement(added_prepare_header, statement_name) else: - self._query = trino.client.TrinoQuery(self._request, sql=operation, - experimental_python_types=self._experimental_pyton_types) + self._query = trino.client.TrinoQuery( + self._request, + sql=operation, + experimental_python_types=self._experimental_pyton_types, + ) result = self._query.execute() self._iterator = iter(result) return result @@ -570,9 +573,7 @@ def fetchall(self) -> List[List[Any]]: def cancel(self): if self._query is None: - raise trino.exceptions.OperationalError( - "Cancel query failed; no running query" - ) + raise trino.exceptions.OperationalError("Cancel query failed; no running query") self._query.cancel() def close(self): @@ -606,9 +607,7 @@ def __eq__(self, other): STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS") -BINARY = DBAPITypeObject( - "ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest" -) +BINARY = DBAPITypeObject("ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest") NUMBER = DBAPITypeObject( "BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL" diff --git a/trino/exceptions.py b/trino/exceptions.py index 86708fd0..f8eddba3 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -139,6 +139,7 @@ class FailedToObtainAddedPrepareHeader(Error): Raise this exception when unable to find the 'X-Trino-Added-Prepare' header in the response of a PREPARE statement request. """ + pass @@ -147,6 +148,7 @@ class FailedToObtainDeallocatedPrepareHeader(Error): Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' header in the response of a DEALLOCATED statement request. """ + pass diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index a085fbf3..30141edf 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -10,15 +10,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.sql import compiler + try: - from sqlalchemy.sql.expression import ( - Alias, - CTE, - Subquery, - ) + from sqlalchemy.sql.expression import CTE, Alias, Subquery except ImportError: # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist from sqlalchemy.sql.expression import Alias + CTE = type(None) Subquery = type(None) @@ -113,8 +111,16 @@ def limit_clause(self, select, **kw): text += "\nLIMIT " + self.process(select._limit_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs, + ): sql = super(TrinoSQLCompiler, self).visit_table( table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs ) @@ -128,13 +134,10 @@ def add_catalog(sql, table): if isinstance(table, (Alias, CTE, Subquery)): return sql - if ( - 'trino' not in table.dialect_options - or 'catalog' not in table.dialect_options['trino'] - ): + if "trino" not in table.dialect_options or "catalog" not in table.dialect_options["trino"]: return sql - catalog = table.dialect_options['trino']['catalog'] + catalog = table.dialect_options["trino"]["catalog"] sql = f'"{catalog}".{sql}' return sql diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 8284ba9c..e54f98c3 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from sqlalchemy import util from sqlalchemy.sql import sqltypes @@ -167,7 +167,7 @@ def aware_split( elif character == close_bracket: parens -= 1 elif character == quote: - if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote: quotes = False elif not quotes: quotes = True diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7c4409a0..d58fc387 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -19,7 +19,8 @@ from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.engine.url import URL -from trino import dbapi as trino_dbapi, logging +from trino import dbapi as trino_dbapi +from trino import logging from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication from trino.dbapi import Cursor from trino.sqlalchemy import compiler, datatype, error @@ -102,7 +103,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "cert" and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key']) + kwargs["auth"] = CertificateAuthentication(url.query["cert"], url.query["key"]) if "source" in url.query: kwargs["source"] = url.query["source"] @@ -122,16 +123,22 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any kwargs["client_tags"] = json.loads(url.query["client_tags"]) if "experimental_python_types" in url.query: - kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"]) + kwargs["experimental_python_types"] = json.loads( + url.query["experimental_python_types"] + ) return args, kwargs - def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def get_columns( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") return self._get_columns(connection, table_name, schema, **kw) - def _get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def _get_columns( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[Dict[str, Any]]: schema = schema or self._get_default_schema_name(connection) query = dedent( """ @@ -158,11 +165,15 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No columns.append(column) return columns - def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + def get_pk_constraint( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> Dict[str, Any]: """Trino has no support for primary keys. Returns a dummy""" return dict(name=None, constrained_columns=[]) - def get_primary_keys(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[str]: + def get_primary_keys( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[str]: pk = self.get_pk_constraint(connection, table_name, schema) return pk.get("constrained_columns") # type: ignore @@ -218,7 +229,9 @@ def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) """Trino has no support for temporary views. Returns an empty list.""" return [] - def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: + def get_view_definition( + self, connection: Connection, view_name: str, schema: str = None, **kw + ) -> str: schema = schema or self._get_default_schema_name(connection) if schema is None: raise exc.NoSuchTableError("schema is required") @@ -233,17 +246,21 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st res = connection.execute(sql.text(query), schema=schema, view=view_name) return res.scalar() - def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def get_indexes( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") - partitioned_columns = self._get_columns(connection, f"{table_name}$partitions", schema, **kw) + partitioned_columns = self._get_columns( + connection, f"{table_name}$partitions", schema, **kw + ) if not partitioned_columns: return [] partition_index = dict( name="partition", column_names=[col["name"] for col in partitioned_columns], - unique=False + unique=False, ) return [partition_index] @@ -263,7 +280,9 @@ def get_check_constraints( """Trino has no support for check constraints. Returns an empty list.""" return [] - def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + def get_table_comment( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> Dict[str, Any]: catalog_name = self._get_default_catalog_name(connection) if catalog_name is None: raise exc.NoSuchTableError("catalog is required in connection") @@ -282,13 +301,13 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str try: res = connection.execute( sql.text(query), - catalog_name=catalog_name, schema_name=schema_name, table_name=table_name + catalog_name=catalog_name, + schema_name=schema_name, + table_name=table_name, ) return dict(text=res.scalar()) except error.TrinoQueryError as e: - if e.error_name in ( - error.PERMISSION_DENIED, - ): + if e.error_name in (error.PERMISSION_DENIED,): return dict(text=None) raise @@ -318,7 +337,9 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None, res = connection.execute(sql.text(query), schema=schema, table=table_name) return res.first() is not None - def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool: + def has_sequence( + self, connection: Connection, sequence_name: str, schema: str = None, **kw + ) -> bool: """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" return False @@ -341,7 +362,11 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: return dbapi_connection.schema def do_execute( - self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None + self, + cursor: Cursor, + statement: str, + parameters: Tuple[Any, ...], + context: DefaultExecutionContext = None, ): cursor.execute(statement, parameters) if context and context.should_autocommit: diff --git a/trino/transaction.py b/trino/transaction.py index e6c85234..c6e6257d 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -12,11 +12,10 @@ from enum import Enum, unique from typing import Iterable -from trino import constants import trino.client import trino.exceptions import trino.logging - +from trino import constants logger = trino.logging.get_logger(__name__)