3737 SnapshotCreationFailedError ,
3838 SnapshotNameVersion ,
3939)
40+ from sqlmesh .core .snapshot .definition import SnapshotEvaluationTriggers
4041from sqlmesh .utils import to_snake_case
4142from sqlmesh .core .state_sync import StateSync
4243from sqlmesh .utils import CorrelationId
@@ -83,6 +84,7 @@ def __init__(
8384 self .default_catalog = default_catalog
8485 self .console = console or get_console ()
8586 self ._circuit_breaker : t .Optional [t .Callable [[], bool ]] = None
87+ self ._restatement_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {}
8688
8789 def evaluate (
8890 self ,
@@ -234,6 +236,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234236 self .console .log_success ("SKIP: No model batches to execute" )
235237 return
236238
239+ directly_modified_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {}
240+ for parent , children in plan .indirectly_modified_snapshots .items ():
241+ parent_id = stage .all_snapshots [parent ].snapshot_id
242+ directly_modified_triggers [parent_id ] = directly_modified_triggers .get (
243+ parent_id , []
244+ ) + [parent_id ]
245+ for child in children :
246+ directly_modified_triggers [child ] = directly_modified_triggers .get (child , []) + [
247+ parent_id
248+ ]
249+ directly_modified_triggers = {
250+ k : list (dict .fromkeys (v )) for k , v in directly_modified_triggers .items ()
251+ }
252+ snapshot_evaluation_triggers = {
253+ s_id : SnapshotEvaluationTriggers (
254+ directly_modified_triggers = directly_modified_triggers .get (s_id , []),
255+ restatement_triggers = self ._restatement_triggers .get (s_id , []),
256+ )
257+ for s_id in [s .snapshot_id for s in stage .all_snapshots .values ()]
258+ }
259+
237260 scheduler = self .create_scheduler (stage .all_snapshots .values (), self .snapshot_evaluator )
238261 # Convert model name restatements to snapshot ID restatements
239262 restatements_by_snapshot_id = {
@@ -249,6 +272,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
249272 start = plan .start ,
250273 end = plan .end ,
251274 restatements = restatements_by_snapshot_id ,
275+ snapshot_evaluation_triggers = snapshot_evaluation_triggers ,
252276 )
253277 if errors :
254278 raise PlanError ("Plan application failed." )
@@ -286,13 +310,14 @@ def visit_restatement_stage(
286310 # by forcing dev environments to re-run intervals that changed in prod
287311 #
288312 # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
289- snapshot_intervals_to_restate . update (
313+ restatement_intervals_all_environments , self . _restatement_triggers = (
290314 self ._restatement_intervals_across_all_environments (
291315 prod_restatements = plan .restatements ,
292316 disable_restatement_models = plan .disabled_restatement_models ,
293317 loaded_snapshots = {s .snapshot_id : s for s in stage .all_snapshots .values ()},
294318 )
295319 )
320+ snapshot_intervals_to_restate .update (restatement_intervals_all_environments )
296321
297322 self .state_sync .remove_intervals (
298323 snapshot_intervals = list (snapshot_intervals_to_restate ),
@@ -415,7 +440,9 @@ def _restatement_intervals_across_all_environments(
415440 prod_restatements : t .Dict [str , Interval ],
416441 disable_restatement_models : t .Set [str ],
417442 loaded_snapshots : t .Dict [SnapshotId , Snapshot ],
418- ) -> t .Set [t .Tuple [SnapshotTableInfo , Interval ]]:
443+ ) -> t .Tuple [
444+ t .Set [t .Tuple [SnapshotTableInfo , Interval ]], t .Dict [SnapshotId , t .List [SnapshotId ]]
445+ ]:
419446 """
420447 Given a map of snapshot names + intervals to restate in prod:
421448 - Look up matching snapshots across all environments (match based on name - regardless of version)
@@ -426,14 +453,14 @@ def _restatement_intervals_across_all_environments(
426453 run in those environments causes the intervals to be repopulated
427454 """
428455 if not prod_restatements :
429- return set ()
456+ return set (), {}
430457
431458 prod_name_versions : t .Set [SnapshotNameVersion ] = {
432459 s .name_version for s in loaded_snapshots .values ()
433460 }
434461
435462 snapshots_to_restate : t .Dict [SnapshotId , t .Tuple [SnapshotTableInfo , Interval ]] = {}
436-
463+ restatement_downstream_ids : t . Dict [ SnapshotId , t . List [ SnapshotId ]] = {}
437464 for env_summary in self .state_sync .get_environments_summary ():
438465 # Fetch the full environment object one at a time to avoid loading all environments into memory at once
439466 env = self .state_sync .get_environment (env_summary .name )
@@ -450,10 +477,17 @@ def _restatement_intervals_across_all_environments(
450477 for restatement , intervals in prod_restatements .items ():
451478 if restatement not in keyed_snapshots :
452479 continue
480+
481+ downstream = env_dag .downstream (restatement )
482+ if not env .is_dev and restatement not in disable_restatement_models :
483+ restatement_downstream_ids [keyed_snapshots [restatement ].snapshot_id ] = [
484+ keyed_snapshots [name ].snapshot_id
485+ for name in downstream
486+ if name not in disable_restatement_models
487+ ]
488+
453489 affected_snapshot_names = [
454- x
455- for x in ([restatement ] + env_dag .downstream (restatement ))
456- if x not in disable_restatement_models
490+ x for x in ([restatement ] + downstream ) if x not in disable_restatement_models
457491 ]
458492 snapshots_to_restate .update (
459493 {
@@ -464,6 +498,14 @@ def _restatement_intervals_across_all_environments(
464498 }
465499 )
466500
501+ restatement_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {
502+ id : [id ] for id in restatement_downstream_ids
503+ }
504+ for parent , children in restatement_downstream_ids .items ():
505+ for child in children :
506+ restatement_triggers [child ] = restatement_triggers .get (child , []) + [parent ]
507+ restatement_triggers = {k : list (dict .fromkeys (v )) for k , v in restatement_triggers .items ()}
508+
467509 # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
468510 # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
469511 # so we only do it if necessary
@@ -499,7 +541,7 @@ def _restatement_intervals_across_all_environments(
499541 )
500542 snapshots_to_restate [full_snapshot_id ] = (full_snapshot .table_info , new_intervals )
501543
502- return set (snapshots_to_restate .values ())
544+ return set (snapshots_to_restate .values ()), restatement_triggers
503545
504546 def _update_intervals_for_new_snapshots (self , snapshots : t .Collection [Snapshot ]) -> None :
505547 snapshots_intervals : t .List [SnapshotIntervals ] = []
0 commit comments