Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
08fbfa0
first pass implementation of executor support for sync callbacks
seanghaeli Nov 4, 2025
a2186c4
Synchronous callback support for BaseExecutor, LocalExecutor, and Cel…
ferruzzi Jan 27, 2026
7e17159
Create a unified WorkflowState and WorkflowKey; added mypy type hint …
ferruzzi Jan 28, 2026
e41698c
First batch Niko fixes
ferruzzi Jan 29, 2026
6de20c8
deovercommentification of hand-off notes
ferruzzi Jan 29, 2026
4722833
CI fixes and minor type-related cleanup
ferruzzi Jan 29, 2026
9719607
Niko suggestions round 2 and CI/prek fixes
ferruzzi Jan 30, 2026
8aee703
CI and MyPy fixes
ferruzzi Jan 30, 2026
587499a
CI and MyPy fixes
ferruzzi Jan 30, 2026
47e1d76
refactor workload file locations
ferruzzi Feb 4, 2026
09679f6
generalize _executor_to_tis and reuse it for all workload types
ferruzzi Feb 5, 2026
e637ccd
celery fixes
ferruzzi Feb 6, 2026
6abf97a
fix bad merge
ferruzzi Feb 6, 2026
a5c7533
mypy and pydantic typing issues
ferruzzi Feb 7, 2026
b33e30c
rename BaseWorkloadSchema.token to BaseWorkloadSchema.identity_token
ferruzzi Feb 11, 2026
9959dd3
use correct state type in callbacks
ferruzzi Feb 11, 2026
1500562
revert changes to the deprecated executors
ferruzzi Feb 11, 2026
8e59873
revert dropped TODO
ferruzzi Feb 11, 2026
1fdc461
missed some identity_token renames in a test module
ferruzzi Feb 11, 2026
a298fe0
make better use of the CallbackKey type alias
ferruzzi Feb 11, 2026
5edaa1d
pr fixes
ferruzzi Feb 11, 2026
d0af94d
fix the callback state lifecycle to match that of tasks
ferruzzi Feb 11, 2026
d901a84
static checks
ferruzzi Feb 11, 2026
279c254
post-rebase fixes
ferruzzi Feb 18, 2026
82cf839
Can't make mypy happy without modifying the hybrid executors at least…
ferruzzi Feb 18, 2026
ac4f31a
pydantic fixes
ferruzzi Feb 19, 2026
fb3927a
typo
ferruzzi Feb 20, 2026
f8f5dc9
rebase and docstring tweaks
ferruzzi Feb 20, 2026
02f47f2
fix celery callback queue
ferruzzi Feb 20, 2026
0673ea5
revert renaming token to identity_token
ferruzzi Feb 23, 2026
57c83ad
whitespace fix
ferruzzi Feb 24, 2026
e7622f1
user-facing phrasing for Niko
ferruzzi Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 78 additions & 29 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from airflow.configuration import conf
from airflow.executors import workloads
from airflow.executors.executor_loader import ExecutorLoader
from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.models import Log
from airflow.models.callback import CallbackKey
from airflow.observability.metrics import stats_utils
from airflow.observability.trace import Trace
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -52,6 +54,7 @@
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

Expand Down Expand Up @@ -143,6 +146,7 @@ class BaseExecutor(LoggingMixin):
active_spans = ThreadSafeDict()

supports_ad_hoc_ti_run: bool = False
supports_callbacks: bool = False
supports_multi_team: bool = False
sentry_integration: str = ""

Expand Down Expand Up @@ -186,8 +190,9 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
self.parallelism: int = parallelism
self.team_name: str | None = team_name
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
self.running: set[WorkloadKey] = set()
self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
self.conf = ExecutorConf(team_name)

Expand All @@ -203,7 +208,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
:meta private:
"""

self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
self.attempts: dict[WorkloadKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)

def __repr__(self):
_repr = f"{self.__class__.__name__}(parallelism={self.parallelism}"
Expand All @@ -224,10 +229,47 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))

def queue_workload(self, workload: workloads.All, session: Session) -> None:
if not isinstance(workload, workloads.ExecuteTask):
raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}")
ti = workload.ti
self.queued_tasks[ti.key] = workload
if isinstance(workload, workloads.ExecuteTask):
ti = workload.ti
self.queued_tasks[ti.key] = workload
elif isinstance(workload, workloads.ExecuteCallback):
if not self.supports_callbacks:
raise NotImplementedError(
f"{type(self).__name__} does not support ExecuteCallback workloads. "
f"Set supports_callbacks = True and implement callback handling in _process_workloads(). "
f"See LocalExecutor or CeleryExecutor for reference implementation."
)
self.queued_callbacks[workload.callback.id] = workload
else:
raise ValueError(
f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. "
f"Workload must be one of: ExecuteTask, ExecuteCallback."
)

def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]:
"""
Select and return the next batch of workloads to schedule, respecting priority policy.

Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work).
Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first).

:param open_slots: Number of available execution slots
"""
workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = []

if self.queued_callbacks:
for key, workload in self.queued_callbacks.items():
if len(workloads_to_schedule) >= open_slots:
break
workloads_to_schedule.append((key, workload))

if open_slots > len(workloads_to_schedule) and self.queued_tasks:
for task_key, task_workload in self.order_queued_tasks_by_priority():
if len(workloads_to_schedule) >= open_slots:
break
workloads_to_schedule.append((task_key, task_workload))

return workloads_to_schedule

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
"""
Expand Down Expand Up @@ -266,10 +308,10 @@ def heartbeat(self) -> None:
"""Heartbeat sent to trigger new jobs."""
open_slots = self.parallelism - len(self.running)

num_running_tasks = len(self.running)
num_queued_tasks = len(self.queued_tasks)
num_running_workloads = len(self.running)
num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks)

self._emit_metrics(open_slots, num_running_tasks, num_queued_tasks)
self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads)
self.trigger_tasks(open_slots)

# Calling child class sync method
Expand Down Expand Up @@ -350,16 +392,16 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workload

def trigger_tasks(self, open_slots: int) -> None:
"""
Initiate async execution of the queued tasks, up to the number of available slots.
Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots.

Callbacks are prioritized over tasks to complete existing work before starting new work.

:param open_slots: Number of open slots
"""
sorted_queue = self.order_queued_tasks_by_priority()
workloads_to_schedule = self._get_workloads_to_schedule(open_slots)
workload_list = []

for _ in range(min((open_slots, len(self.queued_tasks)))):
key, item = sorted_queue.pop()

for key, workload in workloads_to_schedule:
# If a task makes it here but is still understood by the executor
# to be running, it generally means that the task has been killed
# externally and not yet been marked as failed.
Expand All @@ -373,12 +415,12 @@ def trigger_tasks(self, open_slots: int) -> None:
if key in self.attempts:
del self.attempts[key]

if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"):
ti = item.ti
if isinstance(workload, workloads.ExecuteTask) and hasattr(workload, "ti"):
ti = workload.ti

# If it's None, then the span for the current id hasn't been started.
if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None:
if isinstance(ti, workloads.TaskInstance):
if isinstance(ti, TaskInstanceDTO):
parent_context = Trace.extract(ti.parent_context_carrier)
else:
parent_context = Trace.extract(ti.dag_run.context_carrier)
Expand All @@ -397,7 +439,8 @@ def trigger_tasks(self, open_slots: int) -> None:
carrier = Trace.inject()
ti.context_carrier = carrier

workload_list.append(item)
workload_list.append(workload)

if workload_list:
self._process_workloads(workload_list)

Expand Down Expand Up @@ -459,24 +502,25 @@ def running_state(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]:
"""
Return and flush the event buffer.

In case dag_ids is specified it will only return and flush events
for the given dag_ids. Otherwise, it returns and flushes all events.
Note: Callback events (with string keys) are always included regardless of dag_ids filter.

:param dag_ids: the dag_ids to return events for; returns all if given ``None``.
:return: a dict of events
"""
cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {}
cleared_events: dict[WorkloadKey, EventBufferValueType] = {}
if dag_ids is None:
cleared_events = self.event_buffer
self.event_buffer = {}
else:
for ti_key in list(self.event_buffer.keys()):
if ti_key.dag_id in dag_ids:
cleared_events[ti_key] = self.event_buffer.pop(ti_key)
for key in list(self.event_buffer.keys()):
if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
cleared_events[key] = self.event_buffer.pop(key)

return cleared_events

Expand Down Expand Up @@ -529,21 +573,26 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

@property
def slots_available(self):
"""Number of new tasks this executor instance can accept."""
return self.parallelism - len(self.running) - len(self.queued_tasks)
"""Number of new workloads (tasks and callbacks) this executor instance can accept."""
return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks)

@property
def slots_occupied(self):
"""Number of tasks this executor instance is currently managing."""
return len(self.running) + len(self.queued_tasks)
"""Number of workloads (tasks and callbacks) this executor instance is currently managing."""
return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks)

def debug_dump(self):
"""Get called in response to SIGUSR2 by the scheduler."""
self.log.info(
"executor.queued (%d)\n\t%s",
"executor.queued_tasks (%d)\n\t%s",
len(self.queued_tasks),
"\n\t".join(map(repr, self.queued_tasks.items())),
)
self.log.info(
"executor.queued_callbacks (%d)\n\t%s",
len(self.queued_callbacks),
"\n\t".join(map(repr, self.queued_callbacks.items())),
)
self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running)))
self.log.info(
"executor.event_buffer (%d)\n\t%s",
Expand Down
90 changes: 61 additions & 29 deletions airflow-core/src/airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import TaskInstanceState
from airflow.executors.workloads.callback import execute_callback_workload
from airflow.utils.state import CallbackState, TaskInstanceState

# add logger to parameter of setproctitle to support logging
if sys.platform == "darwin":
Expand All @@ -50,13 +51,23 @@
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger

TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Exception | None]
from airflow.executors.workloads.types import WorkloadResultType


def _get_executor_process_title_prefix(team_name: str | None) -> str:
"""
Build the process title prefix for LocalExecutor workers.

:param team_name: Team name from executor configuration
"""
team_suffix = f" [{team_name}]" if team_name else ""
return f"airflow worker -- LocalExecutor{team_suffix}:"


def _run_worker(
logger_name: str,
input: SimpleQueue[workloads.All | None],
output: Queue[TaskInstanceStateType],
output: Queue[WorkloadResultType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
team_conf,
):
Expand All @@ -68,11 +79,8 @@ def _run_worker(
log = structlog.get_logger(logger_name)
log.info("Worker starting up pid=%d", os.getpid())

# Create team suffix for process title
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""

while True:
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: <idle>", log)
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} <idle>", log)
try:
workload = input.get()
except EOFError:
Expand All @@ -87,25 +95,30 @@ def _run_worker(
# Received poison pill, no more tasks to run
return

if not isinstance(workload, workloads.ExecuteTask):
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")

# Decrement this as soon as we pick up a message off the queue
with unread_messages:
unread_messages.value -= 1
key = None
if ti := getattr(workload, "ti", None):
key = ti.key
else:
raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}")

try:
_execute_work(log, workload, team_conf)
# Handle different workload types
if isinstance(workload, workloads.ExecuteTask):
try:
_execute_work(log, workload, team_conf)
output.put((workload.ti.key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("Task execution failed.")
output.put((workload.ti.key, TaskInstanceState.FAILED, e))

elif isinstance(workload, workloads.ExecuteCallback):
output.put((workload.callback.id, CallbackState.RUNNING, None))
try:
_execute_callback(log, workload, team_conf)
output.put((workload.callback.id, CallbackState.SUCCESS, None))
except Exception as e:
log.exception("Callback execution failed")
output.put((workload.callback.id, CallbackState.FAILED, e))

output.put((key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("uhoh")
output.put((key, TaskInstanceState.FAILED, e))
else:
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")


def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None:
Expand All @@ -118,9 +131,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No
"""
from airflow.sdk.execution_time.supervisor import supervise

# Create team suffix for process title
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: {workload.ti.id}", log)
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log)

base_url = team_conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
Expand All @@ -141,6 +152,22 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No
)


def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None:
    """
    Execute a callback workload in the current worker process.

    :param log: Logger instance
    :param workload: The ExecuteCallback workload to execute
    :param team_conf: Team-specific executor configuration
    :raises RuntimeError: If the callback reports failure.
    """
    callback = workload.callback
    title_prefix = _get_executor_process_title_prefix(team_conf.team_name)
    setproctitle(f"{title_prefix} {callback.id}", log)

    succeeded, error_msg = execute_callback_workload(callback, log)
    if not succeeded:
        raise RuntimeError(error_msg or "Callback execution failed")


class LocalExecutor(BaseExecutor):
"""
LocalExecutor executes tasks locally in parallel.
Expand All @@ -155,9 +182,10 @@ class LocalExecutor(BaseExecutor):

supports_multi_team: bool = True
serve_logs: bool = True
supports_callbacks: bool = True

activity_queue: SimpleQueue[workloads.All | None]
result_queue: SimpleQueue[TaskInstanceStateType]
result_queue: SimpleQueue[WorkloadResultType]
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]

Expand Down Expand Up @@ -300,10 +328,14 @@ def end(self) -> None:
def terminate(self):
"""Terminate the executor is not doing anything."""

def _process_workloads(self, workloads):
for workload in workloads:
def _process_workloads(self, workload_list):
for workload in workload_list:
self.activity_queue.put(workload)
del self.queued_tasks[workload.ti.key]
# Remove from appropriate queue based on workload type
if isinstance(workload, workloads.ExecuteTask):
del self.queued_tasks[workload.ti.key]
elif isinstance(workload, workloads.ExecuteCallback):
del self.queued_callbacks[workload.callback.id]
with self._unread_messages:
self._unread_messages.value += len(workloads)
self._unread_messages.value += len(workload_list)
self._check_workers()
Loading
Loading