Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force quoted column aliases for Oracle-like databases #5686

Merged
merged 29 commits into from
Sep 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d5242d0
Replace dataframe label override logic with table column override
villebro Aug 22, 2018
ead1b55
Add mutation to any_date_col
villebro Aug 22, 2018
4540d11
Linting
villebro Aug 22, 2018
027f794
Add mutation to oracle and redshift
villebro Aug 22, 2018
f8f851a
Fine tune how and which labels are mutated
villebro Aug 23, 2018
ce3617b
Implement alias quoting logic for oracle-like databases
villebro Aug 25, 2018
5f67d59
Fix and align column and metric sqla_col methods
villebro Aug 25, 2018
8b6e52d
Clean up typos and redundant logic
villebro Aug 25, 2018
62163fa
Move new attribute to old location
villebro Aug 25, 2018
e8dd7da
Linting
villebro Aug 25, 2018
6c6436f
Replace old sqla_col property references with function calls
villebro Aug 25, 2018
4b0e593
Remove redundant calls to mutate_column_label
villebro Aug 26, 2018
76959dd
Move duplicated logic to common function
villebro Aug 26, 2018
6c7fccc
Add db_engine_specs to all sqla_col calls
villebro Aug 26, 2018
c187104
Add missing mydb
villebro Aug 26, 2018
53b77a2
Add note about snowflake-sqlalchemy regression
villebro Aug 26, 2018
e792d15
Make db_engine_spec mandatory in sqla_col
villebro Aug 26, 2018
44eac83
Small refactoring and cleanup
villebro Aug 26, 2018
a161efc
Remove db_engine_spec from get_from_clause call
villebro Aug 26, 2018
901b635
Make db_engine_spec mandatory in adhoc_metric_to_sa
villebro Aug 26, 2018
7e7a2a8
Remove redundant mutate_expression_label call
villebro Aug 26, 2018
d60ed34
Add missing db_engine_specs to adhoc_metric_to_sa
villebro Aug 26, 2018
fc2c4f2
Rename arg label_name to label in get_column_label()
villebro Aug 27, 2018
3f2d874
Rename label function and add docstring
villebro Aug 27, 2018
95c99d3
Remove redundant db_engine_spec args
villebro Aug 27, 2018
04d3b5f
Rename col_label to label
villebro Aug 27, 2018
efb6313
Remove get_column_name wrapper and make direct calls to db_engine_spec
villebro Aug 27, 2018
3a6fe4b
Remove unneeded db_engine_specs
villebro Aug 27, 2018
0d18163
Rename sa_ vars to sqla_
villebro Aug 27, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation.

*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to
use version 1.1.0 or try a newer version.

See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.

Caching
Expand Down
106 changes: 57 additions & 49 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,21 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)]
export_parent = 'table'

@property
def sqla_col(self):
name = self.column_name
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.column_name)
if not self.expression:
col = column(self.column_name).label(name)
col = column(self.column_name).label(label)
else:
col = literal_column(self.expression).label(name)
col = literal_column(self.expression).label(label)
return col

@property
def datasource(self):
return self.table

def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time')
col = self.get_sqla_col(label='__time')
l = [] # noqa: E741
if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm)))
Expand Down Expand Up @@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )])
export_parent = 'table'

@property
def sqla_col(self):
name = self.metric_name
return literal_column(self.expression).label(name)
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
return literal_column(self.expression).label(label)

@property
def perm(self):
Expand Down Expand Up @@ -421,11 +421,10 @@ def values_for_column(self, column_name, limit=10000):
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
db_engine_spec = self.database.db_engine_spec

qry = (
select([target_col.sqla_col])
.select_from(self.get_from_clause(tp, db_engine_spec))
select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp))
.distinct()
)
if limit:
Expand Down Expand Up @@ -474,7 +473,7 @@ def get_sqla_table(self):
tbl.schema = self.schema
return tbl

def get_from_clause(self, template_processor=None, db_engine_spec=None):
def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
Expand All @@ -484,7 +483,7 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()

def adhoc_metric_to_sa(self, metric, cols):
def adhoc_metric_to_sqla(self, metric, cols):
"""
Turn an adhoc metric into a sqlalchemy column.

Expand All @@ -493,22 +492,25 @@ def adhoc_metric_to_sa(self, metric, cols):
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
expressionType = metric.get('expressionType')
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
expression_type = metric.get('expressionType')
db_engine_spec = self.database.db_engine_spec
label = db_engine_spec.make_label_compatible(metric.get('label'))

if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
sa_column = column(column_name)
sqla_column = column(column_name)
table_column = cols.get(column_name)

if table_column:
sa_column = table_column.sqla_col

sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
sa_metric = sa_metric.label(metric.get('label'))
return sa_metric
elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sa_metric = literal_column(metric.get('sqlExpression'))
sa_metric = sa_metric.label(metric.get('label'))
return sa_metric
sqla_column = table_column.get_sqla_col()

sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
sqla_metric = sqla_metric.label(label)
return sqla_metric
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sqla_metric = literal_column(metric.get('sqlExpression'))
sqla_metric = sqla_metric.label(label)
return sqla_metric
else:
return None

Expand Down Expand Up @@ -566,15 +568,16 @@ def get_sqla_query( # sqla
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).sqla_col)
metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column('COUNT(*)').label('ccount')
main_metric_expr = literal_column('COUNT(*)').label(
db_engine_spec.make_label_compatible('count'))

select_exprs = []
groupby_exprs = []
Expand All @@ -585,16 +588,16 @@ def get_sqla_query( # sqla
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
outer = col.sqla_col
inner = col.sqla_col.label(col.column_name + '__')
outer = col.get_sqla_col()
inner = col.get_sqla_col(col.column_name + '__')

groupby_exprs.append(outer)
select_exprs.append(outer)
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
elif columns:
for s in columns:
select_exprs.append(cols[s].sqla_col)
select_exprs.append(cols[s].get_sqla_col())
metrics_exprs = []

if granularity:
Expand All @@ -618,7 +621,7 @@ def get_sqla_query( # sqla
select_exprs += metrics_exprs
qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor, db_engine_spec)
tbl = self.get_from_clause(template_processor)

if not columns:
qry = qry.group_by(*groupby_exprs)
Expand All @@ -638,33 +641,34 @@ def get_sqla_query( # sqla
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target)
if op in ('in', 'not in'):
cond = col_obj.sqla_col.in_(eq)
cond = col_obj.get_sqla_col().in_(eq)
if '<NULL>' in eq:
cond = or_(cond, col_obj.sqla_col == None) # noqa
cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_num:
eq = utils.string_to_num(flt['val'])
if op == '==':
where_clause_and.append(col_obj.sqla_col == eq)
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == '!=':
where_clause_and.append(col_obj.sqla_col != eq)
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == '>':
where_clause_and.append(col_obj.sqla_col > eq)
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == '<':
where_clause_and.append(col_obj.sqla_col < eq)
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == '>=':
where_clause_and.append(col_obj.sqla_col >= eq)
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == '<=':
where_clause_and.append(col_obj.sqla_col <= eq)
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == 'LIKE':
where_clause_and.append(col_obj.sqla_col.like(eq))
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == 'IS NULL':
where_clause_and.append(col_obj.sqla_col == None) # noqa
where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == 'IS NOT NULL':
where_clause_and.append(col_obj.sqla_col != None) # noqa
where_clause_and.append(
col_obj.get_sqla_col() != None) # noqa
if extras:
where = extras.get('where')
if where:
Expand All @@ -686,7 +690,7 @@ def get_sqla_query( # sqla
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sa(col, cols)
col = self.adhoc_metric_to_sqla(col, cols)
qry = qry.order_by(direction(col))

if row_limit:
Expand All @@ -712,12 +716,12 @@ def get_sqla_query( # sqla
ob = inner_main_metric_expr
if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric):
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric,
)
ob = timeseries_limit_metric.sqla_col
ob = timeseries_limit_metric.get_sqla_col()
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc
Expand Down Expand Up @@ -762,7 +766,7 @@ def _get_top_groups(self, df, dimensions):
group = []
for dimension in dimensions:
col_obj = cols.get(dimension)
group.append(col_obj.sqla_col == row[dimension])
group.append(col_obj.get_sqla_col() == row[dimension])
groups.append(and_(*group))

return or_(*groups)
Expand Down Expand Up @@ -816,6 +820,7 @@ def fetch_metadata(self):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
db_engine_spec = self.database.db_engine_spec

for col in table.columns:
try:
Expand Down Expand Up @@ -848,6 +853,9 @@ def fetch_metadata(self):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
for metric in metrics:
metric.metric_name = db_engine_spec.mutate_expression_label(
metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
Expand Down
4 changes: 1 addition & 3 deletions superset/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
if cursor_description:
column_names = [col[0] for col in cursor_description]

case_sensitive = db_engine_spec.consistent_case_sensitivity
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)
self.column_names = dedup(column_names)

data = data or []
self.df = (
Expand Down
66 changes: 14 additions & 52 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlalchemy import select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from tableschema import Table
Expand Down Expand Up @@ -101,7 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
consistent_case_sensitivity = True # do results have same case as qry for col names?
force_column_alias_quotes = False
arraysize = None

@classmethod
Expand Down Expand Up @@ -376,55 +376,15 @@ def execute(cls, cursor, query, async=False):
cursor.execute(query)

@classmethod
def adjust_df_column_names(cls, df, fd):
"""Based of fields in form_data, return dataframe with new column names

Usually sqla engines return column names whose case matches that of the
original query. For example:
SELECT 1 as col1, 2 as COL2, 3 as Col_3
will usually result in the following df.columns:
['col1', 'COL2', 'Col_3'].
For these engines there is no need to adjust the dataframe column names
(default behavior). However, some engines (at least Snowflake, Oracle and
Redshift) return column names with different case than in the original query,
usually all uppercase. For these the column names need to be adjusted to
correspond to the case of the fields specified in the form data for Viz
to work properly. This adjustment can be done here.
def make_label_compatible(cls, label):
"""
if cls.consistent_case_sensitivity:
return df
else:
return cls.align_df_col_names_with_form_data(df, fd)

@staticmethod
def align_df_col_names_with_form_data(df, fd):
"""Helper function to rename columns that have changed case during query.

Returns a dataframe where column names have been adjusted to correspond with
column names in form data (case insensitive). Examples:
dataframe: 'col1', form_data: 'col1' -> no change
dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
Return a sqlalchemy.sql.elements.quoted_name if the engine requires
quoting of aliases to ensure that select query and query results
have same case.
"""

columns = set()
lowercase_mapping = {}

metrics = utils.get_metric_names(fd.get('metrics', []))
groupby = fd.get('groupby', [])
other_cols = [utils.DTTM_ALIAS]
for col in metrics + groupby + other_cols:
columns.add(col)
lowercase_mapping[col.lower()] = col

rename_cols = {}
for col in df.columns:
if col not in columns:
orig_col = lowercase_mapping.get(col.lower())
if orig_col:
rename_cols[col] = orig_col

return df.rename(index=str, columns=rename_cols)
if cls.force_column_alias_quotes is True:
return quoted_name(label, True)
return label

@staticmethod
def mutate_expression_label(label):
Expand Down Expand Up @@ -478,7 +438,8 @@ def get_table_names(cls, schema, inspector):

class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
consistent_case_sensitivity = False
force_column_alias_quotes = True

time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
Expand Down Expand Up @@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):

class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
consistent_case_sensitivity = False
force_column_alias_quotes = True


class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
consistent_case_sensitivity = False
force_column_alias_quotes = True

time_grain_functions = {
None: '{col}',
Expand All @@ -545,6 +506,7 @@ def convert_dttm(cls, target_type, dttm):
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
Expand Down
Loading