diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..21c33d71
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,11 @@
+[flake8]
+select =
+    E
+    W
+    F
+ignore =
+    W503 # makes Flake8 work like black
+    W504
+    E203 # makes Flake8 work like black
+    E741
+    E501
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 32bf21cc..b815fc8a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,3 +13,16 @@ repos:
         additional_dependencies:
           - "types-pytz"
           - "types-requests"
+
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        args:
+          - "--line-length=99"
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.6.4
+    hooks:
+      - id: isort
+        args: [ "--profile", "black", "--filter-files" ]
diff --git a/README.md b/README.md
index d45c6554..98b01322 100644
--- a/README.md
+++ b/README.md
@@ -424,6 +424,7 @@ We recommend that you use Python3's `venv` for development:
 $ python3 -m venv .venv
 $ . .venv/bin/activate
 $ pip install -e '.[tests]'
+$ pre-commit install
 ```
 
 With `-e` passed to `pip install` above pip can reference the code you are
@@ -441,6 +442,17 @@ When the code is ready, submit a Pull Request.
 See also Trino's [guidelines](https://github.com/trinodb/trino/blob/master/.github/DEVELOPMENT.md).
 Most of them also apply to code in trino-python-client.
 
+### `pre-commit` checks
+
+Code is automatically checked on commit by a [pre-commit](https://pre-commit.com/) git hook.
+
+Following checks are performed:
+
+- [`flake8`](https://flake8.pycqa.org/en/latest/) for code linting
+- [`black`](https://github.com/psf/black) for code formatting
+- [`isort`](https://pycqa.github.io/isort/) for sorting imports
+- [`mypy`](https://mypy.readthedocs.io/en/stable/) for static type checking
+
 ### Running tests
 
 `trino-python-client` uses [pytest](https://pytest.org/) for its tests. To run
diff --git a/setup.py b/setup.py
index 2695a530..2ee59cb2 100755
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
 import re
 import textwrap
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 _version_re = re.compile(r"__version__\s+=\s+(.*)")
 
@@ -39,6 +39,9 @@
     "pytest",
     "pytest-runner",
     "click",
+    "pre-commit",
+    "black",
+    "isort",
 ]
 
 setup(
@@ -75,7 +78,7 @@
         "Programming Language :: Python :: Implementation :: PyPy",
         "Topic :: Database :: Front-Ends",
     ],
-    python_requires='>=3.7',
+    python_requires=">=3.7",
     install_requires=["pytz", "requests"],
     extras_require={
         "all": all_require,
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 840ee8f3..0d3ea56e 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -18,11 +18,11 @@
 from uuid import uuid4
 
 import click
-import trino.logging
 import pytest
-from trino.client import TrinoQuery, TrinoRequest, ClientSession
-from trino.constants import DEFAULT_PORT
 
+import trino.logging
+from trino.client import ClientSession, TrinoQuery, TrinoRequest
+from trino.constants import DEFAULT_PORT
 
 logger = trino.logging.get_logger(__name__)
 
@@ -64,13 +64,7 @@ def start_trino(image_tag=None):
 
 
 def wait_for_trino_workers(host, port, timeout=180):
-    request = TrinoRequest(
-        host=host,
-        port=port,
-        client_session=ClientSession(
-            user="test_fixture"
-        )
-    )
+    request = TrinoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture"))
     sql = "SELECT state FROM system.runtime.nodes"
     t0 = time.time()
     while True:
@@ -116,9 +110,7 @@ def start_trino_and_wait(image_tag=None):
     if host:
         port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT)
     else:
-        container_id, proc, host, port = start_local_trino_server(
-            image_tag
-        )
+        container_id, proc, host, port = start_local_trino_server(image_tag)
 
     print("trino.server.hostname {}".format(host))
     print("trino.server.port {}".format(port))
@@ -167,9 +159,7 @@ def cli():
     pass
 
 
-@click.option(
-    "--cache/--no-cache", default=True, help="enable/disable Docker build cache"
-)
+@click.option("--cache/--no-cache", default=True, help="enable/disable Docker build cache")
 @click.command()
 def trino_server():
     container_id, _, _, _ = start_trino_and_wait()
@@ -198,9 +188,7 @@ def trino_cli(container_id=None):
 
 @cli.command("list")
 def list_():
-    subprocess.check_call(
-        ["docker", "ps", "--filter", "name=trino-python-client-tests-"]
-    )
+    subprocess.check_call(["docker", "ps", "--filter", "name=trino-python-client-tests-"])
 
 
 @cli.command()
diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py
index c5d0dda0..99d21ff2 100644
--- a/tests/integration/test_dbapi_integration.py
+++ b/tests/integration/test_dbapi_integration.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from datetime import datetime, time, date, timezone, timedelta
+from datetime import date, datetime, time, timedelta, timezone
 from decimal import Decimal
 
 import pytest
@@ -19,7 +19,7 @@
 
 import trino
 from tests.integration.conftest import trino_version
-from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
+from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
 from trino.transaction import IsolationLevel
 
 
@@ -27,9 +27,7 @@
 def trino_connection(run_trino):
     _, host, port = run_trino
 
-    yield trino.dbapi.Connection(
-        host=host, port=port, user="test", source="test", max_attempts=1
-    )
+    yield trino.dbapi.Connection(host=host, port=port, user="test", source="test", max_attempts=1)
 
 
 @pytest.fixture
@@ -104,7 +102,7 @@ def test_select_query_result_iteration_statement_params(trino_connection):
         ) x (id, name, letter)
         WHERE id >= ?
         """,
-        params=(3,)  # expecting all the rows with id >= 3
+        params=(3,),  # expecting all the rows with id >= 3
     )
 
 
@@ -153,7 +151,9 @@ def test_execute_many_without_params(trino_connection):
     cur = trino_connection.cursor()
     cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
     cur.fetchall()
-    cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
+    cur.executemany(
+        "INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []
+    )
     with pytest.raises(TrinoUserError) as e:
         cur.fetchall()
     assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)
@@ -166,23 +166,22 @@ def test_execute_many_select(trino_connection):
     assert "Query must return update type" in str(e.value)
 
 
-@pytest.mark.parametrize("connection_experimental_python_types,cursor_experimental_python_types,expected",
-                         [
-                             (None, None, False),
-                             (None, False, False),
-                             (None, True, True),
-                             (False, None, False),
-                             (False, False, False),
-                             (False, True, True),
-                             (True, None, True),
-                             (True, False, False),
-                             (True, True, True),
-                         ])
+@pytest.mark.parametrize(
+    "connection_experimental_python_types,cursor_experimental_python_types,expected",
+    [
+        (None, None, False),
+        (None, False, False),
+        (None, True, True),
+        (False, None, False),
+        (False, False, False),
+        (False, True, True),
+        (True, None, True),
+        (True, False, False),
+        (True, True, True),
+    ],
+)
 def test_experimental_python_types_with_connection_and_cursor(
-        connection_experimental_python_types,
-        cursor_experimental_python_types,
-        expected,
-        run_trino
+    connection_experimental_python_types, cursor_experimental_python_types, expected, run_trino
 ):
     _, host, port = run_trino
 
@@ -195,7 +194,8 @@ def test_experimental_python_types_with_connection_and_cursor(
 
     cur = connection.cursor(experimental_python_types=cursor_experimental_python_types)
 
-    cur.execute("""
+    cur.execute(
+        """
     SELECT
         DECIMAL '0.142857',
         DATE '2018-01-01',
@@ -203,35 +203,36 @@ def test_experimental_python_types_with_connection_and_cursor(
         TIMESTAMP '2019-01-01 00:00:00.000 UTC',
         TIMESTAMP '2019-01-01 00:00:00.000',
         TIME '00:00:00.000'
-    """)
+    """
+    )
     rows = cur.fetchall()
 
     if expected:
-        assert rows[0][0] == Decimal('0.142857')
+        assert rows[0][0] == Decimal("0.142857")
         assert rows[0][1] == date(2018, 1, 1)
         assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1)))
-        assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC'))
+        assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone("UTC"))
         assert rows[0][4] == datetime(2019, 1, 1)
         assert rows[0][5] == time(0, 0, 0, 0)
     else:
         for value in rows[0]:
             assert isinstance(value, str)
 
-        assert rows[0][0] == '0.142857'
-        assert rows[0][1] == '2018-01-01'
-        assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00'
-        assert rows[0][3] == '2019-01-01 00:00:00.000 UTC'
-        assert rows[0][4] == '2019-01-01 00:00:00.000'
-        assert rows[0][5] == '00:00:00.000'
+        assert rows[0][0] == "0.142857"
+        assert rows[0][1] == "2018-01-01"
+        assert rows[0][2] == "2019-01-01 00:00:00.000 +01:00"
+        assert rows[0][3] == "2019-01-01 00:00:00.000 UTC"
+        assert rows[0][4] == "2019-01-01 00:00:00.000"
+        assert rows[0][5] == "00:00:00.000"
 
 
 def test_decimal_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    cur.execute("SELECT ?", params=(Decimal('0.142857'),))
+    cur.execute("SELECT ?", params=(Decimal("0.142857"),))
     rows = cur.fetchall()
 
-    assert rows[0][0] == Decimal('0.142857')
+    assert rows[0][0] == Decimal("0.142857")
 
 
 def test_null_decimal(trino_connection):
@@ -246,7 +247,7 @@ def test_null_decimal(trino_connection):
 def test_biggest_decimal(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = Decimal('99999999999999999999999999999999999999')
+    params = Decimal("99999999999999999999999999999999999999")
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
 
@@ -256,7 +257,7 @@ def test_biggest_decimal(trino_connection):
 def test_smallest_decimal(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = Decimal('-99999999999999999999999999999999999999')
+    params = Decimal("-99999999999999999999999999999999999999")
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
 
@@ -266,7 +267,7 @@ def test_smallest_decimal(trino_connection):
 def test_highest_precision_decimal(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = Decimal('0.99999999999999999999999999999999999999')
+    params = Decimal("0.99999999999999999999999999999999999999")
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
 
@@ -288,7 +289,7 @@ def test_datetime_query_param(trino_connection):
 def test_datetime_with_utc_time_zone_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC'))
+    params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("UTC"))
 
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
@@ -314,7 +315,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
 def test_datetime_with_named_time_zone_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
+    params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone("America/Los_Angeles"))
 
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
@@ -347,14 +348,16 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection):
     cur.execute("SELECT TIMESTAMP '2001-08-22 03:04:05.321 -08:00'")
     rows = cur.fetchall()
 
-    assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
+    assert rows[0][0] == datetime.strptime(
+        "2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z"
+    )
 
 
 def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
     # This is a datetime that lies within a DST transition and not actually exists.
-    params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels'))
+    params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone("Europe/Brussels"))
     with pytest.raises(trino.exceptions.TrinoUserError):
         cur.execute("SELECT ?", params=(params,))
         cur.fetchall()
@@ -365,21 +368,21 @@ def test_doubled_datetimes(trino_connection):
     # See also https://github.com/trinodb/trino/issues/5781
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True)
+    params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True)
 
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
 
-    assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
+    assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern"))
 
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False)
+    params = pytz.timezone("US/Eastern").localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False)
 
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
 
-    assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
+    assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone("US/Eastern"))
 
 
 def test_date_query_param(trino_connection):
@@ -407,11 +410,11 @@ def test_unsupported_python_dates(trino_connection):
 
     # dates below python min (1-1-1) or above max date (9999-12-31) are not supported
     for unsupported_date in [
-        '-0001-01-01',
-        '0000-01-01',
-        '10000-01-01',
-        '-4999999-01-01',  # Trino min date
-        '5000000-12-31',  # Trino max date
+        "-0001-01-01",
+        "0000-01-01",
+        "10000-01-01",
+        "-4999999-01-01",  # Trino min date
+        "5000000-12-31",  # Trino max date
     ]:
         with pytest.raises(trino.exceptions.TrinoDataError):
             cur.execute(f"SELECT DATE '{unsupported_date}'")
@@ -422,26 +425,26 @@ def test_supported_special_dates_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
     for params in (
-            # min python date
-            date(1, 1, 1),
-            # before julian->gregorian switch
-            date(1500, 1, 1),
-            # During julian->gregorian switch
-            date(1752, 9, 4),
-            # before epoch
-            date(1952, 4, 3),
-            date(1970, 1, 1),
-            date(1970, 2, 3),
-            # summer on northern hemisphere (possible DST)
-            date(2017, 7, 1),
-            # winter on northern hemisphere (possible DST on southern hemisphere)
-            date(2017, 1, 1),
-            # winter on southern hemisphere (possible DST on northern hemisphere)
-            date(2017, 12, 31),
-            date(1983, 4, 1),
-            date(1983, 10, 1),
-            # max python date
-            date(9999, 12, 31),
+        # min python date
+        date(1, 1, 1),
+        # before julian->gregorian switch
+        date(1500, 1, 1),
+        # During julian->gregorian switch
+        date(1752, 9, 4),
+        # before epoch
+        date(1952, 4, 3),
+        date(1970, 1, 1),
+        date(1970, 2, 3),
+        # summer on northern hemisphere (possible DST)
+        date(2017, 7, 1),
+        # winter on northern hemisphere (possible DST on southern hemisphere)
+        date(2017, 1, 1),
+        # winter on southern hemisphere (possible DST on northern hemisphere)
+        date(2017, 12, 31),
+        date(1983, 4, 1),
+        date(1983, 10, 1),
+        # max python date
+        date(9999, 12, 31),
     ):
         cur.execute("SELECT ?", params=(params,))
         rows = cur.fetchall()
@@ -464,7 +467,7 @@ def test_time_with_named_time_zone_query_param(trino_connection):
     with pytest.raises(trino.exceptions.NotSupportedError):
         cur = trino_connection.cursor()
 
-        params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai'))
+        params = time(16, 43, 22, 320000, tzinfo=pytz.timezone("Asia/Shanghai"))
 
         cur.execute("SELECT ?", params=(params,))
 
@@ -599,7 +602,10 @@ def test_array_timestamp_query_param(trino_connection):
 def test_array_timestamp_with_timezone_query_param(trino_connection):
     cur = trino_connection.cursor(experimental_python_types=True)
 
-    params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
+    params = [
+        datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+        datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc),
+    ]
 
     cur.execute("SELECT ?", params=(params,))
     rows = cur.fetchall()
@@ -723,15 +729,18 @@ def test_int_query_param(trino_connection):
     assert cur.description[0][1] == "bigint"
 
 
-@pytest.mark.parametrize('params', [
-    'NOT A LIST OR TUPPLE',
-    {'invalid', 'params'},
-    object,
-])
+@pytest.mark.parametrize(
+    "params",
+    [
+        "NOT A LIST OR TUPPLE",
+        {"invalid", "params"},
+        object,
+    ],
+)
 def test_select_query_invalid_params(trino_connection, params):
     cur = trino_connection.cursor()
     with pytest.raises(AssertionError):
-        cur.execute('SELECT ?', params=params)
+        cur.execute("SELECT ?", params=params)
 
 
 def test_select_cursor_iteration(trino_connection):
@@ -877,8 +886,9 @@ def test_transaction_multiple(trino_connection_with_transaction):
     assert len(rows2) == 1000
 
 
-@pytest.mark.skipif(trino_version() == '351', reason="Autocommit behaves "
-                                                     "differently in older Trino versions")
+@pytest.mark.skipif(
+    trino_version() == "351", reason="Autocommit behaves " "differently in older Trino versions"
+)
 def test_transaction_autocommit(trino_connection_in_autocommit):
     with trino_connection_in_autocommit as connection:
         connection.start_transaction()
@@ -887,17 +897,18 @@ def test_transaction_autocommit(trino_connection_in_autocommit):
             """
             CREATE TABLE memory.default.nation
             AS SELECT * from tpch.tiny.nation
-            """)
+            """
+        )
 
         with pytest.raises(TrinoUserError) as transaction_error:
             cur.fetchall()
-        assert "Catalog only supports writes using autocommit: memory" \
-               in str(transaction_error.value)
+        assert "Catalog only supports writes using autocommit: memory" in str(
+            transaction_error.value
+        )
 
 
 def test_invalid_query_throws_correct_error(trino_connection):
-    """Tests that an invalid query raises the correct exception
-    """
+    """Tests that an invalid query raises the correct exception"""
     cur = trino_connection.cursor()
     with pytest.raises(TrinoQueryError):
         cur.execute(
@@ -910,14 +921,14 @@ def test_invalid_query_throws_correct_error(trino_connection):
 
 def test_eager_loading_cursor_description(trino_connection):
     description_expected = [
-        ('node_id', 'varchar', None, None, None, None, None),
-        ('http_uri', 'varchar', None, None, None, None, None),
-        ('node_version', 'varchar', None, None, None, None, None),
-        ('coordinator', 'boolean', None, None, None, None, None),
-        ('state', 'varchar', None, None, None, None, None),
+        ("node_id", "varchar", None, None, None, None, None),
+        ("http_uri", "varchar", None, None, None, None, None),
+        ("node_version", "varchar", None, None, None, None, None),
+        ("coordinator", "boolean", None, None, None, None, None),
+        ("state", "varchar", None, None, None, None, None),
     ]
     cur = trino_connection.cursor()
-    cur.execute('SELECT * FROM system.runtime.nodes')
+    cur.execute("SELECT * FROM system.runtime.nodes")
     description_before = cur.description
 
     assert description_before is not None
@@ -934,7 +945,7 @@ def test_eager_loading_cursor_description(trino_connection):
 def test_info_uri(trino_connection):
     cur = trino_connection.cursor()
     assert cur.info_uri is None
-    cur.execute('SELECT * FROM system.runtime.nodes')
+    cur.execute("SELECT * FROM system.runtime.nodes")
     assert cur.info_uri is not None
     assert cur._query.query_id in cur.info_uri
     cur.fetchall()
@@ -971,44 +982,51 @@ def retrieve_client_tags_from_query(run_trino, client_tags):
     )
 
     cur = trino_connection.cursor()
-    cur.execute('SELECT 1')
+    cur.execute("SELECT 1")
     cur.fetchall()
 
     api_url = "http://" + trino_connection.host + ":" + str(trino_connection.port)
-    query_info = requests.post(api_url + "/ui/login", data={
-        "username": "admin",
-        "password": "",
-        "redirectPath": api_url + '/ui/api/query/' + cur._query.query_id
-    }).json()
-
-    query_client_tags = query_info['session']['clientTags']
+    query_info = requests.post(
+        api_url + "/ui/login",
+        data={
+            "username": "admin",
+            "password": "",
+            "redirectPath": api_url + "/ui/api/query/" + cur._query.query_id,
+        },
+    ).json()
+
+    query_client_tags = query_info["session"]["clientTags"]
     return query_client_tags
 
 
-@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions")
+@pytest.mark.skipif(
+    trino_version() == "351", reason="current_catalog not supported in older Trino versions"
+)
 def test_use_catalog_schema(trino_connection):
     cur = trino_connection.cursor()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
     assert result[0][0] is None
     assert result[0][1] is None
 
-    cur.execute('USE tpch.tiny')
+    cur.execute("USE tpch.tiny")
     cur.fetchall()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
-    assert result[0][0] == 'tpch'
-    assert result[0][1] == 'tiny'
+    assert result[0][0] == "tpch"
+    assert result[0][1] == "tiny"
 
-    cur.execute('USE tpcds.sf1')
+    cur.execute("USE tpcds.sf1")
     cur.fetchall()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
-    assert result[0][0] == 'tpcds'
-    assert result[0][1] == 'sf1'
+    assert result[0][0] == "tpcds"
+    assert result[0][1] == "sf1"
 
 
-@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions")
+@pytest.mark.skipif(
+    trino_version() == "351", reason="current_catalog not supported in older Trino versions"
+)
 def test_use_catalog(run_trino):
     _, host, port = run_trino
 
@@ -1016,35 +1034,33 @@ def test_use_catalog(run_trino):
         host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1
     )
     cur = trino_connection.cursor()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
-    assert result[0][0] == 'tpch'
+    assert result[0][0] == "tpch"
     assert result[0][1] is None
 
-    cur.execute('USE tiny')
+    cur.execute("USE tiny")
     cur.fetchall()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
-    assert result[0][0] == 'tpch'
-    assert result[0][1] == 'tiny'
+    assert result[0][0] == "tpch"
+    assert result[0][1] == "tiny"
 
-    cur.execute('USE sf1')
+    cur.execute("USE sf1")
     cur.fetchall()
-    cur.execute('SELECT current_catalog, current_schema')
+    cur.execute("SELECT current_catalog, current_schema")
     result = cur.fetchall()
-    assert result[0][0] == 'tpch'
-    assert result[0][1] == 'sf1'
+    assert result[0][0] == "tpch"
+    assert result[0][1] == "sf1"
 
 
-@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
+@pytest.mark.skipif(trino_version() == "351", reason="Newer Trino versions return the system role")
 def test_set_role_trino_higher_351(run_trino):
     _, host, port = run_trino
 
-    trino_connection = trino.dbapi.Connection(
-        host=host, port=port, user="test", catalog="tpch"
-    )
+    trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch")
     cur = trino_connection.cursor()
-    cur.execute('SHOW TABLES FROM information_schema')
+    cur.execute("SHOW TABLES FROM information_schema")
     cur.fetchall()
     assert cur._request._client_session.role is None
 
@@ -1053,15 +1069,15 @@ def test_set_role_trino_higher_351(run_trino):
     assert cur._request._client_session.role == "system=ALL"
 
 
-@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog")
+@pytest.mark.skipif(
+    trino_version() != "351", reason="Trino 351 returns the role for the current catalog"
+)
 def test_set_role_trino_351(run_trino):
     _, host, port = run_trino
 
-    trino_connection = trino.dbapi.Connection(
-        host=host, port=port, user="test", catalog="tpch"
-    )
+    trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch")
     cur = trino_connection.cursor()
-    cur.execute('SHOW TABLES FROM information_schema')
+    cur.execute("SHOW TABLES FROM information_schema")
     cur.fetchall()
     assert cur._request._client_session.role is None
 
diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py
index 1dc8f05a..cdac58ab 100644
--- a/tests/integration/test_sqlalchemy_integration.py
+++ b/tests/integration/test_sqlalchemy_integration.py
@@ -11,22 +11,24 @@
 # limitations under the License
 import pytest
 import sqlalchemy as sqla
-from sqlalchemy.sql import and_, or_, not_
+from sqlalchemy.sql import and_, not_, or_
 
 
 @pytest.fixture
 def trino_connection(run_trino, request):
     _, host, port = run_trino
-    engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}",
-                                connect_args={"source": "test", "max_attempts": 1})
+    engine = sqla.create_engine(
+        f"trino://test@{host}:{port}/{request.param}",
+        connect_args={"source": "test", "max_attempts": 1},
+    )
     yield engine, engine.connect()
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_select_query(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
+    nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
     assert_column(nations, "nationkey", sqla.sql.sqltypes.BigInteger)
     assert_column(nations, "name", sqla.sql.sqltypes.String)
     assert_column(nations, "regionkey", sqla.sql.sqltypes.BigInteger)
@@ -36,10 +38,10 @@ def test_select_query(trino_connection):
     rows = result.fetchall()
     assert len(rows) == 25
     for row in rows:
-        assert isinstance(row['nationkey'], int)
-        assert isinstance(row['name'], str)
-        assert isinstance(row['regionkey'], int)
-        assert isinstance(row['comment'], str)
+        assert isinstance(row["nationkey"], int)
+        assert isinstance(row["name"], str)
+        assert isinstance(row["regionkey"], int)
+        assert isinstance(row["comment"], str)
 
 
 def assert_column(table, column_name, column_type):
@@ -47,11 +49,11 @@ def assert_column(table, column_name, column_type):
     assert isinstance(getattr(table.c, column_name).type, column_type)
 
 
-@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["system"], indirect=True)
 def test_select_specific_columns(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nodes = sqla.Table('nodes', metadata, schema='runtime', autoload_with=conn)
+    nodes = sqla.Table("nodes", metadata, schema="runtime", autoload_with=conn)
     assert_column(nodes, "node_id", sqla.sql.sqltypes.String)
     assert_column(nodes, "state", sqla.sql.sqltypes.String)
     query = sqla.select(nodes.c.node_id, nodes.c.state)
@@ -59,26 +61,28 @@ def test_select_specific_columns(trino_connection):
     rows = result.fetchall()
     assert len(rows) > 0
     for row in rows:
-        assert isinstance(row['node_id'], str)
-        assert isinstance(row['state'], str)
+        assert isinstance(row["node_id"], str)
+        assert isinstance(row["state"], str)
 
 
-@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True)
 def test_define_and_create_table(trino_connection):
     engine, conn = trino_connection
     if not engine.dialect.has_schema(engine, "test"):
         engine.execute(sqla.schema.CreateSchema("test"))
     metadata = sqla.MetaData()
     try:
-        sqla.Table('users',
-                   metadata,
-                   sqla.Column('id', sqla.Integer),
-                   sqla.Column('name', sqla.String),
-                   sqla.Column('fullname', sqla.String),
-                   schema="test")
+        sqla.Table(
+            "users",
+            metadata,
+            sqla.Column("id", sqla.Integer),
+            sqla.Column("name", sqla.String),
+            sqla.Column("fullname", sqla.String),
+            schema="test",
+        )
         metadata.create_all(engine)
-        assert sqla.inspect(engine).has_table('users', schema="test")
-        users = sqla.Table('users', metadata, schema='test', autoload_with=conn)
+        assert sqla.inspect(engine).has_table("users", schema="test")
+        users = sqla.Table("users", metadata, schema="test", autoload_with=conn)
         assert_column(users, "id", sqla.sql.sqltypes.Integer)
         assert_column(users, "name", sqla.sql.sqltypes.String)
         assert_column(users, "fullname", sqla.sql.sqltypes.String)
@@ -86,7 +90,7 @@ def test_define_and_create_table(trino_connection):
         metadata.drop_all(engine)
 
 
-@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True)
 def test_insert(trino_connection):
     engine, conn = trino_connection
 
@@ -94,12 +98,14 @@ def test_insert(trino_connection):
         engine.execute(sqla.schema.CreateSchema("test"))
     metadata = sqla.MetaData()
     try:
-        users = sqla.Table('users',
-                           metadata,
-                           sqla.Column('id', sqla.Integer),
-                           sqla.Column('name', sqla.String),
-                           sqla.Column('fullname', sqla.String),
-                           schema="test")
+        users = sqla.Table(
+            "users",
+            metadata,
+            sqla.Column("id", sqla.Integer),
+            sqla.Column("name", sqla.String),
+            sqla.Column("fullname", sqla.String),
+            schema="test",
+        )
         metadata.create_all(engine)
         ins = users.insert()
         conn.execute(ins, {"id": 2, "name": "wendy", "fullname": "Wendy Williams"})
@@ -112,72 +118,79 @@ def test_insert(trino_connection):
         metadata.drop_all(engine)
 
 
-@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True)
 def test_insert_multiple_statements(trino_connection):
     engine, conn = trino_connection
     if not engine.dialect.has_schema(engine, "test"):
         engine.execute(sqla.schema.CreateSchema("test"))
     metadata = sqla.MetaData()
-    users = sqla.Table('users',
-                       metadata,
-                       sqla.Column('id', sqla.Integer),
-                       sqla.Column('name', sqla.String),
-                       sqla.Column('fullname', sqla.String),
-                       schema="test")
+    users = sqla.Table(
+        "users",
+        metadata,
+        sqla.Column("id", sqla.Integer),
+        sqla.Column("name", sqla.String),
+        sqla.Column("fullname", sqla.String),
+        schema="test",
+    )
     metadata.create_all(engine)
     ins = users.insert()
-    conn.execute(ins, [
-        {"id": 2, "name": "wendy", "fullname": "Wendy Williams"},
-        {"id": 3, "name": "john", "fullname": "John Doe"},
-        {"id": 4, "name": "mary", "fullname": "Mary Hopkins"},
-    ])
+    conn.execute(
+        ins,
+        [
+            {"id": 2, "name": "wendy", "fullname": "Wendy Williams"},
+            {"id": 3, "name": "john", "fullname": "John Doe"},
+            {"id": 4, "name": "mary", "fullname": "Mary Hopkins"},
+        ],
+    )
     query = sqla.select(users)
     result = conn.execute(query)
     rows = result.fetchall()
     assert len(rows) == 3
-    assert frozenset(rows) == frozenset([
-        (2, "wendy", "Wendy Williams"),
-        (3, "john", "John Doe"),
-        (4, "mary", "Mary Hopkins"),
-    ])
+    assert frozenset(rows) == frozenset(
+        [
+            (2, "wendy", "Wendy Williams"),
+            (3, "john", "John Doe"),
+            (4, "mary", "Mary Hopkins"),
+        ]
+    )
     metadata.drop_all(engine)
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_operators(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    customers = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
+    customers = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
     query = sqla.select(customers).where(customers.c.nationkey == 2)
     result = conn.execute(query)
     rows = result.fetchall()
     assert len(rows) == 1
     for row in rows:
-        assert isinstance(row['nationkey'], int)
-        assert isinstance(row['name'], str)
-        assert isinstance(row['regionkey'], int)
-        assert isinstance(row['comment'], str)
+        assert isinstance(row["nationkey"], int)
+        assert isinstance(row["name"], str)
+        assert isinstance(row["regionkey"], int)
+        assert isinstance(row["comment"], str)
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_conjunctions(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
-    query = sqla.select(customers).where(and_(
-        customers.c.name.like('%12%'),
-        customers.c.nationkey == 15,
-        or_(
-            customers.c.mktsegment == 'AUTOMOBILE',
-            customers.c.mktsegment == 'HOUSEHOLD'
-        ),
-        not_(customers.c.acctbal < 0)))
+    customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn)
+    query = sqla.select(customers).where(
+        and_(
+            customers.c.name.like("%12%"),
+            customers.c.nationkey == 15,
+            or_(customers.c.mktsegment == "AUTOMOBILE", customers.c.mktsegment == "HOUSEHOLD"),
+            not_(customers.c.acctbal < 0),
+        )
+    )
     result = conn.execute(query)
     rows = result.fetchall()
     assert len(rows) == 1
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_textual_sql(trino_connection):
     _, conn = trino_connection
     s = sqla.text("SELECT * from tiny.customer where nationkey = :e1 AND acctbal < :e2")
@@ -185,70 +198,79 @@ def test_textual_sql(trino_connection):
     rows = result.fetchall()
     assert len(rows) == 3
     for row in rows:
-        assert isinstance(row['custkey'], int)
-        assert isinstance(row['name'], str)
-        assert isinstance(row['address'], str)
-        assert isinstance(row['nationkey'], int)
-        assert isinstance(row['phone'], str)
-        assert isinstance(row['acctbal'], float)
-        assert isinstance(row['mktsegment'], str)
-        assert isinstance(row['comment'], str)
+        assert isinstance(row["custkey"], int)
+        assert isinstance(row["name"], str)
+        assert isinstance(row["address"], str)
+        assert isinstance(row["nationkey"], int)
+        assert isinstance(row["phone"], str)
+        assert isinstance(row["acctbal"], float)
+        assert isinstance(row["mktsegment"], str)
+        assert isinstance(row["comment"], str)
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_alias(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
+    nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
     nations1 = nations.alias("o1")
     nations2 = nations.alias("o2")
-    s = sqla.select(nations1) \
-        .join(nations2, and_(
-            nations1.c.regionkey == nations2.c.regionkey,
-            nations1.c.nationkey != nations2.c.nationkey,
-            nations1.c.regionkey == 1
-        )) \
+    s = (
+        sqla.select(nations1)
+        .join(
+            nations2,
+            and_(
+                nations1.c.regionkey == nations2.c.regionkey,
+                nations1.c.nationkey != nations2.c.nationkey,
+                nations1.c.regionkey == 1,
+            ),
+        )
         .distinct()
+    )
     result = conn.execute(s)
     rows = result.fetchall()
     assert len(rows) == 5
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_subquery(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
-    customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
+    nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
+    customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn)
     automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900)
     automobile_customers_subquery = automobile_customers.subquery()
-    s = sqla.select(nations.c.name).where(nations.c.nationkey.in_(sqla.select(automobile_customers_subquery)))
+    s = sqla.select(nations.c.name).where(
+        nations.c.nationkey.in_(sqla.select(automobile_customers_subquery))
+    )
     result = conn.execute(s)
     rows = result.fetchall()
     assert len(rows) == 15
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_joins(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
-    customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
-    s = sqla.select(nations.c.name) \
-        .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) \
-        .where(customers.c.acctbal < -900) \
+    nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
+    customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn)
+    s = (
+        sqla.select(nations.c.name)
+        .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey))
+        .where(customers.c.acctbal < -900)
         .distinct()
+    )
     result = conn.execute(s)
     rows = result.fetchall()
     assert len(rows) == 15
 
 
-@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
+@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True)
 def test_cte(trino_connection):
     _, conn = trino_connection
     metadata = sqla.MetaData()
-    nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn)
-    customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn)
+    nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn)
+    customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn)
     automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900)
     automobile_customers_cte = automobile_customers.cte()
     s = sqla.select(nations).where(nations.c.nationkey.in_(sqla.select(automobile_customers_cte)))
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 9d968e07..59e312d4 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -10,9 +10,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytest
 from unittest.mock import MagicMock, patch
 
+import pytest
+
 
 @pytest.fixture(scope="session")
 def sample_post_response_data():
@@ -194,8 +195,7 @@ def sample_get_error_response_data():
             "errorType": "USER_ERROR",
             "failureInfo": {
                 "errorLocation": {"columnNumber": 15, "lineNumber": 1},
-                "message": "line 1:15: Schema must be specified "
-                "when session schema is not set",
+                "message": "line 1:15: Schema must be specified " "when session schema is not set",
                 "stack": [
                     "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)",
                     "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)",
diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py
index 77b891ce..1467a489 100644
--- a/tests/unit/oauth_test_utils.py
+++ b/tests/unit/oauth_test_utils.py
@@ -45,9 +45,15 @@ def __call__(self, request, uri, response_headers):
         authorization = request.headers.get("Authorization")
         if authorization and authorization.replace("Bearer ", "") in self.tokens:
             return [200, response_headers, json.dumps(self.sample_post_response_data)]
-        return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
-                                          f'x_token_server="{self.token_server}"',
-                      'Basic realm': '"Trino"'}, ""]
+        return [
+            401,
+            {
+                "Www-Authenticate": f'Bearer x_redirect_server="{self.redirect_server}", '
+                f'x_token_server="{self.token_server}"',
+                "Basic realm": '"Trino"',
+            },
+            "",
+        ]
 
 
 class GetTokenCallback:
@@ -66,19 +72,25 @@ def __call__(self, request, uri, response_headers):
 
 
 def _get_token_requests(challenge_id):
-    return list(filter(
-        lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
-        httpretty.latest_requests()))
+    return list(
+        filter(
+            lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
+            httpretty.latest_requests(),
+        )
+    )
 
 
 def _post_statement_requests():
-    return list(filter(
-        lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
-        httpretty.latest_requests()))
+    return list(
+        filter(
+            lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
+            httpretty.latest_requests(),
+        )
+    )
 
 
 class MultithreadedTokenServer:
-    Challenge = namedtuple('Challenge', ['token', 'attempts'])
+    Challenge = namedtuple("Challenge", ["token", "attempts"])
 
     def __init__(self, sample_post_response_data, attempts=1):
         self.tokens = set()
@@ -90,13 +102,15 @@ def __init__(self, sample_post_response_data, attempts=1):
         httpretty.register_uri(
             method=httpretty.POST,
             uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-            body=self.post_statement_callback)
+            body=self.post_statement_callback,
+        )
 
         # bind get token
         httpretty.register_uri(
             method=httpretty.GET,
             uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
-            body=self.get_token_callback)
+            body=self.get_token_callback,
+        )
 
     # noinspection PyUnusedLocal
     def post_statement_callback(self, request, uri, response_headers):
@@ -111,9 +125,15 @@ def post_statement_callback(self, request, uri, response_headers):
         self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
         redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
         token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
-        return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
-                                          f'x_token_server="{token_server}"',
-                      'Basic realm': '"Trino"'}, ""]
+        return [
+            401,
+            {
+                "Www-Authenticate": f'Bearer x_redirect_server="{redirect_server}", '
+                f'x_token_server="{token_server}"',
+                "Basic realm": '"Trino"',
+            },
+            "",
+        ]
 
     # noinspection PyUnusedLocal
     def get_token_callback(self, request, uri, response_headers):
diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py
index e80f19b8..71d6f74d 100644
--- a/tests/unit/sqlalchemy/conftest.py
+++ b/tests/unit/sqlalchemy/conftest.py
@@ -12,7 +12,7 @@
 import pytest
 from sqlalchemy.sql.sqltypes import ARRAY
 
-from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME
+from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP, SQLType
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py
index 00bd686c..6139affc 100644
--- a/tests/unit/sqlalchemy/test_compiler.py
+++ b/tests/unit/sqlalchemy/test_compiler.py
@@ -10,32 +10,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-from sqlalchemy import (
-    Column,
-    insert,
-    Integer,
-    MetaData,
-    select,
-    String,
-    Table,
-)
+from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select
 from sqlalchemy.schema import CreateTable
 
 from trino.sqlalchemy.dialect import TrinoDialect
 
 metadata = MetaData()
 table = Table(
-    'table',
+    "table",
     metadata,
-    Column('id', Integer),
-    Column('name', String),
+    Column("id", Integer),
+    Column("name", String),
 )
 table_with_catalog = Table(
-    'table',
-    metadata,
-    Column('id', Integer),
-    schema='default',
-    trino_catalog='other'
+    "table", metadata, Column("id", Integer), schema="default", trino_catalog="other"
 )
 
 
@@ -47,7 +35,10 @@ def dialect():
 def test_limit_offset(dialect):
     statement = select(table).limit(10).offset(0)
     query = statement.compile(dialect=dialect)
-    assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2'
+    assert (
+        str(query)
+        == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2'
+    )
 
 
 def test_limit(dialect):
@@ -63,15 +54,16 @@ def test_offset(dialect):
 
 
 def test_cte_insert_order(dialect):
-    cte = select(table).cte('cte')
+    cte = select(table).cte("cte")
     statement = insert(table).from_select(table.columns, cte)
     query = statement.compile(dialect=dialect)
-    assert str(query) == \
-        'INSERT INTO "table" (id, name) WITH cte AS \n'\
-        '(SELECT "table".id AS id, "table".name AS name \n'\
-        'FROM "table")\n'\
-        ' SELECT cte.id, cte.name \n'\
-        'FROM cte'
+    assert (
+        str(query) == 'INSERT INTO "table" (id, name) WITH cte AS \n'
+        '(SELECT "table".id AS id, "table".name AS name \n'
+        'FROM "table")\n'
+        " SELECT cte.id, cte.name \n"
+        "FROM cte"
+    )
 
 
 def test_catalogs_argument(dialect):
@@ -83,9 +75,6 @@ def test_catalogs_argument(dialect):
 def test_catalogs_create_table(dialect):
     statement = CreateTable(table_with_catalog)
     query = statement.compile(dialect=dialect)
-    assert str(query) == \
-        '\n'\
-        'CREATE TABLE "other".default."table" (\n'\
-        '\tid INTEGER\n'\
-        ')\n'\
-        '\n'
+    assert (
+        str(query) == "\n" 'CREATE TABLE "other".default."table" (\n' "\tid INTEGER\n" ")\n" "\n"
+    )
diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py
index 66a2f6b0..90e007b9 100644
--- a/tests/unit/sqlalchemy/test_datatype_parse.py
+++ b/tests/unit/sqlalchemy/test_datatype_parse.py
@@ -10,23 +10,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-from sqlalchemy.sql.sqltypes import (
-    CHAR,
-    VARCHAR,
-    ARRAY,
-    INTEGER,
-    DECIMAL,
-    DATE
-)
+from sqlalchemy.sql.sqltypes import ARRAY, CHAR, DATE, DECIMAL, INTEGER, VARCHAR
 from sqlalchemy.sql.type_api import TypeEngine
 
 from trino.sqlalchemy import datatype
-from trino.sqlalchemy.datatype import (
-    MAP,
-    ROW,
-    TIME,
-    TIMESTAMP
-)
+from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP
 
 
 @pytest.mark.parametrize(
@@ -68,7 +56,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
     "CHAR(10)": CHAR(10),
     "VARCHAR(10)": VARCHAR(10),
     "DECIMAL(20)": DECIMAL(20),
-    "DECIMAL(20, 3)": DECIMAL(20, 3)
+    "DECIMAL(20, 3)": DECIMAL(20, 3),
 }
 
 
@@ -108,7 +96,9 @@ def test_parse_array(type_str: str, sql_type: ARRAY, assert_sqltype):
     "map(varchar(10), decimal(20,3))": MAP(VARCHAR(10), DECIMAL(20, 3)),
     "map(char, array(varchar(10)))": MAP(CHAR(), ARRAY(VARCHAR(10))),
     "map(varchar(10), array(varchar(10)))": MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
-    "map(varchar(10), array(array(varchar(10))))": MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
+    "map(varchar(10), array(array(varchar(10))))": MAP(
+        VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)
+    ),
 }
 
 
@@ -187,7 +177,7 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
     "timestamp(3)": TIMESTAMP(3, timezone=False),
     "timestamp(6)": TIMESTAMP(6),
     "timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
-    "timestamp with time zone": TIMESTAMP(timezone=True)
+    "timestamp with time zone": TIMESTAMP(timezone=True),
 }
 
 
diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py
index b17f8cfe..4cff4c27 100644
--- a/tests/unit/sqlalchemy/test_dialect.py
+++ b/tests/unit/sqlalchemy/test_dialect.py
@@ -7,7 +7,11 @@
 
 from trino.auth import BasicAuthentication
 from trino.dbapi import Connection
-from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect
+from trino.sqlalchemy.dialect import (
+    CertificateAuthentication,
+    JWTAuthentication,
+    TrinoDialect,
+)
 from trino.transaction import IsolationLevel
 
 
@@ -26,7 +30,13 @@ def setup(self):
             (
                 make_url("trino://user@localhost:8080"),
                 list(),
-                dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
+                dict(
+                    host="localhost",
+                    port=8080,
+                    catalog="system",
+                    user="user",
+                    source="trino-sqlalchemy",
+                ),
             ),
             (
                 make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
@@ -38,17 +48,18 @@ def setup(self):
                     user="user",
                     auth=BasicAuthentication("user", "pass"),
                     http_scheme="https",
-                    source="trino-rulez"
+                    source="trino-rulez",
                 ),
             ),
             (
                 make_url(
-                    'trino://user@localhost:8080?'
+                    "trino://user@localhost:8080?"
                     'session_properties={"query_max_run_time": "1d"}'
                     '&http_headers={"trino": 1}'
                     '&extra_credential=[("a", "b"), ("c", "d")]'
                     '&client_tags=[1, "sql"]'
-                    '&experimental_python_types=true'),
+                    "&experimental_python_types=true"
+                ),
                 list(),
                 dict(
                     host="localhost",
@@ -65,7 +76,9 @@ def setup(self):
             ),
         ],
     )
-    def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
+    def test_create_connect_args(
+        self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]
+    ):
         actual_args, actual_kwargs = self.dialect.create_connect_args(url)
 
         assert actual_args == expected_args
@@ -73,7 +86,9 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_
 
     def test_create_connect_args_missing_user_when_specify_password(self):
         url = make_url("trino://:pass@localhost")
-        with pytest.raises(ValueError, match="Username is required when specify password in connection URL"):
+        with pytest.raises(
+            ValueError, match="Username is required when specify password in connection URL"
+        ):
             self.dialect.create_connect_args(url)
 
     def test_create_connect_args_wrong_db_format(self):
@@ -97,36 +112,36 @@ def test_isolation_level(self):
 
 def test_trino_connection_basic_auth():
     dialect = TrinoDialect()
-    username = 'trino-user'
-    password = 'trino-bunny'
-    url = make_url(f'trino://{username}:{password}@host')
+    username = "trino-user"
+    password = "trino-bunny"
+    url = make_url(f"trino://{username}:{password}@host")
     _, cparams = dialect.create_connect_args(url)
 
-    assert cparams['http_scheme'] == "https"
-    assert isinstance(cparams['auth'], BasicAuthentication)
-    assert cparams['auth']._username == username
-    assert cparams['auth']._password == password
+    assert cparams["http_scheme"] == "https"
+    assert isinstance(cparams["auth"], BasicAuthentication)
+    assert cparams["auth"]._username == username
+    assert cparams["auth"]._password == password
 
 
 def test_trino_connection_jwt_auth():
     dialect = TrinoDialect()
-    access_token = 'sample-token'
-    url = make_url(f'trino://host/?access_token={access_token}')
+    access_token = "sample-token"
+    url = make_url(f"trino://host/?access_token={access_token}")
     _, cparams = dialect.create_connect_args(url)
 
-    assert cparams['http_scheme'] == "https"
-    assert isinstance(cparams['auth'], JWTAuthentication)
-    assert cparams['auth'].token == access_token
+    assert cparams["http_scheme"] == "https"
+    assert isinstance(cparams["auth"], JWTAuthentication)
+    assert cparams["auth"].token == access_token
 
 
 def test_trino_connection_certificate_auth():
     dialect = TrinoDialect()
-    cert = '/path/to/cert.pem'
-    key = '/path/to/key.pem'
-    url = make_url(f'trino://host/?cert={cert}&key={key}')
+    cert = "/path/to/cert.pem"
+    key = "/path/to/key.pem"
+    url = make_url(f"trino://host/?cert={cert}&key={key}")
     _, cparams = dialect.create_connect_args(url)
 
-    assert cparams['http_scheme'] == "https"
-    assert isinstance(cparams['auth'], CertificateAuthentication)
-    assert cparams['auth']._cert == cert
-    assert cparams['auth']._key == key
+    assert cparams["http_scheme"] == "https"
+    assert isinstance(cparams["auth"], CertificateAuthentication)
+    assert cparams["auth"]._cert == cert
+    assert cparams["auth"]._key == key
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 12089305..ba30349b 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -23,13 +23,28 @@
 from requests_kerberos.exceptions import KerberosExchangeError
 
 import trino.exceptions
-from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \
-    MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \
-    SERVER_ADDRESS
+from tests.unit.oauth_test_utils import (
+    REDIRECT_RESOURCE,
+    SERVER_ADDRESS,
+    TOKEN_RESOURCE,
+    GetTokenCallback,
+    MultithreadedTokenServer,
+    PostStatementCallback,
+    RedirectHandler,
+    _get_token_requests,
+    _post_statement_requests,
+)
 from trino import constants
 from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
-from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \
-    _RetryWithExponentialBackoff
+from trino.client import (
+    ClientSession,
+    TrinoQuery,
+    TrinoRequest,
+    TrinoResult,
+    _DelayExponential,
+    _retry_with,
+    _RetryWithExponentialBackoff,
+)
 
 
 @mock.patch("trino.client.TrinoRequest.http")
@@ -80,7 +95,7 @@ def test_request_headers(mock_get_and_post):
             headers={
                 accept_encoding_header: accept_encoding_value,
                 client_info_header: client_info_value,
-            }
+            },
         ),
         http_scheme="http",
         redirect_handler=None,
@@ -113,13 +128,8 @@ def test_request_session_properties_headers(mock_get_and_post):
         host="coordinator",
         port=8080,
         client_session=ClientSession(
-            user="test_user",
-            properties={
-                "a": "1",
-                "b": "2",
-                "c": "more=v1,v2"
-            }
-        )
+            user="test_user", properties={"a": "1", "b": "2", "c": "more=v1,v2"}
+        ),
     )
 
     def assert_headers(headers):
@@ -154,10 +164,10 @@ def test_additional_request_post_headers(mock_get_and_post):
         http_scheme="http",
     )
 
-    sql = 'select 1'
+    sql = "select 1"
     additional_headers = {
-        'X-Trino-Fake-1': 'one',
-        'X-Trino-Fake-2': 'two',
+        "X-Trino-Fake-1": "one",
+        "X-Trino-Fake-2": "two",
     }
 
     combined_headers = req.http_headers
@@ -167,7 +177,7 @@ def test_additional_request_post_headers(mock_get_and_post):
 
     # Validate that the post call was performed including the addtional headers
     _, post_kwargs = post.call_args
-    assert post_kwargs['headers'] == combined_headers
+    assert post_kwargs["headers"] == combined_headers
 
 
 def test_request_invalid_http_headers():
@@ -189,10 +199,7 @@ def test_request_client_tags_headers(mock_get_and_post):
     req = TrinoRequest(
         host="coordinator",
         port=8080,
-        client_session=ClientSession(
-            user="test_user",
-            client_tags=["tag1", "tag2"]
-        ),
+        client_session=ClientSession(user="test_user", client_tags=["tag1", "tag2"]),
     )
 
     def assert_headers(headers):
@@ -215,7 +222,7 @@ def test_request_client_tags_headers_no_client_tags(mock_get_and_post):
         port=8080,
         client_session=ClientSession(
             user="test_user",
-        )
+        ),
     )
 
     def assert_headers(headers):
@@ -335,20 +342,20 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token, attempts)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
@@ -359,10 +366,11 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+    )
     response = request.post("select 1")
 
-    assert response.request.headers['Authorization'] == f"Bearer {token}"
+    assert response.request.headers["Authorization"] == f"Bearer {token}"
     assert redirect_handler.redirect_server == redirect_server
     assert get_token_callback.attempts == 0
     assert len(_post_statement_requests()) == 2
@@ -378,20 +386,22 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token, attempts)
     httpretty.register_uri(
-        method=httpretty.GET,
-        uri=f"{TOKEN_RESOURCE}/{challenge_id}",
-        body=get_token_callback)
+        method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", body=get_token_callback
+    )
 
     redirect_handler = RedirectHandler()
 
@@ -402,7 +412,8 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+    )
     with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
         request.post("select 1")
 
@@ -413,21 +424,34 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
     assert len(_get_token_requests(challenge_id)) == _OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS
 
 
-@pytest.mark.parametrize("header,error", [
-    ("", "Error: header WWW-Authenticate not available in the response."),
-    ('Bearer"', 'Error: header info didn\'t have x_redirect_server'),
-    ('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'),  # noqa: E501
-    ('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
-    ('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'),
-])
+@pytest.mark.parametrize(
+    "header,error",
+    [
+        ("", "Error: header WWW-Authenticate not available in the response."),
+        ('Bearer"', "Error: header info didn't have x_redirect_server"),
+        (
+            'x_redirect_server="redirect_server", x_token_server="token_server"',
+            'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"',
+        ),  # noqa: E501
+        (
+            'Bearer x_redirect_server="redirect_server"',
+            "Error: header info didn't have x_token_server",
+        ),
+        (
+            'Bearer x_token_server="token_server"',
+            "Error: header info didn't have x_redirect_server",
+        ),
+    ],
+)
 @httprettified
 def test_oauth2_authentication_missing_headers(header, error):
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        adding_headers={'WWW-Authenticate': header},
-        status=401)
+        adding_headers={"WWW-Authenticate": header},
+        status=401,
+    )
 
     request = TrinoRequest(
         host="coordinator",
@@ -436,7 +460,8 @@ def test_oauth2_authentication_missing_headers(header, error):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()),
+    )
 
     with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
         request.post("select 1")
@@ -444,13 +469,16 @@ def test_oauth2_authentication_missing_headers(header, error):
     assert str(exp.value) == error
 
 
-@pytest.mark.parametrize("header", [
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge',
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"',
-    'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"',
-    'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"',
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
-])
+@pytest.mark.parametrize(
+    "header",
+    [
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge',
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"',
+        'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"',
+        'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"',
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
+    ],
+)
 @httprettified
 def test_oauth2_header_parsing(header, sample_post_response_data):
     token = str(uuid.uuid4())
@@ -464,21 +492,27 @@ def post_statement(request, uri, response_headers):
         authorization = request.headers.get("Authorization")
         if authorization and authorization.replace("Bearer ", "") in token:
             return [200, response_headers, json.dumps(sample_post_response_data)]
-        return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
-                      'Basic realm': '"Trino"'}, ""]
+        return [
+            401,
+            {
+                "Www-Authenticate": header.format(
+                    redirect_server=redirect_server, token_server=token_server
+                ),
+                "Basic realm": '"Trino"',
+            },
+            "",
+        ]
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement)
+        body=post_statement,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
@@ -489,10 +523,10 @@ def post_statement(request, uri, response_headers):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
     ).post("select 1")
 
-    assert response.request.headers['Authorization'] == f"Bearer {token}"
+    assert response.request.headers["Authorization"] == f"Bearer {token}"
     assert redirect_handler.redirect_server == redirect_server
     assert get_token_callback.attempts == 0
     assert len(_post_statement_requests()) == 2
@@ -508,19 +542,23 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     httpretty.register_uri(
         method=httpretty.GET,
         uri=f"{TOKEN_RESOURCE}/{challenge_id}",
         status=http_status,
-        body="error")
+        body="error",
+    )
 
     redirect_handler = RedirectHandler()
 
@@ -531,13 +569,17 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+    )
 
     with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
         request.post("select 1")
 
     assert redirect_handler.redirect_server == redirect_server
-    assert str(exp.value) == f"Error while getting the token response status code: {http_status}, body: error"
+    assert (
+        str(exp.value)
+        == f"Error while getting the token response status code: {http_status}, body: error"
+    )
     assert len(_post_statement_requests()) == 1
     assert len(_get_token_requests(challenge_id)) == 1
 
@@ -564,7 +606,8 @@ def run(self) -> None:
                     user="test",
                 ),
                 http_scheme=constants.HTTPS,
-                auth=auth)
+                auth=auth,
+            )
             for i in range(10):
                 # apparently HTTPretty in the current version is not thread-safe
                 # https://github.com/gabrielfalcao/HTTPretty/issues/209
@@ -650,10 +693,7 @@ def test_trino_fetch_error(mock_requests, sample_get_error_response_data):
     assert "stack" in error.failure_info
     assert len(error.failure_info["stack"]) == 36
     assert "suppressed" in error.failure_info
-    assert (
-        error.message
-        == "line 1:15: Schema must be specified when session schema is not set"
-    )
+    assert error.message == "line 1:15: Schema must be specified when session schema is not set"
     assert error.error_location == (1, 15)
     assert error.query_id == "20210817_140827_00000_arvdv"
 
@@ -803,11 +843,14 @@ def test_authentication_fail_retry(monkeypatch):
     assert post_retry.retry_count == attempts
 
 
-@pytest.mark.parametrize("status_code, attempts", [
-    (502, 3),
-    (503, 3),
-    (504, 3),
-])
+@pytest.mark.parametrize(
+    "status_code, attempts",
+    [
+        (502, 3),
+        (503, 3),
+        (504, 3),
+    ],
+)
 def test_5XX_error_retry(status_code, attempts, monkeypatch):
     http_resp = TrinoRequest.http.Response()
     http_resp.status_code = status_code
@@ -824,7 +867,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch):
         client_session=ClientSession(
             user="test",
         ),
-        max_attempts=attempts
+        max_attempts=attempts,
     )
 
     req.post("URL")
@@ -834,9 +877,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch):
     assert post_retry.retry_count == attempts
 
 
-@pytest.mark.parametrize("status_code", [
-    501
-])
+@pytest.mark.parametrize("status_code", [501])
 def test_error_no_retry(status_code, monkeypatch):
     http_resp = TrinoRequest.http.Response()
     http_resp.status_code = status_code
@@ -887,10 +928,12 @@ def test_trino_result_response_headers():
     headers associated to the TrinoQuery instance provided to the `TrinoResult`
     class.
     """
-    mock_trino_query = mock.Mock(respone_headers={
-        'X-Trino-Fake-1': 'one',
-        'X-Trino-Fake-2': 'two',
-    })
+    mock_trino_query = mock.Mock(
+        respone_headers={
+            "X-Trino-Fake-1": "one",
+            "X-Trino-Fake-2": "two",
+        }
+    )
 
     result = TrinoResult(
         query=mock_trino_query,
@@ -910,8 +953,8 @@ class MockResponse(mock.Mock):
         @property
         def headers(self):
             return {
-                'X-Trino-Fake-1': 'one',
-                'X-Trino-Fake-2': 'two',
+                "X-Trino-Fake-1": "one",
+                "X-Trino-Fake-2": "two",
             }
 
         def json(self):
@@ -930,18 +973,15 @@ def json(self):
         http_scheme="http",
     )
 
-    sql = 'execute my_stament using 1, 2, 3'
+    sql = "execute my_stament using 1, 2, 3"
     additional_headers = {
-        constants.HEADER_PREPARED_STATEMENT: 'my_statement=added_prepare_statement_header'
+        constants.HEADER_PREPARED_STATEMENT: "my_statement=added_prepare_statement_header"
     }
 
     # Patch the post function to avoid making the requests, as well as to
     # validate that the function was called with the right arguments.
-    with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post:
-        query = TrinoQuery(
-            request=req,
-            sql=sql
-        )
+    with mock.patch.object(req, "post", return_value=MockResponse()) as mock_post:
+        query = TrinoQuery(request=req, sql=sql)
         result = query.execute(additional_http_headers=additional_headers)
 
         # Validate the the post function was called with the right argguments
diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py
index 7b1c72c2..873c75ba 100644
--- a/tests/unit/test_dbapi.py
+++ b/tests/unit/test_dbapi.py
@@ -17,8 +17,16 @@
 from httpretty import httprettified
 from requests import Session
 
-from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \
-    GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS
+from tests.unit.oauth_test_utils import (
+    REDIRECT_RESOURCE,
+    SERVER_ADDRESS,
+    TOKEN_RESOURCE,
+    GetTokenCallback,
+    PostStatementCallback,
+    RedirectHandler,
+    _get_token_requests,
+    _post_statement_requests,
+)
 from trino import constants
 from trino.auth import OAuth2Authentication
 from trino.dbapi import connect
@@ -58,28 +66,28 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
     with connect(
-            "coordinator",
-            user="test",
-            auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
-            http_scheme=constants.HTTPS
+        "coordinator",
+        user="test",
+        auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+        http_scheme=constants.HTTPS,
     ) as conn:
         conn.cursor().execute("SELECT 1")
         conn.cursor().execute("SELECT 2")
@@ -87,18 +95,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
     with connect(
-            "coordinator",
-            user="test",
-            auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
-            http_scheme=constants.HTTPS
+        "coordinator",
+        user="test",
+        auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+        http_scheme=constants.HTTPS,
     ) as conn2:
         conn2.cursor().execute("SELECT 1")
         conn2.cursor().execute("SELECT 2")
@@ -115,30 +120,27 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
     authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler)
 
     with connect(
-            "coordinator",
-            user="test",
-            auth=authentication,
-            http_scheme=constants.HTTPS
+        "coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS
     ) as conn:
         conn.cursor().execute("SELECT 1")
         conn.cursor().execute("SELECT 2")
@@ -146,16 +148,10 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     with connect(
-            "coordinator",
-            user="test",
-            auth=authentication,
-            http_scheme=constants.HTTPS
+        "coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS
     ) as conn2:
         conn2.cursor().execute("SELECT 1")
         conn2.cursor().execute("SELECT 2")
@@ -173,31 +169,26 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data):
     redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
     token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
 
-    post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
+    post_statement_callback = PostStatementCallback(
+        redirect_server, token_server, [token], sample_post_response_data
+    )
 
     # bind post statement
     httpretty.register_uri(
         method=httpretty.POST,
         uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        body=post_statement_callback,
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
     authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler)
 
-    conn = connect(
-        "coordinator",
-        user="test",
-        auth=authentication,
-        http_scheme=constants.HTTPS
-    )
+    conn = connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS)
 
     class RunningThread(threading.Thread):
         lock = threading.Lock()
@@ -209,11 +200,7 @@ def run(self) -> None:
             with RunningThread.lock:
                 conn.cursor().execute("SELECT 1")
 
-    threads = [
-        RunningThread(),
-        RunningThread(),
-        RunningThread()
-    ]
+    threads = [RunningThread(), RunningThread(), RunningThread()]
 
     # run and join all threads
     for thread in threads:
diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py
index fd3c5e2d..9753bab5 100644
--- a/tests/unit/test_http.py
+++ b/tests/unit/test_http.py
@@ -10,8 +10,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from trino.client import get_header_values, get_session_property_values
 from trino import constants
+from trino.client import get_header_values, get_session_property_values
 
 
 def test_get_header_values():
diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py
index e97903cf..02adeb9d 100644
--- a/tests/unit/test_transaction.py
+++ b/tests/unit/test_transaction.py
@@ -10,9 +10,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from trino.transaction import IsolationLevel
 import pytest
 
+from trino.transaction import IsolationLevel
+
 
 def test_isolation_level_levels() -> None:
     levels = {
@@ -27,9 +28,7 @@ def test_isolation_level_levels() -> None:
 
 
 def test_isolation_level_values() -> None:
-    values = {
-        0, 1, 2, 3, 4
-    }
+    values = {0, 1, 2, 3, 4}
 
     assert IsolationLevel.values() == values
 
diff --git a/trino/__init__.py b/trino/__init__.py
index 4ff3e55b..42db0646 100644
--- a/trino/__init__.py
+++ b/trino/__init__.py
@@ -10,13 +10,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import auth
-from . import dbapi
-from . import client
-from . import constants
-from . import exceptions
-from . import logging
+from . import auth, client, constants, dbapi, exceptions, logging
 
-__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging']
+__all__ = ["auth", "dbapi", "client", "constants", "exceptions", "logging"]
 
 __version__ = "0.315.0"
diff --git a/trino/auth.py b/trino/auth.py
index e6b4f04c..e2c5d3fa 100644
--- a/trino/auth.py
+++ b/trino/auth.py
@@ -11,18 +11,18 @@
 # limitations under the License.
 
 import abc
+import importlib
 import json
 import os
 import re
 import threading
 import webbrowser
-from typing import Optional, List, Callable
+from typing import Callable, List, Optional
 from urllib.parse import urlparse
 
 from requests import Request
 from requests.auth import AuthBase, extract_cookies_to_jar
 from requests.utils import parse_dict_header
-import importlib
 
 import trino.logging
 from trino.client import exceptions
@@ -95,15 +95,17 @@ def get_exceptions(self):
     def __eq__(self, other):
         if not isinstance(other, KerberosAuthentication):
             return False
-        return (self._config == other._config
-                and self._service_name == other._service_name
-                and self._mutual_authentication == other._mutual_authentication
-                and self._force_preemptive == other._force_preemptive
-                and self._hostname_override == other._hostname_override
-                and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
-                and self._principal == other._principal
-                and self._delegate == other._delegate
-                and self._ca_bundle == other._ca_bundle)
+        return (
+            self._config == other._config
+            and self._service_name == other._service_name
+            and self._mutual_authentication == other._mutual_authentication
+            and self._force_preemptive == other._force_preemptive
+            and self._hostname_override == other._hostname_override
+            and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
+            and self._principal == other._principal
+            and self._delegate == other._delegate
+            and self._ca_bundle == other._ca_bundle
+        )
 
 
 class BasicAuthentication(Authentication):
@@ -143,7 +145,6 @@ def __call__(self, r):
 
 
 class JWTAuthentication(Authentication):
-
     def __init__(self, token):
         self.token = token
 
@@ -251,31 +252,38 @@ def get_token_from_cache(self, host: str) -> Optional[str]:
         try:
             return self._keyring.get_password(host, "token")
         except self._keyring.errors.NoKeyringError as e:
-            raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
-                                                     "detected, check https://pypi.org/project/keyring/ for more "
-                                                     "information.") from e
+            raise trino.exceptions.NotSupportedError(
+                "Although keyring module is installed no backend has been "
+                "detected, check https://pypi.org/project/keyring/ for more "
+                "information."
+            ) from e
 
     def store_token_to_cache(self, host: str, token: str) -> None:
         try:
             # keyring is installed, so we can store the token for reuse within multiple threads
             self._keyring.set_password(host, "token", token)
         except self._keyring.errors.NoKeyringError as e:
-            raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
-                                                     "detected, check https://pypi.org/project/keyring/ for more "
-                                                     "information.") from e
+            raise trino.exceptions.NotSupportedError(
+                "Although keyring module is installed no backend has been "
+                "detected, check https://pypi.org/project/keyring/ for more "
+                "information."
+            ) from e
 
 
 class _OAuth2TokenBearer(AuthBase):
     """
     Custom implementation of Trino Oauth2 based authorization to get the token
     """
+
     MAX_OAUTH_ATTEMPTS = 5
     _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
 
     def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
         self._redirect_auth_url = redirect_auth_url_handler
         keyring_cache = _OAuth2KeyRingTokenCache()
-        self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
+        self._token_cache = (
+            keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
+        )
         self._token_lock = threading.Lock()
         self._inside_oauth_attempt_lock = threading.Lock()
         self._inside_oauth_attempt_blocker = threading.Event()
@@ -285,9 +293,9 @@ def __call__(self, r):
         token = self._get_token_from_cache(host)
 
         if token is not None:
-            r.headers['Authorization'] = "Bearer " + token
+            r.headers["Authorization"] = "Bearer " + token
 
-        r.register_hook('response', self._authenticate)
+        r.register_hook("response", self._authenticate)
 
         return r
 
@@ -312,20 +320,24 @@ def _authenticate(self, response, **kwargs):
 
     def _attempt_oauth(self, response, **kwargs):
         # we have to handle the authentication, may be token the token expired, or it wasn't there at all
-        auth_info = response.headers.get('WWW-Authenticate')
+        auth_info = response.headers.get("WWW-Authenticate")
         if not auth_info:
-            raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.")
+            raise exceptions.TrinoAuthError(
+                "Error: header WWW-Authenticate not available in the response."
+            )
 
         if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info):
             raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}")
 
-        auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1))
+        auth_info_headers = parse_dict_header(
+            _OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)
+        )
 
-        auth_server = auth_info_headers.get('x_redirect_server')
+        auth_server = auth_info_headers.get("x_redirect_server")
         if auth_server is None:
             raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server")
 
-        token_server = auth_info_headers.get('x_token_server')
+        token_server = auth_info_headers.get("x_token_server")
         if token_server is None:
             raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")
 
@@ -349,7 +361,7 @@ def _retry_request(self, response, **kwargs):
         request.prepare_cookies(request._cookies)
 
         host = self._determine_host(response.request.url)
-        request.headers['Authorization'] = "Bearer " + self._get_token_from_cache(host)
+        request.headers["Authorization"] = "Bearer " + self._get_token_from_cache(host)
         retry_response = response.connection.send(request, **kwargs)
         retry_response.history.append(response)
         retry_response.request = request
@@ -359,23 +371,26 @@ def _get_token(self, token_server, response, **kwargs):
         attempts = 0
         while attempts < self.MAX_OAUTH_ATTEMPTS:
             attempts += 1
-            with response.connection.send(Request(method='GET', url=token_server).prepare(), **kwargs) as response:
+            with response.connection.send(
+                Request(method="GET", url=token_server).prepare(), **kwargs
+            ) as response:
                 if response.status_code == 200:
                     token_response = json.loads(response.text)
-                    token = token_response.get('token')
+                    token = token_response.get("token")
                     if token:
                         return token
-                    error = token_response.get('error')
+                    error = token_response.get("error")
                     if error:
                         raise exceptions.TrinoAuthError(f"Error while getting the token: {error}")
                     else:
-                        token_server = token_response.get('nextUri')
+                        token_server = token_response.get("nextUri")
                         logger.debug(f"nextURi auth token server: {token_server}")
                 else:
                     raise exceptions.TrinoAuthError(
                         f"Error while getting the token response "
                         f"status code: {response.status_code}, "
-                        f"body: {response.text}")
+                        f"body: {response.text}"
+                    )
 
         raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")
 
@@ -393,10 +408,12 @@ def _determine_host(url) -> Optional[str]:
 
 
 class OAuth2Authentication(Authentication):
-    def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([
-        WebBrowserRedirectHandler(),
-        ConsoleRedirectHandler()
-    ])):
+    def __init__(
+        self,
+        redirect_auth_url_handler=CompositeRedirectHandler(
+            [WebBrowserRedirectHandler(), ConsoleRedirectHandler()]
+        ),
+    ):
         self._redirect_auth_url = redirect_auth_url_handler
         self._bearer = _OAuth2TokenBearer(self._redirect_auth_url)
 
diff --git a/trino/client.py b/trino/client.py
index 211ce9c0..ba7aca34 100644
--- a/trino/client.py
+++ b/trino/client.py
@@ -62,7 +62,7 @@
 else:
     PROXIES = {}
 
-_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
+_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$")
 
 INF = float("inf")
 NEGATIVE_INF = float("-inf")
@@ -214,8 +214,7 @@ def get_header_values(headers, header):
 def get_session_property_values(headers, header):
     kvs = get_header_values(headers, header)
     return [
-        (k.strip(), urllib.parse.unquote(v.strip()))
-        for k, v in (kv.split("=", 1) for kv in kvs)
+        (k.strip(), urllib.parse.unquote(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs)
     ]
 
 
@@ -245,16 +244,14 @@ def __repr__(self):
 
 
 class _DelayExponential(object):
-    def __init__(
-            self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600  # 100ms  # 2 hours
-    ):
+    def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600):  # 100ms  # 2 hours
         self._base = base
         self._exponent = exponent
         self._jitter = jitter
         self._max_delay = max_delay
 
     def __call__(self, attempt):
-        delay = float(self._base) * (self._exponent ** attempt)
+        delay = float(self._base) * (self._exponent**attempt)
         if self._jitter:
             delay *= random.random()
         delay = min(float(self._max_delay), delay)
@@ -262,9 +259,7 @@ def __call__(self, attempt):
 
 
 class _RetryWithExponentialBackoff(object):
-    def __init__(
-            self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600  # 100ms  # 2 hours
-    ):
+    def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600):  # 100ms  # 2 hours
         self._get_delay = _DelayExponential(base, exponent, jitter, max_delay)
 
     def retry(self, func, args, kwargs, err, attempt):
@@ -383,7 +378,10 @@ def http_headers(self) -> Dict[str, str]:
         headers[constants.HEADER_SOURCE] = self._client_session.source
         headers[constants.HEADER_USER] = self._client_session.user
         headers[constants.HEADER_ROLE] = self._client_session.role
-        if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
+        if (
+            self._client_session.client_tags is not None
+            and len(self._client_session.client_tags) > 0
+        ):
             headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
 
         headers[constants.HEADER_SESSION] = ",".join(
@@ -401,8 +399,10 @@ def http_headers(self) -> Dict[str, str]:
         transaction_id = self._client_session.transaction_id
         headers[constants.HEADER_TRANSACTION] = transaction_id
 
-        if self._client_session.extra_credential is not None and \
-                len(self._client_session.extra_credential) > 0:
+        if (
+            self._client_session.extra_credential is not None
+            and len(self._client_session.extra_credential) > 0
+        ):
 
             for tup in self._client_session.extra_credential:
                 self._verify_extra_credential(tup)
@@ -410,9 +410,12 @@ def http_headers(self) -> Dict[str, str]:
             # HTTP 1.1 section 4.2 combine multiple extra credentials into a
             # comma-separated value
             # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
-            headers[constants.HEADER_EXTRA_CREDENTIAL] = \
-                ", ".join(
-                    [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
+            headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join(
+                [
+                    f"{tup[0]}={urllib.parse.quote_plus(tup[1])}"
+                    for tup in self._client_session.extra_credential
+                ]
+            )
 
         return headers
 
@@ -536,9 +539,7 @@ def process(self, http_response) -> TrinoStatus:
             raise self._process_error(response["error"], response.get("id"))
 
         if constants.HEADER_CLEAR_SESSION in http_response.headers:
-            for prop in get_header_values(
-                http_response.headers, constants.HEADER_CLEAR_SESSION
-            ):
+            for prop in get_header_values(http_response.headers, constants.HEADER_CLEAR_SESSION):
                 self._client_session.properties.pop(prop, None)
 
         if constants.HEADER_SET_SESSION in http_response.headers:
@@ -579,7 +580,7 @@ def _verify_extra_credential(self, header):
             raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'")
 
         try:
-            key.encode().decode('ascii')
+            key.encode().decode("ascii")
         except UnicodeDecodeError:
             raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
 
@@ -643,9 +644,13 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
                     raw_type = {
                         "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
                     }
-                    return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
+                    return [
+                        cls._map_to_python_type((array_item, raw_type)) for array_item in value
+                    ]
                 if raw_type == "row":
-                    raw_types = map(lambda arg: arg["value"], data_type["typeSignature"]["arguments"])
+                    raw_types = map(
+                        lambda arg: arg["value"], data_type["typeSignature"]["arguments"]
+                    )
                     return tuple(
                         cls._map_to_python_type((array_item, raw_type))
                         for (array_item, raw_type) in zip(value, raw_types)
@@ -659,44 +664,53 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
                     "typeSignature": data_type["typeSignature"]["arguments"][1]["value"]
                 }
                 return {
-                    cls._map_to_python_type((key, raw_key_type)):
-                        cls._map_to_python_type((value[key], raw_value_type))
+                    cls._map_to_python_type((key, raw_key_type)): cls._map_to_python_type(
+                        (value[key], raw_value_type)
+                    )
                     for key in value
                 }
             elif "decimal" in raw_type:
                 return Decimal(value)
             elif raw_type == "double":
-                if value == 'Infinity':
+                if value == "Infinity":
                     return INF
-                elif value == '-Infinity':
+                elif value == "-Infinity":
                     return NEGATIVE_INF
-                elif value == 'NaN':
+                elif value == "NaN":
                     return NAN
                 return value
             elif raw_type == "date":
                 return datetime.strptime(value, "%Y-%m-%d").date()
             elif raw_type == "timestamp with time zone":
-                dt, tz = value.rsplit(' ', 1)
-                if tz.startswith('+') or tz.startswith('-'):
+                dt, tz = value.rsplit(" ", 1)
+                if tz.startswith("+") or tz.startswith("-"):
                     return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z")
-                return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz))
+                return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(
+                    tzinfo=pytz.timezone(tz)
+                )
             elif "timestamp" in raw_type:
                 return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
             elif "time with time zone" in raw_type:
-                matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
+                matches = re.match(r"^(.*)([\+\-])(\d{2}):(\d{2})$", value)
                 assert matches is not None
                 assert len(matches.groups()) == 4
-                if matches.group(2) == '-':
+                if matches.group(2) == "-":
                     tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
                 else:
                     tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
-                return datetime.strptime(matches.group(1), "%H:%M:%S.%f").time().replace(tzinfo=timezone(tz))
+                return (
+                    datetime.strptime(matches.group(1), "%H:%M:%S.%f")
+                    .time()
+                    .replace(tzinfo=timezone(tz))
+                )
             elif "time" in raw_type:
                 return datetime.strptime(value, "%H:%M:%S.%f").time()
             else:
                 return value
         except ValueError as e:
-            error_str = f"Could not convert '{value}' into the associated python type for '{raw_type}'"
+            error_str = (
+                f"Could not convert '{value}' into the associated python type for '{raw_type}'"
+            )
             raise trino.exceptions.TrinoDataError(error_str) from e
 
     def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
@@ -707,10 +721,10 @@ class TrinoQuery(object):
     """Represent the execution of a SQL statement by Trino."""
 
     def __init__(
-            self,
-            request: TrinoRequest,
-            sql: str,
-            experimental_python_types: bool = False,
+        self,
+        request: TrinoRequest,
+        sql: str,
+        experimental_python_types: bool = False,
     ) -> None:
         self.query_id: Optional[str] = None
 
@@ -814,6 +828,7 @@ def cancel(self) -> None:
 
     def is_finished(self) -> bool:
         import warnings
+
         warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)
         return self.finished
 
diff --git a/trino/constants.py b/trino/constants.py
index 30046908..41104fdb 100644
--- a/trino/constants.py
+++ b/trino/constants.py
@@ -12,7 +12,6 @@
 
 from typing import Any, Optional
 
-
 DEFAULT_PORT = 8080
 DEFAULT_TLS_PORT = 443
 DEFAULT_SOURCE = "trino-python-client"
@@ -45,9 +44,9 @@
 HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id"
 HEADER_TRANSACTION = "X-Trino-Transaction-Id"
 
-HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement'
-HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare'
-HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare'
+HEADER_PREPARED_STATEMENT = "X-Trino-Prepared-Statement"
+HEADER_ADDED_PREPARE = "X-Trino-Added-Prepare"
+HEADER_DEALLOCATED_PREPARE = "X-Trino-Deallocated-Prepare"
 
 HEADER_SET_SCHEMA = "X-Trino-Set-Schema"
 HEADER_SET_CATALOG = "X-Trino-Set-Catalog"
diff --git a/trino/dbapi.py b/trino/dbapi.py
index 44813168..6828455d 100644
--- a/trino/dbapi.py
+++ b/trino/dbapi.py
@@ -17,31 +17,30 @@
 Fetch methods returns rows as a list of lists on purpose to let the caller
 decide to convert then to a list of tuples.
 """
-from decimal import Decimal
-from typing import Any, List, Optional  # NOQA for mypy types
-
 import copy
-import uuid
 import datetime
 import math
+import uuid
+from decimal import Decimal
+from typing import Any, List, Optional  # NOQA for mypy types
 
-from trino import constants
-import trino.exceptions
 import trino.client
+import trino.exceptions
 import trino.logging
-from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION
+from trino import constants
 from trino.exceptions import (
-    Warning,
-    Error,
-    InterfaceError,
     DatabaseError,
     DataError,
-    OperationalError,
+    Error,
     IntegrityError,
+    InterfaceError,
     InternalError,
-    ProgrammingError,
     NotSupportedError,
+    OperationalError,
+    ProgrammingError,
+    Warning,
 )
+from trino.transaction import NO_TRANSACTION, IsolationLevel, Transaction
 
 __all__ = [
     # https://www.python.org/dev/peps/pep-0249/#globals
@@ -128,7 +127,7 @@ def __init__(
             headers=http_headers,
             transaction_id=NO_TRANSACTION,
             extra_credential=extra_credential,
-            client_tags=client_tags
+            client_tags=client_tags,
         )
         # mypy cannot follow module import
         if http_session is None:
@@ -217,7 +216,9 @@ def cursor(self, experimental_python_types: bool = None):
             self,
             request,
             # if experimental_python_types is not explicitly set in Cursor, take from Connection
-            experimental_python_types if experimental_python_types is not None else self.experimental_python_types
+            experimental_python_types
+            if experimental_python_types is not None
+            else self.experimental_python_types,
         )
 
 
@@ -231,9 +232,7 @@ class Cursor(object):
 
     def __init__(self, connection, request, experimental_python_types: bool = False):
         if not isinstance(connection, Connection):
-            raise ValueError(
-                "connection must be a Connection object: {}".format(type(connection))
-            )
+            raise ValueError("connection must be a Connection object: {}".format(type(connection)))
         self._connection = connection
         self._request = request
 
@@ -268,8 +267,7 @@ def description(self):
 
         # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
         return [
-            (col["name"], col["type"], None, None, None, None, None)
-            for col in self._query.columns
+            (col["name"], col["type"], None, None, None, None, None) for col in self._query.columns
         ]
 
     @property
@@ -317,15 +315,17 @@ def _prepare_statement(self, operation, statement_name):
         :return: string representing the value of the 'X-Trino-Added-Prepare'
             header.
         """
-        sql = 'PREPARE {statement_name} FROM {operation}'.format(
-            statement_name=statement_name,
-            operation=operation
+        sql = "PREPARE {statement_name} FROM {operation}".format(
+            statement_name=statement_name, operation=operation
         )
 
         # Send prepare statement. Copy the _request object to avoid poluting the
         # one that is going to be used to execute the actual operation.
-        query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
-                                        experimental_python_types=self._experimental_pyton_types)
+        query = trino.client.TrinoQuery(
+            copy.deepcopy(self._request),
+            sql=sql,
+            experimental_python_types=self._experimental_pyton_types,
+        )
         result = query.execute()
 
         # Iterate until the 'X-Trino-Added-Prepare' header is found or
@@ -338,16 +338,19 @@ def _prepare_statement(self, operation, statement_name):
 
         raise trino.exceptions.FailedToObtainAddedPrepareHeader
 
-    def _get_added_prepare_statement_trino_query(
-        self,
-        statement_name,
-        params
-    ):
-        sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
+    def _get_added_prepare_statement_trino_query(self, statement_name, params):
+        sql = (
+            "EXECUTE "
+            + statement_name
+            + " USING "
+            + ",".join(map(self._format_prepared_param, params))
+        )
 
         # No need to deepcopy _request here because this is the actual request
         # operation
-        return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types)
+        return trino.client.TrinoQuery(
+            self._request, sql=sql, experimental_python_types=self._experimental_pyton_types
+        )
 
     def _format_prepared_param(self, param):
         """
@@ -374,7 +377,7 @@ def _format_prepared_param(self, param):
             return "DOUBLE '%s'" % param
 
         if isinstance(param, str):
-            return ("'%s'" % param.replace("'", "''"))
+            return "'%s'" % param.replace("'", "''")
 
         if isinstance(param, bytes):
             return "X'%s'" % param.hex()
@@ -386,7 +389,7 @@ def _format_prepared_param(self, param):
         if isinstance(param, datetime.datetime) and param.tzinfo is not None:
             datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
             # named timezones
-            if hasattr(param.tzinfo, 'zone'):
+            if hasattr(param.tzinfo, "zone"):
                 return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone)
             # offset-based timezones
             return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))
@@ -401,17 +404,16 @@ def _format_prepared_param(self, param):
             return "DATE '%s'" % date_str
 
         if isinstance(param, list):
-            return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
+            return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param))
 
         if isinstance(param, tuple):
-            return "ROW(%s)" % ','.join(map(self._format_prepared_param, param))
+            return "ROW(%s)" % ",".join(map(self._format_prepared_param, param))
 
         if isinstance(param, dict):
             keys = list(param.keys())
             values = [param[key] for key in keys]
             return "MAP({}, {})".format(
-                self._format_prepared_param(keys),
-                self._format_prepared_param(values)
+                self._format_prepared_param(keys), self._format_prepared_param(values)
             )
 
         if isinstance(param, uuid.UUID):
@@ -420,19 +422,22 @@ def _format_prepared_param(self, param):
         if isinstance(param, Decimal):
             return "DECIMAL '%s'" % param
 
-        raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
+        raise trino.exceptions.NotSupportedError(
+            "Query parameter of type '%s' is not supported." % type(param)
+        )
 
     def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
-        sql = 'DEALLOCATE PREPARE ' + statement_name
+        sql = "DEALLOCATE PREPARE " + statement_name
 
         # Send deallocate statement. Copy the _request object to avoid poluting the
         # one that is going to be used to execute the actual operation.
-        query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
-                                        experimental_python_types=self._experimental_pyton_types)
+        query = trino.client.TrinoQuery(
+            copy.deepcopy(self._request),
+            sql=sql,
+            experimental_python_types=self._experimental_pyton_types,
+        )
         result = query.execute(
-            additional_http_headers={
-                constants.HEADER_PREPARED_STATEMENT: added_prepare_header
-            }
+            additional_http_headers={constants.HEADER_PREPARED_STATEMENT: added_prepare_header}
         )
 
         # Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
@@ -446,27 +451,22 @@ def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
         raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
 
     def _generate_unique_statement_name(self):
-        return 'st_' + uuid.uuid4().hex.replace('-', '')
+        return "st_" + uuid.uuid4().hex.replace("-", "")
 
     def execute(self, operation, params=None):
         if params:
             assert isinstance(params, (list, tuple)), (
-                'params must be a list or tuple containing the query '
-                'parameter values'
+                "params must be a list or tuple containing the query " "parameter values"
             )
 
             statement_name = self._generate_unique_statement_name()
             # Send prepare statement
-            added_prepare_header = self._prepare_statement(
-                operation, statement_name
-            )
+            added_prepare_header = self._prepare_statement(operation, statement_name)
 
             try:
                 # Send execute statement and assign the return value to `results`
                 # as it will be returned by the function
-                self._query = self._get_added_prepare_statement_trino_query(
-                    statement_name, params
-                )
+                self._query = self._get_added_prepare_statement_trino_query(statement_name, params)
                 result = self._query.execute(
                     additional_http_headers={
                         constants.HEADER_PREPARED_STATEMENT: added_prepare_header
@@ -479,8 +479,11 @@ def execute(self, operation, params=None):
                 self._deallocate_prepare_statement(added_prepare_header, statement_name)
 
         else:
-            self._query = trino.client.TrinoQuery(self._request, sql=operation,
-                                                  experimental_python_types=self._experimental_pyton_types)
+            self._query = trino.client.TrinoQuery(
+                self._request,
+                sql=operation,
+                experimental_python_types=self._experimental_pyton_types,
+            )
             result = self._query.execute()
         self._iterator = iter(result)
         return result
@@ -570,9 +573,7 @@ def fetchall(self) -> List[List[Any]]:
 
     def cancel(self):
         if self._query is None:
-            raise trino.exceptions.OperationalError(
-                "Cancel query failed; no running query"
-            )
+            raise trino.exceptions.OperationalError("Cancel query failed; no running query")
         self._query.cancel()
 
     def close(self):
@@ -606,9 +607,7 @@ def __eq__(self, other):
 
 STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS")
 
-BINARY = DBAPITypeObject(
-    "ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest"
-)
+BINARY = DBAPITypeObject("ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest")
 
 NUMBER = DBAPITypeObject(
     "BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL"
diff --git a/trino/exceptions.py b/trino/exceptions.py
index 86708fd0..f8eddba3 100644
--- a/trino/exceptions.py
+++ b/trino/exceptions.py
@@ -139,6 +139,7 @@ class FailedToObtainAddedPrepareHeader(Error):
     Raise this exception when unable to find the 'X-Trino-Added-Prepare'
     header in the response of a PREPARE statement request.
     """
+
     pass
 
 
@@ -147,6 +148,7 @@ class FailedToObtainDeallocatedPrepareHeader(Error):
     Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare'
     header in the response of a DEALLOCATED statement request.
     """
+
     pass
 
 
diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py
index a085fbf3..30141edf 100644
--- a/trino/sqlalchemy/compiler.py
+++ b/trino/sqlalchemy/compiler.py
@@ -10,15 +10,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from sqlalchemy.sql import compiler
+
 try:
-    from sqlalchemy.sql.expression import (
-        Alias,
-        CTE,
-        Subquery,
-    )
+    from sqlalchemy.sql.expression import CTE, Alias, Subquery
 except ImportError:
     # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist
     from sqlalchemy.sql.expression import Alias
+
     CTE = type(None)
     Subquery = type(None)
 
@@ -113,8 +111,16 @@ def limit_clause(self, select, **kw):
             text += "\nLIMIT " + self.process(select._limit_clause, **kw)
         return text
 
-    def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
-                    fromhints=None, use_schema=True, **kwargs):
+    def visit_table(
+        self,
+        table,
+        asfrom=False,
+        iscrud=False,
+        ashint=False,
+        fromhints=None,
+        use_schema=True,
+        **kwargs,
+    ):
         sql = super(TrinoSQLCompiler, self).visit_table(
             table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
         )
@@ -128,13 +134,10 @@ def add_catalog(sql, table):
         if isinstance(table, (Alias, CTE, Subquery)):
             return sql
 
-        if (
-                'trino' not in table.dialect_options
-                or 'catalog' not in table.dialect_options['trino']
-        ):
+        if "trino" not in table.dialect_options or "catalog" not in table.dialect_options["trino"]:
             return sql
 
-        catalog = table.dialect_options['trino']['catalog']
+        catalog = table.dialect_options["trino"]["catalog"]
         sql = f'"{catalog}".{sql}'
         return sql
 
diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py
index 8284ba9c..e54f98c3 100644
--- a/trino/sqlalchemy/datatype.py
+++ b/trino/sqlalchemy/datatype.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
 
 from sqlalchemy import util
 from sqlalchemy.sql import sqltypes
@@ -167,7 +167,7 @@ def aware_split(
         elif character == close_bracket:
             parens -= 1
         elif character == quote:
-            if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote:
+            if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote:
                 quotes = False
             elif not quotes:
                 quotes = True
diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py
index 7c4409a0..d58fc387 100644
--- a/trino/sqlalchemy/dialect.py
+++ b/trino/sqlalchemy/dialect.py
@@ -19,7 +19,8 @@
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
 from sqlalchemy.engine.url import URL
 
-from trino import dbapi as trino_dbapi, logging
+from trino import dbapi as trino_dbapi
+from trino import logging
 from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication
 from trino.dbapi import Cursor
 from trino.sqlalchemy import compiler, datatype, error
@@ -102,7 +103,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
 
         if "cert" and "key" in url.query:
             kwargs["http_scheme"] = "https"
-            kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key'])
+            kwargs["auth"] = CertificateAuthentication(url.query["cert"], url.query["key"])
 
         if "source" in url.query:
             kwargs["source"] = url.query["source"]
@@ -122,16 +123,22 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
             kwargs["client_tags"] = json.loads(url.query["client_tags"])
 
         if "experimental_python_types" in url.query:
-            kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
+            kwargs["experimental_python_types"] = json.loads(
+                url.query["experimental_python_types"]
+            )
 
         return args, kwargs
 
-    def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
+    def get_columns(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> List[Dict[str, Any]]:
         if not self.has_table(connection, table_name, schema):
             raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
         return self._get_columns(connection, table_name, schema, **kw)
 
-    def _get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
+    def _get_columns(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> List[Dict[str, Any]]:
         schema = schema or self._get_default_schema_name(connection)
         query = dedent(
             """
@@ -158,11 +165,15 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
             columns.append(column)
         return columns
 
-    def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
+    def get_pk_constraint(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> Dict[str, Any]:
         """Trino has no support for primary keys. Returns a dummy"""
         return dict(name=None, constrained_columns=[])
 
-    def get_primary_keys(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[str]:
+    def get_primary_keys(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> List[str]:
         pk = self.get_pk_constraint(connection, table_name, schema)
         return pk.get("constrained_columns")  # type: ignore
 
@@ -218,7 +229,9 @@ def get_temp_view_names(self, connection: Connection, schema: str = None, **kw)
         """Trino has no support for temporary views. Returns an empty list."""
         return []
 
-    def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str:
+    def get_view_definition(
+        self, connection: Connection, view_name: str, schema: str = None, **kw
+    ) -> str:
         schema = schema or self._get_default_schema_name(connection)
         if schema is None:
             raise exc.NoSuchTableError("schema is required")
@@ -233,17 +246,21 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
         res = connection.execute(sql.text(query), schema=schema, view=view_name)
         return res.scalar()
 
-    def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
+    def get_indexes(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> List[Dict[str, Any]]:
         if not self.has_table(connection, table_name, schema):
             raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
 
-        partitioned_columns = self._get_columns(connection, f"{table_name}$partitions", schema, **kw)
+        partitioned_columns = self._get_columns(
+            connection, f"{table_name}$partitions", schema, **kw
+        )
         if not partitioned_columns:
             return []
         partition_index = dict(
             name="partition",
             column_names=[col["name"] for col in partitioned_columns],
-            unique=False
+            unique=False,
         )
         return [partition_index]
 
@@ -263,7 +280,9 @@ def get_check_constraints(
         """Trino has no support for check constraints. Returns an empty list."""
         return []
 
-    def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
+    def get_table_comment(
+        self, connection: Connection, table_name: str, schema: str = None, **kw
+    ) -> Dict[str, Any]:
         catalog_name = self._get_default_catalog_name(connection)
         if catalog_name is None:
             raise exc.NoSuchTableError("catalog is required in connection")
@@ -282,13 +301,13 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
         try:
             res = connection.execute(
                 sql.text(query),
-                catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
+                catalog_name=catalog_name,
+                schema_name=schema_name,
+                table_name=table_name,
             )
             return dict(text=res.scalar())
         except error.TrinoQueryError as e:
-            if e.error_name in (
-                error.PERMISSION_DENIED,
-            ):
+            if e.error_name in (error.PERMISSION_DENIED,):
                 return dict(text=None)
             raise
 
@@ -318,7 +337,9 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None,
         res = connection.execute(sql.text(query), schema=schema, table=table_name)
         return res.first() is not None
 
-    def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:
+    def has_sequence(
+        self, connection: Connection, sequence_name: str, schema: str = None, **kw
+    ) -> bool:
         """Trino has no support for sequence. Returns False indicate that given sequence does not exists."""
         return False
 
@@ -341,7 +362,11 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
         return dbapi_connection.schema
 
     def do_execute(
-        self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
+        self,
+        cursor: Cursor,
+        statement: str,
+        parameters: Tuple[Any, ...],
+        context: DefaultExecutionContext = None,
     ):
         cursor.execute(statement, parameters)
         if context and context.should_autocommit:
diff --git a/trino/transaction.py b/trino/transaction.py
index e6c85234..c6e6257d 100644
--- a/trino/transaction.py
+++ b/trino/transaction.py
@@ -12,11 +12,10 @@
 from enum import Enum, unique
 from typing import Iterable
 
-from trino import constants
 import trino.client
 import trino.exceptions
 import trino.logging
-
+from trino import constants
 
 logger = trino.logging.get_logger(__name__)