diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py index 05f0102152d4a..4b71424e42d8a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py @@ -556,9 +556,11 @@ def execute(self, context): waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, aws_conn_id=self.aws_conn_id, - region=self.region_name, + region_name=self.region_name, log_group=self.awslogs_group, log_stream=self._get_logs_stream_name(), + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py index 3768af58a12fd..39896960cfabd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ecs.py @@ -18,11 +18,13 @@ from __future__ import annotations import asyncio +import warnings from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any from botocore.exceptions import ClientError, WaiterError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger @@ -125,7 +127,10 @@ class TaskDoneTrigger(BaseTrigger): :param waiter_max_attempts: The number of times to ping for status. Will fail after that many unsuccessful attempts. :param aws_conn_id: The Airflow connection used for AWS credentials. - :param region: The AWS region where the cluster is located. + :param region_name: The AWS region where the cluster is located. Used to build the hook. + :param verify: Whether or not to verify SSL certificates. Used to build the hook. + :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. + :param region: (deprecated) use ``region_name`` instead. """ def __init__( @@ -135,17 +140,30 @@ def __init__( waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str | None, - region: str | None, + region_name: str | None = None, log_group: str | None = None, log_stream: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, + region: str | None = None, ): + if region is not None: + warnings.warn( + "`region` is deprecated and will be removed in a future release. " + "Please use `region_name` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + region_name = region self.cluster = cluster self.task_arn = task_arn self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id - self.region = region + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config self.log_group = log_group self.log_stream = log_stream @@ -159,19 +177,27 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "waiter_delay": self.waiter_delay, "waiter_max_attempts": self.waiter_max_attempts, "aws_conn_id": self.aws_conn_id, - "region": self.region, + "region_name": self.region_name, "log_group": self.log_group, "log_stream": self.log_stream, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: async with ( await EcsHook( - aws_conn_id=self.aws_conn_id, region_name=self.region + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, ).get_async_conn() as ecs_client, await AwsLogsHook( - aws_conn_id=self.aws_conn_id, region_name=self.region + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, ).get_async_conn() as logs_client, ): waiter = ecs_client.get_waiter("tasks_stopped") diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py index 5a4ba278cfd14..a46fc053c3830 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py @@ -23,6 +23,7 @@ import pytest from botocore.exceptions import WaiterError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.triggers.ecs import ( @@ -35,6 +36,87 @@ class TestTaskDoneTrigger: + def test_deprecated_region_alias(self): + with pytest.warns(AirflowProviderDeprecationWarning, match="region"): + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=5, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region="eu-west-1", + ) + assert trigger.region_name == "eu-west-1" + _, kwargs = trigger.serialize() + assert kwargs["region_name"] == "eu-west-1" + assert "region" not in kwargs + + def test_serialize_includes_generic_hook_params(self): + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=5, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region_name="eu-west-1", + log_group="lg", + log_stream="ls", + verify=False, + botocore_config={"read_timeout": 7}, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.ecs.TaskDoneTrigger" + assert kwargs == { + "cluster": "cluster", + "task_arn": "task_arn", + "waiter_delay": 5, + "waiter_max_attempts": 10, + "aws_conn_id": "my_conn", + "region_name": "eu-west-1", + "log_group": "lg", + "log_stream": "ls", + "verify": False, + "botocore_config": {"read_timeout": 7}, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ecs.AwsLogsHook") + @mock.patch("airflow.providers.amazon.aws.triggers.ecs.EcsHook") + async def test_run_builds_hooks_with_generic_params(self, ecs_hook_cls, logs_hook_cls): + def make_hook(client): + ctx = mock.MagicMock() + ctx.__aenter__ = AsyncMock(return_value=client) + ctx.__aexit__ = AsyncMock(return_value=False) + instance = mock.MagicMock() + instance.get_async_conn = AsyncMock(return_value=ctx) + return instance + + ecs_client = mock.MagicMock() + ecs_client.get_waiter().wait = AsyncMock() + ecs_hook_cls.return_value = make_hook(ecs_client) + logs_hook_cls.return_value = make_hook(mock.MagicMock()) + + trigger = TaskDoneTrigger( + cluster="cluster", + task_arn="task_arn", + waiter_delay=0, + waiter_max_attempts=10, + aws_conn_id="my_conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 7}, + ) + await trigger.run().asend(None) + + expected = { + "aws_conn_id": "my_conn", + "region_name": "eu-west-1", + "verify": False, + "config": {"read_timeout": 7}, + } + ecs_hook_cls.assert_called_once_with(**expected) + logs_hook_cls.assert_called_once_with(**expected) + @pytest.mark.asyncio @mock.patch.object(EcsHook, "get_async_conn") # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step diff --git a/scripts/ci/prek/check_trigger_serialize_init.py b/scripts/ci/prek/check_trigger_serialize_init.py index f1050a9340ea5..b87eddc67fc5d 100755 --- a/scripts/ci/prek/check_trigger_serialize_init.py +++ b/scripts/ci/prek/check_trigger_serialize_init.py @@ -67,6 +67,9 @@ class defined in another file cannot be resolved statically and are skipped -- t # `pod_names` preserves the value and avoids re-triggering the deprecation path on restart. "google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py::GKEJobTrigger", "cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py::KubernetesJobTrigger", + # `region` is a deprecated alias folded into `region_name` in __init__; serializing only + # `region_name` preserves the value and avoids re-triggering the deprecation path on restart. + "amazon/src/airflow/providers/amazon/aws/triggers/ecs.py::TaskDoneTrigger", }