Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job."""
"""AWS Batch Executor. Each Airflow workload gets delegated out to an AWS Batch Job."""

from __future__ import annotations

Expand All @@ -33,7 +33,7 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts

Expand Down Expand Up @@ -88,6 +88,7 @@ class AwsBatchExecutor(BaseExecutor):
"""

supports_multi_team: bool = True
supports_callbacks: bool = True

# AWS only allows a maximum number of JOBs in the describe_jobs function
DESCRIBE_JOBS_BATCH_SIZE = 99
Expand Down Expand Up @@ -127,26 +128,45 @@ def __init__(self, *args, **kwargs):
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't mind can you narrow down the type for workload here? I think there is a new type that @ferruzzi made basically anywhere you need task | callback as a type.

Copy link
Contributor Author

@dondaum dondaum Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review.

Are you referring to types ExecuteTask and ExecuteCallback? https://github.com/apache/airflow/blob/main/airflow-core/src/airflow/executors/workloads/__init__.py

I think we can only use it if we pin the provider's Airflow core dependency to when this change was implemented. This happened here #61153.

If we want to maintain backwards compatibility, I can't use the new types. Currently the Amazon provider is apache-airflow>=2.11.0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even for just typing? @ferruzzi thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@onikolas is likely thinking of SchedulerWorkload from workloads/types.py, but @dondaum may be right, that isn't going to be in 2.11 so we can't use that until we bump the min_ver.

I think we could do a conditional import though:

workload_type_hint = workloads.All

if airflow version  > 3.2:
    from airflow.executors.workloads.types import SchedulerWorkload 

    workload_type_hint = SchedulerWorkload

And that will force it to get cleaned up later when we pin the versions up??

Copy link
Contributor Author

@dondaum dondaum Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's say we make this conditional import work here. We are still limited by the BaseExecutor type having this

def queue_workload(self, workload: workloads.All, session: Session) -> None:
    ...

from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
if AIRFLOW_V_3_2_PLUS and isinstance(workload, workloads.ExecuteCallback):
self.queued_callbacks[workload.callback.id] = workload
elif isinstance(workload, workloads.ExecuteTask):
ti = workload.ti
self.queued_tasks[ti.key] = workload
else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads as wl

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
if isinstance(w, wl.ExecuteTask):
task_command = [w]
task_key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}

del self.queued_tasks[task_key]
self.execute_async(
key=task_key,
command=task_command, # type: ignore[arg-type]
queue=queue,
executor_config=executor_config,
)
self.running.add(task_key)
elif AIRFLOW_V_3_2_PLUS and isinstance(w, wl.ExecuteCallback):
callback_command = [w]
callback_key = w.callback.id
queue = None
if isinstance(w.callback.data, dict) and "queue" in w.callback.data:
queue = w.callback.data["queue"]

del self.queued_callbacks[callback_key]
self.execute_async(key=callback_key, command=callback_command, queue=queue) # type: ignore[arg-type]
self.running.add(callback_key)
else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)

def check_health(self):
"""Make a test API call to check the health of the Batch Executor."""
Expand Down Expand Up @@ -235,7 +255,7 @@ def sync(self):
def sync_running_jobs(self):
all_job_ids = self.active_workers.get_all_jobs()
if not all_job_ids:
self.log.debug("No active Airflow tasks, skipping sync")
self.log.debug("No active Airflow workloads, skipping sync")
return
describe_job_response = self._describe_jobs(all_job_ids)

Expand All @@ -245,8 +265,8 @@ def sync_running_jobs(self):
if job.get_job_state() == State.FAILED:
self._handle_failed_job(job)
elif job.get_job_state() == State.SUCCESS:
task_key = self.active_workers.pop_by_id(job.job_id)
self.success(task_key)
workload_key = self.active_workers.pop_by_id(job.job_id)
self.success(workload_key)

def _handle_failed_job(self, job):
"""
Expand All @@ -263,15 +283,15 @@ def _handle_failed_job(self, job):
# responsibility for ensuring the process started. Failures in the DAG will be caught by
# Airflow, which will be handled separately.
job_info = self.active_workers.id_to_job_info[job.job_id]
task_key = self.active_workers.id_to_key[job.job_id]
task_cmd = job_info.cmd
workload_key = self.active_workers.id_to_key[job.job_id]
workload_cmd = job_info.cmd
queue = job_info.queue
exec_info = job_info.config
failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id)
if int(failure_count) < int(self.max_submit_job_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
"Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
workload_key,
job.status_reason,
failure_count,
self.max_submit_job_attempts,
Expand All @@ -281,8 +301,8 @@ def _handle_failed_job(self, job):
self.active_workers.pop_by_id(job.job_id)
self.pending_jobs.append(
BatchQueuedJob(
task_key,
task_cmd,
workload_key,
workload_cmd,
queue,
exec_info,
failure_count + 1,
Expand All @@ -291,12 +311,12 @@ def _handle_failed_job(self, job):
)
else:
self.log.error(
"Airflow task %s has failed a maximum of %s times. Marking as failed",
task_key,
"Airflow workload %s has failed a maximum of %s times. Marking as failed",
workload_key,
failure_count,
)
self.active_workers.pop_by_id(job.job_id)
self.fail(task_key)
self.fail(workload_key)

def attempt_submit_jobs(self):
"""
Expand All @@ -309,8 +329,8 @@ def attempt_submit_jobs(self):
"""
for _ in range(len(self.pending_jobs)):
batch_job = self.pending_jobs.popleft()
key = batch_job.key
cmd = batch_job.command
workload_key = batch_job.key
workload_cmd = batch_job.command
queue = batch_job.queue
exec_config = batch_job.executor_config
attempt_number = batch_job.attempt_number
Expand All @@ -319,7 +339,7 @@ def attempt_submit_jobs(self):
self.pending_jobs.append(batch_job)
continue
try:
submit_job_response = self._submit_job(key, cmd, queue, exec_config or {})
submit_job_response = self._submit_job(workload_key, workload_cmd, queue, exec_config or {})
except NoCredentialsError:
self.pending_jobs.append(batch_job)
raise
Expand All @@ -337,18 +357,18 @@ def attempt_submit_jobs(self):
self.log.error(
(
"This job has been unsuccessfully attempted too many times (%s). "
"Dropping the task. Reason: %s"
"Dropping the workload. Reason: %s"
),
attempt_number,
failure_reason,
)
self.log_task_event(
event="batch job submit failure",
extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
f"Dropping the task. Reason: {failure_reason}",
ti_key=key,
f"Dropping the workload. Reason: {failure_reason}",
ti_key=workload_key,
)
self.fail(key=key)
self.fail(key=workload_key)
Comment on lines 365 to +371
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a callback workload exceeds max submit attempts, log_task_event is called with ti_key=workload_key. For callbacks, this key is a string UUID, not a TaskInstanceKey named tuple, which will cause errors since Log(task_instance=...) expects a TaskInstanceKey.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great find. I followed the relevant code. The scheduler uses this log queue to write a Log() entry. The log itself can be initialized without a task instance, but the Executor method expects a task instance key def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey).

So I'm wondering whether we should either adjust the def log_task_event() to accept both keys, which would also perhaps require a change to a different executor, or whether we should remove the callback from this log queue.

@ferruzzi any thoughts on it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same point he made in another PR. #63035 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to change Log() since it can already be instantiated without a Task instance.

https://github.com/apache/airflow/blob/main/airflow-core/src/airflow/models/log.py#L78C9-L78C22

What we could do is:

# BaseExecutor
def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey | None):
    ...
# BatchExecutor
self.log_task_event(
    event="batch job submit failure",
    extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
    f"Dropping the workload. Reason: {failure_reason}"
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this Log() entry for? I am wondering if callbacks need the entry or are they just for tasks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
batch_job.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
attempt_number
Expand All @@ -360,35 +380,39 @@ def attempt_submit_jobs(self):
job_id = submit_job_response["job_id"]
self.active_workers.add_job(
job_id=job_id,
airflow_task_key=key,
airflow_cmd=cmd,
airflow_workload_key=workload_key,
airflow_cmd=workload_cmd,
queue=queue,
exec_config=exec_config,
attempt_number=attempt_number,
)
self.running_state(key, job_id)
self.running_state(workload_key, job_id)

def _describe_jobs(self, job_ids) -> list[BatchJob]:
all_jobs = []
for i in range(0, len(job_ids), self.__class__.DESCRIBE_JOBS_BATCH_SIZE):
batched_job_ids = job_ids[i : i + self.__class__.DESCRIBE_JOBS_BATCH_SIZE]
if not batched_job_ids:
continue
boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids)
boto_describe_workloads = self.batch.describe_jobs(jobs=batched_job_ids)

describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
all_jobs.extend(describe_tasks_response["jobs"])
describe_workloads_response = BatchDescribeJobsResponseSchema().load(boto_describe_workloads)
all_jobs.extend(describe_workloads_response["jobs"])
return all_jobs

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync using Boto3's RunTask API."""
def execute_async(
self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None
):
"""Save the workload to be executed in the next sync using Boto3's RunTask API."""
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')

if len(command) == 1:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads

if isinstance(command[0], ExecuteTask):
if isinstance(command[0], workloads.ExecuteTask) or (
AIRFLOW_V_3_2_PLUS and isinstance(command[0], workloads.ExecuteCallback)
):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
Expand Down Expand Up @@ -433,7 +457,7 @@ def _submit_job_kwargs(
self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
) -> dict:
"""
Override the Airflow command to update the container overrides so kwargs are specific to this task.
Override the Airflow command to update the container overrides so kwargs are specific to this workload.

One last chance to modify Boto3's "submit_job" kwarg params before it gets passed into the Boto3
client. For the latest kwarg parameters:
Expand All @@ -450,7 +474,7 @@ def _submit_job_kwargs(
return submit_job_api

def end(self, heartbeat_interval=10):
"""Wait for all currently running tasks to end and prevent any new jobs from running."""
"""Wait for all currently running workloads to end and prevent any new jobs from running."""
try:
while True:
self.sync()
Expand Down Expand Up @@ -500,7 +524,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id)
self.active_workers.add_job(
job_id=batch_job.job_id,
airflow_task_key=ti.key,
airflow_workload_key=ti.key,
airflow_cmd=ti.command_as_list(),
queue=ti.queue,
exec_config=ti.executor_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
class BatchQueuedJob:
"""Represents a Batch job that is queued. The job will be run in the next heartbeat."""

key: TaskInstanceKey
key: TaskInstanceKey | str
command: CommandType
queue: str
executor_config: ExecutorConfigType
Expand Down Expand Up @@ -91,33 +91,33 @@ class BatchJobCollection:
"""A collection to manage running Batch Jobs."""

def __init__(self):
self.key_to_id: dict[TaskInstanceKey, str] = {}
self.id_to_key: dict[str, TaskInstanceKey] = {}
self.key_to_id: dict[TaskInstanceKey | str, str] = {}
self.id_to_key: dict[str, TaskInstanceKey | str] = {}
self.id_to_failure_counts: dict[str, int] = defaultdict(int)
self.id_to_job_info: dict[str, BatchJobInfo] = {}

def add_job(
self,
job_id: str,
airflow_task_key: TaskInstanceKey,
airflow_workload_key: TaskInstanceKey | str,
airflow_cmd: CommandType,
queue: str,
exec_config: ExecutorConfigType,
attempt_number: int,
):
"""Add a job to the collection."""
self.key_to_id[airflow_task_key] = job_id
self.id_to_key[job_id] = airflow_task_key
self.key_to_id[airflow_workload_key] = job_id
self.id_to_key[job_id] = airflow_workload_key
self.id_to_failure_counts[job_id] = attempt_number
self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)

def pop_by_id(self, job_id: str) -> TaskInstanceKey:
def pop_by_id(self, job_id: str) -> TaskInstanceKey | str:
"""Delete job from collection based off of Batch Job ID."""
task_key = self.id_to_key[job_id]
del self.key_to_id[task_key]
workload_key = self.id_to_key[job_id]
del self.key_to_id[workload_key]
del self.id_to_key[job_id]
del self.id_to_failure_counts[job_id]
return task_key
return workload_key

def failure_count_by_id(self, job_id: str) -> int:
"""Get the number of times a job has failed given a Batch Job Id."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1)
AIRFLOW_V_3_1_8_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 8)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)

try:
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
Expand All @@ -58,6 +59,7 @@ def is_arg_set(value): # type: ignore[misc,no-redef]
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_1_1_PLUS",
"AIRFLOW_V_3_1_8_PLUS",
"AIRFLOW_V_3_2_PLUS",
"NOTSET",
"ArgNotSet",
"is_arg_set",
Expand Down
Loading
Loading