diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index 01bed883ddc92..fc61680121ce5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job.""" +"""AWS Batch Executor. Each Airflow workload gets delegated out to an AWS Batch Job.""" from __future__ import annotations @@ -33,7 +33,7 @@ exponential_backoff_retry, ) from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone from airflow.utils.helpers import merge_dicts @@ -88,6 +88,7 @@ class AwsBatchExecutor(BaseExecutor): """ supports_multi_team: bool = True + supports_callbacks: bool = True # AWS only allows a maximum number of JOBs in the describe_jobs function DESCRIBE_JOBS_BATCH_SIZE = 99 @@ -127,26 +128,45 @@ def __init__(self, *args, **kwargs): def queue_workload(self, workload: workloads.All, session: Session | None) -> None: from airflow.executors import workloads - if not isinstance(workload, workloads.ExecuteTask): + if AIRFLOW_V_3_2_PLUS and isinstance(workload, workloads.ExecuteCallback): + self.queued_callbacks[workload.callback.id] = workload + elif isinstance(workload, workloads.ExecuteTask): + ti = workload.ti + self.queued_tasks[ti.key] = workload + else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload def 
_process_workloads(self, workloads: Sequence[workloads.All]) -> None: - from airflow.executors.workloads import ExecuteTask + from airflow.executors import workloads as wl # Airflow V3 version for w in workloads: - if not isinstance(w, ExecuteTask): + if isinstance(w, wl.ExecuteTask): + task_command = [w] + task_key = w.ti.key + queue = w.ti.queue + executor_config = w.ti.executor_config or {} + + del self.queued_tasks[task_key] + self.execute_async( + key=task_key, + command=task_command, # type: ignore[arg-type] + queue=queue, + executor_config=executor_config, + ) + self.running.add(task_key) + elif AIRFLOW_V_3_2_PLUS and isinstance(w, wl.ExecuteCallback): + callback_command = [w] + callback_key = w.callback.id + queue = None + if isinstance(w.callback.data, dict) and "queue" in w.callback.data: + queue = w.callback.data["queue"] + + del self.queued_callbacks[callback_key] + self.execute_async(key=callback_key, command=callback_command, queue=queue) # type: ignore[arg-type] + self.running.add(callback_key) + else: raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") - command = [w] - key = w.ti.key - queue = w.ti.queue - executor_config = w.ti.executor_config or {} - - del self.queued_tasks[key] - self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type] - self.running.add(key) def check_health(self): """Make a test API call to check the health of the Batch Executor.""" @@ -235,7 +255,7 @@ def sync(self): def sync_running_jobs(self): all_job_ids = self.active_workers.get_all_jobs() if not all_job_ids: - self.log.debug("No active Airflow tasks, skipping sync") + self.log.debug("No active Airflow workloads, skipping sync") return describe_job_response = self._describe_jobs(all_job_ids) @@ -245,8 +265,8 @@ def sync_running_jobs(self): if job.get_job_state() == State.FAILED: self._handle_failed_job(job) elif job.get_job_state() == State.SUCCESS: - task_key = 
self.active_workers.pop_by_id(job.job_id) - self.success(task_key) + workload_key = self.active_workers.pop_by_id(job.job_id) + self.success(workload_key) def _handle_failed_job(self, job): """ @@ -263,15 +283,15 @@ def _handle_failed_job(self, job): # responsibility for ensuring the process started. Failures in the DAG will be caught by # Airflow, which will be handled separately. job_info = self.active_workers.id_to_job_info[job.job_id] - task_key = self.active_workers.id_to_key[job.job_id] - task_cmd = job_info.cmd + workload_key = self.active_workers.id_to_key[job.job_id] + workload_cmd = job_info.cmd queue = job_info.queue exec_info = job_info.config failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id) if int(failure_count) < int(self.max_submit_job_attempts): self.log.warning( - "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", - task_key, + "Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", + workload_key, job.status_reason, failure_count, self.max_submit_job_attempts, @@ -281,8 +301,8 @@ def _handle_failed_job(self, job): self.active_workers.pop_by_id(job.job_id) self.pending_jobs.append( BatchQueuedJob( - task_key, - task_cmd, + workload_key, + workload_cmd, queue, exec_info, failure_count + 1, @@ -291,12 +311,12 @@ def _handle_failed_job(self, job): ) else: self.log.error( - "Airflow task %s has failed a maximum of %s times. Marking as failed", - task_key, + "Airflow workload %s has failed a maximum of %s times. 
Marking as failed", + workload_key, failure_count, ) self.active_workers.pop_by_id(job.job_id) - self.fail(task_key) + self.fail(workload_key) def attempt_submit_jobs(self): """ @@ -309,8 +329,8 @@ def attempt_submit_jobs(self): """ for _ in range(len(self.pending_jobs)): batch_job = self.pending_jobs.popleft() - key = batch_job.key - cmd = batch_job.command + workload_key = batch_job.key + workload_cmd = batch_job.command queue = batch_job.queue exec_config = batch_job.executor_config attempt_number = batch_job.attempt_number @@ -319,7 +339,7 @@ def attempt_submit_jobs(self): self.pending_jobs.append(batch_job) continue try: - submit_job_response = self._submit_job(key, cmd, queue, exec_config or {}) + submit_job_response = self._submit_job(workload_key, workload_cmd, queue, exec_config or {}) except NoCredentialsError: self.pending_jobs.append(batch_job) raise @@ -337,7 +357,7 @@ def attempt_submit_jobs(self): self.log.error( ( "This job has been unsuccessfully attempted too many times (%s). " - "Dropping the task. Reason: %s" + "Dropping the workload. Reason: %s" ), attempt_number, failure_reason, @@ -345,10 +365,10 @@ def attempt_submit_jobs(self): self.log_task_event( event="batch job submit failure", extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). " - f"Dropping the task. Reason: {failure_reason}", - ti_key=key, + f"Dropping the workload. 
Reason: {failure_reason}", + ti_key=workload_key, ) - self.fail(key=key) + self.fail(key=workload_key) else: batch_job.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( attempt_number @@ -360,13 +380,13 @@ def attempt_submit_jobs(self): job_id = submit_job_response["job_id"] self.active_workers.add_job( job_id=job_id, - airflow_task_key=key, - airflow_cmd=cmd, + airflow_workload_key=workload_key, + airflow_cmd=workload_cmd, queue=queue, exec_config=exec_config, attempt_number=attempt_number, ) - self.running_state(key, job_id) + self.running_state(workload_key, job_id) def _describe_jobs(self, job_ids) -> list[BatchJob]: all_jobs = [] @@ -374,21 +394,25 @@ def _describe_jobs(self, job_ids) -> list[BatchJob]: batched_job_ids = job_ids[i : i + self.__class__.DESCRIBE_JOBS_BATCH_SIZE] if not batched_job_ids: continue - boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids) + boto_describe_workloads = self.batch.describe_jobs(jobs=batched_job_ids) - describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks) - all_jobs.extend(describe_tasks_response["jobs"]) + describe_workloads_response = BatchDescribeJobsResponseSchema().load(boto_describe_workloads) + all_jobs.extend(describe_workloads_response["jobs"]) return all_jobs - def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): - """Save the task to be executed in the next sync using Boto3's RunTask API.""" + def execute_async( + self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None + ): + """Save the workload to be executed in the next sync using Boto3's RunTask API.""" if executor_config and "command" in executor_config: raise ValueError('Executor Config should never override "command"') if len(command) == 1: - from airflow.executors.workloads import ExecuteTask + from airflow.executors import workloads - if isinstance(command[0], ExecuteTask): + if isinstance(command[0], 
workloads.ExecuteTask) or ( + AIRFLOW_V_3_2_PLUS and isinstance(command[0], workloads.ExecuteCallback) + ): workload = command[0] ser_input = workload.model_dump_json() command = [ @@ -433,7 +457,7 @@ def _submit_job_kwargs( self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType ) -> dict: """ - Override the Airflow command to update the container overrides so kwargs are specific to this task. + Override the Airflow command to update the container overrides so kwargs are specific to this workload. One last chance to modify Boto3's "submit_job" kwarg params before it gets passed into the Boto3 client. For the latest kwarg parameters: @@ -450,7 +474,7 @@ def _submit_job_kwargs( return submit_job_api def end(self, heartbeat_interval=10): - """Wait for all currently running tasks to end and prevent any new jobs from running.""" + """Wait for all currently running workloads to end and prevent any new jobs from running.""" try: while True: self.sync() @@ -500,7 +524,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id) self.active_workers.add_job( job_id=batch_job.job_id, - airflow_task_key=ti.key, + airflow_workload_key=ti.key, airflow_cmd=ti.command_as_list(), queue=ti.queue, exec_config=ti.executor_config, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py index abcc16ec321fb..16ba84fb2f94f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/utils.py @@ -43,7 +43,7 @@ class BatchQueuedJob: """Represents a Batch job that is queued. 
The job will be run in the next heartbeat.""" - key: TaskInstanceKey + key: TaskInstanceKey | str command: CommandType queue: str executor_config: ExecutorConfigType @@ -91,33 +91,33 @@ class BatchJobCollection: """A collection to manage running Batch Jobs.""" def __init__(self): - self.key_to_id: dict[TaskInstanceKey, str] = {} - self.id_to_key: dict[str, TaskInstanceKey] = {} + self.key_to_id: dict[TaskInstanceKey | str, str] = {} + self.id_to_key: dict[str, TaskInstanceKey | str] = {} self.id_to_failure_counts: dict[str, int] = defaultdict(int) self.id_to_job_info: dict[str, BatchJobInfo] = {} def add_job( self, job_id: str, - airflow_task_key: TaskInstanceKey, + airflow_workload_key: TaskInstanceKey | str, airflow_cmd: CommandType, queue: str, exec_config: ExecutorConfigType, attempt_number: int, ): """Add a job to the collection.""" - self.key_to_id[airflow_task_key] = job_id - self.id_to_key[job_id] = airflow_task_key + self.key_to_id[airflow_workload_key] = job_id + self.id_to_key[job_id] = airflow_workload_key self.id_to_failure_counts[job_id] = attempt_number self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config) - def pop_by_id(self, job_id: str) -> TaskInstanceKey: + def pop_by_id(self, job_id: str) -> TaskInstanceKey | str: """Delete job from collection based off of Batch Job ID.""" - task_key = self.id_to_key[job_id] - del self.key_to_id[task_key] + workload_key = self.id_to_key[job_id] + del self.key_to_id[workload_key] del self.id_to_key[job_id] del self.id_to_failure_counts[job_id] - return task_key + return workload_key def failure_count_by_id(self, job_id: str) -> int: """Get the number of times a job has failed given a Batch Job Id.""" diff --git a/providers/amazon/src/airflow/providers/amazon/version_compat.py b/providers/amazon/src/airflow/providers/amazon/version_compat.py index d56eb0ea33995..d02a776b42193 100644 --- a/providers/amazon/src/airflow/providers/amazon/version_compat.py +++ 
b/providers/amazon/src/airflow/providers/amazon/version_compat.py @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1) AIRFLOW_V_3_1_8_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 8) +AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0) try: from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet @@ -58,6 +59,7 @@ def is_arg_set(value): # type: ignore[misc,no-redef] "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_1_1_PLUS", "AIRFLOW_V_3_1_8_PLUS", + "AIRFLOW_V_3_2_PLUS", "NOTSET", "ArgNotSet", "is_arg_set", diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index ed767fea187dd..0099b967ce85a 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -50,7 +50,7 @@ from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3])) ARN1 = "arn1" @@ -112,7 +112,7 @@ def _setup_method(self): self.first_airflow_key = mock.Mock(spec=tuple) self.collection.add_job( job_id=self.first_job_id, - airflow_task_key=self.first_airflow_key, + airflow_workload_key=self.first_airflow_key, airflow_cmd="command1", queue="queue1", exec_config={}, @@ -123,7 +123,7 @@ def _setup_method(self): self.second_airflow_key = mock.Mock(spec=tuple) self.collection.add_job( job_id=self.second_job_id, -
airflow_task_key=self.second_airflow_key, + airflow_workload_key=self.second_airflow_key, airflow_cmd="command2", queue="queue2", exec_config={}, @@ -211,6 +211,7 @@ def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, moc workload = mock.Mock(spec=ExecuteTask) workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = mock_airflow_key() + workload.ti.queue = "some-job-queue" tags_exec_config = [{"key": "FOO", "value": "BAR"}] workload.ti.executor_config = {"tags": tags_exec_config} ser_workload = json.dumps({"test_key": "test_value"}) @@ -270,6 +271,98 @@ def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, moc assert job_id == ARN1 running_state_mock.assert_called_once_with(workload.ti.key, ARN1) + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="Test requires Airflow 3.2+") + @mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state") + def test_task_sdk_callback(self, running_state_mock, mock_airflow_key, mock_executor, mock_cmd): + """Test task sdk execution for callbacks from end-to-end.""" + from airflow.executors.workloads import ExecuteCallback + + workload = mock.Mock(spec=ExecuteCallback) + workload.callback = mock.Mock() + workload.callback.id = mock_airflow_key() + ser_workload = json.dumps({"test_key": "test_value"}) + workload.model_dump_json.return_value = ser_workload + + mock_executor.queue_workload(workload, mock.Mock()) + + mock_executor.batch.submit_job.return_value = {"jobId": ARN1, "jobName": "some-job-name"} + + assert mock_executor.queued_callbacks[workload.callback.id] == workload + assert len(mock_executor.pending_jobs) == 0 + assert len(mock_executor.running) == 0 + mock_executor._process_workloads([workload]) + assert len(mock_executor.queued_callbacks) == 0 + assert len(mock_executor.running) == 1 + assert workload.callback.id in mock_executor.running + assert len(mock_executor.pending_jobs) == 1 + assert 
mock_executor.pending_jobs[0].command == [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + '{"test_key": "test_value"}', + ] + + mock_executor.attempt_submit_jobs() + mock_executor.batch.submit_job.assert_called_once() + assert len(mock_executor.pending_jobs) == 0 + mock_executor.batch.submit_job.assert_called_once_with( + jobDefinition="some-job-def", + jobName="some-job-name", + jobQueue="some-job-queue", + containerOverrides={ + "command": [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + ser_workload, + ], + "environment": [ + { + "name": "AIRFLOW_IS_EXECUTOR_CONTAINER", + "value": "true", + }, + ], + }, + ) + + # Task is stored in active worker. + assert len(mock_executor.active_workers) == 1 + # Get the job_id for this task key + job_id = next( + job_id + for job_id, key in mock_executor.active_workers.id_to_key.items() + if key == workload.callback.id + ) + assert job_id == ARN1 + running_state_mock.assert_called_once_with(workload.callback.id, ARN1) + + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="Test requires Airflow 3.2+") + @mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state") + def test_task_sdk_callback_with_queue(self, running_state_mock, mock_airflow_key, mock_executor): + """Test task sdk execution for callbacks with queue from end-to-end.""" + from airflow.executors.workloads import ExecuteCallback + + workload = mock.Mock(spec=ExecuteCallback) + workload.callback = mock.Mock() + workload.callback.id = mock_airflow_key() + workload.callback.data = {"queue": "fast-queue"} + + mock_executor.queue_workload(workload, mock.Mock()) + + mock_executor.batch.submit_job.return_value = {"jobId": ARN1, "jobName": "some-job-name"} + + assert mock_executor.queued_callbacks[workload.callback.id] == workload + assert len(mock_executor.pending_jobs) == 0 + assert len(mock_executor.running) == 0 + mock_executor._process_workloads([workload]) +
assert len(mock_executor.queued_callbacks) == 0 + assert len(mock_executor.running) == 1 + assert workload.callback.id in mock_executor.running + assert len(mock_executor.pending_jobs) == 1 + assert mock_executor.pending_jobs[0].queue == "fast-queue" + @mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor): """ @@ -446,7 +539,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog): mock_executor.sync_running_jobs() for i in range(2): assert ( - f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 1 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling." + f"Airflow workload {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 1 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling." in caplog.messages[i] ) @@ -455,7 +548,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog): mock_executor.sync_running_jobs() for i in range(2): assert ( - f"Airflow task {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 2 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling." + f"Airflow workload {airflow_keys[i]} failed due to {jobs[i]['statusReason']}. Failure 2 out of {mock_executor.max_submit_job_attempts} occurred on {jobs[i]['jobId']}. Rescheduling." in caplog.messages[i] ) @@ -464,7 +557,7 @@ def test_task_retry_on_api_failure(self, _, mock_executor, caplog): mock_executor.sync_running_jobs() for i in range(2): assert ( - f"Airflow task {airflow_keys[i]} has failed a maximum of {mock_executor.max_submit_job_attempts} times. Marking as failed" + f"Airflow workload {airflow_keys[i]} has failed a maximum of {mock_executor.max_submit_job_attempts} times. 
Marking as failed" in caplog.text ) @@ -480,7 +573,7 @@ def test_sync_running_jobs_no_jobs(self, mock_executor, caplog): caplog.set_level("DEBUG") assert len(mock_executor.active_workers.get_all_jobs()) == 0 mock_executor.sync_running_jobs() - assert "No active Airflow tasks, skipping sync" in caplog.messages[0] + assert "No active Airflow workloads, skipping sync" in caplog.messages[0] def test_sync_client_error(self, mock_executor, caplog): mock_executor.execute_async("airflow_key", "airflow_cmd") @@ -499,7 +592,7 @@ def test_sync_client_error(self, mock_executor, caplog): def test_sync_exception(self, mock_executor, caplog): mock_executor.active_workers.add_job( job_id="job_id", - airflow_task_key="airflow_key", + airflow_workload_key="airflow_key", airflow_cmd="command", queue="queue", exec_config={}, @@ -617,7 +710,7 @@ def test_terminate(self, mock_airflow_key, mock_executor): def test_terminate_failure(self, mock_executor, caplog): mock_executor.active_workers.add_job( job_id="job_id", - airflow_task_key="airflow_key", + airflow_workload_key="airflow_key", airflow_cmd="command", queue="queue", exec_config={}, @@ -663,7 +756,7 @@ def _mock_sync( """ executor.active_workers.add_job( job_id=job_id, - airflow_task_key=airflow_key, + airflow_workload_key=airflow_key, airflow_cmd="airflow_cmd", queue="queue", exec_config={}, @@ -966,7 +1059,9 @@ def test_submit_job_kwargs_exec_config_overrides( executor = AwsBatchExecutor() - final_run_task_kwargs = executor._submit_job_kwargs(mock_ti_key, command, "queue", exec_config) + final_run_task_kwargs = executor._submit_job_kwargs( + mock_ti_key, command, expected_result["jobQueue"], exec_config + ) assert final_run_task_kwargs == expected_result diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py index c0b914bdb3b6f..3ee8dca26283a 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py +++ 
b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_utils.py @@ -152,7 +152,7 @@ def test_add_job(self): """Test adding a job to the collection.""" self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -170,7 +170,7 @@ def test_add_multiple_jobs(self): """Test adding multiple jobs to the collection.""" self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -178,7 +178,7 @@ def test_add_multiple_jobs(self): ) self.collection.add_job( job_id=self.job_id2, - airflow_task_key=self.key2, + airflow_workload_key=self.key2, airflow_cmd=self.cmd2, queue=self.queue2, exec_config=self.config2, @@ -194,7 +194,7 @@ def test_pop_by_id(self): """Test removing a job from the collection by its ID.""" self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -222,7 +222,7 @@ def test_failure_count_by_id(self): attempt_number = 5 self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -240,7 +240,7 @@ def test_increment_failure_count(self): initial_attempt = 1 self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -264,7 +264,7 @@ def test_get_all_jobs_with_jobs(self): """Test getting all job IDs from a collection with jobs.""" self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -272,7 +272,7 @@ def 
test_get_all_jobs_with_jobs(self): ) self.collection.add_job( job_id=self.job_id2, - airflow_task_key=self.key2, + airflow_workload_key=self.key2, airflow_cmd=self.cmd2, queue=self.queue2, exec_config=self.config2, @@ -288,7 +288,7 @@ def test_len_method(self): assert len(self.collection) == 0 self.collection.add_job( job_id=self.job_id1, - airflow_task_key=self.key1, + airflow_workload_key=self.key1, airflow_cmd=self.cmd1, queue=self.queue1, exec_config=self.config1, @@ -297,7 +297,7 @@ def test_len_method(self): assert len(self.collection) == 1 self.collection.add_job( job_id=self.job_id2, - airflow_task_key=self.key2, + airflow_workload_key=self.key2, airflow_cmd=self.cmd2, queue=self.queue2, exec_config=self.config2,