diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
index ecf2c4d7e299c..7a7f2f72a4c71 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
@@ -508,10 +508,10 @@ describe('async actions', () => {
     fetchMock.delete(updateTableSchemaEndpoint, {});
     fetchMock.post(updateTableSchemaEndpoint, JSON.stringify({ id: 1 }));
 
-    const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+    const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
     fetchMock.get(getTableMetadataEndpoint, {});
     const getExtraTableMetadataEndpoint =
-      'glob:**/api/v1/database/*/table_metadata/extra/';
+      'glob:**/api/v1/database/*/table_metadata/extra/*';
     fetchMock.get(getExtraTableMetadataEndpoint, {});
 
     let isFeatureEnabledMock;
diff --git a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
index f8c94468bf7f3..b5003b16f7b47 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
@@ -61,13 +61,13 @@ beforeEach(() => {
       },
     ],
   });
-  fetchMock.get('glob:*/api/v1/database/*/table/*/*', {
+  fetchMock.get('glob:*/api/v1/database/*/table_metadata/*', {
     status: 200,
     body: {
       columns: table.columns,
     },
   });
-  fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/', {
+  fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/*', {
     status: 200,
     body: {},
   });
diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
index a2fe88020aa45..7b95b6d0f3492 100644
--- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
+++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
@@ -47,7 +47,7 @@ jest.mock(
       <div>{column.name}</div>
     ),
 );
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
 const getExtraTableMetadataEndpoint =
   'glob:**/api/v1/database/*/table_metadata/extra/*';
 const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';
diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
index ef5797fb309c2..c964fc32faaf0 100644
--- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
+++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
@@ -74,7 +74,9 @@ const DatasetPanelWrapper = ({
     const { dbId, tableName, schema } = props;
     setLoading(true);
     setHasColumns?.(false);
-    const path = `/api/v1/database/${dbId}/table/${tableName}/${schema}/`;
+    const path = schema
+      ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
+      : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
     try {
       const response = await SupersetClient.get({
         endpoint: path,
diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts
index 164fe0f0ab19c..41be4c167c9c8 100644
--- a/superset-frontend/src/hooks/apiResources/tables.ts
+++ b/superset-frontend/src/hooks/apiResources/tables.ts
@@ -114,9 +114,13 @@ const tableApi = api.injectEndpoints({
     }),
     tableMetadata: builder.query({
       query: ({ dbId, schema, table }) => ({
-        endpoint: `/api/v1/database/${dbId}/table/${encodeURIComponent(
-          table,
-        )}/${encodeURIComponent(schema)}/`,
+        endpoint: schema
+          ? `/api/v1/database/${dbId}/table_metadata/?name=${encodeURIComponent(
+              table,
+            )}&schema=${encodeURIComponent(schema)}`
+          : `/api/v1/database/${dbId}/table_metadata/?name=${encodeURIComponent(
+              table,
+            )}`,
         transformResponse: ({ json }: TableMetadataReponse) => json,
       }),
     }),
diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index 16b87a567a5f0..d4b14445687cd 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -34,6 +34,7 @@
 from superset.daos.exceptions import DAOCreateFailedError
 from superset.exceptions import SupersetSecurityException
 from superset.extensions import db, security_manager
+from superset.sql_parse import Table
 
 logger = logging.getLogger(__name__)
 
@@ -80,7 +81,10 @@ def validate(self) -> None:
         if (
             database
             and not sql
-            and not DatasetDAO.validate_table_exists(database, table_name, schema)
+            and not DatasetDAO.validate_table_exists(
+                database,
+                Table(table_name, schema),
+            )
         ):
             exceptions.append(TableNotFoundValidationError(table_name))
 
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index afd2791c9c601..363d64aa14880 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -49,7 +49,7 @@
     Integer,
     or_,
     String,
-    Table,
+    Table as DBTable,
     Text,
     update,
 )
@@ -108,7 +108,7 @@
     validate_adhoc_subquery,
 )
 from superset.models.slice import Slice
-from superset.sql_parse import ParsedQuery, sanitize_clause
+from superset.sql_parse import ParsedQuery, sanitize_clause, Table
 from superset.superset_typing import (
     AdhocColumn,
     AdhocMetric,
@@ -1068,7 +1068,7 @@ def data(self) -> dict[str, Any]:
         return {s: getattr(self, s) for s in attrs}
 
 
-sqlatable_user = Table(
+sqlatable_user = DBTable(
     "sqlatable_user",
     metadata,
     Column("id", Integer, primary_key=True),
@@ -1146,6 +1146,7 @@ class SqlaTable(
         foreign_keys=[database_id],
     )
     schema = Column(String(255))
+    catalog = Column(String(256), nullable=True, default=None)
     sql = Column(MediumText())
     is_sqllab_view = Column(Boolean, default=False)
     template_params = Column(Text)
@@ -1322,8 +1323,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
             return get_virtual_table_metadata(dataset=self)
         return get_physical_table_metadata(
             database=self.database,
-            table_name=self.table_name,
-            schema_name=self.schema,
+            table=Table(self.table_name, self.schema, self.catalog),
             normalize_columns=self.normalize_columns,
         )
@@ -1339,7 +1339,9 @@ def select_star(self) -> str | None:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
-            self.table_name, schema=self.schema, show_cols=False, latest_partition=False
+            Table(self.table_name, self.schema, self.catalog),
+            show_cols=False,
+            latest_partition=False,
         )
 
     @property
@@ -1779,7 +1781,13 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
         )
 
-    def get_sqla_table_object(self) -> Table:
-        return self.database.get_table(self.table_name, schema=self.schema)
+    def get_sqla_table_object(self) -> DBTable:
+        return self.database.get_table(
+            Table(
+                self.table_name,
+                self.schema,
+                self.catalog,
+            )
+        )
 
     def fetch_metadata(self, commit: bool = True) -> MetadataResult:
         """
@@ -1791,7 +1799,13 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
         new_columns = self.external_metadata()
         metrics = [
             SqlMetric(**metric)
-            for metric in self.database.get_metrics(self.table_name, self.schema)
+            for metric in self.database.get_metrics(
+                Table(
+                    self.table_name,
+                    self.schema,
+                    self.catalog,
+                )
+            )
         ]
         any_date_col = None
         db_engine_spec = self.db_engine_spec
@@ -2038,7 +2052,7 @@ def load_database(self: SqlaTable) -> None:
 sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
 sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
 
-RLSFilterRoles = Table(
+RLSFilterRoles = DBTable(
     "rls_filter_roles",
     metadata,
     Column("id", Integer, primary_key=True),
@@ -2046,7 +2060,7 @@ def load_database(self: SqlaTable) -> None:
     Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
 )
 
-RLSFilterTables = Table(
+RLSFilterTables = DBTable(
     "rls_filter_tables",
     metadata,
     Column("id", Integer, primary_key=True),
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 4bc11aee42d80..f547e238dfcc8 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -38,7 +38,7 @@
 )
 from superset.models.core import Database
 from superset.result_set import SupersetResultSet
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from superset.superset_typing import ResultSetColumnType
 
 if TYPE_CHECKING:
@@ -47,24 +47,18 @@
 
 def get_physical_table_metadata(
     database: Database,
-    table_name: str,
+    table: Table,
     normalize_columns: bool,
-    schema_name: str | None = None,
 ) -> list[ResultSetColumnType]:
     """Use SQLAlchemy inspector to get table metadata"""
     db_engine_spec = database.db_engine_spec
     db_dialect = database.get_dialect()
-    # ensure empty schema
-    _schema_name = schema_name if schema_name else None
-
-    # Table does not exist or is not visible to a connection.
-    if not (
-        database.has_table_by_name(table_name=table_name, schema=_schema_name)
-        or database.has_view_by_name(view_name=table_name, schema=_schema_name)
-    ):
-        raise NoSuchTableError
+    # Table does not exist or is not visible to a connection.
+    if not (database.has_table(table) or database.has_view(table)):
+        raise NoSuchTableError(table)
 
-    cols = database.get_columns(table_name, schema=_schema_name)
+    cols = database.get_columns(table)
     for col in cols:
         try:
             if isinstance(col["type"], TypeEngine):
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 4647e02ce68af..c9c16beaa69d0 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -30,6 +30,7 @@
 from superset.models.core import Database
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
+from superset.sql_parse import Table
 from superset.utils.core import DatasourceType
 from superset.views.base import DatasourceFilter
 
@@ -72,13 +73,14 @@ def get_related_objects(database_id: int) -> dict[str, Any]:
 
     @staticmethod
     def validate_table_exists(
-        database: Database, table_name: str, schema: str | None
+        database: Database,
+        table: Table,
     ) -> bool:
         try:
-            database.get_table(table_name, schema=schema)
+            database.get_table(table)
             return True
         except SQLAlchemyError as ex:  # pragma: no cover
-            logger.warning("Got an error %s validating table: %s", str(ex), table_name)
+            logger.warning("Got an error %s validating table: %s", str(ex), table)
             return False
 
     @staticmethod
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 4537fc8bc512f..1b5f515ea69bb 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -133,6 +133,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
         RouteMethod.RELATED,
         "tables",
         "table_metadata",
+        "table_metadata_deprecated",
         "table_extra_metadata",
         "table_extra_metadata_deprecated",
         "select_star",
@@ -717,10 +718,10 @@ def tables(self, pk: int, **kwargs: Any) -> FlaskResponse:
     @statsd_metrics
     @event_logger.log_this_with_context(
         action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
-        f".table_metadata",
+        f".table_metadata_deprecated",
         log_to_statsd=False,
     )
-    def table_metadata(
+    def table_metadata_deprecated(
         self, database: Database, table_name: str, schema_name: str
     ) -> FlaskResponse:
         """Get database table metadata.
         ...
              500:
                $ref: '#/components/responses/500'
         """
-        self.incr_stats("init", self.table_metadata.__name__)
+        self.incr_stats("init", self.table_metadata_deprecated.__name__)
         try:
-            table_info = get_table_metadata(database, table_name, schema_name)
+            table_info = get_table_metadata(database, Table(table_name, schema_name))
         except SQLAlchemyError as ex:
-            self.incr_stats("error", self.table_metadata.__name__)
+            self.incr_stats("error", self.table_metadata_deprecated.__name__)
             return self.response_422(error_msg_from_exception(ex))
         except SupersetException as ex:
             return self.response(ex.status, message=ex.message)
 
-        self.incr_stats("success", self.table_metadata.__name__)
+        self.incr_stats("success", self.table_metadata_deprecated.__name__)
         return self.response(200, **table_info)
 
     @expose("/<int:pk>/table_extra/<table_name>/<schema_name>/", methods=("GET",))
@@ -839,7 +840,86 @@ def table_extra_metadata_deprecated(
         payload = database.db_engine_spec.get_extra_table_metadata(database, table)
         return self.response(200, **payload)
 
-    @expose("/<int:pk>/table_metadata/extra/", methods=("GET",))
+    @expose("/<int:pk>/table_metadata/", methods=["GET"])
+    @protect()
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+        f".table_metadata",
+        log_to_statsd=False,
+    )
+    def table_metadata(self, pk: int) -> FlaskResponse:
+        """
+        Get metadata for a given table.
+
+        Optionally, a schema and a catalog can be passed, if different from the default
+        ones.
+        ---
+        get:
+          summary: Get table metadata
+          description: >-
+            Metadata associated with the table (columns, indexes, etc.)
+          parameters:
+          - in: path
+            schema:
+              type: integer
+            name: pk
+            description: The database id
+          - in: query
+            schema:
+              type: string
+            name: name
+            required: true
+            description: Table name
+          - in: query
+            schema:
+              type: string
+            name: schema
+            description: >-
+              Optional table schema, if not passed default schema will be used
+          - in: query
+            schema:
+              type: string
+            name: catalog
+            description: >-
+              Optional table catalog, if not passed default catalog will be used
+          responses:
+            200:
+              description: Table metadata information
+              content:
+                application/json:
+                  schema:
+                    $ref: "#/components/schemas/TableMetadataResponseSchema"
+            401:
+              $ref: '#/components/responses/401'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        self.incr_stats("init", self.table_metadata.__name__)
+
+        database = DatabaseDAO.find_by_id(pk)
+        if database is None:
+            raise DatabaseNotFoundException("No such database")
+
+        try:
+            parameters = QualifiedTableSchema().load(request.args)
+        except ValidationError as ex:
+            raise InvalidPayloadSchemaError(ex) from ex
+
+        table = Table(parameters["name"], parameters["schema"], parameters["catalog"])
+        try:
+            security_manager.raise_for_access(database=database, table=table)
+        except SupersetSecurityException as ex:
+            # instead of raising 403, raise 404 to hide table existence
+            raise TableNotFoundException("No such table") from ex
+
+        payload = database.db_engine_spec.get_table_metadata(database, table)
+
+        return self.response(200, **payload)
+
+    @expose("/<int:pk>/table_metadata/extra/", methods=["GET"])
     @protect()
     @statsd_metrics
     @event_logger.log_this_with_context(
@@ -973,7 +1053,8 @@ def select_star(
         self.incr_stats("init", self.select_star.__name__)
         try:
             result = database.select_star(
-                table_name, schema_name, latest_partition=True
+                Table(table_name, schema_name),
+                latest_partition=True,
             )
         except NoSuchTableError:
             self.incr_stats("error", self.select_star.__name__)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index dccab980bcd18..47f4bf15fac0f 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -17,11 +17,13 @@
 
 # pylint: disable=unused-argument, too-many-lines
 
+from __future__ import annotations
+
 import inspect
 import json
 import os
 import re
-from typing import Any
+from typing import Any, TypedDict
 
 from flask import current_app
 from flask_babel import lazy_gettext as _
@@ -581,6 +583,49 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
     ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
 
 
+class TableMetadataOptionsResponse(TypedDict):
+    deferrable: bool
+    initially: bool
+    match: bool
+    ondelete: bool
+    onupdate: bool
+
+
+class TableMetadataColumnsResponse(TypedDict, total=False):
+    keys: list[str]
+    longType: str
+    name: str
+    type: str
+    duplicates_constraint: str | None
+    comment: str | None
+
+
+class TableMetadataForeignKeysIndexesResponse(TypedDict):
+    column_names: list[str]
+    name: str
+    options: TableMetadataOptionsResponse
+    referred_columns: list[str]
+    referred_schema: str
+    referred_table: str
+    type: str
+
+
+class TableMetadataPrimaryKeyResponse(TypedDict):
+    column_names: list[str]
+    name: str
+    type: str
+
+
+class TableMetadataResponse(TypedDict):
+    name: str
+    columns: list[TableMetadataColumnsResponse]
+    foreignKeys: list[TableMetadataForeignKeysIndexesResponse]
+    indexes: list[TableMetadataForeignKeysIndexesResponse]
+    primaryKey: TableMetadataPrimaryKeyResponse
+    selectStar: str
+    comment: str | None
+
+
 class TableMetadataOptionsResponseSchema(Schema):
     deferrable = fields.Bool()
     initially = fields.Bool()
diff --git a/superset/databases/utils.py b/superset/databases/utils.py
index 8de4bb6f2353d..dfd75eb2233f4 100644
--- a/superset/databases/utils.py
+++ b/superset/databases/utils.py
@@ -14,19 +14,29 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Optional, Union
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
 
 from sqlalchemy.engine.url import make_url, URL
 
 from superset.commands.database.exceptions import DatabaseInvalidError
+from superset.sql_parse import Table
+
+if TYPE_CHECKING:
+    from superset.databases.schemas import (
+        TableMetadataColumnsResponse,
+        TableMetadataForeignKeysIndexesResponse,
+        TableMetadataResponse,
+    )
 
 
 def get_foreign_keys_metadata(
     database: Any,
-    table_name: str,
-    schema_name: Optional[str],
-) -> list[dict[str, Any]]:
-    foreign_keys = database.get_foreign_keys(table_name, schema_name)
+    table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+    foreign_keys = database.get_foreign_keys(table)
     for fk in foreign_keys:
         fk["column_names"] = fk.pop("constrained_columns")
         fk["type"] = "fk"
@@ -34,9 +44,10 @@ def get_foreign_keys_metadata(
 
 
 def get_indexes_metadata(
-    database: Any, table_name: str, schema_name: Optional[str]
-) -> list[dict[str, Any]]:
-    indexes = database.get_indexes(table_name, schema_name)
+    database: Any,
+    table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+    indexes = database.get_indexes(table)
     for idx in indexes:
         idx["type"] = "index"
     return indexes
@@ -51,30 +62,27 @@ def get_col_type(col: dict[Any, Any]) -> str:
     return dtype
 
 
-def get_table_metadata(
-    database: Any, table_name: str, schema_name: Optional[str]
-) -> dict[str, Any]:
+def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse:
     """
     Get table metadata information, including type, pk, fks.
     This function raises SQLAlchemyError when a schema is not found.
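For reviewers poking at the new route by hand, a rough sketch of the old and new request shapes follows. The host, port, and bearer token are placeholders, as are the table names; the query-string contract itself (required `name`, optional `schema` and `catalog`) is the part pinned down by the endpoint and tests in this diff.

```python
import requests

BASE = "http://localhost:8088"  # hypothetical local Superset
HEADERS = {"Authorization": "Bearer <access-token>"}  # assumes token auth is configured

# Deprecated shape: table and schema are path segments, which breaks for
# names containing "/" and leaves no room for a catalog.
legacy = requests.get(
    f"{BASE}/api/v1/database/1/table/orders/sales/",
    headers=HEADERS,
)

# New shape: everything rides in the query string, so unusual characters
# survive URL encoding and catalog is just one more optional parameter.
current = requests.get(
    f"{BASE}/api/v1/database/1/table_metadata/",
    params={"name": "orders", "schema": "sales", "catalog": "analytics"},
    headers=HEADERS,
)
print(current.json()["columns"])  # shaped like TableMetadataResponse above
```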
:param database: The database model - :param table_name: Table name - :param schema_name: schema name + :param table: Table instance :return: Dict table metadata ready for API response """ keys = [] - columns = database.get_columns(table_name, schema_name) - primary_key = database.get_pk_constraint(table_name, schema_name) + columns = database.get_columns(table) + primary_key = database.get_pk_constraint(table) if primary_key and primary_key.get("constrained_columns"): primary_key["column_names"] = primary_key.pop("constrained_columns") primary_key["type"] = "pk" keys += [primary_key] - foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) - indexes = get_indexes_metadata(database, table_name, schema_name) + foreign_keys = get_foreign_keys_metadata(database, table) + indexes = get_indexes_metadata(database, table) keys += foreign_keys + indexes - payload_columns: list[dict[str, Any]] = [] - table_comment = database.get_table_comment(table_name, schema_name) + payload_columns: list[TableMetadataColumnsResponse] = [] + table_comment = database.get_table_comment(table) for col in columns: dtype = get_col_type(col) payload_columns.append( @@ -87,11 +95,10 @@ def get_table_metadata( } ) return { - "name": table_name, + "name": table.table, "columns": payload_columns, "selectStar": database.select_star( - table_name, - schema=schema_name, + table, indent=True, cols=columns, latest_partition=True, @@ -103,7 +110,7 @@ def get_table_metadata( } -def make_url_safe(raw_url: Union[str, URL]) -> URL: +def make_url_safe(raw_url: str | URL) -> URL: """ Wrapper for SQLAlchemy's make_url(), which tends to raise too detailed of errors, which inevitably find their way into server logs. ArgumentErrors diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 451e96927d9fc..63756a221f958 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -61,7 +61,7 @@ from superset import security_manager, sql_parse from superset.constants import TimeGrain as TimeGrainConstants -from superset.databases.utils import make_url_safe +from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table @@ -80,6 +80,7 @@ if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn + from superset.databases.schemas import TableMetadataResponse from superset.models.core import Database from superset.models.sql_lab import Query @@ -1034,6 +1035,21 @@ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any] """ return indexes + @classmethod + def get_table_metadata( # pylint: disable=unused-argument + cls, + database: Database, + table: Table, + ) -> TableMetadataResponse: + """ + Returns basic table metadata + + :param database: Database instance + :param table: A Table instance + :return: Basic table metadata + """ + return get_table_metadata(database, table) + @classmethod def get_extra_table_metadata( # pylint: disable=unused-argument cls, @@ -1472,36 +1488,34 @@ def get_indexes( cls, database: Database, # pylint: disable=unused-argument inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. 
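Across these signatures the `(table_name, schema)` pair collapses into the `Table` value object from `superset.sql_parse`. Judging from the call sites in this diff (`Table("t")`, `Table("t", "s")`, `Table("t", None, "c")`), it behaves like a named tuple with optional `schema` and `catalog`; a minimal sketch of that contract:

```python
from superset.sql_parse import Table

table = Table("orders", "sales", "analytics")
assert (table.table, table.schema, table.catalog) == ("orders", "sales", "analytics")

# schema and catalog default to None, so bare table names keep working,
# which is what calls like Table(table_name, schema_name) above rely on
assert Table("orders").schema is None
assert Table("orders", "sales").catalog is None
```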
:param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ - return inspector.get_indexes(table_name, schema) + return inspector.get_indexes(table.table, table.schema) @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table: Table, ) -> str | None: """ Get comment of table from a given schema and table - :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. If omitted, uses default schema for database + :param table: Table instance :return: comment of table """ comment = None try: - comment = inspector.get_table_comment(table_name, schema) + comment = inspector.get_table_comment(table.table, table.schema) comment = comment.get("text") if isinstance(comment, dict) else None except NotImplementedError: # It's expected that some dialects don't implement the comment method @@ -1515,22 +1529,25 @@ def get_table_comment( def get_columns( # pylint: disable=unused-argument cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ - Get all columns from a given schema and table + Get all columns from a given schema and table. + + The inspector will be bound to a catalog, if one was specified. :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. If omitted, uses default schema for database + :param table: Table instance :param options: Extra options to customise the display of columns in some databases :return: All columns in table """ return convert_inspector_columns( - cast(list[SQLAColumnType], inspector.get_columns(table_name, schema)) + cast( + list[SQLAColumnType], + inspector.get_columns(table.table, table.schema), + ) ) @classmethod @@ -1538,8 +1555,7 @@ def get_metrics( # pylint: disable=unused-argument cls, database: Database, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[MetricType]: """ Get all metrics from a given schema and table. @@ -1556,17 +1572,15 @@ def get_metrics( # pylint: disable=unused-argument @classmethod def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: """ Add a where clause to a query to reference only the most recent partition - :param table_name: Table name - :param schema: Schema name + :param table: Table instance :param database: Database instance :param query: SqlAlchemy query :param columns: List of TableColumns @@ -1589,9 +1603,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: def select_star( # pylint: disable=too-many-arguments,too-many-locals cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -1604,9 +1617,8 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals WARNING: expects only unquoted table and schema names. 
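A sketch of a `select_star` call site after this signature change; `database` stands in for an existing `Database` model and the table name is made up:

```python
from superset.sql_parse import Table

# schema (and now catalog) travel inside the Table value instead of being
# passed as separate keyword arguments
sql = database.select_star(
    Table("events", "logs"),
    show_cols=True,
    latest_partition=False,
)
```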
:param database: Database instance - :param table_name: Table name, unquoted + :param table: Table instance :param engine: SqlAlchemy Engine instance - :param schema: Schema, unquoted :param limit: limit to impose on query :param show_cols: Show columns in query; otherwise use "*" :param indent: Add indentation to query @@ -1618,16 +1630,18 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals fields: str | list[Any] = "*" cols = cols or [] if (show_cols or latest_partition) and not cols: - cols = database.get_columns(table_name, schema) + cols = database.get_columns(table) if show_cols: fields = cls._get_fields(cols) + quote = engine.dialect.identifier_preparer.quote quote_schema = engine.dialect.identifier_preparer.quote_schema - if schema: - full_table_name = quote_schema(schema) + "." + quote(table_name) - else: - full_table_name = quote(table_name) + full_table_name = ( + quote_schema(table.schema) + "." + quote(table.table) + if table.schema + else quote(table.table) + ) qry = select(fields).select_from(text(full_table_name)) @@ -1635,7 +1649,10 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals qry = qry.limit(limit) if latest_partition: partition_query = cls.where_latest_partition( - table_name, schema, database, qry, columns=cols + database, + table, + qry, + columns=cols, ) if partition_query is not None: qry = partition_query diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 20eae4f9336ce..876babf43d290 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -304,20 +304,18 @@ def get_indexes( cls, database: "Database", inspector: Inspector, - table_name: str, - schema: Optional[str], + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. 
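Engine specs maintained outside this repo that override these hooks need the same switch to a single `Table` argument. A hedged sketch of the new override shape; the class name and its body are invented for illustration:

```python
from typing import Any

from sqlalchemy.engine.reflection import Inspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.sql_parse import Table


class MyEngineSpec(BaseEngineSpec):  # hypothetical third-party spec
    @classmethod
    def get_indexes(
        cls,
        database: Any,  # a superset.models.core.Database instance
        inspector: Inspector,
        table: Table,  # was: table_name: str, schema: str | None
    ) -> list[dict[str, Any]]:
        # unpack the tuple when talking to the SQLAlchemy inspector
        return inspector.get_indexes(table.table, table.schema)
```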
:param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ - return cls.normalize_indexes(inspector.get_indexes(table_name, schema)) + return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema)) @classmethod def get_extra_table_metadata( @@ -325,7 +323,7 @@ def get_extra_table_metadata( database: "Database", table: Table, ) -> dict[str, Any]: - indexes = database.get_indexes(table.table, table.schema) + indexes = database.get_indexes(table) if not indexes: return {} partitions_columns = [ @@ -629,9 +627,8 @@ def parameters_json_schema(cls) -> Any: def select_star( # pylint: disable=too-many-arguments cls, database: "Database", - table_name: str, + table: Table, engine: Engine, - schema: Optional[str] = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -690,9 +687,8 @@ def select_star( # pylint: disable=too-many-arguments return super().select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index db2e500b53d8f..8a04ee5d3b0f2 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -21,6 +21,7 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -64,7 +65,9 @@ def epoch_to_dttm(cls) -> str: @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: Union[str, None] + cls, + inspector: Inspector, + table: Table, ) -> Optional[str]: """ Get comment of table from a given schema @@ -78,7 +81,7 @@ def get_table_comment( """ comment = None try: - table_comment = inspector.get_table_comment(table_name, schema) + table_comment = inspector.get_table_comment(table.table, table.schema) comment = table_comment.get("text") return comment[0] except IndexError: diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 2655ed6c9af6e..9e6639337482f 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -412,24 +412,24 @@ def handle_cursor( # pylint: disable=too-many-locals def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: - return BaseEngineSpec.get_columns(inspector, table_name, schema, options) + return BaseEngineSpec.get_columns(inspector, table, options) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: col_names, values = cls.latest_partition( - table_name, schema, database, show_first=True + database, + table, + show_first=True, ) except Exception: # pylint: disable=broad-except # table is not partitioned @@ -449,7 +449,10 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: @classmethod def latest_sub_partition( # type: ignore - cls, table_name: str, schema: str | None, database: Database, **kwargs: Any + cls, + database: Database, + table: Table, + **kwargs: Any, ) -> str: # TODO(bogdan): implement` pass @@ -467,24 +470,24 @@ def _latest_partition_from_df(cls, df: 
pd.DataFrame) -> list[str] | None: @classmethod def _partition_query( # pylint: disable=too-many-arguments cls, - table_name: str, - schema: str | None, + table: Table, indexes: list[dict[str, Any]], database: Database, limit: int = 0, order_by: list[tuple[str, bool]] | None = None, filters: dict[Any, Any] | None = None, ) -> str: - full_table_name = f"{schema}.{table_name}" if schema else table_name + full_table_name = ( + f"{table.schema}.{table.table}" if table.schema else table.table + ) return f"SHOW PARTITIONS {full_table_name}" @classmethod def select_star( # pylint: disable=too-many-arguments cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -493,9 +496,8 @@ def select_star( # pylint: disable=too-many-arguments ) -> str: return super(PrestoEngineSpec, cls).select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index a8143c87edbb6..b46d8132f5836 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -420,8 +420,7 @@ def get_function_names(cls, database: Database) -> list[str]: @classmethod def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument cls, - table_name: str, - schema: str | None, + table: Table, indexes: list[dict[str, Any]], database: Database, limit: int = 0, @@ -434,8 +433,7 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus Note the unused arguments are exposed for sub-classing purposes where custom integrations may require the schema, indexes, etc. to build the partition query. - :param table_name: the name of the table to get partitions from - :param schema: the schema name + :param table: the table instance :param indexes: the indexes associated with the table :param database: the database the query will be run against :param limit: the number of partitions to be returned @@ -464,12 +462,16 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus presto_version = database.get_extra().get("version") if presto_version and Version(presto_version) < Version("0.199"): - full_table_name = f"{schema}.{table_name}" if schema else table_name + full_table_name = ( + f"{table.schema}.{table.table}" if table.schema else table.table + ) partition_select_clause = f"SHOW PARTITIONS FROM {full_table_name}" else: - system_table_name = f'"{table_name}$partitions"' + system_table_name = f'"{table.table}$partitions"' full_table_name = ( - f"{schema}.{system_table_name}" if schema else system_table_name + f"{table.schema}.{system_table_name}" + if table.schema + else system_table_name ) partition_select_clause = f"SELECT * FROM {full_table_name}" @@ -486,16 +488,13 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus @classmethod def where_latest_partition( # pylint: disable=too-many-arguments cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: - col_names, values = cls.latest_partition( - table_name, schema, database, show_first=True - ) + col_names, values = cls.latest_partition(database, table, show_first=True) except Exception: # pylint: disable=broad-except # table is not partitioned return None @@ -529,16 +528,14 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) 
-> list[str] | None: @cache_manager.data_cache.memoize(timeout=60) def latest_partition( # pylint: disable=too-many-arguments cls, - table_name: str, - schema: str | None, database: Database, + table: Table, show_first: bool = False, indexes: list[dict[str, Any]] | None = None, ) -> tuple[list[str], list[str] | None]: """Returns col name and the latest (max) partition value for a table - :param table_name: the name of the table - :param schema: schema / database / namespace + :param table: the table instance :param database: database query will be run against :type database: models.Database :param show_first: displays the value for the first partitioning key @@ -550,11 +547,11 @@ def latest_partition( # pylint: disable=too-many-arguments (['ds'], ('2018-01-01',)) """ if indexes is None: - indexes = database.get_indexes(table_name, schema) + indexes = database.get_indexes(table) if not indexes: raise SupersetTemplateException( - f"Error getting partition for {schema}.{table_name}. " + f"Error getting partition for {table}. " "Verify that this table has a partition." ) @@ -575,20 +572,23 @@ def latest_partition( # pylint: disable=too-many-arguments return column_names, cls._latest_partition_from_df( df=database.get_df( sql=cls._partition_query( - table_name, - schema, + table, indexes, database, limit=1, order_by=[(column_name, True) for column_name in column_names], ), - schema=schema, + catalog=table.catalog, + schema=table.schema, ) ) @classmethod def latest_sub_partition( - cls, table_name: str, schema: str | None, database: Database, **kwargs: Any + cls, + database: Database, + table: Table, + **kwargs: Any, ) -> Any: """Returns the latest (max) partition value for a table @@ -601,12 +601,9 @@ def latest_sub_partition( ``latest_sub_partition('my_table', event_category='page', event_type='click')`` - :param table_name: the name of the table, can be just the table - name or a fully qualified table name as ``schema_name.table_name`` - :type table_name: str - :param schema: schema / database / namespace - :type schema: str :param database: database query will be run against + :param table: the table instance + :type table: Table :type database: models.Database :param kwargs: keyword arguments define the filtering criteria @@ -615,7 +612,7 @@ def latest_sub_partition( >>> latest_sub_partition('sub_partition_table', event_type='click') '2018-01-01' """ - indexes = database.get_indexes(table_name, schema) + indexes = database.get_indexes(table) part_fields = indexes[0]["column_names"] for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary if k not in k in part_fields: # pylint: disable=comparison-with-itself @@ -633,15 +630,14 @@ def latest_sub_partition( field_to_return = field sql = cls._partition_query( - table_name, - schema, + table, indexes, database, limit=1, order_by=[(field_to_return, True)], filters=kwargs, ) - df = database.get_df(sql, schema) + df = database.get_df(sql, table.catalog, table.schema) if df.empty: return "" return df.to_dict()[field_to_return][0] @@ -966,7 +962,9 @@ def _parse_structural_column( # pylint: disable=too-many-locals @classmethod def _show_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table: Table, ) -> list[ResultRow]: """ Show presto column names @@ -976,17 +974,16 @@ def _show_columns( :return: list of column objects """ quote = inspector.engine.dialect.identifier_preparer.quote_identifier - full_table = quote(table_name) - if schema: - full_table = 
f"{quote(schema)}.{full_table}" + full_table = quote(table.table) + if table.schema: + full_table = f"{quote(table.schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ @@ -999,7 +996,7 @@ def get_columns( :return: a list of results that contain column info (i.e. column name and data type) """ - columns = cls._show_columns(inspector, table_name, schema) + columns = cls._show_columns(inspector, table) result: list[ResultSetColumnType] = [] for column in columns: # parse column if it is a row or array @@ -1077,9 +1074,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: def select_star( # pylint: disable=too-many-arguments cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -1102,9 +1098,8 @@ def select_star( # pylint: disable=too-many-arguments ] return super().select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, @@ -1232,11 +1227,10 @@ def get_extra_table_metadata( ) -> dict[str, Any]: metadata = {} - if indexes := database.get_indexes(table.table, table.schema): + if indexes := database.get_indexes(table): col_names, latest_parts = cls.latest_partition( - table.table, - table.schema, database, + table, show_first=True, indexes=indexes, ) @@ -1248,8 +1242,7 @@ def get_extra_table_metadata( "cols": sorted(indexes[0].get("column_names", [])), "latest": dict(zip(col_names, latest_parts)), "partitionQuery": cls._partition_query( - table_name=table.table, - schema=table.schema, + table=table, indexes=indexes, database=database, ), diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 32c119ec1f645..aa151129a078e 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -66,7 +66,7 @@ def get_extra_table_metadata( ) -> dict[str, Any]: metadata = {} - if indexes := database.get_indexes(table.table, table.schema): + if indexes := database.get_indexes(table): col_names, latest_parts = cls.latest_partition( table.table, table.schema, @@ -91,8 +91,7 @@ def get_extra_table_metadata( ), "latest": dict(zip(col_names, latest_parts)), "partitionQuery": cls._partition_query( - table_name=table.table, - schema=table.schema, + table=table, indexes=indexes, database=database, ), @@ -414,8 +413,7 @@ def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]: def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ @@ -423,7 +421,7 @@ def get_columns( "schema_options", expand the schema definition out to show all subfields of nested ROWs as their appropriate dotted paths. """ - base_cols = super().get_columns(inspector, table_name, schema, options) + base_cols = super().get_columns(inspector, table, options) if not (options or {}).get("expand_rows"): return base_cols @@ -434,8 +432,7 @@ def get_indexes( cls, database: Database, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. 
@@ -444,11 +441,10 @@ def get_indexes( :param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ try: - return super().get_indexes(database, inspector, table_name, schema) + return super().get_indexes(database, inspector, table) except NoSuchTableError: return [] diff --git a/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py b/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py new file mode 100644 index 0000000000000..31c7137398717 --- /dev/null +++ b/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add catalog column + +Revision ID: c15655337636 +Revises: 0dc386701747 +Create Date: 2024-04-17 18:09:36.795529 + +""" + +# revision identifiers, used by Alembic. +revision = "c15655337636" +down_revision = "0dc386701747" + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + op.add_column("tables", sa.Column("catalog", sa.String(length=256), nullable=True)) + + +def downgrade(): + op.drop_column("tables", "catalog") diff --git a/superset/models/core.py b/superset/models/core.py index bfd4c39593392..939d3a80b2722 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -45,7 +45,7 @@ Integer, MetaData, String, - Table, + Table as SqlaTable, Text, ) from sqlalchemy.engine import Connection, Dialect, Engine @@ -71,6 +71,7 @@ ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet +from superset.sql_parse import Table from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum @@ -382,13 +383,22 @@ def get_effective_user(self, object_url: URL) -> str | None: ) @contextmanager - def get_sqla_engine( + def get_sqla_engine( # pylint: disable=too-many-arguments self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, override_ssh_tunnel: SSHTunnel | None = None, ) -> Engine: + """ + Context manager for a SQLAlchemy engine. + + This method will return a context manager for a SQLAlchemy engine. Using the + context manager (as opposed to the engine directly) is important because we need + to potentially establish SSH tunnels before the connection is created, and clean + them up once the engine is no longer used. 
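From the caller's side, the catalog-aware engine plumbing changed in this file looks roughly like the following; the catalog and schema names are placeholders and `database` is an existing `Database` model:

```python
from sqlalchemy import text

# catalog/schema are folded into the URI via adjust_engine_params before the
# engine is created, and any SSH tunnel is cleaned up when the block exits
with database.get_sqla_engine(catalog="analytics", schema="sales") as engine:
    with engine.connect() as connection:
        rows = connection.execute(text("SELECT 1")).fetchall()
```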
+ """ from superset.daos.database import ( # pylint: disable=import-outside-toplevel DatabaseDAO, ) @@ -403,7 +413,7 @@ def get_sqla_engine( # if ssh_tunnel is available build engine with information engine_context = ssh_manager_factory.instance.create_tunnel( ssh_tunnel=ssh_tunnel, - sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted, + sqlalchemy_database_uri=sqlalchemy_uri, ) with engine_context as server_context: @@ -415,22 +425,21 @@ def get_sqla_engine( server_context.local_bind_address, ) sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( - sqlalchemy_uri, server_context + sqlalchemy_uri, + server_context, ) + yield self._get_sqla_engine( + catalog=catalog, schema=schema, nullpool=nullpool, source=source, sqlalchemy_uri=sqlalchemy_uri, ) - # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a - # reference to the old method to prevent breaking third-party applications. - # TODO (betodealmeida): Remove in 5.0 - get_sqla_engine_with_context = get_sqla_engine - def _get_sqla_engine( self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, @@ -447,26 +456,10 @@ def _get_sqla_engine( params["poolclass"] = NullPool connect_args = params.get("connect_args", {}) - # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and - # had its signature changed in order to support more DB engine specs. Since DB - # engine specs can be released as 3rd party modules we want to make sure the old - # method is still supported so we don't introduce a breaking change. - if hasattr(self.db_engine_spec, "adjust_database_uri"): - sqlalchemy_url = self.db_engine_spec.adjust_database_uri( - sqlalchemy_url, - schema, - ) - logger.warning( - "DB engine spec %s implements the method `adjust_database_uri`, which is " - "deprecated and will be removed in version 3.0. 
Please update it to " - "implement `adjust_engine_params` instead.", - self.db_engine_spec, - ) - sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params( uri=sqlalchemy_url, connect_args=connect_args, - catalog=None, + catalog=catalog, schema=schema, ) @@ -532,12 +525,16 @@ def _get_sqla_engine( @contextmanager def get_raw_connection( self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, ) -> Connection: with self.get_sqla_engine( - schema=schema, nullpool=nullpool, source=source + catalog=catalog, + schema=schema, + nullpool=nullpool, + source=source, ) as engine: with closing(engine.raw_connection()) as conn: # pre-session queries are used to set the selected schema and, in the @@ -575,11 +572,12 @@ def get_reserved_words(self) -> set[str]: def get_df( # pylint: disable=too-many-locals self, sql: str, + catalog: str | None = None, schema: str | None = None, mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - with self.get_sqla_engine(schema) as engine: + with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url mutate_after_split = config["MUTATE_AFTER_SPLIT"] sql_query_mutator = config["SQL_QUERY_MUTATOR"] @@ -640,8 +638,13 @@ def _log_query(sql: str) -> None: return df - def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: - with self.get_sqla_engine(schema) as engine: + def compile_sqla_query( + self, + qry: Select, + catalog: str | None = None, + schema: str | None = None, + ) -> str: + with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) # pylint: disable=protected-access @@ -652,8 +655,7 @@ def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: def select_star( # pylint: disable=too-many-arguments self, - table_name: str, - schema: str | None = None, + table: Table, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -661,11 +663,10 @@ def select_star( # pylint: disable=too-many-arguments cols: list[ResultSetColumnType] | None = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" - with self.get_sqla_engine(schema) as engine: + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: return self.db_engine_spec.select_star( self, - table_name, - schema=schema, + table, engine=engine, limit=limit, show_cols=show_cols, @@ -690,6 +691,7 @@ def safe_sqlalchemy_uri(self) -> str: ) def get_all_table_names_in_schema( # pylint: disable=unused-argument self, + catalog: str | None, schema: str, cache: bool = False, cache_timeout: int | None = None, @@ -707,7 +709,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :return: The table/schema pairs """ try: - with self.get_inspector_with_context() as inspector: + with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { (table, schema) for table in self.db_engine_spec.get_table_names( @@ -725,6 +727,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument ) def get_all_view_names_in_schema( # pylint: disable=unused-argument self, + catalog: str | None, schema: str, cache: bool = False, cache_timeout: int | None = None, @@ -742,7 +745,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :return: set of views """ try: - with self.get_inspector_with_context() as inspector: + with 
self.get_inspector(catalog=catalog, schema=schema) as inspector: return { (view, schema) for view in self.db_engine_spec.get_view_names( @@ -755,10 +758,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @contextmanager - def get_inspector_with_context( - self, ssh_tunnel: SSHTunnel | None = None + def get_inspector( + self, + catalog: str | None = None, + schema: str | None = None, + ssh_tunnel: SSHTunnel | None = None, ) -> Inspector: - with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: + with self.get_sqla_engine( + catalog=catalog, + schema=schema, + override_ssh_tunnel=ssh_tunnel, + ) as engine: yield sqla.inspect(engine) @cache_util.memoized_func( @@ -767,6 +777,7 @@ def get_inspector_with_context( ) def get_all_schema_names( # pylint: disable=unused-argument self, + catalog: str | None = None, cache: bool = False, cache_timeout: int | None = None, force: bool = False, @@ -783,7 +794,10 @@ def get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector: + with self.get_inspector( + catalog=catalog, + ssh_tunnel=ssh_tunnel, + ) as inspector: return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -835,51 +849,57 @@ def get_encrypted_extra(self) -> dict[str, Any]: def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None: self.db_engine_spec.update_params_from_encrypted_extra(self, params) - def get_table(self, table_name: str, schema: str | None = None) -> Table: + def get_table(self, table: Table) -> SqlaTable: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) - with self.get_sqla_engine() as engine: - return Table( - table_name, + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + return SqlaTable( + table.table, meta, - schema=schema or None, + schema=table.schema or None, autoload=True, autoload_with=engine, ) - def get_table_comment( - self, table_name: str, schema: str | None = None - ) -> str | None: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_table_comment(inspector, table_name, schema) - - def get_columns( - self, table_name: str, schema: str | None = None - ) -> list[ResultSetColumnType]: - with self.get_inspector_with_context() as inspector: + def get_table_comment(self, table: Table) -> str | None: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_table_comment(inspector, table) + + def get_columns(self, table: Table) -> list[ResultSetColumnType]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: return self.db_engine_spec.get_columns( - inspector, table_name, schema, self.schema_options + inspector, table, self.schema_options ) def get_metrics( self, - table_name: str, - schema: str | None = None, + table: Table, ) -> list[MetricType]: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) - - def get_indexes( - self, table_name: str, schema: str | None = None - ) -> list[dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) - - def get_pk_constraint( - self, table_name: str, 
schema: str | None = None - ) -> dict[str, Any]: - with self.get_inspector_with_context() as inspector: - pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_metrics(self, inspector, table) + + def get_indexes(self, table: Table) -> list[dict[str, Any]]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_indexes(self, inspector, table) + + def get_pk_constraint(self, table: Table) -> dict[str, Any]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + pk_constraint = inspector.get_pk_constraint(table.table, table.schema) or {} def _convert(value: Any) -> Any: try: @@ -889,11 +909,12 @@ def _convert(value: Any) -> Any: return {key: _convert(value) for key, value in pk_constraint.items()} - def get_foreign_keys( - self, table_name: str, schema: str | None = None - ) -> list[dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return inspector.get_foreign_keys(table_name, schema) + def get_foreign_keys(self, table: Table) -> list[dict[str, Any]]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return inspector.get_foreign_keys(table.table, table.schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, @@ -942,36 +963,22 @@ def get_perm(self) -> str: return self.perm # type: ignore def has_table(self, table: Table) -> bool: - with self.get_sqla_engine() as engine: - return engine.has_table(table.table_name, table.schema or None) + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + return engine.has_table(table.table, table.schema) - def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine() as engine: - return engine.has_table(table_name, schema) - - @classmethod - def _has_view( - cls, - conn: Connection, - dialect: Dialect, - view_name: str, - schema: str | None = None, - ) -> bool: - view_names: list[str] = [] - try: - view_names = dialect.get_view_names(connection=conn, schema=schema) - except Exception: # pylint: disable=broad-except - logger.warning("Has view failed", exc_info=True) - return view_name in view_names - - def has_view(self, view_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine(schema) as engine: - return engine.run_callable( - self._has_view, engine.dialect, view_name, schema - ) + def has_view(self, table: Table) -> bool: + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + connection = engine.connect() + try: + views = engine.dialect.get_view_names( + connection=connection, + schema=table.schema, + ) + except Exception: # pylint: disable=broad-except + logger.warning("Has view failed", exc_info=True) + views = [] - def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool: - return self.has_view(view_name=view_name, schema=schema) + return table.table in views def get_dialect(self) -> Dialect: sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted) diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index eba3acf36edd4..7f81081777538 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -37,6 +37,7 @@ from superset.daos.datasource import DatasourceDAO from superset.exceptions import 
SupersetException, SupersetSecurityException from superset.models.core import Database +from superset.sql_parse import Table from superset.superset_typing import FlaskResponse from superset.utils.core import DatasourceType from superset.views.base import ( @@ -180,8 +181,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: ) external_metadata = get_physical_table_metadata( database=database, - table_name=params["table_name"], - schema_name=params["schema_name"], + table=Table(params["table_name"], params["schema_name"]), normalize_columns=params.get("normalize_columns") or False, ) except (NoResultFound, NoSuchTableError) as ex: diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 5e004992de28c..8de9f7483bf26 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -1175,6 +1175,169 @@ def test_csv_upload_file_extension_valid( assert response.status_code == 200 +def test_table_metadata_happy_path( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint. + """ + database = mocker.MagicMock() + database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"} + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch("superset.databases.api.security_manager.raise_for_access") + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.json == {"hello": "world"} + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t"), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t&schema=s") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", "s"), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t&catalog=c") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", None, "c"), + ) + + response = client.get( + "/api/v1/database/1/table_metadata/?name=t&schema=s&catalog=c" + ) + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", "s", "c"), + ) + + +def test_table_metadata_no_table( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when no table name is passed. + """ + database = mocker.MagicMock() + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + + response = client.get("/api/v1/database/1/table_metadata/?schema=s&catalog=c") + assert response.status_code == 422 + assert response.json == { + "errors": [ + { + "message": "An error happened when validating the request", + "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR", + "level": "error", + "extra": { + "messages": {"name": ["Missing data for required field."]}, + "issue_codes": [ + { + "code": 1020, + "message": "Issue 1020 - The submitted payload has the incorrect schema.", + } + ], + }, + } + ] + } + + +def test_table_metadata_slashes( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint with names that have slashes. 
+ """ + database = mocker.MagicMock() + database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"} + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch("superset.databases.api.security_manager.raise_for_access") + + client.get("/api/v1/database/1/table_metadata/?name=foo/bar") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("foo/bar"), + ) + + +def test_table_metadata_invalid_database( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when the database is invalid. + """ + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=None) + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.status_code == 404 + assert response.json == { + "errors": [ + { + "message": "No such database", + "error_type": "DATABASE_NOT_FOUND_ERROR", + "level": "error", + "extra": { + "issue_codes": [ + { + "code": 1011, + "message": "Issue 1011 - Superset encountered an unexpected error.", + }, + { + "code": 1036, + "message": "Issue 1036 - The database was deleted.", + }, + ] + }, + } + ] + } + + +def test_table_metadata_unauthorized( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when the user is unauthorized. + """ + database = mocker.MagicMock() + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch( + "superset.databases.api.security_manager.raise_for_access", + side_effect=SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.TABLE_SECURITY_ACCESS_ERROR, + message="You don't have access to the table", + level=ErrorLevel.ERROR, + ) + ), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.status_code == 404 + assert response.json == { + "errors": [ + { + "message": "No such table", + "error_type": "TABLE_NOT_FOUND_ERROR", + "level": "error", + "extra": None, + } + ] + } + def test_table_extra_metadata_happy_path( mocker: MockFixture, client: Any, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index e17e0d2833db8..3bc05ee20eec0 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -232,9 +232,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec): sql = BaseEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, @@ -252,9 +251,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec): sql = NoLimitDBEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 3870297db8b70..18af750bf54a7 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -27,6 +27,7 @@ from sqlalchemy.sql import sqltypes from sqlalchemy_bigquery import BigQueryDialect +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -156,9 +157,8 @@ def test_select_star(mocker: MockFixture) -> None: 
sql = BigQueryEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py index d7dd19ad5923c..7e4ea68aa8013 100644 --- a/tests/unit_tests/db_engine_specs/test_db2.py +++ b/tests/unit_tests/db_engine_specs/test_db2.py @@ -18,6 +18,8 @@ import pytest from pytest_mock import MockerFixture +from superset.sql_parse import Table + def test_epoch_to_dttm() -> None: """ @@ -43,7 +45,7 @@ def test_get_table_comment(mocker: MockerFixture): } assert ( - Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") + Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema")) == "This is a table comment" ) @@ -59,7 +61,8 @@ def test_get_table_comment_empty(mocker: MockerFixture): mock_inspector.get_table_comment.return_value = {} assert ( - Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") == None + Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema")) + is None ) diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 8d57d4ed1a8c3..638b377c82709 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -24,6 +24,7 @@ from sqlalchemy import sql, text, types from sqlalchemy.engine.url import make_url +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -143,7 +144,10 @@ def test_where_latest_partition( expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" result = spec.where_latest_partition( - "table", mock.MagicMock(), mock.MagicMock(), query, columns + mock.MagicMock(), + Table("table"), + query, + columns, ) assert result is not None actual = result.compile( diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 7ce9199c5bd1d..a0a81980bf28a 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -442,7 +442,7 @@ def test_get_columns(mocker: MockerFixture): mock_inspector = mocker.MagicMock() mock_inspector.get_columns.return_value = sqla_columns - actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema") + actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) expected = [ ResultSetColumnType( name="field1", column_name="field1", type=field1_type, is_dttm=False @@ -475,7 +475,9 @@ def test_get_columns_expand_rows(mocker: MockerFixture): mock_inspector.get_columns.return_value = sqla_columns actual = TrinoEngineSpec.get_columns( - mock_inspector, "table", "schema", {"expand_rows": True} + mock_inspector, + Table("table", "schema"), + {"expand_rows": True}, ) expected = [ ResultSetColumnType( diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index beefd3ea3cc5a..ce3ad1822271f 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -18,7 +18,6 @@ # pylint: disable=import-outside-toplevel import json from datetime import datetime -from typing import Optional import pytest from pytest_mock import MockFixture @@ -26,6 +25,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn from 
superset.models.core import Database +from superset.sql_parse import Table def test_get_metrics(mocker: MockFixture) -> None: @@ -37,7 +37,7 @@ def test_get_metrics(mocker: MockFixture) -> None: from superset.models.core import Database database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - assert database.get_metrics("table") == [ + assert database.get_metrics(Table("table")) == [ { "expression": "COUNT(*)", "metric_name": "count", @@ -52,8 +52,7 @@ def get_metrics( cls, database: Database, inspector: Inspector, - table_name: str, - schema: Optional[str], + table: Table, ) -> list[MetricType]: return [ { @@ -65,7 +64,7 @@ def get_metrics( ] database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec) - assert database.get_metrics("table") == [ + assert database.get_metrics(Table("table")) == [ { "expression": "COUNT(DISTINCT user_id)", "metric_name": "count_distinct_user_id",
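
---

For reference, a minimal usage sketch of the refactored API, not part of the patch. The method names and the `(table, schema, catalog)` shape of `superset.sql_parse.Table` are taken from the hunks above; the `Database` instance and the example table names are illustrative placeholders:

from urllib.parse import urlencode

from superset.models.core import Database
from superset.sql_parse import Table


def describe(database: Database) -> None:
    # Hedged sketch: `database` is assumed to be an existing, configured
    # Database instance. Table is a (table, schema, catalog) tuple whose
    # schema and catalog default to None, matching Table("t"),
    # Table("t", "s"), and Table("t", "s", "c") in the tests above.
    table = Table("my_table", "my_schema")

    columns = database.get_columns(table)        # was get_columns(name, schema)
    comment = database.get_table_comment(table)  # was get_table_comment(name, schema)
    exists = database.has_table(table)           # also subsumes has_table_by_name
    is_view = database.has_view(table)           # also subsumes has_view_by_name
    print(columns, comment, exists, is_view)


# The metadata endpoint now takes the table name as a query parameter rather
# than a path segment, which is why names containing slashes round-trip (see
# test_table_metadata_slashes above). The test passes the name unencoded and
# it works, but urlencode keeps arbitrary names safe:
url = f"/api/v1/database/1/table_metadata/?{urlencode({'name': 'foo/bar'})}"

Passing the whole `Table` object through, instead of separate `table_name`/`schema` arguments, is what lets the new optional `catalog` field thread through `get_inspector` and the engine specs without touching every call site again.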