From be9084ae3cf7084bbc530cbf7e50c1ff075db8a7 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 7 May 2018 18:18:31 -0700 Subject: [PATCH] [sql lab] a better approach at limiting queries Currently there are two mechanisms that we use to enforce the row limiting constraints, depending on the database engine: 1. use dbapi's `cursor.fetchmany()` 2. wrap the SQL into a limiting subquery Method 1 isn't great as it can result in the database server storing larger than required result sets in memory expecting another fetch command while we know we don't need that. Method 2 has a positive side of working with all database engines, whether they use LIMIT, ROWNUM, TOP or whatever else since sqlalchemy does the work as specified for the dialect. On the downside though the query optimizer might not be able to optimize this as much as an approach that doesn't use a subquery. Since most modern DBs use the LIMIT syntax, this adds a regex approach to modify the query and force a LIMIT clause without using a subquery for the database that support this syntax and uses method 2 for all others. --- superset/db_engine_specs.py | 26 +++++++++++++++--- superset/models/core.py | 9 +------ superset/sql_lab.py | 3 +-- superset/sql_parse.py | 24 +++++------------ tests/db_engine_specs_test.py | 50 ++++++++++++++++++++++++++++++++--- 5 files changed, 78 insertions(+), 34 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index a718a0d62ca85..189fb713db36d 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,6 +35,7 @@ from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url from sqlalchemy.sql import text +from sqlalchemy.sql.expression import TextAsFrom import sqlparse import unicodecsv from werkzeug.utils import secure_filename @@ -55,6 +56,7 @@ class LimitMethod(object): """Enum the ways that limits can be applied""" FETCH_MANY = 'fetch_many' WRAP_SQL = 'wrap_sql' + FORCE_LIMIT = 'force_limit' class BaseEngineSpec(object): @@ -65,7 +67,7 @@ class BaseEngineSpec(object): cursor_execute_kwargs = {} time_grains = tuple() time_groupby_inline = False - limit_method = LimitMethod.FETCH_MANY + limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False inner_joins = True @@ -88,6 +90,23 @@ def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" return {} + @classmethod + def apply_limit_to_sql(cls, sql, limit, database): + """Alters the SQL statement to apply a LIMIT clause""" + if cls.limit_method == LimitMethod.WRAP_SQL: + qry = ( + select('*') + .select_from( + TextAsFrom(text(sql), ['*']).alias('inner_qry'), + ) + .limit(limit) + ) + return database.compile_sqla_query(qry) + elif LimitMethod.FORCE_LIMIT: + no_limit = re.sub(r"(?i)\s+LIMIT\s+\d+;?(\s|;)*$", '', sql) + return "{no_limit} LIMIT {limit}".format(**locals()) + return sql + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ @@ -337,7 +356,6 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'), @@ -361,6 +379,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -382,6 +401,7 @@ def convert_dttm(cls, target_type, dttm): class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), @@ -1106,6 +1126,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')" + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -1310,7 +1331,6 @@ def get_schema_names(cls, inspector): class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = 'druid' - limit_method = LimitMethod.FETCH_MANY inner_joins = False diff --git a/superset/models/core.py b/superset/models/core.py index 2ad20faca85c6..2187b0540bd6b 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -722,14 +722,7 @@ def select_star( indent=indent, latest_partition=latest_partition, cols=cols) def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from( - TextAsFrom(text(sql), ['*']) - .alias('inner_qry'), - ).limit(limit) - ) - return self.compile_sqla_query(qry) + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 856ea4880fb71..99ec957d80d1c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -188,8 +188,7 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select() and - db_engine_spec.limit_method == LimitMethod.WRAP_SQL): + elif (query.limit and superset_query.is_select()): executed_sql = database.wrap_sql_limit(executed_sql, query.limit) query.limit_used = True diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 790371ae35706..ea1c9c38851c1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,13 +15,13 @@ PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} -# TODO: some sql_lab logic here. class SupersetQuery(object): def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() # TODO: multistatement support + logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: @@ -36,11 +36,7 @@ def is_select(self): return self._parsed[0].get_type() == 'SELECT' def stripped(self): - sql = self.sql - if sql: - while sql[-1] in (' ', ';', '\n', '\t'): - sql = sql[:-1] - return sql + return self.sql.strip(' \t\n;') @staticmethod def __precedes_table_name(token_value): @@ -65,13 +61,12 @@ def __is_result_operation(keyword): @staticmethod def __is_identifier(token): - return ( - isinstance(token, IdentifierList) or isinstance(token, Identifier)) + return isinstance(token, (IdentifierList, Identifier)) def __process_identifier(self, identifier): # exclude subselects if '(' not in '{}'.format(identifier): - self._table_names.add(SupersetQuery.__get_full_name(identifier)) + self._table_names.add(self.__get_full_name(identifier)) return # store aliases @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False): :param overwrite, boolean, table table_name will be dropped if true :return: string, create table as query """ - # TODO(bkyryliuk): enforce that all the columns have names. - # Presto requires it for the CTA operation. - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. - # TODO raise if multi-statement exec_sql = '' sql = self.stripped() if overwrite: @@ -117,7 +107,7 @@ def __extract_from_token(self, token): self.__extract_from_token(item) if item.ttype in Keyword: - if SupersetQuery.__precedes_table_name(item.value.upper()): + if self.__precedes_table_name(item.value.upper()): table_name_preceding_token = True continue @@ -125,7 +115,7 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if SupersetQuery.__is_result_operation(item.value): + if self.__is_result_operation(item.value): table_name_preceding_token = False continue # FROM clause is over @@ -136,5 +126,5 @@ def __extract_from_token(self, token): if isinstance(item, IdentifierList): for token in item.tokens: - if SupersetQuery.__is_identifier(token): + if self.__is_identifier(token): self.__process_identifier(token) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 1a1282ad1a47f..71e6bc493bb8b 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,12 +4,12 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest +from superset.db_engine_specs import MssqlEngineSpec, HiveEngineSpec, MySQLEngineSpec +from superset.models.core import Database +from .base_tests import SupersetTestCase -from superset.db_engine_specs import HiveEngineSpec - -class DbEngineSpecsTestCase(unittest.TestCase): +class DbEngineSpecsTestCase(SupersetTestCase): def test_0_progress(self): log = """ 17/02/07 18:26:27 INFO log.PerfLogger: @@ -80,3 +80,45 @@ def test_job_2_launched_stage_2_stages_progress(self): 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + + def test_wrapped_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_simple_limit_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_limit_query(self): + sql = "SELECT * FROM a LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_newline_query(self): + sql = "SELECT * FROM a\nLIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_lcase_limit_query(self): + sql = "SELECT * FROM a\tlimit 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_limit_query_with_limit_subquery(self): + sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000" + self.assertEquals(expected, limited)