-
Notifications
You must be signed in to change notification settings - Fork 14.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Force quoted column aliases for Oracle-like databases #5686
Changes from 22 commits
d5242d0
ead1b55
4540d11
027f794
f8f851a
ce3617b
5f67d59
8b6e52d
62163fa
e8dd7da
6c6436f
4b0e593
76959dd
6c7fccc
c187104
53b77a2
e792d15
44eac83
a161efc
901b635
7e7a2a8
d60ed34
fc2c4f2
3f2d874
95c99d3
04d3b5f
efb6313
3a6fe4b
0d18163
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,11 @@ | |
config = app.config | ||
|
||
|
||
def get_column_label(db_engine_spec, column_name, label): | ||
label = label if label is not None else column_name | ||
return db_engine_spec.get_column_label(label) | ||
|
||
|
||
class AnnotationDatasource(BaseDatasource): | ||
""" Dummy object so we can query annotations using 'Viz' objects just like | ||
regular datasources. | ||
|
@@ -99,21 +104,20 @@ class TableColumn(Model, BaseColumn): | |
s for s in export_fields if s not in ('table_id',)] | ||
export_parent = 'table' | ||
|
||
@property | ||
def sqla_col(self): | ||
name = self.column_name | ||
def sqla_col(self, db_engine_spec, label=None): | ||
col_label = get_column_label(db_engine_spec, self.column_name, label) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just had a moment thinking about whether |
||
if not self.expression: | ||
col = column(self.column_name).label(name) | ||
col = column(self.column_name).label(col_label) | ||
else: | ||
col = literal_column(self.expression).label(name) | ||
col = literal_column(self.expression).label(col_label) | ||
return col | ||
|
||
@property | ||
def datasource(self): | ||
return self.table | ||
|
||
def get_time_filter(self, start_dttm, end_dttm): | ||
col = self.sqla_col.label('__time') | ||
def get_time_filter(self, db_engine_spec, start_dttm, end_dttm): | ||
col = self.sqla_col(db_engine_spec, '__time') | ||
l = [] # noqa: E741 | ||
if start_dttm: | ||
l.append(col >= text(self.dttm_sql_literal(start_dttm))) | ||
|
@@ -231,10 +235,9 @@ class SqlMetric(Model, BaseMetric): | |
s for s in export_fields if s not in ('table_id', )]) | ||
export_parent = 'table' | ||
|
||
@property | ||
def sqla_col(self): | ||
name = self.metric_name | ||
return literal_column(self.expression).label(name) | ||
def sqla_col(self, db_engine_spec, label=None): | ||
col_label = get_column_label(db_engine_spec, self.metric_name, label) | ||
return literal_column(self.expression).label(col_label) | ||
|
||
@property | ||
def perm(self): | ||
|
@@ -424,8 +427,8 @@ def values_for_column(self, column_name, limit=10000): | |
db_engine_spec = self.database.db_engine_spec | ||
|
||
qry = ( | ||
select([target_col.sqla_col]) | ||
.select_from(self.get_from_clause(tp, db_engine_spec)) | ||
select([target_col.sqla_col(db_engine_spec)]) | ||
.select_from(self.get_from_clause(tp)) | ||
.distinct() | ||
) | ||
if limit: | ||
|
@@ -474,7 +477,7 @@ def get_sqla_table(self): | |
tbl.schema = self.schema | ||
return tbl | ||
|
||
def get_from_clause(self, template_processor=None, db_engine_spec=None): | ||
def get_from_clause(self, template_processor=None): | ||
# Supporting arbitrary SQL statements in place of tables | ||
if self.sql: | ||
from_sql = self.sql | ||
|
@@ -484,30 +487,34 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None): | |
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') | ||
return self.get_sqla_table() | ||
|
||
def adhoc_metric_to_sa(self, metric, cols): | ||
def adhoc_metric_to_sa(self, metric, cols, db_engine_spec): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of scope for this PR, but I just caught the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed both method and variables from |
||
""" | ||
Turn an adhoc metric into a sqlalchemy column. | ||
|
||
:param dict metric: Adhoc metric definition | ||
:param dict cols: Columns for the current table | ||
:param BaseEngineSpec db_engine_spec: Db engine spec for | ||
database specific handling of column labels | ||
:returns: The metric defined as a sqlalchemy column | ||
:rtype: sqlalchemy.sql.column | ||
""" | ||
expressionType = metric.get('expressionType') | ||
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: | ||
expression_type = metric.get('expressionType') | ||
label = db_engine_spec.get_column_label(metric.get('label')) | ||
|
||
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: | ||
column_name = metric.get('column').get('column_name') | ||
sa_column = column(column_name) | ||
table_column = cols.get(column_name) | ||
|
||
if table_column: | ||
sa_column = table_column.sqla_col | ||
sa_column = table_column.sqla_col(db_engine_spec) | ||
|
||
sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) | ||
sa_metric = sa_metric.label(metric.get('label')) | ||
sa_metric = sa_metric.label(label) | ||
return sa_metric | ||
elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: | ||
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: | ||
sa_metric = literal_column(metric.get('sqlExpression')) | ||
sa_metric = sa_metric.label(metric.get('label')) | ||
sa_metric = sa_metric.label(label) | ||
return sa_metric | ||
else: | ||
return None | ||
|
@@ -566,15 +573,16 @@ def get_sqla_query( # sqla | |
metrics_exprs = [] | ||
for m in metrics: | ||
if utils.is_adhoc_metric(m): | ||
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols)) | ||
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols, db_engine_spec)) | ||
elif m in metrics_dict: | ||
metrics_exprs.append(metrics_dict.get(m).sqla_col) | ||
metrics_exprs.append(metrics_dict.get(m).sqla_col(db_engine_spec)) | ||
else: | ||
raise Exception(_("Metric '{}' is not valid".format(m))) | ||
if metrics_exprs: | ||
main_metric_expr = metrics_exprs[0] | ||
else: | ||
main_metric_expr = literal_column('COUNT(*)').label('ccount') | ||
main_metric_expr = literal_column('COUNT(*)').label( | ||
db_engine_spec.get_column_label('count')) | ||
|
||
select_exprs = [] | ||
groupby_exprs = [] | ||
|
@@ -585,16 +593,16 @@ def get_sqla_query( # sqla | |
inner_groupby_exprs = [] | ||
for s in groupby: | ||
col = cols[s] | ||
outer = col.sqla_col | ||
inner = col.sqla_col.label(col.column_name + '__') | ||
outer = col.sqla_col(db_engine_spec) | ||
inner = col.sqla_col(db_engine_spec, col.column_name + '__') | ||
|
||
groupby_exprs.append(outer) | ||
select_exprs.append(outer) | ||
inner_groupby_exprs.append(inner) | ||
inner_select_exprs.append(inner) | ||
elif columns: | ||
for s in columns: | ||
select_exprs.append(cols[s].sqla_col) | ||
select_exprs.append(cols[s].sqla_col(db_engine_spec)) | ||
metrics_exprs = [] | ||
|
||
if granularity: | ||
|
@@ -612,13 +620,14 @@ def get_sqla_query( # sqla | |
self.main_dttm_col in self.dttm_cols and \ | ||
self.main_dttm_col != dttm_col.column_name: | ||
time_filters.append(cols[self.main_dttm_col]. | ||
get_time_filter(from_dttm, to_dttm)) | ||
time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) | ||
get_time_filter(db_engine_spec, from_dttm, to_dttm)) | ||
time_filters.append(dttm_col.get_time_filter(db_engine_spec, | ||
from_dttm, to_dttm)) | ||
|
||
select_exprs += metrics_exprs | ||
qry = sa.select(select_exprs) | ||
|
||
tbl = self.get_from_clause(template_processor, db_engine_spec) | ||
tbl = self.get_from_clause(template_processor) | ||
|
||
if not columns: | ||
qry = qry.group_by(*groupby_exprs) | ||
|
@@ -638,33 +647,35 @@ def get_sqla_query( # sqla | |
target_column_is_numeric=col_obj.is_num, | ||
is_list_target=is_list_target) | ||
if op in ('in', 'not in'): | ||
cond = col_obj.sqla_col.in_(eq) | ||
cond = col_obj.sqla_col(db_engine_spec).in_(eq) | ||
if '<NULL>' in eq: | ||
cond = or_(cond, col_obj.sqla_col == None) # noqa | ||
cond = or_(cond, col_obj.sqla_col(db_engine_spec) == None) # noqa | ||
if op == 'not in': | ||
cond = ~cond | ||
where_clause_and.append(cond) | ||
else: | ||
if col_obj.is_num: | ||
eq = utils.string_to_num(flt['val']) | ||
if op == '==': | ||
where_clause_and.append(col_obj.sqla_col == eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) == eq) | ||
elif op == '!=': | ||
where_clause_and.append(col_obj.sqla_col != eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) != eq) | ||
elif op == '>': | ||
where_clause_and.append(col_obj.sqla_col > eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) > eq) | ||
elif op == '<': | ||
where_clause_and.append(col_obj.sqla_col < eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) < eq) | ||
elif op == '>=': | ||
where_clause_and.append(col_obj.sqla_col >= eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) >= eq) | ||
elif op == '<=': | ||
where_clause_and.append(col_obj.sqla_col <= eq) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec) <= eq) | ||
elif op == 'LIKE': | ||
where_clause_and.append(col_obj.sqla_col.like(eq)) | ||
where_clause_and.append(col_obj.sqla_col(db_engine_spec).like(eq)) | ||
elif op == 'IS NULL': | ||
where_clause_and.append(col_obj.sqla_col == None) # noqa | ||
where_clause_and.append( | ||
col_obj.sqla_col(db_engine_spec) == None) # noqa | ||
elif op == 'IS NOT NULL': | ||
where_clause_and.append(col_obj.sqla_col != None) # noqa | ||
where_clause_and.append( | ||
col_obj.sqla_col(db_engine_spec) != None) # noqa | ||
if extras: | ||
where = extras.get('where') | ||
if where: | ||
|
@@ -686,7 +697,7 @@ def get_sqla_query( # sqla | |
for col, ascending in orderby: | ||
direction = asc if ascending else desc | ||
if utils.is_adhoc_metric(col): | ||
col = self.adhoc_metric_to_sa(col, cols) | ||
col = self.adhoc_metric_to_sa(col, cols, db_engine_spec) | ||
qry = qry.order_by(direction(col)) | ||
|
||
if row_limit: | ||
|
@@ -712,12 +723,13 @@ def get_sqla_query( # sqla | |
ob = inner_main_metric_expr | ||
if timeseries_limit_metric: | ||
if utils.is_adhoc_metric(timeseries_limit_metric): | ||
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols) | ||
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols, | ||
db_engine_spec) | ||
elif timeseries_limit_metric in metrics_dict: | ||
timeseries_limit_metric = metrics_dict.get( | ||
timeseries_limit_metric, | ||
) | ||
ob = timeseries_limit_metric.sqla_col | ||
ob = timeseries_limit_metric.sqla_col(db_engine_spec) | ||
else: | ||
raise Exception(_("Metric '{}' is not valid".format(m))) | ||
direction = desc if order_desc else asc | ||
|
@@ -750,19 +762,19 @@ def get_sqla_query( # sqla | |
} | ||
result = self.query(subquery_obj) | ||
dimensions = [c for c in result.df.columns if c not in metrics] | ||
top_groups = self._get_top_groups(result.df, dimensions) | ||
top_groups = self._get_top_groups(result.df, dimensions, db_engine_spec) | ||
qry = qry.where(top_groups) | ||
|
||
return qry.select_from(tbl) | ||
|
||
def _get_top_groups(self, df, dimensions): | ||
def _get_top_groups(self, df, dimensions, db_engine_spec): | ||
cols = {col.column_name: col for col in self.columns} | ||
groups = [] | ||
for unused, row in df.iterrows(): | ||
group = [] | ||
for dimension in dimensions: | ||
col_obj = cols.get(dimension) | ||
group.append(col_obj.sqla_col == row[dimension]) | ||
group.append(col_obj.sqla_col(db_engine_spec) == row[dimension]) | ||
groups.append(and_(*group)) | ||
|
||
return or_(*groups) | ||
|
@@ -816,6 +828,7 @@ def fetch_metadata(self): | |
.filter(or_(TableColumn.column_name == col.name | ||
for col in table.columns))) | ||
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} | ||
db_engine_spec = self.database.db_engine_spec | ||
|
||
for col in table.columns: | ||
try: | ||
|
@@ -848,6 +861,9 @@ def fetch_metadata(self): | |
)) | ||
if not self.main_dttm_col: | ||
self.main_dttm_col = any_date_col | ||
for metric in metrics: | ||
metric.metric_name = db_engine_spec.mutate_expression_label( | ||
metric.metric_name) | ||
self.add_missing_metrics(metrics) | ||
db.session.merge(self) | ||
db.session.commit() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
db_engine_spec
is available throughself.table.database.db_engine_spec
, so no need to receive it as a method argument.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops;
db_engine_spec
removed from all methods, as it is already available in all classes. I also renamed the method toget_sqla_col
from justsqla_col
, as is was previously a property.