Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def execute(self, context: Context) -> None:
cluster_identifier=self.cluster_identifier,
target_status=self.target_status,
poke_interval=self.poke_interval,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio
from collections.abc import AsyncIterator
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
Expand Down Expand Up @@ -283,6 +284,9 @@ class RedshiftClusterTrigger(BaseTrigger):
:param cluster_identifier: unique identifier of a cluster
:param target_status: Reference to the status which needs to be checked
:param poke_interval: polling period in seconds to check for the status
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
Expand All @@ -291,12 +295,18 @@ def __init__(
cluster_identifier: str,
target_status: str,
poke_interval: float,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
super().__init__()
self.aws_conn_id = aws_conn_id
self.cluster_identifier = cluster_identifier
self.target_status = target_status
self.poke_interval = poke_interval
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize RedshiftClusterTrigger arguments and classpath."""
Expand All @@ -307,15 +317,26 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"cluster_identifier": self.cluster_identifier,
"target_status": self.target_status,
"poke_interval": self.poke_interval,
"region_name": self.region_name,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> RedshiftHook:
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Run async until the cluster status matches the target status."""
try:
hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while True:
status = await hook.cluster_status_async(self.cluster_identifier)
status = await self.hook.cluster_status_async(self.cluster_identifier)
if status == self.target_status:
yield TriggerEvent({"status": "success", "message": "target state met"})
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,33 @@ def test_redshift_cluster_sensor_trigger_serialization(self):
"cluster_identifier": "mock_cluster_identifier",
"target_status": "available",
"poke_interval": POLLING_PERIOD_SECONDS,
"region_name": None,
"verify": None,
"botocore_config": None,
}

def test_redshift_cluster_trigger_serializes_generic_hook_params(self):
"""Asserts the generic AWS hook params are serialized and used to build the hook."""
trigger = RedshiftClusterTrigger(
aws_conn_id="test_redshift_conn_id",
cluster_identifier="mock_cluster_identifier",
target_status="available",
poke_interval=POLLING_PERIOD_SECONDS,
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
_, kwargs = trigger.serialize()
assert kwargs["region_name"] == "eu-west-1"
assert kwargs["verify"] is False
assert kwargs["botocore_config"] == {"read_timeout": 42}

hook = trigger.hook
assert hook.aws_conn_id == "test_redshift_conn_id"
assert hook._region_name == "eu-west-1"
assert hook._verify is False
assert hook._config.read_timeout == 42

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status_async")
async def test_redshift_cluster_sensor_trigger_success(self, mock_cluster_status):
Expand Down
Loading