diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 53e7f06948286..bf991e388e7ac 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -32,7 +32,7 @@ from collections import Counter from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from celery import states as celery_states from deprecated import deprecated @@ -40,7 +40,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.executors import ( - celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043 + celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043. 
) from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats @@ -49,18 +49,23 @@ log = logging.getLogger(__name__) -CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task" +CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery workload" if TYPE_CHECKING: from collections.abc import Sequence + from celery.result import AsyncResult + from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery + if AIRFLOW_V_3_2_PLUS: + from airflow.executors.workloads.types import WorkloadKey + # PEP562 def __getattr__(name): @@ -84,7 +89,7 @@ class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. - It allows distributing the execution of task instances to multiple worker nodes. + It allows distributing the execution of workloads (task instances and callbacks) to multiple worker nodes. 
Celery is a simple, flexible and reliable distributed system to process vast amounts of messages, while providing operations with the tools @@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor): if TYPE_CHECKING: if AIRFLOW_V_3_0_PLUS: # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -127,7 +132,7 @@ def __init__(self, *args, **kwargs): self.celery_app = create_celery_app(self.conf) - # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters) + # Celery doesn't support bulk sending the workloads (which can become a bottleneck on bigger clusters) # so we use a multiprocessing pool to speed this up. # How many worker processes are created for checking celery task state. self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0) @@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs): from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app) - self.tasks = {} - self.task_publish_retries: Counter[TaskInstanceKey] = Counter() - self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3) + self.workloads: dict[WorkloadKey, AsyncResult] = {} + self.workload_publish_retries: Counter[WorkloadKey] = Counter() + self.workload_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3) def start(self) -> None: self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism) - def _num_tasks_per_send_process(self, to_send_count: int) -> int: + def _num_workloads_per_send_process(self, to_send_count: int) -> int: """ - How many Celery tasks should each 
worker process send. + How many Celery workloads should each worker process send. - :return: Number of tasks that should be sent per process + :return: Number of workloads that should be sent per process """ return max(1, math.ceil(to_send_count / self._sync_parallelism)) def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None: - # Airflow V2 version + # Airflow V2 compatibility path — converts task tuples into workload-compatible tuples. task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples] - self._send_tasks(task_tuples_to_send) + self._send_workloads(task_tuples_to_send) def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: - # Airflow V3 version -- have to delay imports until we know we are on v3 + # Airflow V3 version -- have to delay imports until we know we are on v3. from airflow.executors.workloads import ExecuteTask if AIRFLOW_V_3_2_PLUS: from airflow.executors.workloads import ExecuteCallback - tasks: list[WorkloadInCelery] = [] + workloads_to_be_sent: list[WorkloadInCelery] = [] for workload in workloads: if isinstance(workload, ExecuteTask): - tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) + workloads_to_be_sent.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): - # Use default queue for callbacks, or extract from callback data if available + # Use default queue for callbacks, or extract from callback data if available. 
queue = "default" if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: queue = workload.callback.data["queue"] - tasks.append((workload.callback.key, workload, queue, self.team_name)) + workloads_to_be_sent.append((workload.callback.key, workload, queue, self.team_name)) else: raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") - self._send_tasks(tasks) + self._send_workloads(workloads_to_be_sent) - def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]): + def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): # Celery state queries will be stuck if we do not use one same backend - # for all tasks. + # for all workloads. cached_celery_backend = self.celery_app.backend - key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send) - self.log.debug("Sent all tasks.") + key_and_async_results = self._send_workloads_to_celery(workload_tuples_to_send) + self.log.debug("Sent all workloads.") from airflow.providers.celery.executors.celery_executor_utils import ExceptionWithTraceback for key, _, result in key_and_async_results: if isinstance(result, ExceptionWithTraceback) and isinstance( result.exception, AirflowTaskTimeout ): - retries = self.task_publish_retries[key] - if retries < self.task_publish_max_retries: + retries = self.workload_publish_retries[key] + if retries < self.workload_publish_max_retries: Stats.incr("celery.task_timeout_error") self.log.info( - "[Try %s of %s] Task Timeout Error for Task: (%s).", - self.task_publish_retries[key] + 1, - self.task_publish_max_retries, + "[Try %s of %s] Celery Task Timeout Error for Workload: (%s).", + self.workload_publish_retries[key] + 1, + self.workload_publish_max_retries, tuple(key), ) - self.task_publish_retries[key] = retries + 1 + self.workload_publish_retries[key] = retries + 1 continue if key in self.queued_tasks: self.queued_tasks.pop(key) else: self.queued_callbacks.pop(key, None) - 
self.task_publish_retries.pop(key, None) + self.workload_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) self.event_buffer[key] = (TaskInstanceState.FAILED, None) elif result is not None: result.backend = cached_celery_backend self.running.add(key) - self.tasks[key] = result + self.workloads[key] = result - # Store the Celery task_id in the event buffer. This will get "overwritten" if the task + # Store the Celery task_id (workload execution ID) in the event buffer. This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at - # which point we don't need the ID anymore anyway + # which point we don't need the ID anymore anyway. self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) - def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]): - from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor + def _send_workloads_to_celery(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): + from airflow.providers.celery.executors.celery_executor_utils import send_workload_to_executor - if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: + if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1: # One tuple, or max one process -> send it in the main thread. - return list(map(send_task_to_executor, task_tuples_to_send)) + return list(map(send_workload_to_executor, workload_tuples_to_send)) # Use chunks instead of a work queue to reduce context switching - # since tasks are roughly uniform in size - chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send)) - num_processes = min(len(task_tuples_to_send), self._sync_parallelism) + # since workloads are roughly uniform in size. 
+ chunksize = self._num_workloads_per_send_process(len(workload_tuples_to_send)) + num_processes = min(len(workload_tuples_to_send), self._sync_parallelism) - # Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues. + # Use ProcessPoolExecutor with team_name instead of workload objects to avoid pickling issues. # Subprocesses reconstruct the team-specific Celery app from the team name and existing config. with ProcessPoolExecutor(max_workers=num_processes) as send_pool: key_and_async_results = list( - send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize) + send_pool.map(send_workload_to_executor, workload_tuples_to_send, chunksize=chunksize) ) return key_and_async_results def sync(self) -> None: - if not self.tasks: - self.log.debug("No task to query celery, skipping sync") + if not self.workloads: + self.log.debug("No workload to query celery, skipping sync") return - self.update_all_task_states() + self.update_all_workload_states() def debug_dump(self) -> None: """Debug dump; called in response to SIGUSR2 by the scheduler.""" super().debug_dump() self.log.info( - "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items())) + "executor.workloads (%d)\n\t%s", + len(self.workloads), + "\n\t".join(map(repr, self.workloads.items())), ) - def update_all_task_states(self) -> None: - """Update states of the tasks.""" - self.log.debug("Inquiring about %s celery task(s)", len(self.tasks)) - state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values()) + def update_all_workload_states(self) -> None: + """Update states of the workloads.""" + self.log.debug("Inquiring about %s celery workload(s)", len(self.workloads)) + state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.workloads.values()) self.log.debug("Inquiries completed.") - for key, async_result in list(self.tasks.items()): + for key, async_result in list(self.workloads.items()): state, 
info = state_and_info_by_celery_task_id.get(async_result.task_id) if state: - self.update_task_state(key, state, info) + self.update_workload_state(key, state, info) def change_state( self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True ) -> None: super().change_state(key, state, info, remove_running=remove_running) - self.tasks.pop(key, None) + self.workloads.pop(key, None) - def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: - """Update state of a single task.""" + def update_workload_state(self, key: WorkloadKey, state: str, info: Any) -> None: + """Update state of a single workload.""" try: if state == celery_states.SUCCESS: - self.success(key, info) + self.success(cast("TaskInstanceKey", key), info) elif state in (celery_states.FAILURE, celery_states.REVOKED): - self.fail(key, info) + self.fail(cast("TaskInstanceKey", key), info) elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY): pass else: @@ -288,7 +295,9 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None def end(self, synchronous: bool = False) -> None: if synchronous: - while any(task.state not in celery_states.READY_STATES for task in self.tasks.values()): + while any( + workload.state not in celery_states.READY_STATES for workload in self.workloads.values() + ): time.sleep(5) self.sync() @@ -322,7 +331,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task not_adopted_tis.append(ti) if not celery_tasks: - # Nothing to adopt + # Nothing to adopt. return tis states_by_celery_task_id = self.bulk_state_fetcher.get_many( @@ -342,9 +351,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task # Set the correct elements of the state dicts, then update this # like we just queried it. 
- self.tasks[ti.key] = result + self.workloads[ti.key] = result self.running.add(ti.key) - self.update_task_state(ti.key, state, info) + self.update_workload_state(ti.key, state, info) adopted.append(f"{ti} in state {state}") if adopted: @@ -373,7 +382,7 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: return reprs def revoke_task(self, *, ti: TaskInstance): - celery_async_result = self.tasks.pop(ti.key, None) + celery_async_result = self.workloads.pop(ti.key, None) if celery_async_result: try: self.celery_app.control.revoke(celery_async_result.task_id) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index e96d3800ccda4..9f385ae1ffee9 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -75,7 +75,7 @@ from airflow.models.taskinstance import TaskInstanceKey # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define - # the type as the union of both kinds + # the type as the union of both kinds. CommandType = Sequence[str] WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All | CommandType, str | None, str | None] @@ -83,7 +83,7 @@ WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback" ] - # Deprecated alias for backward compatibility + # Deprecated alias for backward compatibility. TaskInstanceInCelery: TypeAlias = WorkloadInCelery TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None] @@ -124,10 +124,10 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: celery_app_name = team_conf.get("celery", "CELERY_APP_NAME") - # Make app name unique per team to ensure proper broker isolation + # Make app name unique per team to ensure proper broker isolation. 
# Each team's executor needs a distinct Celery app name to prevent - # tasks from being routed to the wrong broker - # Only do this if team_conf is an ExecutorConf with team_name (not global conf) + # tasks from being routed to the wrong broker. + # Only do this if team_conf is an ExecutorConf with team_name (not global conf). team_name = getattr(team_conf, "team_name", None) if team_name: celery_app_name = f"{celery_app_name}_{team_name}" @@ -136,7 +136,7 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: celery_app = Celery(celery_app_name, config_source=config) - # Register tasks with this app + # Register tasks with this app. celery_app.task(name="execute_workload")(execute_workload) if not AIRFLOW_V_3_0_PLUS: celery_app.task(name="execute_command")(execute_command) @@ -144,7 +144,7 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery: return celery_app -# Keep module-level app for backward compatibility +# Keep module-level app for backward compatibility. app = _get_celery_app() @@ -186,7 +186,7 @@ def on_celery_worker_ready(*args, **kwargs): # Once Celery 5.5 is out of beta, we can pass `pydantic=True` to the decorator and it will handle the validation -# and deserialization for us +# and deserialization for us. @app.task(name="execute_workload") def execute_workload(input: str) -> None: from pydantic import TypeAdapter @@ -203,7 +203,7 @@ def execute_workload(input: str) -> None: log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) base_url = conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default + # If it's a relative URL, use localhost:8080 as the default. 
if base_url.startswith("/"): base_url = f"http://localhost:8080{base_url}" default_execution_api_server = f"{base_url.rstrip('/')}/execution/" @@ -254,7 +254,7 @@ def execute_command(command_to_exec: CommandType) -> None: def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None: pid = os.fork() if pid: - # In parent, wait for the child + # In parent, wait for the child. pid, ret = os.waitpid(pid, 0) if ret == 0: return @@ -269,7 +269,7 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = from airflow.cli.cli_parser import get_parser parser = get_parser() - # [1:] - remove "airflow" from the start of the command + # [1:] - remove "airflow" from the start of the command. args = parser.parse_args(command_to_exec[1:]) args.shut_down_logging = False if celery_task_id: @@ -329,7 +329,7 @@ def send_workload_to_executor( workload_tuple: WorkloadInCelery, ) -> WorkloadInCeleryResult: """ - Send workload to executor. + Send workload to executor (serialized and executed as a Celery task). This function is called in ProcessPoolExecutor subprocesses. To avoid pickling issues with team-specific Celery apps, we pass the team_name and reconstruct the Celery app here. @@ -340,26 +340,26 @@ def send_workload_to_executor( # ExecutorConf wraps config access to automatically use team-specific config where present. if TYPE_CHECKING: _conf: ExecutorConf | AirflowConfigParser - # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf + # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf. if AIRFLOW_V_3_2_PLUS: from airflow.executors.base_executor import ExecutorConf _conf = ExecutorConf(team_name) else: - # Airflow <3.2 ExecutorConf doesn't exist (at least not with the required attributes), fall back to global conf + # Airflow <3.2 ExecutorConf doesn't exist (at least not with the required attributes), fall back to global conf. 
_conf = conf - # Create the Celery app with the correct configuration + # Create the Celery app with the correct configuration. celery_app = create_celery_app(_conf) if AIRFLOW_V_3_0_PLUS: - # Get the task from the app - task_to_run = celery_app.tasks["execute_workload"] + # Get the task from the app. + celery_task = celery_app.tasks["execute_workload"] if TYPE_CHECKING: assert isinstance(args, workloads.BaseWorkload) args = (args.model_dump_json(),) else: - # Get the task from the app - task_to_run = celery_app.tasks["execute_command"] + # Get the task from the app. + celery_task = celery_app.tasks["execute_command"] args = [args] # type: ignore[list-item] # Pre-import redis.client to avoid SIGALRM interrupting module initialization. @@ -369,27 +369,23 @@ def send_workload_to_executor( try: import redis.client # noqa: F401 except ImportError: - pass # Redis not installed or not using Redis backend + pass # Redis not installed or not using Redis backend. try: with timeout(seconds=OPERATION_TIMEOUT): - result = task_to_run.apply_async(args=args, queue=queue) + result = celery_task.apply_async(args=args, queue=queue) except (Exception, AirflowTaskTimeout) as e: exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}" result = ExceptionWithTraceback(e, exception_traceback) # The type is right for the version, but the type cannot be defined correctly for Airflow 2 and 3 - # concurrently; + # concurrently. return key, args, result -# Backward compatibility alias -send_task_to_executor = send_workload_to_executor - - def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: """ - Fetch and return the state of the given celery task. + Fetch and return the state of the given celery task (workload execution). The scope of this function is global so that it can be called by subprocesses in the pool. 
@@ -403,12 +399,12 @@ def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | Excep try: import redis.client # noqa: F401 except ImportError: - pass # Redis not installed or not using Redis backend + pass # Redis not installed or not using Redis backend. try: with timeout(seconds=OPERATION_TIMEOUT): - # Accessing state property of celery task will make actual network request - # to get the current state of the task + # Accessing state property of celery task (workload execution) triggers a network request + # to get the current state of the task. info = async_result.info if hasattr(async_result, "info") else None return async_result.task_id, async_result.state, info except Exception as e: @@ -428,7 +424,7 @@ class BulkStateFetcher(LoggingMixin): def __init__(self, sync_parallelism: int, celery_app: Celery | None = None): super().__init__() self._sync_parallelism = sync_parallelism - self.celery_app = celery_app or app # Use provided app or fall back to module-level app + self.celery_app = celery_app or app # Use provided app or fall back to module-level app. 
def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> set[str]: return {a.task_id for a in async_tasks} diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 9d55cda5376ae..4a2dd18cf3135 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -103,7 +103,7 @@ def _task_event_logs(self, value): @property def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from celery and kubernetes executor.""" - return self.celery_executor.queued_tasks | self.kubernetes_executor.queued_tasks + return self.celery_executor.queued_tasks | self.kubernetes_executor.queued_tasks # type: ignore[return-value] @queued_tasks.setter def queued_tasks(self, value) -> None: diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index da8ee15571b10..154e61f9924d0 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -27,7 +27,7 @@ from time import sleep from unittest import mock -# leave this it is used by the test worker +# Leave this it is used by the test worker. 
import celery.contrib.testing.tasks # noqa: F401 import pytest import uuid6 @@ -85,8 +85,8 @@ def _prepare_app(broker_url=None, execute=None): test_config = dict(celery_executor_utils.get_celery_configuration()) test_config.update({"broker_url": broker_url}) test_app = Celery(broker_url, config_source=test_config) - # Register the fake execute function with the test_app using the correct task name - # This ensures workers using test_app will execute the fake function + # Register the fake execute function with the test_app using the correct task name. + # This ensures workers using test_app will execute the fake function. test_execute = test_app.task(name=execute_name)(execute) patch_app = mock.patch.object(celery_executor_utils, "app", test_app) @@ -96,7 +96,7 @@ def _prepare_app(broker_url=None, execute=None): celery_executor_utils.execute_command.__wrapped__ = execute patch_execute = mock.patch.object(celery_executor_utils, execute_name, test_execute) - # Patch factory function so CeleryExecutor instances get the test app + # Patch factory function so CeleryExecutor instances get the test app. patch_factory = mock.patch.object(celery_executor_utils, "create_celery_app", return_value=test_app) backend = test_app.backend @@ -106,7 +106,7 @@ def _prepare_app(broker_url=None, execute=None): # race condition where it one of the subprocesses can die with "Table # already exists" error, because SQLA checks for which tables exist, # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT - # EXISTS + # EXISTS. 
session = backend.ResultSession() session.close() @@ -130,7 +130,7 @@ def teardown_method(self) -> None: db.clear_db_jobs() -def setup_dagrun_with_success_and_fail_tasks(dag_maker): +def setup_dagrun_with_success_and_fail_workloads(dag_maker): date = timezone.utcnow() start_date = date - timedelta(days=2) @@ -176,19 +176,19 @@ def test_celery_integration(self, broker_url, executor_config): from airflow.providers.celery.executors import celery_executor, celery_executor_utils if AIRFLOW_V_3_0_PLUS: - # Airflow 3: execute_workload receives JSON string + # Airflow 3: execute_workload receives JSON string. def fake_execute(input: str) -> None: """Fake execute_workload that parses JSON and fails for tasks with 'fail' in task_id.""" import json workload_dict = json.loads(input) - # Check if this is a task that should fail (task_id contains "fail") + # Check if this is a workload that should fail (task_id contains "fail"). if "ti" in workload_dict and "task_id" in workload_dict["ti"]: if "fail" in workload_dict["ti"]["task_id"]: raise AirflowException("fail") else: - # Airflow 2: execute_command receives command list - def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 version + # Airflow 2: execute_command receives command list. + def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 version. 
if "fail" in input: raise AirflowException("fail") @@ -212,7 +212,7 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve TaskInstanceKey("id", "success", "abc", 0, -1), TaskInstanceKey("id", "fail", "abc", 0, -1), ] - dagrun = setup_dagrun_with_success_and_fail_tasks(dag_maker) + dagrun = setup_dagrun_with_success_and_fail_workloads(dag_maker) ti_success, ti_fail = dagrun.task_instances for w in ( workloads.ExecuteTask.make( @@ -246,7 +246,7 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve assert executor.queued_tasks == {} - def test_error_sending_task(self): + def test_error_sending_workload(self): from airflow.providers.celery.executors import celery_executor, celery_executor_utils with _prepare_app(): @@ -265,8 +265,8 @@ def test_error_sending_task(self): executor.queued_tasks[key] = workload - executor.task_publish_retries[key] = 1 + executor.workload_publish_retries[key] = 1 - # Mock send_task_to_executor to return an error result - # This simulates a failure when sending the task to Celery + # Mock send_workload_to_executor to return an error result. + # This simulates a failure when sending the workload to Celery. 
def mock_send_error(task_tuple): key_from_tuple = task_tuple[0] return ( @@ -279,14 +279,14 @@ def mock_send_error(task_tuple): ) with mock.patch.object( - celery_executor_utils, "send_task_to_executor", side_effect=mock_send_error + celery_executor_utils, "send_workload_to_executor", side_effect=mock_send_error ): executor.heartbeat() - assert len(executor.queued_tasks) == 0, "Task should no longer be queued" + assert len(executor.queued_tasks) == 0, "Workload should no longer be queued" assert executor.event_buffer[key][0] == State.FAILED - def test_retry_on_error_sending_task(self, caplog): - """Test that Airflow retries publishing tasks to Celery Broker at least 3 times""" + def test_retry_on_error_sending_workload(self, caplog): + """Test that Airflow retries publishing workloads to Celery Broker at least 3 times""" from airflow.providers.celery.executors import celery_executor, celery_executor_utils with ( @@ -300,8 +300,8 @@ def test_retry_on_error_sending_task(self, caplog): ), ): executor = celery_executor.CeleryExecutor() - assert executor.task_publish_retries == {} - assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3" + assert executor.workload_publish_retries == {} + assert executor.workload_publish_max_retries == 3, "Assert Default Max Retries is 3" with DAG(dag_id="id"): task = BashOperator(task_id="test", bash_command="true", start_date=datetime.now()) @@ -316,27 +316,27 @@ def test_retry_on_error_sending_task(self, caplog): key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) executor.queued_tasks[key] = workload - # Test that when heartbeat is called again, task is published again to Celery Queue + # Test that when heartbeat is called again, workload is published again to Celery Queue. 
executor.heartbeat() - assert dict(executor.task_publish_retries) == {key: 1} - assert len(executor.queued_tasks) == 1, "Task should remain in queue" + assert dict(executor.workload_publish_retries) == {key: 1} + assert len(executor.queued_tasks) == 1, "Workload should remain in queue" assert executor.event_buffer == {} - assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in caplog.text + assert f"[Try 1 of 3] Celery Task Timeout Error for Workload: ({key})." in caplog.text executor.heartbeat() - assert dict(executor.task_publish_retries) == {key: 2} - assert len(executor.queued_tasks) == 1, "Task should remain in queue" + assert dict(executor.workload_publish_retries) == {key: 2} + assert len(executor.queued_tasks) == 1, "Workload should remain in queue" assert executor.event_buffer == {} - assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in caplog.text + assert f"[Try 2 of 3] Celery Task Timeout Error for Workload: ({key})." in caplog.text executor.heartbeat() - assert dict(executor.task_publish_retries) == {key: 3} - assert len(executor.queued_tasks) == 1, "Task should remain in queue" + assert dict(executor.workload_publish_retries) == {key: 3} + assert len(executor.queued_tasks) == 1, "Workload should remain in queue" assert executor.event_buffer == {} - assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in caplog.text + assert f"[Try 3 of 3] Celery Task Timeout Error for Workload: ({key})." in caplog.text executor.heartbeat() - assert dict(executor.task_publish_retries) == {} + assert dict(executor.workload_publish_retries) == {} assert len(executor.queued_tasks) == 0, "Task should no longer be in queue" assert executor.event_buffer[key][0] == State.FAILED @@ -391,7 +391,7 @@ def test_should_support_kv_backend(self, mock_mget, caplog): ] ) - # Assert called - ignore order + # Assert called - ignore order. 
mget_args, _ = mock_mget.call_args assert set(mget_args[0]) == {b"celery-task-meta-456", b"celery-task-meta-123"} mock_mget.assert_called_once_with(mock.ANY) diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index f5d34fb29162d..f901f7eccbd05 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -25,7 +25,7 @@ from datetime import timedelta from unittest import mock -# leave this it is used by the test worker +# Leave this; it is used by the test worker. import celery.contrib.testing.tasks  # noqa: F401 import pytest import time_machine @@ -93,7 +93,7 @@ def _prepare_app(broker_url=None, execute=None): test_execute = test_app.task(execute) patch_app = mock.patch.object(celery_executor_utils, "app", test_app) patch_execute = mock.patch.object(celery_executor_utils, execute_name, test_execute) - # Patch factory function so CeleryExecutor instances get the test app + # Patch factory function so CeleryExecutor instances get the test app. patch_factory = mock.patch.object(celery_executor_utils, "create_celery_app", return_value=test_app) backend = test_app.backend @@ -103,7 +103,7 @@ def _prepare_app(broker_url=None, execute=None): # race condition where it one of the subprocesses can die with "Table # already exists" error, because SQLA checks for which tables exist, # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT - # EXISTS + # EXISTS. session = backend.ResultSession() session.close() @@ -111,7 +111,7 @@ def _prepare_app(broker_url=None, execute=None): try: yield test_app finally: - # Clear event loop to tear down each celery instance + # Clear event loop to tear down each celery instance
set_event_loop(None) @@ -141,7 +141,7 @@ def test_celery_executor_init_with_args_kwargs(self): team_name = "test_team" if AIRFLOW_V_3_2_PLUS: - # Multi-team support with ExecutorConf requires Airflow 3.2+ + # Multi-team support with ExecutorConf requires Airflow 3.2+. executor = celery_executor.CeleryExecutor(parallelism=parallelism, team_name=team_name) else: executor = celery_executor.CeleryExecutor(parallelism) @@ -149,7 +149,7 @@ def test_celery_executor_init_with_args_kwargs(self): assert executor.parallelism == parallelism if AIRFLOW_V_3_2_PLUS: - # Multi-team support with ExecutorConf requires Airflow 3.2+ + # Multi-team support with ExecutorConf requires Airflow 3.2+. assert executor.team_name == team_name assert executor.conf.team_name == team_name @@ -160,8 +160,8 @@ def test_exception_propagation(self, caplog): ) with _prepare_app(): executor = celery_executor.CeleryExecutor() - executor.tasks = {"key": FakeCeleryResult()} - executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values()) + executor.workloads = {"key": FakeCeleryResult()} + executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.workloads.values()) assert celery_executor_utils.CELERY_FETCH_ERR_MSG_HEADER in caplog.text, caplog.record_tuples assert FAKE_EXCEPTION_MSG in caplog.text, caplog.record_tuples @@ -267,7 +267,7 @@ def test_try_adopt_task_instances(self, clean_dags_dagruns_and_dagbundles, testi executor = celery_executor.CeleryExecutor() assert executor.running == set() - assert executor.tasks == {} + assert executor.workloads == {} not_adopted_tis = executor.try_adopt_task_instances(tis) @@ -275,7 +275,7 @@ def test_try_adopt_task_instances(self, clean_dags_dagruns_and_dagbundles, testi key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, 0) assert executor.running == {key_1, key_2} - assert executor.tasks == {key_1: AsyncResult("231"), key_2: AsyncResult("232")} + assert executor.workloads == {key_1: AsyncResult("231"), key_2: 
AsyncResult("232")} assert not_adopted_tis == [] @pytest.fixture @@ -310,12 +310,12 @@ def test_cleanup_stuck_queued_tasks( executor = celery_executor.CeleryExecutor() executor.job_id = 1 executor.running = {ti.key} - executor.tasks = {ti.key: AsyncResult("231")} + executor.workloads = {ti.key: AsyncResult("231")} assert executor.has_task(ti) with pytest.warns(AirflowProviderDeprecationWarning, match="cleanup_stuck_queued_tasks"): executor.cleanup_stuck_queued_tasks(tis=tis) executor.sync() - assert executor.tasks == {} + assert executor.workloads == {} app.control.revoke.assert_called_once_with("231") mock_fail.assert_called() assert not executor.has_task(ti) @@ -344,13 +344,13 @@ def test_revoke_task(self, mock_fail, clean_dags_dagruns_and_dagbundles, testing executor = celery_executor.CeleryExecutor() executor.job_id = 1 executor.running = {ti.key} - executor.tasks = {ti.key: AsyncResult("231")} + executor.workloads = {ti.key: AsyncResult("231")} assert executor.has_task(ti) for ti in tis: executor.revoke_task(ti=ti) executor.sync() app.control.revoke.assert_called_once_with("231") - assert executor.tasks == {} + assert executor.workloads == {} assert not executor.has_task(ti) mock_fail.assert_not_called() @@ -358,18 +358,18 @@ def test_revoke_task(self, mock_fail, clean_dags_dagruns_and_dagbundles, testing def test_result_backend_sqlalchemy_engine_options(self): import importlib - # Scope the mock using context manager so we can clean up afterward + # Scope the mock using context manager so we can clean up afterward. with mock.patch("celery.Celery") as mock_celery: - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) - # reload celery_executor_utils to recreate the celery app with new config + # Reload celery_executor_utils to recreate the celery app with new config.
importlib.reload(celery_executor_utils) call_args = mock_celery.call_args.kwargs.get("config_source") assert "database_engine_options" in call_args assert call_args["database_engine_options"] == {"pool_recycle": 1800} - # Clean up: reload modules with real Celery to restore clean state for subsequent tests + # Clean up: reload modules with real Celery to restore clean state for subsequent tests. importlib.reload(default_celery) importlib.reload(celery_executor_utils) @@ -378,9 +378,9 @@ def test_operation_timeout_config(): assert celery_executor_utils.OPERATION_TIMEOUT == 1 -class MockTask: +class MockWorkload: """ - A picklable object used to mock tasks sent to Celery. Can't use the mock library + A picklable object used to mock workloads sent to Celery. Can't use the mock library here because it's not picklable. """ @@ -407,7 +407,7 @@ def register_signals(): yield - # Restore original signal handlers after test + # Restore original signal handlers after test. signal.signal(signal.SIGINT, orig_sigint) signal.signal(signal.SIGTERM, orig_sigterm) signal.signal(signal.SIGUSR2, orig_sigusr2) @@ -415,20 +415,20 @@ def register_signals(): @pytest.mark.execution_timeout(200) @pytest.mark.quarantined -def test_send_tasks_to_celery_hang(register_signals): +def test_send_workloads_to_celery_hang(register_signals): """ Test that celery_executor does not hang after many runs. """ executor = celery_executor.CeleryExecutor() - task = MockTask() - task_tuples_to_send = [(None, None, None, task) for _ in range(26)] + workload = MockWorkload() + workload_tuples_to_send = [(None, None, None, workload) for _ in range(26)] for _ in range(250): # This loop can hang on Linux if celery_executor does something wrong with # multiprocessing. 
- results = executor._send_tasks_to_celery(task_tuples_to_send) - assert results == [(None, None, 1) for _ in task_tuples_to_send] + results = executor._send_workloads_to_celery(workload_tuples_to_send) + assert results == [(None, None, 1) for _ in workload_tuples_to_send] @conf_vars({("celery", "result_backend"): "rediss://test_user:test_password@localhost:6379/0"}) @@ -438,7 +438,7 @@ def test_celery_executor_with_no_recommended_result_backend(caplog): from airflow.providers.celery.executors.default_celery import log with caplog.at_level(logging.WARNING, logger=log.name): - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) assert "test_password" not in caplog.text assert ( @@ -451,7 +451,7 @@ def test_celery_executor_with_no_recommended_result_backend(caplog): def test_sentinel_kwargs_loaded_from_string(): import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) assert default_celery.DEFAULT_CELERY_CONFIG["broker_transport_options"]["sentinel_kwargs"] == { "service_name": "mymaster" @@ -462,7 +462,7 @@ def test_sentinel_kwargs_loaded_from_string(): def test_celery_task_acks_late_loaded_from_string(): import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) assert default_celery.DEFAULT_CELERY_CONFIG["task_acks_late"] is False @@ -522,7 +522,7 @@ def test_visibility_timeout_not_set_for_unsupported_broker(caplog): def test_celery_extra_celery_config_loaded_from_string(): import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. 
importlib.reload(default_celery) assert default_celery.DEFAULT_CELERY_CONFIG["worker_max_tasks_per_child"] == 10 @@ -532,7 +532,7 @@ def test_result_backend_sentinel_kwargs_loaded_from_string(): """Test that sentinel_kwargs for result backend transport options is correctly parsed.""" import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) assert "result_backend_transport_options" in default_celery.DEFAULT_CELERY_CONFIG assert default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]["sentinel_kwargs"] == { @@ -545,7 +545,7 @@ def test_result_backend_master_name_loaded(): """Test that master_name for result backend transport options is correctly loaded.""" import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) assert "result_backend_transport_options" in default_celery.DEFAULT_CELERY_CONFIG assert ( @@ -563,7 +563,7 @@ def test_result_backend_transport_options_with_multiple_options(): """Test that multiple result backend transport options are correctly loaded.""" import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. importlib.reload(default_celery) result_backend_opts = default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"] assert result_backend_opts["sentinel_kwargs"] == {"password": "redis_password"} @@ -607,7 +607,7 @@ def test_result_backend_sentinel_full_config(): """Test full Redis Sentinel configuration for result backend.""" import importlib - # reload celery conf to apply the new config + # Reload celery conf to apply the new config. 
importlib.reload(default_celery) assert default_celery.DEFAULT_CELERY_CONFIG["result_backend"] == ( @@ -636,13 +636,13 @@ def teardown_method(self) -> None: ("operators", "default_queue"): "global_queue", } ) - def test_multi_team_isolation_and_task_routing(self, monkeypatch): + def test_multi_team_isolation_and_workload_routing(self, monkeypatch): """ - Test multi-team executor isolation and correct task routing. + Test multi-team executor isolation and correct workload routing. Verifies: - Each executor has isolated Celery app and config - - Tasks are routed through team-specific apps (_process_tasks/_process_workloads) + - Workloads are routed through team-specific apps (_process_tasks/_process_workloads) - Backward compatibility with global executor """ # Set up team-specific config via environment variables @@ -651,49 +651,49 @@ def test_multi_team_isolation_and_task_routing(self, monkeypatch): monkeypatch.setenv("AIRFLOW__TEAM_B___CELERY__BROKER_URL", "redis://team-b:6379/0") monkeypatch.setenv("AIRFLOW__TEAM_B___OPERATORS__DEFAULT_QUEUE", "team_b_queue") - # Reload config to pick up environment variables + # Reload config to pick up environment variables. from airflow import configuration configuration.conf.read_dict({}, source="test") - # Create executors with different team configs + # Create executors with different team configs. team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a") team_b_executor = CeleryExecutor(parallelism=3, team_name="team_b") global_executor = CeleryExecutor(parallelism=4) - # Each executor has its own Celery app (critical for isolation) + # Each executor has its own Celery app (critical for isolation). assert team_a_executor.celery_app is not team_b_executor.celery_app assert team_a_executor.celery_app is not global_executor.celery_app - # Team-specific broker URLs are used + # Team-specific broker URLs are used. 
assert "team-a" in team_a_executor.celery_app.conf.broker_url assert "team-b" in team_b_executor.celery_app.conf.broker_url assert "global" in global_executor.celery_app.conf.broker_url - # Team-specific queues are used + # Team-specific queues are used. assert team_a_executor.celery_app.conf.task_default_queue == "team_a_queue" assert team_b_executor.celery_app.conf.task_default_queue == "team_b_queue" assert global_executor.celery_app.conf.task_default_queue == "global_queue" - # Each executor has its own BulkStateFetcher with correct app + # Each executor has its own BulkStateFetcher with correct app. assert team_a_executor.bulk_state_fetcher.celery_app is team_a_executor.celery_app assert team_b_executor.bulk_state_fetcher.celery_app is team_b_executor.celery_app - # Executors have isolated internal state - assert team_a_executor.tasks is not team_b_executor.tasks + # Executors have isolated internal state. + assert team_a_executor.workloads is not team_b_executor.workloads assert team_a_executor.running is not team_b_executor.running assert team_a_executor.queued_tasks is not team_b_executor.queued_tasks @conf_vars({("celery", "broker_url"): "redis://global:6379/0"}) - @mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_tasks") - def test_task_routing_through_team_specific_app(self, mock_send_tasks, monkeypatch): + @mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor._send_workloads") + def test_workload_routing_through_team_specific_app(self, mock_send_workloads, monkeypatch): """ - Test that _process_tasks and _process_workloads pass the correct team_name for task routing. + Test that _process_tasks (v2) and _process_workloads (v3) pass the correct team_name for task routing. With the ProcessPoolExecutor approach, we pass team_name instead of task objects to avoid pickling issues. The subprocess reconstructs the team-specific Celery app from the team_name. 
""" - # Set up team A config + # Set up team A config. monkeypatch.setenv("AIRFLOW__TEAM_A___CELERY__BROKER_URL", "redis://team-a:6379/0") team_a_executor = CeleryExecutor(parallelism=2, team_name="team_a") @@ -702,40 +702,42 @@ def test_task_routing_through_team_specific_app(self, mock_send_tasks, monkeypat from airflow.executors.workloads import ExecuteTask from airflow.models.taskinstancekey import TaskInstanceKey - # Create mock workload + # Create mock workload. mock_ti = mock.Mock() mock_ti.key = TaskInstanceKey("dag", "task", "run", 1) mock_ti.queue = "test_queue" mock_workload = mock.Mock(spec=ExecuteTask) mock_workload.ti = mock_ti - # Process workload through team A executor + # Process workload through team A executor. team_a_executor._process_workloads([mock_workload]) - # Verify _send_tasks received the correct team_name - assert mock_send_tasks.called - task_tuples = mock_send_tasks.call_args[0][0] - team_name_from_call = task_tuples[0][3] # 4th element is now team_name + # Verify _send_workloads received the correct team_name. + assert mock_send_workloads.called + workload_tuples = mock_send_workloads.call_args[0][0] + team_name_from_call = workload_tuples[0][ + 3 + ] # 4th element is team_name (used to reconstruct Celery app in subprocess). - # Critical: team_name is passed so subprocess can reconstruct the correct app + # Critical: team_name is passed so subprocess can reconstruct the correct app. assert team_name_from_call == "team_a" else: from airflow.models.taskinstancekey import TaskInstanceKey - # Test V2 path with execute_command + # Test V2 path with execute_command. mock_key = TaskInstanceKey("dag", "task", "run", 1) mock_command = ["airflow", "tasks", "run", "dag", "task"] mock_queue = "test_queue" - # Process task through team A executor + # Process task through team A executor. 
team_a_executor._process_tasks([(mock_key, mock_command, mock_queue, None)]) - # Verify _send_tasks received team A's execute_command task - assert mock_send_tasks.called - task_tuples = mock_send_tasks.call_args[0][0] - task_from_call = task_tuples[0][3] # 4th element is the task (V2 still uses task object) + # Verify _send_workloads received team A's execute_command workload (v2 compatibility path). + assert mock_send_workloads.called + task_tuples = mock_send_workloads.call_args[0][0] + task_from_call = task_tuples[0][3] # 4th element is the task (V2 still uses task object). - # Critical: task belongs to team A's app, not module-level app + # Critical: Celery task belongs to team A's app, not module-level app. assert task_from_call.app is team_a_executor.celery_app assert task_from_call.name == "execute_command" @@ -756,7 +758,7 @@ def test_celery_tasks_registered_on_import(): "execute_workload must be registered with the Celery app at import time. " "Workers need this to receive tasks without KeyError." ) - # TODO: remove this block when min supported Airflow version is >= 3.0 + # TODO: remove this block when min supported Airflow version is >= 3.0. if not AIRFLOW_V_3_0_PLUS: assert "execute_command" in registered_tasks, ( "execute_command must be registered for Airflow 2.x compatibility."