From 530c26fada7207b659b84ac38f95f597d2549e1a Mon Sep 17 00:00:00 2001 From: Timi Fasubaa Date: Wed, 16 May 2018 19:08:23 -0700 Subject: [PATCH 1/3] force limit only when there is no existing limit --- superset/sql_lab.py | 4 +++- superset/utils.py | 14 ++++++++++++++ superset/views/core.py | 2 +- tests/celery_tests.py | 2 -- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 68e3f70feecb6..52b08273d4cc2 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -171,6 +171,7 @@ def handle_error(msg): # Limit enforced only for retrieving the data, not for the CTA queries. superset_query = SupersetQuery(rendered_query) executed_sql = superset_query.stripped() + SQL_MAX_ROWS = int(app.config.get('SQL_MAX_ROW', None)) if not superset_query.is_select() and not database.allow_dml: return handle_error( 'Only `SELECT` statements are allowed against this database') @@ -185,7 +186,8 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select()): + elif (not query.limit and superset_query.is_select() and SQL_MAX_ROWS): + query.limit = SQL_MAX_ROWS executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True diff --git a/superset/utils.py b/superset/utils.py index 08ce0d2f385af..47626aa8e687e 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -882,3 +882,17 @@ def split_adhoc_filters_into_base_filters(fd): fd['having_filters'] = simple_having_filters fd['filters'] = simple_where_filters del fd['adhoc_filters'] + + +def get_limit_from_sql(sql): + sql = sql.lower() + limit = None + tokens = sql.split() + try: + if 'limit' in tokens: + limit_pos = tokens.index('limit') + 1 + limit = int(tokens[limit_pos]) + except Exception as e: + # fail quietly so we can get the more intelligible error from the database. + logging.error('Non-numeric limit added.\n{}'.format(e)) + return limit diff --git a/superset/views/core.py b/superset/views/core.py index 40d24b268abb4..84e305889bad4 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2394,7 +2394,7 @@ def sql_json(self): query = Query( database_id=int(database_id), - limit=int(app.config.get('SQL_MAX_ROW', None)), + limit=utils.get_limit_from_sql(sql), sql=sql, schema=schema, select_as_cta=request.form.get('select_as_cta') == 'true', diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f6d1a2958fd1c..c785702f515ec 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -196,13 +196,11 @@ def test_run_async_query(self): self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) - self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( 'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role ' "WHERE name='Admin'", query.executed_sql) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) - self.assertEqual(666, query.limit) self.assertEqual(False, query.limit_used) self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) From 767399cd1e2cf39b9c312b245709addb5573a23f Mon Sep 17 00:00:00 2001 From: Timi Fasubaa Date: Fri, 18 May 2018 17:57:21 -0700 Subject: [PATCH 2/3] reuse_regex_logic --- superset/db_engine_specs.py | 10 ++-------- superset/sql_lab.py | 5 +++-- superset/utils.py | 35 ++++++++++++++++++++++++----------- tests/celery_tests.py | 2 ++ 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 0e189b1259e74..5e706c50a4a7d 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,14 +104,8 @@ def apply_limit_to_sql(cls, sql, limit, database): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - no_limit = re.sub(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+\d+ # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """, '', sql) - return '{no_limit} LIMIT {limit}'.format(**locals()) + sql_without_limit = utils.get_query_without_limit(sql) + return '{sql_without_limit} LIMIT {limit}'.format(**locals()) return sql @staticmethod diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 52b08273d4cc2..7aa5d03efb917 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -171,7 +171,7 @@ def handle_error(msg): # Limit enforced only for retrieving the data, not for the CTA queries. superset_query = SupersetQuery(rendered_query) executed_sql = superset_query.stripped() - SQL_MAX_ROWS = int(app.config.get('SQL_MAX_ROW', None)) + SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW') if not superset_query.is_select() and not database.allow_dml: return handle_error( 'Only `SELECT` statements are allowed against this database') @@ -186,7 +186,8 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (not query.limit and superset_query.is_select() and SQL_MAX_ROWS): + elif (superset_query.is_select() and SQL_MAX_ROWS and + (not query.limit or query.limit > SQL_MAX_ROWS)): query.limit = SQL_MAX_ROWS executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True diff --git a/superset/utils.py b/superset/utils.py index 47626aa8e687e..09131f6d9650d 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -18,6 +18,7 @@ import json import logging import os +import re import signal import smtplib import sys @@ -884,15 +885,27 @@ def split_adhoc_filters_into_base_filters(fd): del fd['adhoc_filters'] +def get_query_without_limit(sql): + return re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) + + def get_limit_from_sql(sql): - sql = sql.lower() - limit = None - tokens = sql.split() - try: - if 'limit' in tokens: - limit_pos = tokens.index('limit') + 1 - limit = int(tokens[limit_pos]) - except Exception as e: - # fail quietly so we can get the more intelligible error from the database. - logging.error('Non-numeric limit added.\n{}'.format(e)) - return limit + # returns the limit of the quest or None if it has no limit. + + limit_pattern = re.compile(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+(\d+) # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """) + matches = limit_pattern.findall(sql) + + if matches: + return int(matches[0]) diff --git a/tests/celery_tests.py b/tests/celery_tests.py index c785702f515ec..f6d1a2958fd1c 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -196,11 +196,13 @@ def test_run_async_query(self): self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) + self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( 'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role ' "WHERE name='Admin'", query.executed_sql) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) + self.assertEqual(666, query.limit) self.assertEqual(False, query.limit_used) self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) From c61381ce96b267b87f1058674fa39ff6fc883113 Mon Sep 17 00:00:00 2001 From: Timi Fasubaa Date: Fri, 25 May 2018 15:45:11 -0700 Subject: [PATCH 3/3] add tests --- superset/db_engine_specs.py | 25 ++++++++++++++++++++++++- superset/sql_lab.py | 3 +-- superset/utils.py | 27 --------------------------- superset/views/core.py | 2 +- tests/celery_tests.py | 30 ++++++++++++++++++++++++++++-- tests/db_engine_specs_test.py | 13 +++++++++++++ 6 files changed, 67 insertions(+), 33 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 5e706c50a4a7d..cd5ab87226e54 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,10 +104,33 @@ def apply_limit_to_sql(cls, sql, limit, database): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - sql_without_limit = utils.get_query_without_limit(sql) + sql_without_limit = cls.get_query_without_limit(sql) return '{sql_without_limit} LIMIT {limit}'.format(**locals()) return sql + @classmethod + def get_limit_from_sql(cls, sql): + limit_pattern = re.compile(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+(\d+) # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """) + matches = limit_pattern.findall(sql) + if matches: + return int(matches[0][0]) + + @classmethod + def get_query_without_limit(cls, sql): + return re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 7aa5d03efb917..c9f07ae906c7b 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -186,11 +186,10 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (superset_query.is_select() and SQL_MAX_ROWS and + if (superset_query.is_select() and SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS)): query.limit = SQL_MAX_ROWS executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) - query.limit_used = True # Hook to allow environment-specific mutation (usually comments) to the SQL SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR') diff --git a/superset/utils.py b/superset/utils.py index 09131f6d9650d..08ce0d2f385af 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -18,7 +18,6 @@ import json import logging import os -import re import signal import smtplib import sys @@ -883,29 +882,3 @@ def split_adhoc_filters_into_base_filters(fd): fd['having_filters'] = simple_having_filters fd['filters'] = simple_where_filters del fd['adhoc_filters'] - - -def get_query_without_limit(sql): - return re.sub(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+\d+ # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """, '', sql) - - -def get_limit_from_sql(sql): - # returns the limit of the quest or None if it has no limit. - - limit_pattern = re.compile(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+(\d+) # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """) - matches = limit_pattern.findall(sql) - - if matches: - return int(matches[0]) diff --git a/superset/views/core.py b/superset/views/core.py index 84e305889bad4..b4a1689a9165b 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2394,7 +2394,7 @@ def sql_json(self): query = Query( database_id=int(database_id), - limit=utils.get_limit_from_sql(sql), + limit=mydb.db_engine_spec.get_limit_from_sql(sql), sql=sql, schema=schema, select_as_cta=request.form.get('select_as_cta') == 'true', diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f6d1a2958fd1c..39b7749ae88f7 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -196,10 +196,9 @@ def test_run_async_query(self): self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) - self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( 'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role ' - "WHERE name='Admin'", query.executed_sql) + "WHERE name='Admin' LIMIT 666", query.executed_sql) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) self.assertEqual(666, query.limit) @@ -207,6 +206,33 @@ def test_run_async_query(self): self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) + def test_run_async_query_with_lower_limit(self): + main_db = self.get_main_database(db.session) + eng = main_db.get_sqla_engine() + sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1" + result = self.run_sql( + main_db.id, sql_where, '5', async='true', tmp_table='tmp_async_2', + cta='true') + assert result['query']['state'] in ( + QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) + + time.sleep(1) + + query = self.get_query_by_id(result['query']['serverId']) + df = pd.read_sql_query(query.select_sql, con=eng) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records')) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertTrue('FROM tmp_async_2' in query.select_sql) + self.assertEqual( + 'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role ' + "WHERE name='Alpha' LIMIT 1", query.executed_sql) + self.assertEqual(sql_where, query.sql) + self.assertEqual(0, query.rows) + self.assertEqual(1, query.limit) + self.assertEqual(True, query.select_as_cta) + self.assertEqual(True, query.select_as_cta_used) + @staticmethod def de_unicode_dict(d): def str_if_basestring(o): diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index c38e4f569023a..bdce0b060d020 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -95,6 +95,19 @@ def sql_limit_regex( limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) self.assertEquals(expected_sql, limited) + def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec): + q0 = 'select * from table' + q1 = 'select * from mytable limit 10' + q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20' + q3 = 'select * from (select * from my_subquery limit 10);' + q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;' + + self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None) + self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10) + self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20) + self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None) + self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20) + def test_wrapped_query(self): self.sql_limit_regex( 'SELECT * FROM a',