Nova Forge SDK - API Specification

Table of Contents

  1. NovaModelCustomizer
  2. Runtime Managers
  3. Dataset Loaders
  4. Job Results
  5. Utility Functions
  6. Monitoring
  7. Enums and Configuration

NovaModelCustomizer

The main entrypoint class for customizing and training Nova models.

Constructor

__init__()

Initializes a NovaModelCustomizer instance.

Signature:

def __init__(
 self,
 model: Model,
 method: TrainingMethod,
 infra: RuntimeManager,
 data_s3_path: Optional[str] = None,
 output_s3_path: Optional[str] = None,
 model_path: Optional[str] = None,
 validation_config: Optional[Dict[str, bool]] = None,
 generated_recipe_dir: Optional[str] = None,
 mlflow_monitor: Optional[MLflowMonitor] = None,
 deployment_mode: DeploymentMode = DeploymentMode.FAIL_IF_EXISTS,
 data_mixing_enabled: bool = False,
 enable_job_caching: bool = False,
)

Parameters:

  • model (Model): The Nova model to be trained (e.g., Model.NOVA_MICRO, Model.NOVA_LITE, Model.NOVA_LITE_2, Model.NOVA_PRO)
  • method (TrainingMethod): The fine-tuning method (e.g., TrainingMethod.SFT_LORA, TrainingMethod.RFT)
  • infra (RuntimeManager): Runtime infrastructure manager (e.g., SMTJRuntimeManager, SMHPRuntimeManager, or BedrockRuntimeManager)
  • data_s3_path (Optional[str]): S3 path to the training dataset
  • output_s3_path (Optional[str]): S3 path for output artifacts. If not provided, will be auto-generated
  • model_path (Optional[str]): S3 path to the model
  • validation_config (Optional[Dict[str, bool]]): Optional dict to control validation. Defaults to {'iam': True, 'infra': True, 'recipe': True}.
    • iam (bool): Enable IAM permission validation (default: True)
    • infra (bool): Enable infrastructure validation (default: True)
    • recipe (bool): Enable recipe constraint validation (default: True)
  • generated_recipe_dir (Optional[str]): Optional local path to save the generated recipe
  • mlflow_monitor (Optional[MLflowMonitor]): Optional MLflow monitoring configuration for experiment tracking (SageMaker only, not supported on Bedrock)
  • deployment_mode (DeploymentMode): Behavior when deploying to existing endpoint name. Options: FAIL_IF_EXISTS (default), UPDATE_IF_EXISTS
  • data_mixing_enabled (bool): Enable data mixing feature for CPT and SFT training on SageMaker HyperPod. Default is False
    • Note: The data_mixing_enabled parameter must be set to True during initialization to use data mixing features.
    • Note: Data mixing is only supported for CPT, SFT_LORA, and SFT_FULL methods on SageMaker HyperPod (SMHP).
  • enable_job_caching (bool): Whether to enable job result caching. When enabled, completed job results are cached to job_cache_dir (default: .cached-nova-jobs/) and reused for identical job configurations. Default: False

Raises:

  • ValueError: If region is unsupported or model is invalid

Example:

from amzn_nova_forge import *

# SageMaker Training Jobs (SMTJ)
infra = SMTJRuntimeManager(instance_type="ml.p5.48xlarge", instance_count=2)

customizer = NovaModelCustomizer(
 model=Model.NOVA_MICRO,
 method=TrainingMethod.SFT_LORA,
 infra=infra,
 data_s3_path="s3://my-bucket/training-data/",
 output_s3_path="s3://my-bucket/output/"
)

# Amazon Bedrock (fully managed)
bedrock_infra = BedrockRuntimeManager(
 execution_role="arn:aws:iam::123456789012:role/BedrockRole",
 base_model_identifier="arn:aws:bedrock:us-east-1::custom-model/amazon.nova-2-lite-v1:0:256k:abcdefghijk"
)

bedrock_customizer = NovaModelCustomizer(
 model=Model.NOVA_MICRO,
 method=TrainingMethod.SFT_LORA,
 infra=bedrock_infra,
 data_s3_path="s3://my-bucket/training-data/",
 output_s3_path="s3://my-bucket/output/"
)

# With MLflow monitoring (SageMaker only)
mlflow_monitor = MLflowMonitor(
 tracking_uri="arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
 experiment_name="nova-customization",
 run_name="sft-run-1"
)

customizer_with_mlflow = NovaModelCustomizer(
 model=Model.NOVA_MICRO,
 method=TrainingMethod.SFT_LORA,
 infra=infra,
 data_s3_path="s3://my-bucket/training-data/",
 output_s3_path="s3://my-bucket/output/",
 mlflow_monitor=mlflow_monitor
)
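
The enable_job_caching flag reuses cached results when a job with an identical configuration has already completed. One way such a cache key could be derived is by hashing the normalized configuration. The helper below is a hypothetical illustration of that idea, not the SDK's actual implementation:

```python
import hashlib
import json

def job_cache_key(job_config: dict) -> str:
    # Hypothetical sketch: serialize the configuration deterministically
    # (sorted keys) so identical configs always map to the same key.
    blob = json.dumps(job_config, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()

# Two configs with the same contents produce the same key,
# so a completed job's cached result could be reused.
key_a = job_cache_key({"model": "NOVA_MICRO", "method": "SFT_LORA", "lr": 5e-6})
key_b = job_cache_key({"lr": 5e-6, "method": "SFT_LORA", "model": "NOVA_MICRO"})
```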

Methods

get_data_mixing_config()

Get the current data mixing configuration.

Signature:

def get_data_mixing_config(
 self
) -> Dict[str, Any]

Returns:

  • Dict[str, Any]: Dictionary containing the data mixing configuration

Example:

config = customizer.get_data_mixing_config()
print(config)
# Output: {'customer_data_percent': 50, 'nova_code_percent': 30, 'nova_general_percent': 70}

set_data_mixing_config()

Set the data mixing configuration.

Signature:

def set_data_mixing_config(
 self,
 config: Dict[str, Any]
) -> None

Parameters:

  • config (Dict[str, Any]): Dictionary containing the data mixing configuration
    • customer_data_percent (int/float): Percentage of customer data (0-100)
    • nova_code_percent (int/float): Percentage of Nova code data (0-100)
    • nova_general_percent (int/float): Percentage of Nova general data (0-100)
    • Nova percentages must sum to 100%

Raises:

  • ValueError: If data mixing is not enabled or configuration is invalid

Example:

# Must initialize with data_mixing_enabled=True
customizer = NovaModelCustomizer(
    model=Model.NOVA_LITE_2,
    method=TrainingMethod.SFT_LORA,
    infra=SMHPRuntimeManager(...),
    data_s3_path="s3://bucket/data.jsonl",
    data_mixing_enabled=True
)

# Set data mixing configuration
customizer.set_data_mixing_config({
    "customer_data_percent": 50,
    "nova_code_percent": 30,
    "nova_general_percent": 70
})
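
The configuration above must satisfy the documented constraints: each percentage lies in 0-100, and the Nova percentages (nova_code_percent plus nova_general_percent) sum to 100. A standalone sketch of that check, illustrative only and not the SDK's internal validator:

```python
def validate_data_mixing(config: dict) -> None:
    # Illustrative check mirroring the documented constraints;
    # the SDK performs its own validation internally.
    for key in ("customer_data_percent", "nova_code_percent", "nova_general_percent"):
        value = config[key]
        if not 0 <= value <= 100:
            raise ValueError(f"{key} must be between 0 and 100, got {value}")
    nova_total = config["nova_code_percent"] + config["nova_general_percent"]
    if nova_total != 100:
        raise ValueError(f"Nova percentages must sum to 100, got {nova_total}")

validate_data_mixing({
    "customer_data_percent": 50,
    "nova_code_percent": 30,
    "nova_general_percent": 70,
})  # passes: 30 + 70 == 100
```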

train()

Generates the recipe YAML, configures runtime, and launches a training job.

Signature:

def train(
 self,
 job_name: str,
 recipe_path: Optional[str] = None,
 overrides: Optional[Dict[str, Any]] = None,
 rft_lambda_arn: Optional[str] = None,
 validation_data_s3_path: Optional[str] = None,
 dry_run: Optional[bool] = False
) -> TrainingResult

Parameters:

  • job_name (str): User-defined name for the training job
  • recipe_path (Optional[str]): Path for a YAML recipe file (both S3 and local paths are accepted)
  • overrides (Optional[Dict[str, Any]]): Dictionary of configuration overrides. Example overrides below:
    • max_epochs (int): Maximum number of training epochs
    • lr (float): Learning rate
    • warmup_steps (int): Number of warmup steps
    • loraplus_lr_ratio (float): LoRA+ learning rate ratio
    • global_batch_size (int): Global batch size
    • max_length (int): Maximum sequence length
    • A full list of available overrides can be found via the Nova Customization public documentation or by referencing the training recipes here.
  • rft_lambda_arn (Optional[str]): Rewards Lambda ARN (only used for RFT training methods). If passed, takes priority over rft_lambda_arn set on the RuntimeManager.
  • validation_data_s3_path (Optional[str]): S3 path to validation data. Only applicable to CPT, and optional even for CPT.
  • dry_run (Optional[bool]): If True, performs validation only without starting a job; if False (default), starts the job.

Returns:

  • TrainingResult: Metadata object (either SMTJTrainingResult, SMHPTrainingResult, or BedrockTrainingResult) containing:
    • job_id (str): The training job identifier
    • method (TrainingMethod): The training method used
    • started_time (datetime): Job start timestamp
    • model_artifacts (ModelArtifacts): Paths to model checkpoints and outputs
      • checkpoint_s3_path (Optional[str]): Path to the model checkpoint/trained model
      • output_s3_path (str): Path to the metrics and output tar file
    • model_type (Model): Model type of the model being trained

Raises:

  • Exception: If job execution fails
  • ValueError: If training method is not supported

Example:

result = customizer.train(
 job_name="my-training-job",
 overrides={
 'max_epochs': 10,
 'lr': 5e-6,
 'warmup_steps': 20,
 'global_batch_size': 128
 }
)
print(f"Training job started: {result.job_id}")
print(f"Checkpoint path: {result.model_artifacts.checkpoint_s3_path}")

evaluate()

Generates the recipe YAML, configures runtime, and launches an evaluation job.

Signature:

def evaluate(
 self,
 job_name: str,
 eval_task: EvaluationTask,
 model_path: Optional[str] = None,
 subtask: Optional[str] = None,
 data_s3_path: Optional[str] = None,
 recipe_path: Optional[str] = None,
 overrides: Optional[Dict[str, Any]] = None,
 processor: Optional[Dict[str, Any]] = None,
 rl_env: Optional[Dict[str, Any]] = None,
 dry_run: Optional[bool] = False,
 job_result: Optional[TrainingResult] = None
) -> EvaluationResult | None

Parameters:

  • job_name (str): User-defined name for the evaluation job
  • eval_task (EvaluationTask): The evaluation task to be performed (e.g., EvaluationTask.MMLU)
  • model_path (Optional[str]): S3 path for model to evaluate. If not provided, will attempt to extract from job_result or the customizer's most recent training job.
  • data_s3_path (Optional[str]): S3 URI for the dataset. Only required for BYOD (Bring Your Own Data) evaluation tasks.
  • subtask (Optional[str]): Subtask for evaluation (task-specific)
    • The list of available subtasks per task can be found here: Subtasks
  • recipe_path (Optional[str]): Path for a YAML recipe file (both S3 and local paths are accepted)
  • overrides (Optional[Dict[str, Any]]): Dictionary of inference configuration overrides
    • max_new_tokens (int): Maximum tokens to generate
    • top_k (int): Top-k sampling parameter
    • top_p (float): Top-p (nucleus) sampling parameter
    • temperature (float): Temperature for sampling
  • processor (Optional[Dict[str, Any]]): Optional, Bring Your Own Metrics/RFT lambda Configuration
  • rl_env (Optional[Dict[str, Any]]): Optional, Bring your own reinforcement learning environment config. For RFT_EVAL, if either processor or rl_env is explicitly passed, it takes priority over rft_lambda_arn set on the RuntimeManager.
  • dry_run (Optional[bool]): If True, performs validation only without starting a job; if False (default), starts the job.
  • job_result (Optional[TrainingResult]): Optional TrainingResult object to extract checkpoint path from. If provided and model_path is None, will automatically extract the checkpoint path from the training job's output and validate platform compatibility.

Returns:

  • EvaluationResult(BaseJobResult): Metadata object (either SMTJEvaluationResult, SMHPEvaluationResult, or BedrockEvaluationResult) containing:
    • job_id (str): The evaluation job identifier
    • started_time (datetime): Job start timestamp
    • eval_output_path (str): S3 path to evaluation results
    • eval_task (EvaluationTask): The Evaluation task
  • Returns None if dry_run=True

Example:

from amzn_nova_forge.recipe import *

# General eval task (with overrides)
eval_result = customizer.evaluate(
    job_name="my-eval-job",
    eval_task=EvaluationTask.MMLU,
    model_path="s3://my-bucket/checkpoints/my-model/",
    overrides={
        'max_new_tokens': 2048,
        'temperature': 0,
        'top_p': 1.0
    }
)
print(f"Evaluation job started: {eval_result.job_id}")

# BYOM eval task (by providing processor config)
byom_eval_result = customizer.evaluate(
    job_name='my-eval-test-byom',
    eval_task=EvaluationTask.GEN_QA,
    data_s3_path="s3://bucket/data",
    processor={
        "lambda_arn": "arn:aws:lambda:<region>:123456789012:function:byom-lambda"
    }
)

deploy()

Creates a custom model and deploys it to Amazon Bedrock or SageMaker.

Deployment behavior when endpoint already exists is controlled by the deployment_mode parameter set during NovaModelCustomizer initialization:

  • FAIL_IF_EXISTS: Raise error (default, safest)
  • UPDATE_IF_EXISTS: Try in-place update, fail if not supported (PT only)
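
The two modes can be summarized by the decision below. The DeploymentMode enum name comes from this document, but the member values and the resolve_deployment helper are hypothetical, used only to make the documented behavior concrete:

```python
from enum import Enum, auto

class DeploymentMode(Enum):
    # Mirrors the documented enum; member values here are illustrative.
    FAIL_IF_EXISTS = auto()
    UPDATE_IF_EXISTS = auto()

def resolve_deployment(endpoint_exists: bool, mode: DeploymentMode) -> str:
    # Documented behavior: FAIL_IF_EXISTS raises on a name collision,
    # UPDATE_IF_EXISTS attempts an in-place update (PT only).
    if not endpoint_exists:
        return "create"
    if mode is DeploymentMode.FAIL_IF_EXISTS:
        raise ValueError("Endpoint already exists")
    return "update"
```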

Signature:

def deploy(
  self,
  model_artifact_path: Optional[str] = None,
  deploy_platform: DeployPlatform = DeployPlatform.BEDROCK_OD,
  unit_count: Optional[int] = None,
  endpoint_name: Optional[str] = None,
  job_result: Optional[TrainingResult] = None,
  execution_role_name: Optional[str] = None,
  sagemaker_instance_type: Optional[str] = "ml.p5.48xlarge",
  sagemaker_environment_variables: Optional[Dict[str, Any]] = None,
) -> DeploymentResult
  • Note: If DeployPlatform.BEDROCK_PT or DeployPlatform.SAGEMAKER is selected, you must include a value for unit_count.
  • Note: If model_artifact_path is provided, we will NOT attempt to resolve model_artifact_path from job_result or the enclosing NovaModelCustomizer object.

Parameters:

  • model_artifact_path (Optional[str]): S3 path to the trained model checkpoint. If not provided, will attempt to extract from job_result or the job_id field of the Customizer.
  • deploy_platform (DeployPlatform): Platform to deploy the model to
    • DeployPlatform.BEDROCK_OD: Bedrock On-Demand
    • DeployPlatform.BEDROCK_PT: Bedrock Provisioned Throughput
    • DeployPlatform.SAGEMAKER: SageMaker
  • unit_count (Optional[int]): For Bedrock Provisioned Throughput, the number of model units to purchase; for SageMaker, the initial instance count
  • endpoint_name (Optional[str]): Name of the deployed model's endpoint (auto-generated if not provided)
  • job_result (Optional[TrainingResult]): Training job result object to use for extracting checkpoint path and validating job completion. Also used to retrieve job_id if it's not provided.
  • execution_role_name: Optional IAM execution role name for Bedrock or SageMaker, defaults to BedrockDeployModelExecutionRole or SageMakerExecutionRoleName. If this role does not exist, it will be created.
  • sagemaker_instance_type: Optional EC2 instance type for SageMaker deployment, defaults to ml.p5.48xlarge
  • sagemaker_environment_variables: Optional environment variables for model configuration

Returns:

  • DeploymentResult: Contains:
    • endpoint (EndpointInfo): Endpoint information
    • platform (DeployPlatform): Deployment platform
    • endpoint_name (str): Endpoint name
    • uri (str): Model ARN
    • model_artifact_path (str): S3 path to artifacts
    • created_at (datetime): Deployment creation timestamp

Raises:

  • Exception: When unable to successfully deploy the model
  • ValueError: If platform is not supported

Example:

from amzn_nova_forge.model import *

bedrock_deployment = customizer.deploy(
 model_artifact_path="s3://escrow-bucket/my-model-artifacts/",
 deploy_platform=DeployPlatform.BEDROCK_OD,
 endpoint_name="my-custom-nova-model-bedrock"
)
print(f"Model deployed: {bedrock_deployment.endpoint.uri}")
print(f"Endpoint: {bedrock_deployment.endpoint.endpoint_name}")
print(f"Status: {bedrock_deployment.status}")

sagemaker_deployment = customizer.deploy(
 model_artifact_path="s3://escrow-bucket/my-model-artifacts/",
 deploy_platform=DeployPlatform.SAGEMAKER,
 unit_count=1,
 endpoint_name="my-custom-nova-model-sagemaker",
 sagemaker_environment_variables={
   "CONTEXT_LENGTH": "12000",
   "MAX_CONCURRENCY": "16"
 }
)
print(f"Model deployed: {sagemaker_deployment.endpoint.uri}")
print(f"Endpoint: {sagemaker_deployment.endpoint.endpoint_name}")
print(f"Status: {sagemaker_deployment.status}")

Optionally, you can provide a Bedrock execution role name to be used in deployment. Otherwise, a default Bedrock execution role will be created on your behalf. You can also use the following method to create a Bedrock execution role with scoped down IAM permissions.

import boto3

from amzn_nova_forge.util.bedrock import create_bedrock_execution_role

iam_client = boto3.client("iam")

create_bedrock_execution_role(
    iam_client=iam_client,
    role_name="BedrockDeployModelExecutionRole",
    bedrock_resource="your-model-name", # Optional: Name of the Bedrock resources the IAM role should have restricted create and get access to
    s3_resource="s3-bucket" # Optional: S3 resource the IAM role should have restricted read access to, such as the training output bucket
)

invoke_inference()

Invokes a single inference on a trained model.

Signature:

def invoke_inference(
 self,
 request_body: Dict[str, Any], 
 endpoint_arn: Optional[str]
) -> InferenceResult

Parameters:

  • request_body (Dict[str, Any]): Inference request body
  • endpoint_arn (Optional[str]): Endpoint ARN to invoke inference against. Optional if the request should be sent to an endpoint already deployed through this customizer

Returns:

  • InferenceResult: Metadata object (SingleInferenceResult) containing:
    • job_id (str): Inference job identifier
    • started_time (datetime): Job start timestamp
    • inference_output_path (str): Empty string for single inference

Example:

inference_result = customizer.invoke_inference(
    request_body={
      "messages": [{"role": "user", "content": "Hello! How are you?"}],
      "max_tokens": 100,
      "stream": False,
    },
    endpoint_arn="arn:aws:sagemaker:us-east-1:123456789012:endpoint/endpoint",
)
inference_result.show()

batch_inference()

Launches a batch inference job on a trained model.

Signature:

def batch_inference(
 self,
 job_name: str,
 input_path: str,
 output_s3_path: str,
 model_path: Optional[str] = None,
 recipe_path: Optional[str] = None,
 overrides: Optional[Dict[str, Any]] = None,
 dry_run: Optional[bool] = False
) -> InferenceResult

Parameters:

  • job_name (str): Name for the batch inference job
  • input_path (str): S3 path to input data for inference
  • output_s3_path (str): S3 path for inference outputs
  • model_path (Optional[str]): S3 path to the model
  • recipe_path (Optional[str]): Path for a YAML recipe file
  • overrides (Optional[Dict[str, Any]]): Configuration overrides for inference
  • max_new_tokens (int): Maximum tokens to generate
  • top_k (int): Top-k sampling parameter
  • top_p (float): Top-p (nucleus) sampling parameter
  • temperature (float): Temperature for sampling
  • top_logprobs (int): Number of top log probabilities to return
  • dry_run (Optional[bool]): If True, performs validation only without starting a job; if False (default), starts the job.

Returns:

  • InferenceResult: Metadata object (SMTJBatchInferenceResult) containing:
    • job_id (str): Batch inference job identifier
    • started_time (datetime): Job start timestamp
    • inference_output_path (str): S3 path to inference results
  • Note: Batch inference is only supported on SageMaker platforms (SMTJ, SMHP)

Example:

inference_result = customizer.batch_inference(
 job_name="batch-inference-job",
 input_path="s3://my-bucket/inference-input/",
 output_s3_path="s3://my-bucket/inference-output/",
 model_path="s3://my-bucket/trained-model/"
)
print(f"Batch inference started: {inference_result.job_id}")

In a separate notebook cell, you can run the following commands to get the job status and download a formatted result file when the job completes.

inference_result.get_job_status() # Gets the job status.
inference_result.get("s3://my-bucket/save-location/file-name.jsonl") # Uploads a formatted inference_results.jsonl file to the given S3 location.

get_logs()

Retrieves and displays CloudWatch logs for the current job.

Signature:

def get_logs(
 self,
 limit: Optional[int] = None,
 start_from_head: bool = False,
 end_time: Optional[str] = None
)

Parameters:

  • limit (Optional[int]): Maximum number of log lines to retrieve
  • start_from_head (bool): If True, start from the beginning of logs; if False, start from the end
  • end_time (Optional[str]): Optionally specify an end time for searching a log time range

Returns:

  • None (prints logs to console)

Example:

# After starting a training job
customizer.train(job_name="my-job")
customizer.get_logs(limit=100, start_from_head=True)

Runtime Managers

Runtime managers handle the infrastructure for executing training and evaluation jobs, leveraging the JobConfig dataclass to do so:

@dataclass
class JobConfig:
    job_name: str
    image_uri: str
    recipe_path: str
    output_s3_path: Optional[str] = None
    data_s3_path: Optional[str] = None
    input_s3_data_type: Optional[str] = None
    mlflow_tracking_uri: Optional[str] = None
    mlflow_experiment_name: Optional[str] = None
    mlflow_run_name: Optional[str] = None
  • The specific instance types that can be used with the runtime managers (SMTJ, SMHP) can be found in docs/instance_type_spec.md.
  • This file also defines which instance types can be used with a specific model and method.
  • Bedrock is fully managed and does not require instance type configuration.
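
For instance, a runtime manager's execute() receives a populated JobConfig. The construction below repeats the dataclass definition from above so the snippet runs standalone; the field values are placeholders, not real resources:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class JobConfig:  # same definition as above, repeated so this snippet is self-contained
    job_name: str
    image_uri: str
    recipe_path: str
    output_s3_path: Optional[str] = None
    data_s3_path: Optional[str] = None
    input_s3_data_type: Optional[str] = None
    mlflow_tracking_uri: Optional[str] = None
    mlflow_experiment_name: Optional[str] = None
    mlflow_run_name: Optional[str] = None

config = JobConfig(
    job_name="my-training-job",  # placeholder values throughout
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/nova:latest",
    recipe_path="s3://my-bucket/recipes/sft.yaml",
    data_s3_path="s3://my-bucket/training-data/",
)
```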

Shared RuntimeManager Methods

The following methods are available on all RuntimeManager subclasses.

Properties (shared)

  • rft_lambda (Optional[str]): Lambda ARN or local .py file path. Assigning a new value automatically updates rft_lambda_arn — if the value is an ARN it is resolved immediately; if it is a file path, rft_lambda_arn is cleared until deploy_lambda() is called.
  • rft_lambda_arn (Optional[str]): Resolved Lambda ARN. Set immediately when rft_lambda is assigned an ARN, or populated by deploy_lambda() when rft_lambda is a file path.

Example:

# Set an ARN directly — rft_lambda_arn is updated immediately
runtime.rft_lambda = 'arn:aws:lambda:us-east-1:123456789012:function:my-reward-fn'
print(runtime.rft_lambda_arn)  # 'arn:aws:lambda:us-east-1:123456789012:function:my-reward-fn'

# Set a file path — rft_lambda_arn is cleared until deploy_lambda() is called
runtime.rft_lambda = 'reward.py'
print(runtime.rft_lambda_arn)  # None
runtime.deploy_lambda(lambda_name='my-reward-fn')
print(runtime.rft_lambda_arn)
#'arn:aws:lambda:us-east-1:123456789012:function:my-reward-fn'

deploy_lambda()

Packages a local Python file into a zip and creates or updates a Lambda function. The source file is read from self.rft_lambda, which must be set to a local .py file path before calling this method.

Signature:

def deploy_lambda(
    self,
    lambda_name: Optional[str] = None,
    execution_role_arn: Optional[str] = None,
) -> str

Parameters:

  • lambda_name (Optional[str]): Name for the Lambda function. Defaults to the source filename stem (underscores replaced with hyphens).
  • execution_role_arn (Optional[str]): IAM role ARN for the Lambda. Falls back to the runtime manager's execution_role attribute if not provided.

Returns:

  • str: The deployed Lambda function ARN. Also sets self.rft_lambda_arn on the manager.

Raises:

  • ValueError: If rft_lambda is not set, is already an ARN (nothing to deploy), the source file is not found, or no execution role can be resolved.

Example:

runtime.rft_lambda = 'rft_training_reward.py'
lambda_arn = runtime.deploy_lambda(lambda_name='my-reward-fn')
# runtime.rft_lambda_arn is now set automatically

validate_lambda()

Validates the RFT reward lambda with sample data from S3. Reads the lambda to validate from self.rft_lambda / self.rft_lambda_arn:

  • If rft_lambda is an ARN (or rft_lambda_arn is set), invokes the deployed Lambda with samples from data_s3_path.
  • If rft_lambda is a local .py path, validates by executing lambda_handler directly without deploying.

Signature:

def validate_lambda(
    self,
    data_s3_path: str,
    validation_samples: int = 10,
) -> None

Parameters:

  • data_s3_path (str): S3 path to the training dataset for pulling sample data.
  • validation_samples (int): Number of samples to load from data_s3_path (default: 10).

Raises:

  • ValueError: If rft_lambda is not set, or if validation fails.

Example:

# Validate a local file without deploying
runtime.rft_lambda = 'rft_training_reward.py'
runtime.validate_lambda(data_s3_path='s3://bucket/data.jsonl')

# Validate a deployed lambda
runtime.rft_lambda = 'arn:aws:lambda:us-east-1:123456789012:function:my-reward-fn'
runtime.validate_lambda(data_s3_path='s3://bucket/data.jsonl', validation_samples=20)

SMTJRuntimeManager

Manages SageMaker Training Jobs.

Constructor

Signature:

def __init__(
    self,
    instance_type: str,
    instance_count: int,
    execution_role: Optional[str] = None,
    kms_key_id: Optional[str] = None,
    encrypt_inter_container_traffic: bool = False,
    subnets: Optional[list[str]] = None,
    security_group_ids: Optional[list[str]] = None,
    rft_lambda: Optional[str] = None,
)

Parameters:

  • instance_type (str): EC2 instance type (e.g., "ml.p5.48xlarge", "ml.p4d.24xlarge")
  • instance_count (int): Number of instances to use
  • execution_role (Optional[str]): The execution role for the training job
  • kms_key_id (Optional[str]): Optional KMS Key Id to use in S3 Bucket encryption, training jobs and deployments.
  • encrypt_inter_container_traffic (bool): Boolean that determines whether to encrypt inter-container traffic. Default value is False.
  • subnets (Optional[list[str]]): Optional list of strings representing subnets. Default value is None.
  • security_group_ids (Optional[list[str]]): Optional list of strings representing security group IDs. Default value is None.
  • rft_lambda (Optional[str]): Lambda ARN or local .py file path for RFT reward function. Can also be set or updated after construction.

Example:

from amzn_nova_forge.manager import *
infra = SMTJRuntimeManager(
 instance_type="ml.p5.48xlarge",
 instance_count=2
)

Properties

  • instance_type (str): Returns the instance type
  • instance_count (int): Returns the number of instances

Methods

execute()

Starts a SageMaker training job.

Signature:

def execute(
 self,
 job_config: JobConfig
) -> str

Returns:

  • str: Training job name/ID

cleanup()

Stops and cleans up a training job.

Signature:

def cleanup(
 self,
 job_name: str
) -> None

SMHPRuntimeManager

Manages SageMaker HyperPod jobs.

Constructor

Signature:

def __init__(
 self,
 instance_type: str,
 instance_count: int,
 cluster_name: str,
 namespace: str,
 kms_key_id: Optional[str] = None,
 rft_lambda: Optional[str] = None,
)

Parameters:

  • instance_type (str): EC2 instance type
  • instance_count (int): Number of instances
  • cluster_name (str): HyperPod cluster name
  • namespace (str): Kubernetes namespace
  • kms_key_id (Optional[str]): Optional KMS Key Id to use in S3 Bucket encryption
  • rft_lambda (Optional[str]): Lambda ARN or local .py file path for RFT reward function. Can also be set or updated after construction.

Example:

from amzn_nova_forge.manager import *
infra = SMHPRuntimeManager(
 instance_type="ml.p5.48xlarge",
 instance_count=4,
 cluster_name="my-hyperpod-cluster",
 namespace="default"
)

Properties

  • instance_type (str): Returns the instance type
  • instance_count (int): Returns the number of instances

Methods

execute()

Starts a SageMaker HyperPod job.

Signature:

def execute(
 self,
 job_config: JobConfig
) -> str

Returns:

  • str: HyperPod job ID

cleanup()

Cancels and cleans up a HyperPod job.

Signature:

def cleanup(
 self,
 job_name: str
) -> None

scale_cluster()

Scale a HyperPod cluster instance group up or down. The scaling operation is asynchronous: the cluster status changes to 'Updating' while scaling and to 'InService' when ready.

Signature:

def scale_cluster(
 self,
 instance_group_name: str,
 target_instance_count: int,
) -> Dict[str, Any]

Parameters:

  • instance_group_name (str): Name of the instance group to scale (e.g., 'worker-group')
  • target_instance_count (int): Desired number of instances for the group (must be non-negative)

Returns:

  • Dict[str, Any]: Response containing:
    • ClusterArn (str): ARN of the updated cluster
    • InstanceGroupName (str): Name of the scaled instance group
    • InstanceType (str): Instance type being scaled
    • PreviousCount (int): Current instance count before scaling
    • TargetCount (int): Target instance count after scaling

Raises:

  • ValueError: If target_instance_count is negative or instance group name is invalid
  • ClientError: If scaling fails due to insufficient quota, capacity or other cluster issues.

Example:

from amzn_nova_forge.manager import *

# Create a runtime manager for your cluster
manager = SMHPRuntimeManager(
    instance_type="ml.p4d.24xlarge",
    instance_count=4,
    cluster_name="my-hyperpod-cluster",
    namespace="default"
)

# Scale up the worker group from 4 to 8 instances
result = manager.scale_cluster(
    instance_group_name="worker-group",
    target_instance_count=8
)

# Scale down to 2 instances
result = manager.scale_cluster(
    instance_group_name="worker-group",
    target_instance_count=2
)

Notes:

  • This method only works with Restricted Instance Groups (RIGs) in HyperPod clusters. The cluster must be in 'InService' state before scaling can be initiated.
  • This method can only scale up a SMHP cluster when there is sufficient Service Quota available. You will need to request a quota increase before scaling up a RIG in your HyperPod cluster. You can learn more here.
    • Specifically, you will need to request a service quota increase for "INSTANCE_TYPE for cluster usage".

get_instance_groups()

Gets the RIGs associated with the current cluster defined in the SMHPRuntimeManager. Prints the values to the terminal and returns them as a list of dictionaries.

Signature:

def get_instance_groups(
 self
) -> List[Dict[str, Any]]

Returns:

  • List[Dict[str, Any]]: Response containing:
    • InstanceGroupName: Name of the instance group
    • InstanceType: EC2 instance type (e.g., 'ml.p5.48xlarge')
    • CurrentCount: Current number of instances in the group

Raises:

  • ClientError: If unable to describe the cluster

Example:

from amzn_nova_forge.manager import *

# Create a runtime manager for your cluster
manager = SMHPRuntimeManager(
    instance_type="ml.p4d.24xlarge",
    instance_count=4,
    cluster_name="my-hyperpod-cluster",
    namespace="default"
)

# Get the instance groups available on the current cluster.
instance_groups = manager.get_instance_groups()

BedrockRuntimeManager

Manages Amazon Bedrock model customization jobs.

Constructor

Signature:

def __init__(
 self,
 execution_role: str,
 base_model_identifier: Optional[str] = None,
 kms_key_id: Optional[str] = None,
 rft_lambda: Optional[str] = None,
)

Parameters:

  • execution_role (str): IAM role ARN for Bedrock job execution
  • base_model_identifier (Optional[str]): Base model ARN (e.g., "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-2-lite-v1:0:256k")
  • kms_key_id (Optional[str]): Optional KMS Key Id for encryption
  • rft_lambda (Optional[str]): Lambda ARN or local .py file path for RFT reward function. Can also be set or updated after construction.

Example:

from amzn_nova_forge.manager import *
infra = BedrockRuntimeManager(
 execution_role="arn:aws:iam::123456789012:role/BedrockRole",
 base_model_identifier="arn:aws:bedrock:us-east-1::custom-model/amazon.nova-2-lite-v1:0:256k:abcdefghijk" # optional: your custom model ARN for iterative training
)

Methods

execute()

Starts a Bedrock model customization job.

Signature:

def execute(
 self,
 job_config: JobConfig
) -> str

Returns:

  • str: Bedrock job ARN

cleanup()

Stops a Bedrock customization job.

Signature:

def cleanup(
 self,
 job_name: str
) -> None

SMTJServerlessRuntimeManager

Manages serverless SageMaker Training Jobs.

Constructor

Signature:

def __init__(
    self,
    model_package_group_name: str,
    execution_role: Optional[str] = None,
    kms_key_id: Optional[str] = None,
    encrypt_inter_container_traffic: bool = False,
    subnets: Optional[list[str]] = None,
    security_group_ids: Optional[list[str]] = None,
    max_job_runtime: Optional[int] = 86400,
    rft_lambda: Optional[str] = None,
)

Parameters:

  • model_package_group_name (str): Model package group name to use with SageMaker Model registry (required for SMTJ Serverless)
  • execution_role (Optional[str]): The execution role for the training job
  • kms_key_id (Optional[str]): Optional KMS Key Id to use in S3 Bucket encryption, training jobs and deployments.
  • encrypt_inter_container_traffic (bool): Boolean that determines whether to encrypt inter-container traffic. Default value is False.
  • subnets (Optional[list[str]]): Optional list of strings representing subnets. Default value is None.
  • security_group_ids (Optional[list[str]]): Optional list of strings representing security group IDs. Default value is None.
  • max_job_runtime (Optional[int]): Max Job Runtime in seconds (default: 1 day)
  • rft_lambda (Optional[str]): Lambda ARN or local .py file path for RFT reward function. Can also be set or updated after construction.

Example:

from amzn_nova_forge.manager import *
infra = SMTJServerlessRuntimeManager(
  model_package_group_name="model-group",
)

Properties

  • model_package_group_name (str): Model Package Group name

Methods

execute()

Starts a SageMaker training job.

Signature:

def execute(
 self,
 job_config: JobConfig
) -> str

Returns:

  • str: Training job name/ID

cleanup()

Stops and cleans up a training job.

Signature:

def cleanup(
 self,
 job_name: str
) -> None

Dataset Loaders

Dataset loaders handle loading, transforming, and saving datasets in various formats.

Base Class: DatasetLoader

Abstract base class for all dataset loaders.

Constructor

Signature:

def __init__(
 self,
 **column_mappings
)

Parameters:

  • **column_mappings: Keyword arguments mapping standard column names to dataset column names
    • Example: question="input" where "question" is the standard name and "input" is your column name

Column Mappings

If you are transforming a plain JSON, JSONL, or CSV file from a generic format (e.g. 'input/output') to another format (e.g. Converse for SFT), you need to provide "column mappings" to connect your generic column/field name to the expected ones in the transformation function.

For example, if your plain dataset has "input" and "output" columns, and you want to transform it for SFT (which requires "question" and "answer"), you would provide the following:

loader = JSONDatasetLoader(
    question="input",
    answer="output"
)

Below is a list of accepted column mapping parameters for transformations.

  • SFT: question, answer
    • Optional: system; for image/video data (all required together): image_format/video_format, s3_uri, bucket_owner
    • 2.0: reasoning_text, tools/toolsConfig*
  • RFT: question, reference_answer
    • Optional: system, id, tools*
  • Eval: query, response
    • Optional: images, metadata
  • CPT: text

Additional Notes:

  • If you're providing multimodal data in a generic format, you need to provide ALL three of the following fields:
    • image_format OR video_format, plus s3_uri and bucket_owner
  • *tools/toolsConfig (SFT 2.0) and tools (RFT) parameters can only be provided when transforming from OpenAI Messages format to Converse or OpenAI; a generic format cannot be used for this transformation.
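
As an illustration of the multimodal requirement, the mapping below pairs hypothetical dataset column names (placeholders, not SDK defaults) with the required media fields:

```python
# Hypothetical column names from a user dataset (placeholders, not SDK defaults)
multimodal_mappings = {
    "question": "prompt",
    "answer": "completion",
    "image_format": "img_fmt",     # or "video_format" for video data
    "s3_uri": "media_uri",         # column holding the media S3 URI
    "bucket_owner": "owner_acct",  # column holding the bucket-owner account ID
}

# When a media format key is present, s3_uri and bucket_owner must come with it
assert {"s3_uri", "bucket_owner"} <= multimodal_mappings.keys()

# The mapping is passed to a loader as keyword arguments, e.g.:
# loader = JSONLDatasetLoader(**multimodal_mappings)
```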

JSONLDatasetLoader

Loads datasets from JSONL (JSON Lines) files.

Methods

load()

Loads dataset from a JSONL file (local or S3).

Signature:

def load(
 self,
 path: str
) -> "DatasetLoader"

Parameters:

  • path (str): Path to JSONL file (local path or S3 URI)

Returns:

  • DatasetLoader: Self (for method chaining)

Example:

from amzn_nova_forge.dataset import *
loader = JSONLDatasetLoader()
loader.load("s3://my-bucket/data/training.jsonl")

JSONDatasetLoader

Loads datasets from JSON files.

Methods

load()

Loads dataset from a JSON file (local or S3).

Signature:

def load(
 self,
 path: str
) -> "DatasetLoader"

Parameters:

  • path (str): Path to JSON file (local path or S3 URI)

Returns:

  • DatasetLoader: Self (for method chaining)

Example:

from amzn_nova_forge.dataset import *
loader = JSONDatasetLoader()
loader.load("data/training.json")

CSVDatasetLoader

Loads datasets from CSV files.

Methods

load()

Loads dataset from a CSV file.

Signature:

def load(
 self,
 path: str
) -> "DatasetLoader"

Parameters:

  • path (str): Path to CSV file (local path or S3 URI)

Returns:

  • DatasetLoader: Self (for method chaining)

Example:

from amzn_nova_forge.dataset import *
loader = CSVDatasetLoader(question="user_query", answer="bot_response")
loader.load("data/conversations.csv")

Common DatasetLoader Methods

These methods are available on all DatasetLoader subclasses.

show()

Displays the first n rows of the dataset.

Signature:

def show(
 self,
 n: int = 10
) -> None

Parameters:

  • n (int): Number of rows to display (default: 10)

Example:

loader.show(5) # Show first 5 rows

split_data()

Splits dataset into train, validation, and test sets.

Signature:

def split_data(
 self,
 train_ratio: float = 0.8,
 val_ratio: float = 0.1,
 test_ratio: float = 0.1,
 seed: int = 42,
) -> Tuple["DatasetLoader", "DatasetLoader", "DatasetLoader"]

Parameters:

  • train_ratio (float): Proportion of data for training (default: 0.8)
  • val_ratio (float): Proportion of data for validation (default: 0.1)
  • test_ratio (float): Proportion of data for testing (default: 0.1)
  • seed (int): Random seed for reproducibility (default: 42)

Returns:

  • Tuple[DatasetLoader, DatasetLoader, DatasetLoader]: Three DatasetLoader objects (train, val, test)

Raises:

  • DataPrepError: If ratios don't sum to 1.0 or dataset is empty

Example:

train_loader, val_loader, test_loader = loader.split_data(
 train_ratio=0.7,
 val_ratio=0.2,
 test_ratio=0.1
)

transform()

Transforms dataset to the required format for a specific training method and model. Currently the following transformations are supported:

  • Q/A-formatted CSV/JSON/JSONL to SFT 1.0, SFT 2.0 (without reasoningContent, Tools), RFT, Eval, CPT
  • OpenAI Messages format to SFT 1.0 and SFT 2.0 (with Tools)

Signature:

def transform(
 self,
 method: TrainingMethod,
 model: Model
) -> "DatasetLoader"

Parameters:

  • method (TrainingMethod): The training method (e.g., TrainingMethod.SFT_LORA)
  • model (Model): The Nova model version (e.g., Model.NOVA_LITE)

Returns:

  • DatasetLoader: Self (for method chaining)

Raises:

  • ValueError: If method/model combination is not supported
  • DataPrepError: If transformation fails

Example:

loader.transform(
 method=TrainingMethod.SFT_LORA,
 model=Model.NOVA_MICRO
)

validate()

Validates the dataset against the user's intended training method and model.

Signature:

def validate(
 self,
 method: TrainingMethod,
 model: Model,
 eval_task: Optional[EvaluationTask] = None
) -> None

Parameters:

  • method (TrainingMethod): The training method (e.g., TrainingMethod.SFT_LORA)
  • model (Model): The Nova model version (e.g., Model.NOVA_LITE)
  • eval_task (Optional[EvaluationTask]): The evaluation task (e.g., EvaluationTask.GEN_QA); required only when validating a BYOD Evaluation dataset

Returns:

  • None

Raises:

  • ValueError: If method/model combination is not supported or validation is unsuccessful.

Example:

loader.validate(
 method=TrainingMethod.SFT_LORA,
 model=Model.NOVA_MICRO
)

If you're validating a BYOD Evaluation dataset, you must also provide the eval_task parameter to validate(). For example:

loader.validate(
    method=TrainingMethod.EVALUATION,
    model=Model.NOVA_LITE_2,
    eval_task=EvaluationTask.GEN_QA
)

>> Validation succeeded for 22 samples on an Evaluation BYOD dataset

save_data()

Saves the dataset to a local or S3 location.

Signature:

def save_data(
 self,
 save_path: str
) -> str

Parameters:

  • save_path (str): Path where to save the file (local or S3, must end in .json or .jsonl)

Returns:

  • str: Path where the file was saved

Raises:

  • DataPrepError: If save fails or format is unsupported

Example:

# Save locally
loader.save_data("output/training_data.jsonl")
# Save to S3
loader.save_data("s3://my-bucket/data/training_data.jsonl")

Job Results

Job result classes provide methods to check status and retrieve results from training, evaluation, and inference jobs.

Base Classes

BaseJobResult

Abstract base class for all job results.

Attributes:

  • job_id (str): Job identifier
  • started_time (datetime): Job start timestamp

Methods:

get_job_status()

Gets the current status of the job.

Signature:

def get_job_status(
 self
) -> tuple[JobStatus, str]

Returns:

  • tuple[JobStatus, str]: A tuple of (status enum, raw status string)
  • JobStatus.IN_PROGRESS: Job is running
  • JobStatus.COMPLETED: Job completed successfully
  • JobStatus.FAILED: Job failed

Example:

status, raw_status = result.get_job_status()
if status == JobStatus.COMPLETED:
 print("Job finished!")

dump()

Saves the job result to the given file path.

Signature:

def dump(
 self,
 file_path: Optional[str] = None,
 file_name: Optional[str] = None
) -> Path

Parameters:

  • file_path (Optional[str]): Directory path to save the result. Saves to current directory if not provided
  • file_name (Optional[str]): The file name of the result. Defaults to <job_id>_<platform>.json if not provided

Returns:

  • Path: The full result file path

Example:

result.dump()
# Result will be saved to ./{job_id}_{platform}.json under current dir
result.dump(file_path='/customized/path', file_name='customized_name.json')
# Result will be saved to /customized/path/customized_name.json

load()

Loads the job result from the given file path.

Signature:

@classmethod
def load(
 cls,
 file_path: str
) -> "BaseJobResult"

Returns:

  • BaseJobResult: An instance of the appropriate subclass, such as SMTJEvaluationResult, SMHPEvaluationResult, BedrockEvaluationResult, SMTJTrainingResult, SMHPTrainingResult, or BedrockTrainingResult

Example:

job_result = BaseJobResult.load('./my_job_result.json')

enable_job_notifications()

Enable email notifications for when a job reaches a terminal state (Completed, Failed, or Stopped).

Signature:

def enable_job_notifications(
    self,
    emails: list[str],
    output_s3_path: Optional[str] = None,
    region: Optional[str] = "us-east-1",
    **platform_kwargs
) -> None

Parameters:

  • emails (list[str]): List of email addresses to notify
  • output_s3_path (Optional[str]): S3 path where job outputs are stored.
    • Only required if the SDK cannot automatically extract it from the job result's model_artifacts attribute.
    • For most training jobs, this parameter is automatically populated and does not need to be provided explicitly.
  • region (Optional[str]): AWS region for notification infrastructure (default: "us-east-1")
  • **platform_kwargs: Platform-specific parameters:
    • For SMTJ:
      • kms_key_id (Optional[str]): Customer KMS key ID (not full ARN) for SNS topic encryption
    • For SMHP:
      • namespace (str): Kubernetes namespace where the PyTorchJob runs (e.g., "kubeflow", "default") (Required)
      • kubectl_layer_arn (str): ARN of the lambda-kubectl layer (Required)
      • eks_cluster_arn (Optional[str]): EKS cluster ARN (auto-detected if not provided)
      • vpc_id (Optional[str]): VPC ID (auto-detected if not provided)
      • subnet_ids (Optional[list[str]]): List of subnet IDs for Lambda (auto-detected if not provided)
      • security_group_id (Optional[str]): Security group ID for Lambda (auto-detected if not provided)
      • polling_interval_minutes (Optional[int]): How often to check job status in minutes (default: 5)
      • kms_key_id (Optional[str]): Customer KMS key ID (not full ARN) for SNS topic encryption

Returns:

  • None

Raises:

  • ValueError: If required parameters are missing or invalid
  • NotificationManagerInfraError: If infrastructure setup fails

How It Works:

  1. Creates AWS infrastructure (CloudFormation stack) if it doesn't exist:
    • DynamoDB table to store job notification configurations
    • SNS topic for email notifications
    • Lambda function to monitor job status
    • EventBridge rule (SMTJ) or scheduled rule (SMHP) to trigger Lambda
    • (SMHP only) VPC endpoints for DynamoDB and S3 if needed
  2. Stores job configuration in DynamoDB (including namespace for SMHP)
  3. Subscribes email addresses to SNS topic (users must confirm subscription)
  4. Monitors job status and sends email when job completes, fails, is stopped, or becomes degraded (SMHP only)

Email Confirmation: Users will receive a confirmation email from AWS SNS and must click the confirmation link before receiving job notifications.

Examples:

SMTJ (SageMaker Training Jobs):

# Basic usage - output_s3_path is automatically extracted
result = customizer.train(job_name="my-job")
result.enable_job_notifications(
    emails=["user@example.com", "team@example.com"]
)

# With customer KMS encryption
result.enable_job_notifications(
    emails=["user@example.com"],
    kms_key_id="abc-123-def-456"  # Just the key ID, not full ARN
)

# With custom region
result.enable_job_notifications(
    emails=["user@example.com"],
    region="us-west-2"
)

SMHP (SageMaker HyperPod):

# Basic usage (with auto-detection)
result = customizer.train(job_name="my-job")
result.enable_job_notifications(
    emails=["user@example.com"],
    namespace="kubeflow",  # Required
    kubectl_layer_arn="arn:aws:lambda:<region>:123456789012:layer:kubectl:1"  # Required
)

# With custom polling interval
result.enable_job_notifications(
    emails=["user@example.com"],
    namespace="kubeflow",
    kubectl_layer_arn="arn:aws:lambda:<region>:123456789012:layer:kubectl:1",
    polling_interval_minutes=10  # Check every 10 minutes instead of default 5
)

# With explicit VPC configuration of the cluster where jobs are being monitored.
result.enable_job_notifications(
    emails=["user@example.com"],
    namespace="kubeflow",
    kubectl_layer_arn="arn:aws:lambda:<region>:123456789012:layer:kubectl:1",
    eks_cluster_arn="arn:aws:eks:<region>:123456789012:cluster/my-cluster",
    vpc_id="vpc-12345",
    subnet_ids=["subnet-1", "subnet-2"],
    security_group_id="sg-12345"
)

Important Notes:

  • For SMHP, requires deploying a kubectl Lambda layer from AWS Serverless Application Repository
  • For SMHP, the user will need to manually grant the Lambda function access to your EKS cluster (access-entry).
  • See docs/job_notifications.md for detailed setup instructions, troubleshooting, and advanced usage

EvaluationResult (ABC)

Abstract base class for evaluation job results across platforms.

Attributes:

  • job_id (str): Job identifier
  • started_time (datetime): Job start timestamp
  • eval_task (EvaluationTask): Evaluation task performed
  • eval_output_path (str): S3 path to evaluation results

Subclasses

  • SMTJEvaluationResult
  • SMHPEvaluationResult
  • BedrockEvaluationResult

Methods

get()

Downloads and returns evaluation results as a dictionary.

Signature:

def get(
 self
) -> Dict

Returns:

  • Dict: Evaluation results (empty dict if job not completed)

Example:

eval_result = customizer.evaluate(...)
# Wait for job to complete
results = eval_result.get()
print(results)

show()

Prints evaluation results to console.

Signature:

def show(
 self
) -> None

Example:

eval_result.show()

upload_tensorboard_results()

Uploads TensorBoard results to S3.

Signature:

def upload_tensorboard_results(
 self,
 tensorboard_s3_path: Optional[str] = None
) -> None

Parameters:

  • tensorboard_s3_path (Optional[str]): Target S3 path (auto-generated if not provided)

Example:

eval_result.upload_tensorboard_results(
 tensorboard_s3_path="s3://my-bucket/tensorboard/"
)

clean()

Cleans up local cached results.

Signature:

def clean(
 self
) -> None

SMTJBatchInferenceResult

Result object for batch inference jobs.

Attributes:

  • job_id (str): Job identifier
  • started_time (datetime): Job start timestamp
  • inference_output_path (str): S3 path to inference outputs

Methods

get()

Downloads and returns inference results, optionally saving to S3.

Signature:

def get(
 self,
 s3_path: Optional[str] = None
) -> Dict

Parameters:

  • s3_path (Optional[str]): S3 path to save formatted results

Returns:

  • Dict: Dictionary containing a list of inference results
    • Each result has: system, query, gold_response, inference_response, metadata

Example:

inference_result = customizer.batch_inference(...)
# Wait for job to complete
results = inference_result.get(s3_path="s3://my-bucket/formatted-results.jsonl")
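
Once the results are downloaded, they can be post-processed with plain Python. The payload below is illustrative only (the top-level key and values are assumptions; inspect your actual dict):

```python
# Illustrative payload shape; each entry carries the fields listed above
results = {
    "results": [
        {
            "system": "You are a helpful assistant.",
            "query": "What is 2 + 2?",
            "gold_response": "4",
            "inference_response": "4",
            "metadata": {"id": "sample_1"},
        }
    ]
}

# Count exact matches between gold and inference responses
exact = sum(
    1
    for r in results["results"]
    if r["inference_response"].strip() == r["gold_response"].strip()
)
print(f"{exact}/{len(results['results'])} exact matches")
```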

show()

Prints inference results to console.

Signature:

def show(
 self
) -> None

clean()

Cleans up local cached results.

Signature:

def clean(
 self
) -> None

IAM Role Creation SDK

This SDK provides utility functions for creating IAM roles with specific permissions for AWS Bedrock and SageMaker services.

Methods

create_bedrock_execution_role()

Creates an IAM role with permissions for Bedrock model creation and deployment.

Signature:

def create_bedrock_execution_role(
    iam_client, 
    role_name: str, 
    bedrock_resource: str = "*", 
    s3_resource: str = "*"
) -> Dict

Parameters:

  • iam_client: Boto3 IAM client
  • role_name (str): Name of the IAM role to create
  • bedrock_resource (Optional[str]): Specific Bedrock resource to restrict access. Defaults to "*" (all resources)
  • s3_resource (Optional[str]): Specific S3 resource to restrict access. Defaults to "*" (all resources)

Returns:

  • Dict: IAM role details

Example:

import boto3
from amzn_nova_forge.iam.iam_role_creator import create_bedrock_execution_role

iam_client = boto3.client("iam")
create_bedrock_execution_role(iam_client, "role-name", "bedrock_resource", "s3_resource")

create_sagemaker_execution_role()

Creates an IAM role with permissions for SageMaker model creation and deployment.

Signature:

def create_sagemaker_execution_role(
    iam_client,
    role_name: str,
    s3_resource: str = "*",
    kms_resource: str = "*",
    ec2_condition: Optional[Dict[str, Any]] = None,
    cloudwatch_metric_condition: Optional[Dict[str, Any]] = None,
    cloudwatch_logstream_resource: str = "*",
    cloudwatch_loggroup_resource: str = "*"
) -> Dict

Parameters:

  • iam_client: Boto3 IAM client
  • role_name (str): Name of the IAM role to create
  • s3_resource (Optional[str]): Specific S3 resource to restrict access
  • kms_resource (Optional[str]): Specific KMS resource to restrict access
  • ec2_condition (Optional[Dict]): Conditional access for EC2 resources
  • cloudwatch_metric_condition (Optional[Dict]): Conditional access for CloudWatch metrics
  • cloudwatch_logstream_resource (Optional[str]): Specific CloudWatch log stream resource
  • cloudwatch_loggroup_resource (Optional[str]): Specific CloudWatch log group resource

Returns:

  • Dict: IAM role details

Example:

import boto3
from amzn_nova_forge.iam.iam_role_creator import create_sagemaker_execution_role

iam_client = boto3.client("iam")
create_sagemaker_execution_role(
        iam_client,
        role_name="role-name",
        s3_resource="example-bucket",
        kms_resource="encryption-key",
        ec2_condition={
            "ArnLike": {
                "ec2:Vpc": "arn:aws:ec2:*:*:vpc/example"
            }
        },
        cloudwatch_metric_condition={
            "StringEquals": {
                "cloudwatch:namespace": ["example-namespace"]
            }
        },
        cloudwatch_loggroup_resource="example-loggroup",
        cloudwatch_logstream_resource="example-logstream"
    )

Utility Functions

verify_reward_function()

Verifies a reward function with sample data before using it in RFT training or evaluation. This utility helps you test your reward function implementation to ensure it works correctly and returns the expected format.

Signature:

def verify_reward_function(
    reward_function: str,
    sample_data: List[Dict[str, Any]],
    region: str = "us-east-1",
    validate_format: bool = True,
    platform: Optional[Platform] = None,
) -> Dict[str, Any]

Parameters:

  • reward_function (str): Either a Lambda ARN (string starting with 'arn:aws:lambda:') or a path to a local Python file containing the reward function.
  • sample_data (List[Dict[str, Any]]): List of conversation samples to test. Each sample should be a dict with 'id', 'messages', and optionally 'reference_answer' keys.
  • region (str): AWS region for Lambda invocation (default: "us-east-1").
  • validate_format (bool): If True, validates that sample_data matches RFT format and output matches expected format (default: True).
  • platform (Platform): Platform enum (Platform.SMHP or Platform.SMTJ). Required when using Lambda ARN. When set to Platform.SMHP, validates that Lambda ARN contains 'SageMaker' in the function name as required by SageMaker HyperPod. Optional for local files.

Returns:

  • Dict[str, Any]: Dictionary containing:
    • success (bool): Always True if no exception raised
    • results (list): List of individual test results
    • total_samples (int): Total number of samples tested
    • successful_samples (int): Number of successful tests
    • warnings (list): List of warning messages (e.g., missing reference_answer)

Raises:

  • ValueError: If any validation errors are encountered, with a detailed error message listing all issues found.

Example

from amzn_nova_forge import verify_reward_function
from amzn_nova_forge.model.model_enums import Platform

# Test with Lambda ARN (platform required for Lambda ARNs)
result = verify_reward_function(
    reward_function="arn:aws:lambda:us-east-1:123456789012:function:MySageMakerReward",
    sample_data=[
        {
            "id": "sample_1",
            "reference_answer": "correct answer",
            "messages": [
                {"role": "user", "content": "question"},
                {"role": "assistant", "content": "response"}
            ]
        }
    ],
    platform=Platform.SMHP  # Required for Lambda ARNs
)

print(f"Verification: {'PASSED' if result['success'] else 'FAILED'}")
print(f"Tested {result['total_samples']} samples, {result['successful_samples']} successful")

if result.get('warnings'):
    print(f"\nWarnings:")
    for warning in result['warnings']:
        print(f"  - {warning}")

# Test with local Python file (platform optional)
result = verify_reward_function(
    reward_function="./my_reward_function.py",
    sample_data=[
        {
            "id": "sample_1",
            "reference_answer": "correct answer",
            "messages": [
                {"role": "user", "content": "question"},
                {"role": "assistant", "content": "response"}
            ]
        }
    ]
)

Output Format Requirements from Lambda:

{
    "id": "sample_1",                   # Required: string
    "aggregate_reward_score": 0.75,     # Required: float or int
    "metrics_list": [                   # Optional: validated if present
        {
            "name": "accuracy",         # Required: string
            "value": 0.85,              # Required: float or int
            "type": "Metric"            # Required: "Metric" or "Reward"
        }
    ]
}

Common Validation Errors:

  • Missing required fields in input (messages field is required)
  • Missing required fields in output (id and aggregate_reward_score are required)
  • Invalid data types (e.g., aggregate_reward_score must be a number)
  • Missing platform parameter when using Lambda ARN
  • SMHP Lambda ARN doesn't contain 'SageMaker' in function name
  • Invalid metrics_list structure (must be list of dicts with name, value, type)
  • Invalid metric type (must be "Metric" or "Reward")
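
The checks above can be sketched in plain Python. This illustrates the output contract and is not the SDK's actual validator:

```python
def check_reward_output(output):
    """Return a list of validation errors for a reward-function output dict."""
    errors = []
    if not isinstance(output.get("id"), str):
        errors.append("id is required and must be a string")
    if not isinstance(output.get("aggregate_reward_score"), (int, float)):
        errors.append("aggregate_reward_score is required and must be a number")
    for metric in output.get("metrics_list", []):
        if not isinstance(metric, dict):
            errors.append("metrics_list entries must be dicts")
            continue
        if not isinstance(metric.get("name"), str):
            errors.append("metric name must be a string")
        if not isinstance(metric.get("value"), (int, float)):
            errors.append("metric value must be a number")
        if metric.get("type") not in ("Metric", "Reward"):
            errors.append('metric type must be "Metric" or "Reward"')
    return errors

print(check_reward_output({"id": "sample_1", "aggregate_reward_score": 0.75}))
# prints []
```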

Warnings:

  • Missing reference_answer: While optional in RFT datasets, reference answers are recommended for meaningful reward calculations. Without ground truth, your reward function cannot compare model outputs against expected answers.

Note: The metrics_list field is optional. If provided, it will be validated for proper structure and logged during training/evaluation.
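
A local .py reward function consumes samples like those in sample_data and emits the output format above. The entrypoint name and single-sample calling convention in this sketch are assumptions for illustration, not a documented SDK contract:

```python
def reward_handler(sample):
    """Score one RFT sample; returns the output format described above.

    Assumption: the handler receives one sample dict with 'id', 'messages',
    and optionally 'reference_answer'. The function name is hypothetical.
    """
    # The last assistant turn is the model response to score
    response = next(
        (m["content"] for m in reversed(sample["messages"])
         if m["role"] == "assistant"),
        "",
    )
    reference = sample.get("reference_answer", "")
    score = 1.0 if reference and reference.strip() == response.strip() else 0.0
    return {
        "id": sample["id"],
        "aggregate_reward_score": score,
        "metrics_list": [
            {"name": "exact_match", "value": score, "type": "Metric"},
        ],
    }
```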


Monitoring

CloudWatchLogMonitor

Monitors CloudWatch logs and plots training metrics for Nova model training jobs. Supports both SageMaker Training Jobs (SMTJ) and SageMaker HyperPod (SMHP) platforms.

Factory Methods

from_job_id()

Creates a CloudWatchLogMonitor from a job ID.

Signature:

@classmethod
def from_job_id(
    cls,
    job_id: str,
    platform: Platform,
    started_time: Optional[datetime] = None,
    **kwargs,
) -> "CloudWatchLogMonitor"

Parameters:

  • job_id (str): The training job identifier
  • platform (Platform): Execution platform (Platform.SMTJ or Platform.SMHP)
  • started_time (Optional[datetime]): Job start time (used to filter logs)
  • **kwargs: Platform-specific parameters:
    • SMHP requires: cluster_name (str), optional namespace (str, defaults to "kubeflow")

Returns:

  • CloudWatchLogMonitor: Monitor instance

Example:

from amzn_nova_forge.monitor import CloudWatchLogMonitor
from amzn_nova_forge.model.model_enums import Platform

# SMTJ
monitor = CloudWatchLogMonitor.from_job_id(
    job_id="my-training-job",
    platform=Platform.SMTJ,
    started_time=datetime(2026, 1, 15, 12, 0, 0)
)

# SMHP
monitor = CloudWatchLogMonitor.from_job_id(
    job_id="my-hyperpod-job",
    platform=Platform.SMHP,
    cluster_name="my-cluster",
    namespace="kubeflow"
)

from_job_result()

Creates a CloudWatchLogMonitor from a training job result object.

Signature:

@classmethod
def from_job_result(
    cls,
    job_result: BaseJobResult,
    cloudwatch_logs_client=None
) -> "CloudWatchLogMonitor"

Parameters:

  • job_result (BaseJobResult): A training or evaluation result object (e.g., TrainingResult)
  • cloudwatch_logs_client (Optional): Boto3 CloudWatch Logs client (auto-created if not provided)

Returns:

  • CloudWatchLogMonitor: Monitor instance

Example:

result = customizer.train(job_name="my-job")
monitor = CloudWatchLogMonitor.from_job_result(job_result=result)

Methods

get_logs()

Retrieves CloudWatch log events for the job.

Signature:

def get_logs(
    self,
    limit: Optional[int] = None,
    start_from_head: bool = False,
    end_time: Optional[int] = None,
) -> List[Dict]

Parameters:

  • limit (Optional[int]): Maximum number of log events to retrieve
  • start_from_head (bool): If True, start from the beginning of logs; if False, start from the end
  • end_time (Optional[int]): End time in epoch milliseconds

Returns:

  • List[Dict]: List of log event dictionaries, each containing a "message" key

Example:

logs = monitor.get_logs(limit=100)

show_logs()

Prints CloudWatch log messages to the console.

Signature:

def show_logs(
    self,
    limit: Optional[int] = None,
    start_from_head: bool = False,
    end_time: Optional[int] = None,
) -> None

Parameters:

  • limit (Optional[int]): Maximum number of log events to display
  • start_from_head (bool): If True, start from the beginning of logs; if False, start from the end
  • end_time (Optional[int]): End time in epoch milliseconds

Example:

monitor.show_logs(limit=50, start_from_head=True)

plot_metrics()

Parses training metrics from CloudWatch logs and displays them as matplotlib plots. Automatically fetches the latest logs if the job is still in progress or logs have not been retrieved yet.

Signature:

def plot_metrics(
    self,
    training_method: TrainingMethod,
    metrics: Optional[List[str]] = None,
    starting_step: Optional[int] = None,
    ending_step: Optional[int] = None,
) -> None

Parameters:

  • training_method (TrainingMethod): The training method used for the job (e.g., TrainingMethod.SFT_LORA, TrainingMethod.CPT, TrainingMethod.RFT_LORA)
  • metrics (Optional[List[str]]): List of metric names to plot. Available metrics depend on training method:
    • CPT / SFT: "training_loss"
    • RFT: "reward_score"
  • starting_step (Optional[int]): Filter to only show metrics from this global step onward
  • ending_step (Optional[int]): Filter to only show metrics up to this global step

Raises:

  • ValueError: If starting_step > ending_step, or if no logs are found for the job
  • NotImplementedError: If an unsupported metric is requested for the given training method/platform

Example:

from amzn_nova_forge.monitor import CloudWatchLogMonitor
from amzn_nova_forge.model.model_enums import Platform, TrainingMethod

# Create monitor from a training result
monitor = CloudWatchLogMonitor.from_job_result(job_result=training_result)

# Plot training loss for an SFT job
monitor.plot_metrics(
    training_method=TrainingMethod.SFT_LORA,
    metrics=["training_loss"]
)

# Plot reward score for an RFT job, filtered to steps 50-200
monitor.plot_metrics(
    training_method=TrainingMethod.RFT_LORA,
    metrics=["reward_score"],
    starting_step=50,
    ending_step=200
)

MLflowMonitor

MLflow monitoring configuration for Nova model training. This class provides experiment tracking capabilities through MLflow integration.

Note: MLflow monitoring is only supported for SageMaker platforms (SMTJ, SMHP). It is not available for Bedrock platform.

MLflow Integration Features:

  • Automatic logging of training metrics
  • Model artifact and checkpoint tracking
  • Hyperparameter recording
  • Support for SageMaker MLflow tracking servers
  • Custom MLflow tracking server support (with proper network configuration)

Constructor

Signature:

def __init__(
 self,
 tracking_uri: Optional[str] = None,
 experiment_name: Optional[str] = None,
 run_name: Optional[str] = None,
)

Parameters:

  • tracking_uri (Optional[str]): MLflow tracking server URI or SageMaker MLflow app ARN. If not provided, attempts to use a default SageMaker MLflow tracking server if one exists
  • experiment_name (Optional[str]): Name of the MLflow experiment. If not provided, will use the job name
  • run_name (Optional[str]): Name of the MLflow run. If not provided, will be auto-generated

Raises:

  • ValueError: If MLflow configuration validation fails

Example:

from amzn_nova_forge.monitor import *

# With explicit tracking URI
monitor = MLflowMonitor(
    tracking_uri="arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
    experiment_name="nova-customization",
    run_name="sft-run-1"
)

# With default tracking URI (if available)
monitor = MLflowMonitor(
    experiment_name="nova-customization",
    run_name="sft-run-1"
)

# Use with NovaModelCustomizer
customizer = NovaModelCustomizer(
    model=Model.NOVA_LITE_2,
    method=TrainingMethod.SFT_LORA,
    infra=runtime_manager,
    data_s3_path="s3://bucket/data",
    mlflow_monitor=monitor
)

Methods

to_dict()

Converts MLflow configuration to dictionary format for use in recipe overrides.

Signature:

def to_dict(
 self
) -> dict

Returns:

  • dict: Dictionary with mlflow_* keys for recipe configuration. Returns empty dict if no tracking URI is available

Example:

monitor = MLflowMonitor(
 tracking_uri="arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
 experiment_name="nova-customization"
)

config_dict = monitor.to_dict()
# Returns: {
#   "mlflow_tracking_uri": "arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
#   "mlflow_experiment_name": "nova-customization"
# }

get_presigned_url()

Generates a presigned URL for accessing the MLflow tracking server UI directly without navigating through the AWS Console.

Signature:

def get_presigned_url(
 self,
 session_expiration_duration_in_seconds: int = 43200,
 expires_in_seconds: int = 300
) -> str

Parameters:

  • session_expiration_duration_in_seconds (int, optional): Duration in seconds for which the MLflow UI session is valid after accessing the presigned URL. Default is 43200 seconds (12 hours). Valid range: 1800-43200 seconds
  • expires_in_seconds (int, optional): Duration in seconds for which the presigned URL itself is valid. The URL must be accessed within this time. Default is 300 seconds (5 minutes). Valid range: 5-300 seconds

Returns:

  • str: Presigned URL for accessing the MLflow tracking server UI. This URL must be used within expires_in_seconds

Raises:

  • ValueError: If tracking_uri is not set
  • RuntimeError: If unable to generate presigned URL

Example:

monitor = MLflowMonitor(
    tracking_uri="arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
    experiment_name="nova-customization"
)

# Generate presigned URL with defaults
# URL expires in 5 minutes, but session lasts 12 hours once accessed
url = monitor.get_presigned_url()
print(f"Access MLflow UI at: {url}")

# Generate URL with custom expiration times
url = monitor.get_presigned_url(
    session_expiration_duration_in_seconds=3600,  # 1 hour session
    expires_in_seconds=60  # URL expires in 1 minute
)

MLflow Integration Notes

When MLflow monitoring is enabled:

  1. Training metrics will be automatically logged to the specified MLflow tracking server
  2. Model artifacts and checkpoints will be tracked in MLflow
  3. Hyperparameters and configuration will be recorded as MLflow parameters
  4. You can view experiment results in the MLflow UI

The MLflow integration supports:

  • SageMaker MLflow tracking servers
  • Custom MLflow tracking servers (with appropriate network configuration)
  • Automatic experiment and run creation
  • Metric logging during training
  • Artifact tracking
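
To make the recipe wiring concrete, here is a small sketch of how the `mlflow_*` keys from `to_dict()` merge into a recipe override dict. The `MonitorStandIn` class below is a stand-in that mirrors the documented `to_dict()` contract, not the SDK's `MLflowMonitor`:

```python
# Stand-in mirroring the documented MLflowMonitor.to_dict() behavior
# (illustration only; the real class lives in the SDK).
class MonitorStandIn:
    def __init__(self, tracking_uri, experiment_name):
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name

    def to_dict(self):
        # Returns an empty dict when no tracking URI is available,
        # matching the documented contract.
        if not self.tracking_uri:
            return {}
        return {
            "mlflow_tracking_uri": self.tracking_uri,
            "mlflow_experiment_name": self.experiment_name,
        }

# Merge the mlflow_* keys into a recipe override dict.
recipe_overrides = {"run": {"name": "nova-sft"}}
recipe_overrides.update(
    MonitorStandIn(
        "arn:aws:sagemaker:us-east-1:123456789012:mlflow-app/app-xxx",
        "nova-customization",
    ).to_dict()
)
```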

Enums and Configuration

Model Enum

Supported Nova models with their configurations. Values:

  • Model.NOVA_MICRO: Amazon Nova Micro (Version 1)
    • model_type: "amazon.nova-micro-v1:0:128k"
    • model_path: "nova-micro/prod"
    • version: Version.ONE
  • Model.NOVA_LITE: Amazon Nova Lite (Version 1)
    • model_type: "amazon.nova-lite-v1:0:300k"
    • model_path: "nova-lite/prod"
    • version: Version.ONE
  • Model.NOVA_LITE_2: Amazon Nova Lite (Version 2)
    • model_type: "amazon.nova-2-lite-v1:0:256k"
    • model_path: "nova-lite-2/prod"
    • version: Version.TWO
  • Model.NOVA_PRO: Amazon Nova Pro (Version 1)
    • model_type: "amazon.nova-pro-v1:0:300k"
    • model_path: "nova-pro/prod"
    • version: Version.ONE

Methods:

from_model_type()

Gets Model enum from model type string.

Signature:

@classmethod
def from_model_type(
 cls,
 model_type: str
) -> "Model"

Example:

model = Model.from_model_type("amazon.nova-micro-v1:0:128k")
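
Conceptually, the lookup resolves a `model_type` string to the matching enum member. A stand-in sketch (values taken from the table above; the real enum carries additional fields such as `model_path` and `version`):

```python
from enum import Enum

# Stand-in for the Model enum, keyed by the documented model_type strings.
class ModelStandIn(Enum):
    NOVA_MICRO = "amazon.nova-micro-v1:0:128k"
    NOVA_LITE = "amazon.nova-lite-v1:0:300k"
    NOVA_LITE_2 = "amazon.nova-2-lite-v1:0:256k"
    NOVA_PRO = "amazon.nova-pro-v1:0:300k"

    @classmethod
    def from_model_type(cls, model_type: str) -> "ModelStandIn":
        # Linear scan over members; raise on an unrecognized string.
        for member in cls:
            if member.value == model_type:
                return member
        raise ValueError(f"Unknown model type: {model_type}")
```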

TrainingMethod Enum

Supported training methods.

Values:

  • TrainingMethod.CPT: Continued Pre-Training
  • TrainingMethod.DPO_LORA: Direct Preference Optimization with LoRA
  • TrainingMethod.DPO_FULL: Direct Preference Optimization (full rank)
  • TrainingMethod.SFT_LORA: Supervised Fine-Tuning with LoRA
  • TrainingMethod.SFT_FULL: Supervised Fine-Tuning (full rank)
  • TrainingMethod.RFT_LORA: Reinforcement Fine-Tuning with LoRA
  • TrainingMethod.RFT_FULL: Reinforcement Fine-Tuning (full rank)
  • TrainingMethod.EVALUATION: Evaluation only

DeployPlatform Enum

Supported deployment platforms.

Values:

  • DeployPlatform.BEDROCK_OD: Amazon Bedrock On-Demand
  • DeployPlatform.BEDROCK_PT: Amazon Bedrock Provisioned Throughput
  • DeployPlatform.SAGEMAKER: Amazon SageMaker

DeploymentMode Enum

Deployment behavior when an endpoint with the same name already exists.

Values:

  • DeploymentMode.FAIL_IF_EXISTS: Raise an error if endpoint already exists (safest, default)
  • DeploymentMode.UPDATE_IF_EXISTS: Try in-place update only, fail if not supported (PT only)

Note: Only FAIL_IF_EXISTS and UPDATE_IF_EXISTS modes are currently supported. UPDATE_IF_EXISTS is only applicable for Bedrock Provisioned Throughput (PT) deployments.


EvaluationTask Enum

Supported evaluation tasks. Common values include:

  • EvaluationTask.MMLU: Massive Multitask Language Understanding
  • EvaluationTask.GPQA: Graduate-Level Google-Proof Question Answering
  • EvaluationTask.MATH: Mathematical Problem Solving
  • EvaluationTask.GEN_QA: Custom Dataset Evaluation
  • The full list of available tasks can be found here: AWS Documentation

Platform Enum

Infrastructure platforms.

Values:

  • Platform.SMTJ: SageMaker Training Jobs
  • Platform.SMHP: SageMaker HyperPod
  • Platform.BEDROCK: Amazon Bedrock

JobStatus Enum

Job execution status.

Values:

  • JobStatus.IN_PROGRESS: Job is running
  • JobStatus.COMPLETED: Job completed successfully
  • JobStatus.FAILED: Job failed
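
A minimal polling sketch against these states; the `status()` accessor and `get()` call on the result object are assumptions here (adjust to the actual job-result API), and the enum below is a stand-in for the SDK's `JobStatus`:

```python
import time
from enum import Enum

# Stand-in for the SDK's JobStatus enum (illustration only).
class JobStatus(Enum):
    IN_PROGRESS = "InProgress"
    COMPLETED = "Completed"
    FAILED = "Failed"

def wait_for_completion(result, poll_seconds=60, sleep=time.sleep):
    """Block until the job leaves IN_PROGRESS, then return its artifacts."""
    while result.status() == JobStatus.IN_PROGRESS:
        sleep(poll_seconds)
    if result.status() != JobStatus.COMPLETED:
        raise RuntimeError(f"Job ended with status {result.status().value}")
    return result.get()
```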

RFT Multiturn Infrastructure

For RFT multiturn training and evaluation, you need to set up infrastructure to run reward functions.

Helper Functions

create_rft_execution_role

Creates an IAM role with required permissions for RFT multiturn infrastructure.

Function:

def create_rft_execution_role(
    region: str = "us-east-1",
    role_name: Optional[str] = None,
    custom_policy_path: Optional[str] = None
) -> str

Parameters:

  • region (str): AWS region. Default: "us-east-1"
  • role_name (Optional[str]): Custom role name. Default: "RFTExecutionRoleNovaSDK"
  • custom_policy_path (Optional[str]): Path to custom policy JSON file. If not provided, uses SDK default.

Returns:

  • str: ARN of the created/existing role

Example:

from amzn_nova_forge import create_rft_execution_role

# Create role with default name
role_arn = create_rft_execution_role(region="us-east-1")

# Create role with custom name
role_arn = create_rft_execution_role(region="us-east-1", role_name="my-custom-rft-role")

list_rft_stacks

Lists CloudFormation stacks in the region, optionally filtering for Nova SDK stacks.

Function:

def list_rft_stacks(
    region: str = "us-east-1",
    all_stacks: bool = False
) -> List[str]

Parameters:

  • region (str): AWS region. Default: "us-east-1"
  • all_stacks (bool): If True, list all stacks. If False, only list Nova SDK stacks (ending with "NovaForgeSDK"). Default: False

Returns:

  • List[str]: List of stack names

Example:

from amzn_nova_forge import list_rft_stacks

# List only Nova SDK stacks
nova_stacks = list_rft_stacks(region="us-east-1")

# List all CloudFormation stacks
all_stacks = list_rft_stacks(region="us-east-1", all_stacks=True)
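
The filtering rule is simple to state: keep stack names ending with "NovaForgeSDK" unless all stacks are requested. A standalone sketch of that logic (not the SDK implementation, which queries CloudFormation):

```python
def filter_nova_stacks(stack_names, all_stacks=False):
    """Mirror the documented filter: Nova SDK stacks end with 'NovaForgeSDK'."""
    if all_stacks:
        return list(stack_names)
    return [name for name in stack_names if name.endswith("NovaForgeSDK")]
```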

RFTMultiturnInfrastructure

Manages infrastructure for RFT multiturn training (reward function workers).

Constructor:

def __init__(
    self,
    stack_name: str,
    region: str = "us-east-1",
    vf_env_id: Optional[VFEnvId] = None,
    custom_env: Optional[CustomEnvironment] = None,
    infrastructure_arn: Optional[str] = None,
    python_venv_name: Optional[str] = None,
    vpc_config: Optional[Dict[str, Any]] = None,
    cpu: Optional[str] = None,
    memory: Optional[str] = None,
    rft_role_name: Optional[str] = None,
)

Parameters:

  • stack_name (str): CloudFormation stack name
  • region (str): AWS region. Default: "us-east-1"
  • vf_env_id (Optional[VFEnvId]): Built-in environment ID (VFEnvId.WORDLE or VFEnvId.TERMINAL_BENCH)
  • custom_env (Optional[CustomEnvironment]): Custom environment (mutually exclusive with vf_env_id)
  • infrastructure_arn (Optional[str]): Platform ARN (EC2 instance ID, ECS cluster ARN, or None for LOCAL)
  • python_venv_name (Optional[str]): Python virtual environment name (required for LOCAL/EC2, optional for ECS)
  • vpc_config (Optional[Dict]): VPC configuration for ECS only. Dict with keys:
    • subnets: List[str] - Subnet IDs
    • security_groups: List[str] - Security group IDs
  • cpu (Optional[str]): CPU units for ECS tasks (e.g., "2048"). Ignored for LOCAL/EC2.
  • memory (Optional[str]): Memory in MB for ECS tasks (e.g., "4096"). Ignored for LOCAL/EC2.
  • rft_role_name (Optional[str]): IAM role name for RFT infrastructure. If not provided, uses default role or creates one.

Example:

from amzn_nova_forge import RFTMultiturnInfrastructure, CustomEnvironment, VFEnvId

# Option 1: LOCAL with built-in environment
rft_infra = RFTMultiturnInfrastructure(
    stack_name="my-rft-stack",
    region="us-east-1",
    python_venv_name="my_rft_venv",
    vf_env_id=VFEnvId.WORDLE
)

# Option 2: ECS with custom environment and VPC config
custom_env = CustomEnvironment(
    env_id="my-custom-env", 
    output_dir="~/custom_envs/", 
    env_type="single_turn"
).create(overwrite=True)

rft_infra = RFTMultiturnInfrastructure(
    stack_name="my-rft-stack",
    custom_env=custom_env,
    infrastructure_arn="arn:aws:ecs:us-east-1:123456789012:cluster/my-cluster",
    vpc_config={
        "subnets": ["subnet-12345", "subnet-67890"],
        "security_groups": ["sg-12345"]
    },
    cpu="4096",
    memory="8192"
)

# Deploy infrastructure
rft_infra.setup()

# Start training environment
rft_infra.start_training_environment()

# Use with NovaModelCustomizer
customizer = NovaModelCustomizer(
    model=Model.NOVA_LITE_2,
    method=TrainingMethod.RFT_MULTITURN_LORA,
    infra=runtime,
    data_s3_path="s3://bucket/data.jsonl"
)

training_result = customizer.train(
    job_name="rft-training",
    rft_multiturn_infra=rft_infra
)

CustomEnvironment

Create custom reward functions for RFT multiturn training.

Constructor:

def __init__(
    self,
    env_id: str,
    local_path: Optional[str] = None,
    output_dir: str = "~/custom_envs",
    env_type: str = "single_turn"
)

Methods:

  • create(overwrite: bool = False): Create environment structure
  • validate(): Validate the environment structure before packaging
  • package_and_upload(bucket: Optional[str] = None): Package and upload to S3

Example:

custom_env = CustomEnvironment(
    env_id="my-custom-env",
    output_dir="~/custom_envs/",
    env_type="single_turn"
).create(overwrite=True)

custom_env.validate()
custom_env.package_and_upload()
print(f"Uploaded to: {custom_env.s3_uri}")

RFT Multiturn Methods

Infrastructure Management:

  • setup(): Deploy CloudFormation stack (Lambda, SQS, DynamoDB)
  • start_training_environment(vf_env_args: Dict = None): Start training workers
  • start_evaluation_environment(vf_env_args: Dict = None): Start evaluation workers
  • kill_task(env_type: EnvType): Stop workers
  • cleanup(delete_stack: bool = False, cleanup_environment: bool = False): Clean up resources
    • delete_stack: If True, delete CloudFormation stack
    • cleanup_environment: If True, clean up environment resources:
      • LOCAL/EC2: Delete virtual environment and starter kit directories
      • ECS: Deregister task definitions

Monitoring:

  • get_logs(env_type: EnvType, limit: int = 100, start_from_head: bool = False, log_stream_name: Optional[str] = None, tail: bool = False): View worker logs
    • tail: If True, continuously stream logs in real-time (blocks until Ctrl+C)
  • check_all_queues(): Check SQS queue status
  • flush_all_queues(): Clear all queues

Configuration:

  • get_configuration(): Get infrastructure config
  • get_recipe_overrides(): Get recipe overrides for training

Note: RFT multiturn only supports SageMaker HyperPod (SMHP) platform and Nova 2.0 models (NOVA_LITE_2).


Job Notifications

The Nova Forge SDK provides automated email notifications for training jobs when they reach terminal states (Completed, Failed, or Stopped). This feature helps you monitor long-running jobs without constantly checking their status.

Overview

Job notifications are managed through platform-specific notification managers that automatically set up and manage the required AWS infrastructure:

  • SMTJNotificationManager: For SageMaker Training Jobs (SMTJ)
  • SMHPNotificationManager: For SageMaker HyperPod (SMHP)

How It Works

When you enable notifications for a job, the SDK automatically:

  1. Creates AWS Infrastructure (if it doesn't exist):

    • CloudFormation stack with all required resources
    • DynamoDB table to store job notification configurations
    • SNS topic for email notifications
    • Lambda function to handle job state changes
    • EventBridge rule to monitor job status
    • IAM roles and policies with appropriate permissions
  2. Configures Job Monitoring:

    • Stores job configuration in DynamoDB
    • Subscribes email addresses to SNS topic
    • Monitors job status via EventBridge
  3. Sends Notifications:

    • Detects when job reaches terminal state
    • Validates output artifacts (for SMTJ, checks for manifest.json in output.tar.gz)
    • Sends email notification with job details and console link

Using Job Notifications

The simplest way to enable notifications is through the job result object:

from amzn_nova_forge_sdk import *

# Start a training job
customizer = NovaModelCustomizer(
    model=Model.NOVA_MICRO,
    method=TrainingMethod.SFT_LORA,
    infra=SMTJRuntimeManager(instance_type="ml.p5.48xlarge", instance_count=2),
    data_s3_path="s3://my-bucket/training-data/",
    output_s3_path="s3://my-bucket/output/"
)

result = customizer.train(job_name="my-training-job")

# Enable notifications
result.enable_job_notifications(
    emails=["user@example.com", "team@example.com"],
    region="us-west-2", # Optional
    kms_key_id="1234abcd-12ab-34cd-56ef-1234567890ab", # Optional customer KMS key
    output_s3_path="s3://my-bucket/custom-output-path/" # Optional output path
)

Note: Provide output_s3_path only if the JobResult object does not have model_artifacts; the SDK will prompt for it when you call the function.

Email Confirmation

When you enable notifications:

  1. Each email address receives a confirmation email from AWS SNS
  2. Users must click the confirmation link in the email
  3. After confirmation, they'll receive notifications for all jobs using that SNS topic
  4. Confirmation is only needed once per email address per region

Notification Content

Email notifications include:

  • Job ID and platform (SMTJ/SMHP)
  • Job status (Completed, Failed, or Stopped)
  • Timestamp
  • Link to AWS Console for the job
  • For completed jobs: Validation status of output artifacts
  • For failed jobs: Failure reason (if available)

Infrastructure Details (SMTJ)

CloudFormation Stack

The notification infrastructure is managed as a CloudFormation stack:

  • Stack Name: NovaForgeSDK-SMTJ-JobNotifications
  • Region: Specified when enabling notifications (default: us-east-1)
  • Resources: DynamoDB table, SNS topic, Lambda function, EventBridge rule, IAM roles

DynamoDB Table

Stores job notification configurations:

  • Table Name: NovaForgeSDK-SMTJ-JobNotifications
  • Primary Key: job_id (String)
  • Attributes: emails (String Set), output_s3_path (String), created_at (String), ttl (Number)
  • TTL: Automatically deletes entries after 30 days
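
The 30-day TTL means each item carries an epoch-seconds expiry. A sketch of what an item could look like (attribute names follow the documented schema; the exact item shape and values are assumptions):

```python
import time

THIRTY_DAYS_SECONDS = 30 * 24 * 60 * 60

# Hypothetical item matching the documented table schema.
item = {
    "job_id": "my-training-job",
    "emails": {"user@example.com", "team@example.com"},
    "output_s3_path": "s3://my-bucket/output/",
    "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "ttl": int(time.time()) + THIRTY_DAYS_SECONDS,  # DynamoDB deletes after this
}
```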

SNS Topic

Manages email subscriptions:

  • Topic Name: NovaForgeSDK-SMTJ-Notifications
  • Encryption: Optional KMS encryption
  • Subscriptions: Email protocol with confirmation required

Lambda Function

Handles job state change events:

  • Function Name: NovaForgeSDK-SMTJ-NotificationHandler
  • Runtime: Python 3.12
  • Timeout: 180 seconds
  • Triggers: EventBridge rule for SageMaker Training Job state changes

EventBridge Rule

Monitors job status:

  • Rule Name: NovaForgeSDK-SMTJ-Job-State-Change
  • Event Pattern: SageMaker Training Job State Change events
  • States Monitored: Completed, Failed, Stopped
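
For reference, an EventBridge pattern matching these states could look like the following (expressed here as a Python dict; field names follow the standard SageMaker state-change event schema, not SDK source):

```python
# Sketch of an event pattern for the documented rule: SageMaker Training
# Job state-change events reaching a terminal state.
event_pattern = {
    "source": ["aws.sagemaker"],
    "detail-type": ["SageMaker Training Job State Change"],
    "detail": {
        "TrainingJobStatus": ["Completed", "Failed", "Stopped"],
    },
}
```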

Infrastructure Details (SMHP)

CloudFormation Stack

The notification infrastructure is managed as a CloudFormation stack:

  • Stack Name: NovaForgeSDK-SMHP-JobNotifications-{ClusterName}
  • Region: Specified when enabling notifications (default: us-east-1)
  • Resources: DynamoDB table, SNS topic, Lambda function, EventBridge rule, IAM roles, VPC endpoints

DynamoDB Table

Stores job notification configurations:

  • Table Name: NovaForgeSDK-SMHP-JobNotifications-{ClusterName}
  • Primary Key: job_id (String)
  • Attributes: output_s3_path (String), namespace (String), ttl (Number)
  • TTL: Automatically deletes entries after 30 days
  • Point-in-Time Recovery: Enabled

SNS Topic

Manages email subscriptions:

  • Topic Name: NovaForgeSDK-SMHP-Notifications-{ClusterName}
  • Encryption: Optional KMS encryption
  • Subscriptions: Email protocol with confirmation required

Lambda Function

Handles job status polling:

  • Function Name: NovaForgeSDK-SMHP-NotificationHandler-{ClusterName}
  • Runtime: Python 3.12
  • Timeout: 300 seconds
  • Memory: 512 MB
  • VPC Configuration: Deployed in VPC with access to EKS cluster
  • Layers: kubectl layer for Kubernetes API access
  • Triggers: EventBridge scheduled rule (default: every 5 minutes)

EventBridge Rule

Periodically checks job status:

  • Rule Name: NovaForgeSDK-SMHP-Job-Check-{ClusterName}
  • Schedule: Rate-based (default: every 5 minutes, configurable)
  • Target: Lambda function for polling PyTorchJob status

VPC Endpoints

Enable private AWS service access for Lambda:

  • DynamoDB Gateway Endpoint: NovaForgeSDK-SMHP-DynamoDB-{ClusterName}
  • SNS Interface Endpoint: NovaForgeSDK-SMHP-SNS-{ClusterName}
  • S3 Gateway Endpoint: NovaForgeSDK-SMHP-S3-{ClusterName}

Limitations and Notes

  1. Email Confirmation: Users must confirm their email subscription before receiving notifications.

  2. Region-Specific: Notification infrastructure is created per region. Jobs in different regions require separate infrastructure.

  3. Stack Creation Restrictions: For SMTJ, one notification stack is created per region. For SMHP, one notification stack is created per cluster per region.

  4. KMS Key Requirements: If using KMS encryption:

    • Provide only the key ID, not the full ARN
    • The Lambda function automatically receives permissions to use the key
    • The key must be in the same region as the notification infrastructure
  5. Output Path Required: The output_s3_path is required for manifest validation. The SDK will attempt to extract it from model_artifacts if not provided explicitly.

  6. Hard-coded CloudFormation Stack Names: When the CloudFormation stack is created, it will have one of the following names: NovaForgeSDK-SMTJ-JobNotifications or NovaForgeSDK-SMHP-JobNotifications-{ClusterName}.
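
Item 4 above requires the bare key ID rather than the full ARN. A hypothetical helper (not part of the SDK) to normalize either form:

```python
# Hypothetical helper: accept either a bare KMS key ID or a full key ARN
# and return the bare key ID the SDK expects.
def to_kms_key_id(key_id_or_arn: str) -> str:
    # Key ARNs look like: arn:aws:kms:us-east-1:123456789012:key/1234abcd-...
    if key_id_or_arn.startswith("arn:"):
        return key_id_or_arn.rsplit("/", 1)[-1]
    return key_id_or_arn
```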

Troubleshooting

Notifications Not Received

  1. Check email confirmation: Ensure you clicked the confirmation link in the AWS SNS email
  2. Check spam folder: SNS emails may be filtered as spam
  3. Verify job status: Notifications only sent for terminal states (Completed, Failed, Stopped)
  4. Check CloudWatch Logs: View Lambda function logs for errors

Stack Creation Failures

If CloudFormation stack creation fails:

  1. Check IAM permissions for CloudFormation, DynamoDB, SNS, Lambda, EventBridge, and IAM
  2. Verify no resource name conflicts exist
  3. Check CloudFormation console for detailed error messages

API Reference

See the BaseJobResult.enable_job_notifications() method documentation for detailed parameter information.


Error Handling

All SDK functions may raise exceptions. It's recommended to wrap calls in try-except blocks:

try:
 result = customizer.train(job_name="my-job")
except ValueError as e:
 print(f"Configuration error: {e}")
except Exception as e:
 print(f"Training failed: {e}")

Common exceptions:

  • ValueError: Invalid parameters or configuration
  • DataPrepError: Dataset preparation errors
  • Exception: General job execution or AWS API errors

Best Practices

  1. Always validate your data using loader.show() before training
  2. Use overrides sparingly - start with defaults and tune as needed
  3. Monitor logs during training using get_logs()
  4. Check job status before calling .get() on results
  5. Clean up resources when done to avoid unnecessary costs
  6. Use descriptive job names to help track and organize your experiments
  7. Save results incrementally during long-running jobs
  8. Test with small datasets before scaling up to full training

Additional Resources

  • AWS Documentation: Amazon Bedrock
  • AWS Documentation: Amazon SageMaker
  • SDK GitHub Repository: Check for updates and examples
  • Support: Use AWS Support for technical assistance