-
Notifications
You must be signed in to change notification settings - Fork 16.8k
feat(kubernetes): add executor callback support to KubernetesExecutor #63454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,6 +79,7 @@ class KubernetesExecutor(BaseExecutor): | |
|
|
||
| RUNNING_POD_LOG_LINES = 100 | ||
| supports_ad_hoc_ti_run: bool = True | ||
| supports_callbacks: bool = True | ||
| supports_multi_team: bool = True | ||
|
|
||
| if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: | ||
|
|
@@ -234,29 +235,39 @@ def execute_async( | |
| def queue_workload(self, workload: workloads.All, session: Session | None) -> None: | ||
| from airflow.executors import workloads | ||
|
|
||
| if not isinstance(workload, workloads.ExecuteTask): | ||
| if isinstance(workload, workloads.ExecuteTask): | ||
| self.queued_tasks[workload.ti.key] = workload | ||
| elif isinstance(workload, workloads.ExecuteCallback): | ||
| self.queued_callbacks[workload.callback.id] = workload | ||
| else: | ||
| raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") | ||
| ti = workload.ti | ||
| self.queued_tasks[ti.key] = workload | ||
|
|
||
| def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: | ||
| from airflow.executors.workloads import ExecuteTask | ||
| from airflow.executors.workloads import ExecuteCallback, ExecuteTask | ||
| from airflow.utils.state import CallbackState | ||
|
|
||
| # Airflow V3 version | ||
| for w in workloads: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I know this was existing code, but can you please rename
||
| if not isinstance(w, ExecuteTask): | ||
| if isinstance(w, ExecuteTask): | ||
| # TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled. | ||
| command = [w] | ||
| key = w.ti.key | ||
| queue = w.ti.queue | ||
| executor_config = w.ti.executor_config or {} | ||
|
|
||
| del self.queued_tasks[key] | ||
| self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) | ||
| self.running.add(key) | ||
| elif isinstance(w, ExecuteCallback): | ||
| callback_key = w.callback.key | ||
| del self.queued_callbacks[callback_key] | ||
| # Put on task_queue for pod creation (no executor_config for callbacks) | ||
| self.task_queue.put(KubernetesJob(callback_key, [w], None, None)) | ||
| self.event_buffer[callback_key] = (CallbackState.QUEUED, None) | ||
| self.running.add(callback_key) | ||
| else: | ||
| raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") | ||
|
|
||
| # TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled. | ||
| command = [w] | ||
| key = w.ti.key | ||
| queue = w.ti.queue | ||
| executor_config = w.ti.executor_config or {} | ||
|
|
||
| del self.queued_tasks[key] | ||
| self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) | ||
| self.running.add(key) | ||
|
|
||
| def sync(self) -> None: | ||
| """Synchronize task state.""" | ||
| if TYPE_CHECKING: | ||
|
|
@@ -312,13 +323,16 @@ def sync(self) -> None: | |
| try: | ||
| key = task.key | ||
| self.kube_scheduler.run_next(task) | ||
| self.task_publish_retries.pop(key, None) | ||
| if not isinstance(key, str): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should that be both here and below? Or perhaps something else?
||
| self.task_publish_retries.pop(key, None) | ||
| except PodReconciliationError as e: | ||
| self.log.exception( | ||
| "Pod reconciliation failed, likely due to kubernetes library upgrade. " | ||
| "Try clearing the task to re-run.", | ||
| ) | ||
| self.fail(task[0], e) | ||
| task_key = task.key | ||
| if not isinstance(task_key, str): | ||
| self.fail(task_key, e) | ||
| except ApiException as e: | ||
| try: | ||
| if e.body: | ||
|
|
@@ -331,38 +345,43 @@ def sync(self) -> None: | |
| # Use the body directly as the message instead. | ||
| body = {"message": e.body} | ||
|
|
||
| retries = self.task_publish_retries[key] | ||
| # In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries | ||
| message = body.get("message", "") | ||
| if ( | ||
| (str(e.status) == "403" and "exceeded quota" in message) | ||
| or (str(e.status) == "409" and "object has been modified" in message) | ||
| or (str(e.status) == "410" and "too old resource version" in message) | ||
| or str(e.status) == "500" | ||
| ) and (self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries): | ||
| self.log.warning( | ||
| "[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s", | ||
| self.task_publish_retries[key] + 1, | ||
| self.task_publish_max_retries, | ||
| key, | ||
| e.reason, | ||
| message, | ||
| ) | ||
| self.task_queue.put(task) | ||
| self.task_publish_retries[key] = retries + 1 | ||
| if not isinstance(key, str): | ||
| retries = self.task_publish_retries[key] | ||
| # In case of exceeded quota or conflict errors, requeue the task as per the task_publish_max_retries | ||
| message = body.get("message", "") | ||
| if ( | ||
| (str(e.status) == "403" and "exceeded quota" in message) | ||
| or (str(e.status) == "409" and "object has been modified" in message) | ||
| or (str(e.status) == "410" and "too old resource version" in message) | ||
| or str(e.status) == "500" | ||
| ) and ( | ||
| self.task_publish_max_retries == -1 or retries < self.task_publish_max_retries | ||
| ): | ||
| self.log.warning( | ||
| "[Try %s of %s] Kube ApiException for Task: (%s). Reason: %r. Message: %s", | ||
| self.task_publish_retries[key] + 1, | ||
| self.task_publish_max_retries, | ||
| key, | ||
| e.reason, | ||
| message, | ||
| ) | ||
| self.task_queue.put(task) | ||
| self.task_publish_retries[key] = retries + 1 | ||
| else: | ||
| self.log.error("Pod creation failed with reason %r. Failing task", e.reason) | ||
| self.fail(key, e) | ||
| self.task_publish_retries.pop(key, None) | ||
| else: | ||
| self.log.error("Pod creation failed with reason %r. Failing task", e.reason) | ||
| key = task.key | ||
| self.fail(key, e) | ||
| self.task_publish_retries.pop(key, None) | ||
| self.log.error("Pod creation failed with reason %r.", e.reason) | ||
| except PodMutationHookException as e: | ||
| key = task.key | ||
| self.log.error( | ||
| "Pod Mutation Hook failed for the task %s. Failing task. Details: %s", | ||
| key, | ||
| e.__cause__, | ||
| ) | ||
| self.fail(key, e) | ||
| if not isinstance(key, str): | ||
| self.fail(key, e) | ||
| finally: | ||
| self.task_queue.task_done() | ||
|
|
||
|
|
@@ -372,11 +391,17 @@ def _change_state( | |
| results: KubernetesResults, | ||
| session: Session = NEW_SESSION, | ||
| ) -> None: | ||
| """Change state of the task based on KubernetesResults.""" | ||
| """Change state of the workload based on KubernetesResults.""" | ||
| if TYPE_CHECKING: | ||
| assert self.kube_scheduler | ||
|
|
||
| key = results.key | ||
|
|
||
| # Callback results have a string key (CallbackKey = str) | ||
| if isinstance(key, str): | ||
| self._change_callback_state(results) | ||
| return | ||
|
|
||
| state = results.state | ||
| pod_name = results.pod_name | ||
| namespace = results.namespace | ||
|
|
@@ -468,6 +493,50 @@ def _change_state( | |
|
|
||
| self.event_buffer[key] = state, termination_reason | ||
|
|
||
| def _change_callback_state(self, results: KubernetesResults) -> None: | ||
| """Change state of a callback based on KubernetesResults.""" | ||
| from airflow.utils.state import CallbackState | ||
|
|
||
| if TYPE_CHECKING: | ||
| assert self.kube_scheduler | ||
|
|
||
| key = results.key | ||
| state = results.state | ||
| pod_name = results.pod_name | ||
| namespace = results.namespace | ||
|
|
||
| if state == ADOPTED: | ||
| self.running.discard(key) | ||
| return | ||
|
|
||
| if state == TaskInstanceState.FAILED: | ||
| self.log.warning("Callback %s failed in pod %s/%s", key, namespace, pod_name) | ||
|
|
||
| # Clean up pod | ||
| if self.kube_config.delete_worker_pods: | ||
| if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure: | ||
| self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace) | ||
| self.log.info( | ||
| "Deleted pod for callback %s. Pod name: %s. Namespace: %s", | ||
| key, | ||
| pod_name, | ||
| namespace, | ||
| ) | ||
| else: | ||
| self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace) | ||
|
|
||
| if key not in self.running: | ||
| self.log.debug("Callback key not in running: %s", key) | ||
| return | ||
| self.running.discard(key) | ||
|
|
||
| # Map pod state to CallbackState | ||
| if state == TaskInstanceState.FAILED: | ||
| self.event_buffer[key] = CallbackState.FAILED, None | ||
| else: | ||
| # Pod succeeded (state is None for successful pods in K8s executor) | ||
| self.event_buffer[key] = CallbackState.SUCCESS, None | ||
|
|
||
| def _get_pod_namespace(self, ti: TaskInstance): | ||
| pod_override = ti.executor_config.get("pod_override") | ||
| namespace = None | ||
|
|
@@ -629,6 +698,9 @@ def adopt_launched_task( | |
|
|
||
| self.log.info("attempting to adopt pod %s", pod.metadata.name) | ||
| ti_key = annotations_to_key(pod.metadata.annotations) | ||
| if not isinstance(ti_key, tuple): | ||
| self.log.debug("Skipping non-task pod in adopt_launched_task: %s", pod.metadata.name) | ||
| return | ||
| if ti_key not in tis_to_flush_by_key: | ||
| self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key) | ||
| return | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Please do me a favor and drop this comment while you are in here.