From e386bc426c3080a48d9be1b5fafd8e6fbd84df63 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 20 Jan 2023 17:25:56 +0200 Subject: [PATCH 01/29] re patch sqlatable into exploremixin --- superset/connectors/sqla/models.py | 1308 +++++++++-------- superset/models/helpers.py | 349 +++-- superset/models/sql_lab.py | 3 +- superset/result_set.py | 4 +- .../utils/pandas_postprocessing/boxplot.py | 8 +- .../utils/pandas_postprocessing/flatten.py | 2 +- 6 files changed, 854 insertions(+), 820 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c5fd025f4ee2f..b363188b872df 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -105,7 +105,12 @@ ) from superset.models.annotations import Annotation from superset.models.core import Database -from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult +from superset.models.helpers import ( + AuditMixinNullable, + CertificationMixin, + ExploreMixin, + QueryResult, +) from superset.sql_parse import ParsedQuery, sanitize_clause from superset.superset_typing import ( AdhocColumn, @@ -149,12 +154,13 @@ class SqlaQuery(NamedTuple): prequeries: List[str] sqla_query: Select +from superset.models.helpers import QueryStringExtended -class QueryStringExtended(NamedTuple): - applied_template_filters: Optional[List[str]] - labels_expected: List[str] - prequeries: List[str] - sql: str +# class QueryStringExtended(NamedTuple): +# applied_template_filters: Optional[List[str]] +# labels_expected: List[str] +# prequeries: List[str] +# sql: str @dataclass @@ -534,7 +540,7 @@ def _process_sql_expression( return expression -class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods +class SqlaTable(Model, BaseDatasource, ExploreMixin): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" type = "table" @@ -980,7 +986,7 @@ def adhoc_metric_to_sqla( return self.make_sqla_column_compatible(sqla_metric, label) - def adhoc_column_to_sqla( + def adhoc_column_to_sqla( # type: ignore self, col: AdhocColumn, template_processor: Optional[BaseTemplateProcessor] = None, @@ -1118,649 +1124,649 @@ def get_sqla_row_level_filters( def text(self, clause: str) -> TextClause: return self.db_engine_spec.get_text_clause(clause) - def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - self, - apply_fetch_values_predicate: bool = False, - columns: Optional[List[ColumnTyping]] = None, - extras: Optional[Dict[str, Any]] = None, - filter: Optional[ # pylint: disable=redefined-builtin - List[QueryObjectFilterClause] - ] = None, - from_dttm: Optional[datetime] = None, - granularity: Optional[str] = None, - groupby: Optional[List[Column]] = None, - inner_from_dttm: Optional[datetime] = None, - inner_to_dttm: Optional[datetime] = None, - is_rowcount: bool = False, - is_timeseries: bool = True, - metrics: Optional[List[Metric]] = None, - orderby: Optional[List[OrderBy]] = None, - order_desc: bool = True, - to_dttm: Optional[datetime] = None, - series_columns: Optional[List[Column]] = None, - series_limit: Optional[int] = None, - series_limit_metric: Optional[Metric] = None, - row_limit: Optional[int] = None, - row_offset: Optional[int] = None, - timeseries_limit: Optional[int] = None, - timeseries_limit_metric: Optional[Metric] = None, - time_shift: Optional[str] = None, - ) -> SqlaQuery: - """Querying any sqla table from this common interface""" - if granularity not 
in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - extras = extras or {} - time_grain = extras.get("time_grain_sqla") - - template_kwargs = { - "columns": columns, - "from_dttm": from_dttm.isoformat() if from_dttm else None, - "groupby": groupby, - "metrics": metrics, - "row_limit": row_limit, - "row_offset": row_offset, - "time_column": granularity, - "time_grain": time_grain, - "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.column_name for col in self.columns], - "filter": filter, - } - columns = columns or [] - groupby = groupby or [] - series_column_names = utils.get_column_names(series_columns or []) - # deprecated, to be removed in 2.0 - if is_timeseries and timeseries_limit: - series_limit = timeseries_limit - series_limit_metric = series_limit_metric or timeseries_limit_metric - template_kwargs.update(self.template_params_dict) - extra_cache_keys: List[Any] = [] - template_kwargs["extra_cache_keys"] = extra_cache_keys - removed_filters: List[str] = [] - applied_template_filters: List[str] = [] - template_kwargs["removed_filters"] = removed_filters - template_kwargs["applied_filters"] = applied_template_filters - template_processor = self.get_template_processor(**template_kwargs) - db_engine_spec = self.db_engine_spec - prequeries: List[str] = [] - orderby = orderby or [] - need_groupby = bool(metrics is not None or groupby) - metrics = metrics or [] - - # For backward compatibility - if granularity not in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - columns_by_name: Dict[str, TableColumn] = { - col.column_name: col for col in self.columns - } - - metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} - - if not granularity and is_timeseries: - raise QueryObjectValidationError( - _( - "Datetime column not provided as part table configuration " - "and is required by this type of chart" - ) - ) - if not metrics and not columns and not groupby: - raise QueryObjectValidationError(_("Empty query?")) - - metrics_exprs: List[ColumnElement] = [] - for metric in metrics: - if utils.is_adhoc_metric(metric): - assert isinstance(metric, dict) - metrics_exprs.append( - self.adhoc_metric_to_sqla( - metric=metric, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - ) - elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append( - metrics_by_name[metric].get_sqla_col( - template_processor=template_processor - ) - ) - else: - raise QueryObjectValidationError( - _("Metric '%(metric)s' does not exist", metric=metric) - ) - - if metrics_exprs: - main_metric_expr = metrics_exprs[0] - else: - main_metric_expr, label = literal_column("COUNT(*)"), "ccount" - main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) - - # To ensure correct handling of the ORDER BY labeling we need to reference the - # metric instance if defined in the SELECT clause. 
- # use the key of the ColumnClause for the expected label - metrics_exprs_by_label = {m.key: m for m in metrics_exprs} - metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} - - # Since orderby may use adhoc metrics, too; we need to process them first - orderby_exprs: List[ColumnElement] = [] - for orig_col, ascending in orderby: - col: Union[AdhocMetric, ColumnElement] = orig_col - if isinstance(col, dict): - col = cast(AdhocMetric, col) - if col.get("sqlExpression"): - col["sqlExpression"] = _process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - if utils.is_adhoc_metric(col): - # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) - # if the adhoc metric has been defined before - # use the existing instance. - col = metrics_exprs_by_expr.get(str(col), col) - need_groupby = True - elif col in columns_by_name: - col = columns_by_name[col].get_sqla_col( - template_processor=template_processor - ) - elif col in metrics_exprs_by_label: - col = metrics_exprs_by_label[col] - need_groupby = True - elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col( - template_processor=template_processor - ) - need_groupby = True - - if isinstance(col, ColumnElement): - orderby_exprs.append(col) - else: - # Could not convert a column reference to valid ColumnElement - raise QueryObjectValidationError( - _("Unknown column used in orderby: %(col)s", col=orig_col) - ) - - select_exprs: List[Union[Column, Label]] = [] - groupby_all_columns = {} - groupby_series_columns = {} - - # filter out the pseudo column __timestamp from columns - columns = [col for col in columns if col != utils.DTTM_ALIAS] - dttm_col = columns_by_name.get(granularity) if granularity else None - - if need_groupby: - # dedup columns while preserving order - columns = groupby or columns - for selected in columns: - if isinstance(selected, str): - # if groupby field/expr equals granularity field/expr - if selected == granularity: - table_col = columns_by_name[selected] - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - # if groupby field equals a selected column - elif selected in columns_by_name: - outer = columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - else: - selected = validate_adhoc_subquery( - selected, - self.database_id, - self.schema, - ) - outer = literal_column(f"({selected})") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor - ) - groupby_all_columns[outer.name] = outer - if ( - is_timeseries and not series_column_names - ) or outer.name in series_column_names: - groupby_series_columns[outer.name] = outer - select_exprs.append(outer) - elif columns: - for selected in columns: - if is_adhoc_column(selected): - _sql = selected["sqlExpression"] - _column_label = selected["label"] - elif isinstance(selected, str): - _sql = selected - _column_label = selected - - selected = validate_adhoc_subquery( - _sql, - self.database_id, - self.schema, - ) - select_exprs.append( - columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - if isinstance(selected, str) and selected in columns_by_name - else self.make_sqla_column_compatible( - literal_column(selected), _column_label - ) - ) - metrics_exprs = 
[] - - if granularity: - if granularity not in columns_by_name or not dttm_col: - raise QueryObjectValidationError( - _( - 'Time column "%(col)s" does not exist in dataset', - col=granularity, - ) - ) - time_filters = [] - - if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) - # always put timestamp as the first column - select_exprs.insert(0, timestamp) - groupby_all_columns[timestamp.name] = timestamp - - # Use main dttm column to support index with secondary dttm columns. - if ( - db_engine_spec.time_secondary_columns - and self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col.column_name - ): - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - time_filters.append( - dttm_col.get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - - # Always remove duplicates by column name, as sometimes `metrics_exprs` - # can have the same name as a groupby column (e.g. when users use - # raw columns as custom SQL adhoc metric). - select_exprs = remove_duplicates( - select_exprs + metrics_exprs, key=lambda x: x.name - ) - - # Expected output columns - labels_expected = [c.key for c in select_exprs] - - # Order by columns are "hidden" columns, some databases require them - # always be present in SELECT if an aggregation function is used - if not db_engine_spec.allows_hidden_ordeby_agg: - select_exprs = remove_duplicates(select_exprs + orderby_exprs) - - qry = sa.select(select_exprs) - - tbl, cte = self.get_from_clause(template_processor) - - if groupby_all_columns: - qry = qry.group_by(*groupby_all_columns.values()) - - where_clause_and = [] - having_clause_and = [] - - for flt in filter: # type: ignore - if not all(flt.get(s) for s in ["col", "op"]): - continue - flt_col = flt["col"] - val = flt.get("val") - op = flt["op"].upper() - col_obj: Optional[TableColumn] = None - sqla_col: Optional[Column] = None - if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: - col_obj = dttm_col - elif is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) - else: - col_obj = columns_by_name.get(flt_col) - filter_grain = flt.get("grain") - - if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if get_column_name(flt_col) in removed_filters: - # Skip generating SQLA filter when the jinja template handles it. 
- continue - - if col_obj or sqla_col is not None: - if sqla_col is not None: - pass - elif col_obj and filter_grain: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, template_processor=template_processor - ) - elif col_obj: - sqla_col = col_obj.get_sqla_col( - template_processor=template_processor - ) - col_type = col_obj.type if col_obj else None - col_spec = db_engine_spec.get_column_spec( - native_type=col_type, - db_extra=self.database.get_extra(), - ) - is_list_target = op in ( - utils.FilterOperator.IN.value, - utils.FilterOperator.NOT_IN.value, - ) - - col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" - - if col_spec and not col_advanced_data_type: - target_generic_type = col_spec.generic_type - else: - target_generic_type = GenericDataType.STRING - eq = self.filter_values_handler( - values=val, - operator=op, - target_generic_type=target_generic_type, - target_native_type=col_type, - is_list_target=is_list_target, - db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), - ) - if ( - col_advanced_data_type != "" - and feature_flag_manager.is_feature_enabled( - "ENABLE_ADVANCED_DATA_TYPES" - ) - and col_advanced_data_type in ADVANCED_DATA_TYPES - ): - values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( - { - "type": col_advanced_data_type, - "values": values, - } - ) - if bus_resp["error_message"]: - raise AdvancedDataTypeResponseError( - _(bus_resp["error_message"]) - ) - - where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) - ) - elif is_list_target: - assert isinstance(eq, (tuple, list)) - if len(eq) == 0: - raise QueryObjectValidationError( - _("Filter value list cannot be empty") - ) - if len(eq) > len( - eq_without_none := [x for x in eq if x is not None] - ): - is_null_cond = sqla_col.is_(None) - if eq: - cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) - else: - cond = is_null_cond - else: - cond = sqla_col.in_(eq) - if op == utils.FilterOperator.NOT_IN.value: - cond = ~cond - where_clause_and.append(cond) - elif op == utils.FilterOperator.IS_NULL.value: - where_clause_and.append(sqla_col.is_(None)) - elif op == utils.FilterOperator.IS_NOT_NULL.value: - where_clause_and.append(sqla_col.isnot(None)) - elif op == utils.FilterOperator.IS_TRUE.value: - where_clause_and.append(sqla_col.is_(True)) - elif op == utils.FilterOperator.IS_FALSE.value: - where_clause_and.append(sqla_col.is_(False)) - else: - if ( - op - not in { - utils.FilterOperator.EQUALS.value, - utils.FilterOperator.NOT_EQUALS.value, - } - and eq is None - ): - raise QueryObjectValidationError( - _( - "Must specify a value for filters " - "with comparison operators" - ) - ) - if op == utils.FilterOperator.EQUALS.value: - where_clause_and.append(sqla_col == eq) - elif op == utils.FilterOperator.NOT_EQUALS.value: - where_clause_and.append(sqla_col != eq) - elif op == utils.FilterOperator.GREATER_THAN.value: - where_clause_and.append(sqla_col > eq) - elif op == utils.FilterOperator.LESS_THAN.value: - where_clause_and.append(sqla_col < eq) - elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col >= eq) - elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col <= eq) - elif op == utils.FilterOperator.LIKE.value: - where_clause_and.append(sqla_col.like(eq)) - elif op == 
utils.FilterOperator.ILIKE.value: - where_clause_and.append(sqla_col.ilike(eq)) - elif ( - op == utils.FilterOperator.TEMPORAL_RANGE.value - and isinstance(eq, str) - and col_obj is not None - ): - _since, _until = get_since_until_from_time_range( - time_range=eq, - time_shift=time_shift, - extras=extras, - ) - where_clause_and.append( - col_obj.get_time_filter( - start_dttm=_since, - end_dttm=_until, - label=sqla_col.key, - template_processor=template_processor, - ) - ) - else: - raise QueryObjectValidationError( - _("Invalid filter operation type: %(op)s", op=op) - ) - where_clause_and += self.get_sqla_row_level_filters(template_processor) - if extras: - where = extras.get("where") - if where: - try: - where = template_processor.process_template(f"({where})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in WHERE clause: %(msg)s", - msg=ex.message, - ) - ) from ex - where = _process_sql_expression( - expression=where, - database_id=self.database_id, - schema=self.schema, - ) - where_clause_and += [self.text(where)] - having = extras.get("having") - if having: - try: - having = template_processor.process_template(f"({having})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in HAVING clause: %(msg)s", - msg=ex.message, - ) - ) from ex - having = _process_sql_expression( - expression=having, - database_id=self.database_id, - schema=self.schema, - ) - having_clause_and += [self.text(having)] - - if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) - ) - if granularity: - qry = qry.where(and_(*(time_filters + where_clause_and))) - else: - qry = qry.where(and_(*where_clause_and)) - qry = qry.having(and_(*having_clause_and)) - - self.make_orderby_compatible(select_exprs, orderby_exprs) - - for col, (orig_col, ascending) in zip(orderby_exprs, orderby): - if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): - # if engine does not allow using SELECT alias in ORDER BY - # revert to the underlying column - col = col.element - - if ( - db_engine_spec.allows_alias_in_select - and db_engine_spec.allows_hidden_cc_in_orderby - and col.name in [select_col.name for select_col in select_exprs] - ): - col = literal_column(col.name) - direction = asc if ascending else desc - qry = qry.order_by(direction(col)) - - if row_limit: - qry = qry.limit(row_limit) - if row_offset: - qry = qry.offset(row_offset) - - if series_limit and groupby_series_columns: - if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: - # some sql dialects require for order by expressions - # to also be in the select clause -- others, e.g. 
vertica, - # require a unique inner alias - inner_main_metric_expr = self.make_sqla_column_compatible( - main_metric_expr, "mme_inner__" - ) - inner_groupby_exprs = [] - inner_select_exprs = [] - for gby_name, gby_obj in groupby_series_columns.items(): - label = get_column_name(gby_name) - inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") - inner_groupby_exprs.append(inner) - inner_select_exprs.append(inner) - - inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs).select_from(tbl) - inner_time_filter = [] - - if dttm_col and not db_engine_spec.time_groupby_inline: - inner_time_filter = [ - dttm_col.get_time_filter( - start_dttm=inner_from_dttm or from_dttm, - end_dttm=inner_to_dttm or to_dttm, - template_processor=template_processor, - ) - ] - subq = subq.where(and_(*(where_clause_and + inner_time_filter))) - subq = subq.group_by(*inner_groupby_exprs) - - ob = inner_main_metric_expr - if series_limit_metric: - ob = self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - direction = desc if order_desc else asc - subq = subq.order_by(direction(ob)) - subq = subq.limit(series_limit) - - on_clause = [] - for gby_name, gby_obj in groupby_series_columns.items(): - # in this case the column name, not the alias, needs to be - # conditionally mutated, as it refers to the column alias in - # the inner query - col_name = db_engine_spec.make_label_compatible(gby_name + "__") - on_clause.append(gby_obj == column(col_name)) - - tbl = tbl.join(subq.alias(), and_(*on_clause)) - else: - if series_limit_metric: - orderby = [ - ( - self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ), - not order_desc, - ) - ] - - # run prequery to get top groups - prequery_obj = { - "is_timeseries": False, - "row_limit": series_limit, - "metrics": metrics, - "granularity": granularity, - "groupby": groupby, - "from_dttm": inner_from_dttm or from_dttm, - "to_dttm": inner_to_dttm or to_dttm, - "filter": filter, - "orderby": orderby, - "extras": extras, - "columns": columns, - "order_desc": True, - } - - result = self.query(prequery_obj) - prequeries.append(result.query) - dimensions = [ - c - for c in result.df.columns - if c not in metrics and c in groupby_series_columns - ] - top_groups = self._get_top_groups( - result.df, dimensions, groupby_series_columns, columns_by_name - ) - qry = qry.where(top_groups) - - qry = qry.select_from(tbl) - - if is_rowcount: - if not db_engine_spec.allows_subqueries: - raise QueryObjectValidationError( - _("Database does not support subqueries") - ) - label = "rowcount" - col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = select([col]).select_from(qry.alias("rowcount_qry")) - labels_expected = [label] - - return SqlaQuery( - applied_template_filters=applied_template_filters, - cte=cte, - extra_cache_keys=extra_cache_keys, - labels_expected=labels_expected, - sqla_query=qry, - prequeries=prequeries, - ) + # def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + # self, + # apply_fetch_values_predicate: bool = False, + # columns: Optional[List[ColumnTyping]] = None, + # extras: Optional[Dict[str, Any]] = None, + # filter: Optional[ # pylint: disable=redefined-builtin + # List[QueryObjectFilterClause] + # 
] = None, + # from_dttm: Optional[datetime] = None, + # granularity: Optional[str] = None, + # groupby: Optional[List[Column]] = None, + # inner_from_dttm: Optional[datetime] = None, + # inner_to_dttm: Optional[datetime] = None, + # is_rowcount: bool = False, + # is_timeseries: bool = True, + # metrics: Optional[List[Metric]] = None, + # orderby: Optional[List[OrderBy]] = None, + # order_desc: bool = True, + # to_dttm: Optional[datetime] = None, + # series_columns: Optional[List[Column]] = None, + # series_limit: Optional[int] = None, + # series_limit_metric: Optional[Metric] = None, + # row_limit: Optional[int] = None, + # row_offset: Optional[int] = None, + # timeseries_limit: Optional[int] = None, + # timeseries_limit_metric: Optional[Metric] = None, + # time_shift: Optional[str] = None, + # ) -> SqlaQuery: + # """Querying any sqla table from this common interface""" + # if granularity not in self.dttm_cols and granularity is not None: + # granularity = self.main_dttm_col + + # extras = extras or {} + # time_grain = extras.get("time_grain_sqla") + + # template_kwargs = { + # "columns": columns, + # "from_dttm": from_dttm.isoformat() if from_dttm else None, + # "groupby": groupby, + # "metrics": metrics, + # "row_limit": row_limit, + # "row_offset": row_offset, + # "time_column": granularity, + # "time_grain": time_grain, + # "to_dttm": to_dttm.isoformat() if to_dttm else None, + # "table_columns": [col.column_name for col in self.columns], + # "filter": filter, + # } + # columns = columns or [] + # groupby = groupby or [] + # series_column_names = utils.get_column_names(series_columns or []) + # # deprecated, to be removed in 2.0 + # if is_timeseries and timeseries_limit: + # series_limit = timeseries_limit + # series_limit_metric = series_limit_metric or timeseries_limit_metric + # template_kwargs.update(self.template_params_dict) + # extra_cache_keys: List[Any] = [] + # template_kwargs["extra_cache_keys"] = extra_cache_keys + # removed_filters: List[str] = [] + # applied_template_filters: List[str] = [] + # template_kwargs["removed_filters"] = removed_filters + # template_kwargs["applied_filters"] = applied_template_filters + # template_processor = self.get_template_processor(**template_kwargs) + # db_engine_spec = self.db_engine_spec + # prequeries: List[str] = [] + # orderby = orderby or [] + # need_groupby = bool(metrics is not None or groupby) + # metrics = metrics or [] + + # # For backward compatibility + # if granularity not in self.dttm_cols and granularity is not None: + # granularity = self.main_dttm_col + + # columns_by_name: Dict[str, TableColumn] = { + # col.column_name: col for col in self.columns + # } + + # metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + + # if not granularity and is_timeseries: + # raise QueryObjectValidationError( + # _( + # "Datetime column not provided as part table configuration " + # "and is required by this type of chart" + # ) + # ) + # if not metrics and not columns and not groupby: + # raise QueryObjectValidationError(_("Empty query?")) + + # metrics_exprs: List[ColumnElement] = [] + # for metric in metrics: + # if utils.is_adhoc_metric(metric): + # assert isinstance(metric, dict) + # metrics_exprs.append( + # self.adhoc_metric_to_sqla( + # metric=metric, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ) + # ) + # elif isinstance(metric, str) and metric in metrics_by_name: + # metrics_exprs.append( + # metrics_by_name[metric].get_sqla_col( + # 
template_processor=template_processor + # ) + # ) + # else: + # raise QueryObjectValidationError( + # _("Metric '%(metric)s' does not exist", metric=metric) + # ) + + # if metrics_exprs: + # main_metric_expr = metrics_exprs[0] + # else: + # main_metric_expr, label = literal_column("COUNT(*)"), "ccount" + # main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) + + # # To ensure correct handling of the ORDER BY labeling we need to reference the + # # metric instance if defined in the SELECT clause. + # # use the key of the ColumnClause for the expected label + # metrics_exprs_by_label = {m.key: m for m in metrics_exprs} + # metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} + + # # Since orderby may use adhoc metrics, too; we need to process them first + # orderby_exprs: List[ColumnElement] = [] + # for orig_col, ascending in orderby: + # col: Union[AdhocMetric, ColumnElement] = orig_col + # if isinstance(col, dict): + # col = cast(AdhocMetric, col) + # if col.get("sqlExpression"): + # col["sqlExpression"] = _process_sql_expression( + # expression=col["sqlExpression"], + # database_id=self.database_id, + # schema=self.schema, + # template_processor=template_processor, + # ) + # if utils.is_adhoc_metric(col): + # # add adhoc sort by column to columns_by_name if not exists + # col = self.adhoc_metric_to_sqla(col, columns_by_name) + # # if the adhoc metric has been defined before + # # use the existing instance. + # col = metrics_exprs_by_expr.get(str(col), col) + # need_groupby = True + # elif col in columns_by_name: + # col = columns_by_name[col].get_sqla_col( + # template_processor=template_processor + # ) + # elif col in metrics_exprs_by_label: + # col = metrics_exprs_by_label[col] + # need_groupby = True + # elif col in metrics_by_name: + # col = metrics_by_name[col].get_sqla_col( + # template_processor=template_processor + # ) + # need_groupby = True + + # if isinstance(col, ColumnElement): + # orderby_exprs.append(col) + # else: + # # Could not convert a column reference to valid ColumnElement + # raise QueryObjectValidationError( + # _("Unknown column used in orderby: %(col)s", col=orig_col) + # ) + + # select_exprs: List[Union[Column, Label]] = [] + # groupby_all_columns = {} + # groupby_series_columns = {} + + # # filter out the pseudo column __timestamp from columns + # columns = [col for col in columns if col != utils.DTTM_ALIAS] + # dttm_col = columns_by_name.get(granularity) if granularity else None + + # if need_groupby: + # # dedup columns while preserving order + # columns = groupby or columns + # for selected in columns: + # if isinstance(selected, str): + # # if groupby field/expr equals granularity field/expr + # if selected == granularity: + # table_col = columns_by_name[selected] + # outer = table_col.get_timestamp_expression( + # time_grain=time_grain, + # label=selected, + # template_processor=template_processor, + # ) + # # if groupby field equals a selected column + # elif selected in columns_by_name: + # outer = columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) + # else: + # selected = validate_adhoc_subquery( + # selected, + # self.database_id, + # self.schema, + # ) + # outer = literal_column(f"({selected})") + # outer = self.make_sqla_column_compatible(outer, selected) + # else: + # outer = self.adhoc_column_to_sqla( + # col=selected, template_processor=template_processor + # ) + # groupby_all_columns[outer.name] = outer + # if ( + # is_timeseries and not series_column_names + # ) or outer.name in 
series_column_names: + # groupby_series_columns[outer.name] = outer + # select_exprs.append(outer) + # elif columns: + # for selected in columns: + # if is_adhoc_column(selected): + # _sql = selected["sqlExpression"] + # _column_label = selected["label"] + # elif isinstance(selected, str): + # _sql = selected + # _column_label = selected + + # selected = validate_adhoc_subquery( + # _sql, + # self.database_id, + # self.schema, + # ) + # select_exprs.append( + # columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) + # if isinstance(selected, str) and selected in columns_by_name + # else self.make_sqla_column_compatible( + # literal_column(selected), _column_label + # ) + # ) + # metrics_exprs = [] + + # if granularity: + # if granularity not in columns_by_name or not dttm_col: + # raise QueryObjectValidationError( + # _( + # 'Time column "%(col)s" does not exist in dataset', + # col=granularity, + # ) + # ) + # time_filters = [] + + # if is_timeseries: + # timestamp = dttm_col.get_timestamp_expression( + # time_grain=time_grain, template_processor=template_processor + # ) + # # always put timestamp as the first column + # select_exprs.insert(0, timestamp) + # groupby_all_columns[timestamp.name] = timestamp + + # # Use main dttm column to support index with secondary dttm columns. + # if ( + # db_engine_spec.time_secondary_columns + # and self.main_dttm_col in self.dttm_cols + # and self.main_dttm_col != dttm_col.column_name + # ): + # time_filters.append( + # columns_by_name[self.main_dttm_col].get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + # ) + # time_filters.append( + # dttm_col.get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + # ) + + # # Always remove duplicates by column name, as sometimes `metrics_exprs` + # # can have the same name as a groupby column (e.g. when users use + # # raw columns as custom SQL adhoc metric). + # select_exprs = remove_duplicates( + # select_exprs + metrics_exprs, key=lambda x: x.name + # ) + + # # Expected output columns + # labels_expected = [c.key for c in select_exprs] + + # # Order by columns are "hidden" columns, some databases require them + # # always be present in SELECT if an aggregation function is used + # if not db_engine_spec.allows_hidden_ordeby_agg: + # select_exprs = remove_duplicates(select_exprs + orderby_exprs) + + # qry = sa.select(select_exprs) + + # tbl, cte = self.get_from_clause(template_processor) + + # if groupby_all_columns: + # qry = qry.group_by(*groupby_all_columns.values()) + + # where_clause_and = [] + # having_clause_and = [] + + # for flt in filter: # type: ignore + # if not all(flt.get(s) for s in ["col", "op"]): + # continue + # flt_col = flt["col"] + # val = flt.get("val") + # op = flt["op"].upper() + # col_obj: Optional[TableColumn] = None + # sqla_col: Optional[Column] = None + # if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: + # col_obj = dttm_col + # elif is_adhoc_column(flt_col): + # sqla_col = self.adhoc_column_to_sqla(flt_col) + # else: + # col_obj = columns_by_name.get(flt_col) + # filter_grain = flt.get("grain") + + # if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): + # if get_column_name(flt_col) in removed_filters: + # # Skip generating SQLA filter when the jinja template handles it. 
+ # continue + + # if col_obj or sqla_col is not None: + # if sqla_col is not None: + # pass + # elif col_obj and filter_grain: + # sqla_col = col_obj.get_timestamp_expression( + # time_grain=filter_grain, template_processor=template_processor + # ) + # elif col_obj: + # sqla_col = col_obj.get_sqla_col( + # template_processor=template_processor + # ) + # col_type = col_obj.type if col_obj else None + # col_spec = db_engine_spec.get_column_spec( + # native_type=col_type, + # db_extra=self.database.get_extra(), + # ) + # is_list_target = op in ( + # utils.FilterOperator.IN.value, + # utils.FilterOperator.NOT_IN.value, + # ) + + # col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" + + # if col_spec and not col_advanced_data_type: + # target_generic_type = col_spec.generic_type + # else: + # target_generic_type = GenericDataType.STRING + # eq = self.filter_values_handler( + # values=val, + # operator=op, + # target_generic_type=target_generic_type, + # target_native_type=col_type, + # is_list_target=is_list_target, + # db_engine_spec=db_engine_spec, + # db_extra=self.database.get_extra(), + # ) + # if ( + # col_advanced_data_type != "" + # and feature_flag_manager.is_feature_enabled( + # "ENABLE_ADVANCED_DATA_TYPES" + # ) + # and col_advanced_data_type in ADVANCED_DATA_TYPES + # ): + # values = eq if is_list_target else [eq] # type: ignore + # bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ + # col_advanced_data_type + # ].translate_type( + # { + # "type": col_advanced_data_type, + # "values": values, + # } + # ) + # if bus_resp["error_message"]: + # raise AdvancedDataTypeResponseError( + # _(bus_resp["error_message"]) + # ) + + # where_clause_and.append( + # ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( + # sqla_col, op, bus_resp["values"] + # ) + # ) + # elif is_list_target: + # assert isinstance(eq, (tuple, list)) + # if len(eq) == 0: + # raise QueryObjectValidationError( + # _("Filter value list cannot be empty") + # ) + # if len(eq) > len( + # eq_without_none := [x for x in eq if x is not None] + # ): + # is_null_cond = sqla_col.is_(None) + # if eq: + # cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) + # else: + # cond = is_null_cond + # else: + # cond = sqla_col.in_(eq) + # if op == utils.FilterOperator.NOT_IN.value: + # cond = ~cond + # where_clause_and.append(cond) + # elif op == utils.FilterOperator.IS_NULL.value: + # where_clause_and.append(sqla_col.is_(None)) + # elif op == utils.FilterOperator.IS_NOT_NULL.value: + # where_clause_and.append(sqla_col.isnot(None)) + # elif op == utils.FilterOperator.IS_TRUE.value: + # where_clause_and.append(sqla_col.is_(True)) + # elif op == utils.FilterOperator.IS_FALSE.value: + # where_clause_and.append(sqla_col.is_(False)) + # else: + # if ( + # op + # not in { + # utils.FilterOperator.EQUALS.value, + # utils.FilterOperator.NOT_EQUALS.value, + # } + # and eq is None + # ): + # raise QueryObjectValidationError( + # _( + # "Must specify a value for filters " + # "with comparison operators" + # ) + # ) + # if op == utils.FilterOperator.EQUALS.value: + # where_clause_and.append(sqla_col == eq) + # elif op == utils.FilterOperator.NOT_EQUALS.value: + # where_clause_and.append(sqla_col != eq) + # elif op == utils.FilterOperator.GREATER_THAN.value: + # where_clause_and.append(sqla_col > eq) + # elif op == utils.FilterOperator.LESS_THAN.value: + # where_clause_and.append(sqla_col < eq) + # elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: + # where_clause_and.append(sqla_col >= eq) + # elif 
op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: + # where_clause_and.append(sqla_col <= eq) + # elif op == utils.FilterOperator.LIKE.value: + # where_clause_and.append(sqla_col.like(eq)) + # elif op == utils.FilterOperator.ILIKE.value: + # where_clause_and.append(sqla_col.ilike(eq)) + # elif ( + # op == utils.FilterOperator.TEMPORAL_RANGE.value + # and isinstance(eq, str) + # and col_obj is not None + # ): + # _since, _until = get_since_until_from_time_range( + # time_range=eq, + # time_shift=time_shift, + # extras=extras, + # ) + # where_clause_and.append( + # col_obj.get_time_filter( + # start_dttm=_since, + # end_dttm=_until, + # label=sqla_col.key, + # template_processor=template_processor, + # ) + # ) + # else: + # raise QueryObjectValidationError( + # _("Invalid filter operation type: %(op)s", op=op) + # ) + # where_clause_and += self.get_sqla_row_level_filters(template_processor) + # if extras: + # where = extras.get("where") + # if where: + # try: + # where = template_processor.process_template(f"({where})") + # except TemplateError as ex: + # raise QueryObjectValidationError( + # _( + # "Error in jinja expression in WHERE clause: %(msg)s", + # msg=ex.message, + # ) + # ) from ex + # where = _process_sql_expression( + # expression=where, + # database_id=self.database_id, + # schema=self.schema, + # ) + # where_clause_and += [self.text(where)] + # having = extras.get("having") + # if having: + # try: + # having = template_processor.process_template(f"({having})") + # except TemplateError as ex: + # raise QueryObjectValidationError( + # _( + # "Error in jinja expression in HAVING clause: %(msg)s", + # msg=ex.message, + # ) + # ) from ex + # having = _process_sql_expression( + # expression=having, + # database_id=self.database_id, + # schema=self.schema, + # ) + # having_clause_and += [self.text(having)] + + # if apply_fetch_values_predicate and self.fetch_values_predicate: + # qry = qry.where( + # self.get_fetch_values_predicate(template_processor=template_processor) + # ) + # if granularity: + # qry = qry.where(and_(*(time_filters + where_clause_and))) + # else: + # qry = qry.where(and_(*where_clause_and)) + # qry = qry.having(and_(*having_clause_and)) + + # self.make_orderby_compatible(select_exprs, orderby_exprs) + + # for col, (orig_col, ascending) in zip(orderby_exprs, orderby): + # if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): + # # if engine does not allow using SELECT alias in ORDER BY + # # revert to the underlying column + # col = col.element + + # if ( + # db_engine_spec.allows_alias_in_select + # and db_engine_spec.allows_hidden_cc_in_orderby + # and col.name in [select_col.name for select_col in select_exprs] + # ): + # col = literal_column(col.name) + # direction = asc if ascending else desc + # qry = qry.order_by(direction(col)) + + # if row_limit: + # qry = qry.limit(row_limit) + # if row_offset: + # qry = qry.offset(row_offset) + + # if series_limit and groupby_series_columns: + # if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: + # # some sql dialects require for order by expressions + # # to also be in the select clause -- others, e.g. 
vertica, + # # require a unique inner alias + # inner_main_metric_expr = self.make_sqla_column_compatible( + # main_metric_expr, "mme_inner__" + # ) + # inner_groupby_exprs = [] + # inner_select_exprs = [] + # for gby_name, gby_obj in groupby_series_columns.items(): + # label = get_column_name(gby_name) + # inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") + # inner_groupby_exprs.append(inner) + # inner_select_exprs.append(inner) + + # inner_select_exprs += [inner_main_metric_expr] + # subq = select(inner_select_exprs).select_from(tbl) + # inner_time_filter = [] + + # if dttm_col and not db_engine_spec.time_groupby_inline: + # inner_time_filter = [ + # dttm_col.get_time_filter( + # start_dttm=inner_from_dttm or from_dttm, + # end_dttm=inner_to_dttm or to_dttm, + # template_processor=template_processor, + # ) + # ] + # subq = subq.where(and_(*(where_clause_and + inner_time_filter))) + # subq = subq.group_by(*inner_groupby_exprs) + + # ob = inner_main_metric_expr + # if series_limit_metric: + # ob = self._get_series_orderby( + # series_limit_metric=series_limit_metric, + # metrics_by_name=metrics_by_name, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ) + # direction = desc if order_desc else asc + # subq = subq.order_by(direction(ob)) + # subq = subq.limit(series_limit) + + # on_clause = [] + # for gby_name, gby_obj in groupby_series_columns.items(): + # # in this case the column name, not the alias, needs to be + # # conditionally mutated, as it refers to the column alias in + # # the inner query + # col_name = db_engine_spec.make_label_compatible(gby_name + "__") + # on_clause.append(gby_obj == column(col_name)) + + # tbl = tbl.join(subq.alias(), and_(*on_clause)) + # else: + # if series_limit_metric: + # orderby = [ + # ( + # self._get_series_orderby( + # series_limit_metric=series_limit_metric, + # metrics_by_name=metrics_by_name, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ), + # not order_desc, + # ) + # ] + + # # run prequery to get top groups + # prequery_obj = { + # "is_timeseries": False, + # "row_limit": series_limit, + # "metrics": metrics, + # "granularity": granularity, + # "groupby": groupby, + # "from_dttm": inner_from_dttm or from_dttm, + # "to_dttm": inner_to_dttm or to_dttm, + # "filter": filter, + # "orderby": orderby, + # "extras": extras, + # "columns": columns, + # "order_desc": True, + # } + + # result = self.query(prequery_obj) + # prequeries.append(result.query) + # dimensions = [ + # c + # for c in result.df.columns + # if c not in metrics and c in groupby_series_columns + # ] + # top_groups = self._get_top_groups( + # result.df, dimensions, groupby_series_columns, columns_by_name + # ) + # qry = qry.where(top_groups) + + # qry = qry.select_from(tbl) + + # if is_rowcount: + # if not db_engine_spec.allows_subqueries: + # raise QueryObjectValidationError( + # _("Database does not support subqueries") + # ) + # label = "rowcount" + # col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) + # qry = select([col]).select_from(qry.alias("rowcount_qry")) + # labels_expected = [label] + + # return SqlaQuery( + # applied_template_filters=applied_template_filters, + # cte=cte, + # extra_cache_keys=extra_cache_keys, + # labels_expected=labels_expected, + # sqla_query=qry, + # prequeries=prequeries, + # ) def _get_series_orderby( self, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index fd0a1eff5ca7e..26b07c6e54330 100644 --- 
a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -26,8 +26,8 @@ Any, cast, Dict, + Hashable, List, - Mapping, NamedTuple, Optional, Set, @@ -87,7 +87,13 @@ QueryObjectDict, ) from superset.utils import core as utils -from superset.utils.core import get_user_id +from superset.utils.core import ( + GenericDataType, + get_column_name, + get_user_id, + is_adhoc_column, + remove_duplicates, +) if TYPE_CHECKING: from superset.connectors.sqla.models import SqlMetric, TableColumn @@ -680,7 +686,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods } @property - def query(self) -> str: + def fetch_value_predicate(self) -> str: + return "fix this!" + + def query(self, query_obj: QueryObjectDict) -> QueryResult: raise NotImplementedError() @property @@ -747,13 +756,18 @@ def columns(self) -> List[Any]: def get_fetch_values_predicate(self) -> List[Any]: raise NotImplementedError() - @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: raise NotImplementedError() + def get_sqla_row_level_filters( + self, + template_processor: BaseTemplateProcessor, + ) -> List[TextClause]: + raise NotImplementedError() + def _process_sql_expression( # pylint: disable=no-self-use self, expression: Optional[str], @@ -1156,13 +1170,14 @@ def get_query_str(self, query_obj: QueryObjectDict) -> str: def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Mapping[str, "SqlMetric"], - columns_by_name: Mapping[str, "TableColumn"], + metrics_by_name: Dict[str, "SqlMetric"], + columns_by_name: Dict[str, "TableColumn"], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) ob = self.adhoc_metric_to_sqla( - series_limit_metric, columns_by_name # type: ignore + series_limit_metric, columns_by_name ) elif ( isinstance(series_limit_metric, str) @@ -1180,23 +1195,24 @@ def adhoc_column_to_sqla( col: Type["AdhocColumn"], # type: ignore template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - """ - Turn an adhoc column into a sqlalchemy column. - - :param col: Adhoc column definition - :param template_processor: template_processor instance - :returns: The metric defined as a sqlalchemy column - :rtype: sqlalchemy.sql.column - """ - label = utils.get_column_name(col) # type: ignore - expression = self._process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - sqla_column = literal_column(expression) - return self.make_sqla_column_compatible(sqla_column, label) + raise NotImplementedError() + # """ + # Turn an adhoc column into a sqlalchemy column. 
+ + # :param col: Adhoc column definition + # :param template_processor: template_processor instance + # :returns: The metric defined as a sqlalchemy column + # :rtype: sqlalchemy.sql.column + # """ + # label = utils.get_column_name(col) # type: ignore + # expression = self._process_sql_expression( + # expression=col["sqlExpression"], + # database_id=self.database_id, + # schema=self.schema, + # template_processor=template_processor, + # ) + # sqla_column = literal_column(expression) + # return self.make_sqla_column_compatible(sqla_column, label) def _get_top_groups( self, @@ -1371,7 +1387,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.get("column_name") for col in self.columns], + "table_columns": [col.column_name for col in self.columns], "filter": filter, } columns = columns or [] @@ -1399,11 +1415,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { - col.get("column_name"): col - for col in self.columns # col.column_name: col for col in self.columns + columns_by_name: Dict[str, TableColumn] = { + col.column_name: col for col in self.columns } + metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + if not granularity and is_timeseries: raise QueryObjectValidationError( _( @@ -1425,6 +1442,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma template_processor=template_processor, ) ) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append( + metrics_by_name[metric].get_sqla_col( + template_processor=template_processor + ) + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1463,14 +1486,17 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - gb_column_obj = columns_by_name[col] - if isinstance(gb_column_obj, dict): - col = self.get_sqla_col(gb_column_obj) - else: - col = gb_column_obj.get_sqla_col() + col = columns_by_name[col].get_sqla_col( + template_processor=template_processor + ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True + elif col in metrics_by_name: + col = metrics_by_name[col].get_sqla_col( + template_processor=template_processor + ) + need_groupby = True if isinstance(col, ColumnElement): orderby_exprs.append(col) @@ -1496,33 +1522,23 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] - if isinstance(table_col, dict): - outer = self.get_timestamp_expression( - column=table_col, - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - else: - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) # if groupby field equals a selected column elif selected in columns_by_name: - if isinstance(columns_by_name[selected], dict): - outer = 
sa.column(f"{selected}") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = columns_by_name[selected].get_sqla_col() + outer = columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) else: - selected = self.validate_adhoc_subquery( + selected = validate_adhoc_subquery( selected, self.database_id, self.schema, ) - outer = sa.column(f"{selected}") + outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( @@ -1536,19 +1552,27 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: - selected = self.validate_adhoc_subquery( - selected, + if is_adhoc_column(selected): + _sql = selected["sqlExpression"] + _column_label = selected["label"] + elif isinstance(selected, str): + _sql = selected + _column_label = selected + + selected = validate_adhoc_subquery( + _sql, self.database_id, self.schema, ) - if isinstance(columns_by_name[selected], dict): - select_exprs.append(sa.column(f"{selected}")) - else: - select_exprs.append( - columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) + select_exprs.append( + columns_by_name[selected].get_sqla_col( + template_processor=template_processor ) + if isinstance(selected, str) and selected in columns_by_name + else self.make_sqla_column_compatible( + literal_column(selected), _column_label + ) + ) metrics_exprs = [] if granularity: @@ -1559,57 +1583,41 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col=granularity, ) ) - time_filters: List[Any] = [] + time_filters = [] if is_timeseries: - if isinstance(dttm_col, dict): - timestamp = self.get_timestamp_expression( - dttm_col, time_grain, template_processor=template_processor - ) - else: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns. 
- if db_engine_spec.time_secondary_columns: - if isinstance(dttm_col, dict): - dttm_col_name = dttm_col.get("column_name") - else: - dttm_col_name = dttm_col.column_name - - if ( - self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col_name - ): - if isinstance(self.main_dttm_col, dict): - time_filters.append( - self.get_time_filter( - self.main_dttm_col, - from_dttm, - to_dttm, - ) - ) - else: - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - from_dttm, - to_dttm, - ) - ) - - if isinstance(dttm_col, dict): - time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm)) - else: - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + if ( + db_engine_spec.time_secondary_columns + and self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col.column_name + ): + time_filters.append( + columns_by_name[self.main_dttm_col].get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) + time_filters.append( + dttm_col.get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use # raw columns as custom SQL adhoc metric). - select_exprs = utils.remove_duplicates( + select_exprs = remove_duplicates( select_exprs + metrics_exprs, key=lambda x: x.name ) @@ -1619,7 +1627,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # Order by columns are "hidden" columns, some databases require them # always be present in SELECT if an aggregation function is used if not db_engine_spec.allows_hidden_ordeby_agg: - select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) + select_exprs = remove_duplicates(select_exprs + orderby_exprs) qry = sa.select(select_exprs) @@ -1637,18 +1645,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma flt_col = flt["col"] val = flt.get("val") op = flt["op"].upper() - col_obj: Optional["TableColumn"] = None + col_obj: Optional[TableColumn] = None sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col - elif utils.is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore + elif is_adhoc_column(flt_col): + sqla_col = self.adhoc_column_to_sqla(col=flt_col, template_processor=template_processor) # type: ignore else: col_obj = columns_by_name.get(flt_col) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if utils.get_column_name(flt_col) in removed_filters: + if get_column_name(flt_col) in removed_filters: # Skip generating SQLA filter when the jinja template handles it. 
continue @@ -1656,44 +1664,29 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if sqla_col is not None: pass elif col_obj and filter_grain: - if isinstance(col_obj, dict): - sqla_col = self.get_timestamp_expression( - col_obj, time_grain, template_processor=template_processor - ) - else: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, - template_processor=template_processor, - ) - elif col_obj and isinstance(col_obj, dict): - sqla_col = sa.column(col_obj.get("column_name")) + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, template_processor=template_processor + ) elif col_obj: - sqla_col = col_obj.get_sqla_col() - - if col_obj and isinstance(col_obj, dict): - col_type = col_obj.get("type") - else: - col_type = col_obj.type if col_obj else None + sqla_col = col_obj.get_sqla_col( + template_processor=template_processor + ) + col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, - db_extra=self.database.get_extra(), # type: ignore +# db_extra=self.database.get_extra(), ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - if col_obj and isinstance(col_obj, dict): - col_advanced_data_type = "" - else: - col_advanced_data_type = ( - col_obj.advanced_data_type if col_obj else "" - ) + col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" if col_spec and not col_advanced_data_type: target_generic_type = col_spec.generic_type else: - target_generic_type = utils.GenericDataType.STRING + target_generic_type = GenericDataType.STRING eq = self.filter_values_handler( values=val, operator=op, @@ -1701,7 +1694,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), # type: ignore +# db_extra=self.database.get_extra(), ) if ( col_advanced_data_type != "" @@ -1757,7 +1750,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif op == utils.FilterOperator.IS_FALSE.value: where_clause_and.append(sqla_col.is_(False)) else: - if eq is None: + if ( + op + not in { + utils.FilterOperator.EQUALS.value, + utils.FilterOperator.NOT_EQUALS.value, + } + and eq is None + ): raise QueryObjectValidationError( _( "Must specify a value for filters " @@ -1791,23 +1791,23 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma extras=extras, ) where_clause_and.append( - self.get_time_filter( - time_col=col_obj, + col_obj.get_time_filter( start_dttm=_since, end_dttm=_until, + label=sqla_col.key, + template_processor=template_processor, ) ) else: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - # todo(hugh): fix this w/ template_processor - # where_clause_and += self.get_sqla_row_level_filters(template_processor) + where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: try: - where = template_processor.process_template(f"{where}") + where = template_processor.process_template(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1815,11 +1815,17 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + where = self._process_sql_expression( + expression=where, + database_id=self.database_id, + schema=self.schema, + 
template_processor=template_processor + ) where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(f"{having}") + having = template_processor.process_template(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1827,9 +1833,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + having = self._process_sql_expression( + expression=having, + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) having_clause_and += [self.text(having)] - if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore - qry = qry.where(self.get_fetch_values_predicate()) # type: ignore + + if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore + qry = qry.where( + self.get_fetch_values_predicate(template_processor=template_processor) # type: ignore + ) if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -1869,7 +1884,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_series_columns.items(): - label = utils.get_column_name(gby_name) + label = get_column_name(gby_name) inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) @@ -1879,26 +1894,24 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_time_filter = [] if dttm_col and not db_engine_spec.time_groupby_inline: - if isinstance(dttm_col, dict): - inner_time_filter = [ - self.get_time_filter( - dttm_col, - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - ] - else: - inner_time_filter = [ - dttm_col.get_time_filter( - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - ] - + inner_time_filter = [ + dttm_col.get_time_filter( + start_dttm=inner_from_dttm or from_dttm, + end_dttm=inner_to_dttm or to_dttm, + template_processor=template_processor, + ) + ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr + if series_limit_metric: + ob = self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) direction = sa.desc if order_desc else sa.asc subq = subq.order_by(direction(ob)) subq = subq.limit(series_limit) @@ -1912,6 +1925,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma on_clause.append(gby_obj == sa.column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) + else: + if series_limit_metric: + orderby = [ + ( + self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ), + not order_desc, + ) + ] # run prequery to get top groups prequery_obj = { @@ -1928,7 +1954,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "columns": columns, "order_desc": True, } - result = self.exc_query(prequery_obj) + + result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ c diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index babea35baf39b..5ccba99975bf8 100644 --- a/superset/models/sql_lab.py +++ 
b/superset/models/sql_lab.py @@ -33,6 +33,7 @@ DateTime, Enum, ForeignKey, + Hashable, Integer, Numeric, String, @@ -307,7 +308,7 @@ def default_endpoint(self) -> str: return "" @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]: return [] @property diff --git a/superset/result_set.py b/superset/result_set.py index 3d29673b9fcb9..63d48b1e4bcab 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -70,9 +70,9 @@ def stringify_values(array: NDArray[Any]) -> NDArray[Any]: for obj in it: if na_obj := pd.isna(obj): # pandas type cannot be converted to string - obj[na_obj] = None # type: ignore + obj[na_obj] = None else: - obj[...] = stringify(obj) # type: ignore + obj[...] = stringify(obj) return result diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py index 673c39ebf3836..e2706345b1ea9 100644 --- a/superset/utils/pandas_postprocessing/boxplot.py +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -57,10 +57,10 @@ def boxplot( """ def quartile1(series: Series) -> float: - return np.nanpercentile(series, 25, interpolation="midpoint") # type: ignore + return np.nanpercentile(series, 25, interpolation="midpoint") def quartile3(series: Series) -> float: - return np.nanpercentile(series, 75, interpolation="midpoint") # type: ignore + return np.nanpercentile(series, 75, interpolation="midpoint") if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY: @@ -99,8 +99,8 @@ def whisker_low(series: Series) -> float: return np.nanpercentile(series, low) else: - whisker_high = np.max # type: ignore - whisker_low = np.min # type: ignore + whisker_high = np.max + whisker_low = np.min def outliers(series: Series) -> Set[float]: above = series[series > whisker_high(series)] diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index 1026164e454ee..db783c4bed264 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -85,7 +85,7 @@ def flatten( _columns = [] for series in df.columns.to_flat_index(): _cells = [] - for cell in series if is_sequence(series) else [series]: # type: ignore + for cell in series if is_sequence(series) else [series]: if pd.notnull(cell): # every cell should be converted to string and escape comma _cells.append(escape_separator(str(cell))) From 4528addf1172ce34bb2510123a32fb034bf3088e Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 24 Jan 2023 18:51:16 +0200 Subject: [PATCH 02/29] mk1: working explore with sqlatable model with exploremixin --- superset/connectors/sqla/models.py | 646 ----------------------------- 1 file changed, 646 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index bfdec8b996c20..350c03a793179 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -138,8 +138,6 @@ remove_duplicates, ) -from superset.models.helpers import ExploreMixin - config = app.config metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) @@ -1143,650 +1141,6 @@ def get_sqla_row_level_filters( def text(self, clause: str) -> TextClause: return self.db_engine_spec.get_text_clause(clause) - # def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - # self, - # apply_fetch_values_predicate: bool = False, - # columns: 
Optional[List[ColumnTyping]] = None, - # extras: Optional[Dict[str, Any]] = None, - # filter: Optional[ # pylint: disable=redefined-builtin - # List[QueryObjectFilterClause] - # ] = None, - # from_dttm: Optional[datetime] = None, - # granularity: Optional[str] = None, - # groupby: Optional[List[Column]] = None, - # inner_from_dttm: Optional[datetime] = None, - # inner_to_dttm: Optional[datetime] = None, - # is_rowcount: bool = False, - # is_timeseries: bool = True, - # metrics: Optional[List[Metric]] = None, - # orderby: Optional[List[OrderBy]] = None, - # order_desc: bool = True, - # to_dttm: Optional[datetime] = None, - # series_columns: Optional[List[Column]] = None, - # series_limit: Optional[int] = None, - # series_limit_metric: Optional[Metric] = None, - # row_limit: Optional[int] = None, - # row_offset: Optional[int] = None, - # timeseries_limit: Optional[int] = None, - # timeseries_limit_metric: Optional[Metric] = None, - # time_shift: Optional[str] = None, - # ) -> SqlaQuery: - # """Querying any sqla table from this common interface""" - # if granularity not in self.dttm_cols and granularity is not None: - # granularity = self.main_dttm_col - - # extras = extras or {} - # time_grain = extras.get("time_grain_sqla") - - # template_kwargs = { - # "columns": columns, - # "from_dttm": from_dttm.isoformat() if from_dttm else None, - # "groupby": groupby, - # "metrics": metrics, - # "row_limit": row_limit, - # "row_offset": row_offset, - # "time_column": granularity, - # "time_grain": time_grain, - # "to_dttm": to_dttm.isoformat() if to_dttm else None, - # "table_columns": [col.column_name for col in self.columns], - # "filter": filter, - # } - # columns = columns or [] - # groupby = groupby or [] - # series_column_names = utils.get_column_names(series_columns or []) - # # deprecated, to be removed in 2.0 - # if is_timeseries and timeseries_limit: - # series_limit = timeseries_limit - # series_limit_metric = series_limit_metric or timeseries_limit_metric - # template_kwargs.update(self.template_params_dict) - # extra_cache_keys: List[Any] = [] - # template_kwargs["extra_cache_keys"] = extra_cache_keys - # removed_filters: List[str] = [] - # applied_template_filters: List[str] = [] - # template_kwargs["removed_filters"] = removed_filters - # template_kwargs["applied_filters"] = applied_template_filters - # template_processor = self.get_template_processor(**template_kwargs) - # db_engine_spec = self.db_engine_spec - # prequeries: List[str] = [] - # orderby = orderby or [] - # need_groupby = bool(metrics is not None or groupby) - # metrics = metrics or [] - - # # For backward compatibility - # if granularity not in self.dttm_cols and granularity is not None: - # granularity = self.main_dttm_col - - # columns_by_name: Dict[str, TableColumn] = { - # col.column_name: col for col in self.columns - # } - - # metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} - - # if not granularity and is_timeseries: - # raise QueryObjectValidationError( - # _( - # "Datetime column not provided as part table configuration " - # "and is required by this type of chart" - # ) - # ) - # if not metrics and not columns and not groupby: - # raise QueryObjectValidationError(_("Empty query?")) - - # metrics_exprs: List[ColumnElement] = [] - # for metric in metrics: - # if utils.is_adhoc_metric(metric): - # assert isinstance(metric, dict) - # metrics_exprs.append( - # self.adhoc_metric_to_sqla( - # metric=metric, - # columns_by_name=columns_by_name, - # template_processor=template_processor, 
- # ) - # ) - # elif isinstance(metric, str) and metric in metrics_by_name: - # metrics_exprs.append( - # metrics_by_name[metric].get_sqla_col( - # template_processor=template_processor - # ) - # ) - # else: - # raise QueryObjectValidationError( - # _("Metric '%(metric)s' does not exist", metric=metric) - # ) - - # if metrics_exprs: - # main_metric_expr = metrics_exprs[0] - # else: - # main_metric_expr, label = literal_column("COUNT(*)"), "ccount" - # main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) - - # # To ensure correct handling of the ORDER BY labeling we need to reference the - # # metric instance if defined in the SELECT clause. - # # use the key of the ColumnClause for the expected label - # metrics_exprs_by_label = {m.key: m for m in metrics_exprs} - # metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} - - # # Since orderby may use adhoc metrics, too; we need to process them first - # orderby_exprs: List[ColumnElement] = [] - # for orig_col, ascending in orderby: - # col: Union[AdhocMetric, ColumnElement] = orig_col - # if isinstance(col, dict): - # col = cast(AdhocMetric, col) - # if col.get("sqlExpression"): - # col["sqlExpression"] = _process_sql_expression( - # expression=col["sqlExpression"], - # database_id=self.database_id, - # schema=self.schema, - # template_processor=template_processor, - # ) - # if utils.is_adhoc_metric(col): - # # add adhoc sort by column to columns_by_name if not exists - # col = self.adhoc_metric_to_sqla(col, columns_by_name) - # # if the adhoc metric has been defined before - # # use the existing instance. - # col = metrics_exprs_by_expr.get(str(col), col) - # need_groupby = True - # elif col in columns_by_name: - # col = columns_by_name[col].get_sqla_col( - # template_processor=template_processor - # ) - # elif col in metrics_exprs_by_label: - # col = metrics_exprs_by_label[col] - # need_groupby = True - # elif col in metrics_by_name: - # col = metrics_by_name[col].get_sqla_col( - # template_processor=template_processor - # ) - # need_groupby = True - - # if isinstance(col, ColumnElement): - # orderby_exprs.append(col) - # else: - # # Could not convert a column reference to valid ColumnElement - # raise QueryObjectValidationError( - # _("Unknown column used in orderby: %(col)s", col=orig_col) - # ) - - # select_exprs: List[Union[Column, Label]] = [] - # groupby_all_columns = {} - # groupby_series_columns = {} - - # # filter out the pseudo column __timestamp from columns - # columns = [col for col in columns if col != utils.DTTM_ALIAS] - # dttm_col = columns_by_name.get(granularity) if granularity else None - - # if need_groupby: - # # dedup columns while preserving order - # columns = groupby or columns - # for selected in columns: - # if isinstance(selected, str): - # # if groupby field/expr equals granularity field/expr - # if selected == granularity: - # table_col = columns_by_name[selected] - # outer = table_col.get_timestamp_expression( - # time_grain=time_grain, - # label=selected, - # template_processor=template_processor, - # ) - # # if groupby field equals a selected column - # elif selected in columns_by_name: - # outer = columns_by_name[selected].get_sqla_col( - # template_processor=template_processor - # ) - # else: - # selected = validate_adhoc_subquery( - # selected, - # self.database_id, - # self.schema, - # ) - # outer = literal_column(f"({selected})") - # outer = self.make_sqla_column_compatible(outer, selected) - # else: - # outer = self.adhoc_column_to_sqla( - # col=selected, 
template_processor=template_processor - # ) - # groupby_all_columns[outer.name] = outer - # if ( - # is_timeseries and not series_column_names - # ) or outer.name in series_column_names: - # groupby_series_columns[outer.name] = outer - # select_exprs.append(outer) - # elif columns: - # for selected in columns: - # if is_adhoc_column(selected): - # _sql = selected["sqlExpression"] - # _column_label = selected["label"] - # elif isinstance(selected, str): - # _sql = selected - # _column_label = selected - - # selected = validate_adhoc_subquery( - # _sql, - # self.database_id, - # self.schema, - # ) - # select_exprs.append( - # columns_by_name[selected].get_sqla_col( - # template_processor=template_processor - # ) - # if isinstance(selected, str) and selected in columns_by_name - # else self.make_sqla_column_compatible( - # literal_column(selected), _column_label - # ) - # ) - # metrics_exprs = [] - - # if granularity: - # if granularity not in columns_by_name or not dttm_col: - # raise QueryObjectValidationError( - # _( - # 'Time column "%(col)s" does not exist in dataset', - # col=granularity, - # ) - # ) - # time_filters = [] - - # if is_timeseries: - # timestamp = dttm_col.get_timestamp_expression( - # time_grain=time_grain, template_processor=template_processor - # ) - # # always put timestamp as the first column - # select_exprs.insert(0, timestamp) - # groupby_all_columns[timestamp.name] = timestamp - - # # Use main dttm column to support index with secondary dttm columns. - # if ( - # db_engine_spec.time_secondary_columns - # and self.main_dttm_col in self.dttm_cols - # and self.main_dttm_col != dttm_col.column_name - # ): - # time_filters.append( - # columns_by_name[self.main_dttm_col].get_time_filter( - # start_dttm=from_dttm, - # end_dttm=to_dttm, - # template_processor=template_processor, - # ) - # ) - # time_filters.append( - # dttm_col.get_time_filter( - # start_dttm=from_dttm, - # end_dttm=to_dttm, - # template_processor=template_processor, - # ) - # ) - - # # Always remove duplicates by column name, as sometimes `metrics_exprs` - # # can have the same name as a groupby column (e.g. when users use - # # raw columns as custom SQL adhoc metric). 
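A minimal sketch of the dedup semantics relied on above, assuming
`remove_duplicates` from superset.utils.core (first occurrence wins, order is
preserved), as imported and used in these hunks:

    from superset.utils.core import remove_duplicates

    exprs = ["num", "state", "num"]  # stand-ins for labeled SQLA expressions
    assert remove_duplicates(exprs) == ["num", "state"]
    assert remove_duplicates(exprs, key=len) == ["num", "state"]  # dedup by key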
- # select_exprs = remove_duplicates( - # select_exprs + metrics_exprs, key=lambda x: x.name - # ) - - # # Expected output columns - # labels_expected = [c.key for c in select_exprs] - - # # Order by columns are "hidden" columns, some databases require them - # # always be present in SELECT if an aggregation function is used - # if not db_engine_spec.allows_hidden_ordeby_agg: - # select_exprs = remove_duplicates(select_exprs + orderby_exprs) - - # qry = sa.select(select_exprs) - - # tbl, cte = self.get_from_clause(template_processor) - - # if groupby_all_columns: - # qry = qry.group_by(*groupby_all_columns.values()) - - # where_clause_and = [] - # having_clause_and = [] - - # for flt in filter: # type: ignore - # if not all(flt.get(s) for s in ["col", "op"]): - # continue - # flt_col = flt["col"] - # val = flt.get("val") - # op = flt["op"].upper() - # col_obj: Optional[TableColumn] = None - # sqla_col: Optional[Column] = None - # if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: - # col_obj = dttm_col - # elif is_adhoc_column(flt_col): - # sqla_col = self.adhoc_column_to_sqla(flt_col) - # else: - # col_obj = columns_by_name.get(flt_col) - # filter_grain = flt.get("grain") - - # if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - # if get_column_name(flt_col) in removed_filters: - # # Skip generating SQLA filter when the jinja template handles it. - # continue - - # if col_obj or sqla_col is not None: - # if sqla_col is not None: - # pass - # elif col_obj and filter_grain: - # sqla_col = col_obj.get_timestamp_expression( - # time_grain=filter_grain, template_processor=template_processor - # ) - # elif col_obj: - # sqla_col = col_obj.get_sqla_col( - # template_processor=template_processor - # ) - # col_type = col_obj.type if col_obj else None - # col_spec = db_engine_spec.get_column_spec( - # native_type=col_type, - # db_extra=self.database.get_extra(), - # ) - # is_list_target = op in ( - # utils.FilterOperator.IN.value, - # utils.FilterOperator.NOT_IN.value, - # ) - - # col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" - - # if col_spec and not col_advanced_data_type: - # target_generic_type = col_spec.generic_type - # else: - # target_generic_type = GenericDataType.STRING - # eq = self.filter_values_handler( - # values=val, - # operator=op, - # target_generic_type=target_generic_type, - # target_native_type=col_type, - # is_list_target=is_list_target, - # db_engine_spec=db_engine_spec, - # db_extra=self.database.get_extra(), - # ) - # if ( - # col_advanced_data_type != "" - # and feature_flag_manager.is_feature_enabled( - # "ENABLE_ADVANCED_DATA_TYPES" - # ) - # and col_advanced_data_type in ADVANCED_DATA_TYPES - # ): - # values = eq if is_list_target else [eq] # type: ignore - # bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - # col_advanced_data_type - # ].translate_type( - # { - # "type": col_advanced_data_type, - # "values": values, - # } - # ) - # if bus_resp["error_message"]: - # raise AdvancedDataTypeResponseError( - # _(bus_resp["error_message"]) - # ) - - # where_clause_and.append( - # ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - # sqla_col, op, bus_resp["values"] - # ) - # ) - # elif is_list_target: - # assert isinstance(eq, (tuple, list)) - # if len(eq) == 0: - # raise QueryObjectValidationError( - # _("Filter value list cannot be empty") - # ) - # if len(eq) > len( - # eq_without_none := [x for x in eq if x is not None] - # ): - # is_null_cond = sqla_col.is_(None) - # if eq: - # cond = or_(is_null_cond, 
sqla_col.in_(eq_without_none)) - # else: - # cond = is_null_cond - # else: - # cond = sqla_col.in_(eq) - # if op == utils.FilterOperator.NOT_IN.value: - # cond = ~cond - # where_clause_and.append(cond) - # elif op == utils.FilterOperator.IS_NULL.value: - # where_clause_and.append(sqla_col.is_(None)) - # elif op == utils.FilterOperator.IS_NOT_NULL.value: - # where_clause_and.append(sqla_col.isnot(None)) - # elif op == utils.FilterOperator.IS_TRUE.value: - # where_clause_and.append(sqla_col.is_(True)) - # elif op == utils.FilterOperator.IS_FALSE.value: - # where_clause_and.append(sqla_col.is_(False)) - # else: - # if ( - # op - # not in { - # utils.FilterOperator.EQUALS.value, - # utils.FilterOperator.NOT_EQUALS.value, - # } - # and eq is None - # ): - # raise QueryObjectValidationError( - # _( - # "Must specify a value for filters " - # "with comparison operators" - # ) - # ) - # if op == utils.FilterOperator.EQUALS.value: - # where_clause_and.append(sqla_col == eq) - # elif op == utils.FilterOperator.NOT_EQUALS.value: - # where_clause_and.append(sqla_col != eq) - # elif op == utils.FilterOperator.GREATER_THAN.value: - # where_clause_and.append(sqla_col > eq) - # elif op == utils.FilterOperator.LESS_THAN.value: - # where_clause_and.append(sqla_col < eq) - # elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: - # where_clause_and.append(sqla_col >= eq) - # elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: - # where_clause_and.append(sqla_col <= eq) - # elif op == utils.FilterOperator.LIKE.value: - # where_clause_and.append(sqla_col.like(eq)) - # elif op == utils.FilterOperator.ILIKE.value: - # where_clause_and.append(sqla_col.ilike(eq)) - # elif ( - # op == utils.FilterOperator.TEMPORAL_RANGE.value - # and isinstance(eq, str) - # and col_obj is not None - # ): - # _since, _until = get_since_until_from_time_range( - # time_range=eq, - # time_shift=time_shift, - # extras=extras, - # ) - # where_clause_and.append( - # col_obj.get_time_filter( - # start_dttm=_since, - # end_dttm=_until, - # label=sqla_col.key, - # template_processor=template_processor, - # ) - # ) - # else: - # raise QueryObjectValidationError( - # _("Invalid filter operation type: %(op)s", op=op) - # ) - # where_clause_and += self.get_sqla_row_level_filters(template_processor) - # if extras: - # where = extras.get("where") - # if where: - # try: - # where = template_processor.process_template(f"({where})") - # except TemplateError as ex: - # raise QueryObjectValidationError( - # _( - # "Error in jinja expression in WHERE clause: %(msg)s", - # msg=ex.message, - # ) - # ) from ex - # where = _process_sql_expression( - # expression=where, - # database_id=self.database_id, - # schema=self.schema, - # ) - # where_clause_and += [self.text(where)] - # having = extras.get("having") - # if having: - # try: - # having = template_processor.process_template(f"({having})") - # except TemplateError as ex: - # raise QueryObjectValidationError( - # _( - # "Error in jinja expression in HAVING clause: %(msg)s", - # msg=ex.message, - # ) - # ) from ex - # having = _process_sql_expression( - # expression=having, - # database_id=self.database_id, - # schema=self.schema, - # ) - # having_clause_and += [self.text(having)] - - # if apply_fetch_values_predicate and self.fetch_values_predicate: - # qry = qry.where( - # self.get_fetch_values_predicate(template_processor=template_processor) - # ) - # if granularity: - # qry = qry.where(and_(*(time_filters + where_clause_and))) - # else: - # qry = qry.where(and_(*where_clause_and)) - # 
qry = qry.having(and_(*having_clause_and)) - - # self.make_orderby_compatible(select_exprs, orderby_exprs) - - # for col, (orig_col, ascending) in zip(orderby_exprs, orderby): - # if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): - # # if engine does not allow using SELECT alias in ORDER BY - # # revert to the underlying column - # col = col.element - - # if ( - # db_engine_spec.allows_alias_in_select - # and db_engine_spec.allows_hidden_cc_in_orderby - # and col.name in [select_col.name for select_col in select_exprs] - # ): - # col = literal_column(col.name) - # direction = asc if ascending else desc - # qry = qry.order_by(direction(col)) - - # if row_limit: - # qry = qry.limit(row_limit) - # if row_offset: - # qry = qry.offset(row_offset) - - # if series_limit and groupby_series_columns: - # if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: - # # some sql dialects require for order by expressions - # # to also be in the select clause -- others, e.g. vertica, - # # require a unique inner alias - # inner_main_metric_expr = self.make_sqla_column_compatible( - # main_metric_expr, "mme_inner__" - # ) - # inner_groupby_exprs = [] - # inner_select_exprs = [] - # for gby_name, gby_obj in groupby_series_columns.items(): - # label = get_column_name(gby_name) - # inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") - # inner_groupby_exprs.append(inner) - # inner_select_exprs.append(inner) - - # inner_select_exprs += [inner_main_metric_expr] - # subq = select(inner_select_exprs).select_from(tbl) - # inner_time_filter = [] - - # if dttm_col and not db_engine_spec.time_groupby_inline: - # inner_time_filter = [ - # dttm_col.get_time_filter( - # start_dttm=inner_from_dttm or from_dttm, - # end_dttm=inner_to_dttm or to_dttm, - # template_processor=template_processor, - # ) - # ] - # subq = subq.where(and_(*(where_clause_and + inner_time_filter))) - # subq = subq.group_by(*inner_groupby_exprs) - - # ob = inner_main_metric_expr - # if series_limit_metric: - # ob = self._get_series_orderby( - # series_limit_metric=series_limit_metric, - # metrics_by_name=metrics_by_name, - # columns_by_name=columns_by_name, - # template_processor=template_processor, - # ) - # direction = desc if order_desc else asc - # subq = subq.order_by(direction(ob)) - # subq = subq.limit(series_limit) - - # on_clause = [] - # for gby_name, gby_obj in groupby_series_columns.items(): - # # in this case the column name, not the alias, needs to be - # # conditionally mutated, as it refers to the column alias in - # # the inner query - # col_name = db_engine_spec.make_label_compatible(gby_name + "__") - # on_clause.append(gby_obj == column(col_name)) - - # tbl = tbl.join(subq.alias(), and_(*on_clause)) - # else: - # if series_limit_metric: - # orderby = [ - # ( - # self._get_series_orderby( - # series_limit_metric=series_limit_metric, - # metrics_by_name=metrics_by_name, - # columns_by_name=columns_by_name, - # template_processor=template_processor, - # ), - # not order_desc, - # ) - # ] - - # # run prequery to get top groups - # prequery_obj = { - # "is_timeseries": False, - # "row_limit": series_limit, - # "metrics": metrics, - # "granularity": granularity, - # "groupby": groupby, - # "from_dttm": inner_from_dttm or from_dttm, - # "to_dttm": inner_to_dttm or to_dttm, - # "filter": filter, - # "orderby": orderby, - # "extras": extras, - # "columns": columns, - # "order_desc": True, - # } - - # result = self.query(prequery_obj) - # prequeries.append(result.query) - # dimensions = [ - 
# c - # for c in result.df.columns - # if c not in metrics and c in groupby_series_columns - # ] - # top_groups = self._get_top_groups( - # result.df, dimensions, groupby_series_columns, columns_by_name - # ) - # qry = qry.where(top_groups) - - # qry = qry.select_from(tbl) - - # if is_rowcount: - # if not db_engine_spec.allows_subqueries: - # raise QueryObjectValidationError( - # _("Database does not support subqueries") - # ) - # label = "rowcount" - # col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - # qry = select([col]).select_from(qry.alias("rowcount_qry")) - # labels_expected = [label] - - # return SqlaQuery( - # applied_template_filters=applied_template_filters, - # cte=cte, - # extra_cache_keys=extra_cache_keys, - # labels_expected=labels_expected, - # sqla_query=qry, - # prequeries=prequeries, - # ) - def _get_series_orderby( self, series_limit_metric: Metric, From 516dde4908b174ac18bb3f339ee1fddcec7d1e44 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 25 Jan 2023 01:21:03 +0200 Subject: [PATCH 03/29] working sqlatable + query with uncommented code that needs to be cleaned --- superset/connectors/sqla/models.py | 4 + superset/models/helpers.py | 156 +++++++++++++++++++++++------ superset/models/sql_lab.py | 30 ++++-- superset/utils/core.py | 13 +-- 4 files changed, 156 insertions(+), 47 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 350c03a793179..2e3f2a3c02018 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -630,6 +630,10 @@ class SqlaTable(Model, BaseDatasource, ExploreMixin): # pylint: disable=too-man def __repr__(self) -> str: # pylint: disable=invalid-repr-returned return self.name + @property + def db_extra(self) -> Dict[str, Any]: + return self.database.get_extra() + @staticmethod def _apply_cte(sql: str, cte: Optional[str]) -> str: """ diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4ac78cd47aee7..17f8090bef317 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -15,11 +15,12 @@ # specific language governing permissions and limitations # under the License. """a collection of model-related helper classes and functions""" -# pylint: disable=too-many-lines import json import logging import re import uuid +# pylint: disable=too-many-lines +from collections import defaultdict from datetime import datetime, timedelta from json.decoder import JSONDecodeError from typing import ( @@ -705,7 +706,7 @@ def owners_data(self) -> List[Any]: @property def metrics(self) -> List[Any]: - raise NotImplementedError() + return [] @property def uid(self) -> str: @@ -769,7 +770,43 @@ def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, ) -> List[TextClause]: - raise NotImplementedError() + """ + Return the appropriate row level security filters for this table and the + current user. A custom username can be passed when the user is not present in the + Flask global namespace. + + :param template_processor: The template processor to apply to the filters. + :returns: A list of SQL clauses to be ANDed together. 
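        For example (a sketch): two RLS clauses sharing a ``group_key`` are
        OR'ed together first, and the result is AND'ed with every ungrouped
        clause, e.g. ``(region = 'EU' OR region = 'US') AND dept = 'sales'``.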
+ """ + all_filters: List[TextClause] = [] + filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + try: + for filter_ in security_manager.get_rls_filters(self): + clause = self.text( + f"({template_processor.process_template(filter_.clause)})" + ) + if filter_.group_key: + filter_groups[filter_.group_key].append(clause) + else: + all_filters.append(clause) + + if is_feature_enabled("EMBEDDED_SUPERSET"): + for rule in security_manager.get_guest_rls_filters(self): + clause = self.text( + f"({template_processor.process_template(rule['clause'])})" + ) + all_filters.append(clause) + + grouped_filters = [or_(*clauses) for clauses in filter_groups.values()] + all_filters.extend(grouped_filters) + return all_filters + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in RLS filters: %(msg)s", + msg=ex.message, + ) + ) from ex def _process_sql_expression( # pylint: disable=no-self-use self, @@ -1255,27 +1292,30 @@ def dttm_sql_literal(self, dttm: sa.DateTime, col_type: Optional[str]) -> str: def get_time_filter( self, - time_col: Dict[str, Any], + time_col: "TableColumn", start_dttm: Optional[sa.DateTime], end_dttm: Optional[sa.DateTime], + label: Optional[str] = "__time", + template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - label = "__time" - col = time_col.get("column_name") - sqla_col = literal_column(col) - my_col = self.make_sqla_column_compatible(sqla_col, label) + col = self.transform_tbl_column_to_sqla_column( + time_col, + label=label, + template_processor=template_processor + ) l = [] if start_dttm: l.append( - my_col + col >= self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(start_dttm, time_col.get("type")) + self.dttm_sql_literal(start_dttm, time_col.type) ) ) if end_dttm: l.append( - my_col + col < self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(end_dttm, time_col.get("type")) + self.dttm_sql_literal(end_dttm, time_col.type) ) ) return and_(*l) @@ -1339,11 +1379,24 @@ def get_timestamp_expression( time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) return self.make_sqla_column_compatible(time_expr, label) - def get_sqla_col(self, col: Dict[str, Any]) -> Column: - label = col.get("column_name") - col_type = col.get("type") - col = sa.column(label, type_=col_type) - return self.make_sqla_column_compatible(col, label) + def transform_tbl_column_to_sqla_column( + self, + tbl_column: "TableColumn", + label: Optional[str] = None, + template_processor: Optional[BaseTemplateProcessor] = None + ) -> Column: + label = label or tbl_column.column_name + db_engine_spec = self.db_engine_spec + column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) + type_ = column_spec.sqla_type if column_spec else None + if expression := tbl_column.expression: + if template_processor: + expression = template_processor.process_template(expression) + col = literal_column(expression, type_=type_) + else: + col = sa.column(tbl_column.column_name, type_=type_) + col = self.make_sqla_column_compatible(col, label) + return col def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, @@ -1380,6 +1433,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma extras = extras or {} time_grain = extras.get("time_grain_sqla") + # breakpoint() template_kwargs = { "columns": columns, "from_dttm": from_dttm.isoformat() if from_dttm else None, @@ -1489,6 +1543,9 @@ def 
get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: + # col = self.transform_tbl_column_to_sqla_column( + # columns_by_name[col], template_processor=template_processor + # ) col = columns_by_name[col].get_sqla_col( template_processor=template_processor ) @@ -1532,6 +1589,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) # if groupby field equals a selected column elif selected in columns_by_name: + # outer = self.transform_tbl_column_to_sqla_column( + # columns_by_name[selected], template_processor=template_processor + # ) outer = columns_by_name[selected].get_sqla_col( template_processor=template_processor ) @@ -1567,11 +1627,15 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma self.database_id, self.schema, ) + + data = self.transform_tbl_column_to_sqla_column( + columns_by_name[selected], template_processor=template_processor + ) + # data = columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) select_exprs.append( - columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - if isinstance(selected, str) and selected in columns_by_name + data if isinstance(selected, str) and selected in columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label ) @@ -1603,19 +1667,34 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma and self.main_dttm_col != dttm_col.column_name ): time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( + # columns_by_name[self.main_dttm_col].get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + self.get_time_filter( + time_col=self.main_dttm_col, start_dttm=from_dttm, end_dttm=to_dttm, - template_processor=template_processor, + template_processor=template_processor ) ) - time_filters.append( - dttm_col.get_time_filter( + + time_filter_column = self.get_time_filter( + time_col=dttm_col, start_dttm=from_dttm, end_dttm=to_dttm, - template_processor=template_processor, - ) + template_processor=template_processor ) + time_filters.append(time_filter_column) + + # time_filters.append( + # dttm_col.get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + # ) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. 
when users use @@ -1671,9 +1750,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_grain=filter_grain, template_processor=template_processor ) elif col_obj: - sqla_col = col_obj.get_sqla_col( + sqla_col = self.transform_tbl_column_to_sqla_column( + tbl_column=col_obj, template_processor=template_processor ) + # sqla_col = col_obj.get_sqla_col( + # template_processor=template_processor + # ) col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, @@ -1794,12 +1877,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma extras=extras, ) where_clause_and.append( - col_obj.get_time_filter( + self.get_time_filter( + time_col=col_obj, start_dttm=_since, end_dttm=_until, label=sqla_col.key, template_processor=template_processor, ) + # col_obj.get_time_filter( + # start_dttm=_since, + # end_dttm=_until, + # label=sqla_col.key, + # template_processor=template_processor, + # ) ) else: raise QueryObjectValidationError( @@ -1898,11 +1988,17 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if dttm_col and not db_engine_spec.time_groupby_inline: inner_time_filter = [ - dttm_col.get_time_filter( + self.get_time_filter( + time_col=dttm_col, start_dttm=inner_from_dttm or from_dttm, end_dttm=inner_to_dttm or to_dttm, template_processor=template_processor, ) + # dttm_col.get_time_filter( + # start_dttm=inner_from_dttm or from_dttm, + # end_dttm=inner_to_dttm or to_dttm, + # template_processor=template_processor, + # ) ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) subq = subq.group_by(*inner_groupby_exprs) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index aecbb0340f225..9ffe28e2774b7 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -200,6 +200,9 @@ def columns(self) -> List[Dict[str, Any]]: str_types = ("VARCHAR", "STRING", "CHAR") columns = [] col_type = "" + + from superset.connectors.sqla.models import TableColumn + for col in self.extra.get("columns", []): computed_column = {**col} col_type = col.get("type") @@ -213,16 +216,29 @@ def columns(self) -> List[Dict[str, Any]]: if col_type and any(map(lambda t: t in col_type.upper(), date_types)): computed_column["type_generic"] = GenericDataType.TEMPORAL - computed_column["column_name"] = col.get("name") + column_name = col.get("name") + del computed_column["name"] + del computed_column["type_generic"] computed_column["groupby"] = True - columns.append(computed_column) + columns.append( + TableColumn( + column_name=column_name, + type=col_type, + is_dttm=computed_column["is_dttm"], + groupby=True, + ) + ) return columns + @property + def db_extra(self) -> None: + return None + @property def data(self) -> Dict[str, Any]: order_by_choices = [] for col in self.columns: - column_name = str(col.get("column_name") or "") + column_name = str(col.column_name or "") order_by_choices.append( (json.dumps([column_name, True]), column_name + " [asc]") ) @@ -236,7 +252,7 @@ def data(self) -> Dict[str, Any]: ], "filter_select": True, "name": self.tab_name, - "columns": self.columns, + "columns": [o.data for o in self.columns], "metrics": [], "id": self.id, "type": self.type, @@ -277,7 +293,7 @@ def cache_timeout(self) -> int: @property def column_names(self) -> List[Any]: - return [col.get("column_name") for col in self.columns] + return [col.column_name for col in self.columns] @property def offset(self) -> int: @@ -292,7 +308,7 @@ def 
main_dttm_col(self) -> Optional[str]: @property def dttm_cols(self) -> List[Any]: - return [col.get("column_name") for col in self.columns if col.get("is_dttm")] + return [col.column_name for col in self.columns if col.is_dttm] @property def schema_perm(self) -> str: @@ -335,7 +351,7 @@ def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]: if not column_name: return None for col in self.columns: - if col.get("column_name") == column_name: + if col.column_name == column_name: return col return None diff --git a/superset/utils/core.py b/superset/utils/core.py index 0ab3a685a39c3..2e6014332bb2e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1803,16 +1803,9 @@ def get_time_filter_status( datasource: "BaseDatasource", applied_time_extras: Dict[str, str], ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - - temporal_columns: Set[Any] - if datasource.type == "query": - temporal_columns = { - col.get("column_name") for col in datasource.columns if col.get("is_dttm") - } - else: - temporal_columns = { - col.column_name for col in datasource.columns if col.is_dttm - } + temporal_columns: Set[Any] = { + col.column_name for col in datasource.columns if col.is_dttm + } applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL) From 47c2f834eb93b274c6352f8afafbcc519da38c66 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 25 Jan 2023 20:46:38 +0200 Subject: [PATCH 04/29] fix linting --- superset/connectors/sqla/models.py | 40 ++----- superset/models/helpers.py | 109 ++++++++++-------- superset/models/sql_lab.py | 11 +- superset/result_set.py | 4 +- .../utils/pandas_postprocessing/boxplot.py | 8 +- .../utils/pandas_postprocessing/flatten.py | 2 +- .../charts/data/api_tests.py | 4 +- 7 files changed, 82 insertions(+), 96 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 2e3f2a3c02018..a1a1a8a1ae042 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -50,11 +50,9 @@ from jinja2.exceptions import TemplateError from sqlalchemy import ( and_, - asc, Boolean, Column, DateTime, - desc, Enum, ForeignKey, inspect, @@ -84,9 +82,7 @@ from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager -from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus -from superset.common.utils.time_range_utils import get_since_until_from_time_range from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( find_cached_objects_in_session, @@ -98,13 +94,11 @@ from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import ( - AdvancedDataTypeResponseError, DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, QueryObjectValidationError, SupersetSecurityException, ) -from superset.extensions import feature_flag_manager from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -117,26 +111,12 @@ CertificationMixin, ExploreMixin, QueryResult, + QueryStringExtended, ) from superset.sql_parse import ParsedQuery, sanitize_clause -from superset.superset_typing import ( - AdhocColumn, - AdhocMetric, - Column as ColumnTyping, - Metric, - OrderBy, - 
QueryObjectDict, -) +from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict from superset.utils import core as utils -from superset.utils.core import ( - GenericDataType, - get_column_name, - get_username, - is_adhoc_column, - MediumText, - QueryObjectFilterClause, - remove_duplicates, -) +from superset.utils.core import GenericDataType, get_username, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -161,14 +141,6 @@ class SqlaQuery(NamedTuple): prequeries: List[str] sqla_query: Select -from superset.models.helpers import QueryStringExtended - -# class QueryStringExtended(NamedTuple): -# applied_template_filters: Optional[List[str]] -# labels_expected: List[str] -# prequeries: List[str] -# sql: str - @dataclass class MetadataResult: @@ -547,7 +519,9 @@ def _process_sql_expression( return expression -class SqlaTable(Model, BaseDatasource, ExploreMixin): # pylint: disable=too-many-public-methods +class SqlaTable( + Model, BaseDatasource, ExploreMixin +): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" type = "table" @@ -1007,7 +981,7 @@ def adhoc_metric_to_sqla( return self.make_sqla_column_compatible(sqla_metric, label) - def adhoc_column_to_sqla( # type: ignore + def adhoc_column_to_sqla( self, col: AdhocColumn, template_processor: Optional[BaseTemplateProcessor] = None, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 17f8090bef317..61502dc7ec1eb 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -19,6 +19,7 @@ import logging import re import uuid + # pylint: disable=too-many-lines from collections import defaultdict from datetime import datetime, timedelta @@ -693,6 +694,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def fetch_value_predicate(self) -> str: return "fix this!" 
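    # A sketch (hypothetical subclass) of the minimum surface a datasource
    # supplies on top of this mixin in order to use get_sqla_query():
    #
    #     class MyDatasource(ExploreMixin):
    #         type = "table"          # passed to get_column_spec as the native type
    #
    #         @property
    #         def db_extra(self):     # engine "extra" config; None is acceptable
    #             return None
    #
    #         def query(self, query_obj):
    #             ...                 # execute and return a QueryResult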
+ @property + def type(self) -> str: + raise NotImplementedError() + + @property + def db_extra(self) -> Optional[Dict[str, Any]]: + raise NotImplementedError() + def query(self, query_obj: QueryObjectDict) -> QueryResult: raise NotImplementedError() @@ -756,8 +765,9 @@ def sql(self) -> str: def columns(self) -> List[Any]: raise NotImplementedError() - @property - def get_fetch_values_predicate(self) -> List[Any]: + def get_fetch_values_predicate( + self, template_processor: Optional[BaseTemplateProcessor] = None + ) -> TextClause: raise NotImplementedError() def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: @@ -1216,14 +1226,14 @@ def _get_series_orderby( ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) - ob = self.adhoc_metric_to_sqla( - series_limit_metric, columns_by_name - ) + ob = self.adhoc_metric_to_sqla(series_limit_metric, columns_by_name) elif ( isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col() + ob = metrics_by_name[series_limit_metric].get_sqla_col( + template_processor=template_processor + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) @@ -1232,7 +1242,7 @@ def _get_series_orderby( def adhoc_column_to_sqla( self, - col: Type["AdhocColumn"], # type: ignore + col: "AdhocColumn", # type: ignore template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: raise NotImplementedError() @@ -1290,7 +1300,7 @@ def dttm_sql_literal(self, dttm: sa.DateTime, col_type: Optional[str]) -> str: return f'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}' - def get_time_filter( + def get_time_filter( # pylint: disable=too-many-arguments self, time_col: "TableColumn", start_dttm: Optional[sa.DateTime], @@ -1298,10 +1308,8 @@ def get_time_filter( label: Optional[str] = "__time", template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - col = self.transform_tbl_column_to_sqla_column( - time_col, - label=label, - template_processor=template_processor + col = self.convert_tbl_column_to_sqla_col( + time_col, label=label, template_processor=template_processor ) l = [] if start_dttm: @@ -1379,12 +1387,12 @@ def get_timestamp_expression( time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) return self.make_sqla_column_compatible(time_expr, label) - def transform_tbl_column_to_sqla_column( + def convert_tbl_column_to_sqla_col( self, tbl_column: "TableColumn", label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None - ) -> Column: + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> Column: label = label or tbl_column.column_name db_engine_spec = self.db_engine_spec column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) @@ -1472,11 +1480,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, TableColumn] = { + columns_by_name: Dict[str, "TableColumn"] = { col.column_name: col for col in self.columns } - metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + metrics_by_name: Dict[str, "SqlMetric"] = { + m.metric_name: m for m in self.metrics + } if not granularity and is_timeseries: raise QueryObjectValidationError( @@ -1543,12 +1553,12 @@ def get_sqla_query( # 
pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - # col = self.transform_tbl_column_to_sqla_column( - # columns_by_name[col], template_processor=template_processor - # ) - col = columns_by_name[col].get_sqla_col( - template_processor=template_processor + col = self.convert_tbl_column_to_sqla_col( + columns_by_name[col], template_processor=template_processor ) + # col = columns_by_name[col].get_sqla_col( + # template_processor=template_processor + # ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True @@ -1589,12 +1599,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) # if groupby field equals a selected column elif selected in columns_by_name: - # outer = self.transform_tbl_column_to_sqla_column( - # columns_by_name[selected], template_processor=template_processor - # ) - outer = columns_by_name[selected].get_sqla_col( - template_processor=template_processor + outer = self.convert_tbl_column_to_sqla_col( + columns_by_name[selected], + template_processor=template_processor, ) + # outer = columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) else: selected = validate_adhoc_subquery( selected, @@ -1628,14 +1639,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma self.schema, ) - data = self.transform_tbl_column_to_sqla_column( - columns_by_name[selected], template_processor=template_processor - ) - # data = columns_by_name[selected].get_sqla_col( - # template_processor=template_processor - # ) select_exprs.append( - data if isinstance(selected, str) and selected in columns_by_name + self.convert_tbl_column_to_sqla_col( + columns_by_name[selected], template_processor=template_processor + ) + if isinstance(selected, str) and selected in columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label ) @@ -1673,18 +1681,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # template_processor=template_processor, # ) self.get_time_filter( - time_col=self.main_dttm_col, + time_col=columns_by_name[self.main_dttm_col], start_dttm=from_dttm, end_dttm=to_dttm, - template_processor=template_processor + template_processor=template_processor, ) ) time_filter_column = self.get_time_filter( - time_col=dttm_col, - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor + time_col=dttm_col, + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, ) time_filters.append(time_filter_column) @@ -1727,12 +1735,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma flt_col = flt["col"] val = flt.get("val") op = flt["op"].upper() - col_obj: Optional[TableColumn] = None + col_obj: Optional["TableColumn"] = None sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(col=flt_col, template_processor=template_processor) # type: ignore + sqla_col = self.adhoc_column_to_sqla( + col=flt_col, template_processor=template_processor + ) else: col_obj = columns_by_name.get(flt_col) filter_grain = flt.get("grain") @@ -1750,9 +1760,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_grain=filter_grain, template_processor=template_processor ) elif col_obj: - sqla_col = 
self.transform_tbl_column_to_sqla_column( - tbl_column=col_obj, - template_processor=template_processor + sqla_col = self.convert_tbl_column_to_sqla_col( + tbl_column=col_obj, template_processor=template_processor ) # sqla_col = col_obj.get_sqla_col( # template_processor=template_processor @@ -1760,7 +1769,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, -# db_extra=self.database.get_extra(), + # db_extra=self.database.get_extra(), ) is_list_target = op in ( utils.FilterOperator.IN.value, @@ -1780,7 +1789,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, -# db_extra=self.database.get_extra(), + # db_extra=self.database.get_extra(), ) if ( col_advanced_data_type != "" @@ -1912,7 +1921,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma expression=where, database_id=self.database_id, schema=self.schema, - template_processor=template_processor + template_processor=template_processor, ) where_clause_and += [self.text(where)] having = extras.get("having") @@ -1934,9 +1943,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) having_clause_and += [self.text(having)] - if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore + if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) # type: ignore + self.get_fetch_values_predicate(template_processor=template_processor) ) if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 9ffe28e2774b7..5efc949fafcf1 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -54,6 +54,7 @@ from superset.utils.core import GenericDataType, QueryStatus, user_label if TYPE_CHECKING: + from superset.connectors.sqla.models import TableColumn from superset.db_engine_specs import BaseEngineSpec @@ -182,7 +183,11 @@ def sql_tables(self) -> List[Table]: return list(ParsedQuery(self.sql).tables) @property - def columns(self) -> List[Dict[str, Any]]: + def columns(self) -> List["TableColumn"]: + from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel + TableColumn, + ) + bool_types = ("BOOL",) num_types = ( "DOUBLE", @@ -201,8 +206,6 @@ def columns(self) -> List[Dict[str, Any]]: columns = [] col_type = "" - from superset.connectors.sqla.models import TableColumn - for col in self.extra.get("columns", []): computed_column = {**col} col_type = col.get("type") @@ -231,7 +234,7 @@ def columns(self) -> List[Dict[str, Any]]: return columns @property - def db_extra(self) -> None: + def db_extra(self) -> Optional[Dict[str, Any]]: return None @property diff --git a/superset/result_set.py b/superset/result_set.py index 63d48b1e4bcab..3d29673b9fcb9 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -70,9 +70,9 @@ def stringify_values(array: NDArray[Any]) -> NDArray[Any]: for obj in it: if na_obj := pd.isna(obj): # pandas type cannot be converted to string - obj[na_obj] = None + obj[na_obj] = None # type: ignore else: - obj[...] = stringify(obj) + obj[...] 
= stringify(obj) # type: ignore return result diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py index e2706345b1ea9..673c39ebf3836 100644 --- a/superset/utils/pandas_postprocessing/boxplot.py +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -57,10 +57,10 @@ def boxplot( """ def quartile1(series: Series) -> float: - return np.nanpercentile(series, 25, interpolation="midpoint") + return np.nanpercentile(series, 25, interpolation="midpoint") # type: ignore def quartile3(series: Series) -> float: - return np.nanpercentile(series, 75, interpolation="midpoint") + return np.nanpercentile(series, 75, interpolation="midpoint") # type: ignore if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY: @@ -99,8 +99,8 @@ def whisker_low(series: Series) -> float: return np.nanpercentile(series, low) else: - whisker_high = np.max - whisker_low = np.min + whisker_high = np.max # type: ignore + whisker_low = np.min # type: ignore def outliers(series: Series) -> Set[float]: above = series[series > whisker_high(series)] diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index db783c4bed264..1026164e454ee 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -85,7 +85,7 @@ def flatten( _columns = [] for series in df.columns.to_flat_index(): _cells = [] - for cell in series if is_sequence(series) else [series]: + for cell in series if is_sequence(series) else [series]: # type: ignore if pd.notnull(cell): # every cell should be converted to string and escape comma _cells.append(escape_separator(str(cell))) diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index d83cb8286b529..c8cba3ff43ae8 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -1155,8 +1155,8 @@ def test_chart_cache_timeout_chart_not_found( [ (200, {"where": "1 = 1"}), (200, {"having": "count(*) > 0"}), - (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}), - (400, {"having": "count(*) > (select count(*) from physical_dataset)"}), + (403, {"where": "col1 in (select distinct col1 from physical_dataset)"}), + (403, {"having": "count(*) > (select count(*) from physical_dataset)"}), ], ) @with_feature_flags(ALLOW_ADHOC_SUBQUERY=False) From 247f223fc7ed9e178dff44506db275c50a7476e1 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 7 Feb 2023 08:32:29 -0800 Subject: [PATCH 05/29] fix logos for columns --- superset/connectors/sqla/models.py | 29 ++++++++++++++++++++++ superset/models/sql_lab.py | 39 +++--------------------------- 2 files changed, 32 insertions(+), 36 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a1a1a8a1ae042..5d385984e9079 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -283,6 +283,35 @@ def db_extra(self) -> Dict[str, Any]: def type_generic(self) -> Optional[utils.GenericDataType]: if self.is_dttm: return GenericDataType.TEMPORAL + + bool_types = ("BOOL",) + num_types = ( + "DOUBLE", + "FLOAT", + "INT", + "BIGINT", + "NUMBER", + "LONG", + "REAL", + "NUMERIC", + "DECIMAL", + "MONEY", + ) + date_types = ("DATE", "TIME") + str_types = ("VARCHAR", "STRING", "CHAR") + + if self.table is None: + # Query.TableColumns don't have a reference to a table.db_engine_spec + # reference so this logic will manage 
rendering types + if self.type and any(map(lambda t: t in self.type.upper(), str_types)): + return GenericDataType.STRING + if self.type and any(map(lambda t: t in self.type.upper(), bool_types)): + return GenericDataType.BOOLEAN + if self.type and any(map(lambda t: t in self.type.upper(), num_types)): + return GenericDataType.NUMERIC + if self.type and any(map(lambda t: t in self.type.upper(), date_types)): + return GenericDataType.TEMPORAL + column_spec = self.db_engine_spec.get_column_spec( self.type, db_extra=self.db_extra ) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 71acce50373d1..98225000775d4 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -189,46 +189,13 @@ def columns(self) -> List["TableColumn"]: TableColumn, ) - bool_types = ("BOOL",) - num_types = ( - "DOUBLE", - "FLOAT", - "INT", - "BIGINT", - "NUMBER", - "LONG", - "REAL", - "NUMERIC", - "DECIMAL", - "MONEY", - ) - date_types = ("DATE", "TIME") - str_types = ("VARCHAR", "STRING", "CHAR") columns = [] - col_type = "" - for col in self.extra.get("columns", []): - computed_column = {**col} - col_type = col.get("type") - - if col_type and any(map(lambda t: t in col_type.upper(), str_types)): - computed_column["type_generic"] = GenericDataType.STRING - if col_type and any(map(lambda t: t in col_type.upper(), bool_types)): - computed_column["type_generic"] = GenericDataType.BOOLEAN - if col_type and any(map(lambda t: t in col_type.upper(), num_types)): - computed_column["type_generic"] = GenericDataType.NUMERIC - if col_type and any(map(lambda t: t in col_type.upper(), date_types)): - computed_column["type_generic"] = GenericDataType.TEMPORAL - - column_name = col.get("name") - del computed_column["name"] - del computed_column["type_generic"] - computed_column["groupby"] = True columns.append( TableColumn( - column_name=column_name, - type=col_type, - is_dttm=computed_column["is_dttm"], + column_name=col["name"], + type=col["type"], + is_dttm=col["is_dttm"], groupby=True, ) ) From d4e48696e6e8eea78c898274c6e05ba0103f57e9 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 14 Feb 2023 20:59:36 -0500 Subject: [PATCH 06/29] add filterable true to allow columns to have metrics --- superset/models/sql_lab.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 98225000775d4..b1d405cea45b7 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -197,6 +197,7 @@ def columns(self) -> List["TableColumn"]: type=col["type"], is_dttm=col["is_dttm"], groupby=True, + filterable=True, ) ) return columns From ce25c866cfcc12c015dcda61a2e4f724e3d440c2 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 20 Feb 2023 17:18:44 -0500 Subject: [PATCH 07/29] only stop propagation if e is defined --- .../components/controls/MetricControl/AdhocMetricOption.jsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx index 80cf879f7f256..f3fac06cc4131 100644 --- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx +++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx @@ -48,7 +48,7 @@ class AdhocMetricOption extends React.PureComponent { } onRemoveMetric(e) { - e.stopPropagation(); + if (e !== undefined) e.stopPropagation(); this.props.onRemoveMetric(this.props.index); } 
@@ -67,7 +67,7 @@ class AdhocMetricOption extends React.PureComponent { multi, datasourceWarningMessage, } = this.props; - + console.log('hello'); return ( Date: Tue, 21 Feb 2023 22:05:43 -0500 Subject: [PATCH 08/29] fix save --- .../src/SqlLab/components/SaveDatasetModal/index.tsx | 4 ++-- superset/views/core.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index 949323b9aa75c..d1f3461906b04 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -61,7 +61,7 @@ export type ExploreQuery = QueryResponse & { }; export interface ISimpleColumn { - name?: string | null; + column_name?: string | null; type?: string | null; is_dttm?: boolean | null; } @@ -301,7 +301,7 @@ export const SaveDatasetModal = ({ ...formDataWithDefaults, datasource: `${data.table_id}__table`, ...(defaultVizType === 'table' && { - all_columns: selectedColumns.map(column => column.name), + all_columns: selectedColumns.map(column => column.column_name), }), }), ) diff --git a/superset/views/core.py b/superset/views/core.py index 8d632dcde21bf..2572d18993d7e 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2009,7 +2009,7 @@ def sqllab_viz(self) -> FlaskResponse: # pylint: disable=no-self-use db.session.add(table) cols = [] for config_ in data.get("columns"): - column_name = config_.get("name") + column_name = config_.get("column_name") col = TableColumn( column_name=column_name, filterable=True, From b6297600089c23c333a85491656fdcacb98f7731 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 21 Feb 2023 22:17:29 -0500 Subject: [PATCH 09/29] remove log --- .../components/controls/MetricControl/AdhocMetricOption.jsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx index f3fac06cc4131..11406f0ebe798 100644 --- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx +++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx @@ -67,7 +67,7 @@ class AdhocMetricOption extends React.PureComponent { multi, datasourceWarningMessage, } = this.props; - console.log('hello'); + return ( Date: Tue, 21 Feb 2023 22:21:35 -0500 Subject: [PATCH 10/29] fix linting --- superset/models/sql_lab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index b1d405cea45b7..768ed809cb136 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -52,7 +52,7 @@ ) from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sqllab.limiting_factor import LimitingFactor -from superset.utils.core import GenericDataType, QueryStatus, user_label +from superset.utils.core import QueryStatus, user_label if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn From d4b6d0473b4f1a5078b84458ddc8e73226702bed Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 22 Feb 2023 12:59:59 -0500 Subject: [PATCH 11/29] change column name references --- tests/integration_tests/sqllab_tests.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/sqllab_tests.py 
b/tests/integration_tests/sqllab_tests.py index 19e397e8f6961..b7ffe26a0752b 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -492,8 +492,16 @@ def test_sqllab_viz(self): "datasourceName": f"test_viz_flow_table_{random()}", "schema": "superset", "columns": [ - {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"}, - {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"}, + { + "is_dttm": False, + "type": "STRING", + "column_name": f"viz_type_{random()}", + }, + { + "is_dttm": False, + "type": "OBJECT", + "column_name": f"ccount_{random()}", + }, ], "sql": """\ SELECT * @@ -522,8 +530,16 @@ def test_sqllab_viz_bad_payload(self): "chartType": "dist_bar", "schema": "superset", "columns": [ - {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"}, - {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"}, + { + "is_dttm": False, + "type": "STRING", + "column_name": f"viz_type_{random()}", + }, + { + "is_dttm": False, + "type": "OBJECT", + "column_name": f"ccount_{random()}", + }, ], "sql": """\ SELECT * From e1ab0ac0da6317a9979da006aa18b27f9fe6864c Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 22 Feb 2023 13:14:40 -0500 Subject: [PATCH 12/29] fix lint --- .../src/SqlLab/components/SaveDatasetModal/index.tsx | 2 +- superset-frontend/src/SqlLab/fixtures.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index d1f3461906b04..402e26462e041 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -216,7 +216,7 @@ export const SaveDatasetModal = ({ ...formDataWithDefaults, datasource: `${datasetToOverwrite.datasetid}__table`, ...(defaultVizType === 'table' && { - all_columns: datasource?.columns?.map(column => column.name), + all_columns: datasource?.columns?.map(column => column.column_name), }), }), ]); diff --git a/superset-frontend/src/SqlLab/fixtures.ts b/superset-frontend/src/SqlLab/fixtures.ts index 456a83a3faf1e..50739268ff01c 100644 --- a/superset-frontend/src/SqlLab/fixtures.ts +++ b/superset-frontend/src/SqlLab/fixtures.ts @@ -699,17 +699,17 @@ export const testQuery: ISaveableDatasource = { sql: 'SELECT *', columns: [ { - name: 'Column 1', + column_name: 'Column 1', type: DatasourceType.Query, is_dttm: false, }, { - name: 'Column 3', + column_name: 'Column 3', type: DatasourceType.Query, is_dttm: false, }, { - name: 'Column 2', + column_name: 'Column 2', type: DatasourceType.Query, is_dttm: true, }, From bdd07273e845dcebb3536261be59d3ab697e688b Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 22 Feb 2023 14:21:27 -0500 Subject: [PATCH 13/29] pre-commit --- superset/models/helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 66835b32e1701..2ead2f09c6432 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -19,6 +19,7 @@ import logging import re import uuid + # pylint: disable=too-many-lines from collections import defaultdict from datetime import datetime, timedelta From f2b512f0dc8514509953e67fb70faacf92868a92 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Wed, 22 Feb 2023 14:32:57 -0500 Subject: [PATCH 14/29] lint --- superset/models/helpers.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/superset/models/helpers.py 
b/superset/models/helpers.py index 2ead2f09c6432..25fda9e812e9b 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1674,11 +1674,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma and self.main_dttm_col != dttm_col.column_name ): time_filters.append( - # columns_by_name[self.main_dttm_col].get_time_filter( - # start_dttm=from_dttm, - # end_dttm=to_dttm, - # template_processor=template_processor, - # ) self.get_time_filter( time_col=columns_by_name[self.main_dttm_col], start_dttm=from_dttm, @@ -1695,14 +1690,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ) time_filters.append(time_filter_column) - # time_filters.append( - # dttm_col.get_time_filter( - # start_dttm=from_dttm, - # end_dttm=to_dttm, - # template_processor=template_processor, - # ) - # ) - # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use # raw columns as custom SQL adhoc metric). From 8937044f364c8c2972ce4dab3960e23f29bb09d1 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Tue, 28 Feb 2023 11:20:37 -0500 Subject: [PATCH 15/29] address concerns --- .../components/controls/MetricControl/AdhocMetricOption.jsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx index 11406f0ebe798..c74212f0baf6b 100644 --- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx +++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx @@ -48,7 +48,7 @@ class AdhocMetricOption extends React.PureComponent { } onRemoveMetric(e) { - if (e !== undefined) e.stopPropagation(); + e?.stopPropagation(); this.props.onRemoveMetric(this.props.index); } From 1d443b79b38182101cd5237135176cb6d3f64e91 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 20 Mar 2023 14:02:27 -0700 Subject: [PATCH 16/29] fix saving columns --- superset/views/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/views/core.py b/superset/views/core.py index 777c0f9cafebb..0a624ecf62c08 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2018,7 +2018,7 @@ def sqllab_viz(self) -> FlaskResponse: # pylint: disable=no-self-use db.session.add(table) cols = [] for config_ in data.get("columns"): - column_name = config_.get("column_name") + column_name = config_.get("column_name") or config_.get("name") col = TableColumn( column_name=column_name, filterable=True, From 303a92e67bb139f589903c48c07077ba240a18ce Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 20 Mar 2023 14:14:23 -0700 Subject: [PATCH 17/29] add fix for name vs column_name --- .../src/SqlLab/components/SaveDatasetModal/index.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index 402e26462e041..51a4fb13c41ce 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -301,7 +301,9 @@ export const SaveDatasetModal = ({ ...formDataWithDefaults, datasource: `${data.table_id}__table`, ...(defaultVizType === 'table' && { - all_columns: selectedColumns.map(column => column.column_name), + all_columns: 
selectedColumns.map( + column => column.column_name || column.name, + ), }), }), ) From 0c002538c8074f4c730671ec45876abe70074de4 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 18:13:03 -0400 Subject: [PATCH 18/29] add QueryStringExtended --- superset/connectors/sqla/models.py | 21 +++++++-------------- superset/models/helpers.py | 8 ++++++-- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 776a29165c582..d5cc7545ed7b5 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -93,8 +93,14 @@ ) from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression +from superset.superset_typing import ( + AdhocColumn, + AdhocMetric, + Column as ColumnTyping, + Metric, + QueryObjectDict, +) from superset.exceptions import ( - AdvancedDataTypeResponseError, ColumnNotFoundException, DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, @@ -117,7 +123,6 @@ QueryStringExtended, ) from superset.sql_parse import ParsedQuery, sanitize_clause -from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict from superset.utils import core as utils from superset.utils.core import GenericDataType, get_username, MediumText @@ -135,18 +140,6 @@ } ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES} - -class SqlaQuery(NamedTuple): - applied_template_filters: List[str] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] - cte: Optional[str] - extra_cache_keys: List[Any] - labels_expected: List[str] - prequeries: List[str] - sqla_query: Select - - @dataclass class MetadataResult: added: List[str] = field(default_factory=list) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 1fdb5b35c7654..b9efc6ef07f9f 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -97,6 +97,7 @@ is_adhoc_column, remove_duplicates, ) +from superset.utils.dates import datetime_to_epoch if TYPE_CHECKING: from superset.connectors.sqla.models import SqlMetric, TableColumn @@ -675,20 +676,22 @@ def clone_model( # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): applied_template_filters: Optional[List[str]] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] labels_expected: List[str] prequeries: List[str] sql: str - class SqlaQuery(NamedTuple): applied_template_filters: List[str] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] prequeries: List[str] sqla_query: Select - class ExploreMixin: # pylint: disable=too-many-public-methods """ Allows any flask_appbuilder.Model (Query, Table, etc.) 
@@ -1257,6 +1260,7 @@ def _get_series_orderby( def adhoc_column_to_sqla( self, col: "AdhocColumn", # type: ignore + force_type_check: bool = False, template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: raise NotImplementedError() From d178acaba22026ce77fd747c0fe371cefe3f5f8f Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 18:54:55 -0400 Subject: [PATCH 19/29] bring in master --- superset/models/helpers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b9efc6ef07f9f..70aa9c2f248d5 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1475,6 +1475,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma } columns = columns or [] groupby = groupby or [] + rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: @@ -2091,6 +2092,24 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] + + filter_columns = [flt.get("col") for flt in filter] if filter else [] + rejected_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and col not in self.column_names + and col not in applied_template_filters + ] + rejected_adhoc_filters_columns + + applied_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and (col in self.column_names or col in applied_template_filters) + ] + applied_adhoc_filters_columns return SqlaQuery( applied_template_filters=applied_template_filters, From 9d5dd929f3df0cfdf89941d66f83b8f2ae90512d Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 20:54:13 -0400 Subject: [PATCH 20/29] ok --- superset/connectors/sqla/models.py | 15 ++++++++------- superset/models/helpers.py | 11 +++++++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index d5cc7545ed7b5..af3240ec7d921 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -93,13 +93,6 @@ ) from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression -from superset.superset_typing import ( - AdhocColumn, - AdhocMetric, - Column as ColumnTyping, - Metric, - QueryObjectDict, -) from superset.exceptions import ( ColumnNotFoundException, DatasetInvalidPermissionEvaluationException, @@ -123,6 +116,13 @@ QueryStringExtended, ) from superset.sql_parse import ParsedQuery, sanitize_clause +from superset.superset_typing import ( + AdhocColumn, + AdhocMetric, + Column as ColumnTyping, + Metric, + QueryObjectDict, +) from superset.utils import core as utils from superset.utils.core import GenericDataType, get_username, MediumText @@ -140,6 +140,7 @@ } ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES} + @dataclass class MetadataResult: added: List[str] = field(default_factory=list) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 70aa9c2f248d5..e6835a6919245 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -682,6 +682,7 @@ class QueryStringExtended(NamedTuple): 
prequeries: List[str] sql: str + class SqlaQuery(NamedTuple): applied_template_filters: List[str] applied_filter_columns: List[ColumnTyping] @@ -692,6 +693,7 @@ class SqlaQuery(NamedTuple): prequeries: List[str] sqla_query: Select + class ExploreMixin: # pylint: disable=too-many-public-methods """ Allows any flask_appbuilder.Model (Query, Table, etc.) @@ -941,6 +943,8 @@ def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExten sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, + applied_filter_columns=sqlaq.applied_filter_columns, + rejected_filter_columns=sqlaq.rejected_filter_columns, labels_expected=sqlaq.labels_expected, prequeries=sqlaq.prequeries, sql=sql, @@ -1476,6 +1480,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma columns = columns or [] groupby = groupby or [] rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: @@ -2092,7 +2097,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] - + filter_columns = [flt.get("col") for flt in filter] if filter else [] rejected_filter_columns = [ col @@ -2102,7 +2107,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma and col not in self.column_names and col not in applied_template_filters ] + rejected_adhoc_filters_columns - + applied_filter_columns = [ col for col in filter_columns @@ -2114,6 +2119,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma return SqlaQuery( applied_template_filters=applied_template_filters, cte=cte, + applied_filter_columns=applied_filter_columns, + rejected_filter_columns=rejected_filter_columns, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, From 29dfe09322c2ad581fa0626eac2276c81063acf7 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 21:25:05 -0400 Subject: [PATCH 21/29] ok --- superset/models/helpers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index e6835a6919245..c042d61c2bbe6 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
"""a collection of model-related helper classes and functions""" +import dataclasses import json import logging import re import uuid - # pylint: disable=too-many-lines from collections import defaultdict from datetime import datetime, timedelta @@ -1060,18 +1060,23 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: return df try: - df = self.database.get_df( - sql, self.schema, mutator=assign_column_label # type: ignore - ) + df = self.database.get_df(sql, self.schema, mutator=assign_column_label) except Exception as ex: # pylint: disable=broad-except df = pd.DataFrame() status = QueryStatus.FAILED logger.warning( "Query %s on schema %s failed", sql, self.schema, exc_info=True ) + db_engine_spec = self.db_engine_spec + errors = [ + dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex) + ] error_message = utils.error_msg_from_exception(ex) return QueryResult( + applied_template_filters=query_str_ext.applied_template_filters, + applied_filter_columns=query_str_ext.applied_filter_columns, + rejected_filter_columns=query_str_ext.rejected_filter_columns, status=status, df=df, duration=datetime.now() - qry_start_dttm, @@ -1463,7 +1468,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma extras = extras or {} time_grain = extras.get("time_grain_sqla") - # breakpoint() template_kwargs = { "columns": columns, "from_dttm": from_dttm.isoformat() if from_dttm else None, From dccf089c2f672b5e044728685852751f120b5cf0 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 21:26:37 -0400 Subject: [PATCH 22/29] lit --- superset/models/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index c042d61c2bbe6..077d7d1154d47 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -20,7 +20,6 @@ import logging import re import uuid -# pylint: disable=too-many-lines from collections import defaultdict from datetime import datetime, timedelta from json.decoder import JSONDecodeError From 924514ad79ed1a8f3b437f16adb763d9960150e4 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Mon, 3 Apr 2023 22:33:38 -0400 Subject: [PATCH 23/29] lint --- superset/models/helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 077d7d1154d47..cb841eeffd99b 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1059,7 +1059,9 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: return df try: - df = self.database.get_df(sql, self.schema, mutator=assign_column_label) + df = self.database.get_df( + sql, self.schema, mutator=assign_column_label # type: ignore + ) except Exception as ex: # pylint: disable=broad-except df = pd.DataFrame() status = QueryStatus.FAILED From 1dd895b9b0ebc2efcd2193b72013b4f9f0183564 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Thu, 6 Apr 2023 15:11:21 -0400 Subject: [PATCH 24/29] maybe --- superset/models/helpers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index cb841eeffd99b..c5f1789fb1ce0 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -72,6 +72,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( AdvancedDataTypeResponseError, + ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, SupersetSecurityException, @@ -1756,9 
+1757,12 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla( - col=flt_col, template_processor=template_processor - ) + try: + sqla_col = self.adhoc_column_to_sqla(flt_col) + applied_adhoc_filters_columns.append(flt_col) + except ColumnNotFoundException: + rejected_adhoc_filters_columns.append(flt_col) + continue else: col_obj = columns_by_name.get(cast(str, flt_col)) filter_grain = flt.get("grain") From 245d0ad337d9660c8dafa36039f0366ab9e3eeff Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 7 Apr 2023 11:26:21 -0400 Subject: [PATCH 25/29] ask for check --- superset/models/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index c5f1789fb1ce0..b0f93171de19c 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1758,7 +1758,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col_obj = dttm_col elif is_adhoc_column(flt_col): try: - sqla_col = self.adhoc_column_to_sqla(flt_col) + sqla_col = self.adhoc_column_to_sqla(flt_col, force_type_check=True) applied_adhoc_filters_columns.append(flt_col) except ColumnNotFoundException: rejected_adhoc_filters_columns.append(flt_col) From f32b549bffbea09b4e73385d0323d5bd663a7991 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 7 Apr 2023 11:45:01 -0400 Subject: [PATCH 26/29] linting --- superset/utils/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/utils/core.py b/superset/utils/core.py index 457928c8484b8..8cf1076f5eef3 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1732,13 +1732,13 @@ def extract_dataframe_dtypes( if datasource: for column in datasource.columns: if isinstance(column, dict): - columns_by_name[column.get("column_name")] = column # type: ignore + columns_by_name[column.get("column_name")] = column else: columns_by_name[column.column_name] = column generic_types: List[GenericDataType] = [] for column in df.columns: - column_object = columns_by_name.get(column) # type: ignore + column_object = columns_by_name.get(column) series = df[column] inferred_type = infer_dtype(series) if isinstance(column_object, dict): From df47fb025e2750ba6fa302d4ad8d285fd245624a Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 7 Apr 2023 12:17:37 -0400 Subject: [PATCH 27/29] linting --- superset/connectors/sqla/models.py | 11 ++--------- superset/models/helpers.py | 5 ++++- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index af3240ec7d921..6180a546e7500 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -31,7 +31,6 @@ Dict, Hashable, List, - NamedTuple, Optional, Set, Tuple, @@ -78,7 +77,7 @@ from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table from sqlalchemy.sql.elements import ColumnClause, TextClause -from sqlalchemy.sql.expression import Label, Select, TextAsFrom +from sqlalchemy.sql.expression import Label, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager @@ -116,13 +115,7 @@ QueryStringExtended, ) from superset.sql_parse import ParsedQuery, sanitize_clause -from superset.superset_typing import ( - 
AdhocColumn, - AdhocMetric, - Column as ColumnTyping, - Metric, - QueryObjectDict, -) +from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict from superset.utils import core as utils from superset.utils.core import GenericDataType, get_username, MediumText diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b0f93171de19c..b08c8c1477124 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" import dataclasses import json @@ -935,7 +936,9 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) - def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: + def get_query_str_extended( + self, query_obj: QueryObjectDict, mutate: bool = True + ) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore sql = self._apply_cte(sql, sqlaq.cte) From 266e2244ec559e3d0b1a6bd56283b864ac83600a Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 7 Apr 2023 12:47:15 -0400 Subject: [PATCH 28/29] lint 1 more --- superset/models/helpers.py | 42 +++----------------------------------- 1 file changed, 3 insertions(+), 39 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b08c8c1477124..cc3b34ae627cb 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -943,7 +943,8 @@ def get_query_str_extended( sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) - sql = self.mutate_query_from_config(sql) + if mutate: + sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, applied_filter_columns=sqlaq.applied_filter_columns, @@ -1151,7 +1152,7 @@ def get_from_clause( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument + columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ @@ -1278,23 +1279,6 @@ def adhoc_column_to_sqla( template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: raise NotImplementedError() - # """ - # Turn an adhoc column into a sqlalchemy column. 
- - # :param col: Adhoc column definition - # :param template_processor: template_processor instance - # :returns: The metric defined as a sqlalchemy column - # :rtype: sqlalchemy.sql.column - # """ - # label = utils.get_column_name(col) # type: ignore - # expression = self._process_sql_expression( - # expression=col["sqlExpression"], - # database_id=self.database_id, - # schema=self.schema, - # template_processor=template_processor, - # ) - # sqla_column = literal_column(expression) - # return self.make_sqla_column_compatible(sqla_column, label) def _get_top_groups( self, @@ -1589,9 +1573,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = self.convert_tbl_column_to_sqla_col( columns_by_name[col], template_processor=template_processor ) - # col = columns_by_name[col].get_sqla_col( - # template_processor=template_processor - # ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True @@ -1636,9 +1617,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma columns_by_name[selected], template_processor=template_processor, ) - # outer = columns_by_name[selected].get_sqla_col( - # template_processor=template_processor - # ) else: selected = validate_adhoc_subquery( selected, @@ -1786,9 +1764,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma sqla_col = self.convert_tbl_column_to_sqla_col( tbl_column=col_obj, template_processor=template_processor ) - # sqla_col = col_obj.get_sqla_col( - # template_processor=template_processor - # ) col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, @@ -1916,12 +1891,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma label=sqla_col.key, template_processor=template_processor, ) - # col_obj.get_time_filter( - # start_dttm=_since, - # end_dttm=_until, - # label=sqla_col.key, - # template_processor=template_processor, - # ) ) else: raise QueryObjectValidationError( @@ -2026,11 +1995,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma end_dttm=inner_to_dttm or to_dttm, template_processor=template_processor, ) - # dttm_col.get_time_filter( - # start_dttm=inner_from_dttm or from_dttm, - # end_dttm=inner_to_dttm or to_dttm, - # template_processor=template_processor, - # ) ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) subq = subq.group_by(*inner_groupby_exprs) From dfe7191b8c7470c31bf5a64df6808eeb38fb6107 Mon Sep 17 00:00:00 2001 From: hughhhh Date: Fri, 7 Apr 2023 13:34:12 -0400 Subject: [PATCH 29/29] fe build fix --- .../src/SqlLab/components/SaveDatasetModal/index.tsx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index 51a4fb13c41ce..402e26462e041 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -301,9 +301,7 @@ export const SaveDatasetModal = ({ ...formDataWithDefaults, datasource: `${data.table_id}__table`, ...(defaultVizType === 'table' && { - all_columns: selectedColumns.map( - column => column.column_name || column.name, - ), + all_columns: selectedColumns.map(column => column.column_name), }), }), )
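
The crux of patch 05 above is the substring fallback that maps a raw database type name to a coarse GenericDataType when a column has no db_engine_spec to consult (SQL Lab Query columns carry no table reference). A minimal, runnable sketch of that matching logic, under stated assumptions: the enum mirrors superset.utils.core.GenericDataType but is inlined so the example is self-contained, and the function name infer_generic_type is illustrative rather than anything from the patch.

from enum import Enum
from typing import Optional

class GenericDataType(Enum):
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2
    BOOLEAN = 3

BOOL_TYPES = ("BOOL",)
NUM_TYPES = ("DOUBLE", "FLOAT", "INT", "BIGINT", "NUMBER", "LONG",
             "REAL", "NUMERIC", "DECIMAL", "MONEY")
DATE_TYPES = ("DATE", "TIME")
STR_TYPES = ("VARCHAR", "STRING", "CHAR")

def infer_generic_type(native_type: Optional[str]) -> Optional[GenericDataType]:
    # Substring match against the raw type name, in the same order the
    # patch checks: string, boolean, numeric, then temporal.
    if not native_type:
        return None
    upper = native_type.upper()
    if any(t in upper for t in STR_TYPES):
        return GenericDataType.STRING
    if any(t in upper for t in BOOL_TYPES):
        return GenericDataType.BOOLEAN
    if any(t in upper for t in NUM_TYPES):
        return GenericDataType.NUMERIC
    if any(t in upper for t in DATE_TYPES):
        return GenericDataType.TEMPORAL
    return None

assert infer_generic_type("VARCHAR(255)") is GenericDataType.STRING
assert infer_generic_type("BIGINT") is GenericDataType.NUMERIC
assert infer_generic_type("TIMESTAMP") is GenericDataType.TEMPORAL  # "TIME" matches inside "TIMESTAMP"

The matching is deliberately loose (a type like POINT would hit "INT" and read as numeric), which presumably suffices for the type names SQL Lab result sets actually report.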
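
Patches 18 through 23 thread applied/rejected filter-column bookkeeping through SqlaQuery, QueryStringExtended, and QueryResult. The partitioning rule is compact enough to state on its own; the helper below is an illustrative reshaping (in the patch the two list comprehensions live inline in get_sqla_query), and the dict-based adhoc check is a stand-in for superset.utils.core.is_adhoc_column.

from typing import Any, Dict, List, Tuple, Union

ColumnRef = Union[str, Dict[str, Any]]

def is_adhoc_column(col: ColumnRef) -> bool:
    # Simplified stand-in: an adhoc column is a mapping carrying a
    # sqlExpression payload rather than a plain physical column name.
    return isinstance(col, dict) and "sqlExpression" in col

def split_filter_columns(
    filter_columns: List[ColumnRef],
    column_names: List[str],
    applied_template_filters: List[str],
    applied_adhoc: List[ColumnRef],
    rejected_adhoc: List[ColumnRef],
) -> Tuple[List[ColumnRef], List[ColumnRef]]:
    # A plain column is "applied" when the dataset knows it or a template
    # filter consumed it; otherwise it is "rejected". Adhoc columns are
    # classified elsewhere and appended from their own lists.
    applied = [
        col
        for col in filter_columns
        if col
        and not is_adhoc_column(col)
        and (col in column_names or col in applied_template_filters)
    ] + applied_adhoc
    rejected = [
        col
        for col in filter_columns
        if col
        and not is_adhoc_column(col)
        and col not in column_names
        and col not in applied_template_filters
    ] + rejected_adhoc
    return applied, rejected

adhoc = {"sqlExpression": "1 = 1", "label": "adhoc"}
applied, rejected = split_filter_columns(
    filter_columns=["ds", "missing_col", adhoc],
    column_names=["ds", "num"],
    applied_template_filters=[],
    applied_adhoc=[adhoc],
    rejected_adhoc=[],
)
assert applied == ["ds", adhoc]
assert rejected == ["missing_col"]

Both lists ride back on the query result, so a chart can report which filters actually constrained the data instead of silently ignoring the ones that did not.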
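
The adhoc halves of those lists are filled where filters are compiled: patches 24 and 25 wrap adhoc_column_to_sqla in a try/except so an expression referencing a column the dataset lacks is dropped from the WHERE clause rather than failing the whole query. A sketch of that control flow, with an assumed stand-in resolver in place of the real SQLAlchemy conversion; only ColumnNotFoundException and the force_type_check flag come from the patch itself.

class ColumnNotFoundException(Exception):
    """Raised when an adhoc expression references a column the dataset lacks."""

def apply_adhoc_filter(flt_col, resolve, applied_adhoc, rejected_adhoc):
    # force_type_check=True asks the resolver to verify the referenced
    # column exists before building the SQLAlchemy expression.
    try:
        sqla_col = resolve(flt_col, force_type_check=True)
    except ColumnNotFoundException:
        # The filter drops out of the query; callers surface this through
        # rejected_filter_columns instead of raising to the user.
        rejected_adhoc.append(flt_col)
        return None
    applied_adhoc.append(flt_col)
    return sqla_col

applied, rejected = [], []

def resolve(col, force_type_check=False):
    # Toy resolver: pretend only "missing_col" is unknown to the dataset.
    if "missing_col" in col["sqlExpression"]:
        raise ColumnNotFoundException(col["label"])
    return f"compiled({col['label']})"

ok = {"sqlExpression": "num > 0", "label": "ok"}
bad = {"sqlExpression": "missing_col > 0", "label": "bad"}
assert apply_adhoc_filter(ok, resolve, applied, rejected) == "compiled(ok)"
assert apply_adhoc_filter(bad, resolve, applied, rejected) is None
assert applied == [ok] and rejected == [bad]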