diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 40aea66301214..ab8abe0edca21 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -1513,13 +1513,13 @@ export function createCtasDatasource(vizOptions) { return dispatch => { dispatch(createDatasourceStarted()); return SupersetClient.post({ - endpoint: '/superset/get_or_create_table/', - postPayload: { data: vizOptions }, + endpoint: '/api/v1/dataset/get_or_create/', + jsonPayload: vizOptions, }) .then(({ json }) => { - dispatch(createDatasourceSuccess(json)); + dispatch(createDatasourceSuccess(json.result)); - return json; + return json.result; }) .catch(() => { const errorMsg = t('An error occurred while creating the data source'); diff --git a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx index 2fe1e14a07853..a4c71139c0d3c 100644 --- a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx +++ b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx @@ -48,10 +48,10 @@ const ExploreCtasResultsButton = ({ const dispatch = useDispatch<(dispatch: any) => Promise>(); const buildVizOptions = { - datasourceName: table, + table_name: table, schema, - dbId, - templateParams, + database_id: dbId, + template_params: templateParams, }; const visualize = () => { diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index c502f527acd6f..86cb08bb8690d 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -35,7 +35,6 @@ from superset.superset_typing import FlaskResponse from superset.utils import core as utils from superset.views.base import ( - create_table_permissions, DatasourceFilter, DeleteMixin, ListWidgetWithCheckboxes, @@ -511,7 +510,6 @@ def post_add( # pylint: disable=arguments-differ ) -> None: if fetch_metadata: item.fetch_metadata() - create_table_permissions(item) if flash_message: flash( _( diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 925c3c7cb8c71..d58a1dd3f6152 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -61,6 +61,7 @@ DatasetRelatedObjectsResponse, get_delete_ids_schema, get_export_ids_schema, + GetOrCreateDatasetSchema, ) from superset.utils.core import parse_boolean_string from superset.views.base import DatasourceFilter, generate_download_headers @@ -93,6 +94,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): "refresh", "related_objects", "duplicate", + "get_or_create_dataset", } list_columns = [ "id", @@ -240,6 +242,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): openapi_spec_component_schemas = ( DatasetRelatedObjectsResponse, DatasetDuplicateSchema, + GetOrCreateDatasetSchema, ) list_outer_default_load = True @@ -877,3 +880,70 @@ def import_(self) -> Response: ) command.run() return self.response(200, message="OK") + + @expose("/get_or_create/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".get_or_create_dataset", + log_to_statsd=False, + ) + def get_or_create_dataset(self) -> Response: + """Retrieve a dataset by name, or create it if it does not exist + --- + post: + summary: Retrieve a table by name, or create it if it does not exist + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GetOrCreateDatasetSchema' + responses: + 200: + description: The ID of the table + content: + application/json: + schema: + type: object + properties: + result: + type: object + properties: + table_id: + type: integer + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + body = GetOrCreateDatasetSchema().load(request.json) + except ValidationError as ex: + return self.response(400, message=ex.messages) + table_name = body["table_name"] + database_id = body["database_id"] + table = DatasetDAO.get_table_by_name(database_id, table_name) + if table: + return self.response(200, result={"table_id": table.id}) + + body["database"] = database_id + try: + tbl = CreateDatasetCommand(body).run() + return self.response(200, result={"table_id": tbl.id}) + except DatasetInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except DatasetCreateFailedError as ex: + logger.error( + "Error creating model %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=ex.message) diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 1c55723b47a23..b158fce1fefe8 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -388,6 +388,14 @@ def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None: db.session.rollback() raise ex + @staticmethod + def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]: + return ( + db.session.query(SqlaTable) + .filter_by(database_id=database_id, table_name=table_name) + .one_or_none() + ) + class DatasetColumnDAO(BaseDAO): model_cls = TableColumn diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 223324da3aa9b..103359a2c3f03 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -228,6 +228,17 @@ def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: external_url = fields.String(allow_none=True) +class GetOrCreateDatasetSchema(Schema): + table_name = fields.String(required=True, description="Name of table") + database_id = fields.Integer( + required=True, description="ID of database table belongs to" + ) + schema = fields.String( + description="The schema the table belongs to", allow_none=True + ) + template_params = fields.String(description="Template params for the table") + + class DatasetSchema(SQLAlchemyAutoSchema): """ Schema for the ``Dataset`` model. diff --git a/superset/views/base.py b/superset/views/base.py index 0d69f1482f70f..f6651e5c74528 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -299,12 +299,6 @@ def validate_sqlatable(table: models.SqlaTable) -> None: ) from ex -def create_table_permissions(table: models.SqlaTable) -> None: - security_manager.add_permission_view_menu("datasource_access", table.get_perm()) - if table.schema: - security_manager.add_permission_view_menu("schema_access", table.schema_perm) - - class BaseSupersetView(BaseView): @staticmethod def json_response(obj: Any, status: int = 200) -> FlaskResponse: diff --git a/superset/views/core.py b/superset/views/core.py index f1603837bb660..fb371c209ee90 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -142,7 +142,6 @@ api, BaseSupersetView, common_bootstrap_payload, - create_table_permissions, CsvResponse, data_payload_response, deprecated, @@ -1927,6 +1926,7 @@ def log(self) -> FlaskResponse: # pylint: disable=no-self-use @has_access @expose("/get_or_create_table/", methods=["POST"]) @event_logger.log_this + @deprecated() def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use """Gets or creates a table object with attributes passed to the API. @@ -1956,11 +1956,11 @@ def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use table.schema = data.get("schema") table.template_params = data.get("templateParams") # needed for the table validation. + # fn can be deleted when this endpoint is removed validate_sqlatable(table) db.session.add(table) table.fetch_metadata() - create_table_permissions(table) db.session.commit() return json_success(json.dumps({"table_id": table.id})) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 95236af09041e..6e0551bd9f826 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -34,6 +34,7 @@ DAODeleteFailedError, DAOUpdateFailedError, ) +from superset.datasets.commands.exceptions import DatasetCreateFailedError from superset.datasets.models import Dataset from superset.extensions import db, security_manager from superset.models.core import Database @@ -474,6 +475,7 @@ def test_info_security_dataset(self): "can_write", "can_export", "can_duplicate", + "can_get_or_create_dataset", } def test_create_dataset_item(self): @@ -2302,3 +2304,90 @@ def test_duplicate_invalid_dataset(self): } rv = self.post_assert_metric(uri, table_data, "duplicate") assert rv.status_code == 422 + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_or_create_dataset_already_exists(self): + """ + Dataset API: Test get or create endpoint when table already exists + """ + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "virtual_dataset", + "database_id": get_example_database().id, + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + dataset = ( + db.session.query(SqlaTable) + .filter(SqlaTable.table_name == "virtual_dataset") + .one() + ) + self.assertEqual(response["result"], {"table_id": dataset.id}) + + def test_get_or_create_dataset_database_not_found(self): + """ + Dataset API: Test get or create endpoint when database doesn't exist + """ + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={"table_name": "virtual_dataset", "database_id": 999}, + ) + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["message"], {"database": ["Database does not exist"]}) + + @patch("superset.datasets.commands.create.CreateDatasetCommand.run") + def test_get_or_create_dataset_create_fails(self, command_run_mock): + """ + Dataset API: Test get or create endpoint when create fails + """ + command_run_mock.side_effect = DatasetCreateFailedError + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "virtual_dataset", + "database_id": get_example_database().id, + }, + ) + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["message"], "Dataset could not be created.") + + def test_get_or_create_dataset_creates_table(self): + """ + Dataset API: Test get or create endpoint when table is created + """ + self.login(username="admin") + + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api") + engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col") + + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "test_create_sqla_table_api", + "database_id": examples_db.id, + "template_params": '{"param": 1}', + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_sqla_table_api") + .one() + ) + self.assertEqual(response["result"], {"table_id": table.id}) + self.assertEqual(table.template_params, '{"param": 1}') + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_sqla_table_api") + db.session.commit() diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 5cc5c85beab37..0ce98477a0b2d 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -20,13 +20,18 @@ import pytest import yaml +from sqlalchemy.exc import SQLAlchemyError from superset import db, security_manager from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.sqla.models import SqlaTable from superset.databases.commands.importers.v1 import ImportDatabasesCommand -from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.datasets.commands.create import CreateDatasetCommand +from superset.datasets.commands.exceptions import ( + DatasetInvalidError, + DatasetNotFoundError, +) from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.importers import v0, v1 from superset.models.core import Database @@ -519,3 +524,47 @@ def _get_table_from_list_by_name(name: str, tables: List[Any]): if table.table_name == name: return table raise ValueError(f"Table {name} does not exists in database") + + +class TestCreateDatasetCommand(SupersetTestCase): + def test_database_not_found(self): + self.login(username="admin") + with self.assertRaises(DatasetInvalidError): + CreateDatasetCommand({"table_name": "table", "database": 9999}).run() + + @patch("superset.models.core.Database.get_table") + def test_get_table_from_database_error(self, get_table_mock): + self.login(username="admin") + get_table_mock.side_effect = SQLAlchemyError + with self.assertRaises(DatasetInvalidError): + CreateDatasetCommand( + {"table_name": "table", "database": get_example_database().id} + ).run() + + @patch("superset.security.manager.g") + @patch("superset.commands.utils.g") + def test_create_dataset_command(self, mock_g, mock_g2): + mock_g.user = security_manager.find_user("admin") + mock_g2.user = mock_g.user + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_dataset_command") + engine.execute( + "CREATE TABLE test_create_dataset_command AS SELECT 2 as col" + ) + + table = CreateDatasetCommand( + {"table_name": "test_create_dataset_command", "database": examples_db.id} + ).run() + fetched_table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_dataset_command") + .one() + ) + self.assertEqual(table, fetched_table) + self.assertEqual([owner.username for owner in table.owners], ["admin"]) + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_dataset_command") + db.session.commit()