Skip to content

Commit

Permalink
use sqlparse
Browse files Browse the repository at this point in the history
  • Loading branch information
timifasubaa committed Jul 13, 2018
1 parent de31886 commit dca0bd0
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 77 deletions.
41 changes: 10 additions & 31 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tableschema import Table
from werkzeug.utils import secure_filename

from superset import app, cache_util, conf, db, utils
from superset import app, cache_util, conf, db, sql_parse, utils
from superset.exceptions import SupersetTemplateException
from superset.utils import QueryStatus

Expand Down Expand Up @@ -110,40 +110,19 @@ def apply_limit_to_sql(cls, sql, limit, database):
)
return database.compile_sqla_query(qry)
elif LimitMethod.FORCE_LIMIT:
sql_before_limit, sql_after_limit = cls.get_query_without_limit(sql)
return '{sql_before_limit} LIMIT {limit}{sql_after_limit}'.format(**locals())
parsed_query = sql_parse.SupersetQuery(sql)
sql = parsed_query.get_query_with_new_limit(limit)
return sql

@classmethod
def get_limit_from_sql(cls, sql):
limit_pattern = re.compile(r"""
(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+(\d+) # LIMIT $ROWS
.*$ # everything else
""")
matches = limit_pattern.findall(sql)
if matches:
return matches[0]

@classmethod
def get_query_without_limit(cls, sql):
before_limit = re.sub(r"""
(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+(\d+) # LIMIT $ROWS
(.*$)
""", '', sql)

after_limit_pattern = re.compile(r"""
(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+\d+ # LIMIT $ROWS
(.*$)
""")
after_limit = after_limit_pattern.findall(sql)
after_limit = after_limit[0] if after_limit else ''
return before_limit, after_limit
parsed_query = sql_parse.SupersetQuery(sql)
return parsed_query.limit

@classmethod
def get_query_with_new_limit(cls, sql, limit):
parsed_query = sql_parse.SupersetQuery(sql)
return parsed_query.get_query_with_new_limit(limit)

@staticmethod
def csv_to_df(**kwargs):
Expand Down
44 changes: 44 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@ def __init__(self, sql_statement):
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
self._limit = None
# TODO: multistatement support

logging.info('Parsing with sqlparse statement {}'.format(self.sql))
self._parsed = sqlparse.parse(self.sql)
for statement in self._parsed:
self.__extract_from_token(statement)
self._limit = self._extract_limit_from_query(statement)
self._table_names = self._table_names - self._alias_names

@property
def tables(self):
return self._table_names

@property
def limit(self):
return self._limit

def is_select(self):
return self._parsed[0].get_type() == 'SELECT'

Expand Down Expand Up @@ -128,3 +134,41 @@ def __extract_from_token(self, token):
for token in item.tokens:
if self.__is_identifier(token):
self.__process_identifier(token)

def _get_limit_from_token(self, token):
if token.ttype == sqlparse.tokens.Literal.Number.Integer:
return int(token.value)
elif token.is_group:
return int(token.get_token_at_offset(1).value)

def _extract_limit_from_query(self, statement):
limit_token = None
for pos, item in enumerate(statement.tokens):
if item.ttype in Keyword and item.value.lower() == 'limit':
limit_token = statement.tokens[pos + 2]
return self._get_limit_from_token(limit_token)

def get_query_with_new_limit(self, new_limit):
"""returns the query with the specified limit"""
"""does not change the underlying query"""
if not self._limit:
return self.sql + ' LIMIT ' + str(new_limit)
limit_pos = None
tokens = self._parsed[0].tokens
# Add all items to before_str until there is a limit
for pos, item in enumerate(tokens):
if item.ttype in Keyword and item.value.lower() == 'limit':
limit_pos = pos
break
limit = tokens[limit_pos + 2]
if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
tokens[limit_pos + 2].value = new_limit
elif limit.is_group:
tokens[limit_pos + 2].value = (
'{}, {}'.format(next(limit.get_identifiers()), new_limit)
)

str_res = ''
for i in tokens:
str_res += str(i.value)
return str_res
91 changes: 45 additions & 46 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import textwrap

from superset.db_engine_specs import (
BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, PrestoEngineSpec,
Expand Down Expand Up @@ -143,18 +141,6 @@ def test_modify_limit_query(self):
'SELECT * FROM a LIMIT 1000',
)

def test_modify_newline_query(self):
self.sql_limit_regex(
'SELECT * FROM a\nLIMIT 9999',
'SELECT * FROM a LIMIT 1000',
)

def test_modify_lcase_limit_query(self):
self.sql_limit_regex(
'SELECT * FROM a\tlimit 9999',
'SELECT * FROM a LIMIT 1000',
)

def test_limit_query_with_limit_subquery(self):
self.sql_limit_regex(
'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
Expand All @@ -163,37 +149,38 @@ def test_limit_query_with_limit_subquery(self):

def test_limit_with_expr(self):
self.sql_limit_regex(
textwrap.dedent("""\
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT
99990"""),
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000"""),
table
LIMIT 1000""",
)

def test_limit_expr_and_semicolon(self):
self.sql_limit_regex(
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;"""),
textwrap.dedent("""\
LIMIT 99990 ;""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000"""),
table
LIMIT 1000 ;""",
)

def test_get_datatype(self):
Expand All @@ -204,36 +191,48 @@ def test_get_datatype(self):

def test_limit_with_implicit_offset(self):
self.sql_limit_regex(
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT
99990, 999999"""),
textwrap.dedent("""\
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000, 999999"""),
LIMIT 99990, 999999""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
)

def test_limit_with_explicit_offset(self):
self.sql_limit_regex(
textwrap.dedent("""\
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT
99990 OFFSET 999999"""),
textwrap.dedent("""\
SELECT
'LIMIT 777' AS a
, b
FROM
table LIMIT 1000 OFFSET 999999"""),
LIMIT 99990
OFFSET 999999""",
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
)

def test_limit_with_non_token_limit(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777'""",
"""
SELECT
'LIMIT 777' LIMIT 1000""",
)

0 comments on commit dca0bd0

Please sign in to comment.