From b9dac020f385d5212d87b52776a91aa2e54c52ea Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 14 May 2018 21:43:13 +0300 Subject: [PATCH] Force lowercase column names for Snowflake and Oracle (#4994) * Force lowercase column names for Snowflake and Oracle * Force lowercase column names for Snowflake and Oracle * Remove lowercasing of DB2 columns * Remove DB2 lowercasing * Fix test cases --- superset/db_engine_specs.py | 17 +++++++++++++++++ superset/sql_lab.py | 8 +++----- tests/sqllab_tests.py | 6 +++--- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index a718a0d62ca85..733588a77ad91 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -281,6 +281,15 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): """ return {} + @classmethod + def get_normalized_column_names(cls, cursor_description): + columns = cursor_description if cursor_description else [] + return [cls.normalize_column_name(col[0]) for col in columns] + + @staticmethod + def normalize_column_name(column_name): + return column_name + class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -350,6 +359,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): Grain('year', _('year'), "DATE_TRUNC('YEAR', {col})", 'P1Y'), ) + @staticmethod + def normalize_column_name(column_name): + return column_name.lower() + class VerticaEngineSpec(PostgresBaseEngineSpec): engine = 'vertica' @@ -379,6 +392,10 @@ def convert_dttm(cls, target_type, dttm): """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" ).format(dttm.isoformat()) + @staticmethod + def normalize_column_name(column_name): + return column_name.lower() + class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 856ea4880fb71..98f612a76dab9 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -97,10 +97,8 @@ def session_scope(nullpool): session.close() -def convert_results_to_df(cursor_description, data): +def convert_results_to_df(column_names, data): """Convert raw query results to a DataFrame.""" - column_names = ( - [col[0] for col in cursor_description] if cursor_description else []) column_names = dedup(column_names) # check whether the result set has any nested dict columns @@ -236,7 +234,7 @@ def handle_error(msg): return handle_error(db_engine_spec.extract_error_message(e)) logging.info('Fetching cursor description') - cursor_description = cursor.description + column_names = db_engine_spec.get_normalized_column_names(cursor.description) if conn is not None: conn.commit() @@ -245,7 +243,7 @@ def handle_error(msg): if query.status == utils.QueryStatus.STOPPED: return handle_error('The query has been stopped') - cdf = convert_results_to_df(cursor_description, data) + cdf = convert_results_to_df(column_names, data) query.rows = cdf.size query.progress = 100 diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 4626f53093559..49926f80def1b 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -203,7 +203,7 @@ def test_alias_duplicate(self): raise_on_error=True) def test_df_conversion_no_dict(self): - cols = [['string_col'], ['int_col'], ['float_col']] + cols = ['string_col', 'int_col', 'float_col'] data = [['a', 4, 4.0]] cdf = convert_results_to_df(cols, data) @@ -211,7 +211,7 @@ def test_df_conversion_no_dict(self): self.assertEquals(len(cols), len(cdf.columns)) def test_df_conversion_tuple(self): - cols = [['string_col'], ['int_col'], ['list_col'], ['float_col']] + cols = ['string_col', 'int_col', 'list_col', 'float_col'] data = [(u'Text', 111, [123], 1.0)] cdf = convert_results_to_df(cols, data) @@ -219,7 +219,7 @@ def test_df_conversion_tuple(self): self.assertEquals(len(cols), len(cdf.columns)) def test_df_conversion_dict(self): - cols = [['string_col'], ['dict_col'], ['int_col']] + cols = ['string_col', 'dict_col', 'int_col'] data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]] cdf = convert_results_to_df(cols, data)