Skip to content

Commit

Permalink
[bugfix] temporal columns with expression fail (apache#4890)
Browse files Browse the repository at this point in the history
* [bugfix] temporal columns with expression fail

error msg: "local variable 'literal' referenced before assignment"

Error occurs [only] when using temporal column defined as a SQL
expression.

Also noticed that examples were using `granularity` instead of using
`granularity_sqla` as they should. Fixed that here.

* Add tests
  • Loading branch information
mistercrunch authored Apr 27, 2018
1 parent ca5547f commit 14ceff7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 24 deletions.
5 changes: 5 additions & 0 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ def values_for_column(self, column_name, limit=10000):
def default_query(qry):
return qry

def get_column(self, column_name):
    """Return the column object matching *column_name*, or None if absent."""
    matches = (col for col in self.columns if col.column_name == column_name)
    return next(matches, None)


class BaseColumn(AuditMixinNullable, ImportMixin):
"""Interface for column"""
Expand Down
28 changes: 15 additions & 13 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,24 @@ def get_time_filter(self, start_dttm, end_dttm):

def get_timestamp_expression(self, time_grain):
    """Return a SQLAlchemy expression for this temporal column.

    Builds the ``__timestamp`` (DTTM_ALIAS) column used in time-series
    queries, in three steps:

    1. If the column is a plain physical column with no time grain and no
       epoch format, reference it directly (fast path, no literal SQL).
    2. If the column stores epoch seconds/milliseconds, wrap the expression
       in the database-specific epoch-to-datetime conversion.
    3. If a time grain was requested, wrap the expression in the grain's
       truncation function from the database's grains dict.

    :param time_grain: time-grain key (e.g. ``'P1D'``) or None
    :returns: a labeled SQLAlchemy column/literal_column of type DateTime

    Note: always threading the accumulated SQL through ``expr`` (instead of
    a separately-assigned ``literal`` variable) avoids the
    "local variable 'literal' referenced before assignment" error that
    occurred when the column had a SQL expression but no time grain.
    """
    pdf = self.python_date_format
    is_epoch = pdf in ('epoch_s', 'epoch_ms')
    if not self.expression and not time_grain and not is_epoch:
        # Plain physical datetime column: no transformation needed
        return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)

    expr = self.expression or self.column_name
    if is_epoch:
        # if epoch, translate to DATE using db specific conf
        db_spec = self.table.database.db_engine_spec
        if pdf == 'epoch_s':
            expr = db_spec.epoch_to_dttm().format(col=expr)
        elif pdf == 'epoch_ms':
            expr = db_spec.epoch_ms_to_dttm().format(col=expr)
    if time_grain:
        grain = self.table.database.grains_dict().get(time_grain)
        if grain:
            # 'Time Column' style grains have function '{col}' (identity);
            # a missing grain leaves expr untouched for the same effect
            expr = grain.function.format(col=expr)
    return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

@classmethod
def import_obj(cls, i_column):
Expand Down
22 changes: 11 additions & 11 deletions superset/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def load_world_bank_health_n_pop():
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity": "year",
"granularity_sqla": "year",
"groupby": [],
"metric": 'sum__SP_POP_TOTL',
"metrics": ["sum__SP_POP_TOTL"],
Expand Down Expand Up @@ -593,7 +593,7 @@ def load_birth_names():
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity": "ds",
"granularity_sqla": "ds",
"groupby": [],
"metric": 'sum__num',
"metrics": ["sum__num"],
Expand Down Expand Up @@ -642,7 +642,7 @@ def load_birth_names():
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number", granularity="ds",
viz_type="big_number", granularity_sqla="ds",
compare_lag="5", compare_suffix="over 5Y")),
Slice(
slice_name="Genders",
Expand Down Expand Up @@ -675,7 +675,7 @@ def load_birth_names():
params=get_slice_json(
defaults,
viz_type="line", groupby=['name'],
granularity='ds', rich_tooltip=True, show_legend=True)),
granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
Slice(
slice_name="Average and Sum Trends",
viz_type='dual_line',
Expand All @@ -684,7 +684,7 @@ def load_birth_names():
params=get_slice_json(
defaults,
viz_type="dual_line", metric='avg__num', metric_2='sum__num',
granularity='ds')),
granularity_sqla='ds')),
Slice(
slice_name="Title",
viz_type='markup',
Expand Down Expand Up @@ -729,7 +729,7 @@ def load_birth_names():
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type="big_number_total", granularity="ds",
viz_type="big_number_total", granularity_sqla="ds",
filters=[{
'col': 'gender',
'op': 'in',
Expand Down Expand Up @@ -876,7 +876,7 @@ def load_unicode_test_data():
tbl = obj

slice_data = {
"granularity": "dttm",
"granularity_sqla": "dttm",
"groupby": [],
"metric": 'sum__value',
"row_limit": config.get("ROW_LIMIT"),
Expand Down Expand Up @@ -954,7 +954,7 @@ def load_random_time_series_data():
tbl = obj

slice_data = {
"granularity": "day",
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
"since": "1 year ago",
"until": "now",
Expand Down Expand Up @@ -1017,7 +1017,7 @@ def load_country_map_data():
tbl = obj

slice_data = {
"granularity": "",
"granularity_sqla": "",
"since": "",
"until": "",
"where": "",
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def load_long_lat_data():
tbl = obj

slice_data = {
"granularity": "day",
"granularity_sqla": "day",
"since": "2014-01-01",
"until": "now",
"where": "",
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def load_multiformat_time_series_data():
slice_data = {
"metric": 'count',
"granularity_sqla": col.column_name,
"granularity": "day",
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
"since": "1 year ago",
"until": "now",
Expand Down
61 changes: 61 additions & 0 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,64 @@ def test_grains_dict(self):
self.assertEquals(d.get('day').function, 'DATE({col})')
self.assertEquals(d.get('P1D').function, 'DATE({col})')
self.assertEquals(d.get('Time Column').function, '{col}')


class SqlaTableModelTestCase(SupersetTestCase):
    """Tests for SqlaTable/TableColumn timestamp-expression generation."""

    def test_get_timestamp_expression(self):
        """Plain and expression-backed columns, with and without a grain."""
        table = self.get_table_by_name('birth_names')
        col = table.get_column('ds')

        # No grain, physical column: referenced directly
        expr = col.get_timestamp_expression(None)
        self.assertEquals(str(expr.compile()), 'ds')

        # Daily grain wraps the column in the backend's DATE() function
        expr = col.get_timestamp_expression('P1D')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'DATE(ds)')

        # Column defined as a SQL expression: grain wraps the expression
        col.expression = 'DATE_ADD(ds, 1)'
        expr = col.get_timestamp_expression('P1D')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'DATE(DATE_ADD(ds, 1))')

    def test_get_timestamp_expression_epoch(self):
        """Epoch-seconds columns get the db-specific epoch conversion."""
        table = self.get_table_by_name('birth_names')
        col = table.get_column('ds')

        # Epoch column, no grain: just the conversion
        col.expression = None
        col.python_date_format = 'epoch_s'
        expr = col.get_timestamp_expression(None)
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'from_unixtime(ds)')

        # Epoch column with a grain: conversion wrapped in DATE()
        col.python_date_format = 'epoch_s'
        expr = col.get_timestamp_expression('P1D')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'DATE(from_unixtime(ds))')

        # Epoch + SQL expression + grain: all three layers compose
        col.expression = 'DATE_ADD(ds, 1)'
        expr = col.get_timestamp_expression('P1D')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')

    def test_get_timestamp_expression_backward(self):
        """Legacy grain keys ('day', 'Time Column') keep working."""
        table = self.get_table_by_name('birth_names')
        col = table.get_column('ds')

        col.expression = None
        col.python_date_format = None
        expr = col.get_timestamp_expression('day')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'DATE(ds)')

        # 'Time Column' is the identity grain — column passes through
        col.expression = None
        col.python_date_format = None
        expr = col.get_timestamp_expression('Time Column')
        sql = str(expr.compile())
        if table.database.backend == 'mysql':
            self.assertEquals(sql, 'ds')

0 comments on commit 14ceff7

Please sign in to comment.