From e1f7f369b0e82628bef4262fcc3b0d98ee7a4ec8 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Date: Mon, 20 Feb 2023 20:02:19 -0300 Subject: [PATCH] SSH Tunnel: - Add more robust checks around SSH Tunnel credentials passed in the file we want to import - Add extra exceptions so we can handle new errors and let users know - Update and add tests for new scenarios --- superset/databases/schemas.py | 93 ++++----- .../ssh_tunnel/commands/exceptions.py | 8 + .../integration_tests/databases/api_tests.py | 178 +++++++++++++++++- .../databases/commands_tests.py | 65 ++++++- .../fixtures/importexport.py | 64 +++++++ 5 files changed, 362 insertions(+), 46 deletions(-) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 9daac38b13374..b82969a19a061 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -19,7 +19,7 @@ import inspect import json -from typing import Any, Dict +from typing import Any, Dict, List from flask import current_app from flask_babel import lazy_gettext as _ @@ -33,6 +33,8 @@ from superset.databases.commands.exceptions import DatabaseInvalidError from superset.databases.ssh_tunnel.commands.exceptions import ( SSHTunnelingNotEnabledError, + SSHTunnelInvalidCredentials, + SSHTunnelMissingCredentials, ) from superset.databases.utils import make_url_safe from superset.db_engine_specs import get_engine_spec @@ -724,27 +726,10 @@ def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None: raise ValidationError("Must provide a password for the database") @validates_schema - def validate_ssh_tunnel_password(self, data: Dict[str, Any], **kwargs: Any) -> None: - """If ssh_tunnel has a masked password, password is required""" - uuid = data["uuid"] - existing = db.session.query(Database).filter_by(uuid=uuid).first() - if existing: - return - - # Our DB has a ssh_tunnel in it - if ssh_tunnel := data.get("ssh_tunnel"): - if not is_feature_enabled("SSH_TUNNELING"): - raise SSHTunnelingNotEnabledError() - password = ssh_tunnel.get("password") - if password == PASSWORD_MASK: - raise ValidationError("Must provide a password for the ssh tunnel") - return - - @validates_schema - def validate_ssh_tunnel_private_key( + def validate_ssh_tunnel_credentials( self, data: Dict[str, Any], **kwargs: Any ) -> None: - """If ssh_tunnel has a masked private key, private key is required""" + """If ssh_tunnel has a masked credentials, credentials are required""" uuid = data["uuid"] existing = db.session.query(Database).filter_by(uuid=uuid).first() if existing: @@ -752,34 +737,54 @@ def validate_ssh_tunnel_private_key( # Our DB has a ssh_tunnel in it if ssh_tunnel := data.get("ssh_tunnel"): + # Login methods are (only one from these options): + # 1. password + # 2. private_key + private_key_password + # Based on the data passed we determine what info is required. + # You cannot mix the credentials from both methods. if not is_feature_enabled("SSH_TUNNELING"): + # You are trying to import a Database with SSH Tunnel + # But the Feature Flag is not enabled. raise SSHTunnelingNotEnabledError() + password = ssh_tunnel.get("password") private_key = ssh_tunnel.get("private_key") - if private_key == PASSWORD_MASK: - raise ValidationError("Must provide a private key for the ssh tunnel") - return - - @validates_schema - def validate_ssh_tunnel_pkey_pass( - self, data: Dict[str, Any], **kwargs: Any - ) -> None: - """ - If ssh_tunnel has a masked private key password, private key password is required - """ - uuid = data["uuid"] - existing = db.session.query(Database).filter_by(uuid=uuid).first() - if existing: - return - - # Our DB has a ssh_tunnel in it - if ssh_tunnel := data.get("ssh_tunnel"): - if not is_feature_enabled("SSH_TUNNELING"): - raise SSHTunnelingNotEnabledError() private_key_password = ssh_tunnel.get("private_key_password") - if private_key_password == PASSWORD_MASK: - raise ValidationError( - "Must provide a private key password for the ssh tunnel" - ) + if password is not None: + # Login method #1 (Password) + if private_key is not None or private_key_password is not None: + # You cannot have a mix of login methods + raise SSHTunnelInvalidCredentials() + if password == PASSWORD_MASK: + raise ValidationError("Must provide a password for the ssh tunnel") + if password is None: + # If the SSH Tunnel we're importing has no password then it must + # have a private_key + private_key_password combination + if private_key is None and private_key_password is None: + # We have found nothing related to other credentials + raise SSHTunnelMissingCredentials() + # We need to ask for the missing properties of our method # 2 + # Some times the property is just missing + # or there're times where it's masked. + # If both are masked, we need to return a list of errors + # so the UI ask for both fields at the same time if needed + exception_messages: List[str] = [] + if private_key is None or private_key == PASSWORD_MASK: + # If we get here we need to ask for the private key + exception_messages.append( + "Must provide a private key for the ssh tunnel" + ) + if ( + private_key_password is None + or private_key_password == PASSWORD_MASK + ): + # If we get here we need to ask for the private key password + exception_messages.append( + "Must provide a private key password for the ssh tunnel" + ) + if exception_messages: + # We can ask for just one field or both if masked, if both + # are empty, SSHTunnelMissingCredentials was already raised + raise ValidationError(exception_messages) return diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py index 2495961c369a2..0e3f91cae691d 100644 --- a/superset/databases/ssh_tunnel/commands/exceptions.py +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -57,3 +57,11 @@ def __init__(self, field_name: str) -> None: [_("Field is required")], field_name=field_name, ) + + +class SSHTunnelMissingCredentials(CommandInvalidError): + message = _("Must provide credentials for the SSH Tunnel") + + +class SSHTunnelInvalidCredentials(CommandInvalidError): + message = _("Cannot have multiple credentials for the SSH Tunnel") diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 1d5d298d4ffd5..3859c0be51be9 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -68,6 +68,9 @@ dataset_metadata_config, database_with_ssh_tunnel_config_password, database_with_ssh_tunnel_config_private_key, + database_with_ssh_tunnel_config_mix_credentials, + database_with_ssh_tunnel_config_no_credentials, + database_with_ssh_tunnel_config_private_pass_only, ) from tests.integration_tests.fixtures.unicode_dashboard import ( load_unicode_dashboard_with_position, @@ -2512,8 +2515,8 @@ def test_import_database_masked_ssh_tunnel_private_key_and_password( "extra": { "databases/imported_database.yaml": { "_schema": [ - "Must provide a private key password for the ssh tunnel", "Must provide a private key for the ssh tunnel", + "Must provide a private key password for the ssh tunnel", ] }, "issue_codes": [ @@ -2633,6 +2636,179 @@ def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self): ] } + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_database_masked_ssh_tunnel_feature_no_credentials( + self, mock_schema_is_feature_enabled + ): + """ + Database API: Test import database with ssh_tunnel that has no credentials + """ + self.login(username="admin") + uri = "api/v1/database/import/" + mock_schema_is_feature_enabled.return_value = True + + masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy() + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("database_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open( + "database_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(masked_database_config).encode()) + with bundle.open( + "database_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "formData": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "errors": [ + { + "message": "Must provide credentials for the SSH Tunnel", + "error_type": "GENERIC_COMMAND_ERROR", + "level": "warning", + "extra": { + "issue_codes": [ + { + "code": 1010, + "message": ( + "Issue 1010 - Superset encountered an " + "error while running a command." + ), + } + ], + }, + } + ] + } + + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_database_masked_ssh_tunnel_feature_mix_credentials( + self, mock_schema_is_feature_enabled + ): + """ + Database API: Test import database with ssh_tunnel that has no credentials + """ + self.login(username="admin") + uri = "api/v1/database/import/" + mock_schema_is_feature_enabled.return_value = True + + masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy() + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("database_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open( + "database_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(masked_database_config).encode()) + with bundle.open( + "database_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "formData": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "errors": [ + { + "message": "Cannot have multiple credentials for the SSH Tunnel", + "error_type": "GENERIC_COMMAND_ERROR", + "level": "warning", + "extra": { + "issue_codes": [ + { + "code": 1010, + "message": ( + "Issue 1010 - Superset encountered an " + "error while running a command." + ), + } + ], + }, + } + ] + } + + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd( + self, mock_schema_is_feature_enabled + ): + """ + Database API: Test import database with ssh_tunnel that has no credentials + """ + self.login(username="admin") + uri = "api/v1/database/import/" + mock_schema_is_feature_enabled.return_value = True + + masked_database_config = ( + database_with_ssh_tunnel_config_private_pass_only.copy() + ) + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("database_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open( + "database_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(masked_database_config).encode()) + with bundle.open( + "database_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "formData": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "errors": [ + { + "message": "Error importing database", + "error_type": "GENERIC_COMMAND_ERROR", + "level": "warning", + "extra": { + "databases/imported_database.yaml": { + "_schema": [ + "Must provide a private key for the ssh tunnel", + "Must provide a private key password for the ssh tunnel", + ] + }, + "issue_codes": [ + { + "code": 1010, + "message": ( + "Issue 1010 - Superset encountered an " + "error while running a command." + ), + } + ], + }, + } + ] + } + @mock.patch( "superset.db_engine_specs.base.BaseEngineSpec.get_function_names", ) diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 531b8641eacc9..6a5a9c58bc129 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -64,8 +64,11 @@ from tests.integration_tests.fixtures.importexport import ( database_config, database_metadata_config, + database_with_ssh_tunnel_config_mix_credentials, + database_with_ssh_tunnel_config_no_credentials, database_with_ssh_tunnel_config_password, database_with_ssh_tunnel_config_private_key, + database_with_ssh_tunnel_config_private_pass_only, dataset_config, dataset_metadata_config, ) @@ -665,8 +668,8 @@ def test_import_v1_database_masked_ssh_tunnel_private_key_and_password( assert excinfo.value.normalized_messages() == { "databases/imported_database.yaml": { "_schema": [ - "Must provide a private key password for the ssh tunnel", "Must provide a private key for the ssh tunnel", + "Must provide a private key password for the ssh tunnel", ] } } @@ -751,6 +754,66 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password( db.session.delete(database) db.session.commit() + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_v1_database_masked_ssh_tunnel_no_credentials( + self, mock_schema_is_feature_enabled + ): + """Test that databases with ssh_tunnels that have no credentials are rejected""" + mock_schema_is_feature_enabled.return_value = True + masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy() + contents = { + "metadata.yaml": yaml.safe_dump(database_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(masked_database_config), + } + command = ImportDatabasesCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Must provide credentials for the SSH Tunnel" + + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_v1_database_masked_ssh_tunnel_multiple_credentials( + self, mock_schema_is_feature_enabled + ): + """Test that databases with ssh_tunnels that have multiple credentials are rejected""" + mock_schema_is_feature_enabled.return_value = True + masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy() + contents = { + "metadata.yaml": yaml.safe_dump(database_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(masked_database_config), + } + command = ImportDatabasesCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert ( + str(excinfo.value) == "Cannot have multiple credentials for the SSH Tunnel" + ) + + @mock.patch("superset.databases.schemas.is_feature_enabled") + def test_import_v1_database_masked_ssh_tunnel_only_priv_key_psswd( + self, mock_schema_is_feature_enabled + ): + """Test that databases with ssh_tunnels that have multiple credentials are rejected""" + mock_schema_is_feature_enabled.return_value = True + masked_database_config = ( + database_with_ssh_tunnel_config_private_pass_only.copy() + ) + contents = { + "metadata.yaml": yaml.safe_dump(database_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(masked_database_config), + } + command = ImportDatabasesCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing database" + assert excinfo.value.normalized_messages() == { + "databases/imported_database.yaml": { + "_schema": [ + "Must provide a private key for the ssh tunnel", + "Must provide a private key password for the ssh tunnel", + ] + } + } + @patch("superset.databases.commands.importers.v1.import_dataset") def test_import_v1_rollback(self, mock_import_dataset): """Test than on an exception everything is rolled back""" diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index 1020ca08d7f99..d5c898eba2c9a 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -404,6 +404,70 @@ "version": "1.0.0", } +database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_dml": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "imported_database", + "expose_in_sqllab": True, + "extra": {}, + "sqlalchemy_uri": "sqlite:///test.db", + "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", + "ssh_tunnel": { + "server_address": "localhost", + "server_port": 22, + "username": "Test", + }, + "version": "1.0.0", +} + +database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_dml": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "imported_database", + "expose_in_sqllab": True, + "extra": {}, + "sqlalchemy_uri": "sqlite:///test.db", + "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", + "ssh_tunnel": { + "server_address": "localhost", + "server_port": 22, + "username": "Test", + "password": "XXXXXXXXXX", + "private_key": "XXXXXXXXXX", + }, + "version": "1.0.0", +} + +database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_dml": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "imported_database", + "expose_in_sqllab": True, + "extra": {}, + "sqlalchemy_uri": "sqlite:///test.db", + "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", + "ssh_tunnel": { + "server_address": "localhost", + "server_port": 22, + "username": "Test", + "private_key_password": "XXXXXXXXXX", + }, + "version": "1.0.0", +} + + dataset_config: Dict[str, Any] = { "table_name": "imported_dataset", "main_dttm_col": None,