diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a6b5194570b4b..56a1751243995 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -117,24 +117,24 @@ def get_time_filter(self, start_dttm, end_dttm): def get_timestamp_expression(self, time_grain): """Getting the time component of the query""" + pdf = self.python_date_format + is_epoch = pdf in ('epoch_s', 'epoch_ms') + if not self.expression and not time_grain and not is_epoch: + return column(self.column_name, type_=DateTime).label(DTTM_ALIAS) + expr = self.expression or self.column_name - if not self.expression and not time_grain: - return column(expr, type_=DateTime).label(DTTM_ALIAS) - literal = '{col}' + if is_epoch: + # if epoch, translate to DATE using db specific conf + db_spec = self.table.database.db_engine_spec + if pdf == 'epoch_s': + expr = db_spec.epoch_to_dttm().format(col=expr) + elif pdf == 'epoch_ms': + expr = db_spec.epoch_ms_to_dttm().format(col=expr) if time_grain: - pdf = self.python_date_format - if pdf in ('epoch_s', 'epoch_ms'): - # if epoch, translate to DATE using db specific conf - db_spec = self.table.database.db_engine_spec - if pdf == 'epoch_s': - expr = db_spec.epoch_to_dttm().format(col=expr) - elif pdf == 'epoch_ms': - expr = db_spec.epoch_ms_to_dttm().format(col=expr) grain = self.table.database.grains_dict().get(time_grain) if grain: - literal = grain.function - literal = expr.format(col=expr) - return literal_column(literal, type_=DateTime).label(DTTM_ALIAS) + expr = grain.function.format(col=expr) + return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod def import_obj(cls, i_column): diff --git a/tests/model_tests.py b/tests/model_tests.py index 11fa9e8ca8a3f..8af104f57c9d0 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -106,8 +106,63 @@ def test_grains_dict(self): self.assertEquals(d.get('P1D').function, 'DATE({col})') self.assertEquals(d.get('Time Column').function, '{col}') + +class SqlaTableModelTestCase(SupersetTestCase): + def test_get_timestamp_expression(self): tbl = self.get_table_by_name('birth_names') ds_col = tbl.get_column('ds') sqla_literal = ds_col.get_timestamp_expression(None) self.assertEquals(str(sqla_literal.compile()), 'ds') + + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(ds)') + + ds_col.expression = 'DATE_ADD(ds, 1)' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))') + + def test_get_timestamp_expression_epoch(self): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + + ds_col.expression = None + ds_col.python_date_format = 'epoch_s' + sqla_literal = ds_col.get_timestamp_expression(None) + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'from_unixtime(ds)') + + ds_col.python_date_format = 'epoch_s' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(from_unixtime(ds))') + + ds_col.expression = 'DATE_ADD(ds, 1)' + sqla_literal = ds_col.get_timestamp_expression('P1D') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') + + def test_get_timestamp_expression_backward(self): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + + ds_col.expression = None + ds_col.python_date_format = None + sqla_literal = ds_col.get_timestamp_expression('day') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'DATE(ds)') + + ds_col.expression = None + ds_col.python_date_format = None + sqla_literal = ds_col.get_timestamp_expression('Time Column') + compiled = '{}'.format(sqla_literal.compile()) + if tbl.database.backend == 'mysql': + self.assertEquals(compiled, 'ds')