diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 76dd6087be58a..b8854010faa2e 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -41,6 +41,7 @@ from superset.commands.database.test_connection import TestConnectionDatabaseCommand from superset.daos.database import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.db_engine_specs.base import GenericDBException from superset.exceptions import SupersetErrorsException from superset.extensions import event_logger, security_manager from superset.models.core import Database @@ -118,7 +119,7 @@ def run(self) -> Model: for catalog in catalogs: try: self.add_schema_permissions(database, catalog, ssh_tunnel) - except Exception: # pylint: disable=broad-except + except GenericDBException: # pylint: disable=broad-except logger.warning("Error processing catalog '%s'", catalog) continue except ( diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py index 56d31b03e1998..bf65225d8d872 100644 --- a/superset/commands/database/importers/v1/utils.py +++ b/superset/commands/database/importers/v1/utils.py @@ -62,14 +62,55 @@ def import_database( config["extra"] = json.dumps(config["extra"]) # Before it gets removed in import_from_dict - ssh_tunnel = config.pop("ssh_tunnel", None) + ssh_tunnel_config = config.pop("ssh_tunnel", None) database = Database.import_from_dict(config, recursive=False) if database.id is None: db.session.flush() - if ssh_tunnel: - ssh_tunnel["database_id"] = database.id - SSHTunnel.import_from_dict(ssh_tunnel, recursive=False) + if ssh_tunnel_config: + ssh_tunnel_config["database_id"] = database.id + ssh_tunnel = SSHTunnel.import_from_dict(ssh_tunnel_config, recursive=False) + else: + ssh_tunnel = None + + # TODO (betodealmeida): we should use the `CreateDatabaseCommand` for imports + add_permissions(database, ssh_tunnel) return database + + +def add_permissions(database: Database, ssh_tunnel: SSHTunnel) -> None: + """ + Add DAR for catalogs and schemas. + """ + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names( + cache=False, + ssh_tunnel=ssh_tunnel, + ) + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm( + database.database_name, + catalog, + ), + ) + else: + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 75aebf322be9d..0fc31c096a063 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -41,6 +41,7 @@ from superset.daos.database import DatabaseDAO from superset.daos.dataset import DatasetDAO from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.db_engine_specs.base import GenericDBException from superset.models.core import Database from superset.utils.decorators import on_error, transaction @@ -80,6 +81,7 @@ def run(self) -> Model: database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) self._refresh_catalogs(database, original_database_name, ssh_tunnel) + return database def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: @@ -115,17 +117,13 @@ def _get_catalog_names( ) -> set[str]: """ Helper method to load catalogs. - - This method captures a generic exception, since errors could potentially come - from any of the 50+ database drivers we support. """ - try: return database.get_all_catalog_names( force=True, ssh_tunnel=ssh_tunnel, ) - except Exception as ex: + except GenericDBException as ex: raise DatabaseConnectionFailedError() from ex def _get_schema_names( @@ -136,18 +134,14 @@ def _get_schema_names( ) -> set[str]: """ Helper method to load schemas. - - This method captures a generic exception, since errors could potentially come - from any of the 50+ database drivers we support. """ - try: return database.get_all_schema_names( force=True, catalog=catalog, ssh_tunnel=ssh_tunnel, ) - except Exception as ex: + except GenericDBException as ex: raise DatabaseConnectionFailedError() from ex def _refresh_catalogs( @@ -255,7 +249,7 @@ def _rename_database_in_permissions( catalog: str | None, schemas: set[str], ) -> None: - new_name = security_manager.get_catalog_perm( + new_catalog_perm_name = security_manager.get_catalog_perm( database.database_name, catalog, ) @@ -271,10 +265,10 @@ def _rename_database_in_permissions( perm, ) if existing_pvm: - existing_pvm.view_menu.name = new_name + existing_pvm.view_menu.name = new_catalog_perm_name for schema in schemas: - new_name = security_manager.get_schema_perm( + new_schema_perm_name = security_manager.get_schema_perm( database.database_name, catalog, schema, @@ -291,7 +285,7 @@ def _rename_database_in_permissions( perm, ) if existing_pvm: - existing_pvm.view_menu.name = new_name + existing_pvm.view_menu.name = new_schema_perm_name # rename permissions on datasets and charts for dataset in DatabaseDAO.get_datasets( @@ -299,9 +293,11 @@ def _rename_database_in_permissions( catalog=catalog, schema=schema, ): - dataset.schema_perm = new_name + dataset.catalog_perm = new_catalog_perm_name + dataset.schema_perm = new_schema_perm_name for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: - chart.schema_perm = new_name + chart.catalog_perm = new_catalog_perm_name + chart.schema_perm = new_schema_perm_name def validate(self) -> None: if database_name := self._properties.get("database_name"): diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 91860bf5b352d..4f66e2fdc73fa 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -434,7 +434,25 @@ def get_default_catalog( cls, database: Database, ) -> str | None: + """ + Return the default catalog. + + The default behavior for Databricks is confusing. When Unity Catalog is not + enabled we have (the DB engine spec hasn't been tested with it enabled): + + > SHOW CATALOGS; + spark_catalog + > SELECT current_catalog(); + hive_metastore + + To handle permissions correctly we use the result of `SHOW CATALOGS` when a + single catalog is returned. + """ with database.get_sqla_engine() as engine: + catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")} + if len(catalogs) == 1: + return catalogs.pop() + return engine.execute("SELECT current_catalog()").scalar() @classmethod diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 0f5948ad7b238..a8bae9b64d4a7 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -18,6 +18,7 @@ from io import BytesIO from unittest import mock +from unittest.mock import patch from zipfile import is_zipfile, ZipFile import prison @@ -1768,7 +1769,8 @@ def test_export_chart_gamma(self): assert rv.status_code == 404 - def test_import_chart(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_chart(self, mock_add_permissions): """ Chart API: Test import chart """ @@ -1805,7 +1807,8 @@ def test_import_chart(self): db.session.delete(database) db.session.commit() - def test_import_chart_overwrite(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_chart_overwrite(self, mock_add_permissions): """ Chart API: Test import existing chart """ @@ -1876,7 +1879,8 @@ def test_import_chart_overwrite(self): db.session.delete(database) db.session.commit() - def test_import_chart_invalid(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_chart_invalid(self, mock_add_permissions): """ Chart API: Test import invalid chart """ diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index 1ee4658b88ffc..d8522473d09f2 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -173,7 +173,8 @@ def test_export_chart_command_no_related(self, mock_g): class TestImportChartsCommand(SupersetTestCase): @patch("superset.utils.core.g") @patch("superset.security.manager.g") - def test_import_v1_chart(self, sm_g, utils_g) -> None: + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_chart(self, mock_add_permissions, sm_g, utils_g) -> None: """Test that we can import a chart""" admin = sm_g.user = utils_g.user = security_manager.find_user("admin") contents = { @@ -246,7 +247,8 @@ def test_import_v1_chart(self, sm_g, utils_g) -> None: db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_chart_multiple(self, sm_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_chart_multiple(self, mock_add_permissions, sm_g): """Test that a chart can be imported multiple times""" sm_g.user = security_manager.find_user("admin") contents = { @@ -272,7 +274,8 @@ def test_import_v1_chart_multiple(self, sm_g): db.session.delete(database) db.session.commit() - def test_import_v1_chart_validation(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_chart_validation(self, mock_add_permissions): """Test different validations applied when importing a chart""" # metadata.yaml must be present contents = { diff --git a/tests/integration_tests/commands_test.py b/tests/integration_tests/commands_test.py index 83409fd02280f..77dd32b8169b3 100644 --- a/tests/integration_tests/commands_test.py +++ b/tests/integration_tests/commands_test.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import copy +from unittest.mock import patch import yaml from flask import g @@ -63,8 +64,10 @@ def setUp(self): self.user = user setattr(g, "user", user) - def test_import_assets(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_assets(self, mock_add_permissions): """Test that we can import multiple assets""" + contents = { "metadata.yaml": yaml.safe_dump(metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), @@ -144,13 +147,16 @@ def test_import_assets(self): assert dashboard.owners == [self.user] + mock_add_permissions.assert_called_with(database, None) + db.session.delete(dashboard) db.session.delete(chart) db.session.delete(dataset) db.session.delete(database) db.session.commit() - def test_import_v1_dashboard_overwrite(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dashboard_overwrite(self, mock_add_permissions): """Test that assets can be overwritten""" contents = { "metadata.yaml": yaml.safe_dump(metadata_config), @@ -185,6 +191,9 @@ def test_import_v1_dashboard_overwrite(self): chart = dashboard.slices[0] dataset = chart.table database = dataset.database + + mock_add_permissions.assert_called_with(database, None) + db.session.delete(dashboard) db.session.delete(chart) db.session.delete(dataset) diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 259b9485fbe78..86a855a289073 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2111,7 +2111,8 @@ def test_export_bundle_not_allowed(self): db.session.delete(dashboard) db.session.commit() - def test_import_dashboard(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_dashboard(self, mock_add_permissions): """ Dashboard API: Test import dashboard """ @@ -2215,7 +2216,8 @@ def test_import_dashboard_v0_export(self): db.session.delete(dataset) db.session.commit() - def test_import_dashboard_overwrite(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_dashboard_overwrite(self, mock_add_permissions): """ Dashboard API: Test import existing dashboard """ diff --git a/tests/integration_tests/dashboards/commands_tests.py b/tests/integration_tests/dashboards/commands_tests.py index 334e0425cf1f3..4564f39989722 100644 --- a/tests/integration_tests/dashboards/commands_tests.py +++ b/tests/integration_tests/dashboards/commands_tests.py @@ -488,7 +488,8 @@ def test_import_v0_dashboard_cli_export(self): @patch("superset.utils.core.g") @patch("superset.security.manager.g") - def test_import_v1_dashboard(self, sm_g, utils_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dashboard(self, mock_add_permissions, sm_g, utils_g): """Test that we can import a dashboard""" admin = sm_g.user = utils_g.user = security_manager.find_user("admin") contents = { @@ -577,7 +578,8 @@ def test_import_v1_dashboard(self, sm_g, utils_g): db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_dashboard_multiple(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dashboard_multiple(self, mock_add_permissions, mock_g): """Test that a dashboard can be imported multiple times""" mock_g.user = security_manager.find_user("admin") diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 8d0cd0810f8b1..5dd526025ed92 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2331,7 +2331,8 @@ def test_export_database_non_existing(self): rv = self.get_assert_metric(uri, "export") assert rv.status_code == 404 - def test_import_database(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database(self, mock_add_permissions): """ Database API: Test import database """ @@ -2363,7 +2364,8 @@ def test_import_database(self): db.session.delete(database) db.session.commit() - def test_import_database_overwrite(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database_overwrite(self, mock_add_permissions): """ Database API: Test import existing database """ @@ -2433,7 +2435,8 @@ def test_import_database_overwrite(self): db.session.delete(database) db.session.commit() - def test_import_database_invalid(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database_invalid(self, mock_add_permissions): """ Database API: Test import invalid database """ @@ -2483,7 +2486,8 @@ def test_import_database_invalid(self): ] } - def test_import_database_masked_password(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database_masked_password(self, mock_add_permissions): """ Database API: Test import database with masked password """ @@ -2540,7 +2544,8 @@ def test_import_database_masked_password(self): ] } - def test_import_database_masked_password_provided(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database_masked_password_provided(self, mock_add_permissions): """ Database API: Test import database with masked password provided """ @@ -2586,8 +2591,11 @@ def test_import_database_masked_password_provided(self): db.session.commit() @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_password( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with masked password @@ -2644,8 +2652,11 @@ def test_import_database_masked_ssh_tunnel_password( } @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_password_provided( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with masked password provided @@ -2692,8 +2703,11 @@ def test_import_database_masked_ssh_tunnel_password_provided( db.session.commit() @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_private_key_and_password( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with masked private_key @@ -2753,8 +2767,11 @@ def test_import_database_masked_ssh_tunnel_private_key_and_password( } @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_private_key_and_password_provided( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with masked password provided @@ -2804,7 +2821,11 @@ def test_import_database_masked_ssh_tunnel_private_key_and_password_provided( db.session.delete(database) db.session.commit() - def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self): + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_database_masked_ssh_tunnel_feature_flag_disabled( + self, + mock_add_permissions, + ): """ Database API: Test import database with ssh_tunnel and feature flag disabled """ @@ -2856,8 +2877,11 @@ def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self): } @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_feature_no_credentials( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with ssh_tunnel that has no credentials @@ -2911,8 +2935,11 @@ def test_import_database_masked_ssh_tunnel_feature_no_credentials( } @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_feature_mix_credentials( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with ssh_tunnel that has no credentials @@ -2966,8 +2993,11 @@ def test_import_database_masked_ssh_tunnel_feature_mix_credentials( } @mock.patch("superset.databases.schemas.is_feature_enabled") + @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """ Database API: Test import database with ssh_tunnel that has no credentials @@ -3802,7 +3832,7 @@ def test_get_related_objects(self): assert "dashboards" in rv.json assert "sqllab_tab_states" in rv.json - @patch.dict( + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, @@ -3828,7 +3858,7 @@ def test_validate_sql(self): self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) - @patch.dict( + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, @@ -3864,7 +3894,7 @@ def test_validate_sql_errors(self): ], ) - @patch.dict( + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, @@ -3885,7 +3915,7 @@ def test_validate_sql_not_found(self): rv = self.client.post(uri, json=request_payload) self.assertEqual(rv.status_code, 404) - @patch.dict( + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, @@ -3908,7 +3938,7 @@ def test_validate_sql_validation_fails(self): self.assertEqual(rv.status_code, 400) self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}}) - @patch.dict( + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", {}, clear=True, @@ -3953,8 +3983,8 @@ def test_validate_sql_endpoint_noconfig(self): }, ) - @patch("superset.commands.database.validate_sql.get_validator_by_name") - @patch.dict( + @mock.patch("superset.commands.database.validate_sql.get_validator_by_name") + @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", PRESTO_SQL_VALIDATORS_BY_ENGINE, clear=True, diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 8979b91c47241..bbf9b82592448 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -218,9 +218,9 @@ def test_export_database_command(self, mock_g): "is_active": True, "is_dttm": False, "python_date_format": None, - "type": "STRING" - if example_db.backend == "hive" - else "VARCHAR(255)", + "type": ( + "STRING" if example_db.backend == "hive" else "VARCHAR(255)" + ), "advanced_data_type": None, "verbose_name": None, }, @@ -397,7 +397,8 @@ def test_export_database_command_no_related(self, mock_g): class TestImportDatabasesCommand(SupersetTestCase): @patch("superset.security.manager.g") - def test_import_v1_database(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database(self, mock_add_permissions, mock_g): """Test that a database can be imported""" mock_g.user = security_manager.find_user("admin") @@ -420,13 +421,14 @@ def test_import_v1_database(self, mock_g): assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "someengine://user:pass@host1" + assert database.sqlalchemy_uri == "postgresql://user:pass@host1" db.session.delete(database) db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_database_broken_csv_fields(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_broken_csv_fields(self, mock_add_permissions, mock_g): """ Test that a database can be imported with broken schema. @@ -459,13 +461,14 @@ def test_import_v1_database_broken_csv_fields(self, mock_g): assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == '{"schemas_allowed_for_file_upload": ["upload"]}' - assert database.sqlalchemy_uri == "someengine://user:pass@host1" + assert database.sqlalchemy_uri == "postgresql://user:pass@host1" db.session.delete(database) db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_database_multiple(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_multiple(self, mock_add_permissions, mock_g): """Test that a database can be imported multiple times""" mock_g.user = security_manager.find_user("admin") @@ -509,7 +512,8 @@ def test_import_v1_database_multiple(self, mock_g): db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_database_with_dataset(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_with_dataset(self, mock_add_permissions, mock_g): """Test that a database can be imported with datasets""" mock_g.user = security_manager.find_user("admin") @@ -532,7 +536,10 @@ def test_import_v1_database_with_dataset(self, mock_g): db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_database_with_dataset_multiple(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_with_dataset_multiple( + self, mock_add_permissions, mock_g + ): """Test that a database can be imported multiple times w/o changing datasets""" mock_g.user = security_manager.find_user("admin") @@ -570,7 +577,8 @@ def test_import_v1_database_with_dataset_multiple(self, mock_g): db.session.delete(dataset.database) db.session.commit() - def test_import_v1_database_validation(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_validation(self, mock_add_permissions): """Test different validations applied when importing a database""" # metadata.yaml must be present contents = { @@ -619,7 +627,8 @@ def test_import_v1_database_validation(self): } } - def test_import_v1_database_masked_password(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_database_masked_password(self, mock_add_permissions): """Test that database imports with masked passwords are rejected""" masked_database_config = database_config.copy() masked_database_config["sqlalchemy_uri"] = ( @@ -640,8 +649,11 @@ def test_import_v1_database_masked_password(self): } @patch("superset.databases.schemas.is_feature_enabled") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_masked_ssh_tunnel_password( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """Test that database imports with masked ssh_tunnel passwords are rejected""" mock_schema_is_feature_enabled.return_value = True @@ -661,8 +673,11 @@ def test_import_v1_database_masked_ssh_tunnel_password( } @patch("superset.databases.schemas.is_feature_enabled") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_masked_ssh_tunnel_private_key_and_password( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """Test that database imports with masked ssh_tunnel private_key and private_key_password are rejected""" mock_schema_is_feature_enabled.return_value = True @@ -686,8 +701,10 @@ def test_import_v1_database_masked_ssh_tunnel_private_key_and_password( @patch("superset.databases.schemas.is_feature_enabled") @patch("superset.security.manager.g") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_with_ssh_tunnel_password( self, + mock_add_permissions, mock_g, mock_schema_is_feature_enabled, ): @@ -715,7 +732,7 @@ def test_import_v1_database_with_ssh_tunnel_password( assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "someengine://user:pass@host1" + assert database.sqlalchemy_uri == "postgresql://user:pass@host1" model_ssh_tunnel = ( db.session.query(SSHTunnel) @@ -729,8 +746,10 @@ def test_import_v1_database_with_ssh_tunnel_password( @patch("superset.databases.schemas.is_feature_enabled") @patch("superset.security.manager.g") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_with_ssh_tunnel_private_key_and_password( self, + mock_add_permissions, mock_g, mock_schema_is_feature_enabled, ): @@ -760,7 +779,7 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password( assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "someengine://user:pass@host1" + assert database.sqlalchemy_uri == "postgresql://user:pass@host1" model_ssh_tunnel = ( db.session.query(SSHTunnel) @@ -774,8 +793,11 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password( db.session.commit() @patch("superset.databases.schemas.is_feature_enabled") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_masked_ssh_tunnel_no_credentials( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """Test that databases with ssh_tunnels that have no credentials are rejected""" mock_schema_is_feature_enabled.return_value = True @@ -790,8 +812,11 @@ def test_import_v1_database_masked_ssh_tunnel_no_credentials( assert str(excinfo.value) == "Must provide credentials for the SSH Tunnel" @patch("superset.databases.schemas.is_feature_enabled") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_masked_ssh_tunnel_multiple_credentials( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """Test that databases with ssh_tunnels that have multiple credentials are rejected""" mock_schema_is_feature_enabled.return_value = True @@ -808,8 +833,11 @@ def test_import_v1_database_masked_ssh_tunnel_multiple_credentials( ) @patch("superset.databases.schemas.is_feature_enabled") + @patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_v1_database_masked_ssh_tunnel_only_priv_key_psswd( - self, mock_schema_is_feature_enabled + self, + mock_add_permissions, + mock_schema_is_feature_enabled, ): """Test that databases with ssh_tunnels that have multiple credentials are rejected""" mock_schema_is_feature_enabled.return_value = True @@ -834,7 +862,8 @@ def test_import_v1_database_masked_ssh_tunnel_only_priv_key_psswd( } @patch("superset.commands.database.importers.v1.import_dataset") - def test_import_v1_rollback(self, mock_import_dataset): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_rollback(self, mock_add_permissions, mock_import_dataset): """Test than on an exception everything is rolled back""" num_databases = db.session.query(Database).count() diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 37de6e87c27ad..2bc9c721f6fae 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -2039,7 +2039,8 @@ def test_get_datasets_custom_filter_sql(self): for table_name in self.fixture_tables_names: assert table_name in [ds["table_name"] for ds in data["result"]] - def test_import_dataset(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_dataset(self, mock_add_permissions): """ Dataset API: Test import dataset """ @@ -2102,7 +2103,8 @@ def test_import_dataset_v0_export(self): db.session.delete(dataset) db.session.commit() - def test_import_dataset_overwrite(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_dataset_overwrite(self, mock_add_permissions): """ Dataset API: Test import existing dataset """ diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 53bd7fa051aa6..66a15e2e61d52 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -343,8 +343,9 @@ def test_import_v0_dataset_ui_export(self): @patch("superset.utils.core.g") @patch("superset.security.manager.g") + @patch("superset.commands.database.importers.v1.utils.add_permissions") @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_import_v1_dataset(self, sm_g, utils_g): + def test_import_v1_dataset(self, mock_add_permissions, sm_g, utils_g): """Test that we can import a dataset""" admin = sm_g.user = utils_g.user = security_manager.find_user("admin") contents = { @@ -411,7 +412,8 @@ def test_import_v1_dataset(self, sm_g, utils_g): db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_dataset_multiple(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dataset_multiple(self, mock_add_permissions, mock_g): """Test that a dataset can be imported multiple times""" mock_g.user = security_manager.find_user("admin") @@ -452,7 +454,8 @@ def test_import_v1_dataset_multiple(self, mock_g): db.session.delete(dataset.database) db.session.commit() - def test_import_v1_dataset_validation(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dataset_validation(self, mock_add_permissions): """Test different validations applied when importing a dataset""" # metadata.yaml must be present contents = { @@ -502,7 +505,8 @@ def test_import_v1_dataset_validation(self): } @patch("superset.security.manager.g") - def test_import_v1_dataset_existing_database(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_dataset_existing_database(self, mock_add_permissions, mock_g): """Test that a dataset can be imported when the database already exists""" mock_g.user = security_manager.find_user("admin") diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index cccf4fa7701a4..1ca3e303c623a 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -374,7 +374,7 @@ "database_name": "imported_database", "expose_in_sqllab": True, "extra": {}, - "sqlalchemy_uri": "someengine://user:pass@host1", + "sqlalchemy_uri": "postgresql://user:pass@host1", "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", "version": "1.0.0", } @@ -389,7 +389,7 @@ "database_name": "imported_database", "expose_in_sqllab": True, "extra": {}, - "sqlalchemy_uri": "someengine://user:pass@host1", + "sqlalchemy_uri": "postgresql://user:pass@host1", "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", "ssh_tunnel": { "server_address": "localhost", @@ -411,7 +411,7 @@ "database_name": "imported_database", "expose_in_sqllab": True, "extra": {}, - "sqlalchemy_uri": "someengine://user:pass@host1", + "sqlalchemy_uri": "postgresql://user:pass@host1", "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", "ssh_tunnel": { "server_address": "localhost", diff --git a/tests/integration_tests/queries/saved_queries/api_tests.py b/tests/integration_tests/queries/saved_queries/api_tests.py index 9b1184b1f73f8..4ce0a79dac9a4 100644 --- a/tests/integration_tests/queries/saved_queries/api_tests.py +++ b/tests/integration_tests/queries/saved_queries/api_tests.py @@ -20,6 +20,7 @@ from datetime import datetime from io import BytesIO from typing import Optional +from unittest.mock import patch from zipfile import is_zipfile, ZipFile import yaml @@ -898,7 +899,8 @@ def create_saved_query_import(self): buf.seek(0) return buf - def test_import_saved_queries(self): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_saved_queries(self, mock_add_permissions): """ Saved Query API: Test import """ diff --git a/tests/integration_tests/queries/saved_queries/commands_tests.py b/tests/integration_tests/queries/saved_queries/commands_tests.py index 8babd7efb9dba..4ce816622f7d3 100644 --- a/tests/integration_tests/queries/saved_queries/commands_tests.py +++ b/tests/integration_tests/queries/saved_queries/commands_tests.py @@ -148,7 +148,8 @@ def test_export_query_command_key_order(self, mock_g): class TestImportSavedQueriesCommand(SupersetTestCase): @patch("superset.security.manager.g") - def test_import_v1_saved_queries(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_saved_queries(self, mock_add_permissions, mock_g): """Test that we can import a saved query""" mock_g.user = security_manager.find_user("admin") @@ -178,7 +179,8 @@ def test_import_v1_saved_queries(self, mock_g): db.session.commit() @patch("superset.security.manager.g") - def test_import_v1_saved_queries_multiple(self, mock_g): + @patch("superset.commands.database.importers.v1.utils.add_permissions") + def test_import_v1_saved_queries_multiple(self, mock_add_permissions, mock_g): """Test that a saved query can be imported multiple times""" mock_g.user = security_manager.find_user("admin") diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index 37500d521420a..b1b5e6843f0c2 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -178,7 +178,12 @@ def test_rename_with_catalog( DatabaseDAO.find_by_id.return_value = original_database database_with_catalog.database_name = "my_other_db" DatabaseDAO.update.return_value = database_with_catalog - DatabaseDAO.get_datasets.return_value = [] + + dataset = mocker.MagicMock() + chart = mocker.MagicMock() + DatabaseDAO.get_datasets.return_value = [dataset] + DatasetDAO = mocker.patch("superset.commands.database.update.DatasetDAO") + DatasetDAO.get_related_objects.return_value = {"charts": [chart]} find_permission_view_menu = mocker.patch.object( security_manager, @@ -218,6 +223,11 @@ def test_rename_with_catalog( assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]" assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]" + assert dataset.catalog_perm == "[my_other_db].[catalog2]" + assert dataset.schema_perm == "[my_other_db].[catalog2].[schema4]" + assert chart.catalog_perm == "[my_other_db].[catalog2]" + assert chart.schema_perm == "[my_other_db].[catalog2].[schema4]" + def test_rename_without_catalog( mocker: MockerFixture, diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index bfb472b3e2a0a..86d848ff2192d 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -23,6 +23,7 @@ from sqlalchemy.orm.session import Session from superset import db +from superset.commands.database.importers.v1.utils import add_permissions from superset.commands.exceptions import ImportFailedError from superset.utils import json @@ -37,6 +38,7 @@ def test_import_database(mocker: MockerFixture, session: Session) -> None: from tests.integration_tests.fixtures.importexport import database_config mocker.patch.object(security_manager, "can_access", return_value=True) + mocker.patch("superset.commands.database.importers.v1.utils.add_permissions") engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member @@ -44,7 +46,7 @@ def test_import_database(mocker: MockerFixture, session: Session) -> None: config = copy.deepcopy(database_config) database = import_database(config) assert database.database_name == "imported_database" - assert database.sqlalchemy_uri == "someengine://user:pass@host1" + assert database.sqlalchemy_uri == "postgresql://user:pass@host1" assert database.cache_timeout is None assert database.expose_in_sqllab is True assert database.allow_run_async is False @@ -108,6 +110,7 @@ def test_import_database_managed_externally( from tests.integration_tests.fixtures.importexport import database_config mocker.patch.object(security_manager, "can_access", return_value=True) + mocker.patch("superset.commands.database.importers.v1.utils.add_permissions") engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member @@ -158,6 +161,7 @@ def test_import_database_with_version(mocker: MockerFixture, session: Session) - from tests.integration_tests.fixtures.importexport import database_config mocker.patch.object(security_manager, "can_access", return_value=True) + mocker.patch("superset.commands.database.importers.v1.utils.add_permissions") engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member @@ -166,3 +170,30 @@ def test_import_database_with_version(mocker: MockerFixture, session: Session) - config["extra"]["version"] = "1.1.1" database = import_database(config) assert json.loads(database.extra)["version"] == "1.1.1" + + +def test_add_permissions(mocker: MockerFixture) -> None: + """ + Test adding permissions to a database when it's imported. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] + database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]] + ssh_tunnel = mocker.MagicMock() + add_permission_view_menu = mocker.patch( + "superset.commands.database.importers.v1.utils.security_manager." + "add_permission_view_menu" + ) + + add_permissions(database, ssh_tunnel) + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("catalog_access", "[my_db].[catalog1]"), + mocker.call("catalog_access", "[my_db].[catalog2]"), + mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), + mocker.call("schema_access", "[my_db].[catalog2].[schema2]"), + ] + )