diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py index d5d2929d10cfa..404b40e59e444 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/ec2.py @@ -75,6 +75,8 @@ def execute(self, context: Context) -> Any: aws_conn_id=self.aws_conn_id, region_name=self.region_name, poll_interval=int(self.poke_interval), + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ec2.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ec2.py index 9346e86caf7c6..8b2d4d98af841 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ec2.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ec2.py @@ -37,6 +37,8 @@ class EC2StateSensorTrigger(BaseTrigger): maintained on each worker node). :param region_name: (optional) aws region name associated with the client :param poll_interval: number of seconds to wait before attempting the next poll + :param verify: Whether or not to verify SSL certificates. Used to build the hook. + :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. """ def __init__( @@ -46,12 +48,16 @@ def __init__( aws_conn_id: str | None = "aws_default", region_name: str | None = None, poll_interval: int = 60, + verify: bool | str | None = None, + botocore_config: dict | None = None, ): self.instance_id = instance_id self.target_state = target_state self.aws_conn_id = aws_conn_id self.region_name = region_name self.poll_interval = poll_interval + self.verify = verify + self.botocore_config = botocore_config def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -62,12 +68,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "aws_conn_id": self.aws_conn_id, "region_name": self.region_name, "poll_interval": self.poll_interval, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) @cached_property def hook(self): - return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") + return EC2Hook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + api_type="client_type", + ) async def run(self): while True: diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py index c91b8ecf62e7d..cf87fbdf2d586 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ec2.py @@ -48,6 +48,28 @@ def test_ec2_state_sensor_trigger_serialize(self): assert args["aws_conn_id"] == TEST_CONN_ID assert args["region_name"] == TEST_REGION_NAME assert args["poll_interval"] == TEST_POLL_INTERVAL + assert args["verify"] is None + assert args["botocore_config"] is None + + def test_ec2_state_sensor_trigger_serializes_generic_hook_params(self): + test_ec2_state_sensor = EC2StateSensorTrigger( + instance_id=TEST_INSTANCE_ID, + target_state=TEST_TARGET_STATE, + aws_conn_id=TEST_CONN_ID, + region_name=TEST_REGION_NAME, + poll_interval=TEST_POLL_INTERVAL, + verify=False, + botocore_config={"read_timeout": 99}, + ) + _, args = test_ec2_state_sensor.serialize() + assert args["verify"] is False + assert args["botocore_config"] == {"read_timeout": 99} + + hook = test_ec2_state_sensor.hook + assert hook.aws_conn_id == TEST_CONN_ID + assert hook._region_name == TEST_REGION_NAME + assert hook._verify is False + assert hook._config.read_timeout == 99 @pytest.mark.asyncio @mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")