From 4730951ad18c0c455669c98de7ab44ed74b4f2fe Mon Sep 17 00:00:00 2001 From: Wes Eklund Date: Tue, 27 Jan 2026 14:29:26 -0500 Subject: [PATCH] fix: model deploy pipeline fails when PRE_PROD/PROD env vars are empty This fixes two bugs in the model deployment seed code: 1. Pipeline fails with "Unable to parse environment specification aws:///" when PRE_PROD_ACCOUNT_ID, PRE_PROD_REGION, PROD_ACCOUNT_ID, or PROD_REGION are not provided. Now these values fall back to DEV environment values, allowing single-account deployments. 2. EventBridge rule name exceeds 64 character limit for long project names. Now truncates project name to fit within the limit. Also adds unit tests for the constants module to prevent regression. --- .../seed_code/deploy_app/config/constants.py | 54 ++++- .../deploy_app/deploy_app/pipeline_stack.py | 7 +- .../seed_code/deploy_app/tests/__init__.py | 2 + .../deploy_app/tests/test_constants.py | 214 ++++++++++++++++++ 4 files changed, 265 insertions(+), 12 deletions(-) create mode 100644 modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/__init__.py create mode 100644 modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/test_constants.py diff --git a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/config/constants.py b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/config/constants.py index f31c08d7..ac440f05 100644 --- a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/config/constants.py +++ b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/config/constants.py @@ -17,7 +17,7 @@ import json import os -from typing import Union +from typing import Any, List, Union def get_bool_env(name: str, default: Union[bool, str] = True) -> bool: @@ -26,6 +26,33 @@ def get_bool_env(name: str, default: Union[bool, str] = True) -> bool: return os.getenv(name, default_str).lower() == "true" +def get_env_with_fallback(key: str, fallback: Any, is_json_list: bool = False, warn: bool = False) -> Any: + """Get environment variable with fallback to provided value. + + Args: + key: Environment variable name + fallback: Value to use if env var is not set or empty + is_json_list: If True, parse as JSON list and fall back if empty list + warn: If True, print a warning when fallback is used + + Returns: + The environment variable value, or fallback if not set/empty + """ + raw = os.environ.get(key) + if is_json_list: + value: List[str] = json.loads(raw) if raw else [] + if not value: + if warn: + print(f"INFO: {key} not provided, using fallback") + return fallback + return value + if not raw: + if warn: + print(f"INFO: {key} not provided, using fallback") + return fallback + return raw + + MAX_NAME_LENGTH = 63 REPOSITORY_TYPE = os.getenv("REPOSITORY_TYPE", "CodeCommit") # Default to CODECOMMIT if not set CODE_CONNECTION_ARN = os.getenv("CODE_CONNECTION_ARN", "") @@ -33,23 +60,28 @@ def get_bool_env(name: str, default: Union[bool, str] = True) -> bool: MODEL_BUCKET_ARN = os.environ["MODEL_BUCKET_ARN"] MODEL_PACKAGE_GROUP_NAME = os.getenv("MODEL_PACKAGE_GROUP_NAME", "") +# DEV environment configuration (required) DEV_ACCOUNT_ID = os.environ["DEV_ACCOUNT_ID"] DEV_REGION = os.environ["DEV_REGION"] DEV_VPC_ID = os.environ["DEV_VPC_ID"] DEV_SUBNET_IDS = json.loads(os.environ["DEV_SUBNET_IDS"]) DEV_SECURITY_GROUP_IDS = json.loads(os.environ["DEV_SECURITY_GROUP_IDS"]) -PRE_PROD_ACCOUNT_ID = os.environ["PRE_PROD_ACCOUNT_ID"] -PRE_PROD_REGION = os.environ["PRE_PROD_REGION"] -PRE_PROD_VPC_ID = os.environ["PRE_PROD_VPC_ID"] -PRE_PROD_SUBNET_IDS = json.loads(os.environ["PRE_PROD_SUBNET_IDS"]) -PRE_PROD_SECURITY_GROUP_IDS = json.loads(os.environ["PRE_PROD_SECURITY_GROUP_IDS"]) +# PRE_PROD environment configuration (falls back to DEV if not provided) +PRE_PROD_ACCOUNT_ID = get_env_with_fallback("PRE_PROD_ACCOUNT_ID", DEV_ACCOUNT_ID, warn=True) +PRE_PROD_REGION = get_env_with_fallback("PRE_PROD_REGION", DEV_REGION, warn=True) +PRE_PROD_VPC_ID = get_env_with_fallback("PRE_PROD_VPC_ID", DEV_VPC_ID) +PRE_PROD_SUBNET_IDS = get_env_with_fallback("PRE_PROD_SUBNET_IDS", DEV_SUBNET_IDS, is_json_list=True) +PRE_PROD_SECURITY_GROUP_IDS = get_env_with_fallback( + "PRE_PROD_SECURITY_GROUP_IDS", DEV_SECURITY_GROUP_IDS, is_json_list=True +) -PROD_ACCOUNT_ID = os.environ["PROD_ACCOUNT_ID"] -PROD_REGION = os.environ["PROD_REGION"] -PROD_VPC_ID = os.environ["PROD_VPC_ID"] -PROD_SUBNET_IDS = json.loads(os.environ["PROD_SUBNET_IDS"]) -PROD_SECURITY_GROUP_IDS = json.loads(os.environ["PROD_SECURITY_GROUP_IDS"]) +# PROD environment configuration (falls back to DEV if not provided) +PROD_ACCOUNT_ID = get_env_with_fallback("PROD_ACCOUNT_ID", DEV_ACCOUNT_ID, warn=True) +PROD_REGION = get_env_with_fallback("PROD_REGION", DEV_REGION, warn=True) +PROD_VPC_ID = get_env_with_fallback("PROD_VPC_ID", DEV_VPC_ID) +PROD_SUBNET_IDS = get_env_with_fallback("PROD_SUBNET_IDS", DEV_SUBNET_IDS, is_json_list=True) +PROD_SECURITY_GROUP_IDS = get_env_with_fallback("PROD_SECURITY_GROUP_IDS", DEV_SECURITY_GROUP_IDS, is_json_list=True) PROJECT_NAME = os.getenv("PROJECT_NAME", "") PROJECT_ID = os.getenv("PROJECT_ID", "") diff --git a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/deploy_app/pipeline_stack.py b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/deploy_app/pipeline_stack.py index 0abbe89f..f00477ab 100644 --- a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/deploy_app/pipeline_stack.py +++ b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/deploy_app/pipeline_stack.py @@ -223,10 +223,15 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: Any) -> None: # Add EventBridge rule to trigger pipeline when model is approved in Model Registry if constants.ENABLE_EVENTBRIDGE_TRIGGER: + # Truncate project name to fit within 64 char limit for EventBridge rule names + rule_suffix = "-model-approval-trigger" + max_prefix_len = constants.MAX_NAME_LENGTH - len(rule_suffix) + truncated_project_name = constants.PROJECT_NAME[:max_prefix_len] + events.Rule( self, "ModelApprovalEventRule", - rule_name=f"{constants.PROJECT_NAME}-model-approval-trigger", + rule_name=f"{truncated_project_name}{rule_suffix}", event_pattern=events.EventPattern( source=["aws.sagemaker"], detail_type=["SageMaker Model Package State Change"], diff --git a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/__init__.py b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/test_constants.py b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/test_constants.py new file mode 100644 index 00000000..709c97b6 --- /dev/null +++ b/modules/sagemaker/sagemaker-templates/templates/model_deploy/seed_code/deploy_app/tests/test_constants.py @@ -0,0 +1,214 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +from unittest import mock + +import pytest + + +@pytest.fixture(scope="function") +def clean_constants_module(): + """Remove cached constants module to ensure fresh imports with mocked env vars.""" + modules_to_remove = [k for k in sys.modules if k.startswith("config")] + for mod in modules_to_remove: + del sys.modules[mod] + + yield + + modules_to_remove = [k for k in sys.modules if k.startswith("config")] + for mod in modules_to_remove: + del sys.modules[mod] + + +@pytest.fixture(scope="function") +def base_env_vars(): + """Base environment variables required for constants module to load.""" + return { + "DEV_ACCOUNT_ID": "111111111111", + "DEV_REGION": "us-east-1", + "DEV_VPC_ID": "vpc-dev", + "DEV_SUBNET_IDS": '["subnet-dev"]', + "DEV_SECURITY_GROUP_IDS": '["sg-dev"]', + "MODEL_BUCKET_ARN": "arn:aws:s3:::test-bucket", + } + + +# Tests for get_env_with_fallback helper function + + +def test_get_env_with_fallback_returns_value_when_set(clean_constants_module, base_env_vars): + """Should return the environment variable value when it's set.""" + with mock.patch.dict(os.environ, {**base_env_vars, "TEST_VAR": "test_value"}, clear=True): + from config.constants import get_env_with_fallback + + result = get_env_with_fallback("TEST_VAR", "fallback") + assert result == "test_value" + + +def test_get_env_with_fallback_returns_fallback_when_not_set(clean_constants_module, base_env_vars): + """Should return fallback when environment variable is not set.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + from config.constants import get_env_with_fallback + + result = get_env_with_fallback("TEST_VAR", "fallback") + assert result == "fallback" + + +def test_get_env_with_fallback_returns_fallback_when_empty_string(clean_constants_module, base_env_vars): + """Should return fallback when environment variable is empty string.""" + with mock.patch.dict(os.environ, {**base_env_vars, "TEST_VAR": ""}, clear=True): + from config.constants import get_env_with_fallback + + result = get_env_with_fallback("TEST_VAR", "fallback") + assert result == "fallback" + + +def test_get_env_with_fallback_parses_json_list(clean_constants_module, base_env_vars): + """Should parse JSON list when is_json_list=True.""" + with mock.patch.dict(os.environ, {**base_env_vars, "TEST_LIST": '["a", "b", "c"]'}, clear=True): + from config.constants import get_env_with_fallback + + result = get_env_with_fallback("TEST_LIST", ["default"], is_json_list=True) + assert result == ["a", "b", "c"] + + +def test_get_env_with_fallback_json_list_fallback_when_empty(clean_constants_module, base_env_vars): + """Should return fallback when JSON list is empty.""" + with mock.patch.dict(os.environ, {**base_env_vars, "TEST_LIST": "[]"}, clear=True): + from config.constants import get_env_with_fallback + + result = get_env_with_fallback("TEST_LIST", ["default"], is_json_list=True) + assert result == ["default"] + + +def test_get_env_with_fallback_warns_when_enabled(clean_constants_module, base_env_vars, capsys): + """Should print warning when warn=True and fallback is used.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + from config.constants import get_env_with_fallback + + get_env_with_fallback("TEST_VAR", "fallback", warn=True) + captured = capsys.readouterr() + assert "INFO: TEST_VAR not provided, using fallback" in captured.out + + +def test_get_env_with_fallback_no_warning_by_default(clean_constants_module, base_env_vars, capsys): + """Should not print warning when warn=False (default).""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + from config.constants import get_env_with_fallback + + # Clear any output from module load (PRE_PROD/PROD fallback warnings) + capsys.readouterr() + + get_env_with_fallback("TEST_VAR", "fallback") + captured = capsys.readouterr() + assert "TEST_VAR" not in captured.out + + +# Tests for PRE_PROD/PROD fallback behavior + + +def test_preprod_falls_back_to_dev_when_not_set(clean_constants_module, base_env_vars): + """PRE_PROD should use DEV values when not provided.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + import config.constants as constants + + assert constants.PRE_PROD_ACCOUNT_ID == "111111111111" + assert constants.PRE_PROD_REGION == "us-east-1" + assert constants.PRE_PROD_VPC_ID == "vpc-dev" + assert constants.PRE_PROD_SUBNET_IDS == ["subnet-dev"] + assert constants.PRE_PROD_SECURITY_GROUP_IDS == ["sg-dev"] + + +def test_prod_falls_back_to_dev_when_not_set(clean_constants_module, base_env_vars): + """PROD should use DEV values when not provided.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + import config.constants as constants + + assert constants.PROD_ACCOUNT_ID == "111111111111" + assert constants.PROD_REGION == "us-east-1" + assert constants.PROD_VPC_ID == "vpc-dev" + assert constants.PROD_SUBNET_IDS == ["subnet-dev"] + assert constants.PROD_SECURITY_GROUP_IDS == ["sg-dev"] + + +def test_preprod_uses_own_values_when_set(clean_constants_module, base_env_vars): + """PRE_PROD should use its own values when provided.""" + env = { + **base_env_vars, + "PRE_PROD_ACCOUNT_ID": "222222222222", + "PRE_PROD_REGION": "us-west-2", + "PRE_PROD_VPC_ID": "vpc-preprod", + "PRE_PROD_SUBNET_IDS": '["subnet-preprod"]', + "PRE_PROD_SECURITY_GROUP_IDS": '["sg-preprod"]', + } + with mock.patch.dict(os.environ, env, clear=True): + import config.constants as constants + + assert constants.PRE_PROD_ACCOUNT_ID == "222222222222" + assert constants.PRE_PROD_REGION == "us-west-2" + assert constants.PRE_PROD_VPC_ID == "vpc-preprod" + assert constants.PRE_PROD_SUBNET_IDS == ["subnet-preprod"] + assert constants.PRE_PROD_SECURITY_GROUP_IDS == ["sg-preprod"] + + +def test_prod_uses_own_values_when_set(clean_constants_module, base_env_vars): + """PROD should use its own values when provided.""" + env = { + **base_env_vars, + "PROD_ACCOUNT_ID": "333333333333", + "PROD_REGION": "eu-west-1", + "PROD_VPC_ID": "vpc-prod", + "PROD_SUBNET_IDS": '["subnet-prod"]', + "PROD_SECURITY_GROUP_IDS": '["sg-prod"]', + } + with mock.patch.dict(os.environ, env, clear=True): + import config.constants as constants + + assert constants.PROD_ACCOUNT_ID == "333333333333" + assert constants.PROD_REGION == "eu-west-1" + assert constants.PROD_VPC_ID == "vpc-prod" + assert constants.PROD_SUBNET_IDS == ["subnet-prod"] + assert constants.PROD_SECURITY_GROUP_IDS == ["sg-prod"] + + +def test_empty_string_treated_as_not_set(clean_constants_module, base_env_vars): + """Empty string values should fall back to DEV.""" + env = { + **base_env_vars, + "PRE_PROD_ACCOUNT_ID": "", + "PRE_PROD_REGION": "", + "PROD_ACCOUNT_ID": "", + "PROD_REGION": "", + } + with mock.patch.dict(os.environ, env, clear=True): + import config.constants as constants + + assert constants.PRE_PROD_ACCOUNT_ID == "111111111111" + assert constants.PRE_PROD_REGION == "us-east-1" + assert constants.PROD_ACCOUNT_ID == "111111111111" + assert constants.PROD_REGION == "us-east-1" + + +def test_warns_for_account_and_region_fallback(clean_constants_module, base_env_vars, capsys): + """Should print warnings when ACCOUNT_ID and REGION fall back.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + import config.constants # noqa: F401 + + captured = capsys.readouterr() + assert "PRE_PROD_ACCOUNT_ID not provided" in captured.out + assert "PRE_PROD_REGION not provided" in captured.out + assert "PROD_ACCOUNT_ID not provided" in captured.out + assert "PROD_REGION not provided" in captured.out + + +def test_no_warning_for_vpc_subnet_sg_fallback(clean_constants_module, base_env_vars, capsys): + """Should NOT print warnings when VPC/subnet/SG fall back.""" + with mock.patch.dict(os.environ, base_env_vars, clear=True): + import config.constants # noqa: F401 + + captured = capsys.readouterr() + assert "VPC_ID not provided" not in captured.out + assert "SUBNET_IDS not provided" not in captured.out + assert "SECURITY_GROUP_IDS not provided" not in captured.out