From 7c7dfecffffe5e9a433397ae0690bdd6742485be Mon Sep 17 00:00:00 2001 From: Niko Oliveira Date: Tue, 23 Jun 2026 15:49:16 -0700 Subject: [PATCH 1/3] Standardize ECS TaskDoneTrigger on region_name and AWS hook parameters TaskDoneTrigger used a non-standard 'region' argument and ignored SSL verification and botocore configuration. It now uses region_name and accepts verify and botocore_config like the other Amazon provider triggers, and EcsRunTaskOperator forwards those values when deferring, completing the trigger portion of the ECS migration in apache/airflow#35278. --- .../providers/amazon/aws/operators/ecs.py | 4 +- .../providers/amazon/aws/triggers/ecs.py | 26 ++++++-- .../unit/amazon/aws/triggers/test_ecs.py | 66 +++++++++++++++++++ 3 files changed, 89 insertions(+), 7 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py index 05f0102152d4a..4b71424e42d8a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py @@ -556,9 +556,11 @@ def execute(self, context): waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - region=self.region_name, + region_name=self.region_name, log_group=self.awslogs_group, log_stream=self._get_logs_stream_name(), + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py index 3768af58a12fd..dc9f923abd4d1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py @@ -125,7 +125,9 @@ class TaskDoneTrigger(BaseTrigger): :param waiter_max_attempts: The number of times to ping for status. Will fail after that many unsuccessful attempts. :param aws_conn_id: The Airflow connection used for AWS credentials. - :param region: The AWS region where the cluster is located. + :param region_name: The AWS region where the cluster is located. Used to build the hook. + :param verify: Whether or not to verify SSL certificates. Used to build the hook. + :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. """ def __init__( @@ -135,9 +137,11 @@ def __init__( waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str | None, - region: str | None, + region_name: str | None, log_group: str | None = None, log_stream: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, ): self.cluster = cluster self.task_arn = task_arn @@ -145,7 +149,9 @@ def __init__( self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id - self.region = region + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config self.log_group = log_group self.log_stream = log_stream @@ -159,19 +165,27 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "waiter_delay": self.waiter_delay, "waiter_max_attempts": self.waiter_max_attempts, "aws_conn_id": self.aws_conn_id, - "region": self.region, + "region_name": self.region_name, "log_group": self.log_group, "log_stream": self.log_stream, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: async with ( await EcsHook( - aws_conn_id=self.aws_conn_id, region_name=self.region + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, ).get_async_conn() as ecs_client, await AwsLogsHook( - aws_conn_id=self.aws_conn_id, region_name=self.region + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, ).get_async_conn() as logs_client, ): waiter = ecs_client.get_waiter("tasks_stopped") diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py index 5a4ba278cfd14..ccf77e565a56d 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py @@ -35,6 +35,72 @@ class TestTaskDoneTrigger: + def test_serialize_includes_generic_hook_params(self): + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=5, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region_name="eu-west-1", + log_group="lg", + log_stream="ls", + verify=False, + botocore_config={"read_timeout": 7}, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.ecs.TaskDoneTrigger" + assert kwargs == { + "cluster": "cluster", + "task_arn": "task_arn", + "waiter_delay": 5, + "waiter_max_attempts": 10, + "aws_conn_id": "my_conn", + "region_name": "eu-west-1", + "log_group": "lg", + "log_stream": "ls", + "verify": False, + "botocore_config": {"read_timeout": 7}, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ecs.AwsLogsHook") + @mock.patch("airflow.providers.amazon.aws.triggers.ecs.EcsHook") + async def test_run_builds_hooks_with_generic_params(self, ecs_hook_cls, logs_hook_cls): + def make_hook(client): + ctx = mock.MagicMock() + ctx.__aenter__ = AsyncMock(return_value=client) + ctx.__aexit__ = AsyncMock(return_value=False) + instance = mock.MagicMock() + instance.get_async_conn = AsyncMock(return_value=ctx) + return instance + + ecs_client = mock.MagicMock() + ecs_client.get_waiter().wait = AsyncMock() + ecs_hook_cls.return_value = make_hook(ecs_client) + logs_hook_cls.return_value = make_hook(mock.MagicMock()) + + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=0, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 7}, + ) + await trigger.run().asend(None) + + expected = { + "aws_conn_id": "my_conn", + "region_name": "eu-west-1", + "verify": False, + "config": {"read_timeout": 7}, + } + ecs_hook_cls.assert_called_once_with(**expected) + logs_hook_cls.assert_called_once_with(**expected) + @pytest.mark.asyncio @mock.patch.object(EcsHook, "get_async_conn") # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step From ee14637697d79c404e93b8932e264c99aaa33ba6 Mon Sep 17 00:00:00 2001 From: Niko Oliveira Date: Tue, 23 Jun 2026 18:17:12 -0700 Subject: [PATCH 2/3] Keep region as a deprecated TaskDoneTrigger alias Preserve backward compatibility for ECS TaskDoneTrigger after standardizing on region_name: the old region argument is still accepted as a deprecated alias, emitting AirflowProviderDeprecationWarning. This keeps existing keyword callers working and lets deferred-task triggers serialized by older versions deserialize after upgrade. --- .../airflow/providers/amazon/aws/triggers/ecs.py | 14 +++++++++++++- .../tests/unit/amazon/aws/triggers/test_ecs.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py index dc9f923abd4d1..39896960cfabd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py @@ -18,11 +18,13 @@ from __future__ import annotations import asyncio +import warnings from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any from botocore.exceptions import ClientError, WaiterError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger @@ -128,6 +130,7 @@ class TaskDoneTrigger(BaseTrigger): :param region_name: The AWS region where the cluster is located. Used to build the hook. :param verify: Whether or not to verify SSL certificates. Used to build the hook. :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. + :param region: (deprecated) use ``region_name`` instead. """ def __init__( @@ -137,12 +140,21 @@ def __init__( waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str | None, - region_name: str | None, + region_name: str | None = None, log_group: str | None = None, log_stream: str | None = None, verify: bool | str | None = None, botocore_config: dict | None = None, + region: str | None = None, ): + if region is not None: + warnings.warn( + "`region` is deprecated and will be removed in a future release. " + "Please use `region_name` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + region_name = region self.cluster = cluster self.task_arn = task_arn diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py index ccf77e565a56d..a46fc053c3830 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py @@ -23,6 +23,7 @@ import pytest from botocore.exceptions import WaiterError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.ecs import ( @@ -35,6 +36,21 @@ class TestTaskDoneTrigger: + def test_deprecated_region_alias(self): + with pytest.warns(AirflowProviderDeprecationWarning, match="region"): + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=5, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region="eu-west-1", + ) + assert trigger.region_name == "eu-west-1" + _, kwargs = trigger.serialize() + assert kwargs["region_name"] == "eu-west-1" + assert "region" not in kwargs + def test_serialize_includes_generic_hook_params(self): trigger = TaskDoneTrigger( cluster="cluster", From 2699fbb36fc965d0c00347cb6c1b7c9ab275a592 Mon Sep 17 00:00:00 2001 From: Niko Oliveira Date: Tue, 23 Jun 2026 18:18:41 -0700 Subject: [PATCH 3/3] Exclude TaskDoneTrigger from trigger serialize/init sync check The deprecated region alias is folded into region_name at construction time, so it intentionally does not round-trip through serialize() -- the same by-design pattern already used for GKEJobTrigger's pod_name alias. --- scripts/ci/prek/check_trigger_serialize_init.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/ci/prek/check_trigger_serialize_init.py b/scripts/ci/prek/check_trigger_serialize_init.py index f1050a9340ea5..b87eddc67fc5d 100755 --- a/scripts/ci/prek/check_trigger_serialize_init.py +++ b/scripts/ci/prek/check_trigger_serialize_init.py @@ -67,6 +67,9 @@ class defined in another file cannot be resolved statically and are skipped -- t # `pod_names` preserves the value and avoids re-triggering the deprecation path on restart. "google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py::GKEJobTrigger", "cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py::KubernetesJobTrigger", + # `region` is a deprecated alias folded into `region_name` in __init__; serializing only + # `region_name` preserves the value and avoids re-triggering the deprecation path on restart. + "amazon/src/airflow/providers/amazon/aws/triggers/ecs.py::TaskDoneTrigger", }