diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py index c99228a3e425b..ebafeb850657f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -350,8 +350,8 @@ def execute(self, context: Context) -> dict: trigger=SageMakerTrigger( job_name=self.config["ProcessingJobName"], job_type="Processing", - poke_interval=self.check_interval, - max_attempts=self.max_attempts, + waiter_delay=self.check_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -366,7 +366,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if validated_event["status"] != "success": raise AirflowException(f"Error while running job: {validated_event}") - self.log.info(validated_event["message"]) + self.log.info("SageMaker job %s completed.", validated_event["job_name"]) self.serialized_job = serialize(self.hook.describe_processing_job(validated_event["job_name"])) self.log.info("%s completed successfully.", self.task_id) return {"Processing": self.serialized_job} @@ -602,7 +602,7 @@ def execute(self, context: Context) -> dict: trigger=SageMakerTrigger( job_name=endpoint_info["EndpointName"], job_type="endpoint", - poke_interval=self.check_interval, + waiter_delay=self.check_interval, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -829,8 +829,8 @@ def execute(self, context: Context) -> dict: trigger=SageMakerTrigger( job_name=transform_config["TransformJobName"], job_type="Transform", - poke_interval=self.check_interval, - max_attempts=self.max_attempts, + waiter_delay=self.check_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -852,7 +852,7 @@ def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]: validated_event = validate_execute_complete_event(event) - self.log.info(validated_event["message"]) + self.log.info("SageMaker job %s completed.", validated_event["job_name"]) return self.serialize_result(validated_event["job_name"]) def serialize_result(self, job_name: str) -> dict[str, dict]: @@ -1003,7 +1003,7 @@ def execute(self, context: Context) -> dict: trigger=SageMakerTrigger( job_name=self.config["HyperParameterTuningJobName"], job_type="tuning", - poke_interval=self.check_interval, + waiter_delay=self.check_interval, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -1234,8 +1234,8 @@ def execute(self, context: Context) -> dict: trigger=SageMakerTrigger( job_name=self.config["TrainingJobName"], job_type="Training", - poke_interval=self.check_interval, - max_attempts=self.max_attempts, + waiter_delay=self.check_interval, + waiter_max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", @@ -1249,7 +1249,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if validated_event["status"] != "success": raise AirflowException(f"Error while running job: {validated_event}") - self.log.info(validated_event["message"]) + self.log.info("SageMaker job %s completed.", validated_event["job_name"]) return self.serialize_result(validated_event["job_name"]) def serialize_result(self, job_name: str) -> dict[str, dict]: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py index 57fbbe5871317..271fe8b51fd54 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py @@ -18,63 +18,95 @@ from __future__ import annotations import asyncio +import warnings from collections import Counter from collections.abc import AsyncIterator from enum import IntEnum -from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING from botocore.exceptions import WaiterError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook -from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger from airflow.providers.common.compat.sdk import AirflowException -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import TriggerEvent +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook -class SageMakerTrigger(BaseTrigger): + +class SageMakerTrigger(AwsBaseWaiterTrigger): """ SageMakerTrigger is fired as deferred class with params to run the task in triggerer. :param job_name: name of the job to check status :param job_type: Type of the sagemaker job whether it is Transform or Training - :param poke_interval: polling period in seconds to check for the status - :param max_attempts: Number of times to poll for query state before returning the current state, - defaults to None. + :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: AWS connection ID for sagemaker + :param region_name: The AWS region where the job is running. Used to build the hook. + :param verify: Whether or not to verify SSL certificates. Used to build the hook. + :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. + :param poke_interval: (deprecated) use ``waiter_delay`` instead. + :param max_attempts: (deprecated) use ``waiter_max_attempts`` instead. """ def __init__( self, job_name: str, job_type: str, - poke_interval: int = 30, - max_attempts: int = 480, + waiter_delay: int = 30, + waiter_max_attempts: int = 480, aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, + poke_interval: int | None = None, + max_attempts: int | None = None, ): - super().__init__() + if poke_interval is not None: + warnings.warn( + "`poke_interval` is deprecated and will be removed in a future release. " + "Please use `waiter_delay` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = poke_interval + if max_attempts is not None: + warnings.warn( + "`max_attempts` is deprecated and will be removed in a future release. " + "Please use `waiter_max_attempts` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_max_attempts = max_attempts self.job_name = job_name self.job_type = job_type - self.poke_interval = poke_interval - self.max_attempts = max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serialize SagemakerTrigger arguments and classpath.""" - return ( - "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger", - { - "job_name": self.job_name, - "job_type": self.job_type, - "poke_interval": self.poke_interval, - "max_attempts": self.max_attempts, - "aws_conn_id": self.aws_conn_id, - }, + super().__init__( + serialized_fields={"job_name": job_name, "job_type": job_type}, + waiter_name=self._get_job_type_waiter(job_type), + waiter_args={self._get_waiter_arg_name(job_type): job_name}, + failure_message=f"Error while waiting for {job_type} job", + status_message=f"{job_type} job not done yet", + status_queries=[self._get_response_status_key(job_type)], + return_key="job_name", + return_value=job_name, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + verify=verify, + botocore_config=botocore_config, ) - @cached_property - def hook(self) -> SageMakerHook: - return SageMakerHook(aws_conn_id=self.aws_conn_id) + def hook(self) -> AwsGenericHook: + return SageMakerHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) @staticmethod def _get_job_type_waiter(job_type: str) -> str: @@ -106,26 +138,20 @@ def _get_response_status_key(job_type: str) -> str: "endpoint": "EndpointStatus", }[job_type.lower()] - async def run(self): - self.log.info("job name is %s and job type is %s", self.job_name, self.job_type) - async with await self.hook.get_async_conn() as client: - waiter = self.hook.get_waiter( - self._get_job_type_waiter(self.job_type), deferrable=True, client=client - ) - await async_wait( - waiter=waiter, - waiter_delay=self.poke_interval, - waiter_max_attempts=self.max_attempts, - args={self._get_waiter_arg_name(self.job_type): self.job_name}, - failure_message=f"Error while waiting for {self.job_type} job", - status_message=f"{self.job_type} job not done yet", - status_args=[self._get_response_status_key(self.job_type)], - ) - yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name}) - -class SageMakerPipelineTrigger(BaseTrigger): - """Trigger to wait for a sagemaker pipeline execution to finish.""" +class SageMakerPipelineTrigger(AwsBaseWaiterTrigger): + """ + Trigger to wait for a sagemaker pipeline execution to finish. + + :param waiter_type: Type of waiter to use, see ``Type`` enum. + :param pipeline_execution_arn: ARN of the pipeline execution to wait for. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: The AWS region where the pipeline runs. Used to build the hook. + :param verify: Whether or not to verify SSL certificates. Used to build the hook. + :param botocore_config: Configuration dictionary for the botocore client. Used to build the hook. + """ class Type(IntEnum): """Type of waiter to use.""" @@ -133,42 +159,59 @@ class Type(IntEnum): COMPLETE = 1 STOPPED = 2 + _waiter_name = { + Type.COMPLETE: "PipelineExecutionComplete", + Type.STOPPED: "PipelineExecutionStopped", + } + def __init__( self, - waiter_type: Type, + waiter_type: Type | int, pipeline_execution_arn: str, waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str | None, + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, ): - self.waiter_type = waiter_type + # waiter_type arrives as an int when deserialized from a serialized trigger. + self.waiter_type = self.Type(waiter_type) self.pipeline_execution_arn = pipeline_execution_arn - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts - self.aws_conn_id = aws_conn_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - self.__class__.__module__ + "." + self.__class__.__qualname__, - { + super().__init__( + serialized_fields={ "waiter_type": self.waiter_type.value, # saving the int value here - "pipeline_execution_arn": self.pipeline_execution_arn, - "waiter_delay": self.waiter_delay, - "waiter_max_attempts": self.waiter_max_attempts, - "aws_conn_id": self.aws_conn_id, + "pipeline_execution_arn": pipeline_execution_arn, }, + waiter_name=self._waiter_name[self.waiter_type], + waiter_args={"PipelineExecutionArn": pipeline_execution_arn}, + failure_message="Error while waiting for the pipeline execution to finish", + status_message="Pipeline execution not done yet", + status_queries=["PipelineExecutionStatus"], + return_value=pipeline_execution_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + verify=verify, + botocore_config=botocore_config, ) - _waiter_name = { - Type.COMPLETE: "PipelineExecutionComplete", - Type.STOPPED: "PipelineExecutionStopped", - } + def hook(self) -> AwsGenericHook: + return SageMakerHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) async def run(self) -> AsyncIterator[TriggerEvent]: - hook = SageMakerHook(aws_conn_id=self.aws_conn_id) + # Custom polling loop (instead of the base waiter loop) so we can surface + # per-step pipeline progress in the logs between attempts. + hook = self.hook() async with await hook.get_async_conn() as conn: - waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn) - for _ in range(self.waiter_max_attempts): + waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=conn) + for _ in range(self.attempts): try: await waiter.wait( PipelineExecutionArn=self.pipeline_execution_arn, WaiterConfig={"MaxAttempts": 1} diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py index 0f1851df048b7..68cf554fbb1df 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py @@ -20,15 +20,20 @@ from unittest.mock import AsyncMock import pytest +from botocore.exceptions import WaiterError -from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerPipelineTrigger, SageMakerTrigger from airflow.triggers.base import TriggerEvent JOB_NAME = "job_name" -JOB_TYPE = "job_type" +JOB_TYPE = "training" AWS_CONN_ID = "aws_sagemaker_conn" -POKE_INTERVAL = 30 -MAX_ATTEMPTS = 60 +WAITER_DELAY = 30 +WAITER_MAX_ATTEMPTS = 60 +REGION_NAME = "us-west-2" +PIPELINE_ARN = "arn:aws:sagemaker:us-west-2:123456789012:pipeline/my-pipeline/execution/abc" class TestSagemakerTrigger: @@ -36,17 +41,54 @@ def test_sagemaker_trigger_serialize(self): sagemaker_trigger = SageMakerTrigger( job_name=JOB_NAME, job_type=JOB_TYPE, - poke_interval=POKE_INTERVAL, - max_attempts=MAX_ATTEMPTS, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME, ) class_path, args = sagemaker_trigger.serialize() assert class_path == "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger" assert args["job_name"] == JOB_NAME assert args["job_type"] == JOB_TYPE - assert args["poke_interval"] == POKE_INTERVAL - assert args["max_attempts"] == MAX_ATTEMPTS + assert args["waiter_delay"] == WAITER_DELAY + assert args["waiter_max_attempts"] == WAITER_MAX_ATTEMPTS assert args["aws_conn_id"] == AWS_CONN_ID + assert args["region_name"] == REGION_NAME + + @pytest.mark.parametrize( + ("deprecated_kwarg", "canonical_attr", "value"), + [ + ("poke_interval", "waiter_delay", 17), + ("max_attempts", "attempts", 21), + ], + ) + def test_sagemaker_trigger_deprecated_params(self, deprecated_kwarg, canonical_attr, value): + with pytest.warns(AirflowProviderDeprecationWarning, match=deprecated_kwarg): + trigger = SageMakerTrigger( + job_name=JOB_NAME, + job_type=JOB_TYPE, + aws_conn_id=AWS_CONN_ID, + **{deprecated_kwarg: value}, + ) + assert getattr(trigger, canonical_attr) == value + + def test_sagemaker_trigger_hook_uses_generic_params(self): + sagemaker_trigger = SageMakerTrigger( + job_name=JOB_NAME, + job_type=JOB_TYPE, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME, + verify=False, + botocore_config={"read_timeout": 10}, + ) + hook = sagemaker_trigger.hook() + assert isinstance(hook, SageMakerHook) + assert hook.aws_conn_id == AWS_CONN_ID + assert hook._region_name == REGION_NAME + assert hook._verify is False + assert hook._config.read_timeout == 10 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -69,14 +111,121 @@ async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn, mock_g sagemaker_trigger = SageMakerTrigger( job_name=JOB_NAME, job_type=job_type, - poke_interval=POKE_INTERVAL, - max_attempts=MAX_ATTEMPTS, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, aws_conn_id=AWS_CONN_ID, ) generator = sagemaker_trigger.run() response = await generator.asend(None) - assert response == TriggerEvent( - {"status": "success", "message": "Job completed.", "job_name": JOB_NAME} + assert response == TriggerEvent({"status": "success", "job_name": JOB_NAME}) + + +class TestSagemakerPipelineTrigger: + def test_serialize(self): + trigger = SageMakerPipelineTrigger( + waiter_type=SageMakerPipelineTrigger.Type.COMPLETE, + pipeline_execution_arn=PIPELINE_ARN, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, ) + class_path, args = trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerPipelineTrigger" + assert args["waiter_type"] == SageMakerPipelineTrigger.Type.COMPLETE.value + assert args["pipeline_execution_arn"] == PIPELINE_ARN + assert args["waiter_delay"] == WAITER_DELAY + assert args["waiter_max_attempts"] == WAITER_MAX_ATTEMPTS + assert args["aws_conn_id"] == AWS_CONN_ID + + def test_deserialize_accepts_int_waiter_type(self): + # On deserialization the waiter_type is passed back as the stored int value. + trigger = SageMakerPipelineTrigger( + waiter_type=SageMakerPipelineTrigger.Type.STOPPED.value, + pipeline_execution_arn=PIPELINE_ARN, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, + ) + assert trigger.waiter_type == SageMakerPipelineTrigger.Type.STOPPED + assert trigger.waiter_name == "PipelineExecutionStopped" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + + trigger = SageMakerPipelineTrigger( + waiter_type=SageMakerPipelineTrigger.Type.COMPLETE, + pipeline_execution_arn=PIPELINE_ARN, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, + ) + + response = await trigger.run().asend(None) + + assert response == TriggerEvent({"status": "success", "value": PIPELINE_ARN}) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.asyncio.sleep") + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_async_conn") + async def test_run_logs_steps_then_succeeds(self, mock_async_conn, mock_get_waiter, mock_sleep): + conn = mock.MagicMock() + conn.list_pipeline_execution_steps = AsyncMock( + return_value={ + "PipelineExecutionSteps": [ + {"StepName": "step-1", "StepStatus": "Executing"}, + {"StepName": "step-2", "StepStatus": "Succeeded"}, + ] + } + ) + mock_async_conn.return_value.__aenter__.return_value = conn + mock_sleep.return_value = None + + non_terminal_error = WaiterError( + name="PipelineExecutionComplete", + reason="not done yet", + last_response={"PipelineExecutionStatus": "Executing"}, + ) + mock_get_waiter().wait = AsyncMock(side_effect=[non_terminal_error, None]) + + trigger = SageMakerPipelineTrigger( + waiter_type=SageMakerPipelineTrigger.Type.COMPLETE, + pipeline_execution_arn=PIPELINE_ARN, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, + ) + + response = await trigger.run().asend(None) + + assert response == TriggerEvent({"status": "success", "value": PIPELINE_ARN}) + conn.list_pipeline_execution_steps.assert_awaited_once_with(PipelineExecutionArn=PIPELINE_ARN) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter") + @mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_async_conn") + async def test_run_raises_on_terminal_failure(self, mock_async_conn, mock_get_waiter): + mock_async_conn.return_value.__aenter__.return_value = mock.MagicMock() + terminal_error = WaiterError( + name="PipelineExecutionComplete", + reason="terminal failure", + last_response={"PipelineExecutionStatus": "Failed"}, + ) + mock_get_waiter().wait = AsyncMock(side_effect=terminal_error) + + trigger = SageMakerPipelineTrigger( + waiter_type=SageMakerPipelineTrigger.Type.COMPLETE, + pipeline_execution_arn=PIPELINE_ARN, + waiter_delay=WAITER_DELAY, + waiter_max_attempts=WAITER_MAX_ATTEMPTS, + aws_conn_id=AWS_CONN_ID, + ) + + with pytest.raises(WaiterError, match="terminal failure"): + await trigger.run().asend(None)