diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index a718a0d62ca85..189fb713db36d 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,6 +35,7 @@ from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url from sqlalchemy.sql import text +from sqlalchemy.sql.expression import TextAsFrom import sqlparse import unicodecsv from werkzeug.utils import secure_filename @@ -55,6 +56,7 @@ class LimitMethod(object): """Enum the ways that limits can be applied""" FETCH_MANY = 'fetch_many' WRAP_SQL = 'wrap_sql' + FORCE_LIMIT = 'force_limit' class BaseEngineSpec(object): @@ -65,7 +67,7 @@ class BaseEngineSpec(object): cursor_execute_kwargs = {} time_grains = tuple() time_groupby_inline = False - limit_method = LimitMethod.FETCH_MANY + limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False inner_joins = True @@ -88,6 +90,23 @@ def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" return {} + @classmethod + def apply_limit_to_sql(cls, sql, limit, database): + """Alters the SQL statement to apply a LIMIT clause""" + if cls.limit_method == LimitMethod.WRAP_SQL: + qry = ( + select('*') + .select_from( + TextAsFrom(text(sql), ['*']).alias('inner_qry'), + ) + .limit(limit) + ) + return database.compile_sqla_query(qry) + elif LimitMethod.FORCE_LIMIT: + no_limit = re.sub(r"(?i)\s+LIMIT\s+\d+;?(\s|;)*$", '', sql) + return "{no_limit} LIMIT {limit}".format(**locals()) + return sql + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ @@ -337,7 +356,6 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'), @@ -361,6 +379,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -382,6 +401,7 @@ def convert_dttm(cls, target_type, dttm): class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), @@ -1106,6 +1126,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')" + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -1310,7 +1331,6 @@ def get_schema_names(cls, inspector): class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = 'druid' - limit_method = LimitMethod.FETCH_MANY inner_joins = False diff --git a/superset/models/core.py b/superset/models/core.py index 2ad20faca85c6..2187b0540bd6b 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -722,14 +722,7 @@ def select_star( indent=indent, latest_partition=latest_partition, cols=cols) def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from( - TextAsFrom(text(sql), ['*']) - .alias('inner_qry'), - ).limit(limit) - ) - return self.compile_sqla_query(qry) + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 856ea4880fb71..99ec957d80d1c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -188,8 +188,7 @@ def handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select() and - db_engine_spec.limit_method == LimitMethod.WRAP_SQL): + elif (query.limit and superset_query.is_select()): executed_sql = database.wrap_sql_limit(executed_sql, query.limit) query.limit_used = True diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 790371ae35706..ea1c9c38851c1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,13 +15,13 @@ PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} -# TODO: some sql_lab logic here. class SupersetQuery(object): def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() # TODO: multistatement support + logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: @@ -36,11 +36,7 @@ def is_select(self): return self._parsed[0].get_type() == 'SELECT' def stripped(self): - sql = self.sql - if sql: - while sql[-1] in (' ', ';', '\n', '\t'): - sql = sql[:-1] - return sql + return self.sql.strip(' \t\n;') @staticmethod def __precedes_table_name(token_value): @@ -65,13 +61,12 @@ def __is_result_operation(keyword): @staticmethod def __is_identifier(token): - return ( - isinstance(token, IdentifierList) or isinstance(token, Identifier)) + return isinstance(token, (IdentifierList, Identifier)) def __process_identifier(self, identifier): # exclude subselects if '(' not in '{}'.format(identifier): - self._table_names.add(SupersetQuery.__get_full_name(identifier)) + self._table_names.add(self.__get_full_name(identifier)) return # store aliases @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False): :param overwrite, boolean, table table_name will be dropped if true :return: string, create table as query """ - # TODO(bkyryliuk): enforce that all the columns have names. - # Presto requires it for the CTA operation. - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. - # TODO raise if multi-statement exec_sql = '' sql = self.stripped() if overwrite: @@ -117,7 +107,7 @@ def __extract_from_token(self, token): self.__extract_from_token(item) if item.ttype in Keyword: - if SupersetQuery.__precedes_table_name(item.value.upper()): + if self.__precedes_table_name(item.value.upper()): table_name_preceding_token = True continue @@ -125,7 +115,7 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if SupersetQuery.__is_result_operation(item.value): + if self.__is_result_operation(item.value): table_name_preceding_token = False continue # FROM clause is over @@ -136,5 +126,5 @@ def __extract_from_token(self, token): if isinstance(item, IdentifierList): for token in item.tokens: - if SupersetQuery.__is_identifier(token): + if self.__is_identifier(token): self.__process_identifier(token) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 1a1282ad1a47f..71e6bc493bb8b 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,12 +4,12 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest +from superset.db_engine_specs import MssqlEngineSpec, HiveEngineSpec, MySQLEngineSpec +from superset.models.core import Database +from .base_tests import SupersetTestCase -from superset.db_engine_specs import HiveEngineSpec - -class DbEngineSpecsTestCase(unittest.TestCase): +class DbEngineSpecsTestCase(SupersetTestCase): def test_0_progress(self): log = """ 17/02/07 18:26:27 INFO log.PerfLogger: @@ -80,3 +80,45 @@ def test_job_2_launched_stage_2_stages_progress(self): 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + + def test_wrapped_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_simple_limit_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_limit_query(self): + sql = "SELECT * FROM a LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_newline_query(self): + sql = "SELECT * FROM a\nLIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_lcase_limit_query(self): + sql = "SELECT * FROM a\tlimit 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_limit_query_with_limit_subquery(self): + sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000" + self.assertEquals(expected, limited)