Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,11 @@ def execute(self, context):
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region_name,
region_name=self.region_name,
log_group=self.awslogs_group,
log_stream=self._get_logs_stream_name(),
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from __future__ import annotations

import asyncio
import warnings
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError, WaiterError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
Expand Down Expand Up @@ -125,7 +127,10 @@ class TaskDoneTrigger(BaseTrigger):
:param waiter_max_attempts: The number of times to ping for status.
Will fail after that many unsuccessful attempts.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region: The AWS region where the cluster is located.
:param region_name: The AWS region where the cluster is located. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
:param region: (deprecated) use ``region_name`` instead.
"""

def __init__(
Expand All @@ -135,17 +140,30 @@ def __init__(
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str | None,
region: str | None,
region_name: str | None = None,
log_group: str | None = None,
log_stream: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
region: str | None = None,
):
if region is not None:
warnings.warn(
"`region` is deprecated and will be removed in a future release. "
"Please use `region_name` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
region_name = region
self.cluster = cluster
self.task_arn = task_arn

self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.region = region
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

self.log_group = log_group
self.log_stream = log_stream
Expand All @@ -159,19 +177,27 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"waiter_delay": self.waiter_delay,
"waiter_max_attempts": self.waiter_max_attempts,
"aws_conn_id": self.aws_conn_id,
"region": self.region,
"region_name": self.region_name,
"log_group": self.log_group,
"log_stream": self.log_stream,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
async with (
await EcsHook(
aws_conn_id=self.aws_conn_id, region_name=self.region
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
).get_async_conn() as ecs_client,
await AwsLogsHook(
aws_conn_id=self.aws_conn_id, region_name=self.region
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
).get_async_conn() as logs_client,
):
waiter = ecs_client.get_waiter("tasks_stopped")
Expand Down
82 changes: 82 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
Expand All @@ -35,6 +36,87 @@


class TestTaskDoneTrigger:
def test_deprecated_region_alias(self):
with pytest.warns(AirflowProviderDeprecationWarning, match="region"):
trigger = TaskDoneTrigger(
cluster="cluster",
task_arn="task_arn",
waiter_delay=5,
waiter_max_attempts=10,
aws_conn_id="my_conn",
region="eu-west-1",
)
assert trigger.region_name == "eu-west-1"
_, kwargs = trigger.serialize()
assert kwargs["region_name"] == "eu-west-1"
assert "region" not in kwargs

def test_serialize_includes_generic_hook_params(self):
trigger = TaskDoneTrigger(
cluster="cluster",
task_arn="task_arn",
waiter_delay=5,
waiter_max_attempts=10,
aws_conn_id="my_conn",
region_name="eu-west-1",
log_group="lg",
log_stream="ls",
verify=False,
botocore_config={"read_timeout": 7},
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.amazon.aws.triggers.ecs.TaskDoneTrigger"
assert kwargs == {
"cluster": "cluster",
"task_arn": "task_arn",
"waiter_delay": 5,
"waiter_max_attempts": 10,
"aws_conn_id": "my_conn",
"region_name": "eu-west-1",
"log_group": "lg",
"log_stream": "ls",
"verify": False,
"botocore_config": {"read_timeout": 7},
}

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.ecs.AwsLogsHook")
@mock.patch("airflow.providers.amazon.aws.triggers.ecs.EcsHook")
async def test_run_builds_hooks_with_generic_params(self, ecs_hook_cls, logs_hook_cls):
def make_hook(client):
ctx = mock.MagicMock()
ctx.__aenter__ = AsyncMock(return_value=client)
ctx.__aexit__ = AsyncMock(return_value=False)
instance = mock.MagicMock()
instance.get_async_conn = AsyncMock(return_value=ctx)
return instance

ecs_client = mock.MagicMock()
ecs_client.get_waiter().wait = AsyncMock()
ecs_hook_cls.return_value = make_hook(ecs_client)
logs_hook_cls.return_value = make_hook(mock.MagicMock())

trigger = TaskDoneTrigger(
cluster="cluster",
task_arn="task_arn",
waiter_delay=0,
waiter_max_attempts=10,
aws_conn_id="my_conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 7},
)
await trigger.run().asend(None)

expected = {
"aws_conn_id": "my_conn",
"region_name": "eu-west-1",
"verify": False,
"config": {"read_timeout": 7},
}
ecs_hook_cls.assert_called_once_with(**expected)
logs_hook_cls.assert_called_once_with(**expected)

@pytest.mark.asyncio
@mock.patch.object(EcsHook, "get_async_conn")
# this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step
Expand Down
3 changes: 3 additions & 0 deletions scripts/ci/prek/check_trigger_serialize_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class defined in another file cannot be resolved statically and are skipped -- t
# `pod_names` preserves the value and avoids re-triggering the deprecation path on restart.
"google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py::GKEJobTrigger",
"cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py::KubernetesJobTrigger",
# `region` is a deprecated alias folded into `region_name` in __init__; serializing only
# `region_name` preserves the value and avoids re-triggering the deprecation path on restart.
"amazon/src/airflow/providers/amazon/aws/triggers/ecs.py::TaskDoneTrigger",
}


Expand Down
Loading