From 14ceff72ced9101b80baccef082c71ce1e8430e9 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 26 Apr 2018 21:13:52 -0700 Subject: [PATCH] [bugfix] temporal columns with expression fail (#4890) * [bugfix] temporal columns with expression fail error msg: "local variable 'literal' referenced before assignment" Error occurs [only] when using temporal column defined as a SQL expression. Also noticed that examples were using `granularity` instead of using `granularity_sqla` as they should. Fixed that here. * Add tests --- superset/connectors/base/models.py | 5 +++ superset/connectors/sqla/models.py | 28 +++++++------- superset/data/__init__.py | 22 +++++------ tests/model_tests.py | 61 ++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 24 deletions(-) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 8e4a2a22459d4..9f9522daa45b0 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -241,6 +241,11 @@ def values_for_column(self, column_name, limit=10000): def default_query(qry): return qry + def get_column(self, column_name): + for col in self.columns: + if col.column_name == column_name: + return col + class BaseColumn(AuditMixinNullable, ImportMixin): """Interface for column""" diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c65df0209a82b..56a1751243995 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -117,22 +117,24 @@ def get_time_filter(self, start_dttm, end_dttm): def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" + pdf = self.python_date_format + is_epoch = pdf in ('epoch_s', 'epoch_ms') + if not self.expression and not time_grain and not is_epoch: + return column(self.column_name, type_=DateTime).label(DTTM_ALIAS) + expr = self.expression or self.column_name - if not self.expression and not time_grain: - return column(expr, type_=DateTime).label(DTTM_ALIAS) + if is_epoch: + # if epoch, translate to DATE using db specific conf + db_spec = self.table.database.db_engine_spec + if pdf == 'epoch_s': + expr = db_spec.epoch_to_dttm().format(col=expr) + elif pdf == 'epoch_ms': + expr = db_spec.epoch_ms_to_dttm().format(col=expr) if time_grain: - pdf = self.python_date_format - if pdf in ('epoch_s', 'epoch_ms'): - # if epoch, translate to DATE using db specific conf - db_spec = self.table.database.db_engine_spec - if pdf == 'epoch_s': - expr = db_spec.epoch_to_dttm().format(col=expr) - elif pdf == 'epoch_ms': - expr = db_spec.epoch_ms_to_dttm().format(col=expr) grain = self.table.database.grains_dict().get(time_grain) - literal = grain.function if grain else '{col}' - literal = expr.format(col=expr) - return literal_column(literal, type_=DateTime).label(DTTM_ALIAS) + if grain: + expr = grain.function.format(col=expr) + return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod def import_obj(cls, i_column): diff --git a/superset/data/__init__.py b/superset/data/__init__.py index 160ed647f9216..8ad6c11688a11 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -188,7 +188,7 @@ def load_world_bank_health_n_pop(): "compare_lag": "10", "compare_suffix": "o10Y", "limit": "25", - "granularity": "year", + "granularity_sqla": "year", "groupby": [], "metric": 'sum__SP_POP_TOTL', "metrics": ["sum__SP_POP_TOTL"], @@ -593,7 +593,7 @@ def load_birth_names(): "compare_lag": "10", "compare_suffix": "o10Y", "limit": "25", - "granularity": "ds", + "granularity_sqla": "ds", "groupby": [], "metric": 'sum__num', "metrics": ["sum__num"], @@ -642,7 +642,7 @@ def load_birth_names(): datasource_id=tbl.id, params=get_slice_json( defaults, - viz_type="big_number", granularity="ds", + viz_type="big_number", granularity_sqla="ds", compare_lag="5", compare_suffix="over 5Y")), Slice( slice_name="Genders", @@ -675,7 +675,7 @@ def load_birth_names(): params=get_slice_json( defaults, viz_type="line", groupby=['name'], - granularity='ds', rich_tooltip=True, show_legend=True)), + granularity_sqla='ds', rich_tooltip=True, show_legend=True)), Slice( slice_name="Average and Sum Trends", viz_type='dual_line', @@ -684,7 +684,7 @@ def load_birth_names(): params=get_slice_json( defaults, viz_type="dual_line", metric='avg__num', metric_2='sum__num', - granularity='ds')), + granularity_sqla='ds')), Slice( slice_name="Title", viz_type='markup', @@ -729,7 +729,7 @@ def load_birth_names(): datasource_id=tbl.id, params=get_slice_json( defaults, - viz_type="big_number_total", granularity="ds", + viz_type="big_number_total", granularity_sqla="ds", filters=[{ 'col': 'gender', 'op': 'in', @@ -876,7 +876,7 @@ def load_unicode_test_data(): tbl = obj slice_data = { - "granularity": "dttm", + "granularity_sqla": "dttm", "groupby": [], "metric": 'sum__value', "row_limit": config.get("ROW_LIMIT"), @@ -954,7 +954,7 @@ def load_random_time_series_data(): tbl = obj slice_data = { - "granularity": "day", + "granularity_sqla": "day", "row_limit": config.get("ROW_LIMIT"), "since": "1 year ago", "until": "now", @@ -1017,7 +1017,7 @@ def load_country_map_data(): tbl = obj slice_data = { - "granularity": "", + "granularity_sqla": "", "since": "", "until": "", "where": "", @@ -1092,7 +1092,7 @@ def load_long_lat_data(): tbl = obj slice_data = { - "granularity": "day", + "granularity_sqla": "day", "since": "2014-01-01", "until": "now", "where": "", @@ -1172,7 +1172,7 @@ def load_multiformat_time_series_data(): slice_data = { "metric": 'count', "granularity_sqla": col.column_name, - "granularity": "day", + "granularity_sqla": "day", "row_limit": config.get("ROW_LIMIT"), "since": "1 year ago", "until": "now", diff --git a/tests/model_tests.py b/tests/model_tests.py index cdd4c830fbfe9..8af104f57c9d0 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -105,3 +105,64 @@ def test_grains_dict(self): self.assertEquals(d.get('day').function, 'DATE({col})') self.assertEquals(d.get('P1D').function, 'DATE({col})') self.assertEquals(d.get('Time Column').function, '{col}') + + +class SqlaTableModelTestCase(SupersetTestCase): + + def test_get_timestamp_expression(self): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + sqla_literal = ds_col.get_timestamp_expression(None) + self.assertEquals(str(sqla_literal.compile()), 'ds') + + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(ds)') + + ds_col.expression = 'DATE_ADD(ds, 1)' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))') + + def test_get_timestamp_expression_epoch(self): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + + ds_col.expression = None + ds_col.python_date_format = 'epoch_s' + sqla_literal = ds_col.get_timestamp_expression(None) + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'from_unixtime(ds)') + + ds_col.python_date_format = 'epoch_s' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(from_unixtime(ds))') + + ds_col.expression = 'DATE_ADD(ds, 1)' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') + + def test_get_timestamp_expression_backward(self): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + + ds_col.expression = None + ds_col.python_date_format = None + sqla_literal = ds_col.get_timestamp_expression('day') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(ds)') + + ds_col.expression = None + ds_col.python_date_format = None + sqla_literal = ds_col.get_timestamp_expression('Time Column') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'ds')