Skip to content

Commit bd38d91

Browse files
committed
chore: add unit tests
1 parent ccf3967 commit bd38d91

2 files changed

Lines changed: 160 additions & 0 deletions

File tree

airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,6 +2423,76 @@ def mapped_print_value(arg):
24232423
assert len(success_tis) == rerun_length
24242424

24252425

2426+
def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):
2427+
@task
2428+
def producer():
2429+
context = get_current_context()
2430+
if context["ti"].try_number == 0:
2431+
return [i for i in range(3)]
2432+
return [i for i in range(2)]
2433+
2434+
@task
2435+
def work(arg):
2436+
return arg
2437+
2438+
@task
2439+
def finish(data):
2440+
return sum(data)
2441+
2442+
def _task_ids(tis):
2443+
return [(ti.task_id, ti.map_index) for ti in tis]
2444+
2445+
with dag_maker(session=session):
2446+
produced = producer()
2447+
mapped = work.expand(arg=produced)
2448+
done = finish(produced)
2449+
mapped >> done
2450+
2451+
dr: DagRun = dag_maker.create_dagrun()
2452+
2453+
# First run with 3 mapped task instances.
2454+
dag_maker.run_ti("producer", dr)
2455+
decision = dr.task_instance_scheduling_decisions(session=session)
2456+
assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]
2457+
2458+
for ti in decision.schedulable_tis:
2459+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2460+
decision = dr.task_instance_scheduling_decisions(session=session)
2461+
assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
2462+
dag_maker.run_ti("finish", dr)
2463+
2464+
# Clear and rerun with one fewer mapped task instance.
2465+
clear_task_instances(dr.get_task_instances(session=session), session=session)
2466+
ti = dr.get_task_instance(task_id="producer", session=session)
2467+
ti.try_number += 1
2468+
session.merge(ti)
2469+
2470+
dag_maker.run_ti("producer", dr)
2471+
decision = dr.task_instance_scheduling_decisions(session=session)
2472+
assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]
2473+
2474+
mapped_states = session.execute(
2475+
select(TI.map_index, TI.state)
2476+
.where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
2477+
.order_by(TI.map_index)
2478+
).all()
2479+
assert mapped_states == [
2480+
(0, State.NONE),
2481+
(1, State.NONE),
2482+
(2, TaskInstanceState.REMOVED),
2483+
]
2484+
2485+
for ti in decision.schedulable_tis:
2486+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2487+
decision = dr.task_instance_scheduling_decisions(session=session)
2488+
assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
2489+
2490+
dag_maker.run_ti("finish", dr)
2491+
finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
2492+
assert finish_ti
2493+
assert finish_ti.state == TaskInstanceState.SUCCESS
2494+
2495+
24262496
def test_operator_mapped_task_group_receives_value(dag_maker, session):
24272497
with dag_maker(session=session):
24282498

airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,96 @@ def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(
13221322

13231323
_test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
13241324

1325+
@pytest.mark.parametrize("flag_upstream_failed", [True, False])
1326+
@pytest.mark.parametrize(
1327+
("trigger_rule", "upstream_states"),
1328+
[
1329+
(
1330+
TriggerRule.ALL_SUCCESS,
1331+
_UpstreamTIStates(
1332+
success=3,
1333+
skipped=0,
1334+
failed=0,
1335+
upstream_failed=0,
1336+
removed=2,
1337+
done=5,
1338+
skipped_setup=0,
1339+
success_setup=0,
1340+
),
1341+
),
1342+
(
1343+
TriggerRule.ALL_FAILED,
1344+
_UpstreamTIStates(
1345+
success=0,
1346+
skipped=0,
1347+
failed=3,
1348+
upstream_failed=0,
1349+
removed=2,
1350+
done=5,
1351+
skipped_setup=0,
1352+
success_setup=0,
1353+
),
1354+
),
1355+
(
1356+
TriggerRule.NONE_FAILED,
1357+
_UpstreamTIStates(
1358+
success=3,
1359+
skipped=0,
1360+
failed=0,
1361+
upstream_failed=0,
1362+
removed=2,
1363+
done=5,
1364+
skipped_setup=0,
1365+
success_setup=0,
1366+
),
1367+
),
1368+
(
1369+
TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
1370+
_UpstreamTIStates(
1371+
success=3,
1372+
skipped=0,
1373+
failed=0,
1374+
upstream_failed=0,
1375+
removed=2,
1376+
done=5,
1377+
skipped_setup=0,
1378+
success_setup=0,
1379+
),
1380+
),
1381+
(
1382+
TriggerRule.ALL_DONE_MIN_ONE_SUCCESS,
1383+
_UpstreamTIStates(
1384+
success=3,
1385+
skipped=0,
1386+
failed=0,
1387+
upstream_failed=0,
1388+
removed=2,
1389+
done=5,
1390+
skipped_setup=0,
1391+
success_setup=0,
1392+
),
1393+
),
1394+
],
1395+
)
1396+
def test_non_mapped_task_ignores_removed_upstream_tis(
1397+
self,
1398+
monkeypatch,
1399+
session,
1400+
get_task_instance,
1401+
flag_upstream_failed,
1402+
trigger_rule,
1403+
upstream_states,
1404+
):
1405+
"""
1406+
Non-mapped trigger-rule checks should exclude removed upstream task instances.
1407+
"""
1408+
ti = get_task_instance(
1409+
trigger_rule,
1410+
normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
1411+
)
1412+
monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
1413+
_test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
1414+
13251415

13261416
def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
13271417
from airflow.sdk import task, task_group

0 commit comments

Comments
 (0)