diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index a718a0d62ca85..b91d93f393aca 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,6 +35,7 @@
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.sql import text
+from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
 import unicodecsv
 from werkzeug.utils import secure_filename
@@ -55,6 +56,7 @@ class LimitMethod(object):
     """Enum the ways that limits can be applied"""
     FETCH_MANY = 'fetch_many'
     WRAP_SQL = 'wrap_sql'
+    FORCE_LIMIT = 'force_limit'
 
 
 class BaseEngineSpec(object):
@@ -65,7 +67,7 @@ class BaseEngineSpec(object):
     cursor_execute_kwargs = {}
     time_grains = tuple()
     time_groupby_inline = False
-    limit_method = LimitMethod.FETCH_MANY
+    limit_method = LimitMethod.FORCE_LIMIT
     time_secondary_columns = False
     inner_joins = True
 
@@ -88,6 +90,40 @@ def extra_table_metadata(cls, database, table_name, schema_name):
         """Returns engine-specific table metadata"""
         return {}
 
+    @classmethod
+    def apply_limit_to_sql(cls, sql, limit, database):
+        """Alters the SQL statement to apply a LIMIT clause
+
+        :param sql: the SQL statement to limit, as a string
+        :param limit: maximum number of rows, interpolated as-is
+        :param database: object used to compile the wrapping query
+            (only ``compile_sqla_query`` is called, for WRAP_SQL)
+        :return: the altered statement, or ``sql`` untouched when
+            ``limit_method`` is neither WRAP_SQL nor FORCE_LIMIT
+        """
+        if cls.limit_method == LimitMethod.WRAP_SQL:
+            sql = sql.strip('\t\n ;')
+            qry = (
+                select('*')
+                .select_from(
+                    TextAsFrom(text(sql), ['*']).alias('inner_qry'),
+                )
+                .limit(limit)
+            )
+            return database.compile_sqla_query(qry)
+        elif cls.limit_method == LimitMethod.FORCE_LIMIT:
+            # Inline flags must open the pattern: newer versions of ``re``
+            # reject global flags that are not at the very start.
+            no_limit = re.sub(r"""(?ix)  # case insensitive, verbose
+                (
+                    \s+LIMIT\s+\d+  # existing trailing LIMIT $ROWS
+                )?                  # ...which may be absent
+                [\s;]*              # trailing whitespace and semicolons
+                $
+                """, '', sql, count=1)
+            return '{no_limit} LIMIT {limit}'.format(**locals())
+        return sql
+
     @staticmethod
     def csv_to_df(**kwargs):
         kwargs['filepath_or_buffer'] = \
@@ -337,7 +373,6 @@ def get_table_names(cls, schema, inspector):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
-
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
         Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'),
@@ -361,6 +396,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec):
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
+    limit_method = LimitMethod.WRAP_SQL
 
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
@@ -382,6 +418,7 @@ def convert_dttm(cls, target_type, dttm):
 
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
+    limit_method = LimitMethod.WRAP_SQL
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
         Grain('second', _('second'),
@@ -1106,6 +1143,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
 class MssqlEngineSpec(BaseEngineSpec):
     engine = 'mssql'
     epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
+    limit_method = LimitMethod.WRAP_SQL
 
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
@@ -1310,7 +1348,6 @@ def get_schema_names(cls, inspector):
 class DruidEngineSpec(BaseEngineSpec):
     """Engine spec for Druid.io"""
     engine = 'druid'
-    limit_method = LimitMethod.FETCH_MANY
     inner_joins = False
 
 
diff --git a/superset/models/core.py b/superset/models/core.py
index 2ad20faca85c6..8448c7ba54e49 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -22,7 +22,7 @@
 import sqlalchemy as sqla
 from sqlalchemy import (
     Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
-    MetaData, select, String, Table, Text,
+    MetaData, String, Table, Text,
 )
 from sqlalchemy.engine import url
 from sqlalchemy.engine.url import make_url
@@ -30,8
+30,6 @@ from sqlalchemy.orm.session import make_transient from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import text -from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy_utils import EncryptedType from superset import app, db, db_engine_specs, security_manager, utils @@ -721,15 +719,8 @@ def select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols) - def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from( - TextAsFrom(text(sql), ['*']) - .alias('inner_qry'), - ).limit(limit) - ) - return self.compile_sqla_query(qry) + def apply_limit_to_sql(self, sql, limit=1000): + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 856ea4880fb71..75e88146b9eee 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -17,7 +17,6 @@ from sqlalchemy.pool import NullPool from superset import app, dataframe, db, results_backend, security_manager, utils -from superset.db_engine_specs import LimitMethod from superset.models.sql_lab import Query from superset.sql_parse import SupersetQuery from superset.utils import get_celery_app, QueryStatus @@ -188,9 +187,8 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select() and - db_engine_spec.limit_method == LimitMethod.WRAP_SQL): - executed_sql = database.wrap_sql_limit(executed_sql, query.limit) + elif (query.limit and superset_query.is_select()): + executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True # Hook to allow environment-specific mutation (usually comments) to the SQL diff --git a/superset/sql_parse.py 
b/superset/sql_parse.py index 790371ae35706..ea1c9c38851c1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,13 +15,13 @@ PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} -# TODO: some sql_lab logic here. class SupersetQuery(object): def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() # TODO: multistatement support + logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: @@ -36,11 +36,7 @@ def is_select(self): return self._parsed[0].get_type() == 'SELECT' def stripped(self): - sql = self.sql - if sql: - while sql[-1] in (' ', ';', '\n', '\t'): - sql = sql[:-1] - return sql + return self.sql.strip(' \t\n;') @staticmethod def __precedes_table_name(token_value): @@ -65,13 +61,12 @@ def __is_result_operation(keyword): @staticmethod def __is_identifier(token): - return ( - isinstance(token, IdentifierList) or isinstance(token, Identifier)) + return isinstance(token, (IdentifierList, Identifier)) def __process_identifier(self, identifier): # exclude subselects if '(' not in '{}'.format(identifier): - self._table_names.add(SupersetQuery.__get_full_name(identifier)) + self._table_names.add(self.__get_full_name(identifier)) return # store aliases @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False): :param overwrite, boolean, table table_name will be dropped if true :return: string, create table as query """ - # TODO(bkyryliuk): enforce that all the columns have names. - # Presto requires it for the CTA operation. - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. 
- # TODO raise if multi-statement exec_sql = '' sql = self.stripped() if overwrite: @@ -117,7 +107,7 @@ def __extract_from_token(self, token): self.__extract_from_token(item) if item.ttype in Keyword: - if SupersetQuery.__precedes_table_name(item.value.upper()): + if self.__precedes_table_name(item.value.upper()): table_name_preceding_token = True continue @@ -125,7 +115,7 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if SupersetQuery.__is_result_operation(item.value): + if self.__is_result_operation(item.value): table_name_preceding_token = False continue # FROM clause is over @@ -136,5 +126,5 @@ def __extract_from_token(self, token): if isinstance(item, IdentifierList): for token in item.tokens: - if SupersetQuery.__is_identifier(token): + if self.__is_identifier(token): self.__process_identifier(token) diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 79b71e986e2aa..f6d1a2958fd1c 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -139,38 +139,6 @@ def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', self.logout() return json.loads(resp.data.decode('utf-8')) - def test_add_limit_to_the_query(self): - main_db = self.get_main_database(db.session) - - select_query = 'SELECT * FROM outer_space;' - updated_select_query = main_db.wrap_sql_limit(select_query, 100) - # Different DB engines have their own spacing while compiling - # the queries, that's why ' '.join(query.split()) is used. - # In addition some of the engines do not include OFFSET 0. 
- self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space;) AS inner_qry ' - 'LIMIT 100' in ' '.join(updated_select_query.split()), - ) - - select_query_no_semicolon = 'SELECT * FROM outer_space' - updated_select_query_no_semicolon = main_db.wrap_sql_limit( - select_query_no_semicolon, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space) AS inner_qry ' - 'LIMIT 100' in - ' '.join(updated_select_query_no_semicolon.split()), - ) - - multi_line_query = ( - "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" - ) - updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM planets WHERE ' - "Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in - ' '.join(updated_multi_line_query.split()), - ) - def test_run_sync_query_dont_exist(self): main_db = self.get_main_database(db.session) db_id = main_db.id diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 1a1282ad1a47f..c38e4f569023a 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,12 +4,15 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest +import textwrap -from superset.db_engine_specs import HiveEngineSpec +from superset.db_engine_specs import ( + HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) +from superset.models.core import Database +from .base_tests import SupersetTestCase -class DbEngineSpecsTestCase(unittest.TestCase): +class DbEngineSpecsTestCase(SupersetTestCase): def test_0_progress(self): log = """ 17/02/07 18:26:27 INFO log.PerfLogger: @@ -80,3 +83,100 @@ def test_job_2_launched_stage_2_stages_progress(self): 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + + def get_generic_database(self): + return Database(sqlalchemy_uri='mysql://localhost') + + def sql_limit_regex( + 
self, sql, expected_sql, + engine_spec_class=MySQLEngineSpec, + limit=1000): + main = self.get_generic_database() + limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) + self.assertEquals(expected_sql, limited) + + def test_wrapped_query(self): + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi(self): + self.sql_limit_regex( + 'SELECT * FROM a;', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi_tabs(self): + self.sql_limit_regex( + 'SELECT * FROM a \t \n ; \t \n ', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_simple_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a LIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_newline_query(self): + self.sql_limit_regex( + 'SELECT * FROM a\nLIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_modify_lcase_limit_query(self): + self.sql_limit_regex( + 'SELECT * FROM a\tlimit 9999', + 'SELECT * FROM a LIMIT 1000', + ) + + def test_limit_query_with_limit_subquery(self): + self.sql_limit_regex( + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999', + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000', + ) + + def test_limit_with_expr(self): + self.sql_limit_regex( + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT + 99990"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + ) + + def test_limit_expr_and_semicolon(self): + self.sql_limit_regex( + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990 ;"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + )