Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ def execute(self, context: Context) -> dict:
trigger=SageMakerTrigger(
job_name=self.config["ProcessingJobName"],
job_type="Processing",
poke_interval=self.check_interval,
max_attempts=self.max_attempts,
waiter_delay=self.check_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
Expand All @@ -366,7 +366,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
if validated_event["status"] != "success":
raise AirflowException(f"Error while running job: {validated_event}")

self.log.info(validated_event["message"])
self.log.info("SageMaker job %s completed.", validated_event["job_name"])
self.serialized_job = serialize(self.hook.describe_processing_job(validated_event["job_name"]))
self.log.info("%s completed successfully.", self.task_id)
return {"Processing": self.serialized_job}
Expand Down Expand Up @@ -602,7 +602,7 @@ def execute(self, context: Context) -> dict:
trigger=SageMakerTrigger(
job_name=endpoint_info["EndpointName"],
job_type="endpoint",
poke_interval=self.check_interval,
waiter_delay=self.check_interval,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
Expand Down Expand Up @@ -829,8 +829,8 @@ def execute(self, context: Context) -> dict:
trigger=SageMakerTrigger(
job_name=transform_config["TransformJobName"],
job_type="Transform",
poke_interval=self.check_interval,
max_attempts=self.max_attempts,
waiter_delay=self.check_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
Expand All @@ -852,7 +852,7 @@ def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str],
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
validated_event = validate_execute_complete_event(event)

self.log.info(validated_event["message"])
self.log.info("SageMaker job %s completed.", validated_event["job_name"])
return self.serialize_result(validated_event["job_name"])

def serialize_result(self, job_name: str) -> dict[str, dict]:
Expand Down Expand Up @@ -1003,7 +1003,7 @@ def execute(self, context: Context) -> dict:
trigger=SageMakerTrigger(
job_name=self.config["HyperParameterTuningJobName"],
job_type="tuning",
poke_interval=self.check_interval,
waiter_delay=self.check_interval,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
Expand Down Expand Up @@ -1234,8 +1234,8 @@ def execute(self, context: Context) -> dict:
trigger=SageMakerTrigger(
job_name=self.config["TrainingJobName"],
job_type="Training",
poke_interval=self.check_interval,
max_attempts=self.max_attempts,
waiter_delay=self.check_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
Expand All @@ -1249,7 +1249,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
if validated_event["status"] != "success":
raise AirflowException(f"Error while running job: {validated_event}")

self.log.info(validated_event["message"])
self.log.info("SageMaker job %s completed.", validated_event["job_name"])
return self.serialize_result(validated_event["job_name"])

def serialize_result(self, job_name: str) -> dict[str, dict]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,63 +18,95 @@
from __future__ import annotations

import asyncio
import warnings
from collections import Counter
from collections.abc import AsyncIterator
from enum import IntEnum
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.triggers.base import TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

class SageMakerTrigger(BaseTrigger):

class SageMakerTrigger(AwsBaseWaiterTrigger):
"""
SageMakerTrigger is fired as deferred class with params to run the task in triggerer.

:param job_name: name of the job to check status
:param job_type: Type of the sagemaker job whether it is Transform or Training
:param poke_interval: polling period in seconds to check for the status
:param max_attempts: Number of times to poll for query state before returning the current state,
defaults to None.
:param waiter_delay: polling period in seconds to check for the status
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: AWS connection ID for sagemaker
:param region_name: The AWS region where the job is running. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
:param poke_interval: (deprecated) use ``waiter_delay`` instead.
:param max_attempts: (deprecated) use ``waiter_max_attempts`` instead.
"""

def __init__(
self,
job_name: str,
job_type: str,
poke_interval: int = 30,
max_attempts: int = 480,
waiter_delay: int = 30,
waiter_max_attempts: int = 480,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
poke_interval: int | None = None,
max_attempts: int | None = None,
):
super().__init__()
if poke_interval is not None:
warnings.warn(
"`poke_interval` is deprecated and will be removed in a future release. "
"Please use `waiter_delay` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
waiter_delay = poke_interval
if max_attempts is not None:
warnings.warn(
"`max_attempts` is deprecated and will be removed in a future release. "
"Please use `waiter_max_attempts` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
waiter_max_attempts = max_attempts
self.job_name = job_name
self.job_type = job_type
self.poke_interval = poke_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize SagemakerTrigger arguments and classpath."""
return (
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
{
"job_name": self.job_name,
"job_type": self.job_type,
"poke_interval": self.poke_interval,
"max_attempts": self.max_attempts,
"aws_conn_id": self.aws_conn_id,
},
super().__init__(
serialized_fields={"job_name": job_name, "job_type": job_type},
waiter_name=self._get_job_type_waiter(job_type),
waiter_args={self._get_waiter_arg_name(job_type): job_name},
failure_message=f"Error while waiting for {job_type} job",
status_message=f"{job_type} job not done yet",
status_queries=[self._get_response_status_key(job_type)],
return_key="job_name",
return_value=job_name,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
verify=verify,
botocore_config=botocore_config,
)

@cached_property
def hook(self) -> SageMakerHook:
return SageMakerHook(aws_conn_id=self.aws_conn_id)
def hook(self) -> AwsGenericHook:
return SageMakerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

@staticmethod
def _get_job_type_waiter(job_type: str) -> str:
Expand Down Expand Up @@ -106,69 +138,80 @@ def _get_response_status_key(job_type: str) -> str:
"endpoint": "EndpointStatus",
}[job_type.lower()]

async def run(self):
self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
async with await self.hook.get_async_conn() as client:
waiter = self.hook.get_waiter(
self._get_job_type_waiter(self.job_type), deferrable=True, client=client
)
await async_wait(
waiter=waiter,
waiter_delay=self.poke_interval,
waiter_max_attempts=self.max_attempts,
args={self._get_waiter_arg_name(self.job_type): self.job_name},
failure_message=f"Error while waiting for {self.job_type} job",
status_message=f"{self.job_type} job not done yet",
status_args=[self._get_response_status_key(self.job_type)],
)
yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})


class SageMakerPipelineTrigger(BaseTrigger):
"""Trigger to wait for a sagemaker pipeline execution to finish."""
class SageMakerPipelineTrigger(AwsBaseWaiterTrigger):
"""
Trigger to wait for a sagemaker pipeline execution to finish.

:param waiter_type: Type of waiter to use, see ``Type`` enum.
:param pipeline_execution_arn: ARN of the pipeline execution to wait for.
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: The AWS region where the pipeline runs. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

class Type(IntEnum):
"""Type of waiter to use."""

COMPLETE = 1
STOPPED = 2

_waiter_name = {
Type.COMPLETE: "PipelineExecutionComplete",
Type.STOPPED: "PipelineExecutionStopped",
}

def __init__(
self,
waiter_type: Type,
waiter_type: Type | int,
pipeline_execution_arn: str,
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
self.waiter_type = waiter_type
# waiter_type arrives as an int when deserialized from a serialized trigger.
self.waiter_type = self.Type(waiter_type)
self.pipeline_execution_arn = pipeline_execution_arn
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
super().__init__(
serialized_fields={
"waiter_type": self.waiter_type.value, # saving the int value here
"pipeline_execution_arn": self.pipeline_execution_arn,
"waiter_delay": self.waiter_delay,
"waiter_max_attempts": self.waiter_max_attempts,
"aws_conn_id": self.aws_conn_id,
"pipeline_execution_arn": pipeline_execution_arn,
},
waiter_name=self._waiter_name[self.waiter_type],
waiter_args={"PipelineExecutionArn": pipeline_execution_arn},
failure_message="Error while waiting for the pipeline execution to finish",
status_message="Pipeline execution not done yet",
status_queries=["PipelineExecutionStatus"],
return_value=pipeline_execution_arn,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
verify=verify,
botocore_config=botocore_config,
)

_waiter_name = {
Type.COMPLETE: "PipelineExecutionComplete",
Type.STOPPED: "PipelineExecutionStopped",
}
def hook(self) -> AwsGenericHook:
return SageMakerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
# Custom polling loop (instead of the base waiter loop) so we can surface
# per-step pipeline progress in the logs between attempts.
hook = self.hook()
async with await hook.get_async_conn() as conn:
waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn)
for _ in range(self.waiter_max_attempts):
waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=conn)
for _ in range(self.attempts):
try:
await waiter.wait(
PipelineExecutionArn=self.pipeline_execution_arn, WaiterConfig={"MaxAttempts": 1}
Expand Down
Loading