-
Notifications
You must be signed in to change notification settings - Fork 16.8k
feat: add callback support to aws batch executor #62984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| """AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job.""" | ||
| """AWS Batch Executor. Each Airflow workload gets delegated out to an AWS Batch Job.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
|
|
@@ -33,7 +33,7 @@ | |
| exponential_backoff_retry, | ||
| ) | ||
| from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook | ||
| from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS | ||
| from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS | ||
| from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone | ||
| from airflow.utils.helpers import merge_dicts | ||
|
|
||
|
|
@@ -88,6 +88,7 @@ class AwsBatchExecutor(BaseExecutor): | |
| """ | ||
|
|
||
| supports_multi_team: bool = True | ||
| supports_callbacks: bool = True | ||
|
|
||
| # AWS only allows a maximum number of JOBs in the describe_jobs function | ||
| DESCRIBE_JOBS_BATCH_SIZE = 99 | ||
|
|
@@ -127,26 +128,45 @@ def __init__(self, *args, **kwargs): | |
| def queue_workload(self, workload: workloads.All, session: Session | None) -> None: | ||
| from airflow.executors import workloads | ||
|
|
||
| if not isinstance(workload, workloads.ExecuteTask): | ||
| if AIRFLOW_V_3_2_PLUS and isinstance(workload, workloads.ExecuteCallback): | ||
| self.queued_callbacks[workload.callback.id] = workload | ||
| elif isinstance(workload, workloads.ExecuteTask): | ||
| ti = workload.ti | ||
| self.queued_tasks[ti.key] = workload | ||
| else: | ||
| raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") | ||
| ti = workload.ti | ||
| self.queued_tasks[ti.key] = workload | ||
|
|
||
| def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: | ||
| from airflow.executors.workloads import ExecuteTask | ||
| from airflow.executors import workloads as wl | ||
|
|
||
| # Airflow V3 version | ||
| for w in workloads: | ||
| if not isinstance(w, ExecuteTask): | ||
| if isinstance(w, wl.ExecuteTask): | ||
| task_command = [w] | ||
| task_key = w.ti.key | ||
| queue = w.ti.queue | ||
| executor_config = w.ti.executor_config or {} | ||
|
|
||
| del self.queued_tasks[task_key] | ||
| self.execute_async( | ||
| key=task_key, | ||
| command=task_command, # type: ignore[arg-type] | ||
| queue=queue, | ||
| executor_config=executor_config, | ||
| ) | ||
| self.running.add(task_key) | ||
| elif AIRFLOW_V_3_2_PLUS and isinstance(w, wl.ExecuteCallback): | ||
| callback_command = [w] | ||
| callback_key = w.callback.id | ||
| queue = None | ||
| if isinstance(w.callback.data, dict) and "queue" in w.callback.data: | ||
| queue = w.callback.data["queue"] | ||
|
|
||
| del self.queued_callbacks[callback_key] | ||
| self.execute_async(key=callback_key, command=callback_command, queue=queue) # type: ignore[arg-type] | ||
| self.running.add(callback_key) | ||
| else: | ||
| raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") | ||
| command = [w] | ||
| key = w.ti.key | ||
| queue = w.ti.queue | ||
| executor_config = w.ti.executor_config or {} | ||
|
|
||
| del self.queued_tasks[key] | ||
| self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type] | ||
| self.running.add(key) | ||
|
|
||
| def check_health(self): | ||
| """Make a test API call to check the health of the Batch Executor.""" | ||
|
|
@@ -235,7 +255,7 @@ def sync(self): | |
| def sync_running_jobs(self): | ||
| all_job_ids = self.active_workers.get_all_jobs() | ||
| if not all_job_ids: | ||
| self.log.debug("No active Airflow tasks, skipping sync") | ||
| self.log.debug("No active Airflow workloads, skipping sync") | ||
| return | ||
| describe_job_response = self._describe_jobs(all_job_ids) | ||
|
|
||
|
|
@@ -245,8 +265,8 @@ def sync_running_jobs(self): | |
| if job.get_job_state() == State.FAILED: | ||
| self._handle_failed_job(job) | ||
| elif job.get_job_state() == State.SUCCESS: | ||
| task_key = self.active_workers.pop_by_id(job.job_id) | ||
| self.success(task_key) | ||
| workload_key = self.active_workers.pop_by_id(job.job_id) | ||
| self.success(workload_key) | ||
|
|
||
| def _handle_failed_job(self, job): | ||
| """ | ||
|
|
@@ -263,15 +283,15 @@ def _handle_failed_job(self, job): | |
| # responsibility for ensuring the process started. Failures in the DAG will be caught by | ||
| # Airflow, which will be handled separately. | ||
| job_info = self.active_workers.id_to_job_info[job.job_id] | ||
| task_key = self.active_workers.id_to_key[job.job_id] | ||
| task_cmd = job_info.cmd | ||
| workload_key = self.active_workers.id_to_key[job.job_id] | ||
| workload_cmd = job_info.cmd | ||
| queue = job_info.queue | ||
| exec_info = job_info.config | ||
| failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id) | ||
| if int(failure_count) < int(self.max_submit_job_attempts): | ||
| self.log.warning( | ||
| "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", | ||
| task_key, | ||
| "Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", | ||
| workload_key, | ||
| job.status_reason, | ||
| failure_count, | ||
| self.max_submit_job_attempts, | ||
|
|
@@ -281,8 +301,8 @@ def _handle_failed_job(self, job): | |
| self.active_workers.pop_by_id(job.job_id) | ||
| self.pending_jobs.append( | ||
| BatchQueuedJob( | ||
| task_key, | ||
| task_cmd, | ||
| workload_key, | ||
| workload_cmd, | ||
| queue, | ||
| exec_info, | ||
| failure_count + 1, | ||
|
|
@@ -291,12 +311,12 @@ def _handle_failed_job(self, job): | |
| ) | ||
| else: | ||
| self.log.error( | ||
| "Airflow task %s has failed a maximum of %s times. Marking as failed", | ||
| task_key, | ||
| "Airflow workload %s has failed a maximum of %s times. Marking as failed", | ||
| workload_key, | ||
| failure_count, | ||
| ) | ||
| self.active_workers.pop_by_id(job.job_id) | ||
| self.fail(task_key) | ||
| self.fail(workload_key) | ||
|
|
||
| def attempt_submit_jobs(self): | ||
| """ | ||
|
|
@@ -309,8 +329,8 @@ def attempt_submit_jobs(self): | |
| """ | ||
| for _ in range(len(self.pending_jobs)): | ||
| batch_job = self.pending_jobs.popleft() | ||
| key = batch_job.key | ||
| cmd = batch_job.command | ||
| workload_key = batch_job.key | ||
| workload_cmd = batch_job.command | ||
| queue = batch_job.queue | ||
| exec_config = batch_job.executor_config | ||
| attempt_number = batch_job.attempt_number | ||
|
|
@@ -319,7 +339,7 @@ def attempt_submit_jobs(self): | |
| self.pending_jobs.append(batch_job) | ||
| continue | ||
| try: | ||
| submit_job_response = self._submit_job(key, cmd, queue, exec_config or {}) | ||
| submit_job_response = self._submit_job(workload_key, workload_cmd, queue, exec_config or {}) | ||
| except NoCredentialsError: | ||
| self.pending_jobs.append(batch_job) | ||
| raise | ||
|
|
@@ -337,18 +357,18 @@ def attempt_submit_jobs(self): | |
| self.log.error( | ||
| ( | ||
| "This job has been unsuccessfully attempted too many times (%s). " | ||
| "Dropping the task. Reason: %s" | ||
| "Dropping the workload. Reason: %s" | ||
| ), | ||
| attempt_number, | ||
| failure_reason, | ||
| ) | ||
| self.log_task_event( | ||
| event="batch job submit failure", | ||
| extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). " | ||
| f"Dropping the task. Reason: {failure_reason}", | ||
| ti_key=key, | ||
| f"Dropping the workload. Reason: {failure_reason}", | ||
| ti_key=workload_key, | ||
| ) | ||
| self.fail(key=key) | ||
| self.fail(key=workload_key) | ||
|
Comment on lines
365
to
+371
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When a callback workload exceeds max submit attempts, log_task_event is called with ti_key=workload_key. For callbacks, this key is a string UUID, not a TaskInstanceKey named tuple, which will cause errors since Log(task_instance=...) expects a TaskInstanceKey.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great find. I followed the relevant code. The scheduler uses this log queue to write a So I'm wondering whether we should either adjust the @ferruzzi any thoughts on it ?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same point he made in another PR. #63035 (comment)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to change https://github.com/apache/airflow/blob/main/airflow-core/src/airflow/models/log.py#L78C9-L78C22 What we could do is: # BaseExecutor
def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey | None):
...
# BatchExecutor
self.log_task_event(
event="batch job submit failure",
extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
f"Dropping the workload. Reason: {failure_reason}"
)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this Log() entry for? I am wondering if callbacks need the entry or are they just for tasks?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| else: | ||
| batch_job.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( | ||
| attempt_number | ||
|
|
@@ -360,35 +380,39 @@ def attempt_submit_jobs(self): | |
| job_id = submit_job_response["job_id"] | ||
| self.active_workers.add_job( | ||
| job_id=job_id, | ||
| airflow_task_key=key, | ||
| airflow_cmd=cmd, | ||
| airflow_workload_key=workload_key, | ||
| airflow_cmd=workload_cmd, | ||
| queue=queue, | ||
| exec_config=exec_config, | ||
| attempt_number=attempt_number, | ||
| ) | ||
| self.running_state(key, job_id) | ||
| self.running_state(workload_key, job_id) | ||
|
|
||
| def _describe_jobs(self, job_ids) -> list[BatchJob]: | ||
| all_jobs = [] | ||
| for i in range(0, len(job_ids), self.__class__.DESCRIBE_JOBS_BATCH_SIZE): | ||
| batched_job_ids = job_ids[i : i + self.__class__.DESCRIBE_JOBS_BATCH_SIZE] | ||
| if not batched_job_ids: | ||
| continue | ||
| boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids) | ||
| boto_describe_workloads = self.batch.describe_jobs(jobs=batched_job_ids) | ||
|
|
||
| describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks) | ||
| all_jobs.extend(describe_tasks_response["jobs"]) | ||
| describe_workloads_response = BatchDescribeJobsResponseSchema().load(boto_describe_workloads) | ||
| all_jobs.extend(describe_workloads_response["jobs"]) | ||
| return all_jobs | ||
|
|
||
| def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): | ||
| """Save the task to be executed in the next sync using Boto3's RunTask API.""" | ||
| def execute_async( | ||
| self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None | ||
| ): | ||
| """Save the workload to be executed in the next sync using Boto3's RunTask API.""" | ||
| if executor_config and "command" in executor_config: | ||
| raise ValueError('Executor Config should never override "command"') | ||
|
|
||
| if len(command) == 1: | ||
| from airflow.executors.workloads import ExecuteTask | ||
| from airflow.executors import workloads | ||
|
|
||
| if isinstance(command[0], ExecuteTask): | ||
| if isinstance(command[0], workloads.ExecuteTask) or ( | ||
| AIRFLOW_V_3_2_PLUS and isinstance(command[0], workloads.ExecuteCallback) | ||
| ): | ||
| workload = command[0] | ||
| ser_input = workload.model_dump_json() | ||
| command = [ | ||
|
|
@@ -433,7 +457,7 @@ def _submit_job_kwargs( | |
| self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType | ||
| ) -> dict: | ||
| """ | ||
| Override the Airflow command to update the container overrides so kwargs are specific to this task. | ||
| Override the Airflow command to update the container overrides so kwargs are specific to this workload. | ||
|
|
||
| One last chance to modify Boto3's "submit_job" kwarg params before it gets passed into the Boto3 | ||
| client. For the latest kwarg parameters: | ||
|
|
@@ -450,7 +474,7 @@ def _submit_job_kwargs( | |
| return submit_job_api | ||
|
|
||
| def end(self, heartbeat_interval=10): | ||
| """Wait for all currently running tasks to end and prevent any new jobs from running.""" | ||
| """Wait for all currently running workloads to end and prevent any new jobs from running.""" | ||
| try: | ||
| while True: | ||
| self.sync() | ||
|
|
@@ -500,7 +524,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task | |
| ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id) | ||
| self.active_workers.add_job( | ||
| job_id=batch_job.job_id, | ||
| airflow_task_key=ti.key, | ||
| airflow_workload_key=ti.key, | ||
| airflow_cmd=ti.command_as_list(), | ||
| queue=ti.queue, | ||
| exec_config=ti.executor_config, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you don't mind can you narrow down the type for workload here? I think there is a new type that @ferruzzi made basically anywhere you need
task | callbackas a type.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the review.
Are you referring to types
ExecuteTaskandExecuteCallback? https://github.com/apache/airflow/blob/main/airflow-core/src/airflow/executors/workloads/__init__.pyI think we can only use it if we pin the provider's Airflow core dependency to when this change was implemented. This happened here #61153.
If we want to maintain backwards compatibility, I can't use the new types.. Currently the Amazon provider is
apache-airflow>=2.11.0There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even for just typing? @ferruzzi thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@onikolas is likely thinking of
SchedulerWorkloadfrom workloads/types.py, but @dondaum may be right, that isn't going to be in 2.11 so we can't use that until we bunmp the min_ver.I think we could do a conditional import though:
And that will force it to get cleaned up later when we pin the versions up??
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's say we make this conditional input work here. We are still limited by the
BaseExecutortype having this