diff --git a/superset/models/core.py b/superset/models/core.py index 6f32383ab8058..96a1953faeb3b 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -84,7 +84,11 @@ from superset.utils import cache as cache_util, core as utils, json from superset.utils.backports import StrEnum from superset.utils.core import get_username -from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema +from superset.utils.oauth2 import ( + check_for_oauth2, + get_oauth2_access_token, + OAuth2ClientConfigSchema, +) config = app.config custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] @@ -451,13 +455,14 @@ def get_sqla_engine( # pylint: disable=too-many-arguments engine_context_manager = config["ENGINE_CONTEXT_MANAGER"] with engine_context_manager(self, catalog, schema): - yield self._get_sqla_engine( - catalog=catalog, - schema=schema, - nullpool=nullpool, - source=source, - sqlalchemy_uri=sqlalchemy_uri, - ) + with check_for_oauth2(self): + yield self._get_sqla_engine( + catalog=catalog, + schema=schema, + nullpool=nullpool, + source=source, + sqlalchemy_uri=sqlalchemy_uri, + ) def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901 self, @@ -583,10 +588,9 @@ def get_raw_connection( nullpool=nullpool, source=source, ) as engine: - try: + with check_for_oauth2(self): with closing(engine.raw_connection()) as conn: - # pre-session queries are used to set the selected schema and, in the # noqa: E501 - # future, the selected catalog + # pre-session queries are used to set the selected catalog/schema for prequery in self.db_engine_spec.get_prequeries( database=self, catalog=catalog, @@ -597,11 +601,6 @@ def get_raw_connection( yield conn - except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.db_engine_spec.start_oauth2_dance(self) - raise - def get_default_catalog(self) -> str | None: """ Return the default configured catalog for the database. diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 0918f0792a9a2..b93f89c870016 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,8 +17,9 @@ from __future__ import annotations +from contextlib import contextmanager from datetime import datetime, timedelta, timezone -from typing import Any, TYPE_CHECKING +from typing import Any, Iterator, TYPE_CHECKING import backoff import jwt @@ -32,7 +33,7 @@ if TYPE_CHECKING: from superset.db_engine_specs.base import BaseEngineSpec - from superset.models.core import DatabaseUserOAuth2Tokens + from superset.models.core import Database, DatabaseUserOAuth2Tokens JWT_EXPIRATION = timedelta(minutes=5) @@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema): load_default=lambda: "json", validate=validate.OneOf(["json", "data"]), ) + + +@contextmanager +def check_for_oauth2(database: Database) -> Iterator[None]: + """ + Run code and check if OAuth2 is needed. + """ + try: + yield + except Exception as ex: + if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex): + database.db_engine_spec.start_oauth2_dance(database) + raise diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 5b269fc3ba935..50ccde6605568 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None: } -def test_raw_connection_oauth(mocker: MockerFixture) -> None: +def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None: """ Test that we can start OAuth2 from `raw_connection()` errors. - Some databases that use OAuth2 need to trigger the flow when the connection is - created, rather than when the query runs. This happens when the SQLAlchemy engine - URI cannot be built without the user personal token. + With OAuth2, some databases will raise an exception when the engine is first created + (eg, BigQuery). Others, like, Snowflake, when the connection is created. And + finally, GSheets will raise an exception when the query is executed. - This test verifies that the exception is captured and raised correctly so that the - frontend can trigger the OAuth2 dance. + This tests verifies that when calling `raw_connection()` the OAuth2 flow is + triggered when the engine is created. + """ + g = mocker.patch("superset.db_engine_specs.base.g") + g.user = mocker.MagicMock() + g.user.id = 42 + + database = Database( + id=1, + database_name="my_db", + sqlalchemy_uri="sqlite://", + encrypted_extra=json.dumps(oauth2_client_info), + ) + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore + _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine") + _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required") + + with pytest.raises(OAuth2RedirectError) as excinfo: + with database.get_raw_connection() as conn: + conn.cursor() + assert str(excinfo.value) == "You don't have permission to access the data." + + +def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None: + """ + Test that we can start OAuth2 from `raw_connection()` errors. + + With OAuth2, some databases will raise an exception when the engine is first created + (eg, BigQuery). Others, like, Snowflake, when the connection is created. And + finally, GSheets will raise an exception when the query is executed. + + This tests verifies that when calling `raw_connection()` the OAuth2 flow is + triggered when the connection is created. """ g = mocker.patch("superset.db_engine_specs.base.g") g.user = mocker.MagicMock() @@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None: assert str(excinfo.value) == "You don't have permission to access the data." +def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None: + """ + Test that we can start OAuth2 from `raw_connection()` errors. + + With OAuth2, some databases will raise an exception when the engine is first created + (eg, BigQuery). Others, like, Snowflake, when the connection is created. And + finally, GSheets will raise an exception when the query is executed. + + This tests verifies that when calling `raw_connection()` the OAuth2 flow is + triggered when the connection is created. + """ + g = mocker.patch("superset.db_engine_specs.base.g") + g.user = mocker.MagicMock() + g.user.id = 42 + + database = Database( + id=1, + database_name="my_db", + sqlalchemy_uri="sqlite://", + encrypted_extra=json.dumps(oauth2_client_info), + ) + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore + get_sqla_engine = mocker.patch.object(database, "get_sqla_engine") + get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = ( + OAuth2Error("OAuth2 required") + ) + + with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012 + with database.get_raw_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + assert str(excinfo.value) == "You don't have permission to access the data." + + def test_get_schema_access_for_file_upload() -> None: """ Test the `get_schema_access_for_file_upload` method. @@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None: ) +def test_engine_oauth2(mocker: MockerFixture) -> None: + """ + Test that we handle OAuth2 when `create_engine` fails. + """ + database = Database(database_name="my_db", sqlalchemy_uri="trino://") + mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception) + mocker.patch.object(database, "is_oauth2_enabled", return_value=True) + mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True) + start_oauth2_dance = mocker.patch.object( + database.db_engine_spec, + "start_oauth2_dance", + side_effect=OAuth2Error("OAuth2 required"), + ) + + with pytest.raises(OAuth2Error): + with database.get_sqla_engine("catalog", "schema"): + pass + + start_oauth2_dance.assert_called_with(database) + + def test_purge_oauth2_tokens(session: Session) -> None: """ Test the `purge_oauth2_tokens` method.