From 2f68010729453bdf29e31b7de29731d812e1668c Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 6 Sep 2023 11:54:25 -0700 Subject: [PATCH] fix: `is_select` (#25189) --- superset/sql_parse.py | 74 ++++++++++++++++------------- tests/unit_tests/sql_parse_tests.py | 7 +++ 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 2a283b81f0fc9..34fc354730536 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -244,46 +244,52 @@ def is_select(self) -> bool: # make sure we strip comments; prevents a bug with comments in the CTE parsed = sqlparse.parse(self.strip_comments()) - # Check if this is a CTE - if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE: - if sqloxide_parse is not None: - try: - if not self._check_cte_is_select( - sqloxide_parse(self.strip_comments(), dialect="ansi") - ): - return False - except ValueError: - # sqloxide was not able to parse the query, so let's continue with - # sqlparse - pass - inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or [] - # Check if the inner CTE is a not a SELECT - if any(token.ttype == DDL for token in inner_cte) or any( + for statement in parsed: + # Check if this is a CTE + if statement.is_group and statement[0].ttype == Keyword.CTE: + if sqloxide_parse is not None: + try: + if not self._check_cte_is_select( + sqloxide_parse(self.strip_comments(), dialect="ansi") + ): + return False + except ValueError: + # sqloxide was not able to parse the query, so let's continue with + # sqlparse + pass + inner_cte = self.get_inner_cte_expression(statement.tokens) or [] + # Check if the inner CTE is a not a SELECT + if any(token.ttype == DDL for token in inner_cte) or any( + token.ttype == DML and token.normalized != "SELECT" + for token in inner_cte + ): + return False + + if statement.get_type() == "SELECT": + continue + + if statement.get_type() != "UNKNOWN": + return False + + # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed, + # and no DDL is allowed + if any(token.ttype == DDL for token in statement) or any( token.ttype == DML and token.normalized != "SELECT" - for token in inner_cte + for token in statement ): return False - if parsed[0].get_type() == "SELECT": - return True - - if parsed[0].get_type() != "UNKNOWN": - return False - - # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed, - # and no DDL is allowed - if any(token.ttype == DDL for token in parsed[0]) or any( - token.ttype == DML and token.normalized != "SELECT" for token in parsed[0] - ): - return False + # return false on `EXPLAIN`, `SET`, `SHOW`, etc. + if statement[0].ttype == Keyword: + return False - # return false on `EXPLAIN`, `SET`, `SHOW`, etc. - if parsed[0][0].ttype == Keyword: - return False + if not any( + token.ttype == DML and token.normalized == "SELECT" + for token in statement + ): + return False - return any( - token.ttype == DML and token.normalized == "SELECT" for token in parsed[0] - ) + return True def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]: for token in tokens: diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 7d8839198c430..73074d3df64bf 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1616,3 +1616,10 @@ def test_extract_table_references(mocker: MockerFixture) -> None: Table(table="other_table", schema=None, catalog=None) } logger.warning.assert_not_called() + + +def test_is_select() -> None: + """ + Test `is_select`. + """ + assert not ParsedQuery("SELECT 1; DROP DATABASE superset").is_select()