Support setting roles in dbapi and sqlalchemy #212

Merged 4 commits on Oct 4, 2022
18 changes: 17 additions & 1 deletion README.md
@@ -107,6 +107,7 @@ engine = create_engine(
"session_properties": {'query_max_run_time': '1d'},
"client_tags": ["tag1", "tag2"],
"experimental_python_types": True,
"roles": {"catalog1": "role1"},
}
)

@@ -115,7 +116,8 @@ engine = create_engine(
'trino://user@localhost:8080/system?'
'session_properties={"query_max_run_time": "1d"}'
'&client_tags=["tag1", "tag2"]'
'&experimental_python_types=true',
'&experimental_python_types=true'
'&roles={"catalog1": "role1"}'
Member

No test for the SQLAlchemy integration?

)
```
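The README example above embeds a JSON object directly in the connection URL. The following is a minimal sketch of building such a URL programmatically, assuming the query value should be percent-encoded and that it is percent-decoded again before `json.loads` is applied. `build_trino_url` is a hypothetical helper for illustration and is not part of trino-python-client; the standard-library `parse_qs` is used here only to show the round trip.

```python
import json
from urllib.parse import parse_qs, quote, urlsplit


def build_trino_url(user, host, port, catalog, roles):
    """Hypothetical helper (not part of trino-python-client)."""
    # Serialise the roles dict as JSON, then percent-encode it so that
    # quotes, braces and spaces are safe inside a URL query string.
    roles_param = quote(json.dumps(roles))
    return f"trino://{user}@{host}:{port}/{catalog}?roles={roles_param}"


url = build_trino_url("user", "localhost", 8080, "system", {"catalog1": "role1"})

# Decoding the query string and parsing the JSON recovers the original dict.
decoded = json.loads(parse_qs(urlsplit(url).query)["roles"][0])
```

Percent-encoding is optional for simple values like the README's, but it avoids surprises if a role name contains `&`, `#` or other characters that are significant in URLs.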

@@ -328,6 +330,20 @@ cur.execute('SELECT * FROM system.runtime.nodes')
rows = cur.fetchall()
```

## Roles

Authorization roles to use for catalogs, specified as a dict with key-value pairs for the catalog and role. For example, `{"catalog1": "roleA", "catalog2": "roleB"}` sets `roleA` for `catalog1` and `roleB` for `catalog2`. See Trino docs.

```python
import trino
conn = trino.dbapi.connect(
host='localhost',
port=443,
user='the-user',
roles={"catalog1": "roleA", "catalog2": "roleB"},
)
```

## SSL

### SSL verification
26 changes: 22 additions & 4 deletions tests/integration/test_dbapi_integration.py
@@ -19,6 +19,7 @@

import trino
from tests.integration.conftest import trino_version
from trino import constants
Member

nit: rename the commit "Refactor role into roles" to "Add support for system and connector roles"?

Does that make sense (and is it accurate)?

Contributor Author

@mdesmet mdesmet Sep 28, 2022

@hashhar: It is really just a rename of `role` to `roles`. That commit introduces no new feature; the `role` field already contained multiple roles. I thought this was more consistent with the `roles` argument passed on from `dbapi.connect`, which is introduced in the next commit.

Member

That's why the first commit fails tests on its own. See:

======================================================================================================================================= FAILURES ========================================================================================================================================
_____________________________________________________________________________________________________________________ test_set_role_in_connection_trino_higher_351 ______________________________________________________________________________________________________________________

run_trino = (<Popen: returncode: None args: ['docker', 'run', '--rm', '-p', '56257:8080',...>, 'localhost', 56257)

    @pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
    def test_set_role_in_connection_trino_higher_351(run_trino):
        _, host, port = run_trino

>       trino_connection = trino.dbapi.Connection(
            host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}
        )
E       TypeError: Connection.__init__() got an unexpected keyword argument 'roles'

tests/integration/test_dbapi_integration.py:1077: TypeError
------------------------------------------------------------------------------------------------------------------------------- Captured stdout teardown --------------------------------------------------------------------------------------------------------------------------------
trino-python-client-tests-637e3b1
____________________________________________________________________________________________________________________________ test_role_is_set_when_specified ____________________________________________________________________________________________________________________________

mock_client = <MagicMock name='client' id='4532039360'>

    @patch("trino.dbapi.trino.client")
    def test_role_is_set_when_specified(mock_client):
        roles = {"system": "finance"}
>       with connect("sample_trino_cluster:443", roles=roles) as conn:

tests/unit/test_dbapi.py:262:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = ('sample_trino_cluster:443',), kwargs = {'roles': {'system': 'finance'}}

    def connect(*args, **kwargs):
        """Constructor for creating a connection to the database.

        See class :py:class:`Connection` for arguments.

        :returns: a :py:class:`Connection` object.
        """
>       return Connection(*args, **kwargs)
E       TypeError: Connection.__init__() got an unexpected keyword argument 'roles'

trino/dbapi.py:82: TypeError
================================================================================================================================ short test summary info ================================================================================================================================
FAILED tests/integration/test_dbapi_integration.py::test_set_role_in_connection_trino_higher_351 - TypeError: Connection.__init__() got an unexpected keyword argument 'roles'
FAILED tests/unit/test_dbapi.py::test_role_is_set_when_specified - TypeError: Connection.__init__() got an unexpected keyword argument 'roles'

Contributor Author

I probably messed something up in my fixup commits locally. All good now.

from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
from trino.transaction import IsolationLevel

@@ -1045,11 +1046,11 @@ def test_set_role_trino_higher_351(run_trino):
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert cur._request._client_session.role is None
assert cur._request._client_session.roles == {}

cur.execute("SET ROLE ALL")
cur.fetchall()
assert cur._request._client_session.role == "system=ALL"
assert_role_headers(cur, "system=ALL")


@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog")
@@ -1062,11 +1063,28 @@ def test_set_role_trino_351(run_trino):
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert cur._request._client_session.role is None
assert cur._request._client_session.roles == {}

cur.execute("SET ROLE ALL")
cur.fetchall()
assert cur._request._client_session.role == "tpch=ALL"
assert_role_headers(cur, "tpch=ALL")


@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
def test_set_role_in_connection_trino_higher_351(run_trino):
_, host, port = run_trino

trino_connection = trino.dbapi.Connection(
host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}
)
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert_role_headers(cur, "system=ALL")


def assert_role_headers(cursor, expected_header):
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header


def test_prepared_statements(run_trino):
10 changes: 10 additions & 0 deletions tests/unit/sqlalchemy/test_dialect.py
@@ -63,6 +63,16 @@ def setup(self):
experimental_python_types=True,
),
),
(
make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'),
list(),
dict(host="localhost",
port=8080,
catalog="system",
user="user",
roles={"hive": "finance", "system": "analyst"},
source="trino-sqlalchemy"),
),
],
)
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
30 changes: 16 additions & 14 deletions tests/unit/test_client.py
@@ -13,6 +13,7 @@
import threading
import time
import uuid
from typing import Optional, Dict
from unittest import mock
from urllib.parse import urlparse

@@ -92,10 +93,9 @@ def assert_headers(headers):
assert headers[constants.HEADER_SOURCE] == source
assert headers[constants.HEADER_USER] == user
assert headers[constants.HEADER_SESSION] == ""
assert headers[constants.HEADER_ROLE] is None
assert headers[accept_encoding_header] == accept_encoding_value
assert headers[client_info_header] == client_info_value
assert len(headers.keys()) == 9
assert len(headers.keys()) == 8

req.post("URL")
_, post_kwargs = post.call_args
@@ -988,10 +988,12 @@ def __call__(self):
with_retry(FailerUntil(3).__call__)()


def assert_headers_with_role(headers, role):
def assert_headers_with_roles(headers: Dict[str, str], roles: Optional[str]):
if roles is None:
assert constants.HEADER_ROLE not in headers
else:
assert headers[constants.HEADER_ROLE] == roles
assert headers[constants.HEADER_USER] == "test_user"
assert headers[constants.HEADER_ROLE] == role
assert len(headers.keys()) == 7


def test_request_headers_role_hive_all(mock_get_and_post):
@@ -1001,17 +1003,17 @@ def test_request_headers_role_hive_all(mock_get_and_post):
port=8080,
client_session=ClientSession(
user="test_user",
role="hive=ALL",
roles={"hive": "ALL"}
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
assert_headers_with_roles(post_kwargs["headers"], "hive=ALL")

req.get("URL")
_, get_kwargs = get.call_args
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
assert_headers_with_roles(post_kwargs["headers"], "hive=ALL")


def test_request_headers_role_admin(mock_get_and_post):
@@ -1022,17 +1024,17 @@ def test_request_headers_role_admin(mock_get_and_post):
port=8080,
client_session=ClientSession(
user="test_user",
role="admin",
roles={"system": "admin"}
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert_headers_with_role(post_kwargs["headers"], "admin")
assert_headers_with_roles(post_kwargs["headers"], "system=admin")

req.get("URL")
_, get_kwargs = get.call_args
assert_headers_with_role(post_kwargs["headers"], "admin")
assert_headers_with_roles(post_kwargs["headers"], "system=admin")


def test_request_headers_role_empty(mock_get_and_post):
@@ -1043,14 +1045,14 @@ def test_request_headers_role_empty(mock_get_and_post):
port=8080,
client_session=ClientSession(
user="test_user",
role="",
roles=None,
),
)

req.post("URL")
_, post_kwargs = post.call_args
assert_headers_with_role(post_kwargs["headers"], "")
assert_headers_with_roles(post_kwargs["headers"], None)

req.get("URL")
_, get_kwargs = get.call_args
assert_headers_with_role(post_kwargs["headers"], "")
assert_headers_with_roles(post_kwargs["headers"], None)
12 changes: 10 additions & 2 deletions tests/unit/test_dbapi.py
@@ -248,11 +248,19 @@ def run(self) -> None:

@patch("trino.dbapi.trino.client")
def test_tags_are_set_when_specified(mock_client):
# WHEN
client_tags = ["TAG1", "TAG2"]
with connect("sample_trino_cluster:443", client_tags=client_tags) as conn:
conn.cursor().execute("SOME FAKE QUERY")

# THEN
_, passed_client_tags = mock_client.ClientSession.call_args
assert passed_client_tags["client_tags"] == client_tags


@patch("trino.dbapi.trino.client")
def test_role_is_set_when_specified(mock_client):
roles = {"system": "finance"}
with connect("sample_trino_cluster:443", roles=roles) as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_role = mock_client.ClientSession.call_args
assert passed_role["roles"] == roles
56 changes: 36 additions & 20 deletions trino/client.py
@@ -98,7 +98,7 @@ class ClientSession(object):
:param extra_credential: extra credentials. as list of ``(key, value)``
tuples.
:param client_tags: Client tags as list of strings.
:param role: role for the current session. Some connectors do not
:param roles: roles for the current session. Some connectors do not
support role management. See connector documentation for more details.
"""

@@ -113,7 +113,7 @@ def __init__(
transaction_id: str = None,
extra_credential: List[Tuple[str, str]] = None,
client_tags: List[str] = None,
role: str = None,
roles: Dict[str, str] = None,
):
self._user = user
self._catalog = catalog
@@ -124,7 +124,7 @@ def __init__(
self._transaction_id = transaction_id
self._extra_credential = extra_credential
self._client_tags = client_tags
self._role = role
self._roles = roles or {}
self._prepared_statements: Dict[str, str] = {}
self._object_lock = threading.Lock()

@@ -188,24 +188,15 @@ def extra_credential(self):
def client_tags(self):
return self._client_tags

def __getstate__(self):
state = self.__dict__.copy()
del state["_object_lock"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._object_lock = threading.Lock()

@property
def role(self):
def roles(self):
with self._object_lock:
return self._role
return self._roles

@role.setter
def role(self, role):
@roles.setter
def roles(self, roles):
with self._object_lock:
self._role = role
self._roles = roles

@property
def prepared_statements(self):
@@ -216,6 +207,15 @@ def prepared_statements(self, prepared_statements):
with self._object_lock:
self._prepared_statements = prepared_statements

def __getstate__(self):
state = self.__dict__.copy()
del state["_object_lock"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._object_lock = threading.Lock()


def get_header_values(headers, header):
return [val.strip() for val in headers[header].split(",")]
@@ -224,7 +224,7 @@ def get_header_values(headers, header):
def get_session_property_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote(v.strip()))
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
]

@@ -237,6 +237,14 @@ def get_prepared_statement_values(headers, header):
]


def get_roles_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
]


class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
self.id = id
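The `get_roles_values` helper added above splits a comma-separated response header into `catalog=role` pairs and percent-decodes each role. A self-contained sketch of that parsing follows. `parse_roles_header` is an illustrative name, and unlike the helper in the diff it takes the raw header string and returns a dict rather than a list of tuples.

```python
from urllib.parse import unquote_plus


def parse_roles_header(value):
    """Illustrative stand-in for the parsing done by get_roles_values."""
    # Split on commas first, then on the first "=" only, so that a
    # percent-encoded role value can never be mistaken for a separator.
    pairs = (item.strip().split("=", 1) for item in value.split(","))
    return {catalog.strip(): unquote_plus(role.strip()) for catalog, role in pairs}


parsed = parse_roles_header("system=ALL, hive=data%20team")
```

Here `parsed` maps `system` to `ALL` and `hive` to `data team`. The sketch assumes catalog names contain neither `=` nor `,`.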
@@ -400,7 +408,12 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_SCHEMA] = self._client_session.schema
headers[constants.HEADER_SOURCE] = self._client_session.source
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_ROLE] = self._client_session.role
if len(self._client_session.roles.values()):
headers[constants.HEADER_ROLE] = ",".join(
# ``name`` must not contain ``=``
"{}={}".format(catalog, urllib.parse.quote(str(role)))
for catalog, role in self._client_session.roles.items()
)
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
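The hunk above serialises the session's `roles` dict into a single header value and omits the header entirely when the dict is empty. The serialisation can be sketched in isolation as below. `format_roles_header` is a hypothetical name used only for this illustration, and the sketch assumes catalog names contain neither `=` nor `,`.

```python
from urllib.parse import quote


def format_roles_header(roles):
    """Serialise {"catalog": "role"} into a comma-separated header value."""
    # Each role is percent-encoded so that spaces, commas and "=" inside a
    # role name cannot be confused with the separators.
    return ",".join(
        "{}={}".format(catalog, quote(str(role)))
        for catalog, role in roles.items()
    )


single = format_roles_header({"system": "ALL"})
multiple = format_roles_header({"hive": "finance", "system": "analyst"})
escaped = format_roles_header({"hive": "data team"})
```

These produce `system=ALL`, `hive=finance,system=analyst` and `hive=data%20team` respectively, which matches the header strings asserted in the unit tests in this PR.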

@@ -579,7 +592,10 @@ def process(self, http_response) -> TrinoStatus:
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]

if constants.HEADER_SET_ROLE in http_response.headers:
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]
for key, value in get_roles_values(
http_response.headers, constants.HEADER_SET_ROLE
):
self._client_session.roles[key] = value

if constants.HEADER_ADDED_PREPARE in http_response.headers:
for name, statement in get_prepared_statement_values(
4 changes: 3 additions & 1 deletion trino/dbapi.py
@@ -110,6 +110,7 @@ def __init__(
http_session=None,
client_tags=None,
experimental_python_types=False,
roles=None,
):
self.host = host
self.port = port
@@ -127,7 +128,8 @@ def __init__(
headers=http_headers,
transaction_id=NO_TRANSACTION,
extra_credential=extra_credential,
client_tags=client_tags
client_tags=client_tags,
roles=roles,
)
# mypy cannot follow module import
if http_session is None:
3 changes: 3 additions & 0 deletions trino/sqlalchemy/dialect.py
@@ -124,6 +124,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
if "experimental_python_types" in url.query:
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])

if "roles" in url.query:
kwargs["roles"] = json.loads(url.query["roles"])

return args, kwargs

def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
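Because the dialect passes the `roles` query value through `json.loads`, the value in the URL has to be valid JSON: double-quoted strings for roles, lower-case `true` for booleans. The sketch below isolates that decoding step. `decode_query_kwargs` is a hypothetical stand-in written for this illustration, not the dialect's actual method, and it takes an already-parsed query mapping as input.

```python
import json


def decode_query_kwargs(query):
    """Hypothetical stand-in for the JSON-decoding step in create_connect_args."""
    kwargs = {}
    for key in ("experimental_python_types", "roles"):
        if key in query:
            # Each value arrives as a string and is decoded as JSON.
            kwargs[key] = json.loads(query[key])
    return kwargs


kwargs = decode_query_kwargs({
    "experimental_python_types": "true",
    "roles": '{"hive": "finance", "system": "analyst"}',
})

# A Python-style dict literal with single quotes is not valid JSON.
try:
    json.loads("{'hive': 'finance'}")
    single_quotes_accepted = True
except json.JSONDecodeError:
    single_quotes_accepted = False
```

In this sketch a malformed `roles` value surfaces as a `json.JSONDecodeError` at decoding time.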