diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 5d2b6955d75ba..31804333ba91e 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -491,9 +491,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                         )
                     )
             elif trigger_rule == TR.ALL_SUCCESS:
-                num_failures = upstream - success
-                if ti.map_index > -1:
-                    num_failures -= removed
+                num_failures = upstream - success - removed  # removed is not a failure
                 if num_failures > 0:
                     yield self._failing_status(
                         reason=(
@@ -504,9 +502,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                         )
                     )
             elif trigger_rule == TR.ALL_FAILED:
-                num_success = upstream - failed - upstream_failed
-                if ti.map_index > -1:
-                    num_success -= removed
+                num_success = upstream - failed - upstream_failed - removed  # removed is not a success
                 if num_success > 0:
                     yield self._failing_status(
                         reason=(
@@ -527,9 +523,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                         )
                     )
             elif trigger_rule == TR.NONE_FAILED or trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-                num_failures = upstream - success - skipped
-                if ti.map_index > -1:
-                    num_failures -= removed
+                num_failures = upstream - success - skipped - removed  # removed is not a failure
                 if num_failures > 0:
                     yield self._failing_status(
                         reason=(
@@ -581,11 +575,8 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     )
             elif trigger_rule == TR.ALL_DONE_MIN_ONE_SUCCESS:
                 # For this trigger rule, skipped tasks are not considered "done"
-                non_skipped_done = success + failed + upstream_failed + removed
-                non_skipped_upstream = upstream - skipped
-                if ti.map_index > -1:
-                    non_skipped_upstream -= removed
-                    non_skipped_done -= removed
+                non_skipped_done = success + failed + upstream_failed  # removed is left out of "done"
+                non_skipped_upstream = upstream - skipped - removed  # ...and out of the expected total
 
                 if skipped > 0:
                     yield self._failing_status(
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 
14722f83b0cce..87a0b40741e61 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -2452,6 +2452,127 @@ def mapped_print_value(arg):
     assert len(success_tis) == rerun_length
 
 
+def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):
+    @task
+    def producer():  # 3 items while try_number == 0, 2 items on any later try
+        context = get_current_context()
+        if context["ti"].try_number == 0:
+            return [i for i in range(3)]
+        return [i for i in range(2)]
+
+    @task
+    def work(arg):
+        return arg
+
+    @task
+    def finish(data):
+        return sum(data)
+
+    def _task_ids(tis):  # (task_id, map_index) pairs, for compact assertions
+        return [(ti.task_id, ti.map_index) for ti in tis]
+
+    with dag_maker(session=session):
+        produced = producer()
+        mapped = work.expand(arg=produced)
+        done = finish(produced)
+        mapped >> done  # finish must also wait for every mapped "work" instance
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    # First run with 3 mapped task instances.
+    dag_maker.run_ti("producer", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]
+
+    for ti in decision.schedulable_tis:
+        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
+    dag_maker.run_ti("finish", dr)
+
+    # Clear and rerun with one fewer mapped task instance.
+    clear_task_instances(dr.get_task_instances(session=session), session=session)
+    ti = dr.get_task_instance(task_id="producer", session=session)
+    ti.try_number += 1  # producer now takes the 2-item branch
+    session.merge(ti)
+
+    dag_maker.run_ti("producer", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]
+
+    mapped_states = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
+        .order_by(TI.map_index)
+    ).all()
+    assert mapped_states == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, TaskInstanceState.REMOVED),  # leftover from the 3-item run
+    ]
+
+    for ti in decision.schedulable_tis:
+        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("finish", -1)]  # not blocked by the REMOVED TI
+
+    dag_maker.run_ti("finish", dr)
+    finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
+    assert finish_ti
+    assert finish_ti.state == TaskInstanceState.SUCCESS
+
+
+def test_rerun_with_upstream_task_removed(session, dag_maker):
+    def _task_ids(tis):  # (task_id, map_index) pairs, for compact assertions
+        return [(ti.task_id, ti.map_index) for ti in tis]
+
+    with dag_maker("test", session=session):
+        upstream_1 = EmptyOperator(task_id="upstream_1")
+        upstream_2 = EmptyOperator(task_id="upstream_2")
+        downstream = EmptyOperator(task_id="downstream")
+        [upstream_1, upstream_2] >> downstream
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    dag_maker.run_ti("upstream_1", dr)
+    dag_maker.run_ti("upstream_2", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]
+
+    dag_maker.run_ti("downstream", dr)
+    dr.update_state(session=session)
+    assert dr.state == DagRunState.SUCCESS
+
+    # Rerun with upstream_1 removed: redefine the DAG without it (asserted below to be version 2).
+    with dag_maker("test", session=session, serialized=True) as dag:
+        upstream_2 = EmptyOperator(task_id="upstream_2")
+        downstream = EmptyOperator(task_id="downstream")
+        upstream_2 >> downstream
+
+    latest_version = DagVersion.get_latest_version(dag.dag_id)
+    assert latest_version.version_number == 2
+
+    clear_task_instances(
+        dr.get_task_instances(session=session),
+        session=session,
+        run_on_latest_version=True,
+    )
+
+    upstream_1 = dr.get_task_instance(task_id="upstream_1", map_index=-1, session=session)
+    assert upstream_1.state == TaskInstanceState.REMOVED
+
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("upstream_2", -1)]
+
+    dag_maker.run_ti("upstream_2", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]  # REMOVED upstream_1 ignored
+
+    dag_maker.run_ti("downstream", dr)
+    dr.update_state(session=session)
+    assert dr.state == DagRunState.SUCCESS
+
+
 def test_operator_mapped_task_group_receives_value(dag_maker, session):
     with dag_maker(session=session):
 
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
index 31b240e4429f7..3dc2978481b78 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
@@ -1322,6 +1322,96 @@ def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(
 
         _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
 
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    @pytest.mark.parametrize(
+        ("trigger_rule", "upstream_states"),
+        [
+            (
+                TriggerRule.ALL_SUCCESS,  # 3 success + 2 removed across the 5 upstreams
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.ALL_FAILED,  # 3 failed + 2 removed across the 5 upstreams
+                _UpstreamTIStates(
+                    success=0,
+                    skipped=0,
+                    failed=3,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.NONE_FAILED,  # 3 success + 2 removed across the 5 upstreams
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,  # 3 success + 2 removed across the 5 upstreams
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.ALL_DONE_MIN_ONE_SUCCESS,  # 3 success + 2 removed across the 5 upstreams
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+        ],
+    )
+    def test_non_mapped_task_ignores_removed_upstream_tis(
+        self,
+        monkeypatch,
+        session,
+        get_task_instance,
+        flag_upstream_failed,
+        trigger_rule,
+        upstream_states,
+    ):
+        """
+        Non-mapped trigger-rule checks should exclude removed upstream task instances.
+        """
+        ti = get_task_instance(
+            trigger_rule,
+            normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
+        )
+        monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)  # stub the counts
+        _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
+
 
 def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
     from airflow.sdk import task, task_group