From 029eb75137a3a9c800d0447e59276cbd1e1114db Mon Sep 17 00:00:00 2001 From: Kevin Yang Date: Thu, 12 Mar 2026 00:14:18 -0400 Subject: [PATCH 1/4] feat(kubernetes): add executor callback support to KubernetesExecutor Run synchronous executor callbacks (e.g. deadline alerts) as Kubernetes pods, using the same pod pipeline as task execution. Callbacks are dispatched via annotation-based key discrimination in the watcher, and their pod exit code maps to CallbackState.SUCCESS/FAILED. Also extends execute_workload.py (task-sdk) to handle ExecuteCallback workloads inside pods, making it the unified entrypoint for both tasks and callbacks in containerised executors. Co-Authored-By: Claude Sonnet 4.6 --- .../executors/kubernetes_executor.py | 93 +++++++-- .../executors/kubernetes_executor_types.py | 10 +- .../executors/kubernetes_executor_utils.py | 96 +++++++-- .../kubernetes/kubernetes_helper_functions.py | 13 +- .../cncf/kubernetes/pod_generator.py | 4 +- .../executors/test_kubernetes_executor.py | 194 ++++++++++++++++++ .../sdk/execution_time/execute_workload.py | 31 ++- .../execution_time/test_execute_workload.py | 74 +++++++ 8 files changed, 468 insertions(+), 47 deletions(-) create mode 100644 task-sdk/tests/task_sdk/execution_time/test_execute_workload.py diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index a55341a62a8dd..a9f5224cc553b 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -79,6 +79,7 @@ class KubernetesExecutor(BaseExecutor): RUNNING_POD_LOG_LINES = 100 supports_ad_hoc_ti_run: bool = True + supports_callbacks: bool = True supports_multi_team: bool = True if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: @@ -234,29 +235,39 @@ def execute_async( def queue_workload(self, workload: workloads.All, session: Session | None) -> None: from airflow.executors import workloads - if not isinstance(workload, workloads.ExecuteTask): + if isinstance(workload, workloads.ExecuteTask): + self.queued_tasks[workload.ti.key] = workload + elif isinstance(workload, workloads.ExecuteCallback): + self.queued_callbacks[workload.callback.id] = workload + else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: - from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads import ExecuteCallback, ExecuteTask + from airflow.utils.state import CallbackState # Airflow V3 version for w in workloads: - if not isinstance(w, ExecuteTask): + if isinstance(w, ExecuteTask): + # TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled. + command = [w] + key = w.ti.key + queue = w.ti.queue + executor_config = w.ti.executor_config or {} + + del self.queued_tasks[key] + self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) + self.running.add(key) + elif isinstance(w, ExecuteCallback): + callback_key = w.callback.key + del self.queued_callbacks[callback_key] + # Put on task_queue for pod creation (no executor_config for callbacks) + self.task_queue.put(KubernetesJob(callback_key, [w], None, None)) + self.event_buffer[callback_key] = (CallbackState.QUEUED, None) + self.running.add(callback_key) + else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") - # TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled. - command = [w] - key = w.ti.key - queue = w.ti.queue - executor_config = w.ti.executor_config or {} - - del self.queued_tasks[key] - self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) - self.running.add(key) - def sync(self) -> None: """Synchronize task state.""" if TYPE_CHECKING: @@ -372,11 +383,17 @@ def _change_state( results: KubernetesResults, session: Session = NEW_SESSION, ) -> None: - """Change state of the task based on KubernetesResults.""" + """Change state of the workload based on KubernetesResults.""" if TYPE_CHECKING: assert self.kube_scheduler key = results.key + + # Callback results have a string key (CallbackKey = str) + if isinstance(key, str): + self._change_callback_state(results) + return + state = results.state pod_name = results.pod_name namespace = results.namespace @@ -468,6 +485,50 @@ def _change_state( self.event_buffer[key] = state, termination_reason + def _change_callback_state(self, results: KubernetesResults) -> None: + """Change state of a callback based on KubernetesResults.""" + from airflow.utils.state import CallbackState + + if TYPE_CHECKING: + assert self.kube_scheduler + + key = results.key + state = results.state + pod_name = results.pod_name + namespace = results.namespace + + if state == ADOPTED: + self.running.discard(key) + return + + if state == TaskInstanceState.FAILED: + self.log.warning("Callback %s failed in pod %s/%s", key, namespace, pod_name) + + # Clean up pod + if self.kube_config.delete_worker_pods: + if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure: + self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace) + self.log.info( + "Deleted pod for callback %s. Pod name: %s. Namespace: %s", + key, + pod_name, + namespace, + ) + else: + self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace) + + if key not in self.running: + self.log.debug("Callback key not in running: %s", key) + return + self.running.discard(key) + + # Map pod state to CallbackState + if state == TaskInstanceState.FAILED: + self.event_buffer[key] = CallbackState.FAILED, None + else: + # Pod succeeded (state is None for successful pods in K8s executor) + self.event_buffer[key] = CallbackState.SUCCESS, None + def _get_pod_namespace(self, ti: TaskInstance): pod_override = ti.executor_config.get("pod_override") namespace = None diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py index f8e03f1f04c93..269237bf772a2 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from airflow.models.taskinstance import TaskInstanceKey + from airflow.executors.workloads.types import WorkloadKey from airflow.utils.state import TaskInstanceState @@ -43,9 +43,9 @@ class FailureDetails(TypedDict, total=False): class KubernetesResults(NamedTuple): - """Results from Kubernetes task execution.""" + """Results from Kubernetes workload execution (task or callback).""" - key: TaskInstanceKey + key: WorkloadKey state: TaskInstanceState | str | None pod_name: str namespace: str @@ -69,9 +69,9 @@ class KubernetesWatch(NamedTuple): class KubernetesJob(NamedTuple): - """Job definition for Kubernetes execution.""" + """Job definition for Kubernetes execution (task or callback).""" - key: TaskInstanceKey + key: WorkloadKey command: Sequence[str] kube_executor_config: Any pod_template_file: str | None diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index ed31b43bf2544..967f145c169e5 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -160,22 +160,27 @@ def _run( if event["type"] == "ERROR": return self.process_error(event) annotations = task.metadata.annotations - task_instance_related_annotations = { - "dag_id": annotations["dag_id"], - "task_id": annotations["task_id"], - logical_date_key: annotations.get(logical_date_key), - "run_id": annotations.get("run_id"), - "try_number": annotations["try_number"], - } - map_index = annotations.get("map_index") - if map_index is not None: - task_instance_related_annotations["map_index"] = map_index + + # Callback pods have a "callback_id" annotation instead of task annotations + if annotations.get("callback_id"): + relevant_annotations = {"callback_id": annotations["callback_id"]} + else: + relevant_annotations = { + "dag_id": annotations["dag_id"], + "task_id": annotations["task_id"], + logical_date_key: annotations.get(logical_date_key), + "run_id": annotations.get("run_id"), + "try_number": annotations["try_number"], + } + map_index = annotations.get("map_index") + if map_index is not None: + relevant_annotations["map_index"] = map_index self.process_status( pod_name=task.metadata.name, namespace=task.metadata.namespace, status=task.status.phase, - annotations=task_instance_related_annotations, + annotations=relevant_annotations, resource_version=task.metadata.resource_version, event=event, ) @@ -552,10 +557,11 @@ def run_next(self, next_job: KubernetesJob) -> None: kube_executor_config = next_job.kube_executor_config pod_template_file = next_job.pod_template_file - dag_id, task_id, run_id, try_number, map_index = key if len(command) == 1: - from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads import ExecuteCallback, ExecuteTask + if isinstance(command[0], ExecuteCallback): + return self._run_next_callback(next_job) if isinstance(command[0], ExecuteTask): workload = command[0] command = workload_to_command_args(workload) @@ -566,6 +572,8 @@ def run_next(self, next_job: KubernetesJob) -> None: elif command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') + dag_id, task_id, run_id, try_number, map_index = key + base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config) if not base_worker_pod: @@ -604,6 +612,68 @@ def run_next(self, next_job: KubernetesJob) -> None: self.run_pod_async(pod, **self.kube_config.kube_client_request_args) self.log.debug("Kubernetes Job created!") + def _run_next_callback(self, next_job: KubernetesJob) -> None: + """Build and create a pod for an ExecuteCallback workload.""" + from kubernetes.client import models as k8s + + from airflow.providers.cncf.kubernetes.exceptions import PodMutationHookException + + callback_workload = next_job.command[0] + callback_id = callback_workload.callback.id + + command = workload_to_command_args(callback_workload) + base_worker_pod = get_base_pod_from_template(None, self.kube_config) + + if not base_worker_pod: + raise AirflowException( + f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}" + ) + + pod_id = create_unique_id("callback", callback_id[:8]) + + dynamic_pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + namespace=self.namespace, + name=pod_id, + annotations={"callback_id": callback_id}, + labels=PodGenerator.build_labels_for_k8s_executor_pod( + dag_id="__callback__", + task_id=callback_id[:8], + try_number=1, + airflow_worker=self.scheduler_job_id, + ), + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="base", + image=self.kube_config.kube_image, + args=list(command), + env=[k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", value="True")], + ) + ], + ), + ) + + pod = PodGenerator.reconcile_pods(base_worker_pod, dynamic_pod) + + from airflow.settings import pod_mutation_hook + + try: + pod_mutation_hook(pod) + except Exception as e: + raise PodMutationHookException from e + + self.log.info( + "Creating callback pod %s for callback %s", + pod.metadata.name, + callback_id, + ) + self.log.debug("Kubernetes launching callback image %s", pod.spec.containers[0].image) + + self.run_pod_async(pod, **self.kube_config.kube_client_request_args) + self.log.debug("Kubernetes callback pod created!") + def delete_pod(self, pod_name: str, namespace: str) -> None: """Delete Pod from a namespace; does not raise if it does not exist.""" try: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index 4dff47c4b1d78..da92538b81c4b 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -35,7 +35,7 @@ from airflow.providers.common.compat.sdk import AirflowException if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.executors.workloads.types import WorkloadKey log = logging.getLogger(__name__) @@ -155,9 +155,14 @@ def create_unique_id( return base_name -def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey: - """Build a TaskInstanceKey based on pod annotations.""" - log.debug("Creating task key for annotations %s", annotations) +def annotations_to_key(annotations: dict[str, str]) -> WorkloadKey: + """Build a WorkloadKey (TaskInstanceKey or CallbackKey) based on pod annotations.""" + log.debug("Creating key for annotations %s", annotations) + + # Callback pods have a "callback_id" annotation instead of task annotations + if "callback_id" in annotations: + return annotations["callback_id"] + dag_id = annotations["dag_id"] task_id = annotations["task_id"] try_number = int(annotations["try_number"]) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py index 7e5dc728d775f..322686df6d3f7 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py @@ -62,11 +62,11 @@ MAX_LABEL_LEN = 63 -def workload_to_command_args(workload: workloads.ExecuteTask) -> list[str]: +def workload_to_command_args(workload: workloads.ExecuteTask | workloads.ExecuteCallback) -> list[str]: """ Convert a workload object to Task SDK command arguments. - :param workload: The ExecuteTask workload to convert + :param workload: The workload to convert (ExecuteTask or ExecuteCallback) :return: List of command arguments for the Task SDK """ ser_input = workload.model_dump_json() diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py index af4a503c66cab..6c35ca30075a9 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py @@ -36,6 +36,7 @@ ) from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import ( ADOPTED, + KubernetesJob, KubernetesResults, KubernetesWatch, ) @@ -2095,3 +2096,196 @@ def test_get_pod_namespace_uses_instance_conf(self, monkeypatch): assert namespace == "team-a-ns" finally: executor.end() + + +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Callbacks require Airflow 3.0+") +class TestKubernetesExecutorCallbackSupport: + """Tests for executor callback support in KubernetesExecutor.""" + + @staticmethod + def _make_callback_workload(callback_id="12345678-1234-5678-1234-567812345678"): + from airflow.executors import workloads + from airflow.executors.workloads.base import BundleInfo + from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod + + callback_data = CallbackDTO( + id=callback_id, + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {}}, + ) + return workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test_dag.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="executor_callbacks/test_dag/run_1/12345678", + ) + + def test_supports_callbacks_flag_is_true(self): + assert KubernetesExecutor.supports_callbacks is True + + def test_queue_workload_with_callback(self): + executor = KubernetesExecutor() + workload = self._make_callback_workload() + + executor.queue_workload(workload, session=None) + + assert workload.callback.id in executor.queued_callbacks + assert executor.queued_callbacks[workload.callback.id] is workload + + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_process_callback_workload(self, mock_get_kube_client, mock_kubernetes_job_watcher): + executor = KubernetesExecutor() + executor.job_id = 5 + executor.start() + + try: + workload = self._make_callback_workload() + callback_key = workload.callback.key + + executor.queued_callbacks[callback_key] = workload + executor._process_workloads([workload]) + + # Callback should be removed from queued_callbacks + assert callback_key not in executor.queued_callbacks + # Callback should be added to running + assert callback_key in executor.running + # Should be on the task_queue for pod creation + assert not executor.task_queue.empty() + finally: + executor.end() + + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + @mock.patch( + "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod" + ) + def test_change_state_callback_success( + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher + ): + from airflow.utils.state import CallbackState + + executor = KubernetesExecutor() + executor.job_id = 5 + executor.start() + + try: + callback_key = "12345678-1234-5678-1234-567812345678" + executor.running = {callback_key} + + # Pod succeeded (state=None means success in K8s executor) + results = KubernetesResults( + callback_key, None, "callback-pod", "default", "resource_version", None + ) + executor._change_state(results) + + assert executor.event_buffer[callback_key] == (CallbackState.SUCCESS, None) + assert callback_key not in executor.running + mock_delete_pod.assert_called_once_with(pod_name="callback-pod", namespace="default") + finally: + executor.end() + + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + @mock.patch( + "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod" + ) + def test_change_state_callback_failure( + self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher + ): + from airflow.utils.state import CallbackState + + executor = KubernetesExecutor() + executor.job_id = 5 + executor.start() + + try: + callback_key = "12345678-1234-5678-1234-567812345678" + executor.running = {callback_key} + + results = KubernetesResults( + callback_key, + TaskInstanceState.FAILED, + "callback-pod", + "default", + "resource_version", + None, + ) + executor._change_state(results) + + assert executor.event_buffer[callback_key] == (CallbackState.FAILED, None) + assert callback_key not in executor.running + finally: + executor.end() + + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_change_state_callback_adopted(self, mock_get_kube_client, mock_kubernetes_job_watcher): + executor = KubernetesExecutor() + executor.job_id = 5 + executor.start() + + try: + callback_key = "12345678-1234-5678-1234-567812345678" + executor.running = {callback_key} + + results = KubernetesResults( + callback_key, ADOPTED, "callback-pod", "default", "resource_version", None + ) + executor._change_state(results) + + assert len(executor.event_buffer) == 0 + assert callback_key not in executor.running + finally: + executor.end() + + def test_annotations_to_key_for_callback(self): + """Test that annotations_to_key returns callback_id for callback pods.""" + callback_id = "12345678-1234-5678-1234-567812345678" + annotations = {"callback_id": callback_id} + + key = annotations_to_key(annotations) + + assert key == callback_id + assert isinstance(key, str) + + @mock.patch( + "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils" + ".AirflowKubernetesScheduler.run_pod_async" + ) + @mock.patch("airflow.settings.pod_mutation_hook") + @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher") + @mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_run_next_callback( + self, mock_get_kube_client, mock_kubernetes_job_watcher, mock_mutation_hook, mock_run_pod_async + ): + executor = KubernetesExecutor() + executor.job_id = 5 + executor.start() + + try: + workload = self._make_callback_workload() + callback_key = workload.callback.key + job = KubernetesJob(callback_key, [workload], None, None) + + executor.kube_scheduler.run_next(job) + + # Pod should have been created + mock_run_pod_async.assert_called_once() + created_pod = mock_run_pod_async.call_args[0][0] + + # Verify callback-specific annotations + assert created_pod.metadata.annotations["callback_id"] == workload.callback.id + + # Verify pod has airflow-worker label for watcher discovery + assert "airflow-worker" in created_pod.metadata.labels + assert created_pod.metadata.labels["kubernetes_executor"] == "True" + + # Verify pod command runs execute_workload with callback JSON + container_args = created_pod.spec.containers[0].args + assert container_args[0] == "python" + assert "execute_workload" in container_args[2] + assert "ExecuteCallback" in container_args[4] + finally: + executor.end() diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py index 410c676eeb913..78022ef943f71 100644 --- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py +++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py @@ -16,7 +16,7 @@ # under the License. """ -Module for executing an Airflow task using the workload json provided by a input file. +Module for executing an Airflow workload (task or callback) using the workload json provided by an input file. Usage: python execute_workload.py @@ -34,15 +34,13 @@ import structlog if TYPE_CHECKING: - from airflow.executors.workloads import ExecuteTask + from airflow.executors import workloads log = structlog.get_logger(logger_name=__name__) -def execute_workload(workload: ExecuteTask) -> None: +def execute_workload(workload: workloads.All) -> None: from airflow.executors import workloads - from airflow.sdk.configuration import conf - from airflow.sdk.execution_time.supervisor import supervise from airflow.sdk.log import configure_logging from airflow.settings import dispose_orm @@ -50,10 +48,19 @@ def execute_workload(workload: ExecuteTask) -> None: configure_logging(output=sys.stdout.buffer, json_output=True) - if not isinstance(workload, workloads.ExecuteTask): + if isinstance(workload, workloads.ExecuteTask): + _execute_task(workload) + elif isinstance(workload, workloads.ExecuteCallback): + _execute_callback(workload) + else: raise ValueError(f"Executor does not know how to handle {type(workload)}") - log.info("Executing workload", workload=workload) + +def _execute_task(workload: workloads.ExecuteTask) -> None: + from airflow.sdk.configuration import conf + from airflow.sdk.execution_time.supervisor import supervise + + log.info("Executing task workload", workload=workload) base_url = conf.get("api", "base_url", fallback="/") # If it's a relative URL, use localhost:8080 as the default @@ -78,6 +85,16 @@ def execute_workload(workload: ExecuteTask) -> None: ) +def _execute_callback(workload: workloads.ExecuteCallback) -> None: + from airflow.executors.workloads.callback import execute_callback_workload + + log.info("Executing callback workload", callback_id=workload.callback.id) + + success, error = execute_callback_workload(workload.callback, log) + if not success: + raise RuntimeError(f"Callback execution failed: {error}") + + def main(): parser = argparse.ArgumentParser( description="Execute a workload in a Containerised executor using the task SDK." diff --git a/task-sdk/tests/task_sdk/execution_time/test_execute_workload.py b/task-sdk/tests/task_sdk/execution_time/test_execute_workload.py new file mode 100644 index 0000000000000..1f95b54d9e21c --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_execute_workload.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.executors import workloads +from airflow.executors.workloads.base import BundleInfo +from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod + + +class TestExecuteWorkloadCallback: + """Tests for callback handling in execute_workload.""" + + @staticmethod + def _make_callback_workload(): + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.my_callback", "kwargs": {}}, + ) + return workloads.ExecuteCallback( + callback=callback_data, + dag_rel_path="test_dag.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="executor_callbacks/test_dag/run_1/12345678", + ) + + @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") + @mock.patch("airflow.settings.dispose_orm") + def test_execute_workload_handles_callback(self, mock_dispose_orm, mock_execute_callback): + from airflow.sdk.execution_time.execute_workload import execute_workload + + mock_execute_callback.return_value = (True, None) + + workload = self._make_callback_workload() + execute_workload(workload) + + mock_execute_callback.assert_called_once_with(workload.callback, mock.ANY) + + @mock.patch("airflow.executors.workloads.callback.execute_callback_workload") + @mock.patch("airflow.settings.dispose_orm") + def test_execute_workload_callback_failure_raises(self, mock_dispose_orm, mock_execute_callback): + from airflow.sdk.execution_time.execute_workload import execute_workload + + mock_execute_callback.return_value = (False, "Something went wrong") + + workload = self._make_callback_workload() + with pytest.raises(RuntimeError, match="Callback execution failed"): + execute_workload(workload) + + @mock.patch("airflow.settings.dispose_orm") + def test_execute_workload_rejects_unknown_type(self, mock_dispose_orm): + from airflow.sdk.execution_time.execute_workload import execute_workload + + with pytest.raises(ValueError, match="does not know how to handle"): + execute_workload("not_a_workload") # type: ignore[arg-type] From 1666c03f40d117b12a74d6bb0b9aff0e31c0cbbd Mon Sep 17 00:00:00 2001 From: Kevin Yang Date: Thu, 12 Mar 2026 00:28:46 -0400 Subject: [PATCH 2/4] fix type issue --- .../cncf/kubernetes/executors/kubernetes_executor_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py index 269237bf772a2..6c98f6887b2dc 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py @@ -72,7 +72,7 @@ class KubernetesJob(NamedTuple): """Job definition for Kubernetes execution (task or callback).""" key: WorkloadKey - command: Sequence[str] + command: Sequence[Any] kube_executor_config: Any pod_template_file: str | None From d4f1d4478a527bac8a2c9b273e148500920ee878 Mon Sep 17 00:00:00 2001 From: Kevin Yang Date: Thu, 12 Mar 2026 09:20:40 -0400 Subject: [PATCH 3/4] fix mypy error --- .../providers/cncf/kubernetes/executors/kubernetes_executor.py | 3 +++ .../cncf/kubernetes/executors/kubernetes_executor_utils.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index a9f5224cc553b..662974f43b1c6 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -690,6 +690,9 @@ def adopt_launched_task( self.log.info("attempting to adopt pod %s", pod.metadata.name) ti_key = annotations_to_key(pod.metadata.annotations) + if not isinstance(ti_key, tuple): + self.log.debug("Skipping non-task pod in adopt_launched_task: %s", pod.metadata.name) + return if ti_key not in tis_to_flush_by_key: self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key) return diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index 967f145c169e5..c5295d4ea08e2 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -572,6 +572,8 @@ def run_next(self, next_job: KubernetesJob) -> None: elif command[0:3] != ["airflow", "tasks", "run"]: raise ValueError('The command must start with ["airflow", "tasks", "run"].') + if not isinstance(key, tuple): + raise ValueError(f"Expected a TaskInstanceKey for task workload, got: {type(key)}") dag_id, task_id, run_id, try_number, map_index = key base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config) From 37e77225adf0bdb4bb4a8bfe5f2446c196d5095d Mon Sep 17 00:00:00 2001 From: Kevin Yang Date: Thu, 12 Mar 2026 10:57:50 -0400 Subject: [PATCH 4/4] fix mypy errors --- .../executors/kubernetes_executor.py | 60 +++++++++++-------- .../executors/kubernetes_executor_utils.py | 12 +++- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 662974f43b1c6..6d93963835f69 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -323,13 +323,16 @@ def sync(self) -> None: try: key = task.key self.kube_scheduler.run_next(task) - self.task_publish_retries.pop(key, None) + if not isinstance(key, str): + self.task_publish_retries.pop(key, None) except PodReconciliationError as e: self.log.exception( "Pod reconciliation failed, likely due to kubernetes library upgrade. " "Try clearing the task to re-run.", ) - self.fail(task[0], e) + task_key = task.key + if not isinstance(task_key, str): + self.fail(task_key, e) except ApiException as e: try: if e.body: @@ -342,30 +345,34 @@ def sync(self) -> None: # Use the body directly as the message instead. body = {"message": e.body} - retries = self.task_publish_retries[key] - # In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries - message = body.get("message", "") - if ( - (str(e.status) == "403" and "exceeded quota" in message) - or (str(e.status) == "409" and "object has been modified" in message) - or (str(e.status) == "410" and "too old resource version" in message) - or str(e.status) == "500" - ) and (self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries): - self.log.warning( - "[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s", - self.task_publish_retries[key] + 1, - self.task_publish_max_retries, - key, - e.reason, - message, - ) - self.task_queue.put(task) - self.task_publish_retries[key] = retries + 1 + if not isinstance(key, str): + retries = self.task_publish_retries[key] + # In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries + message = body.get("message", "") + if ( + (str(e.status) == "403" and "exceeded quota" in message) + or (str(e.status) == "409" and "object has been modified" in message) + or (str(e.status) == "410" and "too old resource version" in message) + or str(e.status) == "500" + ) and ( + self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries + ): + self.log.warning( + "[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s", + self.task_publish_retries[key] + 1, + self.task_publish_max_retries, + key, + e.reason, + message, + ) + self.task_queue.put(task) + self.task_publish_retries[key] = retries + 1 + else: + self.log.error("Pod creation failed with reason %r. Failing task", e.reason) + self.fail(key, e) + self.task_publish_retries.pop(key, None) else: - self.log.error("Pod creation failed with reason %r. Failing task", e.reason) - key = task.key - self.fail(key, e) - self.task_publish_retries.pop(key, None) + self.log.error("Pod creation failed with reason %r.", e.reason) except PodMutationHookException as e: key = task.key self.log.error( @@ -373,7 +380,8 @@ def sync(self) -> None: key, e.__cause__, ) - self.fail(key, e) + if not isinstance(key, str): + self.fail(key, e) finally: self.task_queue.task_done() diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index c5295d4ea08e2..652b7185ef336 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -230,7 +230,9 @@ def process_status( elif hasattr(pod.status, "reason") and pod.status.reason == "ProviderFailed": # Most likely this happens due to Kubernetes setup (virtual kubelet, virtual nodes, etc.) key = annotations_to_key(annotations=annotations) - task_key_str = f"{key.dag_id}.{key.task_id}.{key.try_number}" if key else "unknown" + task_key_str = ( + f"{key.dag_id}.{key.task_id}.{key.try_number}" if not isinstance(key, str) else "unknown" + ) self.log.warning( "Event: %s failed to start with reason ProviderFailed, task: %s, annotations: %s", pod_name, @@ -275,7 +277,9 @@ def process_status( continue key = annotations_to_key(annotations=annotations) task_key_str = ( - f"{key.dag_id}.{key.task_id}.{key.try_number}" if key else "unknown" + f"{key.dag_id}.{key.task_id}.{key.try_number}" + if not isinstance(key, str) + else "unknown" ) self.log.warning( "Event: %s has container %s with fatal reason %s, task: %s", @@ -309,7 +313,9 @@ def process_status( ) key = annotations_to_key(annotations=annotations) - task_key_str = f"{key.dag_id}.{key.task_id}.{key.try_number}" if key else "unknown" + task_key_str = ( + f"{key.dag_id}.{key.task_id}.{key.try_number}" if not isinstance(key, str) else "unknown" + ) self.log.warning( "Event: %s Failed, task: %s, annotations: %s", pod_name, task_key_str, annotations_string )