Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -219,7 +220,7 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
"""Add an event to the log table."""
self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))

def queue_workload(self, workload: workloads.All, session: Session) -> None:
def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None:
if isinstance(workload, workloads.ExecuteTask):
ti = workload.ti
self.queued_tasks[ti.key] = workload
Expand All @@ -237,7 +238,7 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None:
f"Workload must be one of: ExecuteTask, ExecuteCallback."
)

def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]:
def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]:
"""
Select and return the next batch of workloads to schedule, respecting priority policy.
Expand All @@ -246,7 +247,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey,
:param open_slots: Number of available execution slots
"""
workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = []
workloads_to_schedule: list[tuple[WorkloadKey, ExecutorWorkload]] = []

if self.queued_callbacks:
for key, workload in self.queued_callbacks.items():
Expand All @@ -262,7 +263,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey,

return workloads_to_schedule

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> None:
"""
Process the given workloads.
Expand Down
112 changes: 44 additions & 68 deletions airflow-core/src/airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@

import structlog

from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.workloads.callback import execute_callback_workload
from airflow.utils.state import CallbackState, TaskInstanceState

# add logger to parameter of setproctitle to support logging
if sys.platform == "darwin":
Expand All @@ -51,9 +48,23 @@
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger

from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadResultType


def _get_execution_api_server_url(team_conf) -> str:
"""
Resolve the execution API server URL from team-specific configuration.

:param team_conf: Team-specific executor configuration (ExecutorConf or AirflowConfigParser)
"""
base_url = team_conf.get("api", "base_url", fallback="/")
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
return team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server)


def _get_executor_process_title_prefix(team_name: str | None) -> str:
"""
Build the process title prefix for LocalExecutor workers.
Expand All @@ -66,7 +77,7 @@ def _get_executor_process_title_prefix(team_name: str | None) -> str:

def _run_worker(
logger_name: str,
input: SimpleQueue[workloads.All | None],
input: SimpleQueue[ExecutorWorkload | None],
output: Queue[WorkloadResultType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
team_conf,
Expand Down Expand Up @@ -99,73 +110,35 @@ def _run_worker(
with unread_messages:
unread_messages.value -= 1

# Handle different workload types
if isinstance(workload, workloads.ExecuteTask):
try:
_execute_work(log, workload, team_conf)
output.put((workload.ti.key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("Task execution failed.")
output.put((workload.ti.key, TaskInstanceState.FAILED, e))

elif isinstance(workload, workloads.ExecuteCallback):
output.put((workload.callback.id, CallbackState.RUNNING, None))
try:
_execute_callback(log, workload, team_conf)
output.put((workload.callback.id, CallbackState.SUCCESS, None))
except Exception as e:
log.exception("Callback execution failed")
output.put((workload.callback.id, CallbackState.FAILED, e))

else:
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")


def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None:
"""
Execute command received and stores result state in queue.

:param log: Logger instance
:param workload: The workload to execute
:param team_conf: Team-specific executor configuration
"""
from airflow.sdk.execution_time.supervisor import supervise
if workload.running_state is not None:
output.put((workload.key, workload.running_state, None))

setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log)

base_url = team_conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"

# This will return the exit code of the task process, but we don't care about that, just if the
# _supervisor_ had an error reporting the state back (which will result in an exception.)
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)
try:
_execute_workload(log, workload, team_conf)
output.put((workload.key, workload.success_state, None))
except Exception as e:
log.exception("Workload execution failed.", workload_type=type(workload).__name__)
output.put((workload.key, workload.failure_state, e))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using .pop(workload.key, None) on both queued_tasks and queued_callbacks for every workload silently swallows missing keys. The old code used del which would raise KeyError if a workload was never queued or got dequeued twice -- surfacing logic bugs rather than hiding them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I can see where you are going. I don't think it's quite that simple, though. With the two pops as I have it, one of them will pop and the other is expected to be a no-op: a task will pop from queued_tasks and the queued_callbacks pop will just not do anything, and vice versa. A del would have a problem with the no-op side of that.

But you are right that I missed the case where it's not in either. How do you feel about this:

for workload in workload_list:
    self.activity_queue.put(workload)
    # A valid workload will exist in exactly one of these dicts.
    # One will succeed, the other will fail gracefully and return None.
    removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop(workload.key, None)
    if not removed:
        raise KeyError(f"Workload {workload.key} was not found in any queue.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented that in e558b31, if you have another idea I can revert it



def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None:
def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> None:
"""
Execute a callback workload.
Execute any workload type in a supervised subprocess.

All workload types are run in a supervised child process, providing process isolation,
stdout/stderr capture, signal handling, and crash detection.

:param log: Logger instance
:param workload: The ExecuteCallback workload to execute
:param workload: The workload to execute (ExecuteTask or ExecuteCallback)
:param team_conf: Team-specific executor configuration
"""
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log)

success, error_msg = execute_callback_workload(workload.callback, log)
from airflow.sdk.execution_time.supervisor import supervise_workload

if not success:
raise RuntimeError(error_msg or "Callback execution failed")
supervise_workload(
workload,
server=_get_execution_api_server_url(team_conf),
proctitle=f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}",
)


class LocalExecutor(BaseExecutor):
Expand All @@ -184,7 +157,7 @@ class LocalExecutor(BaseExecutor):
serve_logs: bool = True
supports_callbacks: bool = True

activity_queue: SimpleQueue[workloads.All | None]
activity_queue: SimpleQueue[ExecutorWorkload | None]
result_queue: SimpleQueue[WorkloadResultType]
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]
Expand Down Expand Up @@ -213,6 +186,7 @@ def start(self) -> None:

# Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour
# (it looks like an int to python)

self._unread_messages = multiprocessing.Value(ctypes.c_uint)

if self.is_mp_using_fork:
Expand Down Expand Up @@ -331,11 +305,13 @@ def terminate(self):
def _process_workloads(self, workload_list):
for workload in workload_list:
self.activity_queue.put(workload)
# Remove from appropriate queue based on workload type
if isinstance(workload, workloads.ExecuteTask):
del self.queued_tasks[workload.ti.key]
elif isinstance(workload, workloads.ExecuteCallback):
del self.queued_callbacks[workload.callback.id]
# A valid workload will exist in exactly one of these dicts.
# One pop will succeed, the other will return None gracefully.
removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop(
workload.key, None
)
if not removed:
raise KeyError(f"Workload {workload.key} was not found in any queue")
with self._unread_messages:
self._unread_messages.value += len(workload_list)
self._check_workers()
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/executors/workloads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@

TaskInstance = TaskInstanceDTO

# Pydantic discriminated union: the ``type`` field on each workload model selects
# the concrete class when (de)serializing an ExecutorWorkload.
ExecutorWorkload = Annotated[
    ExecuteTask | ExecuteCallback,
    Field(discriminator="type"),
]
"""Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer)."""

__all__ = [
"All",
"BaseWorkload",
"BundleInfo",
"CallbackFetchMethod",
"ExecuteCallback",
"ExecuteTask",
"ExecutorWorkload",
"TaskInstance",
"TaskInstanceDTO",
]
72 changes: 71 additions & 1 deletion airflow-core/src/airflow/executors/workloads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from __future__ import annotations

import os
from abc import ABC
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.executors.workloads.types import WorkloadState


class BaseWorkload:
Expand Down Expand Up @@ -83,3 +85,71 @@ class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):
dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)
bundle_info: BundleInfo
log_path: str | None # Rendered relative log filename template the task logs should be written to.

@property
@abstractmethod
def key(self) -> Hashable:
    """
    Return the unique key identifying this workload instance.

    Used by executors for tracking queued/running workloads and reporting results.
    Must be a hashable value suitable for use in sets and as dict keys.
    Must be implemented by subclasses.
    """
    raise NotImplementedError(f"{self.__class__.__name__} must implement key")

@property
@abstractmethod
def display_name(self) -> str:
    """
    Return a human-readable name for this workload, suitable for logging and process titles.

    Used by executors to set worker process titles and log messages.
    Must be implemented by subclasses.

    Example::

        # For a task workload:
        return str(self.ti.id)  # "4d828a62-a417-4936-a7a6-2b3fabacecab"

        # For a callback workload:
        return str(self.callback.id)  # "12345678-1234-5678-1234-567812345678"

        # Results in process titles like:
        # "airflow worker -- LocalExecutor: 4d828a62-a417-4936-a7a6-2b3fabacecab"
    """
    raise NotImplementedError(f"{self.__class__.__name__} must implement display_name")

@property
@abstractmethod
def success_state(self) -> WorkloadState:
    """
    Return the state value representing successful completion of this workload type.

    Must be implemented by subclasses.
    """
    raise NotImplementedError(f"{self.__class__.__name__} must implement success_state")

@property
@abstractmethod
def failure_state(self) -> WorkloadState:
    """
    Return the state value representing failed completion of this workload type.

    Must be implemented by subclasses.
    """
    raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state")

@property
def running_state(self) -> WorkloadState | None:
    """
    Return the state value representing that this workload is actively running.

    Called by the executor worker *before* execution begins. Subclasses may override
    this to emit an intermediate state transition (e.g. callbacks need
    QUEUED → RUNNING → SUCCESS/FAILED). Returns ``None`` by default, meaning
    no intermediate state is emitted.
    """
    return None
Loading
Loading