From 3cd733b81776456f70b470a0b6aa3ebf985d3f4f Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Sun, 29 Dec 2024 10:59:45 +0800 Subject: [PATCH 01/16] add sqlglot parsing --- marimo/_ast/sql_visitor.py | 131 ++++++--------------------- marimo/_dependencies/dependencies.py | 1 + pyproject.toml | 20 ++-- tests/_ast/test_sql_visitor.py | 30 ++++-- 4 files changed, 67 insertions(+), 115 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 9ad3a62e3b8..61bfacc6790 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -301,8 +301,6 @@ def find_sql_defs(sql_statement: str) -> SQLDefs: ) -# TODO(akshayka): there are other kinds of refs to find; this should be -# find_sql_refs def find_sql_refs( sql_statement: str, ) -> list[str]: @@ -315,109 +313,38 @@ def find_sql_refs( Returns: A list of table and schema names referenced in the statement. """ - if not DependencyManager.duckdb.has(): + + # Use sqlglot to parse ast (https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md) + if not DependencyManager.sqlglot.has(): return [] - import duckdb + from sqlglot import exp, parse + from sqlglot.optimizer.scope import build_scope - tokens = duckdb.tokenize(sql_statement) - token_extractor = TokenExtractor( - sql_statement=sql_statement, tokens=tokens - ) refs: list[str] = [] - cte_names: set[str] = set() - i = 0 - - # First pass - collect CTE names - while i < len(tokens): - if token_extractor.is_keyword(i, "with"): - i += 1 - # Handle optional parenthesis after WITH - if token_extractor.token_str(i) == "(": - i += 1 - while i < len(tokens): - if token_extractor.is_keyword(i, "select"): - break - if ( - token_extractor.token_str(i) == "," - or token_extractor.token_str(i) == "(" - ): - i += 1 - continue - cte_name = token_extractor.strip_quotes( - token_extractor.token_str(i) - ) - if not token_extractor.is_keyword(i, "as"): - cte_names.add(cte_name) - i += 1 - if token_extractor.is_keyword(i, "as"): - break - i += 1 - - # Second pass - collect references excluding CTEs - i = 0 - while i < len(tokens): - if token_extractor.is_keyword(i, "from") or token_extractor.is_keyword( - i, "join" - ): - i += 1 - if i < len(tokens): - # Skip over opening parenthesis for subqueries - if token_extractor.token_str(i) == "(": - continue - - # Get table name parts, this could be: - # - catalog.schema.table - # - catalog.table (this is shorthand for catalog.main.table) - # - table - - parts: List[str] = [] - while i < len(tokens): - part = token_extractor.strip_quotes( - token_extractor.token_str(i) - ) - parts.append(part) - # next token is a dot, so we continue getting parts - if ( - i + 1 < len(tokens) - and token_extractor.token_str(i + 1) == "." - ): - i += 2 - continue - break - - if len(parts) == 3: - # If its the default in-memory catalog, - # only add the table name - if parts[0] == "memory": - refs.append(parts[2]) - else: - # Just add the catalog and table, skip schema - refs.extend([parts[0], parts[2]]) - elif len(parts) == 2: - # If its the default in-memory catalog, only add the table - if parts[0] == "memory": - refs.append(parts[1]) + asts = parse(sql_statement) + for sql_ast in asts: + root = build_scope(sql_ast) + if not root: # likely not a query + return [] + + for scope in root.traverse(): + for _alias, (_node, source) in scope.selected_sources.items(): + if isinstance(source, exp.Table): + if source.catalog == "memory": + # Default in-memory catalog, only include table name + refs.append(source.name) else: - # It's a catalog and table, add both - refs.extend(parts) - elif len(parts) == 1: - # It's a table, make sure it's not a CTE - if parts[0] not in cte_names: - refs.append(parts[0]) - else: - LOGGER.warning( - "Unexpected number of parts in SQL reference: %s", - parts, - ) - - i -= 1 # Compensate for outer loop increment - i += 1 - - # Re-use find_sql_defs to find referenced schemas and catalogs during creation. - defs = find_sql_defs(sql_statement) - refs.extend(defs.reffed_schemas) - refs.extend(defs.reffed_catalogs) - - # Remove duplicates while preserving order + # We skip schema if there is a catalog + # Because it may be called "public" or "main" across all catalogs + # and they aren't referenced in the code + if source.catalog: + refs.append(source.catalog) + elif source.db: + refs.append(source.db) # schema + + if source.name: + refs.append(source.name) # table name + + # removes duplicates while preserving order return list(dict.fromkeys(refs)) diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index 808a7ded49a..cf2eef406c5 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -147,6 +147,7 @@ class DependencyManager: numpy = Dependency("numpy") altair = Dependency("altair", min_version="5.3.0", max_version="6.0.0") duckdb = Dependency("duckdb") + sqlglot = Dependency("sqlglot") pillow = Dependency("PIL") plotly = Dependency("plotly") bokeh = Dependency("bokeh") diff --git a/pyproject.toml b/pyproject.toml index 74bec2c56ba..bca60321fe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,16 +75,22 @@ marimo = "marimo._cli.cli:main" homepage = "https://github.com/marimo-team/marimo" [project.optional-dependencies] -sql = ["duckdb >= 1.0.0", "polars[pyarrow] >= 1.9.0"] +sql = [ + "duckdb >= 1.1.0", + "polars[pyarrow] >= 1.9.0", + "sqlglot >= 26.0.1" +] + # List of deps that are recommended for most users # in order to unlock all features in marimo recommended = [ - "duckdb>=1.1.0", # SQL cells - "altair>=5.4.0", # Plotting in datasource viewer - "polars>=1.9.0", # SQL output back in Python - "openai>=1.41.1", # AI features - "ruff", # Formatting - "nbformat>=5.7.0", # Export as IPYNB + "duckdb>=1.1.0", # SQL cells + "altair>=5.4.0", # Plotting in datasource viewer + "polars[pyarrow]>=1.9.0", # SQL output back in Python + "sqlglot>=26.0.1", # SQL cells parsing + "openai>=1.41.1", # AI features + "ruff", # Formatting + "nbformat>=5.7.0", # Export as IPYNB ] dev = [ diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index aa888892e3b..ac3c0b20d0a 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -14,6 +14,7 @@ from marimo._dependencies.dependencies import DependencyManager HAS_DUCKDB = DependencyManager.duckdb.has() +HAS_SQLGLOT = DependencyManager.sqlglot.has() def test_execute_with_string_literal() -> None: @@ -461,6 +462,8 @@ def test_find_sql_defs_with_if_not_exists() -> None: tables=["my_table"], ) + # add sql with table & view in the same query + @pytest.mark.skipif( HAS_DUCKDB, reason="Test requires DuckDB to be unavailable" @@ -469,7 +472,7 @@ def test_find_sql_defs_duckdb_not_available() -> None: assert find_sql_defs("CREATE TABLE test (id INT);") == SQLDefs() -@pytest.mark.skipif(not HAS_DUCKDB, reason="Missing DuckDB") +@pytest.mark.skipif(not HAS_SQLGLOT, reason="Missing sqlglot") class TestFindSQLRefs: @staticmethod def test_find_sql_refs_simple() -> None: @@ -509,8 +512,6 @@ def test_find_sql_refs_with_schema() -> None: @staticmethod def test_find_sql_refs_with_catalog() -> None: # Skip the schema if it's coming from a catalog - # Why? Because it may be called "public" or "main" across all catalogs - # and they aren't referenced in the code sql = "SELECT * FROM my_catalog.my_schema.my_table;" assert find_sql_refs(sql) == ["my_catalog", "my_table"] @@ -566,7 +567,6 @@ def test_find_sql_refs_with_quoted_names() -> None: assert find_sql_refs(sql) == ["My Table", "Weird.Name"] @staticmethod - @pytest.mark.xfail(reason="Multiple CTEs are not supported") def test_find_sql_refs_with_multiple_ctes() -> None: sql = """ WITH @@ -578,7 +578,6 @@ def test_find_sql_refs_with_multiple_ctes() -> None: assert find_sql_refs(sql) == ["table1", "table2"] @staticmethod - @pytest.mark.xfail(reason="Nested joins are not supported") def test_find_sql_refs_with_nested_joins() -> None: sql = """ SELECT * FROM t1 @@ -593,7 +592,7 @@ def test_find_sql_refs_with_lateral_join() -> None: SELECT * FROM employees, LATERAL (SELECT * FROM departments WHERE departments.id = employees.dept_id) dept; """ - assert find_sql_refs(sql) == ["employees", "departments"] + assert find_sql_refs(sql) == ["departments", "employees"] @staticmethod def test_find_sql_refs_with_schema_switching() -> None: @@ -614,3 +613,22 @@ def test_find_sql_refs_with_complex_subqueries() -> None: ) t2; """ assert find_sql_refs(sql) == ["deeply", "table", "another_table"] + + @staticmethod + def test_find_sql_refs_with_alias() -> None: + sql = "SELECT * FROM employees AS e;" + assert find_sql_refs(sql) == ["employees"] + + @staticmethod + def test_find_sql_refs_invalid() -> None: + sql = "CREATE TABLE t1 (id int);" + assert find_sql_refs(sql) == [] + + @staticmethod + def test_find_sql_refs_comment() -> None: + sql = """ + -- comment + SELECT * FROM table1; + -- comment + """ + assert find_sql_refs(sql) == ["table1"] From e3043f507505bc729ac0fbf2ffa38461ae1b7039 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Sun, 29 Dec 2024 11:14:05 +0800 Subject: [PATCH 02/16] fix for comments and ddl's --- marimo/_ast/sql_visitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 61bfacc6790..b7d4c906774 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -325,8 +325,8 @@ def find_sql_refs( asts = parse(sql_statement) for sql_ast in asts: root = build_scope(sql_ast) - if not root: # likely not a query - return [] + if root is None: # Skip ddl's and comments + continue for scope in root.traverse(): for _alias, (_node, source) in scope.selected_sources.items(): From 2880dea7589986716db7a5839ee3101a61680e40 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 15:43:51 +0800 Subject: [PATCH 03/16] add support for dmls --- marimo/_ast/sql_visitor.py | 51 +++++++++------ tests/_ast/test_sql_visitor.py | 112 ++++++++++++++++++++++++++++++--- 2 files changed, 138 insertions(+), 25 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index b7d4c906774..05d5ec2a1a6 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -322,29 +322,44 @@ def find_sql_refs( from sqlglot.optimizer.scope import build_scope refs: list[str] = [] - asts = parse(sql_statement) - for sql_ast in asts: - root = build_scope(sql_ast) - if root is None: # Skip ddl's and comments + + def append_refs_from_table(table_expr: exp.Table) -> None: + if table_expr.catalog == "memory": + # Default in-memory catalog, only include table name + refs.append(table_expr.name) + else: + # We skip schema if there is a catalog + # Because it may be called "public" or "main" across all catalogs + # and they aren't referenced in the code + if table_expr.catalog: + refs.append(table_expr.catalog) + elif table_expr.db: + refs.append(table_expr.db) # schema + + if table_expr.name: + refs.append(table_expr.name) # table name + + expression_list = parse(sql_statement) + for expression in expression_list: + dml_expression = False + if expression.find(exp.Update, exp.Insert, exp.Delete): + dml_expression = True + for table_expr in expression.find_all(exp.Table): + append_refs_from_table(table_expr) + + # this traversal is only available for select statements + root = build_scope(expression) + if root is None: continue + if dml_expression: + LOGGER.warning( + "Scopes should not exist for dml's, may need rework" + ) for scope in root.traverse(): for _alias, (_node, source) in scope.selected_sources.items(): if isinstance(source, exp.Table): - if source.catalog == "memory": - # Default in-memory catalog, only include table name - refs.append(source.name) - else: - # We skip schema if there is a catalog - # Because it may be called "public" or "main" across all catalogs - # and they aren't referenced in the code - if source.catalog: - refs.append(source.catalog) - elif source.db: - refs.append(source.db) # schema - - if source.name: - refs.append(source.name) # table name + append_refs_from_table(source) # removes duplicates while preserving order return list(dict.fromkeys(refs)) diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index ac3c0b20d0a..6b7dcb2a5e5 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -394,6 +394,17 @@ def test_find_sql_defs_create_or_replace_with_catalog() -> None: reffed_catalogs=["my_db"], ) + @staticmethod + def test_find_sql_defs_with_create_as() -> None: + sql = """ + CREATE TABLE t2 AS + WITH t3 AS ( + SELECT * from t1 + ) + SELECT * FROM t3; + """ + assert find_sql_defs(sql) == SQLDefs(tables=["t2"]) + @staticmethod def test_find_sql_defs_with_catalog_and_schema() -> None: sql = """ @@ -462,8 +473,6 @@ def test_find_sql_defs_with_if_not_exists() -> None: tables=["my_table"], ) - # add sql with table & view in the same query - @pytest.mark.skipif( HAS_DUCKDB, reason="Test requires DuckDB to be unavailable" @@ -614,16 +623,25 @@ def test_find_sql_refs_with_complex_subqueries() -> None: """ assert find_sql_refs(sql) == ["deeply", "table", "another_table"] + @staticmethod + def test_find_sql_refs_nested_intersect() -> None: + sql = """ + SELECT * FROM table1 + WHERE id IN ( + SELECT id FROM table2 + UNION + SELECT id FROM table3 + INTERSECT + SELECT id FROM table4 + ); + """ + assert find_sql_refs(sql) == ["table2", "table3", "table4", "table1"] + @staticmethod def test_find_sql_refs_with_alias() -> None: sql = "SELECT * FROM employees AS e;" assert find_sql_refs(sql) == ["employees"] - @staticmethod - def test_find_sql_refs_invalid() -> None: - sql = "CREATE TABLE t1 (id int);" - assert find_sql_refs(sql) == [] - @staticmethod def test_find_sql_refs_comment() -> None: sql = """ @@ -632,3 +650,83 @@ def test_find_sql_refs_comment() -> None: -- comment """ assert find_sql_refs(sql) == ["table1"] + + @staticmethod + def test_find_sql_refs_ddl() -> None: + # we are not referencing any table hence no refs + sql = "CREATE TABLE t1 (id int);" + assert find_sql_refs(sql) == [] + + @staticmethod + def test_find_sql_refs_ddl_with_reference() -> None: + sql = """ + CREATE TABLE table2 AS + WITH x AS ( + SELECT * from my_catalog.my_schema.table1 + ) + SELECT * FROM x; + """ + assert find_sql_refs(sql) == ["my_catalog", "table1"] + + @staticmethod + def test_find_sql_refs_update() -> None: + sql = "UPDATE my_schema.table1 SET id = 1" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_update_with_catalog() -> None: + sql = "UPDATE my_catalog.my_schema.table1 SET id = 1" + assert find_sql_refs(sql) == ["my_catalog", "table1"] + + @staticmethod + def test_find_sql_refs_insert() -> None: + sql = "INSERT INTO my_schema.table1 (id INT) VALUES (1,2);" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_delete() -> None: + sql = "DELETE FROM my_schema.table1 WHERE true;" + assert find_sql_refs(sql) == ["my_schema", "table1"] + + @staticmethod + def test_find_sql_refs_multi_dml() -> None: + sql = """ + INSERT INTO table1 (id INT) VALUES (1,2); + DELETE FROM table2 WHERE true; + UPDATE table3 SET id = 1; + """ + assert find_sql_refs(sql) == ["table1", "table2", "table3"] + + @staticmethod + def test_find_sql_refs_dml_with_query() -> None: + sql = """ + INSERT INTO table1 (id INT) VALUES (1); + SELECT * FROM table2; + """ + assert find_sql_refs(sql) == ["table1", "table2"] + + @staticmethod + def test_find_sql_refs_multi_selects_in_update() -> None: + sql = """ + UPDATE schema1.table1 + SET table1.column1 = ( + SELECT table2.column2 FROM schema2.table2 + ), + table1.column3 = ( + SELECT table3.column3 FROM table3 + ) + WHERE EXISTS ( + SELECT 1 FROM table2 + ) + AND table1.column4 IN ( + SELECT table4.column4 FROM table4 + ); + """ + assert find_sql_refs(sql) == [ + "schema1", + "table1", + "schema2", + "table2", + "table3", + "table4", + ] From 1141567d6edb6b32aea8dbf5ee809033fe5a04d5 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 15:54:43 +0800 Subject: [PATCH 04/16] modify comments and var name --- marimo/_ast/sql_visitor.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 05d5ec2a1a6..ee8b8da5f57 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -323,29 +323,29 @@ def find_sql_refs( refs: list[str] = [] - def append_refs_from_table(table_expr: exp.Table) -> None: - if table_expr.catalog == "memory": + def append_refs_from_table(table: exp.Table) -> None: + if table.catalog == "memory": # Default in-memory catalog, only include table name - refs.append(table_expr.name) + refs.append(table.name) else: # We skip schema if there is a catalog # Because it may be called "public" or "main" across all catalogs # and they aren't referenced in the code - if table_expr.catalog: - refs.append(table_expr.catalog) - elif table_expr.db: - refs.append(table_expr.db) # schema + if table.catalog: + refs.append(table.catalog) + elif table.db: + refs.append(table.db) # schema - if table_expr.name: - refs.append(table_expr.name) # table name + if table.name: + refs.append(table.name) expression_list = parse(sql_statement) for expression in expression_list: dml_expression = False if expression.find(exp.Update, exp.Insert, exp.Delete): dml_expression = True - for table_expr in expression.find_all(exp.Table): - append_refs_from_table(table_expr) + for table in expression.find_all(exp.Table): + append_refs_from_table(table) # this traversal is only available for select statements root = build_scope(expression) @@ -353,7 +353,7 @@ def append_refs_from_table(table_expr: exp.Table) -> None: continue if dml_expression: LOGGER.warning( - "Scopes should not exist for dml's, may need rework" + "Scopes should not exist for dml's, may need rework if this occurs" ) for scope in root.traverse(): @@ -361,5 +361,5 @@ def append_refs_from_table(table_expr: exp.Table) -> None: if isinstance(source, exp.Table): append_refs_from_table(source) - # removes duplicates while preserving order + # remove duplicates while preserving order return list(dict.fromkeys(refs)) From f6f40f06fcc7a5ba2d5658696e78f30c7fd7e809 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 16:04:03 +0800 Subject: [PATCH 05/16] small refactor --- marimo/_ast/sql_visitor.py | 6 ++---- pyproject.toml | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index ee8b8da5f57..4e725bd3416 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -341,9 +341,7 @@ def append_refs_from_table(table: exp.Table) -> None: expression_list = parse(sql_statement) for expression in expression_list: - dml_expression = False - if expression.find(exp.Update, exp.Insert, exp.Delete): - dml_expression = True + if is_dml := bool(expression.find(exp.Update, exp.Insert, exp.Delete)): for table in expression.find_all(exp.Table): append_refs_from_table(table) @@ -351,7 +349,7 @@ def append_refs_from_table(table: exp.Table) -> None: root = build_scope(expression) if root is None: continue - if dml_expression: + if is_dml: LOGGER.warning( "Scopes should not exist for dml's, may need rework if this occurs" ) diff --git a/pyproject.toml b/pyproject.toml index bca60321fe1..3c74abd89a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,8 +76,8 @@ homepage = "https://github.com/marimo-team/marimo" [project.optional-dependencies] sql = [ - "duckdb >= 1.1.0", - "polars[pyarrow] >= 1.9.0", + "duckdb >= 1.1.0", + "polars[pyarrow] >= 1.9.0", "sqlglot >= 26.0.1" ] From 7748a2a7bf284be367f12b3959cd11ea56537d2c Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 17:56:28 +0800 Subject: [PATCH 06/16] adding sqlglot to dev dep --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3c74abd89a2..f87f319a08f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,8 @@ dev = [ "opentelemetry-api~=1.26.0", "opentelemetry-sdk~=1.26.0", # For SQL - "duckdb>=1.0.0", + "duckdb>=1.1.0", + "sqlglot>=26.0.1", # For linting "ruff~=0.6.1", # For AI From b1aff25e822b8a5afa601fd2d035bd46f768f259 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 18:11:23 +0800 Subject: [PATCH 07/16] downgrade sqlglot min req --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f87f319a08f..a79bc6a4f40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ homepage = "https://github.com/marimo-team/marimo" sql = [ "duckdb >= 1.1.0", "polars[pyarrow] >= 1.9.0", - "sqlglot >= 26.0.1" + "sqlglot >= 23.4" ] # List of deps that are recommended for most users @@ -87,7 +87,7 @@ recommended = [ "duckdb>=1.1.0", # SQL cells "altair>=5.4.0", # Plotting in datasource viewer "polars[pyarrow]>=1.9.0", # SQL output back in Python - "sqlglot>=26.0.1", # SQL cells parsing + "sqlglot>=23.4", # SQL cells parsing "openai>=1.41.1", # AI features "ruff", # Formatting "nbformat>=5.7.0", # Export as IPYNB @@ -101,7 +101,7 @@ dev = [ "opentelemetry-sdk~=1.26.0", # For SQL "duckdb>=1.1.0", - "sqlglot>=26.0.1", + "sqlglot>=23.4", # For linting "ruff~=0.6.1", # For AI From e091ba5e45c3a93bce6a0d47e3b1ea676712d71a Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 19:04:38 +0800 Subject: [PATCH 08/16] update sqlglot version to satisfiable --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a79bc6a4f40..cf08e3a08f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ homepage = "https://github.com/marimo-team/marimo" sql = [ "duckdb >= 1.1.0", "polars[pyarrow] >= 1.9.0", - "sqlglot >= 23.4" + "sqlglot >= 25.20" ] # List of deps that are recommended for most users @@ -87,7 +87,7 @@ recommended = [ "duckdb>=1.1.0", # SQL cells "altair>=5.4.0", # Plotting in datasource viewer "polars[pyarrow]>=1.9.0", # SQL output back in Python - "sqlglot>=23.4", # SQL cells parsing + "sqlglot>=25.20", # SQL cells parsing "openai>=1.41.1", # AI features "ruff", # Formatting "nbformat>=5.7.0", # Export as IPYNB @@ -101,7 +101,7 @@ dev = [ "opentelemetry-sdk~=1.26.0", # For SQL "duckdb>=1.1.0", - "sqlglot>=23.4", + "sqlglot>=25.20", # For linting "ruff~=0.6.1", # For AI From 29b70ad109105b3ef48fb16190caaffc73a26777 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 21:57:42 +0800 Subject: [PATCH 09/16] fix types and reduce sqlglot to min version --- marimo/_ast/sql_visitor.py | 5 ++++- pyproject.toml | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 4e725bd3416..c6b595311bf 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -341,6 +341,9 @@ def append_refs_from_table(table: exp.Table) -> None: expression_list = parse(sql_statement) for expression in expression_list: + if expression is None: + continue + if is_dml := bool(expression.find(exp.Update, exp.Insert, exp.Delete)): for table in expression.find_all(exp.Table): append_refs_from_table(table) @@ -354,7 +357,7 @@ def append_refs_from_table(table: exp.Table) -> None: "Scopes should not exist for dml's, may need rework if this occurs" ) - for scope in root.traverse(): + for scope in root.traverse(): # type: ignore for _alias, (_node, source) in scope.selected_sources.items(): if isinstance(source, exp.Table): append_refs_from_table(source) diff --git a/pyproject.toml b/pyproject.toml index cf08e3a08f3..440d2295a95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ homepage = "https://github.com/marimo-team/marimo" [project.optional-dependencies] sql = [ - "duckdb >= 1.1.0", - "polars[pyarrow] >= 1.9.0", - "sqlglot >= 25.20" + "duckdb>=1.1.0", + "polars[pyarrow]>=1.9.0", + "sqlglot>=23.4" ] # List of deps that are recommended for most users @@ -87,7 +87,7 @@ recommended = [ "duckdb>=1.1.0", # SQL cells "altair>=5.4.0", # Plotting in datasource viewer "polars[pyarrow]>=1.9.0", # SQL output back in Python - "sqlglot>=25.20", # SQL cells parsing + "sqlglot>=23.4", # SQL cells parsing "openai>=1.41.1", # AI features "ruff", # Formatting "nbformat>=5.7.0", # Export as IPYNB @@ -101,7 +101,7 @@ dev = [ "opentelemetry-sdk~=1.26.0", # For SQL "duckdb>=1.1.0", - "sqlglot>=25.20", + "sqlglot>=23.4", # For linting "ruff~=0.6.1", # For AI From 40f984e75493352cd3c5d0830776ae419c6520cc Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 22:07:57 +0800 Subject: [PATCH 10/16] remove unecc test --- tests/_ast/test_sql_visitor.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index 6b7dcb2a5e5..e3d16e8ddc2 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -394,17 +394,6 @@ def test_find_sql_defs_create_or_replace_with_catalog() -> None: reffed_catalogs=["my_db"], ) - @staticmethod - def test_find_sql_defs_with_create_as() -> None: - sql = """ - CREATE TABLE t2 AS - WITH t3 AS ( - SELECT * from t1 - ) - SELECT * FROM t3; - """ - assert find_sql_defs(sql) == SQLDefs(tables=["t2"]) - @staticmethod def test_find_sql_defs_with_catalog_and_schema() -> None: sql = """ From a38ac25e78bc668d2f87dd6aed556e1f280759a3 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Mon, 30 Dec 2024 22:30:34 +0800 Subject: [PATCH 11/16] refactor flow --- marimo/_ast/sql_visitor.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index c6b595311bf..2458033992b 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -348,19 +348,17 @@ def append_refs_from_table(table: exp.Table) -> None: for table in expression.find_all(exp.Table): append_refs_from_table(table) - # this traversal is only available for select statements - root = build_scope(expression) - if root is None: - continue - if is_dml: - LOGGER.warning( - "Scopes should not exist for dml's, may need rework if this occurs" - ) + # build_scope only works for select statements + if root := build_scope(expression): + if is_dml: + LOGGER.warning( + "Scopes should not exist for dml's, may need rework if this occurs" + ) - for scope in root.traverse(): # type: ignore - for _alias, (_node, source) in scope.selected_sources.items(): - if isinstance(source, exp.Table): - append_refs_from_table(source) + for scope in root.traverse(): # type: ignore + for _alias, (_node, source) in scope.selected_sources.items(): + if isinstance(source, exp.Table): + append_refs_from_table(source) # remove duplicates while preserving order return list(dict.fromkeys(refs)) From 6c5b390a72dfaf67b919181fe39d812b70007342 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Tue, 31 Dec 2024 09:40:48 +0800 Subject: [PATCH 12/16] set duckdb to original --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 440d2295a95..2555db23d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ homepage = "https://github.com/marimo-team/marimo" [project.optional-dependencies] sql = [ - "duckdb>=1.1.0", + "duckdb>=1.0.0", "polars[pyarrow]>=1.9.0", "sqlglot>=23.4" ] @@ -84,7 +84,7 @@ sql = [ # List of deps that are recommended for most users # in order to unlock all features in marimo recommended = [ - "duckdb>=1.1.0", # SQL cells + "duckdb>=1.0.0", # SQL cells "altair>=5.4.0", # Plotting in datasource viewer "polars[pyarrow]>=1.9.0", # SQL output back in Python "sqlglot>=23.4", # SQL cells parsing @@ -100,7 +100,7 @@ dev = [ "opentelemetry-api~=1.26.0", "opentelemetry-sdk~=1.26.0", # For SQL - "duckdb>=1.1.0", + "duckdb>=1.0.0", "sqlglot>=23.4", # For linting "ruff~=0.6.1", From c4e7594d6dcd12e2c5fbcedcc323ed9ec73b2d1b Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Tue, 31 Dec 2024 09:56:39 +0800 Subject: [PATCH 13/16] refactor test --- tests/_ast/test_sql_visitor.py | 39 ++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index e3d16e8ddc2..7a105ef1a52 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -651,22 +651,17 @@ def test_find_sql_refs_ddl_with_reference() -> None: sql = """ CREATE TABLE table2 AS WITH x AS ( - SELECT * from my_catalog.my_schema.table1 + SELECT * from table1 ) SELECT * FROM x; """ - assert find_sql_refs(sql) == ["my_catalog", "table1"] + assert find_sql_refs(sql) == ["table1"] @staticmethod def test_find_sql_refs_update() -> None: sql = "UPDATE my_schema.table1 SET id = 1" assert find_sql_refs(sql) == ["my_schema", "table1"] - @staticmethod - def test_find_sql_refs_update_with_catalog() -> None: - sql = "UPDATE my_catalog.my_schema.table1 SET id = 1" - assert find_sql_refs(sql) == ["my_catalog", "table1"] - @staticmethod def test_find_sql_refs_insert() -> None: sql = "INSERT INTO my_schema.table1 (id INT) VALUES (1,2);" @@ -687,15 +682,7 @@ def test_find_sql_refs_multi_dml() -> None: assert find_sql_refs(sql) == ["table1", "table2", "table3"] @staticmethod - def test_find_sql_refs_dml_with_query() -> None: - sql = """ - INSERT INTO table1 (id INT) VALUES (1); - SELECT * FROM table2; - """ - assert find_sql_refs(sql) == ["table1", "table2"] - - @staticmethod - def test_find_sql_refs_multi_selects_in_update() -> None: + def test_find_sql_refs_multiple_selects_in_update() -> None: sql = """ UPDATE schema1.table1 SET table1.column1 = ( @@ -719,3 +706,23 @@ def test_find_sql_refs_multi_selects_in_update() -> None: "table3", "table4", ] + + @staticmethod + def test_find_sql_refs_select_in_insert() -> None: + sql = """ + INSERT INTO table1 (column1, column2) + SELECT column1, column2 FROM table2 + WHERE column3 = 'value'; + """ + assert find_sql_refs(sql) == ["table1", "table2"] + + @staticmethod + def test_find_sql_refs_select_in_delete() -> None: + sql = """ + DELETE FROM table1 + WHERE column1 IN ( + SELECT column1 FROM table2 + WHERE column2 = 'value' + ); + """ + assert find_sql_refs(sql) == ["table1", "table2"] From a8482672ae8aa6103e1a86ce908816a0ee92f9c7 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Tue, 31 Dec 2024 09:59:01 +0800 Subject: [PATCH 14/16] add small test --- tests/_ast/test_sql_visitor.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index 7a105ef1a52..407f14f72b6 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -626,6 +626,18 @@ def test_find_sql_refs_nested_intersect() -> None: """ assert find_sql_refs(sql) == ["table2", "table3", "table4", "table1"] + @staticmethod + def test_find_sql_refs_with_recursive_cte() -> None: + sql = """ + WITH RECURSIVE cte AS ( + SELECT 1 AS n FROM table1 + UNION ALL + SELECT n + 1 FROM cte WHERE n < 10 + ) + SELECT * FROM cte; + """ + assert find_sql_refs(sql) == ["table1"] + @staticmethod def test_find_sql_refs_with_alias() -> None: sql = "SELECT * FROM employees AS e;" From 40fdea9007ed916b581514fea060cb0965e15188 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Tue, 31 Dec 2024 10:00:10 +0800 Subject: [PATCH 15/16] remove test --- tests/_ast/test_sql_visitor.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index 407f14f72b6..7a105ef1a52 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -626,18 +626,6 @@ def test_find_sql_refs_nested_intersect() -> None: """ assert find_sql_refs(sql) == ["table2", "table3", "table4", "table1"] - @staticmethod - def test_find_sql_refs_with_recursive_cte() -> None: - sql = """ - WITH RECURSIVE cte AS ( - SELECT 1 AS n FROM table1 - UNION ALL - SELECT n + 1 FROM cte WHERE n < 10 - ) - SELECT * FROM cte; - """ - assert find_sql_refs(sql) == ["table1"] - @staticmethod def test_find_sql_refs_with_alias() -> None: sql = "SELECT * FROM employees AS e;" From 436c0fb26182c27b751d123daf1bd7f9f9011379 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Tue, 31 Dec 2024 11:54:33 +0800 Subject: [PATCH 16/16] add duckdb dialect and error handling for parse --- marimo/_ast/sql_visitor.py | 8 +++++++- tests/_ast/test_sql_visitor.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/marimo/_ast/sql_visitor.py b/marimo/_ast/sql_visitor.py index 2458033992b..837e012dc81 100644 --- a/marimo/_ast/sql_visitor.py +++ b/marimo/_ast/sql_visitor.py @@ -319,6 +319,7 @@ def find_sql_refs( return [] from sqlglot import exp, parse + from sqlglot.errors import ParseError from sqlglot.optimizer.scope import build_scope refs: list[str] = [] @@ -339,7 +340,12 @@ def append_refs_from_table(table: exp.Table) -> None: if table.name: refs.append(table.name) - expression_list = parse(sql_statement) + try: + expression_list = parse(sql_statement, dialect="duckdb") + except ParseError as e: + LOGGER.error(f"Unable to parse SQL. Error: {e}") + return [] + for expression in expression_list: if expression is None: continue diff --git a/tests/_ast/test_sql_visitor.py b/tests/_ast/test_sql_visitor.py index 7a105ef1a52..d1aa2910694 100644 --- a/tests/_ast/test_sql_visitor.py +++ b/tests/_ast/test_sql_visitor.py @@ -726,3 +726,8 @@ def test_find_sql_refs_select_in_delete() -> None: ); """ assert find_sql_refs(sql) == ["table1", "table2"] + + @staticmethod + def test_find_sql_refs_invalid_sql() -> None: + sql = "SELECT * FROM" + assert find_sql_refs(sql) == []