Skip to content

Commit 2b0de13

Browse files
authored
Feat: Support model selection when constructing a plan (#1335)
1 parent 9be3daa commit 2b0de13

File tree

5 files changed

+299
-54
lines changed

5 files changed

+299
-54
lines changed

sqlmesh/cli/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,20 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None:
275275
help="Include unmodified models in the target environment.",
276276
default=None,
277277
)
278+
@click.option(
279+
"--select-model",
280+
type=str,
281+
multiple=True,
282+
help="Select specific model changes that should be included in the plan.",
283+
)
278284
@click.pass_context
279285
@error_handler
280286
def plan(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any) -> None:
281287
"""Plan a migration of the current context's models with the given environment."""
282288
context = ctx.obj
283289
restate_models = kwargs.pop("restate_model", None)
284-
context.plan(environment, restate_models=restate_models, **kwargs)
290+
select_models = kwargs.pop("select_model", None)
291+
context.plan(environment, restate_models=restate_models, select_models=select_models, **kwargs)
285292

286293

287294
@cli.command("run")

sqlmesh/core/context.py

Lines changed: 71 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from sqlmesh.core.plan import Plan
7272
from sqlmesh.core.scheduler import Scheduler
7373
from sqlmesh.core.schema_loader import create_schema_file
74+
from sqlmesh.core.selector import Selector
7475
from sqlmesh.core.snapshot import (
7576
Snapshot,
7677
SnapshotEvaluator,
@@ -545,58 +546,7 @@ def snapshots(self) -> t.Dict[str, Snapshot]:
545546
If one of the snapshots has been previosly stored in the persisted state, the stored
546547
instance will be returned.
547548
"""
548-
prod = self.state_reader.get_environment(c.PROD)
549-
remote_snapshots = (
550-
{
551-
snapshot.name: snapshot
552-
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values()
553-
}
554-
if prod
555-
else {}
556-
)
557-
558-
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
559-
models = self._models.copy()
560-
audits = self._audits.copy()
561-
projects = {config.project for config in self.configs.values()}
562-
563-
for name, snapshot in remote_snapshots.items():
564-
if name not in models and snapshot.project not in projects:
565-
models[name] = snapshot.model
566-
567-
for audit in snapshot.audits:
568-
if name not in audits:
569-
audits[name] = audit
570-
571-
snapshots = {}
572-
573-
for model in models.values():
574-
if model.name not in self._models and model.name in remote_snapshots:
575-
snapshot = remote_snapshots[model.name]
576-
ttl = snapshot.ttl
577-
project = snapshot.project
578-
else:
579-
config = self.config_for_model(model)
580-
ttl = config.snapshot_ttl
581-
project = config.project
582-
583-
snapshot = Snapshot.from_model(
584-
model,
585-
nodes=models,
586-
audits=audits,
587-
cache=fingerprint_cache,
588-
ttl=ttl,
589-
project=project,
590-
)
591-
snapshots[model.name] = snapshot
592-
593-
stored_snapshots = self.state_reader.get_snapshots(snapshots.values())
594-
595-
for snapshot in stored_snapshots.values():
596-
# Keep the original model instance to preserve the query cache.
597-
snapshot.node = snapshots[snapshot.name].node
598-
599-
return {name: stored_snapshots.get(s.snapshot_id, s) for name, s in snapshots.items()}
549+
return self._snapshots()
600550

601551
def render(
602552
self,
@@ -728,6 +678,7 @@ def plan(
728678
no_auto_categorization: t.Optional[bool] = None,
729679
effective_from: t.Optional[TimeLike] = None,
730680
include_unmodified: t.Optional[bool] = None,
681+
select_models: t.Optional[t.Collection[str]] = None,
731682
) -> Plan:
732683
"""Interactively create a migration plan.
733684
@@ -762,6 +713,7 @@ def plan(
762713
option determines the behavior.
763714
effective_from: The effective date from which to apply forward-only changes on production.
764715
include_unmodified: Indicates whether to include unmodified models in the target development environment.
716+
model_selections: A list of model selection strings to filter the models that should be included into this plan.
765717
766718
Returns:
767719
The populated Plan object.
@@ -783,8 +735,18 @@ def plan(
783735
if include_unmodified is None:
784736
include_unmodified = self.config.include_unmodified
785737

738+
models_override: t.Optional[UniqueKeyDict[str, Model]] = None
739+
if select_models:
740+
models_override = Selector(self.state_reader, self._models, self.path).select_models(
741+
select_models, environment, fallback_env_name=create_from or c.PROD
742+
)
743+
786744
plan = Plan(
787-
context_diff=self._context_diff(environment or c.PROD, create_from=create_from),
745+
context_diff=self._context_diff(
746+
environment or c.PROD,
747+
snapshots=self._snapshots(models_override),
748+
create_from=create_from,
749+
),
788750
start=start,
789751
end=end,
790752
execution_time=execution_time,
@@ -1178,6 +1140,62 @@ def _model_tables(self) -> t.Dict[str, str]:
11781140
for name, snapshot in self.snapshots.items()
11791141
}
11801142

1143+
def _snapshots(
1144+
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
1145+
) -> t.Dict[str, Snapshot]:
1146+
prod = self.state_reader.get_environment(c.PROD)
1147+
remote_snapshots = (
1148+
{
1149+
snapshot.name: snapshot
1150+
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values()
1151+
}
1152+
if prod
1153+
else {}
1154+
)
1155+
1156+
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
1157+
models = (models_override or self._models).copy()
1158+
audits = self._audits.copy()
1159+
projects = {config.project for config in self.configs.values()}
1160+
1161+
for name, snapshot in remote_snapshots.items():
1162+
if name not in models and snapshot.project not in projects:
1163+
models[name] = snapshot.model
1164+
1165+
for audit in snapshot.audits:
1166+
if name not in audits:
1167+
audits[name] = audit
1168+
1169+
snapshots = {}
1170+
1171+
for model in models.values():
1172+
if model.name not in self._models and model.name in remote_snapshots:
1173+
snapshot = remote_snapshots[model.name]
1174+
ttl = snapshot.ttl
1175+
project = snapshot.project
1176+
else:
1177+
config = self.config_for_model(model)
1178+
ttl = config.snapshot_ttl
1179+
project = config.project
1180+
1181+
snapshot = Snapshot.from_model(
1182+
model,
1183+
nodes=models,
1184+
audits=audits,
1185+
cache=fingerprint_cache,
1186+
ttl=ttl,
1187+
project=project,
1188+
)
1189+
snapshots[model.name] = snapshot
1190+
1191+
stored_snapshots = self.state_reader.get_snapshots(snapshots.values())
1192+
1193+
for snapshot in stored_snapshots.values():
1194+
# Keep the original model instance to preserve the query cache.
1195+
snapshot.node = snapshots[snapshot.name].node
1196+
1197+
return {name: stored_snapshots.get(s.snapshot_id, s) for name, s in snapshots.items()}
1198+
11811199
def _context_diff(
11821200
self,
11831201
environment: str,

sqlmesh/core/selector.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
from pathlib import Path
5+
6+
from sqlmesh.core.environment import Environment
7+
from sqlmesh.core.loader import update_model_schemas
8+
from sqlmesh.core.model import Model
9+
from sqlmesh.core.state_sync import StateReader
10+
from sqlmesh.utils import UniqueKeyDict
11+
from sqlmesh.utils.dag import DAG
12+
from sqlmesh.utils.errors import SQLMeshError
13+
14+
15+
class Selector:
16+
def __init__(
17+
self,
18+
state_reader: StateReader,
19+
models: UniqueKeyDict[str, Model],
20+
context_path: Path = Path("."),
21+
):
22+
self._state_reader = state_reader
23+
self._models = models
24+
self._context_path = context_path
25+
26+
def select_models(
27+
self,
28+
model_selections: t.Iterable[str],
29+
target_env_name: str,
30+
fallback_env_name: t.Optional[str] = None,
31+
) -> UniqueKeyDict[str, Model]:
32+
"""Given a set of selections returns models from the current state with names matching the
33+
selection while sourcing the remaining models from the target environment.
34+
35+
Args:
36+
model_selections: A set of selections.
37+
target_env_name: The name of the target environment.
38+
fallback_env_name: The name of the fallback environment that will be used if the target
39+
environment doesn't exist.
40+
41+
Returns:
42+
A dictionary of models.
43+
"""
44+
target_env = self._state_reader.get_environment(Environment.normalize_name(target_env_name))
45+
if not target_env and fallback_env_name:
46+
target_env = self._state_reader.get_environment(
47+
Environment.normalize_name(fallback_env_name)
48+
)
49+
if not target_env:
50+
raise SQLMeshError(
51+
f"Either the '{target_env_name}' or the '{fallback_env_name}' environment must exist in order to apply model selection."
52+
)
53+
54+
env_models = {
55+
s.name: s.model
56+
for s in self._state_reader.get_snapshots(
57+
target_env.snapshots, hydrate_seeds=True
58+
).values()
59+
}
60+
61+
# TODO: Support selection expressions.
62+
all_selected_models = set(model_selections)
63+
64+
dag: DAG[str] = DAG()
65+
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
66+
all_model_names = set(self._models) | set(env_models)
67+
for name in all_model_names:
68+
model: t.Optional[Model] = None
69+
if name not in all_selected_models and name in env_models:
70+
# Unselected modified or added model.
71+
model = env_models[name]
72+
elif name in all_selected_models and name in self._models:
73+
# Selected modified or removed model.
74+
model = self._models[name]
75+
76+
if model:
77+
# model.copy() can't be used here due to a cached state that can be a part of a model instance.
78+
model = type(model).parse_obj(model.dict(exclude={"mapping_schema"}))
79+
models[name] = model
80+
dag.add(model.name, model.depends_on)
81+
82+
update_model_schemas(dag, models, self._context_path)
83+
84+
return models

sqlmesh/magics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def test(self, line: str, test_def_raw: t.Optional[str] = None) -> None:
306306
help="Include unmodified models in the target environment.",
307307
default=None,
308308
)
309+
@argument(
310+
"--select-model",
311+
type=str,
312+
nargs="*",
313+
help="Select specific model changes that should be included in the plan.",
314+
)
309315
@line_magic
310316
def plan(self, line: str) -> None:
311317
"""Goes through a set of prompts to both establish a plan and apply it"""
@@ -332,6 +338,7 @@ def plan(self, line: str) -> None:
332338
no_auto_categorization=args.no_auto_categorization,
333339
effective_from=args.effective_from,
334340
include_unmodified=args.include_unmodified,
341+
select_models=args.select_model,
335342
)
336343
self._context.console = console
337344

0 commit comments

Comments
 (0)