From 051c6f085e426dbcdda3fbc324a180ada268355f Mon Sep 17 00:00:00 2001 From: Jo <46752250+georgesittas@users.noreply.github.com> Date: Tue, 10 Dec 2024 19:41:46 +0200 Subject: [PATCH] Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence (#4495) * Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence * Rename WhenSequence to Whens --- sqlglot/dialects/dialect.py | 2 +- sqlglot/expressions.py | 21 +++++++++++++++------ sqlglot/generator.py | 10 +++++++--- sqlglot/parser.py | 8 ++++---- tests/test_parser.py | 4 ++-- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index a363c6e69e..4ed0290a40 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1547,7 +1547,7 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: if alias: targets.add(normalize(alias.this)) - for when in expression.expressions: + for when in expression.args["whens"].expressions: # only remove the target names from the THEN clause # theyre still valid in the part of WHEN MATCHED / WHEN NOT MATCHED # ref: https://github.com/TobikoData/sqlmesh/issues/2934 diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index c000f6faaf..39975d7d57 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6682,16 +6682,22 @@ class Merge(DML): "this": True, "using": True, "on": True, - "expressions": True, + "whens": True, "with": False, "returning": False, } -class When(Func): +class When(Expression): arg_types = {"matched": True, "source": False, "condition": False, "then": True} +class Whens(Expression): + """Wraps around one or more WHEN [NOT] MATCHED [...] clauses.""" + + arg_types = {"expressions": True} + + # https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html # https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16 class NextValueFor(Func): @@ -7349,14 +7355,17 @@ def merge( Returns: Merge: The syntax tree for the MERGE statement. """ + expressions = [] + for when_expr in when_exprs: + expressions.extend( + maybe_parse(when_expr, dialect=dialect, copy=copy, into=Whens, **opts).expressions + ) + merge = Merge( this=maybe_parse(into, dialect=dialect, copy=copy, **opts), using=maybe_parse(using, dialect=dialect, copy=copy, **opts), on=maybe_parse(on, dialect=dialect, copy=copy, **opts), - expressions=[ - maybe_parse(when_expr, dialect=dialect, copy=copy, into=When, **opts) - for when_expr in when_exprs - ], + whens=Whens(expressions=expressions), ) if returning: merge = merge.returning(returning, dialect=dialect, copy=False, **opts) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index e0965c772a..635a5f5fa7 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -3696,6 +3696,9 @@ def when_sql(self, expression: exp.When) -> str: then = self.sql(then_expression) return f"WHEN {matched}{source}{condition} THEN {then}" + def whens_sql(self, expression: exp.Whens) -> str: + return self.expressions(expression, sep=" ", indent=False) + def merge_sql(self, expression: exp.Merge) -> str: table = expression.this table_alias = "" @@ -3708,16 +3711,17 @@ def merge_sql(self, expression: exp.Merge) -> str: this = self.sql(table) using = f"USING {self.sql(expression, 'using')}" on = f"ON {self.sql(expression, 'on')}" - expressions = self.expressions(expression, sep=" ", indent=False) + whens = self.sql(expression, "whens") + returning = self.sql(expression, "returning") if returning: - expressions = f"{expressions}{returning}" + whens = f"{whens}{returning}" sep = self.sep() return self.prepend_ctes( expression, - f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}", + f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{whens}", ) @unsupported_args("format") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c99b3ae629..4cd3d30560 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -778,7 +778,7 @@ class Parser(metaclass=_Parser): exp.Table: lambda self: self._parse_table_parts(), exp.TableAlias: lambda self: self._parse_table_alias(), exp.Tuple: lambda self: self._parse_value(), - exp.When: lambda self: seq_get(self._parse_when_matched(), 0), + exp.Whens: lambda self: self._parse_when_matched(), exp.Where: lambda self: self._parse_where(), exp.Window: lambda self: self._parse_named_window(), exp.With: lambda self: self._parse_with(), @@ -7010,11 +7010,11 @@ def _parse_merge(self) -> exp.Merge: this=target, using=using, on=on, - expressions=self._parse_when_matched(), + whens=self._parse_when_matched(), returning=self._parse_returning(), ) - def _parse_when_matched(self) -> t.List[exp.When]: + def _parse_when_matched(self) -> exp.Whens: whens = [] while self._match(TokenType.WHEN): @@ -7063,7 +7063,7 @@ def _parse_when_matched(self) -> t.List[exp.When]: then=then, ) ) - return whens + return self.expression(exp.Whens, expressions=whens) def _parse_show(self) -> t.Optional[exp.Expression]: parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) diff --git a/tests/test_parser.py b/tests/test_parser.py index cb194462b8..1bf951a65b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -27,9 +27,9 @@ def test_parse_into(self): self.assertIsInstance( parse_one( "WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)", - into=exp.When, + into=exp.Whens, ), - exp.When, + exp.Whens, ) with self.assertRaises(ParseError) as ctx: