diff --git a/.envrc b/.envrc new file mode 100644 index 000000000..3550a30f2 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.github/workflows/code.release.branch.yml b/.github/workflows/code.release.branch.yml index c5cb3c4aa..6683dbbc7 100644 --- a/.github/workflows/code.release.branch.yml +++ b/.github/workflows/code.release.branch.yml @@ -58,6 +58,12 @@ jobs: echo ${RELEASE_TAG:1} > VERSION # Update package-lock.json to reflect new version npm install + # Regenerate CDK baselines from MockApp + echo "๐Ÿ“ Regenerating CDK baselines..." + rm -rf test/cdk/stacks/__baselines__ + mkdir -p test/cdk/stacks/__baselines__ + npm test -- test/cdk/stacks/snapshot.test.ts --testNamePattern="is compatible with baseline" + echo "โœ… CDK baselines regenerated" # Add the generated PR description to the top of CHANGELOG.md echo "๐Ÿ“ Adding release notes to CHANGELOG.md..." # Create a temporary file with the new content diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 42107f79e..a2c57c880 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -75,6 +75,9 @@ jobs: fi pip install -e ./lisa-sdk - name: Run tests + env: + ACCOUNT_NUMBER: '012345678901' + REGION: us-east-1 run: | make test-coverage pre-commit: diff --git a/.gitignore b/.gitignore index 0097dc556..5c733858b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.direnv *.js !tailwind.config.js !postcss.config.js diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ef1b8204e..41ab94870 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,14 +16,14 @@ repos: hooks: - id: bandit args: [--recursive, -c=pyproject.toml] - additional_dependencies: ['bandit[toml]', 'pbr'] + additional_dependencies: ['bandit[toml]', 'pbr', 'PyYAML'] - repo: https://github.com/Yelp/detect-secrets rev: v1.5.0 hooks: - id: detect-secrets exclude: (?x)^( - .*.ipynb|config.yaml|.*.md|.*test.*.py + 
.*.ipynb|config.yaml|.*.md|.*test.*.py|test/cdk/stacks/__baselines__/.*\.json )$ - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/CHANGELOG.md b/CHANGELOG.md index 814ab618d..08bc39d92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,39 @@ +# v5.4.0 + +## Key Features + +### Bedrock Guardrails Integration +LISA Administrators can now apply Bedrock Guardrails to any model via the Model Management page or API. +- **Comprehensive Protection**: Integrated with Bedrock Guardrails through LiteLLM's proxy support of the ApplyGuardrail API, enabling guardrails during prompt input, response generation, and prompt output. +- **Advanced Capabilities**: Supports topic denial, word filtering, sensitive information limitation, contextual grounding checks, and automated reasoning for factual accuracy. +- **Flexible Administration**: Administrators can apply Guardrails to any LISA model (self hosted or 3rd party) via the Model Management UI or API, with customizable permissions for different user groups. +- **Adaptive Policies**: Guardrails and group permissions can be updated anytime to evolve content moderation alongside organizational needs. + +### Offline/Air-gapped Deployment Support +Enhanced the platform to support offline and air-gapped deployments by enabling pre-caching of external dependencies for the REST API and MCP Workbench. +- **Nodeenv Pre-caching**: Added support for pre-caching the required nodeenv in the REST API container to enable offline deployments. +- **Offline Deployment**: Enabled configuration of pre-cached external dependencies for the MCP Workbench to support offline and air-gapped deployments. + +### MCP Workbench Refactoring +Migrated the MCP Workbench deployment to use the shared LisaServe ECS cluster, improving modularity and enabling conditional deployment. +- **MCP Workbench Stack**: Created a dedicated stack that deploys the MCP Workbench as a separate ECS service on the shared cluster. 
+- **Conditional Deployment**: Introduced a configuration flag to control the optional deployment of the MCP Workbench. +- **Container Overrides**: Added support for overriding the MCP Workbench container image during deployment.. + +### MCP Workbench UX Improvements +Enhanced the user experience of the MCP Workbench with tool validation, error display, and theme support. +- **Validation**: Implemented tool validation to improve the user experience. +- **Theming**: Introduced theme support for the MCP Workbench UI. + +## Acknowledgements +* @batzela +* @bedanley +* @dustins +* @estohlmann +* @jmharold + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v5.3.2..v5.4.0 + # v5.3.2 ## Key Features @@ -20,7 +56,7 @@ This release includes updates to our [documentation site](https://awslabs.github ## Acknowledgements * @bedanley -* @dustinps +* @dustins * @estohlmann * @jmharold diff --git a/Makefile b/Makefile index 2d5faacb1..16ce9c41f 100644 --- a/Makefile +++ b/Makefile @@ -13,33 +13,30 @@ PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) HEADLESS = false DOCKER_CMD ?= $(or $(CDK_DOCKER),docker) +# Function to read config with fallback to base config and default value +# Usage: VAR := $(call get_config,property,default_value) +define get_config +$(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq -r $(1) $(PROJECT_DIR)/config-custom.yaml 2>/dev/null | grep -v '^null$$' || \ + (test -f $(PROJECT_DIR)/config-base.yaml && yq -r $(1) $(PROJECT_DIR)/config-base.yaml 2>/dev/null | grep -v '^null$$') || \ + echo "$(2)") +endef # PROFILE (optional argument) ifeq (${PROFILE},) -TEMP_PROFILE := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .profile) -ifneq ($(TEMP_PROFILE), null) -PROFILE := ${TEMP_PROFILE} -else -$(warning profile is not set in the command line using PROFILE variable or config files, attempting deployment without this variable) +PROFILE := $(call get_config,.profile,) +ifeq ($(PROFILE),) +$(warning profile is not set 
in command line using PROFILE variable or config files, attempting deployment without this variable) endif endif # DEPLOYMENT_NAME ifeq (${DEPLOYMENT_NAME},) -DEPLOYMENT_NAME := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .deploymentName) -endif - -ifeq (${DEPLOYMENT_NAME}, null) -DEPLOYMENT_NAME := $(shell cat $(PROJECT_DIR)/config-base.yaml | yq .deploymentName) -endif - -ifeq (${DEPLOYMENT_NAME}, null) -DEPLOYMENT_NAME := prod +DEPLOYMENT_NAME := $(call get_config,.deploymentName,prod) endif # ACCOUNT_NUMBER ifeq (${ACCOUNT_NUMBER},) -ACCOUNT_NUMBER := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .accountNumber) +ACCOUNT_NUMBER := $(call get_config,.accountNumber,) endif ifeq (${ACCOUNT_NUMBER},) @@ -48,18 +45,16 @@ endif # REGION ifeq (${REGION},) -REGION := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .region) +REGION := $(call get_config,.region,) endif ifeq (${REGION},) $(error region must be set in command line using REGION variable or config files) endif +# PARTITION ifeq (${PARTITION},) -PARTITION := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .partition ) -endif -ifeq (${PARTITION}, null) -PARTITION := aws +PARTITION := $(call get_config,.partition,aws) endif # DOMAIN - used for the docker login @@ -76,29 +71,13 @@ endif # Arguments defined through config files # APP_NAME -APP_NAME := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .appName) -ifeq (${APP_NAME}, null) -APP_NAME := $(shell cat $(PROJECT_DIR)/config-base.yaml | yq .appName) -endif - -ifeq (${APP_NAME}, null) -APP_NAME := lisa -endif +APP_NAME := $(call get_config,.appName,lisa) # DEPLOYMENT_STAGE -DEPLOYMENT_STAGE := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .deploymentStage) -ifeq (${DEPLOYMENT_STAGE}, null) -DEPLOYMENT_STAGE := $(shell cat $(PROJECT_DIR)/config-base.yaml | yq .deploymentStage) -endif - -ifeq (${DEPLOYMENT_STAGE}, null) -DEPLOYMENT_STAGE := prod -endif +DEPLOYMENT_STAGE := $(call get_config,.deploymentStage,prod) # ACCOUNT_NUMBERS_ECR 
- AWS account numbers that need to be logged into with Docker CLI to use ECR -ifneq ($(shell cat $(PROJECT_DIR)/config-custom.yaml | yq '.accountNumbersEcr'), null) -ACCOUNT_NUMBERS_ECR := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq .accountNumbersEcr[]) -endif +ACCOUNT_NUMBERS_ECR := $(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq '.accountNumbersEcr[]' $(PROJECT_DIR)/config-custom.yaml 2>/dev/null || echo "") # Append deployed account number to array for dockerLogin rule ACCOUNT_NUMBERS_ECR := $(ACCOUNT_NUMBERS_ECR) $(ACCOUNT_NUMBER) @@ -113,13 +92,18 @@ ifneq ($(findstring $(DEPLOYMENT_STAGE),$(STACK)),$(DEPLOYMENT_STAGE)) endif # MODEL_IDS - IDs of models to deploy -ifneq ($(shell cat $(PROJECT_DIR)/config-custom.yaml | yq '.ecsModels'), null) -MODEL_IDS := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq '.ecsModels[].modelName') -endif +MODEL_IDS := $(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq '.ecsModels[].modelName' $(PROJECT_DIR)/config-custom.yaml 2>/dev/null || echo "") # MODEL_BUCKET - S3 bucket containing model artifacts -MODEL_BUCKET := $(shell cat $(PROJECT_DIR)/config-custom.yaml | yq '.s3BucketModels') +MODEL_BUCKET := $(call get_config,.s3BucketModels,) +# BASE_URL - Base URL for web UI assets based on domain name and deployment stage +DOMAIN_NAME := $(call get_config,.apiGatewayConfig.domainName,) +ifeq ($(DOMAIN_NAME),) +BASE_URL := /$(DEPLOYMENT_STAGE)/ +else +BASE_URL := / +endif ################################################################################# # COMMANDS # @@ -269,7 +253,7 @@ listStacks: @npx cdk list buildNpmModules: - npm run build + BASE_URL=$(BASE_URL) npm run build buildArchive: BUILD_ASSETS=true npm run build diff --git a/VERSION b/VERSION index 84197c894..8a30e8f94 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -5.3.2 +5.4.0 diff --git a/bin/build-images b/bin/build-images index ae9377f21..d5e1725f1 100755 --- a/bin/build-images +++ b/bin/build-images @@ -89,6 +89,13 @@ ecr_login() { fi } +# 
Function to check if a config parameter is enabled (defaults to true if not present) +should_build_image() { + local param="$1" + local value=$(yq ".${param}" "$ROOT/config-custom.yaml" 2>/dev/null) + [[ "$value" != "false" ]] +} + # Main function to build all images build_all_images() { echo "Starting Docker image builds..." @@ -117,6 +124,17 @@ build_all_images() { rsync -av --exclude='__pycache__' ./lambda/ "$BUILD_DIR/" build_image "Dockerfile" "lisa-batch-ingestion" "$LISA_VERSION" "$RAG_DIR" "NODE_ENV=production" + # lisa-mcp-workbench (conditional) + if should_build_image "deployMcpWorkbench"; then + MCP_DIR="./lib/serve/mcp-workbench" + build_image "Dockerfile" "lisa-mcp-workbench" "$LISA_VERSION" "$MCP_DIR" \ + "NODE_ENV=production" \ + "BASE_IMAGE=python:3.13.7-slim" + else + echo "deployMcpWorkbench is disabled, skipping lisa-mcp-workbench build" + echo "" + fi + # lisa-tei build_image "Dockerfile" "lisa-tei" "latest" "./lib/serve/ecs-model/embedding/tei" \ "NODE_ENV=production" \ diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 280676e6a..1618d38dd 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -237,7 +237,8 @@ export class ECSCluster extends Construct { environment, logging: LogDriver.awsLogs({ streamPrefix: identifier }), gpuCount: Ec2Metadata.get(ecsConfig.instanceType).gpuCount, - memoryReservationMiB: Ec2Metadata.get(ecsConfig.instanceType).memory - ecsConfig.containerMemoryBuffer, + memoryReservationMiB: taskDefinition.containerConfig.memoryReservation ?? 
+ (Ec2Metadata.get(ecsConfig.instanceType).memory - ecsConfig.containerMemoryBuffer), portMappings: [{ hostPort: 80, containerPort: 8080, protocol: Protocol.TCP }], healthCheck: containerHealthCheck, // Model containers need to run with privileged set to true diff --git a/ecs_model_deployer/src/lib/lisa_model_stack.ts b/ecs_model_deployer/src/lib/lisa_model_stack.ts index ab4eab304..402e2fdf8 100644 --- a/ecs_model_deployer/src/lib/lisa_model_stack.ts +++ b/ecs_model_deployer/src/lib/lisa_model_stack.ts @@ -57,7 +57,8 @@ export class LisaModelStack extends Stack { super(scope, id, props); const vpc = Vpc.fromLookup(this, `${id}-vpc`, { - vpcId: props.vpcId + vpcId: props.vpcId, + returnVpnGateways: false, }); let subnetSelection: SubnetSelection | undefined; diff --git a/eslint.config.js b/eslint.config.mjs similarity index 98% rename from eslint.config.js rename to eslint.config.mjs index 520d3abf2..a24b7ea89 100644 --- a/eslint.config.js +++ b/eslint.config.mjs @@ -142,30 +142,32 @@ export default [ }, { ignores: [ - 'dist/**', - 'node_modules/**', + '**/*.bundle.js', + '**/*.d.ts', + '**/*.min.js', + '**/.venv/**', + '**/build/**', + '**/coverage/**', + '**/dist/**', + '**/venv/**', + '*.bundle.js', + '*.min.js', + '.venv/**', 'build/**', 'coverage/**', + 'cypress/dist/**', + 'dist/**', + 'ecs_model_deployer/dist/**', 'htmlcov/**', + 'lib/docs/.vitepress/cache/**', + 'lib/docs/dist/**', 'lib/user-interface/react/dist/**', 'lib/user-interface/react/public/**', - 'lib/docs/dist/**', - 'lib/docs/.vitepress/cache/**', - 'ecs_model_deployer/dist/**', + 'node_modules/**', + 'pnpm-lock.yaml', + 'pnpm-workspace.yaml', 'vector_store_deployer/dist/**', - 'cypress/dist/**', - '.venv/**', 'venv/**', - '*.min.js', - '*.bundle.js', - '**/*.min.js', - '**/*.bundle.js', - '**/dist/**', - '**/build/**', - '**/coverage/**', - '**/.venv/**', - '**/venv/**', - '**/*.d.ts' ] } ]; diff --git a/flake.lock b/flake.lock new file mode 100644 index 000000000..5aec80c83 --- /dev/null +++ 
b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1761468971, + "narHash": "sha256-vY2OLVg5ZTobdroQKQQSipSIkHlxOTrIF1fsMzPh8w8=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "78e34d1667d32d8a0ffc3eba4591ff256e80576e", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 000000000..a10634c69 --- /dev/null +++ b/flake.nix @@ -0,0 +1,110 @@ +{ + # LISA Development Environment Flake + # + # LISA (LLM Inference Solution for Amazon) is an open source infrastructure-as-code + # solution for deploying LLM inference capabilities into AWS accounts. This flake + # provides a complete development environment with all necessary tools and dependencies + # for developing, testing, and deploying LISA. 
+ description = "Development environment for LISA - LLM Inference Solution for Amazon"; + + inputs = { + # Pin nixpkgs to the NixOS 25.05 release branch for reproducible package versions + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + + # Utility functions for creating flakes that work across multiple systems + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + # Generate outputs for all default systems (x86_64-linux, aarch64-linux, x86_64-darwin, aarch64-darwin) + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + # Formatter for this flake (run with `nix fmt`) + formatter = pkgs.nixpkgs-fmt; + + # Default development shell (enter with `nix develop`) + devShells.default = pkgs.mkShell { + # Core development tools needed for LISA + packages = with pkgs; [ + awscli2 # AWS command-line interface for deployment and management + gnumake + jq # JSON processor for parsing AWS responses and configuration + pre-commit # Git hook framework for code quality checks + python311Full # Python runtime for LISA backend services + nodejs # Node.js runtime for CDK infrastructure and frontend tooling + nodePackages.aws-cdk # AWS CDK CLI, the command line tool for CDK apps + uv # Fast Python package installer and virtual environment manager + yq # YAML processor for configuration management + ]; + + # Script that runs when entering the development shell + shellHook = '' + echo "Welcome to the LISA development environment!" + echo "Python: $(python --version)" + echo "Node: $(node --version)" + echo "" + + # Set up Python virtual environment using uv + if [ ! -d .venv ]; then + echo "Creating Python virtual environment with uv..." + uv venv + else + echo "Using existing Python virtual environment." + fi + + # Ensure we start fresh if another venv is active + if [ -n "$VIRTUAL_ENV" ]; then + echo "Deactivating existing virtual environment..." 
+ deactivate + fi + + # Activate the project virtual environment + source .venv/bin/activate + + # Ensure pip is up to date + uv pip install --upgrade pip + + # Initialize npm project if package.json doesn't exist + if [ ! -f package.json ]; then + echo "No package.json found. Running npm init..." + npm init -y + fi + + # Install Python development dependencies + echo "Installing Python development dependencies from requirements-dev.txt..." + + # Extract packages that must be installed as binary wheels (no source builds) + only_binary_packages=`grep "^--only-binary=" requirements-dev.txt | sed 's/^--only-binary=//' | tr ',' ' ' | tr -s ' ' | cut -d' ' -f1-` + echo "Extracted binary-only packages: $only_binary_packages" + + # Install requirements with --only-binary flags converted to command-line arguments + # This removes the --only-binary line from the file and passes it as CLI args instead + echo "Installing filtered requirements-dev.txt..." + uv pip install -r <(sed '/^--only-binary/d' requirements-dev.txt) `for p in $only_binary_packages; do echo "--only-binary=$p"; done` + + # Install LISA SDK in editable mode with binary-only installation + echo "Installing lisa-sdk in editable mode..." + uv pip install --only-binary :all: -e lisa-sdk + + # Install Node.js dependencies + echo "Installing Node.js dependencies..." + npm install + + # Configure git hooks for pre-commit + # Unset any existing hooks path to ensure pre-commit can manage hooks + git config --unset-all core.hooksPath 2>/dev/null || true + pre-commit install + + echo "" + echo "Development environment ready!" 
+ echo "Available commands:" + echo " uv pip - For faster package management" + echo " deploylisa - Clean build and deploy LISA in headless mode" + ''; + }; + } + ); +} diff --git a/lambda/mcp_workbench/lambda_functions.py b/lambda/mcp_workbench/lambda_functions.py index a1603d906..3ccc8699c 100644 --- a/lambda/mcp_workbench/lambda_functions.py +++ b/lambda/mcp_workbench/lambda_functions.py @@ -28,6 +28,8 @@ from utilities.common_functions import api_wrapper, retry_config from utilities.exceptions import HTTPException +from .syntax_validator import PythonSyntaxValidator + logger = logging.getLogger(__name__) # Initialize the S3 resource using environment variables @@ -255,3 +257,46 @@ def delete(event: dict, context: dict) -> Dict[str, str]: except Exception as e: logger.error("Unexpected error deleting tool: %s", e, exc_info=True) raise ValueError(f"Failed to delete tool: {e}") from e + + +@api_wrapper +def validate_syntax(event: dict, context: dict) -> Dict[str, Any]: + """Validate Python code syntax without execution.""" + if not is_admin(event): + raise ValueError("Only admin users can validate code syntax.") + + try: + body = json.loads(event["body"], parse_float=Decimal) + + # Ensure the required field is present + if "code" not in body: + raise ValueError("Missing required field: 'code' is required.") + + code = body["code"] + if not isinstance(code, str): + raise ValueError("Code must be a string.") + + logger.info("Validating Python code syntax") + + # Initialize the validator and validate the code + validator = PythonSyntaxValidator() + result = validator.validate_code(code) + + # Convert the dataclass to a dictionary for JSON serialization + response = { + "is_valid": result.is_valid, + "syntax_errors": result.syntax_errors, + "missing_required_imports": result.missing_required_imports, + "validation_timestamp": datetime.now().isoformat(), + } + + logger.info(f"Validation completed. 
Valid: {result.is_valid}, " f"Errors: {len(result.syntax_errors)}") + + return response + + except json.JSONDecodeError as e: + logger.error("Invalid JSON in request body: %s", e, exc_info=True) + raise ValueError(f"Invalid request body: {e}") from e + except Exception as e: + logger.error("Unexpected error validating syntax: %s", e, exc_info=True) + raise ValueError(f"Failed to validate syntax: {e}") from e diff --git a/lambda/mcp_workbench/mcp_mocks.py b/lambda/mcp_workbench/mcp_mocks.py new file mode 100644 index 000000000..a500d5756 --- /dev/null +++ b/lambda/mcp_workbench/mcp_mocks.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mock implementations of MCP Workbench core components for validation purposes. + +These mocks are used by the syntax validator to allow user code to import +and use MCP Workbench constructs without needing the full MCP Workbench +package installed. They provide just enough functionality to validate +the structure and usage of MCP tools. +""" + +from abc import ABC, abstractmethod +from functools import wraps +from typing import Any, Callable + + +class BaseTool(ABC): + """ + Mock BaseTool for validation purposes. + + This provides the same interface as the real BaseTool class, + allowing validation of class-based MCP tools without requiring + the full MCP Workbench package. 
+ """ + + def __init__(self, name: str, description: str): + """ + Initialize the tool with required metadata. + + Args: + name: The name of the tool + description: A description of what the tool does + """ + self.name = name + self.description = description + + @abstractmethod + async def execute(self) -> Callable[..., Any]: + """ + Returns a function to be executed as the tool. + + Returns: + The function to be executed + """ + pass + + +def mcp_tool(name: str, description: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Mock mcp_tool decorator for validation purposes. + + This provides the same interface as the real mcp_tool decorator, + allowing validation of function-based MCP tools without requiring + the full MCP Workbench package. + + Args: + name: The name of the tool + description: A description of what the tool does + + Returns: + The decorated function with MCP tool metadata + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + # Store metadata as function attributes + func._mcp_tool_name = name # type: ignore[attr-defined] + func._mcp_tool_description = description # type: ignore[attr-defined] + func._is_mcp_tool = True # type: ignore[attr-defined] + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # If the function is already async, await it + if hasattr(func, "__code__") and func.__code__.co_flags & 0x80: # CO_COROUTINE + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + # Copy metadata to wrapper + wrapper._mcp_tool_name = name # type: ignore[attr-defined] + wrapper._mcp_tool_description = description # type: ignore[attr-defined] + wrapper._is_mcp_tool = True # type: ignore[attr-defined] + wrapper._original_func = func # type: ignore[attr-defined] + + return wrapper + + return decorator diff --git a/lambda/mcp_workbench/syntax_validator.py b/lambda/mcp_workbench/syntax_validator.py new file mode 100644 index 000000000..50f6afa83 --- /dev/null +++ 
b/lambda/mcp_workbench/syntax_validator.py @@ -0,0 +1,301 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python syntax validation module for MCP Workbench.""" +import ast +import importlib.util +import logging +import os +import sys +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class ValidationResult: + """Result of Python code validation.""" + + is_valid: bool + syntax_errors: List[Dict[str, Any]] + missing_required_imports: Optional[List[str]] = None + + def __post_init__(self) -> None: + """Initialize list fields if None.""" + if self.missing_required_imports is None: + self.missing_required_imports = [] + + +class PythonSyntaxValidator: + """Validates Python code syntax and imports without execution.""" + + # Required MCP Workbench imports + REQUIRED_MCP_IMPORTS = [("mcpworkbench.core.annotations", "mcp_tool"), ("mcpworkbench.core.base_tool", "BaseTool")] + + def __init__(self) -> None: + """Initialize the validator.""" + # Safety limits + self.max_code_size = 100_000 # 100KB + + def validate_code(self, code: str) -> ValidationResult: + """ + Validate Python code for syntax and required imports. 
+ + Args: + code: Python code string to validate + + Returns: + ValidationResult with validation details + """ + syntax_errors = [] + + # Safety checks + if len(code) > self.max_code_size: + return ValidationResult( + is_valid=False, + syntax_errors=[ + { + "type": "CodeTooLarge", + "message": ( + f"Code size ({len(code)} bytes) exceeds maximum allowed ({self.max_code_size} bytes)" + ), + "line": 0, + "column": 0, + } + ], + missing_required_imports=[], + ) + + # Basic input validation + if not code or not code.strip(): + return ValidationResult( + is_valid=False, + syntax_errors=[{"type": "EmptyCode", "message": "Code cannot be empty", "line": 0, "column": 0}], + missing_required_imports=[], + ) + + # 1. AST-based syntax validation (fast check) + try: + tree = ast.parse(code) + logger.info("AST parsing successful") + except SyntaxError as e: + syntax_errors.append(self._format_syntax_error(e)) + logger.warning(f"Syntax error: {e}") + + # Return early if syntax is invalid + return ValidationResult(is_valid=False, syntax_errors=syntax_errors, missing_required_imports=[]) + except Exception as e: + syntax_errors.append( + {"type": "ParseError", "message": f"Failed to parse code: {str(e)}", "line": 0, "column": 0} + ) + logger.error(f"Parse error: {e}") + return ValidationResult(is_valid=False, syntax_errors=syntax_errors, missing_required_imports=[]) + + # 2. Module execution validation (comprehensive check) + execution_errors = self._validate_module_execution(code) + syntax_errors.extend(execution_errors) + + # 3. 
Check for required MCP imports + missing_required_imports = self._check_required_mcp_imports(tree) + + # Determine overall validity + is_valid = len(syntax_errors) == 0 and len(missing_required_imports) == 0 + + return ValidationResult( + is_valid=is_valid, syntax_errors=syntax_errors, missing_required_imports=missing_required_imports + ) + + def _validate_module_execution(self, code: str) -> List[Dict[str, Any]]: + """Validate code by attempting to execute it as a module.""" + errors = [] + + try: + # Set up the MCP environment FIRST (inject mocks into sys.modules) + # This must happen before exec() so imports can find the mocks + self._setup_mcp_environment(None) + + # Create a temporary module spec + spec = importlib.util.spec_from_loader("temp_validation_module", loader=None) + if spec is None: + errors.append( + { + "type": "ModuleError", + "message": "Failed to create module spec for validation", + "line": 0, + "column": 0, + } + ) + return errors + + module = importlib.util.module_from_spec(spec) + + # Execute the code in the module context + # The mocks are already in sys.modules so imports will work + exec(code, module.__dict__) # nosec B102 + logger.info("Module execution successful") + + except ImportError as e: + errors.append({"type": "ImportError", "message": str(e), "line": 0, "column": 0}) + logger.warning(f"Import error during execution: {e}") + except SyntaxError as e: + # Shouldn't happen since AST passed, but just in case + errors.append(self._format_syntax_error(e)) + logger.warning(f"Syntax error during execution: {e}") + except NameError as e: + errors.append({"type": "NameError", "message": str(e), "line": 0, "column": 0}) + logger.warning(f"Name error during execution: {e}") + except Exception as e: + errors.append( + {"type": "ExecutionError", "message": f"Error executing code: {str(e)}", "line": 0, "column": 0} + ) + logger.error(f"Execution error: {e}") + + return errors + + def _setup_mcp_environment(self, module: Any) -> None: + """Set 
up the module with required MCP imports available.""" + # Check if real MCP Workbench is available + if "mcpworkbench.core.base_tool" not in sys.modules: + # Real package not available, inject mocks into sys.modules + logger.info("Real MCP Workbench not found, setting up mocks") + mcp_tool_func: Any = None + base_tool_class: Any = None + + try: + # Try relative import first (when running as part of a package) + from .mcp_mocks import BaseTool as base_tool_class + from .mcp_mocks import mcp_tool as mcp_tool_func + + logger.info("Successfully imported mocks via relative import") + except ImportError as e: + logger.info(f"Relative import failed: {e}, trying absolute import") + try: + # Fall back to absolute import (when running standalone) + import mcp_mocks + + mcp_tool_func = mcp_mocks.mcp_tool + base_tool_class = mcp_mocks.BaseTool + logger.info("Successfully imported mocks via absolute import") + except ImportError as mock_error: + logger.error(f"CRITICAL: Failed to import MCP mocks via both methods: {mock_error}") + logger.error(f"Current directory: {os.getcwd()}") + logger.error(f"sys.path: {sys.path[:3]}") # Show first 3 paths + return + + # Create mock module hierarchy in sys.modules + # This allows user code to do: from mcpworkbench.core.base_tool import BaseTool + if "mcpworkbench" not in sys.modules: + sys.modules["mcpworkbench"] = ModuleType("mcpworkbench") + + if "mcpworkbench.core" not in sys.modules: + core_module = ModuleType("mcpworkbench.core") + sys.modules["mcpworkbench.core"] = core_module + sys.modules["mcpworkbench"].core = core_module # type: ignore[attr-defined] + + # Create and register the base_tool mock module + base_tool_module = ModuleType("mcpworkbench.core.base_tool") + base_tool_module.BaseTool = base_tool_class # type: ignore[attr-defined] + sys.modules["mcpworkbench.core.base_tool"] = base_tool_module + sys.modules["mcpworkbench.core"].base_tool = base_tool_module # type: ignore[attr-defined] + + # 
Create and register the annotations mock module + annotations_module = ModuleType("mcpworkbench.core.annotations") + annotations_module.mcp_tool = mcp_tool_func # type: ignore[attr-defined] + sys.modules["mcpworkbench.core.annotations"] = annotations_module + sys.modules["mcpworkbench.core"].annotations = annotations_module # type: ignore[attr-defined] + + logger.info("MCP mock modules successfully injected into sys.modules") + logger.info(f"Modules now in sys.modules: {[k for k in sys.modules.keys() if 'mcpworkbench' in k]}") + else: + logger.info("Real MCP Workbench package is already available in sys.modules") + + def _check_required_mcp_imports(self, tree: ast.AST) -> List[str]: + """Check if required MCP imports are present in the AST.""" + missing_required = [] + + # Collect all imports from the AST + imports = self._collect_imports(tree) + + # Check if at least one required import is present + has_required_import = False + + for module, name in self.REQUIRED_MCP_IMPORTS: + # Check if imported via 'from module import name' + if module in imports["from_imports"] and name in imports["from_imports"][module]: + has_required_import = True + break + + # Check if imported via star import + if module in imports["star_imports"]: + has_required_import = True + break + + if not has_required_import: + missing_required.append("At least one of the required MCP Workbench imports is missing") + + return missing_required + + def _collect_imports(self, tree: ast.AST) -> Dict[str, Any]: + """Collect all import statements from the AST.""" + imports: Dict[str, Any] = { + "modules": set(), # Direct module imports: import os + "from_imports": {}, # From imports: from os import path -> {'os': {'path'}} + "aliases": {}, # Import aliases: import numpy as np -> {'np': 'numpy'} + "star_imports": set(), # Star imports: from os import * + } + + class ImportVisitor(ast.NodeVisitor): + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + module_name = alias.name + 
alias_name = alias.asname or module_name + imports["modules"].add(module_name) + if alias.asname: + imports["aliases"][alias_name] = module_name + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module: + module_name = node.module + if node.names and len(node.names) == 1 and node.names[0].name == "*": + # Star import + imports["star_imports"].add(module_name) + else: + # Regular from import + if module_name not in imports["from_imports"]: + imports["from_imports"][module_name] = set() + + for alias in node.names: + name = alias.name + alias_name = alias.asname or name + imports["from_imports"][module_name].add(name) + if alias.asname: + imports["aliases"][alias_name] = f"{module_name}.{name}" + self.generic_visit(node) + + visitor = ImportVisitor() + visitor.visit(tree) + return imports + + def _format_syntax_error(self, syntax_error: SyntaxError) -> Dict[str, Any]: + """Format a SyntaxError into a standardized error dictionary.""" + return { + "type": "SyntaxError", + "message": str(syntax_error.msg) if syntax_error.msg else "Syntax error", + "line": syntax_error.lineno or 0, + "column": syntax_error.offset or 0, + "text": syntax_error.text.strip() if syntax_error.text else "", + } diff --git a/lambda/models/clients/litellm_client.py b/lambda/models/clients/litellm_client.py index 343dd12b9..11fc6546c 100644 --- a/lambda/models/clients/litellm_client.py +++ b/lambda/models/clients/litellm_client.py @@ -98,3 +98,100 @@ def get_model(self, identifier: str) -> Dict[str, Any]: if len(filtered_models) < 1: raise ModelNotFoundError("Specified model was not found.") return filtered_models[0] + + def create_guardrail(self, guardrail_config: Dict[str, Any]) -> Dict[str, Any]: + """ + Create a new guardrail configuration in LiteLLM. + + Args: + guardrail_config: Dictionary containing guardrail configuration including + guardrail_name, guardrail_identifier, guardrail_version, mode, etc. 
+ + Returns: + Dictionary containing the created guardrail information including LiteLLM guardrail ID + """ + resp = requests.post( + self._base_uri + "/guardrails", + headers=self._headers, + json=guardrail_config, + timeout=self._timeout, + verify=self._verify, + ) + resp.raise_for_status() + return resp.json() # type: ignore [no-any-return] + + def update_guardrail(self, guardrail_id: str, guardrail_config: Dict[str, Any]) -> Dict[str, Any]: + """ + Update an existing guardrail configuration in LiteLLM. + + Args: + guardrail_id: The LiteLLM guardrail ID to update + guardrail_config: Dictionary containing updated guardrail configuration + + Returns: + Dictionary containing the updated guardrail information + """ + resp = requests.put( + self._base_uri + f"/guardrails/{guardrail_id}", + headers=self._headers, + json=guardrail_config, + timeout=self._timeout, + verify=self._verify, + ) + resp.raise_for_status() + return resp.json() # type: ignore [no-any-return] + + def delete_guardrail(self, guardrail_id: str) -> None: + """ + Delete a guardrail configuration from LiteLLM. + + Args: + guardrail_id: The LiteLLM guardrail ID to delete + """ + resp = requests.delete( + self._base_uri + f"/guardrails/{guardrail_id}", + headers=self._headers, + timeout=self._timeout, + verify=self._verify, + ) + resp.raise_for_status() + + def get_guardrail_info(self, guardrail_id: str) -> Dict[str, Any]: + """ + Get information about a specific guardrail. + + Args: + guardrail_id: The LiteLLM guardrail ID to retrieve + + Returns: + Dictionary containing guardrail information + """ + resp = requests.get( + self._base_uri + f"/guardrails/{guardrail_id}", + headers=self._headers, + timeout=self._timeout, + verify=self._verify, + ) + resp.raise_for_status() + return resp.json() # type: ignore [no-any-return] + + def apply_guardrail(self, guardrail_name: str, text: str) -> Dict[str, Any]: + """ + Apply a guardrail to text content for validation. 
+ + Args: + guardrail_name: Name of the guardrail to apply + text: Text content to validate against the guardrail + + Returns: + Dictionary containing validation results + """ + resp = requests.post( + self._base_uri + "/guardrails/apply_guardrail", + headers=self._headers, + json={"guardrail_name": guardrail_name, "text": text}, + timeout=self._timeout, + verify=self._verify, + ) + resp.raise_for_status() + return resp.json() # type: ignore [no-any-return] diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 195d11d0f..d2282d114 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -73,6 +73,65 @@ def __str__(self) -> str: EMBEDDING = "embedding" +class GuardrailMode(str, Enum): + """Defines supported guardrail execution modes.""" + + def __str__(self) -> str: + """Returns string representation of the enum value.""" + return str(self.value) + + PRE_CALL = "pre_call" + DURING_CALL = "during_call" + POST_CALL = "post_call" + + +class GuardrailConfig(BaseModel): + """Defines configuration for a single guardrail.""" + + guardrailName: str = Field(min_length=1) + guardrailIdentifier: str = Field(min_length=1) + guardrailVersion: str = Field(default="DRAFT") + mode: GuardrailMode = Field(default=GuardrailMode.PRE_CALL) + description: Optional[str] = None + allowedGroups: List[str] = Field(default_factory=list) + markedForDeletion: Optional[bool] = Field(default=False) + + +# Type alias for guardrails configuration - maps guardrail IDs to their configs +GuardrailsConfig: TypeAlias = Dict[str, GuardrailConfig] + + +class GuardrailRequest(BaseModel): + """Defines request structure for guardrails API operations.""" + + model_id: str = Field(min_length=1) + guardrails_config: GuardrailsConfig + + +class GuardrailResponse(BaseModel): + """Defines response structure for guardrails API operations.""" + + model_id: str + guardrails_config: GuardrailsConfig + success: bool + message: str + + +class 
GuardrailsTableEntry(BaseModel): + """Represents a guardrail entry in DynamoDB table.""" + + guardrailId: str # Partition key + modelId: str # Sort key + guardrailName: str + guardrailIdentifier: str + guardrailVersion: str + mode: str + description: Optional[str] + allowedGroups: List[str] + createdDate: int = Field(default_factory=lambda: int(time.time() * 1000)) + lastModifiedDate: int = Field(default_factory=lambda: int(time.time() * 1000)) + + class MetricConfig(BaseModel): """Defines metrics configuration for auto-scaling policies.""" @@ -230,6 +289,7 @@ class LISAModel(BaseModel): streaming: bool features: Optional[List[ModelFeature]] = None allowedGroups: Optional[List[str]] = None + guardrailsConfig: Optional[GuardrailsConfig] = None class ApiResponseBase(BaseModel): @@ -255,6 +315,7 @@ class CreateModelRequest(BaseModel): features: Optional[List[ModelFeature]] = None allowedGroups: Optional[List[str]] = None apiKey: Optional[str] = None + guardrailsConfig: Optional[GuardrailsConfig] = None @model_validator(mode="after") def validate_create_model_request(self) -> Self: @@ -308,6 +369,7 @@ class UpdateModelRequest(BaseModel): allowedGroups: Optional[List[str]] = None features: Optional[List[ModelFeature]] = None containerConfig: Optional[ContainerConfigUpdatable] = None + guardrailsConfig: Optional[GuardrailsConfig] = None @model_validator(mode="after") def validate_update_model_request(self) -> Self: @@ -321,11 +383,13 @@ def validate_update_model_request(self) -> Self: self.allowedGroups, self.features, self.containerConfig, + self.guardrailsConfig, ] if not validate_any_fields_defined(fields): raise ValueError( "At least one field out of autoScalingInstanceConfig, containerConfig, enabled, modelType, " - "modelDescription, streaming, allowedGroups, or features must be defined in request payload." + "modelDescription, streaming, allowedGroups, features, or guardrailsConfig must be " + "defined in request payload." 
) if self.modelType == ModelType.EMBEDDING and self.streaming: diff --git a/lambda/models/handler/base_handler.py b/lambda/models/handler/base_handler.py index e165ab913..66ac1c30e 100644 --- a/lambda/models/handler/base_handler.py +++ b/lambda/models/handler/base_handler.py @@ -26,11 +26,13 @@ def __init__( autoscaling_client: Any, stepfunctions_client: Any, model_table_resource: Any, + guardrails_table_resource: Any, ): """Make all clients available for use in any handler class.""" self._autoscaling = autoscaling_client self._stepfunctions = stepfunctions_client self._model_table = model_table_resource + self._guardrails_table = guardrails_table_resource def __call__(self, *args: Any, **kwargs: Any) -> Any: """All handlers must implement the __call__ method.""" diff --git a/lambda/models/handler/get_model_handler.py b/lambda/models/handler/get_model_handler.py index 675381953..c6b82d520 100644 --- a/lambda/models/handler/get_model_handler.py +++ b/lambda/models/handler/get_model_handler.py @@ -21,7 +21,7 @@ from ..domain_objects import GetModelResponse from ..exception import ModelNotFoundError from .base_handler import BaseApiHandler -from .utils import to_lisa_model +from .utils import attach_guardrails_to_model, fetch_guardrails_for_model, to_lisa_model class GetModelHandler(BaseApiHandler): @@ -37,6 +37,10 @@ def __call__( model = to_lisa_model(ddb_item) + # Fetch and attach guardrails for this model + guardrail_items = fetch_guardrails_for_model(self._guardrails_table, model_id) + attach_guardrails_to_model(model, guardrail_items) + # Check if user has access to this model based on groups if not is_admin and user_groups is not None: if not user_has_group_access(user_groups, model.allowedGroups or []): diff --git a/lambda/models/handler/list_models_handler.py b/lambda/models/handler/list_models_handler.py index 56b29fd7b..c1f27ef90 100644 --- a/lambda/models/handler/list_models_handler.py +++ b/lambda/models/handler/list_models_handler.py @@ -20,7 +20,7 @@ 
from ..domain_objects import ListModelsResponse from .base_handler import BaseApiHandler -from .utils import to_lisa_model +from .utils import attach_guardrails_to_model, fetch_all_guardrails, group_guardrails_by_model, to_lisa_model class ListModelsHandler(BaseApiHandler): @@ -39,6 +39,15 @@ def __call__(self, user_groups: Optional[List[str]] = None, is_admin: bool = Fal models_list = [to_lisa_model(m) for m in ddb_models] + # Fetch all guardrails and group them by model ID + all_guardrails = fetch_all_guardrails(self._guardrails_table) + guardrails_by_model = group_guardrails_by_model(all_guardrails) + + # Attach guardrails to models + for model in models_list: + if model.modelId in guardrails_by_model: + attach_guardrails_to_model(model, guardrails_by_model[model.modelId]) + # Filter models based on user groups if not admin if not is_admin and user_groups is not None: models_list = [ diff --git a/lambda/models/handler/update_model_handler.py b/lambda/models/handler/update_model_handler.py index 6ae691a28..4d21d631f 100644 --- a/lambda/models/handler/update_model_handler.py +++ b/lambda/models/handler/update_model_handler.py @@ -20,7 +20,7 @@ from ..domain_objects import ModelStatus, UpdateModelRequest, UpdateModelResponse from ..exception import InvalidStateTransitionError, ModelNotFoundError from .base_handler import BaseApiHandler -from .utils import to_lisa_model +from .utils import attach_guardrails_to_model, fetch_guardrails_for_model, to_lisa_model class UpdateModelHandler(BaseApiHandler): @@ -109,4 +109,10 @@ def __call__(self, model_id: str, update_request: UpdateModelRequest) -> UpdateM stateMachineArn=os.environ["UPDATE_SFN_ARN"], input=json.dumps(state_machine_payload) ) - return UpdateModelResponse(model=to_lisa_model(ddb_item)) + model = to_lisa_model(ddb_item) + + # Fetch and attach guardrails for this model + guardrail_items = fetch_guardrails_for_model(self._guardrails_table, model_id) + attach_guardrails_to_model(model, guardrail_items) + + 
return UpdateModelResponse(model=model) diff --git a/lambda/models/handler/utils.py b/lambda/models/handler/utils.py index db2b1519d..f45df0b4e 100644 --- a/lambda/models/handler/utils.py +++ b/lambda/models/handler/utils.py @@ -14,9 +14,9 @@ """Common utility functions across all API handlers.""" -from typing import Any, Dict +from typing import Any, Dict, List -from ..domain_objects import LISAModel +from ..domain_objects import GuardrailConfig, LISAModel def to_lisa_model(model_dict: Dict[str, Any]) -> LISAModel: @@ -26,3 +26,55 @@ def to_lisa_model(model_dict: Dict[str, Any]) -> LISAModel: model_dict["model_config"]["modelUrl"] = model_dict["model_url"] lisa_model: LISAModel = LISAModel.model_validate(model_dict["model_config"]) return lisa_model + + +def create_guardrail_config(item: Dict[str, Any]) -> GuardrailConfig: + """Create a GuardrailConfig object from a DynamoDB guardrail item.""" + return GuardrailConfig(**item) + + +def attach_guardrails_to_model(model: LISAModel, guardrail_items: List[Dict[str, Any]]) -> None: + """Build guardrails config from DDB items and attach to model.""" + if not guardrail_items: + return + + model.guardrailsConfig = { + f"guardrail-{item['guardrailName']}": create_guardrail_config(item) for item in guardrail_items + } + + +def fetch_guardrails_for_model(guardrails_table, model_id: str) -> List[Dict[str, Any]]: + """Query guardrails table for a specific model ID.""" + guardrails_response = guardrails_table.query( + IndexName="ModelIdIndex", + KeyConditionExpression="modelId = :modelId", + ExpressionAttributeValues={":modelId": model_id}, + ) + return guardrails_response.get("Items", []) + + +def fetch_all_guardrails(guardrails_table) -> List[Dict[str, Any]]: + """Scan all guardrails from the table with pagination.""" + all_guardrails = [] + guardrails_response = guardrails_table.scan() + all_guardrails.extend(guardrails_response.get("Items", [])) + pagination_key = guardrails_response.get("LastEvaluatedKey", None) + + while 
pagination_key: + guardrails_response = guardrails_table.scan(ExclusiveStartKey=pagination_key) + all_guardrails.extend(guardrails_response.get("Items", [])) + pagination_key = guardrails_response.get("LastEvaluatedKey", None) + + return all_guardrails + + +def group_guardrails_by_model(guardrail_items: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: + """Group guardrail items by modelId.""" + guardrails_by_model: Dict[str, List[Dict[str, Any]]] = {} + for item in guardrail_items: + model_id = item["modelId"] + if model_id not in guardrails_by_model: + guardrails_by_model[model_id] = [] + guardrails_by_model[model_id].append(item) + + return guardrails_by_model diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 034a89dd0..595b03655 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -57,6 +57,7 @@ dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config) model_table = dynamodb.Table(os.environ["MODEL_TABLE_NAME"]) +guardrails_table = dynamodb.Table(os.environ["GUARDRAILS_TABLE_NAME"]) stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -92,6 +93,7 @@ async def create_model(create_request: CreateModelRequest) -> CreateModelRespons autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, model_table_resource=model_table, + guardrails_table_resource=guardrails_table, ) return create_handler(create_request=create_request) @@ -104,6 +106,7 @@ async def list_models(request: Request) -> ListModelsResponse: autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, model_table_resource=model_table, + guardrails_table_resource=guardrails_table, ) user_groups = [] @@ -130,6 +133,7 @@ async def get_model( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, 
model_table_resource=model_table, + guardrails_table_resource=guardrails_table, ) user_groups = [] @@ -157,6 +161,7 @@ async def update_model( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, model_table_resource=model_table, + guardrails_table_resource=guardrails_table, ) return update_handler(model_id=model_id, update_request=update_request) @@ -170,6 +175,7 @@ async def delete_model( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, model_table_resource=model_table, + guardrails_table_resource=guardrails_table, ) return delete_handler(model_id=model_id) diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index b757c0238..365af01e7 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -24,7 +24,7 @@ import boto3 from botocore.config import Config from models.clients.litellm_client import LiteLLMClient -from models.domain_objects import CreateModelRequest, InferenceContainer, ModelStatus +from models.domain_objects import CreateModelRequest, GuardrailsTableEntry, InferenceContainer, ModelStatus from models.exception import ( MaxPollsExceededException, StackFailedToCreateException, @@ -46,6 +46,7 @@ ec2Client = boto3.client("ec2", region_name=os.environ["AWS_REGION"], config=retry_config) ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) model_table = ddbResource.Table(os.environ["MODEL_TABLE_NAME"]) +guardrails_table = ddbResource.Table(os.environ["GUARDRAILS_TABLE_NAME"]) cfnClient = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=retry_config) iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -96,6 +97,11 @@ def handle_set_model_to_creating(event: Dict[str, Any], context: Any) -> Dict[st ) ) + # Create a copy of event without guardrailsConfig + # Guardrails are stored separately in the guardrails 
table + model_config_data = deepcopy(event) + model_config_data.pop("guardrailsConfig", None) + model_table.update_item( Key={"model_id": request.modelId}, UpdateExpression=( @@ -104,7 +110,7 @@ def handle_set_model_to_creating(event: Dict[str, Any], context: Any) -> Dict[st ), ExpressionAttributeValues={ ":model_status": ModelStatus.CREATING, - ":model_config": event, + ":model_config": model_config_data, ":model_description": request.modelDescription, ":lm": int(datetime.now(UTC).timestamp()), }, @@ -137,7 +143,7 @@ def handle_start_copy_docker_image(event: Dict[str, Any], context: Any) -> Dict[ # Remove registry URL if present to get just the repository name if "/" in repository_name: - repository_name = repository_name.split("/")[-1] + repository_name = repository_name.split("/", 1)[1] # Verify image exists in ECR ecrClient.describe_images(repositoryName=repository_name, imageIds=[{"imageTag": image_tag}]) @@ -406,6 +412,108 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str return output_dict +def handle_add_guardrails_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Add guardrails to LiteLLM and store them in DynamoDB.""" + logger.info(f"Adding guardrails to LiteLLM for model: {event.get('modelId')}") + output_dict = deepcopy(event) + + # Check if guardrails config exists + if not event.get("guardrailsConfig") or not event["guardrailsConfig"]: + logger.info("No guardrails configuration found, skipping guardrail creation") + output_dict["guardrail_ids"] = [] + return output_dict + + guardrail_ids = [] + created_guardrails = [] + + try: + # Process each guardrail in the configuration + for guardrail_key, guardrail_config in event["guardrailsConfig"].items(): + logger.info(f"Processing guardrail: {guardrail_key}") + + model_id = event["modelId"] + + # Transform guardrail config to LiteLLM format + litellm_guardrail_config = { + "guardrail": { + "guardrail_name": 
f'{guardrail_config["guardrailName"]}-{model_id}', + "litellm_params": { + "guardrail": "bedrock", + "mode": str(guardrail_config.get("mode", "pre_call")), + "guardrailIdentifier": guardrail_config["guardrailIdentifier"], + "guardrailVersion": guardrail_config.get("guardrailVersion", "DRAFT"), + "default_on": False, + }, + "guardrail_info": {"description": guardrail_config.get("description", "")}, + } + } + + # Create guardrail in LiteLLM + logger.info(f"Creating guardrail in LiteLLM: {guardrail_config['guardrailName']}") + litellm_response = litellm_client.create_guardrail(litellm_guardrail_config) + + # Extract LiteLLM guardrail ID from response + litellm_guardrail_id = None + if "guardrail_id" in litellm_response: + litellm_guardrail_id = litellm_response["guardrail_id"] + else: + logger.error(f"Unexpected LiteLLM guardrail response structure: {litellm_response}") + raise KeyError(f"Could not find guardrail ID in LiteLLM response: {litellm_response}") + + # Create guardrail entry for DynamoDB + guardrail_entry = GuardrailsTableEntry( + guardrailId=litellm_guardrail_id, + modelId=model_id, + guardrailName=guardrail_config["guardrailName"], + guardrailIdentifier=guardrail_config["guardrailIdentifier"], + guardrailVersion=guardrail_config.get("guardrailVersion", "DRAFT"), + mode=guardrail_config.get("mode", "pre_call"), + description=guardrail_config.get("description"), + allowedGroups=guardrail_config.get("allowedGroups", []), + ) + + # Store in DynamoDB + logger.info(f"Storing guardrail in DynamoDB: {litellm_guardrail_id}") + guardrails_table.put_item(Item=guardrail_entry.model_dump()) + + guardrail_ids.append(litellm_guardrail_id) + created_guardrails.append( + { + "guardrail_id": litellm_guardrail_id, + "guardrail_name": guardrail_config["guardrailName"], + } + ) + + logger.info( + f"Successfully created guardrail: {guardrail_config['guardrailName']} with ID: {litellm_guardrail_id}" + ) + + except Exception as e: + logger.error(f"Error creating guardrails: 
{str(e)}") + + # Clean up any created guardrails on failure + for created_guardrail in created_guardrails: + try: + logger.info(f"Cleaning up guardrail: {created_guardrail['guardrail_id']}") + # Delete from DynamoDB (table key attributes are guardrailId / modelId) + guardrails_table.delete_item( + Key={"guardrailId": created_guardrail["guardrail_id"], "modelId": event["modelId"]} + ) + # Delete from LiteLLM (created_guardrails entries store the LiteLLM ID under "guardrail_id") + litellm_client.delete_guardrail(created_guardrail["guardrail_id"]) + except Exception as cleanup_error: + logger.error(f"Error during guardrail cleanup: {str(cleanup_error)}") + + # Re-raise the original exception + raise e + + output_dict["guardrail_ids"] = guardrail_ids + output_dict["created_guardrails"] = created_guardrails + + logger.info(f"Successfully created {len(guardrail_ids)} guardrails for model: {event['modelId']}") + return output_dict + + def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: """ Handle failures from state machine. @@ -419,11 +527,32 @@ def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: to Failed. Cleaning up the CloudFormation stack, if it still exists, will happen in the DeleteModel API.
""" logger.error(f"Handling state machine failure: {event}") - error_dict = json.loads( # error from SFN is json payload on top of json payload we add to the exception - json.loads(event["Cause"])["errorMessage"] - ) - error_reason = error_dict["error"] - original_event = error_dict["event"] + + try: + # Parse the error from Step Functions + cause_data = json.loads(event["Cause"]) + error_message = cause_data["errorMessage"] + + # Try to parse the error message as JSON (for our custom exceptions) + try: + error_dict = json.loads(error_message) + if isinstance(error_dict, dict) and "error" in error_dict: + error_reason = error_dict["error"] + original_event = error_dict.get("event", event) + else: + # If it's not our expected format, use the raw error message + error_reason = str(error_dict) if error_dict else "Unknown error" + original_event = event + except (json.JSONDecodeError, TypeError): + # If error_message is not JSON, use it directly + error_reason = error_message + original_event = event + + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.error(f"Error parsing failure event: {str(e)}") + error_reason = f"Failed to parse error details: {str(e)}" + original_event = event + logger.error(f"Failure reason: {error_reason}, Model ID: {original_event.get('modelId', 'unknown')}") # terminate EC2 instance if we have one recorded diff --git a/lambda/models/state_machine/delete_model.py b/lambda/models/state_machine/delete_model.py index e5f4ac9cc..e9707364b 100644 --- a/lambda/models/state_machine/delete_model.py +++ b/lambda/models/state_machine/delete_model.py @@ -34,6 +34,7 @@ cloudformation = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=retry_config) dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) ddb_table = dynamodb.Table(os.environ["MODEL_TABLE_NAME"]) +guardrails_table = dynamodb.Table(os.environ["GUARDRAILS_TABLE_NAME"]) iam_client = boto3.client("iam", 
region_name=os.environ["AWS_REGION"], config=retry_config) secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -87,6 +88,68 @@ def handle_delete_from_litellm(event: Dict[str, Any], context: Any) -> Dict[str, return event +def handle_delete_guardrails(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Delete all guardrails associated with the model from both LiteLLM and DynamoDB.""" + logger.info(f"Deleting guardrails for model: {event.get('modelId')}") + output_dict = deepcopy(event) + + model_id = event["modelId"] + deleted_guardrails = [] + + try: + # Get all guardrails for this model from DynamoDB + response = guardrails_table.query( + IndexName="ModelIdIndex", + KeyConditionExpression="modelId = :modelId", + ExpressionAttributeValues={":modelId": model_id}, + ) + + guardrails_to_delete = response.get("Items", []) + + if not guardrails_to_delete: + logger.info(f"No guardrails found for model: {model_id}") + output_dict["deleted_guardrails"] = [] + return output_dict + + logger.info(f"Found {len(guardrails_to_delete)} guardrails to delete for model: {model_id}") + + # Delete each guardrail from both LiteLLM and DynamoDB + for guardrail in guardrails_to_delete: + guardrail_id = guardrail["guardrailId"] + guardrail_name = guardrail["guardrailName"] + + try: + logger.info(f"Deleting guardrail from LiteLLM: {guardrail_name} (ID: {guardrail_id})") + # Delete from LiteLLM + litellm_client.delete_guardrail(guardrail_id) + + logger.info(f"Deleting guardrail from DynamoDB: {guardrail_name} (ID: {guardrail_id})") + # Delete from DynamoDB + guardrails_table.delete_item(Key={"guardrailId": guardrail_id, "modelId": model_id}) + + deleted_guardrails.append( + {"guardrail_id": guardrail_id, "guardrail_name": guardrail_name, "action": "deleted"} + ) + + logger.info(f"Successfully deleted guardrail: {guardrail_name}") + + except Exception as delete_error: + logger.error(f"Error deleting individual guardrail 
{guardrail_name}: {str(delete_error)}") + # Continue with other guardrails even if one fails + # We don't want to stop the entire model deletion process because of a guardrail deletion failure + continue + + except Exception as e: + logger.error(f"Error during guardrail deletion process for model {model_id}: {str(e)}") + # Don't raise the exception - we want to continue with model deletion even if guardrail cleanup fails + # Log the error but proceed with the deletion workflow + + output_dict["deleted_guardrails"] = deleted_guardrails + logger.info(f"Completed guardrail deletion for model: {model_id}. Deleted {len(deleted_guardrails)} guardrails.") + + return output_dict + + def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: """Initialize stack deletion.""" stack_arn = event[CFN_STACK_ARN] diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 7479cc49a..3978fcd39 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -23,11 +23,12 @@ import boto3 from models.clients.litellm_client import LiteLLMClient -from models.domain_objects import ModelStatus +from models.domain_objects import GuardrailsTableEntry, ModelStatus from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) model_table = ddbResource.Table(os.environ["MODEL_TABLE_NAME"]) +guardrails_table = ddbResource.Table(os.environ["GUARDRAILS_TABLE_NAME"]) autoscaling_client = boto3.client("autoscaling", region_name=os.environ["AWS_REGION"], config=retry_config) ecs_client = boto3.client("ecs", region_name=os.environ["AWS_REGION"], config=retry_config) cfn_client = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -297,6 +298,9 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> 
def _build_litellm_guardrail_config(guardrail_config: Dict[str, Any], model_id: str) -> Dict[str, Any]:
    """Translate one guardrail config entry into the payload sent to LiteLLM.

    The LiteLLM guardrail name is the configured name suffixed with the model id.
    Mode defaults to "pre_call" and version to "DRAFT" when not supplied.
    """
    return {
        "guardrail": {
            "guardrail_name": f'{guardrail_config["guardrailName"]}-{model_id}',
            "litellm_params": {
                "guardrail": "bedrock",
                "mode": str(guardrail_config.get("mode", "pre_call")),
                "guardrailIdentifier": guardrail_config["guardrailIdentifier"],
                "guardrailVersion": guardrail_config.get("guardrailVersion", "DRAFT"),
                "default_on": False,
            },
            "guardrail_info": {"description": guardrail_config.get("description", "")},
        }
    }


def _delete_guardrail(guardrail: Dict[str, Any], model_id: str) -> Dict[str, Any]:
    """Delete a stored guardrail from LiteLLM, then from DynamoDB, and return its operation record."""
    guardrail_id = guardrail["guardrailId"]
    litellm_client.delete_guardrail(guardrail_id)
    guardrails_table.delete_item(Key={"guardrailId": guardrail_id, "modelId": model_id})
    return {
        "guardrail_id": guardrail_id,
        "guardrail_name": guardrail["guardrailName"],
        "action": "deleted",
    }


def handle_update_guardrails(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """
    Update guardrails for a model in LiteLLM and DynamoDB.

    This handler will:
    1. Read ``update_payload.guardrailsConfig`` from the event
    2. Delete guardrails flagged with ``markedForDeletion``
    3. Update guardrails whose name already exists for the model (LiteLLM + DynamoDB)
    4. Create guardrails whose name is new (LiteLLM + DynamoDB)
    5. Delete stored guardrails that no longer appear in the configuration

    Args:
        event: State machine payload; must contain ``model_id`` and ``update_payload``.
        context: Lambda context (unused).

    Returns:
        A deep copy of ``event`` extended with ``guardrail_updates`` (list of
        ``{"guardrail_id", "guardrail_name", "action"}`` records) and
        ``guardrail_update_summary`` (counts per action). When there is no guardrails
        configuration both are present but empty/zero, and the legacy
        ``guardrail_update_ids`` key is also set to an empty list.

    Raises:
        Exception: Any failure while querying, updating or creating guardrails is re-raised
            after a best-effort rollback of guardrails created during this invocation.
            Individual deletion failures are logged and do not abort the update.
    """
    logger.info(f"Updating guardrails for model: {event.get('model_id')}")
    output_dict = deepcopy(event)

    model_id = event["model_id"]
    guardrails_config = event["update_payload"].get("guardrailsConfig")

    # Nothing to do: emit the same output keys as the normal path so downstream
    # states can rely on one shape. "guardrail_update_ids" is kept for compatibility.
    if not guardrails_config:
        logger.info("No guardrails configuration found, skipping guardrail updates")
        output_dict["guardrail_update_ids"] = []
        output_dict["guardrail_updates"] = []
        output_dict["guardrail_update_summary"] = {"updated": 0, "created": 0, "deleted": 0}
        return output_dict

    updated_guardrails: list[Dict[str, Any]] = []
    created_guardrails: list[Dict[str, Any]] = []
    deleted_guardrails: list[Dict[str, Any]] = []
    # IDs of guardrails created in LiteLLM during this run. Recorded as soon as LiteLLM
    # returns the ID (before the DynamoDB write) so a failed write does not orphan them.
    rollback_guardrail_ids: list[str] = []

    try:
        # Get existing guardrails for this model from DynamoDB, keyed by guardrail name.
        # NOTE(review): only the first page of the query result is read — assumes a model
        # has few guardrails; paginate on LastEvaluatedKey if that stops being true.
        response = guardrails_table.query(
            IndexName="ModelIdIndex",
            KeyConditionExpression="modelId = :modelId",
            ExpressionAttributeValues={":modelId": model_id},
        )
        existing_guardrails = {item["guardrailName"]: item for item in response.get("Items", [])}

        # Names seen in the new configuration (including deletion markers); anything stored
        # but not seen is removed at the end.
        processed_guardrail_names = set()

        for guardrail_config in guardrails_config.values():
            guardrail_name = guardrail_config["guardrailName"]
            logger.info(f"Processing guardrail update: {guardrail_name}")
            processed_guardrail_names.add(guardrail_name)

            # Explicit deletion marker: delete and skip normal processing.
            if guardrail_config.get("markedForDeletion", False):
                logger.info(f"Found guardrail marked for deletion: {guardrail_name}")
                guardrail_to_delete = existing_guardrails.get(guardrail_name)

                if not guardrail_to_delete:
                    logger.warning(f"No matching guardrail found for deletion: {guardrail_name}")
                    continue

                try:
                    logger.info(
                        f"Deleting guardrail: {guardrail_to_delete['guardrailName']} "
                        f"(ID: {guardrail_to_delete['guardrailId']})"
                    )
                    deleted_guardrails.append(_delete_guardrail(guardrail_to_delete, model_id))
                    logger.info(f"Successfully deleted guardrail: {guardrail_to_delete['guardrailName']}")
                except Exception as delete_error:
                    # Continue with other operations even if one deletion fails
                    logger.error(f"Error deleting guardrail marked for deletion: {str(delete_error)}")
                continue

            litellm_guardrail_config = _build_litellm_guardrail_config(guardrail_config, model_id)

            if guardrail_name in existing_guardrails:
                # Update existing guardrail
                existing_guardrail = existing_guardrails[guardrail_name]
                litellm_guardrail_id = existing_guardrail["guardrailId"]

                logger.info(f"Updating guardrail in LiteLLM: {guardrail_name}")
                litellm_client.update_guardrail(litellm_guardrail_id, litellm_guardrail_config)

                guardrails_table.update_item(
                    Key={"guardrailId": litellm_guardrail_id, "modelId": model_id},
                    UpdateExpression=(
                        "SET guardrailIdentifier = :gi, guardrailVersion = :gv, #mode = :m, "
                        "description = :d, allowedGroups = :ag, lastModifiedDate = :lm"
                    ),
                    ExpressionAttributeNames={"#mode": "mode"},  # mode is a reserved keyword in DynamoDB
                    ExpressionAttributeValues={
                        ":gi": guardrail_config["guardrailIdentifier"],
                        ":gv": guardrail_config.get("guardrailVersion", "DRAFT"),
                        ":m": str(guardrail_config.get("mode", "pre_call")),
                        ":d": guardrail_config.get("description"),
                        ":ag": guardrail_config.get("allowedGroups", []),
                        ":lm": int(datetime.now(UTC).timestamp() * 1000),
                    },
                )

                updated_guardrails.append(
                    {
                        "guardrail_id": litellm_guardrail_id,
                        "guardrail_name": guardrail_name,
                        "action": "updated",
                    }
                )
                logger.info(f"Successfully updated guardrail: {guardrail_name}")
                continue

            # Create a new guardrail
            logger.info(f"Creating new guardrail in LiteLLM: {guardrail_name}")
            litellm_response = litellm_client.create_guardrail(litellm_guardrail_config)

            if "guardrail_id" not in litellm_response:
                logger.error(f"Unexpected LiteLLM guardrail response structure: {litellm_response}")
                raise KeyError(f"Could not find guardrail ID in LiteLLM response: {litellm_response}")
            litellm_guardrail_id = litellm_response["guardrail_id"]

            # From here on the guardrail exists in LiteLLM and must be rolled back on failure.
            rollback_guardrail_ids.append(litellm_guardrail_id)

            guardrail_entry = GuardrailsTableEntry(
                guardrailId=litellm_guardrail_id,
                modelId=model_id,
                guardrailName=guardrail_config["guardrailName"],
                guardrailIdentifier=guardrail_config["guardrailIdentifier"],
                guardrailVersion=guardrail_config.get("guardrailVersion", "DRAFT"),
                mode=str(guardrail_config.get("mode", "pre_call")),
                description=guardrail_config.get("description"),
                allowedGroups=guardrail_config.get("allowedGroups", []),
            )

            logger.info(f"Storing new guardrail in DynamoDB: {litellm_guardrail_id}")
            guardrails_table.put_item(Item=guardrail_entry.model_dump())

            created_guardrails.append(
                {"guardrail_id": litellm_guardrail_id, "guardrail_name": guardrail_name, "action": "created"}
            )
            logger.info(f"Successfully created new guardrail: {guardrail_name}")

        # Delete guardrails that are no longer in the configuration
        for guardrail_name, existing_guardrail in existing_guardrails.items():
            if guardrail_name in processed_guardrail_names:
                continue

            logger.info(f"Deleting removed guardrail: {guardrail_name}")
            try:
                deleted_guardrails.append(_delete_guardrail(existing_guardrail, model_id))
                logger.info(f"Successfully deleted guardrail: {guardrail_name}")
            except Exception as delete_error:
                # Continue with other operations even if one deletion fails
                logger.error(f"Error deleting guardrail {guardrail_name}: {str(delete_error)}")

    except Exception as e:
        logger.error(f"Error updating guardrails: {str(e)}")

        # Clean up any guardrails created during this invocation. The two stores are cleaned
        # independently so a failure in one does not leave the other behind.
        for guardrail_id in rollback_guardrail_ids:
            logger.info(f"Cleaning up created guardrail: {guardrail_id}")
            try:
                guardrails_table.delete_item(Key={"guardrailId": guardrail_id, "modelId": model_id})
            except Exception as cleanup_error:
                logger.error(f"Error during guardrail cleanup: {str(cleanup_error)}")
            try:
                litellm_client.delete_guardrail(guardrail_id)
            except Exception as cleanup_error:
                logger.error(f"Error during guardrail cleanup: {str(cleanup_error)}")

        # Re-raise the original exception with its traceback intact
        raise

    output_dict["guardrail_updates"] = updated_guardrails + created_guardrails + deleted_guardrails
    output_dict["guardrail_update_summary"] = {
        "updated": len(updated_guardrails),
        "created": len(created_guardrails),
        "deleted": len(deleted_guardrails),
    }

    logger.info(
        f"Successfully processed guardrail updates for model: {model_id}. "
        f"Updated: {len(updated_guardrails)}, Created: {len(created_guardrails)}, "
        f"Deleted: {len(deleted_guardrails)}"
    )
    return output_dict
Cluster */ + public readonly cluster: Cluster; + private readonly targetGroups: Partial> = {}; + private readonly config: Config; + private readonly ecsConfig: ECSConfig; + private readonly vpc: Vpc; + private readonly securityGroup: ISecurityGroup; + private readonly logGroup: LogGroup; + private readonly volumes: Volume[]; + private readonly mountPoints: MountPoint[]; + private readonly baseEnvironment: Record; + private readonly autoScalingGroup: AutoScalingGroup; + private readonly asgCapacityProvider: AsgCapacityProvider; + private readonly identifier: string; /** * Creates a task definition with its associated container and IAM role (base method). @@ -108,7 +129,6 @@ export class ECSCluster extends Construct { taskDefinitionName: string, config: Config, taskDefinition: TaskDefinition, - baseEnvironment: Record, ecsConfig: ECSConfig, volumes: Volume[], mountPoints: MountPoint[], @@ -124,24 +144,9 @@ export class ECSCluster extends Construct { ...(executionRole && { executionRole }), }); - // Grant CloudWatch logs permissions to both task role and execution role + // Grant CloudWatch logs write permissions to task role and execution role logGroup.grantWrite(taskRole); - if (executionRole) { - logGroup.grantWrite(executionRole); - } else { - // If no custom execution role, ensure the default execution role has CloudWatch permissions - // This is critical for log stream creation during container startup - ec2TaskDefinition.addToExecutionRolePolicy(new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - 'logs:DescribeLogStreams' - ], - resources: [logGroup.logGroupArn, `${logGroup.logGroupArn}:*`] - })); - } + logGroup.grantWrite(ec2TaskDefinition.obtainExecutionRole()); // Add container to task definition const containerHealthCheckConfig = taskDefinition.containerConfig.healthCheckConfig; @@ -165,7 +170,7 @@ export class ECSCluster extends Construct { const container = 
ec2TaskDefinition.addContainer(createCdkId([taskDefinitionName, 'Container']), { containerName: createCdkId([config.deploymentName, taskDefinitionName], 32, 2), image, - environment: {...baseEnvironment, ...taskDefinition.environment}, + environment: {...this.baseEnvironment, ...taskDefinition.environment}, logging: LogDriver.awsLogs({ logGroup: logGroup, streamPrefix: taskDefinitionName @@ -191,21 +196,8 @@ export class ECSCluster extends Construct { */ constructor (scope: Construct, id: string, props: ECSClusterProps) { super(scope, id); - const { config, identifier, vpc, securityGroup, ecsConfig } = props; - - // Retrieve execution role if it has been overridden - const executionRole = config.roles ? Role.fromRoleArn( - this, - createCdkId([identifier, 'ER']), - StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/roles/${identifier}EX`), - ) : undefined; - - // Create ECS task definition - const taskRole = Role.fromRoleArn( - this, - createCdkId([identifier, 'TR']), - StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/roles/${identifier}`), - ); + const { config, identifier, vpc, securityGroup, ecsConfig, environment } = props; + this.identifier = identifier; // Create ECS cluster const cluster = new Cluster(this, createCdkId([config.deploymentName, config.deploymentStage, 'Cl']), { @@ -294,7 +286,7 @@ export class ECSCluster extends Construct { const baseEnvironment: { [key: string]: string; - } = {}; + } = {...environment}; const volumes: Volume[] = []; const mountPoints: MountPoint[] = []; @@ -401,6 +393,11 @@ export class ECSCluster extends Construct { createCdkId([identifier, 'ApplicationListener']), listenerProps, ); + + // Expose load balancer, listener, and cluster for shared use + this.loadBalancer = loadBalancer; + this.listener = listener; + this.cluster = cluster; const protocol = listenerProps.port === 443 ? 
'https' : 'http'; const domain = @@ -413,79 +410,105 @@ export class ECSCluster extends Construct { .concat('*') .join(','); - Object.entries(ecsConfig.tasks).forEach(([name, definition]) => { - const taskResult = this.createTaskDefinition( - name, - config, - definition, - baseEnvironment, - ecsConfig, - volumes, - mountPoints, - logGroup, - taskRole, - executionRole - ); - const { taskDefinition, container } = taskResult; - - // Create ECS service for primary task - const serviceProps: Ec2ServiceProps = { - cluster: cluster, - serviceName: createCdkId([name], 32, 2), - taskDefinition: taskDefinition, - circuitBreaker: !config.region.includes('iso') ? { rollback: true } : undefined, - capacityProviderStrategies: [ - { capacityProvider: asgCapacityProvider.capacityProviderName, weight: 1 } - ] - }; + // Store configuration for later use by addTask method + this.config = config; + this.ecsConfig = ecsConfig; + this.vpc = vpc; + this.securityGroup = securityGroup; + this.logGroup = logGroup; + this.volumes = volumes; + this.mountPoints = mountPoints; + this.baseEnvironment = baseEnvironment; + this.autoScalingGroup = autoScalingGroup; + this.asgCapacityProvider = asgCapacityProvider; + } - const service = new Ec2Service(this, createCdkId([config.deploymentName, name, 'Ec2Svc']), serviceProps); - const scalableTaskCount = service.autoScaleTaskCount({ - minCapacity: 1, - // 10 is just a magic number we don't expect to hit because we don't have better data on this - maxCapacity: 10 - }); - service.node.addDependency(autoScalingGroup); - - // since our containers are using ephemeral ports, the load balancer must be allowed to access them - service.connections.allowFrom(loadBalancer, Port.allTcp()); - - // Create target groups for both services - const loadBalancerHealthCheckConfig = ecsConfig.loadBalancerConfig.healthCheckConfig; - - const targetGroup = listener.addTargets(createCdkId([identifier, name, 'TgtGrp']), { - targetGroupName: createCdkId([config.deploymentName, 
/**
 * Add a task to the ECS cluster with its own ECS service and load balancer target group.
 *
 * The task role and execution role are looked up from the SSM parameters
 * `${deploymentPrefix}/roles/${taskName}` and `${deploymentPrefix}/roles/${taskName}EX`.
 * The created container, task role, service and target group are recorded on this
 * construct under `taskName`.
 *
 * @param taskName - The name of the task (e.g., ECSTasks.REST, ECSTasks.MCPWORKBENCH)
 * @param taskDefinition - The task definition configuration. Environment variables within the
 *                         task definition are merged over the cluster environment variables.
 * @returns Object containing the created service and target group
 * @throws Error if an application target condition has an unsupported type
 */
public addTask (
    taskName: ECSTasks,
    taskDefinition: TaskDefinition,
): { service: Ec2Service; targetGroup?: ApplicationTargetGroup } {
    // Retrieve task role and execution role for the task
    const taskRole = Role.fromRoleArn(
        this,
        createCdkId([taskName, 'TR']),
        StringParameter.valueForStringParameter(this, `${this.config.deploymentPrefix}/roles/${taskName}`),
    );
    const executionRole = Role.fromRoleArn(
        this,
        createCdkId([taskName, 'ER']),
        StringParameter.valueForStringParameter(this, `${this.config.deploymentPrefix}/roles/${taskName}EX`),
    );

    const { taskDefinition: ec2TaskDefinition, container } = this.createTaskDefinition(
        taskName,
        this.config,
        taskDefinition,
        this.ecsConfig,
        this.volumes,
        this.mountPoints,
        this.logGroup,
        taskRole,
        executionRole
    );

    // Store references
    this.containers[taskName] = container;
    this.taskRoles[taskName] = taskRole;

    // Create ECS service
    const serviceProps: Ec2ServiceProps = {
        cluster: this.cluster,
        serviceName: createCdkId([taskName], 32, 2),
        taskDefinition: ec2TaskDefinition,
        // Deployment circuit breaker is not enabled in 'iso' regions
        circuitBreaker: !this.config.region.includes('iso') ? { rollback: true } : undefined,
        capacityProviderStrategies: [
            { capacityProvider: this.asgCapacityProvider.capacityProviderName, weight: 1 }
        ]
    };

    const service = new Ec2Service(this, createCdkId([this.config.deploymentName, taskName, 'Ec2Svc']), serviceProps);
    service.node.addDependency(this.autoScalingGroup);

    // Store service reference
    this.services[taskName] = service;

    // Containers use ephemeral host ports, so the load balancer must be allowed to reach all of them
    service.connections.allowFrom(this.loadBalancer, Port.allTcp());

    const loadBalancerHealthCheckConfig = this.ecsConfig.loadBalancerConfig.healthCheckConfig;
    const applicationTarget = taskDefinition.applicationTarget;

    const targetGroup = this.listener.addTargets(createCdkId([this.identifier, taskName, 'TgtGrp']), {
        targetGroupName: createCdkId([this.config.deploymentName, this.identifier, taskName], 32, 2).toLowerCase(),
        healthCheck: {
            path: loadBalancerHealthCheckConfig.path,
            interval: Duration.seconds(loadBalancerHealthCheckConfig.interval),
            timeout: Duration.seconds(loadBalancerHealthCheckConfig.timeout),
            healthyThresholdCount: loadBalancerHealthCheckConfig.healthyThresholdCount,
            unhealthyThresholdCount: loadBalancerHealthCheckConfig.unhealthyThresholdCount,
        },
        port: 80,
        targets: [service],
        // Routing rule (priority + conditions) is only attached when a priority is configured;
        // otherwise the target group is added without a rule.
        ...(applicationTarget?.priority && {
            priority: applicationTarget.priority,
            conditions: applicationTarget.conditions?.map(({ type, values }) => {
                switch (type) {
                    case 'pathPatterns':
                        return ListenerCondition.pathPatterns(values);
                    default:
                        // Fail at synth time instead of passing `undefined` to the listener rule
                        throw new Error(`Unsupported application target condition type: ${String(type)}`);
                }
            })
        })
    });

    // Store target group reference alongside the container, role and service
    this.targetGroups[taskName] = targetGroup;

    return { service, targetGroup };
}
the amount of memory to buffer (or subtract off) from the total instance memory, if we don't include this, // the container can have a hard time finding available RAM resources to start and the tasks will fail deployment @@ -44,6 +35,7 @@ const INSTANCE_MEMORY_RESERVATION = 1024; const SERVE_CONTAINER_MEMORY_RESERVATION = 1024 * 2; const WORKBENCH_CONTAINER_MEMORY_RESERVATION = 1024; + /** * Properties for FastApiContainer Construct. * @@ -56,17 +48,15 @@ type FastApiContainerProps = { securityGroup: ISecurityGroup; tokenTable: ITable | undefined; vpc: Vpc; + managementKeyName: string; } & BaseProps; /** * FastApiContainer Construct. */ export class FastApiContainer extends Construct { - /** Map of all container definitions by identifier */ - public readonly containers: ContainerDefinition[] = []; - - /** Map of all task roles by identifier */ - public readonly taskRoles: Partial> = {}; + /** ECS Cluster **/ + public readonly apiCluster: ECSCluster; /** FastAPI URL **/ public readonly endpoint: string; @@ -79,31 +69,57 @@ export class FastApiContainer extends Construct { constructor (scope: Construct, id: string, props: FastApiContainerProps) { super(scope, id); - const { config, securityGroup, tokenTable, vpc } = props; + const { config, securityGroup, tokenTable, vpc, managementKeyName} = props; + + const instanceType = 'm5.large'; + const buildArgs: Record | undefined = { BASE_IMAGE: config.baseImage, PYPI_INDEX_URL: config.pypiConfig.indexUrl, PYPI_TRUSTED_HOST: config.pypiConfig.trustedHost, LITELLM_CONFIG: yamlDump(config.litellmConfig), }; - const baseEnvironment: Record = { + + // Add build config overrides if provided + if (config.restApiConfig.buildConfig?.NODEENV_CACHE_DIR) { + buildArgs.NODEENV_CACHE_DIR = config.restApiConfig.buildConfig.NODEENV_CACHE_DIR; + } + + // Add MCP Workbench build config overrides if provided + if (config.mcpWorkbenchBuildConfig) { + Object.entries(config.mcpWorkbenchBuildConfig).forEach(([key, value]) => { + if (value) { + 
buildArgs[key] = value; + } + }); + } + + // Environment variables for all containers + const environment: Record = { LOG_LEVEL: config.logLevel, AWS_REGION: config.region, AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM - THREADS: Ec2Metadata.get('m5.large').vCpus.toString(), - LITELLM_KEY: config.litellmConfig.db_key, - OPENAI_API_KEY: config.litellmConfig.db_key, - TIKTOKEN_CACHE_DIR: '/app/TIKTOKEN_CACHE', + THREADS: Ec2Metadata.get(instanceType).vCpus.toString(), USE_AUTH: 'true', AUTHORITY: config.authConfig!.authority, CLIENT_ID: config.authConfig!.clientId, ADMIN_GROUP: config.authConfig!.adminGroup, USER_GROUP: config.authConfig!.userGroup, JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, + MANAGEMENT_KEY_NAME: managementKeyName }; if (tokenTable) { - baseEnvironment.TOKEN_TABLE_NAME = tokenTable.tableName; + environment.TOKEN_TABLE_NAME = tokenTable.tableName; + } + + // Requires mount point /etc/pki from host + if (config.region.includes('iso')) { + environment.SSL_CERT_DIR = '/etc/pki/tls/certs'; + environment.SSL_CERT_FILE = config.certificateAuthorityBundle; + environment.REQUESTS_CA_BUNDLE = config.certificateAuthorityBundle; + environment.AWS_CA_BUNDLE = config.certificateAuthorityBundle; + environment.CURL_CA_BUNDLE = config.certificateAuthorityBundle; } // Pre-generate the tiktoken cache to ensure it does not attempt to fetch data from the internet at runtime. 
@@ -125,7 +141,7 @@ export class FastApiContainer extends Construct { path: REST_API_PATH, type: EcsSourceType.ASSET }; - const instanceType = 'm5.large'; + const healthCheckConfig = { command: ['CMD-SHELL', 'exit 0'], interval: 10, @@ -149,46 +165,9 @@ export class FastApiContainer extends Construct { } }, buildArgs, - tasks: { - [ECSTasks.REST]: { - environment: baseEnvironment, - containerConfig: { - image: restApiImage, - healthCheckConfig, - environment: {}, - sharedMemorySize: 0 - }, - // set a softlimit of what we expect to use - containerMemoryReservationMiB: SERVE_CONTAINER_MEMORY_RESERVATION - }, - [ECSTasks.MCPWORKBENCH]: { - environment: {...baseEnvironment, - RCLONE_CONFIG_S3_REGION: config.region, - MCPWORKBENCH_BUCKET: [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(), - }, - containerConfig: { - image: { - baseImage: config.baseImage, - path: MCP_WORKBENCH_PATH, - type: EcsSourceType.ASSET - }, - healthCheckConfig, - environment: {}, - sharedMemorySize: 0, - privileged: true - }, - applicationTarget: { - port: 8000, - priority: 80, - conditions: [ - { type: 'pathPatterns', values: ['/v2/mcp/*'] } - ] - }, - containerMemoryReservationMiB: WORKBENCH_CONTAINER_MEMORY_RESERVATION, - } - }, + tasks: {}, // reserve at least enough memory for each task and a buffer for the instance to use - containerMemoryBuffer: Ec2Metadata.get(instanceType).memory - (INSTANCE_MEMORY_RESERVATION + SERVE_CONTAINER_MEMORY_RESERVATION + WORKBENCH_CONTAINER_MEMORY_RESERVATION), + containerMemoryBuffer: Ec2Metadata.get(instanceType).memory - (INSTANCE_MEMORY_RESERVATION + SERVE_CONTAINER_MEMORY_RESERVATION + (config.deployMcpWorkbench ? 
WORKBENCH_CONTAINER_MEMORY_RESERVATION : 0)), instanceType, internetFacing: config.restApiConfig.internetFacing, loadBalancerConfig: { @@ -209,99 +188,38 @@ export class FastApiContainer extends Construct { ecsConfig, config, securityGroup, - vpc - }); - - const workbenchService = apiCluster.services.MCPWORKBENCH; - - // Create Lambda function to handle S3 events and trigger MCP Workbench service redeployment - const s3EventHandlerRole = new Role(this, 'S3EventHandlerRole', { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - inlinePolicies: { - 'S3EventHandlerPolicy': new PolicyDocument({ - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents' - ], - resources: [`arn:${config.partition}:logs:*:*:*`] - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ecs:UpdateService', - 'ecs:DescribeServices', - 'ecs:DescribeClusters' - ], - resources: [ - `arn:${config.partition}:ecs:${config.region}:*:cluster/${workbenchService?.cluster?.clusterName}*`, - `arn:${config.partition}:ecs:${config.region}:*:service/${workbenchService?.cluster?.clusterName}*/${workbenchService?.serviceName}*` - ] - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ssm:GetParameter' - ], - resources: [ - `arn:${config.partition}:ssm:${config.region}:*:parameter${config.deploymentPrefix}/deploymentName` - ] - }) - ] - }) - } + vpc, + environment }); - const s3EventHandlerLambda = new lambda.Function(this, 'S3EventHandlerLambda', { - runtime: getDefaultRuntime(), - handler: 'mcp_workbench.s3_event_handler.handler', - code: lambda.Code.fromAsset(config.lambdaPath ?? 
LAMBDA_PATH), - timeout: Duration.minutes(2), - role: s3EventHandlerRole, + // Add the REST API task to the cluster (default target, no priority/conditions) + apiCluster.addTask(ECSTasks.REST, { environment: { - DEPLOYMENT_PREFIX: config.deploymentPrefix!, - API_NAME: props.apiName, - ECS_CLUSTER_NAME: workbenchService!.cluster?.clusterName, - MCPWORKBENCH_SERVICE_NAME: workbenchService!.serviceName - } - }); - - // Create EventBridge rule to trigger Lambda when S3 objects are created/deleted - const rescanMcpWorkbenchRule = new events.Rule(this, 'RescanMCPWorkbenchRule', { - eventPattern: { - source: ['aws.s3', 'debug'], - detailType: [ - 'Object Created', - 'Object Deleted' - ], - detail: { - bucket: { - name: [[config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase()] - } - } + LITELLM_KEY: config.litellmConfig.db_key, + OPENAI_API_KEY: config.litellmConfig.db_key, + TIKTOKEN_CACHE_DIR: '/app/TIKTOKEN_CACHE', }, + containerConfig: { + image: restApiImage, + healthCheckConfig, + environment: {}, + sharedMemorySize: 0 + }, + containerMemoryReservationMiB: SERVE_CONTAINER_MEMORY_RESERVATION, + applicationTarget: { + port: 8080 + } }); - rescanMcpWorkbenchRule.addTarget(new targets.LambdaFunction(s3EventHandlerLambda, { - retryAttempts: 2, - maxEventAge: Duration.minutes(5) - })); - if (tokenTable) { - Object.entries(apiCluster.taskRoles).forEach(([, role]) => { - tokenTable.grantReadData(role); - }); + // Grant token table access to REST API task role only + const restTaskRole = apiCluster.taskRoles[ECSTasks.REST]; + if (restTaskRole) { + tokenTable.grantReadData(restTaskRole); + } } - letIfDefined(apiCluster.taskRoles.MCPWORKBENCH, (taskRole) => { - const bucketName = [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(); - const workbenchBucket = Bucket.fromBucketName(scope, 'MCPWorkbenchBucket', bucketName); - workbenchBucket.grantRead(taskRole); - }); - + 
this.apiCluster = apiCluster; this.endpoint = apiCluster.endpointUrl; new StringParameter(scope, 'FastApiEndpoint', { @@ -309,10 +227,6 @@ export class FastApiContainer extends Construct { stringValue: this.endpoint }); - // Update - this.containers = Object.values(apiCluster.containers); - this.taskRoles = apiCluster.taskRoles; - // CFN output new CfnOutput(this, `${props.apiName}Url`, { value: apiCluster.endpointUrl, diff --git a/lib/api-base/utils.ts b/lib/api-base/utils.ts index 0d28305b5..3c5f40d59 100644 --- a/lib/api-base/utils.ts +++ b/lib/api-base/utils.ts @@ -85,10 +85,11 @@ export function registerAPIEndpoint ( authorizer?: IAuthorizer, role?: IRole, ): IFunction { - const functionId = `${ + // Validate the function id + const functionId = ( funcDef.id || - [cdk.Stack.of(scope).stackName, funcDef.resource, funcDef.name, funcDef.disambiguator].filter(Boolean).join('-') - }`; + [cdk.Stack.of(scope).stackName, funcDef.resource, funcDef.name, funcDef.disambiguator].filter(Boolean).join('-') + ).replace(/[^a-zA-Z0-9-_]/g, '-').slice(0, 63); const functionResource = getOrCreateResource(scope, api.root, funcDef.path.split('/')); let handler; diff --git a/lib/chat/api/configuration.ts b/lib/chat/api/configuration.ts index 7531bf708..26824f674 100644 --- a/lib/chat/api/configuration.ts +++ b/lib/chat/api/configuration.ts @@ -46,7 +46,7 @@ type ConfigurationApiProps = { rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; - mcpApi: McpApi; + mcpApi?: McpApi; } & BaseProps; /** @@ -88,11 +88,8 @@ export class ConfigurationApi extends Construct { removalPolicy: config.removalPolicy, }); - const mcpServersTable = dynamodb.Table.fromTableName(this, 'McpServersTable', mcpApi.mcpServersTableNameParameter.stringValue); const lambdaRole: IRole = createLambdaRole(this, config.deploymentName, 'ConfigurationApi', this.configTable.tableArn, config.roles?.LambdaConfigurationApiExecutionRole); - mcpServersTable.grantReadWriteData(lambdaRole); - // Populate the 
App Config table with default config const date = new Date(); new AwsCustomResource(this, 'lisa-init-ddb-config', { @@ -103,36 +100,42 @@ export class ConfigurationApi extends Construct { parameters: { TableName: this.configTable.tableName, Item: { - 'versionId': {'N': '0'}, - 'changedBy': {'S': 'System'}, - 'configScope': {'S': 'global'}, - 'changeReason': {'S': 'Initial deployment default config'}, - 'createdAt': {'S': Math.round(date.getTime() / 1000).toString()}, - 'configuration': {'M': { - 'enabledComponents': {'M': { - 'deleteSessionHistory': {'BOOL': 'True'}, - 'viewMetaData': {'BOOL': 'True'}, - 'editKwargs': {'BOOL': 'True'}, - 'editPromptTemplate': {'BOOL': 'True'}, - 'editChatHistoryBuffer': {'BOOL': 'True'}, - 'editNumOfRagDocument': {'BOOL': 'True'}, - 'uploadRagDocs': {'BOOL': 'True'}, - 'uploadContextDocs': {'BOOL': 'True'}, - 'documentSummarization': {'BOOL': 'True'}, - 'showRagLibrary': {'BOOL': 'True'}, - 'showMcpWorkbench': {'BOOL': 'False'}, - 'showPromptTemplateLibrary': {'BOOL': 'True'}, - 'mcpConnections': {'BOOL': 'True'}, - 'modelLibrary': {'BOOL': 'True'}, - 'encryptSession': {'BOOL': 'False'}, - }}, - 'systemBanner': {'M': { - 'isEnabled': {'BOOL': 'False'}, - 'text': {'S': ''}, - 'textColor': {'S': ''}, - 'backgroundColor': {'S': ''} - }} - }} + 'versionId': { 'N': '0' }, + 'changedBy': { 'S': 'System' }, + 'configScope': { 'S': 'global' }, + 'changeReason': { 'S': 'Initial deployment default config' }, + 'createdAt': { 'S': Math.round(date.getTime() / 1000).toString() }, + 'configuration': { + 'M': { + 'enabledComponents': { + 'M': { + 'deleteSessionHistory': { 'BOOL': 'True' }, + 'viewMetaData': { 'BOOL': 'True' }, + 'editKwargs': { 'BOOL': 'True' }, + 'editPromptTemplate': { 'BOOL': 'True' }, + 'editChatHistoryBuffer': { 'BOOL': 'True' }, + 'editNumOfRagDocument': { 'BOOL': 'True' }, + 'uploadRagDocs': { 'BOOL': 'True' }, + 'uploadContextDocs': { 'BOOL': 'True' }, + 'documentSummarization': { 'BOOL': 'True' }, + 'showRagLibrary': { 
'BOOL': 'True' }, + 'showMcpWorkbench': { 'BOOL': 'False' }, + 'showPromptTemplateLibrary': { 'BOOL': 'True' }, + 'mcpConnections': { 'BOOL': 'True' }, + 'modelLibrary': { 'BOOL': 'True' }, + 'encryptSession': { 'BOOL': 'False' }, + } + }, + 'systemBanner': { + 'M': { + 'isEnabled': { 'BOOL': 'False' }, + 'text': { 'S': '' }, + 'textColor': { 'S': '' }, + 'backgroundColor': { 'S': '' } + } + } + } + } }, }, }, @@ -146,13 +149,15 @@ export class ConfigurationApi extends Construct { const fastApiEndpoint = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/serve/endpoint`); - const environment = { + let environment = { CONFIG_TABLE_NAME: this.configTable.tableName, - FASTAPI_ENDPOINT: fastApiEndpoint, - // add MCP_SERVERS_TABLE_NAME so we can update it if the configuration changes - MCP_SERVERS_TABLE_NAME: mcpServersTable.tableName + FASTAPI_ENDPOINT: fastApiEndpoint }; + if (mcpApi) { + this.createMcpApiTable(mcpApi, lambdaRole, environment); + } + // Create API Lambda functions const apis: PythonLambdaFunction[] = [ { @@ -196,4 +201,12 @@ export class ConfigurationApi extends Construct { } }); } + + private createMcpApiTable (mcpApi: McpApi, lambdaRole: IRole, environment: Record) { + const mcpServersTable = dynamodb.Table.fromTableName(this, 'McpServersTable', mcpApi.mcpServersTableNameParameter.stringValue); + mcpServersTable.grantReadWriteData(lambdaRole); + + // add MCP_SERVERS_TABLE_NAME so we can update it if the configuration changes + environment.MCP_SERVERS_TABLE_NAME = mcpServersTable.tableName; + } } diff --git a/lib/chat/chatConstruct.ts b/lib/chat/chatConstruct.ts index 470f98f7e..c3b5355c3 100644 --- a/lib/chat/chatConstruct.ts +++ b/lib/chat/chatConstruct.ts @@ -51,14 +51,15 @@ export class LisaChatApplicationConstruct extends Construct { const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; - const mcpApi = new McpApi(scope, 'McpApi', { - authorizer, - config, - restApiId, - rootResourceId, - 
securityGroups, - vpc, - }); + const mcpApi = config.deployMcpWorkbench ? + new McpApi(scope, 'McpApi', { + authorizer, + config, + restApiId, + rootResourceId, + securityGroups, + vpc, + }) : undefined; // Create Configuration API first to get the configuration table const configurationApi = new ConfigurationApi(scope, 'ConfigurationApi', { @@ -68,7 +69,7 @@ export class LisaChatApplicationConstruct extends Construct { rootResourceId, securityGroups, vpc, - mcpApi + ...(config.deployMcpWorkbench ? { mcpApi } : {}) }); // Add REST API Lambdas to APIGW diff --git a/lib/core/apiBaseConstruct.ts b/lib/core/apiBaseConstruct.ts index 7cb27dcc4..8467309c7 100644 --- a/lib/core/apiBaseConstruct.ts +++ b/lib/core/apiBaseConstruct.ts @@ -23,7 +23,6 @@ import { BaseProps } from '../schema'; import { Vpc } from '../networking/vpc'; import { Role } from 'aws-cdk-lib/aws-iam'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; -import McpWorkbenchConstruct from '../serve/mcpWorkbenchConstruct'; export type LisaApiBaseProps = { vpc: Vpc; @@ -80,13 +79,6 @@ export class LisaApiBaseConstruct extends Construct { binaryMediaTypes: ['font/*', 'image/*'], }); - new McpWorkbenchConstruct(this, id + 'McpWorkbench', { - ...props, - authorizer: this.authorizer!, - restApiId: restApi.restApiId, - rootResourceId: restApi.restApiRootResourceId, - securityGroups: [props.vpc.securityGroups.ecsModelAlbSg], - }); this.restApi = restApi; this.restApiId = restApi.restApiId; diff --git a/lib/core/apiDeploymentConstruct.ts b/lib/core/apiDeploymentConstruct.ts index 288bd486a..a02f9b11b 100644 --- a/lib/core/apiDeploymentConstruct.ts +++ b/lib/core/apiDeploymentConstruct.ts @@ -49,7 +49,7 @@ export class LisaApiDeploymentConstruct extends Construct { const api_url = `https://${restApiId}.execute-api.${Aws.REGION}.${Aws.URL_SUFFIX}/${config.deploymentStage}`; new StringParameter(scope, 'LisaApiDeploymentStringParameter', { - parameterName: 
`${config.deploymentPrefix}/${config.deploymentName}/${config.appName}/LisaApiUrl`, + parameterName: `${config.deploymentPrefix}/LisaApiUrl`, stringValue: api_url, description: 'API Gateway URL for LISA', }); diff --git a/lib/core/iam/ecs.json b/lib/core/iam/ecs.json index 38b59a068..17710ad00 100644 --- a/lib/core/iam/ecs.json +++ b/lib/core/iam/ecs.json @@ -92,17 +92,6 @@ ], "Resource": "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:*" }, - { - "Effect": "Allow", - "Action": [ - "logs:DescribeLogGroups", - "logs:CreateLogStream", - "logs:FilterLogEvents", - "logs:PutLogEvents", - "logs:DescribeLogStreams" - ], - "Resource": "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:*" - }, { "Effect": "Allow", "Action": [ @@ -196,6 +185,28 @@ "Effect": "Allow", "Action": "secretsmanager:GetSecretValue", "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:${AWS::Partition}:s3:::*-mcpworkbench-*", + "arn:${AWS::Partition}:s3:::*-mcpworkbench-*/*" + ] + }, + { + "Effect": "Allow", + "Action": [ + "dynamodb:GetItem", + "dynamodb:Query", + "dynamodb:Scan" + ], + "Resource": "arn:${AWS::Partition}:dynamodb:${AWS::Region}:${AWS::AccountId}:table/*-LISAApiTokenTable" } ] } diff --git a/lib/core/iam/roles.ts b/lib/core/iam/roles.ts index 8118a4424..34788c7d9 100644 --- a/lib/core/iam/roles.ts +++ b/lib/core/iam/roles.ts @@ -28,7 +28,9 @@ export enum Roles { ECS_MODEL_DEPLOYER_ROLE = 'ECSModelDeployerRole', ECS_MODEL_TASK_ROLE = 'ECSModelTaskRole', ECS_REST_API_EX_ROLE = 'ECSRestApiExRole', + ECS_MCPWORKBENCH_API_EX_ROLE = 'ECSMcpWorkbenchApiExRole', ECS_REST_API_ROLE = 'ECSRestApiRole', + ECS_MCPWORKBENCH_API_ROLE = 'ECSMcpWorkbenchApiRole', LAMBDA_CONFIGURATION_API_EXECUTION_ROLE = 'LambdaConfigurationApiExecutionRole', LAMBDA_EXECUTION_ROLE = 'LambdaExecutionRole', MODEL_API_ROLE = 'ModelApiRole', @@ -53,7 +55,9 @@ export const RoleNames: Record 
= { [Roles.ECS_MODEL_DEPLOYER_ROLE]: 'ECSModelDeployerRole', [Roles.ECS_MODEL_TASK_ROLE]: 'ECSModelTaskRole', [Roles.ECS_REST_API_EX_ROLE]: 'ECSRestApiExRole', + [Roles.ECS_MCPWORKBENCH_API_EX_ROLE]: 'ECSMcpWorkbenchApiExRole', [Roles.ECS_REST_API_ROLE]: 'ECSRestApiRole', + [Roles.ECS_MCPWORKBENCH_API_ROLE]: 'ECSMcpWorkbenchApiRole', [Roles.LAMBDA_CONFIGURATION_API_EXECUTION_ROLE]: 'LambdaConfigurationApiExecutionRole', [Roles.LAMBDA_EXECUTION_ROLE]: 'LambdaExecutionRole', [Roles.MODEL_API_ROLE]: 'ModelApiRole', diff --git a/lib/docs/.vitepress/config.mts b/lib/docs/.vitepress/config.mts index 59be67204..b51a4b1b9 100644 --- a/lib/docs/.vitepress/config.mts +++ b/lib/docs/.vitepress/config.mts @@ -55,6 +55,7 @@ const navLinks = [ { text: 'Model Compatibility', link: '/config/model-compatibility' }, { text: 'Model Management API', link: '/config/model-management-api' }, { text: 'Model Management UI', link: '/config/model-management-ui' }, + { text: 'Guardrails', link: '/config/guardrails' }, { text: 'Usage & Features', link: '/config/usage' }, { text: 'RAG Vector Stores', link: '/config/vector-stores' }, { text: 'Langfuse Tracing', link: '/config/langfuse-tracing'}, diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index cdd605d82..888d15fa0 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -283,7 +283,7 @@ Update your `config-custom.yaml` to point to ADC-accessible repositories: ```yaml # Configure pip to use ADC-accessible PyPI mirror -pipConfig: +pypiConfig: indexUrl: https://your-adc-pypi-mirror.com/simple trustedHost: your-adc-pypi-mirror.com @@ -293,9 +293,45 @@ npmConfig: # Use ADC-accessible base images for LISA-Serve and Batch Ingestion baseImage: /python:3.11 + +# Configure offline build dependencies for REST API (nodeenv for prisma-client-py) +restApiConfig: + buildConfig: + NODEENV_CACHE_DIR: "./nodeenv-cache" # Path relative to lib/serve/rest-api/ + +# Configure offline build dependencies for MCP Workbench (S6 
Overlay and rclone) +mcpWorkbenchBuildConfig: + S6_OVERLAY_NOARCH_SOURCE: "./s6-overlay-noarch.tar.xz" # Path relative to lib/serve/mcp-workbench/ + S6_OVERLAY_ARCH_SOURCE: "./s6-overlay-x86_64.tar.xz" # Path relative to lib/serve/mcp-workbench/ + RCLONE_SOURCE: "./rclone-linux-amd64.zip" # Path relative to lib/serve/mcp-workbench/ ``` You'll also want any model hosting base containers available, e.g. vllm/vllm-openai:latest and ghcr.io/huggingface/text-embeddings-inference:latest +#### Preparing Offline Build Dependencies + +For environments without internet access during Docker builds, you can pre-cache required dependencies: + +**REST API nodeenv cache** (required by prisma-client-py): +```bash +# Create the cache directory in the REST API build context +python -m nodeenv lib/serve/rest-api/nodeenv-cache +``` + +**MCP Workbench dependencies** (S6 Overlay and rclone): +```bash +# Download S6 Overlay files +cd lib/serve/mcp-workbench/ +wget https://github.com/just-containers/s6-overlay/releases/download/v3.1.6.2/s6-overlay-noarch.tar.xz +wget https://github.com/just-containers/s6-overlay/releases/download/v3.1.6.2/s6-overlay-x86_64.tar.xz + +# Download rclone (saved under the file name referenced by RCLONE_SOURCE above) +wget -O rclone-linux-amd64.zip https://github.com/rclone/rclone/releases/download/v1.71.0/rclone-v1.71.0-linux-amd64.zip + +cd ../../.. +``` + +These cached dependencies will be used during the Docker build process instead of downloading from the internet. + To utilize the prebuilt hosting model containers with self-hosted models, select `type: ecr` in the Model Deployment > Container Configs. ### Deployment Steps diff --git a/lib/docs/config/guardrails.md b/lib/docs/config/guardrails.md new file mode 100644 index 000000000..74246bb06 --- /dev/null +++ b/lib/docs/config/guardrails.md @@ -0,0 +1,326 @@ +# Guardrails + +## Overview + +Guardrails in LISA provide a powerful way to ensure safe and compliant model outputs through integration with AWS Bedrock Guardrails via LiteLLM.
Guardrails can validate, filter, and control both input prompts and model responses, helping you enforce content policies, protect sensitive information, and maintain compliance standards. + +## Key Features + +- **AWS Bedrock Integration**: Leverages AWS Bedrock Guardrails for robust content filtering +- **Flexible Application Modes**: Apply guardrails `pre_call`, `during_call`, or `post_call` +- **Group-Based Access Control**: Target specific user groups with different guardrail policies +- **Per-Model Configuration**: Each model can have multiple guardrails with different settings +- **Automatic Application**: Guardrails are automatically applied based on user group membership + +## Architecture + +### Data Storage + +Guardrails are stored in a dedicated DynamoDB table with the following structure: + +- **Partition Key**: `guardrailId` (LiteLLM-generated unique identifier) +- **Sort Key**: `modelId` (the model this guardrail is associated with) +- **Global Secondary Index**: `ModelIdIndex` (allows querying all guardrails for a specific model) + +### Guardrail Application Flow + +1. **User Request**: User sends a request to invoke a model +2. **Group Extraction**: System extracts user's group memberships from JWT token +3. **Guardrail Lookup**: System queries DynamoDB for guardrails associated with the model +4. **Group Matching**: System determines which guardrails apply based on: + - Public guardrails (no `allowedGroups` specified) apply to all users + - Private guardrails apply only if user belongs to at least one of the `allowedGroups` +5. **Guardrail Injection**: Applicable guardrails are added to the LiteLLM request +6. **Validation**: AWS Bedrock Guardrails validate the request/response +7.
**Response Handling**: If a guardrail is triggered, the admin configured guardrail response is returned + +## Prerequisites: Creating Guardrails in AWS Bedrock Console + +Before you can attach guardrails to models in LISA, you must first create the guardrails in the AWS Bedrock Console. + +### Steps to Create a Guardrail in AWS Bedrock Console + +1. **Navigate to AWS Bedrock Console** + - Open the AWS Console + - Navigate to Amazon Bedrock service + - Select "Guardrails" from the left navigation menu + +2. **Create a New Guardrail** + - Click "Create guardrail" + - Provide a name for your guardrail + - Define blocked messaging responses + - Configure guardrail policies: + - **Content filters**: Filter harmful content categories (hate, insults, sexual, violence, etc.) + - **Denied topics**: Define topics to block + - **Word filters**: Block or redact specific words/phrases + - **Sensitive information filters**: Redact PII (email, phone, SSN, etc.) + - **Contextual grounding**: Prevent hallucinations and ensure relevance + +3. **Configure Guardrail Settings** + - Configure version (`DRAFT` or numbered version) + +4. **Test Your Guardrail** + - Use the AWS Console's test functionality + - Verify the guardrail behaves as expected + +5. 
**Note the Guardrail Details** + - Copy the **Guardrail ID** (e.g., `abc123xyz`) + - Note the **Guardrail Version** (e.g., `DRAFT`, `1`, `2`) + - Alternatively, copy the full **Guardrail ARN** + +## Guardrail Configuration + +### Configuration Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `guardrailName` | string | Yes | Friendly name for the guardrail in LISA | +| `guardrailIdentifier` | string | Yes | AWS Bedrock Guardrail ARN or ID | +| `guardrailVersion` | string | No | Version to use (default: "DRAFT") | +| `mode` | string | No | When to apply: "pre_call", "during_call", or "post_call" (default: "pre_call") | +| `description` | string | No | Human-readable description of the guardrail's purpose | +| `allowedGroups` | array | No | List of user groups this guardrail applies to (empty = public) | + +### Guardrail Modes + +- `pre_call`: Validates input prompts before sending to the model +- `during_call`: Validates during model processing (streaming scenarios) +- `post_call`: Validates model responses after generation + +## Managing Guardrails via LISA Models API + +### Creating a Model with Guardrails + +Guardrails are attached to models as part of the model creation process. 
Include the `guardrailsConfig` field in your model creation request to apply a guardrail to a model: + +```bash +POST /{deploymentStage}/models + +{ + "modelId": "my-model-id", + "modelName": "MyModel", + "streaming": true, + "features": [], + "modelType": "textgen", + "guardrailsConfig": { + "guardrail-1": { + "guardrailName": "ContentFilter", + "guardrailIdentifier": "abc123xyz", + "guardrailVersion": "1", + "mode": "pre_call", + "description": "Filters harmful content from user inputs", + "allowedGroups": ["team-a", "team-b"] + }, + "guardrail-2": { + "guardrailName": "PIIProtection", + "guardrailIdentifier": "arn:aws:bedrock:us-east-1:123456789012:guardrail/xyz789", + "guardrailVersion": "DRAFT", + "mode": "post_call", + "description": "Redacts PII from model responses", + "allowedGroups": [] + } + } +} +``` + +**Notes:** +- The guardrail key (e.g., "guardrail-1") is an internal identifier +- `guardrailIdentifier` must match an existing AWS Bedrock Guardrail +- Empty `allowedGroups` means the guardrail applies to all users +- Multiple guardrails can be configured per model +- Guardrails are automatically registered in LiteLLM during model creation + +### Listing Guardrails + +#### List All Guardrails via LiteLLM API + +To view all guardrails registered in LiteLLM: + +```bash +GET /v2/serve/v2/guardrails/list +``` + +This endpoint queries LiteLLM directly and returns all registered guardrails across all models. + +### Updating Guardrails + +Update guardrails by sending a PUT request to update the model: + +```bash +PUT /{deploymentStage}/models/{modelId} + +{ + "guardrailsConfig": { + "guardrail-1": { + "guardrailName": "ContentFilter", + "guardrailIdentifier": "abc123xyz", + "guardrailVersion": "2", + "mode": "pre_call", + "description": "Updated content filter with stricter rules", + "allowedGroups": ["team-a", "team-b", "team-c"] + } + } +} +``` + +**Update Operations:** + +1. **Modify Existing Guardrail**: Include the guardrail with updated values +2. 
**Add New Guardrail**: Include a new guardrail key with complete configuration +3. **Remove Guardrail**: Set `markedForDeletion: true` on the guardrail + +### Deleting Guardrails + +Guardrails are deleted when their associated model is deleted or when they are marked for deletion: + +```bash +PUT /{deploymentStage}/models/{modelId} + +{ + "guardrailsConfig": { + "guardrail-1": { + "guardrailName": "ContentFilter", + "guardrailIdentifier": "abc123xyz", + "markedForDeletion": true + } + } +} +``` + +This operation: +1. Removes the guardrail from LiteLLM +2. Deletes all associated guardrails from DynamoDB +3. Removes guardrail configurations from LiteLLM + +**Note**: Deleting a model does NOT delete the underlying AWS Bedrock Guardrail. It only removes the association between the guardrail and the LISA model and removes the guardrail from LiteLLM. + +## Managing Guardrails via UI + +### Creating Guardrails During Model Creation + +1. **Create the Guardrail in AWS Bedrock Console first** (see Prerequisites section) +2. Navigate to **Model Management** and select **Create Model** +3. Fill in the base model configuration +4. Navigate to **Guardrails Configuration** +5. Click **Add Guardrail** +6. Configure the guardrail: + - **Guardrail Name**: Enter a friendly name + - **Guardrail Identifier**: Enter the AWS Bedrock Guardrail ARN or ID (from AWS Console) + - **Guardrail Version**: Specify version (default: DRAFT) + - **Mode**: Select when to apply (Pre Call, During Call, or Post Call) + - **Description** (optional): Describe the guardrail's purpose + - **Allowed Groups** (optional): Add group names that should have this guardrail +7. Click **Add** to add groups, or press Enter after typing a group name +8. Repeat steps 5-7 to add multiple guardrails +9. Finish remaining configuration steps +10. Click **Create Model** to finalize + +### Viewing Guardrails + +1. Navigate to **Model Management** +2. Select a model card +3. Select **Actions** and then **Update** +4.
Guardrails will be displayed in the model details under **Guardrails Configuration** + +### Updating / Removing Guardrails + +1. Navigate to **Model Management** +2. Select the model card of the model you want to update +3. Select **Actions** and then **Update** +4. Navigate to **Guardrails Configuration** +5. Modify existing guardrails +6. Add new guardrails using the **Add Guardrail** button +7. Remove guardrails by clicking the **X** button +8. Navigate to the final page and click **Update Model** to save changes + +**Note**: When editing a model, clicking the X button marks guardrails for deletion rather than removing them immediately. They will be deleted when you save the model. + +## Best Practices + +### 1. Design Guardrails in AWS Bedrock First + +- Test guardrails thoroughly in AWS Bedrock Console before attaching to models +- Create separate guardrails for different purposes (content filtering, PII, compliance) +- Use descriptive names to identify guardrail purposes + +### 2. Use Group-Based Access Appropriately + +- Start with public guardrails for baseline protection +- Add team-specific guardrails for specialized requirements +- Document which groups require which guardrails +- Regularly audit group memberships + +### 3. Monitor and Iterate + +- Adjust guardrail sensitivity based on false positives/negatives +- Update guardrails in AWS Bedrock as policies and requirements evolve +- Update model configurations in LISA to use new guardrail versions + +## Troubleshooting + +### Guardrails Not Being Applied + +**Symptom**: Requests are not being filtered as expected + +**Possible Causes**: +1. Guardrail doesn't exist in AWS Bedrock +2. Guardrail identifier is incorrect +3. Guardrail is not attached to model +4. Guardrail version attached to model is incorrect +5. User is not a member of the required groups +6. AWS Bedrock Guardrail is not accessible from LISA VPC + +**Resolution**: +1. Verify guardrail exists in AWS Bedrock Console +2.
Check guardrail identifier configured in LISA matches AWS Bedrock Console +3. Verify user group memberships +4. Check guardrail configuration in model details +5. Check CloudWatch logs for errors +6. Check REST ECS Container for errors + +### Guardrail Updates Not Taking Effect + +**Symptom**: Updated guardrail configuration not being applied + +**Possible Causes**: +1. Model update did not complete successfully +2. Guardrail changes made in AWS Bedrock but version not updated in LISA +3. Cache issues with model configuration + +**Resolution**: +1. Check model status (should be "In Service") +2. Verify `guardrailVersion` in LISA matches the version in AWS Bedrock +3. Check state machine execution logs +4. Verify guardrail configuration + +### Invalid Guardrail Identifier Error + +**Symptom**: Error during model creation or update mentioning invalid guardrail + +**Possible Causes**: +1. Guardrail doesn't exist in AWS Bedrock +2. Incorrect guardrail ID or ARN +3. Guardrail in different AWS region + +**Resolution**: +1. Verify guardrail exists in AWS Bedrock Console in the correct region +2. Copy guardrail identifier directly from AWS Console + +### High Latency with Guardrails + +**Symptom**: Requests take significantly longer with guardrails enabled + +**Possible Causes**: +1. Too many guardrails configured +2. Complex guardrail rules in AWS Bedrock + +**Resolution**: +1. Reduce number of guardrails where possible +2. Optimize guardrail rules in AWS Bedrock Console +3. 
Consider using only critical guardrails for performance-sensitive applications + +## Additional Resources + +- [AWS Bedrock Guardrails Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html) +- [LiteLLM Guardrails API](https://litellm-api.up.railway.app/#/Guardrails) +- LISA Model Management API Documentation diff --git a/lib/docs/config/role-overrides.md b/lib/docs/config/role-overrides.md index 6eafa8c6f..e3293542e 100644 --- a/lib/docs/config/role-overrides.md +++ b/lib/docs/config/role-overrides.md @@ -123,6 +123,28 @@ The example provided is an export from a deployed LISA instance based on Least P } } }, + "ECSMcpWorkbenchApiExRole": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "ecr:GetAuthorizationToken", + "ecr:BatchCheckLayerAvailability", + "ecr:GetDownloadUrlForLayer", + "ecr:BatchGetImage", + "logs:CreateLogStream", + "logs:PutLogEvents" + ], + "Resource": "*" + } + ], + "Version": "2012-10-17" + } + } + }, "ECSRestApiExRoleDefaultPolicy": { "Type": "AWS::IAM::Policy", "Properties": { @@ -736,6 +758,30 @@ The example provided is an export from a deployed LISA instance based on Least P "RoleName": "app-REST-Role" } }, + "ECSMcpWorkbenchApiRole": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": { + "Service": "ecs-tasks.amazonaws.com" + } + } + ], + "Version": "2012-10-17" + }, + "Description": "Allow MCP Workbench API task access to AWS resources", + "ManagedPolicyArns": [ + { + "Ref": "appECSPolicy361D8A62" + } + ], + "RoleName": "app-REST-Role" + } + }, "DocsDeployerRole": { "Type": "AWS::IAM::Role", "Properties": { diff --git a/lib/docs/index.ts b/lib/docs/index.ts index ae9dabd50..09822c3c3 100644 --- a/lib/docs/index.ts +++ b/lib/docs/index.ts @@ -17,7 +17,7 @@ import { Construct } from 'constructs'; import { 
LisaDocsConstruct, LisaDocsProps } from './docConstruct'; import { Stack } from 'aws-cdk-lib'; -export * from './docConstruct'; +export { LisaDocsConstruct, LisaDocsProps }; /** * Lisa Docs Stack diff --git a/lib/iam/iamConstruct.ts b/lib/iam/iamConstruct.ts index 82a7707c7..00eac6060 100644 --- a/lib/iam/iamConstruct.ts +++ b/lib/iam/iamConstruct.ts @@ -78,6 +78,10 @@ export class LisaServeIAMSConstruct extends Construct { id: 'REST', type: ECSTaskType.API, }, + { + id: 'MCPWORKBENCH', + type: ECSTaskType.API, + }, ]; ecsRoles.forEach((role) => { @@ -95,17 +99,18 @@ export class LisaServeIAMSConstruct extends Construct { description: `Role ARN for LISA ${role.type} ${role.id} ECS Task`, }); - if (config.roles) { - const executionRoleOverride = getRoleId(`ECS_${role.id}_${role.type}_EX_ROLE`.toUpperCase()); + const executionRoleId = createCdkId([role.id, 'ExRole']); + const executionRoleName = createCdkId([config.deploymentName, role.id, 'ExRole']); + const executionRole = config.roles ? 
( // @ts-expect-error - dynamic key lookup of object - const executionRole = Role.fromRoleName(scope, createCdkId([role.id, 'ExRole']), config.roles[executionRoleOverride]); - - new StringParameter(scope, createCdkId([config.deploymentName, role.id, 'EX', 'SP']), { - parameterName: `${config.deploymentPrefix}/roles/${role.id}EX`, - stringValue: executionRole.roleArn, - description: `Role ARN for LISA ${role.type} ${role.id} ECS Execution`, - }); - } + Role.fromRoleName(scope, executionRoleId, config.roles[getRoleId(`ECS_${role.id}_${role.type}_EX_ROLE`.toUpperCase())]) + ) : this.createEcsExecutionRole(role, executionRoleId, executionRoleName); + + new StringParameter(scope, createCdkId([config.deploymentName, role.id, 'EX', 'SP']), { + parameterName: `${config.deploymentPrefix}/roles/${role.id}EX`, + stringValue: executionRole.roleArn, + description: `Role ARN for LISA ${role.type} ${role.id} ECS Execution`, + }); }); } @@ -175,4 +180,15 @@ export class LisaServeIAMSConstruct extends Construct { managedPolicies: [taskPolicy], }); } + + private createEcsExecutionRole (role: ECSRole, roleId: string, roleName: string): IRole { + return new Role(this.scope, roleId, { + assumedBy: new ServicePrincipal('ecs-tasks.amazonaws.com'), + roleName, + description: `Allow ${role.id} ${role.type} execution role to pull images and write logs`, + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AmazonECSTaskExecutionRolePolicy') + ], + }); + } } diff --git a/lib/models/guardrails-table.ts b/lib/models/guardrails-table.ts new file mode 100644 index 000000000..d5f5a3793 --- /dev/null +++ b/lib/models/guardrails-table.ts @@ -0,0 +1,66 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { AttributeType, BillingMode, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; +import { Construct } from 'constructs'; + +/** + * Properties for GuardrailsTable Construct. + */ +export type GuardrailsTableProps = { + deploymentPrefix: string; + removalPolicy: any; +}; + +/** + * DynamoDB table for storing Bedrock Guardrails configurations per model + */ +export class GuardrailsTable extends Construct { + public readonly table: Table; + + constructor (scope: Construct, id: string, props: GuardrailsTableProps) { + super(scope, id); + + const { removalPolicy } = props; + + // Create the guardrails table with composite key structure + this.table = new Table(this, 'GuardrailsTable', { + partitionKey: { + name: 'guardrailId', + type: AttributeType.STRING + }, + sortKey: { + name: 'modelId', + type: AttributeType.STRING + }, + billingMode: BillingMode.PAY_PER_REQUEST, + encryption: TableEncryption.AWS_MANAGED, + removalPolicy: removalPolicy, + }); + + this.table.addGlobalSecondaryIndex({ + indexName: 'ModelIdIndex', + partitionKey: { + name: 'modelId', + type: AttributeType.STRING + }, + sortKey: { + name: 'guardrailId', + type: AttributeType.STRING + }, + }); + } +} diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index 0ec602a99..1689e0445 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -42,7 +42,7 @@ import { Vpc } from '../networking/vpc'; import { ECSModelDeployer } from './ecs-model-deployer'; import { DockerImageBuilder } from './docker-image-builder'; import { DeleteModelStateMachine } from 
'./state-machine/delete-model'; -import { AttributeType, BillingMode, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; +import { AttributeType, BillingMode, ITable, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { CreateModelStateMachine } from './state-machine/create-model'; import { UpdateModelStateMachine } from './state-machine/update-model'; import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; @@ -60,6 +60,7 @@ import { LAMBDA_PATH } from '../util'; */ type ModelsApiProps = BaseProps & { authorizer?: IAuthorizer; + guardrailsTable: ITable; lisaServeEndpointUrlPs?: StringParameter; restApiId: string; rootResourceId: string; @@ -74,7 +75,7 @@ export class ModelsApi extends Construct { constructor (scope: Construct, id: string, props: ModelsApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, config, guardrailsTable, restApiId, rootResourceId, securityGroups, vpc } = props; const lisaServeEndpointUrlPs = props.lisaServeEndpointUrlPs ?? StringParameter.fromStringParameterName( scope, @@ -150,12 +151,13 @@ export class ModelsApi extends Construct { const stateMachinesLambdaRole = config.roles ? 
Role.fromRoleName(this, Roles.MODEL_SFN_LAMBDA_ROLE, config.roles.ModelsSfnLambdaRole) : - this.createStateMachineLambdaRole(modelTable.tableArn, dockerImageBuilder.dockerImageBuilderFn.functionArn, + this.createStateMachineLambdaRole(modelTable.tableArn, guardrailsTable.tableArn, dockerImageBuilder.dockerImageBuilderFn.functionArn, ecsModelDeployer.ecsModelDeployerFn.functionArn, lisaServeEndpointUrlPs.parameterArn, managementKeyName, config); const createModelStateMachine = new CreateModelStateMachine(this, 'CreateModelWorkflow', { config: config, modelTable: modelTable, + guardrailsTable: guardrailsTable, lambdaLayers: lambdaLayers, role: stateMachinesLambdaRole, vpc: vpc, @@ -171,6 +173,7 @@ export class ModelsApi extends Construct { const deleteModelStateMachine = new DeleteModelStateMachine(this, 'DeleteModelWorkflow', { config: config, modelTable: modelTable, + guardrailsTable: guardrailsTable, lambdaLayers: lambdaLayers, role: stateMachinesLambdaRole, vpc: vpc, @@ -183,6 +186,7 @@ export class ModelsApi extends Construct { const updateModelStateMachine = new UpdateModelStateMachine(this, 'UpdateModelWorkflow', { config: config, modelTable: modelTable, + guardrailsTable: guardrailsTable, lambdaLayers: lambdaLayers, role: stateMachinesLambdaRole, vpc: vpc, @@ -200,6 +204,7 @@ export class ModelsApi extends Construct { DELETE_SFN_ARN: deleteModelStateMachine.stateMachineArn, UPDATE_SFN_ARN: updateModelStateMachine.stateMachineArn, MODEL_TABLE_NAME: modelTable.tableName, + GUARDRAILS_TABLE_NAME: guardrailsTable.tableName, }; const lambdaRole: IRole = createLambdaRole(this, config.deploymentName, 'ModelApi', modelTable.tableArn, config.roles?.ModelApiRole); @@ -322,6 +327,21 @@ export class ModelsApi extends Construct { `${modelTable.tableArn}/*` ], }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'dynamodb:GetItem', + 'dynamodb:PutItem', + 'dynamodb:UpdateItem', + 'dynamodb:DeleteItem', + 'dynamodb:Query', + 'dynamodb:Scan', + ], + resources: 
[ + guardrailsTable.tableArn, + `${guardrailsTable.tableArn}/*` + ], + }), new PolicyStatement({ effect: Effect.ALLOW, actions: [ @@ -376,7 +396,7 @@ export class ModelsApi extends Construct { * @param managementKeyName - Name of the management key secret * @returns The created role */ - createStateMachineLambdaRole (modelTableArn: string, dockerImageBuilderFnArn: string, ecsModelDeployerFnArn: string, lisaServeEndpointUrlParamArn: string, managementKeyName: string, config: any): IRole { + createStateMachineLambdaRole (modelTableArn: string, guardrailTableArn: string ,dockerImageBuilderFnArn: string, ecsModelDeployerFnArn: string, lisaServeEndpointUrlParamArn: string, managementKeyName: string, config: any): IRole { return new Role(this, Roles.MODEL_SFN_LAMBDA_ROLE, { assumedBy: new ServicePrincipal('lambda.amazonaws.com'), managedPolicies: [ @@ -393,10 +413,13 @@ export class ModelsApi extends Construct { 'dynamodb:PutItem', 'dynamodb:UpdateItem', 'dynamodb:Scan', + 'dynamodb:Query' ], resources: [ modelTableArn, `${modelTableArn}/*`, + guardrailTableArn, + `${guardrailTableArn}/*`, ] }), new PolicyStatement({ @@ -498,6 +521,7 @@ export class ModelsApi extends Construct { actions: [ 'bedrock:InvokeModel', 'bedrock:InvokeModelWithResponseStream', + 'bedrock:ApplyGuardrail' ], resources: ['*'], // Bedrock model ARNs are dynamic and region-specific }), diff --git a/lib/models/modelsApiConstruct.ts b/lib/models/modelsApiConstruct.ts index 07ca79bbf..9d0448054 100644 --- a/lib/models/modelsApiConstruct.ts +++ b/lib/models/modelsApiConstruct.ts @@ -18,6 +18,7 @@ import { Stack, StackProps } from 'aws-cdk-lib'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { Construct } from 'constructs'; import { Vpc } from '../networking/vpc'; @@ -28,6 +29,7 @@ import { StringParameter } from 'aws-cdk-lib/aws-ssm'; export type LisaModelsApiProps = BaseProps & 
StackProps & { authorizer?: IAuthorizer; + guardrailsTable: ITable; lisaServeEndpointUrlPs?: StringParameter; restApiId: string; rootResourceId: string; @@ -47,12 +49,13 @@ export class LisaModelsApiConstruct extends Construct { constructor (scope: Stack, id: string, props: LisaModelsApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, config, guardrailsTable, restApiId, rootResourceId, securityGroups, vpc } = props; // Add REST API Lambdas to APIGW new ModelsApi(scope, 'ModelsApi', { authorizer, config, + guardrailsTable, restApiId, rootResourceId, securityGroups, diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index 57c7a3d4c..a71675586 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -40,6 +40,7 @@ import { LAMBDA_PATH } from '../../util'; type CreateModelStateMachineProps = BaseProps & { modelTable: ITable, + guardrailsTable: ITable, lambdaLayers: ILayerVersion[]; dockerImageBuilderFnArn: string; ecsModelDeployerFnArn: string; @@ -70,6 +71,7 @@ export class CreateModelStateMachine extends Construct { ECS_MODEL_DEPLOYER_FN_ARN: ecsModelDeployerFnArn, LISA_API_URL_PS_NAME: restApiContainerEndpointPs.parameterName, MODEL_TABLE_NAME: modelTable.tableName, + GUARDRAILS_TABLE_NAME: props.guardrailsTable.tableName, REST_API_VERSION: 'v2', MANAGEMENT_KEY_NAME: managementKeyName, RESTAPI_SSL_CERT_ARN: config.restApiConfig?.sslCertIamArn ?? 
'', @@ -211,12 +213,32 @@ export class CreateModelStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + const addGuardrailsToLitellm = new LambdaInvoke(this, 'AddGuardrailsToLitellm', { + lambdaFunction: new Function(this, 'AddGuardrailsToLitellmFunc', { + runtime: getDefaultRuntime(), + handler: 'models.state_machine.create_model.handle_add_guardrails_to_litellm', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + const successState = new Succeed(this, 'CreateSuccess'); const failState = new Fail(this, 'CreateFailed'); // Check if image is pre-existing ECR image const checkImageTypeChoice = new Choice(this, 'CheckImageTypeChoice'); + // Check if guardrails need to be added + const checkGuardrailsChoice = new Choice(this, 'CheckGuardrailsChoice'); + // State Machine definition setModelToCreating.next(createModelInfraChoice); createModelInfraChoice @@ -259,9 +281,18 @@ export class CreateModelStateMachine extends Construct { .otherwise(addModelToLitellm); waitBeforePollingCreateStack.next(pollCreateStack); + // Check for guardrails and add them if present + addModelToLitellm.next(checkGuardrailsChoice); + checkGuardrailsChoice + .when(Condition.isPresent('$.guardrailsConfig'), addGuardrailsToLitellm) + .otherwise(successState); + // terminal states handleFailureState.next(failState); - addModelToLitellm.next(successState); + addGuardrailsToLitellm.next(successState); + addGuardrailsToLitellm.addCatch(handleFailureState, { // fail if guardrail creation fails + errors: ['States.TaskFailed'], + }); const stateMachine = new StateMachine(this, 'CreateModelSM', { definitionBody: DefinitionBody.fromChainable(setModelToCreating), diff --git a/lib/models/state-machine/delete-model.ts b/lib/models/state-machine/delete-model.ts index 
03c4b9bb1..3f773b8f0 100644 --- a/lib/models/state-machine/delete-model.ts +++ b/lib/models/state-machine/delete-model.ts @@ -38,6 +38,7 @@ import { LAMBDA_PATH } from '../../util'; type DeleteModelStateMachineProps = BaseProps & { modelTable: ITable, + guardrailsTable: ITable, lambdaLayers: ILayerVersion[], vpc: Vpc, securityGroups: ISecurityGroup[]; @@ -57,10 +58,11 @@ export class DeleteModelStateMachine extends Construct { constructor (scope: Construct, id: string, props: DeleteModelStateMachineProps) { super(scope, id); - const { config, modelTable, lambdaLayers, role, vpc, securityGroups, restApiContainerEndpointPs, managementKeyName, executionRole } = props; + const { config, modelTable, guardrailsTable, lambdaLayers, role, vpc, securityGroups, restApiContainerEndpointPs, managementKeyName, executionRole } = props; const environment = { // Environment variables to set in all Lambda functions MODEL_TABLE_NAME: modelTable.tableName, + GUARDRAILS_TABLE_NAME: guardrailsTable.tableName, LISA_API_URL_PS_NAME: restApiContainerEndpointPs.parameterName, REST_API_VERSION: 'v2', MANAGEMENT_KEY_NAME: managementKeyName, @@ -154,6 +156,23 @@ export class DeleteModelStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + const deleteGuardrails = new LambdaInvoke(this, 'DeleteGuardrails', { + lambdaFunction: new Function(this, 'DeleteGuardrailsFunc', { + runtime: getDefaultRuntime(), + handler: 'models.state_machine.delete_model.handle_delete_guardrails', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + const successState = new Succeed(this, 'DeleteSuccess'); const deleteStackChoice = new Choice(this, 'DeleteStackChoice'); @@ -164,7 +183,8 @@ export class DeleteModelStateMachine extends Construct { // State Machine definition 
setModelToDeleting.next(deleteFromLitellm); - deleteFromLitellm.next(deleteStackChoice); + deleteFromLitellm.next(deleteGuardrails); + deleteGuardrails.next(deleteStackChoice); deleteStackChoice .when(Condition.isNotNull('$.cloudformation_stack_arn'), deleteStack) diff --git a/lib/models/state-machine/update-model.ts b/lib/models/state-machine/update-model.ts index 1c2b3c0dc..79c36eca6 100644 --- a/lib/models/state-machine/update-model.ts +++ b/lib/models/state-machine/update-model.ts @@ -39,6 +39,7 @@ import { LAMBDA_PATH } from '../../util'; type UpdateModelStateMachineProps = BaseProps & { modelTable: ITable, + guardrailsTable: ITable, lambdaLayers: ILayerVersion[], vpc: Vpc, securityGroups: ISecurityGroup[]; @@ -72,6 +73,7 @@ export class UpdateModelStateMachine extends Construct { const environment = { // Environment variables to set in all Lambda functions MODEL_TABLE_NAME: modelTable.tableName, + GUARDRAILS_TABLE_NAME: props.guardrailsTable.tableName, LISA_API_URL_PS_NAME: restApiContainerEndpointPs.parameterName, REST_API_VERSION: 'v2', MANAGEMENT_KEY_NAME: managementKeyName, @@ -147,6 +149,23 @@ export class UpdateModelStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + const handleUpdateGuardrails = new LambdaInvoke(this, 'HandleUpdateGuardrails', { + lambdaFunction: new Function(this, 'HandleUpdateGuardrailsFunc', { + runtime: getDefaultRuntime(), + handler: 'models.state_machine.update_model.handle_update_guardrails', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + const handleFinishUpdate = new LambdaInvoke(this, 'HandleFinishUpdate', { lambdaFunction: new Function(this, 'HandleFinishUpdateFunc', { runtime: getDefaultRuntime(), @@ -169,6 +188,7 @@ export class UpdateModelStateMachine extends Construct { // 
choice states const hasEcsUpdateChoice = new Choice(this, 'HasEcsUpdateChoice'); + const hasGuardrailsUpdateChoice = new Choice(this, 'HasGuardrailsUpdateChoice'); const hasCapacityUpdateChoice = new Choice(this, 'HasCapacityUpdateChoice'); const pollAsgChoice = new Choice(this, 'PollAsgChoice'); const pollEcsDeploymentChoice = new Choice(this, 'PollEcsDeploymentChoice'); @@ -190,15 +210,22 @@ export class UpdateModelStateMachine extends Construct { // ECS update flow hasEcsUpdateChoice .when(Condition.booleanEquals('$.needs_ecs_update', true), handleEcsUpdate) - .otherwise(hasCapacityUpdateChoice); + .otherwise(hasGuardrailsUpdateChoice); handleEcsUpdate.next(handlePollEcsDeployment); handlePollEcsDeployment.next(pollEcsDeploymentChoice); pollEcsDeploymentChoice .when(Condition.booleanEquals('$.should_continue_ecs_polling', true), waitBeforePollEcsDeployment) - .otherwise(hasCapacityUpdateChoice); + .otherwise(hasGuardrailsUpdateChoice); waitBeforePollEcsDeployment.next(handlePollEcsDeployment); + // Guardrails update flow + hasGuardrailsUpdateChoice + .when(Condition.booleanEquals('$.needs_guardrails_update', true), handleUpdateGuardrails) + .otherwise(hasCapacityUpdateChoice); + + handleUpdateGuardrails.next(hasCapacityUpdateChoice); + // Existing capacity update flow hasCapacityUpdateChoice .when(Condition.booleanEquals('$.has_capacity_update', true), handlePollCapacity) diff --git a/lib/networking/vpc/index.ts b/lib/networking/vpc/index.ts index 2341e6cf3..97e40b52d 100644 --- a/lib/networking/vpc/index.ts +++ b/lib/networking/vpc/index.ts @@ -67,6 +67,7 @@ export class Vpc extends Construct { // Imports VPC for use by application if supplied, else creates a VPC. vpc = ec2Vpc.fromLookup(this, 'imported-vpc', { vpcId: config.vpcId, + returnVpnGateways: false, }); // Checks if SubnetIds are provided in the config, if so we import them for use. 
diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index 29804e78e..7734d74f6 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -461,7 +461,8 @@ export const ContainerConfigSchema = z.object({ .describe('Environment variables for the container.'), sharedMemorySize: z.number().min(0).default(0).describe('The value for the size of the /dev/shm volume.'), healthCheckConfig: ContainerHealthCheckConfigSchema.default({}), - privileged: z.boolean().optional() + privileged: z.boolean().optional(), + memoryReservation: z.number().min(0).optional().describe('Memory reservation in MiB for the container.') }).describe('Configuration for the container.'); export type ContainerConfig = z.infer; @@ -709,6 +710,9 @@ const FastApiContainerConfigSchema = z.object({ domainName: z.string().nullish().default(null), sslCertIamArn: z.string().nullish().default(null).describe('ARN of the self-signed cert to be used throughout the system'), imageConfig: ImageAssetSchema.optional().describe('Override image configuration for ECS FastAPI Containers'), + buildConfig: z.object({ + NODEENV_CACHE_DIR: z.string().optional().describe('Override with a path relative to the build directory for a pre-cached nodeenv directory. Defaults to NODEENV_CACHE. 
For offline environments, populate using: python -m nodeenv PATH') + }).default({}), rdsConfig: RdsInstanceConfig .default({ dbName: 'postgres', @@ -780,6 +784,8 @@ const RoleConfig = z.object({ ECSRestApiRole: z.string().max(64), ECSRestApiExRole: z.string().max(64), LambdaExecutionRole: z.string().max(64), + ECSMcpWorkbenchApiRole: z.string().max(64), + ECSMcpWorkbenchApiExRole: z.string().max(64), LambdaConfigurationApiExecutionRole: z.string().max(64), ModelApiRole: z.string().max(64), ModelsSfnLambdaRole: z.string().max(64), @@ -813,6 +819,12 @@ export const RawConfigObject = z.object({ partition: z.string().default('aws').describe('AWS partition for deployment.'), domain: z.string().default('amazonaws.com').describe('AWS domain for deployment'), restApiConfig: FastApiContainerConfigSchema.describe('Image override for Rest API'), + mcpWorkbenchConfig: ImageAssetSchema.optional().describe('Image override for MCP Workbench'), + mcpWorkbenchBuildConfig: z.object({ + S6_OVERLAY_NOARCH_SOURCE: z.string().optional().describe('Override the URL with a path relative to the build directory for the architecture independent S6 overlay tar.xz.'), + S6_OVERLAY_ARCH_SOURCE: z.string().optional().describe('Override the URL with a path relative to the build directory for the architecture specific S6 overlay tar.xz.'), + RCLONE_SOURCE: z.string().optional().describe('Override the URL with a path relative to the build directory for an rclone .zip') + }).default({}), batchIngestionConfig: ImageAssetSchema.optional().describe('Image override for Batch Ingestion'), vpcId: z.string().optional().describe('VPC ID for the application. (e.g. 
vpc-0123456789abcdef)'), subnets: z.array(z.object({ @@ -843,6 +855,7 @@ export const RawConfigObject = z.object({ deployDocs: z.boolean().default(true).describe('Whether to deploy docs stacks.'), deployUi: z.boolean().default(true).describe('Whether to deploy UI stacks.'), deployMetrics: z.boolean().default(true).describe('Whether to deploy Metrics stack.'), + deployMcpWorkbench: z.boolean().default(true).describe('Whether to deploy MCP Workbench stack.'), logLevel: z.union([z.literal('DEBUG'), z.literal('INFO'), z.literal('WARNING'), z.literal('ERROR')]) .default('DEBUG') .describe('Log level for application.'), @@ -898,7 +911,7 @@ export const RawConfigObject = z.object({ .describe('Aspect CDK injector for permissions. Ref: https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_iam.PermissionsBoundary.html'), stackSynthesizer: z.nativeEnum(stackSynthesizerType).optional().describe('Set the stack synthesize type. Ref: https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.StackSynthesizer.html'), bootstrapQualifier: z.string().optional().describe('CDK bootstrap qualifier to use for stack synthesis. Defaults to CDK default if not specified.'), - bootstrapRolePrefix: z.string().optional().describe('Prefix for CDK bootstrap role names. Useful when roles have custom prefixes like LLNL_User_Roles_. Leave empty for standard role names.'), + bootstrapRolePrefix: z.string().optional().describe('Prefix for CDK bootstrap role names. Useful when roles have custom prefixes like My_User_Roles_. 
Leave empty for standard role names.'), litellmConfig: LiteLLMConfig, convertInlinePoliciesToManaged: z.boolean().optional().default(false).describe('Convert inline policies to managed policies'), iamRdsAuth: z.boolean().optional().default(false).describe('Enable IAM authentication for RDS'), diff --git a/lib/serve/ecs-model/vllm/src/entrypoint.sh b/lib/serve/ecs-model/vllm/src/entrypoint.sh index 4a5492aa8..472a4d6da 100644 --- a/lib/serve/ecs-model/vllm/src/entrypoint.sh +++ b/lib/serve/ecs-model/vllm/src/entrypoint.sh @@ -27,6 +27,23 @@ if [[ -n "${MAX_TOTAL_TOKENS}" ]]; then ADDITIONAL_ARGS+=" --max-model-len ${MAX_TOTAL_TOKENS}" fi +# Add vLLM specific arguments from environment variables +if [[ -n "${VLLM_TENSOR_PARALLEL_SIZE}" ]]; then + ADDITIONAL_ARGS+=" --tensor-parallel-size ${VLLM_TENSOR_PARALLEL_SIZE}" +fi + +if [[ -n "${VLLM_ASYNC_SCHEDULING}" ]] && [[ "${VLLM_ASYNC_SCHEDULING}" == "true" ]]; then + ADDITIONAL_ARGS+=" --async-scheduling" +fi + +if [[ -n "${VLLM_MAX_PARALLEL_LOADING_WORKERS}" ]]; then + ADDITIONAL_ARGS+=" --max-parallel-loading-workers ${VLLM_MAX_PARALLEL_LOADING_WORKERS}" +fi + +if [[ -n "${VLLM_USE_TQDM_ON_LOAD}" ]] && [[ "${VLLM_USE_TQDM_ON_LOAD}" == "true" ]]; then + ADDITIONAL_ARGS+=" --use-tqdm-on-load" +fi + # Start the webserver echo "Starting vLLM" python3 -m vllm.entrypoints.openai.api_server \ diff --git a/lib/serve/index.ts b/lib/serve/index.ts index 4124f5a9e..28edb3d10 100644 --- a/lib/serve/index.ts +++ b/lib/serve/index.ts @@ -20,8 +20,8 @@ import { Construct } from 'constructs'; import { FastApiContainer } from '../api-base/fastApiContainer'; import { LisaServeApplicationConstruct, LisaServeApplicationProps } from './serveApplicationConstruct'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; - export * from './serveApplicationConstruct'; +export * from './mcpWorkbenchConstruct'; /** * LisaServe Application stack. 
@@ -32,6 +32,8 @@ export class LisaServeApplicationStack extends Stack { public readonly modelsPs: StringParameter; public readonly endpointUrl: StringParameter; public readonly tokenTable?: ITable; + public readonly guardrailsTableNamePs: StringParameter; + public readonly guardrailsTable: ITable; /** * @param {Construct} scope - The parent or owner of the construct. @@ -47,5 +49,7 @@ export class LisaServeApplicationStack extends Stack { this.modelsPs = app.modelsPs; this.restApi = app.restApi; this.tokenTable = app.tokenTable; + this.guardrailsTableNamePs = app.guardrailsTableNamePs; + this.guardrailsTable = app.guardrailsTable; } } diff --git a/lib/serve/mcp-workbench/Dockerfile b/lib/serve/mcp-workbench/Dockerfile index cc72b2705..8879beb75 100644 --- a/lib/serve/mcp-workbench/Dockerfile +++ b/lib/serve/mcp-workbench/Dockerfile @@ -1,4 +1,5 @@ -FROM python:3.13.7-slim +ARG BASE_IMAGE=python:3.13.7-slim +FROM ${BASE_IMAGE} ARG RCLONE_VERSION=v1.71.0 ARG RCLONE_ARCH=amd64 diff --git a/lib/serve/mcpWorkbenchConstruct.ts b/lib/serve/mcpWorkbenchConstruct.ts index 5322e6247..3204f7e55 100644 --- a/lib/serve/mcpWorkbenchConstruct.ts +++ b/lib/serve/mcpWorkbenchConstruct.ts @@ -18,31 +18,36 @@ import { IAuthorizer, IRestApi, RestApi } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { Construct } from 'constructs'; import { Vpc } from '../networking/vpc'; -import { BaseProps, Config } from '../schema'; +import { BaseProps, Config, EcsSourceType } from '../schema'; import * as s3 from 'aws-cdk-lib/aws-s3'; -import { RemovalPolicy, StackProps } from 'aws-cdk-lib'; +import { Duration, RemovalPolicy, StackProps } from 'aws-cdk-lib'; import { createCdkId } from '../core/utils'; import * as ssm from 'aws-cdk-lib/aws-ssm'; import { getDefaultRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../api-base/utils'; import * as iam from 'aws-cdk-lib/aws-iam'; -import { LAMBDA_PATH } from '../util'; +import { LAMBDA_PATH, 
MCP_WORKBENCH_PATH } from '../util'; import * as lambda from 'aws-cdk-lib/aws-lambda'; +import * as events from 'aws-cdk-lib/aws-events'; +import * as targets from 'aws-cdk-lib/aws-events-targets'; +import { ECSCluster, ECSTasks } from '../api-base/ecsCluster'; +import { Ec2Service } from 'aws-cdk-lib/aws-ecs'; export type McpWorkbenchConstructProps = { - authorizer: IAuthorizer; restApiId: string; rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; + apiCluster: ECSCluster; + authorizer?: IAuthorizer; } & BaseProps & StackProps; -export default class McpWorkbenchConstruct extends Construct { +export class McpWorkbenchConstruct extends Construct { + public readonly workbenchBucket: s3.Bucket; + constructor (scope: Construct, id: string, props: McpWorkbenchConstructProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; - - const workbenchBucket = this.createWorkbenchBucket(scope, config); + const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc, apiCluster } = props; // Get common layer based on arn from SSM due to issues with cross stack references const commonLambdaLayer = lambda.LayerVersion.fromLayerVersionArn( @@ -62,12 +67,14 @@ export default class McpWorkbenchConstruct extends Construct { rootResourceId: rootResourceId, }); - const lambdaLayers = [commonLambdaLayer, fastapiLambdaLayer]; + const lambdaLayers = [commonLambdaLayer, fastapiLambdaLayer]; - this.createWorkbenchApi(restApi, rootResourceId, config, vpc, securityGroups, authorizer, workbenchBucket, lambdaLayers); + const workbenchBucket = this.createWorkbenchBucket(scope, config); + this.createWorkbenchApi(restApi, config, vpc, securityGroups, workbenchBucket, lambdaLayers, authorizer); + this.createWorkbenchService(apiCluster, config); } - private createWorkbenchApi (restApi: IRestApi, rootResourceId: string, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[], authorizer: IAuthorizer, 
workbenchBucket: s3.Bucket, lambdaLayers: lambda.ILayerVersion[]) { + private createWorkbenchApi (restApi: IRestApi, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[], workbenchBucket: s3.Bucket, lambdaLayers: lambda.ILayerVersion[], authorizer?: IAuthorizer) { const env = { ADMIN_GROUP: config.authConfig?.adminGroup || '', @@ -110,6 +117,13 @@ export default class McpWorkbenchConstruct extends Construct { method: 'DELETE', environment: env, path: 'mcp-workbench/{toolId}' + }, { + name: 'validate_syntax', + resource: 'mcp_workbench', + description: 'Validate Python code syntax', + method: 'POST', + environment: env, + path: 'mcp-workbench/validate-syntax' }]; // Create IAM role for Lambda @@ -148,7 +162,11 @@ export default class McpWorkbenchConstruct extends Construct { authorizer, lambdaRole, ); - if (f.method === 'POST' || f.method === 'PUT') { + + // Grant S3 permissions based on function type + if (['validate_syntax'].includes(f.name)) { + // No S3 permissions needed for syntax validation + } else if (f.method === 'POST' || f.method === 'PUT') { workbenchBucket.grantWrite(lambdaFunction); } else if (f.method === 'GET') { workbenchBucket.grantRead(lambdaFunction); @@ -173,4 +191,122 @@ export default class McpWorkbenchConstruct extends Construct { eventBridgeEnabled: true }); } + + private createWorkbenchService (apiCluster: ECSCluster, config: Config) { + + const mcpWorkbenchImage = config.mcpWorkbenchConfig || { + baseImage: config.baseImage, + path: MCP_WORKBENCH_PATH, + type: EcsSourceType.ASSET + }; + + const mcpWorkbenchTaskDefinition = { + environment: { + RCLONE_CONFIG_S3_REGION: config.region, + MCPWORKBENCH_BUCKET: [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(), + }, + containerConfig: { + image: mcpWorkbenchImage, + healthCheckConfig: { + command: ['CMD-SHELL', 'exit 0'], + interval: 10, + startPeriod: 30, + timeout: 5, + retries: 3 + }, + environment: {}, + sharedMemorySize: 0, 
+ privileged: true + }, + containerMemoryReservationMiB: 1024, + applicationTarget: { + port: 8000, + priority: 80, + conditions: [{ + type: 'pathPatterns' as const, + values: ['/v2/mcp/*'] + }] + } + }; + + const { service } = apiCluster.addTask(ECSTasks.MCPWORKBENCH, mcpWorkbenchTaskDefinition); + + this.createS3EventHandler(config, service); + } + + private createS3EventHandler (config: any, workbenchService: Ec2Service) { + const s3EventHandlerRole = new iam.Role(this, 'S3EventHandlerRole', { + assumedBy: new iam.ServicePrincipal('lambda.amazonaws.com'), + inlinePolicies: { + 'S3EventHandlerPolicy': new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: [ + 'logs:CreateLogGroup', + 'logs:CreateLogStream', + 'logs:PutLogEvents' + ], + resources: [`arn:${config.partition}:logs:*:*:*`] + }), + new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: [ + 'ecs:UpdateService', + 'ecs:DescribeServices', + 'ecs:DescribeClusters' + ], + resources: [ + `arn:${config.partition}:ecs:${config.region}:*:cluster/${workbenchService.cluster.clusterName}*`, + `arn:${config.partition}:ecs:${config.region}:*:service/${workbenchService.cluster.clusterName}*/${workbenchService.serviceName}*` + ] + }), + new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: [ + 'ssm:GetParameter' + ], + resources: [ + `arn:${config.partition}:ssm:${config.region}:*:parameter${config.deploymentPrefix}/deploymentName` + ] + }) + ] + }) + } + }); + + const s3EventHandlerLambda = new lambda.Function(this, 'S3EventHandlerLambda', { + runtime: getDefaultRuntime(), + handler: 'mcp_workbench.s3_event_handler.handler', + code: lambda.Code.fromAsset(config.lambdaPath ?? 
LAMBDA_PATH), + timeout: Duration.minutes(2), + role: s3EventHandlerRole, + environment: { + DEPLOYMENT_PREFIX: config.deploymentPrefix!, + API_NAME: 'MCPWorkbench', + ECS_CLUSTER_NAME: workbenchService.cluster.clusterName, + MCPWORKBENCH_SERVICE_NAME: workbenchService.serviceName + } + }); + + const rescanMcpWorkbenchRule = new events.Rule(this, 'RescanMCPWorkbenchRule', { + eventPattern: { + source: ['aws.s3', 'debug'], + detailType: [ + 'Object Created', + 'Object Deleted' + ], + detail: { + bucket: { + name: [[config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase()] + } + } + }, + }); + + rescanMcpWorkbenchRule.addTarget(new targets.LambdaFunction(s3EventHandlerLambda, { + retryAttempts: 2, + maxEventAge: Duration.minutes(5) + })); + } } diff --git a/lib/serve/mcpWorkbenchStack.ts b/lib/serve/mcpWorkbenchStack.ts new file mode 100644 index 000000000..d95a677ae --- /dev/null +++ b/lib/serve/mcpWorkbenchStack.ts @@ -0,0 +1,51 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { Stack, StackProps } from 'aws-cdk-lib'; +import { Construct } from 'constructs'; +import { BaseProps } from '../schema'; +import { McpWorkbenchConstruct } from './mcpWorkbenchConstruct'; +import { Vpc } from '../networking/vpc'; +import { ECSCluster } from '../api-base/ecsCluster'; +import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; + +export type McpWorkbenchStackProps = { + vpc: Vpc; + restApiId: string; + rootResourceId: string; + apiCluster: ECSCluster; + authorizer?: IAuthorizer; +} & BaseProps & StackProps; + +export class McpWorkbenchStack extends Stack { + constructor (scope: Construct, id: string, props: McpWorkbenchStackProps) { + super(scope, id, props); + + const { vpc, restApiId, rootResourceId, authorizer, apiCluster } = props; + + new McpWorkbenchConstruct(this, 'McpWorkbench', { + ...props, + restApiId, + rootResourceId, + securityGroups: [vpc.securityGroups.ecsModelAlbSg], + vpc: vpc, + apiCluster, + authorizer + }); + } + + +} diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index 325aefff3..9681ce9bd 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,6 +1,9 @@ ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} +ARG NODEENV_CACHE_DIR=NODEENV_CACHE +ENV NODEENV_CACHE_DIR=$NODEENV_CACHE_DIR + # Install build dependencies for madoka package RUN apt-get update && apt-get install -y \ gcc \ @@ -25,6 +28,22 @@ WORKDIR /app COPY src/requirements.txt . 
RUN pip install --no-cache-dir --upgrade -r requirements.txt +# Copy nodeenv cache directory (always exists, may be empty or populated) +COPY ${NODEENV_CACHE_DIR} /tmp/nodeenv-cache/ + +# Pre-cache nodeenv for prisma-client-py +# If the copied directory has content, use it (for offline environments) +# Otherwise, download it during build (requires internet) +RUN mkdir -p /root/.cache/prisma-python && \ + if [ -d "/tmp/nodeenv-cache" ] && [ -n "$(ls /tmp/nodeenv-cache 2>/dev/null)" ]; then \ + echo "Using pre-cached nodeenv from host" && \ + cp -r /tmp/nodeenv-cache /root/.cache/prisma-python/nodeenv && \ + rm -rf /tmp/nodeenv-cache; \ + else \ + echo "Downloading nodeenv (requires internet)" && \ + python -m nodeenv /root/.cache/prisma-python/nodeenv; \ + fi + # Copy the source code into the container COPY src/ ./src diff --git a/lib/serve/rest-api/NODEENV_CACHE/.gitkeep b/lib/serve/rest-api/NODEENV_CACHE/.gitkeep new file mode 100644 index 000000000..b098f0c35 --- /dev/null +++ b/lib/serve/rest-api/NODEENV_CACHE/.gitkeep @@ -0,0 +1,2 @@ +# Placeholder to ensure NODEENV_CACHE directory exists in build context +# For offline builds, populate this directory using: python -m nodeenv NODEENV_CACHE diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 1341b3469..70f28f488 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -14,6 +14,7 @@ """Model invocation routes.""" +import json import logging import os from collections.abc import Iterator @@ -25,7 +26,15 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.status import HTTP_401_UNAUTHORIZED -from ....auth import Authorizer +from ....auth import Authorizer, extract_user_groups_from_jwt +from ....utils.guardrails import ( + create_guardrail_json_response, + create_guardrail_streaming_response, + 
extract_guardrail_response, + get_applicable_guardrails, + get_model_guardrails, + is_guardrail_violation, +) # Local LiteLLM installation URL. By default, LiteLLM runs on port 4000. Change the port here if the # port was changed as part of the LiteLLM startup in entrypoint.sh @@ -81,6 +90,94 @@ router = APIRouter() +async def apply_guardrails_to_request(params: dict, model_id: str, jwt_data: dict) -> None: + """ + Apply guardrails to a chat completion request. + + This function modifies the params dict in-place, adding applicable guardrails + based on the user's group membership and the model's guardrail configuration. + + Args: + params: The request parameters dict to modify + model_id: The model ID to get guardrails for + jwt_data: JWT data containing user information + + Raises: + No exceptions are raised - errors are logged and the request continues + """ + try: + # Get guardrails for this model + guardrails = await get_model_guardrails(model_id) + + if not guardrails: + return + + # Extract user groups from JWT + user_groups = extract_user_groups_from_jwt(jwt_data) + + # Determine which guardrails apply to this user + applicable_guardrail_names = get_applicable_guardrails(user_groups, guardrails, model_id) + + # Add guardrails to request if any apply + if applicable_guardrail_names: + params["guardrails"] = applicable_guardrail_names + logger.info(f"Applying guardrails to model {model_id}: {applicable_guardrail_names}") + + except Exception as e: + logger.error(f"Error applying guardrails for model {model_id}: {e}") + # Continue with request even if guardrails fail to apply + + +def handle_guardrail_violation_response( + response: requests.Response, model_id: str, params: dict, is_streaming: bool +) -> Response | None: + """ + Handle guardrail violation errors in LiteLLM responses. + + Checks if a 400 error response contains a guardrail violation and converts it + into an appropriate format (streaming or non-streaming). 
+ + Args: + response: The HTTP response from LiteLLM + model_id: The model ID from the request + params: The original request parameters + is_streaming: Whether this is a streaming request + + Returns: + Response object if a guardrail violation was handled, None otherwise + """ + if response.status_code != 400: + return None + + try: + error_response = response.json() + error_msg = error_response.get("error", {}).get("message", "") + + if not is_guardrail_violation(error_msg): + return None + + logger.info("Guardrail policy violated") + + guardrail_response = extract_guardrail_response(error_msg) + if not guardrail_response: + return None + + created = int(error_response.get("created", 0) if is_streaming else params.get("created", 0)) + + if is_streaming: + # Return as streaming response + return StreamingResponse( + create_guardrail_streaming_response(guardrail_response, model_id, created), status_code=200 + ) + else: + # Return as a normal completion response + return create_guardrail_json_response(guardrail_response, model_id, created) + + except Exception as e: + logger.error(f"Error handling guardrail violation: {e}") + return None + + def generate_response(iterator: Iterator[Union[str, bytes]]) -> Iterator[str]: """For streaming responses, generate strings instead of bytes objects so that clients recognize the LLM output.""" for line in iterator: @@ -90,6 +187,66 @@ def generate_response(iterator: Iterator[Union[str, bytes]]) -> Iterator[str]: yield f"{line}\n\n" +def generate_response_with_guardrail_handling(iterator: Iterator[Union[str, bytes]], model: str) -> Iterator[str]: + """ + Generate streaming responses with guardrail violation error handling. + + This wrapper checks each chunk in the stream for guardrail violations and converts + them into properly formatted streaming responses. 
+ """ + for line in iterator: + if isinstance(line, bytes): + line = line.decode() + + if not line: + continue + + # Check if this line contains an error (SSE format: "data: {...}") + if line.startswith("data: "): + try: + # Extract JSON from SSE data line + data_content = line[6:].strip() # Remove "data: " prefix + + # Skip [DONE] marker + if data_content == "[DONE]": + yield f"{line}\n\n" + continue + + # Try to parse as JSON to check for errors + chunk_data = json.loads(data_content) + + # Check if this is an error chunk + if "error" in chunk_data: + error_msg = chunk_data.get("error", {}).get("message", "") + + if is_guardrail_violation(error_msg): + logger.info("Guardrail policy violated in streaming response") + + guardrail_response = extract_guardrail_response(error_msg) + if guardrail_response: + # Stream the guardrail response + created = int(chunk_data.get("created", 0)) + for chunk in create_guardrail_streaming_response(guardrail_response, model, created): + yield chunk + return # Stop streaming after guardrail response + else: + # Could not extract guardrail response, pass through the error + yield f"{line}\n\n" + else: + # Different error, pass it through + yield f"{line}\n\n" + else: + # Normal chunk, pass it through + yield f"{line}\n\n" + + except json.JSONDecodeError: + # Not valid JSON or not in expected format, pass through as-is + yield f"{line}\n\n" + else: + # Not in SSE format, pass through as-is + yield f"{line}\n\n" + + @router.api_route("/{api_path:path}", methods=["GET", "POST", "OPTIONS", "PUT", "PATCH", "DELETE", "HEAD"]) async def litellm_passthrough(request: Request, api_path: str) -> Response: """ @@ -104,6 +261,7 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: authorizer = Authorizer() require_admin = api_path not in OPENAI_ROUTES + jwt_data = await authorizer.authenticate_request(request) if not await authorizer.can_access(request, require_admin): raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, 
detail="Not authenticated in litellm_passthrough") @@ -114,16 +272,46 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: headers["Authorization"] = f"Bearer {LITELLM_KEY}" http_method = request.method - if http_method == "GET": + if http_method == "GET" or http_method == "DELETE": response = requests.request(method=http_method, url=litellm_path, headers=headers) return JSONResponse(response.json(), status_code=response.status_code) - # not a GET request, so expect a JSON payload as part of the request + # not a GET or DELETE request, so expect a JSON payload as part of the request params = await request.json() + + # Apply guardrails for chat/completions requests + if api_path in ["chat/completions", "v1/chat/completions"]: + model_id = params.get("model") + if model_id: + await apply_guardrails_to_request(params, model_id, jwt_data) + if params.get("stream", False): # if a streaming request response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers, stream=True) - return StreamingResponse(generate_response(response.iter_lines()), status_code=response.status_code) + + # Check for guardrail violations + model_id = params.get("model", "") + guardrail_response = handle_guardrail_violation_response(response, model_id, params, is_streaming=True) + if guardrail_response: + return guardrail_response + + # Normal streaming (no error or non-guardrail error) + # Use guardrail-aware generator for chat/completions endpoints + if api_path in ["chat/completions", "v1/chat/completions"]: + model_id = params.get("model", "") + return StreamingResponse( + generate_response_with_guardrail_handling(response.iter_lines(), model_id), + status_code=response.status_code, + ) + else: + return StreamingResponse(generate_response(response.iter_lines()), status_code=response.status_code) else: # not a streaming request response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers) + + # Check for 
guardrail violations + model_id = params.get("model", "") + guardrail_response = handle_guardrail_violation_response(response, model_id, params, is_streaming=False) + if guardrail_response: + return guardrail_response + if response.status_code != 200: logger.error(f"LiteLLM error response: {response.text}") return JSONResponse(response.json(), status_code=response.status_code) diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index 07f3bc07f..2f36552d1 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -132,6 +132,50 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node +def extract_user_groups_from_jwt(jwt_data: Optional[Dict[str, Any]]) -> list[str]: + """ + Extract user groups from JWT data using the JWT_GROUPS_PROP environment variable. + + This follows the same property path traversal logic as is_user_in_group() function. + + Parameters + ---------- + jwt_data : Optional[Dict[str, Any]] + JWT data from authentication. None if user authenticated via API token. + + Returns + ------- + list[str] + List of groups the user belongs to. Empty list if no JWT data or groups not found. 
+ """ + if jwt_data is None: + # API token users have no JWT, treat as having no group restrictions + return [] + + jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "") + if not jwt_groups_property: + logger.warning("JWT_GROUPS_PROP environment variable not set") + return [] + + # Traverse the property path to find groups + props = jwt_groups_property.split(".") + current_node = jwt_data + + for prop in props: + if isinstance(current_node, dict) and prop in current_node: + current_node = current_node[prop] + else: + logger.debug(f"Groups property path '{jwt_groups_property}' not found in JWT data") + return [] + + # current_node should now be the groups list + if isinstance(current_node, list): + return current_node + else: + logger.warning(f"Expected list of groups but got {type(current_node)}") + return [] + + def get_authorization_token(headers: Dict[str, str], header_name: str = AuthHeaders.AUTHORIZATION) -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: diff --git a/lib/serve/rest-api/src/utils/guardrails.py b/lib/serve/rest-api/src/utils/guardrails.py new file mode 100644 index 000000000..1d6eb042a --- /dev/null +++ b/lib/serve/rest-api/src/utils/guardrails.py @@ -0,0 +1,239 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for managing and applying LiteLLM guardrails.""" + +import json +import os +import re +from collections.abc import Iterator +from typing import Any, Dict, List, Optional + +import boto3 +from fastapi.responses import JSONResponse +from loguru import logger + + +async def get_model_guardrails(model_id: str) -> List[Dict[str, Any]]: + """ + Query the guardrails DynamoDB table for guardrails associated with a model. + + Parameters + ---------- + model_id : str + The model ID to query guardrails for. + + Returns + ------- + List[Dict[str, Any]] + List of guardrail configurations for the model. Returns empty list if no guardrails found. + """ + try: + dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"]) + guardrails_table = dynamodb.Table(os.environ["GUARDRAILS_TABLE_NAME"]) + + # Query using the ModelIdIndex GSI + response = guardrails_table.query( + IndexName="ModelIdIndex", + KeyConditionExpression="modelId = :modelId", + ExpressionAttributeValues={":modelId": model_id}, + ) + + guardrails = response.get("Items", []) + logger.debug(f"Found {len(guardrails)} guardrails for model {model_id}") + return guardrails + + except Exception as e: + logger.error(f"Error fetching guardrails for model {model_id}: {e}") + return [] + + +def get_applicable_guardrails(user_groups: List[str], guardrails: List[Dict[str, Any]], model_id: str) -> List[str]: + """ + Determine which guardrails apply to a user based on group membership. + + A guardrail applies if: + - It has no allowed_groups (public guardrail, applies to everyone) + - The user is a member of at least one of the guardrail's allowed_groups + + Parameters + ---------- + user_groups : List[str] + List of groups the user belongs to. + + guardrails : List[Dict[str, Any]] + List of guardrail configurations from DynamoDB. + + model_id : str + The model ID being invoked. Used to construct the full LiteLLM guardrail name. 
+ + Returns + ------- + List[str] + List of LiteLLM guardrail names (format: {guardrail_name}-{model_id}) that should be applied to the request. + """ + applicable_guardrails = [] + + for guardrail in guardrails: + # Skip guardrails marked for deletion + if guardrail.get("markedForDeletion", False): + continue + + allowed_groups = guardrail.get("allowedGroups", []) + guardrail_name = guardrail.get("guardrailName") + + if not guardrail_name: + logger.warning(f"Guardrail missing guardrailName field: {guardrail}") + continue + + # Construct the full LiteLLM guardrail name (matches format used in create_model.py) + litellm_guardrail_name = f"{guardrail_name}-{model_id}" + + # If no groups specified, guardrail is public (applies to everyone) + if not allowed_groups: + applicable_guardrails.append(litellm_guardrail_name) + logger.debug(f"Applying public guardrail: {litellm_guardrail_name}") + continue + + # Check if user has any matching group + if any(group in allowed_groups for group in user_groups): + applicable_guardrails.append(litellm_guardrail_name) + logger.debug(f"Applying guardrail {litellm_guardrail_name} based on group membership") + + return applicable_guardrails + + +def is_guardrail_violation(error_msg: str) -> bool: + """ + Check if an error message indicates a guardrail policy violation. + + Parameters + ---------- + error_msg : str + The error message to check. + + Returns + ------- + bool + True if the error message indicates a guardrail violation, False otherwise. + """ + return "Violated guardrail policy" in error_msg + + +def extract_guardrail_response(error_msg: str) -> Optional[str]: + """ + Extract the bedrock_guardrail_response from an error message. + + Parameters + ---------- + error_msg : str + The error message containing the guardrail response. + + Returns + ------- + Optional[str] + The extracted guardrail response text, or None if not found. 
+ """ + match = re.search(r"'bedrock_guardrail_response':\s*'([^']*)'", error_msg) + return match.group(1) if match else None + + +def create_guardrail_streaming_response(guardrail_response: str, model_id: str, created: int = 0) -> Iterator[str]: + """ + Generate streaming response chunks for a guardrail violation. + + Parameters + ---------- + guardrail_response : str + The guardrail response text to stream. + model_id : str + The model ID associated with the request. + created : int, optional + The creation timestamp, by default 0. + + Yields + ------ + str + Properly formatted SSE chunks for the guardrail response. + """ + # First chunk with content + response_chunk = { + "id": "guardrail-response", + "object": "chat.completion.chunk", + "created": created, + "model": model_id, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": guardrail_response}, + "finish_reason": None, + } + ], + "lisa_guardrail_triggered": True, + } + yield f"data: {json.dumps(response_chunk)}\n\n" + + # Final chunk with finish_reason + final_chunk = { + "id": "guardrail-response", + "object": "chat.completion.chunk", + "created": created, + "model": model_id, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + "lisa_guardrail_triggered": True, + } + yield f"data: {json.dumps(final_chunk)}\n\n" + yield "data: [DONE]\n\n" + + +def create_guardrail_json_response(guardrail_response: str, model_id: str, created: int = 0) -> JSONResponse: + """ + Create a JSON response for a guardrail violation. + + Parameters + ---------- + guardrail_response : str + The guardrail response text. + model_id : str + The model ID associated with the request. + created : int, optional + The creation timestamp, by default 0. + + Returns + ------- + JSONResponse + A properly formatted JSON response for the guardrail violation. 
+ """ + response_data = { + "id": "guardrail-response", + "object": "chat.completion", + "created": created, + "model": model_id, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": guardrail_response}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "lisa_guardrail_triggered": True, + } + return JSONResponse(response_data, status_code=200) diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index 26cfa9fbf..df5be76bd 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -20,6 +20,7 @@ import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; import { Code, Function, IFunction, ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { FastApiContainer } from '../api-base/fastApiContainer'; +import { ECSCluster } from '../api-base/ecsCluster'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; import { BaseProps, Config } from '../schema'; @@ -40,8 +41,8 @@ import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resour import { getDefaultRuntime } from '../api-base/utils'; import { ISecurityGroup, Port } from 'aws-cdk-lib/aws-ec2'; import { ECSTasks } from '../api-base/ecsCluster'; -import { letIfDefined } from '../util/common-functions'; import { EventBus } from 'aws-cdk-lib/aws-events'; +import { GuardrailsTable } from '../models/guardrails-table'; export type LisaServeApplicationProps = { vpc: Vpc; @@ -57,6 +58,10 @@ export class LisaServeApplicationConstruct extends Construct { public readonly modelsPs: StringParameter; public readonly endpointUrl: StringParameter; public readonly tokenTable?: ITable; + public readonly ecsCluster: ECSCluster; + public readonly managementKeySecretName: string; + public readonly guardrailsTableNamePs: StringParameter; + public readonly guardrailsTable: 
ITable; /** * @param {Stack} scope - The parent or owner of the construct. @@ -83,6 +88,22 @@ export class LisaServeApplicationConstruct extends Construct { } this.tokenTable = tokenTable; + const { managementKeySecretName } = this.createManagementKeySecret(scope, config, vpc, securityGroups); + this.managementKeySecretName = managementKeySecretName; + + // Create guardrails table in serve stack to avoid circular dependency + const guardrailsTableConstruct = new GuardrailsTable(scope, 'GuardrailsTable', { + deploymentPrefix: config.deploymentPrefix || '', + removalPolicy: config.removalPolicy, + }); + this.guardrailsTable = guardrailsTableConstruct.table; + + // Create SSM parameter for guardrails table name + this.guardrailsTableNamePs = new StringParameter(scope, 'GuardrailsTableNameParameter', { + parameterName: `${config.deploymentPrefix}/guardrailsTableName`, + stringValue: this.guardrailsTable.tableName, + }); + // Create REST API const restApi = new FastApiContainer(scope, 'RestApi', { apiName: 'REST', @@ -91,78 +112,7 @@ export class LisaServeApplicationConstruct extends Construct { securityGroup: vpc.securityGroups.restApiAlbSg, tokenTable: tokenTable, vpc: vpc, - }); - - // Create EventBus for management key rotation events - const managementEventBus = new EventBus(scope, createCdkId([scope.node.id, 'managementEventBus']), { - eventBusName: `${config.deploymentName}-lisa-management-events`, - }); - - // Use a stable name for the management key secret - const managementKeySecret = new Secret(scope, createCdkId([scope.node.id, 'managementKeySecret']), { - secretName: `${config.deploymentName}-lisa-management-key`, // Use stable name based on deployment - description: 'LISA management key secret', - generateSecretString: { - excludePunctuation: true, - passwordLength: 16 - }, - removalPolicy: config.removalPolicy - }); - - // Add rotation policy for the management key secret - const rotationLambda = new Function(scope, createCdkId([scope.node.id, 
'managementKeyRotationLambda']), { - runtime: getDefaultRuntime(), - handler: 'management_key.handler', - code: Code.fromAsset(config.lambdaPath || LAMBDA_PATH), - timeout: Duration.minutes(5), - environment: { - EVENT_BUS_NAME: managementEventBus.eventBusName, - }, - role: new Role(scope, createCdkId([scope.node.id, 'managementKeyRotationRole']), { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - managedPolicies: [ - ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), - ], - inlinePolicies: { - 'SecretsManagerRotation': new PolicyDocument({ - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'secretsmanager:DescribeSecret', - 'secretsmanager:GetSecretValue', - 'secretsmanager:PutSecretValue', - 'secretsmanager:UpdateSecretVersionStage' - ], - resources: [managementKeySecret.secretArn] - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'events:PutEvents' - ], - resources: [managementEventBus.eventBusArn] - }) - ] - }) - } - }), - securityGroups: securityGroups, - vpc: vpc.vpc, - }); - - // Configure automatic rotation every 30 days - managementKeySecret.addRotationSchedule('RotationSchedule', { - automaticallyAfter: Duration.days(30), - rotationLambda: rotationLambda - }); - - const managementKeySecretNameStringParameter = new StringParameter(scope, createCdkId(['ManagementKeySecretName']), { - parameterName: `${config.deploymentPrefix}/managementKeySecretName`, - stringValue: managementKeySecret.secretName, - }); - restApi.containers.forEach((container) => { - container.addEnvironment('MANAGEMENT_KEY_NAME', managementKeySecretNameStringParameter.stringValue); + managementKeyName: managementKeySecretName }); // LiteLLM requires a PostgreSQL database to support multiple-instance scaling with dynamic model management. @@ -241,14 +191,27 @@ export class LisaServeApplicationConstruct extends Construct { ...(config.iamRdsAuth ? 
{} : { passwordSecretId: litellmDbPasswordSecret.secretName }) })); - Object.values(restApi.taskRoles).forEach((taskRole) => { - litellmDbConnectionInfoPs.grantRead(taskRole); - }); - // update the rdsConfig with the endpoint address config.restApiConfig.rdsConfig.dbHost = litellmDb.dbInstanceEndpointAddress; - letIfDefined(restApi.taskRoles[ECSTasks.REST], (serveRole) => { + // Create Parameter Store entry with RestAPI URI + this.endpointUrl = new StringParameter(scope, createCdkId(['LisaServeRestApiUri', 'StringParameter']), { + parameterName: `${config.deploymentPrefix}/lisaServeRestApiUri`, + stringValue: restApi.endpoint, + description: 'URI for LISA Serve API', + }); + + // Create Parameter Store entry with registeredModels + this.modelsPs = new StringParameter(scope, createCdkId(['RegisteredModels', 'StringParameter']), { + parameterName: `${config.deploymentPrefix}/registeredModels`, + stringValue: JSON.stringify([]), + description: 'Serialized JSON of registered models data', + }); + + const serveRole = restApi.apiCluster.taskRoles[ECSTasks.REST]; + if (serveRole) { + // Grant access to REST API task role only + litellmDbConnectionInfoPs.grantRead(serveRole); if (config.iamRdsAuth) { litellmDb.grantConnect(serveRole, serveRole.roleName); @@ -277,53 +240,27 @@ export class LisaServeApplicationConstruct extends Construct { litellmDb.grantConnect(serveRole); litellmDbPasswordSecret.grantRead(serveRole); } - }); - - restApi.containers.forEach((container) => { - container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); - }); - - if (config.region.includes('iso')) { - const ca_bundle = config.certificateAuthorityBundle ?? 
''; - - restApi.containers.forEach((container) => { - container.addEnvironment('SSL_CERT_DIR', '/etc/pki/tls/certs'); - container.addEnvironment('SSL_CERT_FILE', ca_bundle); - container.addEnvironment('REQUESTS_CA_BUNDLE', ca_bundle); - container.addEnvironment('CURL_CA_BUNDLE', ca_bundle); - container.addEnvironment('AWS_CA_BUNDLE', ca_bundle); - }); + this.modelsPs.grantRead(serveRole); } - // Create Parameter Store entry with RestAPI URI - this.endpointUrl = new StringParameter(scope, createCdkId(['LisaServeRestApiUri', 'StringParameter']), { - parameterName: `${config.deploymentPrefix}/lisaServeRestApiUri`, - stringValue: restApi.endpoint, - description: 'URI for LISA Serve API', - }); - - // Create Parameter Store entry with registeredModels - this.modelsPs = new StringParameter(scope, createCdkId(['RegisteredModels', 'StringParameter']), { - parameterName: `${config.deploymentPrefix}/registeredModels`, - stringValue: JSON.stringify([]), - description: 'Serialized JSON of registered models data', - }); - - letIfDefined(restApi.taskRoles[ECSTasks.REST], (serveRole) => { - this.modelsPs.grantRead(serveRole); - }); + // Use the guardrails table name from the construct we just created + const guardrailsTableName = this.guardrailsTable.tableName; // Add parameter as container environment variable for both RestAPI and RagAPI - restApi.containers.forEach((container) => { + const container = restApi.apiCluster.containers[ECSTasks.REST]; + if (container) { + container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); container.addEnvironment('REGISTERED_MODELS_PS_NAME', this.modelsPs.parameterName); container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); - }); + container.addEnvironment('GUARDRAILS_TABLE_NAME', guardrailsTableName); + } restApi.node.addDependency(this.modelsPs); restApi.node.addDependency(litellmDbConnectionInfoPs); restApi.node.addDependency(this.endpointUrl); // Update 
this.restApi = restApi; + this.ecsCluster = restApi.apiCluster; // Grant permissions after restApi is fully constructed // Additional permissions for REST API Role @@ -334,6 +271,7 @@ export class LisaServeApplicationConstruct extends Construct { actions: [ 'bedrock:InvokeModel', 'bedrock:InvokeModelWithResponseStream', + 'bedrock:ApplyGuardrail', ], resources: [ '*' @@ -352,14 +290,36 @@ export class LisaServeApplicationConstruct extends Construct { ] }); + // Grant DynamoDB permissions for guardrails table + const guardrails_permissions = new Policy(scope, 'GuardrailsTablePerms', { + statements: [ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'dynamodb:Query', + 'dynamodb:GetItem', + ], + resources: [ + `arn:${config.partition}:dynamodb:${config.region}:${config.accountNumber}:table/${guardrailsTableName}/*`, + ], + }), + ] + }); + // Grant SSM parameter read access and attach invocation permissions - const restRole = restApi.taskRoles[ECSTasks.REST]; + const restRole = restApi.apiCluster.taskRoles[ECSTasks.REST]; if (restRole) { this.modelsPs.grantRead(restRole); litellmDbConnectionInfoPs.grantRead(restRole); restRole.attachInlinePolicy(invocation_permissions); + restRole.attachInlinePolicy(guardrails_permissions); + if (serveRole) { + this.modelsPs.grantRead(serveRole); + litellmDbConnectionInfoPs.grantRead(serveRole); + serveRole.attachInlinePolicy(invocation_permissions); + } } - } + }; getIAMAuthLambda (scope: Stack, config: Config, secret: ISecret, user: string, vpc: Vpc, securityGroups: ISecurityGroup[]): IFunction { // Create the IAM role for updating the database to allow IAM authentication @@ -412,4 +372,74 @@ export class LisaServeApplicationConstruct extends Construct { StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/layerVersion/common`), ); } + + private createManagementKeySecret (scope: Stack, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[]): { managementKeySecretName: string } { + const 
managementKeySecretName = `${config.deploymentName}-lisa-management-key`; + + const managementEventBus = new EventBus(scope, createCdkId([scope.node.id, 'managementEventBus']), { + eventBusName: `${config.deploymentName}-lisa-management-events`, + }); + + const managementKeySecret = new Secret(scope, createCdkId([scope.node.id, 'managementKeySecret']), { + secretName: managementKeySecretName, + description: 'LISA management key secret', + generateSecretString: { + excludePunctuation: true, + passwordLength: 16 + }, + removalPolicy: config.removalPolicy + }); + + const rotationLambda = new Function(scope, createCdkId([scope.node.id, 'managementKeyRotationLambda']), { + runtime: getDefaultRuntime(), + handler: 'management_key.handler', + code: Code.fromAsset(config.lambdaPath || LAMBDA_PATH), + timeout: Duration.minutes(5), + environment: { + EVENT_BUS_NAME: managementEventBus.eventBusName, + }, + role: new Role(scope, createCdkId([scope.node.id, 'managementKeyRotationRole']), { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + inlinePolicies: { + 'SecretsManagerRotation': new PolicyDocument({ + statements: [ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'secretsmanager:DescribeSecret', + 'secretsmanager:GetSecretValue', + 'secretsmanager:PutSecretValue', + 'secretsmanager:UpdateSecretVersionStage' + ], + resources: [managementKeySecret.secretArn] + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['events:PutEvents'], + resources: [managementEventBus.eventBusArn] + }) + ] + }) + } + }), + securityGroups: securityGroups, + vpc: vpc.vpc, + }); + + managementKeySecret.addRotationSchedule('RotationSchedule', { + automaticallyAfter: Duration.days(30), + rotationLambda: rotationLambda + }); + + new StringParameter(scope, createCdkId(['ManagementKeySecretName']), { + parameterName: 
`${config.deploymentPrefix}/managementKeySecretName`, + stringValue: managementKeySecret.secretName, + }); + + return { managementKeySecretName }; + } + } diff --git a/lib/stages.ts b/lib/stages.ts index 739810fc0..061681c9f 100644 --- a/lib/stages.ts +++ b/lib/stages.ts @@ -29,6 +29,7 @@ import { Tags, } from 'aws-cdk-lib'; import * as lambda from 'aws-cdk-lib/aws-lambda'; +import * as events from 'aws-cdk-lib/aws-events'; import { Construct } from 'constructs'; import { AwsSolutionsChecks, NagSuppressions, NIST80053R5Checks } from 'cdk-nag'; @@ -43,6 +44,7 @@ import { LisaNetworkingStack } from './networking'; import { LisaRagStack } from './rag'; import { BaseProps, stackSynthesizerType } from './schema'; import { LisaServeApplicationStack } from './serve'; +import { McpWorkbenchStack } from './serve/mcpWorkbenchStack'; import { UserInterfaceStack } from './user-interface'; import { LisaDocsStack } from './docs'; import { LisaMetricsStack } from './metrics'; @@ -165,6 +167,27 @@ class RemoveEventSourceMappingTagsAspect implements IAspect { } } +/** + * Removes Tags property from all AWS::Events::Rule resources in a CDK application. + * This is required for AWS GovCloud regions which don't support Tags on Rule resources. + */ +class RemoveEventRuleTagsAspect implements IAspect { + /** + * Checks if the given node is an instance of CfnResource and specifically an AWS::Events::Rule resource. + * If true, it removes the Tags property to prevent deployment failures in AWS GovCloud regions. + * + * @param {Construct} node - The CDK construct being visited. 
+ */ + public visit (node: Construct): void { + // Check if the node is a CloudFormation resource of type AWS::Events::Rule + if (node instanceof events.CfnRule) { + // Remove Tags property for AWS GovCloud compatibility + node.addPropertyDeletionOverride('Tags'); + } + } +} + + export type CommonStackProps = { synthesizer?: IStackSynthesizer; } & BaseProps; @@ -260,6 +283,7 @@ export class LisaServeApplicationStage extends Stage { ...baseStackProps, authorizer: apiBaseStack.authorizer, description: `LISA-models: ${config.deploymentName}-${config.deploymentStage}`, + guardrailsTable: serveStack.guardrailsTable, lisaServeEndpointUrlPs: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, @@ -271,6 +295,24 @@ export class LisaServeApplicationStage extends Stage { apiDeploymentStack.addDependency(modelsApiDeploymentStack); this.stacks.push(modelsApiDeploymentStack); + if (config.deployMcpWorkbench) { + const mcpWorkbenchStack = new McpWorkbenchStack(this, 'LisaMcpWorkbench', { + ...baseStackProps, + stackName: createCdkId([config.deploymentName, config.appName, 'mcp-workbench', config.deploymentStage]), + description: `LISA-mcp-workbench: ${config.deploymentName}-${config.deploymentStage}`, + vpc: networkingStack.vpc, + restApiId: apiBaseStack.restApiId, + rootResourceId: apiBaseStack.rootResourceId, + apiCluster: serveStack.restApi.apiCluster, + authorizer: apiBaseStack.authorizer, + }); + mcpWorkbenchStack.addDependency(coreStack); + mcpWorkbenchStack.addDependency(apiBaseStack); + mcpWorkbenchStack.addDependency(serveStack); + apiDeploymentStack.addDependency(mcpWorkbenchStack); + this.stacks.push(mcpWorkbenchStack); + } + if (config.deployRag) { const ragStack = new LisaRagStack(this, 'LisaRAG', { ...baseStackProps, @@ -414,6 +456,7 @@ export class LisaServeApplicationStage extends Stage { // AWS GovCloud regions don't support Tags on EventSourceMapping resources 
if (config.region.includes('gov')) { Aspects.of(this).add(new RemoveEventSourceMappingTagsAspect()); + Aspects.of(this).add(new RemoveEventRuleTagsAspect()); } } } diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 2b8121ffc..90f8a8715 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "5.3.2", + "version": "5.4.0", "type": "module", "scripts": { "dev": "vite", @@ -36,6 +36,7 @@ "langchain": "^0.3.15", "lodash": "^4.17.21", "luxon": "^3.5.0", + "dompurify": "^3.2.5", "mermaid": "^11.10.1", "react": "^18.3.1", "react-ace": "^14.0.1", diff --git a/lib/user-interface/react/src/App.tsx b/lib/user-interface/react/src/App.tsx index 1edd43a66..cb4e27b81 100644 --- a/lib/user-interface/react/src/App.tsx +++ b/lib/user-interface/react/src/App.tsx @@ -26,7 +26,7 @@ import Chatbot from './pages/Chatbot'; import Topbar from './components/Topbar'; import SystemBanner from './components/system-banner/system-banner'; import { useAppSelector } from './config/store'; -import { selectCurrentUserIsAdmin } from './shared/reducers/user.reducer'; +import { selectCurrentUserIsAdmin, selectCurrentUserIsUser } from './shared/reducers/user.reducer'; import ModelManagement from './pages/ModelManagement'; import ModelLibrary from './pages/ModelLibrary'; import NotificationBanner from './shared/notification/notification'; @@ -43,6 +43,8 @@ import { ConfigurationContext } from './shared/configuration.provider'; import McpServers from '@/pages/Mcp'; import ModelComparisonPage from './pages/ModelComparison'; import McpWorkbench from './pages/McpWorkbench'; +import ColorSchemeContext from './shared/color-scheme.provider'; +import { applyMode, Mode } from '@cloudscape-design/global-styles'; export type RouteProps = { @@ -51,16 +53,23 @@ export type RouteProps = { configs?: IConfiguration }; -const PrivateRoute = ({ children, showConfig, 
configs }: RouteProps) => { +const PrivateRoute = ({ children }: RouteProps) => { const auth = useAuth(); - if (auth.isAuthenticated) { - if (showConfig && configs?.configuration.enabledComponents[showConfig] === false) { - return ; - } + const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); + const isUser = useAppSelector(selectCurrentUserIsUser); + + if (auth.isAuthenticated && (isUserAdmin || isUser)) { return children; } else if (auth.isLoading) { return ; + } else if (auth.isAuthenticated && !isUserAdmin && !isUser) { + return ( +
+

Access Denied

+

You do not have permission to access this application. Please contact your administrator.

+
+ ); } else { return ; } @@ -88,6 +97,23 @@ function App () { }); const config = fullConfig?.[0]; + const [colorScheme, setColorScheme] = useState(() => { + // Check to see if Media-Queries are supported + if (window.matchMedia) { + // Check if the dark-mode Media-Query matches + if (window.matchMedia('(prefers-color-scheme: dark)').matches) { + // Dark + return Mode.Dark; + } + } + + return Mode.Light; + }); + + useEffect(() => { + applyMode(colorScheme); + }, [colorScheme]); + useEffect(() => { if (nav) { setShowNavigation(true); @@ -97,60 +123,61 @@ function App () { }, [nav]); return ( - - {config?.configuration.systemBanner.isEnabled && } -
- -
- - } - toolsHide={true} - notifications={} - stickyNotifications={true} - navigation={nav} - navigationWidth={450} - content={ - - - - - } - /> - - - - } - /> - - - - } - /> - {config?.configuration?.enabledComponents?.modelLibrary && - - - } - />} - {config?.configuration?.enabledComponents?.showRagLibrary && + + + {config?.configuration.systemBanner.isEnabled && } +
+ +
+ + } + toolsHide={true} + notifications={} + stickyNotifications={true} + navigation={nav} + navigationWidth={450} + content={ + + + + + } + /> + + + + } + /> + + + + } + /> + {config?.configuration?.enabledComponents?.modelLibrary && + + + } + />} + {config?.configuration?.enabledComponents?.showRagLibrary && <> } - {config?.configuration?.enabledComponents?.showPromptTemplateLibrary && - - - } - />} - - - - } - /> - {config?.configuration?.enabledComponents?.mcpConnections && - - - } - />} - {config?.configuration?.enabledComponents?.showMcpWorkbench && + {config?.configuration?.enabledComponents?.showPromptTemplateLibrary && + + + } + />} + + + } + /> + {config?.configuration?.enabledComponents?.mcpConnections && + + + } + />} + {config?.configuration?.enabledComponents?.showMcpWorkbench && + + } + /> + } + {config?.configuration?.enabledComponents?.enableModelComparisonUtility && + + + } /> - } - {config?.configuration?.enabledComponents?.enableModelComparisonUtility && - - } - /> - } - - - Loading configuration... - - : - - } /> - - } - /> - {confirmationModal && } - {config?.configuration.systemBanner.isEnabled && } -
+ + + Loading configuration... + + : + + } /> +
+ } + /> + {confirmationModal && } + {config?.configuration.systemBanner.isEnabled && } +
+ ); } diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index 90fb48633..2e246379e 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -14,16 +14,17 @@ limitations under the License. */ -import { ReactElement, useEffect, useState } from 'react'; +import { ReactElement, useContext } from 'react'; import { useAuth } from 'react-oidc-context'; import { useHref, useNavigate } from 'react-router-dom'; -import { applyDensity, applyMode, Density, Mode } from '@cloudscape-design/global-styles'; +import { applyDensity, Density, Mode } from '@cloudscape-design/global-styles'; import TopNavigation, { TopNavigationProps } from '@cloudscape-design/components/top-navigation'; import { getBaseURI } from './utils'; import { purgeStore, useAppSelector } from '../config/store'; import { selectCurrentUserIsAdmin, selectCurrentUsername } from '../shared/reducers/user.reducer'; import { IConfiguration } from '../shared/model/configuration.model'; import { ButtonDropdownProps } from '@cloudscape-design/components'; +import ColorSchemeContext from '@/shared/color-scheme.provider'; applyDensity(Density.Comfortable); @@ -36,31 +37,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { const auth = useAuth(); const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); const userName = useAppSelector(selectCurrentUsername); - const [isDarkMode, setIsDarkMode] = useState(window.matchMedia('(prefers-color-scheme: dark)').matches); - - useEffect(() => { - if (isDarkMode) { - applyMode(Mode.Dark); - } else { - applyMode(Mode.Light); - } - }, [isDarkMode]); - - useEffect(() => { - // Check to see if Media-Queries are supported - if (window.matchMedia) { - // Check if the dark-mode Media-Query matches - if (window.matchMedia('(prefers-color-scheme: dark)').matches) { - // Dark - applyMode(Mode.Dark); - } else { - // Light - applyMode(Mode.Light); - 
} - } else { - // Default (when Media-Queries are not supported) - } - }, []); + const {colorScheme, setColorScheme} = useContext(ColorSchemeContext); const libraryItems = [ ...(configs?.configuration.enabledComponents?.modelLibrary ? [{ @@ -183,7 +160,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { await auth.signoutSilent(); break; case 'color-mode': - setIsDarkMode(!isDarkMode); + setColorScheme(colorScheme === Mode.Light ? Mode.Dark : Mode.Light); break; default: break; @@ -192,7 +169,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { iconName: 'user-profile', items: [ { id: 'version-info', text: `LISA v${window.gitInfo?.revisionTag}`, disabled: true }, - { id: 'color-mode', text: isDarkMode ? 'Light mode' : 'Dark mode', iconSvg: ( + { id: 'color-mode', text: colorScheme === Mode.Light ? 'Dark mode' : 'Light mode', iconSvg: ( !msg.guardrailTriggered); + // Always concatenate filtered session history with new messages + const messagesToProcess = filteredHistory.concat(params.message); let messages = messagesToProcess.map((msg) => { const baseMessage: any = { @@ -199,6 +201,8 @@ export const useChatGeneration = ({ const resp: string[] = []; const toolCallsAccumulator: { [index: number]: any } = {}; + let guardrailTriggered = false; + for await (const chunk of stream) { // Check if stop was requested if (stopRequested.current) { @@ -208,6 +212,13 @@ export const useChatGeneration = ({ const content = chunk.content as string; + // Check if this chunk indicates a guardrail was triggered + const isGuardrailTriggered = (chunk as any).id === 'guardrail-response'; + + if (isGuardrailTriggered) { + guardrailTriggered = true; + } + // Get tool calls from LangChain streaming chunks let tool_calls: any[] = []; @@ -355,17 +366,25 @@ export const useChatGeneration = ({ setSession((prev) => { const lastMessage = prev.history[prev.history.length - 1]; if (lastMessage?.type === MessageTypes.AI) { + let updatedHistory = [...prev.history.slice(0, -1), 
+ new LisaChatMessage({ + ...lastMessage, + usage: { + ...lastMessage.usage, + responseTime: parseFloat(responseTime.toFixed(2)) + }, + guardrailTriggered: guardrailTriggered + }) + ]; + + // If guardrail was triggered, also mark the user message + if (guardrailTriggered) { + updatedHistory = markLastUserMessageAsGuardrailTriggered(updatedHistory); + } + return { ...prev, - history: [...prev.history.slice(0, -1), - new LisaChatMessage({ - ...lastMessage, - usage: { - ...lastMessage.usage, - responseTime: parseFloat(responseTime.toFixed(2)) - } - }) - ], + history: updatedHistory, }; } return prev; @@ -385,23 +404,40 @@ export const useChatGeneration = ({ const content = response.content as string; const usage = response.response_metadata.tokenUsage; + // Check if guardrail was triggered + const isGuardrailTriggered = (response as any)?.id === 'guardrail-response'; + // Calculate response time const responseTime = (performance.now() - startTime) / 1000; await memory.saveContext({ input: params.input }, { output: content }); - setSession((prev) => ({ - ...prev, - history: [...prev.history, new LisaChatMessage({ - type: 'ai', - content, - metadata, - toolCalls: [...(response.tool_calls ?? [])], - usage: { - ...usage, - responseTime: parseFloat(responseTime.toFixed(2)) - } - })], - })); + + // Create the AI message + const aiMessage = new LisaChatMessage({ + type: 'ai', + content, + metadata, + toolCalls: [...(response.tool_calls ?? 
[])], + usage: { + ...usage, + responseTime: parseFloat(responseTime.toFixed(2)) + }, + guardrailTriggered: isGuardrailTriggered + }); + + setSession((prev) => { + let updatedHistory = [...prev.history, aiMessage]; + + // If guardrail was triggered, also mark the user message + if (isGuardrailTriggered) { + updatedHistory = markLastUserMessageAsGuardrailTriggered(updatedHistory); + } + + return { + ...prev, + history: updatedHistory, + }; + }); } } } catch (error) { diff --git a/lib/user-interface/react/src/components/mcp-workbench/McpWorkbenchManagementComponent.tsx b/lib/user-interface/react/src/components/mcp-workbench/McpWorkbenchManagementComponent.tsx index bdbef9ff4..1f4cbf49e 100644 --- a/lib/user-interface/react/src/components/mcp-workbench/McpWorkbenchManagementComponent.tsx +++ b/lib/user-interface/react/src/components/mcp-workbench/McpWorkbenchManagementComponent.tsx @@ -14,16 +14,14 @@ limitations under the License. */ -import { Button, CodeEditor, Container, Grid, SpaceBetween, List, Header, Box, Input, FormField, TextFilter, Pagination } from '@cloudscape-design/components'; +import { Button, Container, Grid, SpaceBetween, List, Header, Box, Input, FormField, TextFilter, Pagination, Link, TextContent, Spinner } from '@cloudscape-design/components'; +import AceEditor from 'react-ace'; +import {Editor} from 'ace-builds'; +import { CancellableEventHandler } from '@cloudscape-design/components/internal/events'; import 'react'; -import 'ace-builds'; -import ace from 'ace-builds'; -import 'ace-builds/src-noconflict/mode-python'; -import 'ace-builds/src-noconflict/theme-tomorrow'; -import 'ace-builds/src-noconflict/ext-language_tools'; -import { ReactElement, useEffect, useState } from 'react'; +import { ReactElement, useCallback, useContext, useEffect, useState } from 'react'; import { useAppDispatch } from '@/config/store'; -import { useNotificationService } from '@/shared/util/hooks'; +import { useDebounce, useNotificationService } from 
'@/shared/util/hooks'; import * as z from 'zod'; import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; import { @@ -32,11 +30,15 @@ import { useCreateMcpToolMutation, useUpdateMcpToolMutation, useDeleteMcpToolMutation, - mcpToolsApi + mcpToolsApi, + useValidateMcpToolMutation, } from '@/shared/reducers/mcp-tools.reducer'; import { IMcpTool } from '@/shared/model/mcp-tools.model'; import { setBreadcrumbs } from '@/shared/reducers/breadcrumbs.reducer'; import { useValidationReducer } from '@/shared/validation'; +import { ModifyMethod } from '@/shared/validation/modify-method'; +import ColorSchemeContext from '@/shared/color-scheme.provider'; +import { Mode } from '@cloudscape-design/global-styles'; const DEFAULT_CONTENT = 'from mcpworkbench.core.annotations import mcp_tool\nfrom mcpworkbench.core.base_tool import BaseTool\nfrom typing import Annotated\n\n\n# =============================================================================\n# METHOD 1: FUNCTION-BASED APPROACH WITH @mcp_tool DECORATOR\n# =============================================================================\n# This is a simpler approach for straightforward tools that don\'t need\n# complex initialization or state management.\n\n@mcp_tool(\n name="simple_calculator",\n description="A simple calculator using the decorator approach"\n)\nasync def simple_calculator(\n operator: Annotated[str, "The arithmetic operation: add, subtract, multiply, or divide"],\n left_operand: Annotated[float, "The first number in the operation"],\n right_operand: Annotated[float, "The second number in the operation"]\n) -> dict:\n \'\'\'\n Perform basic arithmetic operations using the decorator approach.\n \n The @mcp_tool decorator automatically:\n 1. Registers the function as an MCP tool\n 2. Extracts parameter information from type annotations\n 3. Uses the Annotated descriptions for parameter documentation\n 4. 
Handles the MCP protocol communication\n \n This approach is ideal for:\n - Simple, stateless operations\n - Quick prototyping\n - Tools that don\'t need complex initialization\n \'\'\'\n \n if operator == "add":\n result = left_operand + right_operand\n elif operator == "subtract":\n result = left_operand - right_operand\n elif operator == "multiply":\n result = left_operand * right_operand\n elif operator == "divide":\n if right_operand == 0:\n raise ValueError("Cannot divide by zero")\n result = left_operand / right_operand\n else:\n raise ValueError(f"Unknown operator: {operator}")\n \n return {\n "operator": operator,\n "left_operand": left_operand,\n "right_operand": right_operand,\n "result": result\n }\n\n\n# =============================================================================\n# METHOD 2: CLASS-BASED APPROACH\n# =============================================================================\n# This is the more structured approach, ideal for complex tools that need\n# initialization, state management, or multiple related operations.\n\nclass CalculatorTool(BaseTool):\n """\n A simple calculator tool that performs basic arithmetic operations.\n \n This class demonstrates the class-based approach to creating MCP tools:\n 1. Inherit from BaseTool\n 2. Initialize with name and description in __init__\n 3. Implement execute() method that returns the callable function\n 4. 
Define the actual tool function with proper type annotations\n """\n \n def __init__(self):\n """\n Initialize the tool with metadata.\n \n The BaseTool constructor requires:\n - name: A unique identifier for the tool\n - description: A clear description of what the tool does\n """\n super().__init__(\n name="calculator",\n description="Performs basic arithmetic operations (add, subtract, multiply, divide)"\n )\n\n async def execute(self):\n """\n Return the callable function that implements the tool\'s functionality.\n \n This method is called by the MCP framework to get the actual function\n that will be executed when the tool is invoked.\n """\n return self.calculate\n \n async def calculate(\n self,\n operator: Annotated[str, "add, subtract, multiply, or divide"],\n left_operand: Annotated[float, "The first number"],\n right_operand: Annotated[float, "The second number"]\n ):\n """\n Execute the calculator operation.\n \n Parameter Type Annotations with Context:\n =======================================\n Notice the use of Annotated[type, "description"] for each parameter.\n This is OPTIONAL but highly recommended because it provides:\n \n 1. Type information for the MCP framework\n 2. Human-readable descriptions that help AI models understand\n what each parameter is for\n 3. 
Better error messages and validation\n \n The Annotated type comes from typing module and follows this pattern:\n Annotated[actual_type, "description_string"]\n \n Examples:\n - Annotated[str, "The operation to perform"]\n - Annotated[int, "A positive integer between 1 and 100"]\n - Annotated[list[str], "A list of file paths to process"]\n """ \n if operator == "add":\n result = left_operand + right_operand\n elif operator == "subtract":\n result = left_operand - right_operand\n elif operator == "multiply":\n result = left_operand * right_operand\n elif operator == "divide":\n if right_operand == 0:\n raise ValueError("Cannot divide by zero")\n result = left_operand / right_operand\n else:\n raise ValueError(f"Unknown operator: {operator}")\n \n return {\n "operator": operator,\n "left_operand": left_operand,\n "right_operand": right_operand,\n "result": result\n }'; @@ -47,29 +49,35 @@ export function McpWorkbenchManagementComponent (): ReactElement { // API hooks const { data: tools = [], isFetching: isLoadingTools, refetch } = useListMcpToolsQuery(); const [selectedToolId, setSelectedToolId] = useState(null); - const { data: selectedToolData, isFetching: isLoadingTool, } = useGetMcpToolQuery(selectedToolId!, { + const { data: selectedToolData, isUninitialized } = useGetMcpToolQuery(selectedToolId!, { + skip: selectedToolId === null, refetchOnMountOrArgChange: true, refetchOnFocus: true }); + const [loadingAce, setLoadingAce] = useState(true); + const [isDirty, setIsDirty] = useState(false); const [createToolMutation, { isLoading: isCreating }] = useCreateMcpToolMutation(); const [updateToolMutation, { isLoading: isUpdating }] = useUpdateMcpToolMutation(); const [deleteToolMutation] = useDeleteMcpToolMutation(); - - const [isDirty, setIsDirty] = useState(false); + const {colorScheme} = useContext(ColorSchemeContext); + const [statusText, setStatusText] = useState(''); const schema = z.object({ id: z.string().regex(/^[a-z0-9_.]+?(\.py)?$/).trim().min(3, 'String 
cannot be empty.'), contents: z.string().trim().min(1, 'String cannot be empty.'), }); - const { errors, touchFields, setFields, isValid, state } = useValidationReducer(schema, { - form: { id: `my_new_tool_${Date.now()}`, contents: DEFAULT_CONTENT} as IMcpTool, + const { errors, touchFields, setFields, isValid, state, setState } = useValidationReducer(schema, { + form: { } as Partial, formSubmitting: false, touched: {}, validateAll: false }); + const [validateMcpToolMutation, {isLoading: isLoadingValidateMcpTool, data: validMcpToolResponse} ] = useValidateMcpToolMutation(); + const [editor, setEditor] = useState(); + // Filtering and pagination state const [filterText, setFilterText] = useState(''); const [currentPageIndex, setCurrentPageIndex] = useState(1); @@ -87,6 +95,83 @@ export function McpWorkbenchManagementComponent (): ReactElement { currentPageIndex * pageSize ); + const [ waitingForValidation, setWaitingForValidation ] = useState(false); + + // Dont validate immediately, wait until this hasn't been called for 300ms + const debouncedValidation = useDebounce(useCallback((contents: string) => { + validateMcpToolMutation(contents).then((response) => { + setWaitingForValidation(false); + setStatusText(undefined); + + // Handle validation response + if ('data' in response) { + // Successful validation response + const validationData = response.data; + const annotations = []; + + // Add syntax error annotations + if (validationData.syntax_errors && validationData.syntax_errors.length > 0) { + validationData.syntax_errors.forEach((error) => { + annotations.push({ + row: Math.max(0, error.line - 1), // Ace editor is 0-indexed + column: error.column, + type: 'error', + text: `${error.type}: ${error.message}` + }); + }); + } + + if (validationData.missing_required_imports.length > 0) { + validationData.missing_required_imports.forEach((error) => { + annotations.push({ + row: 0, + column: 0, + type: 'error', + text: error + }); + }); + } + + // Apply annotations to 
editor + if (editor) { + editor.getSession().setAnnotations(annotations); + } + + } else if ('error' in response) { + // Error response from validation API + console.error('Validation API error:', response.error); + + // Clear any existing annotations + if (editor) { + editor.getSession().setAnnotations([]); + } + + // Show error notification + const errorMessage = 'data' in response.error && response.error.data?.message + ? response.error.data.message + : 'Unknown validation error'; + notificationService.generateNotification( + `Validation error: ${errorMessage}`, + 'error' + ); + } + }).catch((error) => { + // Handle promise rejection + console.error('Validation request failed:', error); + + // Clear any existing annotations + if (editor) { + editor.getSession().setAnnotations([]); + } + + // Show error notification + notificationService.generateNotification( + `Validation failed: ${error.message || 'Unknown error'}`, + 'error' + ); + }); + }, [validateMcpToolMutation, editor, notificationService]), 300); + // remove top breadcrumbs dispatch(setBreadcrumbs([])); @@ -95,83 +180,109 @@ export function McpWorkbenchManagementComponent (): ReactElement { setCurrentPageIndex(1); }, [filterText]); + useEffect(() => { + async function loadAce () { + await import('ace-builds'); + + // Import language modes you need + await import('ace-builds/src-noconflict/mode-python'); + + // Import themes + await import('ace-builds/src-noconflict/theme-cloud_editor'); + await import('ace-builds/src-noconflict/theme-cloud_editor_dark'); + + setLoadingAce(false); + } + + loadAce(); + }, []); + // Update editor content when a tool is selected useEffect(() => { - if (selectedToolId !== null && selectedToolData && isDirty === false) { + if (!isUninitialized && selectedToolData?.id) { setFields({ id: selectedToolData.id, contents: selectedToolData.contents, size: selectedToolData.size, - updated_at : selectedToolData.updated_at + updated_at: selectedToolData.updated_at }); - 
setIsDirty(true); + setIsDirty(false); + setStatusText(undefined); } - }, [selectedToolId, selectedToolData, setFields, isDirty]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isUninitialized, selectedToolData]); // Handle tool selection const handleToolSelect = (tool: IMcpTool) => { if (isDirty) { dispatch( setConfirmationModal({ - action: 'Switch Tool?', - resourceName: 'Unsaved change', + title: 'Switch Tool?', + action: 'Switch Tool', onConfirm: () => { setSelectedToolId(tool.id); - setIsDirty(false); + setStatusText('Loading MCP tool.'); }, - description: 'You have unsaved changes. Switching tools will lose these changes.' + description: 'You have unsaved changes. Switching tools will discard these changes.' }) ); } else { setSelectedToolId(tool.id); + setStatusText('Loading MCP tool.'); } }; - // Handle editor content change - const handleEditorChange = (value: string) => { - setFields({ - contents: value - }); - setIsDirty(true); - touchFields(['contents']); - }; - // Handle creating new tool - const handleCreateNew = () => { + const handleCreateNew = (event: CancellableEventHandler) => { + event.preventDefault(); + const newTool = { - id: ['my_new_tool', Date.now()].join('-'), + id: '', contents: DEFAULT_CONTENT, }; if (isDirty && selectedToolId) { dispatch( setConfirmationModal({ + title: 'Create New Tool?', action: 'Create New Tool', - resourceName: '', onConfirm: () => { setSelectedToolId(null); + touchFields(['id'], ModifyMethod.Unset); setFields(newTool); setIsDirty(false); }, - description: 'You have unsaved changes. Creating a new tool will lose these changes.' + description: 'You have unsaved changes. Creating a new tool will discard these changes.' 
}) ); } else { setSelectedToolId(null); + touchFields(['id'], ModifyMethod.Unset); setFields(newTool); - setIsDirty(true); + setIsDirty(false); } }; // Handle create tool const handleCreateTool = async () => { + const result = schema.safeParse(state.form); + if (!result.success) { + setState({ + ...state, + validateAll: true + }); + + return; + } + try { - await createToolMutation({ + const result = await createToolMutation({ id: state.form.id, contents: state.form.contents }).unwrap(); notificationService.generateNotification(`Successfully created tool: ${state.form.id}`, 'success'); + setSelectedToolId(result.id); setIsDirty(false); dispatch(mcpToolsApi.util.invalidateTags(['mcpTools'])); refetch(); @@ -218,7 +329,7 @@ export function McpWorkbenchManagementComponent (): ReactElement { id: '', contents: '', size: undefined, - updated_at : undefined + updated_at: undefined }); setIsDirty(false); } @@ -234,6 +345,14 @@ export function McpWorkbenchManagementComponent (): ReactElement { ); }; + const disabled = !isDirty || !isValid || (isLoadingValidateMcpTool || waitingForValidation) || !validMcpToolResponse?.is_valid; + const disabledReason = [ + {predicate: !isDirty, message: 'Tool has not been modified.'}, + {predicate: !isValid, message: 'Ensure all fields are valid.'}, + {predicate: isLoadingValidateMcpTool || waitingForValidation, message: 'Validating tool.'}, + {predicate: !validMcpToolResponse?.is_valid, message: 'Please correct all errors.'} + ].find((reason) => reason.predicate)?.message; + return ( */}
+ { + setFields({ + contents + }); + debouncedValidation(contents); + setStatusText('Validating MCP tool.'); + setIsDirty(true); + touchFields(['contents']); + + if (!waitingForValidation) { + setWaitingForValidation(true); + } + }} + onLoad={(editor) => { + setEditor(editor); + }} + width='100%' + /> +
+ + + + + { statusText ? +

{statusText}

+
: null} + + + {selectedToolId === null ? ( + + ) : ( + + )} +
+
+ :
+ +

Select an existing tool or Create Tool

+ {statusText ? +

{statusText}

+
: null} +
+ +
} ); diff --git a/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx index b5327ec01..ab34da769 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx @@ -32,84 +32,86 @@ export function AutoScalingConfig (props: AutoScalingConfigProps) : ReactElement Auto Scaling Capacity}> - + + + props.touchFields(['autoScalingConfig.blockDeviceVolumeSize'])} disabled={props.isEdit} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.blockDeviceVolumeSize': Number(detail.value) }); }}/> GBs - - + + props.touchFields(['autoScalingConfig.minCapacity'])} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.minCapacity': Number(detail.value) }); }}/> instances - - + + props.touchFields(['autoScalingConfig.maxCapacity'])} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.maxCapacity': Number(detail.value) }); }}/> instances - - + + props.touchFields(['autoScalingConfig.desiredCapacity'])} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.desiredCapacity': detail.value.trim().length > 0 ? 
Number(detail.value) : undefined }); }}/> instances - - + + props.touchFields(['autoScalingConfig.cooldown'])} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.cooldown': Number(detail.value) }); }}/> seconds - - + + props.touchFields(['autoScalingConfig.defaultInstanceWarmup'])} onChange={({ detail }) => { props.setFields({ 'autoScalingConfig.defaultInstanceWarmup': Number(detail.value) }); }}/> seconds - + Metric Config}> - props.touchFields(['autoScalingConfig.metricConfig.albMetricName'])} disabled={props.isEdit} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.albMetricName': detail.value }); - }}/> + props.touchFields(['autoScalingConfig.metricConfig.albMetricName'])} disabled={props.isEdit} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.albMetricName': detail.value }); + }}/> - props.touchFields(['autoScalingConfig.metricConfig.targetValue'])} disabled={props.isEdit} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.targetValue': Number(detail.value) }); - }}/> + props.touchFields(['autoScalingConfig.metricConfig.targetValue'])} disabled={props.isEdit} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.targetValue': Number(detail.value) }); + }}/> - - props.touchFields(['autoScalingConfig.metricConfig.duration'])} disabled={props.isEdit} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.duration': Number(detail.value) }); - }}/> - seconds - + + props.touchFields(['autoScalingConfig.metricConfig.duration'])} disabled={props.isEdit} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.duration': Number(detail.value) }); + }}/> + seconds + - - props.touchFields(['autoScalingConfig.metricConfig.estimatedInstanceWarmup'])} disabled={props.isEdit} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.estimatedInstanceWarmup': Number(detail.value) 
}); - }}/> - seconds - + + props.touchFields(['autoScalingConfig.metricConfig.estimatedInstanceWarmup'])} disabled={props.isEdit} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.estimatedInstanceWarmup': Number(detail.value) }); + }}/> + seconds + diff --git a/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx index 14e11839e..d0319892f 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx @@ -38,123 +38,125 @@ export function BaseModelConfig (props: FormProps & BaseModelConf return ( - props.touchFields(['modelId'])} onChange={({ detail }) => { - props.setFields({ 'modelId': detail.value }); - }} disabled={props.isEdit} placeholder='mistral-vllm'/> + props.touchFields(['modelId'])} onChange={({ detail }) => { + props.setFields({ 'modelId': detail.value }); + }} disabled={props.isEdit} placeholder='mistral-vllm'/> - props.touchFields(['modelName'])} onChange={({ detail }) => { - props.setFields({ 'modelName': detail.value }); - }} disabled={props.isEdit} placeholder='mistralai/Mistral-7B-Instruct-v0.2'/> + constraintText='The full model name is the repository path, or the third party model provider path. The path format typically will be: {ProviderPath}/{ProviderModelName}. 
Users do not see this value in the chat assistant user interface.'> + + props.touchFields(['modelName'])} onChange={({ detail }) => { + props.setFields({ 'modelName': detail.value }); + }} disabled={props.isEdit} placeholder='mistralai/Mistral-7B-Instruct-v0.2'/> + Model Description (Optional)} errorText={props.formErrors?.modelDescription}> - Model Description (optional)} errorText={props.formErrors?.modelDescription}> - props.touchFields(['modelDescription'])} onChange={({ detail }) => { - props.setFields({ 'modelDescription': detail.value }); - }} placeholder='Brief description of the model and its capabilities'/> + props.touchFields(['modelDescription'])} onChange={({ detail }) => { + props.setFields({ 'modelDescription': detail.value }); + }} placeholder='Brief description of the model and its capabilities'/> + {!props.item.lisaHostedModel && <>API Key (Optional)} errorText={props.formErrors?.apiKey}> - {!props.item.lisaHostedModel && API Key (optional)} errorText={props.formErrors?.apiKey}> - props.touchFields(['apiKey'])} onChange={({ detail }) => { - props.setFields({ 'apiKey': detail.value }); - }} disabled={props.isEdit}/> - } - Model URL (optional)} errorText={props.formErrors?.modelUrl}> - props.touchFields(['modelUrl'])} onChange={({ detail }) => { - props.setFields({ 'modelUrl': detail.value }); - }} disabled={props.isEdit}/> + props.touchFields(['apiKey'])} onChange={({ detail }) => { + props.setFields({ 'apiKey': detail.value }); + }} disabled={props.isEdit}/>} + Model URL (Optional)} errorText={props.formErrors?.modelUrl}> + props.touchFields(['modelUrl'])} onChange={({ detail }) => { + props.setFields({ 'modelUrl': detail.value }); + }} disabled={props.isEdit}/> - { + const fields = { + 'modelType': detail.selectedOption.value, + }; - // turn off streaming for embedded models - if (fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen) { - fields['streaming'] = false; - } + // turn off streaming for embedded models + 
if (fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen) { + fields['streaming'] = false; + } - // turn off summarization and image input for embedded and imagegen models - if ((fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen)) { - fields['features'] = props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION && feature.name !== ModelFeatures.IMAGE_INPUT && feature.name !== ModelFeatures.TOOL_CALLS); - } + // turn off summarization and image input for embedded and imagegen models + if ((fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen)) { + fields['features'] = props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION && feature.name !== ModelFeatures.IMAGE_INPUT && feature.name !== ModelFeatures.TOOL_CALLS); + } - props.setFields(fields); - }} - onBlur={() => props.touchFields(['modelType'])} - options={[ - { label: 'TEXTGEN', value: ModelType.textgen }, - { label: 'IMAGEGEN', value: ModelType.imagegen }, - { label: 'EMBEDDING', value: ModelType.embedding }, - ]} - disabled={props.isEdit} - /> - + props.setFields(fields); + }} + onBlur={() => props.touchFields(['modelType'])} + options={[ + { label: 'TEXTGEN', value: ModelType.textgen }, + { label: 'IMAGEGEN', value: ModelType.imagegen }, + { label: 'EMBEDDING', value: ModelType.embedding }, + ]} + disabled={props.isEdit} + /> {props.item.lisaHostedModel && ( <> - ({value: instance}))} + selectedOption={{value: props.item.instanceType}} + loadingText='Loading instances' + disabled={props.isEdit} + onBlur={() => props.touchFields(['instanceType'])} + onChange={({ detail }) => { + props.setFields({ 'instanceType': detail.selectedOption.value }); + }} + filteringType='auto' + statusType={ isLoadingInstances ? 
'loading' : 'finished'} + virtualScroll + /> - props.touchFields(['inferenceContainer'])} + onChange={({ detail }) => + props.setFields({ + 'inferenceContainer': detail.selectedOption.value, + }) + } + options={[ + { label: 'TGI', value: InferenceContainer.TGI }, + { label: 'TEI', value: InferenceContainer.TEI }, + { label: 'VLLM', value: InferenceContainer.VLLM }, + ]} + disabled={props.isEdit} + /> )} - + + + props.setFields({'streaming': detail.checked}) @@ -163,8 +165,10 @@ export function BaseModelConfig (props: FormProps & BaseModelConf disabled={isEmbeddingModel || isImageModel} checked={props.item.streaming} /> - - + + + + { if (detail.checked && props.item.features.find((feature) => feature.name === ModelFeatures.TOOL_CALLS) === undefined) { @@ -177,8 +181,10 @@ export function BaseModelConfig (props: FormProps & BaseModelConf onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.TOOL_CALLS) !== undefined} /> - - + + + + { if (detail.checked && props.item.features.find((feature) => feature.name === ModelFeatures.IMAGE_INPUT) === undefined) { @@ -191,9 +197,11 @@ export function BaseModelConfig (props: FormProps & BaseModelConf onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.IMAGE_INPUT) !== undefined} /> - - feature.name === ModelFeatures.SUMMARIZATION) !== undefined ? 'Ensure model context is large enough to support these requests.' : ''}> + + + feature.name === ModelFeatures.SUMMARIZATION) !== undefined ? 'Ensure model context is large enough to support these requests.' 
: ''}> + { if (detail.checked && props.item.features.find((feature) => feature.name === ModelFeatures.SUMMARIZATION) === undefined) { @@ -206,19 +214,19 @@ export function BaseModelConfig (props: FormProps & BaseModelConf onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.SUMMARIZATION) !== undefined} /> - + - - feature.name === ModelFeatures.SUMMARIZATION) !== undefined ? props.item.features.filter((feature) => feature.name === 'summarization')[0].overview : ''} inputMode='text' onBlur={() => props.touchFields(['features'])} onChange={({ detail }) => { - props.setFields({ 'features': [...props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION), {name: ModelFeatures.SUMMARIZATION, overview: detail.value}] }); - }} disabled={!props.item.features.find((feature) => feature.name === ModelFeatures.SUMMARIZATION)} placeholder='Optional overview of Summarization for Model'/> + Summarization Capabilities (Optional)} errorText={props.formErrors?.summarizationCapabilities}> + feature.name === ModelFeatures.SUMMARIZATION) !== undefined ? props.item.features.filter((feature) => feature.name === 'summarization')[0].overview : ''} inputMode='text' onBlur={() => props.touchFields(['features'])} onChange={({ detail }) => { + props.setFields({ 'features': [...props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION), {name: ModelFeatures.SUMMARIZATION, overview: detail.value}] }); + }} disabled={!props.item.features.find((feature) => feature.name === ModelFeatures.SUMMARIZATION)} placeholder='Overview of Summarization for Model'/> props.setFields({ 'allowedGroups': values })} - description='Restrict model access to specific groups. Leave empty to allow access to all users.' + constraintText='Restrict model access to specific groups. Leave empty to allow access to all users.' 
/> ); diff --git a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx index 3b20905a7..e8c86449c 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx @@ -38,43 +38,43 @@ export function ContainerConfig (props: ContainerConfigProps) : ReactElement { > - - props.touchFields(['containerConfig.sharedMemorySize'])} - onChange={({ detail }) => { - props.setFields({ 'containerConfig.sharedMemorySize': Number(detail.value) }); - }} - /> - MiB - - + props.touchFields(['containerConfig.image.baseImage'])} + value={props.item.sharedMemorySize.toString()} + type='number' + inputMode='numeric' + onBlur={() => props.touchFields(['containerConfig.sharedMemorySize'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.image.baseImage': detail.value }); + props.setFields({ 'containerConfig.sharedMemorySize': Number(detail.value) }); }} /> + MiB + + + props.touchFields(['containerConfig.image.baseImage'])} + onChange={({ detail }) => { + props.setFields({ 'containerConfig.image.baseImage': detail.value }); + }} + /> - props.touchFields(['containerConfig.image.type'])} + onChange={({ detail }) => { + props.setFields({ 'containerConfig.image.type': detail.selectedOption.value }); + }} + options={[ + { label: 'asset', value: EcsSourceType.ASSET, description: 'Base container image used to build model hosting image, e.g. 
\'vllm/vllm-openai\'' }, + { label: 'ecr', value: EcsSourceType.ECR, description: 'Prebuilt ECR image url used when deploying to ECS' }, + ]} + /> - - {props.item.healthCheckConfig.command.map((item, index) => - - props.touchFields(['containerConfig.healthCheckConfig.command'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.command' : props.item.healthCheckConfig.command.map((item, i) => i === index ? detail.value : item) }); - }}/> - - - )} - - + + {props.item.healthCheckConfig.command.map((item, index) => + + props.touchFields(['containerConfig.healthCheckConfig.command'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.command' : props.item.healthCheckConfig.command.map((item, i) => i === index ? detail.value : item) }); + }}/> + + + )} + + - - props.touchFields(['containerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.interval': Number(detail.value) }); - }}/> - seconds - + + props.touchFields(['containerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.interval': Number(detail.value) }); + }}/> + seconds + - - props.touchFields(['containerConfig.healthCheckConfig.startPeriod'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.startPeriod': Number(detail.value) }); - }}/> - seconds - + + props.touchFields(['containerConfig.healthCheckConfig.startPeriod'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.startPeriod': Number(detail.value) }); + }}/> + seconds + - - props.touchFields(['containerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.timeout': Number(detail.value) }); - }}/> - seconds - - - props.touchFields(['containerConfig.healthCheckConfig.retries'])} onChange={({ detail }) => { - props.setFields({ 
'containerConfig.healthCheckConfig.retries': Number(detail.value) }); + + props.touchFields(['containerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.timeout': Number(detail.value) }); }}/> + seconds + + + props.touchFields(['containerConfig.healthCheckConfig.retries'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.retries': Number(detail.value) }); + }}/> value !== undefined ); } + + // Handle guardrailsConfig if present + if (updateFields.guardrailsConfig !== undefined) { + // Send the complete current guardrails config from state.form, not just the diff + // This ensures all guardrails are preserved unless explicitly marked for deletion + updateRequest.guardrailsConfig = state.form.guardrailsConfig; + } + updateModelMutation(updateRequest); } } @@ -353,6 +363,15 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { onEdit: state.form.lisaHostedModel, forExternalModel: false }, + { + title: 'Guardrails Configuration', + description: 'Configure guardrails for your model (optional).', + content: ( + + ), + onEdit: true, + forExternalModel: true + }, { title: `Review and ${props.isEdit ? 'Update' : 'Create'}`, description: `Review configuration ${props.isEdit ? 'changes' : ''} prior to submitting.`, diff --git a/lib/user-interface/react/src/components/model-management/create-model/GuardrailsConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/GuardrailsConfig.tsx new file mode 100644 index 000000000..74a2bd964 --- /dev/null +++ b/lib/user-interface/react/src/components/model-management/create-model/GuardrailsConfig.tsx @@ -0,0 +1,283 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import React, { ReactElement, useState } from 'react'; +import { FormProps } from '../../../shared/form/form-props'; +import { IGuardrailsConfig, GuardrailMode } from '../../../shared/model/model-management.model'; +import { + Button, + Container, + FormField, + Grid, + Header, + Icon, + Input, + Select, + SpaceBetween, + Textarea, + TokenGroup +} from '@cloudscape-design/components'; + +type GuardrailsConfigProps = FormProps & { + isEdit: boolean; +}; + +export function GuardrailsConfig (props: GuardrailsConfigProps): ReactElement { + const guardrails = props.item || {}; + const guardrailEntries = Object.entries(guardrails); + const [groupInputValues, setGroupInputValues] = useState>({}); + + const addGuardrail = () => { + const newKey = `guardrail-${Date.now()}`; + const newGuardrails = { + ...guardrails, + [newKey]: { + guardrailName: '', + guardrailIdentifier: '', + guardrailVersion: 'DRAFT', + mode: GuardrailMode.PRE_CALL, + description: '', + allowedGroups: [], + } + }; + props.setFields({ 'guardrailsConfig': newGuardrails }); + }; + + const removeGuardrail = (key: string) => { + if (props.isEdit) { + // Mark for deletion instead of removing + const updatedGuardrails = { + ...guardrails, + [key]: { + ...guardrails[key], + markedForDeletion: true + } + }; + props.setFields({ 'guardrailsConfig': updatedGuardrails }); + } else { + // Remove completely for new models + const remainingGuardrails = Object.fromEntries( + Object.entries(guardrails).filter(([k]) => k !== key) + ); + + props.setFields({ 'guardrailsConfig': remainingGuardrails }); + } + 
props.touchFields(['guardrailsConfig']); + }; + + const updateGuardrail = (key: string, field: string, value: any) => { + const updatedGuardrails = { + ...guardrails, + [key]: { + ...guardrails[key], + [field]: value + } + }; + props.setFields({ 'guardrailsConfig': updatedGuardrails }); + }; + + const modeOptions = [ + { label: 'Pre Call', value: GuardrailMode.PRE_CALL, description: 'Execute guardrail before LLM call' }, + { label: 'During Call', value: GuardrailMode.DURING_CALL, description: 'Execute guardrail during LLM call' }, + { label: 'Post Call', value: GuardrailMode.POST_CALL, description: 'Execute guardrail after LLM call' } + ]; + + return ( + + + Add Guardrail + + } + > + Guardrails Configuration + + } + > + {guardrailEntries.length === 0 ? ( +
+ No guardrails configured. Click "Add Guardrail" to create one. +
+ ) : ( + + {guardrailEntries.map(([key, guardrail]) => { + // Skip rendering guardrails marked for deletion + if (guardrail.markedForDeletion) { + return null; + } + + return ( + +
+ {guardrail.guardrailName || 'New Guardrail'} +
+
+ +
+ + } + > + + + updateGuardrail(key, 'guardrailName', detail.value)} + onBlur={() => props.touchFields([`guardrailsConfig.guardrails.${key}.guardrailName`])} + placeholder='Enter guardrail name' + /> + + updateGuardrail(key, 'guardrailIdentifier', detail.value)} + onBlur={() => props.touchFields([`guardrailsConfig.guardrails.${key}.guardrailIdentifier`])} + placeholder='Enter guardrail identifier (ARN or ID)' + /> + + + updateGuardrail(key, 'guardrailVersion', detail.value)} + onBlur={() => props.touchFields([`guardrailsConfig.guardrails.${key}.guardrailVersion`])} + placeholder='Enter version (e.g., DRAFT, 1, 2)' + /> + + +