-
Notifications
You must be signed in to change notification settings - Fork 16.8k
Clean up CeleryExecutor to use workload terminology and typing #63888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
SameerMesiah97
wants to merge
1
commit into
apache:main
Choose a base branch
from
SameerMesiah97:CeleryExecutor-Cleanup
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+198
−191
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,15 +32,15 @@ | |
| from collections import Counter | ||
| from concurrent.futures import ProcessPoolExecutor | ||
| from multiprocessing import cpu_count | ||
| from typing import TYPE_CHECKING, Any | ||
| from typing import TYPE_CHECKING, Any, cast | ||
|
|
||
| from celery import states as celery_states | ||
| from deprecated import deprecated | ||
|
|
||
| from airflow.exceptions import AirflowProviderDeprecationWarning | ||
| from airflow.executors.base_executor import BaseExecutor | ||
| from airflow.providers.celery.executors import ( | ||
| celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043 | ||
| celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043. | ||
| ) | ||
| from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS | ||
| from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats | ||
|
|
@@ -49,18 +49,23 @@ | |
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task" | ||
| CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload" | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
|
|
||
| from celery.result import AsyncResult | ||
|
|
||
| from airflow.cli.cli_config import GroupCommand | ||
| from airflow.executors import workloads | ||
| from airflow.models.taskinstance import TaskInstance | ||
| from airflow.models.taskinstancekey import TaskInstanceKey | ||
| from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery | ||
|
|
||
| if AIRFLOW_V_3_2_PLUS: | ||
| from airflow.executors.workloads.types import WorkloadKey | ||
|
|
||
|
|
||
| # PEP562 | ||
| def __getattr__(name): | ||
|
|
@@ -84,7 +89,7 @@ class CeleryExecutor(BaseExecutor): | |
| """ | ||
| CeleryExecutor is recommended for production use of Airflow. | ||
| It allows distributing the execution of task instances to multiple worker nodes. | ||
| It allows distributing the execution of workloads (task instances and callbacks) to multiple worker nodes. | ||
| Celery is a simple, flexible and reliable distributed system to process | ||
| vast amounts of messages, while providing operations with the tools | ||
|
|
@@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor): | |
| if TYPE_CHECKING: | ||
| if AIRFLOW_V_3_0_PLUS: | ||
| # TODO: TaskSDK: move this type change into BaseExecutor | ||
| queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] | ||
| queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
|
|
@@ -127,7 +132,7 @@ def __init__(self, *args, **kwargs): | |
|
|
||
| self.celery_app = create_celery_app(self.conf) | ||
|
|
||
| # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters) | ||
| # Celery doesn't support bulk sending the workloads (which can become a bottleneck on bigger clusters) | ||
| # so we use a multiprocessing pool to speed this up. | ||
| # How many worker processes are created for checking celery task state. | ||
| self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0) | ||
|
|
@@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs): | |
| from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher | ||
|
|
||
| self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app) | ||
| self.tasks = {} | ||
| self.task_publish_retries: Counter[TaskInstanceKey] = Counter() | ||
| self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3) | ||
| self.workloads: dict[WorkloadKey, AsyncResult] = {} | ||
| self.workload_publish_retries: Counter[WorkloadKey] = Counter() | ||
| self.workload_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3) | ||
|
|
||
| def start(self) -> None: | ||
| self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism) | ||
|
|
||
| def _num_tasks_per_send_process(self, to_send_count: int) -> int: | ||
| def _num_workloads_per_send_process(self, to_send_count: int) -> int: | ||
| """ | ||
| How many Celery tasks should each worker process send. | ||
| How many Celery workloads should each worker process send. | ||
| :return: Number of tasks that should be sent per process | ||
| :return: Number of workloads that should be sent per process | ||
| """ | ||
| return max(1, math.ceil(to_send_count / self._sync_parallelism)) | ||
|
|
||
| def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None: | ||
| # Airflow V2 version | ||
| # Airflow V2 compatibility path — converts task tuples into workload-compatible tuples. | ||
|
|
||
| task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples] | ||
|
|
||
| self._send_tasks(task_tuples_to_send) | ||
| self._send_workloads(task_tuples_to_send) | ||
|
|
||
| def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: | ||
| # Airflow V3 version -- have to delay imports until we know we are on v3 | ||
| # Airflow V3 version -- have to delay imports until we know we are on v3. | ||
| from airflow.executors.workloads import ExecuteTask | ||
|
|
||
| if AIRFLOW_V_3_2_PLUS: | ||
| from airflow.executors.workloads import ExecuteCallback | ||
|
|
||
| tasks: list[WorkloadInCelery] = [] | ||
| workloads_to_be_sent: list[WorkloadInCelery] = [] | ||
| for workload in workloads: | ||
| if isinstance(workload, ExecuteTask): | ||
| tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) | ||
| workloads_to_be_sent.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) | ||
| elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): | ||
| # Use default queue for callbacks, or extract from callback data if available | ||
| # Use default queue for callbacks, or extract from callback data if available. | ||
| queue = "default" | ||
| if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: | ||
| queue = workload.callback.data["queue"] | ||
| tasks.append((workload.callback.key, workload, queue, self.team_name)) | ||
| workloads_to_be_sent.append((workload.callback.key, workload, queue, self.team_name)) | ||
| else: | ||
| raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") | ||
|
|
||
| self._send_tasks(tasks) | ||
| self._send_workloads(workloads_to_be_sent) | ||
|
|
||
| def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]): | ||
| def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): | ||
| # Celery state queries will be stuck if we do not use one same backend | ||
| # for all tasks. | ||
| # for all workloads. | ||
| cached_celery_backend = self.celery_app.backend | ||
|
|
||
| key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send) | ||
| self.log.debug("Sent all tasks.") | ||
| key_and_async_results = self._send_workloads_to_celery(workload_tuples_to_send) | ||
| self.log.debug("Sent all workloads.") | ||
| from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback | ||
|
|
||
| for key, _, result in key_and_async_results: | ||
| if isinstance(result, ExceptionWithTraceback) and isinstance( | ||
| result.exception, AirflowTaskTimeout | ||
| ): | ||
| retries = self.task_publish_retries[key] | ||
| if retries < self.task_publish_max_retries: | ||
| retries = self.workload_publish_retries[key] | ||
| if retries < self.workload_publish_max_retries: | ||
| Stats.incr("celery.task_timeout_error") | ||
| self.log.info( | ||
| "[Try %s of %s] Task Timeout Error for Task: (%s).", | ||
| self.task_publish_retries[key] + 1, | ||
| self.task_publish_max_retries, | ||
| "[Try %s of %s] Task Timeout Error for Workload: (%s).", | ||
| self.workload_publish_retries[key] + 1, | ||
| self.workload_publish_max_retries, | ||
| tuple(key), | ||
| ) | ||
| self.task_publish_retries[key] = retries + 1 | ||
| self.workload_publish_retries[key] = retries + 1 | ||
| continue | ||
| if key in self.queued_tasks: | ||
| self.queued_tasks.pop(key) | ||
| else: | ||
| self.queued_callbacks.pop(key, None) | ||
| self.task_publish_retries.pop(key, None) | ||
| self.workload_publish_retries.pop(key, None) | ||
| if isinstance(result, ExceptionWithTraceback): | ||
| self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) | ||
| self.event_buffer[key] = (TaskInstanceState.FAILED, None) | ||
| elif result is not None: | ||
| result.backend = cached_celery_backend | ||
| self.running.add(key) | ||
| self.tasks[key] = result | ||
| self.workloads[key] = result | ||
|
|
||
| # Store the Celery task_id in the event buffer. This will get "overwritten" if the task | ||
| # Store the Celery task_id (workload execution ID) in the event buffer. This will get "overwritten" if the task | ||
| # has another event, but that is fine, because the only other events are success/failed at | ||
| # which point we don't need the ID anymore anyway | ||
| # which point we don't need the ID anymore anyway. | ||
| self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) | ||
|
|
||
| def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]): | ||
| from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor | ||
| def _send_workloads_to_celery(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): | ||
| from airflow.providers.celery.executors.celery_executor_utils import send_workload_to_executor | ||
|
|
||
| if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: | ||
| if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1: | ||
| # One tuple, or max one process -> send it in the main thread. | ||
| return list(map(send_task_to_executor, task_tuples_to_send)) | ||
| return list(map(send_workload_to_executor, workload_tuples_to_send)) | ||
|
|
||
| # Use chunks instead of a work queue to reduce context switching | ||
| # since tasks are roughly uniform in size | ||
| chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send)) | ||
| num_processes = min(len(task_tuples_to_send), self._sync_parallelism) | ||
| # since workloads are roughly uniform in size. | ||
| chunksize = self._num_workloads_per_send_process(len(workload_tuples_to_send)) | ||
| num_processes = min(len(workload_tuples_to_send), self._sync_parallelism) | ||
|
|
||
| # Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues. | ||
| # Use ProcessPoolExecutor with team_name instead of workload objects to avoid pickling issues. | ||
| # Subprocesses reconstruct the team-specific Celery app from the team name and existing config. | ||
| with ProcessPoolExecutor(max_workers=num_processes) as send_pool: | ||
| key_and_async_results = list( | ||
| send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize) | ||
| send_pool.map(send_workload_to_executor, workload_tuples_to_send, chunksize=chunksize) | ||
| ) | ||
| return key_and_async_results | ||
|
|
||
| def sync(self) -> None: | ||
| if not self.tasks: | ||
| self.log.debug("No task to query celery, skipping sync") | ||
| if not self.workloads: | ||
| self.log.debug("No workload to query celery, skipping sync") | ||
| return | ||
| self.update_all_task_states() | ||
| self.update_all_workload_states() | ||
|
|
||
| def debug_dump(self) -> None: | ||
| """Debug dump; called in response to SIGUSR2 by the scheduler.""" | ||
| super().debug_dump() | ||
| self.log.info( | ||
| "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items())) | ||
| "executor.workloads (%d)\n\t%s", | ||
| len(self.workloads), | ||
| "\n\t".join(map(repr, self.workloads.items())), | ||
| ) | ||
|
|
||
| def update_all_task_states(self) -> None: | ||
| """Update states of the tasks.""" | ||
| self.log.debug("Inquiring about %s celery task(s)", len(self.tasks)) | ||
| state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values()) | ||
| def update_all_workload_states(self) -> None: | ||
| """Update states of the workloads.""" | ||
| self.log.debug("Inquiring about %s celery workload(s)", len(self.workloads)) | ||
| state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.workloads.values()) | ||
|
|
||
| self.log.debug("Inquiries completed.") | ||
| for key, async_result in list(self.tasks.items()): | ||
| for key, async_result in list(self.workloads.items()): | ||
| state, info = state_and_info_by_celery_task_id.get(async_result.task_id) | ||
| if state: | ||
| self.update_task_state(key, state, info) | ||
| self.update_workload_state(key, state, info) | ||
|
|
||
| def change_state( | ||
| self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True | ||
| ) -> None: | ||
| super().change_state(key, state, info, remove_running=remove_running) | ||
| self.tasks.pop(key, None) | ||
| self.workloads.pop(key, None) | ||
|
|
||
| def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: | ||
| """Update state of a single task.""" | ||
| def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Double check me, but I don't believe callbacks support this part (yet?), do they? If not, then maybe we should leave this one for now? This change makes the method look like it properly handles both types when it doesn't. |
||
| """Update state of a single workload.""" | ||
| try: | ||
| if state == celery_states.SUCCESS: | ||
| self.success(key, info) | ||
| self.success(cast("TaskInstanceKey", key), info) | ||
| elif state in (celery_states.FAILURE, celery_states.REVOKED): | ||
| self.fail(key, info) | ||
| self.fail(cast("TaskInstanceKey", key), info) | ||
| elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY): | ||
| pass | ||
| else: | ||
|
|
@@ -288,7 +295,9 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None | |
|
|
||
| def end(self, synchronous: bool = False) -> None: | ||
| if synchronous: | ||
| while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()): | ||
| while any( | ||
| workload.state not in celery_states.READY_STATES for workload in self.workloads.values() | ||
| ): | ||
| time.sleep(5) | ||
| self.sync() | ||
|
|
||
|
|
@@ -322,7 +331,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task | |
| not_adopted_tis.append(ti) | ||
|
|
||
| if not celery_tasks: | ||
| # Nothing to adopt | ||
| # Nothing to adopt. | ||
| return tis | ||
|
|
||
| states_by_celery_task_id = self.bulk_state_fetcher.get_many( | ||
|
|
@@ -342,9 +351,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task | |
|
|
||
| # Set the correct elements of the state dicts, then update this | ||
| # like we just queried it. | ||
| self.tasks[ti.key] = result | ||
| self.workloads[ti.key] = result | ||
| self.running.add(ti.key) | ||
| self.update_task_state(ti.key, state, info) | ||
| self.update_workload_state(ti.key, state, info) | ||
| adopted.append(f"{ti} in state {state}") | ||
|
|
||
| if adopted: | ||
|
|
@@ -373,7 +382,7 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: | |
| return reprs | ||
|
|
||
| def revoke_task(self, *, ti: TaskInstance): | ||
| celery_async_result = self.tasks.pop(ti.key, None) | ||
| celery_async_result = self.workloads.pop(ti.key, None) | ||
| if celery_async_result: | ||
| try: | ||
| self.celery_app.control.revoke(celery_async_result.task_id) | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there's a miss here. Up above, you import WorkloadKey only when the Airflow version is 3.2 or higher, but here you use it whenever the Airflow version is 3.0 or higher. What about using this as the import block?
I know there's some community debate over using try/except around imports, but I think this feels like the right time to use one.
((I think the same comment goes for celery_executor_utils.py as well))