Skip to content

Commit

Permalink
Feat!: add returning to merge expression builder (#4125)
Browse files Browse the repository at this point in the history
* Add returning to merge

* Fix

* Fix

* Fix test

* Fmt

* Quote self
  • Loading branch information
max-muoto authored Sep 16, 2024
1 parent 22c456d commit ba015dc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
20 changes: 12 additions & 8 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""

from __future__ import annotations

import datetime
import math
import numbers
Expand All @@ -36,6 +35,7 @@
from sqlglot.tokens import Token, TokenError

if t.TYPE_CHECKING:
from typing_extensions import Self
from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import DialectType

Expand Down Expand Up @@ -1368,7 +1368,7 @@ def returning(
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> DML:
) -> "Self":
"""
Set the RETURNING expression. Not supported by all dialects.
Expand Down Expand Up @@ -6276,7 +6276,7 @@ class Use(Expression):
arg_types = {"this": True, "kind": False}


class Merge(Expression):
class Merge(DML):
arg_types = {
"this": True,
"using": True,
Expand Down Expand Up @@ -6840,9 +6840,7 @@ def delete(
if where:
delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
if returning:
delete_expr = t.cast(
Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
)
delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
return delete_expr


Expand Down Expand Up @@ -6885,7 +6883,7 @@ def insert(
insert = Insert(this=this, expression=expr, overwrite=overwrite)

if returning:
insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts))
insert = insert.returning(returning, dialect=dialect, copy=False, **opts)

return insert

Expand All @@ -6895,6 +6893,7 @@ def merge(
into: ExpOrStr,
using: ExpOrStr,
on: ExpOrStr,
returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
Expand All @@ -6915,14 +6914,15 @@ def merge(
into: The target table to merge data into.
using: The source table to merge data from.
on: The join condition for the merge.
returning: The columns to return from the merge.
dialect: The dialect used to parse the input expressions.
copy: Whether to copy the expression.
**opts: Other options to use to parse the input expressions.
Returns:
Merge: The syntax tree for the MERGE statement.
"""
return Merge(
merge = Merge(
this=maybe_parse(into, dialect=dialect, copy=copy, **opts),
using=maybe_parse(using, dialect=dialect, copy=copy, **opts),
on=maybe_parse(on, dialect=dialect, copy=copy, **opts),
Expand All @@ -6931,6 +6931,10 @@ def merge(
for when_expr in when_exprs
],
)
if returning:
merge = merge.returning(returning, dialect=dialect, copy=False, **opts)

return merge


def condition(
Expand Down
8 changes: 5 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3625,13 +3625,15 @@ def merge_sql(self, expression: exp.Merge) -> str:
using = f"USING {self.sql(expression, 'using')}"
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ", indent=False)
returning = self.sql(expression, "returning")
if returning:
expressions = f"{expressions}{returning}"

sep = self.sep()
returning = self.expressions(expression, key="returning", indent=False)
returning = f"RETURNING {returning}" if returning else ""

return self.prepend_ctes(
expression,
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}{sep}{returning}",
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}",
)

@unsupported_args("format")
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6846,7 +6846,7 @@ def _parse_merge(self) -> exp.Merge:
using=using,
on=on,
expressions=self._parse_when_matched(),
returning=self._match(TokenType.RETURNING) and self._parse_csv(self._parse_bitwise),
returning=self._parse_returning(),
)

def _parse_when_matched(self) -> t.List[exp.When]:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,16 @@ def test_build(self):
),
"MERGE INTO target_table AS target USING source_table AS source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.name = source.name",
),
(
lambda: exp.merge(
"WHEN MATCHED THEN UPDATE SET target.name = source.name",
into=exp.table_("target_table").as_("target"),
using=exp.table_("source_table").as_("source"),
on="target.id = source.id",
returning="target.*",
),
"MERGE INTO target_table AS target USING source_table AS source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.name = source.name RETURNING target.*",
),
]:
with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)

0 comments on commit ba015dc

Please sign in to comment.