Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def execute(self, context: Context) -> Any:
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
poll_interval=int(self.poke_interval),
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class EC2StateSensorTrigger(BaseTrigger):
maintained on each worker node).
:param region_name: (optional) aws region name associated with the client
:param poll_interval: number of seconds to wait before attempting the next poll
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
Expand All @@ -46,12 +48,16 @@ def __init__(
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
poll_interval: int = 60,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
self.instance_id = instance_id
self.target_state = target_state
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
Expand All @@ -62,12 +68,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"poll_interval": self.poll_interval,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self):
return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
return EC2Hook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
api_type="client_type",
)

async def run(self):
while True:
Expand Down
22 changes: 22 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def test_ec2_state_sensor_trigger_serialize(self):
assert args["aws_conn_id"] == TEST_CONN_ID
assert args["region_name"] == TEST_REGION_NAME
assert args["poll_interval"] == TEST_POLL_INTERVAL
assert args["verify"] is None
assert args["botocore_config"] is None

def test_ec2_state_sensor_trigger_serializes_generic_hook_params(self):
test_ec2_state_sensor = EC2StateSensorTrigger(
instance_id=TEST_INSTANCE_ID,
target_state=TEST_TARGET_STATE,
aws_conn_id=TEST_CONN_ID,
region_name=TEST_REGION_NAME,
poll_interval=TEST_POLL_INTERVAL,
verify=False,
botocore_config={"read_timeout": 99},
)
_, args = test_ec2_state_sensor.serialize()
assert args["verify"] is False
assert args["botocore_config"] == {"read_timeout": 99}

hook = test_ec2_state_sensor.hook
assert hook.aws_conn_id == TEST_CONN_ID
assert hook._region_name == TEST_REGION_NAME
assert hook._verify is False
assert hook._config.read_timeout == 99

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
Expand Down
Loading