Skip to content

Commit

Permalink
Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence (
Browse files Browse the repository at this point in the history
#4495)

* Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence

* Rename WhenSequence to Whens
  • Loading branch information
georgesittas authored Dec 10, 2024
1 parent 43975e4 commit 051c6f0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 16 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
if alias:
targets.add(normalize(alias.this))

for when in expression.expressions:
for when in expression.args["whens"].expressions:
# only remove the target names from the THEN clause
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
# ref: https://github.com/TobikoData/sqlmesh/issues/2934
Expand Down
21 changes: 15 additions & 6 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6682,16 +6682,22 @@ class Merge(DML):
"this": True,
"using": True,
"on": True,
"expressions": True,
"whens": True,
"with": False,
"returning": False,
}


class When(Func):
class When(Expression):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}


class Whens(Expression):
"""Wraps around one or more WHEN [NOT] MATCHED [...] clauses."""

arg_types = {"expressions": True}


# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
class NextValueFor(Func):
Expand Down Expand Up @@ -7349,14 +7355,17 @@ def merge(
Returns:
Merge: The syntax tree for the MERGE statement.
"""
expressions = []
for when_expr in when_exprs:
expressions.extend(
maybe_parse(when_expr, dialect=dialect, copy=copy, into=Whens, **opts).expressions
)

merge = Merge(
this=maybe_parse(into, dialect=dialect, copy=copy, **opts),
using=maybe_parse(using, dialect=dialect, copy=copy, **opts),
on=maybe_parse(on, dialect=dialect, copy=copy, **opts),
expressions=[
maybe_parse(when_expr, dialect=dialect, copy=copy, into=When, **opts)
for when_expr in when_exprs
],
whens=Whens(expressions=expressions),
)
if returning:
merge = merge.returning(returning, dialect=dialect, copy=False, **opts)
Expand Down
10 changes: 7 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3696,6 +3696,9 @@ def when_sql(self, expression: exp.When) -> str:
then = self.sql(then_expression)
return f"WHEN {matched}{source}{condition} THEN {then}"

def whens_sql(self, expression: exp.Whens) -> str:
return self.expressions(expression, sep=" ", indent=False)

def merge_sql(self, expression: exp.Merge) -> str:
table = expression.this
table_alias = ""
Expand All @@ -3708,16 +3711,17 @@ def merge_sql(self, expression: exp.Merge) -> str:
this = self.sql(table)
using = f"USING {self.sql(expression, 'using')}"
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ", indent=False)
whens = self.sql(expression, "whens")

returning = self.sql(expression, "returning")
if returning:
expressions = f"{expressions}{returning}"
whens = f"{whens}{returning}"

sep = self.sep()

return self.prepend_ctes(
expression,
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}",
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{whens}",
)

@unsupported_args("format")
Expand Down
8 changes: 4 additions & 4 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ class Parser(metaclass=_Parser):
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Tuple: lambda self: self._parse_value(),
exp.When: lambda self: seq_get(self._parse_when_matched(), 0),
exp.Whens: lambda self: self._parse_when_matched(),
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
Expand Down Expand Up @@ -7010,11 +7010,11 @@ def _parse_merge(self) -> exp.Merge:
this=target,
using=using,
on=on,
expressions=self._parse_when_matched(),
whens=self._parse_when_matched(),
returning=self._parse_returning(),
)

def _parse_when_matched(self) -> t.List[exp.When]:
def _parse_when_matched(self) -> exp.Whens:
whens = []

while self._match(TokenType.WHEN):
Expand Down Expand Up @@ -7063,7 +7063,7 @@ def _parse_when_matched(self) -> t.List[exp.When]:
then=then,
)
)
return whens
return self.expression(exp.Whens, expressions=whens)

def _parse_show(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def test_parse_into(self):
self.assertIsInstance(
parse_one(
"WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)",
into=exp.When,
into=exp.Whens,
),
exp.When,
exp.Whens,
)

with self.assertRaises(ParseError) as ctx:
Expand Down

0 comments on commit 051c6f0

Please sign in to comment.