Skip to content

Commit a1c72e7

Browse files
authored
Feat: allow UNION macros as model queries (#2361)
1 parent b8f9d1e commit a1c72e7

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,8 +1477,12 @@ def load_sql_based_model(
14771477
**meta_fields,
14781478
)
14791479

1480-
if query_or_seed_insert is not None and isinstance(
1481-
query_or_seed_insert, (exp.Query, d.JinjaQuery)
1480+
if query_or_seed_insert is not None and (
1481+
isinstance(query_or_seed_insert, (exp.Query, d.JinjaQuery))
1482+
or (
1483+
isinstance(query_or_seed_insert, d.MacroFunc)
1484+
and query_or_seed_insert.this.name.lower() == "union"
1485+
)
14821486
):
14831487
jinja_macro_references.update(
14841488
extract_macro_references(query_or_seed_insert.sql(dialect=dialect))
@@ -1792,6 +1796,7 @@ def _split_sql_model_statements(
17921796
if (
17931797
isinstance(expression, (exp.Query, d.JinjaQuery))
17941798
or expression == INSERT_SEED_MACRO_CALL
1799+
or (isinstance(expression, d.MacroFunc) and expression.this.name.lower() == "union")
17951800
):
17961801
query_positions.append((expression, idx))
17971802

tests/core/test_model.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def test_model_validation(query, error):
190190
assert error in str(ex.value)
191191

192192

193-
def test_model_union_query():
193+
def test_model_union_query(sushi_context, assert_exp_eq):
194194
expressions = d.parse(
195195
"""
196196
MODEL (
@@ -204,6 +204,37 @@ def test_model_union_query():
204204

205205
load_sql_based_model(expressions)
206206

207+
expressions = d.parse(
208+
"""
209+
MODEL (
210+
name sushi.test,
211+
kind FULL,
212+
);
213+
214+
@union('all', sushi.marketing, sushi.marketing)
215+
"""
216+
)
217+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
218+
assert_exp_eq(
219+
sushi_context.get_model("sushi.test").render_query(),
220+
"""SELECT
221+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
222+
CAST("marketing"."status" AS TEXT) AS "status",
223+
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
224+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
225+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
226+
FROM "memory"."sushi"."marketing" AS "marketing"
227+
UNION ALL
228+
SELECT
229+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
230+
CAST("marketing"."status" AS TEXT) AS "status",
231+
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
232+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
233+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
234+
FROM "memory"."sushi"."marketing" AS "marketing"
235+
""",
236+
)
237+
207238

208239
def test_model_validation_union_query():
209240
expressions = d.parse(

0 commit comments

Comments
 (0)