Skip to content

Commit 236287f

Browse files
authored
Feat: allow parameterizing model kinds using macros (#3401)
1 parent 7a0b34e commit 236287f

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

sqlmesh/core/dialect.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -552,11 +552,15 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
552552
elif key == "columns":
553553
value = self._parse_schema()
554554
elif key == "kind":
555-
id_var = self._parse_id_var(any_token=True)
556-
if not id_var:
557-
value = None
555+
if self._match(TokenType.PARAMETER):
556+
field = _parse_macro(self)
558557
else:
559-
kind = ModelKindName[id_var.name.upper()]
558+
field = self._parse_id_var(any_token=True)
559+
560+
if not field or isinstance(field, (MacroVar, MacroFunc)):
561+
value = field
562+
else:
563+
kind = ModelKindName[field.name.upper()]
560564

561565
if kind in (
562566
ModelKindName.INCREMENTAL_BY_TIME_RANGE,
@@ -573,11 +577,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
573577
else:
574578
props = None
575579

576-
value = self.expression(
577-
ModelKind,
578-
this=kind.value,
579-
expressions=props,
580-
)
580+
value = self.expression(ModelKind, this=kind.value, expressions=props)
581581
elif key == "expression":
582582
value = self._parse_conjunction()
583583
else:

tests/core/test_model.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SqlModel,
3939
TimeColumn,
4040
ExternalKind,
41+
ViewKind,
4142
create_external_model,
4243
create_seed_model,
4344
create_sql_model,
@@ -6091,3 +6092,48 @@ def test_cluster_with_complex_expression():
60916092

60926093
model = load_sql_based_model(expressions)
60936094
assert [expr.sql("snowflake") for expr in model.clustered_by] == ['(TO_DATE("CLUSTER_COL"))']
6095+
6096+
6097+
def test_parametric_model_kind(tmp_path: Path):
6098+
init_example_project(tmp_path, dialect="duckdb")
6099+
6100+
test_sql_file = tmp_path / "models/test_model.sql"
6101+
test_sql_file.write_text(
6102+
"""
6103+
MODEL (
6104+
name test_schema.test_model,
6105+
kind @IF(@gateway = 'main', VIEW, FULL)
6106+
);
6107+
6108+
SELECT
6109+
1 AS c
6110+
"""
6111+
)
6112+
6113+
db_path = str(tmp_path / "db.db")
6114+
config = Config(
6115+
gateways={
6116+
"main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)),
6117+
"other": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)),
6118+
},
6119+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
6120+
)
6121+
6122+
context = Context(paths=tmp_path, config=config)
6123+
plan = context.plan(no_prompts=True, auto_apply=True, no_diff=True)
6124+
6125+
assert len(plan.context_diff.new_snapshots) == 4
6126+
assert isinstance(context.get_model("test_schema.test_model").kind, ViewKind)
6127+
6128+
context = Context(paths=tmp_path, config=config, gateway="other")
6129+
plan = context.plan(no_prompts=True, auto_apply=True, no_diff=True)
6130+
diff = plan.context_diff
6131+
6132+
assert len(diff.new_snapshots) == 1
6133+
assert len(diff.modified_snapshots) == 1
6134+
6135+
new_snapshot, old_snapshot = next(iter(diff.modified_snapshots.values()))
6136+
assert isinstance(t.cast(SqlModel, new_snapshot.node).kind, FullKind)
6137+
assert isinstance(t.cast(SqlModel, old_snapshot.node).kind, ViewKind)
6138+
6139+
assert isinstance(context.get_model("test_schema.test_model").kind, FullKind)

0 commit comments

Comments
 (0)