Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: create permissions on DB import #29802

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
Expand Down Expand Up @@ -118,7 +119,7 @@ def run(self) -> Model:
for catalog in catalogs:
try:
self.add_schema_permissions(database, catalog, ssh_tunnel)
except Exception: # pylint: disable=broad-except
except GenericDBException: # pylint: disable=broad-except
logger.warning("Error processing catalog '%s'", catalog)
continue
except (
Expand Down
49 changes: 45 additions & 4 deletions superset/commands/database/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,55 @@ def import_database(
config["extra"] = json.dumps(config["extra"])

# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
ssh_tunnel_config = config.pop("ssh_tunnel", None)

database = Database.import_from_dict(config, recursive=False)
if database.id is None:
db.session.flush()

if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)
if ssh_tunnel_config:
ssh_tunnel_config["database_id"] = database.id
ssh_tunnel = SSHTunnel.import_from_dict(ssh_tunnel_config, recursive=False)
else:
ssh_tunnel = None

# TODO (betodealmeida): we should use the `CreateDatabaseCommand` for imports
add_permissions(database, ssh_tunnel)

return database


def add_permissions(database: Database, ssh_tunnel: "SSHTunnel | None") -> None:
    """
    Add DAR for catalogs and schemas.

    Creates a ``catalog_access`` permission for every catalog returned by the
    database (only when its engine spec reports ``supports_catalog``), and a
    ``schema_access`` permission for every schema found in each catalog.

    :param database: database whose catalogs and schemas are enumerated
    :param ssh_tunnel: tunnel forwarded to the catalog/schema lookups, or
        ``None`` when the database was imported without an SSH tunnel
    """
    if database.db_engine_spec.supports_catalog:
        catalogs = database.get_all_catalog_names(
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
        for catalog in catalogs:
            security_manager.add_permission_view_menu(
                "catalog_access",
                security_manager.get_catalog_perm(
                    database.database_name,
                    catalog,
                ),
            )
    else:
        # No catalog support: use a single ``None`` placeholder so the schema
        # loop below still runs once; no catalog permission is created.
        catalogs = [None]

    for catalog in catalogs:
        for schema in database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        ):
            security_manager.add_permission_view_menu(
                "schema_access",
                security_manager.get_schema_perm(
                    database.database_name,
                    catalog,
                    schema,
                ),
            )
28 changes: 12 additions & 16 deletions superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction

Expand Down Expand Up @@ -80,6 +81,7 @@ def run(self) -> Model:
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)

return database

def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
Expand Down Expand Up @@ -115,17 +117,13 @@ def _get_catalog_names(
) -> set[str]:
"""
Helper method to load catalogs.

This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""

try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

def _get_schema_names(
Expand All @@ -136,18 +134,14 @@ def _get_schema_names(
) -> set[str]:
"""
Helper method to load schemas.

This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""

try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

def _refresh_catalogs(
Expand Down Expand Up @@ -255,7 +249,7 @@ def _rename_database_in_permissions(
catalog: str | None,
schemas: set[str],
) -> None:
new_name = security_manager.get_catalog_perm(
new_catalog_perm_name = security_manager.get_catalog_perm(
database.database_name,
catalog,
)
Expand All @@ -271,10 +265,10 @@ def _rename_database_in_permissions(
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
existing_pvm.view_menu.name = new_catalog_perm_name

for schema in schemas:
new_name = security_manager.get_schema_perm(
new_schema_perm_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
Expand All @@ -291,17 +285,19 @@ def _rename_database_in_permissions(
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
existing_pvm.view_menu.name = new_schema_perm_name

# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
catalog=catalog,
schema=schema,
):
dataset.schema_perm = new_name
dataset.catalog_perm = new_catalog_perm_name
dataset.schema_perm = new_schema_perm_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
chart.schema_perm = new_name
chart.catalog_perm = new_catalog_perm_name
chart.schema_perm = new_schema_perm_name

def validate(self) -> None:
if database_name := self._properties.get("database_name"):
Expand Down
18 changes: 18 additions & 0 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,25 @@ def get_default_catalog(
cls,
database: Database,
) -> str | None:
"""
Return the default catalog.

The default behavior for Databricks is confusing. When Unity Catalog is not
enabled we have (the DB engine spec hasn't been tested with it enabled):

> SHOW CATALOGS;
spark_catalog
> SELECT current_catalog();
hive_metastore

To handle permissions correctly we use the result of `SHOW CATALOGS` when a
single catalog is returned.
"""
with database.get_sqla_engine() as engine:
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
if len(catalogs) == 1:
return catalogs.pop()
Comment on lines +452 to +454
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If SHOW CATALOGS returns only spark_catalog, it seems we're returning it. Do we want to return spark_catalog, or do we need to return hive_metastore instead (when there's only spark_catalog)? I'm asking because I think there are issues with other metadata calls if we try to use spark_catalog when the catalog selected in the form is hive_metastore (or if UC is not enabled).

Also, if multi_catalog is disabled, do we want to rely on SHOW CATALOGS? Should we validate if multi_catalog is disabled first, and if so return the value of current_catalog() directly (which should be the catalog from the form)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that we need to create the permissions regardless of the state of multi_catalog. For Databricks, the migration process will create permissions for:

  1. Whatever get_default_catalog() returns.
  2. All the catalogs from get_catalog_names different from the default catalog from (1).

For Databricks, this means we're creating 2 permissions when Unity Catalog is not enabled:

  1. [Databricks].[hive_metastore]
  2. [Databricks].[spark_catalog]

This in itself is already a problem for the admin — which one should they assign roles to? Either one is problematic, depending on the status of multi_catalog:

With multi_catalog disabled, the catalog sent by API calls is null, so it gets replaced with the default catalog, hive_metastore. If the admin assigns roles to [Databricks].[spark_catalog], people won't be able to access datasets because of the name mismatch.

With multi_catalog enabled, the only option in the dropdown is spark_catalog, which means that if the admin assigns roles to [Databricks].[hive_metastore], people won't be able to access datasets.

With the change in this PR, everything is consistent, because the default catalog is part of the list of catalogs.


return engine.execute("SELECT current_catalog()").scalar()

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from io import BytesIO
from unittest import mock
from unittest.mock import patch
from zipfile import is_zipfile, ZipFile

import prison
Expand Down Expand Up @@ -1768,7 +1769,8 @@ def test_export_chart_gamma(self):

assert rv.status_code == 404

def test_import_chart(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart(self, mock_add_permissions):
"""
Chart API: Test import chart
"""
Expand Down Expand Up @@ -1805,7 +1807,8 @@ def test_import_chart(self):
db.session.delete(database)
db.session.commit()

def test_import_chart_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart_overwrite(self, mock_add_permissions):
"""
Chart API: Test import existing chart
"""
Expand Down Expand Up @@ -1876,7 +1879,8 @@ def test_import_chart_overwrite(self):
db.session.delete(database)
db.session.commit()

def test_import_chart_invalid(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart_invalid(self, mock_add_permissions):
"""
Chart API: Test import invalid chart
"""
Expand Down
9 changes: 6 additions & 3 deletions tests/integration_tests/charts/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def test_export_chart_command_no_related(self, mock_g):
class TestImportChartsCommand(SupersetTestCase):
@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
def test_import_v1_chart(self, sm_g, utils_g) -> None:
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart(self, mock_add_permissions, sm_g, utils_g) -> None:
"""Test that we can import a chart"""
admin = sm_g.user = utils_g.user = security_manager.find_user("admin")
contents = {
Expand Down Expand Up @@ -246,7 +247,8 @@ def test_import_v1_chart(self, sm_g, utils_g) -> None:
db.session.commit()

@patch("superset.security.manager.g")
def test_import_v1_chart_multiple(self, sm_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart_multiple(self, mock_add_permissions, sm_g):
"""Test that a chart can be imported multiple times"""
sm_g.user = security_manager.find_user("admin")
contents = {
Expand All @@ -272,7 +274,8 @@ def test_import_v1_chart_multiple(self, sm_g):
db.session.delete(database)
db.session.commit()

def test_import_v1_chart_validation(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart_validation(self, mock_add_permissions):
"""Test different validations applied when importing a chart"""
# metadata.yaml must be present
contents = {
Expand Down
13 changes: 11 additions & 2 deletions tests/integration_tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import copy
from unittest.mock import patch

import yaml
from flask import g
Expand Down Expand Up @@ -63,8 +64,10 @@ def setUp(self):
self.user = user
setattr(g, "user", user)

def test_import_assets(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_assets(self, mock_add_permissions):
"""Test that we can import multiple assets"""

contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
Expand Down Expand Up @@ -144,13 +147,16 @@ def test_import_assets(self):

assert dashboard.owners == [self.user]

mock_add_permissions.assert_called_with(database, None)

db.session.delete(dashboard)
db.session.delete(chart)
db.session.delete(dataset)
db.session.delete(database)
db.session.commit()

def test_import_v1_dashboard_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard_overwrite(self, mock_add_permissions):
"""Test that assets can be overwritten"""
contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
Expand Down Expand Up @@ -185,6 +191,9 @@ def test_import_v1_dashboard_overwrite(self):
chart = dashboard.slices[0]
dataset = chart.table
database = dataset.database

mock_add_permissions.assert_called_with(database, None)

db.session.delete(dashboard)
db.session.delete(chart)
db.session.delete(dataset)
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/dashboards/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,7 +2111,8 @@ def test_export_bundle_not_allowed(self):
db.session.delete(dashboard)
db.session.commit()

def test_import_dashboard(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_dashboard(self, mock_add_permissions):
"""
Dashboard API: Test import dashboard
"""
Expand Down Expand Up @@ -2215,7 +2216,8 @@ def test_import_dashboard_v0_export(self):
db.session.delete(dataset)
db.session.commit()

def test_import_dashboard_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_dashboard_overwrite(self, mock_add_permissions):
"""
Dashboard API: Test import existing dashboard
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/dashboards/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def test_import_v0_dashboard_cli_export(self):

@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
def test_import_v1_dashboard(self, sm_g, utils_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard(self, mock_add_permissions, sm_g, utils_g):
"""Test that we can import a dashboard"""
admin = sm_g.user = utils_g.user = security_manager.find_user("admin")
contents = {
Expand Down Expand Up @@ -577,7 +578,8 @@ def test_import_v1_dashboard(self, sm_g, utils_g):
db.session.commit()

@patch("superset.security.manager.g")
def test_import_v1_dashboard_multiple(self, mock_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard_multiple(self, mock_add_permissions, mock_g):
"""Test that a dashboard can be imported multiple times"""
mock_g.user = security_manager.find_user("admin")

Expand Down
Loading
Loading