Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 168 additions & 108 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from sqlalchemy.orm.interfaces import LoaderOption
from sqlalchemy.sql.selectable import Subquery
from sqlalchemy.sql.selectable import Select, Subquery

from airflow._shared.logging.types import Logger
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads.types import SchedulerWorkload
from airflow.models.pool import PoolStats
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.sqlalchemy import CommitProhibitorGuard

Expand Down Expand Up @@ -476,26 +477,22 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
self.log.info("\n\t".join(map(repr, callstack)))
self.log.info("-" * 80)

def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
def _acquire_pool_capacity(
self, max_tis: int, session: Session

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self, max_tis: int, session: Session
self, max_tis: int, *, session: Session

Or, convert it all to be kwarg only. Same with the others.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had the rule with session as kwarg mainly for the @provide_session decorator which is not used here. So a session needs to be provided by caller anyway. So it is rather a nit.

) -> tuple[dict[str, PoolStats], int, set[str]]:
"""
Find TIs that are ready for execution based on conditions.

Conditions include:
- pool limits
- DAG max_active_tasks
- executor state
- priority
- max active tis per DAG
- max active tis per DAG run

:param max_tis: Maximum number of TIs to queue in this loop.
:return: list[airflow.models.TaskInstance]
Acquire the scheduler critical-section lock and read current pool utilisation.

On PostgreSQL a transactional advisory lock is taken first so that only one
scheduler at a time enters the critical section; pool rows are then locked via
``SELECT … FOR UPDATE`` (or ``NOWAIT`` where supported).

Returns a ``(pools, effective_max_tis, starved_pools)`` tuple. ``effective_max_tis``
is zero when all pools are already full; callers should short-circuit in that case.
"""
from airflow.models.pool import Pool
from airflow.utils.db import DBLocks

executable_tis: list[TI] = []

if get_dialect_name(session) == "postgresql":
# Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock
# timeout", try to take out a transactional advisory lock (unlocks automatically on
Expand Down Expand Up @@ -523,11 +520,33 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -

if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []

max_tis = int(min(max_tis, pool_slots_free))
return pools, 0, set()

effective_max_tis = int(min(max_tis, pool_slots_free))
starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}
return pools, effective_max_tis, starved_pools

def _select_task_instances_to_queue(
self,
max_tis: int,
pools: dict[str, PoolStats],
starved_pools: set[str],
session: Session,

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
session: Session,
*,
session: Session,

) -> list[TI]:
"""
Select SCHEDULED TIs that can run given pool and concurrency constraints, and mark them QUEUED.

``pools`` and ``starved_pools`` must come from a prior ``_acquire_pool_capacity`` call (or an
equivalent pre-built dict in tests). The pool stats are updated in-place as slots are
virtually allocated to each selected TI.

:param max_tis: Upper bound on TIs to select this cycle.
:param pools: Current pool utilisation as returned by ``Pool.slots_stats``.
:param starved_pools: Pools that are already at capacity; TIs in these pools are skipped.
:param session: SQLAlchemy session (must remain open until the caller commits).
:return: TIs that were moved to QUEUED state.
"""
executable_tis: list[TI] = []

# dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
concurrency_map = ConcurrencyMap()
Expand All @@ -549,96 +568,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
num_starved_tasks = len(starved_tasks)
num_starved_tasks_task_dagrun_concurrency = len(starved_tasks_task_dagrun_concurrency)

# This behaves the same as 'concurrency_map.load()' with the difference that
# 'load()' executes immediately while '_get_current_dr_task_concurrency' creates a
# subquery object that is then executed along with main query.
# The results of 'load()' aren't used again here because by the time the main query
# executes, there could be a change that will be ignored.
dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES)

query = (
select(TI)
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.join(TI.dag_run)
.where(DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.where(DM.bundle_name.is_not(None))
.join(
dr_task_concurrency_subquery,
and_(
TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
TI.run_id == dr_task_concurrency_subquery.c.run_id,
),
isouter=True,
)
.where(
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks
)
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
query = self._build_schedulable_tis_query(
starved_pools,
starved_dags,
starved_tasks,
starved_tasks_task_dagrun_concurrency,
max_tis,
)

# Starvation filters should be applied before computing the row_num based on the
# max_active_tasks limit. That way, starved dags and tasks that shouldn't run,
# won't occupy a slot.
if starved_pools:
query = query.where(TI.pool.not_in(starved_pools))

if starved_dags:
query = query.where(TI.dag_id.not_in(starved_dags))

if starved_tasks:
query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))

if starved_tasks_task_dagrun_concurrency:
query = query.where(
tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
)

# Create a subquery with row numbers partitioned by dag_id and run_id.
# Different dags can have the same run_id but
# the dag_id combined with the run_id uniquely identify a run.
ranked_query = (
query.add_columns(
func.row_number()
.over(
partition_by=[TI.dag_id, TI.run_id],
order_by=[-TI.priority_weight, DR.logical_date, TI.map_index],
)
.label("row_num"),
DM.max_active_tasks.label("dr_max_active_tasks"),
# Create columns for the order_by checks here for sqlite.
TI.priority_weight.label("priority_weight_for_ordering"),
DR.logical_date.label("logical_date_for_ordering"),
TI.map_index.label("map_index_for_ordering"),
)
).subquery()

# Select only rows where row_number <= max_active_tasks.
query = (
select(TI)
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.select_from(ranked_query)
.join(
TI,
(TI.dag_id == ranked_query.c.dag_id)
& (TI.task_id == ranked_query.c.task_id)
& (TI.run_id == ranked_query.c.run_id)
& (TI.map_index == ranked_query.c.map_index),
)
.where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks)
# Add the order_by columns from the ranked query for sqlite.
.order_by(
-ranked_query.c.priority_weight_for_ordering,
ranked_query.c.logical_date_for_ordering,
ranked_query.c.map_index_for_ordering,
)
.options(selectinload(TI.dag_model))
)

query = query.limit(max_tis)

timer = stats.timer("scheduler.critical_section_query_duration")
timer.start()

Expand Down Expand Up @@ -899,6 +836,126 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
stats.gauge("scheduler.tasks.executable", len(executable_tis))

return self._mark_task_instances_queued(executable_tis, session)

def _build_schedulable_tis_query(
self,
starved_pools: set[str],
starved_dags: set[str],
starved_tasks: set[tuple[str, str]],
starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]],
max_tis: int,
) -> Select[tuple[TI]]:
"""
Build a query that fetches SCHEDULED TIs eligible for execution this cycle.

Applies current starvation exclusions so that saturated pools, DAGs, or tasks
don't re-appear in the candidate set. Row-number windowing enforces
``max_active_tasks`` per DagRun. The returned query is ready to be wrapped
with ``with_row_locks`` and executed by the caller; no session is required here.

This behaves the same as calling ``concurrency_map.load()`` followed by
``_get_current_dr_task_concurrency``, with the difference that the subquery
object is built here and executed as part of the main query, so any state
changes between construction and execution are naturally ignored.
"""
dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES)

query = (
select(TI)
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.join(TI.dag_run)
.where(DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.where(DM.bundle_name.is_not(None))
.join(
dr_task_concurrency_subquery,
and_(
TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
TI.run_id == dr_task_concurrency_subquery.c.run_id,
),
isouter=True,
)
.where(func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks)
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
)

# Starvation filters should be applied before computing the row_num based on the
# max_active_tasks limit. That way, starved dags and tasks that shouldn't run,
# won't occupy a slot.
if starved_pools:
query = query.where(TI.pool.not_in(starved_pools))

if starved_dags:
query = query.where(TI.dag_id.not_in(starved_dags))

if starved_tasks:
query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))

if starved_tasks_task_dagrun_concurrency:
query = query.where(
tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
)

# Create a subquery with row numbers partitioned by dag_id and run_id.
# Different dags can have the same run_id but
# the dag_id combined with the run_id uniquely identify a run.
ranked_query = (
query.add_columns(
func.row_number()
.over(
partition_by=[TI.dag_id, TI.run_id],
order_by=[-TI.priority_weight, DR.logical_date, TI.map_index],
)
.label("row_num"),
DM.max_active_tasks.label("dr_max_active_tasks"),
# Create columns for the order_by checks here for sqlite.
TI.priority_weight.label("priority_weight_for_ordering"),
DR.logical_date.label("logical_date_for_ordering"),
TI.map_index.label("map_index_for_ordering"),
)
).subquery()

# Select only rows where row_number <= max_active_tasks.
return (
select(TI)
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.select_from(ranked_query)
.join(
TI,
(TI.dag_id == ranked_query.c.dag_id)
& (TI.task_id == ranked_query.c.task_id)
& (TI.run_id == ranked_query.c.run_id)
& (TI.map_index == ranked_query.c.map_index),
)
.where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks)
# Add the order_by columns from the ranked query for sqlite.
.order_by(
-ranked_query.c.priority_weight_for_ordering,
ranked_query.c.logical_date_for_ordering,
ranked_query.c.map_index_for_ordering,
)
.options(selectinload(TI.dag_model))
.limit(max_tis)
)

def _mark_task_instances_queued(self, executable_tis: list[TI], session: Session) -> list[TI]:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _mark_task_instances_queued(self, executable_tis: list[TI], session: Session) -> list[TI]:
def _mark_task_instances_queued(self, executable_tis: list[TI], *, session: Session) -> list[TI]:

"""
Bulk-update ``executable_tis`` to QUEUED state and detach them from the session.

Handles ``external_executor_id`` pre-assignment for executors that opt in via
``pre_assigns_external_executor_id``, using a CASE expression in mixed-executor
deployments. UUIDs are read back via RETURNING on PostgreSQL and a follow-up
SELECT on other databases.

After this call the TIs are transient (detached from the ORM session) and carry
their final ``external_executor_id`` values in memory.

:return: ``executable_tis`` (same list, post-transient) or ``[]`` if the filter
could not be built (should not happen in practice).
"""
if executable_tis:
task_instance_str = "\n".join(
f"\t{x!r} (id={x.id}, try_number={x.try_number})" for x in executable_tis
Expand Down Expand Up @@ -1072,7 +1129,10 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int:
self.log.debug("max_tis query size is less than or equal to zero. No query will be performed!")
return 0

queued_tis = self._executable_task_instances_to_queued(max_tis, session=session)
pools, max_tis, starved_pools = self._acquire_pool_capacity(max_tis, session=session)
if max_tis == 0:
return 0
queued_tis = self._select_task_instances_to_queue(max_tis, pools, starved_pools, session=session)

# Sort queued TIs to their respective executor
executor_to_queued_tis = self._executor_to_workloads(queued_tis, session)
Expand Down
Loading
Loading