diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index d67c25c7bafaa..385964d42b156 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -51,6 +51,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.executor_utils import ExecutorName + from airflow.executors.workloads import ExecutorWorkload from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -219,7 +220,7 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey): """Add an event to the log table.""" self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) - def queue_workload(self, workload: workloads.All, session: Session) -> None: + def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None: if isinstance(workload, workloads.ExecuteTask): ti = workload.ti self.queued_tasks[ti.key] = workload @@ -237,7 +238,7 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: f"Workload must be one of: ExecuteTask, ExecuteCallback." ) - def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: + def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]: """ Select and return the next batch of workloads to schedule, respecting priority policy. 
@@ -246,7 +247,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, :param open_slots: Number of available execution slots """ - workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = [] + workloads_to_schedule: list[tuple[WorkloadKey, ExecutorWorkload]] = [] if self.queued_callbacks: for key, workload in self.queued_callbacks.items(): @@ -262,7 +263,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, return workloads_to_schedule - def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> None: """ Process the given workloads. diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 9b5939a0bd2e7..c81d69089442c 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -35,10 +35,7 @@ import structlog -from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor -from airflow.executors.workloads.callback import execute_callback_workload -from airflow.utils.state import CallbackState, TaskInstanceState # add logger to parameter of setproctitle to support logging if sys.platform == "darwin": @@ -51,9 +48,23 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger + from airflow.executors.workloads import ExecutorWorkload from airflow.executors.workloads.types import WorkloadResultType +def _get_execution_api_server_url(team_conf) -> str: + """ + Resolve the execution API server URL from team-specific configuration. 
+ + :param team_conf: Team-specific executor configuration (ExecutorConf or AirflowConfigParser) + """ + base_url = team_conf.get("api", "base_url", fallback="/") + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + default_execution_api_server = f"{base_url.rstrip('/')}/execution/" + return team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server) + + def _get_executor_process_title_prefix(team_name: str | None) -> str: """ Build the process title prefix for LocalExecutor workers. @@ -66,7 +77,7 @@ def _get_executor_process_title_prefix(team_name: str | None) -> str: def _run_worker( logger_name: str, - input: SimpleQueue[workloads.All | None], + input: SimpleQueue[ExecutorWorkload | None], output: Queue[WorkloadResultType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], team_conf, @@ -99,73 +110,35 @@ def _run_worker( with unread_messages: unread_messages.value -= 1 - # Handle different workload types - if isinstance(workload, workloads.ExecuteTask): - try: - _execute_work(log, workload, team_conf) - output.put((workload.ti.key, TaskInstanceState.SUCCESS, None)) - except Exception as e: - log.exception("Task execution failed.") - output.put((workload.ti.key, TaskInstanceState.FAILED, e)) - - elif isinstance(workload, workloads.ExecuteCallback): - output.put((workload.callback.id, CallbackState.RUNNING, None)) - try: - _execute_callback(log, workload, team_conf) - output.put((workload.callback.id, CallbackState.SUCCESS, None)) - except Exception as e: - log.exception("Callback execution failed") - output.put((workload.callback.id, CallbackState.FAILED, e)) - - else: - raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}") - - -def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None: - """ - Execute command received and stores result state in queue. 
- - :param log: Logger instance - :param workload: The workload to execute - :param team_conf: Team-specific executor configuration - """ - from airflow.sdk.execution_time.supervisor import supervise + if workload.running_state is not None: + output.put((workload.key, workload.running_state, None)) - setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log) - - base_url = team_conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - - # This will return the exit code of the task process, but we don't care about that, just if the - # _supervisor_ had an error reporting the state back (which will result in an exception.) - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) + try: + _execute_workload(log, workload, team_conf) + output.put((workload.key, workload.success_state, None)) + except Exception as e: + log.exception("Workload execution failed.", workload_type=type(workload).__name__) + output.put((workload.key, workload.failure_state, e)) -def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None: +def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> None: """ - Execute a callback workload. + Execute any workload type in a supervised subprocess. + + All workload types are run in a supervised child process, providing process isolation, + stdout/stderr capture, signal handling, and crash detection. 
:param log: Logger instance - :param workload: The ExecuteCallback workload to execute + :param workload: The workload to execute (ExecuteTask or ExecuteCallback) :param team_conf: Team-specific executor configuration """ - setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log) - - success, error_msg = execute_callback_workload(workload.callback, log) + from airflow.sdk.execution_time.supervisor import supervise_workload - if not success: - raise RuntimeError(error_msg or "Callback execution failed") + supervise_workload( + workload, + server=_get_execution_api_server_url(team_conf), + proctitle=f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}", + ) class LocalExecutor(BaseExecutor): @@ -184,7 +157,7 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True supports_callbacks: bool = True - activity_queue: SimpleQueue[workloads.All | None] + activity_queue: SimpleQueue[ExecutorWorkload | None] result_queue: SimpleQueue[WorkloadResultType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] @@ -213,6 +186,7 @@ def start(self) -> None: # Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour # (it looks like an int to python) + self._unread_messages = multiprocessing.Value(ctypes.c_uint) if self.is_mp_using_fork: @@ -331,11 +305,13 @@ def terminate(self): def _process_workloads(self, workload_list): for workload in workload_list: self.activity_queue.put(workload) - # Remove from appropriate queue based on workload type - if isinstance(workload, workloads.ExecuteTask): - del self.queued_tasks[workload.ti.key] - elif isinstance(workload, workloads.ExecuteCallback): - del self.queued_callbacks[workload.callback.id] + # A valid workload will exist in exactly one of these dicts. + # One pop will succeed, the other will return None gracefully. 
+ removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop( + workload.key, None + ) + if not removed: + raise KeyError(f"Workload {workload.key} was not found in any queue") with self._unread_messages: self._unread_messages.value += len(workload_list) self._check_workers() diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py index 462e38ad0aaac..e0af7df2922eb 100644 --- a/airflow-core/src/airflow/executors/workloads/__init__.py +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -34,6 +34,12 @@ TaskInstance = TaskInstanceDTO +ExecutorWorkload = Annotated[ + ExecuteTask | ExecuteCallback, + Field(discriminator="type"), +] +"""Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer).""" + __all__ = [ "All", "BaseWorkload", @@ -41,6 +47,7 @@ "CallbackFetchMethod", "ExecuteCallback", "ExecuteTask", + "ExecutorWorkload", "TaskInstance", "TaskInstanceDTO", ] diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 97cf16ebaf64d..6404e991e0cad 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -19,13 +19,15 @@ from __future__ import annotations import os -from abc import ABC +from abc import ABC, abstractmethod +from collections.abc import Hashable from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator + from airflow.executors.workloads.types import WorkloadState class BaseWorkload: @@ -83,3 +85,71 @@ class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`) bundle_info: BundleInfo log_path: str | None # Rendered relative log filename template the task logs should be 
written to. + + @property + @abstractmethod + def key(self) -> Hashable: + """ + Return the unique key identifying this workload instance. + + Used by executors for tracking queued/running workloads and reporting results. + Must be a hashable value suitable for use in sets and as dict keys. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement key") + + @property + @abstractmethod + def display_name(self) -> str: + """ + Return a human-readable name for this workload, suitable for logging and process titles. + + Used by executors to set worker process titles and log messages. + + Must be implemented by subclasses. + + Example:: + + # For a task workload: + return str(self.ti.id) # "4d828a62-a417-4936-a7a6-2b3fabacecab" + + # For a callback workload: + return str(self.callback.id) # "12345678-1234-5678-1234-567812345678" + + # Results in process titles like: + # "airflow worker -- LocalExecutor: 4d828a62-a417-4936-a7a6-2b3fabacecab" + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement display_name") + + @property + @abstractmethod + def success_state(self) -> WorkloadState: + """ + Return the state value representing successful completion of this workload type. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement success_state") + + @property + @abstractmethod + def failure_state(self) -> WorkloadState: + """ + Return the state value representing failed completion of this workload type. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state") + + @property + def running_state(self) -> WorkloadState | None: + """ + Return the state value representing that this workload is actively running. + + Called by the executor worker *before* execution begins. Subclasses may override + this to emit an intermediate state transition (e.g. 
callbacks need + QUEUED → RUNNING → SUCCESS/FAILED). Returns ``None`` by default, meaning + no intermediate state is emitted. + """ + return None diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 273c55953675b..a78dbab43a594 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -19,7 +19,6 @@ from __future__ import annotations from enum import Enum -from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Literal from uuid import UUID @@ -28,6 +27,7 @@ from pydantic import BaseModel, Field, field_validator from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.utils.state import CallbackState if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -75,6 +75,32 @@ class ExecuteCallback(BaseDagBundleWorkload): type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") + @property + def key(self) -> CallbackKey: + """Return the callback key for this workload.""" + return self.callback.key + + @property + def display_name(self) -> str: + """Return a human-readable name for logging and process titles.""" + if path := self.callback.data.get("path", ""): + # Use just the function/class name for brevity in process titles. + # The full path and UUID are available in log messages if needed. 
+ return path.rsplit(".", 1)[-1] + return str(self.callback.id) + + @property + def success_state(self) -> CallbackState: + return CallbackState.SUCCESS + + @property + def failure_state(self) -> CallbackState: + return CallbackState.FAILED + + @property + def running_state(self) -> CallbackState: + return CallbackState.RUNNING + @classmethod def make( cls, @@ -99,66 +125,3 @@ def make( log_path=fname, bundle_info=bundle_info, ) - - -def execute_callback_workload( - callback: CallbackDTO, - log, -) -> tuple[bool, str | None]: - """ - Execute a callback function by importing and calling it, returning the success state. - - Supports two patterns: - 1. Functions - called directly with kwargs - 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context - - Example: - # Function callback - callback.data = {"path": "my_module.alert_func", "kwargs": {"msg": "Alert!", "context": {...}}} - execute_callback_workload(callback, log) # Calls alert_func(msg="Alert!", context={...}) - - # Notifier callback - callback.data = {"path": "airflow.providers.slack...SlackWebhookNotifier", "kwargs": {"text": "Alert!", "context": {...}}} - execute_callback_workload(callback, log) # SlackWebhookNotifier(text=..., context=...) then calls instance(context) - - :param callback: The Callback schema containing path and kwargs - :param log: Logger instance for recording execution - :return: Tuple of (success: bool, error_message: str | None) - """ - from airflow.models.callback import _accepts_context # circular import - - callback_path = callback.data.get("path") - callback_kwargs = callback.data.get("kwargs", {}) - - if not callback_path: - return False, "Callback path not found in data." 
- - try: - # Import the callback callable - # Expected format: "module.path.to.function_or_class" - module_path, function_name = callback_path.rsplit(".", 1) - module = import_module(module_path) - callback_callable = getattr(module, function_name) - context = callback_kwargs.pop("context", None) - - log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) - - # If the callback is a callable, call it. If it is a class, instantiate it. - # Rather than forcing all custom callbacks to accept context, conditionally provide it only if supported. - if _accepts_context(callback_callable) and context is not None: - result = callback_callable(**callback_kwargs, context=context) - else: - result = callback_callable(**callback_kwargs) - - # If the callback is a class then it is now instantiated and callable, call it. - if callable(result): - log.debug("Calling result with context for %s", callback_path) - result = result(context) - - log.info("Callback %s executed successfully.", callback_path) - return True, None - - except Exception as e: - error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" - log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) - return False, error_msg diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index 4ca8c310fb5c2..3a95aab831061 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -25,6 +25,7 @@ from pydantic import BaseModel, Field from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -73,6 +74,24 @@ class ExecuteTask(BaseDagBundleWorkload): type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + @property + def key(self) -> TaskInstanceKey: + """Return the 
TaskInstanceKey for this workload.""" + return self.ti.key + + @property + def display_name(self) -> str: + """Return the task instance ID as a display name.""" + return str(self.ti.id) + + @property + def success_state(self) -> TaskInstanceState: + return TaskInstanceState.SUCCESS + + @property + def failure_state(self) -> TaskInstanceState: + return TaskInstanceState.FAILED + @classmethod def make( cls, diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index e08d58fa4da5c..c43c1fb509bea 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import inspect -from collections.abc import Callable from datetime import datetime from enum import Enum from importlib import import_module @@ -29,6 +27,8 @@ from sqlalchemy import ForeignKey, Integer, String, Text, Uuid from sqlalchemy.orm import Mapped, mapped_column, relationship +# Re-exporting as _accepts_context for backward compatibility +from airflow._shared.module_loading import accepts_context as _accepts_context # noqa: F401 from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone from airflow.executors.workloads import BaseWorkload @@ -53,16 +53,6 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) -def _accepts_context(callback: Callable) -> bool: - """Check if callback accepts a 'context' parameter or **kwargs.""" - try: - sig = inspect.signature(callback) - except (ValueError, TypeError): - return True - params = sig.parameters - return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) - - class CallbackType(str, Enum): """ Types of Callbacks. 
diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index 9c2470c77eae6..b27920fed8614 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -22,8 +22,8 @@ from collections.abc import AsyncIterator from typing import Any -from airflow._shared.module_loading import import_string, qualname -from airflow.models.callback import CallbackState, _accepts_context +from airflow._shared.module_loading import accepts_context, import_string, qualname +from airflow.models.callback import CallbackState from airflow.triggers.base import BaseTrigger, TriggerEvent log = logging.getLogger(__name__) @@ -55,7 +55,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` context = self.callback_kwargs.pop("context", None) - if _accepts_context(callback) and context is not None: + if accepts_context(callback) and context is not None: result = await callback(**self.callback_kwargs, context=context) else: result = await callback(**self.callback_kwargs) diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index fa0f311d018fe..530c401227966 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -36,10 +36,11 @@ from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor from airflow.executors.workloads.base import BundleInfo -from airflow.executors.workloads.callback import CallbackDTO, execute_callback_workload +from airflow.executors.workloads.callback import CallbackDTO from airflow.models.callback import CallbackFetchMethod from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator 
+from airflow.sdk.execution_time.callback_supervisor import execute_callback from airflow.serialization.definitions.baseoperator import SerializedBaseOperator from airflow.utils.state import State, TaskInstanceState @@ -661,64 +662,21 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): class TestExecuteCallbackWorkload: - def test_execute_function_callback_success(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "builtins.dict", - "kwargs": {"a": 1, "b": 2, "c": 3}, - }, - ) + @pytest.mark.parametrize( + ("path", "kwargs", "expect_success", "error_contains"), + [ + pytest.param("builtins.dict", {"a": 1, "b": 2, "c": 3}, True, None, id="function_success"), + pytest.param("", {}, False, "Callback path not found", id="missing_path"), + pytest.param("nonexistent.module.function", {}, False, "ModuleNotFoundError", id="import_error"), + pytest.param("builtins.len", {}, False, "TypeError", id="execution_error"), + ], + ) + def test_execute_callback(self, path, kwargs, expect_success, error_contains): log = structlog.get_logger() + success, error = execute_callback(path, kwargs, log) - success, error = execute_callback_workload(callback_data, log) - - assert success is True - assert error is None - - def test_execute_callback_missing_path(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={"kwargs": {}}, # Missing 'path' - ) - log = structlog.get_logger() - - success, error = execute_callback_workload(callback_data, log) - - assert success is False - assert "Callback path not found" in error - - def test_execute_callback_import_error(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "nonexistent.module.function", - "kwargs": {}, - }, - ) - log = structlog.get_logger() 
- - success, error = execute_callback_workload(callback_data, log) - - assert success is False - assert "ModuleNotFoundError" in error - - def test_execute_callback_execution_error(self): - # Use a function that will raise an error; len() requires an argument - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "builtins.len", - "kwargs": {}, - }, - ) - log = structlog.get_logger() - - success, error = execute_callback_workload(callback_data, log) - - assert success is False - assert "TypeError" in error + assert success is expect_success + if error_contains: + assert error_contains in error + else: + assert error is None diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 59afffe6833fe..7fd8410dc9752 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -28,7 +28,8 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads -from airflow.executors.local_executor import LocalExecutor, _execute_work +from airflow.executors.base_executor import ExecutorConf +from airflow.executors.local_executor import LocalExecutor, _execute_workload from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.executors.workloads.task import TaskInstanceDTO @@ -54,6 +55,17 @@ ) +def _make_mock_task_workload(): + """Create a MagicMock that passes isinstance checks for ExecuteTask and has required attributes.""" + task_workload = mock.MagicMock(spec=workloads.ExecuteTask) + task_workload.ti = mock.MagicMock(spec=TaskInstanceDTO) + task_workload.dag_rel_path = "some/path" + task_workload.bundle_info = mock.MagicMock(spec=BundleInfo) + task_workload.token = "test_token" + task_workload.log_path = None + return task_workload + + class 
TestLocalExecutor: """ When the executor is started, end() must be called before the test finishes. @@ -115,8 +127,8 @@ def test_executor_lazy_worker_spawning(self, mock_freeze, mock_unfreeze): executor.end() @skip_non_fork_mp_start - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_execution(self, mock_supervise): + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_execution(self, mock_supervise_workload): success_tis = [ TaskInstanceDTO( id=uuid7(), @@ -139,14 +151,14 @@ def test_execution(self, mock_supervise): # We just mock both styles here, only one will be hit though has_failed_once = False - def fake_supervise(ti, **kwargs): + def fake_supervise_workload(workload, **kwargs): nonlocal has_failed_once - if ti.id == fail_ti.id and not has_failed_once: + if workload.ti.id == fail_ti.id and not has_failed_once: has_failed_once = True raise RuntimeError("fake failure") return 0 - mock_supervise.side_effect = fake_supervise + mock_supervise_workload.side_effect = fake_supervise_workload executor = LocalExecutor(parallelism=2) @@ -261,26 +273,20 @@ def test_clean_stop_on_signal(self): "relative_base_url", ], ) - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_execution_api_server_url_config(self, mock_supervise, conf_values, expected_server): + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_execution_api_server_url_config(self, mock_supervise_workload, conf_values, expected_server): """Test that execution_api_server_url is correctly configured with fallback""" from airflow.executors.base_executor import ExecutorConf with conf_vars(conf_values): team_conf = ExecutorConf(team_name=None) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf) - - mock_supervise.assert_called_with( - ti=mock.ANY, - dag_rel_path=mock.ANY, - bundle_info=mock.ANY, - token=mock.ANY, - server=expected_server, - log_path=mock.ANY, - ) + 
_execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_team_and_global_config_isolation(self, mock_supervise): + mock_supervise_workload.assert_called_once() + assert mock_supervise_workload.call_args.kwargs["server"] == expected_server + + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_team_and_global_config_isolation(self, mock_supervise_workload): """Test that team-specific and global executors use correct configurations side-by-side""" from airflow.executors.base_executor import ExecutorConf @@ -303,23 +309,21 @@ def test_team_and_global_config_isolation(self, mock_supervise): with conf_vars(config_overrides): # Test team-specific config team_conf = ExecutorConf(team_name=team_name) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf) + _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) # Verify team-specific server URL was used - assert mock_supervise.call_count == 1 - call_kwargs = mock_supervise.call_args[1] - assert call_kwargs["server"] == team_server + assert mock_supervise_workload.call_count == 1 + assert mock_supervise_workload.call_args.kwargs["server"] == team_server - mock_supervise.reset_mock() + mock_supervise_workload.reset_mock() # Test global config (no team) global_conf = ExecutorConf(team_name=None) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=global_conf) + _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=global_conf) # Verify default server URL was used - assert mock_supervise.call_count == 1 - call_kwargs = mock_supervise.call_args[1] - assert call_kwargs["server"] == default_server + assert mock_supervise_workload.call_count == 1 + assert mock_supervise_workload.call_args.kwargs["server"] == default_server def test_multiple_team_executors_isolation(self): """Test that multiple 
team executors can coexist with isolated resources""" @@ -377,18 +381,18 @@ def test_global_executor_without_team_name(self): class TestLocalExecutorCallbackSupport: + CALLBACK_UUID = "12345678-1234-5678-1234-567812345678" + def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() assert executor.supports_callbacks is True @skip_non_fork_mp_start - @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") - def test_process_callback_workload(self, mock_execute_callback): - mock_execute_callback.return_value = (True, None) - + def test_process_callback_workload_queue_management(self): + """Test that _process_workloads correctly removes callbacks from queued_callbacks.""" executor = LocalExecutor(parallelism=1) callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", + id=self.CALLBACK_UUID, fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, ) @@ -411,3 +415,51 @@ def test_process_callback_workload(self, mock_execute_callback): finally: executor.end() + + @mock.patch("airflow.sdk.execution_time.callback_supervisor.supervise_callback", return_value=0) + def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback): + callback_data = CallbackDTO( + id=self.CALLBACK_UUID, + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {"arg1": "val1"}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + _execute_workload(log=mock.ANY, workload=callback_workload, team_conf=ExecutorConf(team_name=None)) + + mock_supervise_callback.assert_called_once_with( + id=self.CALLBACK_UUID, + callback_path="test.module.my_callback", + callback_kwargs={"arg1": "val1"}, + log_path="test.log", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + ) + + @mock.patch( 
+ "airflow.sdk.execution_time.callback_supervisor.supervise_callback", + side_effect=RuntimeError("Callback subprocess exited with code 1"), + ) + def test_execute_workload_raises_on_callback_failure(self, mock_supervise_callback): + callback_data = CallbackDTO( + id=self.CALLBACK_UUID, + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + with pytest.raises(RuntimeError, match="Callback subprocess exited with code 1"): + _execute_workload( + log=mock.ANY, workload=callback_workload, team_conf=ExecutorConf(team_name=None) + ) diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 5d4940b77adf8..bb3102decf8a8 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -21,6 +21,7 @@ import pytest from sqlalchemy import select +from airflow._shared.module_loading import accepts_context from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.models import Trigger from airflow.models.callback import ( @@ -30,7 +31,6 @@ DagProcessorCallback, ExecutorCallback, TriggererCallback, - _accepts_context, ) from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.triggers.base import TriggerEvent @@ -239,26 +239,26 @@ def test_true_when_var_keyword_present(self): def func_with_var_keyword(**kwargs): pass - assert _accepts_context(func_with_var_keyword) is True + assert accepts_context(func_with_var_keyword) is True def test_true_when_context_param_present(self): def func_with_context(context, alert_type): pass - assert _accepts_context(func_with_context) is True + assert accepts_context(func_with_context) is True def 
test_false_when_no_context_or_var_keyword(self): def func_without_context(a, b): pass - assert _accepts_context(func_without_context) is False + assert accepts_context(func_without_context) is False def test_false_when_no_params(self): def func_no_params(): pass - assert _accepts_context(func_no_params) is False + assert accepts_context(func_no_params) is False def test_true_for_uninspectable_callable(self): - with patch("airflow.models.callback.inspect.signature", side_effect=ValueError): - assert _accepts_context(lambda: None) is True + with patch("airflow._shared.module_loading.inspect.signature", side_effect=ValueError): + assert accepts_context(lambda: None) is True diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 768f2d9dbf646..d8508791d20cb 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -53,8 +53,6 @@ except ImportError: from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager # type:ignore[no-redef] -if AIRFLOW_V_3_2_PLUS: - from airflow.executors.workloads.callback import execute_callback_workload log = logging.getLogger(__name__) @@ -191,39 +189,18 @@ def on_celery_worker_ready(*args, **kwargs): def execute_workload(input: str) -> None: from pydantic import TypeAdapter - from airflow.executors import workloads - from airflow.providers.common.compat.sdk import conf - from airflow.sdk.execution_time.supervisor import supervise + from airflow.executors.workloads import ExecutorWorkload - decoder = TypeAdapter[workloads.All](workloads.All) + decoder = TypeAdapter[ExecutorWorkload](ExecutorWorkload) workload = decoder.validate_json(input) celery_task_id = app.current_task.request.id log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) - base_url = 
conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - - if isinstance(workload, workloads.ExecuteTask): - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) - elif isinstance(workload, workloads.ExecuteCallback): - success, error_msg = execute_callback_workload(workload.callback, log) - if not success: - raise RuntimeError(error_msg or "Callback execution failed") - else: - raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") + from airflow.sdk.execution_time.supervisor import supervise_workload + + supervise_workload(workload) if not AIRFLOW_V_3_0_PLUS: diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index f0cffa96c05ea..4f7628151eb02 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -212,7 +212,7 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) - try: supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
- # Same like in airflow/executors/local_executor.py:_execute_work() + # Same as in airflow/executors/local_executor.py:_execute_workload() ti=ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, diff --git a/shared/module_loading/src/airflow_shared/module_loading/__init__.py b/shared/module_loading/src/airflow_shared/module_loading/__init__.py index 5b2b7c9496ec1..2406db56a24f5 100644 --- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py +++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py @@ -18,6 +18,7 @@ from __future__ import annotations import functools +import inspect import logging import pkgutil import sys @@ -43,6 +44,16 @@ from types import ModuleType +def accepts_context(callback: Callable) -> bool: + """Check if callback accepts a 'context' parameter or **kwargs.""" + try: + sig = inspect.signature(callback) + except (ValueError, TypeError): + return True + params = sig.parameters + return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. 
diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml index 315e0ea8a133f..bbcc6cefaca3c 100644 --- a/task-sdk/.pre-commit-config.yaml +++ b/task-sdk/.pre-commit-config.yaml @@ -45,6 +45,7 @@ repos: ^src/airflow/sdk/definitions/_internal/types\.py$| ^src/airflow/sdk/execution_time/execute_workload\.py$| ^src/airflow/sdk/execution_time/secrets_masker\.py$| + ^src/airflow/sdk/execution_time/callback_supervisor\.py$| ^src/airflow/sdk/execution_time/supervisor\.py$| ^src/airflow/sdk/execution_time/task_runner\.py$| ^src/airflow/sdk/serde/serializers/kubernetes\.py$| diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py new file mode 100644 index 0000000000000..f9747f6e82573 --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Supervised execution of callback workloads.""" + +from __future__ import annotations + +import sys +import time +from importlib import import_module +from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol +from uuid import UUID + +import attrs +import structlog +from pydantic import TypeAdapter + +from airflow.sdk.execution_time.supervisor import ( + MIN_HEARTBEAT_INTERVAL, + SOCKET_CLEANUP_TIMEOUT, + WatchedSubprocess, + _make_process_nondumpable, +) + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger + from typing_extensions import Self + + # Core (airflow.executors.workloads.base.BundleInfo) and SDK (airflow.sdk.api.datamodels._generated.BundleInfo) + # are structurally identical, but MyPy treats them as different types. This Protocol makes MyPy happy. + class _BundleInfoLike(Protocol): + name: str + version: str | None + + +__all__ = ["CallbackSubprocess", "supervise_callback"] + +log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") + + +def execute_callback( + callback_path: str, + callback_kwargs: dict, + log, +) -> tuple[bool, str | None]: + """ + Execute a callback function by importing and calling it, returning the success state. + + Supports two patterns: + 1. Functions - called directly with kwargs + 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context + + Example: + # Function callback + execute_callback("my_module.alert_func", {"msg": "Alert!", "context": {...}}, log) + + # Notifier callback + execute_callback("airflow.providers.slack...SlackWebhookNotifier", {"text": "Alert!"}, log) + + :param callback_path: Dot-separated import path to the callback function or class. + :param callback_kwargs: Keyword arguments to pass to the callback. + :param log: Logger instance for recording execution. 
+ :return: Tuple of (success: bool, error_message: str | None) + """ + from airflow.sdk._shared.module_loading import accepts_context + + if not callback_path: + return False, "Callback path not found." + + try: + # Import the callback callable + # Expected format: "module.path.to.function_or_class" + module_path, function_name = callback_path.rsplit(".", 1) + module = import_module(module_path) + callback_callable = getattr(module, function_name) + + log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) + + kwargs_without_context = {k: v for k, v in callback_kwargs.items() if k != "context"} + + # Call the callable with all kwargs if it accepts context, otherwise strip context. + if accepts_context(callback_callable): + result = callback_callable(**callback_kwargs) + else: + result = callback_callable(**kwargs_without_context) + + # If the callback was a class then it is now instantiated and callable, call it. + # Try keyword args first. If the callable only accepts positional args (like + # BaseNotifier.__call__(self, *args)), fall back to passing context positionally. + if callable(result): + try: + if accepts_context(result): + result = result(**callback_kwargs) + else: + result = result(**kwargs_without_context) + except TypeError: + result = result(callback_kwargs.get("context", {})) + + log.info("Callback %s executed successfully.", callback_path) + return True, None + + except Exception as e: + error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" + log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) + return False, error_msg + + +# An empty message set; the callback subprocess doesn't currently communicate back to the +# supervisor. This means callback code cannot access runtime services like Connection.get() +# or Variable.get() which require the supervisor to pass requests to the API server. 
+# To enable this, add the needed message types here and implement _handle_request accordingly. +# See ActivitySubprocess.decoder in supervisor.py for the full task message set and examples. +_EmptyMessage: TypeAdapter[None] = TypeAdapter(None) + + +@attrs.define(kw_only=True) +class CallbackSubprocess(WatchedSubprocess): + """ + Supervised subprocess for executing callbacks. + + Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling + while keeping a simple lifecycle: start, run callback, exit. + """ + + decoder: ClassVar[TypeAdapter] = _EmptyMessage + + @classmethod + def start( # type: ignore[override] + cls, + *, + id: str, + callback_path: str, + callback_kwargs: dict, + bundle_info: _BundleInfoLike | None = None, + logger: FilteringBoundLogger | None = None, + **kwargs, + ) -> Self: + """Fork and start a new subprocess to execute the given callback.""" + + # Use a closure to pass callback data to the child process. Note that this + # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child + # inherits the parent's memory space and the variables are available directly. + def _target(): + _log = structlog.get_logger(logger_name="callback_runner") + + # If bundle info is provided, initialize the bundle and ensure its path is importable. + # This is needed for user-defined callbacks that live inside a DAG bundle rather than + # in an installed package or the plugins directory. 
+ if bundle_info and bundle_info.name: + try: + from airflow.dag_processing.bundles.manager import DagBundlesManager + + bundle = DagBundlesManager().get_bundle( + name=bundle_info.name, + version=bundle_info.version, + ) + bundle.initialize() + if (bundle_path := str(bundle.path)) not in sys.path: + sys.path.append(bundle_path) + log.debug( + "Added bundle path to sys.path", bundle_name=bundle_info.name, path=bundle_path + ) + except Exception: + log.warning( + "Failed to initialize DAG bundle for callback", + bundle_name=bundle_info.name, + exc_info=True, + ) + + success, error_msg = execute_callback(callback_path, callback_kwargs, _log) + if not success: + _log.error("Callback failed", error=error_msg) + sys.exit(1) + + return super().start( + id=UUID(id) if not isinstance(id, UUID) else id, + target=_target, + logger=logger, + **kwargs, + ) + + def wait(self) -> int: + """ + Wait for the callback subprocess to complete. + + Mirrors the structure of ActivitySubprocess.wait() but without heartbeating, + task API state management, or log uploading. + """ + if self._exit_code is not None: + return self._exit_code + + try: + self._monitor_subprocess() + finally: + self.selector.close() + + self._exit_code = self._exit_code if self._exit_code is not None else 1 + return self._exit_code + + def _monitor_subprocess(self): + """ + Monitor the subprocess until it exits. + + A simplified version of ActivitySubprocess._monitor_subprocess() without heartbeating + or timeout handling, just process output monitoring and stuck-socket cleanup. + """ + while self._exit_code is None or self._open_sockets: + self._service_subprocess(max_wait_time=MIN_HEARTBEAT_INTERVAL) + + # If the process has exited but sockets remain open, apply a timeout + # to prevent hanging indefinitely on stuck sockets. 
+ if self._exit_code is not None and self._open_sockets: + if ( + self._process_exit_monotonic + and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT + ): + log.warning( + "Process exited with open sockets; cleaning up after timeout", + pid=self.pid, + exit_code=self._exit_code, + socket_types=list(self._open_sockets.values()), + timeout_seconds=SOCKET_CLEANUP_TIMEOUT, + ) + self._cleanup_open_sockets() + + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: + """Handle incoming requests from the callback subprocess (currently none expected).""" + log.warning("Unexpected request from callback subprocess", msg=msg) + + +def _configure_logging(log_path: str) -> tuple[FilteringBoundLogger, BinaryIO]: + """Configure file-based logging for the callback subprocess.""" + from airflow.sdk.log import init_log_file, logging_processors + + log_file = init_log_file(log_path) + log_file_descriptor: BinaryIO = log_file.open("ab") + underlying_logger = structlog.BytesLogger(log_file_descriptor) + processors = logging_processors(json_output=True) + logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="callback").bind() + + return logger, log_file_descriptor + + +def supervise_callback( + *, + id: str, + callback_path: str, + callback_kwargs: dict, + log_path: str | None = None, + bundle_info: _BundleInfoLike | None = None, +) -> int: + """ + Run a single callback execution to completion in a supervised subprocess. + + :param id: Unique identifier for this callback execution. + :param callback_path: Dot-separated import path to the callback function or class. + :param callback_kwargs: Keyword arguments to pass to the callback. + :param log_path: Path to write logs, if required. + :param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable. + :return: Exit code of the subprocess (0 = success). 
+ """ + _make_process_nondumpable() + + start = time.monotonic() + + logger: FilteringBoundLogger + log_file_descriptor: BinaryIO | None = None + if log_path: + logger, log_file_descriptor = _configure_logging(log_path) + else: + # When no log file is requested, still use a callback-specific logger + # so logs are clearly separated from task logs. + logger = structlog.get_logger(logger_name="callback").bind() + + try: + process = CallbackSubprocess.start( + id=id, + callback_path=callback_path, + callback_kwargs=callback_kwargs, + bundle_info=bundle_info, + logger=logger, + subprocess_logs_to_stdout=True, + ) + + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecutorCallback", + workload_id=id, + exit_code=exit_code, + duration=end - start, + ) + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py index 410c676eeb913..0c36d131fc796 100644 --- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py +++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py @@ -40,9 +40,7 @@ def execute_workload(workload: ExecuteTask) -> None: - from airflow.executors import workloads - from airflow.sdk.configuration import conf - from airflow.sdk.execution_time.supervisor import supervise + from airflow.sdk.execution_time.supervisor import supervise_workload from airflow.sdk.log import configure_logging from airflow.settings import dispose_orm @@ -50,28 +48,10 @@ def execute_workload(workload: ExecuteTask) -> None: configure_logging(output=sys.stdout.buffer, json_output=True) - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"Executor does not know how to handle {type(workload)}") - log.info("Executing workload", 
workload=workload) - base_url = conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - server = conf.get("core", "execution_api_server_url", fallback=default_execution_api_server) - log.info("Connecting to server:", server=server) - - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=server, - log_path=workload.log_path, - sentry_integration=workload.sentry_integration, + supervise_workload( + workload, # Include the output of the task to stdout too, so that in process logs can be read from via the # kubeapi as pod logs. subprocess_logs_to_stdout=True, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 1dfefee54047c..0029a7f46d6b1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -133,13 +133,12 @@ from structlog.typing import FilteringBoundLogger, WrappedLogger from typing_extensions import Self - from airflow.executors.workloads import BundleInfo + from airflow.executors.workloads import BundleInfo, ExecutorWorkload from airflow.sdk.bases.secrets_backend import BaseSecretsBackend from airflow.sdk.definitions.connection import Connection from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI - -__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise"] +__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", "supervise_task", "supervise_workload"] log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor") @@ -2000,7 +1999,7 @@ def _configure_logging(log_path: 
str, client: Client) -> tuple[FilteringBoundLog return logger, log_file_descriptor -def supervise( +def supervise_task( *, ti: TaskInstance, bundle_info: BundleInfo, @@ -2107,8 +2106,9 @@ def supervise( exit_code = process.wait() end = time.monotonic() log.info( - "Task finished", - task_instance_id=str(ti.id), + "Workload finished", + workload_type="ExecuteTask", + workload_id=str(ti.id), exit_code=exit_code, duration=end - start, final_state=process.final_state, @@ -2120,3 +2120,99 @@ def supervise( if close_client and client: with suppress(Exception): client.close() + + +def supervise_workload( + workload: ExecutorWorkload, + *, + server: str | None = None, + dry_run: bool = False, + client: Client | None = None, + subprocess_logs_to_stdout: bool = False, + proctitle: str | None = None, +) -> int: + """ + Run any workload type to completion in a supervised subprocess. + + Dispatch to the appropriate supervisor based on workload type. Workload-specific + attributes (log_path, sentry_integration, bundle_info, etc.) are read from the + workload object itself. + + :param workload: The ``ExecutorWorkload`` to execute. + :param server: Base URL of the API server (used by task workloads). + :param dry_run: If True, execute without actual task execution (simulate run). + :param client: Optional preconfigured client for communication with the server. + :param subprocess_logs_to_stdout: Should task logs also be sent to stdout via the main logger. + :param proctitle: Process title to set for this workload. If not provided, defaults to + ``"airflow supervisor: <workload.display_name>"``. Executors may pass a custom title + that includes executor-specific context (e.g. team name). + :return: Exit code of the process. + """ + # Imports deferred to avoid an SDK/core dependency at module load time. 
+ from airflow.executors.workloads.callback import ExecuteCallback + from airflow.executors.workloads.task import ExecuteTask + + try: + from setproctitle import setproctitle + + setproctitle(proctitle or f"airflow supervisor: {workload.display_name}") + except ImportError: + pass + + # Resolve server URL from config when not explicitly provided. + # For example, team-specific executors may wish to pass their own server URL. + if server is None: + base_url = conf.get("api", "base_url", fallback="/") + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + server = conf.get( + "core", + "execution_api_server_url", + fallback=f"{base_url.rstrip('/')}/execution/", + ) + + if isinstance(workload, ExecuteTask): + return supervise_task( + # workload.ti is a TaskInstanceDTO which duck-types as TaskInstance. + # TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + bundle_info=workload.bundle_info, + dag_rel_path=workload.dag_rel_path, + token=workload.token, + server=server, + dry_run=dry_run, + log_path=workload.log_path, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + client=client, + sentry_integration=getattr(workload, "sentry_integration", ""), + ) + if isinstance(workload, ExecuteCallback): + from airflow.sdk.execution_time.callback_supervisor import supervise_callback + + return supervise_callback( + id=workload.callback.id, + callback_path=workload.callback.data.get("path", ""), + callback_kwargs=workload.callback.data.get("kwargs", {}), + log_path=workload.log_path, + bundle_info=workload.bundle_info, + ) + raise ValueError(f"Unknown workload type: {type(workload).__name__}") + + +def supervise(**kwargs) -> int: + """ + Call ``supervise_task()`` with a deprecation warning. + + This wrapper exists for backward compatibility with provider packages that may import ``supervise`` directly. + + .. deprecated:: + Use :func:`supervise_task` instead. 
+ """ + import warnings + + warnings.warn( + "supervise() is deprecated, use supervise_task() instead.", + DeprecationWarning, + stacklevel=2, + ) + return supervise_task(**kwargs) diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py new file mode 100644 index 0000000000000..93b459d200e87 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for the callback supervisor module.""" + +from __future__ import annotations + +import pytest +import structlog + +from airflow.sdk.execution_time.callback_supervisor import execute_callback + + +def callback_no_args(): + """A simple callback that takes no arguments.""" + return "ok" + + +def callback_with_kwargs(arg1, arg2): + """A callback that accepts keyword arguments.""" + return f"{arg1}-{arg2}" + + +def callback_that_raises(): + """A callback that always raises.""" + raise ValueError("something went wrong") + + +class CallableClass: + """A class that returns a callable instance (like BaseNotifier).""" + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, context): + return "notified" + + +class TestExecuteCallback: + @pytest.mark.parametrize( + ("path", "kwargs", "expect_success", "error_contains"), + [ + pytest.param( + f"{__name__}.callback_no_args", + {}, + True, + None, + id="successful_no_args", + ), + pytest.param( + f"{__name__}.callback_with_kwargs", + {"arg1": "hello", "arg2": "world"}, + True, + None, + id="successful_with_kwargs", + ), + pytest.param( + f"{__name__}.CallableClass", + {"msg": "alert"}, + True, + None, + id="callable_class_pattern", + ), + pytest.param( + "", + {}, + False, + "Callback path not found", + id="empty_path", + ), + pytest.param( + "nonexistent.module.function", + {}, + False, + "ModuleNotFoundError", + id="import_error", + ), + pytest.param( + f"{__name__}.callback_that_raises", + {}, + False, + "ValueError", + id="execution_error", + ), + pytest.param( + f"{__name__}.nonexistent_function_xyz", + {}, + False, + "AttributeError", + id="attribute_error", + ), + ], + ) + def test_execute_callback(self, path, kwargs, expect_success, error_contains): + log = structlog.get_logger() + success, error = execute_callback(path, kwargs, log) + + assert success is expect_success + if error_contains: + assert error_contains in error + else: + assert error is None diff --git 
a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b486ce7776611..e091101f84190 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -143,7 +143,7 @@ _remote_logging_conn, process_log_messages_from_subprocess, set_supervisor_comms, - supervise, + supervise_task, ) from airflow.sdk.execution_time.task_runner import run @@ -230,7 +230,7 @@ def test_supervise( with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): with expectation: - supervise(**kw) + supervise_task(**kw) @pytest.mark.usefixtures("disable_capturing") @@ -624,7 +624,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): - exit_code = supervise( + exit_code = supervise_task( ti=ti, dag_rel_path=dagfile_path, token="", @@ -679,7 +679,7 @@ def mock_monotonic(): patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)), patch("airflow.sdk.execution_time.supervisor.time.monotonic", side_effect=mock_monotonic), ): - exit_code = supervise( + exit_code = supervise_task( ti=ti, dag_rel_path="super_basic_deferred_run.py", token="", @@ -723,12 +723,13 @@ def mock_monotonic(): "exit_code": 0, "duration": 0.0, "final_state": "deferred", - "event": "Task finished", + "event": "Workload finished", + "workload_type": "ExecuteTask", + "workload_id": str(ti.id), "timestamp": mocker.ANY, "level": "info", "logger": "supervisor", "loc": mocker.ANY, - "task_instance_id": str(ti.id), } in captured_logs @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"])