Skip to content

Commit cab40ae

Browse files
authored
Feat: Support simple model selection expressions (#1338)
1 parent 2b0de13 commit cab40ae

File tree

3 files changed

+69
-5
lines changed

3 files changed

+69
-5
lines changed

sqlmesh/core/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -737,9 +737,9 @@ def plan(
737737

738738
models_override: t.Optional[UniqueKeyDict[str, Model]] = None
739739
if select_models:
740-
models_override = Selector(self.state_reader, self._models, self.path).select_models(
741-
select_models, environment, fallback_env_name=create_from or c.PROD
742-
)
740+
models_override = Selector(
741+
self.state_reader, self._models, self.path, dag=self.dag
742+
).select_models(select_models, environment, fallback_env_name=create_from or c.PROD)
743743

744744
plan = Plan(
745745
context_diff=self._context_diff(

sqlmesh/core/selector.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import fnmatch
34
import typing as t
45
from pathlib import Path
56

@@ -18,11 +19,19 @@ def __init__(
1819
state_reader: StateReader,
1920
models: UniqueKeyDict[str, Model],
2021
context_path: Path = Path("."),
22+
dag: t.Optional[DAG[str]] = None,
2123
):
2224
self._state_reader = state_reader
2325
self._models = models
2426
self._context_path = context_path
2527

28+
if dag is None:
29+
self._dag: DAG[str] = DAG()
30+
for model in models.values():
31+
self._dag.add(model.name, model.depends_on)
32+
else:
33+
self._dag = dag
34+
2635
def select_models(
2736
self,
2837
model_selections: t.Iterable[str],
@@ -58,8 +67,7 @@ def select_models(
5867
).values()
5968
}
6069

61-
# TODO: Support selection expressions.
62-
all_selected_models = set(model_selections)
70+
all_selected_models = self._expand_model_selections(model_selections)
6371

6472
dag: DAG[str] = DAG()
6573
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -82,3 +90,35 @@ def select_models(
8290
update_model_schemas(dag, models, self._context_path)
8391

8492
return models
93+
94+
def _expand_model_selections(self, model_selections: t.Iterable[str]) -> t.Set[str]:
95+
result: t.Set[str] = set()
96+
97+
def _add_model(model_name: str, include_upstream: bool, include_downstream: bool) -> None:
98+
result.add(model_name)
99+
if include_upstream:
100+
result.update(self._dag.upstream(model_name))
101+
if include_downstream:
102+
result.update(self._dag.downstream(model_name))
103+
104+
for selection in model_selections:
105+
if not selection:
106+
continue
107+
108+
include_upstream = False
109+
include_downstream = False
110+
if selection[0] == "+":
111+
selection = selection[1:]
112+
include_upstream = True
113+
if selection[-1] == "+":
114+
selection = selection[:-1]
115+
include_downstream = True
116+
117+
if "*" in selection:
118+
for name in self._models:
119+
if fnmatch.fnmatch(name, selection):
120+
_add_model(name, include_upstream, include_downstream)
121+
else:
122+
_add_model(selection, include_upstream, include_downstream)
123+
124+
return result

tests/core/test_selector.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,34 @@ def test_select_models(mocker: MockerFixture, make_snapshot):
9090
removed_model.name: removed_model.copy(update={"mapping_schema": added_model_schema}),
9191
},
9292
)
93+
_assert_models_equal(
94+
selector.select_models(["+modified_model"], env_name),
95+
{
96+
added_model.name: added_model,
97+
modified_model_v2.name: modified_model_v2.copy(
98+
update={"mapping_schema": added_model_schema}
99+
),
100+
removed_model.name: removed_model.copy(update={"mapping_schema": added_model_schema}),
101+
},
102+
)
103+
_assert_models_equal(
104+
selector.select_models(["added_model+"], env_name),
105+
{
106+
added_model.name: added_model,
107+
modified_model_v2.name: modified_model_v2.copy(
108+
update={"mapping_schema": added_model_schema}
109+
),
110+
removed_model.name: removed_model.copy(update={"mapping_schema": added_model_schema}),
111+
},
112+
)
93113
_assert_models_equal(
94114
selector.select_models(["added_model", "modified_model", "removed_model"], env_name),
95115
local_models,
96116
)
117+
_assert_models_equal(
118+
selector.select_models(["*_model", "removed_model"], env_name),
119+
local_models,
120+
)
97121

98122

99123
def test_select_models_missing_env(mocker: MockerFixture, make_snapshot):

0 commit comments

Comments
 (0)