diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
index ecf2c4d7e299c..7a7f2f72a4c71 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
@@ -508,10 +508,10 @@ describe('async actions', () => {
fetchMock.delete(updateTableSchemaEndpoint, {});
fetchMock.post(updateTableSchemaEndpoint, JSON.stringify({ id: 1 }));
- const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+ const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
fetchMock.get(getTableMetadataEndpoint, {});
const getExtraTableMetadataEndpoint =
- 'glob:**/api/v1/database/*/table_metadata/extra/';
+ 'glob:**/api/v1/database/*/table_metadata/extra/*';
fetchMock.get(getExtraTableMetadataEndpoint, {});
let isFeatureEnabledMock;
diff --git a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
index f8c94468bf7f3..b5003b16f7b47 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
@@ -61,13 +61,13 @@ beforeEach(() => {
},
],
});
- fetchMock.get('glob:*/api/v1/database/*/table/*/*', {
+ fetchMock.get('glob:*/api/v1/database/*/table_metadata/*', {
status: 200,
body: {
columns: table.columns,
},
});
- fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/', {
+ fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/*', {
status: 200,
body: {},
});
diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
index a2fe88020aa45..7b95b6d0f3492 100644
--- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
+++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
@@ -47,7 +47,7 @@ jest.mock(
{column.name}
),
);
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
const getExtraTableMetadataEndpoint =
'glob:**/api/v1/database/*/table_metadata/extra/*';
const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';
diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
index ef5797fb309c2..c964fc32faaf0 100644
--- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
+++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
@@ -74,7 +74,9 @@ const DatasetPanelWrapper = ({
const { dbId, tableName, schema } = props;
setLoading(true);
setHasColumns?.(false);
- const path = `/api/v1/database/${dbId}/table/${tableName}/${schema}/`;
+ const path = schema
+ ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
+ : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
try {
const response = await SupersetClient.get({
endpoint: path,
diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts
index 164fe0f0ab19c..41be4c167c9c8 100644
--- a/superset-frontend/src/hooks/apiResources/tables.ts
+++ b/superset-frontend/src/hooks/apiResources/tables.ts
@@ -114,9 +114,9 @@ const tableApi = api.injectEndpoints({
}),
tableMetadata: builder.query({
query: ({ dbId, schema, table }) => ({
- endpoint: `/api/v1/database/${dbId}/table/${encodeURIComponent(
- table,
- )}/${encodeURIComponent(schema)}/`,
+ endpoint: schema
+ ? `/api/v1/database/${dbId}/table_metadata/?name=${table}&schema=${schema}`
+ : `/api/v1/database/${dbId}/table_metadata/?name=${table}`,
transformResponse: ({ json }: TableMetadataReponse) => json,
}),
}),
diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index 16b87a567a5f0..d4b14445687cd 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -34,6 +34,7 @@
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -80,7 +81,10 @@ def validate(self) -> None:
if (
database
and not sql
- and not DatasetDAO.validate_table_exists(database, table_name, schema)
+ and not DatasetDAO.validate_table_exists(
+ database,
+ Table(table_name, schema),
+ )
):
exceptions.append(TableNotFoundValidationError(table_name))
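Note: the `Table` value object from `superset.sql_parse` that this command now passes around takes its fields in `(table, schema, catalog)` order, matching the `Table("t", "s", "c")` calls in the tests further down in this diff. A minimal sketch with illustrative names:

```python
# Hedged sketch: Table is a small value object following
# [[catalog.]schema.]table; only the table name is required.
from superset.sql_parse import Table

table = Table("orders", "public")              # schema-qualified, no catalog
qualified = Table("orders", "public", "hive")  # with an explicit catalog
assert table.catalog is None
assert qualified.catalog == "hive"
```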
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index afd2791c9c601..363d64aa14880 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -49,7 +49,7 @@
Integer,
or_,
String,
- Table,
+ Table as DBTable,
Text,
update,
)
@@ -108,7 +108,7 @@
validate_adhoc_subquery,
)
from superset.models.slice import Slice
-from superset.sql_parse import ParsedQuery, sanitize_clause
+from superset.sql_parse import ParsedQuery, sanitize_clause, Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@@ -1068,7 +1068,7 @@ def data(self) -> dict[str, Any]:
return {s: getattr(self, s) for s in attrs}
-sqlatable_user = Table(
+sqlatable_user = DBTable(
"sqlatable_user",
metadata,
Column("id", Integer, primary_key=True),
@@ -1146,6 +1146,7 @@ class SqlaTable(
foreign_keys=[database_id],
)
schema = Column(String(255))
+ catalog = Column(String(256), nullable=True, default=None)
sql = Column(MediumText())
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
@@ -1322,8 +1323,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database,
- table_name=self.table_name,
- schema_name=self.schema,
+ table=Table(self.table_name, self.schema, self.catalog),
normalize_columns=self.normalize_columns,
)
@@ -1339,7 +1339,9 @@ def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
- self.table_name, schema=self.schema, show_cols=False, latest_partition=False
+ Table(self.table_name, self.schema, self.catalog),
+ show_cols=False,
+ latest_partition=False,
)
@property
@@ -1779,7 +1781,13 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
)
def get_sqla_table_object(self) -> Table:
- return self.database.get_table(self.table_name, schema=self.schema)
+ return self.database.get_table(
+ Table(
+ self.table_name,
+ self.schema,
+ self.catalog,
+ )
+ )
def fetch_metadata(self, commit: bool = True) -> MetadataResult:
"""
@@ -1791,7 +1799,13 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
new_columns = self.external_metadata()
metrics = [
SqlMetric(**metric)
- for metric in self.database.get_metrics(self.table_name, self.schema)
+ for metric in self.database.get_metrics(
+ Table(
+ self.table_name,
+ self.schema,
+ self.catalog,
+ )
+ )
]
any_date_col = None
db_engine_spec = self.db_engine_spec
@@ -2038,7 +2052,7 @@ def load_database(self: SqlaTable) -> None:
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
-RLSFilterRoles = Table(
+RLSFilterRoles = DBTable(
"rls_filter_roles",
metadata,
Column("id", Integer, primary_key=True),
@@ -2046,7 +2060,7 @@ def load_database(self: SqlaTable) -> None:
Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)
-RLSFilterTables = Table(
+RLSFilterTables = DBTable(
"rls_filter_tables",
metadata,
Column("id", Integer, primary_key=True),
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 4bc11aee42d80..f547e238dfcc8 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -38,7 +38,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
@@ -47,24 +47,18 @@
def get_physical_table_metadata(
database: Database,
- table_name: str,
+ table: Table,
normalize_columns: bool,
- schema_name: str | None = None,
) -> list[ResultSetColumnType]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()
- # ensure empty schema
- _schema_name = schema_name if schema_name else None
- # Table does not exist or is not visible to a connection.
- if not (
- database.has_table_by_name(table_name=table_name, schema=_schema_name)
- or database.has_view_by_name(view_name=table_name, schema=_schema_name)
- ):
- raise NoSuchTableError
+ # Table does not exist or is not visible to a connection.
+ if not (database.has_table(table) or database.has_view(table)):
+ raise NoSuchTableError(table)
- cols = database.get_columns(table_name, schema=_schema_name)
+ cols = database.get_columns(table)
for col in cols:
try:
if isinstance(col["type"], TypeEngine):
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 4647e02ce68af..c9c16beaa69d0 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -30,6 +30,7 @@
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from superset.views.base import DatasourceFilter
@@ -72,13 +73,14 @@ def get_related_objects(database_id: int) -> dict[str, Any]:
@staticmethod
def validate_table_exists(
- database: Database, table_name: str, schema: str | None
+ database: Database,
+ table: Table,
) -> bool:
try:
- database.get_table(table_name, schema=schema)
+ database.get_table(table)
return True
except SQLAlchemyError as ex: # pragma: no cover
- logger.warning("Got an error %s validating table: %s", str(ex), table_name)
+ logger.warning("Got an error %s validating table: %s", str(ex), table)
return False
@staticmethod
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 4537fc8bc512f..1b5f515ea69bb 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -133,6 +133,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
RouteMethod.RELATED,
"tables",
"table_metadata",
+ "table_metadata_deprecated",
"table_extra_metadata",
"table_extra_metadata_deprecated",
"select_star",
@@ -717,10 +718,10 @@ def tables(self, pk: int, **kwargs: Any) -> FlaskResponse:
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
- f".table_metadata",
+ f".table_metadata_deprecated",
log_to_statsd=False,
)
- def table_metadata(
+ def table_metadata_deprecated(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
"""Get database table metadata.
@@ -761,16 +762,16 @@ def table_metadata(
500:
$ref: '#/components/responses/500'
"""
- self.incr_stats("init", self.table_metadata.__name__)
+ self.incr_stats("init", self.table_metadata_deprecated.__name__)
try:
- table_info = get_table_metadata(database, table_name, schema_name)
+ table_info = get_table_metadata(database, Table(table_name, schema_name))
except SQLAlchemyError as ex:
- self.incr_stats("error", self.table_metadata.__name__)
+ self.incr_stats("error", self.table_metadata_deprecated.__name__)
return self.response_422(error_msg_from_exception(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
- self.incr_stats("success", self.table_metadata.__name__)
+ self.incr_stats("success", self.table_metadata_deprecated.__name__)
return self.response(200, **table_info)
@expose("//table_extra///", methods=("GET",))
@@ -839,7 +840,86 @@ def table_extra_metadata_deprecated(
payload = database.db_engine_spec.get_extra_table_metadata(database, table)
return self.response(200, **payload)
- @expose("//table_metadata/extra/", methods=("GET",))
+ @expose("//table_metadata/", methods=["GET"])
+ @protect()
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+ f".table_metadata",
+ log_to_statsd=False,
+ )
+ def table_metadata(self, pk: int) -> FlaskResponse:
+ """
+ Get metadata for a given table.
+
+ Optionally, a schema and a catalog can be passed, if different from the default
+ ones.
+ ---
+ get:
+ summary: Get table metadata
+ description: >-
+ Metadata associated with the table (columns, indexes, etc.)
+ parameters:
+ - in: path
+ schema:
+ type: integer
+ name: pk
+ description: The database id
+ - in: query
+ schema:
+ type: string
+ name: name
+ required: true
+ description: Table name
+ - in: query
+ schema:
+ type: string
+ name: schema
+ description: >-
+ Optional table schema, if not passed default schema will be used
+ - in: query
+ schema:
+ type: string
+ name: catalog
+ description: >-
+ Optional table catalog, if not passed default catalog will be used
+ responses:
+ 200:
+ description: Table metadata information
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/TableExtraMetadataResponseSchema"
+ 401:
+ $ref: '#/components/responses/401'
+ 404:
+ $ref: '#/components/responses/404'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ self.incr_stats("init", self.table_metadata.__name__)
+
+ database = DatabaseDAO.find_by_id(pk)
+ if database is None:
+ raise DatabaseNotFoundException("No such database")
+
+ try:
+ parameters = QualifiedTableSchema().load(request.args)
+ except ValidationError as ex:
+ raise InvalidPayloadSchemaError(ex) from ex
+
+ table = Table(parameters["name"], parameters["schema"], parameters["catalog"])
+ try:
+ security_manager.raise_for_access(database=database, table=table)
+ except SupersetSecurityException as ex:
+ # instead of raising 403, raise 404 to hide table existence
+ raise TableNotFoundException("No such table") from ex
+
+ payload = database.db_engine_spec.get_table_metadata(database, table)
+
+ return self.response(200, **payload)
+
+ @expose("//table_metadata/extra/", methods=["GET"])
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
@@ -973,7 +1053,8 @@ def select_star(
self.incr_stats("init", self.select_star.__name__)
try:
result = database.select_star(
- table_name, schema_name, latest_partition=True
+ Table(table_name, schema_name),
+ latest_partition=True,
)
except NoSuchTableError:
self.incr_stats("error", self.select_star.__name__)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index dccab980bcd18..47f4bf15fac0f 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -17,11 +17,13 @@
# pylint: disable=unused-argument, too-many-lines
+from __future__ import annotations
+
import inspect
import json
import os
import re
-from typing import Any
+from typing import Any, TypedDict
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -581,6 +583,49 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
+class TableMetadataOptionsResponse(TypedDict):
+ deferrable: bool
+ initially: bool
+ match: bool
+ ondelete: bool
+ onupdate: bool
+
+
+class TableMetadataColumnsResponse(TypedDict, total=False):
+ keys: list[str]
+ longType: str
+ name: str
+ type: str
+ duplicates_constraint: str | None
+ comment: str | None
+
+
+class TableMetadataForeignKeysIndexesResponse(TypedDict):
+ column_names: list[str]
+ name: str
+ options: TableMetadataOptionsResponse
+ referred_columns: list[str]
+ referred_schema: str
+ referred_table: str
+ type: str
+
+
+class TableMetadataPrimaryKeyResponse(TypedDict):
+ column_names: list[str]
+ name: str
+ type: str
+
+
+class TableMetadataResponse(TypedDict):
+ name: str
+ columns: list[TableMetadataColumnsResponse]
+ foreignKeys: list[TableMetadataForeignKeysIndexesResponse]
+ indexes: list[TableMetadataForeignKeysIndexesResponse]
+ primaryKey: TableMetadataPrimaryKeyResponse
+ selectStar: str
+ comment: str | None
+
+
class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()
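Note: `TableMetadataResponse` is a total `TypedDict` (only `TableMetadataColumnsResponse` is `total=False`), so a payload must carry all seven keys. An illustrative value that type-checks under the definitions above:

```python
from superset.databases.schemas import TableMetadataResponse

payload: TableMetadataResponse = {
    "name": "orders",
    "columns": [{"name": "id", "type": "INTEGER", "longType": "INTEGER", "keys": []}],
    "foreignKeys": [],
    "indexes": [],
    "primaryKey": {"column_names": ["id"], "name": "orders_pkey", "type": "pk"},
    "selectStar": "SELECT *\nFROM orders\nLIMIT 100",
    "comment": None,
}
```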
diff --git a/superset/databases/utils.py b/superset/databases/utils.py
index 8de4bb6f2353d..dfd75eb2233f4 100644
--- a/superset/databases/utils.py
+++ b/superset/databases/utils.py
@@ -14,19 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Optional, Union
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
from sqlalchemy.engine.url import make_url, URL
from superset.commands.database.exceptions import DatabaseInvalidError
+from superset.sql_parse import Table
+
+if TYPE_CHECKING:
+ from superset.databases.schemas import (
+ TableMetadataColumnsResponse,
+ TableMetadataForeignKeysIndexesResponse,
+ TableMetadataResponse,
+ )
def get_foreign_keys_metadata(
database: Any,
- table_name: str,
- schema_name: Optional[str],
-) -> list[dict[str, Any]]:
- foreign_keys = database.get_foreign_keys(table_name, schema_name)
+ table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+ foreign_keys = database.get_foreign_keys(table)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
fk["type"] = "fk"
@@ -34,9 +44,10 @@ def get_foreign_keys_metadata(
def get_indexes_metadata(
- database: Any, table_name: str, schema_name: Optional[str]
-) -> list[dict[str, Any]]:
- indexes = database.get_indexes(table_name, schema_name)
+ database: Any,
+ table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+ indexes = database.get_indexes(table)
for idx in indexes:
idx["type"] = "index"
return indexes
@@ -51,30 +62,27 @@ def get_col_type(col: dict[Any, Any]) -> str:
return dtype
-def get_table_metadata(
- database: Any, table_name: str, schema_name: Optional[str]
-) -> dict[str, Any]:
+def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
:param database: The database model
- :param table_name: Table name
- :param schema_name: schema name
+ :param table: Table instance
:return: Dict table metadata ready for API response
"""
keys = []
- columns = database.get_columns(table_name, schema_name)
- primary_key = database.get_pk_constraint(table_name, schema_name)
+ columns = database.get_columns(table)
+ primary_key = database.get_pk_constraint(table)
if primary_key and primary_key.get("constrained_columns"):
primary_key["column_names"] = primary_key.pop("constrained_columns")
primary_key["type"] = "pk"
keys += [primary_key]
- foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
- indexes = get_indexes_metadata(database, table_name, schema_name)
+ foreign_keys = get_foreign_keys_metadata(database, table)
+ indexes = get_indexes_metadata(database, table)
keys += foreign_keys + indexes
- payload_columns: list[dict[str, Any]] = []
- table_comment = database.get_table_comment(table_name, schema_name)
+ payload_columns: list[TableMetadataColumnsResponse] = []
+ table_comment = database.get_table_comment(table)
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
@@ -87,11 +95,10 @@ def get_table_metadata(
}
)
return {
- "name": table_name,
+ "name": table.table,
"columns": payload_columns,
"selectStar": database.select_star(
- table_name,
- schema=schema_name,
+ table,
indent=True,
cols=columns,
latest_partition=True,
@@ -103,7 +110,7 @@ def get_table_metadata(
}
-def make_url_safe(raw_url: Union[str, URL]) -> URL:
+def make_url_safe(raw_url: str | URL) -> URL:
"""
Wrapper for SQLAlchemy's make_url(), which tends to raise too detailed of
errors, which inevitably find their way into server logs. ArgumentErrors
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 451e96927d9fc..63756a221f958 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -61,7 +61,7 @@
from superset import security_manager, sql_parse
from superset.constants import TimeGrain as TimeGrainConstants
-from superset.databases.utils import make_url_safe
+from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
@@ -80,6 +80,7 @@
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
+ from superset.databases.schemas import TableMetadataResponse
from superset.models.core import Database
from superset.models.sql_lab import Query
@@ -1034,6 +1035,21 @@ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]
"""
return indexes
+ @classmethod
+ def get_table_metadata( # pylint: disable=unused-argument
+ cls,
+ database: Database,
+ table: Table,
+ ) -> TableMetadataResponse:
+ """
+ Returns basic table metadata
+
+ :param database: Database instance
+ :param table: A Table instance
+ :return: Basic table metadata
+ """
+ return get_table_metadata(database, table)
+
@classmethod
def get_extra_table_metadata( # pylint: disable=unused-argument
cls,
@@ -1472,36 +1488,34 @@ def get_indexes(
cls,
database: Database, # pylint: disable=unused-argument
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
- return inspector.get_indexes(table_name, schema)
+ return inspector.get_indexes(table.table, table.schema)
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: str | None
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> str | None:
"""
Get comment of table from a given schema and table
-
:param inspector: SqlAlchemy Inspector instance
- :param table_name: Table name
- :param schema: Schema name. If omitted, uses default schema for database
+ :param table: Table instance
:return: comment of table
"""
comment = None
try:
- comment = inspector.get_table_comment(table_name, schema)
+ comment = inspector.get_table_comment(table.table, table.schema)
comment = comment.get("text") if isinstance(comment, dict) else None
except NotImplementedError:
# It's expected that some dialects don't implement the comment method
@@ -1515,22 +1529,25 @@ def get_table_comment(
def get_columns( # pylint: disable=unused-argument
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
- Get all columns from a given schema and table
+ Get all columns from a given schema and table.
+
+ The inspector will be bound to a catalog, if one was specified.
:param inspector: SqlAlchemy Inspector instance
- :param table_name: Table name
- :param schema: Schema name. If omitted, uses default schema for database
+ :param table: Table instance
:param options: Extra options to customise the display of columns in
some databases
:return: All columns in table
"""
return convert_inspector_columns(
- cast(list[SQLAColumnType], inspector.get_columns(table_name, schema))
+ cast(
+ list[SQLAColumnType],
+ inspector.get_columns(table.table, table.schema),
+ )
)
@classmethod
@@ -1538,8 +1555,7 @@ def get_metrics( # pylint: disable=unused-argument
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[MetricType]:
"""
Get all metrics from a given schema and table.
@@ -1556,17 +1572,15 @@ def get_metrics( # pylint: disable=unused-argument
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
"""
Add a where clause to a query to reference only the most recent partition
- :param table_name: Table name
- :param schema: Schema name
+ :param table: Table instance
:param database: Database instance
:param query: SqlAlchemy query
:param columns: List of TableColumns
@@ -1589,9 +1603,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
def select_star( # pylint: disable=too-many-arguments,too-many-locals
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -1604,9 +1617,8 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
WARNING: expects only unquoted table and schema names.
:param database: Database instance
- :param table_name: Table name, unquoted
+ :param table: Table instance
:param engine: SqlAlchemy Engine instance
- :param schema: Schema, unquoted
:param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query
@@ -1618,16 +1630,18 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
fields: str | list[Any] = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
- cols = database.get_columns(table_name, schema)
+ cols = database.get_columns(table)
if show_cols:
fields = cls._get_fields(cols)
+
quote = engine.dialect.identifier_preparer.quote
quote_schema = engine.dialect.identifier_preparer.quote_schema
- if schema:
- full_table_name = quote_schema(schema) + "." + quote(table_name)
- else:
- full_table_name = quote(table_name)
+ full_table_name = (
+ quote_schema(table.schema) + "." + quote(table.table)
+ if table.schema
+ else quote(table.table)
+ )
qry = select(fields).select_from(text(full_table_name))
@@ -1635,7 +1649,10 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
qry = qry.limit(limit)
if latest_partition:
partition_query = cls.where_latest_partition(
- table_name, schema, database, qry, columns=cols
+ database,
+ table,
+ qry,
+ columns=cols,
)
if partition_query is not None:
qry = partition_query
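Note: because `get_table_metadata` is a classmethod on the spec, third-party engine specs can override it while reusing the default payload. A hypothetical override (the `MyEngineSpec` name and the comment tweak are invented for illustration):

```python
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
from superset.sql_parse import Table


class MyEngineSpec(BaseEngineSpec):  # hypothetical third-party spec
    engine = "myengine"

    @classmethod
    def get_table_metadata(cls, database: Database, table: Table):
        # Reuse the default payload, then augment it with engine-specific info.
        payload = super().get_table_metadata(database, table)
        payload["comment"] = payload.get("comment") or "managed table"
        return payload
```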
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 20eae4f9336ce..876babf43d290 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -304,20 +304,18 @@ def get_indexes(
cls,
database: "Database",
inspector: Inspector,
- table_name: str,
- schema: Optional[str],
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
- return cls.normalize_indexes(inspector.get_indexes(table_name, schema))
+ return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
@classmethod
def get_extra_table_metadata(
@@ -325,7 +323,7 @@ def get_extra_table_metadata(
database: "Database",
table: Table,
) -> dict[str, Any]:
- indexes = database.get_indexes(table.table, table.schema)
+ indexes = database.get_indexes(table)
if not indexes:
return {}
partitions_columns = [
@@ -629,9 +627,8 @@ def parameters_json_schema(cls) -> Any:
def select_star( # pylint: disable=too-many-arguments
cls,
database: "Database",
- table_name: str,
+ table: Table,
engine: Engine,
- schema: Optional[str] = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -690,9 +687,8 @@ def select_star( # pylint: disable=too-many-arguments
return super().select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py
index db2e500b53d8f..8a04ee5d3b0f2 100644
--- a/superset/db_engine_specs/db2.py
+++ b/superset/db_engine_specs/db2.py
@@ -21,6 +21,7 @@
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -64,7 +65,9 @@ def epoch_to_dttm(cls) -> str:
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: Union[str, None]
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> Optional[str]:
"""
Get comment of table from a given schema
@@ -78,7 +81,7 @@ def get_table_comment(
"""
comment = None
try:
- table_comment = inspector.get_table_comment(table_name, schema)
+ table_comment = inspector.get_table_comment(table.table, table.schema)
comment = table_comment.get("text")
return comment[0]
except IndexError:
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 2655ed6c9af6e..9e6639337482f 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -412,24 +412,24 @@ def handle_cursor( # pylint: disable=too-many-locals
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
- return BaseEngineSpec.get_columns(inspector, table_name, schema, options)
+ return BaseEngineSpec.get_columns(inspector, table, options)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
- table_name, schema, database, show_first=True
+ database,
+ table,
+ show_first=True,
)
except Exception: # pylint: disable=broad-except
# table is not partitioned
@@ -449,7 +449,10 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
@classmethod
def latest_sub_partition( # type: ignore
- cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
+ cls,
+ database: Database,
+ table: Table,
+ **kwargs: Any,
) -> str:
# TODO(bogdan): implement`
pass
@@ -467,24 +470,24 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
@classmethod
def _partition_query( # pylint: disable=too-many-arguments
cls,
- table_name: str,
- schema: str | None,
+ table: Table,
indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
order_by: list[tuple[str, bool]] | None = None,
filters: dict[Any, Any] | None = None,
) -> str:
- full_table_name = f"{schema}.{table_name}" if schema else table_name
+ full_table_name = (
+ f"{table.schema}.{table.table}" if table.schema else table.table
+ )
return f"SHOW PARTITIONS {full_table_name}"
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -493,9 +496,8 @@ def select_star( # pylint: disable=too-many-arguments
) -> str:
return super(PrestoEngineSpec, cls).select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
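Note: the Hive `_partition_query` change is purely mechanical and the emitted SQL is unchanged; a quick check of the qualification logic under the new signature:

```python
from superset.sql_parse import Table

table = Table("logs", "analytics")  # illustrative names
full_table_name = f"{table.schema}.{table.table}" if table.schema else table.table
assert f"SHOW PARTITIONS {full_table_name}" == "SHOW PARTITIONS analytics.logs"
```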
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index a8143c87edbb6..b46d8132f5836 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -420,8 +420,7 @@ def get_function_names(cls, database: Database) -> list[str]:
@classmethod
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
cls,
- table_name: str,
- schema: str | None,
+ table: Table,
indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
@@ -434,8 +433,7 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus
Note the unused arguments are exposed for sub-classing purposes where custom
integrations may require the schema, indexes, etc. to build the partition query.
- :param table_name: the name of the table to get partitions from
- :param schema: the schema name
+ :param table: the table instance
:param indexes: the indexes associated with the table
:param database: the database the query will be run against
:param limit: the number of partitions to be returned
@@ -464,12 +462,16 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus
presto_version = database.get_extra().get("version")
if presto_version and Version(presto_version) < Version("0.199"):
- full_table_name = f"{schema}.{table_name}" if schema else table_name
+ full_table_name = (
+ f"{table.schema}.{table.table}" if table.schema else table.table
+ )
partition_select_clause = f"SHOW PARTITIONS FROM {full_table_name}"
else:
- system_table_name = f'"{table_name}$partitions"'
+ system_table_name = f'"{table.table}$partitions"'
full_table_name = (
- f"{schema}.{system_table_name}" if schema else system_table_name
+ f"{table.schema}.{system_table_name}"
+ if table.schema
+ else system_table_name
)
partition_select_clause = f"SELECT * FROM {full_table_name}"
@@ -486,16 +488,13 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
- col_names, values = cls.latest_partition(
- table_name, schema, database, show_first=True
- )
+ col_names, values = cls.latest_partition(database, table, show_first=True)
except Exception: # pylint: disable=broad-except
# table is not partitioned
return None
@@ -529,16 +528,14 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
@cache_manager.data_cache.memoize(timeout=60)
def latest_partition( # pylint: disable=too-many-arguments
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
show_first: bool = False,
indexes: list[dict[str, Any]] | None = None,
) -> tuple[list[str], list[str] | None]:
"""Returns col name and the latest (max) partition value for a table
- :param table_name: the name of the table
- :param schema: schema / database / namespace
+ :param table: the table instance
:param database: database query will be run against
:type database: models.Database
:param show_first: displays the value for the first partitioning key
@@ -550,11 +547,11 @@ def latest_partition( # pylint: disable=too-many-arguments
(['ds'], ('2018-01-01',))
"""
if indexes is None:
- indexes = database.get_indexes(table_name, schema)
+ indexes = database.get_indexes(table)
if not indexes:
raise SupersetTemplateException(
- f"Error getting partition for {schema}.{table_name}. "
+ f"Error getting partition for {table}. "
"Verify that this table has a partition."
)
@@ -575,20 +572,23 @@ def latest_partition( # pylint: disable=too-many-arguments
return column_names, cls._latest_partition_from_df(
df=database.get_df(
sql=cls._partition_query(
- table_name,
- schema,
+ table,
indexes,
database,
limit=1,
order_by=[(column_name, True) for column_name in column_names],
),
- schema=schema,
+ catalog=table.catalog,
+ schema=table.schema,
)
)
@classmethod
def latest_sub_partition(
- cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
+ cls,
+ database: Database,
+ table: Table,
+ **kwargs: Any,
) -> Any:
"""Returns the latest (max) partition value for a table
@@ -601,12 +601,9 @@ def latest_sub_partition(
``latest_sub_partition('my_table',
event_category='page', event_type='click')``
- :param table_name: the name of the table, can be just the table
- name or a fully qualified table name as ``schema_name.table_name``
- :type table_name: str
- :param schema: schema / database / namespace
- :type schema: str
:param database: database query will be run against
+ :param table: the table instance
+ :type table: Table
:type database: models.Database
:param kwargs: keyword arguments define the filtering criteria
@@ -615,7 +612,7 @@ def latest_sub_partition(
>>> latest_sub_partition('sub_partition_table', event_type='click')
'2018-01-01'
"""
- indexes = database.get_indexes(table_name, schema)
+ indexes = database.get_indexes(table)
part_fields = indexes[0]["column_names"]
for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary
if k not in k in part_fields: # pylint: disable=comparison-with-itself
@@ -633,15 +630,14 @@ def latest_sub_partition(
field_to_return = field
sql = cls._partition_query(
- table_name,
- schema,
+ table,
indexes,
database,
limit=1,
order_by=[(field_to_return, True)],
filters=kwargs,
)
- df = database.get_df(sql, schema)
+ df = database.get_df(sql, table.catalog, table.schema)
if df.empty:
return ""
return df.to_dict()[field_to_return][0]
@@ -966,7 +962,9 @@ def _parse_structural_column( # pylint: disable=too-many-locals
@classmethod
def _show_columns(
- cls, inspector: Inspector, table_name: str, schema: str | None
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> list[ResultRow]:
"""
Show presto column names
@@ -976,17 +974,16 @@ def _show_columns(
:return: list of column objects
"""
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
- full_table = quote(table_name)
- if schema:
- full_table = f"{quote(schema)}.{full_table}"
+ full_table = quote(table.table)
+ if table.schema:
+ full_table = f"{quote(table.schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
@@ -999,7 +996,7 @@ def get_columns(
:return: a list of results that contain column info
(i.e. column name and data type)
"""
- columns = cls._show_columns(inspector, table_name, schema)
+ columns = cls._show_columns(inspector, table)
result: list[ResultSetColumnType] = []
for column in columns:
# parse column if it is a row or array
@@ -1077,9 +1074,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
def select_star( # pylint: disable=too-many-arguments
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -1102,9 +1098,8 @@ def select_star( # pylint: disable=too-many-arguments
]
return super().select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
@@ -1232,11 +1227,10 @@ def get_extra_table_metadata(
) -> dict[str, Any]:
metadata = {}
- if indexes := database.get_indexes(table.table, table.schema):
+ if indexes := database.get_indexes(table):
col_names, latest_parts = cls.latest_partition(
- table.table,
- table.schema,
database,
+ table,
show_first=True,
indexes=indexes,
)
@@ -1248,8 +1242,7 @@ def get_extra_table_metadata(
"cols": sorted(indexes[0].get("column_names", [])),
"latest": dict(zip(col_names, latest_parts)),
"partitionQuery": cls._partition_query(
- table_name=table.table,
- schema=table.schema,
+ table=table,
indexes=indexes,
database=database,
),
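Note: `latest_partition`'s arguments were reordered as well as collapsed, so any external callers (including Jinja macros that proxy to it) must move `database` first. A hedged sketch of the new call shape, assuming a live `Database` instance:

```python
from __future__ import annotations

from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.core import Database
from superset.sql_parse import Table


def newest_partition(database: Database) -> tuple[list[str], list[str] | None]:
    # Database now comes first; name, schema, and catalog travel as one object.
    return PrestoEngineSpec.latest_partition(
        database,
        Table("logs", "analytics"),  # illustrative names
        show_first=True,
    )
```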
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 32c119ec1f645..aa151129a078e 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -66,8 +66,7 @@ def get_extra_table_metadata(
) -> dict[str, Any]:
metadata = {}
- if indexes := database.get_indexes(table.table, table.schema):
+ if indexes := database.get_indexes(table):
col_names, latest_parts = cls.latest_partition(
- table.table,
- table.schema,
database,
+ table,
@@ -91,8 +91,7 @@ def get_extra_table_metadata(
),
"latest": dict(zip(col_names, latest_parts)),
"partitionQuery": cls._partition_query(
- table_name=table.table,
- schema=table.schema,
+ table=table,
indexes=indexes,
database=database,
),
@@ -414,8 +413,7 @@ def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]:
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
@@ -423,7 +421,7 @@ def get_columns(
"schema_options", expand the schema definition out to show all
subfields of nested ROWs as their appropriate dotted paths.
"""
- base_cols = super().get_columns(inspector, table_name, schema, options)
+ base_cols = super().get_columns(inspector, table, options)
if not (options or {}).get("expand_rows"):
return base_cols
@@ -434,8 +432,7 @@ def get_indexes(
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -444,11 +441,10 @@ def get_indexes(
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
try:
- return super().get_indexes(database, inspector, table_name, schema)
+ return super().get_indexes(database, inspector, table)
except NoSuchTableError:
return []
diff --git a/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py b/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py
new file mode 100644
index 0000000000000..31c7137398717
--- /dev/null
+++ b/superset/migrations/versions/2024-04-17_18-09_c15655337636_add_catalog_column.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add catalog column
+
+Revision ID: c15655337636
+Revises: 0dc386701747
+Create Date: 2024-04-17 18:09:36.795529
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "c15655337636"
+down_revision = "0dc386701747"
+
+import sqlalchemy as sa
+from alembic import op
+
+
+def upgrade():
+ op.add_column("tables", sa.Column("catalog", sa.String(length=256), nullable=True))
+
+
+def downgrade():
+ op.drop_column("tables", "catalog")
diff --git a/superset/models/core.py b/superset/models/core.py
index bfd4c39593392..939d3a80b2722 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -45,7 +45,7 @@
Integer,
MetaData,
String,
- Table,
+ Table as SqlaTable,
Text,
)
from sqlalchemy.engine import Connection, Dialect, Engine
@@ -71,6 +71,7 @@
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
+from superset.sql_parse import Table
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
@@ -382,13 +383,22 @@ def get_effective_user(self, object_url: URL) -> str | None:
)
@contextmanager
- def get_sqla_engine(
+ def get_sqla_engine( # pylint: disable=too-many-arguments
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
override_ssh_tunnel: SSHTunnel | None = None,
) -> Engine:
+ """
+ Context manager for a SQLAlchemy engine.
+
+ This method will return a context manager for a SQLAlchemy engine. Using the
+ context manager (as opposed to the engine directly) is important because we need
+ to potentially establish SSH tunnels before the connection is created, and clean
+ them up once the engine is no longer used.
+ """
from superset.daos.database import ( # pylint: disable=import-outside-toplevel
DatabaseDAO,
)
@@ -403,7 +413,7 @@ def get_sqla_engine(
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
- sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
+ sqlalchemy_database_uri=sqlalchemy_uri,
)
with engine_context as server_context:
@@ -415,22 +425,21 @@ def get_sqla_engine(
server_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
- sqlalchemy_uri, server_context
+ sqlalchemy_uri,
+ server_context,
)
+
yield self._get_sqla_engine(
+ catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
- # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a
- # reference to the old method to prevent breaking third-party applications.
- # TODO (betodealmeida): Remove in 5.0
- get_sqla_engine_with_context = get_sqla_engine
-
def _get_sqla_engine(
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
@@ -447,26 +456,10 @@ def _get_sqla_engine(
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})
- # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
- # had its signature changed in order to support more DB engine specs. Since DB
- # engine specs can be released as 3rd party modules we want to make sure the old
- # method is still supported so we don't introduce a breaking change.
- if hasattr(self.db_engine_spec, "adjust_database_uri"):
- sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
- sqlalchemy_url,
- schema,
- )
- logger.warning(
- "DB engine spec %s implements the method `adjust_database_uri`, which is "
- "deprecated and will be removed in version 3.0. Please update it to "
- "implement `adjust_engine_params` instead.",
- self.db_engine_spec,
- )
-
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
- catalog=None,
+ catalog=catalog,
schema=schema,
)
@@ -532,12 +525,16 @@ def _get_sqla_engine(
@contextmanager
def get_raw_connection(
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine(
- schema=schema, nullpool=nullpool, source=source
+ catalog=catalog,
+ schema=schema,
+ nullpool=nullpool,
+ source=source,
) as engine:
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the
@@ -575,11 +572,12 @@ def get_reserved_words(self) -> set[str]:
def get_df( # pylint: disable=too-many-locals
self,
sql: str,
+ catalog: str | None = None,
schema: str | None = None,
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
- with self.get_sqla_engine(schema) as engine:
+ with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
engine_url = engine.url
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
@@ -640,8 +638,13 @@ def _log_query(sql: str) -> None:
return df
- def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
- with self.get_sqla_engine(schema) as engine:
+ def compile_sqla_query(
+ self,
+ qry: Select,
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> str:
+ with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
# pylint: disable=protected-access
@@ -652,8 +655,7 @@ def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
def select_star( # pylint: disable=too-many-arguments
self,
- table_name: str,
- schema: str | None = None,
+ table: Table,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -661,11 +663,10 @@ def select_star( # pylint: disable=too-many-arguments
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
- with self.get_sqla_engine(schema) as engine:
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
return self.db_engine_spec.select_star(
self,
- table_name,
- schema=schema,
+ table,
engine=engine,
limit=limit,
show_cols=show_cols,
@@ -690,6 +691,7 @@ def safe_sqlalchemy_uri(self) -> str:
)
def get_all_table_names_in_schema( # pylint: disable=unused-argument
self,
+ catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
@@ -707,7 +709,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
:return: The table/schema pairs
"""
try:
- with self.get_inspector_with_context() as inspector:
+ with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
@@ -725,6 +727,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
)
def get_all_view_names_in_schema( # pylint: disable=unused-argument
self,
+ catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
@@ -742,7 +745,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
:return: set of views
"""
try:
- with self.get_inspector_with_context() as inspector:
+ with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
@@ -755,10 +758,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@contextmanager
- def get_inspector_with_context(
- self, ssh_tunnel: SSHTunnel | None = None
+ def get_inspector(
+ self,
+ catalog: str | None = None,
+ schema: str | None = None,
+ ssh_tunnel: SSHTunnel | None = None,
) -> Inspector:
- with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
+ with self.get_sqla_engine(
+ catalog=catalog,
+ schema=schema,
+ override_ssh_tunnel=ssh_tunnel,
+ ) as engine:
yield sqla.inspect(engine)
@cache_util.memoized_func(
@@ -767,6 +777,7 @@ def get_inspector_with_context(
)
def get_all_schema_names( # pylint: disable=unused-argument
self,
+ catalog: str | None = None,
cache: bool = False,
cache_timeout: int | None = None,
force: bool = False,
@@ -783,7 +794,10 @@ def get_all_schema_names( # pylint: disable=unused-argument
:return: schema list
"""
try:
- with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
+ with self.get_inspector(
+ catalog=catalog,
+ ssh_tunnel=ssh_tunnel,
+ ) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@@ -835,51 +849,57 @@ def get_encrypted_extra(self) -> dict[str, Any]:
def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
self.db_engine_spec.update_params_from_encrypted_extra(self, params)
- def get_table(self, table_name: str, schema: str | None = None) -> Table:
+ def get_table(self, table: Table) -> SqlaTable:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
- with self.get_sqla_engine() as engine:
- return Table(
- table_name,
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ return SqlaTable(
+ table.table,
meta,
- schema=schema or None,
+ schema=table.schema or None,
autoload=True,
autoload_with=engine,
)
- def get_table_comment(
- self, table_name: str, schema: str | None = None
- ) -> str | None:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
-
- def get_columns(
- self, table_name: str, schema: str | None = None
- ) -> list[ResultSetColumnType]:
- with self.get_inspector_with_context() as inspector:
+ def get_table_comment(self, table: Table) -> str | None:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_table_comment(inspector, table)
+
+ def get_columns(self, table: Table) -> list[ResultSetColumnType]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
return self.db_engine_spec.get_columns(
- inspector, table_name, schema, self.schema_options
+ inspector, table, self.schema_options
)
def get_metrics(
self,
- table_name: str,
- schema: str | None = None,
+ table: Table,
) -> list[MetricType]:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
-
- def get_indexes(
- self, table_name: str, schema: str | None = None
- ) -> list[dict[str, Any]]:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
-
- def get_pk_constraint(
- self, table_name: str, schema: str | None = None
- ) -> dict[str, Any]:
- with self.get_inspector_with_context() as inspector:
- pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_metrics(self, inspector, table)
+
+ def get_indexes(self, table: Table) -> list[dict[str, Any]]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_indexes(self, inspector, table)
+
+ def get_pk_constraint(self, table: Table) -> dict[str, Any]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ pk_constraint = inspector.get_pk_constraint(table.table, table.schema) or {}
def _convert(value: Any) -> Any:
try:
@@ -889,11 +909,12 @@ def _convert(value: Any) -> Any:
return {key: _convert(value) for key, value in pk_constraint.items()}
- def get_foreign_keys(
- self, table_name: str, schema: str | None = None
- ) -> list[dict[str, Any]]:
- with self.get_inspector_with_context() as inspector:
- return inspector.get_foreign_keys(table_name, schema)
+ def get_foreign_keys(self, table: Table) -> list[dict[str, Any]]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return inspector.get_foreign_keys(table.table, table.schema)
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
self,
@@ -942,36 +963,22 @@ def get_perm(self) -> str:
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
- with self.get_sqla_engine() as engine:
- return engine.has_table(table.table_name, table.schema or None)
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ return engine.has_table(table.table, table.schema)
- def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
- with self.get_sqla_engine() as engine:
- return engine.has_table(table_name, schema)
-
- @classmethod
- def _has_view(
- cls,
- conn: Connection,
- dialect: Dialect,
- view_name: str,
- schema: str | None = None,
- ) -> bool:
- view_names: list[str] = []
- try:
- view_names = dialect.get_view_names(connection=conn, schema=schema)
- except Exception: # pylint: disable=broad-except
- logger.warning("Has view failed", exc_info=True)
- return view_name in view_names
-
- def has_view(self, view_name: str, schema: str | None = None) -> bool:
- with self.get_sqla_engine(schema) as engine:
- return engine.run_callable(
- self._has_view, engine.dialect, view_name, schema
- )
+ def has_view(self, table: Table) -> bool:
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ connection = engine.connect()
+ try:
+ views = engine.dialect.get_view_names(
+ connection=connection,
+ schema=table.schema,
+ )
+ except Exception: # pylint: disable=broad-except
+ logger.warning("Has view failed", exc_info=True)
+ views = []
- def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
- return self.has_view(view_name=view_name, schema=schema)
+ return table.table in views
def get_dialect(self) -> Dialect:
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
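Note: with the engine context manager now catalog-aware, callers thread the catalog through alongside the schema. A hedged usage sketch (catalog and schema names are illustrative, and `database` is assumed to be a `Database` instance):

```python
from superset.models.core import Database


def ping(database: Database) -> None:
    # The context manager handles SSH tunnel setup/teardown around the engine.
    with database.get_sqla_engine(catalog="hive", schema="default") as engine:
        with engine.connect() as conn:
            conn.execute("SELECT 1")
```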
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index eba3acf36edd4..7f81081777538 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -37,6 +37,7 @@
from superset.daos.datasource import DatasourceDAO
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.core import Database
+from superset.sql_parse import Table
from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType
from superset.views.base import (
@@ -180,8 +181,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse:
)
external_metadata = get_physical_table_metadata(
database=database,
- table_name=params["table_name"],
- schema_name=params["schema_name"],
+ table=Table(params["table_name"], params["schema_name"]),
normalize_columns=params.get("normalize_columns") or False,
)
except (NoResultFound, NoSuchTableError) as ex:
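
For reference, the positional order used by the view above is `Table(table, schema, catalog)`; omitted trailing arguments default to `None`. A small illustrative check (the identifier names are hypothetical):

    from superset.sql_parse import Table

    table = Table("my_table", "my_schema")  # this view passes no catalog
    assert table.table == "my_table"
    assert table.schema == "my_schema"
    assert table.catalog is None
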
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 5e004992de28c..8de9f7483bf26 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -1175,6 +1175,169 @@ def test_csv_upload_file_extension_valid(
assert response.status_code == 200
+def test_table_metadata_happy_path(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint.
+ """
+ database = mocker.MagicMock()
+ database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"}
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch("superset.databases.api.security_manager.raise_for_access")
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.json == {"hello": "world"}
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t"),
+ )
+
+ client.get("/api/v1/database/1/table_metadata/?name=t&schema=s")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", "s"),
+ )
+
+ client.get("/api/v1/database/1/table_metadata/?name=t&catalog=c")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", None, "c"),
+ )
+
+ client.get(
+ "/api/v1/database/1/table_metadata/?name=t&schema=s&catalog=c"
+ )
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", "s", "c"),
+ )
+
+
+def test_table_metadata_no_table(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when no table name is passed.
+ """
+ database = mocker.MagicMock()
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+
+ response = client.get("/api/v1/database/1/table_metadata/?schema=s&catalog=c")
+ assert response.status_code == 422
+ assert response.json == {
+ "errors": [
+ {
+ "message": "An error happened when validating the request",
+ "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR",
+ "level": "error",
+ "extra": {
+ "messages": {"name": ["Missing data for required field."]},
+ "issue_codes": [
+ {
+ "code": 1020,
+ "message": "Issue 1020 - The submitted payload has the incorrect schema.",
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+
+def test_table_metadata_slashes(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint with names that have slashes.
+ """
+ database = mocker.MagicMock()
+ database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"}
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch("superset.databases.api.security_manager.raise_for_access")
+
+ client.get("/api/v1/database/1/table_metadata/?name=foo/bar")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("foo/bar"),
+ )
+
+
+def test_table_metadata_invalid_database(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when the database is invalid.
+ """
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=None)
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.status_code == 404
+ assert response.json == {
+ "errors": [
+ {
+ "message": "No such database",
+ "error_type": "DATABASE_NOT_FOUND_ERROR",
+ "level": "error",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1011,
+ "message": "Issue 1011 - Superset encountered an unexpected error.",
+ },
+ {
+ "code": 1036,
+ "message": "Issue 1036 - The database was deleted.",
+ },
+ ]
+ },
+ }
+ ]
+ }
+
+
+def test_table_metadata_unauthorized(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when the user is unauthorized.
+ """
+ database = mocker.MagicMock()
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch(
+ "superset.databases.api.security_manager.raise_for_access",
+ side_effect=SupersetSecurityException(
+ SupersetError(
+ error_type=SupersetErrorType.TABLE_SECURITY_ACCESS_ERROR,
+ message="You don't have access to the table",
+ level=ErrorLevel.ERROR,
+ )
+ ),
+ )
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.status_code == 404
+ assert response.json == {
+ "errors": [
+ {
+ "message": "No such table",
+ "error_type": "TABLE_NOT_FOUND_ERROR",
+ "level": "error",
+ "extra": None,
+ }
+ ]
+ }
+
def test_table_extra_metadata_happy_path(
mocker: MockFixture,
client: Any,
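
The tests above pin down the request contract for the renamed endpoint: the table name, schema, and catalog all travel in the query string rather than in path segments, which is why names containing slashes (see test_table_metadata_slashes) need no special escaping. A rough client-side sketch, assuming a local Superset at port 8088 and a previously obtained bearer token (both assumptions, not part of this patch):

    import requests

    # Query-string parameters mirror test_table_metadata_happy_path above.
    resp = requests.get(
        "http://localhost:8088/api/v1/database/1/table_metadata/",
        params={"name": "orders", "schema": "public", "catalog": "examples"},
        headers={"Authorization": "Bearer <token>"},  # placeholder token
    )
    resp.raise_for_status()
    metadata = resp.json()
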
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index e17e0d2833db8..3bc05ee20eec0 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -232,9 +232,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
sql = BaseEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
@@ -252,9 +251,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
sql = NoLimitDBEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py
index 3870297db8b70..18af750bf54a7 100644
--- a/tests/unit_tests/db_engine_specs/test_bigquery.py
+++ b/tests/unit_tests/db_engine_specs/test_bigquery.py
@@ -27,6 +27,7 @@
from sqlalchemy.sql import sqltypes
from sqlalchemy_bigquery import BigQueryDialect
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@@ -156,9 +157,8 @@ def test_select_star(mocker: MockFixture) -> None:
sql = BigQueryEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py
index d7dd19ad5923c..7e4ea68aa8013 100644
--- a/tests/unit_tests/db_engine_specs/test_db2.py
+++ b/tests/unit_tests/db_engine_specs/test_db2.py
@@ -18,6 +18,8 @@
import pytest
from pytest_mock import MockerFixture
+from superset.sql_parse import Table
+
def test_epoch_to_dttm() -> None:
"""
@@ -43,7 +45,7 @@ def test_get_table_comment(mocker: MockerFixture):
}
assert (
- Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema")
+ Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema"))
== "This is a table comment"
)
@@ -59,7 +61,8 @@ def test_get_table_comment_empty(mocker: MockerFixture):
mock_inspector.get_table_comment.return_value = {}
assert (
- Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") == None
+ Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema"))
+ is None
)
diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py
index 8d57d4ed1a8c3..638b377c82709 100644
--- a/tests/unit_tests/db_engine_specs/test_presto.py
+++ b/tests/unit_tests/db_engine_specs/test_presto.py
@@ -24,6 +24,7 @@
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@@ -143,7 +144,10 @@ def test_where_latest_partition(
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(
- "table", mock.MagicMock(), mock.MagicMock(), query, columns
+ mock.MagicMock(),
+ Table("table"),
+ query,
+ columns,
)
assert result is not None
actual = result.compile(
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 7ce9199c5bd1d..a0a81980bf28a 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -442,7 +442,7 @@ def test_get_columns(mocker: MockerFixture):
mock_inspector = mocker.MagicMock()
mock_inspector.get_columns.return_value = sqla_columns
- actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema")
+ actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
expected = [
ResultSetColumnType(
name="field1", column_name="field1", type=field1_type, is_dttm=False
@@ -475,7 +475,9 @@ def test_get_columns_expand_rows(mocker: MockerFixture):
mock_inspector.get_columns.return_value = sqla_columns
actual = TrinoEngineSpec.get_columns(
- mock_inspector, "table", "schema", {"expand_rows": True}
+ mock_inspector,
+ Table("table", "schema"),
+ {"expand_rows": True},
)
expected = [
ResultSetColumnType(
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index beefd3ea3cc5a..ce3ad1822271f 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -18,7 +18,6 @@
# pylint: disable=import-outside-toplevel
import json
from datetime import datetime
-from typing import Optional
import pytest
from pytest_mock import MockFixture
@@ -26,6 +25,7 @@
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
+from superset.sql_parse import Table
def test_get_metrics(mocker: MockFixture) -> None:
@@ -37,7 +37,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
from superset.models.core import Database
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- assert database.get_metrics("table") == [
+ assert database.get_metrics(Table("table")) == [
{
"expression": "COUNT(*)",
"metric_name": "count",
@@ -52,8 +52,7 @@ def get_metrics(
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: Optional[str],
+ table: Table,
) -> list[MetricType]:
return [
{
@@ -65,7 +64,7 @@ def get_metrics(
]
database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec)
- assert database.get_metrics("table") == [
+ assert database.get_metrics(Table("table")) == [
{
"expression": "COUNT(DISTINCT user_id)",
"metric_name": "count_distinct_user_id",