fix(sqlglot): Address regressions introduced in #26476 #27217

Merged
merged 1 commit into from
Feb 23, 2024
17 changes: 11 additions & 6 deletions superset/sql_parse.py
@@ -28,7 +28,7 @@
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
- from sqlglot.errors import ParseError
+ from sqlglot.errors import SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
@@ -287,7 +287,7 @@ def _extract_tables_from_sql(self) -> set[Table]:
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
- except ParseError:
+ except SqlglotError:
Member Author

Parsing (rightly or wrongly) can throw either a ParseError or a TokenError. Both appear to derive from the SqlglotError type, so it seemed prudent to use a broader except.

logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()

@@ -319,12 +319,17 @@ def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statement in order to extract tables.
- literal = statement.find(exp.Literal)
- if not literal:
+ if not (literal := statement.find(exp.Literal)):
return set()

- pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
- sources = pseudo_query.find_all(exp.Table)
+ try:
Member Author

This issue surfaced when trying to parse an ADD JAR ... statement, which sqlglot believes contains literal expressions. The TL;DR is that parse_one can throw, and we should catch these errors (as we do when using the sqlglot.parse method).

i have no idea what your use case is, but if you want more lenient sql parsing, you can try error_level=IGNORE

Member Author

Ha! I was wondering whether I should have raised this with you.

always free to chat. you can hit me up on slack (i'm in your slack or you can come to mine)

Contributor

Aside, but I had a similar problem with the old sqlparse solution with giving CTEs names which are keywords in other engines; I'm curious if swapping to sqlglot has fixed that actually, since it's more engine-aware. Specifically I made the mistake of using the name ref for the CTE in a trino query, and REF is apparently a keyword in some engines (not trino as far as I'm aware). I'll give it a go with the new sqlglot approach sometime.

+ pseudo_query = parse_one(
+ f"SELECT {literal.this}",
+ dialect=self._dialect,
+ )
+ sources = pseudo_query.find_all(exp.Table)
+ except SqlglotError:
+ return set()
else:
sources = [
source
10 changes: 6 additions & 4 deletions tests/unit_tests/sql_parse_tests.py
@@ -271,6 +271,7 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}
+ assert extract_tables('SELECT * FROM "tbname') == set()
Member Author

The ill-formed SQL statement throws a TokenError as opposed to a ParseError in sqlglot.

Member Author

@tobymao ideally, should ill-formed SQL throw a TokenError or a ParseError?

it depends on if it's a token error or a parser error!

when you have unbalanced quotes, it's basically impossible to finish tokenization, so that's why we throw a token error.

most everything else is a parser error
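
The distinction described in this thread can be illustrated with a toy two-stage pipeline. This is a minimal sketch, not sqlglot's implementation: `SqlError`, `TokenError`, and `ParseError` below are local stand-ins that only mirror the hierarchy discussed here, and `tokenize`, `parse`, and `safe_parse` are hypothetical helpers with a deliberately trivial grammar.

```python
from typing import Optional


class SqlError(Exception):
    """Stand-in base class (plays the role of a shared library error type)."""


class TokenError(SqlError):
    """Raised when the input cannot even be split into tokens."""


class ParseError(SqlError):
    """Raised when the tokens do not form a valid statement."""


def tokenize(sql: str) -> list[str]:
    # An unterminated double quote leaves no way to tell where the
    # identifier ends, so tokenization itself cannot finish.
    if sql.count('"') % 2 != 0:
        raise TokenError("unbalanced quote")
    return sql.split()


def parse(sql: str) -> list[str]:
    tokens = tokenize(sql)
    # Toy grammar check: parentheses must balance.
    if sql.count("(") != sql.count(")"):
        raise ParseError("unbalanced parenthesis")
    return tokens


def safe_parse(sql: str) -> Optional[list[str]]:
    # Catching the shared base class covers failures from both stages.
    try:
        return parse(sql)
    except SqlError:
        return None
```

With these stand-ins, an unbalanced quote fails in the first stage and an unbalanced parenthesis in the second, yet a single `except SqlError` handles both, which is the reasoning behind widening the `except` clause in this PR.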



def test_extract_tables_show_tables_from() -> None:
@@ -558,6 +559,10 @@ def test_extract_tables_multistatement() -> None:
Table("t1"),
Table("t2"),
}
+ assert extract_tables(
Member Author

This example throws because sqlglot thinks the ADD JAR statement contains a valid literal.

+ "ADD JAR file:///hive.jar; SELECT * FROM t1;",
+ engine="hive",
+ ) == {Table("t1")}
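
The behavior this test expects, where one unparseable statement does not discard the tables found in the others, can be sketched without sqlglot. Everything here is hypothetical and much simpler than the real code: `tables_in_statement` is a toy regex-based extractor, `StatementError` stands in for the library's error type, and `extract_tables_sketch` is not the function under test.

```python
import re


class StatementError(Exception):
    """Hypothetical stand-in for a parsing library's base error."""


def tables_in_statement(statement: str) -> set[str]:
    # Toy extractor: only understands `SELECT ... FROM <name>`.
    if not statement.upper().startswith("SELECT"):
        raise StatementError(f"cannot parse: {statement!r}")
    return set(re.findall(r"\bFROM\s+(\w+)", statement, flags=re.IGNORECASE))


def extract_tables_sketch(script: str) -> set[str]:
    tables: set[str] = set()
    # Split on semicolons and drop empty fragments.
    for statement in filter(None, (s.strip() for s in script.split(";"))):
        try:
            tables |= tables_in_statement(statement)
        except StatementError:
            # Skip the statement that failed; keep results from the rest.
            continue
    return tables
```

Handling the error per statement, rather than around the whole script, is what lets the `SELECT` after the `ADD JAR` statement still contribute `t1`.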


def test_extract_tables_complex() -> None:
@@ -1815,10 +1820,7 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
# test falling back to sqlparse
logger = mocker.patch("superset.sql_parse.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
- assert extract_table_references(
- sql,
- "trino",
- ) == {
+ assert extract_table_references(sql, "trino") == {
Member Author

This was picked up by Black as part of the pre-commit check.

Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}