From d5242d066efa2894714b235c1375110325a43dd8 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Aug 2018 18:19:40 +0300 Subject: [PATCH 01/29] Replace dataframe label override logic with table column override --- superset/connectors/sqla/models.py | 8 +++- superset/dataframe.py | 4 +- superset/db_engine_specs.py | 61 +++++------------------------- superset/models/core.py | 3 +- superset/views/core.py | 8 ++-- superset/viz.py | 4 -- 6 files changed, 25 insertions(+), 63 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 44a2cfb1ca9d5..d8e052b7f76c8 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -816,6 +816,7 @@ def fetch_metadata(self): .filter(or_(TableColumn.column_name == col.name for col in table.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} + db_engine_spec = self.database.db_engine_spec for col in table.columns: try: @@ -827,7 +828,10 @@ def fetch_metadata(self): logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: - dbcol = TableColumn(column_name=col.name, type=datatype) + dbcol = TableColumn( + column_name=db_engine_spec.mutate_column_label(col.name), + type=datatype + ) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = dbcol.is_num @@ -848,6 +852,8 @@ def fetch_metadata(self): )) if not self.main_dttm_col: self.main_dttm_col = any_date_col + for metric in metrics: + metric.metric_name = db_engine_spec.mutate_column_label(metric.metric_name) self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() diff --git a/superset/dataframe.py b/superset/dataframe.py index 834f11804743b..1678dd97f7f6d 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec): if cursor_description: column_names = [col[0] for col in cursor_description] - case_sensitive = db_engine_spec.consistent_case_sensitivity - self.column_names = dedup(column_names, - case_sensitive=case_sensitive) + self.column_names = dedup(column_names) data = data or [] self.df = ( diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 13eb69502bf51..2c8e561ddff6a 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -101,7 +101,6 @@ class BaseEngineSpec(object): time_secondary_columns = False inner_joins = True allows_subquery = True - consistent_case_sensitivity = True # do results have same case as qry for col names? arraysize = None @classmethod @@ -375,56 +374,9 @@ def execute(cls, cursor, query, async=False): cursor.arraysize = cls.arraysize cursor.execute(query) - @classmethod - def adjust_df_column_names(cls, df, fd): - """Based of fields in form_data, return dataframe with new column names - - Usually sqla engines return column names whose case matches that of the - original query. For example: - SELECT 1 as col1, 2 as COL2, 3 as Col_3 - will usually result in the following df.columns: - ['col1', 'COL2', 'Col_3']. - For these engines there is no need to adjust the dataframe column names - (default behavior). However, some engines (at least Snowflake, Oracle and - Redshift) return column names with different case than in the original query, - usually all uppercase. For these the column names need to be adjusted to - correspond to the case of the fields specified in the form data for Viz - to work properly. This adjustment can be done here. 
- """ - if cls.consistent_case_sensitivity: - return df - else: - return cls.align_df_col_names_with_form_data(df, fd) - @staticmethod - def align_df_col_names_with_form_data(df, fd): - """Helper function to rename columns that have changed case during query. - - Returns a dataframe where column names have been adjusted to correspond with - column names in form data (case insensitive). Examples: - dataframe: 'col1', form_data: 'col1' -> no change - dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1' - dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1' - """ - - columns = set() - lowercase_mapping = {} - - metrics = utils.get_metric_names(fd.get('metrics', [])) - groupby = fd.get('groupby', []) - other_cols = [utils.DTTM_ALIAS] - for col in metrics + groupby + other_cols: - columns.add(col) - lowercase_mapping[col.lower()] = col - - rename_cols = {} - for col in df.columns: - if col not in columns: - orig_col = lowercase_mapping.get(col.lower()) - if orig_col: - rename_cols[col] = orig_col - - return df.rename(index=str, columns=rename_cols) + def mutate_column_label(label): + return label @staticmethod def mutate_expression_label(label): @@ -478,7 +430,6 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - consistent_case_sensitivity = False time_grain_functions = { None: '{col}', 'PT1S': "DATE_TRUNC('SECOND', {col})", @@ -508,6 +459,14 @@ def adjust_database_uri(cls, uri, selected_schema=None): uri.database = database + '/' + selected_schema return uri + @staticmethod + def mutate_column_label(label): + return label.upper() + + @staticmethod + def mutate_expression_label(label): + return label.upper() + class VerticaEngineSpec(PostgresBaseEngineSpec): engine = 'vertica' diff --git a/superset/models/core.py b/superset/models/core.py index 9d9674c19560a..f0acc502401f3 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -795,7 +795,8 @@ def needs_conversion(df_series): self.db_engine_spec.execute(cursor, sqls[-1]) if cursor.description is not None: - columns = [col_desc[0] for col_desc in cursor.description] + columns = [self.db_engine_spec.mutate_column_label(col_desc[0]) + for col_desc in cursor.description] else: columns = [] diff --git a/superset/views/core.py b/superset/views/core.py index 2e4a8e8bc56f5..68c2814c827c2 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2191,6 +2191,7 @@ def sqllab_viz(self): q = SupersetQuery(data.get('sql')) table.sql = q.stripped() db.session.add(table) + mydb = db.session.query(models.Database).filter_by(id=table.database_id).first() cols = [] for config in data.get('columns'): column_name = config.get('name') @@ -2198,7 +2199,7 @@ def sqllab_viz(self): TableColumn = SqlaTable.column_class SqlMetric = SqlaTable.metric_class col = TableColumn( - column_name=column_name, + column_name=mydb.db_engine_spec.mutate_column_label(column_name), filterable=True, groupby=True, is_dttm=config.get('is_date', False), @@ -2208,7 +2209,8 @@ def sqllab_viz(self): table.columns = cols table.metrics = [ - SqlMetric(metric_name='count', expression='count(*)'), + SqlMetric(metric_name=mydb.db_engine_spec.mutate_expression_label('count'), + expression='count(*)'), ] db.session.commit() return self.json_response(json.dumps({ @@ -2254,7 +2256,7 @@ def table(self, database_id, table_name, schema): dtype = col['type'].__class__.__name__ pass payload_columns.append({ - 'name': col['name'], + 'name': 
mydb.db_engine_spec.mutate_column_label(col['name']), 'type': dtype.split('(')[0] if '(' in dtype else dtype, 'longType': dtype, 'keys': [ diff --git a/superset/viz.py b/superset/viz.py index 5f4cea84984ae..df2bdaf0a5282 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -386,10 +386,6 @@ def get_df_payload(self, query_obj=None): if query_obj and not is_loaded: try: df = self.get_df(query_obj) - if hasattr(self.datasource, 'database') and \ - hasattr(self.datasource.database, 'db_engine_spec'): - db_engine_spec = self.datasource.database.db_engine_spec - df = db_engine_spec.adjust_df_column_names(df, self.form_data) if self.status != utils.QueryStatus.FAILED: stats_logger.incr('loaded_from_source') is_loaded = True From ead1b558295974e16f435b5479cc6b2904d3af26 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Aug 2018 19:46:54 +0300 Subject: [PATCH 02/29] Add mutation to any_date_col --- superset/connectors/sqla/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index d8e052b7f76c8..c70c05f0421d7 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -841,7 +841,7 @@ def fetch_metadata(self): dbcol.type = datatype self.columns.append(dbcol) if not any_date_col and dbcol.is_time: - any_date_col = col.name + any_date_col = db_engine_spec.mutate_column_label(col.name) metrics += dbcol.get_metrics().values() metrics.append(M( From 4540d11ac740c7ff56acdc4e882ac2bd4aa5af2f Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Aug 2018 19:47:20 +0300 Subject: [PATCH 03/29] Linting --- superset/connectors/sqla/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c70c05f0421d7..1a766309a5b59 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -830,7 +830,7 @@ def fetch_metadata(self): if not dbcol: dbcol = TableColumn( column_name=db_engine_spec.mutate_column_label(col.name), - type=datatype + type=datatype, ) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string From 027f794f9400c200c1730423489c31d89a5ccd91 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 22 Aug 2018 19:48:28 +0300 Subject: [PATCH 04/29] Add mutation to oracle and redshift --- superset/db_engine_specs.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 2c8e561ddff6a..eb829781fe3f0 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -476,6 +476,14 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): engine = 'redshift' consistent_case_sensitivity = False + @staticmethod + def mutate_column_label(label): + return label.upper() + + @staticmethod + def mutate_expression_label(label): + return label.upper() + class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' @@ -500,6 +508,14 @@ def convert_dttm(cls, target_type, dttm): """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" ).format(dttm.isoformat()) + @staticmethod + def mutate_column_label(label): + return label.upper() + + @staticmethod + def mutate_expression_label(label): + return label.upper() + class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' From f8f851a37e14a257bd880c3aae7be8e3727cd8d7 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 23 Aug 2018 08:20:29 +0300 Subject: [PATCH 05/29] Fine tune how and which labels are mutated 
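[Editor's note, not part of the patch series] For readers skimming the series: the hook introduced in patches 01-04 simply rewrites a label to match the way the engine will report it back in the result set. Below is a minimal sketch using simplified, hypothetical stand-in classes, not code taken from the patches:

    # Sketch only: Oracle, Redshift and Snowflake fold unquoted identifiers
    # to upper case, so stored column/metric labels must match what the
    # cursor later reports back.
    class BaseEngineSpecSketch(object):
        @staticmethod
        def mutate_column_label(label):
            # default behaviour: most engines echo the label back unchanged
            return label

    class OracleLikeEngineSpecSketch(BaseEngineSpecSketch):
        @staticmethod
        def mutate_column_label(label):
            # variant added for Oracle/Redshift/Snowflake in the patches above
            return label.upper()

    assert BaseEngineSpecSketch.mutate_column_label('sum__num') == 'sum__num'
    assert OracleLikeEngineSpecSketch.mutate_column_label('sum__num') == 'SUM__NUM'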
--- superset/connectors/sqla/models.py | 3 ++- superset/db_engine_specs.py | 12 ------------ superset/models/core.py | 3 +-- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 1a766309a5b59..8e06735f9e939 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -853,7 +853,8 @@ def fetch_metadata(self): if not self.main_dttm_col: self.main_dttm_col = any_date_col for metric in metrics: - metric.metric_name = db_engine_spec.mutate_column_label(metric.metric_name) + metric.metric_name = db_engine_spec.mutate_expression_label( + metric.metric_name) self.add_missing_metrics(metrics) db.session.merge(self) db.session.commit() diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index eb829781fe3f0..0052b16fe1460 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -463,10 +463,6 @@ def adjust_database_uri(cls, uri, selected_schema=None): def mutate_column_label(label): return label.upper() - @staticmethod - def mutate_expression_label(label): - return label.upper() - class VerticaEngineSpec(PostgresBaseEngineSpec): engine = 'vertica' @@ -480,10 +476,6 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): def mutate_column_label(label): return label.upper() - @staticmethod - def mutate_expression_label(label): - return label.upper() - class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' @@ -512,10 +504,6 @@ def convert_dttm(cls, target_type, dttm): def mutate_column_label(label): return label.upper() - @staticmethod - def mutate_expression_label(label): - return label.upper() - class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' diff --git a/superset/models/core.py b/superset/models/core.py index f0acc502401f3..9d9674c19560a 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -795,8 +795,7 @@ def needs_conversion(df_series): self.db_engine_spec.execute(cursor, sqls[-1]) if cursor.description is not None: - columns = [self.db_engine_spec.mutate_column_label(col_desc[0]) - for col_desc in cursor.description] + columns = [col_desc[0] for col_desc in cursor.description] else: columns = [] From ce3617b2324e46befd2a8dfeef91789ee1d4f3a7 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 00:10:19 +0300 Subject: [PATCH 06/29] Implement alias quoting logic for oracle-like databases --- superset/connectors/sqla/models.py | 64 ++++++++++++++++++------------ superset/db_engine_specs.py | 30 ++++++-------- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8e06735f9e939..1f0a53edb37c8 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -99,21 +99,22 @@ class TableColumn(Model, BaseColumn): s for s in export_fields if s not in ('table_id',)] export_parent = 'table' - @property - def sqla_col(self): - name = self.column_name + def sqla_col(self, db_engine_spec=None, label=None): + if db_engine_spec is None or label is None: + label = self.column_name + col_label = db_engine_spec.get_column_label(label) if not self.expression: - col = column(self.column_name).label(name) + col = column(self.column_name).label(col_label) else: - col = literal_column(self.expression).label(name) + col = literal_column(self.expression).label(col_label) return col @property def datasource(self): return self.table - def get_time_filter(self, start_dttm, end_dttm): - col = 
self.sqla_col.label('__time') + def get_time_filter(self, start_dttm, end_dttm, db_engine_spec): + col = self.sqla_col(db_engine_spec, '__time') l = [] # noqa: E741 if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) @@ -231,10 +232,14 @@ class SqlMetric(Model, BaseMetric): s for s in export_fields if s not in ('table_id', )]) export_parent = 'table' - @property - def sqla_col(self): - name = self.metric_name - return literal_column(self.expression).label(name) + def sqla_col(self, db_engine_spec=None, label=None): + if label is None: + label = self.metric_name + if db_engine_spec: + col_label = db_engine_spec.get_column_label(label) + else: + col_label = self.metric_name + return literal_column(self.expression).label(col_label) @property def perm(self): @@ -424,7 +429,7 @@ def values_for_column(self, column_name, limit=10000): db_engine_spec = self.database.db_engine_spec qry = ( - select([target_col.sqla_col]) + select([target_col.sqla_col(db_engine_spec)]) .select_from(self.get_from_clause(tp, db_engine_spec)) .distinct() ) @@ -437,7 +442,7 @@ def values_for_column(self, column_name, limit=10000): engine = self.database.get_sqla_engine() sql = '{}'.format( - qry.compile(engine, compile_kwargs={'literal_binds': True}), + qry.compile(postgresql.dialect, compile_kwargs={'literal_binds': True}), ) sql = self.mutate_query_from_config(sql) @@ -484,30 +489,37 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None): return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() - def adhoc_metric_to_sa(self, metric, cols): + def adhoc_metric_to_sa(self, metric, cols, db_engine_spec=None): """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict cols: Columns for the current table + :param BaseEngineSpec db_engine_spec: Db engine specs for + database specific handling of column labels :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ expressionType = metric.get('expressionType') + if db_engine_spec and 'label' in metric: + label = db_engine_spec.get_column_label(metric['label']) + else: + label = metric.get('label') + if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: column_name = metric.get('column').get('column_name') sa_column = column(column_name) table_column = cols.get(column_name) if table_column: - sa_column = table_column.sqla_col + sa_column = table_column.sqla_col(db_engine_spec) sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) - sa_metric = sa_metric.label(metric.get('label')) + sa_metric = sa_metric.label(label) return sa_metric elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: sa_metric = literal_column(metric.get('sqlExpression')) - sa_metric = sa_metric.label(metric.get('label')) + sa_metric = sa_metric.label(label) return sa_metric else: return None @@ -566,15 +578,16 @@ def get_sqla_query( # sqla metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): - metrics_exprs.append(self.adhoc_metric_to_sa(m, cols)) + metrics_exprs.append(self.adhoc_metric_to_sa(m, cols, db_engine_spec)) elif m in metrics_dict: - metrics_exprs.append(metrics_dict.get(m).sqla_col) + metrics_exprs.append(metrics_dict.get(m).sqla_col(db_engine_spec)) else: raise Exception(_("Metric '{}' is not valid".format(m))) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: - main_metric_expr = literal_column('COUNT(*)').label('ccount') + main_metric_expr = literal_column('COUNT(*)').label( + 
db_engine_spec.get_column_label('count')) select_exprs = [] groupby_exprs = [] @@ -585,8 +598,8 @@ def get_sqla_query( # sqla inner_groupby_exprs = [] for s in groupby: col = cols[s] - outer = col.sqla_col - inner = col.sqla_col.label(col.column_name + '__') + outer = col.sqla_col(db_engine_spec) + inner = col.sqla_col(db_engine_spec, col.column_name + '__') groupby_exprs.append(outer) select_exprs.append(outer) @@ -594,7 +607,7 @@ def get_sqla_query( # sqla inner_select_exprs.append(inner) elif columns: for s in columns: - select_exprs.append(cols[s].sqla_col) + select_exprs.append(cols[s].sqla_col(db_engine_spec)) metrics_exprs = [] if granularity: @@ -612,8 +625,9 @@ def get_sqla_query( # sqla self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col]. - get_time_filter(from_dttm, to_dttm)) - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + get_time_filter(from_dttm, to_dttm, db_engine_spec)) + time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm, + db_engine_spec)) select_exprs += metrics_exprs qry = sa.select(select_exprs) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 0052b16fe1460..2912995153b5c 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,7 +35,7 @@ from sqlalchemy import select from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url -from sqlalchemy.sql import text +from sqlalchemy.sql import text, quoted_name from sqlalchemy.sql.expression import TextAsFrom import sqlparse from tableschema import Table @@ -102,6 +102,7 @@ class BaseEngineSpec(object): inner_joins = True allows_subquery = True arraysize = None + force_column_alias_quotes = False @classmethod def get_time_grains(cls): @@ -374,9 +375,11 @@ def execute(cls, cursor, query, async=False): cursor.arraysize = cls.arraysize cursor.execute(query) - @staticmethod - def mutate_column_label(label): - return label + @classmethod + def get_column_label(cls, label_name): + if cls.force_column_alias_quotes is True: + return quoted_name(label_name, True) + return label_name @staticmethod def mutate_expression_label(label): @@ -430,6 +433,8 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' + force_column_alias_quotes = True + time_grain_functions = { None: '{col}', 'PT1S': "DATE_TRUNC('SECOND', {col})", @@ -459,10 +464,6 @@ def adjust_database_uri(cls, uri, selected_schema=None): uri.database = database + '/' + selected_schema return uri - @staticmethod - def mutate_column_label(label): - return label.upper() - class VerticaEngineSpec(PostgresBaseEngineSpec): engine = 'vertica' @@ -470,17 +471,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec): class RedshiftEngineSpec(PostgresBaseEngineSpec): engine = 'redshift' - consistent_case_sensitivity = False - - @staticmethod - def mutate_column_label(label): - return label.upper() + force_column_alias_quotes = True class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' limit_method = LimitMethod.WRAP_SQL - consistent_case_sensitivity = False + force_column_alias_quotes = True time_grain_functions = { None: '{col}', @@ -500,14 +497,11 @@ def convert_dttm(cls, target_type, dttm): """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" ).format(dttm.isoformat()) - @staticmethod - def mutate_column_label(label): - return label.upper() - class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' limit_method = 
LimitMethod.WRAP_SQL + force_column_alias_quotes = True time_grain_functions = { None: '{col}', 'PT1S': 'CAST({col} as TIMESTAMP)' From 5f67d59ac9ce4f5b084ed0616c25f7f5cd028ee5 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 00:16:41 +0300 Subject: [PATCH 07/29] Fix and align column and metric sqla_col methods --- superset/connectors/sqla/models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 1f0a53edb37c8..e16b3d336ac15 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -100,9 +100,14 @@ class TableColumn(Model, BaseColumn): export_parent = 'table' def sqla_col(self, db_engine_spec=None, label=None): - if db_engine_spec is None or label is None: + if label is None: label = self.column_name - col_label = db_engine_spec.get_column_label(label) + + if db_engine_spec: + col_label = db_engine_spec.get_column_label(label) + else: + col_label = label + if not self.expression: col = column(self.column_name).label(col_label) else: @@ -235,10 +240,12 @@ class SqlMetric(Model, BaseMetric): def sqla_col(self, db_engine_spec=None, label=None): if label is None: label = self.metric_name + if db_engine_spec: col_label = db_engine_spec.get_column_label(label) else: col_label = self.metric_name + return literal_column(self.expression).label(col_label) @property From 8b6e52d48ceec90a2a7b07343a0eed37fc4c406b Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 00:21:51 +0300 Subject: [PATCH 08/29] Clean up typos and redundant logic --- superset/connectors/sqla/models.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e16b3d336ac15..b67a90ba7d293 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -449,7 +449,7 @@ def values_for_column(self, column_name, limit=10000): engine = self.database.get_sqla_engine() sql = '{}'.format( - qry.compile(postgresql.dialect, compile_kwargs={'literal_binds': True}), + qry.compile(engine, compile_kwargs={'literal_binds': True}), ) sql = self.mutate_query_from_config(sql) @@ -502,7 +502,7 @@ def adhoc_metric_to_sa(self, metric, cols, db_engine_spec=None): :param dict metric: Adhoc metric definition :param dict cols: Columns for the current table - :param BaseEngineSpec db_engine_spec: Db engine specs for + :param BaseEngineSpec db_engine_spec: Db engine spec for database specific handling of column labels :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column @@ -850,8 +850,7 @@ def fetch_metadata(self): dbcol = dbcols.get(col.name, None) if not dbcol: dbcol = TableColumn( - column_name=db_engine_spec.mutate_column_label(col.name), - type=datatype, + column_name=col.name, type=datatype, ) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string @@ -862,7 +861,7 @@ def fetch_metadata(self): dbcol.type = datatype self.columns.append(dbcol) if not any_date_col and dbcol.is_time: - any_date_col = db_engine_spec.mutate_column_label(col.name) + any_date_col = col.name metrics += dbcol.get_metrics().values() metrics.append(M( From 62163fa8cc0be4fb83cc4ec5d8a9ea4cc33c22f0 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 00:24:05 +0300 Subject: [PATCH 09/29] Move new attribute to old location --- superset/db_engine_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 2912995153b5c..40fe05d0d589a 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -101,8 +101,8 @@ class BaseEngineSpec(object): time_secondary_columns = False inner_joins = True allows_subquery = True - arraysize = None force_column_alias_quotes = False + arraysize = None @classmethod def get_time_grains(cls): From e8dd7da2233b8cece256cca294a8de345c160f8a Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 00:25:37 +0300 Subject: [PATCH 10/29] Linting --- superset/db_engine_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 40fe05d0d589a..365cd609eb00b 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,7 +35,7 @@ from sqlalchemy import select from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url -from sqlalchemy.sql import text, quoted_name +from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import TextAsFrom import sqlparse from tableschema import Table From 6c6436f24ff4c24c54afc4c6a85e2d8865e05f56 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 01:01:23 +0300 Subject: [PATCH 11/29] Replace old sqla_col property references with function calls --- superset/connectors/sqla/models.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b67a90ba7d293..54e9ccab31c43 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -659,9 +659,9 @@ def get_sqla_query( # sqla target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target) if op in ('in', 'not in'): - cond = col_obj.sqla_col.in_(eq) + cond = col_obj.sqla_col().in_(eq) if '' in eq: - cond = or_(cond, col_obj.sqla_col == None) # noqa + cond = or_(cond, col_obj.sqla_col() == None) # noqa if op == 'not in': cond = ~cond where_clause_and.append(cond) @@ -669,23 +669,23 @@ def get_sqla_query( # sqla if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': - where_clause_and.append(col_obj.sqla_col == eq) + where_clause_and.append(col_obj.sqla_col() == eq) elif op == '!=': - where_clause_and.append(col_obj.sqla_col != eq) + where_clause_and.append(col_obj.sqla_col() != eq) elif op == '>': - where_clause_and.append(col_obj.sqla_col > eq) + where_clause_and.append(col_obj.sqla_col() > eq) elif op == '<': - where_clause_and.append(col_obj.sqla_col < eq) + where_clause_and.append(col_obj.sqla_col() < eq) elif op == '>=': - where_clause_and.append(col_obj.sqla_col >= eq) + where_clause_and.append(col_obj.sqla_col() >= eq) elif op == '<=': - where_clause_and.append(col_obj.sqla_col <= eq) + where_clause_and.append(col_obj.sqla_col() <= eq) elif op == 'LIKE': - where_clause_and.append(col_obj.sqla_col.like(eq)) + where_clause_and.append(col_obj.sqla_col().like(eq)) elif op == 'IS NULL': - where_clause_and.append(col_obj.sqla_col == None) # noqa + where_clause_and.append(col_obj.sqla_col() == None) # noqa elif op == 'IS NOT NULL': - where_clause_and.append(col_obj.sqla_col != None) # noqa + where_clause_and.append(col_obj.sqla_col() != None) # noqa if extras: where = extras.get('where') if where: @@ -738,7 +738,7 @@ def get_sqla_query( # sqla timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, ) - ob = timeseries_limit_metric.sqla_col + ob = 
timeseries_limit_metric.sqla_col() else: raise Exception(_("Metric '{}' is not valid".format(m))) direction = desc if order_desc else asc @@ -783,7 +783,7 @@ def _get_top_groups(self, df, dimensions): group = [] for dimension in dimensions: col_obj = cols.get(dimension) - group.append(col_obj.sqla_col == row[dimension]) + group.append(col_obj.sqla_col() == row[dimension]) groups.append(and_(*group)) return or_(*groups) From 4b0e5938c41e88eb7c2c660d4cfdccf6e101de45 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 07:57:44 +0300 Subject: [PATCH 12/29] Remove redundant calls to mutate_column_label --- superset/views/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/superset/views/core.py b/superset/views/core.py index 68c2814c827c2..64545e0ce6234 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2191,7 +2191,6 @@ def sqllab_viz(self): q = SupersetQuery(data.get('sql')) table.sql = q.stripped() db.session.add(table) - mydb = db.session.query(models.Database).filter_by(id=table.database_id).first() cols = [] for config in data.get('columns'): column_name = config.get('name') @@ -2199,7 +2198,7 @@ def sqllab_viz(self): TableColumn = SqlaTable.column_class SqlMetric = SqlaTable.metric_class col = TableColumn( - column_name=mydb.db_engine_spec.mutate_column_label(column_name), + column_name=column_name, filterable=True, groupby=True, is_dttm=config.get('is_date', False), @@ -2256,7 +2255,7 @@ def table(self, database_id, table_name, schema): dtype = col['type'].__class__.__name__ pass payload_columns.append({ - 'name': mydb.db_engine_spec.mutate_column_label(col['name']), + 'name': col['name'], 'type': dtype.split('(')[0] if '(' in dtype else dtype, 'longType': dtype, 'keys': [ From 76959dd906c0817f2e0616cbf0eb27928e6cd37a Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 08:02:01 +0300 Subject: [PATCH 13/29] Move duplicated logic to common function --- superset/connectors/sqla/models.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 54e9ccab31c43..ceb244d264923 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -34,6 +34,16 @@ config = app.config +def get_column_label(db_engine_spec, name, label): + if label is None: + label = name + + if db_engine_spec: + return db_engine_spec.get_column_label(label) + + return label + + class AnnotationDatasource(BaseDatasource): """ Dummy object so we can query annotations using 'Viz' objects just like regular datasources. 
@@ -100,14 +110,7 @@ class TableColumn(Model, BaseColumn): export_parent = 'table' def sqla_col(self, db_engine_spec=None, label=None): - if label is None: - label = self.column_name - - if db_engine_spec: - col_label = db_engine_spec.get_column_label(label) - else: - col_label = label - + col_label = get_column_label(db_engine_spec, self.column_name, label) if not self.expression: col = column(self.column_name).label(col_label) else: @@ -238,14 +241,7 @@ class SqlMetric(Model, BaseMetric): export_parent = 'table' def sqla_col(self, db_engine_spec=None, label=None): - if label is None: - label = self.metric_name - - if db_engine_spec: - col_label = db_engine_spec.get_column_label(label) - else: - col_label = self.metric_name - + col_label = get_column_label(db_engine_spec, self.metric_name, label) return literal_column(self.expression).label(col_label) @property From 6c7fccc8e4af1d94f53832437167831a1a3a484c Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 08:21:41 +0300 Subject: [PATCH 14/29] Add db_engine_specs to all sqla_col calls --- superset/connectors/sqla/models.py | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ceb244d264923..0d5d1f744c8e0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -655,9 +655,9 @@ def get_sqla_query( # sqla target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target) if op in ('in', 'not in'): - cond = col_obj.sqla_col().in_(eq) + cond = col_obj.sqla_col(db_engine_spec).in_(eq) if '' in eq: - cond = or_(cond, col_obj.sqla_col() == None) # noqa + cond = or_(cond, col_obj.sqla_col(db_engine_spec) == None) # noqa if op == 'not in': cond = ~cond where_clause_and.append(cond) @@ -665,23 +665,25 @@ def get_sqla_query( # sqla if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': - where_clause_and.append(col_obj.sqla_col() == eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) == eq) elif op == '!=': - where_clause_and.append(col_obj.sqla_col() != eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) != eq) elif op == '>': - where_clause_and.append(col_obj.sqla_col() > eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) > eq) elif op == '<': - where_clause_and.append(col_obj.sqla_col() < eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) < eq) elif op == '>=': - where_clause_and.append(col_obj.sqla_col() >= eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) >= eq) elif op == '<=': - where_clause_and.append(col_obj.sqla_col() <= eq) + where_clause_and.append(col_obj.sqla_col(db_engine_spec) <= eq) elif op == 'LIKE': - where_clause_and.append(col_obj.sqla_col().like(eq)) + where_clause_and.append(col_obj.sqla_col(db_engine_spec).like(eq)) elif op == 'IS NULL': - where_clause_and.append(col_obj.sqla_col() == None) # noqa + where_clause_and.append( + col_obj.sqla_col(db_engine_spec) == None) # noqa elif op == 'IS NOT NULL': - where_clause_and.append(col_obj.sqla_col() != None) # noqa + where_clause_and.append( + col_obj.sqla_col(db_engine_spec) != None) # noqa if extras: where = extras.get('where') if where: @@ -734,7 +736,7 @@ def get_sqla_query( # sqla timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, ) - ob = timeseries_limit_metric.sqla_col() + ob = timeseries_limit_metric.sqla_col(db_engine_spec) else: raise Exception(_("Metric '{}' is not valid".format(m))) direction = desc 
if order_desc else asc @@ -767,19 +769,19 @@ def get_sqla_query( # sqla } result = self.query(subquery_obj) dimensions = [c for c in result.df.columns if c not in metrics] - top_groups = self._get_top_groups(result.df, dimensions) + top_groups = self._get_top_groups(result.df, dimensions, db_engine_spec) qry = qry.where(top_groups) return qry.select_from(tbl) - def _get_top_groups(self, df, dimensions): + def _get_top_groups(self, df, dimensions, db_engine_spec=None): cols = {col.column_name: col for col in self.columns} groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: col_obj = cols.get(dimension) - group.append(col_obj.sqla_col() == row[dimension]) + group.append(col_obj.sqla_col(db_engine_spec) == row[dimension]) groups.append(and_(*group)) return or_(*groups) From c187104a999ccb2968fed6d1b81eb6043dd8433f Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 08:36:18 +0300 Subject: [PATCH 15/29] Add missing mydb --- superset/views/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/superset/views/core.py b/superset/views/core.py index 64545e0ce6234..22804700fd329 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2191,6 +2191,7 @@ def sqllab_viz(self): q = SupersetQuery(data.get('sql')) table.sql = q.stripped() db.session.add(table) + mydb = db.session.query(models.Database).filter_by(id=table.database_id).first() cols = [] for config in data.get('columns'): column_name = config.get('name') From 53b77a2c30f8dbfc9288e6c9af7879519ab95a8d Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 08:40:52 +0300 Subject: [PATCH 16/29] Add note about snowflake-sqlalchemy regression --- docs/installation.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/installation.rst b/docs/installation.rst index 008a2648f1a1b..d1d1fd5e4008a 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does not test for user rights during engine creation. +*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of +snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to +use version 1.1.0 or try a newer version. + See `Snowflake SQLAlchemy `_. 
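[Editor's note, not part of the patch] If you are affected by the regression described above, one possible workaround is to pin the known-good release in your requirements file until a fixed version is available, e.g.::

    snowflake-sqlalchemy==1.1.0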
Caching From e792d15317724991f54df01db0b81afc7f43df77 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 11:17:11 +0300 Subject: [PATCH 17/29] Make db_engine_spec mandatory in sqla_col --- superset/connectors/sqla/models.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 0d5d1f744c8e0..4f692d9c886e0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -34,14 +34,9 @@ config = app.config -def get_column_label(db_engine_spec, name, label): - if label is None: - label = name - - if db_engine_spec: - return db_engine_spec.get_column_label(label) - - return label +def get_column_label(db_engine_spec, column_name, label): + label = label if label is not None else column_name + return db_engine_spec.get_column_label(label) class AnnotationDatasource(BaseDatasource): @@ -109,7 +104,7 @@ class TableColumn(Model, BaseColumn): s for s in export_fields if s not in ('table_id',)] export_parent = 'table' - def sqla_col(self, db_engine_spec=None, label=None): + def sqla_col(self, db_engine_spec, label=None): col_label = get_column_label(db_engine_spec, self.column_name, label) if not self.expression: col = column(self.column_name).label(col_label) @@ -121,7 +116,7 @@ def sqla_col(self, db_engine_spec=None, label=None): def datasource(self): return self.table - def get_time_filter(self, start_dttm, end_dttm, db_engine_spec): + def get_time_filter(self, db_engine_spec, start_dttm, end_dttm): col = self.sqla_col(db_engine_spec, '__time') l = [] # noqa: E741 if start_dttm: @@ -240,7 +235,7 @@ class SqlMetric(Model, BaseMetric): s for s in export_fields if s not in ('table_id', )]) export_parent = 'table' - def sqla_col(self, db_engine_spec=None, label=None): + def sqla_col(self, db_engine_spec, label=None): col_label = get_column_label(db_engine_spec, self.metric_name, label) return literal_column(self.expression).label(col_label) @@ -482,7 +477,7 @@ def get_sqla_table(self): tbl.schema = self.schema return tbl - def get_from_clause(self, template_processor=None, db_engine_spec=None): + def get_from_clause(self, template_processor=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql @@ -628,14 +623,14 @@ def get_sqla_query( # sqla self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col]. 
- get_time_filter(from_dttm, to_dttm, db_engine_spec)) - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm, - db_engine_spec)) + get_time_filter(db_engine_spec, from_dttm, to_dttm)) + time_filters.append(dttm_col.get_time_filter(db_engine_spec, + from_dttm, to_dttm)) select_exprs += metrics_exprs qry = sa.select(select_exprs) - tbl = self.get_from_clause(template_processor, db_engine_spec) + tbl = self.get_from_clause(template_processor) if not columns: qry = qry.group_by(*groupby_exprs) From 44eac831866e0ac872cbbafa3128ea8617a41087 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 11:25:33 +0300 Subject: [PATCH 18/29] Small refactoring and cleanup --- superset/connectors/sqla/models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 4f692d9c886e0..3394e41a775a8 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -769,7 +769,7 @@ def get_sqla_query( # sqla return qry.select_from(tbl) - def _get_top_groups(self, df, dimensions, db_engine_spec=None): + def _get_top_groups(self, df, dimensions, db_engine_spec): cols = {col.column_name: col for col in self.columns} groups = [] for unused, row in df.iterrows(): @@ -842,9 +842,7 @@ def fetch_metadata(self): logging.exception(e) dbcol = dbcols.get(col.name, None) if not dbcol: - dbcol = TableColumn( - column_name=col.name, type=datatype, - ) + dbcol = TableColumn(column_name=col.name, type=datatype) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = dbcol.is_num From a161efc688d917ab6499585687d1dafd51eca258 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 11:36:54 +0300 Subject: [PATCH 19/29] Remove db_engine_spec from get_from_clause call --- superset/connectors/sqla/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3394e41a775a8..d60c82c92be07 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -428,7 +428,7 @@ def values_for_column(self, column_name, limit=10000): qry = ( select([target_col.sqla_col(db_engine_spec)]) - .select_from(self.get_from_clause(tp, db_engine_spec)) + .select_from(self.get_from_clause(tp)) .distinct() ) if limit: From 901b6355bf0a6446a089c6d321a8f826ac2e9a09 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 13:37:16 +0300 Subject: [PATCH 20/29] Make db_engine_spec mandatory in adhoc_metric_to_sa --- superset/connectors/sqla/models.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index d60c82c92be07..c520477e7eea9 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -487,7 +487,7 @@ def get_from_clause(self, template_processor=None): return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() - def adhoc_metric_to_sa(self, metric, cols, db_engine_spec=None): + def adhoc_metric_to_sa(self, metric, cols, db_engine_spec): """ Turn an adhoc metric into a sqlalchemy column. 
@@ -498,13 +498,10 @@ def adhoc_metric_to_sa(self, metric, cols, db_engine_spec=None): :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ - expressionType = metric.get('expressionType') - if db_engine_spec and 'label' in metric: - label = db_engine_spec.get_column_label(metric['label']) - else: - label = metric.get('label') + expression_type = metric.get('expressionType') + label = db_engine_spec.get_column_label(metric.get('label')) - if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: + if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: column_name = metric.get('column').get('column_name') sa_column = column(column_name) table_column = cols.get(column_name) @@ -515,7 +512,7 @@ def adhoc_metric_to_sa(self, metric, cols, db_engine_spec=None): sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) sa_metric = sa_metric.label(label) return sa_metric - elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: + elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: sa_metric = literal_column(metric.get('sqlExpression')) sa_metric = sa_metric.label(label) return sa_metric From 7e7a2a886013701e014c4a334d1781bc7564ff22 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 13:40:57 +0300 Subject: [PATCH 21/29] Remove redundant mutate_expression_label call --- superset/views/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/superset/views/core.py b/superset/views/core.py index 22804700fd329..2e4a8e8bc56f5 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2191,7 +2191,6 @@ def sqllab_viz(self): q = SupersetQuery(data.get('sql')) table.sql = q.stripped() db.session.add(table) - mydb = db.session.query(models.Database).filter_by(id=table.database_id).first() cols = [] for config in data.get('columns'): column_name = config.get('name') @@ -2209,8 +2208,7 @@ def sqllab_viz(self): table.columns = cols table.metrics = [ - SqlMetric(metric_name=mydb.db_engine_spec.mutate_expression_label('count'), - expression='count(*)'), + SqlMetric(metric_name='count', expression='count(*)'), ] db.session.commit() return self.json_response(json.dumps({ From d60ed34fe5855fae811fcb7e0b629fa8fa034d33 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 26 Aug 2018 14:08:50 +0300 Subject: [PATCH 22/29] Add missing db_engine_specs to adhoc_metric_to_sa --- superset/connectors/sqla/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c520477e7eea9..b907ea8b04628 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -697,7 +697,7 @@ def get_sqla_query( # sqla for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): - col = self.adhoc_metric_to_sa(col, cols) + col = self.adhoc_metric_to_sa(col, cols, db_engine_spec) qry = qry.order_by(direction(col)) if row_limit: @@ -723,7 +723,8 @@ def get_sqla_query( # sqla ob = inner_main_metric_expr if timeseries_limit_metric: if utils.is_adhoc_metric(timeseries_limit_metric): - ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols) + ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols, + db_engine_spec) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, From fc2c4f2607aeed03ef4f58396d8b8efe5039c03d Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 27 
Aug 2018 10:54:44 +0300 Subject: [PATCH 23/29] Rename arg label_name to label in get_column_label() --- superset/db_engine_specs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 365cd609eb00b..7b5f3b6c0ccae 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -376,10 +376,10 @@ def execute(cls, cursor, query, async=False): cursor.execute(query) @classmethod - def get_column_label(cls, label_name): + def get_column_label(cls, label): if cls.force_column_alias_quotes is True: - return quoted_name(label_name, True) - return label_name + return quoted_name(label, True) + return label @staticmethod def mutate_expression_label(label): From 3f2d874da307a83fcff106259cadb08c557456f1 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 27 Aug 2018 13:48:27 +0300 Subject: [PATCH 24/29] Rename label function and add docstring --- superset/connectors/sqla/models.py | 75 +++++++++++++++--------------- superset/db_engine_specs.py | 7 ++- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b907ea8b04628..68fe2464d665f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -36,7 +36,7 @@ def get_column_label(db_engine_spec, column_name, label): label = label if label is not None else column_name - return db_engine_spec.get_column_label(label) + return db_engine_spec.make_label_compatible(label) class AnnotationDatasource(BaseDatasource): @@ -104,7 +104,8 @@ class TableColumn(Model, BaseColumn): s for s in export_fields if s not in ('table_id',)] export_parent = 'table' - def sqla_col(self, db_engine_spec, label=None): + def get_sqla_col(self, label=None): + db_engine_spec = self.table.database.db_engine_spec col_label = get_column_label(db_engine_spec, self.column_name, label) if not self.expression: col = column(self.column_name).label(col_label) @@ -116,8 +117,9 @@ def sqla_col(self, db_engine_spec, label=None): def datasource(self): return self.table - def get_time_filter(self, db_engine_spec, start_dttm, end_dttm): - col = self.sqla_col(db_engine_spec, '__time') + def get_time_filter(self, start_dttm, end_dttm): + db_engine_spec = self.table.database.db_engine_spec + col = self.get_sqla_col('__time') l = [] # noqa: E741 if start_dttm: l.append(col >= text(self.dttm_sql_literal(start_dttm))) @@ -235,7 +237,8 @@ class SqlMetric(Model, BaseMetric): s for s in export_fields if s not in ('table_id', )]) export_parent = 'table' - def sqla_col(self, db_engine_spec, label=None): + def get_sqla_col(self, label=None): + db_engine_spec = self.table.database.db_engine_spec col_label = get_column_label(db_engine_spec, self.metric_name, label) return literal_column(self.expression).label(col_label) @@ -427,7 +430,7 @@ def values_for_column(self, column_name, limit=10000): db_engine_spec = self.database.db_engine_spec qry = ( - select([target_col.sqla_col(db_engine_spec)]) + select([target_col.get_sqla_col()]) .select_from(self.get_from_clause(tp)) .distinct() ) @@ -487,7 +490,7 @@ def get_from_clause(self, template_processor=None): return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() - def adhoc_metric_to_sa(self, metric, cols, db_engine_spec): + def adhoc_metric_to_sqla(self, metric, cols, db_engine_spec): """ Turn an adhoc metric into a sqlalchemy column. 
@@ -499,7 +502,7 @@ def adhoc_metric_to_sa(self, metric, cols, db_engine_spec): :rtype: sqlalchemy.sql.column """ expression_type = metric.get('expressionType') - label = db_engine_spec.get_column_label(metric.get('label')) + label = db_engine_spec.make_label_compatible(metric.get('label')) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: column_name = metric.get('column').get('column_name') @@ -507,7 +510,7 @@ def adhoc_metric_to_sa(self, metric, cols, db_engine_spec): table_column = cols.get(column_name) if table_column: - sa_column = table_column.sqla_col(db_engine_spec) + sa_column = table_column.get_sqla_col() sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) sa_metric = sa_metric.label(label) @@ -573,16 +576,16 @@ def get_sqla_query( # sqla metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): - metrics_exprs.append(self.adhoc_metric_to_sa(m, cols, db_engine_spec)) + metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols, db_engine_spec)) elif m in metrics_dict: - metrics_exprs.append(metrics_dict.get(m).sqla_col(db_engine_spec)) + metrics_exprs.append(metrics_dict.get(m).get_sqla_col()) else: raise Exception(_("Metric '{}' is not valid".format(m))) if metrics_exprs: main_metric_expr = metrics_exprs[0] else: main_metric_expr = literal_column('COUNT(*)').label( - db_engine_spec.get_column_label('count')) + db_engine_spec.make_label_compatible('count')) select_exprs = [] groupby_exprs = [] @@ -593,8 +596,8 @@ def get_sqla_query( # sqla inner_groupby_exprs = [] for s in groupby: col = cols[s] - outer = col.sqla_col(db_engine_spec) - inner = col.sqla_col(db_engine_spec, col.column_name + '__') + outer = col.get_sqla_col() + inner = col.get_sqla_col(col.column_name + '__') groupby_exprs.append(outer) select_exprs.append(outer) @@ -602,7 +605,7 @@ def get_sqla_query( # sqla inner_select_exprs.append(inner) elif columns: for s in columns: - select_exprs.append(cols[s].sqla_col(db_engine_spec)) + select_exprs.append(cols[s].get_sqla_col()) metrics_exprs = [] if granularity: @@ -620,9 +623,8 @@ def get_sqla_query( # sqla self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col]. 
- get_time_filter(db_engine_spec, from_dttm, to_dttm)) - time_filters.append(dttm_col.get_time_filter(db_engine_spec, - from_dttm, to_dttm)) + get_time_filter(from_dttm, to_dttm)) + time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) select_exprs += metrics_exprs qry = sa.select(select_exprs) @@ -647,9 +649,9 @@ def get_sqla_query( # sqla target_column_is_numeric=col_obj.is_num, is_list_target=is_list_target) if op in ('in', 'not in'): - cond = col_obj.sqla_col(db_engine_spec).in_(eq) + cond = col_obj.get_sqla_col().in_(eq) if '' in eq: - cond = or_(cond, col_obj.sqla_col(db_engine_spec) == None) # noqa + cond = or_(cond, col_obj.get_sqla_col() == None) # noqa if op == 'not in': cond = ~cond where_clause_and.append(cond) @@ -657,25 +659,24 @@ def get_sqla_query( # sqla if col_obj.is_num: eq = utils.string_to_num(flt['val']) if op == '==': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) == eq) + where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == '!=': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) != eq) + where_clause_and.append(col_obj.get_sqla_col() != eq) elif op == '>': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) > eq) + where_clause_and.append(col_obj.get_sqla_col() > eq) elif op == '<': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) < eq) + where_clause_and.append(col_obj.get_sqla_col() < eq) elif op == '>=': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) >= eq) + where_clause_and.append(col_obj.get_sqla_col() >= eq) elif op == '<=': - where_clause_and.append(col_obj.sqla_col(db_engine_spec) <= eq) + where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == 'LIKE': - where_clause_and.append(col_obj.sqla_col(db_engine_spec).like(eq)) + where_clause_and.append(col_obj.get_sqla_col().like(eq)) elif op == 'IS NULL': - where_clause_and.append( - col_obj.sqla_col(db_engine_spec) == None) # noqa + where_clause_and.append(col_obj.get_sqla_col() == None) # noqa elif op == 'IS NOT NULL': where_clause_and.append( - col_obj.sqla_col(db_engine_spec) != None) # noqa + col_obj.get_sqla_col() != None) # noqa if extras: where = extras.get('where') if where: @@ -697,7 +698,7 @@ def get_sqla_query( # sqla for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): - col = self.adhoc_metric_to_sa(col, cols, db_engine_spec) + col = self.adhoc_metric_to_sqla(col, cols, db_engine_spec) qry = qry.order_by(direction(col)) if row_limit: @@ -723,13 +724,13 @@ def get_sqla_query( # sqla ob = inner_main_metric_expr if timeseries_limit_metric: if utils.is_adhoc_metric(timeseries_limit_metric): - ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols, - db_engine_spec) + ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols, + db_engine_spec) elif timeseries_limit_metric in metrics_dict: timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric, ) - ob = timeseries_limit_metric.sqla_col(db_engine_spec) + ob = timeseries_limit_metric.get_sqla_col() else: raise Exception(_("Metric '{}' is not valid".format(m))) direction = desc if order_desc else asc @@ -762,19 +763,19 @@ def get_sqla_query( # sqla } result = self.query(subquery_obj) dimensions = [c for c in result.df.columns if c not in metrics] - top_groups = self._get_top_groups(result.df, dimensions, db_engine_spec) + top_groups = self._get_top_groups(result.df, dimensions) qry = qry.where(top_groups) return qry.select_from(tbl) - def _get_top_groups(self, df, dimensions, db_engine_spec): + def 
_get_top_groups(self, df, dimensions): cols = {col.column_name: col for col in self.columns} groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: col_obj = cols.get(dimension) - group.append(col_obj.sqla_col(db_engine_spec) == row[dimension]) + group.append(col_obj.get_sqla_col() == row[dimension]) groups.append(and_(*group)) return or_(*groups) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 7b5f3b6c0ccae..a6ae8ce60353e 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -376,7 +376,12 @@ def execute(cls, cursor, query, async=False): cursor.execute(query) @classmethod - def get_column_label(cls, label): + def make_label_compatible(cls, label): + """ + Return a sqlalchemy.sql.elements.quoted_name if the engine requires + quoting of aliases to ensure that select query and query results + have same case. + """ if cls.force_column_alias_quotes is True: return quoted_name(label, True) return label From 95c99d3e78779da2a215758cb66dd646cb828fec Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 27 Aug 2018 13:53:52 +0300 Subject: [PATCH 25/29] Remove redundant db_engine_spec args --- superset/connectors/sqla/models.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 68fe2464d665f..4cacd5370eb94 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -490,18 +490,17 @@ def get_from_clause(self, template_processor=None): return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() - def adhoc_metric_to_sqla(self, metric, cols, db_engine_spec): + def adhoc_metric_to_sqla(self, metric, cols): """ Turn an adhoc metric into a sqlalchemy column. 
         :param dict metric: Adhoc metric definition
         :param dict cols: Columns for the current table
-        :param BaseEngineSpec db_engine_spec: Db engine spec for
-            database specific handling of column labels
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
         expression_type = metric.get('expressionType')
+        db_engine_spec = self.database.db_engine_spec
         label = db_engine_spec.make_label_compatible(metric.get('label'))
 
         if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
@@ -576,7 +575,7 @@ def get_sqla_query(  # sqla
         metrics_exprs = []
         for m in metrics:
             if utils.is_adhoc_metric(m):
-                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols, db_engine_spec))
+                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
                 metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
             else:
@@ -698,7 +697,7 @@ def get_sqla_query(  # sqla
             for col, ascending in orderby:
                 direction = asc if ascending else desc
                 if utils.is_adhoc_metric(col):
-                    col = self.adhoc_metric_to_sqla(col, cols, db_engine_spec)
+                    col = self.adhoc_metric_to_sqla(col, cols)
                 qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -724,8 +723,7 @@ def get_sqla_query(  # sqla
                 ob = inner_main_metric_expr
                 if timeseries_limit_metric:
                     if utils.is_adhoc_metric(timeseries_limit_metric):
-                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols,
-                                                       db_engine_spec)
+                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                     elif timeseries_limit_metric in metrics_dict:
                         timeseries_limit_metric = metrics_dict.get(
                             timeseries_limit_metric,
From 04d3b5fc16a14f67a0caaee35cab2f30790c1c77 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Mon, 27 Aug 2018 14:00:24 +0300
Subject: [PATCH 26/29] Rename col_label to label

---
 superset/connectors/sqla/models.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4cacd5370eb94..046b23eb4c764 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -106,11 +106,11 @@ class TableColumn(Model, BaseColumn):
 
     def get_sqla_col(self, label=None):
         db_engine_spec = self.table.database.db_engine_spec
-        col_label = get_column_label(db_engine_spec, self.column_name, label)
+        label = get_column_label(db_engine_spec, self.column_name, label)
         if not self.expression:
-            col = column(self.column_name).label(col_label)
+            col = column(self.column_name).label(label)
         else:
-            col = literal_column(self.expression).label(col_label)
+            col = literal_column(self.expression).label(label)
         return col
 
     @property
@@ -239,8 +239,8 @@ class SqlMetric(Model, BaseMetric):
 
     def get_sqla_col(self, label=None):
         db_engine_spec = self.table.database.db_engine_spec
-        col_label = get_column_label(db_engine_spec, self.metric_name, label)
-        return literal_column(self.expression).label(col_label)
+        label = get_column_label(db_engine_spec, self.metric_name, label)
+        return literal_column(self.expression).label(label)
 
     @property
     def perm(self):
From efb6313079f163803c2a82c5c9d85c04d7f4b9fa Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Mon, 27 Aug 2018 14:10:49 +0300
Subject: [PATCH 27/29] Remove get_column_label wrapper and make direct calls to db_engine_spec

---
 superset/connectors/sqla/models.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 046b23eb4c764..3fa3d642b4af7 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -34,11 +34,6 @@
 config = app.config
 
 
-def get_column_label(db_engine_spec, column_name, label):
-    label = label if label is not None else column_name
-    return db_engine_spec.make_label_compatible(label)
-
-
 class AnnotationDatasource(BaseDatasource):
     """ Dummy object so we can query annotations using 'Viz' objects just like
         regular datasources.
@@ -106,7 +101,7 @@ class TableColumn(Model, BaseColumn):
 
     def get_sqla_col(self, label=None):
         db_engine_spec = self.table.database.db_engine_spec
-        label = get_column_label(db_engine_spec, self.column_name, label)
+        label = db_engine_spec.make_label_compatible(label if label else self.column_name)
         if not self.expression:
             col = column(self.column_name).label(label)
         else:
@@ -239,7 +234,7 @@ class SqlMetric(Model, BaseMetric):
 
     def get_sqla_col(self, label=None):
         db_engine_spec = self.table.database.db_engine_spec
-        label = get_column_label(db_engine_spec, self.metric_name, label)
+        label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
         return literal_column(self.expression).label(label)
 
     @property
From 3a6fe4b09ad840f22459d0b2e4f7a05ab08c8a31 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Mon, 27 Aug 2018 14:13:26 +0300
Subject: [PATCH 28/29] Remove unneeded db_engine_specs

---
 superset/connectors/sqla/models.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 3fa3d642b4af7..6c4d2d66776f0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -113,7 +113,6 @@ def datasource(self):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        db_engine_spec = self.table.database.db_engine_spec
         col = self.get_sqla_col('__time')
         l = []  # noqa: E741
         if start_dttm:
@@ -422,7 +421,6 @@ def values_for_column(self, column_name, limit=10000):
         cols = {col.column_name: col for col in self.columns}
         target_col = cols[column_name]
         tp = self.get_template_processor()
-        db_engine_spec = self.database.db_engine_spec
 
         qry = (
             select([target_col.get_sqla_col()])
From 0d18163c7c29664e44eebe52e6fcea15e8cb91af Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Mon, 27 Aug 2018 14:19:58 +0300
Subject: [PATCH 29/29] Rename sa_ vars to sqla_

---
 superset/connectors/sqla/models.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 6c4d2d66776f0..fcfe9e0b3fd81 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -113,7 +113,7 @@ def datasource(self):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        col = self.get_sqla_col('__time')
+        col = self.get_sqla_col(label='__time')
         l = []  # noqa: E741
         if start_dttm:
             l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -498,19 +498,19 @@ def adhoc_metric_to_sqla(self, metric, cols):
 
         if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
-            sa_column = column(column_name)
+            sqla_column = column(column_name)
             table_column = cols.get(column_name)
 
             if table_column:
-                sa_column = table_column.get_sqla_col()
+                sqla_column = table_column.get_sqla_col()
 
-            sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
-            sa_metric = sa_metric.label(label)
+            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+            sqla_metric = sqla_metric.label(label)
-            return sa_metric
+            return sqla_metric
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: - sa_metric = literal_column(metric.get('sqlExpression')) - sa_metric = sa_metric.label(label) - return sa_metric + sqla_metric = literal_column(metric.get('sqlExpression')) + sqla_metric = sqla_metric.label(label) + return sqla_metric else: return None
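
As background for reviewers, the effect that make_label_compatible() relies on when force_column_alias_quotes is enabled can be reproduced with a few lines of plain SQLAlchemy (1.x-style select(); the label name below is only an illustration, not taken from the patches). Wrapping a label in quoted_name(label, True) forces the alias to be rendered quoted, which is what keeps the result-set column case identical to the query on case-folding engines such as Oracle and Snowflake:

    from sqlalchemy import literal_column, select
    from sqlalchemy.sql.elements import quoted_name

    # Unquoted alias: Oracle/Snowflake fold it to upper case in the result set.
    plain = select([literal_column('COUNT(*)').label('count_metric')])
    print(plain)   # SELECT COUNT(*) AS count_metric

    # quoted_name(..., True) forces a quoted alias, so the result column keeps its case.
    quoted = select([literal_column('COUNT(*)').label(quoted_name('count_metric', True))])
    print(quoted)  # SELECT COUNT(*) AS "count_metric"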