diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
index 949323b9aa75c..402e26462e041 100644
--- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
+++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
@@ -61,7 +61,7 @@ export type ExploreQuery = QueryResponse & {
 };
 
 export interface ISimpleColumn {
-  name?: string | null;
+  column_name?: string | null;
   type?: string | null;
   is_dttm?: boolean | null;
 }
@@ -216,7 +216,7 @@ export const SaveDatasetModal = ({
           ...formDataWithDefaults,
           datasource: `${datasetToOverwrite.datasetid}__table`,
           ...(defaultVizType === 'table' && {
-            all_columns: datasource?.columns?.map(column => column.name),
+            all_columns: datasource?.columns?.map(column => column.column_name),
           }),
         }),
       ]);
@@ -301,7 +301,7 @@ export const SaveDatasetModal = ({
           ...formDataWithDefaults,
           datasource: `${data.table_id}__table`,
           ...(defaultVizType === 'table' && {
-            all_columns: selectedColumns.map(column => column.name),
+            all_columns: selectedColumns.map(column => column.column_name),
           }),
         }),
       )
diff --git a/superset-frontend/src/SqlLab/fixtures.ts b/superset-frontend/src/SqlLab/fixtures.ts
index fcb0fff8e3d70..ba88a41b0accc 100644
--- a/superset-frontend/src/SqlLab/fixtures.ts
+++ b/superset-frontend/src/SqlLab/fixtures.ts
@@ -692,17 +692,17 @@ export const testQuery: ISaveableDatasource = {
   sql: 'SELECT *',
   columns: [
     {
-      name: 'Column 1',
+      column_name: 'Column 1',
       type: DatasourceType.Query,
       is_dttm: false,
     },
     {
-      name: 'Column 3',
+      column_name: 'Column 3',
       type: DatasourceType.Query,
       is_dttm: false,
     },
     {
-      name: 'Column 2',
+      column_name: 'Column 2',
       type: DatasourceType.Query,
       is_dttm: true,
     },
diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
index 80cf879f7f256..c74212f0baf6b 100644
--- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
+++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
@@ -48,7 +48,7 @@ class AdhocMetricOption extends React.PureComponent {
   }
 
   onRemoveMetric(e) {
-    e.stopPropagation();
+    e?.stopPropagation();
     this.props.onRemoveMetric(this.props.index);
   }
 
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 593c5f853b935..6180a546e7500 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -31,7 +31,6 @@
     Dict,
     Hashable,
     List,
-    NamedTuple,
     Optional,
     Set,
     Tuple,
@@ -50,11 +49,9 @@
 from jinja2.exceptions import TemplateError
 from sqlalchemy import (
     and_,
-    asc,
     Boolean,
     Column,
     DateTime,
-    desc,
     Enum,
     ForeignKey,
     inspect,
@@ -80,13 +77,11 @@
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table
 from sqlalchemy.sql.elements import ColumnClause, TextClause
-from sqlalchemy.sql.expression import Label, Select, TextAsFrom
+from sqlalchemy.sql.expression import Label, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 
 from superset import app, db, is_feature_enabled, security_manager
-from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
-from superset.common.utils.time_range_utils import get_since_until_from_time_range
 from superset.connectors.base.models import BaseColumn, BaseDatasource,
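The frontend hunks above rename the SQL Lab column field from `name` to `column_name` (in `ISimpleColumn`, the fixtures, and the `all_columns` mapping), and the `sqllab_viz` view later in this diff accepts either key. A minimal Python sketch of that backward-compatible normalization; the helper name and payload shapes are illustrative and not part of the patch:

```python
# Hypothetical helper illustrating the payload change: SQL Lab column payloads
# now carry "column_name" instead of "name"; the backend view change further
# down in this diff still falls back to the old key.
from typing import Any, Dict, List


def normalize_columns(payload_columns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return column dicts keyed by the new "column_name" field."""
    normalized = []
    for col in payload_columns:
        normalized.append(
            {
                "column_name": col.get("column_name") or col.get("name"),
                "type": col.get("type"),
                "is_dttm": bool(col.get("is_dttm")),
            }
        )
    return normalized


# Old-style and new-style payloads normalize to the same shape.
print(normalize_columns([{"name": "ds", "type": "DATE", "is_dttm": True}]))
print(normalize_columns([{"column_name": "ds", "type": "DATE", "is_dttm": True}]))
```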
BaseMetric from superset.connectors.sqla.utils import ( find_cached_objects_in_session, @@ -98,7 +93,6 @@ from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression from superset.exceptions import ( - AdvancedDataTypeResponseError, ColumnNotFoundException, DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, @@ -106,7 +100,6 @@ SupersetGenericDBErrorException, SupersetSecurityException, ) -from superset.extensions import feature_flag_manager from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -114,26 +107,17 @@ ) from superset.models.annotations import Annotation from superset.models.core import Database -from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult -from superset.sql_parse import ParsedQuery, sanitize_clause -from superset.superset_typing import ( - AdhocColumn, - AdhocMetric, - Column as ColumnTyping, - Metric, - OrderBy, - QueryObjectDict, +from superset.models.helpers import ( + AuditMixinNullable, + CertificationMixin, + ExploreMixin, + QueryResult, + QueryStringExtended, ) +from superset.sql_parse import ParsedQuery, sanitize_clause +from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict from superset.utils import core as utils -from superset.utils.core import ( - GenericDataType, - get_column_name, - get_username, - is_adhoc_column, - MediumText, - QueryObjectFilterClause, - remove_duplicates, -) +from superset.utils.core import GenericDataType, get_username, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -150,26 +134,6 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES} -class SqlaQuery(NamedTuple): - applied_template_filters: List[str] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] - cte: Optional[str] - extra_cache_keys: List[Any] - labels_expected: List[str] - prequeries: List[str] - sqla_query: Select - - -class QueryStringExtended(NamedTuple): - applied_template_filters: Optional[List[str]] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] - labels_expected: List[str] - prequeries: List[str] - sql: str - - @dataclass class MetadataResult: added: List[str] = field(default_factory=list) @@ -310,6 +274,35 @@ def db_extra(self) -> Dict[str, Any]: def type_generic(self) -> Optional[utils.GenericDataType]: if self.is_dttm: return GenericDataType.TEMPORAL + + bool_types = ("BOOL",) + num_types = ( + "DOUBLE", + "FLOAT", + "INT", + "BIGINT", + "NUMBER", + "LONG", + "REAL", + "NUMERIC", + "DECIMAL", + "MONEY", + ) + date_types = ("DATE", "TIME") + str_types = ("VARCHAR", "STRING", "CHAR") + + if self.table is None: + # Query.TableColumns don't have a reference to a table.db_engine_spec + # reference so this logic will manage rendering types + if self.type and any(map(lambda t: t in self.type.upper(), str_types)): + return GenericDataType.STRING + if self.type and any(map(lambda t: t in self.type.upper(), bool_types)): + return GenericDataType.BOOLEAN + if self.type and any(map(lambda t: t in self.type.upper(), num_types)): + return GenericDataType.NUMERIC + if self.type and any(map(lambda t: t in self.type.upper(), date_types)): + return GenericDataType.TEMPORAL + column_spec = self.db_engine_spec.get_column_spec( self.type, db_extra=self.db_extra ) @@ -545,8 +538,10 @@ def _process_sql_expression( return expression -class SqlaTable(Model, 
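The `type_generic` hunk above adds a string-matching fallback for columns that have no backing table (and therefore no `db_engine_spec` to consult), such as columns coming from a SQL Lab query. A standalone sketch of that classification using the same type-name tuples; the enum below is a local stand-in for `GenericDataType`:

```python
from enum import Enum
from typing import Optional


class GenericDataType(Enum):  # stand-in for superset.utils.core.GenericDataType
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2
    BOOLEAN = 3


BOOL_TYPES = ("BOOL",)
NUM_TYPES = ("DOUBLE", "FLOAT", "INT", "BIGINT", "NUMBER", "LONG", "REAL",
             "NUMERIC", "DECIMAL", "MONEY")
DATE_TYPES = ("DATE", "TIME")
STR_TYPES = ("VARCHAR", "STRING", "CHAR")


def infer_generic_type(native_type: Optional[str]) -> Optional[GenericDataType]:
    """Mirror the fallback used when a column has no db_engine_spec to consult."""
    if not native_type:
        return None
    upper = native_type.upper()
    if any(t in upper for t in STR_TYPES):
        return GenericDataType.STRING
    if any(t in upper for t in BOOL_TYPES):
        return GenericDataType.BOOLEAN
    if any(t in upper for t in NUM_TYPES):
        return GenericDataType.NUMERIC
    if any(t in upper for t in DATE_TYPES):
        return GenericDataType.TEMPORAL
    return None


print(infer_generic_type("VARCHAR(255)"))  # GenericDataType.STRING
print(infer_generic_type("TIMESTAMP"))     # GenericDataType.TEMPORAL
```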
BaseDatasource): # pylint: disable=too-many-public-methods - """An ORM object for SqlAlchemy table references.""" +class SqlaTable( + Model, BaseDatasource, ExploreMixin +): # pylint: disable=too-many-public-methods + """An ORM object for SqlAlchemy table references""" type = "table" query_language = "sql" @@ -626,6 +621,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def __repr__(self) -> str: # pylint: disable=invalid-repr-returned return self.name + @property + def db_extra(self) -> Dict[str, Any]: + return self.database.get_extra() + @staticmethod def _apply_cte(sql: str, cte: Optional[str]) -> str: """ @@ -1151,680 +1150,6 @@ def get_sqla_row_level_filters( def text(self, clause: str) -> TextClause: return self.db_engine_spec.get_text_clause(clause) - def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - self, - apply_fetch_values_predicate: bool = False, - columns: Optional[List[ColumnTyping]] = None, - extras: Optional[Dict[str, Any]] = None, - filter: Optional[ # pylint: disable=redefined-builtin - List[QueryObjectFilterClause] - ] = None, - from_dttm: Optional[datetime] = None, - granularity: Optional[str] = None, - groupby: Optional[List[Column]] = None, - inner_from_dttm: Optional[datetime] = None, - inner_to_dttm: Optional[datetime] = None, - is_rowcount: bool = False, - is_timeseries: bool = True, - metrics: Optional[List[Metric]] = None, - orderby: Optional[List[OrderBy]] = None, - order_desc: bool = True, - to_dttm: Optional[datetime] = None, - series_columns: Optional[List[Column]] = None, - series_limit: Optional[int] = None, - series_limit_metric: Optional[Metric] = None, - row_limit: Optional[int] = None, - row_offset: Optional[int] = None, - timeseries_limit: Optional[int] = None, - timeseries_limit_metric: Optional[Metric] = None, - time_shift: Optional[str] = None, - ) -> SqlaQuery: - """Querying any sqla table from this common interface""" - if granularity not in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - extras = extras or {} - time_grain = extras.get("time_grain_sqla") - - template_kwargs = { - "columns": columns, - "from_dttm": from_dttm.isoformat() if from_dttm else None, - "groupby": groupby, - "metrics": metrics, - "row_limit": row_limit, - "row_offset": row_offset, - "time_column": granularity, - "time_grain": time_grain, - "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.column_name for col in self.columns], - "filter": filter, - } - columns = columns or [] - groupby = groupby or [] - rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] - applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] - series_column_names = utils.get_column_names(series_columns or []) - # deprecated, to be removed in 2.0 - if is_timeseries and timeseries_limit: - series_limit = timeseries_limit - series_limit_metric = series_limit_metric or timeseries_limit_metric - template_kwargs.update(self.template_params_dict) - extra_cache_keys: List[Any] = [] - template_kwargs["extra_cache_keys"] = extra_cache_keys - removed_filters: List[str] = [] - applied_template_filters: List[str] = [] - template_kwargs["removed_filters"] = removed_filters - template_kwargs["applied_filters"] = applied_template_filters - template_processor = self.get_template_processor(**template_kwargs) - db_engine_spec = self.db_engine_spec - prequeries: List[str] = [] - orderby = orderby or [] - need_groupby = bool(metrics is not 
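With this change `SqlaTable` inherits the shared query-building machinery from `ExploreMixin` (added in `superset/models/helpers.py` further down) and only has to supply datasource-specific pieces such as `db_extra`; the large block of removed lines that follows is the old per-table copy of `get_sqla_query`. A much-simplified sketch of the composition pattern, with placeholder behavior; names other than `ExploreMixin` and `db_extra` are illustrative:

```python
from typing import Any, Dict


class ExploreMixinSketch:
    """Stand-in for superset.models.helpers.ExploreMixin: shared query building."""

    @property
    def db_extra(self) -> Dict[str, Any]:
        raise NotImplementedError  # each concrete datasource supplies this

    def get_sqla_query(self, **query_obj: Any) -> str:
        # The real mixin builds a SQLAlchemy Select; a placeholder keeps this runnable.
        return f"SELECT ... /* built with db_extra={self.db_extra} */"


class FakeDatabase:
    def get_extra(self) -> Dict[str, Any]:
        return {"engine_params": {}}


class SqlaTableSketch(ExploreMixinSketch):
    """Mirrors SqlaTable(Model, BaseDatasource, ExploreMixin) delegating db_extra."""

    def __init__(self, database: FakeDatabase) -> None:
        self.database = database

    @property
    def db_extra(self) -> Dict[str, Any]:
        return self.database.get_extra()


print(SqlaTableSketch(FakeDatabase()).get_sqla_query())
```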
None or groupby) - metrics = metrics or [] - - # For backward compatibility - if granularity not in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - columns_by_name: Dict[str, TableColumn] = { - col.column_name: col for col in self.columns - } - - metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} - - if not granularity and is_timeseries: - raise QueryObjectValidationError( - _( - "Datetime column not provided as part table configuration " - "and is required by this type of chart" - ) - ) - if not metrics and not columns and not groupby: - raise QueryObjectValidationError(_("Empty query?")) - - metrics_exprs: List[ColumnElement] = [] - for metric in metrics: - if utils.is_adhoc_metric(metric): - assert isinstance(metric, dict) - metrics_exprs.append( - self.adhoc_metric_to_sqla( - metric=metric, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - ) - elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append( - metrics_by_name[metric].get_sqla_col( - template_processor=template_processor - ) - ) - else: - raise QueryObjectValidationError( - _("Metric '%(metric)s' does not exist", metric=metric) - ) - - if metrics_exprs: - main_metric_expr = metrics_exprs[0] - else: - main_metric_expr, label = literal_column("COUNT(*)"), "count" - main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) - - # To ensure correct handling of the ORDER BY labeling we need to reference the - # metric instance if defined in the SELECT clause. - # use the key of the ColumnClause for the expected label - metrics_exprs_by_label = {m.key: m for m in metrics_exprs} - metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} - - # Since orderby may use adhoc metrics, too; we need to process them first - orderby_exprs: List[ColumnElement] = [] - for orig_col, ascending in orderby: - col: Union[AdhocMetric, ColumnElement] = orig_col - if isinstance(col, dict): - col = cast(AdhocMetric, col) - if col.get("sqlExpression"): - col["sqlExpression"] = _process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - if utils.is_adhoc_metric(col): - # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) - # if the adhoc metric has been defined before - # use the existing instance. 
- col = metrics_exprs_by_expr.get(str(col), col) - need_groupby = True - elif col in columns_by_name: - col = columns_by_name[col].get_sqla_col( - template_processor=template_processor - ) - elif col in metrics_exprs_by_label: - col = metrics_exprs_by_label[col] - need_groupby = True - elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col( - template_processor=template_processor - ) - need_groupby = True - - if isinstance(col, ColumnElement): - orderby_exprs.append(col) - else: - # Could not convert a column reference to valid ColumnElement - raise QueryObjectValidationError( - _("Unknown column used in orderby: %(col)s", col=orig_col) - ) - - select_exprs: List[Union[Column, Label]] = [] - groupby_all_columns = {} - groupby_series_columns = {} - - # filter out the pseudo column __timestamp from columns - columns = [col for col in columns if col != utils.DTTM_ALIAS] - dttm_col = columns_by_name.get(granularity) if granularity else None - - if need_groupby: - # dedup columns while preserving order - columns = groupby or columns - for selected in columns: - if isinstance(selected, str): - # if groupby field/expr equals granularity field/expr - if selected == granularity: - table_col = columns_by_name[selected] - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - # if groupby field equals a selected column - elif selected in columns_by_name: - outer = columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - else: - selected = validate_adhoc_subquery( - selected, - self.database_id, - self.schema, - ) - outer = literal_column(f"({selected})") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor - ) - groupby_all_columns[outer.name] = outer - if ( - is_timeseries and not series_column_names - ) or outer.name in series_column_names: - groupby_series_columns[outer.name] = outer - select_exprs.append(outer) - elif columns: - for selected in columns: - if is_adhoc_column(selected): - _sql = selected["sqlExpression"] - _column_label = selected["label"] - elif isinstance(selected, str): - _sql = selected - _column_label = selected - - selected = validate_adhoc_subquery( - _sql, - self.database_id, - self.schema, - ) - select_exprs.append( - columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - if isinstance(selected, str) and selected in columns_by_name - else self.make_sqla_column_compatible( - literal_column(selected), _column_label - ) - ) - metrics_exprs = [] - - if granularity: - if granularity not in columns_by_name or not dttm_col: - raise QueryObjectValidationError( - _( - 'Time column "%(col)s" does not exist in dataset', - col=granularity, - ) - ) - time_filters = [] - - if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) - # always put timestamp as the first column - select_exprs.insert(0, timestamp) - groupby_all_columns[timestamp.name] = timestamp - - # Use main dttm column to support index with secondary dttm columns. 
- if ( - db_engine_spec.time_secondary_columns - and self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col.column_name - ): - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - time_filters.append( - dttm_col.get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - - # Always remove duplicates by column name, as sometimes `metrics_exprs` - # can have the same name as a groupby column (e.g. when users use - # raw columns as custom SQL adhoc metric). - select_exprs = remove_duplicates( - select_exprs + metrics_exprs, key=lambda x: x.name - ) - - # Expected output columns - labels_expected = [c.key for c in select_exprs] - - # Order by columns are "hidden" columns, some databases require them - # always be present in SELECT if an aggregation function is used - if not db_engine_spec.allows_hidden_orderby_agg: - select_exprs = remove_duplicates(select_exprs + orderby_exprs) - - qry = sa.select(select_exprs) - - tbl, cte = self.get_from_clause(template_processor) - - if groupby_all_columns: - qry = qry.group_by(*groupby_all_columns.values()) - - where_clause_and = [] - having_clause_and = [] - - for flt in filter: # type: ignore - if not all(flt.get(s) for s in ["col", "op"]): - continue - flt_col = flt["col"] - val = flt.get("val") - op = flt["op"].upper() - col_obj: Optional[TableColumn] = None - sqla_col: Optional[Column] = None - if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: - col_obj = dttm_col - elif is_adhoc_column(flt_col): - try: - sqla_col = self.adhoc_column_to_sqla( - col=flt_col, - force_type_check=True, - template_processor=template_processor, - ) - applied_adhoc_filters_columns.append(flt_col) - except ColumnNotFoundException: - rejected_adhoc_filters_columns.append(flt_col) - continue - else: - col_obj = columns_by_name.get(cast(str, flt_col)) - filter_grain = flt.get("grain") - - if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if get_column_name(flt_col) in removed_filters: - # Skip generating SQLA filter when the jinja template handles it. 
- continue - - if col_obj or sqla_col is not None: - if sqla_col is not None: - pass - elif col_obj and filter_grain: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, template_processor=template_processor - ) - elif col_obj: - sqla_col = col_obj.get_sqla_col( - template_processor=template_processor - ) - col_type = col_obj.type if col_obj else None - col_spec = db_engine_spec.get_column_spec( - native_type=col_type, - db_extra=self.database.get_extra(), - ) - is_list_target = op in ( - utils.FilterOperator.IN.value, - utils.FilterOperator.NOT_IN.value, - ) - - col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" - - if col_spec and not col_advanced_data_type: - target_generic_type = col_spec.generic_type - else: - target_generic_type = GenericDataType.STRING - eq = self.filter_values_handler( - values=val, - operator=op, - target_generic_type=target_generic_type, - target_native_type=col_type, - is_list_target=is_list_target, - db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), - ) - if ( - col_advanced_data_type != "" - and feature_flag_manager.is_feature_enabled( - "ENABLE_ADVANCED_DATA_TYPES" - ) - and col_advanced_data_type in ADVANCED_DATA_TYPES - ): - values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( - { - "type": col_advanced_data_type, - "values": values, - } - ) - if bus_resp["error_message"]: - raise AdvancedDataTypeResponseError( - _(bus_resp["error_message"]) - ) - - where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) - ) - elif is_list_target: - assert isinstance(eq, (tuple, list)) - if len(eq) == 0: - raise QueryObjectValidationError( - _("Filter value list cannot be empty") - ) - if len(eq) > len( - eq_without_none := [x for x in eq if x is not None] - ): - is_null_cond = sqla_col.is_(None) - if eq: - cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) - else: - cond = is_null_cond - else: - cond = sqla_col.in_(eq) - if op == utils.FilterOperator.NOT_IN.value: - cond = ~cond - where_clause_and.append(cond) - elif op == utils.FilterOperator.IS_NULL.value: - where_clause_and.append(sqla_col.is_(None)) - elif op == utils.FilterOperator.IS_NOT_NULL.value: - where_clause_and.append(sqla_col.isnot(None)) - elif op == utils.FilterOperator.IS_TRUE.value: - where_clause_and.append(sqla_col.is_(True)) - elif op == utils.FilterOperator.IS_FALSE.value: - where_clause_and.append(sqla_col.is_(False)) - else: - if ( - op - not in { - utils.FilterOperator.EQUALS.value, - utils.FilterOperator.NOT_EQUALS.value, - } - and eq is None - ): - raise QueryObjectValidationError( - _( - "Must specify a value for filters " - "with comparison operators" - ) - ) - if op == utils.FilterOperator.EQUALS.value: - where_clause_and.append(sqla_col == eq) - elif op == utils.FilterOperator.NOT_EQUALS.value: - where_clause_and.append(sqla_col != eq) - elif op == utils.FilterOperator.GREATER_THAN.value: - where_clause_and.append(sqla_col > eq) - elif op == utils.FilterOperator.LESS_THAN.value: - where_clause_and.append(sqla_col < eq) - elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col >= eq) - elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col <= eq) - elif op == utils.FilterOperator.LIKE.value: - where_clause_and.append(sqla_col.like(eq)) - elif op == 
utils.FilterOperator.ILIKE.value: - where_clause_and.append(sqla_col.ilike(eq)) - elif ( - op == utils.FilterOperator.TEMPORAL_RANGE.value - and isinstance(eq, str) - and col_obj is not None - ): - _since, _until = get_since_until_from_time_range( - time_range=eq, - time_shift=time_shift, - extras=extras, - ) - where_clause_and.append( - col_obj.get_time_filter( - start_dttm=_since, - end_dttm=_until, - label=sqla_col.key, - template_processor=template_processor, - ) - ) - else: - raise QueryObjectValidationError( - _("Invalid filter operation type: %(op)s", op=op) - ) - where_clause_and += self.get_sqla_row_level_filters(template_processor) - if extras: - where = extras.get("where") - if where: - try: - where = template_processor.process_template(f"({where})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in WHERE clause: %(msg)s", - msg=ex.message, - ) - ) from ex - where = _process_sql_expression( - expression=where, - database_id=self.database_id, - schema=self.schema, - ) - where_clause_and += [self.text(where)] - having = extras.get("having") - if having: - try: - having = template_processor.process_template(f"({having})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in HAVING clause: %(msg)s", - msg=ex.message, - ) - ) from ex - having = _process_sql_expression( - expression=having, - database_id=self.database_id, - schema=self.schema, - ) - having_clause_and += [self.text(having)] - - if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) - ) - if granularity: - qry = qry.where(and_(*(time_filters + where_clause_and))) - else: - qry = qry.where(and_(*where_clause_and)) - qry = qry.having(and_(*having_clause_and)) - - self.make_orderby_compatible(select_exprs, orderby_exprs) - - for col, (orig_col, ascending) in zip(orderby_exprs, orderby): - if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): - # if engine does not allow using SELECT alias in ORDER BY - # revert to the underlying column - col = col.element - - if ( - db_engine_spec.allows_alias_in_select - and db_engine_spec.allows_hidden_cc_in_orderby - and col.name in [select_col.name for select_col in select_exprs] - ): - col = literal_column(col.name) - direction = asc if ascending else desc - qry = qry.order_by(direction(col)) - - if row_limit: - qry = qry.limit(row_limit) - if row_offset: - qry = qry.offset(row_offset) - - if series_limit and groupby_series_columns: - if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: - # some sql dialects require for order by expressions - # to also be in the select clause -- others, e.g. 
vertica, - # require a unique inner alias - inner_main_metric_expr = self.make_sqla_column_compatible( - main_metric_expr, "mme_inner__" - ) - inner_groupby_exprs = [] - inner_select_exprs = [] - for gby_name, gby_obj in groupby_series_columns.items(): - label = get_column_name(gby_name) - inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") - inner_groupby_exprs.append(inner) - inner_select_exprs.append(inner) - - inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs).select_from(tbl) - inner_time_filter = [] - - if dttm_col and not db_engine_spec.time_groupby_inline: - inner_time_filter = [ - dttm_col.get_time_filter( - start_dttm=inner_from_dttm or from_dttm, - end_dttm=inner_to_dttm or to_dttm, - template_processor=template_processor, - ) - ] - subq = subq.where(and_(*(where_clause_and + inner_time_filter))) - subq = subq.group_by(*inner_groupby_exprs) - - ob = inner_main_metric_expr - if series_limit_metric: - ob = self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - direction = desc if order_desc else asc - subq = subq.order_by(direction(ob)) - subq = subq.limit(series_limit) - - on_clause = [] - for gby_name, gby_obj in groupby_series_columns.items(): - # in this case the column name, not the alias, needs to be - # conditionally mutated, as it refers to the column alias in - # the inner query - col_name = db_engine_spec.make_label_compatible(gby_name + "__") - on_clause.append(gby_obj == column(col_name)) - - tbl = tbl.join(subq.alias(), and_(*on_clause)) - else: - if series_limit_metric: - orderby = [ - ( - self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ), - not order_desc, - ) - ] - - # run prequery to get top groups - prequery_obj = { - "is_timeseries": False, - "row_limit": series_limit, - "metrics": metrics, - "granularity": granularity, - "groupby": groupby, - "from_dttm": inner_from_dttm or from_dttm, - "to_dttm": inner_to_dttm or to_dttm, - "filter": filter, - "orderby": orderby, - "extras": extras, - "columns": columns, - "order_desc": True, - } - - result = self.query(prequery_obj) - prequeries.append(result.query) - dimensions = [ - c - for c in result.df.columns - if c not in metrics and c in groupby_series_columns - ] - top_groups = self._get_top_groups( - result.df, dimensions, groupby_series_columns, columns_by_name - ) - qry = qry.where(top_groups) - - qry = qry.select_from(tbl) - - if is_rowcount: - if not db_engine_spec.allows_subqueries: - raise QueryObjectValidationError( - _("Database does not support subqueries") - ) - label = "rowcount" - col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = select([col]).select_from(qry.alias("rowcount_qry")) - labels_expected = [label] - - filter_columns = [flt.get("col") for flt in filter] if filter else [] - rejected_filter_columns = [ - col - for col in filter_columns - if col - and not is_adhoc_column(col) - and col not in self.column_names - and col not in applied_template_filters - ] + rejected_adhoc_filters_columns - applied_filter_columns = [ - col - for col in filter_columns - if col - and not is_adhoc_column(col) - and (col in self.column_names or col in applied_template_filters) - ] + applied_adhoc_filters_columns - - return SqlaQuery( - 
applied_template_filters=applied_template_filters, - rejected_filter_columns=rejected_filter_columns, - applied_filter_columns=applied_filter_columns, - cte=cte, - extra_cache_keys=extra_cache_keys, - labels_expected=labels_expected, - sqla_query=qry, - prequeries=prequeries, - ) - def _get_series_orderby( self, series_limit_metric: Metric, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 0790e3709abd6..cc3b34ae627cb 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -14,20 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""a collection of model-related helper classes and functions""" # pylint: disable=too-many-lines +"""a collection of model-related helper classes and functions""" +import dataclasses import json import logging import re import uuid +from collections import defaultdict from datetime import datetime, timedelta from json.decoder import JSONDecodeError from typing import ( Any, cast, Dict, + Hashable, List, - Mapping, NamedTuple, Optional, Set, @@ -71,6 +73,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( AdvancedDataTypeResponseError, + ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, SupersetSecurityException, @@ -88,7 +91,13 @@ QueryObjectDict, ) from superset.utils import core as utils -from superset.utils.core import get_user_id +from superset.utils.core import ( + GenericDataType, + get_column_name, + get_user_id, + is_adhoc_column, + remove_duplicates, +) from superset.utils.dates import datetime_to_epoch if TYPE_CHECKING: @@ -668,6 +677,8 @@ def clone_model( # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): applied_template_filters: Optional[List[str]] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] labels_expected: List[str] prequeries: List[str] sql: str @@ -675,6 +686,8 @@ class QueryStringExtended(NamedTuple): class SqlaQuery(NamedTuple): applied_template_filters: List[str] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] @@ -698,7 +711,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods } @property - def query(self) -> str: + def fetch_value_predicate(self) -> str: + return "fix this!" 
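`QueryStringExtended` and `SqlaQuery` (relocated here from `connectors/sqla/models.py`) now also carry `applied_filter_columns` and `rejected_filter_columns`. A standalone sketch of the partition that `get_sqla_query` performs near its end: simple filters whose column is neither a dataset column nor an applied template filter are reported as rejected (adhoc filter columns are tracked separately via `ColumnNotFoundException`):

```python
from typing import Any, Dict, List, Tuple


def partition_filter_columns(
    filters: List[Dict[str, Any]],
    dataset_columns: List[str],
    applied_template_filters: List[str],
) -> Tuple[List[str], List[str]]:
    """Split simple filter columns into (applied, rejected), mirroring get_sqla_query."""
    filter_columns = [flt.get("col") for flt in filters]
    applied = [
        col
        for col in filter_columns
        if col and (col in dataset_columns or col in applied_template_filters)
    ]
    rejected = [
        col
        for col in filter_columns
        if col and col not in dataset_columns and col not in applied_template_filters
    ]
    return applied, rejected


applied, rejected = partition_filter_columns(
    filters=[{"col": "gender", "op": "=="}, {"col": "not_a_column", "op": "=="}],
    dataset_columns=["gender", "state"],
    applied_template_filters=[],
)
print(applied, rejected)  # ['gender'] ['not_a_column']
```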
+ + @property + def type(self) -> str: + raise NotImplementedError() + + @property + def db_extra(self) -> Optional[Dict[str, Any]]: + raise NotImplementedError() + + def query(self, query_obj: QueryObjectDict) -> QueryResult: raise NotImplementedError() @property @@ -711,7 +735,7 @@ def owners_data(self) -> List[Any]: @property def metrics(self) -> List[Any]: - raise NotImplementedError() + return [] @property def uid(self) -> str: @@ -761,17 +785,59 @@ def sql(self) -> str: def columns(self) -> List[Any]: raise NotImplementedError() - @property - def get_fetch_values_predicate(self) -> List[Any]: + def get_fetch_values_predicate( + self, template_processor: Optional[BaseTemplateProcessor] = None + ) -> TextClause: raise NotImplementedError() - @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: raise NotImplementedError() + def get_sqla_row_level_filters( + self, + template_processor: BaseTemplateProcessor, + ) -> List[TextClause]: + """ + Return the appropriate row level security filters for this table and the + current user. A custom username can be passed when the user is not present in the + Flask global namespace. + + :param template_processor: The template processor to apply to the filters. + :returns: A list of SQL clauses to be ANDed together. + """ + all_filters: List[TextClause] = [] + filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + try: + for filter_ in security_manager.get_rls_filters(self): + clause = self.text( + f"({template_processor.process_template(filter_.clause)})" + ) + if filter_.group_key: + filter_groups[filter_.group_key].append(clause) + else: + all_filters.append(clause) + + if is_feature_enabled("EMBEDDED_SUPERSET"): + for rule in security_manager.get_guest_rls_filters(self): + clause = self.text( + f"({template_processor.process_template(rule['clause'])})" + ) + all_filters.append(clause) + + grouped_filters = [or_(*clauses) for clauses in filter_groups.values()] + all_filters.extend(grouped_filters) + return all_filters + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in RLS filters: %(msg)s", + msg=ex.message, + ) + ) from ex + def _process_sql_expression( # pylint: disable=no-self-use self, expression: Optional[str], @@ -870,14 +936,19 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) - def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: + def get_query_str_extended( + self, query_obj: QueryObjectDict, mutate: bool = True + ) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) - sql = self.mutate_query_from_config(sql) + if mutate: + sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, + applied_filter_columns=sqlaq.applied_filter_columns, + rejected_filter_columns=sqlaq.rejected_filter_columns, labels_expected=sqlaq.labels_expected, prequeries=sqlaq.prequeries, sql=sql, @@ -1002,9 +1073,16 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: logger.warning( "Query %s on schema %s failed", sql, self.schema, exc_info=True ) + db_engine_spec = 
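`get_sqla_row_level_filters` now lives on `ExploreMixin`: RLS clauses that share a `group_key` are OR-ed together, and the resulting groups plus the ungrouped clauses are AND-ed by the caller. A minimal sketch of that grouping with plain SQL strings instead of `TextClause` objects; the `RLSFilter` shape is illustrative:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class RLSFilter:  # illustrative stand-in for a row-level-security filter record
    clause: str
    group_key: Optional[str] = None


def combine_rls_filters(filters: List[RLSFilter]) -> List[str]:
    """OR clauses within a group_key, keep ungrouped clauses as-is (caller ANDs all)."""
    ungrouped: List[str] = []
    groups: Dict[str, List[str]] = defaultdict(list)
    for flt in filters:
        clause = f"({flt.clause})"
        if flt.group_key:
            groups[flt.group_key].append(clause)
        else:
            ungrouped.append(clause)
    grouped = [" OR ".join(clauses) for clauses in groups.values()]
    return ungrouped + grouped


print(combine_rls_filters([
    RLSFilter("region = 'EMEA'", group_key="region"),
    RLSFilter("region = 'APAC'", group_key="region"),
    RLSFilter("deleted = 0"),
]))
# ["(deleted = 0)", "(region = 'EMEA') OR (region = 'APAC')"]
```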
self.db_engine_spec + errors = [ + dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex) + ] error_message = utils.error_msg_from_exception(ex) return QueryResult( + applied_template_filters=query_str_ext.applied_template_filters, + applied_filter_columns=query_str_ext.applied_filter_columns, + rejected_filter_columns=query_str_ext.rejected_filter_columns, status=status, df=df, duration=datetime.now() - qry_start_dttm, @@ -1074,7 +1152,7 @@ def get_from_clause( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument + columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ @@ -1174,19 +1252,20 @@ def get_query_str(self, query_obj: QueryObjectDict) -> str: def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Mapping[str, "SqlMetric"], - columns_by_name: Mapping[str, "TableColumn"], + metrics_by_name: Dict[str, "SqlMetric"], + columns_by_name: Dict[str, "TableColumn"], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) - ob = self.adhoc_metric_to_sqla( - series_limit_metric, columns_by_name # type: ignore - ) + ob = self.adhoc_metric_to_sqla(series_limit_metric, columns_by_name) elif ( isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col() + ob = metrics_by_name[series_limit_metric].get_sqla_col( + template_processor=template_processor + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) @@ -1195,26 +1274,11 @@ def _get_series_orderby( def adhoc_column_to_sqla( self, - col: Type["AdhocColumn"], # type: ignore + col: "AdhocColumn", # type: ignore + force_type_check: bool = False, template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - """ - Turn an adhoc column into a sqlalchemy column. 
- - :param col: Adhoc column definition - :param template_processor: template_processor instance - :returns: The metric defined as a sqlalchemy column - :rtype: sqlalchemy.sql.column - """ - label = utils.get_column_name(col) # type: ignore - expression = self._process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - sqla_column = literal_column(expression) - return self.make_sqla_column_compatible(sqla_column, label) + raise NotImplementedError() def _get_top_groups( self, @@ -1252,29 +1316,30 @@ def dttm_sql_literal(self, dttm: sa.DateTime, col_type: Optional[str]) -> str: return f'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}' - def get_time_filter( + def get_time_filter( # pylint: disable=too-many-arguments self, - time_col: Dict[str, Any], + time_col: "TableColumn", start_dttm: Optional[sa.DateTime], end_dttm: Optional[sa.DateTime], + label: Optional[str] = "__time", + template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - label = "__time" - col = time_col.get("column_name") - sqla_col = literal_column(col) - my_col = self.make_sqla_column_compatible(sqla_col, label) + col = self.convert_tbl_column_to_sqla_col( + time_col, label=label, template_processor=template_processor + ) l = [] if start_dttm: l.append( - my_col + col >= self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(start_dttm, time_col.get("type")) + self.dttm_sql_literal(start_dttm, time_col.type) ) ) if end_dttm: l.append( - my_col + col < self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(end_dttm, time_col.get("type")) + self.dttm_sql_literal(end_dttm, time_col.type) ) ) return and_(*l) @@ -1338,11 +1403,24 @@ def get_timestamp_expression( time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) return self.make_sqla_column_compatible(time_expr, label) - def get_sqla_col(self, col: Dict[str, Any]) -> Column: - label = col.get("column_name") - col_type = col.get("type") - col = sa.column(label, type_=col_type) - return self.make_sqla_column_compatible(col, label) + def convert_tbl_column_to_sqla_col( + self, + tbl_column: "TableColumn", + label: Optional[str] = None, + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> Column: + label = label or tbl_column.column_name + db_engine_spec = self.db_engine_spec + column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) + type_ = column_spec.sqla_type if column_spec else None + if expression := tbl_column.expression: + if template_processor: + expression = template_processor.process_template(expression) + col = literal_column(expression, type_=type_) + else: + col = sa.column(tbl_column.column_name, type_=type_) + col = self.make_sqla_column_compatible(col, label) + return col def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, @@ -1389,11 +1467,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.get("column_name") for col in self.columns], + "table_columns": [col.column_name for col in self.columns], "filter": filter, } columns = columns or [] groupby = groupby or [] + rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] series_column_names = 
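`convert_tbl_column_to_sqla_col` replaces the old dict-based `get_sqla_col`: a column with a SQL `expression` becomes a `literal_column` (after optional Jinja templating), otherwise a plain named column, and either way the result is labeled. A rough sketch using SQLAlchemy directly; the `TableColumnSketch` dataclass is a stand-in, and the type resolution via `db_engine_spec.get_column_spec` is omitted:

```python
from dataclasses import dataclass
from typing import Optional

from sqlalchemy import column, literal_column
from sqlalchemy.sql import ColumnElement


@dataclass
class TableColumnSketch:  # simplified stand-in for TableColumn
    column_name: str
    expression: Optional[str] = None


def to_sqla_col(tbl_column: TableColumnSketch, label: Optional[str] = None) -> ColumnElement:
    """Physical columns become column(name); computed columns become literal_column(expr)."""
    label = label or tbl_column.column_name
    if tbl_column.expression:
        col = literal_column(tbl_column.expression)  # a labeled SQL expression
    else:
        col = column(tbl_column.column_name)  # a plain named column
    return col.label(label)


print(str(to_sqla_col(TableColumnSketch("num_boys"))))
print(str(to_sqla_col(TableColumnSketch("ratio", "num_boys / num_girls"))))
```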
utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: @@ -1418,8 +1498,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma granularity = self.main_dttm_col columns_by_name: Dict[str, "TableColumn"] = { - col.get("column_name"): col - for col in self.columns # col.column_name: col for col in self.columns + col.column_name: col for col in self.columns + } + + metrics_by_name: Dict[str, "SqlMetric"] = { + m.metric_name: m for m in self.metrics } if not granularity and is_timeseries: @@ -1443,6 +1526,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma template_processor=template_processor, ) ) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append( + metrics_by_name[metric].get_sqla_col( + template_processor=template_processor + ) + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1481,14 +1570,17 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - gb_column_obj = columns_by_name[col] - if isinstance(gb_column_obj, dict): - col = self.get_sqla_col(gb_column_obj) - else: - col = gb_column_obj.get_sqla_col() + col = self.convert_tbl_column_to_sqla_col( + columns_by_name[col], template_processor=template_processor + ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True + elif col in metrics_by_name: + col = metrics_by_name[col].get_sqla_col( + template_processor=template_processor + ) + need_groupby = True if isinstance(col, ColumnElement): orderby_exprs.append(col) @@ -1514,33 +1606,24 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] - if isinstance(table_col, dict): - outer = self.get_timestamp_expression( - column=table_col, - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - else: - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) # if groupby field equals a selected column elif selected in columns_by_name: - if isinstance(columns_by_name[selected], dict): - outer = sa.column(f"{selected}") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = columns_by_name[selected].get_sqla_col() + outer = self.convert_tbl_column_to_sqla_col( + columns_by_name[selected], + template_processor=template_processor, + ) else: - selected = self.validate_adhoc_subquery( + selected = validate_adhoc_subquery( selected, self.database_id, self.schema, ) - outer = sa.column(f"{selected}") + outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( @@ -1554,19 +1637,28 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: - selected = self.validate_adhoc_subquery( - selected, + if is_adhoc_column(selected): + _sql = selected["sqlExpression"] + _column_label = selected["label"] + elif isinstance(selected, str): + _sql = selected + _column_label = 
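`get_sqla_query` on the mixin now resolves metrics given as saved-metric names via the new `metrics_by_name` lookup, in addition to adhoc metric dicts. A plain-Python sketch of the dispatch, with SQL strings standing in for SQLAlchemy expressions and adhoc metrics reduced to a raw `sqlExpression`:

```python
from typing import Any, Dict, List, Union

AdhocMetric = Dict[str, Any]
saved_metrics = {"count": "COUNT(*)", "sum_boys": "SUM(num_boys)"}


def resolve_metric_sql(metric: Union[str, AdhocMetric]) -> str:
    """Adhoc metric dicts carry their own SQL; strings must match a saved metric."""
    if isinstance(metric, dict):  # utils.is_adhoc_metric in the real code
        return metric["sqlExpression"]
    if metric in saved_metrics:
        return saved_metrics[metric]
    raise ValueError(f"Metric '{metric}' does not exist")


print([resolve_metric_sql(m) for m in ["count", {"sqlExpression": "MAX(num)"}]])
```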
selected + + selected = validate_adhoc_subquery( + _sql, self.database_id, self.schema, ) - if isinstance(columns_by_name[selected], dict): - select_exprs.append(sa.column(f"{selected}")) - else: - select_exprs.append( - columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) + + select_exprs.append( + self.convert_tbl_column_to_sqla_col( + columns_by_name[selected], template_processor=template_processor ) + if isinstance(selected, str) and selected in columns_by_name + else self.make_sqla_column_compatible( + literal_column(selected), _column_label + ) + ) metrics_exprs = [] if granularity: @@ -1577,57 +1669,43 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col=granularity, ) ) - time_filters: List[Any] = [] + time_filters = [] if is_timeseries: - if isinstance(dttm_col, dict): - timestamp = self.get_timestamp_expression( - dttm_col, time_grain, template_processor=template_processor - ) - else: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns. - if db_engine_spec.time_secondary_columns: - if isinstance(dttm_col, dict): - dttm_col_name = dttm_col.get("column_name") - else: - dttm_col_name = dttm_col.column_name - - if ( - self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col_name - ): - if isinstance(self.main_dttm_col, dict): - time_filters.append( - self.get_time_filter( - self.main_dttm_col, - from_dttm, - to_dttm, - ) - ) - else: - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - from_dttm, - to_dttm, - ) - ) + if ( + db_engine_spec.time_secondary_columns + and self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col.column_name + ): + time_filters.append( + self.get_time_filter( + time_col=columns_by_name[self.main_dttm_col], + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) - if isinstance(dttm_col, dict): - time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm)) - else: - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + time_filter_column = self.get_time_filter( + time_col=dttm_col, + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + time_filters.append(time_filter_column) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use # raw columns as custom SQL adhoc metric). 
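Time filtering also moves onto the mixin: `get_time_filter` now receives the `TableColumn` itself plus an optional label and template processor, and builds a half-open range (`col >= start AND col < end`). A simplified sketch that binds ISO strings directly rather than going through `dttm_sql_literal` and `get_text_clause`:

```python
from datetime import datetime
from typing import Optional

from sqlalchemy import and_, column
from sqlalchemy.sql import ColumnElement


def time_filter(
    column_name: str,
    start_dttm: Optional[datetime],
    end_dttm: Optional[datetime],
) -> ColumnElement:
    """Half-open time range: inclusive start, exclusive end, like the real helper."""
    col = column(column_name)
    conditions = []
    if start_dttm:
        conditions.append(col >= start_dttm.isoformat())
    if end_dttm:
        conditions.append(col < end_dttm.isoformat())
    return and_(*conditions)


clause = time_filter("ds", datetime(2023, 1, 1), datetime(2023, 2, 1))
print(clause)  # e.g. "ds >= :ds_1 AND ds < :ds_2" with the ISO strings bound
```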
- select_exprs = utils.remove_duplicates( + select_exprs = remove_duplicates( select_exprs + metrics_exprs, key=lambda x: x.name ) @@ -1637,7 +1715,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # Order by columns are "hidden" columns, some databases require them # always be present in SELECT if an aggregation function is used if not db_engine_spec.allows_hidden_orderby_agg: - select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) + select_exprs = remove_duplicates(select_exprs + orderby_exprs) qry = sa.select(select_exprs) @@ -1659,14 +1737,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col - elif utils.is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore + elif is_adhoc_column(flt_col): + try: + sqla_col = self.adhoc_column_to_sqla(flt_col, force_type_check=True) + applied_adhoc_filters_columns.append(flt_col) + except ColumnNotFoundException: + rejected_adhoc_filters_columns.append(flt_col) + continue else: col_obj = columns_by_name.get(cast(str, flt_col)) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if utils.get_column_name(flt_col) in removed_filters: + if get_column_name(flt_col) in removed_filters: # Skip generating SQLA filter when the jinja template handles it. continue @@ -1674,44 +1757,29 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if sqla_col is not None: pass elif col_obj and filter_grain: - if isinstance(col_obj, dict): - sqla_col = self.get_timestamp_expression( - col_obj, time_grain, template_processor=template_processor - ) - else: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, - template_processor=template_processor, - ) - elif col_obj and isinstance(col_obj, dict): - sqla_col = sa.column(col_obj.get("column_name")) + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, template_processor=template_processor + ) elif col_obj: - sqla_col = col_obj.get_sqla_col() - - if col_obj and isinstance(col_obj, dict): - col_type = col_obj.get("type") - else: - col_type = col_obj.type if col_obj else None + sqla_col = self.convert_tbl_column_to_sqla_col( + tbl_column=col_obj, template_processor=template_processor + ) + col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, - db_extra=self.database.get_extra(), # type: ignore + # db_extra=self.database.get_extra(), ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - if col_obj and isinstance(col_obj, dict): - col_advanced_data_type = "" - else: - col_advanced_data_type = ( - col_obj.advanced_data_type if col_obj else "" - ) + col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" if col_spec and not col_advanced_data_type: target_generic_type = col_spec.generic_type else: - target_generic_type = utils.GenericDataType.STRING + target_generic_type = GenericDataType.STRING eq = self.filter_values_handler( values=val, operator=op, @@ -1719,7 +1787,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), # type: ignore + # db_extra=self.database.get_extra(), ) if ( col_advanced_data_type != "" @@ -1775,7 
+1843,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif op == utils.FilterOperator.IS_FALSE.value: where_clause_and.append(sqla_col.is_(False)) else: - if eq is None: + if ( + op + not in { + utils.FilterOperator.EQUALS.value, + utils.FilterOperator.NOT_EQUALS.value, + } + and eq is None + ): raise QueryObjectValidationError( _( "Must specify a value for filters " @@ -1813,19 +1888,20 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_col=col_obj, start_dttm=_since, end_dttm=_until, + label=sqla_col.key, + template_processor=template_processor, ) ) else: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - # todo(hugh): fix this w/ template_processor - # where_clause_and += self.get_sqla_row_level_filters(template_processor) + where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: try: - where = template_processor.process_template(f"{where}") + where = template_processor.process_template(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1833,11 +1909,17 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + where = self._process_sql_expression( + expression=where, + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(f"{having}") + having = template_processor.process_template(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1845,9 +1927,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + having = self._process_sql_expression( + expression=having, + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) having_clause_and += [self.text(having)] + if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore - qry = qry.where(self.get_fetch_values_predicate()) # type: ignore + qry = qry.where( + self.get_fetch_values_predicate(template_processor=template_processor) + ) if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -1887,7 +1978,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_series_columns.items(): - label = utils.get_column_name(gby_name) + label = get_column_name(gby_name) inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) @@ -1897,26 +1988,25 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_time_filter = [] if dttm_col and not db_engine_spec.time_groupby_inline: - if isinstance(dttm_col, dict): - inner_time_filter = [ - self.get_time_filter( - dttm_col, - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - ] - else: - inner_time_filter = [ - dttm_col.get_time_filter( - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - ] - + inner_time_filter = [ + self.get_time_filter( + time_col=dttm_col, + start_dttm=inner_from_dttm or from_dttm, + end_dttm=inner_to_dttm or to_dttm, + template_processor=template_processor, + ) + ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) subq = 
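The extras `WHERE`/`HAVING` snippets are now wrapped in parentheses before templating and are routed through `_process_sql_expression` (sanitization plus adhoc-subquery validation); equality and inequality filters are also allowed to compare against a null value now. A tiny sketch of why the added parentheses matter when AND-ing a user-supplied fragment with generated conditions:

```python
from typing import List, Optional


def build_where_sql(base_conditions: List[str], extra_where: Optional[str]) -> str:
    """AND user-supplied WHERE text with generated conditions, parenthesized so an
    'a = 1 OR b = 2' fragment cannot change the meaning of the whole clause."""
    conditions = list(base_conditions)
    if extra_where:
        conditions.append(f"({extra_where})")
    return " AND ".join(conditions)


print(build_where_sql(["ds >= '2023-01-01'"], "gender = 'boy' OR gender = 'girl'"))
# ds >= '2023-01-01' AND (gender = 'boy' OR gender = 'girl')
```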
subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr + if series_limit_metric: + ob = self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) direction = sa.desc if order_desc else sa.asc subq = subq.order_by(direction(ob)) subq = subq.limit(series_limit) @@ -1930,6 +2020,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma on_clause.append(gby_obj == sa.column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) + else: + if series_limit_metric: + orderby = [ + ( + self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ), + not order_desc, + ) + ] # run prequery to get top groups prequery_obj = { @@ -1946,7 +2049,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "columns": columns, "order_desc": True, } - result = self.exc_query(prequery_obj) + + result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ c @@ -1970,9 +2074,29 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] + filter_columns = [flt.get("col") for flt in filter] if filter else [] + rejected_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and col not in self.column_names + and col not in applied_template_filters + ] + rejected_adhoc_filters_columns + + applied_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and (col in self.column_names or col in applied_template_filters) + ] + applied_adhoc_filters_columns + return SqlaQuery( applied_template_filters=applied_template_filters, cte=cte, + applied_filter_columns=applied_filter_columns, + rejected_filter_columns=rejected_filter_columns, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 976ee177f94e5..87dcba2a81c3f 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -19,7 +19,7 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING import simplejson as json import sqlalchemy as sqla @@ -52,9 +52,10 @@ ) from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sqllab.limiting_factor import LimitingFactor -from superset.utils.core import GenericDataType, QueryStatus, user_label +from superset.utils.core import QueryStatus, user_label if TYPE_CHECKING: + from superset.connectors.sqla.models import TableColumn from superset.db_engine_specs import BaseEngineSpec @@ -183,47 +184,33 @@ def sql_tables(self) -> List[Table]: return list(ParsedQuery(self.sql).tables) @property - def columns(self) -> List[Dict[str, Any]]: - bool_types = ("BOOL",) - num_types = ( - "DOUBLE", - "FLOAT", - "INT", - "BIGINT", - "NUMBER", - "LONG", - "REAL", - "NUMERIC", - "DECIMAL", - "MONEY", + def columns(self) -> List["TableColumn"]: + from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel + TableColumn, ) - date_types = ("DATE", "TIME") - str_types = ("VARCHAR", "STRING", "CHAR") + columns = [] - col_type = "" for 
col in self.extra.get("columns", []): - computed_column = {**col} - col_type = col.get("type") - - if col_type and any(map(lambda t: t in col_type.upper(), str_types)): - computed_column["type_generic"] = GenericDataType.STRING - if col_type and any(map(lambda t: t in col_type.upper(), bool_types)): - computed_column["type_generic"] = GenericDataType.BOOLEAN - if col_type and any(map(lambda t: t in col_type.upper(), num_types)): - computed_column["type_generic"] = GenericDataType.NUMERIC - if col_type and any(map(lambda t: t in col_type.upper(), date_types)): - computed_column["type_generic"] = GenericDataType.TEMPORAL - - computed_column["column_name"] = col.get("name") - computed_column["groupby"] = True - columns.append(computed_column) + columns.append( + TableColumn( + column_name=col["name"], + type=col["type"], + is_dttm=col["is_dttm"], + groupby=True, + filterable=True, + ) + ) return columns + @property + def db_extra(self) -> Optional[Dict[str, Any]]: + return None + @property def data(self) -> Dict[str, Any]: order_by_choices = [] for col in self.columns: - column_name = str(col.get("column_name") or "") + column_name = str(col.column_name or "") order_by_choices.append( (json.dumps([column_name, True]), f"{column_name} " + __("[asc]")) ) @@ -237,7 +224,7 @@ def data(self) -> Dict[str, Any]: ], "filter_select": True, "name": self.tab_name, - "columns": self.columns, + "columns": [o.data for o in self.columns], "metrics": [], "id": self.id, "type": self.type, @@ -280,7 +267,7 @@ def cache_timeout(self) -> int: @property def column_names(self) -> List[Any]: - return [col.get("column_name") for col in self.columns] + return [col.column_name for col in self.columns] @property def offset(self) -> int: @@ -295,7 +282,7 @@ def main_dttm_col(self) -> Optional[str]: @property def dttm_cols(self) -> List[Any]: - return [col.get("column_name") for col in self.columns if col.get("is_dttm")] + return [col.column_name for col in self.columns if col.is_dttm] @property def schema_perm(self) -> str: @@ -310,7 +297,7 @@ def default_endpoint(self) -> str: return "" @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]: return [] @property @@ -338,7 +325,7 @@ def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]: if not column_name: return None for col in self.columns: - if col.get("column_name") == column_name: + if col.column_name == column_name: return col return None diff --git a/superset/utils/core.py b/superset/utils/core.py index 460c17b949dac..8cf1076f5eef3 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1732,13 +1732,13 @@ def extract_dataframe_dtypes( if datasource: for column in datasource.columns: if isinstance(column, dict): - columns_by_name[column.get("column_name")] = column # type: ignore + columns_by_name[column.get("column_name")] = column else: columns_by_name[column.column_name] = column generic_types: List[GenericDataType] = [] for column in df.columns: - column_object = columns_by_name.get(column) # type: ignore + column_object = columns_by_name.get(column) series = df[column] inferred_type = infer_dtype(series) if isinstance(column_object, dict): @@ -1786,15 +1786,9 @@ def get_time_filter_status( datasource: "BaseDatasource", applied_time_extras: Dict[str, str], ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - temporal_columns: Set[Any] - if datasource.type == "query": - temporal_columns = { - col.get("column_name") for col 
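`Query.columns` now returns real `TableColumn` objects built from the result-set metadata stored on the query, so downstream code (`order_by_choices`, `dttm_cols`, `get_time_filter_status`) can use attribute access uniformly for tables and queries. A hedged sketch of that conversion using a minimal stand-in class:

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TableColumnSketch:  # stand-in for superset.connectors.sqla.models.TableColumn
    column_name: str
    type: str
    is_dttm: bool
    groupby: bool = True
    filterable: bool = True


def columns_from_query_extra(extra: Dict[str, Any]) -> List[TableColumnSketch]:
    """Build column objects from the {"columns": [...]} metadata stored on a Query."""
    return [
        TableColumnSketch(
            column_name=col["name"],
            type=col["type"],
            is_dttm=col["is_dttm"],
        )
        for col in extra.get("columns", [])
    ]


cols = columns_from_query_extra(
    {"columns": [{"name": "ds", "type": "TIMESTAMP", "is_dttm": True}]}
)
print([c.column_name for c in cols if c.is_dttm])  # ['ds']
```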
in datasource.columns if col.get("is_dttm")
-        }
-    else:
-        temporal_columns = {
-            col.column_name for col in datasource.columns if col.is_dttm
-        }
+    temporal_columns: Set[Any] = {
+        col.column_name for col in datasource.columns if col.is_dttm
+    }
     applied: List[Dict[str, str]] = []
     rejected: List[Dict[str, str]] = []
     time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
diff --git a/superset/views/core.py b/superset/views/core.py
index 6559db125422e..8df0d4f044ac8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2023,7 +2023,7 @@ def sqllab_viz(self) -> FlaskResponse:  # pylint: disable=no-self-use
         db.session.add(table)
         cols = []
         for config_ in data.get("columns"):
-            column_name = config_.get("name")
+            column_name = config_.get("column_name") or config_.get("name")
             col = TableColumn(
                 column_name=column_name,
                 filterable=True,
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index 821c80ec42083..db81488c3f9d2 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -1242,8 +1242,8 @@ def test_chart_cache_timeout_chart_not_found(
     [
         (200, {"where": "1 = 1"}),
         (200, {"having": "count(*) > 0"}),
-        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
-        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+        (403, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (403, {"having": "count(*) > (select count(*) from physical_dataset)"}),
     ],
 )
 @with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index d9f26239d1394..27ccdde96be29 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -493,8 +493,16 @@ def test_sqllab_viz(self):
             "datasourceName": f"test_viz_flow_table_{random()}",
             "schema": "superset",
             "columns": [
-                {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
-                {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
+                {
+                    "is_dttm": False,
+                    "type": "STRING",
+                    "column_name": f"viz_type_{random()}",
+                },
+                {
+                    "is_dttm": False,
+                    "type": "OBJECT",
+                    "column_name": f"ccount_{random()}",
+                },
             ],
             "sql": """\
                 SELECT *
@@ -523,8 +531,16 @@ def test_sqllab_viz_bad_payload(self):
             "chartType": "dist_bar",
             "schema": "superset",
             "columns": [
-                {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
-                {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
+                {
+                    "is_dttm": False,
+                    "type": "STRING",
+                    "column_name": f"viz_type_{random()}",
+                },
+                {
+                    "is_dttm": False,
+                    "type": "OBJECT",
+                    "column_name": f"ccount_{random()}",
+                },
             ],
             "sql": """\
                 SELECT *