 from collections import Counter
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing import cpu_count
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from celery import states as celery_states
 from deprecated import deprecated
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.celery.executors import (
-    celery_executor_utils as _celery_executor_utils,  # noqa: F401  # Needed to register Celery tasks at worker startup, see #63043
+    celery_executor_utils as _celery_executor_utils,  # noqa: F401  # Needed to register Celery tasks at worker startup, see #63043.
 )
 from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
 from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
 
 log = logging.getLogger(__name__)
 
 
-CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
+CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload"
 
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
+    from celery.result import AsyncResult
+
     from airflow.cli.cli_config import GroupCommand
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery
 
+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads.types import WorkloadKey
+
 
 # PEP562
 def __getattr__(name):
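
The body of __getattr__ is cut off in this view. As a rough, standalone sketch of the PEP 562 pattern the comment names (a module-level __getattr__ that resolves attributes lazily; the "app" attribute here is an assumption for illustration, not taken from this diff):

    def __getattr__(name):
        # Hypothetical: import the Celery app object only on first access.
        if name == "app":
            from airflow.providers.celery.executors.celery_executor_utils import app

            return app
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
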
@@ -84,7 +89,7 @@ class CeleryExecutor(BaseExecutor):
8489 """
8590 CeleryExecutor is recommended for production use of Airflow.
8691
87- It allows distributing the execution of task instances to multiple worker nodes.
92+ It allows distributing the execution of workloads ( task instances and callbacks) to multiple worker nodes.
8893
8994 Celery is a simple, flexible and reliable distributed system to process
9095 vast amounts of messages, while providing operations with the tools
@@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor):
     if TYPE_CHECKING:
         if AIRFLOW_V_3_0_PLUS:
             # TODO: TaskSDK: move this type change into BaseExecutor
-            queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+            queued_tasks: dict[WorkloadKey, workloads.All]  # type: ignore[assignment]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -127,7 +132,7 @@ def __init__(self, *args, **kwargs):
 
         self.celery_app = create_celery_app(self.conf)
 
-        # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
+        # Celery doesn't support bulk sending the workloads (which can become a bottleneck on bigger clusters)
         # so we use a multiprocessing pool to speed this up.
         # How many worker processes are created for checking celery task state.
         self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0)
@@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs):
         from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher
 
         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app)
-        self.tasks = {}
-        self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
-        self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)
+        self.workloads: dict[WorkloadKey, AsyncResult] = {}
+        self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+        self.workload_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)
 
     def start(self) -> None:
         self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)
 
-    def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+    def _num_workloads_per_send_process(self, to_send_count: int) -> int:
         """
-        How many Celery tasks should each worker process send.
+        How many Celery workloads should each worker process send.
 
-        :return: Number of tasks that should be sent per process
+        :return: Number of workloads that should be sent per process
         """
         return max(1, math.ceil(to_send_count / self._sync_parallelism))
 
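A quick standalone check of the chunking arithmetic above (not part of the diff):

    import math

    to_send_count, sync_parallelism = 100, 8
    chunksize = max(1, math.ceil(to_send_count / sync_parallelism))
    assert chunksize == 13  # each of the 8 sender processes handles ~13 workloads per chunk
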
     def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
-        # Airflow V2 version
+        # Airflow V2 compatibility path — converts task tuples into workload-compatible tuples.
 
         task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples]
 
-        self._send_tasks(task_tuples_to_send)
+        self._send_workloads(task_tuples_to_send)
 
     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        # Airflow V3 version -- have to delay imports until we know we are on v3
+        # Airflow V3 version -- have to delay imports until we know we are on v3.
         from airflow.executors.workloads import ExecuteTask
 
         if AIRFLOW_V_3_2_PLUS:
             from airflow.executors.workloads import ExecuteCallback
 
-        tasks: list[WorkloadInCelery] = []
+        workloads_to_be_sent: list[WorkloadInCelery] = []
         for workload in workloads:
             if isinstance(workload, ExecuteTask):
-                tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
+                workloads_to_be_sent.append((workload.ti.key, workload, workload.ti.queue, self.team_name))
             elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
-                # Use default queue for callbacks, or extract from callback data if available
+                # Use default queue for callbacks, or extract from callback data if available.
                 queue = "default"
                 if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data:
                     queue = workload.callback.data["queue"]
-                tasks.append((workload.callback.key, workload, queue, self.team_name))
+                workloads_to_be_sent.append((workload.callback.key, workload, queue, self.team_name))
             else:
                 raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}")
 
-        self._send_tasks(tasks)
+        self._send_workloads(workloads_to_be_sent)
 
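For reference, the WorkloadInCelery tuples built above share one shape; a sketch of the two variants (illustrative, inferred from the code in this hunk):

    # task:     (workload.ti.key,       workload, workload.ti.queue, team_name)
    # callback: (workload.callback.key, workload, queue,             team_name)
    # where queue is "default" unless workload.callback.data is a dict with a "queue" entry
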
-    def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+    def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
         # Celery state queries will be stuck if we do not use one same backend
-        # for all tasks.
+        # for all workloads.
         cached_celery_backend = self.celery_app.backend
 
-        key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
-        self.log.debug("Sent all tasks.")
+        key_and_async_results = self._send_workloads_to_celery(workload_tuples_to_send)
+        self.log.debug("Sent all workloads.")
         from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback
 
         for key, _, result in key_and_async_results:
             if isinstance(result, ExceptionWithTraceback) and isinstance(
                 result.exception, AirflowTaskTimeout
             ):
-                retries = self.task_publish_retries[key]
-                if retries < self.task_publish_max_retries:
+                retries = self.workload_publish_retries[key]
+                if retries < self.workload_publish_max_retries:
                     Stats.incr("celery.task_timeout_error")
                     self.log.info(
-                        "[Try %s of %s] Task Timeout Error for Task: (%s).",
-                        self.task_publish_retries[key] + 1,
-                        self.task_publish_max_retries,
+                        "[Try %s of %s] Task Timeout Error for Workload: (%s).",
+                        self.workload_publish_retries[key] + 1,
+                        self.workload_publish_max_retries,
                         tuple(key),
                     )
-                    self.task_publish_retries[key] = retries + 1
+                    self.workload_publish_retries[key] = retries + 1
                     continue
             if key in self.queued_tasks:
                 self.queued_tasks.pop(key)
             else:
                 self.queued_callbacks.pop(key, None)
-            self.task_publish_retries.pop(key, None)
+            self.workload_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback)
                 self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
-                self.tasks[key] = result
+                self.workloads[key] = result
 
-                # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
+                # Store the Celery task_id (workload execution ID) in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
-                # which point we don't need the ID anymore anyway
+                # which point we don't need the ID anymore anyway.
                 self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
 
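The retry bookkeeping above relies on collections.Counter returning 0 for missing keys. A minimal standalone sketch of that publish-retry behavior (names are illustrative, not the executor's API):

    from collections import Counter

    retries: Counter[str] = Counter()
    max_retries = 3
    key = "example-workload-key"
    for _ in range(5):                  # pretend every publish attempt times out
        if retries[key] < max_retries:  # a missing key reads as 0
            retries[key] += 1           # keep the workload queued and retry later
        else:
            break                       # out of retries; the workload is marked FAILED
    assert retries[key] == 3
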
-    def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
-        from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor
+    def _send_workloads_to_celery(self, workload_tuples_to_send: Sequence[WorkloadInCelery]):
+        from airflow.providers.celery.executors.celery_executor_utils import send_workload_to_executor
 
-        if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+        if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
             # One tuple, or max one process -> send it in the main thread.
-            return list(map(send_task_to_executor, task_tuples_to_send))
+            return list(map(send_workload_to_executor, workload_tuples_to_send))
 
         # Use chunks instead of a work queue to reduce context switching
-        # since tasks are roughly uniform in size
-        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
-        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+        # since workloads are roughly uniform in size.
+        chunksize = self._num_workloads_per_send_process(len(workload_tuples_to_send))
+        num_processes = min(len(workload_tuples_to_send), self._sync_parallelism)
 
-        # Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues.
+        # Use ProcessPoolExecutor with team_name instead of workload objects to avoid pickling issues.
         # Subprocesses reconstruct the team-specific Celery app from the team name and existing config.
         with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
             key_and_async_results = list(
-                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
+                send_pool.map(send_workload_to_executor, workload_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
 
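The send pool above is the standard concurrent.futures chunked-map pattern. A self-contained sketch under the same assumptions (a picklable, module-level send function and roughly uniform items; send_workload_to_executor itself is not shown in this diff):

    import math
    from concurrent.futures import ProcessPoolExecutor

    def send(item: int) -> int:
        # Stand-in for send_workload_to_executor; must be module-level to be picklable.
        return item * 2

    if __name__ == "__main__":
        items = list(range(100))
        parallelism = 8
        chunksize = max(1, math.ceil(len(items) / parallelism))
        with ProcessPoolExecutor(max_workers=min(len(items), parallelism)) as pool:
            results = list(pool.map(send, items, chunksize=chunksize))
        assert results[:3] == [0, 2, 4]
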
     def sync(self) -> None:
-        if not self.tasks:
-            self.log.debug("No task to query celery, skipping sync")
+        if not self.workloads:
+            self.log.debug("No workload to query celery, skipping sync")
             return
-        self.update_all_task_states()
+        self.update_all_workload_states()
 
     def debug_dump(self) -> None:
         """Debug dump; called in response to SIGUSR2 by the scheduler."""
         super().debug_dump()
         self.log.info(
-            "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items()))
+            "executor.workloads (%d)\n\t%s",
+            len(self.workloads),
+            "\n\t".join(map(repr, self.workloads.items())),
         )
 
-    def update_all_task_states(self) -> None:
-        """Update states of the tasks."""
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())
+    def update_all_workload_states(self) -> None:
+        """Update states of the workloads."""
+        self.log.debug("Inquiring about %s celery workload(s)", len(self.workloads))
+        state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.workloads.values())
 
         self.log.debug("Inquiries completed.")
-        for key, async_result in list(self.tasks.items()):
+        for key, async_result in list(self.workloads.items()):
             state, info = state_and_info_by_celery_task_id.get(async_result.task_id)
             if state:
-                self.update_task_state(key, state, info)
+                self.update_workload_state(key, state, info)
 
     def change_state(
         self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
     ) -> None:
         super().change_state(key, state, info, remove_running=remove_running)
-        self.tasks.pop(key, None)
+        self.workloads.pop(key, None)
 
-    def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
-        """Update state of a single task."""
+    def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None:
+        """Update state of a single workload."""
         try:
             if state == celery_states.SUCCESS:
-                self.success(key, info)
+                self.success(cast("TaskInstanceKey", key), info)
             elif state in (celery_states.FAILURE, celery_states.REVOKED):
-                self.fail(key, info)
+                self.fail(cast("TaskInstanceKey", key), info)
             elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY):
                 pass
             else:
@@ -288,7 +295,9 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None
 
     def end(self, synchronous: bool = False) -> None:
         if synchronous:
-            while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()):
+            while any(
+                workload.state not in celery_states.READY_STATES for workload in self.workloads.values()
+            ):
                 time.sleep(5)
             self.sync()
 
@@ -322,7 +331,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
             not_adopted_tis.append(ti)
 
         if not celery_tasks:
-            # Nothing to adopt
+            # Nothing to adopt.
             return tis
 
         states_by_celery_task_id = self.bulk_state_fetcher.get_many(
@@ -342,9 +351,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
 
             # Set the correct elements of the state dicts, then update this
             # like we just queried it.
-            self.tasks[ti.key] = result
+            self.workloads[ti.key] = result
             self.running.add(ti.key)
-            self.update_task_state(ti.key, state, info)
+            self.update_workload_state(ti.key, state, info)
             adopted.append(f"{ti} in state {state}")
 
         if adopted:
@@ -373,7 +382,7 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
         return reprs
 
     def revoke_task(self, *, ti: TaskInstance):
-        celery_async_result = self.tasks.pop(ti.key, None)
+        celery_async_result = self.workloads.pop(ti.key, None)
         if celery_async_result:
             try:
                 self.celery_app.control.revoke(celery_async_result.task_id)