From 2f6a8e1dd800b87a2924cb49c25e8165d3fb9b96 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 27 Feb 2026 21:20:47 -0800 Subject: [PATCH 01/19] Move ExecutorCallback execution into a supervised process and make the entire flow more generic to account for future workload types --- .../src/airflow/executors/base_executor.py | 9 +- .../src/airflow/executors/local_executor.py | 124 ++++---- .../airflow/executors/workloads/__init__.py | 9 +- .../src/airflow/executors/workloads/base.py | 60 +++- .../airflow/executors/workloads/callback.py | 19 ++ .../src/airflow/executors/workloads/task.py | 19 ++ .../unit/executors/test_base_executor.py | 40 +-- .../unit/executors/test_local_executor.py | 78 ++++- .../celery/executors/celery_executor_utils.py | 13 +- .../src/airflow/providers/edge3/cli/worker.py | 2 +- .../sdk/execution_time/callback_supervisor.py | 270 ++++++++++++++++++ .../airflow/sdk/execution_time/supervisor.py | 5 +- .../test_callback_supervisor.py | 101 +++++++ .../execution_time/test_supervisor.py | 5 +- 14 files changed, 624 insertions(+), 130 deletions(-) create mode 100644 task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py create mode 100644 task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index d67c25c7bafaa..385964d42b156 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -51,6 +51,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand from airflow.executors.executor_utils import ExecutorName + from airflow.executors.workloads import ExecutorWorkload from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -219,7 +220,7 @@ def log_task_event(self, *, event: 
str, extra: str, ti_key: TaskInstanceKey): """Add an event to the log table.""" self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) - def queue_workload(self, workload: workloads.All, session: Session) -> None: + def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None: if isinstance(workload, workloads.ExecuteTask): ti = workload.ti self.queued_tasks[ti.key] = workload @@ -237,7 +238,7 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: f"Workload must be one of: ExecuteTask, ExecuteCallback." ) - def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: + def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]: """ Select and return the next batch of workloads to schedule, respecting priority policy. @@ -246,7 +247,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, :param open_slots: Number of available execution slots """ - workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = [] + workloads_to_schedule: list[tuple[WorkloadKey, ExecutorWorkload]] = [] if self.queued_callbacks: for key, workload in self.queued_callbacks.items(): @@ -262,7 +263,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, return workloads_to_schedule - def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> None: """ Process the given workloads. 
diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 9b5939a0bd2e7..5da04009df104 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -37,8 +37,6 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor -from airflow.executors.workloads.callback import execute_callback_workload -from airflow.utils.state import CallbackState, TaskInstanceState # add logger to parameter of setproctitle to support logging if sys.platform == "darwin": @@ -51,6 +49,7 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger + from airflow.executors.workloads import ExecutorWorkload from airflow.executors.workloads.types import WorkloadResultType @@ -66,7 +65,7 @@ def _get_executor_process_title_prefix(team_name: str | None) -> str: def _run_worker( logger_name: str, - input: SimpleQueue[workloads.All | None], + input: SimpleQueue[ExecutorWorkload | None], output: Queue[WorkloadResultType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], team_conf, @@ -99,73 +98,60 @@ def _run_worker( with unread_messages: unread_messages.value -= 1 - # Handle different workload types - if isinstance(workload, workloads.ExecuteTask): - try: - _execute_work(log, workload, team_conf) - output.put((workload.ti.key, TaskInstanceState.SUCCESS, None)) - except Exception as e: - log.exception("Task execution failed.") - output.put((workload.ti.key, TaskInstanceState.FAILED, e)) - - elif isinstance(workload, workloads.ExecuteCallback): - output.put((workload.callback.id, CallbackState.RUNNING, None)) - try: - _execute_callback(log, workload, team_conf) - output.put((workload.callback.id, CallbackState.SUCCESS, None)) - except Exception as e: - log.exception("Callback execution failed") - output.put((workload.callback.id, CallbackState.FAILED, e)) - - else: - raise 
ValueError(f"LocalExecutor does not know how to handle {type(workload)}") - - -def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None: - """ - Execute command received and stores result state in queue. + try: + _execute_workload(log, workload, team_conf) + output.put((workload.key, workload.success_state, None)) + except Exception as e: + log.exception("Workload execution failed.", workload_type=type(workload).__name__) + output.put((workload.key, workload.failure_state, e)) - :param log: Logger instance - :param workload: The workload to execute - :param team_conf: Team-specific executor configuration - """ - from airflow.sdk.execution_time.supervisor import supervise - - setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log) - - base_url = team_conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - - # This will return the exit code of the task process, but we don't care about that, just if the - # _supervisor_ had an error reporting the state back (which will result in an exception.) - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) - - -def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None: + +def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> None: """ - Execute a callback workload. + Execute any workload type in a supervised subprocess. 
+ + All workload types are run in a supervised child process, providing process isolation, + stdout/stderr capture, signal handling, and crash detection. :param log: Logger instance - :param workload: The ExecuteCallback workload to execute + :param workload: The workload to execute (ExecuteTask or ExecuteCallback) :param team_conf: Team-specific executor configuration """ - setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log) - - success, error_msg = execute_callback_workload(workload.callback, log) - - if not success: - raise RuntimeError(error_msg or "Callback execution failed") + setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}", log) + + if isinstance(workload, workloads.ExecuteTask): + from airflow.sdk.execution_time.supervisor import supervise + + base_url = team_conf.get("api", "base_url", fallback="/") + # If it's a relative URL, use localhost:8080 as the default + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + default_execution_api_server = f"{base_url.rstrip('/')}/execution/" + + # This will return the exit code of the task process, but we don't care about that, just if the + # _supervisor_ had an error reporting the state back (which will result in an exception.) + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
+ ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), + log_path=workload.log_path, + ) + elif isinstance(workload, workloads.ExecuteCallback): + from airflow.sdk.execution_time.callback_supervisor import supervise_callback + + exit_code = supervise_callback( + id=workload.callback.id, + callback_path=workload.callback.data.get("path", ""), + callback_kwargs=workload.callback.data.get("kwargs", {}), + log_path=workload.log_path, + ) + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") + else: + raise ValueError(f"LocalExecutor does not know how to execute {type(workload).__name__!r}") class LocalExecutor(BaseExecutor): @@ -184,7 +170,7 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True supports_callbacks: bool = True - activity_queue: SimpleQueue[workloads.All | None] + activity_queue: SimpleQueue[ExecutorWorkload | None] result_queue: SimpleQueue[WorkloadResultType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] @@ -331,11 +317,9 @@ def terminate(self): def _process_workloads(self, workload_list): for workload in workload_list: self.activity_queue.put(workload) - # Remove from appropriate queue based on workload type - if isinstance(workload, workloads.ExecuteTask): - del self.queued_tasks[workload.ti.key] - elif isinstance(workload, workloads.ExecuteCallback): - del self.queued_callbacks[workload.callback.id] + # Remove from the appropriate queue using the workload's key. 
+ self.queued_tasks.pop(workload.key, None) + self.queued_callbacks.pop(workload.key, None) with self._unread_messages: self._unread_messages.value += len(workload_list) self._check_workers() diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py index 462e38ad0aaac..bc555326b5edf 100644 --- a/airflow-core/src/airflow/executors/workloads/__init__.py +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -34,6 +34,12 @@ TaskInstance = TaskInstanceDTO +ExecutorWorkload = Annotated[ + ExecuteTask | ExecuteCallback, + Field(discriminator="type"), +] +"""Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer).""" + __all__ = [ "All", "BaseWorkload", @@ -41,6 +47,7 @@ "CallbackFetchMethod", "ExecuteCallback", "ExecuteTask", + "ExecutorWorkload", "TaskInstance", "TaskInstanceDTO", -] +] \ No newline at end of file diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 97cf16ebaf64d..bd122aa5ac5d6 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -19,8 +19,8 @@ from __future__ import annotations import os -from abc import ABC -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -83,3 +83,59 @@ class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`) bundle_info: BundleInfo log_path: str | None # Rendered relative log filename template the task logs should be written to. + + @property + @abstractmethod + def key(self) -> Any: + """ + Return the unique key identifying this workload instance. + + Used by executors for tracking queued/running workloads and reporting results. 
+ Must be a hashable value suitable for use in sets and as dict keys. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement key") + + @property + @abstractmethod + def display_name(self) -> str: + """ + Return a human-readable name for this workload, suitable for logging and process titles. + + Used by executors to set worker process titles and log messages. + + Must be implemented by subclasses. + + Example:: + + # For a task workload: + return str(self.ti.id) # "4d828a62-a417-4936-a7a6-2b3fabacecab" + + # For a callback workload: + return str(self.callback.id) # "12345678-1234-5678-1234-567812345678" + + # Results in process titles like: + # "airflow worker -- LocalExecutor: 4d828a62-a417-4936-a7a6-2b3fabacecab" + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement display_name") + + @property + @abstractmethod + def success_state(self) -> Any: + """ + Return the state value representing successful completion of this workload type. + + Must be implemented by subclasses. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement success_state") + + @property + @abstractmethod + def failure_state(self) -> Any: + """ + Return the state value representing failed completion of this workload type. + + Must be implemented by subclasses. 
+ """ + raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state") diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 273c55953675b..b0496d9d05940 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -28,6 +28,7 @@ from pydantic import BaseModel, Field, field_validator from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.utils.state import CallbackState if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -75,6 +76,24 @@ class ExecuteCallback(BaseDagBundleWorkload): type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") + @property + def key(self) -> CallbackKey: + """Return the callback key for this workload.""" + return self.callback.key + + @property + def display_name(self) -> str: + """Return the callback ID as a display name.""" + return str(self.callback.id) + + @property + def success_state(self) -> CallbackState: + return CallbackState.SUCCESS + + @property + def failure_state(self) -> CallbackState: + return CallbackState.FAILED + @classmethod def make( cls, diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index 4ca8c310fb5c2..3a95aab831061 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -25,6 +25,7 @@ from pydantic import BaseModel, Field from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator @@ -73,6 +74,24 @@ class ExecuteTask(BaseDagBundleWorkload): type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + @property + def key(self) -> TaskInstanceKey: + 
"""Return the TaskInstanceKey for this workload.""" + return self.ti.key + + @property + def display_name(self) -> str: + """Return the task instance ID as a display name.""" + return str(self.ti.id) + + @property + def success_state(self) -> TaskInstanceState: + return TaskInstanceState.SUCCESS + + @property + def failure_state(self) -> TaskInstanceState: + return TaskInstanceState.FAILED + @classmethod def make( cls, diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index fa0f311d018fe..dd90f786cb4c0 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -36,10 +36,11 @@ from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor from airflow.executors.workloads.base import BundleInfo -from airflow.executors.workloads.callback import CallbackDTO, execute_callback_workload +from airflow.executors.workloads.callback import CallbackDTO from airflow.models.callback import CallbackFetchMethod from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator +from airflow.sdk.execution_time.callback_supervisor import execute_callback from airflow.serialization.definitions.baseoperator import SerializedBaseOperator from airflow.utils.state import State, TaskInstanceState @@ -662,63 +663,34 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): class TestExecuteCallbackWorkload: def test_execute_function_callback_success(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "builtins.dict", - "kwargs": {"a": 1, "b": 2, "c": 3}, - }, - ) log = structlog.get_logger() - success, error = execute_callback_workload(callback_data, log) + success, error = execute_callback("builtins.dict", 
{"a": 1, "b": 2, "c": 3}, log) assert success is True assert error is None def test_execute_callback_missing_path(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={"kwargs": {}}, # Missing 'path' - ) log = structlog.get_logger() - success, error = execute_callback_workload(callback_data, log) + success, error = execute_callback("", {}, log) assert success is False assert "Callback path not found" in error def test_execute_callback_import_error(self): - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "nonexistent.module.function", - "kwargs": {}, - }, - ) log = structlog.get_logger() - success, error = execute_callback_workload(callback_data, log) + success, error = execute_callback("nonexistent.module.function", {}, log) assert success is False assert "ModuleNotFoundError" in error def test_execute_callback_execution_error(self): # Use a function that will raise an error; len() requires an argument - callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", - fetch_method=CallbackFetchMethod.IMPORT_PATH, - data={ - "path": "builtins.len", - "kwargs": {}, - }, - ) log = structlog.get_logger() - success, error = execute_callback_workload(callback_data, log) + success, error = execute_callback("builtins.len", {}, log) assert success is False assert "TypeError" in error diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 59afffe6833fe..f20d111b1ac0f 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -28,7 +28,8 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads -from airflow.executors.local_executor import LocalExecutor, _execute_work +from airflow.executors.base_executor 
import ExecutorConf +from airflow.executors.local_executor import LocalExecutor, _execute_workload from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.executors.workloads.task import TaskInstanceDTO @@ -54,6 +55,17 @@ ) +def _make_mock_task_workload(): + """Create a MagicMock that passes isinstance checks for ExecuteTask and has required attributes.""" + task_workload = mock.MagicMock(spec=workloads.ExecuteTask) + task_workload.ti = mock.MagicMock() + task_workload.dag_rel_path = "some/path" + task_workload.bundle_info = mock.MagicMock() + task_workload.token = "test_token" + task_workload.log_path = None + return task_workload + + class TestLocalExecutor: """ When the executor is started, end() must be called before the test finishes. @@ -268,7 +280,7 @@ def test_execution_api_server_url_config(self, mock_supervise, conf_values, expe with conf_vars(conf_values): team_conf = ExecutorConf(team_name=None) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf) + _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) mock_supervise.assert_called_with( ti=mock.ANY, @@ -303,7 +315,7 @@ def test_team_and_global_config_isolation(self, mock_supervise): with conf_vars(config_overrides): # Test team-specific config team_conf = ExecutorConf(team_name=team_name) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf) + _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) # Verify team-specific server URL was used assert mock_supervise.call_count == 1 @@ -314,7 +326,7 @@ def test_team_and_global_config_isolation(self, mock_supervise): # Test global config (no team) global_conf = ExecutorConf(team_name=None) - _execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=global_conf) + _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=global_conf) # Verify 
default server URL was used assert mock_supervise.call_count == 1 @@ -377,18 +389,18 @@ def test_global_executor_without_team_name(self): class TestLocalExecutorCallbackSupport: + CALLBACK_UUID = "12345678-1234-5678-1234-567812345678" + def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() assert executor.supports_callbacks is True - @skip_non_fork_mp_start - @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") - def test_process_callback_workload(self, mock_execute_callback): - mock_execute_callback.return_value = (True, None) - + @skip_spawn_mp_start + def test_process_callback_workload_queue_management(self): + """Test that _process_workloads correctly removes callbacks from queued_callbacks.""" executor = LocalExecutor(parallelism=1) callback_data = CallbackDTO( - id="12345678-1234-5678-1234-567812345678", + id=self.CALLBACK_UUID, fetch_method=CallbackFetchMethod.IMPORT_PATH, data={"path": "test.func", "kwargs": {}}, ) @@ -411,3 +423,49 @@ def test_process_callback_workload(self, mock_execute_callback): finally: executor.end() + + @mock.patch("airflow.sdk.execution_time.callback_supervisor.supervise_callback", return_value=0) + def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback): + + callback_data = CallbackDTO( + id=self.CALLBACK_UUID, + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {"arg1": "val1"}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + _execute_workload(log=mock.ANY, workload=callback_workload, team_conf=ExecutorConf(team_name=None)) + + mock_supervise_callback.assert_called_once_with( + id=self.CALLBACK_UUID, + callback_path="test.module.my_callback", + callback_kwargs={"arg1": "val1"}, + log_path="test.log", + ) + + 
@mock.patch("airflow.sdk.execution_time.callback_supervisor.supervise_callback", return_value=1) + def test_execute_workload_raises_on_callback_failure(self, mock_supervise_callback): + + callback_data = CallbackDTO( + id=self.CALLBACK_UUID, + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {}}, + ) + callback_workload = workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + with pytest.raises(RuntimeError, match="Callback subprocess exited with code 1"): + _execute_workload( + log=mock.ANY, workload=callback_workload, team_conf=ExecutorConf(team_name=None) + ) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 768f2d9dbf646..639285dc9b064 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -54,7 +54,7 @@ from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager # type:ignore[no-redef] if AIRFLOW_V_3_2_PLUS: - from airflow.executors.workloads.callback import execute_callback_workload + from airflow.sdk.execution_time.callback_supervisor import supervise_callback log = logging.getLogger(__name__) @@ -219,9 +219,14 @@ def execute_workload(input: str) -> None: log_path=workload.log_path, ) elif isinstance(workload, workloads.ExecuteCallback): - success, error_msg = execute_callback_workload(workload.callback, log) - if not success: - raise RuntimeError(error_msg or "Callback execution failed") + exit_code = supervise_callback( + id=workload.callback.id, + callback_path=workload.callback.data.get("path", ""), + callback_kwargs=workload.callback.data.get("kwargs", {}), + log_path=workload.log_path, + ) + 
if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") else: raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index f0cffa96c05ea..4f7628151eb02 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -212,7 +212,7 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) - try: supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - # Same like in airflow/executors/local_executor.py:_execute_work() + # Same like in airflow/executors/local_executor.py:_execute_workload() ti=ti, # type: ignore[arg-type] dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py new file mode 100644 index 0000000000000..a7b080c301fef --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Supervised execution of callback workloads.""" + +from __future__ import annotations + +import os +import time +from importlib import import_module +from typing import TYPE_CHECKING, BinaryIO, ClassVar + +import attrs +import structlog +from pydantic import TypeAdapter + +from airflow.sdk.execution_time.supervisor import WatchedSubprocess + +if TYPE_CHECKING: + from collections.abc import Callable + + from structlog.typing import FilteringBoundLogger + from typing_extensions import Self + +__all__ = ["CallbackSubprocess", "execute_callback", "supervise_callback"] + +log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") + + +def execute_callback( + callback_path: str, + callback_kwargs: dict, + log, +) -> tuple[bool, str | None]: + """ + Execute a callback function by importing and calling it, returning the success state. + + Supports two patterns: + 1. Functions - called directly with kwargs + 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context + + Example: + # Function callback + execute_callback("my_module.alert_func", {"msg": "Alert!", "context": {...}}, log) + + # Notifier callback + execute_callback("airflow.providers.slack...SlackWebhookNotifier", {"text": "Alert!"}, log) + + :param callback_path: Dot-separated import path to the callback function or class. + :param callback_kwargs: Keyword arguments to pass to the callback. + :param log: Logger instance for recording execution. + :return: Tuple of (success: bool, error_message: str | None) + """ + if not callback_path: + return False, "Callback path not found." 
+ + try: + # Import the callback callable + # Expected format: "module.path.to.function_or_class" + module_path, function_name = callback_path.rsplit(".", 1) + module = import_module(module_path) + callback_callable = getattr(module, function_name) + + log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) + + # If the callback is a callable, call it. If it is a class, instantiate it. + result = callback_callable(**callback_kwargs) + + # If the callback is a class then it is now instantiated and callable, call it. + if callable(result): + context = callback_kwargs.get("context", {}) + log.debug("Calling result with context for %s", callback_path) + result = result(context) + + log.info("Callback %s executed successfully.", callback_path) + return True, None + + except Exception as e: + error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" + log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) + return False, error_msg + + +def _callback_subprocess_main(): + """ + Entry point for the callback subprocess, runs after fork. + + Reads the callback path and kwargs from environment variables, + executes the callback, and exits with an appropriate code. + """ + import json + import sys + + log = structlog.get_logger(logger_name="callback_runner") + + callback_path = os.environ.get("_AIRFLOW_CALLBACK_PATH", "") + callback_kwargs_json = os.environ.get("_AIRFLOW_CALLBACK_KWARGS", "{}") + + if not callback_path: + print("No callback path found in environment", file=sys.stderr) + sys.exit(1) + + try: + callback_kwargs = json.loads(callback_kwargs_json) + except Exception: + log.exception("Failed to deserialize callback kwargs") + sys.exit(1) + + success, error_msg = execute_callback(callback_path, callback_kwargs, log) + if not success: + log.error("Callback failed", error=error_msg) + sys.exit(1) + + +# An empty message set — callbacks don't send requests back to the supervisor (yet). 
+_EmptyMessage: TypeAdapter[None] = TypeAdapter(None) + + +@attrs.define(kw_only=True) +class CallbackSubprocess(WatchedSubprocess): + """ + Supervised subprocess for executing callbacks. + + Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling + while keeping a simple lifecycle: start, run callback, exit. + """ + + decoder: ClassVar[TypeAdapter] = _EmptyMessage + + @classmethod + def start( # type: ignore[override] + cls, + *, + id: str, + callback_path: str, + callback_kwargs: dict, + target: Callable[[], None] = _callback_subprocess_main, + logger: FilteringBoundLogger | None = None, + **kwargs, + ) -> Self: + """Fork and start a new subprocess to execute the given callback.""" + import json + from datetime import date, datetime + from uuid import UUID + + class _ExtendedEncoder(json.JSONEncoder): + """Handle types that stdlib json can't serialize (UUID, datetime, etc.).""" + + def default(self, o): + if isinstance(o, UUID): + return str(o) + if isinstance(o, datetime): + return o.isoformat() + if isinstance(o, date): + return o.isoformat() + if hasattr(o, "__str__"): + return str(o) + return super().default(o) + + # Pass the callback data to the child process via environment variables. + # These are set before fork so the child inherits them, and cleaned up in the parent after. + os.environ["_AIRFLOW_CALLBACK_PATH"] = callback_path + os.environ["_AIRFLOW_CALLBACK_KWARGS"] = json.dumps(callback_kwargs, cls=_ExtendedEncoder) + try: + proc: Self = super().start( + id=id, + target=target, + logger=logger, + **kwargs, + ) + finally: + # Clean up the env vars in the parent process + os.environ.pop("_AIRFLOW_CALLBACK_PATH", None) + os.environ.pop("_AIRFLOW_CALLBACK_KWARGS", None) + return proc + + def wait(self) -> int: + """ + Wait for the callback subprocess to complete. + + A simplified monitor loop compared to ActivitySubprocess — no heartbeating, + no task API state management. Just monitors process output and waits for exit. 
+ """ + if self._exit_code is not None: + return self._exit_code + + try: + while self._exit_code is None or self._open_sockets: + self._service_subprocess(max_wait_time=5.0) + finally: + self.selector.close() + + self._exit_code = self._exit_code if self._exit_code is not None else 1 + return self._exit_code + + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: + """Handle incoming requests from the callback subprocess (currently none expected).""" + log.warning("Unexpected request from callback subprocess", msg=msg) + + +def _configure_logging(log_path: str) -> tuple[FilteringBoundLogger, BinaryIO]: + """Configure file-based logging for the callback subprocess.""" + from airflow.sdk.log import init_log_file, logging_processors + + log_file = init_log_file(log_path) + log_file_descriptor: BinaryIO = log_file.open("ab") + underlying_logger = structlog.BytesLogger(log_file_descriptor) + processors = logging_processors(json_output=True) + logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="callback").bind() + + return logger, log_file_descriptor + + +def supervise_callback( + *, + id: str, + callback_path: str, + callback_kwargs: dict, + log_path: str | None = None, +) -> int: + """ + Run a single callback execution to completion in a supervised subprocess. + + :param id: Unique identifier for this callback execution. + :param callback_path: Dot-separated import path to the callback function or class. + :param callback_kwargs: Keyword arguments to pass to the callback. + :param log_path: Path to write logs, if required. + :return: Exit code of the subprocess (0 = success). 
+ """ + start = time.monotonic() + + logger: FilteringBoundLogger | None = None + log_file_descriptor: BinaryIO | None = None + if log_path: + logger, log_file_descriptor = _configure_logging(log_path) + + try: + process = CallbackSubprocess.start( + id=id, + callback_path=callback_path, + callback_kwargs=callback_kwargs, + logger=logger, + ) + + exit_code = process.wait() + end = time.monotonic() + log.info( + "Workload finished", + workload_type="ExecutorCallback", + workload_id=id, + exit_code=exit_code, + duration=end - start, + ) + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 1dfefee54047c..222bba34adf38 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -2107,8 +2107,9 @@ def supervise( exit_code = process.wait() end = time.monotonic() log.info( - "Task finished", - task_instance_id=str(ti.id), + "Workload finished", + workload_type="ExecuteTask", + workload_id=str(ti.id), exit_code=exit_code, duration=end - start, final_state=process.final_state, diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py new file mode 100644 index 0000000000000..65457dbfaa28b --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for the callback supervisor module.""" + +from __future__ import annotations + +import pytest +import structlog + +from airflow.sdk.execution_time.callback_supervisor import execute_callback + + +def callback_no_args(): + """A simple callback that takes no arguments.""" + return "ok" + + +def callback_with_kwargs(arg1, arg2): + """A callback that accepts keyword arguments.""" + return f"{arg1}-{arg2}" + + +def callback_that_raises(): + """A callback that always raises.""" + raise ValueError("something went wrong") + + +class CallableClass: + """A class that returns a callable instance (like BaseNotifier).""" + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, context): + return "notified" + + +@pytest.fixture +def log(): + return structlog.get_logger() + + +class TestExecuteCallback: + def test_successful_callback_no_args(self, log): + success, error = execute_callback(f"{__name__}.callback_no_args", {}, log) + + assert success is True + assert error is None + + def test_successful_callback_with_kwargs(self, log): + success, error = execute_callback( + f"{__name__}.callback_with_kwargs", {"arg1": "hello", "arg2": "world"}, log + ) + + assert success is True + assert error is None + + def test_empty_path_returns_failure(self, log): + success, error = execute_callback("", {}, log) + + assert success is False + assert "Callback path not found" in error + + def test_import_error_returns_failure(self, log): + success, error = execute_callback("nonexistent.module.function", {}, log) + + assert success is False + assert 
"ModuleNotFoundError" in error + + def test_execution_error_returns_failure(self, log): + success, error = execute_callback(f"{__name__}.callback_that_raises", {}, log) + + assert success is False + assert "ValueError" in error + + def test_callable_class_pattern(self, log): + """Test the class-that-returns-callable pattern (like BaseNotifier).""" + success, error = execute_callback(f"{__name__}.CallableClass", {"msg": "alert"}, log) + + assert success is True + assert error is None + + def test_attribute_error_for_nonexistent_function(self, log): + success, error = execute_callback(f"{__name__}.nonexistent_function_xyz", {}, log) + + assert success is False + assert "AttributeError" in error diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b486ce7776611..bcea78269f3fb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -723,12 +723,13 @@ def mock_monotonic(): "exit_code": 0, "duration": 0.0, "final_state": "deferred", - "event": "Task finished", + "event": "Workload finished", + "workload_type": "ExecuteTask", + "workload_id": str(ti.id), "timestamp": mocker.ANY, "level": "info", "logger": "supervisor", "loc": mocker.ANY, - "task_instance_id": str(ti.id), } in captured_logs @pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"]) From 0f671a1157eb8213d9d3b2cff482cd08089d46c1 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 14:13:11 -0800 Subject: [PATCH 02/19] merge conflict whitespace oops --- airflow-core/src/airflow/executors/workloads/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py index bc555326b5edf..e0af7df2922eb 100644 --- a/airflow-core/src/airflow/executors/workloads/__init__.py +++ 
b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -50,4 +50,4 @@ "ExecutorWorkload", "TaskInstance", "TaskInstanceDTO", -] \ No newline at end of file +] From 64fa0f5f5527d4ace647a9f860b7fac243ba7c9b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 16:34:12 -0800 Subject: [PATCH 03/19] moved runtime error into supervise_callback() --- airflow-core/src/airflow/executors/local_executor.py | 4 +--- airflow-core/tests/unit/executors/test_local_executor.py | 6 ++++-- .../providers/celery/executors/celery_executor_utils.py | 4 +--- .../src/airflow/sdk/execution_time/callback_supervisor.py | 2 ++ 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 5da04009df104..dd27580b0df78 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -142,14 +142,12 @@ def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> Non elif isinstance(workload, workloads.ExecuteCallback): from airflow.sdk.execution_time.callback_supervisor import supervise_callback - exit_code = supervise_callback( + supervise_callback( id=workload.callback.id, callback_path=workload.callback.data.get("path", ""), callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, ) - if exit_code != 0: - raise RuntimeError(f"Callback subprocess exited with code {exit_code}") else: raise ValueError(f"LocalExecutor does not know how to execute {type(workload).__name__!r}") diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index f20d111b1ac0f..acf700ba5d217 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -449,9 +449,11 @@ def test_execute_workload_calls_supervise_callback(self, 
mock_supervise_callback log_path="test.log", ) - @mock.patch("airflow.sdk.execution_time.callback_supervisor.supervise_callback", return_value=1) + @mock.patch( + "airflow.sdk.execution_time.callback_supervisor.supervise_callback", + side_effect=RuntimeError("Callback subprocess exited with code 1"), + ) def test_execute_workload_raises_on_callback_failure(self, mock_supervise_callback): - callback_data = CallbackDTO( id=self.CALLBACK_UUID, fetch_method=CallbackFetchMethod.IMPORT_PATH, diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 639285dc9b064..82fcf46e152a1 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -219,14 +219,12 @@ def execute_workload(input: str) -> None: log_path=workload.log_path, ) elif isinstance(workload, workloads.ExecuteCallback): - exit_code = supervise_callback( + supervise_callback( id=workload.callback.id, callback_path=workload.callback.data.get("path", ""), callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, ) - if exit_code != 0: - raise RuntimeError(f"Callback subprocess exited with code {exit_code}") else: raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index a7b080c301fef..9142d4bb36b2c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -264,6 +264,8 @@ def supervise_callback( exit_code=exit_code, duration=end - start, ) + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") return exit_code finally: if log_path and 
log_file_descriptor: From 3b82e1ef608ba3dfe385f093c352b148e91881ac Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 16:41:33 -0800 Subject: [PATCH 04/19] add some context to the comment about not supporting message types --- .../src/airflow/sdk/execution_time/callback_supervisor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 9142d4bb36b2c..95c4c38805a47 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -125,7 +125,11 @@ def _callback_subprocess_main(): sys.exit(1) -# An empty message set — callbacks don't send requests back to the supervisor (yet). +# An empty message set; the callback subprocess doesn't currently communicate back to the +# supervisor. This means callback code cannot access runtime services like Connection.get() +# or Variable.get() which require the supervisor to pass requests to the API server. +# To enable this, add the needed message types here and implement _handle_request accordingly. +# See ActivitySubprocess.decoder in supervisor.py for the full task message set and examples.
_EmptyMessage: TypeAdapter[None] = TypeAdapter(None) From 2596cb9a01e2524ffb69ccdb031c456938678495 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 16:52:08 -0800 Subject: [PATCH 05/19] Copilot review fixes --- airflow-core/tests/unit/executors/test_local_executor.py | 4 ++-- .../src/airflow/sdk/execution_time/callback_supervisor.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index acf700ba5d217..ae801124e6f6a 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -58,9 +58,9 @@ def _make_mock_task_workload(): """Create a MagicMock that passes isinstance checks for ExecuteTask and has required attributes.""" task_workload = mock.MagicMock(spec=workloads.ExecuteTask) - task_workload.ti = mock.MagicMock() + task_workload.ti = mock.MagicMock(spec=TaskInstanceDTO) task_workload.dag_rel_path = "some/path" - task_workload.bundle_info = mock.MagicMock() + task_workload.bundle_info = mock.MagicMock(spec=BundleInfo) task_workload.token = "test_token" task_workload.log_path = None return task_workload diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 95c4c38805a47..643359af27b2c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -246,10 +246,14 @@ def supervise_callback( """ start = time.monotonic() - logger: FilteringBoundLogger | None = None + logger: FilteringBoundLogger log_file_descriptor: BinaryIO | None = None if log_path: logger, log_file_descriptor = _configure_logging(log_path) + else: + # When no log file is requested, still use a callback-specific logger + # so logs are clearly separated from task logs. 
+ logger = structlog.get_logger(logger_name="callback").bind() try: process = CallbackSubprocess.start( From 170348f04d9046a199ed591a75ae67d5e6a9f1ea Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 2 Mar 2026 17:26:02 -0800 Subject: [PATCH 06/19] Improve callback display name --- airflow-core/src/airflow/executors/workloads/callback.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index b0496d9d05940..6f25ebb812cab 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -83,7 +83,11 @@ def key(self) -> CallbackKey: @property def display_name(self) -> str: - """Return the callback ID as a display name.""" + """Return a human-readable name for logging and process titles.""" + if path := self.callback.data.get("path", ""): + # Use just the function/class name for brevity in process titles. + # The full path and UUID are available in log messages if needed. 
+ return path.rsplit(".", 1)[-1] return str(self.callback.id) @property From 9965f46548967ec69a690eb7cb5a32a415f5d45b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 3 Mar 2026 13:35:52 -0800 Subject: [PATCH 07/19] remove execute_callback from export --- task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 643359af27b2c..c47026ad9b84f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -35,7 +35,7 @@ from structlog.typing import FilteringBoundLogger from typing_extensions import Self -__all__ = ["CallbackSubprocess", "execute_callback", "supervise_callback"] +__all__ = ["CallbackSubprocess", "supervise_callback"] log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") From 428549a9bd0e31bb3f026fd4b455408ceb4f18f1 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 16 Mar 2026 15:10:41 -0700 Subject: [PATCH 08/19] Shivam fixes --- .../src/airflow/executors/local_executor.py | 2 + .../src/airflow/executors/workloads/base.py | 11 ++- .../airflow/executors/workloads/callback.py | 64 ---------------- .../unit/executors/test_local_executor.py | 2 +- .../celery/executors/celery_executor_utils.py | 10 ++- .../sdk/execution_time/callback_supervisor.py | 73 +++++++++++++++++-- 6 files changed, 82 insertions(+), 80 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index dd27580b0df78..ca76c9675e5f2 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -147,6 +147,7 @@ def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> Non 
callback_path=workload.callback.data.get("path", ""), callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, + bundle_info=workload.bundle_info, ) else: raise ValueError(f"LocalExecutor does not know how to execute {type(workload).__name__!r}") @@ -197,6 +198,7 @@ def start(self) -> None: # Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour # (it looks like an int to python) + self._unread_messages = multiprocessing.Value(ctypes.c_uint) if self.is_mp_using_fork: diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index bd122aa5ac5d6..8eb14097cda52 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -20,12 +20,14 @@ import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from collections.abc import Hashable +from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field if TYPE_CHECKING: from airflow.api_fastapi.auth.tokens import JWTGenerator + from airflow.executors.workloads.types import WorkloadState class BaseWorkload: @@ -86,7 +88,7 @@ class BaseDagBundleWorkload(BaseWorkloadSchema, ABC): @property @abstractmethod - def key(self) -> Any: + def key(self) -> Hashable: """ Return the unique key identifying this workload instance. @@ -122,7 +124,7 @@ def display_name(self) -> str: @property @abstractmethod - def success_state(self) -> Any: + def success_state(self) -> WorkloadState: """ Return the state value representing successful completion of this workload type. @@ -132,10 +134,11 @@ def success_state(self) -> Any: @property @abstractmethod - def failure_state(self) -> Any: + def failure_state(self) -> WorkloadState: """ Return the state value representing failed completion of this workload type. Must be implemented by subclasses. 
""" raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state") + diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 6f25ebb812cab..2d66e570b7a9b 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -19,7 +19,6 @@ from __future__ import annotations from enum import Enum -from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Literal from uuid import UUID @@ -122,66 +121,3 @@ def make( log_path=fname, bundle_info=bundle_info, ) - - -def execute_callback_workload( - callback: CallbackDTO, - log, -) -> tuple[bool, str | None]: - """ - Execute a callback function by importing and calling it, returning the success state. - - Supports two patterns: - 1. Functions - called directly with kwargs - 2. Classes that return callable instances (like BaseNotifier) - instantiated then called with context - - Example: - # Function callback - callback.data = {"path": "my_module.alert_func", "kwargs": {"msg": "Alert!", "context": {...}}} - execute_callback_workload(callback, log) # Calls alert_func(msg="Alert!", context={...}) - - # Notifier callback - callback.data = {"path": "airflow.providers.slack...SlackWebhookNotifier", "kwargs": {"text": "Alert!", "context": {...}}} - execute_callback_workload(callback, log) # SlackWebhookNotifier(text=..., context=...) then calls instance(context) - - :param callback: The Callback schema containing path and kwargs - :param log: Logger instance for recording execution - :return: Tuple of (success: bool, error_message: str | None) - """ - from airflow.models.callback import _accepts_context # circular import - - callback_path = callback.data.get("path") - callback_kwargs = callback.data.get("kwargs", {}) - - if not callback_path: - return False, "Callback path not found in data." 
- - try: - # Import the callback callable - # Expected format: "module.path.to.function_or_class" - module_path, function_name = callback_path.rsplit(".", 1) - module = import_module(module_path) - callback_callable = getattr(module, function_name) - context = callback_kwargs.pop("context", None) - - log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) - - # If the callback is a callable, call it. If it is a class, instantiate it. - # Rather than forcing all custom callbacks to accept context, conditionally provide it only if supported. - if _accepts_context(callback_callable) and context is not None: - result = callback_callable(**callback_kwargs, context=context) - else: - result = callback_callable(**callback_kwargs) - - # If the callback is a class then it is now instantiated and callable, call it. - if callable(result): - log.debug("Calling result with context for %s", callback_path) - result = result(context) - - log.info("Callback %s executed successfully.", callback_path) - return True, None - - except Exception as e: - error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}" - log.exception("Callback %s(%s) execution failed: %s", callback_path, callback_kwargs, error_msg) - return False, error_msg diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index ae801124e6f6a..f75d786ddda48 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -426,7 +426,6 @@ def test_process_callback_workload_queue_management(self): @mock.patch("airflow.sdk.execution_time.callback_supervisor.supervise_callback", return_value=0) def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback): - callback_data = CallbackDTO( id=self.CALLBACK_UUID, fetch_method=CallbackFetchMethod.IMPORT_PATH, @@ -447,6 +446,7 @@ def test_execute_workload_calls_supervise_callback(self, 
mock_supervise_callback callback_path="test.module.my_callback", callback_kwargs={"arg1": "val1"}, log_path="test.log", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), ) @mock.patch( diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 82fcf46e152a1..c91fd45cfaf38 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -53,8 +53,6 @@ except ImportError: from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager # type:ignore[no-redef] -if AIRFLOW_V_3_2_PLUS: - from airflow.sdk.execution_time.callback_supervisor import supervise_callback log = logging.getLogger(__name__) @@ -193,7 +191,6 @@ def execute_workload(input: str) -> None: from airflow.executors import workloads from airflow.providers.common.compat.sdk import conf - from airflow.sdk.execution_time.supervisor import supervise decoder = TypeAdapter[workloads.All](workloads.All) workload = decoder.validate_json(input) @@ -209,6 +206,8 @@ def execute_workload(input: str) -> None: default_execution_api_server = f"{base_url.rstrip('/')}/execution/" if isinstance(workload, workloads.ExecuteTask): + from airflow.sdk.execution_time.supervisor import supervise + supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
ti=workload.ti, # type: ignore[arg-type] @@ -218,12 +217,15 @@ def execute_workload(input: str) -> None: server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), log_path=workload.log_path, ) - elif isinstance(workload, workloads.ExecuteCallback): + elif isinstance(workload, workloads.ExecuteCallback) and AIRFLOW_V_3_2_PLUS: + from airflow.sdk.execution_time.callback_supervisor import supervise_callback + supervise_callback( id=workload.callback.id, callback_path=workload.callback.data.get("path", ""), callback_kwargs=workload.callback.data.get("kwargs", {}), log_path=workload.log_path, + bundle_info=workload.bundle_info, ) else: raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index c47026ad9b84f..9e4f881111204 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -27,7 +27,11 @@ import structlog from pydantic import TypeAdapter -from airflow.sdk.execution_time.supervisor import WatchedSubprocess +from airflow.sdk.execution_time.supervisor import ( + MIN_HEARTBEAT_INTERVAL, + SOCKET_CLEANUP_TIMEOUT, + WatchedSubprocess, +) if TYPE_CHECKING: from collections.abc import Callable @@ -35,6 +39,8 @@ from structlog.typing import FilteringBoundLogger from typing_extensions import Self + from airflow.executors.workloads.base import BundleInfo + __all__ = ["CallbackSubprocess", "supervise_callback"] log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor") @@ -64,6 +70,8 @@ def execute_callback( :param log: Logger instance for recording execution. 
:return: Tuple of (success: bool, error_message: str | None) """ + from airflow.models.callback import _accepts_context # lazy import to avoid circular deps + if not callback_path: return False, "Callback path not found." @@ -81,7 +89,7 @@ def execute_callback( # If the callback is a class then it is now instantiated and callable, call it. if callable(result): - context = callback_kwargs.get("context", {}) + context = callback_kwargs.get("context", {}) if _accepts_context(result) else {} log.debug("Calling result with context for %s", callback_path) result = result(context) @@ -180,7 +188,7 @@ def default(self, o): os.environ["_AIRFLOW_CALLBACK_KWARGS"] = json.dumps(callback_kwargs, cls=_ExtendedEncoder) try: proc: Self = super().start( - id=id, + id=UUID(id) if not isinstance(id, UUID) else id, target=target, logger=logger, **kwargs, @@ -195,21 +203,46 @@ def wait(self) -> int: """ Wait for the callback subprocess to complete. - A simplified monitor loop compared to ActivitySubprocess — no heartbeating, - no task API state management. Just monitors process output and waits for exit. + Mirrors the structure of ActivitySubprocess.wait() but without heartbeating, + task API state management, or log uploading. """ if self._exit_code is not None: return self._exit_code try: - while self._exit_code is None or self._open_sockets: - self._service_subprocess(max_wait_time=5.0) + self._monitor_subprocess() finally: self.selector.close() self._exit_code = self._exit_code if self._exit_code is not None else 1 return self._exit_code + def _monitor_subprocess(self): + """ + Monitor the subprocess until it exits. + + A simplified version of ActivitySubprocess._monitor_subprocess() without heartbeating + or timeout handling, just process output monitoring and stuck-socket cleanup. 
+ """ + while self._exit_code is None or self._open_sockets: + self._service_subprocess(max_wait_time=MIN_HEARTBEAT_INTERVAL) + + # If the process has exited but sockets remain open, apply a timeout + # to prevent hanging indefinitely on stuck sockets. + if self._exit_code is not None and self._open_sockets: + if ( + self._process_exit_monotonic + and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT + ): + log.warning( + "Process exited with open sockets; cleaning up after timeout", + pid=self.pid, + exit_code=self._exit_code, + socket_types=list(self._open_sockets.values()), + timeout_seconds=SOCKET_CLEANUP_TIMEOUT, + ) + self._cleanup_open_sockets() + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: """Handle incoming requests from the callback subprocess (currently none expected).""" log.warning("Unexpected request from callback subprocess", msg=msg) @@ -234,6 +267,7 @@ def supervise_callback( callback_path: str, callback_kwargs: dict, log_path: str | None = None, + bundle_info: BundleInfo | None = None, ) -> int: """ Run a single callback execution to completion in a supervised subprocess. @@ -242,10 +276,35 @@ def supervise_callback( :param callback_path: Dot-separated import path to the callback function or class. :param callback_kwargs: Keyword arguments to pass to the callback. :param log_path: Path to write logs, if required. + :param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable. :return: Exit code of the subprocess (0 = success). """ + import sys + start = time.monotonic() + # If bundle info is provided, initialize the bundle and ensure its path is importable. + # This is needed for user-defined callbacks that live inside a DAG bundle rather than + # in an installed package or the plugins directory. 
+ if bundle_info and bundle_info.name: + try: + from airflow.dag_processing.bundles.manager import DagBundlesManager + + bundle = DagBundlesManager().get_bundle( + name=bundle_info.name, + version=bundle_info.version, + ) + bundle.initialize() + if (bundle_path := str(bundle.path)) not in sys.path: + sys.path.append(bundle_path) + log.debug("Added bundle path to sys.path", bundle_name=bundle_info.name, path=bundle_path) + except Exception: + log.warning( + "Failed to initialize DAG bundle for callback", + bundle_name=bundle_info.name, + exc_info=True, + ) + logger: FilteringBoundLogger log_file_descriptor: BinaryIO | None = None if log_path: From b5afbe4a3e902d65f1b0034043afe2ee0dd56890 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 16 Mar 2026 21:31:02 -0700 Subject: [PATCH 09/19] static checks --- airflow-core/src/airflow/executors/workloads/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 8eb14097cda52..8fff22f055f00 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -141,4 +141,3 @@ def failure_state(self) -> WorkloadState: Must be implemented by subclasses. 
""" raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state") - From 43d814ac833e3e80da05a8efa65d3727b02d486d Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 17 Mar 2026 14:41:44 -0700 Subject: [PATCH 10/19] fix Core/SDK barrier issues --- airflow-core/src/airflow/models/callback.py | 14 ++------------ airflow-core/src/airflow/triggers/callback.py | 6 +++--- airflow-core/tests/unit/models/test_callback.py | 14 +++++++------- .../src/airflow_shared/module_loading/__init__.py | 11 +++++++++++ task-sdk/.pre-commit-config.yaml | 1 + .../sdk/execution_time/callback_supervisor.py | 15 ++++++++++----- 6 files changed, 34 insertions(+), 27 deletions(-) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index e08d58fa4da5c..c43c1fb509bea 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import inspect -from collections.abc import Callable from datetime import datetime from enum import Enum from importlib import import_module @@ -29,6 +27,8 @@ from sqlalchemy import ForeignKey, Integer, String, Text, Uuid from sqlalchemy.orm import Mapped, mapped_column, relationship +# Re-exporting as _accepts_context for backward compatibility +from airflow._shared.module_loading import accepts_context as _accepts_context # noqa: F401 from airflow._shared.observability.metrics.stats import Stats from airflow._shared.timezones import timezone from airflow.executors.workloads import BaseWorkload @@ -53,16 +53,6 @@ TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) -def _accepts_context(callback: Callable) -> bool: - """Check if callback accepts a 'context' parameter or **kwargs.""" - try: - sig = inspect.signature(callback) - except (ValueError, TypeError): - return True - params = sig.parameters - return "context" in params or any(p.kind == 
inspect.Parameter.VAR_KEYWORD for p in params.values()) - - class CallbackType(str, Enum): """ Types of Callbacks. diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index 9c2470c77eae6..b27920fed8614 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -22,8 +22,8 @@ from collections.abc import AsyncIterator from typing import Any -from airflow._shared.module_loading import import_string, qualname -from airflow.models.callback import CallbackState, _accepts_context +from airflow._shared.module_loading import accepts_context, import_string, qualname +from airflow.models.callback import CallbackState from airflow.triggers.base import BaseTrigger, TriggerEvent log = logging.getLogger(__name__) @@ -55,7 +55,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` context = self.callback_kwargs.pop("context", None) - if _accepts_context(callback) and context is not None: + if accepts_context(callback) and context is not None: result = await callback(**self.callback_kwargs, context=context) else: result = await callback(**self.callback_kwargs) diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 5d4940b77adf8..bb3102decf8a8 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -21,6 +21,7 @@ import pytest from sqlalchemy import select +from airflow._shared.module_loading import accepts_context from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.models import Trigger from airflow.models.callback import ( @@ -30,7 +31,6 @@ DagProcessorCallback, ExecutorCallback, TriggererCallback, - _accepts_context, ) from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from 
airflow.triggers.base import TriggerEvent @@ -239,26 +239,26 @@ def test_true_when_var_keyword_present(self): def func_with_var_keyword(**kwargs): pass - assert _accepts_context(func_with_var_keyword) is True + assert accepts_context(func_with_var_keyword) is True def test_true_when_context_param_present(self): def func_with_context(context, alert_type): pass - assert _accepts_context(func_with_context) is True + assert accepts_context(func_with_context) is True def test_false_when_no_context_or_var_keyword(self): def func_without_context(a, b): pass - assert _accepts_context(func_without_context) is False + assert accepts_context(func_without_context) is False def test_false_when_no_params(self): def func_no_params(): pass - assert _accepts_context(func_no_params) is False + assert accepts_context(func_no_params) is False def test_true_for_uninspectable_callable(self): - with patch("airflow.models.callback.inspect.signature", side_effect=ValueError): - assert _accepts_context(lambda: None) is True + with patch("airflow._shared.module_loading.inspect.signature", side_effect=ValueError): + assert accepts_context(lambda: None) is True diff --git a/shared/module_loading/src/airflow_shared/module_loading/__init__.py b/shared/module_loading/src/airflow_shared/module_loading/__init__.py index 5b2b7c9496ec1..2406db56a24f5 100644 --- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py +++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py @@ -18,6 +18,7 @@ from __future__ import annotations import functools +import inspect import logging import pkgutil import sys @@ -43,6 +44,16 @@ from types import ModuleType +def accepts_context(callback: Callable) -> bool: + """Check if callback accepts a 'context' parameter or **kwargs.""" + try: + sig = inspect.signature(callback) + except (ValueError, TypeError): + return True + params = sig.parameters + return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in 
params.values()) + + def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml index 315e0ea8a133f..bbcc6cefaca3c 100644 --- a/task-sdk/.pre-commit-config.yaml +++ b/task-sdk/.pre-commit-config.yaml @@ -45,6 +45,7 @@ repos: ^src/airflow/sdk/definitions/_internal/types\.py$| ^src/airflow/sdk/execution_time/execute_workload\.py$| ^src/airflow/sdk/execution_time/secrets_masker\.py$| + ^src/airflow/sdk/execution_time/callback_supervisor\.py$| ^src/airflow/sdk/execution_time/supervisor\.py$| ^src/airflow/sdk/execution_time/task_runner\.py$| ^src/airflow/sdk/serde/serializers/kubernetes\.py$| diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 9e4f881111204..6b408cd721a05 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -21,7 +21,7 @@ import os import time from importlib import import_module -from typing import TYPE_CHECKING, BinaryIO, ClassVar +from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol import attrs import structlog @@ -39,7 +39,12 @@ from structlog.typing import FilteringBoundLogger from typing_extensions import Self - from airflow.executors.workloads.base import BundleInfo + # Core (airflow.executors.workloads.base.BundleInfo) and SDK (airflow.sdk.api.datamodels._generated.BundleInfo) + # are structurally identical, but MyPy treats them as different types. This Protocol makes MyPy happy. + class _BundleInfoLike(Protocol): + name: str + version: str | None + __all__ = ["CallbackSubprocess", "supervise_callback"] @@ -70,7 +75,7 @@ def execute_callback( :param log: Logger instance for recording execution. 
:return: Tuple of (success: bool, error_message: str | None) """ - from airflow.models.callback import _accepts_context # lazy import to avoid circular deps + from airflow.sdk._shared.module_loading import accepts_context if not callback_path: return False, "Callback path not found." @@ -89,7 +94,7 @@ def execute_callback( # If the callback is a class then it is now instantiated and callable, call it. if callable(result): - context = callback_kwargs.get("context", {}) if _accepts_context(result) else {} + context = callback_kwargs.get("context", {}) if accepts_context(result) else {} log.debug("Calling result with context for %s", callback_path) result = result(context) @@ -267,7 +272,7 @@ def supervise_callback( callback_path: str, callback_kwargs: dict, log_path: str | None = None, - bundle_info: BundleInfo | None = None, + bundle_info: _BundleInfoLike | None = None, ) -> int: """ Run a single callback execution to completion in a supervised subprocess. From 51e1b6b886b3dd471a0caedeaf0016c01b29a2ce Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 19 Mar 2026 14:45:24 -0700 Subject: [PATCH 11/19] Ash fixes - create supervise_workload and move setproctitle in there as well --- .../src/airflow/executors/local_executor.py | 55 ++++------ .../celery/executors/celery_executor_utils.py | 33 +----- .../sdk/execution_time/execute_workload.py | 26 +---- .../airflow/sdk/execution_time/supervisor.py | 101 +++++++++++++++++- .../execution_time/test_supervisor.py | 8 +- 5 files changed, 127 insertions(+), 96 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index ca76c9675e5f2..f60cc28351a43 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -35,7 +35,6 @@ import structlog -from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor # add logger to parameter of setproctitle to 
support logging @@ -53,6 +52,19 @@ from airflow.executors.workloads.types import WorkloadResultType +def _get_execution_api_server_url(team_conf) -> str: + """ + Resolve the execution API server URL from team-specific configuration. + + :param team_conf: Team-specific executor configuration (ExecutorConf or AirflowConfigParser) + """ + base_url = team_conf.get("api", "base_url", fallback="/") + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + default_execution_api_server = f"{base_url.rstrip('/')}/execution/" + return team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server) + + def _get_executor_process_title_prefix(team_name: str | None) -> str: """ Build the process title prefix for LocalExecutor workers. @@ -117,40 +129,13 @@ def _execute_workload(log: Logger, workload: ExecutorWorkload, team_conf) -> Non :param workload: The workload to execute (ExecuteTask or ExecuteCallback) :param team_conf: Team-specific executor configuration """ - setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}", log) - - if isinstance(workload, workloads.ExecuteTask): - from airflow.sdk.execution_time.supervisor import supervise - - base_url = team_conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - - # This will return the exit code of the task process, but we don't care about that, just if the - # _supervisor_ had an error reporting the state back (which will result in an exception.) - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
- ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) - elif isinstance(workload, workloads.ExecuteCallback): - from airflow.sdk.execution_time.callback_supervisor import supervise_callback - - supervise_callback( - id=workload.callback.id, - callback_path=workload.callback.data.get("path", ""), - callback_kwargs=workload.callback.data.get("kwargs", {}), - log_path=workload.log_path, - bundle_info=workload.bundle_info, - ) - else: - raise ValueError(f"LocalExecutor does not know how to execute {type(workload).__name__!r}") + from airflow.sdk.execution_time.supervisor import supervise_workload + + supervise_workload( + workload, + server=_get_execution_api_server_url(team_conf), + proctitle=f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}", + ) class LocalExecutor(BaseExecutor): diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index c91fd45cfaf38..e729ab5207b0c 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -199,36 +199,9 @@ def execute_workload(input: str) -> None: log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) - base_url = conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - - if isinstance(workload, workloads.ExecuteTask): - from airflow.sdk.execution_time.supervisor import supervise - - supervise( - # This is the "wrong" ti 
type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) - elif isinstance(workload, workloads.ExecuteCallback) and AIRFLOW_V_3_2_PLUS: - from airflow.sdk.execution_time.callback_supervisor import supervise_callback - - supervise_callback( - id=workload.callback.id, - callback_path=workload.callback.data.get("path", ""), - callback_kwargs=workload.callback.data.get("kwargs", {}), - log_path=workload.log_path, - bundle_info=workload.bundle_info, - ) - else: - raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") + from airflow.sdk.execution_time.supervisor import supervise_workload + + supervise_workload(workload) if not AIRFLOW_V_3_0_PLUS: diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py index 410c676eeb913..0c36d131fc796 100644 --- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py +++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py @@ -40,9 +40,7 @@ def execute_workload(workload: ExecuteTask) -> None: - from airflow.executors import workloads - from airflow.sdk.configuration import conf - from airflow.sdk.execution_time.supervisor import supervise + from airflow.sdk.execution_time.supervisor import supervise_workload from airflow.sdk.log import configure_logging from airflow.settings import dispose_orm @@ -50,28 +48,10 @@ def execute_workload(workload: ExecuteTask) -> None: configure_logging(output=sys.stdout.buffer, json_output=True) - if not isinstance(workload, workloads.ExecuteTask): - raise ValueError(f"Executor does not know how to handle {type(workload)}") - log.info("Executing workload", workload=workload) - base_url = conf.get("api", 
"base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - server = conf.get("core", "execution_api_server_url", fallback=default_execution_api_server) - log.info("Connecting to server:", server=server) - - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=server, - log_path=workload.log_path, - sentry_integration=workload.sentry_integration, + supervise_workload( + workload, # Include the output of the task to stdout too, so that in process logs can be read from via the # kubeapi as pod logs. subprocess_logs_to_stdout=True, diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 222bba34adf38..d19d87e45e5d0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -133,13 +133,12 @@ from structlog.typing import FilteringBoundLogger, WrappedLogger from typing_extensions import Self - from airflow.executors.workloads import BundleInfo + from airflow.executors.workloads import BundleInfo, ExecutorWorkload from airflow.sdk.bases.secrets_backend import BaseSecretsBackend from airflow.sdk.definitions.connection import Connection from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI - -__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise"] +__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", "supervise_task", "supervise_workload"] log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor") @@ -2000,7 +1999,7 @@ def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLog 
return logger, log_file_descriptor -def supervise( +def supervise_task( *, ti: TaskInstance, bundle_info: BundleInfo, @@ -2121,3 +2120,97 @@ def supervise( if close_client and client: with suppress(Exception): client.close() + + +def supervise_workload( + workload: ExecutorWorkload, + *, + server: str | None = None, + dry_run: bool = False, + client: Client | None = None, + subprocess_logs_to_stdout: bool = False, + proctitle: str | None = None, +) -> int: + """ + Run any workload type to completion in a supervised subprocess. + + Dispatch to the appropriate supervisor based on workload type. Workload-specific + attributes (log_path, sentry_integration, bundle_info, etc.) are read from the + workload object itself. + + :param workload: The ``ExecutorWorkload`` to execute. + :param server: Base URL of the API server (used by task workloads). + :param dry_run: If True, execute without actual task execution (simulate run). + :param client: Optional preconfigured client for communication with the server. + :param subprocess_logs_to_stdout: Should task logs also be sent to stdout via the main logger. + :param proctitle: Process title to set for this workload. If not provided, defaults to + ``"airflow supervisor: "``. Executors may pass a custom title + that includes executor-specific context (e.g. team name). + :return: Exit code of the process. + """ + # Imports deferred to avoid an SDK/core dependency at module load time. + from airflow.executors.workloads.callback import ExecuteCallback + from airflow.executors.workloads.task import ExecuteTask + + try: + from setproctitle import setproctitle + + setproctitle(proctitle or f"airflow supervisor: {workload.display_name}") + except ImportError: + pass + + # Resolve server URL from config when not explicitly provided. + # For example, team-specific executors may wish to pass their own server URL. 
+ if server is None: + base_url = conf.get("api", "base_url", fallback="/") + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + server = conf.get( + "core", + "execution_api_server_url", + fallback=f"{base_url.rstrip('/')}/execution/", + ) + + if isinstance(workload, ExecuteTask): + return supervise_task( + ti=workload.ti, + bundle_info=workload.bundle_info, + dag_rel_path=workload.dag_rel_path, + token=workload.token, + server=server, + dry_run=dry_run, + log_path=workload.log_path, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + client=client, + sentry_integration=getattr(workload, "sentry_integration", ""), + ) + if isinstance(workload, ExecuteCallback): + from airflow.sdk.execution_time.callback_supervisor import supervise_callback + + return supervise_callback( + id=workload.callback.id, + callback_path=workload.callback.data.get("path", ""), + callback_kwargs=workload.callback.data.get("kwargs", {}), + log_path=workload.log_path, + bundle_info=workload.bundle_info, + ) + raise ValueError(f"Unknown workload type: {type(workload).__name__}") + + +def supervise(**kwargs) -> int: + """ + Call ``supervise_task()`` with a deprecation warning. + + This wrapper exists for backward compatibility with provider packages that may import ``supervise`` directly. + + .. deprecated:: + Use :func:`supervise_task` instead. 
+ """ + import warnings + + warnings.warn( + "supervise() is deprecated, use supervise_task() instead.", + DeprecationWarning, + stacklevel=2, + ) + return supervise_task(**kwargs) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index bcea78269f3fb..e091101f84190 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -143,7 +143,7 @@ _remote_logging_conn, process_log_messages_from_subprocess, set_supervisor_comms, - supervise, + supervise_task, ) from airflow.sdk.execution_time.task_runner import run @@ -230,7 +230,7 @@ def test_supervise( with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): with expectation: - supervise(**kw) + supervise_task(**kw) @pytest.mark.usefixtures("disable_capturing") @@ -624,7 +624,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): - exit_code = supervise( + exit_code = supervise_task( ti=ti, dag_rel_path=dagfile_path, token="", @@ -679,7 +679,7 @@ def mock_monotonic(): patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)), patch("airflow.sdk.execution_time.supervisor.time.monotonic", side_effect=mock_monotonic), ): - exit_code = supervise( + exit_code = supervise_task( ti=ti, dag_rel_path="super_basic_deferred_run.py", token="", From 79f92937dd20fa42036d723edf0cd6e645b639dc Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 19 Mar 2026 17:30:37 -0700 Subject: [PATCH 12/19] merge conflict fixes adn mypy --- airflow-core/tests/unit/executors/test_local_executor.py | 2 +- .../providers/celery/executors/celery_executor_utils.py | 4 ++-- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff 
--git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index f75d786ddda48..da1d65d6595e5 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -395,7 +395,7 @@ def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() assert executor.supports_callbacks is True - @skip_spawn_mp_start + @skip_non_fork_mp_start def test_process_callback_workload_queue_management(self): """Test that _process_workloads correctly removes callbacks from queued_callbacks.""" executor = LocalExecutor(parallelism=1) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index e729ab5207b0c..0b622799a6a21 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -189,10 +189,10 @@ def on_celery_worker_ready(*args, **kwargs): def execute_workload(input: str) -> None: from pydantic import TypeAdapter - from airflow.executors import workloads + from airflow.executors.workloads import ExecutorWorkload from airflow.providers.common.compat.sdk import conf - decoder = TypeAdapter[workloads.All](workloads.All) + decoder = TypeAdapter[ExecutorWorkload](ExecutorWorkload) workload = decoder.validate_json(input) celery_task_id = app.current_task.request.id diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d19d87e45e5d0..0029a7f46d6b1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -2173,7 +2173,9 @@ def supervise_workload( if isinstance(workload, ExecuteTask): return supervise_task( - ti=workload.ti, + # workload.ti is a TaskInstanceDTO 
which duck-types as TaskInstance. + # TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] bundle_info=workload.bundle_info, dag_rel_path=workload.dag_rel_path, token=workload.token, From 937618f23d4fb6c5c067d63d6d138bb0e94e39ab Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 20 Mar 2026 09:15:57 -0700 Subject: [PATCH 13/19] fix unit tests --- .../unit/executors/test_local_executor.py | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index da1d65d6595e5..7fd8410dc9752 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -127,8 +127,8 @@ def test_executor_lazy_worker_spawning(self, mock_freeze, mock_unfreeze): executor.end() @skip_non_fork_mp_start - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_execution(self, mock_supervise): + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_execution(self, mock_supervise_workload): success_tis = [ TaskInstanceDTO( id=uuid7(), @@ -151,14 +151,14 @@ def test_execution(self, mock_supervise): # We just mock both styles here, only one will be hit though has_failed_once = False - def fake_supervise(ti, **kwargs): + def fake_supervise_workload(workload, **kwargs): nonlocal has_failed_once - if ti.id == fail_ti.id and not has_failed_once: + if workload.ti.id == fail_ti.id and not has_failed_once: has_failed_once = True raise RuntimeError("fake failure") return 0 - mock_supervise.side_effect = fake_supervise + mock_supervise_workload.side_effect = fake_supervise_workload executor = LocalExecutor(parallelism=2) @@ -273,8 +273,8 @@ def test_clean_stop_on_signal(self): "relative_base_url", ], ) - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_execution_api_server_url_config(self, 
mock_supervise, conf_values, expected_server): + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_execution_api_server_url_config(self, mock_supervise_workload, conf_values, expected_server): """Test that execution_api_server_url is correctly configured with fallback""" from airflow.executors.base_executor import ExecutorConf @@ -282,17 +282,11 @@ def test_execution_api_server_url_config(self, mock_supervise, conf_values, expe team_conf = ExecutorConf(team_name=None) _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) - mock_supervise.assert_called_with( - ti=mock.ANY, - dag_rel_path=mock.ANY, - bundle_info=mock.ANY, - token=mock.ANY, - server=expected_server, - log_path=mock.ANY, - ) + mock_supervise_workload.assert_called_once() + assert mock_supervise_workload.call_args.kwargs["server"] == expected_server - @mock.patch("airflow.sdk.execution_time.supervisor.supervise") - def test_team_and_global_config_isolation(self, mock_supervise): + @mock.patch("airflow.sdk.execution_time.supervisor.supervise_workload") + def test_team_and_global_config_isolation(self, mock_supervise_workload): """Test that team-specific and global executors use correct configurations side-by-side""" from airflow.executors.base_executor import ExecutorConf @@ -318,20 +312,18 @@ def test_team_and_global_config_isolation(self, mock_supervise): _execute_workload(log=mock.ANY, workload=_make_mock_task_workload(), team_conf=team_conf) # Verify team-specific server URL was used - assert mock_supervise.call_count == 1 - call_kwargs = mock_supervise.call_args[1] - assert call_kwargs["server"] == team_server + assert mock_supervise_workload.call_count == 1 + assert mock_supervise_workload.call_args.kwargs["server"] == team_server - mock_supervise.reset_mock() + mock_supervise_workload.reset_mock() # Test global config (no team) global_conf = ExecutorConf(team_name=None) _execute_workload(log=mock.ANY, 
workload=_make_mock_task_workload(), team_conf=global_conf) # Verify default server URL was used - assert mock_supervise.call_count == 1 - call_kwargs = mock_supervise.call_args[1] - assert call_kwargs["server"] == default_server + assert mock_supervise_workload.call_count == 1 + assert mock_supervise_workload.call_args.kwargs["server"] == default_server def test_multiple_team_executors_isolation(self): """Test that multiple team executors can coexist with isolated resources""" From a2f3252bf1294805e268b64ba6b73af16cf47b79 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 24 Mar 2026 10:53:45 -0700 Subject: [PATCH 14/19] add running state back --- airflow-core/src/airflow/executors/local_executor.py | 3 +++ airflow-core/src/airflow/executors/workloads/base.py | 12 ++++++++++++ .../src/airflow/executors/workloads/callback.py | 4 ++++ .../celery/executors/celery_executor_utils.py | 1 - 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index f60cc28351a43..59ddb54f440f4 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -110,6 +110,9 @@ def _run_worker( with unread_messages: unread_messages.value -= 1 + if workload.running_state is not None: + output.put((workload.key, workload.running_state, None)) + try: _execute_workload(log, workload, team_conf) output.put((workload.key, workload.success_state, None)) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 8fff22f055f00..6404e991e0cad 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -141,3 +141,15 @@ def failure_state(self) -> WorkloadState: Must be implemented by subclasses. 
""" raise NotImplementedError(f"{self.__class__.__name__} must implement failure_state") + + @property + def running_state(self) -> WorkloadState | None: + """ + Return the state value representing that this workload is actively running. + + Called by the executor worker *before* execution begins. Subclasses may override + this to emit an intermediate state transition (e.g. callbacks need + QUEUED → RUNNING → SUCCESS/FAILED). Returns ``None`` by default, meaning + no intermediate state is emitted. + """ + return None diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 2d66e570b7a9b..a78dbab43a594 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -97,6 +97,10 @@ def success_state(self) -> CallbackState: def failure_state(self) -> CallbackState: return CallbackState.FAILED + @property + def running_state(self) -> CallbackState: + return CallbackState.RUNNING + @classmethod def make( cls, diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 0b622799a6a21..d8508791d20cb 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -190,7 +190,6 @@ def execute_workload(input: str) -> None: from pydantic import TypeAdapter from airflow.executors.workloads import ExecutorWorkload - from airflow.providers.common.compat.sdk import conf decoder = TypeAdapter[ExecutorWorkload](ExecutorWorkload) workload = decoder.validate_json(input) From 6df881fdf1612cf8797cbdcbf9c28e690b2d4c2d Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 24 Mar 2026 15:35:25 -0700 Subject: [PATCH 15/19] account for *args in accepts_context --- 
.../src/airflow_shared/module_loading/__init__.py | 6 ++++-- .../airflow/sdk/execution_time/callback_supervisor.py | 10 +++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/shared/module_loading/src/airflow_shared/module_loading/__init__.py b/shared/module_loading/src/airflow_shared/module_loading/__init__.py index 2406db56a24f5..2b6c1b388d37d 100644 --- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py +++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py @@ -45,13 +45,15 @@ def accepts_context(callback: Callable) -> bool: - """Check if callback accepts a 'context' parameter or **kwargs.""" + """Check if callback accepts a 'context' parameter, *args, or **kwargs.""" try: sig = inspect.signature(callback) except (ValueError, TypeError): return True params = sig.parameters - return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + return "context" in params or any( + p.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) for p in params.values() + ) def import_string(dotted_path: str): diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 6b408cd721a05..13f1ac57681ab 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -94,9 +94,13 @@ def execute_callback( # If the callback is a class then it is now instantiated and callable, call it. 
if callable(result): - context = callback_kwargs.get("context", {}) if accepts_context(result) else {} - log.debug("Calling result with context for %s", callback_path) - result = result(context) + if accepts_context(result): + context = callback_kwargs.get("context", {}) + log.debug("Calling result with context for %s", callback_path) + result = result(context) + else: + log.debug("Calling result without context for %s", callback_path) + result = result() log.info("Callback %s executed successfully.", callback_path) return True, None From 21a255f3e58168345c369f576e5a2c48faf27652 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 24 Mar 2026 15:36:14 -0700 Subject: [PATCH 16/19] fix encoder --- .../src/airflow/sdk/execution_time/callback_supervisor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 13f1ac57681ab..e10017ee78078 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -183,12 +183,8 @@ class _ExtendedEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, UUID): return str(o) - if isinstance(o, datetime): + if isinstance(o, (datetime, date)): return o.isoformat() - if isinstance(o, date): - return o.isoformat() - if hasattr(o, "__str__"): - return str(o) return super().default(o) # Pass the callback data to the child process via environment variables. 
From 845db573b7e68ba581cc9217123ae204b6ed9c63 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 25 Mar 2026 17:02:52 -0700 Subject: [PATCH 17/19] fix context, use closure for subprocess, and plumb callback logs through to the scheduler --- .../sdk/execution_time/callback_supervisor.py | 100 +++++------------- 1 file changed, 29 insertions(+), 71 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index e10017ee78078..53f423bde7691 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -18,7 +18,6 @@ from __future__ import annotations -import os import time from importlib import import_module from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol @@ -34,8 +33,6 @@ ) if TYPE_CHECKING: - from collections.abc import Callable - from structlog.typing import FilteringBoundLogger from typing_extensions import Self @@ -89,18 +86,20 @@ def execute_callback( log.debug("Executing callback %s(%s)...", callback_path, callback_kwargs) - # If the callback is a callable, call it. If it is a class, instantiate it. - result = callback_callable(**callback_kwargs) + kwargs_without_context = {k: v for k, v in callback_kwargs.items() if k != "context"} + + # Call the callable with all kwargs if it accepts context, otherwise strip context. + if accepts_context(callback_callable): + result = callback_callable(**callback_kwargs) + else: + result = callback_callable(**kwargs_without_context) - # If the callback is a class then it is now instantiated and callable, call it. + # If the callback was a class then it is now instantiated and callable, call it. 
if callable(result): if accepts_context(result): - context = callback_kwargs.get("context", {}) - log.debug("Calling result with context for %s", callback_path) - result = result(context) + result = result(**callback_kwargs) else: - log.debug("Calling result without context for %s", callback_path) - result = result() + result = result(**kwargs_without_context) log.info("Callback %s executed successfully.", callback_path) return True, None @@ -111,37 +110,6 @@ def execute_callback( return False, error_msg -def _callback_subprocess_main(): - """ - Entry point for the callback subprocess, runs after fork. - - Reads the callback path and kwargs from environment variables, - executes the callback, and exits with an appropriate code. - """ - import json - import sys - - log = structlog.get_logger(logger_name="callback_runner") - - callback_path = os.environ.get("_AIRFLOW_CALLBACK_PATH", "") - callback_kwargs_json = os.environ.get("_AIRFLOW_CALLBACK_KWARGS", "{}") - - if not callback_path: - print("No callback path found in environment", file=sys.stderr) - sys.exit(1) - - try: - callback_kwargs = json.loads(callback_kwargs_json) - except Exception: - log.exception("Failed to deserialize callback kwargs") - sys.exit(1) - - success, error_msg = execute_callback(callback_path, callback_kwargs, log) - if not success: - log.error("Callback failed", error=error_msg) - sys.exit(1) - - # An empty message set; the callback subprocess doesn't currently communicate back to the # supervisor. This means callback code cannot access runtime services like Connection.get() # or Variable.get() which require the supervisor to pass requests to the API server. 
@@ -168,41 +136,30 @@ def start( # type: ignore[override] id: str, callback_path: str, callback_kwargs: dict, - target: Callable[[], None] = _callback_subprocess_main, logger: FilteringBoundLogger | None = None, **kwargs, ) -> Self: """Fork and start a new subprocess to execute the given callback.""" - import json - from datetime import date, datetime from uuid import UUID - class _ExtendedEncoder(json.JSONEncoder): - """Handle types that stdlib json can't serialize (UUID, datetime, etc.).""" - - def default(self, o): - if isinstance(o, UUID): - return str(o) - if isinstance(o, (datetime, date)): - return o.isoformat() - return super().default(o) - - # Pass the callback data to the child process via environment variables. - # These are set before fork so the child inherits them, and cleaned up in the parent after. - os.environ["_AIRFLOW_CALLBACK_PATH"] = callback_path - os.environ["_AIRFLOW_CALLBACK_KWARGS"] = json.dumps(callback_kwargs, cls=_ExtendedEncoder) - try: - proc: Self = super().start( - id=UUID(id) if not isinstance(id, UUID) else id, - target=target, - logger=logger, - **kwargs, - ) - finally: - # Clean up the env vars in the parent process - os.environ.pop("_AIRFLOW_CALLBACK_PATH", None) - os.environ.pop("_AIRFLOW_CALLBACK_KWARGS", None) - return proc + # Use a closure to pass callback data to the child process. Note that this + # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child + # inherits the parent's memory space and the variables are available directly. 
+ def _target(): + import sys + + _log = structlog.get_logger(logger_name="callback_runner") + success, error_msg = execute_callback(callback_path, callback_kwargs, _log) + if not success: + _log.error("Callback failed", error=error_msg) + sys.exit(1) + + return super().start( + id=UUID(id) if not isinstance(id, UUID) else id, + target=_target, + logger=logger, + **kwargs, + ) def wait(self) -> int: """ @@ -325,6 +282,7 @@ def supervise_callback( callback_path=callback_path, callback_kwargs=callback_kwargs, logger=logger, + subprocess_logs_to_stdout=True, ) exit_code = process.wait() From e558b315c1b06956ae4d9d2fb4144da68a43ae8a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 25 Mar 2026 17:08:27 -0700 Subject: [PATCH 18/19] fix dequeue cleanup --- airflow-core/src/airflow/executors/local_executor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 59ddb54f440f4..c81d69089442c 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -305,9 +305,13 @@ def terminate(self): def _process_workloads(self, workload_list): for workload in workload_list: self.activity_queue.put(workload) - # Remove from the appropriate queue using the workload's key. - self.queued_tasks.pop(workload.key, None) - self.queued_callbacks.pop(workload.key, None) + # A valid workload will exist in exactly one of these dicts. + # One pop will succeed, the other will return None gracefully. 
+ removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop( + workload.key, None + ) + if not removed: + raise KeyError(f"Workload {workload.key} was not found in any queue") with self._unread_messages: self._unread_messages.value += len(workload_list) self._check_workers() From b8052b2f2944c6d252fec3ee1470d487b44336e5 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 25 Mar 2026 21:27:03 -0700 Subject: [PATCH 19/19] Kaxil fixes rnd 2 - move bundle loading into start() - add _make_process_nondumpable to supervise_callback - fix context passing to callback_callable - revert adding *args check to accepts_context - parameterize tests - top-leveled some imports --- .../unit/executors/test_base_executor.py | 46 +++---- .../airflow_shared/module_loading/__init__.py | 6 +- .../sdk/execution_time/callback_supervisor.py | 70 ++++++----- .../test_callback_supervisor.py | 112 ++++++++++-------- 4 files changed, 121 insertions(+), 113 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index dd90f786cb4c0..530c401227966 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -662,35 +662,21 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): class TestExecuteCallbackWorkload: - def test_execute_function_callback_success(self): + @pytest.mark.parametrize( + ("path", "kwargs", "expect_success", "error_contains"), + [ + pytest.param("builtins.dict", {"a": 1, "b": 2, "c": 3}, True, None, id="function_success"), + pytest.param("", {}, False, "Callback path not found", id="missing_path"), + pytest.param("nonexistent.module.function", {}, False, "ModuleNotFoundError", id="import_error"), + pytest.param("builtins.len", {}, False, "TypeError", id="execution_error"), + ], + ) + def test_execute_callback(self, path, kwargs, expect_success, error_contains): log = 
structlog.get_logger() + success, error = execute_callback(path, kwargs, log) - success, error = execute_callback("builtins.dict", {"a": 1, "b": 2, "c": 3}, log) - - assert success is True - assert error is None - - def test_execute_callback_missing_path(self): - log = structlog.get_logger() - - success, error = execute_callback("", {}, log) - - assert success is False - assert "Callback path not found" in error - - def test_execute_callback_import_error(self): - log = structlog.get_logger() - - success, error = execute_callback("nonexistent.module.function", {}, log) - - assert success is False - assert "ModuleNotFoundError" in error - - def test_execute_callback_execution_error(self): - # Use a function that will raise an error; len() requires an argument - log = structlog.get_logger() - - success, error = execute_callback("builtins.len", {}, log) - - assert success is False - assert "TypeError" in error + assert success is expect_success + if error_contains: + assert error_contains in error + else: + assert error is None diff --git a/shared/module_loading/src/airflow_shared/module_loading/__init__.py b/shared/module_loading/src/airflow_shared/module_loading/__init__.py index 2b6c1b388d37d..2406db56a24f5 100644 --- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py +++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py @@ -45,15 +45,13 @@ def accepts_context(callback: Callable) -> bool: - """Check if callback accepts a 'context' parameter, *args, or **kwargs.""" + """Check if callback accepts a 'context' parameter or **kwargs.""" try: sig = inspect.signature(callback) except (ValueError, TypeError): return True params = sig.parameters - return "context" in params or any( - p.kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) for p in params.values() - ) + return "context" in params or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) def import_string(dotted_path: str): diff --git 
a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 53f423bde7691..f9747f6e82573 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -18,9 +18,11 @@ from __future__ import annotations +import sys import time from importlib import import_module from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol +from uuid import UUID import attrs import structlog @@ -30,6 +32,7 @@ MIN_HEARTBEAT_INTERVAL, SOCKET_CLEANUP_TIMEOUT, WatchedSubprocess, + _make_process_nondumpable, ) if TYPE_CHECKING: @@ -95,11 +98,16 @@ def execute_callback( result = callback_callable(**kwargs_without_context) # If the callback was a class then it is now instantiated and callable, call it. + # Try keyword args first. If the callable only accepts positional args (like + # BaseNotifier.__call__(self, *args)), fall back to passing context positionally. if callable(result): - if accepts_context(result): - result = result(**callback_kwargs) - else: - result = result(**kwargs_without_context) + try: + if accepts_context(result): + result = result(**callback_kwargs) + else: + result = result(**kwargs_without_context) + except TypeError: + result = result(callback_kwargs.get("context", {})) log.info("Callback %s executed successfully.", callback_path) return True, None @@ -136,19 +144,42 @@ def start( # type: ignore[override] id: str, callback_path: str, callback_kwargs: dict, + bundle_info: _BundleInfoLike | None = None, logger: FilteringBoundLogger | None = None, **kwargs, ) -> Self: """Fork and start a new subprocess to execute the given callback.""" - from uuid import UUID # Use a closure to pass callback data to the child process. Note that this # ONLY works because WatchedSubprocess.start() uses os.fork(), so the child # inherits the parent's memory space and the variables are available directly. 
def _target(): - import sys - _log = structlog.get_logger(logger_name="callback_runner") + + # If bundle info is provided, initialize the bundle and ensure its path is importable. + # This is needed for user-defined callbacks that live inside a DAG bundle rather than + # in an installed package or the plugins directory. + if bundle_info and bundle_info.name: + try: + from airflow.dag_processing.bundles.manager import DagBundlesManager + + bundle = DagBundlesManager().get_bundle( + name=bundle_info.name, + version=bundle_info.version, + ) + bundle.initialize() + if (bundle_path := str(bundle.path)) not in sys.path: + sys.path.append(bundle_path) + log.debug( + "Added bundle path to sys.path", bundle_name=bundle_info.name, path=bundle_path + ) + except Exception: + log.warning( + "Failed to initialize DAG bundle for callback", + bundle_name=bundle_info.name, + exc_info=True, + ) + success, error_msg = execute_callback(callback_path, callback_kwargs, _log) if not success: _log.error("Callback failed", error=error_msg) @@ -241,32 +272,10 @@ def supervise_callback( :param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable. :return: Exit code of the subprocess (0 = success). """ - import sys + _make_process_nondumpable() start = time.monotonic() - # If bundle info is provided, initialize the bundle and ensure its path is importable. - # This is needed for user-defined callbacks that live inside a DAG bundle rather than - # in an installed package or the plugins directory. 
- if bundle_info and bundle_info.name: - try: - from airflow.dag_processing.bundles.manager import DagBundlesManager - - bundle = DagBundlesManager().get_bundle( - name=bundle_info.name, - version=bundle_info.version, - ) - bundle.initialize() - if (bundle_path := str(bundle.path)) not in sys.path: - sys.path.append(bundle_path) - log.debug("Added bundle path to sys.path", bundle_name=bundle_info.name, path=bundle_path) - except Exception: - log.warning( - "Failed to initialize DAG bundle for callback", - bundle_name=bundle_info.name, - exc_info=True, - ) - logger: FilteringBoundLogger log_file_descriptor: BinaryIO | None = None if log_path: @@ -281,6 +290,7 @@ def supervise_callback( id=id, callback_path=callback_path, callback_kwargs=callback_kwargs, + bundle_info=bundle_info, logger=logger, subprocess_logs_to_stdout=True, ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 65457dbfaa28b..93b459d200e87 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -49,53 +49,67 @@ def __call__(self, context): return "notified" -@pytest.fixture -def log(): - return structlog.get_logger() - - class TestExecuteCallback: - def test_successful_callback_no_args(self, log): - success, error = execute_callback(f"{__name__}.callback_no_args", {}, log) - - assert success is True - assert error is None - - def test_successful_callback_with_kwargs(self, log): - success, error = execute_callback( - f"{__name__}.callback_with_kwargs", {"arg1": "hello", "arg2": "world"}, log - ) - - assert success is True - assert error is None - - def test_empty_path_returns_failure(self, log): - success, error = execute_callback("", {}, log) - - assert success is False - assert "Callback path not found" in error - - def test_import_error_returns_failure(self, log): - success, error = 
execute_callback("nonexistent.module.function", {}, log) - - assert success is False - assert "ModuleNotFoundError" in error - - def test_execution_error_returns_failure(self, log): - success, error = execute_callback(f"{__name__}.callback_that_raises", {}, log) - - assert success is False - assert "ValueError" in error - - def test_callable_class_pattern(self, log): - """Test the class-that-returns-callable pattern (like BaseNotifier).""" - success, error = execute_callback(f"{__name__}.CallableClass", {"msg": "alert"}, log) - - assert success is True - assert error is None - - def test_attribute_error_for_nonexistent_function(self, log): - success, error = execute_callback(f"{__name__}.nonexistent_function_xyz", {}, log) - - assert success is False - assert "AttributeError" in error + @pytest.mark.parametrize( + ("path", "kwargs", "expect_success", "error_contains"), + [ + pytest.param( + f"{__name__}.callback_no_args", + {}, + True, + None, + id="successful_no_args", + ), + pytest.param( + f"{__name__}.callback_with_kwargs", + {"arg1": "hello", "arg2": "world"}, + True, + None, + id="successful_with_kwargs", + ), + pytest.param( + f"{__name__}.CallableClass", + {"msg": "alert"}, + True, + None, + id="callable_class_pattern", + ), + pytest.param( + "", + {}, + False, + "Callback path not found", + id="empty_path", + ), + pytest.param( + "nonexistent.module.function", + {}, + False, + "ModuleNotFoundError", + id="import_error", + ), + pytest.param( + f"{__name__}.callback_that_raises", + {}, + False, + "ValueError", + id="execution_error", + ), + pytest.param( + f"{__name__}.nonexistent_function_xyz", + {}, + False, + "AttributeError", + id="attribute_error", + ), + ], + ) + def test_execute_callback(self, path, kwargs, expect_success, error_contains): + log = structlog.get_logger() + success, error = execute_callback(path, kwargs, log) + + assert success is expect_success + if error_contains: + assert error_contains in error + else: + assert error is None