Skip to content

Commit bfb9ef6

Browse files
committed
add when_matched support to custom time-range materialization
Parse and validate optional materialization_properties.when_matched and pass it through to adapter.merge so update clauses can preserve immutable columns like _inserted_at.
1 parent afe6b3a commit bfb9ef6

2 files changed

Lines changed: 99 additions & 1 deletion

File tree

sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
from sqlmesh import CustomMaterialization
44
from sqlmesh.core.model import Model
55
from sqlmesh.core.model.kind import TimeColumn
6+
import sqlmesh.core.dialect as d
67
from sqlglot import exp
78
from sqlmesh.utils.date import make_inclusive
89
from sqlmesh.utils.errors import ConfigError, SQLMeshError
910
from pydantic import model_validator
10-
from sqlmesh.utils.pydantic import list_of_fields_validator, bool_validator
11+
from sqlmesh.utils.pydantic import (
12+
bool_validator,
13+
list_of_fields_validator,
14+
validate_expression,
15+
)
1116
from sqlmesh.utils.date import TimeLike
1217
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
1318
from sqlmesh import CustomKind
@@ -22,6 +27,26 @@ class NonIdempotentIncrementalByTimeRangeKind(CustomKind):
2227
_primary_key: t.List[exp.Expression]
2328

2429
_partition_by_time_column: bool
30+
_when_matched: t.Optional[exp.Whens]
31+
32+
def _parse_when_matched(self, value: t.Any) -> t.Optional[exp.Whens]:
    """Normalize the optional ``when_matched`` materialization property.

    Accepts ``None``, a raw SQL string of ``WHEN MATCHED ...`` clauses
    (optionally wrapped in one outer pair of parentheses), a list of string
    fragments that are joined before parsing, or an already-parsed
    expression. Returns a validated ``exp.Whens`` with the user's merge
    aliases rewritten to the adapter's canonical target/source aliases,
    or ``None`` when no property was provided.
    """
    if value is None:
        return None

    if isinstance(value, list):
        # NOTE(review): assumes list items are plain strings — confirm
        # against how materialization_properties delivers multi-part values.
        value = " ".join(value)

    if isinstance(value, str):
        value = value.strip()
        # sqlglot does not parse the surrounding parentheses of a Whens
        # block, so strip one matching outer pair before parsing. The
        # previous code dropped the final character whenever the string
        # merely *started* with "(", corrupting unbalanced input; require
        # a closing ")" before stripping.
        if value.startswith("(") and value.endswith(")"):
            value = value[1:-1]
        value = t.cast(exp.Whens, d.parse_one(value, into=exp.Whens, dialect=self.dialect))

    value = validate_expression(value, dialect=self.dialect)
    # Rewrite user-written target/source aliases to the engine adapter's
    # canonical merge aliases (see MERGE_TARGET_ALIAS / MERGE_SOURCE_ALIAS).
    return t.cast(
        exp.Whens,
        value.transform(d.replace_merge_table_aliases, dialect=self.dialect),
    )
2550

2651
@model_validator(mode="after")
2752
def _validate_model(self):
@@ -49,6 +74,10 @@ def _validate_model(self):
4974
self.materialization_properties.get("partition_by_time_column", True)
5075
)
5176

77+
self._when_matched = self._parse_when_matched(
78+
self.materialization_properties.get("when_matched")
79+
)
80+
5281
return self
5382

5483
@property
@@ -63,6 +92,10 @@ def primary_key(self) -> t.List[exp.Expression]:
6392
def partition_by_time_column(self) -> bool:
6493
return self._partition_by_time_column
6594

95+
@property
def when_matched(self) -> t.Optional[exp.Whens]:
    """Parsed WHEN MATCHED clauses to forward to ``adapter.merge``, or ``None``."""
    return self._when_matched
98+
6699

67100
class NonIdempotentIncrementalByTimeRangeMaterialization(
68101
CustomMaterialization[NonIdempotentIncrementalByTimeRangeKind]
@@ -130,6 +163,7 @@ def _inject_alias(node: exp.Expression, alias: str) -> exp.Expression:
130163
source_table=query_or_df,
131164
target_columns_to_types=columns_to_types,
132165
unique_key=model.kind.primary_key,
166+
when_matched=model.kind.when_matched,
133167
merge_filter=exp.and_(*betweens),
134168
source_columns=source_columns,
135169
)

tests/materializations/test_non_idempotent_incremental_by_time_range.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ def test_kind(make_model: ModelMaker):
5959
exp.to_column("id", quoted=True),
6060
exp.to_column("ds", quoted=True),
6161
]
62+
assert model.kind.when_matched is None
63+
64+
model = make_model(
65+
[
66+
"time_column = ds",
67+
"primary_key = (id, ds)",
68+
"when_matched = 'when matched then update set target.name = source.name'",
69+
]
70+
)
71+
assert model.kind.when_matched is not None
6272

6373
# required fields
6474
with pytest.raises(ConfigError, match=r"Invalid time_column"):
@@ -165,6 +175,60 @@ def test_append(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
165175
]
166176

167177

178+
def test_insert_with_when_matched(
    make_model: ModelMaker, make_mocked_engine_adapter: MockedEngineAdapterMaker
):
    """A model-level `when_matched` property must surface as a custom
    WHEN MATCHED clause in the MERGE emitted by a non-first insert,
    with target/source aliases rewritten to the canonical merge aliases."""
    model: Model = make_model(
        [
            "time_column = ds",
            "primary_key = name",
            "when_matched = 'when matched then update set target.name = source.name'",
        ],
        dialect="trino",
    )
    adapter = make_mocked_engine_adapter(TrinoEngineAdapter)
    strategy = NonIdempotentIncrementalByTimeRangeMaterialization(adapter)

    start = to_timestamp("2020-01-01")
    end = to_timestamp("2020-01-03")

    # is_first_insert=False forces the merge (rather than plain insert) path.
    strategy.insert(
        "test.snapshot_table",
        query_or_df=model.render_query(
            start=start, end=end, execution_time=now(), runtime_stage=RuntimeStage.EVALUATING
        ),
        model=model,
        is_first_insert=False,
        start=start,
        end=end,
        render_kwargs={},
    )

    # The user's `target.`/`source.` aliases appear as __MERGE_TARGET__ /
    # __MERGE_SOURCE__, and the default WHEN NOT MATCHED insert is retained.
    assert to_sql_calls(adapter) == [
        parse_one(
            """
            MERGE INTO "test"."snapshot_table" AS "__merge_target__"
            USING (
            SELECT
            CAST("name" AS VARCHAR) AS "name",
            CAST("ds" AS TIMESTAMP) AS "ds"
            FROM "upstream"."table" AS "table"
            WHERE
            "ds" BETWEEN '2020-01-01 00:00:00' AND '2020-01-02 23:59:59.999999'
            ) AS "__MERGE_SOURCE__"
            ON (
            "__MERGE_SOURCE__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP)
            AND "__MERGE_TARGET__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP)
            )
            AND "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name"
            WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name"
            WHEN NOT MATCHED THEN INSERT ("name", "ds") VALUES ("__MERGE_SOURCE__"."name", "__MERGE_SOURCE__"."ds")
            """,
            dialect=adapter.dialect,
        ).sql(dialect=adapter.dialect),
    ]
230+
231+
168232
def test_partition_by_time_column_opt_out(make_model: ModelMaker):
169233
model = make_model(
170234
["time_column = ds", "primary_key = name", "partition_by_time_column = false"]

0 commit comments

Comments
 (0)