diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 1a3f55b7f6f3d..b5725ad4ba7dd 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -110,12 +110,13 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from sqlalchemy.orm.interfaces import LoaderOption - from sqlalchemy.sql.selectable import Subquery + from sqlalchemy.sql.selectable import Select, Subquery from airflow._shared.logging.types import Logger from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads.types import SchedulerWorkload + from airflow.models.pool import PoolStats from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -476,26 +477,22 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None: self.log.info("\n\t".join(map(repr, callstack))) self.log.info("-" * 80) - def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]: + def _acquire_pool_capacity( + self, max_tis: int, session: Session + ) -> tuple[dict[str, PoolStats], int, set[str]]: """ - Find TIs that are ready for execution based on conditions. - - Conditions include: - - pool limits - - DAG max_active_tasks - - executor state - - priority - - max active tis per DAG - - max active tis per DAG run - - :param max_tis: Maximum number of TIs to queue in this loop. - :return: list[airflow.models.TaskInstance] + Acquire the scheduler critical-section lock and read current pool utilisation. + + On PostgreSQL a transactional advisory lock is taken first so that only one + scheduler at a time enters the critical section; pool rows are then locked via + ``SELECT … FOR UPDATE`` (or ``NOWAIT`` where supported). 
+ + Returns a ``(pools, effective_max_tis, starved_pools)`` tuple. ``effective_max_tis`` + is zero when all pools are already full; callers should short-circuit in that case. """ from airflow.models.pool import Pool from airflow.utils.db import DBLocks - executable_tis: list[TI] = [] - if get_dialect_name(session) == "postgresql": # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock # timeout", try to take out a transactional advisory lock (unlocks automatically on @@ -523,11 +520,33 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - if pool_slots_free == 0: self.log.debug("All pools are full!") - return [] - - max_tis = int(min(max_tis, pool_slots_free)) + return pools, 0, set() + effective_max_tis = int(min(max_tis, pool_slots_free)) starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0} + return pools, effective_max_tis, starved_pools + + def _select_task_instances_to_queue( + self, + max_tis: int, + pools: dict[str, PoolStats], + starved_pools: set[str], + session: Session, + ) -> list[TI]: + """ + Select SCHEDULED TIs that can run given pool and concurrency constraints, and mark them QUEUED. + + ``pools`` and ``starved_pools`` must come from a prior ``_acquire_pool_capacity`` call (or an + equivalent pre-built dict in tests). The pool stats are updated in-place as slots are + virtually allocated to each selected TI. + + :param max_tis: Upper bound on TIs to select this cycle. + :param pools: Current pool utilisation as returned by ``Pool.slots_stats``. + :param starved_pools: Pools that are already at capacity; TIs in these pools are skipped. + :param session: SQLAlchemy session (must remain open until the caller commits). + :return: TIs that were moved to QUEUED state. + """ + executable_tis: list[TI] = [] # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. 
concurrency_map = ConcurrencyMap() @@ -549,96 +568,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - num_starved_tasks = len(starved_tasks) num_starved_tasks_task_dagrun_concurrency = len(starved_tasks_task_dagrun_concurrency) - # This behaves the same as 'concurrency_map.load()' with the difference that - # 'load()' executes immediately while '_get_current_dr_task_concurrency' creates a - # subquery object that is then executed along with main query. - # The results of 'load()' aren't used again here because by the time the main query - # executes, there could be a change that will be ignored. - dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES) - - query = ( - select(TI) - .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") - .join(TI.dag_run) - .where(DR.state == DagRunState.RUNNING) - .join(TI.dag_model) - .where(~DM.is_paused) - .where(TI.state == TaskInstanceState.SCHEDULED) - .where(DM.bundle_name.is_not(None)) - .join( - dr_task_concurrency_subquery, - and_( - TI.dag_id == dr_task_concurrency_subquery.c.dag_id, - TI.run_id == dr_task_concurrency_subquery.c.run_id, - ), - isouter=True, - ) - .where( - func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks - ) - .order_by(-TI.priority_weight, DR.logical_date, TI.map_index) + query = self._build_schedulable_tis_query( + starved_pools, + starved_dags, + starved_tasks, + starved_tasks_task_dagrun_concurrency, + max_tis, ) - # Starvation filters should be applied before computing the row_num based on the - # max_active_tasks limit. That way, starved dags and tasks that shouldn't run, - # won't occupy a slot. 
- if starved_pools: - query = query.where(TI.pool.not_in(starved_pools)) - - if starved_dags: - query = query.where(TI.dag_id.not_in(starved_dags)) - - if starved_tasks: - query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks)) - - if starved_tasks_task_dagrun_concurrency: - query = query.where( - tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency) - ) - - # Create a subquery with row numbers partitioned by dag_id and run_id. - # Different dags can have the same run_id but - # the dag_id combined with the run_id uniquely identify a run. - ranked_query = ( - query.add_columns( - func.row_number() - .over( - partition_by=[TI.dag_id, TI.run_id], - order_by=[-TI.priority_weight, DR.logical_date, TI.map_index], - ) - .label("row_num"), - DM.max_active_tasks.label("dr_max_active_tasks"), - # Create columns for the order_by checks here for sqlite. - TI.priority_weight.label("priority_weight_for_ordering"), - DR.logical_date.label("logical_date_for_ordering"), - TI.map_index.label("map_index_for_ordering"), - ) - ).subquery() - - # Select only rows where row_number <= max_active_tasks. - query = ( - select(TI) - .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") - .select_from(ranked_query) - .join( - TI, - (TI.dag_id == ranked_query.c.dag_id) - & (TI.task_id == ranked_query.c.task_id) - & (TI.run_id == ranked_query.c.run_id) - & (TI.map_index == ranked_query.c.map_index), - ) - .where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks) - # Add the order_by columns from the ranked query for sqlite. 
- .order_by( - -ranked_query.c.priority_weight_for_ordering, - ranked_query.c.logical_date_for_ordering, - ranked_query.c.map_index_for_ordering, - ) - .options(selectinload(TI.dag_model)) - ) - - query = query.limit(max_tis) - timer = stats.timer("scheduler.critical_section_query_duration") timer.start() @@ -899,6 +836,126 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - stats.gauge("scheduler.tasks.starving", num_starving_tasks_total) stats.gauge("scheduler.tasks.executable", len(executable_tis)) + return self._mark_task_instances_queued(executable_tis, session) + + def _build_schedulable_tis_query( + self, + starved_pools: set[str], + starved_dags: set[str], + starved_tasks: set[tuple[str, str]], + starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]], + max_tis: int, + ) -> Select[tuple[TI]]: + """ + Build a query that fetches SCHEDULED TIs eligible for execution this cycle. + + Applies current starvation exclusions so that saturated pools, DAGs, or tasks + don't re-appear in the candidate set. Row-number windowing enforces + ``max_active_tasks`` per DagRun. The returned query is ready to be wrapped + with ``with_row_locks`` and executed by the caller; no session is required here. + + The subquery from ``_get_current_dr_task_concurrency`` behaves the same as + ``concurrency_map.load()``, except that ``load()`` executes immediately while the + subquery is executed along with the main query. The ``load()`` results are not reused + here because any change made before the main query executes would be ignored. 
+ """ + dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES) + + query = ( + select(TI) + .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") + .join(TI.dag_run) + .where(DR.state == DagRunState.RUNNING) + .join(TI.dag_model) + .where(~DM.is_paused) + .where(TI.state == TaskInstanceState.SCHEDULED) + .where(DM.bundle_name.is_not(None)) + .join( + dr_task_concurrency_subquery, + and_( + TI.dag_id == dr_task_concurrency_subquery.c.dag_id, + TI.run_id == dr_task_concurrency_subquery.c.run_id, + ), + isouter=True, + ) + .where(func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks) + .order_by(-TI.priority_weight, DR.logical_date, TI.map_index) + ) + + # Starvation filters should be applied before computing the row_num based on the + # max_active_tasks limit. That way, starved dags and tasks that shouldn't run, + # won't occupy a slot. + if starved_pools: + query = query.where(TI.pool.not_in(starved_pools)) + + if starved_dags: + query = query.where(TI.dag_id.not_in(starved_dags)) + + if starved_tasks: + query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks)) + + if starved_tasks_task_dagrun_concurrency: + query = query.where( + tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency) + ) + + # Create a subquery with row numbers partitioned by dag_id and run_id. + # Different dags can have the same run_id but + # the dag_id combined with the run_id uniquely identify a run. + ranked_query = ( + query.add_columns( + func.row_number() + .over( + partition_by=[TI.dag_id, TI.run_id], + order_by=[-TI.priority_weight, DR.logical_date, TI.map_index], + ) + .label("row_num"), + DM.max_active_tasks.label("dr_max_active_tasks"), + # Create columns for the order_by checks here for sqlite. 
+ TI.priority_weight.label("priority_weight_for_ordering"), + DR.logical_date.label("logical_date_for_ordering"), + TI.map_index.label("map_index_for_ordering"), + ) + ).subquery() + + # Select only rows where row_number <= max_active_tasks. + return ( + select(TI) + .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") + .select_from(ranked_query) + .join( + TI, + (TI.dag_id == ranked_query.c.dag_id) + & (TI.task_id == ranked_query.c.task_id) + & (TI.run_id == ranked_query.c.run_id) + & (TI.map_index == ranked_query.c.map_index), + ) + .where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks) + # Add the order_by columns from the ranked query for sqlite. + .order_by( + -ranked_query.c.priority_weight_for_ordering, + ranked_query.c.logical_date_for_ordering, + ranked_query.c.map_index_for_ordering, + ) + .options(selectinload(TI.dag_model)) + .limit(max_tis) + ) + + def _mark_task_instances_queued(self, executable_tis: list[TI], session: Session) -> list[TI]: + """ + Bulk-update ``executable_tis`` to QUEUED state and detach them from the session. + + Handles ``external_executor_id`` pre-assignment for executors that opt in via + ``pre_assigns_external_executor_id``, using a CASE expression in mixed-executor + deployments. UUIDs are read back via RETURNING on PostgreSQL and a follow-up + SELECT on other databases. + + After this call the TIs are transient (detached from the ORM session) and carry + their final ``external_executor_id`` values in memory. + + :return: ``executable_tis`` (same list, post-transient) or ``[]`` if the filter + could not be built (should not happen in practice). + """ if executable_tis: task_instance_str = "\n".join( f"\t{x!r} (id={x.id}, try_number={x.try_number})" for x in executable_tis @@ -1072,7 +1129,10 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: self.log.debug("max_tis query size is less than or equal to zero. 
No query will be performed!") return 0 - queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) + pools, max_tis, starved_pools = self._acquire_pool_capacity(max_tis, session=session) + if max_tis == 0: + return 0 + queued_tis = self._select_task_instances_to_queue(max_tis, pools, starved_pools, session=session) # Sort queued TIs to their respective executor executor_to_queued_tis = self._executor_to_workloads(queued_tis, session) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index eb763635ddf32..883ef1e6b751d 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -77,7 +77,7 @@ from airflow.models.deadline import Deadline from airflow.models.deadline_alert import DeadlineAlert from airflow.models.log import Log -from airflow.models.pool import Pool +from airflow.models.pool import Pool, PoolStats from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.models.team import Team @@ -277,6 +277,26 @@ def _clean_db(): clear_db_triggers() +def make_pool_stats( + pool: str = "default_pool", + total: int | float = 128, + running: int = 0, + queued: int = 0, + deferred: int = 0, + scheduled: int = 0, +) -> dict[str, PoolStats]: + return { + pool: PoolStats( + total=total, + running=running, + queued=queued, + deferred=deferred, + scheduled=scheduled, + open=total - running - queued, + ) + } + + @patch.dict( ExecutorLoader.executors, {MOCK_EXECUTOR: f"{MockExecutor.__module__}.{MockExecutor.__qualname__}"} ) @@ -1243,7 +1263,7 @@ def test_find_executable_task_instances_backfill(self, dag_maker): session.merge(ti_non_backfill) session.flush() - queued_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + queued_tis = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) 
assert len(queued_tis) == 2 assert {x.key for x in queued_tis} == {ti_non_backfill.key, ti_backfill.key} session.rollback() @@ -1279,7 +1299,8 @@ def test_find_executable_task_instances_pool(self, dag_maker): session.add(pool2) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(32, session) + res = self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) session.flush() assert len(res) == 3 res_keys = [] @@ -1321,7 +1342,7 @@ def test_find_executable_task_instances_only_running_dagruns( ti.state = State.SCHEDULED session.merge(ti) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) session.flush() assert total_executed_ti == len(res) @@ -1354,7 +1375,7 @@ def test_find_executable_task_instances_order_logical_date(self, dag_maker): session.merge(ti) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -1383,7 +1404,7 @@ def test_find_executable_task_instances_order_priority(self, dag_maker): session.merge(ti) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -1419,7 +1440,7 @@ def test_find_executable_task_instances_executor(self, dag_maker, mock_executors session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = 
self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) assert len(res) == 5 res_ti_keys = [res_ti.key for res_ti in res] @@ -1481,7 +1502,7 @@ def test_find_executable_task_instances_executor_with_teams(self, dag_maker, moc scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) # All tasks should be queued since they have valid executor mappings assert len(res) == 5 @@ -1605,10 +1626,11 @@ def test_max_active_tasks_per_dr_limit_applied_in_task_query(self, dag_maker, mo queued_tis = None while count < task_num: - # Use `_executable_task_instances_to_queued` because it returns a list of TIs - # while `_critical_section_enqueue_task_instances` just returns the number of the TIs. - queued_tis = self.job_runner._executable_task_instances_to_queued( - max_tis=self.job_runner.executor.slots_available, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity( + self.job_runner.executor.slots_available, session + ) + queued_tis = self.job_runner._select_task_instances_to_queue( + max_tis, pools, starved_pools, session ) count += len(queued_tis) iterations += 1 @@ -1665,9 +1687,10 @@ def test_max_active_tasks_per_dr_limit_partial_capacity(self, dag_maker, mock_ex run_id="run1", ) - queued_tis = self.job_runner._executable_task_instances_to_queued( - max_tis=self.job_runner.executor.slots_available, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity( + self.job_runner.executor.slots_available, session ) + queued_tis = self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) assert queued_tis is not None # Only 2 tasks should have been queued. 
@@ -1715,9 +1738,10 @@ def test_max_active_tasks_per_dr_limit_starvation_filter_ordering(self, dag_make run_id="run1", ) - queued_tis = self.job_runner._executable_task_instances_to_queued( - max_tis=self.job_runner.executor.slots_available, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity( + self.job_runner.executor.slots_available, session ) + queued_tis = self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) assert queued_tis is not None # 4 tasks should have been queued. @@ -1771,7 +1795,8 @@ def test_find_executable_task_instances_order_priority_with_pools(self, dag_make session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(32, session) + res = self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) assert len(res) == 2 assert ti3.key == res[0].key @@ -1802,7 +1827,7 @@ def test_find_executable_task_instances_order_logical_date_and_priority(self, da session.merge(ti) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) session.flush() assert [ti.key for ti in res] == [tis[1].key] session.rollback() @@ -1830,14 +1855,16 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker, mock_ex session.flush() # Two tasks w/o pool up for execution and our default pool size is 1 - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(total=1), set(), session) assert len(res) == 1 ti2.state = State.RUNNING session.flush() # One task w/o pool up for execution and one task running - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = 
self.job_runner._select_task_instances_to_queue( + 32, make_pool_stats(total=1, running=1), set(), session + ) assert len(res) == 0 session.rollback() @@ -1866,7 +1893,7 @@ def test_queued_task_instances_fails_with_missing_dag(self, dag_maker, session): ti.state = State.SCHEDULED session.merge(ti) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) session.flush() assert len(res) == 0 tis = dr.get_task_instances(session=session) @@ -1889,7 +1916,7 @@ def test_nonexistent_pool(self, dag_maker): session.merge(ti) session.commit() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) session.flush() assert len(res) == 0 session.rollback() @@ -1916,7 +1943,8 @@ def test_infinite_pool(self, dag_maker): session.add(infinite_pool) session.commit() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(32, session) + res = self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) session.flush() assert len(res) == 1 session.rollback() @@ -1942,7 +1970,9 @@ def test_not_enough_pool_slots(self, caplog, dag_maker): session.commit() cannot_run_ti_id = next(t for t in dr.task_instances if t.task_id == "cannot_run").id with caplog.at_level(logging.WARNING): - self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + self.job_runner._select_task_instances_to_queue( + 32, make_pool_stats("some_pool", total=2), set(), session + ) assert ( f"Not executing room for 1 more (limit is 2) - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, 
make_pool_stats(), set(), session) assert len(res) == 1 session.rollback() @@ -2316,7 +2348,7 @@ def test_find_executable_task_instances_deferred_does_not_block_different_task(s session.merge(ti_b1) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) queued_task_ids = [ti.task_id for ti in res] # task_b should be queued, task_a should be blocked assert "task_b" in queued_task_ids @@ -2349,7 +2381,7 @@ def test_find_executable_task_instances_deferred_to_success_unblocks(self, dag_m session.merge(ti2) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) assert len(res) == 0 # Step 2: ti1 completes -> ti2 should be unblocked @@ -2357,7 +2389,7 @@ def test_find_executable_task_instances_deferred_to_success_unblocks(self, dag_m session.merge(ti1) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) assert len(res) == 1 assert res[0].key == ti2.key session.rollback() @@ -2390,7 +2422,7 @@ def test_find_executable_task_instances_max_active_tis_per_dagrun_deferred(self, session.merge(ti_a1) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) queued_task_ids = [(ti.task_id, ti.map_index) for ti in res] # ti_a1 should be blocked, task_b may be queued assert ("task_a", 1) not in queued_task_ids @@ -2426,7 +2458,7 @@ def test_find_executable_task_instances_deferred_does_not_affect_max_active_task session.merge(t3) session.flush() - res = 
self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) # Deferred doesn't count toward max_active_tasks=2, so both scheduled can run assert len(res) == 2 session.rollback() @@ -2457,7 +2489,7 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_ session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=100, session=session) + res = self.job_runner._select_task_instances_to_queue(100, make_pool_stats(), set(), session) assert len(res) == 0 session.rollback() @@ -2484,7 +2516,7 @@ def test_find_executable_task_instances_not_enough_pool_slots_for_first(self, da # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(total=1), set(), session) assert len(res) == 1 assert res[0].key == ti2.key @@ -2521,7 +2553,7 @@ def test_find_executable_task_instances_not_enough_dag_concurrency_for_first(sel # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) assert len(res) == 1 assert res[0].key == ti2.key @@ -2550,7 +2582,7 @@ def test_find_executable_task_instances_not_enough_task_concurrency_for_first(se # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) assert len(res) == 1 assert res[0].key == ti1b.key @@ -2579,7 
+2611,7 @@ def test_find_executable_task_instances_task_concurrency_per_dagrun_for_first(se # Schedule ti with higher priority, # because it's running in a different DAG run with 0 active tis - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) assert len(res) == 1 assert res[0].key == ti2a.key @@ -2612,7 +2644,7 @@ def test_find_executable_task_instances_not_enough_task_concurrency_per_dagrun_f # Schedule ti with lower priority, # because the one with higher priority is limited by a concurrency limit - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue(1, make_pool_stats(), set(), session) assert len(res) == 1 assert res[0].key == ti1b.key @@ -2649,7 +2681,16 @@ def test_find_executable_task_instances_negative_open_pool_slots(self, dag_maker ti2.state = State.RUNNING session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=1, session=session) + res = self.job_runner._select_task_instances_to_queue( + 1, + { + **make_pool_stats(total=0), + **make_pool_stats("pool1", total=1), + **make_pool_stats("pool2", total=1, running=2), + }, + set(), + session, + ) assert len(res) == 1 assert res[0].key == ti1.key @@ -2675,7 +2716,7 @@ def test_emit_pool_starving_tasks_metrics(self, mock_get_backend, dag_maker): set_default_pool_slots(1) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(total=1), set(), session) assert len(res) == 0 mock_stats.gauge.assert_has_calls( @@ -2691,7 +2732,7 @@ def test_emit_pool_starving_tasks_metrics(self, mock_get_backend, dag_maker): set_default_pool_slots(2) session.flush() - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = 
self.job_runner._select_task_instances_to_queue(32, make_pool_stats(total=2), set(), session) assert len(res) == 1 mock_stats.gauge.assert_has_calls( @@ -2727,7 +2768,7 @@ def test_enqueue_task_instances_with_queued_state(self, dag_maker, session): assert mock_queue_workload.called session.rollback() - def test_executable_task_instances_to_queued_sets_external_executor_id(self, dag_maker, session): + def test_select_task_instances_to_queue_sets_external_executor_id(self, dag_maker, session): """external_executor_id is written to the DB in the same UPDATE that sets state=QUEUED.""" dag_id = "SchedulerJobTest.test_executable_sets_external_executor_id" session = settings.Session() @@ -2746,7 +2787,7 @@ class PreAssigningExecutor(MockExecutor): session.merge(ti) session.flush() - returned_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + returned_tis = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) assert len(returned_tis) == 1 # In-memory object (post make_transient) should carry the UUID @@ -4065,7 +4106,7 @@ def test_do_not_schedule_removed_task(self, dag_maker, session): self.job_runner = SchedulerJobRunner(job=scheduler_job) # Try to find executable task instances - should not find any for the removed task - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) # Should be empty because the task no longer exists in the DAG assert res == [] @@ -4325,8 +4366,9 @@ def test_scheduler_verify_pool_full(self, dag_maker, mock_executor): dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.RUNNING) self.job_runner._schedule_dag_run(dr, session) session.flush() - task_instances_list = self.job_runner._executable_task_instances_to_queued( - max_tis=32, session=session + pools, max_tis, starved_pools = 
self.job_runner._acquire_pool_capacity(32, session) + task_instances_list = self.job_runner._select_task_instances_to_queue( + max_tis, pools, starved_pools, session ) assert len(task_instances_list) == 1 @@ -4370,8 +4412,9 @@ def _create_dagruns(): for dr in _create_dagruns(): self.job_runner._schedule_dag_run(dr, session) - task_instances_list = self.job_runner._executable_task_instances_to_queued( - max_tis=32, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(32, session) + task_instances_list = self.job_runner._select_task_instances_to_queue( + max_tis, pools, starved_pools, session ) # As tasks require 2 slots, only 3 can fit into 6 available @@ -4437,9 +4480,11 @@ def _create_dagruns(dag: SerializedDAG): for dr in _create_dagruns(dag_d2): self.job_runner._schedule_dag_run(dr, session) - self.job_runner._executable_task_instances_to_queued(max_tis=2, session=session) - task_instances_list2 = self.job_runner._executable_task_instances_to_queued( - max_tis=2, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(2, session) + self.job_runner._select_task_instances_to_queue(max_tis, pools, starved_pools, session) + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(2, session) + task_instances_list2 = self.job_runner._select_task_instances_to_queue( + max_tis, pools, starved_pools, session ) # Make sure we get TIs from a non-full pool in the 2nd list @@ -4496,8 +4541,9 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker, mock_executor): session.merge(ti) session.flush() - task_instances_list = self.job_runner._executable_task_instances_to_queued( - max_tis=32, session=session + pools, max_tis, starved_pools = self.job_runner._acquire_pool_capacity(32, session) + task_instances_list = self.job_runner._select_task_instances_to_queue( + max_tis, pools, starved_pools, session ) # Only second and third @@ -8924,7 +8970,7 @@ def 
test_multi_team_scheduling_loop_batch_optimization(self, dag_maker, mock_exe with mock.patch.object(self.job_runner, "_get_team_names_for_dag_ids") as mock_batch: mock_batch.return_value = {"dag_a": "team_a", "dag_b": "team_b"} - res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) + res = self.job_runner._select_task_instances_to_queue(32, make_pool_stats(), set(), session) # Verify batch method was called with unique DAG IDs mock_batch.assert_called_once_with({"dag_a", "dag_b"}, session)