Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67672.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The scheduler no longer issues the trigger-rule upstream task-instance count query once per downstream task. For tasks that share the same upstreams within a scheduling pass, the count is computed once and reused, cutting database round-trips for DAGs where a mapped upstream feeds many downstream tasks.
12 changes: 11 additions & 1 deletion airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,16 +1520,26 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
if new_tis is not None:
additional_tis.extend(new_tis)
expansion_happened = True
# Expansion changes a mapped task's instance count, which invalidates the
# trigger-rule upstream-count memo on this DepContext (a downstream evaluated
# later in this same pass must see the post-expansion count).
dep_context.upstream_task_id_counts.clear()
if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
# It's enough to revise map index once per task id,
# checking the map index for each mapped task significantly slows down scheduling
if schedulable.task.task_id not in revised_map_index_task_ids:
ready_tis.extend(
revised_tis = list(
self._revise_map_indexes_if_mapped(
schedulable.task, dag_version_id=schedulable.dag_version_id, session=session
)
)
ready_tis.extend(revised_tis)
revised_map_index_task_ids.add(schedulable.task.task_id)
if revised_tis:
Comment on lines -1527 to +1538

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if the size of revised_tis would be significant. We could do

prev_ti_size = len(ready_tis)
ready_tis.extend(self._revise_map_indexes_if_mapped(...))
if len(ready_tis) > prev_ti_size:
    ...

to avoid building a list.

# Revising a mapped task can add new instances, growing its instance count
# the same way expansion does. Drop the upstream-count memo so a downstream
# evaluated later in this pass recomputes it instead of reading a stale value.
dep_context.upstream_task_id_counts.clear()

# _revise_map_indexes_if_mapped might mark the current task as REMOVED
# after calculating mapped task length, so we need to re-check
Expand Down
13 changes: 13 additions & 0 deletions airflow-core/src/airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ class DepContext:
have_changed_ti_states: bool = False
"""Have any of the TIs state's been changed as a result of evaluating dependencies"""

upstream_task_id_counts: dict[tuple[str, str, frozenset[str]], list[tuple[str, int]]] = attr.ib(
factory=dict, init=False
)
"""
Per-pass memo of the trigger-rule upstream task-instance counts, keyed by
``(dag_id, run_id, frozenset of direct-upstream task_ids)``.

Shares the lifetime and snapshot semantics of ``finished_tis`` (one scheduling pass). Only
populated for the "simple" case where the count-query predicate is exactly
``task_id IN (upstream_ids)`` and is therefore identical for every downstream sharing the same
direct upstreams; the mapped-task-group case uses per-ti map-index predicates and is not cached.
"""

def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]:
"""
Ensure finished_tis is populated if it's currently None, which allows running tasks without dag_run.
Expand Down
39 changes: 29 additions & 10 deletions airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import collections.abc
import functools
from collections import Counter
from collections.abc import Iterator, KeysView, Mapping, Sequence
from collections.abc import Iterator, KeysView, Mapping
from typing import TYPE_CHECKING, NamedTuple

from sqlalchemy import and_, func, or_, select
Expand All @@ -31,7 +31,6 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql import ColumnElement

Expand All @@ -40,7 +39,6 @@
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
from airflow.typing_compat import Unpack


class _UpstreamTIStates(NamedTuple):
Expand Down Expand Up @@ -372,13 +370,34 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
upstream = len(upstream_tasks)
upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
else:
# The below type annotation is acceptable on SQLA2.1, but not on 2.0
task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute( # type: ignore[type-arg]
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
.group_by(TaskInstance.task_id)
).all()
# In the simple case, `_iter_upstream_conditions` emits exactly
# `task_id IN (upstream_task_ids)` (the matching `get_closest_mapped_task_group()
# is None` branch). That predicate, and therefore the resulting counts, are
# identical for every downstream that shares the same set of direct upstreams, so
# we memoize them on the DepContext and run the query once per pass instead of
# once per downstream. The mapped-task-group case uses per-ti map-index predicates
# and is left un-memoized. The cache shares finished_tis' per-pass snapshot
# semantics; it is cleared in DagRun._get_ready_tis when a mapped task expands and
# changes its instance count.
cache_key: tuple[str, str, frozenset[str]] | None = None
task_id_counts: list[tuple[str, int]] | None = None
if task.get_closest_mapped_task_group() is None:
cache_key = (ti.dag_id, ti.run_id, frozenset(upstream_tasks))
task_id_counts = dep_context.upstream_task_id_counts.get(cache_key)
if task_id_counts is None:
task_id_counts = [
(task_id, count)
for task_id, count in session.execute(
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
.group_by(TaskInstance.task_id)
)
]
if cache_key is not None:
dep_context.upstream_task_id_counts[cache_key] = task_id_counts
# `task_id_counts` only contains task_ids matched by `task_id IN (upstream_tasks)`,
# so every key is present in `upstream_tasks`; is_setup is re-derived locally.
upstream = sum(count for _, count in task_id_counts)
upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)

Expand Down
204 changes: 204 additions & 0 deletions airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import Mock

import pytest
from sqlalchemy import event

import airflow.settings
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -1997,3 +2000,204 @@ def _test_trigger_rule(
else:
assert not dep_statuses
assert ti.state == expected_ti_state


@contextmanager
def _count_upstream_count_queries():
    """
    Yield a tally of trigger-rule upstream-count queries executed while the context is active.

    An ``after_cursor_execute`` listener is attached to ``airflow.settings.engine`` for the
    duration of the ``with`` block and detached again on exit, even if the body raises.
    A statement is counted when its lower-cased SQL contains every one of ``count(``,
    ``group by``, ``task_id`` and ``task_instance``. This is meant to single out the
    ``SELECT task_instance.task_id, count(task_instance.task_id) ... GROUP BY
    task_instance.task_id`` query issued while evaluating the trigger rule for a plain
    (non-mapped-task-group) downstream.

    The yielded object is a dict; the running total lives under the ``"n"`` key.
    """
    tally = {"n": 0}
    # Every fragment must be present in the lower-cased statement for it to be counted.
    required_fragments = ("count(", "group by", "task_id", "task_instance")

    def _record(conn, cursor, statement, parameters, context, executemany):
        lowered = statement.lower()
        if all(fragment in lowered for fragment in required_fragments):
            tally["n"] += 1

    event.listen(airflow.settings.engine, "after_cursor_execute", _record)
    try:
        yield tally
    finally:
        # Always detach so the listener does not leak into other tests.
        event.remove(airflow.settings.engine, "after_cursor_execute", _record)


def _expand_mapped_task(dr, dag, task_id, states, session):
    """
    Give the mapped task ``task_id`` exactly ``len(states)`` instances in the given states.

    Two starting shapes are supported:

    * a single unexpanded placeholder (``map_index == -1``): it becomes map index 0 and
      new task instances are added for map indexes ``1 .. len(states) - 1``;
    * instances that already exist for the task: they are ordered by ``map_index`` and
      only their states are overwritten (the instance count must equal ``len(states)``).

    The session is flushed before returning; committing is left to the caller.
    """
    tis = [candidate for candidate in dr.get_task_instances(session=session) if candidate.task_id == task_id]
    assert tis, f"no task instances found for {task_id!r}"

    only_placeholder = len(tis) == 1 and tis[0].map_index == -1
    if not only_placeholder:
        # Already expanded: pair instances with states in map-index order.
        tis.sort(key=lambda candidate: candidate.map_index)
        assert len(tis) == len(states), f"{task_id!r}: {len(tis)} instances but {len(states)} states given"
        for existing, wanted_state in zip(tis, states):
            existing.state = wanted_state
            session.merge(existing)
    else:
        placeholder = tis[0]
        mapped_task = placeholder.task
        dag_version = DagVersion.get_latest_version(dag.dag_id)
        if TYPE_CHECKING:
            assert dag_version
        # Re-use the placeholder row as map index 0 ...
        placeholder.map_index = 0
        placeholder.state = states[0]
        session.merge(placeholder)
        # ... and create the remaining map indexes from scratch.
        for map_index, wanted_state in enumerate(states[1:], start=1):
            new_ti = TaskInstance(
                mapped_task, run_id=dr.run_id, map_index=map_index, dag_version_id=dag_version.id
            )
            new_ti.state = wanted_state
            session.add(new_ti)
            new_ti.dag_run = dr
    session.flush()


class TestTriggerRuleUpstreamCountMemo:
    """In the simple case the upstream-count query is memoized per scheduling pass (one DepContext)."""

    def _make_dag(
        self, dag_maker, session, *, n_downstreams, src_states, trigger_rule=TriggerRule.ALL_SUCCESS
    ):
        """
        Create the ``trmemo_simple`` dag run: a mapped ``src`` feeding plain tasks ``p0 .. p{n-1}``.

        ``src`` is mapped over ``len(src_states)`` values and its instances are put into
        ``src_states``; every downstream uses ``trigger_rule``. Returns the dag run.
        """

        @task
        def src(i):
            return i

        @task(trigger_rule=trigger_rule)
        def plain():
            return 1

        mapped_length = len(src_states)
        with dag_maker(dag_id="trmemo_simple", session=session) as dag:
            mapped_src = src.expand(i=list(range(mapped_length)))
            for k in range(n_downstreams):
                mapped_src >> plain.override(task_id=f"p{k}")()

        dag_run = dag_maker.create_dagrun()
        _expand_mapped_task(dag_run, dag, "src", src_states, session)
        session.commit()
        return dag_run

    def test_memoized_across_downstreams_sharing_upstream(self, dag_maker, session):
        """Several plain downstreams of one mapped upstream trigger the count query once per pass."""
        n_downstreams = 4
        dag_run = self._make_dag(
            dag_maker, session, n_downstreams=n_downstreams, src_states=[SUCCESS] * 3
        )
        shared_context = DepContext()
        with _count_upstream_count_queries() as query_tally:
            for k in range(n_downstreams):
                downstream_ti = dag_run.get_task_instance(f"p{k}", session=session)
                dep = TriggerRuleDep()
                failing = list(
                    dep._evaluate_trigger_rule(ti=downstream_ti, dep_context=shared_context, session=session)
                )
                # Every upstream instance succeeded, so ALL_SUCCESS holds and nothing is reported.
                assert failing == []
        assert query_tally["n"] == 1

    def test_memoized_count_value_is_correct(self, dag_maker, session):
        """
        Check that the memo stores the real count rather than merely being present.

        There are three upstream instances, of which only two finished successfully, so
        ALL_SUCCESS must NOT be met: ``upstream`` (3) > ``success`` (2). If the memo wrongly
        held 2, the rule would pass.
        """
        n_downstreams = 2
        dag_run = self._make_dag(
            dag_maker,
            session,
            n_downstreams=n_downstreams,
            src_states=[SUCCESS, SUCCESS, TaskInstanceState.RUNNING],
        )
        shared_context = DepContext()
        with _count_upstream_count_queries() as query_tally:
            for k in range(n_downstreams):
                downstream_ti = dag_run.get_task_instance(f"p{k}", session=session)
                dep = TriggerRuleDep()
                reported = list(
                    dep._evaluate_trigger_rule(ti=downstream_ti, dep_context=shared_context, session=session)
                )
                assert len(reported) == 1
                assert not reported[0].passed
        assert query_tally["n"] == 1

    def test_distinct_upstream_sets_are_not_collapsed(self, dag_maker, session):
        """Different upstream sets produce different cache keys, hence one query per downstream."""

        @task
        def src_a(i):
            return i

        @task
        def src_b(i):
            return i

        @task
        def plain():
            return 1

        with dag_maker(dag_id="trmemo_keys", session=session) as dag:
            mapped_a = src_a.expand(i=[0, 1])
            mapped_b = src_b.expand(i=[0, 1, 2])
            mapped_a >> plain.override(task_id="pa")()
            mapped_b >> plain.override(task_id="pb")()

        dag_run = dag_maker.create_dagrun()
        _expand_mapped_task(dag_run, dag, "src_a", [SUCCESS] * 2, session)
        _expand_mapped_task(dag_run, dag, "src_b", [SUCCESS] * 3, session)
        session.commit()

        shared_context = DepContext()
        with _count_upstream_count_queries() as query_tally:
            for downstream_id in ("pa", "pb"):
                downstream_ti = dag_run.get_task_instance(downstream_id, session=session)
                dep = TriggerRuleDep()
                failing = list(
                    dep._evaluate_trigger_rule(ti=downstream_ti, dep_context=shared_context, session=session)
                )
                assert failing == []
        assert query_tally["n"] == 2

    def test_revise_growing_a_mapped_upstream_clears_memo_within_pass(self, dag_maker, session):
        """
        A mapped upstream that grows through ``_revise_map_indexes_if_mapped`` mid-pass must
        invalidate the upstream-count memo, so that a downstream evaluated later in the same
        pass recomputes the count rather than reusing the pre-grow value.

        ``_get_ready_tis`` is driven with the fixed order ``[d1, mapped-instance, d2]``: d1 fills
        the memo from the pre-grow instances, the mapped instance is revised and grows, and d2
        then has to recompute, which makes the upstream-count query run twice. Without the
        cache clear in ``_get_ready_tis`` d2 would read the stale value and the query would run
        only once.
        """

        @task
        def src(arg):
            return arg

        @task
        def plain():
            return 1

        def _build(length):
            with dag_maker(dag_id="trmemo_revise", session=session, serialized=True):
                mapped_src = src.expand(arg=list(range(length)))
                for downstream_id in ("d1", "d2"):
                    mapped_src >> plain.override(task_id=downstream_id)()

        _build(4)
        dag_run = dag_maker.create_dagrun()
        # Serialize the DAG again with one more mapped element; revising then adds map_index 4.
        _build(5)
        dag_run.dag = dag_maker.serialized_dag
        session.commit()

        serialized = dag_maker.serialized_dag
        d1 = dag_run.get_task_instance("d1", session=session)
        d2 = dag_run.get_task_instance("d2", session=session)
        src0 = dag_run.get_task_instance("src", map_index=0, session=session)
        # Point each instance at its task from the re-serialized (longer) DAG.
        for instance, source_task_id in ((d1, "d1"), (d2, "d2"), (src0, "src")):
            instance.task = serialized.get_task(source_task_id)

        with _count_upstream_count_queries() as query_tally:
            dag_run._get_ready_tis([d1, src0, d2], [], session)
        assert query_tally["n"] == 2
Loading