Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
)
elif trigger_rule == TR.ALL_SUCCESS:
num_failures = upstream - success
if ti.map_index > -1:
num_failures -= removed
num_failures = upstream - success - removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand All @@ -504,9 +502,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
)
elif trigger_rule == TR.ALL_FAILED:
num_success = upstream - failed - upstream_failed
if ti.map_index > -1:
num_success -= removed
num_success = upstream - failed - upstream_failed - removed
if num_success > 0:
yield self._failing_status(
reason=(
Expand All @@ -527,9 +523,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
)
elif trigger_rule == TR.NONE_FAILED or trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
num_failures = upstream - success - skipped
if ti.map_index > -1:
num_failures -= removed
num_failures = upstream - success - skipped - removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand Down Expand Up @@ -581,11 +575,8 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
)
elif trigger_rule == TR.ALL_DONE_MIN_ONE_SUCCESS:
# For this trigger rule, skipped tasks are not considered "done"
non_skipped_done = success + failed + upstream_failed + removed
non_skipped_upstream = upstream - skipped
if ti.map_index > -1:
non_skipped_upstream -= removed
non_skipped_done -= removed
non_skipped_done = success + failed + upstream_failed
non_skipped_upstream = upstream - skipped - removed

if skipped > 0:
yield self._failing_status(
Expand Down
121 changes: 121 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,127 @@ def mapped_print_value(arg):
assert len(success_tis) == rerun_length


def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):
    """A rerun whose mapped task expands to fewer TIs must not deadlock downstream tasks."""

    @task
    def producer():
        ctx = get_current_context()
        # Emit 3 items on the first try and 2 on any rerun, shrinking the mapped task.
        length = 3 if ctx["ti"].try_number == 0 else 2
        return list(range(length))

    @task
    def work(arg):
        return arg

    @task
    def finish(data):
        return sum(data)

    def _ids(tis):
        return [(ti.task_id, ti.map_index) for ti in tis]

    with dag_maker(session=session):
        produced = producer()
        mapped = work.expand(arg=produced)
        done = finish(produced)
        mapped >> done

    dr: DagRun = dag_maker.create_dagrun()

    # First run expands to 3 mapped task instances.
    dag_maker.run_ti("producer", dr)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]

    for mapped_ti in decision.schedulable_tis:
        dag_maker.run_ti(mapped_ti.task_id, dr, map_index=mapped_ti.map_index)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("finish", -1)]
    dag_maker.run_ti("finish", dr)

    # Clear everything and rerun; the producer now emits one fewer item.
    clear_task_instances(dr.get_task_instances(session=session), session=session)
    producer_ti = dr.get_task_instance(task_id="producer", session=session)
    producer_ti.try_number += 1
    session.merge(producer_ti)

    dag_maker.run_ti("producer", dr)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]

    # The third mapped TI from the first run should now be marked REMOVED.
    mapped_states = session.execute(
        select(TI.map_index, TI.state)
        .where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
        .order_by(TI.map_index)
    ).all()
    assert mapped_states == [
        (0, State.NONE),
        (1, State.NONE),
        (2, TaskInstanceState.REMOVED),
    ]

    for mapped_ti in decision.schedulable_tis:
        dag_maker.run_ti(mapped_ti.task_id, dr, map_index=mapped_ti.map_index)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("finish", -1)]

    # Downstream must still be schedulable and able to succeed despite the REMOVED TI.
    dag_maker.run_ti("finish", dr)
    finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
    assert finish_ti
    assert finish_ti.state == TaskInstanceState.SUCCESS


def test_rerun_with_upstream_task_removed(session, dag_maker):
    """Clearing onto a newer DAG version that dropped an upstream task must still succeed."""

    def _ids(tis):
        return [(ti.task_id, ti.map_index) for ti in tis]

    with dag_maker("test", session=session):
        first = EmptyOperator(task_id="upstream_1")
        second = EmptyOperator(task_id="upstream_2")
        tail = EmptyOperator(task_id="downstream")
        [first, second] >> tail

    dr: DagRun = dag_maker.create_dagrun()

    # Run the original two-upstream DAG to completion.
    dag_maker.run_ti("upstream_1", dr)
    dag_maker.run_ti("upstream_2", dr)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("downstream", -1)]

    dag_maker.run_ti("downstream", dr)
    dr.update_state(session=session)
    assert dr.state == DagRunState.SUCCESS

    # Redefine the DAG without upstream_1, producing a second serialized version.
    with dag_maker("test", session=session, serialized=True) as dag:
        second = EmptyOperator(task_id="upstream_2")
        tail = EmptyOperator(task_id="downstream")
        second >> tail

    latest_version = DagVersion.get_latest_version(dag.dag_id)
    assert latest_version.version_number == 2

    # Clearing on the latest version should mark the dropped task's TI as REMOVED.
    clear_task_instances(
        dr.get_task_instances(session=session),
        session=session,
        run_on_latest_version=True,
    )

    removed_ti = dr.get_task_instance(task_id="upstream_1", map_index=-1, session=session)
    assert removed_ti.state == TaskInstanceState.REMOVED

    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("upstream_2", -1)]

    dag_maker.run_ti("upstream_2", dr)
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert _ids(decision.schedulable_tis) == [("downstream", -1)]

    # The REMOVED upstream must not block the downstream task or the run.
    dag_maker.run_ti("downstream", dr)
    dr.update_state(session=session)
    assert dr.state == DagRunState.SUCCESS


def test_operator_mapped_task_group_receives_value(dag_maker, session):
with dag_maker(session=session):

Expand Down
90 changes: 90 additions & 0 deletions airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,96 @@ def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(

_test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)

@pytest.mark.parametrize("flag_upstream_failed", [True, False])
@pytest.mark.parametrize(
    ("trigger_rule", "upstream_states"),
    [
        # Every case shares the same "2 removed out of 5 done" upstream shape;
        # only the success/failed split varies per trigger rule.
        (
            rule,
            _UpstreamTIStates(
                skipped=0,
                upstream_failed=0,
                removed=2,
                done=5,
                skipped_setup=0,
                success_setup=0,
                **counts,
            ),
        )
        for rule, counts in [
            (TriggerRule.ALL_SUCCESS, {"success": 3, "failed": 0}),
            (TriggerRule.ALL_FAILED, {"success": 0, "failed": 3}),
            (TriggerRule.NONE_FAILED, {"success": 3, "failed": 0}),
            (TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, {"success": 3, "failed": 0}),
            (TriggerRule.ALL_DONE_MIN_ONE_SUCCESS, {"success": 3, "failed": 0}),
        ]
    ],
)
def test_non_mapped_task_ignores_removed_upstream_tis(
    self,
    monkeypatch,
    session,
    get_task_instance,
    flag_upstream_failed,
    trigger_rule,
    upstream_states,
):
    """
    Non-mapped trigger-rule checks should exclude removed upstream task instances.
    """
    ti = get_task_instance(
        trigger_rule,
        normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
    )
    # Force the dep check to see the synthetic upstream-state counts above.
    monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
    _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)


def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
from airflow.sdk import task, task_group
Expand Down
Loading