Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force quoted column aliases for Oracle-like databases #5686

Merged
merged 29 commits into from
Sep 4, 2018
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d5242d0
Replace dataframe label override logic with table column override
villebro Aug 22, 2018
ead1b55
Add mutation to any_date_col
villebro Aug 22, 2018
4540d11
Linting
villebro Aug 22, 2018
027f794
Add mutation to oracle and redshift
villebro Aug 22, 2018
f8f851a
Fine tune how and which labels are mutated
villebro Aug 23, 2018
ce3617b
Implement alias quoting logic for oracle-like databases
villebro Aug 25, 2018
5f67d59
Fix and align column and metric sqla_col methods
villebro Aug 25, 2018
8b6e52d
Clean up typos and redundant logic
villebro Aug 25, 2018
62163fa
Move new attribute to old location
villebro Aug 25, 2018
e8dd7da
Linting
villebro Aug 25, 2018
6c6436f
Replace old sqla_col property references with function calls
villebro Aug 25, 2018
4b0e593
Remove redundant calls to mutate_column_label
villebro Aug 26, 2018
76959dd
Move duplicated logic to common function
villebro Aug 26, 2018
6c7fccc
Add db_engine_specs to all sqla_col calls
villebro Aug 26, 2018
c187104
Add missing mydb
villebro Aug 26, 2018
53b77a2
Add note about snowflake-sqlalchemy regression
villebro Aug 26, 2018
e792d15
Make db_engine_spec mandatory in sqla_col
villebro Aug 26, 2018
44eac83
Small refactoring and cleanup
villebro Aug 26, 2018
a161efc
Remove db_engine_spec from get_from_clause call
villebro Aug 26, 2018
901b635
Make db_engine_spec mandatory in adhoc_metric_to_sa
villebro Aug 26, 2018
7e7a2a8
Remove redundant mutate_expression_label call
villebro Aug 26, 2018
d60ed34
Add missing db_engine_specs to adhoc_metric_to_sa
villebro Aug 26, 2018
fc2c4f2
Rename arg label_name to label in get_column_label()
villebro Aug 27, 2018
3f2d874
Rename label function and add docstring
villebro Aug 27, 2018
95c99d3
Remove redundant db_engine_spec args
villebro Aug 27, 2018
04d3b5f
Rename col_label to label
villebro Aug 27, 2018
efb6313
Remove get_column_name wrapper and make direct calls to db_engine_spec
villebro Aug 27, 2018
3a6fe4b
Remove unneeded db_engine_specs
villebro Aug 27, 2018
0d18163
Rename sa_ vars to sqla_
villebro Aug 27, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation.

*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
the snowflake-sqlalchemy package that causes problems when it is used with Superset. It is
recommended to use version 1.1.0, or to try a newer version once one is available.

See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.

Caching
Expand Down
110 changes: 63 additions & 47 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
config = app.config


def get_column_label(db_engine_spec, column_name, label):
    """Resolve the effective label for a column and run it through the
    engine-specific label mutation.

    Falls back to ``column_name`` when no explicit ``label`` is given, then
    delegates to the db engine spec so databases with labeling quirks
    (e.g. Oracle-like engines) can quote/mutate the alias.
    """
    effective_label = column_name if label is None else label
    return db_engine_spec.get_column_label(effective_label)


class AnnotationDatasource(BaseDatasource):
""" Dummy object so we can query annotations using 'Viz' objects just like
regular datasources.
Expand Down Expand Up @@ -99,21 +104,20 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)]
export_parent = 'table'

@property
def sqla_col(self):
name = self.column_name
def sqla_col(self, db_engine_spec, label=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think db_engine_spec is available through self.table.database.db_engine_spec, so no need to receive it as a method argument.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops; db_engine_spec removed from all methods, as it is already available in all classes. I also renamed the method to get_sqla_col from just sqla_col, as it was previously a property.

col_label = get_column_label(db_engine_spec, self.column_name, label)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just had a moment thinking about whether verbose_name should ever make its way in here, and concluded no — at least for now.

if not self.expression:
col = column(self.column_name).label(name)
col = column(self.column_name).label(col_label)
else:
col = literal_column(self.expression).label(name)
col = literal_column(self.expression).label(col_label)
return col

@property
def datasource(self):
return self.table

def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time')
def get_time_filter(self, db_engine_spec, start_dttm, end_dttm):
col = self.sqla_col(db_engine_spec, '__time')
l = [] # noqa: E741
if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm)))
Expand Down Expand Up @@ -231,10 +235,9 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )])
export_parent = 'table'

@property
def sqla_col(self):
name = self.metric_name
return literal_column(self.expression).label(name)
def sqla_col(self, db_engine_spec, label=None):
    """Return this metric's SQL expression as a labeled SQLAlchemy column.

    The label defaults to the metric name and is mutated by the engine
    spec to satisfy database-specific alias rules.
    """
    effective_label = get_column_label(db_engine_spec, self.metric_name, label)
    sqla_metric = literal_column(self.expression)
    return sqla_metric.label(effective_label)

@property
def perm(self):
Expand Down Expand Up @@ -424,8 +427,8 @@ def values_for_column(self, column_name, limit=10000):
db_engine_spec = self.database.db_engine_spec

qry = (
select([target_col.sqla_col])
.select_from(self.get_from_clause(tp, db_engine_spec))
select([target_col.sqla_col(db_engine_spec)])
.select_from(self.get_from_clause(tp))
.distinct()
)
if limit:
Expand Down Expand Up @@ -474,7 +477,7 @@ def get_sqla_table(self):
tbl.schema = self.schema
return tbl

def get_from_clause(self, template_processor=None, db_engine_spec=None):
def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
Expand All @@ -484,30 +487,34 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()

def adhoc_metric_to_sa(self, metric, cols):
def adhoc_metric_to_sa(self, metric, cols, db_engine_spec):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think db_engine_spec is also in scope here through self.table.database.db_engine_spec

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope for this PR, but I just caught the ..._to_sa should be ..._to_sqla to be consistent with aliasing sqla.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed both method and variables from sa to sqla.

"""
Turn an adhoc metric into a sqlalchemy column.

:param dict metric: Adhoc metric definition
:param dict cols: Columns for the current table
:param BaseEngineSpec db_engine_spec: Db engine spec for
database specific handling of column labels
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
expressionType = metric.get('expressionType')
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
expression_type = metric.get('expressionType')
label = db_engine_spec.get_column_label(metric.get('label'))

if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
sa_column = column(column_name)
table_column = cols.get(column_name)

if table_column:
sa_column = table_column.sqla_col
sa_column = table_column.sqla_col(db_engine_spec)

sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
sa_metric = sa_metric.label(metric.get('label'))
sa_metric = sa_metric.label(label)
return sa_metric
elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sa_metric = literal_column(metric.get('sqlExpression'))
sa_metric = sa_metric.label(metric.get('label'))
sa_metric = sa_metric.label(label)
return sa_metric
else:
return None
Expand Down Expand Up @@ -566,15 +573,16 @@ def get_sqla_query( # sqla
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols, db_engine_spec))
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).sqla_col)
metrics_exprs.append(metrics_dict.get(m).sqla_col(db_engine_spec))
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column('COUNT(*)').label('ccount')
main_metric_expr = literal_column('COUNT(*)').label(
db_engine_spec.get_column_label('count'))

select_exprs = []
groupby_exprs = []
Expand All @@ -585,16 +593,16 @@ def get_sqla_query( # sqla
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
outer = col.sqla_col
inner = col.sqla_col.label(col.column_name + '__')
outer = col.sqla_col(db_engine_spec)
inner = col.sqla_col(db_engine_spec, col.column_name + '__')

groupby_exprs.append(outer)
select_exprs.append(outer)
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
elif columns:
for s in columns:
select_exprs.append(cols[s].sqla_col)
select_exprs.append(cols[s].sqla_col(db_engine_spec))
metrics_exprs = []

if granularity:
Expand All @@ -612,13 +620,14 @@ def get_sqla_query( # sqla
self.main_dttm_col in self.dttm_cols and \
self.main_dttm_col != dttm_col.column_name:
time_filters.append(cols[self.main_dttm_col].
get_time_filter(from_dttm, to_dttm))
time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
get_time_filter(db_engine_spec, from_dttm, to_dttm))
time_filters.append(dttm_col.get_time_filter(db_engine_spec,
from_dttm, to_dttm))

select_exprs += metrics_exprs
qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor, db_engine_spec)
tbl = self.get_from_clause(template_processor)

if not columns:
qry = qry.group_by(*groupby_exprs)
Expand All @@ -638,33 +647,35 @@ def get_sqla_query( # sqla
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target)
if op in ('in', 'not in'):
cond = col_obj.sqla_col.in_(eq)
cond = col_obj.sqla_col(db_engine_spec).in_(eq)
if '<NULL>' in eq:
cond = or_(cond, col_obj.sqla_col == None) # noqa
cond = or_(cond, col_obj.sqla_col(db_engine_spec) == None) # noqa
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_num:
eq = utils.string_to_num(flt['val'])
if op == '==':
where_clause_and.append(col_obj.sqla_col == eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) == eq)
elif op == '!=':
where_clause_and.append(col_obj.sqla_col != eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) != eq)
elif op == '>':
where_clause_and.append(col_obj.sqla_col > eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) > eq)
elif op == '<':
where_clause_and.append(col_obj.sqla_col < eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) < eq)
elif op == '>=':
where_clause_and.append(col_obj.sqla_col >= eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) >= eq)
elif op == '<=':
where_clause_and.append(col_obj.sqla_col <= eq)
where_clause_and.append(col_obj.sqla_col(db_engine_spec) <= eq)
elif op == 'LIKE':
where_clause_and.append(col_obj.sqla_col.like(eq))
where_clause_and.append(col_obj.sqla_col(db_engine_spec).like(eq))
elif op == 'IS NULL':
where_clause_and.append(col_obj.sqla_col == None) # noqa
where_clause_and.append(
col_obj.sqla_col(db_engine_spec) == None) # noqa
elif op == 'IS NOT NULL':
where_clause_and.append(col_obj.sqla_col != None) # noqa
where_clause_and.append(
col_obj.sqla_col(db_engine_spec) != None) # noqa
if extras:
where = extras.get('where')
if where:
Expand All @@ -686,7 +697,7 @@ def get_sqla_query( # sqla
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sa(col, cols)
col = self.adhoc_metric_to_sa(col, cols, db_engine_spec)
qry = qry.order_by(direction(col))

if row_limit:
Expand All @@ -712,12 +723,13 @@ def get_sqla_query( # sqla
ob = inner_main_metric_expr
if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric):
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols,
db_engine_spec)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric,
)
ob = timeseries_limit_metric.sqla_col
ob = timeseries_limit_metric.sqla_col(db_engine_spec)
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc
Expand Down Expand Up @@ -750,19 +762,19 @@ def get_sqla_query( # sqla
}
result = self.query(subquery_obj)
dimensions = [c for c in result.df.columns if c not in metrics]
top_groups = self._get_top_groups(result.df, dimensions)
top_groups = self._get_top_groups(result.df, dimensions, db_engine_spec)
qry = qry.where(top_groups)

return qry.select_from(tbl)

def _get_top_groups(self, df, dimensions):
def _get_top_groups(self, df, dimensions, db_engine_spec):
    """Build an OR-of-ANDs filter matching the dimension value combinations
    present in ``df`` (the result of the top-groups subquery).

    Each dataframe row contributes one AND clause that pins every listed
    dimension to that row's value.
    """
    columns_by_name = {c.column_name: c for c in self.columns}
    group_filters = []
    for _, row in df.iterrows():
        matches = [
            columns_by_name.get(dim).sqla_col(db_engine_spec) == row[dim]
            for dim in dimensions
        ]
        group_filters.append(and_(*matches))

    return or_(*group_filters)
Expand Down Expand Up @@ -816,6 +828,7 @@ def fetch_metadata(self):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
db_engine_spec = self.database.db_engine_spec

for col in table.columns:
try:
Expand Down Expand Up @@ -848,6 +861,9 @@ def fetch_metadata(self):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
for metric in metrics:
metric.metric_name = db_engine_spec.mutate_expression_label(
metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
Expand Down
4 changes: 1 addition & 3 deletions superset/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
if cursor_description:
column_names = [col[0] for col in cursor_description]

case_sensitive = db_engine_spec.consistent_case_sensitivity
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)
self.column_names = dedup(column_names)

data = data or []
self.df = (
Expand Down
Loading