Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class KubernetesExecutor(BaseExecutor):

RUNNING_POD_LOG_LINES = 100
supports_ad_hoc_ti_run: bool = True
supports_callbacks: bool = True
supports_multi_team: bool = True

if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
Expand Down Expand Up @@ -234,29 +235,39 @@ def execute_async(
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
    """Register a workload (task or callback) as queued on this executor.

    :param workload: the workload to queue; only ``ExecuteTask`` and
        ``ExecuteCallback`` are supported by this executor.
    :param session: ORM session, part of the executor API (unused here).
    :raises RuntimeError: if the workload type is not supported.
    """
    from airflow.executors import workloads

    if isinstance(workload, workloads.ExecuteTask):
        self.queued_tasks[workload.ti.key] = workload
    elif isinstance(workload, workloads.ExecuteCallback):
        # Key by ``callback.key`` so that ``_process_workloads`` — which pops
        # entries via ``workload.callback.key`` — finds what we queued here.
        self.queued_callbacks[workload.callback.key] = workload
    else:
        raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
    """Move queued workloads into the running set and hand them to Kubernetes.

    Tasks are dispatched through ``execute_async``; callbacks are placed
    directly on the ``task_queue`` for pod creation and marked QUEUED in the
    event buffer.

    :param workloads: workloads previously registered via ``queue_workload``.
    :raises RuntimeError: if a workload type is not supported.
    """
    from airflow.executors.workloads import ExecuteCallback, ExecuteTask
    from airflow.utils.state import CallbackState

    for workload in workloads:
        if isinstance(workload, ExecuteTask):
            # TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled.
            command = [workload]
            key = workload.ti.key
            queue = workload.ti.queue
            executor_config = workload.ti.executor_config or {}

            del self.queued_tasks[key]
            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
            self.running.add(key)
        elif isinstance(workload, ExecuteCallback):
            callback_key = workload.callback.key
            del self.queued_callbacks[callback_key]
            # Put on task_queue for pod creation (no executor_config for callbacks).
            self.task_queue.put(KubernetesJob(callback_key, [workload], None, None))
            self.event_buffer[callback_key] = (CallbackState.QUEUED, None)
            self.running.add(callback_key)
        else:
            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def sync(self) -> None:
"""Synchronize task state."""
if TYPE_CHECKING:
Expand Down Expand Up @@ -312,13 +323,16 @@ def sync(self) -> None:
try:
key = task.key
self.kube_scheduler.run_next(task)
self.task_publish_retries.pop(key, None)
if not isinstance(key, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not isinstance(key, str):

Should that be

if isinstance(key, TaskInstanceKey):

both here and below? or perhaps

if not isinstance(key, WorkloadKey):

?

self.task_publish_retries.pop(key, None)
except PodReconciliationError as e:
self.log.exception(
"Pod reconciliation failed, likely due to kubernetes library upgrade. "
"Try clearing the task to re-run.",
)
self.fail(task[0], e)
task_key = task.key
if not isinstance(task_key, str):
self.fail(task_key, e)
except ApiException as e:
try:
if e.body:
Expand All @@ -331,38 +345,43 @@ def sync(self) -> None:
# Use the body directly as the message instead.
body = {"message": e.body}

retries = self.task_publish_retries[key]
# In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries
message = body.get("message", "")
if (
(str(e.status) == "403" and "exceeded quota" in message)
or (str(e.status) == "409" and "object has been modified" in message)
or (str(e.status) == "410" and "too old resource version" in message)
or str(e.status) == "500"
) and (self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries):
self.log.warning(
"[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s",
self.task_publish_retries[key] + 1,
self.task_publish_max_retries,
key,
e.reason,
message,
)
self.task_queue.put(task)
self.task_publish_retries[key] = retries + 1
if not isinstance(key, str):
retries = self.task_publish_retries[key]
# In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries
message = body.get("message", "")
if (
(str(e.status) == "403" and "exceeded quota" in message)
or (str(e.status) == "409" and "object has been modified" in message)
or (str(e.status) == "410" and "too old resource version" in message)
or str(e.status) == "500"
) and (
self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries
):
self.log.warning(
"[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s",
self.task_publish_retries[key] + 1,
self.task_publish_max_retries,
key,
e.reason,
message,
)
self.task_queue.put(task)
self.task_publish_retries[key] = retries + 1
else:
self.log.error("Pod creation failed with reason %r. Failing task", e.reason)
self.fail(key, e)
self.task_publish_retries.pop(key, None)
else:
self.log.error("Pod creation failed with reason %r. Failing task", e.reason)
key = task.key
self.fail(key, e)
self.task_publish_retries.pop(key, None)
self.log.error("Pod creation failed with reason %r.", e.reason)
except PodMutationHookException as e:
key = task.key
self.log.error(
"Pod Mutation Hook failed for the task %s. Failing task. Details: %s",
key,
e.__cause__,
)
self.fail(key, e)
if not isinstance(key, str):
self.fail(key, e)
finally:
self.task_queue.task_done()

Expand All @@ -372,11 +391,17 @@ def _change_state(
results: KubernetesResults,
session: Session = NEW_SESSION,
) -> None:
"""Change state of the task based on KubernetesResults."""
"""Change state of the workload based on KubernetesResults."""
if TYPE_CHECKING:
assert self.kube_scheduler

key = results.key

# Callback results have a string key (CallbackKey = str)
if isinstance(key, str):
self._change_callback_state(results)
return

state = results.state
pod_name = results.pod_name
namespace = results.namespace
Expand Down Expand Up @@ -468,6 +493,50 @@ def _change_state(

self.event_buffer[key] = state, termination_reason

def _change_callback_state(self, results: KubernetesResults) -> None:
    """Change state of a callback based on KubernetesResults."""
    from airflow.utils.state import CallbackState

    if TYPE_CHECKING:
        assert self.kube_scheduler

    key = results.key
    state = results.state
    pod_name = results.pod_name
    namespace = results.namespace

    if state == ADOPTED:
        # Pod was adopted by another scheduler — just stop tracking it here.
        self.running.discard(key)
        return

    failed = state == TaskInstanceState.FAILED
    if failed:
        self.log.warning("Callback %s failed in pod %s/%s", key, namespace, pod_name)

    # Pod cleanup: delete when configured (honoring the on-failure setting),
    # otherwise mark the pod done so it can be reaped later.
    if not self.kube_config.delete_worker_pods:
        self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace)
    elif not failed or self.kube_config.delete_worker_pods_on_failure:
        self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
        self.log.info(
            "Deleted pod for callback %s. Pod name: %s. Namespace: %s",
            key,
            pod_name,
            namespace,
        )

    if key not in self.running:
        self.log.debug("Callback key not in running: %s", key)
        return
    self.running.discard(key)

    # Map the pod outcome onto CallbackState (a successful pod reports state None).
    self.event_buffer[key] = (CallbackState.FAILED if failed else CallbackState.SUCCESS, None)

def _get_pod_namespace(self, ti: TaskInstance):
pod_override = ti.executor_config.get("pod_override")
namespace = None
Expand Down Expand Up @@ -629,6 +698,9 @@ def adopt_launched_task(

self.log.info("attempting to adopt pod %s", pod.metadata.name)
ti_key = annotations_to_key(pod.metadata.annotations)
if not isinstance(ti_key, tuple):
self.log.debug("Skipping non-task pod in adopt_launched_task: %s", pod.metadata.name)
return
if ti_key not in tis_to_flush_by_key:
self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key)
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from airflow.models.taskinstance import TaskInstanceKey
from airflow.executors.workloads.types import WorkloadKey
from airflow.utils.state import TaskInstanceState


Expand All @@ -43,9 +43,9 @@ class FailureDetails(TypedDict, total=False):


class KubernetesResults(NamedTuple):
"""Results from Kubernetes task execution."""
"""Results from Kubernetes workload execution (task or callback)."""

key: TaskInstanceKey
key: WorkloadKey
state: TaskInstanceState | str | None
pod_name: str
namespace: str
Expand All @@ -69,10 +69,10 @@ class KubernetesWatch(NamedTuple):


class KubernetesJob(NamedTuple):
    """Job definition for Kubernetes execution (task or callback)."""

    # For callbacks this is a string key; tasks use a TaskInstanceKey —
    # see the isinstance(key, str) dispatch in the executor. (WorkloadKey
    # is the union alias covering both.)
    key: WorkloadKey
    # For workload-based execution this holds a one-element list containing
    # the workload object rather than a CLI command, hence Sequence[Any].
    command: Sequence[Any]
    kube_executor_config: Any
    pod_template_file: str | None

Expand Down
Loading
Loading