From b34f61bed8130b74bd0fbb0656306c7a6ee02048 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 14 May 2018 14:44:05 -0500 Subject: [PATCH] [sql lab] a better approach at limiting queries (#4947) * [sql lab] a better approach at limiting queries Currently there are two mechanisms that we use to enforce the row limiting constraints, depending on the database engine: 1. use dbapi's `cursor.fetchmany()` 2. wrap the SQL into a limiting subquery Method 1 isn't great as it can result in the database server storing larger than required result sets in memory expecting another fetch command while we know we don't need that. Method 2 has a positive side of working with all database engines, whether they use LIMIT, ROWNUM, TOP or whatever else since sqlalchemy does the work as specified for the dialect. On the downside though the query optimizer might not be able to optimize this as much as an approach that doesn't use a subquery. Since most modern DBs use the LIMIT syntax, this adds a regex approach to modify the query and force a LIMIT clause without using a subquery for the database that support this syntax and uses method 2 for all others. * Fixing build * Fix lint * Added more tests * Fix tests --- superset/db_engine_specs.py | 33 ++++++++++- superset/models/core.py | 15 +---- superset/sql_lab.py | 6 +- superset/sql_parse.py | 24 +++----- tests/celery_tests.py | 32 ---------- tests/db_engine_specs_test.py | 106 +++++++++++++++++++++++++++++++++- 6 files changed, 145 insertions(+), 71 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index eea06688e8a7a..0e189b1259e74 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,6 +35,7 @@ from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url from sqlalchemy.sql import text +from sqlalchemy.sql.expression import TextAsFrom import sqlparse import unicodecsv from werkzeug.utils import secure_filename @@ -55,6 +56,7 @@ class LimitMethod(object): """Enum the ways that limits can be applied""" FETCH_MANY = 'fetch_many' WRAP_SQL = 'wrap_sql' + FORCE_LIMIT = 'force_limit' class BaseEngineSpec(object): @@ -65,7 +67,7 @@ class BaseEngineSpec(object): cursor_execute_kwargs = {} time_grains = tuple() time_groupby_inline = False - limit_method = LimitMethod.FETCH_MANY + limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False inner_joins = True @@ -88,6 +90,30 @@ def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" return {} + @classmethod + def apply_limit_to_sql(cls, sql, limit, database): + """Alters the SQL statement to apply a LIMIT clause""" + if cls.limit_method == LimitMethod.WRAP_SQL: + sql = sql.strip('\t\n ;') + qry = ( + select('*') + .select_from( + TextAsFrom(text(sql), ['*']).alias('inner_qry'), + ) + .limit(limit) + ) + return database.compile_sqla_query(qry) + elif LimitMethod.FORCE_LIMIT: + no_limit = re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) + return '{no_limit} LIMIT {limit}'.format(**locals()) + return sql + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ @@ -346,7 +372,6 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'), @@ -374,6 +399,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -399,6 +425,7 @@ def normalize_column_name(column_name): class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), @@ -1123,6 +1150,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')" + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -1327,7 +1355,6 @@ def get_schema_names(cls, inspector): class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = 'druid' - limit_method = LimitMethod.FETCH_MANY inner_joins = False diff --git a/superset/models/core.py b/superset/models/core.py index 2ad20faca85c6..8448c7ba54e49 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -22,7 +22,7 @@ import sqlalchemy as sqla from sqlalchemy import ( Boolean, Column, create_engine, DateTime, ForeignKey, Integer, - MetaData, select, String, Table, Text, + MetaData, String, Table, Text, ) from sqlalchemy.engine import url from sqlalchemy.engine.url import make_url @@ -30,8 +30,6 @@ from sqlalchemy.orm.session import make_transient from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import text -from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy_utils import EncryptedType from superset import app, db, db_engine_specs, security_manager, utils @@ -721,15 +719,8 @@ def select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols) - def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from( - TextAsFrom(text(sql), ['*']) - .alias('inner_qry'), - ).limit(limit) - ) - return self.compile_sqla_query(qry) + def apply_limit_to_sql(self, sql, limit=1000): + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 98f612a76dab9..68e3f70feecb6 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -17,7 +17,6 @@ from sqlalchemy.pool import NullPool from superset import app, dataframe, db, results_backend, security_manager, utils -from superset.db_engine_specs import LimitMethod from superset.models.sql_lab import Query from superset.sql_parse import SupersetQuery from superset.utils import get_celery_app, QueryStatus @@ -186,9 +185,8 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select() and - db_engine_spec.limit_method == LimitMethod.WRAP_SQL): - executed_sql = database.wrap_sql_limit(executed_sql, query.limit) + elif (query.limit and superset_query.is_select()): + executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True # Hook to allow environment-specific mutation (usually comments) to the SQL diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 790371ae35706..ea1c9c38851c1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,13 +15,13 @@ PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} -# TODO: some sql_lab logic here. class SupersetQuery(object): def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() # TODO: multistatement support + logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: @@ -36,11 +36,7 @@ def is_select(self): return self._parsed[0].get_type() == 'SELECT' def stripped(self): - sql = self.sql - if sql: - while sql[-1] in (' ', ';', '\n', '\t'): - sql = sql[:-1] - return sql + return self.sql.strip(' \t\n;') @staticmethod def __precedes_table_name(token_value): @@ -65,13 +61,12 @@ def __is_result_operation(keyword): @staticmethod def __is_identifier(token): - return ( - isinstance(token, IdentifierList) or isinstance(token, Identifier)) + return isinstance(token, (IdentifierList, Identifier)) def __process_identifier(self, identifier): # exclude subselects if '(' not in '{}'.format(identifier): - self._table_names.add(SupersetQuery.__get_full_name(identifier)) + self._table_names.add(self.__get_full_name(identifier)) return # store aliases @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False): :param overwrite, boolean, table table_name will be dropped if true :return: string, create table as query """ - # TODO(bkyryliuk): enforce that all the columns have names. - # Presto requires it for the CTA operation. - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. - # TODO raise if multi-statement exec_sql = '' sql = self.stripped() if overwrite: @@ -117,7 +107,7 @@ def __extract_from_token(self, token): self.__extract_from_token(item) if item.ttype in Keyword: - if SupersetQuery.__precedes_table_name(item.value.upper()): + if self.__precedes_table_name(item.value.upper()): table_name_preceding_token = True continue @@ -125,7 +115,7 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if SupersetQuery.__is_result_operation(item.value): + if self.__is_result_operation(item.value): table_name_preceding_token = False continue # FROM clause is over @@ -136,5 +126,5 @@ def __extract_from_token(self, token): if isinstance(item, IdentifierList): for token in item.tokens: - if SupersetQuery.__is_identifier(token): + if self.__is_identifier(token): self.__process_identifier(token) diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 79b71e986e2aa..f6d1a2958fd1c 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -139,38 +139,6 @@ def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', self.logout() return json.loads(resp.data.decode('utf-8')) - def test_add_limit_to_the_query(self): - main_db = self.get_main_database(db.session) - - select_query = 'SELECT * FROM outer_space;' - updated_select_query = main_db.wrap_sql_limit(select_query, 100) - # Different DB engines have their own spacing while compiling - # the queries, that's why ' '.join(query.split()) is used. - # In addition some of the engines do not include OFFSET 0. - self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space;) AS inner_qry ' - 'LIMIT 100' in ' '.join(updated_select_query.split()), - ) - - select_query_no_semicolon = 'SELECT * FROM outer_space' - updated_select_query_no_semicolon = main_db.wrap_sql_limit( - select_query_no_semicolon, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space) AS inner_qry ' - 'LIMIT 100' in - ' '.join(updated_select_query_no_semicolon.split()), - ) - - multi_line_query = ( - "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" - ) - updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM planets WHERE ' - "Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in - ' '.join(updated_multi_line_query.split()), - ) - def test_run_sync_query_dont_exist(self): main_db = self.get_main_database(db.session) db_id = main_db.id diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 1a1282ad1a47f..c38e4f569023a 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,12 +4,15 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest +import textwrap -from superset.db_engine_specs import HiveEngineSpec +from superset.db_engine_specs import ( + HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) +from superset.models.core import Database +from .base_tests import SupersetTestCase -class DbEngineSpecsTestCase(unittest.TestCase): +class DbEngineSpecsTestCase(SupersetTestCase): def test_0_progress(self): log = """ 17/02/07 18:26:27 INFO log.PerfLogger: @@ -80,3 +83,100 @@ def test_job_2_launched_stage_2_stages_progress(self): 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + + def get_generic_database(self): + return Database(sqlalchemy_uri='mysql://localhost') + + def sql_limit_regex( + self, sql, expected_sql, + engine_spec_class=MySQLEngineSpec, + limit=1000): + main = self.get_generic_database() + limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) + self.assertEquals(expected_sql, limited) + + def test_wrapped_query(self): + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi(self): + self.sql_limit_regex( + 'SELECT * FROM a;', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi_tabs(self): + self.sql_limit_regex( + 'SELECT * FROM a \t \n ; \t \n ', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_simple_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a LIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_newline_query(self): + self.sql_limit_regex( + 'SELECT * FROM a\nLIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_lcase_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a\tlimit 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_limit_query_with_limit_subquery(self): + self.sql_limit_regex( + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999', + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000', + ) + + def test_limit_with_expr(self): + self.sql_limit_regex( + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT + 99990"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + ) + + def test_limit_expr_and_semicolon(self): + self.sql_limit_regex( + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990 ;"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + )