From 7e819135de285886ee345b72d9e4da753142427b Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Sun, 13 May 2018 13:31:40 -0500 Subject: [PATCH] Added more tests --- superset/db_engine_specs.py | 8 +++ superset/models/core.py | 2 +- superset/sql_lab.py | 2 +- tests/db_engine_specs_test.py | 117 +++++++++++++++++++++++++--------- 4 files changed, 97 insertions(+), 32 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 957da21194fee..88c51eadbcd38 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -94,6 +94,7 @@ def extra_table_metadata(cls, database, table_name, schema_name): def apply_limit_to_sql(cls, sql, limit, database): """Alters the SQL statement to apply a LIMIT clause""" if cls.limit_method == LimitMethod.WRAP_SQL: + sql = sql.strip('\t\n ;') qry = ( select('*') .select_from( @@ -104,6 +105,13 @@ def apply_limit_to_sql(cls, sql, limit, database): return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql) + no_limit = re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) return '{no_limit} LIMIT {limit}'.format(**locals()) return sql diff --git a/superset/models/core.py b/superset/models/core.py index ed33461f23c5e..8448c7ba54e49 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -719,7 +719,7 @@ def select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols) - def wrap_sql_limit(self, sql, limit=1000): + def apply_limit_to_sql(self, sql, limit=1000): return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 881271a2281f9..75e88146b9eee 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -188,7 +188,7 @@ def handle_error(msg): executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True elif (query.limit and superset_query.is_select()): - executed_sql = database.wrap_sql_limit(executed_sql, query.limit) + executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True # Hook to allow environment-specific mutation (usually comments) to the SQL diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 0726a68fbff0e..71c93ee328f7f 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,6 +4,9 @@ from __future__ import print_function from __future__ import unicode_literals +import textwrap + +from superset import db from superset.db_engine_specs import ( HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) from superset.models.core import Database @@ -82,44 +85,98 @@ def test_job_2_launched_stage_2_stages_progress(self): """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + def sql_limit_regex( + self, sql, expected_sql, + engine_spec_class=MySQLEngineSpec, + limit=1000): + main = self.get_main_database(db.session) + limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) + self.assertEquals(expected_sql, limited) + def test_wrapped_query(self): - sql = 'SELECT * FROM a' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi(self): + self.sql_limit_regex( + 'SELECT * FROM a;', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi_tabs(self): + self.sql_limit_regex( + 'SELECT * FROM a \t \n ; \t \n ', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) def test_simple_limit_query(self): - sql = 'SELECT * FROM a' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_limit_query(self): - sql = 'SELECT * FROM a LIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a LIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_newline_query(self): - sql = 'SELECT * FROM a\nLIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a\nLIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_lcase_limit_query(self): - sql = 'SELECT * FROM a\tlimit 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a\tlimit 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_limit_query_with_limit_subquery(self): - sql = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999', + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000', + ) + + def test_limit_with_expr(self): + self.sql_limit_regex( + textwrap.dedent( + """\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT + 99990"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + ) + + def test_limit_expr_and_semicolon(self): + self.sql_limit_regex( + textwrap.dedent( + """\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990 ;"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + )