Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(databases): test connection api #10723

Merged
merged 10 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def init_views(self) -> None:
AlertLogModelView,
AlertModelView,
AlertObservationModelView,
ValidatorInlineView,
SQLObserverInlineView,
ValidatorInlineView,
)
from superset.views.annotations import (
AnnotationLayerModelView,
Expand Down
99 changes: 95 additions & 4 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)

from superset import event_logger
from superset.constants import RouteMethod
Expand All @@ -33,8 +40,10 @@
DatabaseDeleteFailedError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.decorators import check_datasource_access
Expand All @@ -44,6 +53,7 @@
DatabasePostSchema,
DatabasePutSchema,
DatabaseRelatedObjectsResponse,
DatabaseTestConnectionSchema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
Expand All @@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_metadata",
"select_star",
"schemas",
"test_connection",
"related_objects",
}
class_permission_name = "DatabaseView"
Expand Down Expand Up @@ -343,7 +354,7 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ
@rison(database_schemas_query_schema)
@statsd_metrics
def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
""" Get all schemas from a database
"""Get all schemas from a database
---
get:
description: Get all schemas from a database
Expand Down Expand Up @@ -400,7 +411,7 @@ def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database table metadata
Expand Down Expand Up @@ -457,7 +468,7 @@ def table_metadata(
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database select star for table
Expand Down Expand Up @@ -506,6 +517,86 @@ def select_star(
self.incr_stats("success", self.select_star.__name__)
return self.response(200, result=result)

@expose("/test_connection", methods=["POST"])
@protect()
@safe
@event_logger.log_this
@statsd_metrics
def test_connection( # pylint: disable=too-many-return-statements
self,
) -> FlaskResponse:
"""Tests a database connection
---
post:
description: >-
Tests a database connection
requestBody:
description: Database schema
required: true
content:
application/json:
schema:
type: object
properties:
encrypted_extra:
type: object
extras:
type: object
name:
type: string
server_cert:
type: string
responses:
200:
description: Database Test Connection
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = DatabaseTestConnectionSchema().load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_(f"Could not load database driver: {driver_name}"),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except OperationalError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)

@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
@safe
Expand Down
5 changes: 5 additions & 0 deletions superset/databases/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DeleteFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException


class DatabaseInvalidError(CommandInvalidError):
Expand Down Expand Up @@ -109,3 +110,7 @@ class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):

class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")


class DatabaseSecurityUnsafeError(DBSecurityException):
message = _("Stopped an unsafe database connection")
67 changes: 67 additions & 0 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from contextlib import closing
from typing import Any, Dict, Optional

import simplejson as json
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import select

from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException

logger = logging.getLogger(__name__)


class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
self._model: Optional[Database] = None

def run(self) -> None:
self.validate()
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted

database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=json.dumps(self._properties.get("extra", {})),
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=json.dumps(self._properties.get("encrypted_extra", {})),
)
if database is not None:
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
with closing(engine.connect()) as conn:
conn.scalar(select([1]))
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()

def validate(self) -> None:
database_name = self._properties.get("database_name")
if database_name is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
21 changes: 20 additions & 1 deletion superset/databases/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional

from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
Expand Down Expand Up @@ -45,6 +45,25 @@ def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
)
return not db.session.query(database_query.exists()).scalar()

@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
.one_or_none()
)

@staticmethod
def build_db_for_connection_test(
server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
) -> Optional[Database]:
return Database(
server_cert=server_cert,
extra=extra,
impersonate_user=impersonate_user,
encrypted_extra=encrypted_extra,
)

@classmethod
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
datasets = cls.find_by_id(database_id).tables
Expand Down
23 changes: 21 additions & 2 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import inspect
import json

from flask import current_app
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Length, ValidationError
from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError

from superset import app
from superset.exceptions import CertificateException
from superset.utils.core import markdown, parse_ssl_cert

Expand Down Expand Up @@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
)
]
)
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value:
if value.startswith("sqlite"):
raise ValidationError(
[
Expand Down Expand Up @@ -291,6 +291,25 @@ class DatabasePutSchema(Schema):
)


class DatabaseTestConnectionSchema(Schema):
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
extra = fields.String(description=extra_description, validate=extra_validator)
encrypted_extra = fields.String(
description=encrypted_extra_description, validate=encrypted_extra_validator
)
server_cert = fields.String(
description=server_cert_description, validate=server_cert_validator
)
sqlalchemy_uri = fields.String(
description=sqlalchemy_uri_description,
required=True,
validate=[Length(1, 1024), sqlalchemy_uri_validator],
)


class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()
Expand Down
2 changes: 1 addition & 1 deletion superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@

if TYPE_CHECKING:
# pylint: disable=unused-import
from werkzeug.datastructures import TypeConversionDict
from flask_appbuilder.security.sqla.models import User
from werkzeug.datastructures import TypeConversionDict

# Globals
config = app.config
Expand Down
25 changes: 13 additions & 12 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi):

csrf_exempt = False
method_permission_name = {
"get_list": "list",
"get": "show",
"bulk_delete": "delete",
"data": "list",
"delete": "delete",
"distinct": "list",
"export": "mulexport",
"get": "show",
"get_list": "list",
"info": "list",
"post": "add",
"put": "edit",
"delete": "delete",
"bulk_delete": "delete",
"info": "list",
"related": "list",
"distinct": "list",
"thumbnail": "list",
"refresh": "edit",
"data": "list",
"viz_types": "list",
"related": "list",
"related_objects": "list",
"table_metadata": "list",
"select_star": "list",
"schemas": "list",
"select_star": "list",
"table_metadata": "list",
"test_connection": "post",
"thumbnail": "list",
"viz_types": "list",
}

order_rel_fields: Dict[str, Tuple[str, str]] = {}
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def testconn( # pylint: disable=too-many-return-statements,no-self-use
logger.warning("Stopped an unsafe database connection")
return json_error_response(_(str(ex)), 400)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
logger.warning("Unexpected error %s", type(ex).__name__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return json_error_response(
_("Unexpected error occurred, please check your logs for details"), 400
)
Expand Down
Loading