Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch committed Apr 27, 2018
1 parent 6a7b25f commit 3ac0ba5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
28 changes: 14 additions & 14 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,24 @@ def get_time_filter(self, start_dttm, end_dttm):

def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
pdf = self.python_date_format
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not self.expression and not time_grain and not is_epoch:
return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)

expr = self.expression or self.column_name
if not self.expression and not time_grain:
return column(expr, type_=DateTime).label(DTTM_ALIAS)
literal = '{col}'
if is_epoch:
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
if time_grain:
pdf = self.python_date_format
if pdf in ('epoch_s', 'epoch_ms'):
# if epoch, translate to DATE using db specific conf
db_spec = self.table.database.db_engine_spec
if pdf == 'epoch_s':
expr = db_spec.epoch_to_dttm().format(col=expr)
elif pdf == 'epoch_ms':
expr = db_spec.epoch_ms_to_dttm().format(col=expr)
grain = self.table.database.grains_dict().get(time_grain)
if grain:
literal = grain.function
literal = expr.format(col=expr)
return literal_column(literal, type_=DateTime).label(DTTM_ALIAS)
expr = grain.function.format(col=expr)
return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

@classmethod
def import_obj(cls, i_column):
Expand Down
55 changes: 55 additions & 0 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,63 @@ def test_grains_dict(self):
self.assertEquals(d.get('P1D').function, 'DATE({col})')
self.assertEquals(d.get('Time Column').function, '{col}')


class SqlaTableModelTestCase(SupersetTestCase):

def test_get_timestamp_expression(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
sqla_literal = ds_col.get_timestamp_expression(None)
self.assertEquals(str(sqla_literal.compile()), 'ds')

sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')

ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))')

def test_get_timestamp_expression_epoch(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')

ds_col.expression = None
ds_col.python_date_format = 'epoch_s'
sqla_literal = ds_col.get_timestamp_expression(None)
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'from_unixtime(ds)')

ds_col.python_date_format = 'epoch_s'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(ds))')

ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')

def test_get_timestamp_expression_backward(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')

ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('day')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')

ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('Time Column')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'ds')

0 comments on commit 3ac0ba5

Please sign in to comment.