From 21e794a66ffeedd7775748bcef286a6069729c9f Mon Sep 17 00:00:00 2001 From: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:31:10 -0300 Subject: [PATCH] fix(database import): Gracefully handle error to get catalog schemas (#31437) --- superset/commands/database/create.py | 47 +----------- .../commands/database/importers/v1/utils.py | 38 +--------- superset/commands/database/utils.py | 67 ++++++++++++++++ .../commands/importers/v1/import_test.py | 28 ------- .../databases/commands/utils_test.py | 76 +++++++++++++++++++ 5 files changed, 147 insertions(+), 109 deletions(-) create mode 100644 superset/commands/database/utils.py create mode 100644 tests/unit_tests/databases/commands/utils_test.py diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 0eb8d43b37c99..9425dc5ea945c 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -39,11 +39,11 @@ SSHTunnelInvalidError, ) from superset.commands.database.test_connection import TestConnectionDatabaseCommand +from superset.commands.database.utils import add_permissions from superset.daos.database import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel -from superset.db_engine_specs.base import GenericDBException from superset.exceptions import OAuth2RedirectError, SupersetErrorsException -from superset.extensions import event_logger, security_manager +from superset.extensions import event_logger from superset.models.core import Database from superset.utils.decorators import on_error, transaction @@ -99,28 +99,7 @@ def run(self) -> Model: ).run() # add catalog/schema permissions - if database.db_engine_spec.supports_catalog: - catalogs = database.get_all_catalog_names( - cache=False, - ssh_tunnel=ssh_tunnel, - ) - for catalog in catalogs: - security_manager.add_permission_view_menu( - "catalog_access", - security_manager.get_catalog_perm( - database.database_name, catalog - ), - ) - else: - # add a dummy catalog for DBs that don't support them - catalogs = [None] - - for catalog in catalogs: - try: - self.add_schema_permissions(database, catalog, ssh_tunnel) - except GenericDBException: # pylint: disable=broad-except - logger.warning("Error processing catalog '%s'", catalog) - continue + add_permissions(database, ssh_tunnel) except ( SSHTunnelInvalidError, SSHTunnelCreateFailedError, @@ -148,26 +127,6 @@ def run(self) -> Model: return database - def add_schema_permissions( - self, - database: Database, - catalog: str, - ssh_tunnel: Optional[SSHTunnel], - ) -> None: - for schema in database.get_all_schema_names( - catalog=catalog, - cache=False, - ssh_tunnel=ssh_tunnel, - ): - security_manager.add_permission_view_menu( - "schema_access", - security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ), - ) - def validate(self) -> None: exceptions: list[ValidationError] = [] sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri") diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py index c9927a033dad8..0098bfa26d456 100644 --- a/superset/commands/database/importers/v1/utils.py +++ b/superset/commands/database/importers/v1/utils.py @@ -19,6 +19,7 @@ from typing import Any from superset import app, db, security_manager +from superset.commands.database.utils import add_permissions from superset.commands.exceptions import ImportFailedError from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -86,40 +87,3 @@ def import_database( logger.warning(ex.message) return database - - -def add_permissions(database: Database, ssh_tunnel: SSHTunnel) -> None: - """ - Add DAR for catalogs and schemas. - """ - if database.db_engine_spec.supports_catalog: - catalogs = database.get_all_catalog_names( - cache=False, - ssh_tunnel=ssh_tunnel, - ) - - for catalog in catalogs: - security_manager.add_permission_view_menu( - "catalog_access", - security_manager.get_catalog_perm( - database.database_name, - catalog, - ), - ) - else: - catalogs = [None] - - for catalog in catalogs: - for schema in database.get_all_schema_names( - catalog=catalog, - cache=False, - ssh_tunnel=ssh_tunnel, - ): - security_manager.add_permission_view_menu( - "schema_access", - security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ), - ) diff --git a/superset/commands/database/utils.py b/superset/commands/database/utils.py new file mode 100644 index 0000000000000..ea0ce1a27e274 --- /dev/null +++ b/superset/commands/database/utils.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +from superset import security_manager +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.db_engine_specs.base import GenericDBException +from superset.models.core import Database + +logger = logging.getLogger(__name__) + + +def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None: + """ + Add DAR for catalogs and schemas. + """ + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names( + cache=False, + ssh_tunnel=ssh_tunnel, + ) + + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm( + database.database_name, + catalog, + ), + ) + else: + catalogs = [None] + + for catalog in catalogs: + try: + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) + except GenericDBException: # pylint: disable=broad-except + logger.warning("Error processing catalog '%s'", catalog) + continue diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index f101216bbf0de..06be5bc16f1f0 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -23,7 +23,6 @@ from sqlalchemy.orm.session import Session from superset import db -from superset.commands.database.importers.v1.utils import add_permissions from superset.commands.exceptions import ImportFailedError from superset.utils import json @@ -216,30 +215,3 @@ def test_import_database_with_user_impersonation( database = import_database(config) assert database.impersonate_user is True - - -def test_add_permissions(mocker: MockerFixture) -> None: - """ - Test adding permissions to a database when it's imported. - """ - database = mocker.MagicMock() - database.database_name = "my_db" - database.db_engine_spec.supports_catalog = True - database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] - database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]] - ssh_tunnel = mocker.MagicMock() - add_permission_view_menu = mocker.patch( - "superset.commands.database.importers.v1.utils.security_manager." - "add_permission_view_menu" - ) - - add_permissions(database, ssh_tunnel) - - add_permission_view_menu.assert_has_calls( - [ - mocker.call("catalog_access", "[my_db].[catalog1]"), - mocker.call("catalog_access", "[my_db].[catalog2]"), - mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), - mocker.call("schema_access", "[my_db].[catalog2].[schema2]"), - ] - ) diff --git a/tests/unit_tests/databases/commands/utils_test.py b/tests/unit_tests/databases/commands/utils_test.py new file mode 100644 index 0000000000000..e8f27d041602c --- /dev/null +++ b/tests/unit_tests/databases/commands/utils_test.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pytest_mock import MockerFixture + +from superset.commands.database.utils import add_permissions + + +def test_add_permissions(mocker: MockerFixture) -> None: + """ + Test adding permissions to a database when it's created. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] + database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]] + ssh_tunnel = mocker.MagicMock() + add_permission_view_menu = mocker.patch( + "superset.commands.database.importers.v1.utils.security_manager." + "add_permission_view_menu" + ) + + add_permissions(database, ssh_tunnel) + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("catalog_access", "[my_db].[catalog1]"), + mocker.call("catalog_access", "[my_db].[catalog2]"), + mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), + mocker.call("schema_access", "[my_db].[catalog2].[schema2]"), + ] + ) + + +def test_add_permissions_handle_failures(mocker: MockerFixture) -> None: + """ + Test adding permissions to a database when it's created in case + the request to get all schemas for one fo the catalogs fail. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2", "catalog3"] + database.get_all_schema_names.side_effect = [["schema1"], Exception, ["schema3"]] + ssh_tunnel = mocker.MagicMock() + add_permission_view_menu = mocker.patch( + "superset.commands.database.importers.v1.utils.security_manager." + "add_permission_view_menu" + ) + + add_permissions(database, ssh_tunnel) + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("catalog_access", "[my_db].[catalog1]"), + mocker.call("catalog_access", "[my_db].[catalog2]"), + mocker.call("catalog_access", "[my_db].[catalog3]"), + mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), + mocker.call("schema_access", "[my_db].[catalog3].[schema3]"), + ] + )