Skip to content

Commit c3b716b

Browse files
authored
fix!: allow expressions in scd type 2 model unique keys (#2164)
1 parent ec5d223 commit c3b716b

8 files changed

Lines changed: 50 additions & 44 deletions

File tree

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1217,7 +1217,7 @@ def scd_type_2_by_column(
12171217
self,
12181218
target_table: TableName,
12191219
source_table: QueryOrDF,
1220-
unique_key: t.Sequence[exp.Column],
1220+
unique_key: t.Sequence[exp.Expression],
12211221
valid_from_name: str,
12221222
valid_to_name: str,
12231223
execution_time: TimeLike,
@@ -1248,7 +1248,7 @@ def _scd_type_2(
12481248
self,
12491249
target_table: TableName,
12501250
source_table: QueryOrDF,
1251-
unique_key: t.Union[t.Sequence[exp.Expression], t.Sequence[exp.Column]],
1251+
unique_key: t.Sequence[exp.Expression],
12521252
valid_from_name: str,
12531253
valid_to_name: str,
12541254
execution_time: TimeLike,

sqlmesh/core/model/kind.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
from sqlmesh.utils.pydantic import (
1616
PydanticModel,
1717
SQLGlotBool,
18-
SQLGlotListOfColumns,
1918
SQLGlotListOfColumnsOrStar,
2019
SQLGlotListOfFields,
2120
SQLGlotPositiveInt,
@@ -336,7 +335,7 @@ class FullKind(_ModelKind):
336335

337336

338337
class _SCDType2Kind(_ModelKind):
339-
unique_key: SQLGlotListOfColumns
338+
unique_key: SQLGlotListOfFields
340339
valid_from_name: SQLGlotString = "valid_from"
341340
valid_to_name: SQLGlotString = "valid_to"
342341
invalidate_hard_deletes: SQLGlotBool = False

sqlmesh/core/model/meta.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -287,14 +287,9 @@ def time_column(self) -> t.Optional[TimeColumn]:
287287

288288
@property
289289
def unique_key(self) -> t.List[exp.Expression]:
290-
if isinstance(self.kind, IncrementalByUniqueKeyKind):
291-
return self.kind.unique_key
292-
return []
293-
294-
@property
295-
def unique_key_columns(self) -> t.List[exp.Column]:
296-
if self.kind.is_scd_type_2:
297-
assert isinstance(self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind))
290+
if isinstance(
291+
self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind, IncrementalByUniqueKeyKind)
292+
):
298293
return self.kind.unique_key
299294
return []
300295

sqlmesh/core/snapshot/evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1280,7 +1280,7 @@ def insert(
12801280
self.adapter.scd_type_2_by_time(
12811281
target_table=name,
12821282
source_table=query_or_df,
1283-
unique_key=model.unique_key_columns,
1283+
unique_key=model.unique_key,
12841284
valid_from_name=model.kind.valid_from_name,
12851285
valid_to_name=model.kind.valid_to_name,
12861286
updated_at_name=model.kind.updated_at_name,
@@ -1295,7 +1295,7 @@ def insert(
12951295
self.adapter.scd_type_2_by_column(
12961296
target_table=name,
12971297
source_table=query_or_df,
1298-
unique_key=model.unique_key_columns,
1298+
unique_key=model.unique_key,
12991299
valid_from_name=model.kind.valid_from_name,
13001300
valid_to_name=model.kind.valid_to_name,
13011301
check_columns=model.kind.columns,
@@ -1325,7 +1325,7 @@ def append(
13251325
self.adapter.scd_type_2_by_time(
13261326
target_table=table_name,
13271327
source_table=query_or_df,
1328-
unique_key=model.unique_key_columns,
1328+
unique_key=model.unique_key,
13291329
valid_from_name=model.kind.valid_from_name,
13301330
valid_to_name=model.kind.valid_to_name,
13311331
updated_at_name=model.kind.updated_at_name,
@@ -1340,7 +1340,7 @@ def append(
13401340
self.adapter.scd_type_2_by_column(
13411341
target_table=table_name,
13421342
source_table=query_or_df,
1343-
unique_key=model.unique_key_columns,
1343+
unique_key=model.unique_key,
13441344
valid_from_name=model.kind.valid_from_name,
13451345
valid_to_name=model.kind.valid_to_name,
13461346
check_columns=model.kind.columns,

sqlmesh/utils/pydantic.py

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
from sqlmesh.core import dialect as d
1515
from sqlmesh.utils import str_to_bool
16-
from sqlmesh.utils.errors import SQLMeshError
1716

1817
if sys.version_info >= (3, 9):
1918
from typing import Annotated
@@ -324,14 +323,6 @@ def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expression]:
324323
return _get_fields(v, values)
325324

326325

327-
def list_of_columns_validator(v: t.Any, values: t.Any) -> t.List[exp.Column]:
328-
expressions = _get_fields(v, values)
329-
for expression in expressions:
330-
if not isinstance(expression, exp.Column):
331-
raise SQLMeshError(f"Invalid column {expression}. Value must be a column")
332-
return t.cast(t.List[exp.Column], expressions)
333-
334-
335326
def list_of_columns_or_star_validator(
336327
v: t.Any, values: t.Any
337328
) -> t.Union[exp.Star, t.List[exp.Column]]:
@@ -347,7 +338,6 @@ def list_of_columns_or_star_validator(
347338
SQLGlotBool = bool
348339
SQLGlotPositiveInt = int
349340
SQLGlotListOfFields = t.List[exp.Expression]
350-
SQLGlotListOfColumns = t.List[exp.Column]
351341
SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star]
352342
elif PYDANTIC_MAJOR_VERSION >= 2:
353343
from pydantic.functional_validators import BeforeValidator # type: ignore
@@ -359,7 +349,6 @@ def list_of_columns_or_star_validator(
359349
SQLGlotListOfFields = Annotated[
360350
t.List[exp.Expression], BeforeValidator(list_of_fields_validator)
361351
]
362-
SQLGlotListOfColumns = Annotated[t.List[exp.Column], BeforeValidator(list_of_columns_validator)]
363352
SQLGlotListOfColumnsOrStar = Annotated[
364353
t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator)
365354
]
@@ -387,8 +376,5 @@ class SQLGlotPositiveInt(PydanticTypeProxy[int]):
387376
class SQLGlotListOfFields(PydanticTypeProxy[t.List[exp.Expression]]):
388377
validate = list_of_fields_validator
389378

390-
class SQLGlotListOfColumns(PydanticTypeProxy[t.List[exp.Column]]):
391-
validate = list_of_columns_validator
392-
393379
class SQLGlotListOfColumnsOrStar(PydanticTypeProxy[t.Union[exp.Star, t.List[exp.Column]]]):
394380
validate = list_of_columns_or_star_validator

tests/core/engine_adapter/test_base.py

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -978,7 +978,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
978978
source_table=t.cast(
979979
exp.Select, parse_one("SELECT id, name, price, test_updated_at FROM source")
980980
),
981-
unique_key=[exp.func("COALESCE", "id", "''")],
981+
unique_key=[
982+
parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""),
983+
parse_one("""COALESCE("name", '')"""),
984+
],
982985
valid_from_name="test_valid_from",
983986
valid_to_name="test_valid_to",
984987
updated_at_name="test_updated_at",
@@ -999,7 +1002,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
9991002
"""
10001003
CREATE OR REPLACE TABLE "target" AS
10011004
WITH "source" AS (
1002-
SELECT DISTINCT ON (COALESCE("id", ''))
1005+
SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", ''))
10031006
TRUE AS "_exists",
10041007
"id",
10051008
"name",
@@ -1045,17 +1048,24 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
10451048
"static"."test_valid_to"
10461049
FROM "static"
10471050
LEFT JOIN "latest"
1048-
ON COALESCE("static"."id", '') = COALESCE("latest"."id", '')
1051+
ON (
1052+
COALESCE("static"."id", '') || '|' || COALESCE("static"."name", '')
1053+
) = (
1054+
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
1055+
)
1056+
AND COALESCE("static"."name", '') = COALESCE("latest"."name", '')
10491057
WHERE
10501058
"latest"."test_valid_to" IS NULL
10511059
), "latest_deleted" AS (
10521060
SELECT
10531061
TRUE AS "_exists",
1054-
COALESCE("id", '') AS "_key0",
1062+
COALESCE("id", '') || '|' || COALESCE("name", '') AS "_key0",
1063+
COALESCE("name", '') AS "_key1",
10551064
MAX("test_valid_to") AS "test_valid_to"
10561065
FROM "deleted"
10571066
GROUP BY
1058-
COALESCE("id", '')
1067+
COALESCE("id", '') || '|' || COALESCE("name", ''),
1068+
COALESCE("name", '')
10591069
), "joined" AS (
10601070
SELECT
10611071
"source"."_exists",
@@ -1071,7 +1081,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
10711081
"source"."test_updated_at" AS "test_updated_at"
10721082
FROM "latest"
10731083
LEFT JOIN "source"
1074-
ON COALESCE("latest"."id", '') = COALESCE("source"."id", '')
1084+
ON (
1085+
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
1086+
) = (
1087+
COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '')
1088+
)
1089+
AND COALESCE("latest"."name", '') = COALESCE("source"."name", '')
10751090
UNION
10761091
SELECT
10771092
"source"."_exists",
@@ -1087,7 +1102,12 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
10871102
"source"."test_updated_at" AS "test_updated_at"
10881103
FROM "latest"
10891104
RIGHT JOIN "source"
1090-
ON COALESCE("latest"."id", '') = COALESCE("source"."id", '')
1105+
ON (
1106+
COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '')
1107+
) = (
1108+
COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '')
1109+
)
1110+
AND COALESCE("latest"."name", '') = COALESCE("source"."name", '')
10911111
), "updated_rows" AS (
10921112
SELECT
10931113
COALESCE("joined"."t_id", "joined"."id") AS "id",
@@ -1114,7 +1134,10 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
11141134
END AS "test_valid_to"
11151135
FROM "joined"
11161136
LEFT JOIN "latest_deleted"
1117-
ON COALESCE("joined"."id", '') = "latest_deleted"."_key0"
1137+
ON (
1138+
COALESCE("joined"."id", '') || '|' || COALESCE("joined"."name", '')
1139+
) = "latest_deleted"."_key0"
1140+
AND COALESCE("joined"."name", '') = "latest_deleted"."_key1"
11181141
), "inserted_rows" AS (
11191142
SELECT
11201143
"id",

tests/core/engine_adapter/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1179,7 +1179,7 @@ def test_scd_type_2_by_time(ctx: TestContext):
11791179
ctx.engine_adapter.scd_type_2_by_time(
11801180
table,
11811181
ctx.input_data(input_data, input_schema),
1182-
unique_key=[exp.to_identifier("id")],
1182+
unique_key=[parse_one("COALESCE(id, -1)")],
11831183
valid_from_name="valid_from",
11841184
valid_to_name="valid_to",
11851185
updated_at_name="updated_at",

tests/core/test_model.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2460,7 +2460,7 @@ def test_scd_type_2_by_time_defaults():
24602460
MODEL (
24612461
name db.table,
24622462
kind SCD_TYPE_2 (
2463-
unique_key "ID",
2463+
unique_key (COALESCE("ID", '') || '|' || COALESCE("ds", ''), COALESCE("ds", '')),
24642464
),
24652465
);
24662466
SELECT
@@ -2473,7 +2473,10 @@ def test_scd_type_2_by_time_defaults():
24732473
"""
24742474
)
24752475
scd_type_2_model = load_sql_based_model(model_def)
2476-
assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)]
2476+
assert scd_type_2_model.unique_key == [
2477+
parse_one("""COALESCE("ID", '') || '|' || COALESCE("ds", '')"""),
2478+
parse_one("""COALESCE("ds", '')"""),
2479+
]
24772480
assert scd_type_2_model.columns_to_types == {
24782481
"ID": exp.DataType.build("int"),
24792482
"ds": exp.DataType.build("varchar"),
@@ -2525,7 +2528,7 @@ def test_scd_type_2_by_time_overrides():
25252528
"""
25262529
)
25272530
scd_type_2_model = load_sql_based_model(model_def)
2528-
assert scd_type_2_model.unique_key_columns == [
2531+
assert scd_type_2_model.unique_key == [
25292532
exp.column("iD", quoted=True),
25302533
exp.column("ds", quoted=False),
25312534
]
@@ -2566,7 +2569,7 @@ def test_scd_type_2_by_column_defaults():
25662569
"""
25672570
)
25682571
scd_type_2_model = load_sql_based_model(model_def)
2569-
assert scd_type_2_model.unique_key_columns == [exp.to_column("ID", quoted=True)]
2572+
assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)]
25702573
assert scd_type_2_model.kind.columns == [exp.to_column("value_to_track", quoted=True)]
25712574
assert scd_type_2_model.columns_to_types == {
25722575
"ID": exp.DataType.build("int"),
@@ -2614,7 +2617,7 @@ def test_scd_type_2_by_column_overrides():
26142617
"""
26152618
)
26162619
scd_type_2_model = load_sql_based_model(model_def)
2617-
assert scd_type_2_model.unique_key_columns == [
2620+
assert scd_type_2_model.unique_key == [
26182621
exp.column("iD", quoted=True),
26192622
exp.column("ds", quoted=False),
26202623
]

0 commit comments

Comments
 (0)