From 08fbfa00596b0d3f1f2d8b8c7f932f4fbb83f7c9 Mon Sep 17 00:00:00 2001 From: Sean Ghaeli Date: Mon, 3 Nov 2025 17:34:42 -0800 Subject: [PATCH 01/32] first pass implementation of executor support for sync callbacks --- .../src/airflow/executors/workloads.py | 2 +- .../src/airflow/jobs/scheduler_job_runner.py | 104 +++++++++++++++++- airflow-core/src/airflow/models/deadline.py | 63 ++++++++++- .../celery/executors/celery_executor.py | 45 +++++--- .../celery/executors/celery_executor_utils.py | 59 ++++++++-- 5 files changed, 242 insertions(+), 31 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 7cf1aae60ff21..37d83b6af1593 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -205,6 +205,6 @@ class RunTrigger(BaseModel): All = Annotated[ - ExecuteTask | RunTrigger, + ExecuteTask | ExecuteCallback | RunTrigger, Field(discriminator="type"), ] diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 4117153d4e7cd..fc1ff887f5fce 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -85,8 +85,8 @@ TaskOutletAssetReference, ) from airflow.models.backfill import Backfill -from airflow.models.callback import Callback -from airflow.models.dag import DagModel +from airflow.models.callback import Callback, CallbackState, ExecutorCallback +from airflow.models.dag import DagModel, get_next_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag from airflow.models.dagbundle import DagBundleModel @@ -993,6 +993,103 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: return len(queued_tis) + def _enqueue_executor_callbacks(self, session: Session) -> None: + """ + Enqueue ExecutorCallback workloads to executors. 
+ + Similar to _enqueue_task_instances, but for callbacks that need to run on executors. + Queries for QUEUED ExecutorCallback instances and routes them to the appropriate executor. + + :param session: The database session + """ + # Query for QUEUED ExecutorCallback instances + from airflow.models.callback import CallbackType + + queued_callbacks = session.scalars( + select(ExecutorCallback) + .where(ExecutorCallback.type == CallbackType.EXECUTOR) + .where(ExecutorCallback.state == CallbackState.QUEUED) + .order_by(ExecutorCallback.priority_weight.desc()) + .limit(conf.getint("scheduler", "max_callback_workloads_per_loop", fallback=100)) + ).all() + + if not queued_callbacks: + return + + # Group callbacks by executor (based on callback executor attribute or default executor) + executor_to_callbacks: dict[BaseExecutor, list[ExecutorCallback]] = defaultdict(list) + + for callback in queued_callbacks: + # Get the executor name from callback data if specified + executor_name = None + if isinstance(callback.data, dict): + executor_name = callback.data.get("executor") + + # Find the appropriate executor + executor = None + if executor_name: + # Find executor by name - try multiple matching strategies + for exec in self.job.executors: + # Match by class name (e.g., "CeleryExecutor") + if exec.__class__.__name__ == executor_name: + executor = exec + break + # Match by executor name attribute if available + if hasattr(exec, "name") and exec.name and str(exec.name) == executor_name: + executor = exec + break + # Match by executor name attribute if available + if hasattr(exec, "executor_name") and exec.executor_name == executor_name: + executor = exec + break + + # Default to first executor if no specific executor found + if executor is None: + executor = self.job.executors[0] if self.job.executors else None + + if executor is None: + self.log.warning("No executor available for callback %s", callback.id) + continue + + executor_to_callbacks[executor].append(callback) + + # 
Enqueue callbacks for each executor + for executor, callbacks in executor_to_callbacks.items(): + for callback in callbacks: + # Get the associated DagRun for the callback + # For deadline callbacks, we stored dag_run_id in the callback data + dag_run = None + if isinstance(callback.data, dict) and "dag_run_id" in callback.data: + dag_run_id = callback.data["dag_run_id"] + dag_run = session.get(DagRun, dag_run_id) + elif isinstance(callback.data, dict) and "dag_id" in callback.data: + # Fallback: try to find the latest dag_run for the dag_id + dag_id = callback.data["dag_id"] + dag_run = session.scalars( + select(DagRun) + .where(DagRun.dag_id == dag_id) + .order_by(DagRun.execution_date.desc()) + .limit(1) + ).first() + + if dag_run is None: + self.log.warning("Could not find DagRun for callback %s", callback.id) + continue + + # Create ExecuteCallback workload + workload = workloads.ExecuteCallback.make( + callback=callback, + dag_run=dag_run, + generator=executor.jwt_generator, + ) + + # Queue the workload + executor.queue_workload(workload, session=session) + + # Update callback state to RUNNING + callback.state = CallbackState.RUNNING + session.add(callback) + @staticmethod def _process_task_event_logs(log_records: deque[Log], session: Session): objects = (log_records.popleft() for _ in range(len(log_records))) @@ -1657,6 +1754,9 @@ def _run_scheduler_loop(self) -> None: ): deadline.handle_miss(session) + # Route ExecutorCallback workloads to executors (similar to task routing) + self._enqueue_executor_callbacks(session) + # Heartbeat the scheduler periodically perform_heartbeat( job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index bbf24cc2842d2..2172386ea7d3a 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -21,6 +21,7 @@ from collections.abc import Sequence from 
dataclasses import dataclass from datetime import datetime, timedelta +from enum import Enum from typing import TYPE_CHECKING, Any, cast from uuid import UUID @@ -31,8 +32,12 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone +from airflow.models import Trigger from airflow.models.base import Base from airflow.models.callback import Callback, CallbackDefinitionProtocol +from airflow.observability.stats import Stats +from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback +from airflow.triggers.callback import CallbackTrigger from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name @@ -77,6 +82,19 @@ def __get__(self, instance, cls=None): return self.method(cls) +class DeadlineCallbackState(str, Enum): + """ + All possible states of deadline callbacks once the deadline is missed. + + `None` state implies that the deadline is pending (`deadline_time` hasn't passed yet). 
+ """ + + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + class Deadline(Base): """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed.""" @@ -224,9 +242,48 @@ def get_simple_context(): "deadline": {"id": self.id, "deadline_time": self.deadline_time}, } - self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()} - self.missed = True - self.callback.queue() + if isinstance(self.callback, AsyncCallback): + callback_trigger = CallbackTrigger( + callback_path=self.callback.path, + callback_kwargs=(self.callback.kwargs or {}) | {"context": get_simple_context()}, + ) + trigger_orm = Trigger.from_object(callback_trigger) + session.add(trigger_orm) + session.flush() + self.trigger = trigger_orm + + elif isinstance(self.callback, SyncCallback): + from airflow.models.callback import CallbackFetchMethod, ExecutorCallback + + # Create an ExecutorCallback for processing through the executor pipeline + # Store deadline and dag_run info in the callback data for later retrieval + callback_data = self.callback.serialize() + callback_data["deadline_id"] = str(self.id) + callback_data["dag_run_id"] = str(self.dagrun.id) + callback_data["dag_id"] = self.dagrun.dag_id + + # Add context to callback kwargs (similar to AsyncCallback) + if "kwargs" not in callback_data: + callback_data["kwargs"] = {} + callback_data["kwargs"] = (callback_data.get("kwargs") or {}) | {"context": get_simple_context()} + + # Create a modified callback_def with the additional data + class ModifiedCallback: + def serialize(self): + return callback_data + + executor_callback = ExecutorCallback( + callback_def=ModifiedCallback(), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + priority_weight=1, + ) + executor_callback.queue() + session.add(executor_callback) + session.flush() + + else: + raise TypeError("Unknown Callback type") + session.add(self) Stats.incr( "deadline_alerts.deadline_missed", diff 
--git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 7144425b2c3d7..a131f7820a5db 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -39,7 +39,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats from airflow.utils.state import TaskInstanceState @@ -160,14 +160,21 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: # Airflow V3 version -- have to delay imports until we know we are on v3 from airflow.executors.workloads import ExecuteTask - tasks = [ - (workload.ti.key, workload, workload.ti.queue, self.team_name) - for workload in workloads - if isinstance(workload, ExecuteTask) - ] - if len(tasks) != len(workloads): - invalid = list(workload for workload in workloads if not isinstance(workload, ExecuteTask)) - raise ValueError(f"{type(self)}._process_workloads cannot handle {invalid}") + if AIRFLOW_V_3_2_PLUS: + from airflow.executors.workloads import ExecuteCallback + + tasks = [] + for workload in workloads: + if isinstance(workload, ExecuteTask): + tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) + elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): + # Use default queue for callbacks, or extract from callback data if available + queue = "default" + if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: + queue = workload.callback.data["queue"] + tasks.append((workload.callback.key, workload, queue, self.team_name)) + else: + raise 
ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") self._send_tasks(tasks) @@ -378,8 +385,20 @@ def get_cli_commands() -> list[GroupCommand]: def queue_workload(self, workload: workloads.All, session: Session | None) -> None: from airflow.executors import workloads - - if not isinstance(workload, workloads.ExecuteTask): + from airflow.models.taskinstancekey import TaskInstanceKey + + if isinstance(workload, workloads.ExecuteTask): + ti = workload.ti + self.queued_tasks[ti.key] = workload + elif isinstance(workload, workloads.ExecuteCallback): + # For callbacks, use a synthetic key based on callback ID + callback_key = TaskInstanceKey( + dag_id="callback", + task_id=str(workload.callback.id), + run_id="callback", + try_number=1, + map_index=-1, + ) + self.queued_tasks[callback_key] = workload + else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index b09737701f3f7..ece85595bb444 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -182,9 +182,6 @@ def execute_workload(input: str) -> None: celery_task_id = app.current_task.request.id - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") - log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) base_url = conf.get("api", "base_url", fallback="/") @@ -193,15 +190,53 @@ def execute_workload(input: str) -> None: base_url = f"http://localhost:8080{base_url}" default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - supervise( - # This is the "wrong" ti type, but it duck types the 
same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) + if isinstance(workload, workloads.ExecuteTask): + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), + log_path=workload.log_path, + ) + elif isinstance(workload, workloads.ExecuteCallback): + # Execute callback workload - import and execute the callback function + from airflow.models.callback import CallbackFetchMethod + + callback = workload.callback + if callback.fetch_type != CallbackFetchMethod.IMPORT_PATH: + raise ValueError( + f"CeleryExecutor only supports callbacks with fetch_type={CallbackFetchMethod.IMPORT_PATH}, " + f"got {callback.fetch_type}" + ) + + # Extract callback path and kwargs from data + if not isinstance(callback.data, dict): + raise ValueError(f"Callback data must be a dict, got {type(callback.data)}") + + callback_path = callback.data.get("path") + callback_kwargs = callback.data.get("kwargs") or {} + + if not callback_path: + raise ValueError("Callback path not found in callback data") + + # Import the callback function using Airflow's import_string utility + from airflow.utils.module_loading import import_string + + callback_func = import_string(callback_path) + + # Execute the callback + log.info("[%s] Executing callback: %s", celery_task_id, callback_path) + try: + result = callback_func(**callback_kwargs) + log.info("[%s] Callback executed successfully", celery_task_id) + return result + except Exception: + log.exception("[%s] Callback execution 
failed", celery_task_id) + raise + else: + raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") if not AIRFLOW_V_3_0_PLUS: From a2186c422e9de60f3873ec3d15dff1564b5d1f74 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 27 Jan 2026 11:20:08 -0800 Subject: [PATCH 02/32] Synchronous callback support for BaseExecutor, LocalExecutor, and CeleryExecutor Add support for the Callback workload to be run in the executors. Other executors will need to be updated before they can support the workload, but I tried to make it as non-invasive as I could. --- .../src/airflow/executors/base_executor.py | 98 +++++++++++--- .../src/airflow/executors/local_executor.py | 82 +++++++---- .../src/airflow/executors/workloads.py | 65 ++++++++- .../src/airflow/jobs/scheduler_job_runner.py | 55 +++++--- airflow-core/src/airflow/models/deadline.py | 72 +++++----- .../unit/executors/test_base_executor.py | 127 ++++++++++++++++++ .../unit/executors/test_local_executor.py | 40 ++++++ .../celery/executors/celery_executor.py | 1 + .../celery/executors/celery_executor_utils.py | 37 +---- .../src/airflow/sdk/definitions/deadline.py | 7 +- .../task_sdk/definitions/test_deadline.py | 23 +++- 11 files changed, 461 insertions(+), 146 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 3bb8a70fa2712..c21fada8e672d 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -143,6 +143,7 @@ class BaseExecutor(LoggingMixin): active_spans = ThreadSafeDict() supports_ad_hoc_ti_run: bool = False + supports_callbacks: bool = False supports_multi_team: bool = False sentry_integration: str = "" @@ -186,8 +187,9 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.parallelism: int = parallelism self.team_name: str | None = team_name self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} - 
self.running: set[TaskInstanceKey] = set() - self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {} + self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {} + self.running: set[TaskInstanceKey | str] = set() + self.event_buffer: dict[TaskInstanceKey | str, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() self.conf = ExecutorConf(team_name) @@ -224,10 +226,46 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey): self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) def queue_workload(self, workload: workloads.All, session: Session) -> None: - if not isinstance(workload, workloads.ExecuteTask): + if isinstance(workload, workloads.ExecuteTask): + ti = workload.ti + self.queued_tasks[ti.key] = workload + elif isinstance(workload, workloads.ExecuteCallback): + self.queued_callbacks[workload.callback.id] = workload + else: raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") - ti = workload.ti - self.queued_tasks[ti.key] = workload + + def _get_workloads_to_schedule( + self, open_slots: int + ) -> list[tuple[TaskInstanceKey | str, workloads.All]]: + """ + Select and return the next batch of workloads to schedule, respecting priority policy. + + Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work). + Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first). 
+ + :param open_slots: Number of available execution slots + """ + workloads_to_schedule: list[tuple[TaskInstanceKey | str, workloads.All]] = [] + + if self.queued_callbacks: + for key, workload in self.queued_callbacks.items(): + if len(workloads_to_schedule) >= open_slots: + break + workloads_to_schedule.append((key, workload)) + + remaining_slots = open_slots - len(workloads_to_schedule) + if remaining_slots and self.queued_tasks: + sorted_tasks = sorted( + self.queued_tasks.items(), + key=lambda x: x[1].ti.priority_weight, + reverse=True, + ) + for key, workload in sorted_tasks: + if len(workloads_to_schedule) >= open_slots: + break + workloads_to_schedule.append((key, workload)) + + return workloads_to_schedule def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: """ @@ -266,10 +304,10 @@ def heartbeat(self) -> None: """Heartbeat sent to trigger new jobs.""" open_slots = self.parallelism - len(self.running) - num_running_tasks = len(self.running) - num_queued_tasks = len(self.queued_tasks) + num_running_workloads = len(self.running) + num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks) - self._emit_metrics(open_slots, num_running_tasks, num_queued_tasks) + self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) self.trigger_tasks(open_slots) # Calling child class sync method @@ -350,16 +388,16 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workload def trigger_tasks(self, open_slots: int) -> None: """ - Initiate async execution of the queued tasks, up to the number of available slots. + Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots. + + Callbacks are prioritized over tasks to complete existing work before starting new work. 
:param open_slots: Number of open slots """ - sorted_queue = self.order_queued_tasks_by_priority() + workloads_to_schedule = self._get_workloads_to_schedule(open_slots) workload_list = [] - for _ in range(min((open_slots, len(self.queued_tasks)))): - key, item = sorted_queue.pop() - + for key, workload in workloads_to_schedule: # If a task makes it here but is still understood by the executor # to be running, it generally means that the task has been killed # externally and not yet been marked as failed. @@ -373,8 +411,8 @@ def trigger_tasks(self, open_slots: int) -> None: if key in self.attempts: del self.attempts[key] - if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"): - ti = item.ti + if isinstance(workload, workloads.ExecuteTask) and hasattr(workload, "ti"): + ti = workload.ti # If it's None, then the span for the current id hasn't been started. if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None: @@ -397,9 +435,20 @@ def trigger_tasks(self, open_slots: int) -> None: carrier = Trace.inject() ti.context_carrier = carrier - workload_list.append(item) + workload_list.append(workload) + if workload_list: - self._process_workloads(workload_list) + try: + self._process_workloads(workload_list) + except AttributeError as e: + if any(isinstance(workload, workloads.ExecuteCallback) for workload in workload_list): + raise NotImplementedError( + f"{type(self).__name__} does not support ExecuteCallback workloads. " + f"This executor needs to be updated to handle both ExecuteTask and ExecuteCallback types. " + f"See any executor with supports_callbacks=True (LocalExecutor or CeleryExecutor, for example) for reference implementation." + ) from e + # Re-raise if it's a different AttributeError + raise # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did # it die". 
It is possible for the task itself to finish with success, but the state of the task to be set @@ -529,21 +578,26 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @property def slots_available(self): - """Number of new tasks this executor instance can accept.""" - return self.parallelism - len(self.running) - len(self.queued_tasks) + """Number of new workloads (tasks and callbacks) this executor instance can accept.""" + return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks) @property def slots_occupied(self): - """Number of tasks this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + """Number of workloads (tasks and callbacks) this executor instance is currently managing.""" + return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks) def debug_dump(self): """Get called in response to SIGUSR2 by the scheduler.""" self.log.info( - "executor.queued (%d)\n\t%s", + "executor.queued_tasks (%d)\n\t%s", len(self.queued_tasks), "\n\t".join(map(repr, self.queued_tasks.items())), ) + self.log.info( + "executor.queued_callbacks (%d)\n\t%s", + len(self.queued_callbacks), + "\n\t".join(map(repr, self.queued_callbacks.items())), + ) self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running))) self.log.info( "executor.event_buffer (%d)\n\t%s", diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 604de7c7f00f4..197f12bb4f7a6 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -53,6 +53,16 @@ TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Exception | None] +def _get_executor_process_title_prefix(team_name: str | None) -> str: + """ + Build the process title prefix for LocalExecutor workers. 
+ + :param team_name: Team name from executor configuration + """ + team_suffix = f" [{team_name}]" if team_name else "" + return f"airflow worker -- LocalExecutor{team_suffix}:" + + def _run_worker( logger_name: str, input: SimpleQueue[workloads.All | None], @@ -68,11 +78,8 @@ def _run_worker( log = structlog.get_logger(logger_name) log.info("Worker starting up pid=%d", os.getpid()) - # Create team suffix for process title - team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else "" - while True: - setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: ", log) + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} ", log) try: workload = input.get() except EOFError: @@ -87,25 +94,31 @@ def _run_worker( # Received poison pill, no more tasks to run return - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") - # Decrement this as soon as we pick up a message off the queue with unread_messages: unread_messages.value -= 1 - key = None - if ti := getattr(workload, "ti", None): - key = ti.key - else: - raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}") - try: - _execute_work(log, workload, team_conf) + # Handle different workload types + if isinstance(workload, workloads.ExecuteTask): + key = workload.ti.key + try: + _execute_work(log, workload, team_conf) + output.put((key, TaskInstanceState.SUCCESS, None)) + except Exception as e: + log.exception("Task execution failed.") + output.put((key, TaskInstanceState.FAILED, e)) + + elif isinstance(workload, workloads.ExecuteCallback): + key = workload.callback.id + try: + _execute_callback(log, workload, team_conf) + output.put((key, TaskInstanceState.SUCCESS, None)) + except Exception as e: + log.exception("Callback execution failed") + output.put((key, TaskInstanceState.FAILED, e)) - output.put((key, TaskInstanceState.SUCCESS, None)) - except Exception as e: - 
log.exception("uhoh") - output.put((key, TaskInstanceState.FAILED, e)) + else: + raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None: @@ -118,9 +131,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No """ from airflow.sdk.execution_time.supervisor import supervise - # Create team suffix for process title - team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else "" - setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: {workload.ti.id}", log) + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log) base_url = team_conf.get("api", "base_url", fallback="/") # If it's a relative URL, use localhost:8080 as the default @@ -141,6 +152,22 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No ) +def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None: + """ + Execute a callback workload. + + :param log: Logger instance + :param workload: The ExecuteCallback workload to execute + :param team_conf: Team-specific executor configuration + """ + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log) + + success, error_msg = workloads.execute_callback_workload(workload.callback, log) + + if not success: + raise RuntimeError(error_msg or "Callback execution failed") + + class LocalExecutor(BaseExecutor): """ LocalExecutor executes tasks locally in parallel. 
@@ -155,6 +182,7 @@ class LocalExecutor(BaseExecutor): supports_multi_team: bool = True serve_logs: bool = True + supports_callbacks: bool = True activity_queue: SimpleQueue[workloads.All | None] result_queue: SimpleQueue[TaskInstanceStateType] @@ -300,10 +328,14 @@ def end(self) -> None: def terminate(self): """Terminate the executor is not doing anything.""" - def _process_workloads(self, workloads): - for workload in workloads: + def _process_workloads(self, workload_list): + for workload in workload_list: self.activity_queue.put(workload) - del self.queued_tasks[workload.ti.key] + # Remove from appropriate queue based on workload type + if isinstance(workload, workloads.ExecuteTask): + del self.queued_tasks[workload.ti.key] + elif isinstance(workload, workloads.ExecuteCallback): + del self.queued_callbacks[workload.callback.id] with self._unread_messages: - self._unread_messages.value += len(workloads) + self._unread_messages.value += len(workload_list) self._check_workers() diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 37d83b6af1593..e0096256f4bd6 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -20,16 +20,19 @@ import uuid from abc import ABC from datetime import datetime +from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Annotated, Literal import structlog from pydantic import BaseModel, Field +from airflow.models.callback import CallbackFetchMethod + if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.models import DagRun - from airflow.models.callback import Callback as CallbackModel, CallbackFetchMethod + from airflow.models.callback import Callback as CallbackModel from airflow.models.taskinstance import TaskInstance as TIModel from airflow.models.taskinstancekey import TaskInstanceKey @@ -92,7 +95,7 @@ class Callback(BaseModel): 
"""Schema for Callback with minimal required fields needed for Executors and Task SDK.""" id: uuid.UUID - fetch_type: CallbackFetchMethod + fetch_method: CallbackFetchMethod data: dict @@ -208,3 +211,61 @@ class RunTrigger(BaseModel): ExecuteTask | ExecuteCallback | RunTrigger, Field(discriminator="type"), ] + + +def execute_callback_workload( + callback: Callback, + log, +) -> tuple[bool, str | None]: + """ + Execute a callback function by importing and calling it, returning the success state. + + Supports two patterns: + 1. Functions - called directly with kwargs + 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context + + Example: + # Function callback + callback.data = {"path": "my_module.alert_func", "kwargs": {"msg": "Alert!", "context": {...}}} + execute_callback_workload(callback, log) # Calls alert_func(msg="Alert!", context={...}) + + # Notifier callback + callback.data = {"path": "airflow.providers.slack...SlackWebhookNotifier", "kwargs": {"text": "Alert!", "context": {...}}} + execute_callback_workload(callback, log) # SlackWebhookNotifier(text=..., context=...) then calls instance(context) + + :param callback: The Callback schema containing path and kwargs + :param log: Logger instance for recording execution + :return: Tuple of (success: bool, error_message: str | None) + """ + # Extract callback details from data + callback_path = callback.data.get("path") + callback_kwargs = callback.data.get("kwargs", {}) + + if not callback_path: + return False, "Callback path not found in data." + + try: + # Import the callback callable + # Expected format: "module.path.to.function_or_class" + module_path, function_name = callback_path.rsplit(".", 1) + module = import_module(module_path) + callback_callable = getattr(module, function_name) + + log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) + + # If the callback is a callabale, call it. If it is a class, instantiate it. 
+ result = callback_callable(**callback_kwargs) + + # If the callback is a class then it is now instantiated and callable, call it. + if callable(result): + context = callback_kwargs.get("context", {}) + log.debug("Calling result with context for %s", callback_path) + result = result(context) + + log.info("Callback %s executed successfully.", callback_path) + return True, None + + except Exception as e: + error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" + log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) + return False, error_msg diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index fc1ff887f5fce..4c420db7f0145 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -95,6 +95,7 @@ from airflow.models.pool import normalize_pool_name_for_stats from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.team import Team from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason from airflow.observability.metrics import stats_utils @@ -130,7 +131,6 @@ from airflow._shared.logging.types import Logger from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName - from airflow.models.taskinstance import TaskInstanceKey from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -1152,21 +1152,44 @@ def process_executor_events( ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {} event_buffer = executor.get_event_buffer() tis_with_right_state: list[TaskInstanceKey] = [] - - # Report execution - for ti_key, (state, _) in event_buffer.items(): - # We create map 
(dag_id, task_id, logical_date) -> in-memory try_number - ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number - - cls.logger().info("Received executor event with state %s for task instance %s", state, ti_key) - if state in ( - TaskInstanceState.FAILED, - TaskInstanceState.SUCCESS, - TaskInstanceState.QUEUED, - TaskInstanceState.RUNNING, - TaskInstanceState.RESTARTING, - ): - tis_with_right_state.append(ti_key) + callback_keys_with_events: list[str] = [] + + # Report execution - handle both task and callback events + for key, (state, _) in event_buffer.items(): + if isinstance(key, TaskInstanceKey): + ti_primary_key_to_try_number_map[key.primary] = key.try_number + cls.logger().info("Received executor event with state %s for task instance %s", state, key) + if state in ( + TaskInstanceState.FAILED, + TaskInstanceState.SUCCESS, + TaskInstanceState.QUEUED, + TaskInstanceState.RUNNING, + TaskInstanceState.RESTARTING, + ): + tis_with_right_state.append(key) + else: + # Callback event (key is string UUID) + cls.logger().info("Received executor event with state %s for callback %s", state, key) + if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS): + callback_keys_with_events.append(key) + + # Handle callback completion events + for callback_id in callback_keys_with_events: + state, info = event_buffer.pop(callback_id) + callback = session.get(Callback, callback_id) + if callback: + # Note: We receive TaskInstanceState from executor (SUCCESS/FAILED) but convert to CallbackState here. + # This is intentional - executor layer uses generic completion states, scheduler converts to proper types. 
+ if state == TaskInstanceState.SUCCESS: + callback.state = CallbackState.SUCCESS + cls.logger().info("Callback %s completed successfully", callback_id) + elif state == TaskInstanceState.FAILED: + callback.state = CallbackState.FAILED + callback.output = str(info) if info else "Execution failed" + cls.logger().error("Callback %s failed: %s", callback_id, callback.output) + session.add(callback) + else: + cls.logger().warning("Callback %s not found in database", callback_id) # Return if no finished tasks if not tis_with_right_state: diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 2172386ea7d3a..316d3eff56e36 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -32,12 +32,12 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone -from airflow.models import Trigger from airflow.models.base import Base -from airflow.models.callback import Callback, CallbackDefinitionProtocol -from airflow.observability.stats import Stats -from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback -from airflow.triggers.callback import CallbackTrigger +from airflow.models.callback import ( + Callback, + ExecutorCallback, + TriggererCallback, +) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name @@ -242,48 +242,36 @@ def get_simple_context(): "deadline": {"id": self.id, "deadline_time": self.deadline_time}, } - if isinstance(self.callback, AsyncCallback): - callback_trigger = CallbackTrigger( - callback_path=self.callback.path, - callback_kwargs=(self.callback.kwargs or {}) | {"context": get_simple_context()}, - ) - trigger_orm = Trigger.from_object(callback_trigger) - session.add(trigger_orm) + if isinstance(self.callback, TriggererCallback): + # Update the callback with context 
before queuing + if "kwargs" not in self.callback.data: + self.callback.data["kwargs"] = {} + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { + "context": get_simple_context() + } + + self.callback.queue() + session.add(self.callback) session.flush() - self.trigger = trigger_orm - - elif isinstance(self.callback, SyncCallback): - from airflow.models.callback import CallbackFetchMethod, ExecutorCallback - - # Create an ExecutorCallback for processing through the executor pipeline - # Store deadline and dag_run info in the callback data for later retrieval - callback_data = self.callback.serialize() - callback_data["deadline_id"] = str(self.id) - callback_data["dag_run_id"] = str(self.dagrun.id) - callback_data["dag_id"] = self.dagrun.dag_id - - # Add context to callback kwargs (similar to AsyncCallback) - if "kwargs" not in callback_data: - callback_data["kwargs"] = {} - callback_data["kwargs"] = (callback_data.get("kwargs") or {}) | {"context": get_simple_context()} - - # Create a modified callback_def with the additional data - class ModifiedCallback: - def serialize(self): - return callback_data - - executor_callback = ExecutorCallback( - callback_def=ModifiedCallback(), - fetch_method=CallbackFetchMethod.IMPORT_PATH, - priority_weight=1, - ) - executor_callback.queue() - session.add(executor_callback) + + elif isinstance(self.callback, ExecutorCallback): + if "kwargs" not in self.callback.data: + self.callback.data["kwargs"] = {} + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { + "context": get_simple_context() + } + self.callback.data["deadline_id"] = str(self.id) + self.callback.data["dag_run_id"] = str(self.dagrun.id) + self.callback.data["dag_id"] = self.dagrun.dag_id + + self.callback.queue() + session.add(self.callback) session.flush() else: - raise TypeError("Unknown Callback type") + raise TypeError(f"Unknown Callback type: {type(self.callback).__name__}") + self.missed = True session.add(self) 
Stats.incr( "deadline_alerts.deadline_missed", diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 5c2a3d6d549df..c0c17e3946d72 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -25,6 +25,7 @@ import pendulum import pytest +import structlog import time_machine from airflow._shared.timezones import timezone @@ -34,6 +35,8 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor +from airflow.executors.workloads import Callback, execute_callback_workload +from airflow.models.callback import CallbackFetchMethod from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator from airflow.serialization.definitions.baseoperator import SerializedBaseOperator @@ -573,3 +576,127 @@ def test_executor_conf_get_mandatory_value(self): team_executor_conf = ExecutorConf(team_name="test_team") assert team_executor_conf.get_mandatory_value("celery", "broker_url") == "redis://team-redis" + + +class TestCallbackSupport: + def test_supports_callbacks_flag_default_false(self): + executor = BaseExecutor() + assert executor.supports_callbacks is False + + def test_local_executor_supports_callbacks_true(self): + """Test that LocalExecutor sets supports_callbacks to True.""" + executor = LocalExecutor() + assert executor.supports_callbacks is True + + @pytest.mark.db_test + def test_queue_workload_with_execute_callback(self, dag_maker, session): + executor = BaseExecutor() + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + 
bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + executor.queue_workload(callback_workload, session) + + assert len(executor.queued_callbacks) == 1 + assert callback_data.id in executor.queued_callbacks + + @pytest.mark.db_test + def test_get_workloads_to_schedule_prioritizes_callbacks(self, dag_maker, session): + executor = BaseExecutor() + dagrun = setup_dagrun(dag_maker) + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + executor.queue_workload(callback_workload, session) + + for ti in dagrun.task_instances: + task_workload = workloads.ExecuteTask.make(ti) + executor.queue_workload(task_workload, session) + + workloads_to_schedule = executor._get_workloads_to_schedule(open_slots=10) + + assert len(workloads_to_schedule) == 4 # 1 callback + 3 tasks + _, first_workload = workloads_to_schedule[0] + assert isinstance(first_workload, workloads.ExecuteCallback) # Assert callback comes first + + +class TestExecuteCallbackWorkload: + def test_execute_function_callback_success(self): + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "builtins.dict", + "kwargs": {"a": 1, "b": 2, "c": 3}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is True + assert error is None + + def test_execute_callback_missing_path(self): + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"kwargs": {}}, # Missing 'path' + ) + log = 
structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "Callback path not found" in error + + def test_execute_callback_import_error(self): + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "nonexistent.module.function", + "kwargs": {}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "ModuleNotFoundError" in error + + def test_execute_callback_execution_error(self): + # Use a function that will raise an error; len() requires an argument + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={ + "path": "builtins.len", + "kwargs": {}, + }, + ) + log = structlog.get_logger() + + success, error = execute_callback_workload(callback_data, log) + + assert success is False + assert "TypeError" in error diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 5f216cca2e767..a152d9c2d1a87 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -21,6 +21,7 @@ import multiprocessing import os from unittest import mock +from uuid import UUID import pytest from kgb import spy_on @@ -29,6 +30,8 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads from airflow.executors.local_executor import LocalExecutor, _execute_work +from airflow.executors.workloads import Callback +from airflow.models.callback import CallbackFetchMethod from airflow.settings import Session from airflow.utils.state import State @@ -327,3 +330,40 @@ def test_global_executor_without_team_name(self): assert len(executor.workers) == 2 executor.end() + + +class 
TestLocalExecutorCallbackSupport: + def test_supports_callbacks_flag_is_true(self): + executor = LocalExecutor() + assert executor.supports_callbacks is True + + @skip_spawn_mp_start + @mock.patch("airflow.executors.workloads.execute_callback_workload") + def test_process_callback_workload(self, mock_execute_callback): + mock_execute_callback.return_value = (True, None) + + executor = LocalExecutor(parallelism=1) + callback_data = Callback( + id=UUID("12345678-1234-5678-1234-567812345678"), + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + executor.start() + + try: + executor.queued_callbacks[callback_data.id] = callback_workload + executor._process_workloads([callback_workload]) + assert len(executor.queued_callbacks) == 0 + # We can't easily verify worker execution without running the worker, + # but we can verify the helper is called via mock + + finally: + executor.end() diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index a131f7820a5db..e365d2992817f 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -91,6 +91,7 @@ class CeleryExecutor(BaseExecutor): """ supports_ad_hoc_ti_run: bool = True + supports_callbacks: bool = True sentry_integration: str = "sentry_sdk.integrations.celery.CeleryIntegration" # TODO: Remove this flag once providers depend on Airflow 3.2. 
diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index ece85595bb444..bfa07d155c85d 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -201,40 +201,9 @@ def execute_workload(input: str) -> None: log_path=workload.log_path, ) elif isinstance(workload, workloads.ExecuteCallback): - # Execute callback workload - import and execute the callback function - from airflow.models.callback import CallbackFetchMethod - - callback = workload.callback - if callback.fetch_type != CallbackFetchMethod.IMPORT_PATH: - raise ValueError( - f"CeleryExecutor only supports callbacks with fetch_type={CallbackFetchMethod.IMPORT_PATH}, " - f"got {callback.fetch_type}" - ) - - # Extract callback path and kwargs from data - if not isinstance(callback.data, dict): - raise ValueError(f"Callback data must be a dict, got {type(callback.data)}") - - callback_path = callback.data.get("path") - callback_kwargs = callback.data.get("kwargs") or {} - - if not callback_path: - raise ValueError("Callback path not found in callback data") - - # Import the callback function using Airflow's import_string utility - from airflow.utils.module_loading import import_string - - callback_func = import_string(callback_path) - - # Execute the callback - log.info("[%s] Executing callback: %s", celery_task_id, callback_path) - try: - result = callback_func(**callback_kwargs) - log.info("[%s] Callback executed successfully", celery_task_id) - return result - except Exception: - log.exception("[%s] Callback execution failed", celery_task_id) - raise + success, error_msg = workloads.execute_callback_workload(workload.callback, log) + if not success: + raise RuntimeError(error_msg or "Callback execution failed") else: raise ValueError(f"CeleryExecutor does not know 
how to handle {type(workload)}") diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index aeb1ff89010a2..b591c538ad991 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -21,7 +21,10 @@ from typing import TYPE_CHECKING from airflow.models.deadline import DeadlineReferenceType, ReferenceModels -from airflow.sdk.definitions.callback import AsyncCallback, Callback +from airflow.sdk.definitions.callback import AsyncCallback, Callback, SyncCallback +from airflow.sdk.serde import deserialize, serialize +from airflow.serialization.definitions.deadline import DeadlineAlertFields +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding if TYPE_CHECKING: from collections.abc import Callable @@ -44,7 +47,7 @@ def __init__( self.reference = reference self.interval = interval - if not isinstance(callback, AsyncCallback): + if not isinstance(callback, (AsyncCallback, SyncCallback)): raise ValueError(f"Callbacks of type {type(callback).__name__} are not currently supported") self.callback = callback diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py index 1025cfc27a3ac..8e9e816b30705 100644 --- a/task-sdk/tests/task_sdk/definitions/test_deadline.py +++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py @@ -138,10 +138,27 @@ def test_deadline_alert_in_set(self): alert_set = {alert1, alert2} assert len(alert_set) == 1 - def test_deadline_alert_unsupported_callback(self): - with pytest.raises(ValueError, match="Callbacks of type SyncCallback are not currently supported"): + @pytest.mark.parametrize( + ("callback_class"), + [ + pytest.param(AsyncCallback, id="async_callback"), + pytest.param(SyncCallback, id="sync_callback"), + ], + ) + def test_deadline_alert_accepts_all_callbacks(self, callback_class): + alert = DeadlineAlert( + 
reference=DeadlineReference.DAGRUN_QUEUED_AT, + interval=timedelta(hours=1), + callback=callback_class(TEST_CALLBACK_PATH), + ) + assert alert.callback is not None + assert isinstance(alert.callback, callback_class) + + def test_deadline_alert_rejects_invalid_callback(self): + """Test that DeadlineAlert rejects non-callback types.""" + with pytest.raises(ValueError, match="Callbacks of type str are not currently supported"): DeadlineAlert( reference=DeadlineReference.DAGRUN_QUEUED_AT, interval=timedelta(hours=1), - callback=SyncCallback(TEST_CALLBACK_PATH), + callback="not_a_callback", # type: ignore ) From 7e17159839e193a9f5ff01b62af1df3528f48051 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 28 Jan 2026 12:24:41 -0800 Subject: [PATCH 03/32] Create a unified WorkflowState and WorkflowKey; added mypy type hint fixes --- .../src/airflow/executors/base_executor.py | 25 +++++++++++-------- .../src/airflow/executors/local_executor.py | 14 ++++++----- .../src/airflow/executors/workloads.py | 10 ++++++-- .../src/airflow/jobs/scheduler_job_runner.py | 3 ++- airflow-core/src/airflow/models/callback.py | 16 +++--------- airflow-core/src/airflow/utils/state.py | 13 ++++++++++ .../celery/executors/celery_executor.py | 2 +- .../executors/celery_kubernetes_executor.py | 5 ++-- .../executors/local_kubernetes_executor.py | 5 ++-- 9 files changed, 55 insertions(+), 38 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index c21fada8e672d..040072c8dc5e2 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -52,6 +52,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.executor_utils import ExecutorName + from airflow.executors.workloads import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey 
import TaskInstanceKey @@ -188,8 +189,8 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.team_name: str | None = team_name self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {} - self.running: set[TaskInstanceKey | str] = set() - self.event_buffer: dict[TaskInstanceKey | str, EventBufferValueType] = {} + self.running: set[WorkloadKey] = set() + self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() self.conf = ExecutorConf(team_name) @@ -205,7 +206,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) :meta private: """ - self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType) + self.attempts: dict[WorkloadKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType) def __repr__(self): _repr = f"{self.__class__.__name__}(parallelism={self.parallelism}" @@ -230,13 +231,13 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: ti = workload.ti self.queued_tasks[ti.key] = workload elif isinstance(workload, workloads.ExecuteCallback): - self.queued_callbacks[workload.callback.id] = workload + self.queued_callbacks[str(workload.callback.id)] = workload else: raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") def _get_workloads_to_schedule( self, open_slots: int - ) -> list[tuple[TaskInstanceKey | str, workloads.All]]: + ) -> list[tuple[WorkloadKey, workloads.All]]: """ Select and return the next batch of workloads to schedule, respecting priority policy. 
@@ -245,7 +246,7 @@ def _get_workloads_to_schedule( :param open_slots: Number of available execution slots """ - workloads_to_schedule: list[tuple[TaskInstanceKey | str, workloads.All]] = [] + workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = [] if self.queued_callbacks: for key, workload in self.queued_callbacks.items(): @@ -508,24 +509,26 @@ def running_state(self, key: TaskInstanceKey, info=None) -> None: """ self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False) - def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]: + def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer. In case dag_ids is specified it will only return and flush events for the given dag_ids. Otherwise, it returns and flushes all events. + Note: Callback events (with string keys) are always included regardless of dag_ids filter. :param dag_ids: the dag_ids to return events for; returns all if given ``None``. 
:return: a dict of events """ - cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {} + cleared_events: dict[WorkloadKey, EventBufferValueType] = {} if dag_ids is None: cleared_events = self.event_buffer self.event_buffer = {} else: - for ti_key in list(self.event_buffer.keys()): - if ti_key.dag_id in dag_ids: - cleared_events[ti_key] = self.event_buffer.pop(ti_key) + for key in list(self.event_buffer.keys()): + # Include if it's a callback (string key) or if it's a task in the specified dags + if isinstance(key, str) or key.dag_id in dag_ids: + cleared_events[key] = self.event_buffer.pop(key) return cleared_events diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 197f12bb4f7a6..b7ca7688ee0fc 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -37,7 +37,7 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import CallbackState, TaskInstanceState # add logger to parameter of setproctitle to support logging if sys.platform == "darwin": @@ -50,7 +50,9 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger - TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Exception | None] + from airflow.executors.workloads import WorkloadKey, WorkloadState + + WorkloadResultType = tuple[WorkloadKey, WorkloadState, Exception | None] def _get_executor_process_title_prefix(team_name: str | None) -> str: @@ -66,7 +68,7 @@ def _get_executor_process_title_prefix(team_name: str | None) -> str: def _run_worker( logger_name: str, input: SimpleQueue[workloads.All | None], - output: Queue[TaskInstanceStateType], + output: Queue[WorkloadResultType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], team_conf, ): @@ -112,10 +114,10 @@ def 
_run_worker( key = workload.callback.id try: _execute_callback(log, workload, team_conf) - output.put((key, TaskInstanceState.SUCCESS, None)) + output.put((key, CallbackState.SUCCESS, None)) except Exception as e: log.exception("Callback execution failed") - output.put((key, TaskInstanceState.FAILED, e)) + output.put((key, CallbackState.FAILED, e)) else: raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") @@ -185,7 +187,7 @@ class LocalExecutor(BaseExecutor): supports_callbacks: bool = True activity_queue: SimpleQueue[workloads.All | None] - result_queue: SimpleQueue[TaskInstanceStateType] + result_queue: SimpleQueue[WorkloadResultType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index e0096256f4bd6..428e35fb8983d 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -27,14 +27,20 @@ import structlog from pydantic import BaseModel, Field -from airflow.models.callback import CallbackFetchMethod +# NOTE: noqa because ruff wants this in TYPE_CHECKING, but if we do that then pydantic fails at runtime. 
+from airflow.models.callback import CallbackFetchMethod # noqa: TCH001 if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.models import DagRun - from airflow.models.callback import Callback as CallbackModel + from airflow.models.callback import Callback as CallbackModel, CallbackKey from airflow.models.taskinstance import TaskInstance as TIModel from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.state import CallbackState, TaskInstanceState + + # TODO: I wonder if we can programatically assemble these unions somehow + WorkloadKey = TaskInstanceKey | CallbackKey + WorkloadState = TaskInstanceState | CallbackState __all__ = ["All", "ExecuteTask", "ExecuteCallback"] diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 4c420db7f0145..a04173f72de1b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -85,7 +85,8 @@ TaskOutletAssetReference, ) from airflow.models.backfill import Backfill -from airflow.models.callback import Callback, CallbackState, ExecutorCallback +from airflow.models.callback import Callback, ExecutorCallback +from airflow.utils.state import CallbackState from airflow.models.dag import DagModel, get_next_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index ea45a10f1f1ae..5012f2d860673 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -31,6 +31,7 @@ from airflow._shared.timezones import timezone from airflow.models import Base from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime +from airflow.utils.state import CallbackState if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -38,20 +39,9 @@ from 
airflow.callbacks.callback_requests import CallbackRequest from airflow.triggers.base import TriggerEvent -log = structlog.get_logger(__name__) - - -class CallbackState(str, Enum): - """All possible states of callbacks.""" + CallbackKey = str # Callback keys are str(UUID) - PENDING = "pending" - QUEUED = "queued" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - - def __str__(self) -> str: - return self.value +log = structlog.get_logger(__name__) ACTIVE_STATES = frozenset((CallbackState.QUEUED, CallbackState.RUNNING)) diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index b392a02352574..926003d86b0d1 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -20,6 +20,19 @@ from enum import Enum +class CallbackState(str, Enum): + """All possible states of callbacks.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + class TerminalTIState(str, Enum): """States that a Task Instance can be in that indicate it has reached a terminal state.""" diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index e365d2992817f..42a3d8d4b6f3f 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -164,7 +164,7 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: if AIRFLOW_V_3_2_PLUS: from airflow.executors.workloads import ExecuteCallback - tasks = [] + tasks: list[TaskInstanceInCelery] = [] for workload in workloads: if isinstance(workload, ExecuteTask): tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) diff --git 
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 49ae5b35b6f52..6ec97e2480fe6 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -34,6 +34,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType + from airflow.executors.workloads import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, @@ -113,8 +114,8 @@ def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" @property - def running(self) -> set[TaskInstanceKey]: - """Return running tasks from celery and kubernetes executor.""" + def running(self) -> set[WorkloadKey]: + """Combine running from both executors.""" return self.celery_executor.running.union(self.kubernetes_executor.running) @running.setter diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 114da7ec36fe8..763e977d85bb9 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -25,6 +25,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor +from airflow.executors.workloads import WorkloadKey from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -109,8 +110,8 @@ def 
queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" @property - def running(self) -> set[TaskInstanceKey]: - """Return running tasks from local and kubernetes executor.""" + def running(self) -> set[WorkloadKey]: + """Combine running from both executors.""" return self.local_executor.running.union(self.kubernetes_executor.running) @running.setter From e41698c31a8d3031d20b767a46ab24b347cd941d Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 29 Jan 2026 12:35:39 -0800 Subject: [PATCH 04/32] First batch Niko fixes --- .../src/airflow/executors/base_executor.py | 32 ++++++------------- .../src/airflow/executors/workloads.py | 1 - .../src/airflow/jobs/scheduler_job_runner.py | 3 +- .../executors/local_kubernetes_executor.py | 2 +- 4 files changed, 12 insertions(+), 26 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 040072c8dc5e2..9e890639a027e 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -231,13 +231,17 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: ti = workload.ti self.queued_tasks[ti.key] = workload elif isinstance(workload, workloads.ExecuteCallback): + if not self.supports_callbacks: + raise NotImplementedError( + f"{type(self).__name__} does not support ExecuteCallback workloads. " + f"Set supports_callbacks = True and implement callback handling in _process_workloads(). " + f"See LocalExecutor or CeleryExecutor for reference implementation." 
+ ) self.queued_callbacks[str(workload.callback.id)] = workload else: raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") - def _get_workloads_to_schedule( - self, open_slots: int - ) -> list[tuple[WorkloadKey, workloads.All]]: + def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: """ Select and return the next batch of workloads to schedule, respecting priority policy. @@ -254,14 +258,8 @@ def _get_workloads_to_schedule( break workloads_to_schedule.append((key, workload)) - remaining_slots = open_slots - len(workloads_to_schedule) - if remaining_slots and self.queued_tasks: - sorted_tasks = sorted( - self.queued_tasks.items(), - key=lambda x: x[1].ti.priority_weight, - reverse=True, - ) - for key, workload in sorted_tasks: + if open_slots > len(workloads_to_schedule) and self.queued_tasks: + for key, workload in self.order_queued_tasks_by_priority(): if len(workloads_to_schedule) >= open_slots: break workloads_to_schedule.append((key, workload)) @@ -439,17 +437,7 @@ def trigger_tasks(self, open_slots: int) -> None: workload_list.append(workload) if workload_list: - try: - self._process_workloads(workload_list) - except AttributeError as e: - if any(isinstance(workload, workloads.ExecuteCallback) for workload in workload_list): - raise NotImplementedError( - f"{type(self).__name__} does not support ExecuteCallback workloads. " - f"This executor needs to be updated to handle both ExecuteTask and ExecuteCallback types. " - f"See any executor with supports_callbacks=True (LocalExecutor or CeleryExecutor, for example) for reference implementation." - ) from e - # Re-raise if it's a different AttributeError - raise + self._process_workloads(workload_list) # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did # it die". 
It is possible for the task itself to finish with success, but the state of the task to be set diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 428e35fb8983d..39f97ce05a54d 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -38,7 +38,6 @@ from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.state import CallbackState, TaskInstanceState - # TODO: I wonder if we can programatically assemble these unions somehow WorkloadKey = TaskInstanceKey | CallbackKey WorkloadState = TaskInstanceState | CallbackState diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index a04173f72de1b..3973330bc9d7b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -86,7 +86,6 @@ ) from airflow.models.backfill import Backfill from airflow.models.callback import Callback, ExecutorCallback -from airflow.utils.state import CallbackState from airflow.models.dag import DagModel, get_next_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag @@ -117,7 +116,7 @@ prohibit_commit, with_row_locks, ) -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 763e977d85bb9..7e0c7d140845b 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -25,7 +25,6 @@ from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.executors.workloads import WorkloadKey from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -34,6 +33,7 @@ from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType from airflow.executors.local_executor import LocalExecutor + from airflow.executors.workloads import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, From 6de20c8a259280662ee4c4d2e031b8f701438643 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 29 Jan 2026 12:45:03 -0800 Subject: [PATCH 05/32] deovercommentification of hand-off notes --- .../src/airflow/jobs/scheduler_job_runner.py | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3973330bc9d7b..b58fe8b7961d6 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -85,7 +85,7 @@ TaskOutletAssetReference, ) from airflow.models.backfill import Backfill -from airflow.models.callback import Callback, ExecutorCallback +from airflow.models.callback import Callback, CallbackType, ExecutorCallback from airflow.models.dag import DagModel, get_next_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag @@ -1002,9 +1002,6 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: :param session: The database session """ - # Query for QUEUED ExecutorCallback instances - from airflow.models.callback import CallbackType - 
queued_callbacks = session.scalars( select(ExecutorCallback) .where(ExecutorCallback.type == CallbackType.EXECUTOR) @@ -1016,33 +1013,25 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: if not queued_callbacks: return - # Group callbacks by executor (based on callback executor attribute or default executor) executor_to_callbacks: dict[BaseExecutor, list[ExecutorCallback]] = defaultdict(list) for callback in queued_callbacks: - # Get the executor name from callback data if specified executor_name = None if isinstance(callback.data, dict): executor_name = callback.data.get("executor") - # Find the appropriate executor executor = None if executor_name: - # Find executor by name - try multiple matching strategies - for exec in self.job.executors: - # Match by class name (e.g., "CeleryExecutor") - if exec.__class__.__name__ == executor_name: - executor = exec + for e in self.job.executors: + if e.__class__.__name__ == executor_name: + executor = e break - # Match by executor name attribute if available - if hasattr(exec, "name") and exec.name and str(exec.name) == executor_name: - executor = exec + if hasattr(e, "name") and e.name and str(e.name) == executor_name: + executor = e break - # Match by executor name attribute if available - if hasattr(exec, "executor_name") and exec.executor_name == executor_name: - executor = exec + if hasattr(e, "executor_name") and e.executor_name == executor_name: + executor = e break - # Default to first executor if no specific executor found if executor is None: executor = self.job.executors[0] if self.job.executors else None @@ -1056,8 +1045,6 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: # Enqueue callbacks for each executor for executor, callbacks in executor_to_callbacks.items(): for callback in callbacks: - # Get the associated DagRun for the callback - # For deadline callbacks, we stored dag_run_id in the callback data dag_run = None if isinstance(callback.data, dict) and "dag_run_id" in 
callback.data: dag_run_id = callback.data["dag_run_id"] @@ -1076,17 +1063,13 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: self.log.warning("Could not find DagRun for callback %s", callback.id) continue - # Create ExecuteCallback workload workload = workloads.ExecuteCallback.make( callback=callback, dag_run=dag_run, generator=executor.jwt_generator, ) - # Queue the workload executor.queue_workload(workload, session=session) - - # Update callback state to RUNNING callback.state = CallbackState.RUNNING session.add(callback) From 4722833f8014cf52440759227b98c1aaf7e33556 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 29 Jan 2026 15:05:46 -0800 Subject: [PATCH 06/32] CI fixes and minor type-related cleanup --- .../src/airflow/executors/base_executor.py | 6 +- .../src/airflow/executors/workloads.py | 4 +- .../unit/executors/test_base_executor.py | 35 ++++++++-- .../celery/executors/celery_executor.py | 66 ++++++++++++++++++- .../executors/celery_kubernetes_executor.py | 2 +- .../executors/local_kubernetes_executor.py | 2 +- 6 files changed, 100 insertions(+), 15 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 9e890639a027e..a1d027fc6736c 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -237,7 +237,7 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: f"Set supports_callbacks = True and implement callback handling in _process_workloads(). " f"See LocalExecutor or CeleryExecutor for reference implementation." 
) - self.queued_callbacks[str(workload.callback.id)] = workload + self.queued_callbacks[workload.callback.id] = workload else: raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") @@ -259,10 +259,10 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads_to_schedule.append((key, workload)) if open_slots > len(workloads_to_schedule) and self.queued_tasks: - for key, workload in self.order_queued_tasks_by_priority(): + for task_key, task_workload in self.order_queued_tasks_by_priority(): if len(workloads_to_schedule) >= open_slots: break - workloads_to_schedule.append((key, workload)) + workloads_to_schedule.append((task_key, task_workload)) return workloads_to_schedule diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 39f97ce05a54d..97fffbd9dbb09 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -99,7 +99,7 @@ def key(self) -> TaskInstanceKey: class Callback(BaseModel): """Schema for Callback with minimal required fields needed for Executors and Task SDK.""" - id: uuid.UUID + id: str # A uuid.UUID stored as a string fetch_method: CallbackFetchMethod data: dict @@ -180,7 +180,7 @@ def make( return cls( callback=Callback.model_validate(callback, from_attributes=True), dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(callback.id), generator), + token=cls.generate_token(callback.id, generator), log_path=fname, bundle_info=bundle_info, ) diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index c0c17e3946d72..9360ffc58c56f 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -588,11 +588,31 @@ def 
test_local_executor_supports_callbacks_true(self): executor = LocalExecutor() assert executor.supports_callbacks is True + @pytest.mark.db_test + def test_queue_callback_without_support_raises_error(self, dag_maker, session): + executor = BaseExecutor() # supports_callbacks = False by default + callback_data = Callback( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.func", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + with pytest.raises(NotImplementedError, match="does not support ExecuteCallback"): + executor.queue_workload(callback_workload, session) + @pytest.mark.db_test def test_queue_workload_with_execute_callback(self, dag_maker, session): executor = BaseExecutor() + executor.supports_callbacks = True # Enable for this test callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, ) @@ -610,11 +630,12 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session): assert callback_data.id in executor.queued_callbacks @pytest.mark.db_test - def test_get_workloads_to_schedule_prioritizes_callbacks(self, dag_maker, session): + def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): executor = BaseExecutor() + executor.supports_callbacks = True # Enable for this test dagrun = setup_dagrun(dag_maker) callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, ) @@ -641,7 +662,7 @@ def test_get_workloads_to_schedule_prioritizes_callbacks(self, dag_maker, sessio class 
TestExecuteCallbackWorkload: def test_execute_function_callback_success(self): callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ "path": "builtins.dict", @@ -657,7 +678,7 @@ def test_execute_function_callback_success(self): def test_execute_callback_missing_path(self): callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"kwargs": {}}, # Missing 'path' ) @@ -670,7 +691,7 @@ def test_execute_callback_missing_path(self): def test_execute_callback_import_error(self): callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ "path": "nonexistent.module.function", @@ -687,7 +708,7 @@ def test_execute_callback_import_error(self): def test_execute_callback_execution_error(self): # Use a function that will raise an error; len() requires an argument callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ "path": "builtins.len", diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 42a3d8d4b6f3f..c0e19043ea1ce 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -157,6 +157,37 @@ def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None: self._send_tasks(task_tuples_to_send) + from airflow.executors.workloads import ExecuteCallback, ExecuteTask + from airflow.models.taskinstancekey import TaskInstanceKey + from 
airflow.providers.celery.executors.celery_executor_utils import execute_workload +>>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) + + tasks: list[TaskInstanceInCelery] = [] + for workload in workloads: + if isinstance(workload, ExecuteTask): +<<<<<<< HEAD + tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) + elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): + tasks.append((workload.ti.key, workload, workload.ti.queue, execute_workload)) + elif isinstance(workload, ExecuteCallback): + # For callbacks, use a synthetic key based on callback ID + callback_key = TaskInstanceKey( + dag_id="callback", + task_id=workload.callback.id, + run_id="callback", + try_number=1, + map_index=-1, + ) +>>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) + # Use default queue for callbacks, or extract from callback data if available + queue = "default" + if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: + queue = workload.callback.data["queue"] + tasks.append((workload.callback.key, workload, queue, self.team_name)) + else: + raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") + + self._send_tasks(tasks) def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: # Airflow V3 version -- have to delay imports until we know we are on v3 from airflow.executors.workloads import ExecuteTask @@ -178,6 +209,39 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") self._send_tasks(tasks) +======= + from airflow.executors.workloads import ExecuteCallback, ExecuteTask + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.celery.executors.celery_executor_utils import execute_workload +>>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) + + tasks: list[TaskInstanceInCelery] = [] + for workload in workloads: + 
if isinstance(workload, ExecuteTask): +<<<<<<< HEAD + tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) + elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): +======= + tasks.append((workload.ti.key, workload, workload.ti.queue, execute_workload)) + elif isinstance(workload, ExecuteCallback): + # For callbacks, use a synthetic key based on callback ID + callback_key = TaskInstanceKey( + dag_id="callback", + task_id=workload.callback.id, + run_id="callback", + try_number=1, + map_index=-1, + ) +>>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) + # Use default queue for callbacks, or extract from callback data if available + queue = "default" + if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: + queue = workload.callback.data["queue"] + tasks.append((workload.callback.key, workload, queue, self.team_name)) + else: + raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") + + self._send_tasks(tasks) def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): # Celery state queries will be stuck if we do not use one same backend @@ -395,7 +459,7 @@ def queue_workload(self, workload: workloads.All, session: Session | None) -> No # For callbacks, use a synthetic key based on callback ID callback_key = TaskInstanceKey( dag_id="callback", - task_id=str(workload.callback.id), + task_id=workload.callback.id, run_id="callback", try_number=1, map_index=-1, diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 6ec97e2480fe6..afa7a143f16cb 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -228,7 +228,7 @@ def heartbeat(self) -> None: def get_event_buffer( self, dag_ids: 
list[str] | None = None - ) -> dict[TaskInstanceKey, EventBufferValueType]: + ) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer from celery and kubernetes executor. diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 7e0c7d140845b..fae50fb9b30a7 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -222,7 +222,7 @@ def heartbeat(self) -> None: def get_event_buffer( self, dag_ids: list[str] | None = None - ) -> dict[TaskInstanceKey, EventBufferValueType]: + ) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer from local and kubernetes executor. From 9719607df7bd77313e3f34643398beb60d519e36 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 29 Jan 2026 16:34:20 -0800 Subject: [PATCH 07/32] Niko suggestions round 2 and Ci/prek fixes --- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 9 ++++++++- .../celery/executors/celery_kubernetes_executor.py | 4 +--- .../kubernetes/executors/local_kubernetes_executor.py | 4 +--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index b58fe8b7961d6..44ee915621f44 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1002,12 +1002,19 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: :param session: The database session """ + num_occupied_slots = sum(executor.slots_occupied for executor in self.job.executors) + max_callbacks = conf.getint("core", "parallelism") - num_occupied_slots + + if max_callbacks <= 
0: + self.log.debug("No available slots for callbacks; all executors at capacity") + return + queued_callbacks = session.scalars( select(ExecutorCallback) .where(ExecutorCallback.type == CallbackType.EXECUTOR) .where(ExecutorCallback.state == CallbackState.QUEUED) .order_by(ExecutorCallback.priority_weight.desc()) - .limit(conf.getint("scheduler", "max_callback_workloads_per_loop", fallback=100)) + .limit(max_callbacks) ).all() if not queued_callbacks: diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index afa7a143f16cb..f2203ad746ee1 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -226,9 +226,7 @@ def heartbeat(self) -> None: self.celery_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( - self, dag_ids: list[str] | None = None - ) -> dict[WorkloadKey, EventBufferValueType]: + def get_event_buffer(self, dag_ids: list[str] | None = None) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer from celery and kubernetes executor. 
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index fae50fb9b30a7..9209de0a7ac15 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -220,9 +220,7 @@ def heartbeat(self) -> None: self.local_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( - self, dag_ids: list[str] | None = None - ) -> dict[WorkloadKey, EventBufferValueType]: + def get_event_buffer(self, dag_ids: list[str] | None = None) -> dict[WorkloadKey, EventBufferValueType]: """ Return and flush the event buffer from local and kubernetes executor. From 8aee70322057b4e19933ac790d7fe89d9a864fa1 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 29 Jan 2026 18:51:11 -0800 Subject: [PATCH 08/32] CI and MyPy fixes --- airflow-core/src/airflow/executors/local_executor.py | 4 ++-- airflow-core/tests/unit/executors/test_local_executor.py | 3 +-- docs/spelling_wordlist.txt | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index b7ca7688ee0fc..2f2de402e1ce3 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -102,7 +102,7 @@ def _run_worker( # Handle different workload types if isinstance(workload, workloads.ExecuteTask): - key = workload.ti.key + key: WorkloadKey = workload.ti.key try: _execute_work(log, workload, team_conf) output.put((key, TaskInstanceState.SUCCESS, None)) @@ -111,7 +111,7 @@ def _run_worker( output.put((key, TaskInstanceState.FAILED, e)) elif isinstance(workload, workloads.ExecuteCallback): - key = workload.callback.id + 
key: WorkloadKey = workload.callback.id try: _execute_callback(log, workload, team_conf) output.put((key, CallbackState.SUCCESS, None)) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index a152d9c2d1a87..265aff6384492 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -21,7 +21,6 @@ import multiprocessing import os from unittest import mock -from uuid import UUID import pytest from kgb import spy_on @@ -344,7 +343,7 @@ def test_process_callback_workload(self, mock_execute_callback): executor = LocalExecutor(parallelism=1) callback_data = Callback( - id=UUID("12345678-1234-5678-1234-567812345678"), + id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, ) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index e7bef2859afcc..12f7e3ac427f4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2092,6 +2092,7 @@ winrm WIT workgroup workgroups +WorkloadKey workspaces writeable wsman From 587499a53e53c0f6716e33d6642898fc3881b464 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 30 Jan 2026 10:17:53 -0800 Subject: [PATCH 09/32] CI and MyPy fixes --- airflow-core/src/airflow/executors/local_executor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 2f2de402e1ce3..dd398c1273df7 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -102,22 +102,20 @@ def _run_worker( # Handle different workload types if isinstance(workload, workloads.ExecuteTask): - key: WorkloadKey = workload.ti.key try: _execute_work(log, workload, team_conf) - output.put((key, TaskInstanceState.SUCCESS, None)) 
+ output.put((workload.ti.key, TaskInstanceState.SUCCESS, None)) except Exception as e: log.exception("Task execution failed.") - output.put((key, TaskInstanceState.FAILED, e)) + output.put((workload.ti.key, TaskInstanceState.FAILED, e)) elif isinstance(workload, workloads.ExecuteCallback): - key: WorkloadKey = workload.callback.id try: _execute_callback(log, workload, team_conf) - output.put((key, CallbackState.SUCCESS, None)) + output.put((workload.callback.id, CallbackState.SUCCESS, None)) except Exception as e: log.exception("Callback execution failed") - output.put((key, CallbackState.FAILED, e)) + output.put((workload.callback.id, CallbackState.FAILED, e)) else: raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") From 47e1d76e384c1a8dd72a620e5d1e94f779914a89 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 4 Feb 2026 14:02:10 -0800 Subject: [PATCH 10/32] refactor workload file locations --- .../src/airflow/executors/base_executor.py | 5 +- .../src/airflow/executors/local_executor.py | 4 +- .../airflow/executors/workloads/__init__.py | 35 ++++ .../src/airflow/executors/workloads/base.py | 82 +++++++++ .../{workloads.py => workloads/callback.py} | 166 ++---------------- .../src/airflow/executors/workloads/task.py | 104 +++++++++++ .../airflow/executors/workloads/trigger.py | 42 +++++ .../src/airflow/executors/workloads/types.py | 37 ++++ .../src/airflow/jobs/triggerer_job_runner.py | 5 +- airflow-core/src/airflow/models/callback.py | 23 +-- .../src/airflow/models/taskinstance.py | 11 +- .../unit/executors/test_base_executor.py | 23 +-- .../unit/executors/test_local_executor.py | 12 +- .../celery/executors/celery_executor.py | 64 ------- .../celery/test_celery_executor.py | 7 +- .../edge3/worker_api/v2-edge-generated.yaml | 6 +- 16 files changed, 371 insertions(+), 255 deletions(-) create mode 100644 airflow-core/src/airflow/executors/workloads/__init__.py create mode 100644 airflow-core/src/airflow/executors/workloads/base.py 
rename airflow-core/src/airflow/executors/{workloads.py => workloads/callback.py} (51%) create mode 100644 airflow-core/src/airflow/executors/workloads/task.py create mode 100644 airflow-core/src/airflow/executors/workloads/trigger.py create mode 100644 airflow-core/src/airflow/executors/workloads/types.py diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index a1d027fc6736c..f0226091a9b27 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -32,6 +32,7 @@ from airflow.configuration import conf from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models import Log from airflow.observability.metrics import stats_utils from airflow.observability.trace import Trace @@ -52,7 +53,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.executor_utils import ExecutorName - from airflow.executors.workloads import WorkloadKey + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -415,7 +416,7 @@ def trigger_tasks(self, open_slots: int) -> None: # If it's None, then the span for the current id hasn't been started. 
if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None: - if isinstance(ti, workloads.TaskInstance): + if isinstance(ti, TaskInstanceDTO): parent_context = Trace.extract(ti.parent_context_carrier) else: parent_context = Trace.extract(ti.dag_run.context_carrier) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index dd398c1273df7..0e91a522788e6 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -50,9 +50,7 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger - from airflow.executors.workloads import WorkloadKey, WorkloadState - - WorkloadResultType = tuple[WorkloadKey, WorkloadState, Exception | None] + from airflow.executors.workloads.types import WorkloadResultType def _get_executor_process_title_prefix(team_name: str | None) -> str: diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py new file mode 100644 index 0000000000000..dca4c991f637b --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workload schemas for executor communication.""" + +from __future__ import annotations + +from typing import Annotated + +from pydantic import Field + +from airflow.executors.workloads.base import BaseWorkload, BundleInfo +from airflow.executors.workloads.callback import CallbackFetchMethod, ExecuteCallback +from airflow.executors.workloads.task import ExecuteTask +from airflow.executors.workloads.trigger import RunTrigger + +All = Annotated[ + ExecuteTask | ExecuteCallback | RunTrigger, + Field(discriminator="type"), +] + +__all__ = ["All", "BaseWorkload", "BundleInfo", "CallbackFetchMethod", "ExecuteCallback", "ExecuteTask"] diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py new file mode 100644 index 0000000000000..7e6bffc56b35f --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""ORM models and Pydantic schemas for BaseWorkload.""" + +from __future__ import annotations + +import os +from abc import ABC +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + + +class BaseWorkload: + """ + Mixin for ORM models that can be scheduled as workloads. + + This mixin defines the interface that scheduler workloads (TaskInstance, + ExecutorCallback, etc.) must implement to provide routing information to the scheduler. + + Subclasses must override: + - get_dag_id() -> str | None + - get_executor_name() -> str | None + """ + + def get_dag_id(self) -> str | None: + """ + Return the DAG ID for scheduler routing. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get_dag_id()") + + def get_executor_name(self) -> str | None: + """ + Return the executor name for scheduler routing. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get_executor_name()") + + +class BundleInfo(BaseModel): + """Schema for telling task which bundle to run with.""" + + name: str + version: str | None = None + + +class BaseWorkloadSchema(BaseModel): + """Base Pydantic schema for executor workload DTOs.""" + + token: str # The identity token for this workload. + + @staticmethod + def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: + return generator.generate({"sub": sub_id}) if generator else "" + + +class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): + """Base class for Workloads that are associated with a DAG bundle.""" + + dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`) + bundle_info: BundleInfo + log_path: str | None # Rendered relative log filename template the task logs should be written to. 
diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads/callback.py similarity index 51% rename from airflow-core/src/airflow/executors/workloads.py rename to airflow-core/src/airflow/executors/workloads/callback.py index 97fffbd9dbb09..d8e33f08def57 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -14,89 +14,39 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Callback workload schemas for executor communication.""" + from __future__ import annotations -import os -import uuid -from abc import ABC -from datetime import datetime +from enum import Enum from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal +from typing import TYPE_CHECKING, Literal import structlog from pydantic import BaseModel, Field -# NOTE: noqa because ruff wants this in TYPE_CHECKING, but if we do that then pydantic fails at runtime. 
-from airflow.models.callback import CallbackFetchMethod # noqa: TCH001 +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.models import DagRun - from airflow.models.callback import Callback as CallbackModel, CallbackKey - from airflow.models.taskinstance import TaskInstance as TIModel - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.utils.state import CallbackState, TaskInstanceState - - WorkloadKey = TaskInstanceKey | CallbackKey - WorkloadState = TaskInstanceState | CallbackState - - -__all__ = ["All", "ExecuteTask", "ExecuteCallback"] + from airflow.models.callback import Callback as CallbackModel log = structlog.get_logger(__name__) -class BaseWorkload(BaseModel): - token: str - """The identity token for this workload""" - - @staticmethod - def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" - - -class BundleInfo(BaseModel): - """Schema for telling task which bundle to run with.""" +class CallbackFetchMethod(str, Enum): + """Methods used to fetch callback at runtime.""" - name: str - version: str | None = None + # For future use once Dag Processor callbacks (on_success_callback/on_failure_callback) get moved to executors + DAG_ATTRIBUTE = "dag_attribute" - -class TaskInstance(BaseModel): - """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" - - id: uuid.UUID - dag_version_id: uuid.UUID - task_id: str - dag_id: str - run_id: str - try_number: int - map_index: int = -1 - - pool_slots: int - queue: str - priority_weight: int - executor_config: dict | None = Field(default=None, exclude=True) - - parent_context_carrier: dict | None = None - context_carrier: dict | None = None - - # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across the codebase? 
- @property - def key(self) -> TaskInstanceKey: - from airflow.models.taskinstancekey import TaskInstanceKey - - return TaskInstanceKey( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=self.run_id, - try_number=self.try_number, - map_index=self.map_index, - ) + # For deadline callbacks since they import callbacks through the import path + IMPORT_PATH = "import_path" -class Callback(BaseModel): +class CallbackDTO(BaseModel): """Schema for Callback with minimal required fields needed for Executors and Task SDK.""" id: str # A uuid.UUID stored as a string @@ -104,60 +54,10 @@ class Callback(BaseModel): data: dict -class BaseDagBundleWorkload(BaseWorkload, ABC): - """Base class for Workloads that are associated with a DAG bundle.""" - - dag_rel_path: os.PathLike[str] - """The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)""" - - bundle_info: BundleInfo - - log_path: str | None - """The rendered relative log filename template the task logs should be written to""" - - -class ExecuteTask(BaseDagBundleWorkload): - """Execute the given Task.""" - - ti: TaskInstance - sentry_integration: str = "" - - type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") - - @classmethod - def make( - cls, - ti: TIModel, - dag_rel_path: Path | None = None, - generator: JWTGenerator | None = None, - bundle_info: BundleInfo | None = None, - sentry_integration: str = "", - ) -> ExecuteTask: - from airflow.utils.helpers import log_filename_template_renderer - - ser_ti = TaskInstance.model_validate(ti, from_attributes=True) - ser_ti.parent_context_carrier = ti.dag_run.context_carrier - if not bundle_info: - bundle_info = BundleInfo( - name=ti.dag_model.bundle_name, - version=ti.dag_run.bundle_version, - ) - fname = log_filename_template_renderer()(ti=ti) - - return cls( - ti=ser_ti, - dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(ti.id), generator), - log_path=fname, - 
bundle_info=bundle_info, - sentry_integration=sentry_integration, - ) - - class ExecuteCallback(BaseDagBundleWorkload): """Execute the given Callback.""" - callback: Callback + callback: CallbackDTO type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") @@ -170,6 +70,7 @@ def make( generator: JWTGenerator | None = None, bundle_info: BundleInfo | None = None, ) -> ExecuteCallback: + """Create an ExecuteCallback workload from a Callback ORM model.""" if not bundle_info: bundle_info = BundleInfo( name=dag_run.dag_model.bundle_name, @@ -178,7 +79,7 @@ def make( fname = f"executor_callbacks/{callback.id}" # TODO: better log file template return cls( - callback=Callback.model_validate(callback, from_attributes=True), + callback=CallbackDTO.model_validate(callback, from_attributes=True), dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), token=cls.generate_token(callback.id, generator), log_path=fname, @@ -186,40 +87,8 @@ def make( ) -class RunTrigger(BaseModel): - """Execute an async "trigger" process that yields events.""" - - id: int - - ti: TaskInstance | None - """ - The task instance associated with this trigger. - - Could be none for asset-based triggers. - """ - - classpath: str - """ - Dot-separated name of the module+fn to import and run this workload. - - Consumers of this Workload must perform their own validation of this input. 
- """ - - encrypted_kwargs: str - - timeout_after: datetime | None = None - - type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger") - - -All = Annotated[ - ExecuteTask | ExecuteCallback | RunTrigger, - Field(discriminator="type"), -] - - def execute_callback_workload( - callback: Callback, + callback: CallbackDTO, log, ) -> tuple[bool, str | None]: """ @@ -242,7 +111,6 @@ def execute_callback_workload( :param log: Logger instance for recording execution :return: Tuple of (success: bool, error_message: str | None) """ - # Extract callback details from data callback_path = callback.data.get("path") callback_kwargs = callback.data.get("kwargs", {}) diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py new file mode 100644 index 0000000000000..3620c08dff13d --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Task workload schemas for executor communication.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from pydantic import BaseModel, Field + +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + from airflow.models.taskinstance import TaskInstance as TIModel + from airflow.models.taskinstancekey import TaskInstanceKey + + +class TaskInstanceDTO(BaseModel): + """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" + + id: uuid.UUID + dag_version_id: uuid.UUID + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int = -1 + + pool_slots: int + queue: str + priority_weight: int + executor_config: dict | None = Field(default=None, exclude=True) + + parent_context_carrier: dict | None = None + context_carrier: dict | None = None + + @property + def key(self) -> TaskInstanceKey: + """Return the TaskInstanceKey for this task instance.""" + from airflow.models.taskinstancekey import TaskInstanceKey + + return TaskInstanceKey( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + try_number=self.try_number, + map_index=self.map_index, + ) + + +class ExecuteTask(BaseDagBundleWorkload): + """Execute the given Task.""" + + ti: TaskInstanceDTO + sentry_integration: str = "" + + type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + + @classmethod + def make( + cls, + ti: TIModel, + dag_rel_path: Path | None = None, + generator: JWTGenerator | None = None, + bundle_info: BundleInfo | None = None, + sentry_integration: str = "", + ) -> ExecuteTask: + """Create an ExecuteTask workload from a TaskInstance ORM model.""" + from airflow.utils.helpers import log_filename_template_renderer + + ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True) + ser_ti.parent_context_carrier = 
ti.dag_run.context_carrier + if not bundle_info: + bundle_info = BundleInfo( + name=ti.dag_model.bundle_name, + version=ti.dag_run.bundle_version, + ) + fname = log_filename_template_renderer()(ti=ti) + + return cls( + ti=ser_ti, + dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), + token=cls.generate_token(str(ti.id), generator), + log_path=fname, + bundle_info=bundle_info, + sentry_integration=sentry_integration, + ) diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py b/airflow-core/src/airflow/executors/workloads/trigger.py new file mode 100644 index 0000000000000..25bca9ce44b13 --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/trigger.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trigger workload schemas for executor communication.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + +# Using noqa because Ruff wants this in a TYPE_CHECKING block but Pydantic fails if it is. +from airflow.executors.workloads.task import TaskInstanceDTO # noqa: TCH001 + + +class RunTrigger(BaseModel): + """ + Execute an async "trigger" process that yields events. 
+ + Consumers of this Workload must perform their own validation of the classpath input. + """ + + id: int + ti: TaskInstanceDTO | None # Could be none for asset-based triggers. + classpath: str # Dot-separated name of the module+fn to import and run this workload. + encrypted_kwargs: str + timeout_after: datetime | None = None + type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger") diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py new file mode 100644 index 0000000000000..bc42444873a53 --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/types.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Type aliases for Workloads.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.models.callback import CallbackKey, ExecutorCallback + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.state import CallbackState, TaskInstanceState + + # Type aliases for workload keys and states (used by executor layer) + WorkloadKey = TaskInstanceKey | CallbackKey + WorkloadState = TaskInstanceState | CallbackState + + # Type alias for scheduler workloads (ORM models that can be routed to executors) + SchedulerWorkload = TaskInstance | ExecutorCallback + + # Type alias for executor workload results (used by executor implementations) + WorkloadResultType = tuple[WorkloadKey, WorkloadState, Exception | None] diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index be8213c423fd6..5567e4763ea1d 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -45,6 +45,7 @@ from airflow._shared.timezones import timezone from airflow.configuration import conf from airflow.executors import workloads +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger @@ -687,9 +688,7 @@ def update_triggers(self, requested_trigger_ids: set[int]): ti_id=new_trigger_orm.task_instance.id, ) continue - ser_ti = workloads.TaskInstance.model_validate( - new_trigger_orm.task_instance, from_attributes=True - ) + ser_ti = TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, from_attributes=True) # When producing logs from TIs, include the job id producing the logs to disambiguate it. 
self.logger_cache[new_id] = TriggerLoggingFactory( log_path=f"{log_path}.trigger.{self.job.id}.log", diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index 5012f2d860673..b05d943347ebd 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations from datetime import datetime @@ -29,6 +30,8 @@ from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone +from airflow.executors.workloads import BaseWorkload +from airflow.executors.workloads.callback import CallbackFetchMethod from airflow.models import Base from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime from airflow.utils.state import CallbackState @@ -60,16 +63,6 @@ class CallbackType(str, Enum): DAG_PROCESSOR = "dag_processor" -class CallbackFetchMethod(str, Enum): - """Methods used to fetch callback at runtime.""" - - # For future use once Dag Processor callbacks (on_success_callback/on_failure_callback) get moved to executors - DAG_ATTRIBUTE = "dag_attribute" - - # For deadline callbacks since they import callbacks through the import path - IMPORT_PATH = "import_path" - - class CallbackDefinitionProtocol(Protocol): """Protocol for TaskSDK Callback definition.""" @@ -93,7 +86,7 @@ class ImportPathExecutorCallbackDefProtocol(ImportPathCallbackDefProtocol, Proto executor: str | None -class Callback(Base): +class Callback(Base, BaseWorkload): """Base class for callbacks.""" __tablename__ = "callback" @@ -159,6 +152,14 @@ def get_metric_info(self, status: CallbackState, result: Any) -> dict: return {"stat": name, "tags": tags} + def get_dag_id(self) -> str | None: + """Return the DAG ID for scheduler routing.""" + return self.data.get("dag_id") + + def 
get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.data.get("executor") + @staticmethod def create_from_sdk_def(callback_def: CallbackDefinitionProtocol, **kwargs) -> Callback: # Cannot check actual type using isinstance() because that would require SDK import diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 475cbd7ae68fb..8b12db483affa 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -70,6 +70,7 @@ from airflow._shared.timezones import timezone from airflow.assets.manager import asset_manager from airflow.configuration import conf +from airflow.executors.workloads import BaseWorkload from airflow.listeners.listener import get_listener_manager from airflow.models.asset import AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies @@ -406,7 +407,7 @@ def uuid7() -> UUID: return uuid6.uuid7() -class TaskInstance(Base, LoggingMixin): +class TaskInstance(Base, LoggingMixin, BaseWorkload): """ Task instances store the state of a task instance. 
@@ -802,6 +803,14 @@ def key(self) -> TaskInstanceKey: """Returns a tuple that identifies the task instance uniquely.""" return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) + def get_dag_id(self) -> str: + """Return the DAG ID for scheduler routing.""" + return self.dag_id + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + @provide_session def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: """ diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 9360ffc58c56f..fa0f311d018fe 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -35,7 +35,8 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor -from airflow.executors.workloads import Callback, execute_callback_workload +from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.callback import CallbackDTO, execute_callback_workload from airflow.models.callback import CallbackFetchMethod from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator @@ -591,7 +592,7 @@ def test_local_executor_supports_callbacks_true(self): @pytest.mark.db_test def test_queue_callback_without_support_raises_error(self, dag_maker, session): executor = BaseExecutor() # supports_callbacks = False by default - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, @@ -599,7 +600,7 @@ def test_queue_callback_without_support_raises_error(self, dag_maker, session): callback_workload = 
workloads.ExecuteCallback( callback=callback_data, dag_rel_path="test.py", - bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + bundle_info=BundleInfo(name="test_bundle", version="1.0"), token="test_token", log_path="test.log", ) @@ -611,7 +612,7 @@ def test_queue_callback_without_support_raises_error(self, dag_maker, session): def test_queue_workload_with_execute_callback(self, dag_maker, session): executor = BaseExecutor() executor.supports_callbacks = True # Enable for this test - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, @@ -619,7 +620,7 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session): callback_workload = workloads.ExecuteCallback( callback=callback_data, dag_rel_path="test.py", - bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + bundle_info=BundleInfo(name="test_bundle", version="1.0"), token="test_token", log_path="test.log", ) @@ -634,7 +635,7 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): executor = BaseExecutor() executor.supports_callbacks = True # Enable for this test dagrun = setup_dagrun(dag_maker) - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, @@ -642,7 +643,7 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): callback_workload = workloads.ExecuteCallback( callback=callback_data, dag_rel_path="test.py", - bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + bundle_info=BundleInfo(name="test_bundle", version="1.0"), token="test_token", log_path="test.log", ) @@ -661,7 +662,7 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): class TestExecuteCallbackWorkload: def test_execute_function_callback_success(self): - 
callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ @@ -677,7 +678,7 @@ def test_execute_function_callback_success(self): assert error is None def test_execute_callback_missing_path(self): - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"kwargs": {}}, # Missing 'path' @@ -690,7 +691,7 @@ def test_execute_callback_missing_path(self): assert "Callback path not found" in error def test_execute_callback_import_error(self): - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ @@ -707,7 +708,7 @@ def test_execute_callback_import_error(self): def test_execute_callback_execution_error(self): # Use a function that will raise an error; len() requires an argument - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={ diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 265aff6384492..34e8f818aa94c 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -29,7 +29,9 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads from airflow.executors.local_executor import LocalExecutor, _execute_work -from airflow.executors.workloads import Callback +from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.callback import CallbackDTO +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models.callback import CallbackFetchMethod from airflow.settings import Session from airflow.utils.state import State @@ -83,7 +85,7 @@ def 
test_executor_worker_spawned(self, mock_freeze, mock_unfreeze): @mock.patch("airflow.sdk.execution_time.supervisor.supervise") def test_execution(self, mock_supervise): success_tis = [ - workloads.TaskInstance( + TaskInstanceDTO( id=uuid7(), dag_version_id=uuid7(), task_id=f"success_{i}", @@ -337,12 +339,12 @@ def test_supports_callbacks_flag_is_true(self): assert executor.supports_callbacks is True @skip_spawn_mp_start - @mock.patch("airflow.executors.workloads.execute_callback_workload") + @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") def test_process_callback_workload(self, mock_execute_callback): mock_execute_callback.return_value = (True, None) executor = LocalExecutor(parallelism=1) - callback_data = Callback( + callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, @@ -350,7 +352,7 @@ def test_process_callback_workload(self, mock_execute_callback): callback_workload = workloads.ExecuteCallback( callback=callback_data, dag_rel_path="test.py", - bundle_info=workloads.BundleInfo(name="test_bundle", version="1.0"), + bundle_info=BundleInfo(name="test_bundle", version="1.0"), token="test_token", log_path="test.log", ) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index c0e19043ea1ce..376c9ed96b204 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -157,37 +157,6 @@ def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None: self._send_tasks(task_tuples_to_send) - from airflow.executors.workloads import ExecuteCallback, ExecuteTask - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.celery.executors.celery_executor_utils import execute_workload ->>>>>>> 
5f63b2e16c (CI fixes and minor type-related cleanup) - - tasks: list[TaskInstanceInCelery] = [] - for workload in workloads: - if isinstance(workload, ExecuteTask): -<<<<<<< HEAD - tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) - elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): - tasks.append((workload.ti.key, workload, workload.ti.queue, execute_workload)) - elif isinstance(workload, ExecuteCallback): - # For callbacks, use a synthetic key based on callback ID - callback_key = TaskInstanceKey( - dag_id="callback", - task_id=workload.callback.id, - run_id="callback", - try_number=1, - map_index=-1, - ) ->>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) - # Use default queue for callbacks, or extract from callback data if available - queue = "default" - if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: - queue = workload.callback.data["queue"] - tasks.append((workload.callback.key, workload, queue, self.team_name)) - else: - raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") - - self._send_tasks(tasks) def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: # Airflow V3 version -- have to delay imports until we know we are on v3 from airflow.executors.workloads import ExecuteTask @@ -209,39 +178,6 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") self._send_tasks(tasks) -======= - from airflow.executors.workloads import ExecuteCallback, ExecuteTask - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.celery.executors.celery_executor_utils import execute_workload ->>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) - - tasks: list[TaskInstanceInCelery] = [] - for workload in workloads: - if isinstance(workload, ExecuteTask): -<<<<<<< HEAD - tasks.append((workload.ti.key, 
workload, workload.ti.queue, self.team_name)) - elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback): -======= - tasks.append((workload.ti.key, workload, workload.ti.queue, execute_workload)) - elif isinstance(workload, ExecuteCallback): - # For callbacks, use a synthetic key based on callback ID - callback_key = TaskInstanceKey( - dag_id="callback", - task_id=workload.callback.id, - run_id="callback", - try_number=1, - map_index=-1, - ) ->>>>>>> 5f63b2e16c (CI fixes and minor type-related cleanup) - # Use default queue for callbacks, or extract from callback data if available - queue = "default" - if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: - queue = workload.callback.data["queue"] - tasks.append((workload.callback.key, workload, queue, self.team_name)) - else: - raise ValueError(f"{type(self)}._process_workloads cannot handle {type(workload)}") - - self._send_tasks(tasks) def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): # Celery state queries will be stuck if we do not use one same backend diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 3b055dd0b4d66..da8ee15571b10 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -42,6 +42,7 @@ from airflow._shared.timezones import timezone from airflow.configuration import conf from airflow.executors import workloads +from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, TaskInstanceKey @@ -197,7 +198,7 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel="info"): - 
ti_success = workloads.TaskInstance.model_construct( + ti_success = TaskInstanceDTO.model_construct( id=uuid7(), task_id="success", dag_id="id", @@ -257,7 +258,7 @@ def test_error_sending_task(self): else: ti = TaskInstance(task=task, run_id="abc") workload = workloads.ExecuteTask.model_construct( - ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), + ti=TaskInstanceDTO.model_validate(ti, from_attributes=True), ) key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) @@ -309,7 +310,7 @@ def test_retry_on_error_sending_task(self, caplog): else: ti = TaskInstance(task=task, run_id="abc") workload = workloads.ExecuteTask.model_construct( - ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), + ti=TaskInstanceDTO.model_validate(ti, from_attributes=True), ) key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 293d8700ec725..94feded9ff924 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1005,7 +1005,7 @@ components: - type: 'null' title: Log Path ti: - $ref: '#/components/schemas/TaskInstance' + $ref: '#/components/schemas/TaskInstanceDTO' sentry_integration: type: string title: Sentry Integration @@ -1151,7 +1151,7 @@ components: - log_chunk_data title: PushLogsBody description: Incremental new log content from worker. - TaskInstance: + TaskInstanceDTO: properties: id: type: string @@ -1209,7 +1209,7 @@ components: - pool_slots - queue - priority_weight - title: TaskInstance + title: TaskInstanceDTO description: Schema for TaskInstance with minimal required fields needed for Executors and Task SDK. 
TaskInstanceState: From 09679f61a9f5c28a9a6ff1f2ab88f867dd431eaf Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 4 Feb 2026 17:44:24 -0800 Subject: [PATCH 11/32] generalize _executor_to_tis and reuse it for all workload types --- .../src/airflow/executors/workloads/types.py | 21 ++-- .../src/airflow/jobs/scheduler_job_runner.py | 112 ++++++++---------- 2 files changed, 60 insertions(+), 73 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py index bc42444873a53..31cda7028466f 100644 --- a/airflow-core/src/airflow/executors/workloads/types.py +++ b/airflow-core/src/airflow/executors/workloads/types.py @@ -18,20 +18,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias + +from airflow.models.callback import ExecutorCallback +from airflow.models.taskinstance import TaskInstance if TYPE_CHECKING: - from airflow.models.callback import CallbackKey, ExecutorCallback - from airflow.models.taskinstance import TaskInstance + from airflow.models.callback import CallbackKey from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.state import CallbackState, TaskInstanceState # Type aliases for workload keys and states (used by executor layer) - WorkloadKey = TaskInstanceKey | CallbackKey - WorkloadState = TaskInstanceState | CallbackState - - # Type alias for scheduler workloads (ORM models that can be routed to executors) - SchedulerWorkload = TaskInstance | ExecutorCallback + WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey + WorkloadState: TypeAlias = TaskInstanceState | CallbackState # Type alias for executor workload results (used by executor implementations) - WorkloadResultType = tuple[WorkloadKey, WorkloadState, Exception | None] + WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState, Exception | None] + +# Type alias for scheduler workloads (ORM models that can be routed to 
executors) +# Must be outside TYPE_CHECKING for use in function signatures +SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 44ee915621f44..b83a4357b86b5 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -131,6 +131,7 @@ from airflow._shared.logging.types import Logger from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName + from airflow.executors.workloads.types import SchedulerWorkload from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -981,7 +982,7 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int: queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) # Sort queued TIs to their respective executor - executor_to_queued_tis = self._executor_to_tis(queued_tis, session) + executor_to_queued_tis = self._executor_to_workloads(queued_tis, session) for executor, queued_tis_per_executor in executor_to_queued_tis.items(): self.log.info( "Trying to enqueue tasks: %s for executor: %s", @@ -1020,38 +1021,15 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: if not queued_callbacks: return - executor_to_callbacks: dict[BaseExecutor, list[ExecutorCallback]] = defaultdict(list) - - for callback in queued_callbacks: - executor_name = None - if isinstance(callback.data, dict): - executor_name = callback.data.get("executor") - - executor = None - if executor_name: - for e in self.job.executors: - if e.__class__.__name__ == executor_name: - executor = e - break - if hasattr(e, "name") and e.name and str(e.name) == executor_name: - executor = e - break - if hasattr(e, "executor_name") and e.executor_name == executor_name: - executor = e - break - # Default to 
first executor if no specific executor found - if executor is None: - executor = self.job.executors[0] if self.job.executors else None - - if executor is None: - self.log.warning("No executor available for callback %s", callback.id) - continue - - executor_to_callbacks[executor].append(callback) + # Route callbacks to executors using the generalized routing method + executor_to_callbacks = self._executor_to_workloads(queued_callbacks, session) # Enqueue callbacks for each executor for executor, callbacks in executor_to_callbacks.items(): for callback in callbacks: + if not isinstance(callback, ExecutorCallback): + # Can't happen since we queried ExecutorCallback, but satisfies mypy. + continue dag_run = None if isinstance(callback.data, dict) and "dag_run_id" in callback.data: dag_run_id = callback.data["dag_run_id"] @@ -1150,11 +1128,11 @@ def process_executor_events( ti_primary_key_to_try_number_map[key.primary] = key.try_number cls.logger().info("Received executor event with state %s for task instance %s", state, key) if state in ( - TaskInstanceState.FAILED, - TaskInstanceState.SUCCESS, - TaskInstanceState.QUEUED, - TaskInstanceState.RUNNING, - TaskInstanceState.RESTARTING, + TaskInstanceState.FAILED, + TaskInstanceState.SUCCESS, + TaskInstanceState.QUEUED, + TaskInstanceState.RUNNING, + TaskInstanceState.RESTARTING, ): tis_with_right_state.append(key) else: @@ -2547,7 +2525,7 @@ def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None: scheduled) up to 2 times before failing the task. 
""" tasks_stuck_in_queued = self._get_tis_stuck_in_queued(session) - for executor, stuck_tis in self._executor_to_tis(tasks_stuck_in_queued, session).items(): + for executor, stuck_tis in self._executor_to_workloads(tasks_stuck_in_queued, session).items(): try: for ti in stuck_tis: executor.revoke_task(ti=ti) @@ -2838,7 +2816,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int: ) to_reset: list[TaskInstance] = [] - exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset, session) + exec_to_tis = self._executor_to_workloads(tis_to_adopt_or_reset, session) for executor, tis in exec_to_tis.items(): to_reset.extend(executor.try_adopt_task_instances(tis)) @@ -3187,50 +3165,54 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: session.add(warning) existing_warned_dag_ids.add(warning.dag_id) - def _executor_to_tis( + def _executor_to_workloads( self, - tis: Iterable[TaskInstance], + workloads: Iterable[SchedulerWorkload], session, dag_id_to_team_name: dict[str, str | None] | None = None, - ) -> dict[BaseExecutor, list[TaskInstance]]: - """Organize TIs into lists per their respective executor.""" - tis_iter: Iterable[TaskInstance] + ) -> dict[BaseExecutor, list[SchedulerWorkload]]: + """Organize workloads into lists per their respective executor.""" + workloads_iter: Iterable[SchedulerWorkload] if conf.getboolean("core", "multi_team"): if dag_id_to_team_name is None: - if isinstance(tis, list): - tis_list = tis + if isinstance(workloads, list): + workloads_list = workloads else: - tis_list = list(tis) - if tis_list: + workloads_list = list(workloads) + if workloads_list: dag_id_to_team_name = self._get_team_names_for_dag_ids( - {ti.dag_id for ti in tis_list}, session + {dag_id for workload in workloads_list if + (dag_id := workload.get_dag_id()) is not None}, + session, ) else: dag_id_to_team_name = {} - tis_iter = tis_list + workloads_iter = workloads_list else: - tis_iter = tis + workloads_iter = workloads else: 
dag_id_to_team_name = {} - tis_iter = tis + workloads_iter = workloads - _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list) - for ti in tis_iter: + _executor_to_workloads: defaultdict[BaseExecutor, list[SchedulerWorkload]] = defaultdict(list) + for workload in workloads_iter: if executor_obj := self._try_to_load_executor( - ti, session, team_name=dag_id_to_team_name.get(ti.dag_id, NOTSET) + workload, session, team_name=dag_id_to_team_name.get(workload.get_dag_id(), NOTSET) ): - _executor_to_tis[executor_obj].append(ti) + _executor_to_workloads[executor_obj].append(workload) - return _executor_to_tis + return _executor_to_workloads - def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> BaseExecutor | None: + def _try_to_load_executor( + self, workload: SchedulerWorkload, session, team_name=NOTSET + ) -> BaseExecutor | None: """ Try to load the given executor. In this context, we don't want to fail if the executor does not exist. Catch the exception and log to the user. - :param ti: TaskInstance to load executor for + :param workload: SchedulerWorkload (TaskInstance or ExecutorCallback) to load executor for :param session: Database session for queries :param team_name: Optional pre-resolved team name. If NOTSET and multi-team is enabled, will query the database to resolve team name. None indicates global team. 
@@ -3239,12 +3221,12 @@ def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> if conf.getboolean("core", "multi_team"): # Use provided team_name if available, otherwise query the database if team_name is NOTSET: - team_name = self._get_task_team_name(ti, session) + team_name = self._get_task_team_name(workload, session) else: team_name = None - # Firstly, check if there is no executor set on the TaskInstance, if not, we need to fetch the default + # Firstly, check if there is no executor set on the workload, if not, we need to fetch the default # (either globally or for the team) - if ti.executor is None: + if workload.executor is None: if not team_name: # No team is specified, so just use the global default executor executor = self.executor @@ -3259,22 +3241,24 @@ def _try_to_load_executor(self, ti: TaskInstance, session, team_name=NOTSET) -> # No executor found for that team, fall back to global default executor = self.executor else: - # An executor is specified on the TaskInstance (as a str), so we need to find it in the list of executors - for _executor in self.executors: - if _executor.name and ti.executor in (_executor.name.alias, _executor.name.module_path): + # An executor is specified on the workload (as a str), so we need to find it in the list of executors + for _executor in self.job.executors: + if _executor.name and workload.executor in (_executor.name.alias, _executor.name.module_path): # The executor must either match the team or be global (i.e. 
team_name is None) if team_name and _executor.team_name == team_name or _executor.team_name is None: executor = _executor if executor is not None: - self.log.debug("Found executor %s for task %s (team: %s)", executor.name, ti, team_name) + self.log.debug("Found executor %s for workload %s (team: %s)", executor.name, workload, team_name) else: # This case should not happen unless some (as of now unknown) edge case occurs or direct DB # modification, since the DAG parser will validate the tasks in the DAG and ensure the executor # they request is available and if not, disallow the DAG to be scheduled. # Keeping this exception handling because this is a critical issue if we do somehow find # ourselves here and the user should get some feedback about that. - self.log.warning("Executor, %s, was not found but a Task was configured to use it", ti.executor) + self.log.warning( + "Executor, %s, was not found but a workload was configured to use it", workload.executor + ) return executor From e637ccdd525420fc60bc8c4723080209f6a494f5 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 5 Feb 2026 23:32:36 -0800 Subject: [PATCH 12/32] celery fixes --- .../airflow/executors/workloads/callback.py | 7 +++- .../celery/executors/celery_executor.py | 33 ++++++++---------- .../celery/executors/celery_executor_utils.py | 34 ++++++++++++------- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index d8e33f08def57..2073e716b016f 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.models import DagRun - from airflow.models.callback import Callback as CallbackModel + from airflow.models.callback import Callback as CallbackModel, CallbackKey log = structlog.get_logger(__name__) @@ 
-53,6 +53,11 @@ class CallbackDTO(BaseModel): fetch_method: CallbackFetchMethod data: dict + @property + def key(self) -> CallbackKey: + """Return callback ID as key (CallbackKey = str).""" + return self.id + class ExecuteCallback(BaseDagBundleWorkload): """Execute the given Callback.""" diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 376c9ed96b204..60c844cd583f8 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -56,9 +56,10 @@ from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery, TaskTuple + from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery # PEP562 @@ -98,10 +99,14 @@ class CeleryExecutor(BaseExecutor): supports_sentry: bool = True supports_multi_team: bool = True - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. 
- # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + if TYPE_CHECKING: + if AIRFLOW_V_3_2_PLUS: + # In v3.2+, callbacks are supported, so keys can be TaskInstanceKey OR str (CallbackKey) + queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] + elif AIRFLOW_V_3_0_PLUS: + # In v3.0-3.1, only tasks are supported (no callbacks yet) + # TODO: TaskSDK: move this type change into BaseExecutor + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment,no-redef] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -164,7 +169,7 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: if AIRFLOW_V_3_2_PLUS: from airflow.executors.workloads import ExecuteCallback - tasks: list[TaskInstanceInCelery] = [] + tasks: list[WorkloadInCelery] = [] for workload in workloads: if isinstance(workload, ExecuteTask): tasks.append((workload.ti.key, workload, workload.ti.queue, self.team_name)) @@ -179,7 +184,7 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: self._send_tasks(tasks) - def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): + def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]): # Celery state queries will be stuck if we do not use one same backend # for all tasks. 
cached_celery_backend = self.celery_app.backend @@ -218,7 +223,7 @@ def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): # which point we don't need the ID anymore anyway self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) - def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): + def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[WorkloadInCelery]): from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: @@ -386,20 +391,12 @@ def get_cli_commands() -> list[GroupCommand]: def queue_workload(self, workload: workloads.All, session: Session | None) -> None: from airflow.executors import workloads - from airflow.models.taskinstancekey import TaskInstanceKey if isinstance(workload, workloads.ExecuteTask): ti = workload.ti self.queued_tasks[ti.key] = workload elif isinstance(workload, workloads.ExecuteCallback): - # For callbacks, use a synthetic key based on callback ID - callback_key = TaskInstanceKey( - dag_id="callback", - task_id=workload.callback.id, - run_id="callback", - try_number=1, - map_index=-1, - ) - self.queued_tasks[callback_key] = workload + # Use workload.callback.key (CallbackKey = str) + self.queued_tasks[workload.callback.key] = workload else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index bfa07d155c85d..c92ac306556a5 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -41,7 +41,7 @@ from sqlalchemy import select from airflow.configuration import AirflowConfigParser, conf -from airflow.executors.base_executor import 
BaseExecutor +from airflow.executors.base_executor import BaseExecutor, EventBufferValueType from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, timeout from airflow.utils.log.logging_mixin import LoggingMixin @@ -53,6 +53,9 @@ except ImportError: from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager +if AIRFLOW_V_3_2_PLUS: + from airflow.executors.workloads.callback import execute_callback_workload + log = logging.getLogger(__name__) if sys.platform == "darwin": @@ -66,17 +69,22 @@ from celery.result import AsyncResult from airflow.executors import workloads - from airflow.executors.base_executor import EventBufferValueType, ExecutorConf + from airflow.executors.base_executor import ExecutorConf + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstanceKey # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define # the type as the union of both kinds CommandType = Sequence[str] - TaskInstanceInCelery: TypeAlias = tuple[ - TaskInstanceKey, workloads.All | CommandType, str | None, str | None + WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All | CommandType, str | None, str | None] + WorkloadInCeleryResult: TypeAlias = tuple[ + WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback" ] + # Deprecated alias for backward compatibility + TaskInstanceInCelery: TypeAlias = WorkloadInCelery + TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None] OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout") @@ -201,7 +209,7 @@ def execute_workload(input: str) -> None: log_path=workload.log_path, ) elif isinstance(workload, workloads.ExecuteCallback): - success, error_msg = workloads.execute_callback_workload(workload.callback, log) + success, error_msg = 
execute_callback_workload(workload.callback, log) if not success: raise RuntimeError(error_msg or "Callback execution failed") else: @@ -307,16 +315,16 @@ def __init__(self, exception: BaseException, exception_traceback: str): self.traceback = exception_traceback -def send_task_to_executor( - task_tuple: TaskInstanceInCelery, -) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]: +def send_workload_to_executor( + workload_tuple: WorkloadInCelery, +) -> WorkloadInCeleryResult: """ - Send task to executor. + Send workload to executor. This function is called in ProcessPoolExecutor subprocesses. To avoid pickling issues with team-specific Celery apps, we pass the team_name and reconstruct the Celery app here. """ - key, args, queue, team_name = task_tuple + key, args, queue, team_name = workload_tuple # Reconstruct the Celery app from configuration, which may or may not be team-specific. # ExecutorConf wraps config access to automatically use team-specific config where present. @@ -341,8 +349,6 @@ def send_task_to_executor( assert isinstance(args, workloads.BaseWorkload) args = (args.model_dump_json(),) else: - # Get the task from the app - task_to_run = celery_app.tasks["execute_command"] args = [args] # type: ignore[list-item] # Pre-import redis.client to avoid SIGALRM interrupting module initialization. @@ -366,6 +372,10 @@ def send_task_to_executor( return key, args, result +# Backward compatibility alias +send_task_to_executor = send_workload_to_executor + + def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: """ Fetch and return the state of the given celery task. 
From 6abf97a5de85d91b64ce6b5dd6735d75bf458e4e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 6 Feb 2026 07:48:38 -0800 Subject: [PATCH 13/32] fix bad merge --- .../src/airflow/executors/local_executor.py | 3 ++- .../airflow/executors/workloads/callback.py | 13 +++++++++-- .../src/airflow/jobs/scheduler_job_runner.py | 2 +- .../tests/unit/jobs/test_scheduler_job.py | 22 +++++++++---------- .../celery/executors/celery_executor_utils.py | 8 ++++--- .../executors/celery_kubernetes_executor.py | 14 +++++++----- .../executors/local_kubernetes_executor.py | 2 +- .../src/airflow/sdk/definitions/deadline.py | 3 --- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 0e91a522788e6..7acf24cb15805 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -37,6 +37,7 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor +from airflow.executors.workloads.callback import execute_callback_workload from airflow.utils.state import CallbackState, TaskInstanceState # add logger to parameter of setproctitle to support logging @@ -160,7 +161,7 @@ def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_con """ setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log) - success, error_msg = workloads.execute_callback_workload(workload.callback, log) + success, error_msg = execute_callback_workload(workload.callback, log) if not success: raise RuntimeError(error_msg or "Callback execution failed") diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 2073e716b016f..bac5b443178a8 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ 
b/airflow-core/src/airflow/executors/workloads/callback.py @@ -22,9 +22,10 @@ from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Literal +from uuid import UUID import structlog -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo @@ -53,6 +54,14 @@ class CallbackDTO(BaseModel): fetch_method: CallbackFetchMethod data: dict + @field_validator("id", mode="before") + @classmethod + def validate_id(cls, v): + """Convert UUID to str if needed.""" + if isinstance(v, UUID): + return str(v) + return v + @property def key(self) -> CallbackKey: """Return callback ID as key (CallbackKey = str).""" @@ -86,7 +95,7 @@ def make( return cls( callback=CallbackDTO.model_validate(callback, from_attributes=True), dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), - token=cls.generate_token(callback.id, generator), + token=cls.generate_token(str(callback.id), generator), log_path=fname, bundle_info=bundle_info, ) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index b83a4357b86b5..b5071abea71b2 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -86,7 +86,7 @@ ) from airflow.models.backfill import Backfill from airflow.models.callback import Callback, CallbackType, ExecutorCallback -from airflow.models.dag import DagModel, get_next_data_interval +from airflow.models.dag import DagModel from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag from airflow.models.dagbundle import DagBundleModel diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index ea9c51220ca6b..f883cbcd2cb60 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -1311,7 +1311,7 @@ def test_find_executable_task_instances_executor_with_teams(self, dag_maker, moc assert len(res) == 5 # Verify that each task is routed to the correct executor - executor_to_tis = self.job_runner._executor_to_tis(res, session) + executor_to_tis = self.job_runner._executor_to_workloads(res, session) # Team pi tasks should go to mock_executors[0] (configured for team_pi) a_tis_in_executor = [ti for ti in executor_to_tis.get(mock_executors[0], []) if ti.dag_id == "dag_a"] @@ -7909,7 +7909,7 @@ def test_multi_team_get_team_names_for_dag_ids_database_error(self, mock_log, da assert result == {} mock_log.exception.assert_called_once() - def test_multi_team_get_task_team_name_success(self, dag_maker, session): + def test_multi_team_get_workload_team_name_success(self, dag_maker, session): """Test successful team name resolution for a single task.""" clear_db_teams() clear_db_dag_bundles() @@ -7932,10 +7932,10 @@ def test_multi_team_get_task_team_name_success(self, dag_maker, session): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) assert result == "team_a" - def test_multi_team_get_task_team_name_no_team(self, dag_maker, session): + def test_multi_team_get_workload_team_name_no_team(self, dag_maker, session): """Test team resolution when no team is associated with the DAG.""" with dag_maker(dag_id="dag_no_team", session=session): task = EmptyOperator(task_id="task_no_team") @@ -7946,10 +7946,10 @@ def test_multi_team_get_task_team_name_no_team(self, dag_maker, session): scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) assert result is None - def 
test_multi_team_get_task_team_name_database_error(self, dag_maker, session): + def test_multi_team_get_workload_team_name_database_error(self, dag_maker, session): """Test graceful error handling when individual task team resolution fails. This code should _not_ fail the scheduler.""" with dag_maker(dag_id="dag_test", session=session): task = EmptyOperator(task_id="task_test") @@ -7962,7 +7962,7 @@ def test_multi_team_get_task_team_name_database_error(self, dag_maker, session): # Mock _get_team_names_for_dag_ids to return empty dict (simulates database error handling in that function) with mock.patch.object(self.job_runner, "_get_team_names_for_dag_ids", return_value={}) as mock_batch: - result = self.job_runner._get_task_team_name(ti, session) + result = self.job_runner._get_workload_team_name(ti, session) mock_batch.assert_called_once_with([ti.dag_id], session) # Should return None when batch function returns empty dict @@ -7980,7 +7980,7 @@ def test_multi_team_try_to_load_executor_multi_team_disabled(self, dag_maker, mo scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result = self.job_runner._try_to_load_executor(ti, session) # Should not call team resolution when multi_team is disabled mock_team_resolve.assert_not_called() @@ -8177,7 +8177,7 @@ def test_multi_team_try_to_load_executor_explicit_executor_team_mismatch( # Should log a warning when no executor is found mock_log.warning.assert_called_once_with( - "Executor, %s, was not found but a Task was configured to use it", "secondary_exec" + "Executor, %s, was not found but a workload was configured to use it", "secondary_exec" ) # Should return None since we failed to resolve an executor due to the mismatch. 
In practice, this @@ -8229,7 +8229,7 @@ def test_multi_team_try_to_load_executor_team_name_pre_resolved(self, dag_maker, self.job_runner = SchedulerJobRunner(job=scheduler_job) # Call with pre-resolved team name (as done in the scheduling loop) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result = self.job_runner._try_to_load_executor(ti, session, team_name="team_a") mock_team_resolve.assert_not_called() # We don't query for the team if it is pre-resolved @@ -8342,7 +8342,7 @@ def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_e scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) - with mock.patch.object(self.job_runner, "_get_task_team_name") as mock_team_resolve: + with mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_team_resolve: result1 = self.job_runner._try_to_load_executor(ti1, session) result2 = self.job_runner._try_to_load_executor(ti2, session) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index c92ac306556a5..8f486ce9945b3 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -41,7 +41,7 @@ from sqlalchemy import select from airflow.configuration import AirflowConfigParser, conf -from airflow.executors.base_executor import BaseExecutor, EventBufferValueType +from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, timeout from airflow.utils.log.logging_mixin import LoggingMixin @@ -69,7 +69,7 @@ from 
celery.result import AsyncResult from airflow.executors import workloads - from airflow.executors.base_executor import ExecutorConf + from airflow.executors.base_executor import EventBufferValueType, ExecutorConf from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstanceKey @@ -331,6 +331,7 @@ def send_workload_to_executor( if TYPE_CHECKING: _conf: ExecutorConf | AirflowConfigParser # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf + # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf if AIRFLOW_V_3_2_PLUS: from airflow.executors.base_executor import ExecutorConf @@ -338,7 +339,6 @@ def send_workload_to_executor( else: # Airflow <3.2 ExecutorConf doesn't exist (at least not with the required attributes), fall back to global conf _conf = conf - # Create the Celery app with the correct configuration celery_app = create_celery_app(_conf) @@ -349,6 +349,8 @@ def send_workload_to_executor( assert isinstance(args, workloads.BaseWorkload) args = (args.model_dump_json(),) else: + # Get the task from the app + task_to_run = celery_app.tasks["execute_command"] args = [args] # type: ignore[list-item] # Pre-import redis.client to avoid SIGALRM interrupting module initialization. 
diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index f2203ad746ee1..a8a4b3eb75eaa 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -34,7 +34,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType - from airflow.executors.workloads import WorkloadKey + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, @@ -101,11 +101,15 @@ def _task_event_logs(self): def _task_event_logs(self, value): """Not implemented for hybrid executors.""" - @property - def queued_tasks(self) -> dict[TaskInstanceKey, Any]: - """Return queued tasks from celery and kubernetes executor.""" + @property # type: ignore[override] + def queued_tasks(self) -> dict[TaskInstanceKey | str, Any]: + """ + Return queued tasks from celery and kubernetes executor. + + TODO: Union type used for compatibility. When AIRFLOW_V_3_0_PLUS is removed, change return type to WorkloadKey. 
+ """ queued_tasks = self.celery_executor.queued_tasks.copy() - queued_tasks.update(self.kubernetes_executor.queued_tasks) + queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] return queued_tasks diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 9209de0a7ac15..7ff79f2fae4bc 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -33,7 +33,7 @@ from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType from airflow.executors.local_executor import LocalExecutor - from airflow.executors.workloads import WorkloadKey + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index b591c538ad991..8c55e10d45c86 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -22,9 +22,6 @@ from airflow.models.deadline import DeadlineReferenceType, ReferenceModels from airflow.sdk.definitions.callback import AsyncCallback, Callback, SyncCallback -from airflow.sdk.serde import deserialize, serialize -from airflow.serialization.definitions.deadline import DeadlineAlertFields -from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding if TYPE_CHECKING: from collections.abc import Callable From a5c75337b02c7f695e32e40439c96a6469d265fd Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 6 Feb 2026 16:19:23 -0800 Subject: [PATCH 14/32] mypy and pydantic typing issues 
--- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index b5071abea71b2..e2c09d58498b5 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -3226,7 +3226,7 @@ def _try_to_load_executor( team_name = None # Firstly, check if there is no executor set on the workload, if not, we need to fetch the default # (either globally or for the team) - if workload.executor is None: + if workload.get_executor_name() is None: if not team_name: # No team is specified, so just use the global default executor executor = self.executor @@ -3243,7 +3243,7 @@ def _try_to_load_executor( else: # An executor is specified on the workload (as a str), so we need to find it in the list of executors for _executor in self.job.executors: - if _executor.name and workload.executor in (_executor.name.alias, _executor.name.module_path): + if workload.get_executor_name() in (_executor.name.alias, _executor.name.module_path): # The executor must either match the team or be global (i.e. team_name is None) if team_name and _executor.team_name == team_name or _executor.team_name is None: executor = _executor @@ -3257,7 +3257,8 @@ def _try_to_load_executor( # Keeping this exception handling because this is a critical issue if we do somehow find # ourselves here and the user should get some feedback about that. 
self.log.warning( - "Executor, %s, was not found but a workload was configured to use it", workload.executor + "Executor, %s, was not found but a workload was configured to use it", + workload.get_executor_name(), ) return executor From b33e30cef3774e34bd5d344fcbd824266493e880 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 18:45:50 -0800 Subject: [PATCH 15/32] rename BaseWorkloadSchema.token to BaseWorkloadSchema.identity_token --- airflow-core/src/airflow/executors/local_executor.py | 2 +- airflow-core/src/airflow/executors/workloads/base.py | 2 +- airflow-core/src/airflow/executors/workloads/callback.py | 2 +- airflow-core/src/airflow/executors/workloads/task.py | 2 +- airflow-core/tests/unit/executors/test_local_executor.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 7acf24cb15805..7dc1471a39be0 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -145,7 +145,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, - token=workload.token, + token=workload.identity_token, server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), log_path=workload.log_path, ) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 7e6bffc56b35f..98676adea19e1 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -67,7 +67,7 @@ class BundleInfo(BaseModel): class BaseWorkloadSchema(BaseModel): """Base Pydantic schema for executor workload DTOs.""" - token: str # The identity token for this workload. 
+ identity_token: str @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index bac5b443178a8..1615808e40d17 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -95,7 +95,7 @@ def make( return cls( callback=CallbackDTO.model_validate(callback, from_attributes=True), dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(callback.id), generator), + identity_token=cls.generate_token(str(callback.id), generator), log_path=fname, bundle_info=bundle_info, ) diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index 3620c08dff13d..10ac15514358c 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -97,7 +97,7 @@ def make( return cls( ti=ser_ti, dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), - token=cls.generate_token(str(ti.id), generator), + identity_token=cls.generate_token(str(ti.id), generator), log_path=fname, bundle_info=bundle_info, sentry_integration=sentry_integration, diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 34e8f818aa94c..23e2feaa56634 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -125,7 +125,7 @@ def fake_supervise(ti, **kwargs): for ti in success_tis: executor.queue_workload( workloads.ExecuteTask( - token="", + identity_token="", ti=ti, dag_rel_path="some/path", log_path=None, @@ -136,7 +136,7 @@ def fake_supervise(ti, **kwargs): executor.queue_workload( workloads.ExecuteTask( - token="", + identity_token="", 
ti=fail_ti, dag_rel_path="some/path", log_path=None, @@ -353,7 +353,7 @@ def test_process_callback_workload(self, mock_execute_callback): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - token="test_token", + identity_token="test_token", log_path="test.log", ) From 9959dd3be8f59bfb68afc8f8c8299648edae2f0a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 18:46:16 -0800 Subject: [PATCH 16/32] use correct state type in callbacks --- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index e2c09d58498b5..bac2de5a16cf4 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1138,7 +1138,7 @@ def process_executor_events( else: # Callback event (key is string UUID) cls.logger().info("Received executor event with state %s for callback %s", state, key) - if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS): + if state in (CallbackState.FAILED, CallbackState.SUCCESS): callback_keys_with_events.append(key) # Handle callback completion events @@ -1146,12 +1146,10 @@ def process_executor_events( state, info = event_buffer.pop(callback_id) callback = session.get(Callback, callback_id) if callback: - # Note: We receive TaskInstanceState from executor (SUCCESS/FAILED) but convert to CallbackState here. - # This is intentional - executor layer uses generic completion states, scheduler converts to proper types. 
- if state == TaskInstanceState.SUCCESS: + if state == CallbackState.SUCCESS: callback.state = CallbackState.SUCCESS cls.logger().info("Callback %s completed successfully", callback_id) - elif state == TaskInstanceState.FAILED: + elif state == CallbackState.FAILED: callback.state = CallbackState.FAILED callback.output = str(info) if info else "Execution failed" cls.logger().error("Callback %s failed: %s", callback_id, callback.output) From 15005629a93b3c764733381afc8f9cee77e53b80 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 18:56:20 -0800 Subject: [PATCH 17/32] revert changes to the deprecated executors --- .../executors/celery_kubernetes_executor.py | 21 ++++++++----------- .../executors/local_kubernetes_executor.py | 9 ++++---- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index a8a4b3eb75eaa..49ae5b35b6f52 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -34,7 +34,6 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType - from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, @@ -101,15 +100,11 @@ def _task_event_logs(self): def _task_event_logs(self, value): """Not implemented for hybrid executors.""" - @property # type: ignore[override] - def queued_tasks(self) -> dict[TaskInstanceKey | str, Any]: - """ - Return queued tasks from celery and kubernetes executor. - - TODO: Union type used for compatibility. When AIRFLOW_V_3_0_PLUS is removed, change return type to WorkloadKey. 
- """ + @property + def queued_tasks(self) -> dict[TaskInstanceKey, Any]: + """Return queued tasks from celery and kubernetes executor.""" queued_tasks = self.celery_executor.queued_tasks.copy() - queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] + queued_tasks.update(self.kubernetes_executor.queued_tasks) return queued_tasks @@ -118,8 +113,8 @@ def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" @property - def running(self) -> set[WorkloadKey]: - """Combine running from both executors.""" + def running(self) -> set[TaskInstanceKey]: + """Return running tasks from celery and kubernetes executor.""" return self.celery_executor.running.union(self.kubernetes_executor.running) @running.setter @@ -230,7 +225,9 @@ def heartbeat(self) -> None: self.celery_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer(self, dag_ids: list[str] | None = None) -> dict[WorkloadKey, EventBufferValueType]: + def get_event_buffer( + self, dag_ids: list[str] | None = None + ) -> dict[TaskInstanceKey, EventBufferValueType]: """ Return and flush the event buffer from celery and kubernetes executor. 
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 7ff79f2fae4bc..114da7ec36fe8 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -33,7 +33,6 @@ from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType from airflow.executors.local_executor import LocalExecutor - from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, @@ -110,8 +109,8 @@ def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" @property - def running(self) -> set[WorkloadKey]: - """Combine running from both executors.""" + def running(self) -> set[TaskInstanceKey]: + """Return running tasks from local and kubernetes executor.""" return self.local_executor.running.union(self.kubernetes_executor.running) @running.setter @@ -220,7 +219,9 @@ def heartbeat(self) -> None: self.local_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer(self, dag_ids: list[str] | None = None) -> dict[WorkloadKey, EventBufferValueType]: + def get_event_buffer( + self, dag_ids: list[str] | None = None + ) -> dict[TaskInstanceKey, EventBufferValueType]: """ Return and flush the event buffer from local and kubernetes executor. 
From 8e59873031ac3ff62068ccae9829bd675e0a8080 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 19:25:36 -0800 Subject: [PATCH 18/32] revert dropped TODO --- airflow-core/src/airflow/executors/workloads/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index 10ac15514358c..fe20e5a1afa2f 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -51,9 +51,9 @@ class TaskInstanceDTO(BaseModel): parent_context_carrier: dict | None = None context_carrier: dict | None = None + # TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across the codebase? @property def key(self) -> TaskInstanceKey: - """Return the TaskInstanceKey for this task instance.""" from airflow.models.taskinstancekey import TaskInstanceKey return TaskInstanceKey( From 1fdc461a41a2ac1e56df711b25d41f51caffb6ce Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 19:49:54 -0800 Subject: [PATCH 19/32] missed some identity_token renames in a test module --- airflow-core/tests/unit/executors/test_base_executor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index fa0f311d018fe..31c6663f98eca 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -601,7 +601,7 @@ def test_queue_callback_without_support_raises_error(self, dag_maker, session): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - token="test_token", + identity_token="test_token", log_path="test.log", ) @@ -621,7 +621,7 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session): callback=callback_data, 
dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - token="test_token", + identity_token="test_token", log_path="test.log", ) @@ -644,7 +644,7 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - token="test_token", + identity_token="test_token", log_path="test.log", ) executor.queue_workload(callback_workload, session) From a298fe07bac57f134bc8c10ef479dd6e55884b4b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 19:50:23 -0800 Subject: [PATCH 20/32] make better use of the CallbackKey type alias --- airflow-core/src/airflow/executors/base_executor.py | 5 +++-- airflow-core/src/airflow/models/callback.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index f0226091a9b27..b39e2e9910a5c 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -34,6 +34,7 @@ from airflow.executors.executor_loader import ExecutorLoader from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models import Log +from airflow.models.callback import CallbackKey from airflow.observability.metrics import stats_utils from airflow.observability.trace import Trace from airflow.utils.log.logging_mixin import LoggingMixin @@ -515,8 +516,8 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy self.event_buffer = {} else: for key in list(self.event_buffer.keys()): - # Include if it's a callback (string key) or if it's a task in the specified dags - if isinstance(key, str) or key.dag_id in dag_ids: + # Include if it's a callback or if it's a task in the specified dags + if isinstance(key, CallbackKey) or key.dag_id in dag_ids: cleared_events[key] = self.event_buffer.pop(key) return 
cleared_events diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index b05d943347ebd..7e5ee8b63ee69 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -36,14 +36,14 @@ from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime from airflow.utils.state import CallbackState +CallbackKey = str # Callback keys are str(UUID) + if TYPE_CHECKING: from sqlalchemy.orm import Session from airflow.callbacks.callback_requests import CallbackRequest from airflow.triggers.base import TriggerEvent - CallbackKey = str # Callback keys are str(UUID) - log = structlog.get_logger(__name__) From 5edaa1dda90c8e775522b31de25925d78e1f568b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 21:05:33 -0800 Subject: [PATCH 21/32] pr fixes --- .../src/airflow/jobs/scheduler_job_runner.py | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index bac2de5a16cf4..031b86224ad4c 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1030,22 +1030,28 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: if not isinstance(callback, ExecutorCallback): # Can't happen since we queried ExecutorCallback, but satisfies mypy. 
continue - dag_run = None - if isinstance(callback.data, dict) and "dag_run_id" in callback.data: - dag_run_id = callback.data["dag_run_id"] - dag_run = session.get(DagRun, dag_run_id) - elif isinstance(callback.data, dict) and "dag_id" in callback.data: - # Fallback: try to find the latest dag_run for the dag_id - dag_id = callback.data["dag_id"] - dag_run = session.scalars( - select(DagRun) - .where(DagRun.dag_id == dag_id) - .order_by(DagRun.execution_date.desc()) - .limit(1) - ).first() + + # TODO: Add dagrun_id as a proper ORM foreign key on the callback table instead of storing in data dict. + # This would eliminate this reconstruction step. For now, all ExecutorCallbacks + # are expected to have dag_run_id set in their data dict (e.g., by Deadline.handle_miss). + if not isinstance(callback.data, dict) or "dag_run_id" not in callback.data: + self.log.error( + "ExecutorCallback %s is missing required 'dag_run_id' in data dict. " + "This indicates a bug in callback creation. Skipping callback.", + callback.id + ) + continue + + dag_run_id = callback.data["dag_run_id"] + dag_run = session.get(DagRun, dag_run_id) if dag_run is None: - self.log.warning("Could not find DagRun for callback %s", callback.id) + self.log.warning( + "Could not find DagRun with id=%s for callback %s. 
" + "DagRun may have been deleted.", + dag_run_id, + callback.id + ) continue workload = workloads.ExecuteCallback.make( @@ -1145,17 +1151,20 @@ def process_executor_events( for callback_id in callback_keys_with_events: state, info = event_buffer.pop(callback_id) callback = session.get(Callback, callback_id) - if callback: - if state == CallbackState.SUCCESS: - callback.state = CallbackState.SUCCESS - cls.logger().info("Callback %s completed successfully", callback_id) - elif state == CallbackState.FAILED: - callback.state = CallbackState.FAILED - callback.output = str(info) if info else "Execution failed" - cls.logger().error("Callback %s failed: %s", callback_id, callback.output) - session.add(callback) - else: - cls.logger().warning("Callback %s not found in database", callback_id) + if not callback: + # This should not normally happen - we just received an event for this callback. + # Only possible if callback was deleted mid-execution (e.g., cascade delete from DagRun deletion). + cls.logger().warning("Callback %s not found in database (may have been cascade deleted)", callback_id) + continue + + if state == CallbackState.SUCCESS: + callback.state = CallbackState.SUCCESS + cls.logger().info("Callback %s completed successfully", callback_id) + elif state == CallbackState.FAILED: + callback.state = CallbackState.FAILED + callback.output = str(info) if info else "Execution failed" + cls.logger().error("Callback %s failed: %s", callback_id, callback.output) + session.add(callback) # Return if no finished tasks if not tis_with_right_state: From d0af94d2d141ad52f76e9c255da2b587a475c717 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 10 Feb 2026 23:01:50 -0800 Subject: [PATCH 22/32] fix the callback state lifecycle to match that of tasks --- .../src/airflow/executors/local_executor.py | 1 + .../src/airflow/jobs/scheduler_job_runner.py | 29 ++++++------ airflow-core/src/airflow/models/callback.py | 4 +- airflow-core/src/airflow/models/deadline.py | 3 +- 
airflow-core/src/airflow/utils/state.py | 1 + .../tests/unit/jobs/test_scheduler_job.py | 45 ++++++++++++++++++- .../tests/unit/models/test_callback.py | 8 ++-- 7 files changed, 70 insertions(+), 21 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 7dc1471a39be0..4102c0f3850a3 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -109,6 +109,7 @@ def _run_worker( output.put((workload.ti.key, TaskInstanceState.FAILED, e)) elif isinstance(workload, workloads.ExecuteCallback): + output.put((workload.callback.id, CallbackState.RUNNING, None)) try: _execute_callback(log, workload, team_conf) output.put((workload.callback.id, CallbackState.SUCCESS, None)) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 031b86224ad4c..7a5942e0fce9f 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1010,19 +1010,19 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: self.log.debug("No available slots for callbacks; all executors at capacity") return - queued_callbacks = session.scalars( + pending_callbacks = session.scalars( select(ExecutorCallback) .where(ExecutorCallback.type == CallbackType.EXECUTOR) - .where(ExecutorCallback.state == CallbackState.QUEUED) + .where(ExecutorCallback.state == CallbackState.PENDING) .order_by(ExecutorCallback.priority_weight.desc()) .limit(max_callbacks) ).all() - if not queued_callbacks: + if not pending_callbacks: return # Route callbacks to executors using the generalized routing method - executor_to_callbacks = self._executor_to_workloads(queued_callbacks, session) + executor_to_callbacks = self._executor_to_workloads(pending_callbacks, session) # Enqueue callbacks for each executor for executor, callbacks in 
executor_to_callbacks.items(): @@ -1030,7 +1030,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: if not isinstance(callback, ExecutorCallback): # Can't happen since we queried ExecutorCallback, but satisfies mypy. continue - + # TODO: Add dagrun_id as a proper ORM foreign key on the callback table instead of storing in data dict. # This would eliminate this reconstruction step. For now, all ExecutorCallbacks # are expected to have dag_run_id set in their data dict (e.g., by Deadline.handle_miss). @@ -1041,15 +1041,15 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: callback.id ) continue - + dag_run_id = callback.data["dag_run_id"] dag_run = session.get(DagRun, dag_run_id) - + if dag_run is None: self.log.warning( "Could not find DagRun with id=%s for callback %s. " - "DagRun may have been deleted.", - dag_run_id, + "DagRun may have been deleted.", + dag_run_id, callback.id ) continue @@ -1061,7 +1061,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: ) executor.queue_workload(workload, session=session) - callback.state = CallbackState.RUNNING + callback.state = CallbackState.QUEUED session.add(callback) @staticmethod @@ -1144,10 +1144,10 @@ def process_executor_events( else: # Callback event (key is string UUID) cls.logger().info("Received executor event with state %s for callback %s", state, key) - if state in (CallbackState.FAILED, CallbackState.SUCCESS): + if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS): callback_keys_with_events.append(key) - # Handle callback completion events + # Handle callback state events for callback_id in callback_keys_with_events: state, info = event_buffer.pop(callback_id) callback = session.get(Callback, callback_id) @@ -1157,7 +1157,10 @@ def process_executor_events( cls.logger().warning("Callback %s not found in database (may have been cascade deleted)", callback_id) continue - if state == CallbackState.SUCCESS: + if state == 
CallbackState.RUNNING: + callback.state = CallbackState.RUNNING + cls.logger().info("Callback %s is currently running", callback_id) + elif state == CallbackState.SUCCESS: callback.state = CallbackState.SUCCESS cls.logger().info("Callback %s completed successfully", callback_id) elif state == CallbackState.FAILED: diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index 7e5ee8b63ee69..117ea9c0a2d31 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -47,7 +47,7 @@ log = structlog.get_logger(__name__) -ACTIVE_STATES = frozenset((CallbackState.QUEUED, CallbackState.RUNNING)) +ACTIVE_STATES = frozenset((CallbackState.PENDING, CallbackState.QUEUED, CallbackState.RUNNING)) TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) @@ -130,7 +130,7 @@ def __init__(self, priority_weight: int = 1, prefix: str = "", **kwargs): :param prefix: Optional prefix for metric names :param kwargs: Additional data emitted in metric tags """ - self.state = CallbackState.PENDING + self.state = CallbackState.SCHEDULED self.priority_weight = priority_weight self.data = kwargs # kwargs can be used to include additional info in metric tags if prefix: diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 316d3eff56e36..723364aae32d3 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -41,6 +41,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name +from airflow.utils.state import CallbackState if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -264,7 +265,7 @@ def get_simple_context(): self.callback.data["dag_run_id"] = str(self.dagrun.id) self.callback.data["dag_id"] = self.dagrun.dag_id - self.callback.queue() + 
self.callback.state = CallbackState.PENDING session.add(self.callback) session.flush() diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index 926003d86b0d1..332efb105533d 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -23,6 +23,7 @@ class CallbackState(str, Enum): """All possible states of callbacks.""" + SCHEDULED = "scheduled" PENDING = "pending" QUEUED = "queued" RUNNING = "running" diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index f883cbcd2cb60..07c38aa84e44a 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -65,6 +65,7 @@ PartitionedAssetKeyLog, ) from airflow.models.backfill import Backfill, _create_backfill +from airflow.models.callback import ExecutorCallback from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel @@ -92,7 +93,7 @@ from airflow.timetables.base import DagRunInfo, DataInterval from airflow.utils.session import create_session, provide_session from airflow.utils.span_status import SpanStatus -from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -556,6 +557,48 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta any_order=True, ) + def test_enqueue_executor_callbacks_only_selects_pending_state(self, dag_maker, session): + def test_callback(): + pass + + def create_callback_in_state(state: CallbackState): + callback = Deadline( + deadline_time=timezone.utcnow(), + callback=SyncCallback(test_callback), + 
dagrun_id=dag_run.id, + deadline_alert_id=None, + ).callback + callback.state = state + callback.data["dag_run_id"] = dag_run.id + callback.data["dag_id"] = dag_run.dag_id + return callback + + with dag_maker(dag_id="test_callback_states"): + pass + dag_run = dag_maker.create_dagrun() + + scheduled_callback = create_callback_in_state(CallbackState.SCHEDULED) + pending_callback = create_callback_in_state(CallbackState.PENDING) + queued_callback = create_callback_in_state(CallbackState.QUEUED) + running_callback = create_callback_in_state(CallbackState.RUNNING) + session.add_all([scheduled_callback, pending_callback, queued_callback, running_callback]) + session.flush() + + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + # Verify initial state before calling _enqueue_executor_callbacks + assert session.get(ExecutorCallback, pending_callback.id).state == CallbackState.PENDING + + self.job_runner._enqueue_executor_callbacks(session) + # PENDING should progress to QUEUED after _enqueue_executor_callbacks + assert session.get(ExecutorCallback, pending_callback.id).state == CallbackState.QUEUED + + # Other callbacks should remain in their original states + assert session.get(ExecutorCallback, scheduled_callback.id).state == CallbackState.SCHEDULED + assert session.get(ExecutorCallback, queued_callback.id).state == CallbackState.QUEUED + assert session.get(ExecutorCallback, running_callback.id).state == CallbackState.RUNNING + @mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest") @mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr") def test_process_executor_events_with_callback( diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index dfc19fc61a354..6ab6ad2d02df7 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -123,7 +123,7 @@ def test_polymorphic_serde(self, session): assert 
isinstance(retrieved, TriggererCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_ASYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING.value + assert retrieved.state == CallbackState.SCHEDULED.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None @@ -131,7 +131,7 @@ def test_polymorphic_serde(self, session): def test_queue(self, session): callback = TriggererCallback(TEST_ASYNC_CALLBACK) - assert callback.state == CallbackState.PENDING + assert callback.state == CallbackState.SCHEDULED assert callback.trigger is None callback.queue() @@ -193,7 +193,7 @@ def test_polymorphic_serde(self, session): assert isinstance(retrieved, ExecutorCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_SYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING.value + assert retrieved.state == CallbackState.SCHEDULED.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None @@ -201,7 +201,7 @@ def test_polymorphic_serde(self, session): def test_queue(self): callback = ExecutorCallback(TEST_SYNC_CALLBACK, fetch_method=CallbackFetchMethod.DAG_ATTRIBUTE) - assert callback.state == CallbackState.PENDING + assert callback.state == CallbackState.SCHEDULED callback.queue() assert callback.state == CallbackState.QUEUED From d901a8488121e92cc38142c6271883b68c6d2394 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 11 Feb 2026 00:31:42 -0800 Subject: [PATCH 23/32] static checks --- .../src/airflow/jobs/scheduler_job_runner.py | 21 ++++++++++--------- .../edge3/worker_api/v2-edge-generated.yaml | 6 +++--- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 7a5942e0fce9f..ea40e02e9af1f 100644 --- 
a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1030,7 +1030,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: if not isinstance(callback, ExecutorCallback): # Can't happen since we queried ExecutorCallback, but satisfies mypy. continue - + # TODO: Add dagrun_id as a proper ORM foreign key on the callback table instead of storing in data dict. # This would eliminate this reconstruction step. For now, all ExecutorCallbacks # are expected to have dag_run_id set in their data dict (e.g., by Deadline.handle_miss). @@ -1038,19 +1038,18 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: self.log.error( "ExecutorCallback %s is missing required 'dag_run_id' in data dict. " "This indicates a bug in callback creation. Skipping callback.", - callback.id + callback.id, ) continue - + dag_run_id = callback.data["dag_run_id"] dag_run = session.get(DagRun, dag_run_id) - + if dag_run is None: self.log.warning( - "Could not find DagRun with id=%s for callback %s. " - "DagRun may have been deleted.", - dag_run_id, - callback.id + "Could not find DagRun with id=%s for callback %s. DagRun may have been deleted.", + dag_run_id, + callback.id, ) continue @@ -1154,9 +1153,11 @@ def process_executor_events( if not callback: # This should not normally happen - we just received an event for this callback. # Only possible if callback was deleted mid-execution (e.g., cascade delete from DagRun deletion). 
- cls.logger().warning("Callback %s not found in database (may have been cascade deleted)", callback_id) + cls.logger().warning( + "Callback %s not found in database (may have been cascade deleted)", callback_id + ) continue - + if state == CallbackState.RUNNING: callback.state = CallbackState.RUNNING cls.logger().info("Callback %s is currently running", callback_id) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 94feded9ff924..0827a8819446f 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -990,9 +990,9 @@ components: description: Status of a Edge Worker instance. ExecuteTask: properties: - token: + identity_token: type: string - title: Token + title: Identity Token dag_rel_path: type: string format: path @@ -1017,7 +1017,7 @@ components: default: ExecuteTask type: object required: - - token + - identity_token - dag_rel_path - bundle_info - log_path From 279c2542f2809debf60daf6678a4e20980fda232 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 17 Feb 2026 19:04:32 -0800 Subject: [PATCH 24/32] post-rebase fixes --- .../src/airflow/executors/workloads/base.py | 9 ++- .../src/airflow/jobs/scheduler_job_runner.py | 61 +++++++++++-------- .../tests/unit/jobs/test_scheduler_job.py | 8 +-- .../celery/executors/celery_executor_utils.py | 2 +- .../sdk/execution_time/execute_workload.py | 2 +- 5 files changed, 48 insertions(+), 34 deletions(-) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 98676adea19e1..a2755503233b2 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -22,7 +22,7 @@ from abc import ABC from typing import TYPE_CHECKING -from pydantic import BaseModel 
+from pydantic import BaseModel, Field if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -67,7 +67,12 @@ class BundleInfo(BaseModel): class BaseWorkloadSchema(BaseModel): """Base Pydantic schema for executor workload DTOs.""" - identity_token: str + identity_token: str = Field(validation_alias="token") + + @property + def token(self) -> str: + """Backward compat alias for identity_token.""" + return self.identity_token @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index ea40e02e9af1f..c46bbfe389cfa 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -360,32 +360,35 @@ def _get_team_names_for_dag_ids( # Return dict with all None values to ensure graceful degradation return {} - def _get_task_team_name(self, task_instance: TaskInstance, session: Session) -> str | None: + def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) -> str | None: """ - Resolve team name for a task instance using the DAG > Bundle > Team relationship chain. + Resolve team name for a workload using the DAG > Bundle > Team relationship chain. 
- TaskInstance > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team + Workload > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team - :param task_instance: The TaskInstance to resolve team name for + :param workload: The Workload to resolve team name for :param session: Database session for queries :return: Team name if found or None """ # Use the batch query function with a single DAG ID - dag_id_to_team_name = self._get_team_names_for_dag_ids([task_instance.dag_id], session) - team_name = dag_id_to_team_name.get(task_instance.dag_id) + if dag_id := workload.get_dag_id(): + dag_id_to_team_name = self._get_team_names_for_dag_ids([dag_id], session) + team_name = dag_id_to_team_name.get(dag_id) + else: + team_name = None # mypy didn't like the implicit defaulting to None if team_name: self.log.debug( - "Resolved team name '%s' for task %s (dag_id=%s)", + "Resolved team name '%s' for workload %s (dag_id=%s)", team_name, - task_instance.task_id, - task_instance.dag_id, + workload, + dag_id, ) else: self.log.debug( - "No team found for task %s (dag_id=%s) - DAG may not have bundle or team association", - task_instance.task_id, - task_instance.dag_id, + "No team found for workload %s (dag_id=%s) - DAG may not have bundle or team association", + workload, + dag_id, ) return team_name @@ -1003,7 +1006,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None: :param session: The database session """ - num_occupied_slots = sum(executor.slots_occupied for executor in self.job.executors) + num_occupied_slots = sum(executor.slots_occupied for executor in self.executors) max_callbacks = conf.getint("core", "parallelism") - num_occupied_slots if max_callbacks <= 0: @@ -1133,11 +1136,11 @@ def process_executor_events( ti_primary_key_to_try_number_map[key.primary] = key.try_number cls.logger().info("Received executor event with state %s for task instance %s", state, key) if state in ( - TaskInstanceState.FAILED, - TaskInstanceState.SUCCESS, - 
TaskInstanceState.QUEUED, - TaskInstanceState.RUNNING, - TaskInstanceState.RESTARTING, + TaskInstanceState.FAILED, + TaskInstanceState.SUCCESS, + TaskInstanceState.QUEUED, + TaskInstanceState.RUNNING, + TaskInstanceState.RESTARTING, ): tis_with_right_state.append(key) else: @@ -3192,8 +3195,11 @@ def _executor_to_workloads( workloads_list = list(workloads) if workloads_list: dag_id_to_team_name = self._get_team_names_for_dag_ids( - {dag_id for workload in workloads_list if - (dag_id := workload.get_dag_id()) is not None}, + { + dag_id + for workload in workloads_list + if (dag_id := workload.get_dag_id()) is not None + }, session, ) else: @@ -3207,9 +3213,9 @@ def _executor_to_workloads( _executor_to_workloads: defaultdict[BaseExecutor, list[SchedulerWorkload]] = defaultdict(list) for workload in workloads_iter: - if executor_obj := self._try_to_load_executor( - workload, session, team_name=dag_id_to_team_name.get(workload.get_dag_id(), NOTSET) - ): + _dag_id = workload.get_dag_id() + _team = dag_id_to_team_name.get(_dag_id, NOTSET) if _dag_id else NOTSET + if executor_obj := self._try_to_load_executor(workload, session, team_name=_team): _executor_to_workloads[executor_obj].append(workload) return _executor_to_workloads @@ -3232,7 +3238,7 @@ def _try_to_load_executor( if conf.getboolean("core", "multi_team"): # Use provided team_name if available, otherwise query the database if team_name is NOTSET: - team_name = self._get_task_team_name(workload, session) + team_name = self._get_workload_team_name(workload, session) else: team_name = None # Firstly, check if there is no executor set on the workload, if not, we need to fetch the default @@ -3253,8 +3259,11 @@ def _try_to_load_executor( executor = self.executor else: # An executor is specified on the workload (as a str), so we need to find it in the list of executors - for _executor in self.job.executors: - if workload.get_executor_name() in (_executor.name.alias, _executor.name.module_path): + for _executor in 
self.executors: + if _executor.name and workload.get_executor_name() in ( + _executor.name.alias, + _executor.name.module_path, + ): # The executor must either match the team or be global (i.e. team_name is None) if team_name and _executor.team_name == team_name or _executor.team_name is None: executor = _executor diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 07c38aa84e44a..93fa1cebd8757 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -8363,13 +8363,13 @@ def test_multi_team_executor_to_tis_batch_optimization(self, dag_maker, mock_exe with ( assert_queries_count(1, session=session), - mock.patch.object(self.job_runner, "_get_task_team_name") as mock_single, + mock.patch.object(self.job_runner, "_get_workload_team_name") as mock_single, ): - executor_to_tis = self.job_runner._executor_to_tis([ti1, ti2], session) + executor_to_workloads = self.job_runner._executor_to_workloads([ti1, ti2], session) mock_single.assert_not_called() - assert executor_to_tis[mock_executors[0]] == [ti1] - assert executor_to_tis[mock_executors[1]] == [ti2] + assert executor_to_workloads[mock_executors[0]] == [ti1] + assert executor_to_workloads[mock_executors[1]] == [ti2] @conf_vars({("core", "multi_team"): "false"}) def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_executors, session): diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 8f486ce9945b3..a3631c49ad3ee 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -204,7 +204,7 @@ def execute_workload(input: str) -> None: ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, 
bundle_info=workload.bundle_info, - token=workload.token, + token=workload.identity_token, server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), log_path=workload.log_path, ) diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py index 410c676eeb913..cd9849a29ccf8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py +++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py @@ -68,7 +68,7 @@ def execute_workload(workload: ExecuteTask) -> None: ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, - token=workload.token, + token=workload.identity_token, server=server, log_path=workload.log_path, sentry_integration=workload.sentry_integration, From 82cf839c70765e8cd746dfee2bf0f97b56ad4bcd Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 18 Feb 2026 14:44:21 -0800 Subject: [PATCH 25/32] Can't make mypy happy without modifying the hybrid executors at least a little --- .../celery/executors/celery_kubernetes_executor.py | 12 ++++++------ .../executors/local_kubernetes_executor.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 49ae5b35b6f52..34cfb27e86a81 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -104,18 +104,18 @@ def _task_event_logs(self, value): def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from celery and kubernetes executor.""" queued_tasks = self.celery_executor.queued_tasks.copy() - queued_tasks.update(self.kubernetes_executor.queued_tasks) + 
queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] - return queued_tasks + return queued_tasks # type: ignore[return-value] @queued_tasks.setter def queued_tasks(self, value) -> None: """Not implemented for hybrid executors.""" - @property + @property # type: ignore[override] def running(self) -> set[TaskInstanceKey]: """Return running tasks from celery and kubernetes executor.""" - return self.celery_executor.running.union(self.kubernetes_executor.running) + return self.celery_executor.running.union(self.kubernetes_executor.running) # type: ignore[return-value, arg-type] @running.setter def running(self, value) -> None: @@ -225,7 +225,7 @@ def heartbeat(self) -> None: self.celery_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( + def get_event_buffer( # type: ignore[override] self, dag_ids: list[str] | None = None ) -> dict[TaskInstanceKey, EventBufferValueType]: """ @@ -237,7 +237,7 @@ def get_event_buffer( cleared_events_from_celery = self.celery_executor.get_event_buffer(dag_ids) cleared_events_from_kubernetes = self.kubernetes_executor.get_event_buffer(dag_ids) - return {**cleared_events_from_celery, **cleared_events_from_kubernetes} + return {**cleared_events_from_celery, **cleared_events_from_kubernetes} # type: ignore[dict-item] def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 114da7ec36fe8..f2eb64e46c588 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -108,10 +108,10 @@ def queued_tasks(self) -> dict[TaskInstanceKey, Any]: def queued_tasks(self, 
value) -> None: """Not implemented for hybrid executors.""" - @property + @property # type: ignore[override] def running(self) -> set[TaskInstanceKey]: """Return running tasks from local and kubernetes executor.""" - return self.local_executor.running.union(self.kubernetes_executor.running) + return self.local_executor.running.union(self.kubernetes_executor.running) # type: ignore[return-value, arg-type] @running.setter def running(self, value) -> None: @@ -219,7 +219,7 @@ def heartbeat(self) -> None: self.local_executor.heartbeat() self.kubernetes_executor.heartbeat() - def get_event_buffer( + def get_event_buffer( # type: ignore[override] self, dag_ids: list[str] | None = None ) -> dict[TaskInstanceKey, EventBufferValueType]: """ @@ -231,7 +231,7 @@ def get_event_buffer( cleared_events_from_local = self.local_executor.get_event_buffer(dag_ids) cleared_events_from_kubernetes = self.kubernetes_executor.get_event_buffer(dag_ids) - return {**cleared_events_from_local, **cleared_events_from_kubernetes} + return {**cleared_events_from_local, **cleared_events_from_kubernetes} # type: ignore[dict-item] def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ From ac4f31a0de2826967a7218e8202c15f726c1e72d Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 18 Feb 2026 17:21:46 -0800 Subject: [PATCH 26/32] pydantic fixes --- airflow-core/src/airflow/executors/workloads/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index a2755503233b2..b66c1ab72eb78 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -22,7 +22,7 @@ from abc import ABC from typing import TYPE_CHECKING -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import 
JWTGenerator @@ -67,6 +67,8 @@ class BundleInfo(BaseModel): class BaseWorkloadSchema(BaseModel): """Base Pydantic schema for executor workload DTOs.""" + model_config = ConfigDict(populate_by_name=True) + identity_token: str = Field(validation_alias="token") @property From fb3927a5df96ba0def0c3c28ad91a7805d873f75 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 20 Feb 2026 09:00:03 -0800 Subject: [PATCH 27/32] typo --- airflow-core/src/airflow/executors/workloads/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 1615808e40d17..197f71090ede9 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -140,7 +140,7 @@ def execute_callback_workload( log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) - # If the callback is a callabale, call it. If it is a class, instantiate it. + # If the callback is a callable, call it. If it is a class, instantiate it. result = callback_callable(**callback_kwargs) # If the callback is a class then it is now instantiated and callable, call it. 
From f8f5dc9fe8dc52599951189d39c59da5fe11a79c Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 20 Feb 2026 10:50:06 -0800 Subject: [PATCH 28/32] rebase and docstring tweaks --- .../src/airflow/executors/base_executor.py | 1 - .../src/airflow/jobs/scheduler_job_runner.py | 7 +++---- airflow-core/src/airflow/models/deadline.py | 14 -------------- .../celery/executors/celery_executor_utils.py | 1 - 4 files changed, 3 insertions(+), 20 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index b39e2e9910a5c..f6ba8983d9324 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -516,7 +516,6 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy self.event_buffer = {} else: for key in list(self.event_buffer.keys()): - # Include if it's a callback or if it's a task in the specified dags if isinstance(key, CallbackKey) or key.dag_id in dag_ids: cleared_events[key] = self.event_buffer.pop(key) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index c46bbfe389cfa..19059866bface 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -3241,14 +3241,13 @@ def _try_to_load_executor( team_name = self._get_workload_team_name(workload, session) else: team_name = None - # Firstly, check if there is no executor set on the workload, if not, we need to fetch the default - # (either globally or for the team) + # If there is no executor set on the workload fetch the default (either globally or for the team) if workload.get_executor_name() is None: if not team_name: - # No team is specified, so just use the global default executor + # No team is specified, use the global default executor executor = self.executor else: - # We do have a team, so we need to find the 
default executor for that team + # We do have a team, use the default executor for that team for _executor in self.executors: # First executor that resolves should be the default for that team if _executor.team_name == team_name: diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 723364aae32d3..debfe949b314c 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -21,7 +21,6 @@ from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from enum import Enum from typing import TYPE_CHECKING, Any, cast from uuid import UUID @@ -83,19 +82,6 @@ def __get__(self, instance, cls=None): return self.method(cls) -class DeadlineCallbackState(str, Enum): - """ - All possible states of deadline callbacks once the deadline is missed. - - `None` state implies that the deadline is pending (`deadline_time` hasn't passed yet). - """ - - QUEUED = "queued" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - - class Deadline(Base): """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed.""" diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index a3631c49ad3ee..10272552dbc23 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -331,7 +331,6 @@ def send_workload_to_executor( if TYPE_CHECKING: _conf: ExecutorConf | AirflowConfigParser # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf - # Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf if AIRFLOW_V_3_2_PLUS: from airflow.executors.base_executor import ExecutorConf From 02f47f2a1854cac36194515093ed3d9b519340fe 
Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 20 Feb 2026 14:38:29 -0800 Subject: [PATCH 29/32] fix celery callback queue --- .../celery/executors/celery_executor.py | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 60c844cd583f8..cd142acc7a182 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -52,11 +52,8 @@ if TYPE_CHECKING: from collections.abc import Sequence - from sqlalchemy.orm import Session - from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads - from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery @@ -100,13 +97,9 @@ class CeleryExecutor(BaseExecutor): supports_multi_team: bool = True if TYPE_CHECKING: - if AIRFLOW_V_3_2_PLUS: - # In v3.2+, callbacks are supported, so keys can be TaskInstanceKey OR str (CallbackKey) - queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] - elif AIRFLOW_V_3_0_PLUS: - # In v3.0-3.1, only tasks are supported (no callbacks yet) + if AIRFLOW_V_3_0_PLUS: # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment,no-redef] + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -208,7 +201,10 @@ def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]): ) self.task_publish_retries[key] = retries + 1 continue - self.queued_tasks.pop(key) + if key in 
self.queued_tasks: + self.queued_tasks.pop(key) + else: + self.queued_callbacks.pop(key, None) self.task_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) @@ -388,15 +384,3 @@ def get_cli_commands() -> list[GroupCommand]: from airflow.providers.celery.cli.definition import get_celery_cli_commands return get_celery_cli_commands() - - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if isinstance(workload, workloads.ExecuteTask): - ti = workload.ti - self.queued_tasks[ti.key] = workload - elif isinstance(workload, workloads.ExecuteCallback): - # Use workload.callback.key (CallbackKey = str) - self.queued_tasks[workload.callback.key] = workload - else: - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") From 0673ea566760c149a5c0c6e9c5a32af455d65aee Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 23 Feb 2026 13:12:07 -0800 Subject: [PATCH 30/32] revert renaming token to identity_token It was a much more invasive change than intended and may be considered breaking. We can do that later if we decide we want it. 
--- airflow-core/src/airflow/executors/local_executor.py | 2 +- airflow-core/src/airflow/executors/workloads/base.py | 10 +++------- .../src/airflow/executors/workloads/callback.py | 2 +- airflow-core/src/airflow/executors/workloads/task.py | 2 +- .../tests/unit/executors/test_base_executor.py | 6 +++--- .../tests/unit/executors/test_local_executor.py | 6 +++--- .../celery/executors/celery_executor_utils.py | 2 +- .../providers/edge3/worker_api/v2-edge-generated.yaml | 6 +++--- .../src/airflow/sdk/execution_time/execute_workload.py | 2 +- 9 files changed, 17 insertions(+), 21 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 4102c0f3850a3..9b5939a0bd2e7 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -146,7 +146,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, - token=workload.identity_token, + token=workload.token, server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), log_path=workload.log_path, ) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index b66c1ab72eb78..cf622209d67ba 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -22,7 +22,7 @@ from abc import ABC from typing import TYPE_CHECKING -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -69,12 +69,8 @@ class BaseWorkloadSchema(BaseModel): model_config = ConfigDict(populate_by_name=True) - identity_token: str = Field(validation_alias="token") - - @property - def token(self) -> str: - 
"""Backward compat alias for identity_token.""" - return self.identity_token + token: str + """The identity token for this workload""" @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 197f71090ede9..c15bb33fba70e 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -95,7 +95,7 @@ def make( return cls( callback=CallbackDTO.model_validate(callback, from_attributes=True), dag_rel_path=dag_rel_path or Path(dag_run.dag_model.relative_fileloc or ""), - identity_token=cls.generate_token(str(callback.id), generator), + token=cls.generate_token(str(callback.id), generator), log_path=fname, bundle_info=bundle_info, ) diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index fe20e5a1afa2f..d691dcb6f0968 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -97,7 +97,7 @@ def make( return cls( ti=ser_ti, dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or ""), - identity_token=cls.generate_token(str(ti.id), generator), + token=cls.generate_token(str(ti.id), generator), log_path=fname, bundle_info=bundle_info, sentry_integration=sentry_integration, diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 31c6663f98eca..fa0f311d018fe 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -601,7 +601,7 @@ def test_queue_callback_without_support_raises_error(self, dag_maker, session): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - 
identity_token="test_token", + token="test_token", log_path="test.log", ) @@ -621,7 +621,7 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - identity_token="test_token", + token="test_token", log_path="test.log", ) @@ -644,7 +644,7 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - identity_token="test_token", + token="test_token", log_path="test.log", ) executor.queue_workload(callback_workload, session) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 23e2feaa56634..34e8f818aa94c 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -125,7 +125,7 @@ def fake_supervise(ti, **kwargs): for ti in success_tis: executor.queue_workload( workloads.ExecuteTask( - identity_token="", + token="", ti=ti, dag_rel_path="some/path", log_path=None, @@ -136,7 +136,7 @@ def fake_supervise(ti, **kwargs): executor.queue_workload( workloads.ExecuteTask( - identity_token="", + token="", ti=fail_ti, dag_rel_path="some/path", log_path=None, @@ -353,7 +353,7 @@ def test_process_callback_workload(self, mock_execute_callback): callback=callback_data, dag_rel_path="test.py", bundle_info=BundleInfo(name="test_bundle", version="1.0"), - identity_token="test_token", + token="test_token", log_path="test.log", ) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 10272552dbc23..578d0a909acc1 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ 
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -204,7 +204,7 @@ def execute_workload(input: str) -> None: ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, - token=workload.identity_token, + token=workload.token, server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), log_path=workload.log_path, ) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 0827a8819446f..94feded9ff924 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -990,9 +990,9 @@ components: description: Status of a Edge Worker instance. ExecuteTask: properties: - identity_token: + token: type: string - title: Identity Token + title: Token dag_rel_path: type: string format: path @@ -1017,7 +1017,7 @@ components: default: ExecuteTask type: object required: - - identity_token + - token - dag_rel_path - bundle_info - log_path diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py index cd9849a29ccf8..410c676eeb913 100644 --- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py +++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py @@ -68,7 +68,7 @@ def execute_workload(workload: ExecuteTask) -> None: ti=workload.ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, - token=workload.identity_token, + token=workload.token, server=server, log_path=workload.log_path, sentry_integration=workload.sentry_integration, From 57c83ad65431c331c2145cb4b5a10248fd92c268 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 23 Feb 2026 16:56:54 -0800 Subject: [PATCH 31/32] whitespace fix --- 
airflow-core/src/airflow/models/callback.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index 117ea9c0a2d31..ea482ab7ba8d5 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations from datetime import datetime From e7622f1337ac9893078a7ce73d1c585bb182f036 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 27 Feb 2026 14:34:03 -0800 Subject: [PATCH 32/32] user-facing phrasing for Niko --- airflow-core/src/airflow/executors/base_executor.py | 5 ++++- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 10 ++++++---- airflow-core/tests/unit/jobs/test_scheduler_job.py | 3 ++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index f6ba8983d9324..2997d55d8bb3b 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -241,7 +241,10 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: ) self.queued_callbacks[workload.callback.id] = workload else: - raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") + raise ValueError( + f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. " + f"Workload must be one of: ExecuteTask, ExecuteCallback." 
+ ) def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: """ diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 19059866bface..d192798004211 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -379,14 +379,14 @@ def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) if team_name: self.log.debug( - "Resolved team name '%s' for workload %s (dag_id=%s)", + "Resolved team name '%s' for task or callback %s (dag_id=%s)", team_name, workload, dag_id, ) else: self.log.debug( - "No team found for workload %s (dag_id=%s) - DAG may not have bundle or team association", + "No team found for task or callback %s (dag_id=%s) - DAG may not have bundle or team association", workload, dag_id, ) @@ -3268,7 +3268,9 @@ def _try_to_load_executor( executor = _executor if executor is not None: - self.log.debug("Found executor %s for workload %s (team: %s)", executor.name, workload, team_name) + self.log.debug( + "Found executor %s for task or callback %s (team: %s)", executor.name, workload, team_name + ) else: # This case should not happen unless some (as of now unknown) edge case occurs or direct DB # modification, since the DAG parser will validate the tasks in the DAG and ensure the executor @@ -3276,7 +3278,7 @@ def _try_to_load_executor( # Keeping this exception handling because this is a critical issue if we do somehow find # ourselves here and the user should get some feedback about that. 
self.log.warning( - "Executor, %s, was not found but a workload was configured to use it", + "Executor, %s, was not found but a Task or Callback was configured to use it", workload.get_executor_name(), ) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 93fa1cebd8757..abbd5f20067b9 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -8220,7 +8220,8 @@ def test_multi_team_try_to_load_executor_explicit_executor_team_mismatch( # Should log a warning when no executor is found mock_log.warning.assert_called_once_with( - "Executor, %s, was not found but a workload was configured to use it", "secondary_exec" + "Executor, %s, was not found but a Task or Callback was configured to use it", + "secondary_exec", ) # Should return None since we failed to resolve an executor due to the mismatch. In practice, this