diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 459bb12eeaafd..5e985c9344287 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1202,18 +1202,6 @@ def recalculate(self) -> _UnfinishedStates: execute=execute_callbacks, ) - if dag.deadline: - # The dagrun has succeeded. If there were any Deadlines for it which were not breached, they are no longer needed. - deadline_alerts = [ - DeadlineAlertModel.get_by_id(alert_id, session) for alert_id in dag.deadline - ] - - if any( - deadline_alert.reference_class in SerializedReferenceModels.TYPES.DAGRUN - for deadline_alert in deadline_alerts - ): - Deadline.prune_deadlines(session=session, conditions={DagRun.id: self.id}) - # if *all tasks* are deadlocked, the run failed elif unfinished.should_schedule and not are_runnable_tasks: self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self) @@ -1267,6 +1255,20 @@ def recalculate(self) -> _UnfinishedStates: self.data_interval_start, self.data_interval_end, ) + + if dag.deadline: + # The dagrun has reached a terminal state. Prune any pending deadlines + # so they don't fire after the run is already finished. + deadline_alerts = [ + DeadlineAlertModel.get_by_id(alert_id, session) for alert_id in dag.deadline + ] + + if any( + deadline_alert.reference_class in SerializedReferenceModels.TYPES.DAGRUN + for deadline_alert in deadline_alerts + ): + Deadline.prune_deadlines(session=session, conditions={DagRun.id: self.id}) + session.flush() self._emit_dagrun_span(state=self.state) diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index dd34c2d10e7ea..ce3b5e1e8c85c 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -1360,6 +1360,36 @@ def test_dagrun_success_handles_empty_deadline_list(self, mock_prune, dag_maker, mock_prune.assert_not_called() assert dag_run.state == DagRunState.SUCCESS + @mock.patch.object(Deadline, "prune_deadlines") + @mock.patch.object(DeadlineAlertModel, "get_by_id") + def test_dagrun_failure_prunes_dagrun_deadlines( + self, mock_get_by_id, mock_prune, session, deadline_test_dag + ): + """Deadlines should be pruned when a DAG run fails, not just on success.""" + mock_deadline_alert = mock.MagicMock() + mock_deadline_alert.reference_class = SerializedReferenceModels.FixedDatetimeDeadline + mock_get_by_id.return_value = mock_deadline_alert + + scheduler_dag = deadline_test_dag() + + deadline_ids = ["deadline-uuid-1", "deadline-uuid-2"] + scheduler_dag.deadline = deadline_ids + + dag_run = self.create_dag_run( + dag=scheduler_dag, + task_states={"task_1": TaskInstanceState.SUCCESS, "task_2": TaskInstanceState.FAILED}, + session=session, + ) + dag_run.dag = scheduler_dag + + dag_run.update_state(session=session) + + assert mock_get_by_id.call_count == len(deadline_ids) + for deadline_id in deadline_ids: + mock_get_by_id.assert_any_call(deadline_id, session) + mock_prune.assert_called_once_with(session=session, conditions={DagRun.id: dag_run.id}) + assert dag_run.state == DagRunState.FAILED + @pytest.mark.parametrize( ("run_type", "expected_tis"),