Skip to content

Commit

Permalink
Fix!: exp.Merge condition for Trino/Postgres (tobymao#4596)
Browse files Browse the repository at this point in the history
* Fix!: exp.Merge condition for Trino/Postgres

* address PR review comment

* Fixups

---------

Co-authored-by: Jo <[email protected]>
  • Loading branch information
MikeWallis42 and georgesittas authored Jan 13, 2025
1 parent 199508a commit b7ab3f1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
30 changes: 18 additions & 12 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,19 +1570,25 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
targets.add(normalize(alias.this))

for when in expression.args["whens"].expressions:
# only remove the target names from the THEN clause
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
# ref: https://github.com/TobikoData/sqlmesh/issues/2934
then = when.args.get("then")
# only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED
# they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part
# (not the column list) of the INSERT
then: exp.Insert | exp.Update | None = when.args.get("then")
if then:
then.transform(
lambda node: (
exp.column(node.this)
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
else node
),
copy=False,
)
if isinstance(then, exp.Update):
for equals in then.find_all(exp.EQ):
equal_lhs = equals.this
if (
isinstance(equal_lhs, exp.Column)
and normalize(equal_lhs.args.get("table")) in targets
):
equal_lhs.replace(exp.column(equal_lhs.this))
if isinstance(then, exp.Insert):
column_list = then.this
if isinstance(column_list, exp.Tuple):
for column in column_list.expressions:
if normalize(column.args.get("table")) in targets:
column.replace(exp.column(column.this))

return self.merge_sql(expression)

Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,17 @@ def test_merge(self):
},
)

# needs to preserve the target alias in then WHEN condition and function but not in the THEN clause
self.validate_all(
"""MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a
WHEN MATCHED THEN UPDATE SET target.b = COALESCE(src.b, target.b)
WHEN NOT MATCHED THEN INSERT (target.a, target.b) VALUES (src.a, src.b)""",
write={
"trino": """MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a WHEN MATCHED THEN UPDATE SET b = COALESCE(src.b, target.b) WHEN NOT MATCHED THEN INSERT (a, b) VALUES (src.a, src.b)""",
"postgres": """MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a WHEN MATCHED THEN UPDATE SET b = COALESCE(src.b, target.b) WHEN NOT MATCHED THEN INSERT (a, b) VALUES (src.a, src.b)""",
},
)

def test_substring(self):
self.validate_all(
"SUBSTR('123456', 2, 3)",
Expand Down

0 comments on commit b7ab3f1

Please sign in to comment.