-
Notifications
You must be signed in to change notification settings - Fork 17.3k
Refactor _executable_task_instances_to_queued to make logic more readable #66878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ashb
wants to merge
1
commit into
apache:main
Choose a base branch
from
astronomer:scheduler-enqueue-readability
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+277
−171
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -110,12 +110,13 @@ | |||||||
| from sqlalchemy.engine import CursorResult | ||||||||
| from sqlalchemy.orm import Session | ||||||||
| from sqlalchemy.orm.interfaces import LoaderOption | ||||||||
| from sqlalchemy.sql.selectable import Subquery | ||||||||
| from sqlalchemy.sql.selectable import Select, Subquery | ||||||||
|
|
||||||||
| from airflow._shared.logging.types import Logger | ||||||||
| from airflow.executors.base_executor import BaseExecutor | ||||||||
| from airflow.executors.executor_utils import ExecutorName | ||||||||
| from airflow.executors.workloads.types import SchedulerWorkload | ||||||||
| from airflow.models.pool import PoolStats | ||||||||
| from airflow.serialization.definitions.dag import SerializedDAG | ||||||||
| from airflow.utils.sqlalchemy import CommitProhibitorGuard | ||||||||
|
|
||||||||
|
|
@@ -476,26 +477,22 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None: | |||||||
| self.log.info("\n\t".join(map(repr, callstack))) | ||||||||
| self.log.info("-" * 80) | ||||||||
|
|
||||||||
| def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]: | ||||||||
| def _acquire_pool_capacity( | ||||||||
| self, max_tis: int, session: Session | ||||||||
| ) -> tuple[dict[str, PoolStats], int, set[str]]: | ||||||||
| """ | ||||||||
| Find TIs that are ready for execution based on conditions. | ||||||||
|
|
||||||||
| Conditions include: | ||||||||
| - pool limits | ||||||||
| - DAG max_active_tasks | ||||||||
| - executor state | ||||||||
| - priority | ||||||||
| - max active tis per DAG | ||||||||
| - max active tis per DAG run | ||||||||
|
|
||||||||
| :param max_tis: Maximum number of TIs to queue in this loop. | ||||||||
| :return: list[airflow.models.TaskInstance] | ||||||||
| Acquire the scheduler critical-section lock and read current pool utilisation. | ||||||||
|
|
||||||||
| On PostgreSQL a transactional advisory lock is taken first so that only one | ||||||||
| scheduler at a time enters the critical section; pool rows are then locked via | ||||||||
| ``SELECT … FOR UPDATE`` (or ``NOWAIT`` where supported). | ||||||||
|
|
||||||||
| Returns a ``(pools, effective_max_tis, starved_pools)`` tuple. ``effective_max_tis`` | ||||||||
| is zero when all pools are already full; callers should short-circuit in that case. | ||||||||
| """ | ||||||||
| from airflow.models.pool import Pool | ||||||||
| from airflow.utils.db import DBLocks | ||||||||
|
|
||||||||
| executable_tis: list[TI] = [] | ||||||||
|
|
||||||||
| if get_dialect_name(session) == "postgresql": | ||||||||
| # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock | ||||||||
| # timeout", try to take out a transactional advisory lock (unlocks automatically on | ||||||||
|
|
@@ -523,11 +520,33 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - | |||||||
|
|
||||||||
| if pool_slots_free == 0: | ||||||||
| self.log.debug("All pools are full!") | ||||||||
| return [] | ||||||||
|
|
||||||||
| max_tis = int(min(max_tis, pool_slots_free)) | ||||||||
| return pools, 0, set() | ||||||||
|
|
||||||||
| effective_max_tis = int(min(max_tis, pool_slots_free)) | ||||||||
| starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0} | ||||||||
| return pools, effective_max_tis, starved_pools | ||||||||
|
|
||||||||
| def _select_task_instances_to_queue( | ||||||||
| self, | ||||||||
| max_tis: int, | ||||||||
| pools: dict[str, PoolStats], | ||||||||
| starved_pools: set[str], | ||||||||
| session: Session, | ||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| ) -> list[TI]: | ||||||||
| """ | ||||||||
| Select SCHEDULED TIs that can run given pool and concurrency constraints, and mark them QUEUED. | ||||||||
|
|
||||||||
| ``pools`` and ``starved_pools`` must come from a prior ``_acquire_pool_capacity`` call (or an | ||||||||
| equivalent pre-built dict in tests). The pool stats are updated in-place as slots are | ||||||||
| virtually allocated to each selected TI. | ||||||||
|
|
||||||||
| :param max_tis: Upper bound on TIs to select this cycle. | ||||||||
| :param pools: Current pool utilisation as returned by ``Pool.slots_stats``. | ||||||||
| :param starved_pools: Pools that are already at capacity; TIs in these pools are skipped. | ||||||||
| :param session: SQLAlchemy session (must remain open until the caller commits). | ||||||||
| :return: TIs that were moved to QUEUED state. | ||||||||
| """ | ||||||||
| executable_tis: list[TI] = [] | ||||||||
|
|
||||||||
| # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. | ||||||||
| concurrency_map = ConcurrencyMap() | ||||||||
|
|
@@ -549,96 +568,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - | |||||||
| num_starved_tasks = len(starved_tasks) | ||||||||
| num_starved_tasks_task_dagrun_concurrency = len(starved_tasks_task_dagrun_concurrency) | ||||||||
|
|
||||||||
| # This behaves the same as 'concurrency_map.load()' with the difference that | ||||||||
| # 'load()' executes immediately while '_get_current_dr_task_concurrency' creates a | ||||||||
| # subquery object that is then executed along with main query. | ||||||||
| # The results of 'load()' aren't used again here because by the time the main query | ||||||||
| # executes, there could be a change that will be ignored. | ||||||||
| dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES) | ||||||||
|
|
||||||||
| query = ( | ||||||||
| select(TI) | ||||||||
| .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") | ||||||||
| .join(TI.dag_run) | ||||||||
| .where(DR.state == DagRunState.RUNNING) | ||||||||
| .join(TI.dag_model) | ||||||||
| .where(~DM.is_paused) | ||||||||
| .where(TI.state == TaskInstanceState.SCHEDULED) | ||||||||
| .where(DM.bundle_name.is_not(None)) | ||||||||
| .join( | ||||||||
| dr_task_concurrency_subquery, | ||||||||
| and_( | ||||||||
| TI.dag_id == dr_task_concurrency_subquery.c.dag_id, | ||||||||
| TI.run_id == dr_task_concurrency_subquery.c.run_id, | ||||||||
| ), | ||||||||
| isouter=True, | ||||||||
| ) | ||||||||
| .where( | ||||||||
| func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks | ||||||||
| ) | ||||||||
| .order_by(-TI.priority_weight, DR.logical_date, TI.map_index) | ||||||||
| query = self._build_schedulable_tis_query( | ||||||||
| starved_pools, | ||||||||
| starved_dags, | ||||||||
| starved_tasks, | ||||||||
| starved_tasks_task_dagrun_concurrency, | ||||||||
| max_tis, | ||||||||
| ) | ||||||||
|
|
||||||||
| # Starvation filters should be applied before computing the row_num based on the | ||||||||
| # max_active_tasks limit. That way, starved dags and tasks that shouldn't run, | ||||||||
| # won't occupy a slot. | ||||||||
| if starved_pools: | ||||||||
| query = query.where(TI.pool.not_in(starved_pools)) | ||||||||
|
|
||||||||
| if starved_dags: | ||||||||
| query = query.where(TI.dag_id.not_in(starved_dags)) | ||||||||
|
|
||||||||
| if starved_tasks: | ||||||||
| query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks)) | ||||||||
|
|
||||||||
| if starved_tasks_task_dagrun_concurrency: | ||||||||
| query = query.where( | ||||||||
| tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency) | ||||||||
| ) | ||||||||
|
|
||||||||
| # Create a subquery with row numbers partitioned by dag_id and run_id. | ||||||||
| # Different dags can have the same run_id but | ||||||||
| # the dag_id combined with the run_id uniquely identify a run. | ||||||||
| ranked_query = ( | ||||||||
| query.add_columns( | ||||||||
| func.row_number() | ||||||||
| .over( | ||||||||
| partition_by=[TI.dag_id, TI.run_id], | ||||||||
| order_by=[-TI.priority_weight, DR.logical_date, TI.map_index], | ||||||||
| ) | ||||||||
| .label("row_num"), | ||||||||
| DM.max_active_tasks.label("dr_max_active_tasks"), | ||||||||
| # Create columns for the order_by checks here for sqlite. | ||||||||
| TI.priority_weight.label("priority_weight_for_ordering"), | ||||||||
| DR.logical_date.label("logical_date_for_ordering"), | ||||||||
| TI.map_index.label("map_index_for_ordering"), | ||||||||
| ) | ||||||||
| ).subquery() | ||||||||
|
|
||||||||
| # Select only rows where row_number <= max_active_tasks. | ||||||||
| query = ( | ||||||||
| select(TI) | ||||||||
| .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") | ||||||||
| .select_from(ranked_query) | ||||||||
| .join( | ||||||||
| TI, | ||||||||
| (TI.dag_id == ranked_query.c.dag_id) | ||||||||
| & (TI.task_id == ranked_query.c.task_id) | ||||||||
| & (TI.run_id == ranked_query.c.run_id) | ||||||||
| & (TI.map_index == ranked_query.c.map_index), | ||||||||
| ) | ||||||||
| .where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks) | ||||||||
| # Add the order_by columns from the ranked query for sqlite. | ||||||||
| .order_by( | ||||||||
| -ranked_query.c.priority_weight_for_ordering, | ||||||||
| ranked_query.c.logical_date_for_ordering, | ||||||||
| ranked_query.c.map_index_for_ordering, | ||||||||
| ) | ||||||||
| .options(selectinload(TI.dag_model)) | ||||||||
| ) | ||||||||
|
|
||||||||
| query = query.limit(max_tis) | ||||||||
|
|
||||||||
| timer = stats.timer("scheduler.critical_section_query_duration") | ||||||||
| timer.start() | ||||||||
|
|
||||||||
|
|
@@ -899,6 +836,126 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - | |||||||
| stats.gauge("scheduler.tasks.starving", num_starving_tasks_total) | ||||||||
| stats.gauge("scheduler.tasks.executable", len(executable_tis)) | ||||||||
|
|
||||||||
| return self._mark_task_instances_queued(executable_tis, session) | ||||||||
|
|
||||||||
| def _build_schedulable_tis_query( | ||||||||
| self, | ||||||||
| starved_pools: set[str], | ||||||||
| starved_dags: set[str], | ||||||||
| starved_tasks: set[tuple[str, str]], | ||||||||
| starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]], | ||||||||
| max_tis: int, | ||||||||
| ) -> Select[tuple[TI]]: | ||||||||
| """ | ||||||||
| Build a query that fetches SCHEDULED TIs eligible for execution this cycle. | ||||||||
|
|
||||||||
| Applies current starvation exclusions so that saturated pools, DAGs, or tasks | ||||||||
| don't re-appear in the candidate set. Row-number windowing enforces | ||||||||
| ``max_active_tasks`` per DagRun. The returned query is ready to be wrapped | ||||||||
| with ``with_row_locks`` and executed by the caller; no session is required here. | ||||||||
|
|
||||||||
| This behaves the same as calling ``concurrency_map.load()`` followed by | ||||||||
| ``_get_current_dr_task_concurrency``, with the difference that the subquery | ||||||||
| object is built here and executed as part of the main query, so any state | ||||||||
| changes between construction and execution are naturally ignored. | ||||||||
| """ | ||||||||
| dr_task_concurrency_subquery = _get_current_dr_task_concurrency(states=EXECUTION_STATES) | ||||||||
|
|
||||||||
| query = ( | ||||||||
| select(TI) | ||||||||
| .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") | ||||||||
| .join(TI.dag_run) | ||||||||
| .where(DR.state == DagRunState.RUNNING) | ||||||||
| .join(TI.dag_model) | ||||||||
| .where(~DM.is_paused) | ||||||||
| .where(TI.state == TaskInstanceState.SCHEDULED) | ||||||||
| .where(DM.bundle_name.is_not(None)) | ||||||||
| .join( | ||||||||
| dr_task_concurrency_subquery, | ||||||||
| and_( | ||||||||
| TI.dag_id == dr_task_concurrency_subquery.c.dag_id, | ||||||||
| TI.run_id == dr_task_concurrency_subquery.c.run_id, | ||||||||
| ), | ||||||||
| isouter=True, | ||||||||
| ) | ||||||||
| .where(func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < DM.max_active_tasks) | ||||||||
| .order_by(-TI.priority_weight, DR.logical_date, TI.map_index) | ||||||||
| ) | ||||||||
|
|
||||||||
| # Starvation filters should be applied before computing the row_num based on the | ||||||||
| # max_active_tasks limit. That way, starved dags and tasks that shouldn't run, | ||||||||
| # won't occupy a slot. | ||||||||
| if starved_pools: | ||||||||
| query = query.where(TI.pool.not_in(starved_pools)) | ||||||||
|
|
||||||||
| if starved_dags: | ||||||||
| query = query.where(TI.dag_id.not_in(starved_dags)) | ||||||||
|
|
||||||||
| if starved_tasks: | ||||||||
| query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks)) | ||||||||
|
|
||||||||
| if starved_tasks_task_dagrun_concurrency: | ||||||||
| query = query.where( | ||||||||
| tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency) | ||||||||
| ) | ||||||||
|
|
||||||||
| # Create a subquery with row numbers partitioned by dag_id and run_id. | ||||||||
| # Different dags can have the same run_id but | ||||||||
| # the dag_id combined with the run_id uniquely identify a run. | ||||||||
| ranked_query = ( | ||||||||
| query.add_columns( | ||||||||
| func.row_number() | ||||||||
| .over( | ||||||||
| partition_by=[TI.dag_id, TI.run_id], | ||||||||
| order_by=[-TI.priority_weight, DR.logical_date, TI.map_index], | ||||||||
| ) | ||||||||
| .label("row_num"), | ||||||||
| DM.max_active_tasks.label("dr_max_active_tasks"), | ||||||||
| # Create columns for the order_by checks here for sqlite. | ||||||||
| TI.priority_weight.label("priority_weight_for_ordering"), | ||||||||
| DR.logical_date.label("logical_date_for_ordering"), | ||||||||
| TI.map_index.label("map_index_for_ordering"), | ||||||||
| ) | ||||||||
| ).subquery() | ||||||||
|
|
||||||||
| # Select only rows where row_number <= max_active_tasks. | ||||||||
| return ( | ||||||||
| select(TI) | ||||||||
| .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") | ||||||||
| .select_from(ranked_query) | ||||||||
| .join( | ||||||||
| TI, | ||||||||
| (TI.dag_id == ranked_query.c.dag_id) | ||||||||
| & (TI.task_id == ranked_query.c.task_id) | ||||||||
| & (TI.run_id == ranked_query.c.run_id) | ||||||||
| & (TI.map_index == ranked_query.c.map_index), | ||||||||
| ) | ||||||||
| .where(ranked_query.c.row_num <= ranked_query.c.dr_max_active_tasks) | ||||||||
| # Add the order_by columns from the ranked query for sqlite. | ||||||||
| .order_by( | ||||||||
| -ranked_query.c.priority_weight_for_ordering, | ||||||||
| ranked_query.c.logical_date_for_ordering, | ||||||||
| ranked_query.c.map_index_for_ordering, | ||||||||
| ) | ||||||||
| .options(selectinload(TI.dag_model)) | ||||||||
| .limit(max_tis) | ||||||||
| ) | ||||||||
|
|
||||||||
| def _mark_task_instances_queued(self, executable_tis: list[TI], session: Session) -> list[TI]: | ||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| """ | ||||||||
| Bulk-update ``executable_tis`` to QUEUED state and detach them from the session. | ||||||||
|
|
||||||||
| Handles ``external_executor_id`` pre-assignment for executors that opt in via | ||||||||
| ``pre_assigns_external_executor_id``, using a CASE expression in mixed-executor | ||||||||
| deployments. UUIDs are read back via RETURNING on PostgreSQL and a follow-up | ||||||||
| SELECT on other databases. | ||||||||
|
|
||||||||
| After this call the TIs are transient (detached from the ORM session) and carry | ||||||||
| their final ``external_executor_id`` values in memory. | ||||||||
|
|
||||||||
| :return: ``executable_tis`` (same list, post-transient) or ``[]`` if the filter | ||||||||
| could not be built (should not happen in practice). | ||||||||
| """ | ||||||||
| if executable_tis: | ||||||||
| task_instance_str = "\n".join( | ||||||||
| f"\t{x!r} (id={x.id}, try_number={x.try_number})" for x in executable_tis | ||||||||
|
|
@@ -1072,7 +1129,10 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: | |||||||
| self.log.debug("max_tis query size is less than or equal to zero. No query will be performed!") | ||||||||
| return 0 | ||||||||
|
|
||||||||
| queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) | ||||||||
| pools, max_tis, starved_pools = self._acquire_pool_capacity(max_tis, session=session) | ||||||||
| if max_tis == 0: | ||||||||
| return 0 | ||||||||
| queued_tis = self._select_task_instances_to_queue(max_tis, pools, starved_pools, session=session) | ||||||||
|
|
||||||||
| # Sort queued TIs to their respective executor | ||||||||
| executor_to_queued_tis = self._executor_to_workloads(queued_tis, session) | ||||||||
|
|
||||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or, convert it all to be kwarg only. Same with the others.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we had the rule with session as kwarg mainly for the
@provide_sessiondecorator which is not used here. So a session needs to be provided by caller anyway. So it is rather a nit.