Commit 29a855f

Author: Sameer Mesiah

Clean up CeleryExecutor docstrings, comments, variable names, and typing to align with the workload-based executor model.

1 parent c438e2a, commit 29a855f
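
For orientation while reading the diff below, a minimal sketch of the renamed executor surface this commit introduces. Illustrative only: the class name CeleryExecutorSurface and the stand-in aliases for WorkloadKey and AsyncResult are invented here so the snippet runs without Airflow or Celery installed; the real CeleryExecutor carries far more state and behaviour than shown.

# Illustrative sketch only -- distilled from the diff below, not the real class.
from collections import Counter
from typing import Any

WorkloadKey = Any   # stand-in for airflow.executors.workloads.types.WorkloadKey
AsyncResult = Any   # stand-in for celery.result.AsyncResult


class CeleryExecutorSurface:
    """Workload-oriented names introduced by this commit (old names in comments)."""

    def __init__(self) -> None:
        self.workloads: dict[WorkloadKey, AsyncResult] = {}              # was: self.tasks
        self.workload_publish_retries: Counter[WorkloadKey] = Counter()  # was: self.task_publish_retries
        self.workload_publish_max_retries = 3                            # config key "task_publish_max_retries" is unchanged

    def update_all_workload_states(self) -> None:  # was: update_all_task_states
        """Poll Celery for the state of every in-flight workload."""

    def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None:  # was: update_task_state
        """Map a single Celery state onto the executor's success/fail bookkeeping."""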

5 files changed: 198 additions & 191 deletions

File changed: providers/celery/src/airflow/providers/celery/executors/celery_executor.py

72 additions & 63 deletions
@@ -32,15 +32,15 @@
 from collections import Counter
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 from celery import states as celery_states
 from deprecated import deprecated

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.celery.executors import (
-    celery_executor_utils as _celery_executor_utils,  # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043
+    celery_executor_utils as _celery_executor_utils,  # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043.
 )
 from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
 from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
@@ -49,18 +49,23 @@
 log = logging.getLogger(__name__)


-CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
+CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload"


 if TYPE_CHECKING:
     from collections.abc import Sequence

+    from celery.result import AsyncResult
+
     from airflow.cli.cli_config import GroupCommand
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery

+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads.types import WorkloadKey
+

 # PEP562
 def __getattr__(name):
@@ -84,7 +89,7 @@ class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow.

-    It allows distributing the execution of task instances to multiple worker nodes.
+    It allows distributing the execution of workloads (task instances and callbacks) to multiple worker nodes.

     Celery is a simple, flexible and reliable distributed system to process
     vast amounts of messages, while providing operations with the tools
@@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor):
     if TYPE_CHECKING:
         if AIRFLOW_V_3_0_PLUS:
             # TODO: TaskSDK: move this type change into BaseExecutor
-            queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+            queued_tasks: dict[WorkloadKey, workloads.All]  # type: ignore[assignment]

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -127,7 +132,7 @@ def __init__(self, *args, **kwargs):

         self.celery_app = create_celery_app(self.conf)

-        # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
+        # Celery doesn't support bulk sending the workloads (which can become a bottleneck on bigger clusters)
         # so we use a multiprocessing pool to speed this up.
         # How many worker processes are created for checking celery task state.
         self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0)
@@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs):
         from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher

         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app)
-        self.tasks = {}
-        self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
-        self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)
+        self.workloads: dict[WorkloadKey, AsyncResult] = {}
+        self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+        self.workload_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)

     def start(self) -> None:
         self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)

-    def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+    def _num_workloads_per_send_process(self, to_send_count: int) -> int:
         """
-        How many Celery tasks should each worker process send.
+        How many Celery workloads should each worker process send.

-        :return: Number of tasks that should be sent per process
+        :return: Number of workloads that should be sent per process
         """
         return max(1, math.ceil(to_send_count / self._sync_parallelism))

     def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
-        # Airflow V2 version
+        # Airflow V2 compatibility path — converts task tuples into workload-compatible tuples.

         task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples]

-        self._send_tasks(task_tuples_to_send)
+        self._send_workloads(task_tuples_to_send)

     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        # Airflow V3 version -- have to delay imports until we know we are on v3
+        # Airflow V3 version -- have to delay imports until we know we are on v3.
        from airflow.executors.workloads import ExecuteTask

         if AIRFLOW_V_3_2_PLUS:
             from airflow.executors.workloads import ExecuteCallback

-        tasks: list[WorkloadInCelery] = []
+        workloads_to_be_sent: list[WorkloadInCelery] = []
         for workload in workloads:
             if isinstance(workload, ExecuteTask):
-                tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
+                workloads_to_be_sent.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
             elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
-                # Use default queue for callbacks, or extract from callback data if available
+                # Use default queue for callbacks, or extract from callback data if available.
                 queue = "default"
                 if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data:
                     queue = workload.callback.data["queue"]
-                tasks.append((workload.callback.key, workload, queue, self.team_name))
+                workloads_to_be_sent.append((workload.callback.key, workload, queue, self.team_name))
             else:
                 raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}")

-        self._send_tasks(tasks)
+        self._send_workloads(workloads_to_be_sent)

-    def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+    def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
         # Celery state queries will be stuck if we do not use one same backend
-        # for all tasks.
+        # for all workloads.
         cached_celery_backend = self.celery_app.backend

-        key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
-        self.log.debug("Sent all tasks.")
+        key_and_async_results = self._send_workloads_to_celery(workload_tuples_to_send)
+        self.log.debug("Sent all workloads.")
         from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback

         for key, _, result in key_and_async_results:
             if isinstance(result, ExceptionWithTraceback) and isinstance(
                 result.exception, AirflowTaskTimeout
             ):
-                retries = self.task_publish_retries[key]
-                if retries < self.task_publish_max_retries:
+                retries = self.workload_publish_retries[key]
+                if retries < self.workload_publish_max_retries:
                     Stats.incr("celery.task_timeout_error")
                     self.log.info(
-                        "[Try %s of %s] Task Timeout Error for Task: (%s).",
-                        self.task_publish_retries[key] + 1,
-                        self.task_publish_max_retries,
+                        "[Try %s of %s] Task Timeout Error for Workload: (%s).",
+                        self.workload_publish_retries[key] + 1,
+                        self.workload_publish_max_retries,
                         tuple(key),
                     )
-                    self.task_publish_retries[key] = retries + 1
+                    self.workload_publish_retries[key] = retries + 1
                     continue
             if key in self.queued_tasks:
                 self.queued_tasks.pop(key)
             else:
                 self.queued_callbacks.pop(key, None)
-            self.task_publish_retries.pop(key, None)
+            self.workload_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback)
                 self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
-                self.tasks[key] = result
+                self.workloads[key] = result

-                # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
+                # Store the Celery task_id (workload execution ID) in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
-                # which point we don't need the ID anymore anyway
+                # which point we don't need the ID anymore anyway.
                 self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)

-    def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
-        from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor
+    def _send_workloads_to_celery(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
+        from airflow.providers.celery.executors.celery_executor_utils import send_workload_to_executor

-        if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+        if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
             # One tuple, or max one process -> send it in the main thread.
-            return list(map(send_task_to_executor, task_tuples_to_send))
+            return list(map(send_workload_to_executor, workload_tuples_to_send))

         # Use chunks instead of a work queue to reduce context switching
-        # since tasks are roughly uniform in size
-        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
-        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+        # since workloads are roughly uniform in size.
+        chunksize = self._num_workloads_per_send_process(len(workload_tuples_to_send))
+        num_processes = min(len(workload_tuples_to_send), self._sync_parallelism)

-        # Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues.
+        # Use ProcessPoolExecutor with team_name instead of workload objects to avoid pickling issues.
         # Subprocesses reconstruct the team-specific Celery app from the team name and existing config.
         with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
             key_and_async_results = list(
-                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
+                send_pool.map(send_workload_to_executor, workload_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results

     def sync(self) -> None:
-        if not self.tasks:
-            self.log.debug("No task to query celery, skipping sync")
+        if not self.workloads:
+            self.log.debug("No workload to query celery, skipping sync")
             return
-        self.update_all_task_states()
+        self.update_all_workload_states()

     def debug_dump(self) -> None:
         """Debug dump; called in response to SIGUSR2 by the scheduler."""
         super().debug_dump()
         self.log.info(
-            "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items()))
+            "executor.workloads (%d)\n\t%s",
+            len(self.workloads),
+            "\n\t".join(map(repr, self.workloads.items())),
         )

-    def update_all_task_states(self) -> None:
-        """Update states of the tasks."""
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())
+    def update_all_workload_states(self) -> None:
+        """Update states of the workloads."""
+        self.log.debug("Inquiring about %s celery workload(s)", len(self.workloads))
+        state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.workloads.values())

         self.log.debug("Inquiries completed.")
-        for key, async_result in list(self.tasks.items()):
+        for key, async_result in list(self.workloads.items()):
             state, info = state_and_info_by_celery_task_id.get(async_result.task_id)
             if state:
-                self.update_task_state(key, state, info)
+                self.update_workload_state(key, state, info)

     def change_state(
         self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
     ) -> None:
         super().change_state(key, state, info, remove_running=remove_running)
-        self.tasks.pop(key, None)
+        self.workloads.pop(key, None)

-    def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
-        """Update state of a single task."""
+    def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None:
+        """Update state of a single workload."""
         try:
             if state == celery_states.SUCCESS:
-                self.success(key, info)
+                self.success(cast("TaskInstanceKey", key), info)
             elif state in (celery_states.FAILURE, celery_states.REVOKED):
-                self.fail(key, info)
+                self.fail(cast("TaskInstanceKey", key), info)
             elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY):
                 pass
             else:
@@ -288,7 +295,9 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None

     def end(self, synchronous: bool = False) -> None:
         if synchronous:
-            while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()):
+            while any(
+                workload.state not in celery_states.READY_STATES for workload in self.workloads.values()
+            ):
                 time.sleep(5)
                 self.sync()

@@ -322,7 +331,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
                 not_adopted_tis.append(ti)

         if not celery_tasks:
-            # Nothing to adopt
+            # Nothing to adopt.
             return tis

         states_by_celery_task_id = self.bulk_state_fetcher.get_many(
@@ -342,9 +351,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

             # Set the correct elements of the state dicts, then update this
             # like we just queried it.
-            self.tasks[ti.key] = result
+            self.workloads[ti.key] = result
             self.running.add(ti.key)
-            self.update_task_state(ti.key, state, info)
+            self.update_workload_state(ti.key, state, info)
             adopted.append(f"{ti} in state {state}")

         if adopted:
@@ -373,7 +382,7 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
         return reprs

     def revoke_task(self, *, ti: TaskInstance):
-        celery_async_result = self.tasks.pop(ti.key, None)
+        celery_async_result = self.workloads.pop(ti.key, None)
         if celery_async_result:
             try:
                 self.celery_app.control.revoke(celery_async_result.task_id)
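
The callback branch in _process_workloads above falls back to the "default" queue unless the callback's data dict names one explicitly. A standalone sketch of just that routing decision, using invented stand-ins (FakeCallback, FakeExecuteCallback) in place of the real airflow.executors.workloads.ExecuteCallback so it runs without Airflow installed:

# Illustrative sketch only: mirrors the queue selection in _process_workloads.
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeCallback:
    """Stand-in for the callback object carried by an ExecuteCallback workload."""

    key: str
    data: Any = None


@dataclass
class FakeExecuteCallback:
    """Stand-in for airflow.executors.workloads.ExecuteCallback."""

    callback: FakeCallback


def pick_queue(workload: FakeExecuteCallback) -> str:
    """Default queue for callbacks, unless the callback data names an explicit queue."""
    queue = "default"
    if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data:
        queue = workload.callback.data["queue"]
    return queue


assert pick_queue(FakeExecuteCallback(FakeCallback("cb-1"))) == "default"
assert pick_queue(FakeExecuteCallback(FakeCallback("cb-2", {"queue": "callbacks"}))) == "callbacks"

Either way the executor then ships (key, workload, queue, team_name) to Celery through the same _send_workloads path used for ExecuteTask workloads, so callbacks and task instances share one publishing pipeline.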
