Skip to content

Commit

Permalink
SSH Tunnel:
Browse files Browse the repository at this point in the history
- Add more robust checks around SSH Tunnel credentials passed in the file we want to import
- Add extra exceptions so we can handle new errors and let users know
- Update and add tests for new scenarios
  • Loading branch information
Antonio-RiveroMartnez committed Feb 20, 2023
1 parent 0b8be63 commit e1f7f36
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 46 deletions.
93 changes: 49 additions & 44 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import inspect
import json
from typing import Any, Dict
from typing import Any, Dict, List

from flask import current_app
from flask_babel import lazy_gettext as _
Expand All @@ -33,6 +33,8 @@
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelingNotEnabledError,
SSHTunnelInvalidCredentials,
SSHTunnelMissingCredentials,
)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
Expand Down Expand Up @@ -724,62 +726,65 @@ def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
raise ValidationError("Must provide a password for the database")

@validates_schema
def validate_ssh_tunnel_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
"""If ssh_tunnel has a masked password, password is required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
if existing:
return

# Our DB has a ssh_tunnel in it
if ssh_tunnel := data.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
password = ssh_tunnel.get("password")
if password == PASSWORD_MASK:
raise ValidationError("Must provide a password for the ssh tunnel")
return

@validates_schema
def validate_ssh_tunnel_private_key(
def validate_ssh_tunnel_credentials(
self, data: Dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked private key, private key is required"""
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
if existing:
return

# Our DB has a ssh_tunnel in it
if ssh_tunnel := data.get("ssh_tunnel"):
# Login methods are (only one from these options):
# 1. password
# 2. private_key + private_key_password
# Based on the data passed we determine what info is required.
# You cannot mix the credentials from both methods.
if not is_feature_enabled("SSH_TUNNELING"):
# You are trying to import a Database with SSH Tunnel
# But the Feature Flag is not enabled.
raise SSHTunnelingNotEnabledError()
password = ssh_tunnel.get("password")
private_key = ssh_tunnel.get("private_key")
if private_key == PASSWORD_MASK:
raise ValidationError("Must provide a private key for the ssh tunnel")
return

@validates_schema
def validate_ssh_tunnel_pkey_pass(
self, data: Dict[str, Any], **kwargs: Any
) -> None:
"""
If ssh_tunnel has a masked private key password, private key password is required
"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
if existing:
return

# Our DB has a ssh_tunnel in it
if ssh_tunnel := data.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
private_key_password = ssh_tunnel.get("private_key_password")
if private_key_password == PASSWORD_MASK:
raise ValidationError(
"Must provide a private key password for the ssh tunnel"
)
if password is not None:
# Login method #1 (Password)
if private_key is not None or private_key_password is not None:
# You cannot have a mix of login methods
raise SSHTunnelInvalidCredentials()
if password == PASSWORD_MASK:
raise ValidationError("Must provide a password for the ssh tunnel")
if password is None:
# If the SSH Tunnel we're importing has no password then it must
# have a private_key + private_key_password combination
if private_key is None and private_key_password is None:
# We have found nothing related to other credentials
raise SSHTunnelMissingCredentials()
# We need to ask for the missing properties of our method # 2
# Some times the property is just missing
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
exception_messages: List[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
"Must provide a private key for the ssh tunnel"
)
if (
private_key_password is None
or private_key_password == PASSWORD_MASK
):
# If we get here we need to ask for the private key password
exception_messages.append(
"Must provide a private key password for the ssh tunnel"
)
if exception_messages:
# We can ask for just one field or both if masked, if both
# are empty, SSHTunnelMissingCredentials was already raised
raise ValidationError(exception_messages)
return


Expand Down
8 changes: 8 additions & 0 deletions superset/databases/ssh_tunnel/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ def __init__(self, field_name: str) -> None:
[_("Field is required")],
field_name=field_name,
)


class SSHTunnelMissingCredentials(CommandInvalidError):
message = _("Must provide credentials for the SSH Tunnel")


class SSHTunnelInvalidCredentials(CommandInvalidError):
message = _("Cannot have multiple credentials for the SSH Tunnel")
178 changes: 177 additions & 1 deletion tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
dataset_metadata_config,
database_with_ssh_tunnel_config_password,
database_with_ssh_tunnel_config_private_key,
database_with_ssh_tunnel_config_mix_credentials,
database_with_ssh_tunnel_config_no_credentials,
database_with_ssh_tunnel_config_private_pass_only,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_position,
Expand Down Expand Up @@ -2512,8 +2515,8 @@ def test_import_database_masked_ssh_tunnel_private_key_and_password(
"extra": {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key password for the ssh tunnel",
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
},
"issue_codes": [
Expand Down Expand Up @@ -2633,6 +2636,179 @@ def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self):
]
}

@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_no_credentials(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True

masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()

buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))

assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Must provide credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}

@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_mix_credentials(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True

masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()

buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))

assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Cannot have multiple credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}

@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True

masked_database_config = (
database_with_ssh_tunnel_config_private_pass_only.copy()
)

buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))

assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}

@mock.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_function_names",
)
Expand Down
Loading

0 comments on commit e1f7f36

Please sign in to comment.