diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 42f7bdfae..a82c92632 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -3,22 +3,21 @@ pwd sudo apt-get update -y -sudo apt-get install -y jq yq +sudo apt-get install -y jq -make createPythonEnvironment +python3 -m venv .venv . .venv/bin/activate echo "source .venv/bin/activate" >> ~/.bashrc echo "source .venv/bin/activate" >> ~/.zshrc -echo "alias deploylisa='make clean && npm ci && make deploy HEADLESS=true'" >> ~/.bashrc -echo "alias deploylisa='make clean && npm ci && make deploy HEADLESS=true'" >> ~/.zshrc +echo "alias deploylisa='npm run clean && npm ci && HEADLESS=true npm run deploy'" >> ~/.bashrc +echo "alias deploylisa='npm run clean && npm ci && HEADLESS=true npm run deploy'" >> ~/.zshrc python -m pip install --upgrade pip -pip3 install yq huggingface_hub s5cmd -make installPythonRequirements +pip3 install huggingface_hub s5cmd +npm run install:python -make createTypeScriptEnvironment -make installTypeScriptRequirements +npm install git config --unset-all core.hooksPath pre-commit install diff --git a/.github/workflows/code.deploy.demo.yml b/.github/workflows/code.deploy.demo.yml index dcc30ce58..0bf6abad1 100644 --- a/.github/workflows/code.deploy.demo.yml +++ b/.github/workflows/code.deploy.demo.yml @@ -40,12 +40,18 @@ jobs: uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v4 with: node-version: 24.x - - name: Install CDK dependencies + - name: Install dependencies run: | npm ci + pip install -r requirements-dev.txt + pip install -e ./lisa-sdk + pip install -e lib/serve/mcp-workbench - name: Deploy LISA run: | - make deploy HEADLESS=true + npm run deploy + env: + HEADLESS: "true" + SKIP_INSTALL: "true" SendSlackNotification: name: Send Slack Notification needs: [ DeployLISA ] diff --git a/.github/workflows/code.deploy.dev.yml b/.github/workflows/code.deploy.dev.yml index 3f660a849..07758a9e4 100644 --- a/.github/workflows/code.deploy.dev.yml +++ b/.github/workflows/code.deploy.dev.yml @@ -40,12 +40,18 @@ jobs: uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v4 with: node-version: 24.x - - name: Install CDK dependencies + - name: Install dependencies run: | npm ci + pip install -r requirements-dev.txt + pip install -e ./lisa-sdk + pip install -e lib/serve/mcp-workbench - name: Deploy LISA run: | - make deploy HEADLESS=true + npm run deploy + env: + HEADLESS: "true" + SKIP_INSTALL: "true" SendSlackNotification: name: Send Slack Notification needs: [ DeployLISA ] diff --git a/.github/workflows/code.e2e-full-test.weekly.yml b/.github/workflows/code.e2e-full-test.weekly.yml new file mode 100644 index 000000000..587ba20b5 --- /dev/null +++ b/.github/workflows/code.e2e-full-test.weekly.yml @@ -0,0 +1,82 @@ +name: Weekly Full E2E Tests + +on: + schedule: + - cron: '0 2 * * 0' # Every Sunday at 02:00 UTC + workflow_dispatch: + inputs: + ref: + description: 'Branch or tag to test against' + required: false + default: 'develop' + type: string + workflow_call: + inputs: + ref: + description: 'Branch or tag to test against' + required: false + default: 'develop' + type: string + +permissions: + contents: read + +env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_URL }} + +jobs: + notify_full_e2e_start: + name: Starting Full E2E Tests + runs-on: ubuntu-latest + steps: + - name: Send "Full E2E Tests Starting" to Slack + uses: rtCamp/action-slack-notify@e31e87e03dd19038e411e38ae27cbad084a90661 # v2 + env: + SLACK_TITLE: 'Full E2E Tests Starting' + MSG_MINIMAL: true + SLACK_MESSAGE: 'Full E2E test suite has started on ref `${{ inputs.ref || github.ref_name }}`...' + + full-e2e: + name: Run Full E2E Tests + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 + with: + ref: ${{ inputs.ref || 'develop' }} + - name: Setup Node.js + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v4 + with: + node-version: '24' + cache: 'npm' + - name: Install base dependencies + run: npm ci + - name: Run Cypress Full E2E Suite + env: + ADMIN_USER_NAME: ${{ secrets.ADMIN_USER_NAME }} + ADMIN_PASSWORD: ${{ secrets.ADMIN_PASSWORD }} + USER_NAME: ${{ secrets.USER_NAME }} + USER_PASSWORD: ${{ secrets.USER_PASSWORD }} + run: npx cypress run --config-file cypress/cypress.e2e.config.ts + - name: Archive Cypress videos & screenshots + if: always() + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v4 + with: + name: cypress-full-e2e-artifacts + path: | + cypress/videos/e2e + cypress/screenshots/e2e + + notify_full_e2e_end: + name: Full E2E Tests Finished + runs-on: ubuntu-latest + needs: full-e2e + if: always() + steps: + - name: Notify Full E2E results to Slack + uses: rtCamp/action-slack-notify@e31e87e03dd19038e411e38ae27cbad084a90661 # v2 + env: + SLACK_COLOR: ${{ needs.full-e2e.result == 'success' && 'good' || 'danger' }} + SLACK_TITLE: 'Full E2E Tests Finished' + MSG_MINIMAL: false + SLACK_MESSAGE: ${{ needs.full-e2e.result == 'success' && format('Full E2E test suite passed on ref `{0}`.', inputs.ref || github.ref_name) || format(' Full E2E test suite {0} on ref `{1}`.', needs.full-e2e.result, inputs.ref || github.ref_name) }} diff --git a/.github/workflows/code.end-to-end-test.nightly.yml b/.github/workflows/code.end-to-end-test.nightly.yml index 236a0f0c5..9c946f93c 100644 --- a/.github/workflows/code.end-to-end-test.nightly.yml +++ b/.github/workflows/code.end-to-end-test.nightly.yml @@ -1,4 +1,4 @@ -name: Nightly E2E Tests +name: Nightly E2E Health Check on: schedule: @@ -19,18 +19,18 @@ jobs: - name: Send β€œE2E Tests Starting” to Slack uses: rtCamp/action-slack-notify@e31e87e03dd19038e411e38ae27cbad084a90661 # v2 env: - SLACK_TITLE: 'E2E Tests Starting' + SLACK_TITLE: 'Nightly E2E Health Check Starting' MSG_MINIMAL: true - SLACK_MESSAGE: 'E2E tests have started…' + SLACK_MESSAGE: 'Nightly E2E health check (quick specs) has started...' e2e: name: πŸƒβ€β™€οΈ Run E2E Tests runs-on: ubuntu-latest - needs: notify_e2e_start + timeout-minutes: 15 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: - ref: develop + ref: develop - name: Setup Node.js uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v4 with: @@ -44,9 +44,14 @@ jobs: ADMIN_PASSWORD: ${{ secrets.ADMIN_PASSWORD }} USER_NAME: ${{ secrets.USER_NAME }} USER_PASSWORD: ${{ secrets.USER_PASSWORD }} - run: npx cypress run --config-file cypress/cypress.e2e.config.ts + # Quick specs only β€” excludes 000-cleanup and bedrock-model-workflow (long-running infra tests). + # Update this list when adding new quick E2E specs. + run: >- + npx cypress run + --config-file cypress/cypress.e2e.config.ts + --spec "cypress/src/e2e/specs/admin.e2e.spec.ts,cypress/src/e2e/specs/user.e2e.spec.ts,cypress/src/e2e/specs/chat.e2e.spec.ts,cypress/src/e2e/specs/bedrock-quick.e2e.spec.ts" - name: Archive Cypress videos & screenshots - if: failure() || always() + if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v4 with: name: cypress-e2e-artifacts @@ -64,8 +69,6 @@ jobs: uses: rtCamp/action-slack-notify@e31e87e03dd19038e411e38ae27cbad084a90661 # v2 env: SLACK_COLOR: ${{ needs.e2e.result == 'success' && 'good' || 'danger' }} - SLACK_TITLE: 'E2E Tests Finished' + SLACK_TITLE: 'Nightly E2E Health Check Finished' MSG_MINIMAL: false - SLACK_MESSAGE_ON_SUCCESS: 'βœ… E2E tests passed on branch `${{ github.ref_name }}`.' - SLACK_MESSAGE_ON_FAILURE: ' ❌ E2E tests failed on branch `${{ github.ref_name }}`.' - SLACK_MESSAGE: 'E2E tests completed with status `${{ job.status }}`.' + SLACK_MESSAGE: ${{ needs.e2e.result == 'success' && format('Nightly E2E health check passed on branch `{0}`.', github.ref_name) || format(' Nightly E2E health check {0} on branch `{1}`.', needs.e2e.result, github.ref_name) }} diff --git a/.github/workflows/code.release.branch.yml b/.github/workflows/code.release.branch.yml index 211013afb..5bd174cec 100644 --- a/.github/workflows/code.release.branch.yml +++ b/.github/workflows/code.release.branch.yml @@ -96,3 +96,11 @@ jobs: env: GH_TOKEN: ${{ github.token }} GITHUB_TOKEN: ${{ secrets.LEAD_ACCESS_TOKEN }} + + run_full_e2e: + name: Run Full E2E on Release Branch + needs: MakeNewReleaseBranch + uses: ./.github/workflows/code.e2e-full-test.weekly.yml + with: + ref: release/${{ github.event.inputs.release_tag }} + secrets: inherit # pragma: allowlist secret diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 1ed8551e6..46e4802a4 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -64,22 +64,14 @@ jobs: uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: "3.13" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # Try hash-verified install first, fall back to regular - if [ -f "requirements-dev-hashes.txt" ]; then - pip install --require-hashes -r requirements-dev-hashes.txt - else - pip install -r requirements-dev.txt - fi - pip install -e ./lisa-sdk + - name: Install Python dependencies + run: npm run install:python - name: Run tests env: ACCOUNT_NUMBER: '012345678901' REGION: us-east-1 run: | - make test-coverage + npm run test:python:coverage pre-commit: name: Run All Pre-Commit needs: [send_starting_slack_notification] diff --git a/.gitignore b/.gitignore index d1e3a8ea0..5281f6d80 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ __pycache__/ *.key *.pem TIKTOKEN_CACHE +# Ignore only top-level docs directory, not lib/docs +/docs/ # CDK asset staging directory .cdk.staging diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 33a78ef94..15d9b1297 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,9 +7,9 @@ repos: - id: verify-config name: Verify config file description: Verify config file to check if certain parameters are empty - entry: scripts/verify-config.sh + entry: node scripts/verify-config.mjs verbose: true - language: script + language: system files: config-base.yaml - repo: https://github.com/PyCQA/bandit @@ -90,7 +90,7 @@ repos: args: - --max-line-length=120 - --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends - - --ignore=B008,B042,E203,W503 # Ignore error for function calls in argument defaults and exception __init__ args + - --ignore=B008,B042,E203,E704,W503 # Ignore error for function calls in argument defaults and exception __init__ args exclude: ^(__init__.py$|.*\/__init__.py$|test/cdk/stacks/__baselines__/) diff --git a/CHANGELOG.md b/CHANGELOG.md index e4909afa1..a25aa4e01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,43 @@ +# v6.5.0 + +## Key Features + +### Self-Service RAG Administration + +A new RAG Admin role gives designated users full control over RAG repository operations, document ingestion, collection management, and pipeline configuration without granting full system administrator privileges. This reduces the operational bottleneck where every RAG change required a system administrator. Self-service RAG is especially useful in multi-tenant environments. + +### Operations Metrics Dashboard + +New dashboard reports track metrics across models and clusters, including inference latency, token usage, and batch ingestion job status. For example, customers can use the new input/output token reports to derive costs across users, groups, and models. This is useful for multi-tenant environments with a variety of end-user orgs. Also, model containers publish Prometheus metrics for vLLM, TEI, and TGI, and batch ingestion jobs report totals and failures for RAG document ingestion. + +### Integrating Externally Deployed Models + +Administrators can register deployed models that are not LISA-managed by providing a URL that can be accessed from the LiteLLM ECS cluster. These models appear and behave like other models in the platform. + +### AWS Session Credentials + +LISA now lets you attach AWS credentials to a chat session. While that session is active, MCP tools can use those credentials to call AWS APIs, so tool-based workflows can reach AWS resources in the same context as the conversation instead of requiring separate per-tool setup. + +An example of a tool using this can be seen: lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py + +## Other Key Changes + +- Updated OpenSearch for new RAG collections to the latest supported version and indexing engine (existing collections continue to work as before) +- Introduced optional audit logging for input/output from requests to LISA with opt-in and filtering +- Implemented a deployment Lambda to ensure configured models are present in LiteLLM +- Split Cypress E2E workflows into nightly health checks and weekly full suite runs, with API-based resource cleanup between runs +- Updated LiteLLM to version 1.82.4 + +## Acknowledgements +* @bedanley +* @drduhe +* @Ernest-Gray +* @estohlmann +* @gingerknight +* @jmharold + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v6.4.0..v6.5.0 + # v6.4.0 ## Key Features diff --git a/Makefile b/Makefile deleted file mode 100644 index e3ba78f73..000000000 --- a/Makefile +++ /dev/null @@ -1,482 +0,0 @@ -.PHONY: \ - bootstrap createPythonEnvironment installPythonRequirements \ - createTypeScriptEnvironment installTypeScriptRequirements \ - deploy destroy \ - clean cleanTypeScript cleanPython cleanCfn cleanMisc \ - help dockerCheck dockerLogin listStacks modelCheck buildNpmModules \ - test test-coverage test-lambda test-mcp-workbench test-sdk test-rest-api test-sdk-integ test-integ test-rag-integ test-metadata-integ \ - lock-poetry validate-deps - -################################################################################# -# GLOBALS # -################################################################################# - -PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) -HEADLESS = false -DOCKER_CMD ?= $(or $(CDK_DOCKER),docker) - -# Function to read config with fallback to base config and default value -# Usage: VAR := $(call get_config,property,default_value) -define get_config -$(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq -r $(1) $(PROJECT_DIR)/config-custom.yaml 2>/dev/null | grep -v '^null$$' || \ - (test -f $(PROJECT_DIR)/config-base.yaml && yq -r $(1) $(PROJECT_DIR)/config-base.yaml 2>/dev/null | grep -v '^null$$') || \ - echo "$(2)") -endef - -# PROFILE (optional argument) -ifeq (${PROFILE},) -PROFILE := $(call get_config,.profile,) -ifeq ($(PROFILE),) -$(warning profile is not set in command line using PROFILE variable or config files, attempting deployment without this variable) -endif -endif - -# DEPLOYMENT_NAME -ifeq (${DEPLOYMENT_NAME},) -DEPLOYMENT_NAME := $(call get_config,.deploymentName,prod) -endif - -# ACCOUNT_NUMBER -ifeq (${ACCOUNT_NUMBER},) -ACCOUNT_NUMBER := $(call get_config,.accountNumber,) -endif - -ifeq (${ACCOUNT_NUMBER},) -$(error accountNumber must be set in command line using ACCOUNT_NUMBER variable or config files) -endif - -# REGION -ifeq (${REGION},) -REGION := $(call get_config,.region,) -endif - -ifeq (${REGION},) -$(error region must be set in command line using REGION variable or config files) -endif - -# PARTITION -ifeq (${PARTITION},) -PARTITION := $(call get_config,.partition,aws) -endif - -# DOMAIN - used for the docker login -ifeq (${DOMAIN},) -ifeq ($(findstring isob,${REGION}),isob) -DOMAIN := sc2s.sgov.gov -else ifeq ($(findstring iso,${REGION}),iso) -DOMAIN := c2s.ic.gov -else -DOMAIN := amazonaws.com -endif -endif - -# Arguments defined through config files - -# APP_NAME -APP_NAME := $(call get_config,.appName,lisa) - -# DEPLOYMENT_STAGE -DEPLOYMENT_STAGE := $(call get_config,.deploymentStage,prod) - -# ACCOUNT_NUMBERS_ECR - AWS account numbers that need to be logged into with Docker CLI to use ECR -ACCOUNT_NUMBERS_ECR := $(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq '.accountNumbersEcr[]' $(PROJECT_DIR)/config-custom.yaml 2>/dev/null || echo "") - -# Append deployed account number to array for dockerLogin rule -ACCOUNT_NUMBERS_ECR := $(ACCOUNT_NUMBERS_ECR) $(ACCOUNT_NUMBER) - -# STACK -ifeq ($(STACK),) - STACK := $(DEPLOYMENT_STAGE)/* -endif - -ifneq ($(findstring $(DEPLOYMENT_STAGE),$(STACK)),$(DEPLOYMENT_STAGE)) - override STACK := $(DEPLOYMENT_STAGE)/$(STACK) -endif - -# MODEL_IDS - IDs of models to deploy -MODEL_IDS := $(shell test -f $(PROJECT_DIR)/config-custom.yaml && yq '.ecsModels[].modelName' $(PROJECT_DIR)/config-custom.yaml 2>/dev/null || echo "") - -# MODEL_BUCKET - S3 bucket containing model artifacts -MODEL_BUCKET := $(call get_config,.s3BucketModels,) - -# BASE_URL - Base URL for web UI assets based on domain name and deployment stage -DOMAIN_NAME := $(call get_config,.apiGatewayConfig.domainName,) -ifeq ($(DOMAIN_NAME),) -BASE_URL := /$(DEPLOYMENT_STAGE)/ -else -BASE_URL := / -endif - -################################################################################# -# COMMANDS # -################################################################################# - -## Bootstrap AWS Account with CDK bootstrap -bootstrap: - @printf "Bootstrapping: $(ACCOUNT_NUMBER) | $(REGION) | $(PARTITION)\n" - -ifdef PROFILE - @npx cdk bootstrap \ - --profile $(PROFILE) \ - aws://$(ACCOUNT_NUMBER)/$(REGION) \ - --partition $(PARTITION) \ - --cloudformation-execution-policies arn:$(PARTITION):iam::aws:policy/AdministratorAccess -else - @npx cdk bootstrap \ - aws://$(ACCOUNT_NUMBER)/$(REGION) \ - --partition $(PARTITION) \ - --cloudformation-execution-policies arn:$(PARTITION):iam::aws:policy/AdministratorAccess -endif - - -## Set up Python interpreter environment to match LISA deployed version -createPythonEnvironment: - python3.13 -m venv .venv - @printf ">>> New virtual environment created. To activate run: 'source .venv/bin/activate'" - - -## Install Python dependencies for development -installPythonRequirements: - CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install pip --upgrade - CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install --prefer-binary -r requirements-dev.txt - CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install -e lisa-sdk - CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install -e lib/serve/mcp-workbench - -## Set up TypeScript interpreter environment -createTypeScriptEnvironment: - npm init - - -## Install TypeScript dependencies for development -installTypeScriptRequirements: - npm install - -install: installPythonRequirements installTypeScriptRequirements - -## Make sure Docker is running -dockerCheck: - @cmd_output=$$($(DOCKER_CMD) ps); \ - if [ $$? != 0 ]; then \ - echo "Process $(DOCKER_CMD) is not running. Exiting..."; \ - exit 1; \ - fi; \ - - -## Check if models are uploaded -modelCheck: - @access_token=""; \ - for MODEL_ID in $(MODEL_IDS); do \ - $(PROJECT_DIR)/scripts/check-for-models.sh -m $$MODEL_ID -s $(MODEL_BUCKET); \ - if [ $$? != 0 ]; then \ - localModelDir="./models"; \ - if [ ! -d "$$localModelDir" ]; then \ - mkdir "$$localModelDir"; \ - fi; \ - echo; \ - echo "Preparing to download, convert, and upload safetensors for model: $$MODEL_ID"; \ - echo "Local directory: '$$localModelDir' will be used to store downloaded and converted model weights"; \ - echo "Note: sudo privileges required to remove model dir due to docker mount using root"; \ - echo "Would you like to continue? [y/N] "; \ - read confirm_download; \ - if [ $${confirm_download:-'N'} = 'y' ]; then \ - mkdir -p $$localModelDir; \ - if [ -z "$$access_token" ]; then \ - if [ -n "$$HUGGINGFACE_TOKEN" ]; then \ - access_token="$$HUGGINGFACE_TOKEN"; \ - elif [ -f ".hf_token_cache" ]; then \ - access_token=$$(cat .hf_token_cache); \ - else \ - echo "What is your huggingface access token? "; \ - read access_token; \ - echo "$$access_token" > .hf_token_cache; \ - fi; \ - fi; \ - echo "Converting and uploading safetensors for model: $$MODEL_ID"; \ - tgiImage=$$(yq -r '[.ecsModels[] | select(.inferenceContainer == "tgi") | .baseImage] | first' $(PROJECT_DIR)/config-custom.yaml); \ - if [ "$$tgiImage" = "null" ] || [ -z "$$tgiImage" ]; then \ - tgiImage="ghcr.io/huggingface/text-generation-inference:latest"; \ - fi; \ - echo "Using TGI image: $$tgiImage"; \ - $(PROJECT_DIR)/scripts/convert-and-upload-model.sh -m $$MODEL_ID -s $(MODEL_BUCKET) -a $$access_token -t $$tgiImage -d $$localModelDir; \ - fi; \ - fi; \ - done - -## Run all clean commands -clean: cleanTypeScript cleanPython cleanCfn cleanMisc - - -## Delete all compiled Python files and related artifacts -cleanPython: - @find . -type f -name "*.py[co]" -delete - @find . -type d -name "__pycache__" -exec rm -rf {} + - @find . -type d -name ".pytest_cache" -exec rm -rf {} + - @find . -type d -name "*.egg-info" -exec rm -rf {} + - @find . -type d -name "dist" -exec rm -rf {} + - @find . -type d -name ".mypy_cache" -exec rm -rf {} + - @find . -type d -name ".tox" -exec rm -rf {} + - - -## Delete TypeScript artifacts and related folders -cleanTypeScript: - @find . -type f -name "*.js.map" -delete - @find . -type d -name "dist" -exec rm -rf {} + - @find . -type d -name "build" -exec rm -rf {} + - @find . -type d -name ".tscache" -exec rm -rf {} + - @find . -type d -name ".jest_cache" -exec rm -rf {} + - @find . -type d -name "node_modules" -exec rm -rf {} + - @find . -type d -name "cdk.out" -exec rm -rf {} + - @find . -type d -name "coverage" -exec rm -rf {} + - - -## Delete CloudFormation outputs -cleanCfn: - @find . -type d -name "cdk.out" -exec rm -rf {} + - - -## Delete all misc files -cleanMisc: - @find . -type f -name "*.DS_Store" -delete - @rm -f .hf_token_cache - - -## Login Docker CLI to Amazon Elastic Container Registry -dockerLogin: dockerCheck -ifdef PROFILE - @$(foreach ACCOUNT,$(ACCOUNT_NUMBERS_ECR), \ - aws ecr get-login-password --region ${REGION} --profile ${PROFILE} | $(DOCKER_CMD) login --username AWS --password-stdin ${ACCOUNT}.dkr.ecr.${REGION}.${DOMAIN} >/dev/null 2>&1; \ - ) -else - @$(foreach ACCOUNT,$(ACCOUNT_NUMBERS_ECR), \ - aws ecr get-login-password --region ${REGION} | $(DOCKER_CMD) login --username AWS --password-stdin ${ACCOUNT}.dkr.ecr.${REGION}.${DOMAIN} >/dev/null 2>&1; \ - ) -endif - - -listStacks: - @npx cdk list - -buildNpmModules: - BASE_URL=$(BASE_URL) npm run build - -buildArchive: - BUILD_ASSETS=true npm run build - -define print_config - @printf "\n \ - DEPLOYING $(STACK) STACK APP INFRASTRUCTURE \n \ - -----------------------------------\n \ - Account Number $(ACCOUNT_NUMBER)\n \ - Region $(REGION)\n \ - Partition $(PARTITION)\n \ - Domain $(DOMAIN)\n \ - App Name $(APP_NAME)\n \ - Deployment Stage $(DEPLOYMENT_STAGE)\n \ - Deployment Name $(DEPLOYMENT_NAME)" - @if [ -n "$(PROFILE)" ]; then \ - printf "\n Deployment Profile $(PROFILE)"; \ - fi - @printf "\n-----------------------------------\n" -endef - -## Deploy all infrastructure -deploy: install dockerCheck dockerLogin cleanMisc modelCheck buildNpmModules - $(call print_config) -ifeq ($(HEADLESS),true) - npx cdk deploy ${STACK} $(if $(PROFILE),--profile ${PROFILE}) --require-approval never -c ${ENV}='$(shell echo '${${ENV}}')'; -else - @printf "Is the configuration correct? [y/N] "\ - && read confirm_config &&\ - if [ $${confirm_config:-'N'} = 'y' ]; then \ - npx cdk deploy ${STACK} $(if $(PROFILE),--profile ${PROFILE}) -c ${ENV}='$(shell echo '${${ENV}}')'; \ - fi; -endif - - -## Tear down all infrastructure -destroy: cleanMisc - $(call print_config) -ifeq ($(HEADLESS),true) - npx cdk destroy ${STACK} --force $(if $(PROFILE),--profile ${PROFILE}); -else - @printf "Is the configuration correct? [y/N] "\ - && read confirm_config &&\ - if [ $${confirm_config:-'N'} = 'y' ]; then \ - npx cdk destroy ${STACK} --force $(if $(PROFILE),--profile ${PROFILE}); \ - fi; -endif - - - -################################################################################# -# SELF DOCUMENTING COMMANDS # -################################################################################# - -.DEFAULT_GOAL := help - -# Inspired by -# sed script explained: -# /^##/: -# * save line in hold space -# * purge line -# * Loop: -# * append newline + line to hold space -# * go to next line -# * if line starts with doc comment, strip comment character off and loop -# * remove target prerequisites -# * append hold space (+ newline) to line -# * replace newline plus comments by `---` -# * print line -# Separate expressions are necessary because labels cannot be delimited by -# semicolon; see - -help: - @echo "$$(tput bold)Available rules:$$(tput sgr0)" - @echo - @sed -n -e "/^## / { \ - h; \ - s/.*//; \ - :doc" \ - -e "H; \ - n; \ - s/^## //; \ - t doc" \ - -e "s/:.*//; \ - G; \ - s/\\n## /---/; \ - s/\\n/ /g; \ - p; \ - }" ${MAKEFILE_LIST} \ - | LC_ALL='C' sort --ignore-case \ - | awk -F '---' \ - -v ncol=$$(tput cols) \ - -v indent=35 \ - -v col_on="$$(tput setaf 6)" \ - -v col_off="$$(tput sgr0)" \ - '{ \ - printf "%s%*s%s ", col_on, -indent, $$1, col_off; \ - n = split($$2, words, " "); \ - line_length = ncol - indent; \ - for (i = 1; i <= n; i++) { \ - line_length -= length(words[i]) + 1; \ - if (line_length <= 0) { \ - line_length = ncol - indent - length(words[i]) - 1; \ - printf "\n%*s ", -indent, " "; \ - } \ - printf "%s ", words[i]; \ - } \ - printf "\n"; \ - }' \ - | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars') - -## Run all Python unit tests (non-integration) with coverage report -test-coverage: - @echo "Running lambda tests with coverage..." - @pytest test/lambda --verbose \ - --cov lambda \ - --cov-report term-missing \ - --cov-report html:build/coverage \ - --cov-report xml:build/coverage/coverage.xml \ - --cov-fail-under 83 - @echo "" - @echo "Running MCP Workbench tests with coverage..." - @pytest test/mcp-workbench --verbose \ - --cov lib/serve/mcp-workbench/src \ - --cov-report term-missing \ - --cov-report html:build/coverage-mcp \ - --cov-report xml:build/coverage-mcp/coverage.xml \ - --cov-append \ - --cov-fail-under 83 - @echo "" - @echo "Running SDK tests with coverage..." - @pytest test/sdk --verbose \ - --cov lisa-sdk/lisapy \ - --cov-report term-missing \ - --cov-report html:build/coverage-sdk \ - --cov-report xml:build/coverage-sdk/coverage.xml \ - --cov-append \ - --cov-fail-under 80 - @echo "" - @echo "Running REST API tests with coverage..." - @pytest test/rest-api --verbose \ - --cov lib/serve/rest-api/src \ - --cov-config lib/serve/rest-api/.coveragerc \ - --cov-report term-missing \ - --cov-report html:build/coverage-rest-api \ - --cov-report xml:build/coverage-rest-api/coverage.xml \ - --cov-append \ - --cov-fail-under 80 - - -## Run all Python unit tests (non-integration) without coverage -test: - @echo "Running lambda tests..." - @pytest test/lambda --verbose - @echo "" - @echo "Running MCP Workbench tests..." - @pytest test/mcp-workbench --verbose - @echo "" - @echo "Running SDK tests..." - @pytest test/sdk --verbose - @echo "" - @echo "Running REST API tests..." - @pytest test/rest-api --verbose - -## Run lambda tests only -test-lambda: - pytest test/lambda --verbose - -## Run MCP Workbench tests only -test-mcp-workbench: - pytest test/mcp-workbench --verbose - -## Run LISA SDK unit tests only -test-sdk: - pytest test/sdk --verbose - -## Run REST API unit tests only -test-rest-api: - pytest test/rest-api --verbose - -## Run LISA SDK integration tests (requires deployed LISA environment) -test-sdk-integ: - @echo "Running LISA SDK integration tests..." - @echo "Note: These tests require a deployed LISA environment with:" - @echo " - --api or --url argument for API endpoint" - @echo " - --region, --deployment, --profile arguments" - @echo " - AWS credentials configured" - @echo "" - @echo "Example: pytest test/integration/sdk --api https://your-api.com --region us-west-2" - @echo "" - pytest test/integration/sdk --verbose - -## Run integration tests (Python-based) -test-integ: - pytest test/python --verbose - -## Run RAG integration tests (requires deployed LISA environment) -test-rag-integ: - @echo "Running RAG integration tests..." - @echo "Note: These tests require a deployed LISA environment with:" - @echo " - LISA_API_URL environment variable set" - @echo " - LISA_DEPLOYMENT_NAME environment variable set" - @echo " - AWS credentials configured" - @echo "" - pytest test/integration --verbose - -## Run repository metadata preservation integration tests -test-metadata-integ: - pytest test/integration/test_repository_update_metadata_preservation.py --verbose - -## Regenerate all Poetry lock files -lock-poetry: - @echo "Regenerating Poetry lock files..." - @cd lisa-sdk && poetry lock && echo "βœ“ lisa-sdk/poetry.lock updated" - -## Validate all requirements files can be installed -validate-deps: - @echo "Validating requirements files..." - @for req in $$(find . -name "requirements*.txt" -not -path "./node_modules/*" -not -path "./.venv/*"); do \ - echo "Checking $$req..."; \ - pip-compile --dry-run --quiet $$req 2>&1 | grep -i "error\|conflict" && echo "βœ— $$req has conflicts" || echo "βœ“ $$req is valid"; \ - done diff --git a/README.md b/README.md index 365d2dd40..05a08a521 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,10 @@ # LLM Inference Solution for Amazon Dedicated Cloud (LISA) + [![Full Documentation](https://img.shields.io/badge/Full%20Documentation-blue?style=for-the-badge&logo=Vite&logoColor=white)](https://awslabs.github.io/LISA/) [![Contact Us](https://img.shields.io/badge/Contact%20Us-green?style=for-the-badge&logo=maildotru&logoColor=white)](mailto:lisa-product-team@amazon.com) + ## What is LISA? + Our large language model (LLM) inference solution for the Amazon Dedicated Cloud (ADC), LISA, is open source infrastructure-as-code. Customers deploy it directly into an Amazon Web Services (AWS) account in any region. LISA is scalable and ready to support production use cases. LISA accelerates GenAI adoption by offering built-in configurability with Amazon Bedrock models, Knowledge Bases, and Guardrails. Also by offering advanced capabilities like an optional enterprise-ready chat user interface (UI) with configurable features, authentication, resource access control, centralized model orchestration via LiteLLM, model self-hosting via Amazon ECS, retrieval augmented generation (RAG), APIs, and broad model context protocol (MCP) support and features. LISA is also compatible with OpenAI’s API specification making it easily configurable with supporting solutions. For example, the Continue plugin for VSCode and JetBrains integrated development environments (IDE). @@ -9,6 +12,7 @@ LISA accelerates GenAI adoption by offering built-in configurability with Amazon LISA's roadmap is customer-driven, with new capabilities launching monthly. Reach out to the product team to ask questions, provide feedback, and send feature requests via the "Contact Us" button above. ## Key Features + * **Open Source**: No subscription or licensing fees. LISA costs are based on service usage. * **Ongoing Releases**: The product roadmap is customer-driven with releases typically every 2-4 weeks. LISA is backed by a software development team that builds production grade solutions to accelerate customers' GenAI adoption. * **Model Flexibility**: Bring your own models for self-hosting, or quickly configure LISA with 100+ models supported by third-party model providers, including Amazon Bedrock and Jumpstart. @@ -18,30 +22,40 @@ LISA's roadmap is customer-driven, with new capabilities launching monthly. Reac * **FedRAMP**: Leverages FedRAMP High compliant services. ## Major Components + LISA’s four major components include Serve, a Chat UI, RAG, and MCP. LISA Serve and LISA MCP are standalone, foundational core solutions with APIs for customers not leveraging LISA’s Chat UI. Both LISA’s Chat UI and RAG are optional components, but must be used with Serve. Read more in the Architecture Overview section of LISA's documentation site linked above. ## Deployment Prerequisites + ### Pre-Deployment Steps + * Set up or have access to an AWS account. * Ensure that your AWS account has the appropriate permissions. Resource creation during the AWS CDK deployment expects Administrator or Administrator-like permissions, to include resource creation and mutation permissions. Installation will not succeed if this profile does not have permissions to create and edit arbitrary resources for the system. This level of permissions is not required for the runtime of LISA. This is only necessary for deployment and subsequent updates. * If using the chat UI, have your Identity Provider (IdP) information available, and access. * If using an existing VPC, have its information available. * Familiarity with AWS Cloud Development Kit (CDK) and infrastructure-as-code principles is a plus. * AWS CDK and Model Management both leverage AWS Systems Manager Agent (SSM) parameter store. Confirm that SSM is approved for use by your organization before beginning. If you're new to CDK, review the [AWS CDK Documentation](https://docs.aws.amazon.com/cdk/v2/guide/home.html) and consult with your AWS support team. + ### Software + * AWS CLI installed and configured * Python 3.13 * Node.js 24 * Docker installed and running * Sufficient disk space for model downloads and conversions + ## Getting Started + For detailed instructions on setting up, configuring, and deploying LISA, please refer to our separate documentation on installation and usage. -- [Deployment Guide](lib/docs/admin/getting-started.md) -- [Configuration](lib/docs/config/configuration.md) + +* [Deployment Guide](lib/docs/admin/getting-started.md) +* [Configuration](lib/docs/config/configuration.md) + ## License + Although this repository is released under the Apache 2.0 license, when configured to use PGVector as a RAG store it uses the third party `psycopg2-binary` library. The `psycopg2-binary` project's licensing includes diff --git a/VERSION b/VERSION index 19b860c18..f22d756da 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -6.4.0 +6.5.0 diff --git a/bin/build-images b/bin/build-images old mode 100755 new mode 100644 index 394cd1d98..da7d16be8 --- a/bin/build-images +++ b/bin/build-images @@ -97,7 +97,8 @@ ecr_login() { # Function to check if a config parameter is enabled (defaults to true if not present) should_build_image() { local param="$1" - local value=$(yq ".${param}" "$ROOT/custom-config.yaml" 2>/dev/null) + local value + value=$(cd "$ROOT" && node scripts/config.mjs --get ".${param}" 2>/dev/null || true) [[ "$value" != "false" ]] } diff --git a/cdk.json b/cdk.json index 2bd48137f..c52859489 100644 --- a/cdk.json +++ b/cdk.json @@ -1,5 +1,5 @@ { - "app": "npm run deploy", + "app": "tsx ./bin/lisa.ts", "requireApproval": "never", "watch": { "include": [ diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..e3ccedcda --- /dev/null +++ b/conftest.py @@ -0,0 +1,64 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Root conftest - registers pytest options needed by integration tests. + +pytest_addoption() must be in a root-level conftest because pytest parses +command-line options before loading subdirectory conftests. Options used by +test/integration/sdk/ (--api, --url, etc.) must be registered here. +""" + +from pytest import Parser + + +def pytest_addoption(parser: Parser) -> None: + """Register CLI options for integration tests (e.g. test/integration/sdk/).""" + parser.addoption( + "--url", + action="store", + default=None, + help="REST url used for testing. If not provided, read from config-custom.yaml or fetch from AWS.", + ) + parser.addoption( + "--api", + action="store", + default=None, + help="API Gateway url used for testing. If not provided, read from config-custom.yaml or fetch from AWS.", + ) + parser.addoption("--verify", action="store", default="false", help="Verify https request") + parser.addoption( + "--region", + action="store", + default=None, + help="AWS region. Defaults to config-custom.yaml or us-west-2.", + ) + parser.addoption( + "--stage", + action="store", + default=None, + help="Deployment stage. Defaults to config-custom.yaml or dev.", + ) + parser.addoption( + "--deployment", + action="store", + default=None, + help="Deployment name. Defaults to config-custom.yaml or app.", + ) + parser.addoption( + "--profile", + action="store", + default=None, + help="AWS profile. Defaults to config-custom.yaml or default.", + ) + parser.addoption("--auth_token", action="store", default=None, help="Auth token for API tests") diff --git a/cypress/cypress.smoke.config.ts b/cypress/cypress.smoke.config.ts index 463a7cd9d..325b90160 100644 --- a/cypress/cypress.smoke.config.ts +++ b/cypress/cypress.smoke.config.ts @@ -27,6 +27,8 @@ export default defineConfig({ screenshotsFolder: `${PROJECT_ROOT}/screenshots/smoke`, trashAssetsBeforeRuns: true, // wipe out old videos/screenshots + defaultCommandTimeout: 10000, // 10 seconds β€” CI runners need more headroom than the 4s default + e2e: { specPattern: `${PROJECT_ROOT}/src/smoke/specs/**/*.smoke.spec.ts`, supportFile: `${PROJECT_ROOT}/src/smoke/support/index.ts`, diff --git a/cypress/package.json b/cypress/package.json index 90f3f975b..480a73d49 100644 --- a/cypress/package.json +++ b/cypress/package.json @@ -4,7 +4,7 @@ "version": "1.0.0", "devDependencies": { "@types/node": "^25.3.3", - "cypress": "^15.7.1", + "cypress": "^15.12.0", "dotenv": "^17.2.3", "lint-staged": "^16.2.7", "lodash": "^4.17.21" diff --git a/cypress/src/e2e/specs/000-cleanup.e2e.spec.ts b/cypress/src/e2e/specs/000-cleanup.e2e.spec.ts new file mode 100644 index 000000000..09f41d0f4 --- /dev/null +++ b/cypress/src/e2e/specs/000-cleanup.e2e.spec.ts @@ -0,0 +1,187 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Pre-test cleanup spec. Runs before other E2E specs to ensure a clean + * environment by deleting all e2e-* prefixed resources and polling until + * async deletions (models, repositories) are fully complete. + * + * Runs first via alphabetical filename ordering (000- prefix). + */ + +import { makeAuthenticatedRequest } from '../../support/collectionHelpers'; + +const E2E_PREFIX = 'e2e-'; +const E2E_PROMPT_PREFIX = 'E2E '; +const POLL_INTERVAL = 5000; +const DELETION_TIMEOUT = 120000; + +describe('E2E Environment Cleanup', () => { + before(() => { + Cypress.session.clearAllSavedSessions(); + }); + + beforeEach(() => { + cy.loginAs('admin'); + }); + + it('Delete all E2E sessions', () => { + makeAuthenticatedRequest('DELETE', '/session').then((response) => { + if (response.status >= 200 && response.status < 300) { + cy.log('Deleted all sessions'); + } else { + cy.log(`Session deletion returned status: ${response.status}`); + } + }); + }); + + it('Delete all E2E repositories and wait for removal', () => { + makeAuthenticatedRequest('GET', '/repository').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list repositories: ${response.status}`); + return; + } + + const repositories = response.body ?? []; + const e2eRepos = repositories.filter((r: any) => + typeof r.repositoryId === 'string' && r.repositoryId.startsWith(E2E_PREFIX) + ); + + if (e2eRepos.length === 0) { + cy.log('No E2E repositories to clean up'); + return; + } + + cy.log(`Deleting ${e2eRepos.length} E2E repository(ies)`); + + const repoIds = e2eRepos.map((r: any) => r.repositoryId); + + e2eRepos.forEach((repo: any) => { + makeAuthenticatedRequest('DELETE', `/repository/${repo.repositoryId}`).then((delResp) => { + cy.log(`DELETE /repository/${repo.repositoryId} β†’ ${delResp.status}`); + }); + }); + + // Poll until all e2e repos are fully removed + pollUntilGone('repositories', '/repository', repoIds, (body) => { + const repos = body ?? []; + return repos.filter((r: any) => repoIds.includes(r.repositoryId)); + }); + }); + }); + + it('Delete all E2E prompt templates', () => { + makeAuthenticatedRequest('GET', '/prompt-templates').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list prompt templates: ${response.status}`); + return; + } + + const templates = response.body?.templates ?? []; + const e2eTemplates = templates.filter((t: any) => + typeof t.title === 'string' && t.title.startsWith(E2E_PROMPT_PREFIX) + ); + + if (e2eTemplates.length === 0) { + cy.log('No E2E prompt templates to clean up'); + return; + } + + cy.log(`Deleting ${e2eTemplates.length} E2E prompt template(s)`); + + e2eTemplates.forEach((template: any) => { + const templateId = template.promptTemplateId || template.id; + if (templateId) { + makeAuthenticatedRequest('DELETE', `/prompt-templates/${templateId}`).then((delResp) => { + cy.log(`DELETE prompt template "${template.title}" β†’ ${delResp.status}`); + }); + } + }); + }); + }); + + it('Delete all E2E models and wait for removal', () => { + makeAuthenticatedRequest('GET', '/models').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list models: ${response.status}`); + return; + } + + const models = response.body?.models ?? []; + const e2eModels = models.filter((m: any) => + typeof m.modelId === 'string' && m.modelId.startsWith(E2E_PREFIX) + ); + + if (e2eModels.length === 0) { + cy.log('No E2E models to clean up'); + return; + } + + cy.log(`Deleting ${e2eModels.length} E2E model(s)`); + + const modelIds = e2eModels.map((m: any) => m.modelId); + + e2eModels.forEach((model: any) => { + makeAuthenticatedRequest('DELETE', `/models/${model.modelId}`).then((delResp) => { + cy.log(`DELETE /models/${model.modelId} β†’ ${delResp.status}`); + }); + }); + + // Poll until all e2e models are fully removed + pollUntilGone('models', '/models', modelIds, (body) => { + const models = body?.models ?? []; + return models.filter((m: any) => modelIds.includes(m.modelId)); + }); + }); + }); +}); + +/** + * Poll an API endpoint until none of the target resource IDs remain. + * Handles async deletion (state machines, CloudFormation teardown). + */ +function pollUntilGone ( + resourceType: string, + endpoint: string, + targetIds: string[], + extractRemaining: (body: any) => any[], +) { + cy.log(`Waiting for ${targetIds.length} ${resourceType} to be fully removed...`); + const startTime = Date.now(); + + function check (): void { + makeAuthenticatedRequest('GET', endpoint).then((response) => { + const remaining = response.status === 200 ? extractRemaining(response.body) : []; + + if (remaining.length === 0) { + cy.log(`All E2E ${resourceType} fully removed`); + return; + } + + const elapsed = Date.now() - startTime; + if (elapsed < DELETION_TIMEOUT) { + cy.log(`${remaining.length} ${resourceType} still deleting, polling...`); + cy.wait(POLL_INTERVAL).then(() => check()); + } else { + cy.log(`WARNING: ${remaining.length} ${resourceType} still present after ${DELETION_TIMEOUT}ms`); + } + }); + } + + check(); +} diff --git a/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts b/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts index 6656525f8..8d0ddaaf2 100644 --- a/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts +++ b/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts @@ -17,15 +17,22 @@ /// /** - * E2E test for Bedrock model creation and chat workflow. - * Creates a Bedrock model, then uses it in chat. + * Full E2E test for Bedrock model creation and chat workflow. + * Creates a Bedrock model, repository, collections, documents, and prompt templates. + * Used by the weekly and release CI workflows. + * + * Cleanup strategy: + * - 000-cleanup.e2e.spec.ts runs first (alphabetical ordering) to sweep orphaned resources + * and poll until async deletions complete + * - skipCleanup: false: inline UI-based cleanup runs after tests + * - after(): best-effort API sweep catches anything inline cleanup missed */ import { runBedrockModelWorkflowTests } from '../../shared/specs/bedrock-model-workflow.shared.spec'; +import { sweepAllE2eResources } from '../../support/cleanupHelpers'; describe('Bedrock Model Workflow (E2E)', () => { before(() => { - // Clear Cypress session cache to allow fresh login Cypress.session.clearAllSavedSessions(); }); @@ -33,5 +40,11 @@ describe('Bedrock Model Workflow (E2E)', () => { cy.loginAs('admin'); }); - runBedrockModelWorkflowTests({skipCleanup: true}); + after(() => { + // Best-effort sweep to catch anything inline cleanup missed or if tests failed + cy.loginAs('admin'); + sweepAllE2eResources(); + }); + + runBedrockModelWorkflowTests({skipCleanup: false}); }); diff --git a/cypress/src/e2e/specs/bedrock-quick.e2e.spec.ts b/cypress/src/e2e/specs/bedrock-quick.e2e.spec.ts new file mode 100644 index 000000000..f302d1594 --- /dev/null +++ b/cypress/src/e2e/specs/bedrock-quick.e2e.spec.ts @@ -0,0 +1,38 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Quick E2E test for Bedrock model creation, prompt templates, and chat. + * No infrastructure provisioning (repositories, collections, documents). + * Suitable for nightly health check runs. + */ + +import { runBedrockQuickTests } from '../../shared/specs/bedrock-model-workflow.shared.spec'; + +describe('Bedrock Quick Workflow (E2E)', () => { + before(() => { + // Clear Cypress session cache to allow fresh login + Cypress.session.clearAllSavedSessions(); + }); + + beforeEach(() => { + cy.loginAs('admin'); + }); + + runBedrockQuickTests({skipCleanup: true}); +}); diff --git a/cypress/src/e2e/support/commands.ts b/cypress/src/e2e/support/commands.ts index 20201a9d5..9de6039a8 100644 --- a/cypress/src/e2e/support/commands.ts +++ b/cypress/src/e2e/support/commands.ts @@ -32,6 +32,7 @@ const APIS = [ { pattern: '**/mcp-workbench*', alias: 'getMcpWorkbench' }, { pattern: '**/prompt-templates*', alias: 'getPromptTemplates' }, { pattern: '**/user-preferences*', alias: 'getUserPreferences' }, + { pattern: '**/project*', alias: 'getProjects' }, ]; /** @@ -58,7 +59,7 @@ function waitForCriticalApis () { */ function waitForAppReady () { // Wait for "Loading configuration..." to disappear - cy.contains('Loading configuration...', { timeout: 15000 }).should('not.exist'); + cy.contains('Loading configuration...', { timeout: 30000 }).should('not.exist'); // Wait for any loading spinners to complete cy.get('body').then(($body) => { @@ -80,7 +81,7 @@ function waitForAppReady () { Cypress.Commands.add('loginAs', (role = 'user') => { const log = Cypress.log({ displayName: 'Cognito Login', - message: [`πŸ” Authenticating | ${role}`], + message: [`Authenticating | ${role}`], autoEnd: false, }); @@ -126,20 +127,47 @@ Cypress.Commands.add('loginAs', (role = 'user') => { cy.get('@usernameInput').clear({ force: true }); cy.get('@usernameInput').type(username, { force: true }); - // Fill password - cy.get('input[name="password"]') - .filter(':visible') - .type(password, { force: true, log: false }); - - // Submit - cy.get('input[type="submit"], input[aria-label="submit"], button[type="submit"]') - .filter(':visible') - .first() - .click({ force: true }); + // Handle both single-page and two-step login flows + // Check if password field is already visible (single-page flow) + cy.get('body').then(($body) => { + const passwordVisible = $body.find('input[name="password"]:visible').length > 0; + + if (!passwordVisible) { + // Two-step flow: click Next/Continue button to proceed to password + cy.get('input[type="submit"], button[type="submit"], button:contains("Next"), button:contains("Continue"), input[value="Next"], input[name="signInSubmitButton"]') + .filter(':visible') + .first() + .click({ force: true }); + + // Wait for password field to appear + cy.get('input[name="password"]', { timeout: 10000 }) + .should('be.visible'); + } + + // Fill password + cy.get('input[name="password"]') + .filter(':visible') + .type(password, { force: true, log: false }); + + // Submit the form + cy.get('input[type="submit"], button[type="submit"], input[name="signInSubmitButton"]') + .filter(':visible') + .first() + .click({ force: true }); + }); }); - // Wait for redirect back to app and allow configuration to load - cy.wait(2000); + // Wait for redirect back to app and OIDC token to be stored + // The app needs time to process the auth callback and store tokens + cy.url({ timeout: 30000 }).should('not.include', 'amazoncognito.com'); + + // Wait for OIDC token to appear in sessionStorage + cy.window({ timeout: 15000 }).should((win) => { + const hasOidcToken = Object.keys(win.sessionStorage).some((key) => + key.startsWith('oidc.user:') + ); + expect(hasOidcToken, 'OIDC token should be stored after login').to.equal(true); + }); }); }); }, @@ -149,10 +177,9 @@ Cypress.Commands.add('loginAs', (role = 'user') => { // The key format is: oidc.user:: // We check for any key starting with 'oidc.user:' since we don't have the exact values here cy.window().then((win) => { - const hasOidcToken = Object.keys(win.sessionStorage).some((key) => - key.startsWith('oidc.user:') - ); - expect(hasOidcToken).to.equal(true); + const sessionKeys = Object.keys(win.sessionStorage); + const oidcKey = sessionKeys.find((key) => key.startsWith('oidc.user:')); + expect(oidcKey, 'OIDC token should exist in sessionStorage').to.not.equal(undefined); }); }, cacheAcrossSpecs: false, diff --git a/cypress/src/shared/specs/admin.shared.spec.ts b/cypress/src/shared/specs/admin.shared.spec.ts index d5c041a05..f7dd58632 100644 --- a/cypress/src/shared/specs/admin.shared.spec.ts +++ b/cypress/src/shared/specs/admin.shared.spec.ts @@ -23,6 +23,7 @@ import { navigateAndVerifyAdminPage, + navigateToAdminPage, expandAdminMenu, collapseAdminMenu, } from '../../support/adminHelpers'; @@ -41,12 +42,11 @@ export function runAdminTests (options: { }); it('Admin can access Configuration page', () => { - navigateAndVerifyAdminPage( - 'Configuration', - '/configuration', - 'Configuration', - 'custom' - ); + navigateToAdminPage('Configuration'); + cy.url().should('include', '/configuration'); + + // Check for the submit button which is always visible + cy.get('[data-testid="configuration-submit"]').should('be.visible'); }); it('Model Management page loads and shows model cards', () => { diff --git a/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts b/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts index 5a5176309..2fa6ac57e 100644 --- a/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts +++ b/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts @@ -55,17 +55,12 @@ import { completePromptTemplateWizard, waitForPromptTemplateCreationSuccess, verifyPromptTemplateInList, - deletePromptTemplateIfExists, selectPromptTemplateInChat, - promptTemplateExists, PromptTemplateType, } from '../../support/promptTemplateHelpers'; import { - CollectionConfig, - navigateToRagManagement, waitForRepositoryReady, getAutoCreatedCollectionInfo, - renameCollection, uploadDocument, waitForDocumentIngested, selectRagRepositoryInChat, @@ -93,37 +88,22 @@ const DEFAULT_TEST_MODEL: BedrockModelConfig = { export type BedrockWorkflowTestOptions = { modelConfig?: BedrockModelConfig; repositoryConfig?: RepositoryConfig; - collectionConfig?: CollectionConfig; promptTemplateConfig?: PromptTemplateConfig; skipChat?: boolean; skipCleanup?: boolean; testDocumentPath?: string; }; -export function runBedrockModelWorkflowTests (options: BedrockWorkflowTestOptions = {}) { +/** + * Quick tests: model wizard, prompt templates, and chat with persona/directive. + * No infrastructure provisioning or long waits. Suitable for nightly runs. + */ +export function runBedrockQuickTests (options: BedrockWorkflowTestOptions = {}) { const dateString = getTodayDateString(); const testModel = options.modelConfig || DEFAULT_TEST_MODEL; - const testRepository: RepositoryConfig = options.repositoryConfig || { - repositoryId: `e2e-repo-${dateString}`, - knowledgeBaseName: 'test-bedrock-kb', - dataSourceIndex: 0, - }; - const testCollection: CollectionConfig = options.collectionConfig || { - collectionId: `e2e-collection-${dateString}`, - collectionName: `E2E Test Collection ${dateString}`, - repositoryId: testRepository.repositoryId, - }; - const testDocumentPath = options.testDocumentPath || 'test-document.txt'; - // Track test state for dependencies const testState = { modelCreated: false, - repositoryCreated: false, - repositoryReady: false, - collectionRenamed: false, - collectionId: '', // Store the actual collection ID - documentUploaded: false, - documentIngested: false, personaTemplateCreated: false, directiveTemplateCreated: false, }; @@ -201,6 +181,104 @@ Respond with only one phrase per message, chosen randomly. Treat every input as verifyModelInList(testModel.modelId); }); + it('Admin creates a persona prompt template (or uses existing)', () => { + navigateToPromptTemplates(); + + // Wait for prompt templates API to load and check if template already exists + cy.wait('@getPromptTemplates', { timeout: 30000 }).then((interception) => { + const templates = (interception.response?.body as { templates?: any[] })?.templates ?? []; + const templateExists = templates.some((template: any) => template.title === testPromptTemplatePersona.title); + + if (templateExists) { + cy.log(`Prompt template "${testPromptTemplatePersona.title}" already exists, skipping creation`); + testState.personaTemplateCreated = true; + } else { + openCreatePromptTemplateWizard(); + fillPromptTemplateConfig(testPromptTemplatePersona); + completePromptTemplateWizard(); + waitForPromptTemplateCreationSuccess(testPromptTemplatePersona.title); + testState.personaTemplateCreated = true; + } + }); + }); + + it('Persona prompt template appears in Prompt Templates list', function () { + if (!testState.personaTemplateCreated) { + this.skip(); + } + + navigateToPromptTemplates(); + cy.wait('@getPromptTemplates', { timeout: 30000 }); + verifyPromptTemplateInList(testPromptTemplatePersona.title); + }); + + it('Admin creates a directive prompt template (or uses existing)', () => { + navigateToPromptTemplates(); + + // Wait for prompt templates API to load and check if template already exists + cy.wait('@getPromptTemplates', { timeout: 30000 }).then((interception) => { + const templates = (interception.response?.body as { templates?: any[] })?.templates ?? []; + const templateExists = templates.some((template: any) => template.title === testPromptTemplateDirective.title); + + if (templateExists) { + cy.log(`Prompt template "${testPromptTemplateDirective.title}" already exists, skipping creation`); + testState.directiveTemplateCreated = true; + } else { + openCreatePromptTemplateWizard(); + fillPromptTemplateConfig(testPromptTemplateDirective); + completePromptTemplateWizard(); + waitForPromptTemplateCreationSuccess(testPromptTemplateDirective.title); + testState.directiveTemplateCreated = true; + } + }); + }); + + it('Directive prompt template appears in Prompt Templates list', function () { + if (!testState.directiveTemplateCreated) { + this.skip(); + } + + navigateToPromptTemplates(); + cy.wait('@getPromptTemplates', { timeout: 30000 }); + verifyPromptTemplateInList(testPromptTemplateDirective.title); + }); + + it('Send chat message with persona and directive', () => { + navigateAndVerifyChatPage(); + selectModelInChat(testModel.modelId); + + // Apply the Magic 8 Ball persona (system prompt) + selectPromptTemplateInChat(testPromptTemplatePersona.title, PromptTemplateType.Persona); + selectPromptTemplateInChat(testPromptTemplateDirective.title, PromptTemplateType.Directive); + sendMessageWithButton(); + verifyChatResponseReceived(); + }); +} + +/** + * Infrastructure tests: repository creation, collection management, document ingestion, and RAG chat. + * These involve long waits (up to 5 min each) for provisioning. Suitable for weekly/release runs. + */ +export function runBedrockInfraTests (options: BedrockWorkflowTestOptions = {}) { + const dateString = getTodayDateString(); + const testModel = options.modelConfig || DEFAULT_TEST_MODEL; + const testRepository: RepositoryConfig = options.repositoryConfig || { + repositoryId: `e2e-repo-${dateString}`, + knowledgeBaseName: 'test-bedrock-kb', + dataSourceIndex: 0, + }; + const testDocumentPath = options.testDocumentPath || 'test-document.txt'; + + const testState = { + repositoryCreated: false, + repositoryReady: false, + collectionReady: false, + collectionId: '', + collectionName: '', + documentUploaded: false, + documentIngested: false, + }; + it('Admin creates a Bedrock Knowledgebase repository (or uses existing)', () => { navigateToRepositoryManagement(); @@ -215,12 +293,16 @@ Respond with only one phrase per message, chosen randomly. Treat every input as } else { openCreateRepositoryWizard(); fillRepositoryConfig(testRepository); - selectKnowledgeBase(testRepository.knowledgeBaseName); - selectDataSource(testRepository.dataSourceIndex); - skipToCreateRepository(); - completeRepositoryWizard(); - waitForRepositoryCreationSuccess(testRepository.repositoryId); - testState.repositoryCreated = true; + + selectKnowledgeBase(testRepository.knowledgeBaseName).then((kbSelected) => { + expect(kbSelected, `Knowledge Base "${testRepository.knowledgeBaseName}" should be available`).to.equal(true); + + selectDataSource(testRepository.dataSourceIndex); + skipToCreateRepository(); + completeRepositoryWizard(); + waitForRepositoryCreationSuccess(testRepository.repositoryId); + testState.repositoryCreated = true; + }); } }); }); @@ -240,38 +322,36 @@ Respond with only one phrase per message, chosen randomly. Treat every input as } navigateToRepositoryManagement(); - waitForRepositoryReady(testRepository.repositoryId, 300000); + waitForRepositoryReady(testRepository.repositoryId, 1200000); testState.repositoryReady = true; }); - it('Rename auto-created collection to known name', function () { + it('Get auto-created default collection info', function () { if (!testState.repositoryReady) { this.skip(); } - navigateToRagManagement(); - - // Get the auto-created collection info (name and ID) and rename it + // Fetch the default collection's name and ID via API getAutoCreatedCollectionInfo(testRepository.repositoryId).then((collectionInfo) => { - cy.log(`Auto-created collection: ${collectionInfo.name} (ID: ${collectionInfo.id})`); - testState.collectionId = collectionInfo.id; // Store the collection ID - renameCollection(collectionInfo.name, testCollection.collectionName); - testState.collectionRenamed = true; + cy.log(`Default collection: ${collectionInfo.name} (ID: ${collectionInfo.id})`); + testState.collectionId = collectionInfo.id; + testState.collectionName = collectionInfo.name; + testState.collectionReady = true; }); }); it('Upload test document to collection via chat page', function () { - if (!testState.collectionRenamed) { + if (!testState.collectionReady) { this.skip(); } // Navigate to chat page navigateAndVerifyChatPage(); - // Select model, repository, and collection + // Select model, repository, and collection (use actual default collection name) selectModelInChat(testModel.modelId); selectRagRepositoryInChat(testRepository.repositoryId); - selectCollectionInChat(testCollection.collectionName); + selectCollectionInChat(testState.collectionName); // Upload the document uploadDocument(testDocumentPath); @@ -286,136 +366,15 @@ Respond with only one phrase per message, chosen randomly. Treat every input as testState.documentIngested = true; }); - it('Admin creates a persona prompt template', () => { - navigateToPromptTemplates(); - - promptTemplateExists(testPromptTemplatePersona.title).then((exists) => { - if (exists) { - cy.log(`Prompt template ${testPromptTemplatePersona.title} already exists, skipping creation`); - return; - } - - openCreatePromptTemplateWizard(); - fillPromptTemplateConfig(testPromptTemplatePersona); - completePromptTemplateWizard(); - waitForPromptTemplateCreationSuccess(testPromptTemplatePersona.title); - }); - }); - - it('Rename auto-created collection to known name', function () { - if (!testState.repositoryReady) { - this.skip(); - } - - navigateToRagManagement(); - - // Get the auto-created collection info (name and ID) and rename it - getAutoCreatedCollectionInfo(testRepository.repositoryId).then((collectionInfo) => { - cy.log(`Auto-created collection: ${collectionInfo.name} (ID: ${collectionInfo.id})`); - testState.collectionId = collectionInfo.id; // Store the collection ID - renameCollection(collectionInfo.name, testCollection.collectionName); - testState.collectionRenamed = true; - }); - }); - - it('Upload test document to collection via chat page', function () { - if (!testState.collectionRenamed) { - this.skip(); - } - - // Navigate to chat page - navigateAndVerifyChatPage(); - - // Select model, repository, and collection - selectModelInChat(testModel.modelId); - selectRagRepositoryInChat(testRepository.repositoryId); - selectCollectionInChat(testCollection.collectionName); - - // Upload the document - uploadDocument(testDocumentPath); - testState.documentUploaded = true; - }); - - it('Wait for document to be ingested', function () { - if (!testState.documentUploaded) { - this.skip(); - } - - waitForDocumentIngested(testRepository.repositoryId, testState.collectionId, testDocumentPath, 300000); - testState.documentIngested = true; - }); - - it('Admin creates a persona prompt template (or uses existing)', () => { - navigateToPromptTemplates(); - - // Wait for prompt templates API to load and check if template already exists - cy.wait('@getPromptTemplates', { timeout: 30000 }).then((interception) => { - const templates = interception.response?.body || []; - const templateExists = templates.some((template: any) => template.title === testPromptTemplatePersona.title); - - if (templateExists) { - cy.log(`Prompt template "${testPromptTemplatePersona.title}" already exists, skipping creation`); - testState.personaTemplateCreated = true; - } else { - openCreatePromptTemplateWizard(); - fillPromptTemplateConfig(testPromptTemplatePersona); - completePromptTemplateWizard(); - waitForPromptTemplateCreationSuccess(testPromptTemplatePersona.title); - testState.personaTemplateCreated = true; - } - }); - }); - - it('Persona prompt template appears in Prompt Templates list', function () { - if (!testState.personaTemplateCreated) { - this.skip(); - } - - navigateToPromptTemplates(); - verifyPromptTemplateInList(testPromptTemplatePersona.title); - }); - - it('Admin creates a directive prompt template (or uses existing)', () => { - navigateToPromptTemplates(); - - promptTemplateExists(testPromptTemplateDirective.title).then((exists) => { - if (exists) { - cy.log(`Prompt template ${testPromptTemplateDirective.title} already exists, skipping creation`); - return; - } - - openCreatePromptTemplateWizard(); - fillPromptTemplateConfig(testPromptTemplateDirective); - completePromptTemplateWizard(); - waitForPromptTemplateCreationSuccess(testPromptTemplateDirective.title); - }); - }); - - it('Directive prompt template appears in Prompt Templates list', function () { - if (!testState.directiveTemplateCreated) { + it('Send chat message with rag response', function () { + if (!testState.documentIngested) { this.skip(); } - navigateToPromptTemplates(); - verifyPromptTemplateInList(testPromptTemplateDirective.title); - }); - - it('Send chat message with persona and directive', () => { - navigateAndVerifyChatPage(); - selectModelInChat(testModel.modelId); - - // Apply the Magic 8 Ball persona (system prompt) - selectPromptTemplateInChat(testPromptTemplatePersona.title, PromptTemplateType.Persona); - selectPromptTemplateInChat(testPromptTemplateDirective.title, PromptTemplateType.Directive); - sendMessageWithButton(); - verifyChatResponseReceived(); - }); - - it('Send chat message with rag response', () => { navigateAndVerifyChatPage(); selectModelInChat(testModel.modelId); selectRagRepositoryInChat(testRepository.repositoryId); - selectCollectionInChat(testCollection.collectionName); + selectCollectionInChat(testState.collectionName); insertChatPrompt('Who is Whiskers?'); sendMessageWithButton(); verifyChatResponseReceived(); @@ -433,18 +392,6 @@ Respond with only one phrase per message, chosen randomly. Treat every input as deleteRepositoryIfExists(testRepository.repositoryId); }); - it('Cleanup: delete persona prompt template', () => { - navigateToPromptTemplates(); - cy.wait(2000); - deletePromptTemplateIfExists(testPromptTemplatePersona.title); - }); - - it('Cleanup: delete directive prompt template', () => { - navigateToPromptTemplates(); - cy.wait(2000); - deletePromptTemplateIfExists(testPromptTemplateDirective.title); - }); - it('Cleanup: delete test model', () => { navigateToAdminPage('Model Management'); cy.wait(2000); @@ -452,3 +399,12 @@ Respond with only one phrase per message, chosen randomly. Treat every input as }); } } + +/** + * Full workflow: runs both quick tests and infrastructure tests. + * Backward-compatible wrapper used by the full E2E spec (weekly/release). + */ +export function runBedrockModelWorkflowTests (options: BedrockWorkflowTestOptions = {}) { + runBedrockQuickTests(options); + runBedrockInfraTests(options); +} diff --git a/cypress/src/shared/specs/project.shared.spec.ts b/cypress/src/shared/specs/project.shared.spec.ts new file mode 100644 index 000000000..58d8b3171 --- /dev/null +++ b/cypress/src/shared/specs/project.shared.spec.ts @@ -0,0 +1,423 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Shared test suite for Projects Organization feature. + * Can be used by both smoke tests (with fixtures) and e2e tests (with real data). + */ + +import { + navigateToChatPage, + switchToProjectsView, + switchToHistoryView, + verifyCurrentView, + verifyViewPersistence, + createProject, + renameProject, + deleteProjectOnly, + deleteProjectWithSessions, + verifyProjectExists, + verifyProjectNotExists, +} from '../../support/projectHelpers'; + +export function runProjectsTests (options: { + verifyFixtureData?: boolean; +} = {}) { + const { verifyFixtureData = false } = options; + + describe('View Toggle & Navigation', () => { + it('should have configuration loaded with projectOrganization enabled', () => { + // First verify the intercept captured the configuration call + cy.wait('@getConfiguration', { timeout: 10000 }).its('response.body').then((body) => { + expect(body).to.be.an('array'); + expect(body[0]).to.have.nested.property('configuration.enabledComponents.projectOrganization', true); + }); + }); + + it('should display segmented control with History and Projects options', () => { + navigateToChatPage(); + + // Wait for projects to load + cy.wait('@getProjects', { timeout: 30000 }); + + // Verify both view options are visible + cy.get('[data-testid="project-history-toggle"]').should('be.visible'); + cy.get('[data-testid="history"]', { timeout: 5000 }).should('be.visible'); + cy.get('[data-testid="projects"]').should('be.visible'); + }); + + it('should switch between History and Projects views', () => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + + // Default should be History + verifyCurrentView('history'); + + // Switch to Projects + switchToProjectsView(); + verifyCurrentView('projects'); + + // Switch back to History + switchToHistoryView(); + verifyCurrentView('history'); + }); + + it('should persist view selection to localStorage', () => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + + // Switch to Projects view + switchToProjectsView(); + verifyViewPersistence('projects'); + + // Switch to History view + switchToHistoryView(); + verifyViewPersistence('history'); + }); + + it('should restore view selection after page refresh', () => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + + // Switch to Projects view + switchToProjectsView(); + verifyCurrentView('projects'); + + // Refresh page + cy.reload(); + cy.wait('@getProjects', { timeout: 30000 }); + + // Should still be on Projects view + verifyCurrentView('projects'); + }); + }); + + describe('Create Projects', () => { + beforeEach(() => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + switchToProjectsView(); + }); + + it('should open New Project modal from New dropdown', () => { + // Click New button + cy.get('[data-testid="new-session-dropdown"]').click(); + + // Click New Project menu item + cy.get('[data-testid="new-project"]').should('be.visible').click(); + + // Verify modal appears with correct header + cy.contains('h2', 'New Project', { timeout: 5000 }).should('be.visible'); + + // Cancel to close modal + cy.get('[data-testid="create-project-cancel"]').click(); + }); + + it('should create a new project successfully', () => { + const projectName = `Test Project ${Date.now()}`; + + createProject(projectName); + + // Verify project appears in list + verifyProjectExists(projectName); + }); + + it('should validate empty project name', () => { + // Open New Project + cy.get('[data-testid="new-session-dropdown"]').click(); + cy.get('[data-testid="new-project"]').click(); + + // Wait for modal + cy.contains('h2', 'New Project', { timeout: 5000 }).should('be.visible'); + + // Try to confirm with empty name - Create button should be disabled + cy.get('[data-testid="create-project-input"]').should('be.visible').clear(); + cy.get('button').filter(':visible').contains('Create').closest('button').should('be.disabled'); + + // Cancel to close modal + cy.get('[data-testid="create-project-cancel"]').click(); + }); + + if (verifyFixtureData) { + it('should display fixture projects', () => { + verifyProjectExists('Research'); + verifyProjectExists('Product Dev'); + }); + } + }); + + describe('Rename Projects', () => { + beforeEach(() => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + switchToProjectsView(); + }); + + it('should open Rename modal with current name pre-filled', () => { + const projectName = verifyFixtureData ? 'Research' : 'Test Project'; + + // Skip if no projects exist + cy.get('body').then(($body) => { + if (!$body.text().includes(projectName)) { + cy.log('Skipping: No projects found'); + return; + } + + // Open project actions menu + cy.get(`[aria-label="Project actions for ${projectName}"]`) + .first() + .should('be.visible') + .click(); + + // Click Rename + cy.get('[data-testid="rename"]').click(); + + // Verify modal with current name + cy.contains('h2', 'Rename Project', { timeout: 5000 }).should('be.visible'); + cy.get('[data-testid="rename-project-input"] input').should('have.value', projectName); + + // Cancel + cy.get('[data-testid="rename-project-cancel"]').click(); + }); + }); + + it('should rename a project successfully', () => { + const originalName = `Original ${Date.now()}`; + const newName = `Renamed ${Date.now()}`; + + // Create a project first + createProject(originalName); + verifyProjectExists(originalName); + + // Rename it + renameProject(originalName, newName); + + // Verify new name appears and old name is gone + verifyProjectExists(newName); + verifyProjectNotExists(originalName); + }); + }); + + describe('Delete Projects', () => { + beforeEach(() => { + navigateToChatPage(); + cy.wait('@getProjects', { timeout: 30000 }); + switchToProjectsView(); + }); + + it('should show delete modal with two options', () => { + const projectName = `Delete Test ${Date.now()}`; + + // Create a project + createProject(projectName); + + // Open delete modal + cy.get(`[aria-label="Project actions for ${projectName}"]`) + .first() + .should('be.visible') + .click(); + cy.get('[data-testid="delete"]').first().click(); + + // Verify modal shows both delete options as buttons + cy.contains('h2', 'Delete Project', { timeout: 5000 }).should('be.visible'); + cy.contains('button', 'Delete project only').should('be.visible'); + cy.contains('button', 'Delete project and sessions').should('be.visible'); + + // Cancel + cy.get('[data-testid="delete-project-cancel"]').click(); + }); + + it('should delete project only (keep sessions)', () => { + // Use existing Product Dev project - first assign a session to it + const projectName = 'Product Dev'; + const sessionName = 'How do I get started'; // Partial match for truncated display + + verifyProjectExists(projectName); + + // Assign a session to this project first + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + // Find session row and click its actions button (3 dots) + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }) + .first() + .should('be.visible') + .find('[aria-label="Control instance"]') + .first() + .click(); + // Projects are listed directly under "Add to Project" category - click the project name + cy.contains('[role="menuitem"]', projectName).click(); + cy.wait('@assignSession'); + + // Switch back to Projects view and delete project only + switchToProjectsView(); + deleteProjectOnly(projectName); + + // Verify project is gone + verifyProjectNotExists(projectName); + + // Verify session still exists in History view (no longer has project badge) + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section again + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }).first().should('be.visible'); + }); + + it('should delete project with all sessions', () => { + // Use existing Research project from fixtures which has sessions assigned + const projectName = 'Research'; + const sessionName = 'Technical Discussion'; + + verifyProjectExists(projectName); + + // Verify session exists before delete + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + cy.contains(sessionName, { timeout: 5000 }).should('exist'); + + // Switch back and delete project with all sessions + switchToProjectsView(); + deleteProjectWithSessions(projectName); + + // Verify project is gone + verifyProjectNotExists(projectName); + + // Reload page to force refetch of sessions with updated mock data + cy.reload(); + cy.wait('@getSessions'); + cy.wait('@getProjects'); + + // Verify session is gone from History view + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + cy.contains(sessionName, { timeout: 5000 }).should('not.exist'); + }); + }); + + describe('Assign Sessions to Projects', () => { + beforeEach(() => { + navigateToChatPage(); + cy.wait('@getSessions', { timeout: 30000 }); + cy.wait('@getProjects', { timeout: 30000 }); + }); + + it('should show "Add to Project" in session context menu', () => { + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + + // Use a session not assigned to a project + const sessionName = 'How do I get started'; + + // Find session and click its actions button + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }) + .first() + .should('be.visible') + .find('[aria-label="Control instance"]') + .first() + .click(); + + // Verify "Add to Project" menu category exists (nested dropdown item) + cy.contains('Add to Project').should('be.visible'); + + // Click elsewhere to close menu + cy.get('body').click(); + }); + + if (verifyFixtureData) { + it('should assign session to project from History view', () => { + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + + // Use a session not already assigned to a project + const sessionName = 'How do I get started'; + const projectName = 'Research'; + + // Find session and click its actions button + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }) + .first() + .should('be.visible') + .find('[aria-label="Control instance"]') + .first() + .click(); + + // Click the project name in the dropdown menu (not the badge) + cy.get('[role="menu"]').contains(projectName).click(); + cy.wait('@assignSession'); + + // Switch to Projects view and verify session appears there + switchToProjectsView(); + cy.contains(projectName).should('be.visible'); + }); + } + + it('should display session in both History and Projects views', () => { + if (!verifyFixtureData) { + cy.log('Skipping: Requires fixture data'); + return; + } + + const sessionName = 'Technical Discussion'; + const projectName = 'Research'; + + // Verify in History view + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }).should('be.visible'); + + // Verify in Projects view + switchToProjectsView(); + cy.contains(projectName).should('be.visible'); + }); + }); + + describe('Remove Sessions from Projects', () => { + beforeEach(() => { + navigateToChatPage(); + cy.wait('@getSessions', { timeout: 30000 }); + cy.wait('@getProjects', { timeout: 30000 }); + }); + + if (verifyFixtureData) { + it('should remove session from project in History view', () => { + switchToHistoryView(); + cy.contains('Last 3 Months').click(); // Expand section + + // Use a session that is assigned to a project + const sessionName = 'Technical Discussion'; + + // Verify session has project badge before removal + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }) + .first() + .should('be.visible') + .should('contain', 'Research'); + + // Find session and click its actions button + cy.contains('[data-testid="session-item"]', sessionName, { timeout: 5000 }) + .first() + .find('[aria-label="Control instance"]') + .first() + .click(); + + // Click Remove from Project + cy.get('[role="menu"]').contains('Remove from Project').click(); + + // Verify the unassign API was called with correct body + cy.wait('@assignSession').its('request.body').should('deep.equal', { unassign: true }); + }); + } + }); + +} diff --git a/cypress/src/shared/specs/rag-admin.shared.spec.ts b/cypress/src/shared/specs/rag-admin.shared.spec.ts new file mode 100644 index 000000000..bb1f9bd8a --- /dev/null +++ b/cypress/src/shared/specs/rag-admin.shared.spec.ts @@ -0,0 +1,95 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Shared test suite for RAG Admin role features. + * RAG Admins see the Administration dropdown with only RAG Management. + * They cannot access admin-only pages but can access chat. + * + * Can be used by both smoke tests (with fixtures) and e2e tests (with real data). + */ + +import { expandRagAdminMenu } from '../../support/adminHelpers'; +import { waitForContentToLoad, verifyCloudscapeTableHasData } from '../../support/dataHelpers'; +import { navigateAndVerifyChatPage } from '../../support/chatHelpers'; + +const ADMIN_MENU_SELECTOR = '[role="menu"][aria-label="Administration"]'; +const MENU_ITEM_SELECTOR = '[role="menuitem"]'; + +export function runRagAdminTests (options: { + expectMinItems?: boolean; + verifyFixtureData?: boolean; +} = {}) { + const { expectMinItems = false, verifyFixtureData = false } = options; + + it('RAG Admin sees Administration button with only RAG Management', () => { + expandRagAdminMenu(); + }); + + it('RAG Admin can navigate to RAG Management page', () => { + const minItems = expectMinItems ? 3 : 0; + + // Use expandRagAdminMenu to wait for stable header and open the correct menu + expandRagAdminMenu(); + + // Click RAG Management from the open menu + cy.get(ADMIN_MENU_SELECTOR, { timeout: 10000 }) + .filter(':visible') + .contains(MENU_ITEM_SELECTOR, 'RAG Management') + .click(); + + cy.url().should('include', '/repository-management'); + cy.wait('@getRepositories', { timeout: 10000 }); + waitForContentToLoad(); + + if (minItems > 0) { + verifyCloudscapeTableHasData(minItems); + } + + if (verifyFixtureData) { + cy.contains('Technical Documentation').should('be.visible'); + cy.contains('Product Knowledge Base').should('be.visible'); + cy.contains('Training Materials').should('be.visible'); + } + }); + + it('RAG Admin cannot access admin-only pages', () => { + const adminOnlyPaths = [ + '#/configuration', + '#/model-management', + '#/api-token-management', + '#/mcp-management', + '#/mcp-workbench', + ]; + + adminOnlyPaths.forEach((path) => { + cy.visit(path, { failOnStatusCode: false, timeout: 10000 }); + const stripped = path.replace('#/', ''); + + cy.url({ timeout: 10000 }).should('satisfy', (url: string) => { + return !url.includes(stripped) || + url.includes('access-denied') || + url.includes('unauthorized'); + }, `Expected rag-admin to be redirected from ${path}`); + }); + }); + + it('RAG Admin can access chat', () => { + navigateAndVerifyChatPage(); + }); +} diff --git a/cypress/src/shared/specs/user.shared.spec.ts b/cypress/src/shared/specs/user.shared.spec.ts index adbfc890e..13b27c2dd 100644 --- a/cypress/src/shared/specs/user.shared.spec.ts +++ b/cypress/src/shared/specs/user.shared.spec.ts @@ -25,7 +25,10 @@ import { checkNoAdminButton } from '../../support/adminHelpers'; -export function runUserTests () { +// eslint-disable-next-line @typescript-eslint/no-unused-vars +export function runUserTests (options: { + verifyFixtureData?: boolean; +} = {}) { it('Non-admin does not see the Administration button', () => { // Wait for configuration to load before checking UI // cy.wait('@getConfiguration', { timeout: 30000 }); diff --git a/cypress/src/smoke/fixtures/configuration.json b/cypress/src/smoke/fixtures/configuration.json index b2980009e..219dcf80a 100644 --- a/cypress/src/smoke/fixtures/configuration.json +++ b/cypress/src/smoke/fixtures/configuration.json @@ -17,13 +17,17 @@ "editNumOfRagDocument": true, "editChatHistoryBuffer": true, "uploadRagDocs": true, + "ragSelectionAvailable": true, "uploadContextDocs": true, "documentSummarization": true, "showRagLibrary": true, "showPromptTemplateLibrary": true, "mcpConnections": true, - "showMcpWorkbench": true - } + "awsSessions": false, + "showMcpWorkbench": true, + "projectOrganization": true + }, + "maxProjectsPerUser": 50 }, "changedBy": "System", "configScope": "global", diff --git a/cypress/src/smoke/fixtures/env.json b/cypress/src/smoke/fixtures/env.json index 4f1115b9b..ad6dd8827 100644 --- a/cypress/src/smoke/fixtures/env.json +++ b/cypress/src/smoke/fixtures/env.json @@ -11,5 +11,6 @@ "HOSTED_MCP_ENABLED": true, "API_BASE_URL": "/dev/", "USE_CUSTOM_BRANDING": false, - "CUSTOM_DISPLAY_NAME": "LISA" + "CUSTOM_DISPLAY_NAME": "LISA", + "RAG_ADMIN_GROUP": "rag-admin" } diff --git a/cypress/src/smoke/fixtures/project.json b/cypress/src/smoke/fixtures/project.json new file mode 100644 index 000000000..6181123d5 --- /dev/null +++ b/cypress/src/smoke/fixtures/project.json @@ -0,0 +1,14 @@ +[ + { + "projectId": "proj-001", + "name": "Research", + "_createDaysAgo": 45, + "_updatedDaysAgo": 30 + }, + { + "projectId": "proj-002", + "name": "Product Dev", + "_createDaysAgo": 50, + "_updatedDaysAgo": 35 + } +] diff --git a/cypress/src/smoke/fixtures/repository.json b/cypress/src/smoke/fixtures/repository.json index a5868956c..3ceff179c 100644 --- a/cypress/src/smoke/fixtures/repository.json +++ b/cypress/src/smoke/fixtures/repository.json @@ -5,7 +5,7 @@ "type": "pgvector", "embeddingModelId": "titan-embed", "status": "UPDATE_COMPLETE", - "allowedGroups": ["admin"], + "allowedGroups": ["admin", "rag-admin"], "metadata": { "tags": [] }, @@ -38,7 +38,7 @@ "type": "opensearch", "embeddingModelId": "e5-embed", "status": "UPDATE_COMPLETE", - "allowedGroups": ["admin"], + "allowedGroups": ["admin", "rag-admin"], "metadata": { "tags": ["open-rag"] }, diff --git a/cypress/src/smoke/fixtures/session.json b/cypress/src/smoke/fixtures/session.json index 97d001066..677b14f2a 100644 --- a/cypress/src/smoke/fixtures/session.json +++ b/cypress/src/smoke/fixtures/session.json @@ -3,27 +3,29 @@ "sessionId": "f56fc284-629c-4ba7-ab3d-56f4a21c13ee", "name": "Technical Discussion", "firstHumanMessage": "What is the difference between REST and GraphQL?", - "startTime": "2026-01-02T08:30:00.000000+00:00", - "createTime": "2026-01-02T08:30:00.000000+00:00", - "lastUpdated": "2026-01-02T09:15:00.000000+00:00", + "_startDaysAgo": 46, + "_updatedDaysAgo": 45, + "_expectedBucket": "Last 3 Months", + "projectId": "proj-001", "isEncrypted": false }, { "sessionId": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", "name": "Product Questions", "firstHumanMessage": "Tell me about the product features", - "startTime": "2026-01-01T14:20:00.000000+00:00", - "createTime": "2026-01-01T14:20:00.000000+00:00", - "lastUpdated": "2026-01-01T15:45:00.000000+00:00", + "_startDaysAgo": 61, + "_updatedDaysAgo": 60, + "_expectedBucket": "Last 3 Months", + "projectId": "proj-001", "isEncrypted": false }, { "sessionId": "12345678-90ab-cdef-1234-567890abcdef", "name": null, "firstHumanMessage": "How do I get started with the platform?", - "startTime": "2025-12-28T10:00:00.000000+00:00", - "createTime": "2025-12-28T10:00:00.000000+00:00", - "lastUpdated": "2025-12-28T11:30:00.000000+00:00", + "_startDaysAgo": 51, + "_updatedDaysAgo": 50, + "_expectedBucket": "Last 3 Months", "isEncrypted": false } ] diff --git a/cypress/src/smoke/specs/project.smoke.spec.ts b/cypress/src/smoke/specs/project.smoke.spec.ts new file mode 100644 index 000000000..7524b8b76 --- /dev/null +++ b/cypress/src/smoke/specs/project.smoke.spec.ts @@ -0,0 +1,38 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Smoke test suite for Projects Organization feature. + * Uses shared test suite with fixture-based data validation. + */ + +import { runProjectsTests } from '../../shared/specs/project.shared.spec'; + +describe('Projects Organization (Smoke)', () => { + beforeEach(() => { + cy.loginAs('user'); + }); + + after(() => { + cy.clearAllSessionStorage(); + }); + + runProjectsTests({ + verifyFixtureData: true, + }); +}); diff --git a/cypress/src/smoke/specs/rag-admin.smoke.spec.ts b/cypress/src/smoke/specs/rag-admin.smoke.spec.ts new file mode 100644 index 000000000..0998546a0 --- /dev/null +++ b/cypress/src/smoke/specs/rag-admin.smoke.spec.ts @@ -0,0 +1,39 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/// + +/** + * Smoke test suite for RAG Admin role features. + * Uses shared test suite with fixture data verification enabled. + */ + +import { runRagAdminTests } from '../../shared/specs/rag-admin.shared.spec'; + +describe('RAG Admin Navigation (Smoke)', () => { + beforeEach(() => { + cy.loginAs('rag-admin'); + }); + + after(() => { + cy.clearAllSessionStorage(); + }); + + runRagAdminTests({ + expectMinItems: true, + verifyFixtureData: true, + }); +}); diff --git a/cypress/src/smoke/support/commands.ts b/cypress/src/smoke/support/commands.ts index 05fe3dc5e..c62c9ee78 100644 --- a/cypress/src/smoke/support/commands.ts +++ b/cypress/src/smoke/support/commands.ts @@ -18,14 +18,12 @@ import { randomUUID, randomString, toBase64Url } from './utils'; -// API endpoints with their aliases (matching E2E pattern) const API_STUBS = [ { endpoint: 'models', alias: 'getModels' }, { endpoint: 'prompt-templates', alias: 'getPromptTemplates' }, { endpoint: 'repository', alias: 'getRepositories' }, { endpoint: 'configuration', alias: 'getConfiguration' }, { endpoint: 'health', alias: 'getHealth' }, - { endpoint: 'session', alias: 'getSessions' }, { endpoint: 'api-tokens', alias: 'getApiTokens' }, { endpoint: 'mcp', alias: 'getMcp' }, { endpoint: 'mcp-server', alias: 'getMcpServers' }, @@ -33,6 +31,160 @@ const API_STUBS = [ { endpoint: 'collections', alias: 'getCollections' }, ]; +// Stateful mock data for projects +let mockProjects: Array<{ + projectId: string; + name: string; + createTime: string; + lastUpdated: string; +}> = []; + +// Stateful mock data for sessions +let mockSessions: Array<{ + sessionId: string; + name: string | null; + firstHumanMessage: string; + startTime: string; + createTime: string; + lastUpdated: string; + projectId?: string; + isEncrypted: boolean; +}> = []; + +/** + * Setup stateful project stubs that track mutations. + */ +function setupProjectStubs (apiBase: string) { + // Initialize from fixture with dynamic dates computed from _*DaysAgo metadata + cy.fixture('project.json').then((fixtureProjects) => { + mockProjects = fixtureProjects.map(applyDateOffsets) as typeof mockProjects; + }); + + // GET projects - returns current state + cy.intercept('GET', `**${apiBase}/project*`, (req) => { + req.reply({ body: mockProjects }); + }).as('getProjects'); + + // POST - create project and add to state + cy.intercept('POST', `**${apiBase}/project`, (req) => { + const newProject = { + projectId: randomUUID(), + name: req.body?.name || 'New Project', + createTime: new Date().toISOString(), + lastUpdated: new Date().toISOString(), + }; + mockProjects.push(newProject); + req.reply({ statusCode: 201, body: newProject }); + }).as('createProject'); + + // PUT - update project name in state + cy.intercept('PUT', new RegExp(`${apiBase}/project/[^/]+$`), (req) => { + const projectId = req.url.split('/').pop()?.split('?')[0]; + const idx = mockProjects.findIndex((p) => p.projectId === projectId); + if (idx >= 0 && req.body?.name) { + mockProjects[idx].name = req.body.name; + mockProjects[idx].lastUpdated = new Date().toISOString(); + } + req.reply({ statusCode: 200, body: mockProjects[idx] || { message: 'Updated' } }); + }).as('updateProject'); + + // DELETE - remove project from state, optionally delete sessions + cy.intercept('DELETE', `**${apiBase}/project/*`, (req) => { + const url = new URL(req.url); + const pathParts = url.pathname.split('/'); + const projectId = pathParts[pathParts.length - 1]; + const deleteSessions = req.body?.deleteSessions === true; + + // Remove project + mockProjects = mockProjects.filter((p) => p.projectId !== projectId); + + // If deleteSessions is true, remove sessions with this projectId + if (deleteSessions) { + mockSessions = mockSessions.filter((s) => s.projectId !== projectId); + } else { + // Just unassign sessions from this project + mockSessions = mockSessions.map((s) => + s.projectId === projectId ? { ...s, projectId: undefined } : s + ); + } + + req.reply({ statusCode: 200, body: { message: 'Deleted' } }); + }).as('deleteProject'); + + // PUT session assignment + cy.intercept('PUT', `**${apiBase}/project/*/session/*`, (req) => { + const url = new URL(req.url); + const pathParts = url.pathname.split('/'); + const sessionId = pathParts[pathParts.length - 1]; + const projectId = pathParts[pathParts.length - 3]; + const unassign = req.body?.unassign === true; + + const idx = mockSessions.findIndex((s) => s.sessionId === sessionId); + if (idx >= 0) { + mockSessions[idx].projectId = unassign ? undefined : projectId; + } + + req.reply({ statusCode: 200, body: { message: 'Session assignment updated' } }); + }).as('assignSession'); +} + +/** + * Compute a date relative to now, offset by the given number of days. + */ +function daysAgo (days: number): string { + const date = new Date(); + date.setDate(date.getDate() - days); + return date.toISOString(); +} + +/** + * Transforms a fixture entry by converting underscore-prefixed day-offset + * metadata fields (_startDaysAgo, _updatedDaysAgo, _createDaysAgo) into + * real ISO date strings, then strips the metadata fields. + * + * This keeps fixture JSON files as the single source of truth for both + * API shape and timing intent. See Sessions.tsx for bucket boundaries: + * Last Day (<=1), Last 7 Days (<=7), Last Month (<=30), + * Last 3 Months (<=90), Older (>90). + */ +function applyDateOffsets (fixture: Record): Record { + const result = { ...fixture }; + + if (typeof result._startDaysAgo === 'number') { + result.startTime = daysAgo(result._startDaysAgo as number); + result.createTime = daysAgo(result._startDaysAgo as number); + } + if (typeof result._createDaysAgo === 'number') { + result.createTime = daysAgo(result._createDaysAgo as number); + } + if (typeof result._updatedDaysAgo === 'number') { + result.lastUpdated = daysAgo(result._updatedDaysAgo as number); + } + + // Strip metadata fields before using as mock API response + delete result._startDaysAgo; + delete result._createDaysAgo; + delete result._updatedDaysAgo; + delete result._expectedBucket; + + return result; +} + +/** + * Setup stateful session stubs that track mutations. + */ +function setupSessionStubs (apiBase: string) { + // Initialize from fixture with dynamic dates computed from _*DaysAgo metadata + cy.fixture('session.json').then((fixtureSessions) => { + mockSessions = fixtureSessions.map(applyDateOffsets) as typeof mockSessions; + }); + + // GET sessions - returns current state + cy.intercept('GET', `**${apiBase}/session*`, (req) => { + req.reply({ body: mockSessions }); + }).as('getSessions'); +} + /** * Setup API stubs for smoke tests. */ @@ -45,18 +197,23 @@ function setupApiStubs (env: Record) { headers: { 'Content-Type': 'application/javascript' }, }).as('stubEnv'); - // Stub all API endpoints with consistent aliases + // Stub all static API endpoints API_STUBS.forEach(({ endpoint, alias }) => { cy.intercept('GET', `**${apiBase}/${endpoint}*`, { fixture: `${endpoint}.json` }).as(alias); }); + + // Setup stateful project stubs + setupProjectStubs(apiBase); + + // Setup stateful session stubs + setupSessionStubs(apiBase); } /** * Build a mock OIDC user object. */ -function buildOidcUser (role: 'admin' | 'user', env: Record) { - const isAdmin = role === 'admin'; - const groups = isAdmin ? ['admin'] : ['user']; +function buildOidcUser (role: 'admin' | 'user' | 'rag-admin', env: Record) { + const groups = role === 'admin' ? ['admin'] : role === 'rag-admin' ? ['rag-admin'] : ['user']; const now = Math.floor(Date.now() / 1000); const jwtPayload = { @@ -93,7 +250,7 @@ function buildOidcUser (role: 'admin' | 'user', env: Record) { /** * Setup OIDC stubs for the login flow. */ -function setupOidcStubs (role: 'admin' | 'user', env: Record) { +function setupOidcStubs (role: 'admin' | 'user' | 'rag-admin', env: Record) { const oidcUser = buildOidcUser(role, env); // Stub OIDC discovery @@ -131,12 +288,12 @@ function setupOidcStubs (role: 'admin' | 'user', env: Record) { */ function waitForAppReady () { // Wait for "Loading configuration..." to disappear - cy.contains('Loading configuration...', { timeout: 15000 }).should('not.exist'); + cy.contains('Loading configuration...', { timeout: 30000 }).should('not.exist'); // Wait for spinners to disappear cy.get('body').then(($body) => { if ($body.find('[class*="awsui_spinner"]').length > 0) { - cy.get('[class*="awsui_spinner"]', { timeout: 10000 }).should('not.exist'); + cy.get('[class*="awsui_spinner"]', { timeout: 15000 }).should('not.exist'); } }); } @@ -144,7 +301,7 @@ function waitForAppReady () { /** * Custom command to log in a user via stubbed OIDC flow. */ -Cypress.Commands.add('loginAs', (role = 'user') => { +Cypress.Commands.add('loginAs', (role: 'admin' | 'user' | 'rag-admin' = 'user') => { cy.fixture('env.json').then((env) => { // Setup all stubs setupApiStubs(env); @@ -153,11 +310,13 @@ Cypress.Commands.add('loginAs', (role = 'user') => { // Visit the app cy.visit('/'); - // Click sign in to trigger OIDC flow - cy.contains('Sign in').click(); + cy.get('button', { timeout: 30000 }) + .contains('Sign in') + .should('be.visible') + .click({ force: true }); // Wait for the redirect and login to complete - cy.contains('Sign in', { timeout: 10000 }).should('not.exist'); + cy.get('button').contains('Sign in', { timeout: 20000 }).should('not.exist'); // Wait for app to be ready waitForAppReady(); diff --git a/cypress/src/smoke/support/index.ts b/cypress/src/smoke/support/index.ts index c2cc568ce..ac0f58c3e 100644 --- a/cypress/src/smoke/support/index.ts +++ b/cypress/src/smoke/support/index.ts @@ -28,7 +28,7 @@ declare global { * @param role - The role to simulate ('admin' or 'user') * @example cy.session('admin', () => cy.loginAs('admin')) */ - loginAs(role?: 'admin' | 'user'): Chainable; + loginAs(role?: 'admin' | 'user' | 'rag-admin'): Chainable; /** * Custom command to setup API stubs for a given role. @@ -36,7 +36,7 @@ declare global { * @param role - The role to simulate ('admin' or 'user') * @example cy.setupStubs('admin') */ - setupStubs(role?: 'admin' | 'user'): Chainable; + setupStubs(role?: 'admin' | 'user' | 'rag-admin'): Chainable; } } } diff --git a/cypress/src/support/adminHelpers.ts b/cypress/src/support/adminHelpers.ts index 93f76c187..e47bc1e2f 100644 --- a/cypress/src/support/adminHelpers.ts +++ b/cypress/src/support/adminHelpers.ts @@ -93,6 +93,40 @@ export function collapseAdminMenu () { cy.get(ADMIN_MENU_SELECTOR).should('not.be.visible'); } +/** + * Expand the admin menu for a RAG Admin user and verify only RAG Management is present. + * Admin-only items (Configuration, Model Management, etc.) should not appear. + */ +export function expandRagAdminMenu () { + getLibraryButton().should('be.visible'); + getAdminButton().should('be.visible'); + + getAdminButton() + .click() + .should('have.attr', 'aria-expanded', 'true'); + + // Cloudscape may render multiple menu elements (collapsed/expanded views). + // Filter to visible only to avoid asserting on hidden duplicates. + const ADMIN_ONLY_ITEMS = [ + 'Configuration', + 'Model Management', + 'API Token Management', + 'MCP Management', + 'MCP Workbench', + ]; + + cy.get(ADMIN_MENU_SELECTOR, { timeout: 10000 }) + .filter(':visible') + .should('have.length', 1) + .within(() => { + cy.get(MENU_ITEM_SELECTOR).filter(':visible').should('have.length', 1); + cy.contains(MENU_ITEM_SELECTOR, 'RAG Management').should('be.visible'); + ADMIN_ONLY_ITEMS.forEach((item) => { + cy.contains(MENU_ITEM_SELECTOR, item).should('not.exist'); + }); + }); +} + export function checkNoAdminButton () { // Use the specific selector for the Administration button cy.get('header button[aria-label="Administration"]').should('not.exist'); @@ -120,6 +154,8 @@ export function navigateToAdminPage (menuItemName: string) { export function verifyAdminPageLoaded (urlFragment: string, pageTitle?: string) { cy.url().should('include', urlFragment); + waitForContentToLoad(); + if (pageTitle) { cy.get('h1, h2, [data-testid="page-title"]') .should('be.visible') @@ -147,7 +183,6 @@ export function navigateAndVerifyAdminPage ( ) { navigateToAdminPage(menuItemName); verifyAdminPageLoaded(urlFragment, pageTitle); - waitForContentToLoad(); switch (contentType) { case 'table': diff --git a/cypress/src/support/chatHelpers.ts b/cypress/src/support/chatHelpers.ts index 2a71854c6..8b791bf40 100644 --- a/cypress/src/support/chatHelpers.ts +++ b/cypress/src/support/chatHelpers.ts @@ -21,27 +21,24 @@ // Chat page selectors export const CHAT_SELECTORS = { - MODEL_INPUT: 'input[placeholder*="model" i], input[aria-label*="model" i]', - RAG_REPO_INPUT: 'input#rag-repository-autosuggest, input[placeholder*="RAG Repository" i]', - COLLECTION_INPUT: 'input#collection-autosuggest, input[placeholder*="collection" i]', - MESSAGE_INPUT: 'textarea[placeholder*="message" i]', + MODEL_INPUT: '[data-testid="model-selection-autosuggest"] input, input[placeholder*="model" i], input[aria-label*="model" i]', + RAG_REPO_INPUT: '[data-testid="rag-repository-autosuggest"] input, input#rag-repository-autosuggest, input[placeholder*="RAG Repository" i]', + COLLECTION_INPUT: '[data-testid="rag-collection-autosuggest"] input, input#collection-autosuggest, input[placeholder*="collection" i]', + MESSAGE_INPUT: '[data-testid="chat-prompt-textarea"] textarea', DROPDOWN_OPTION: '[role="option"], [role="menuitem"]', }; /** - * Navigate to the AI Assistant (chat) page by clicking the menu item + * Navigate to the AI Assistant (chat) page */ export function navigateToChatPage () { - // For e2e tests, login should already direct to chat page - // For smoke tests, we may need to click the menu item // Check if we're already on the chat page cy.url().then((url) => { - if (!url.includes('/ai-assistant')) { - cy.get('a[aria-label="AI Assistant"]') - .eq(2) - .should('exist') - .and('be.visible') - .click(); + if (!url.includes('ai-assistant')) { + // Use client-side navigation to preserve auth state + cy.window().then((win) => { + win.location.hash = '#/ai-assistant'; + }); } }); } @@ -50,12 +47,12 @@ export function navigateToChatPage () { * Verify that the chat page has loaded correctly */ export function verifyChatPageLoaded () { - cy.url().should('include', '/ai-assistant'); + cy.url({ timeout: 10000 }).should('include', 'ai-assistant'); // Wait for the prompt input textarea to be visible // Use attribute selectors that are stable across builds - cy.get('textarea[placeholder*="message" i]') - .first() + // Allow extra time for lazy-loaded Chat route to render + cy.get(CHAT_SELECTORS.MESSAGE_INPUT, { timeout: 15000 }) .should('exist') .and('be.visible'); } @@ -69,8 +66,11 @@ export function waitForInitialDataLoad () { cy.get('[data-testid="loading"], .awsui-spinner, .loading', { timeout: 5000 }) .should('not.exist'); - // Give the page more time to stabilize after auth and initial API calls - cy.wait(3000); + // Wait for the model selection input to be ready (indicates models API has loaded) + cy.get(CHAT_SELECTORS.MODEL_INPUT, { timeout: 15000 }) + .first() + .should('be.visible') + .and('not.be.disabled'); } /** diff --git a/cypress/src/support/cleanupHelpers.ts b/cypress/src/support/cleanupHelpers.ts new file mode 100644 index 000000000..338be16fc --- /dev/null +++ b/cypress/src/support/cleanupHelpers.ts @@ -0,0 +1,156 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/** + * cleanupHelpers.ts + * API-based sweep cleanup for E2E test resources. + * Finds and deletes ALL resources matching the e2e- prefix, + * regardless of which test run created them. + */ + +import { makeAuthenticatedRequest } from './collectionHelpers'; + +const E2E_PREFIX = 'e2e-'; +const E2E_PROMPT_PREFIX = 'E2E '; + +/** + * Delete all models whose modelId starts with the E2E prefix. + */ +export function sweepE2eModels () { + cy.log('Sweeping E2E models...'); + makeAuthenticatedRequest('GET', '/models').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list models: ${response.status}`); + return; + } + + const models = response.body?.models ?? []; + const e2eModels = models.filter((m: any) => + typeof m.modelId === 'string' && m.modelId.startsWith(E2E_PREFIX) + ); + + cy.log(`Found ${e2eModels.length} E2E model(s) to clean up`); + + e2eModels.forEach((model: any) => { + cy.log(`Deleting model: ${model.modelId}`); + makeAuthenticatedRequest('DELETE', `/models/${model.modelId}`).then((delResp) => { + if (delResp.status >= 200 && delResp.status < 300) { + cy.log(`Deleted model ${model.modelId}`); + } else { + cy.log(`Failed to delete model ${model.modelId}: ${delResp.status}`); + } + }); + }); + }); +} + +/** + * Delete all repositories whose repositoryId starts with the E2E prefix. + * Repository deletion cascades to collections and documents. + */ +export function sweepE2eRepositories () { + cy.log('Sweeping E2E repositories...'); + makeAuthenticatedRequest('GET', '/repository').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list repositories: ${response.status}`); + return; + } + + const repositories = response.body ?? []; + const e2eRepos = repositories.filter((r: any) => + typeof r.repositoryId === 'string' && r.repositoryId.startsWith(E2E_PREFIX) + ); + + cy.log(`Found ${e2eRepos.length} E2E repository(ies) to clean up`); + + e2eRepos.forEach((repo: any) => { + cy.log(`Deleting repository: ${repo.repositoryId}`); + makeAuthenticatedRequest('DELETE', `/repository/${repo.repositoryId}`).then((delResp) => { + if (delResp.status >= 200 && delResp.status < 300) { + cy.log(`Deleted repository ${repo.repositoryId}`); + } else { + cy.log(`Failed to delete repository ${repo.repositoryId}: ${delResp.status}`); + } + }); + }); + }); +} + +/** + * Delete all prompt templates whose title starts with the E2E prefix. + */ +export function sweepE2ePromptTemplates () { + cy.log('Sweeping E2E prompt templates...'); + makeAuthenticatedRequest('GET', '/prompt-templates').then((response) => { + if (response.status !== 200) { + cy.log(`Failed to list prompt templates: ${response.status}`); + return; + } + + const templates = response.body?.templates ?? []; + const e2eTemplates = templates.filter((t: any) => + typeof t.title === 'string' && t.title.startsWith(E2E_PROMPT_PREFIX) + ); + + cy.log(`Found ${e2eTemplates.length} E2E prompt template(s) to clean up`); + + e2eTemplates.forEach((template: any) => { + const templateId = template.promptTemplateId || template.id; + if (!templateId) { + cy.log(`Skipping template "${template.title}" - no ID found`); + return; + } + cy.log(`Deleting prompt template: ${template.title} (${templateId})`); + makeAuthenticatedRequest('DELETE', `/prompt-templates/${templateId}`).then((delResp) => { + if (delResp.status >= 200 && delResp.status < 300) { + cy.log(`Deleted prompt template ${templateId}`); + } else { + cy.log(`Failed to delete prompt template ${templateId}: ${delResp.status}`); + } + }); + }); + }); +} + +/** + * Delete all sessions for the current user. + */ +export function sweepE2eSessions () { + cy.log('Sweeping E2E sessions...'); + makeAuthenticatedRequest('DELETE', '/session').then((response) => { + if (response.status >= 200 && response.status < 300) { + cy.log('Deleted all sessions'); + } else { + cy.log(`Failed to delete sessions: ${response.status}`); + } + }); +} + +/** + * Sweep all E2E test resources. Intended to run in before/after hooks + * to ensure a clean environment regardless of prior test state. + * + * Deletion order matters: sessions first, then repositories (cascades to + * collections/documents), then prompt templates, then models. + */ +export function sweepAllE2eResources () { + cy.log('=== Starting E2E resource sweep ==='); + sweepE2eSessions(); + sweepE2eRepositories(); + sweepE2ePromptTemplates(); + sweepE2eModels(); + cy.log('=== E2E resource sweep complete ==='); +} diff --git a/cypress/src/support/collectionHelpers.ts b/cypress/src/support/collectionHelpers.ts index e118b00eb..df964ad16 100644 --- a/cypress/src/support/collectionHelpers.ts +++ b/cypress/src/support/collectionHelpers.ts @@ -19,6 +19,8 @@ * Reusable helpers for RAG collection management and document operations. */ +import { navigateToAdminPage } from './adminHelpers'; + export type CollectionConfig = { collectionId: string; collectionName: string; @@ -29,25 +31,23 @@ export type CollectionConfig = { * Navigate to the RAG Management page */ export function navigateToRagManagement () { - cy.visit('/#/repository-management'); - cy.url().should('include', '/repository-management'); - cy.wait(1000); + navigateToAdminPage('RAG Management'); + cy.url({ timeout: 30000 }).should('include', '/repository-management'); } /** * Get the API base URL from the application's environment */ -function getApiBaseUrl (): Cypress.Chainable { - return cy.window().then((win: any) => { - const apiBaseUrl = win.env?.API_BASE_URL || ''; - return apiBaseUrl.replace(/\/+$/, ''); // Remove trailing slashes - }); +export function getApiBaseUrl (): Cypress.Chainable { + // Get base URL from Cypress config and ensure it doesn't have trailing slash + const baseUrl = Cypress.config('baseUrl') as string; + return cy.wrap(baseUrl.replace(/\/+$/, '')); } /** * Get the authentication token from session storage */ -function getAuthToken (): Cypress.Chainable { +export function getAuthToken (): Cypress.Chainable { return cy.window().then((win) => { // Find the OIDC token in sessionStorage const oidcKey = Object.keys(win.sessionStorage).find((key) => key.startsWith('oidc.user:')); @@ -65,7 +65,7 @@ function getAuthToken (): Cypress.Chainable { * @param path - API path (e.g., '/repository', '/collections') * @param options - Additional request options (body, headers, etc.) */ -function makeAuthenticatedRequest ( +export function makeAuthenticatedRequest ( method: string, path: string, options: Partial = {} @@ -87,10 +87,10 @@ function makeAuthenticatedRequest ( } /** - * Wait for repository to be fully created (up to 5 minutes) + * Wait for repository to be fully created (up to 20 minutes) * Checks repository status until it's CREATE_COMPLETE or UPDATE_COMPLETE */ -export function waitForRepositoryReady (repositoryId: string, timeoutMs: number = 300000) { +export function waitForRepositoryReady (repositoryId: string, timeoutMs: number = 1200000) { cy.log(`Waiting for repository ${repositoryId} to be ready...`); const startTime = Date.now(); @@ -213,8 +213,9 @@ export function uploadDocument (filePath: string) { .find('input[type="file"]') .selectFile(`src/e2e/fixtures/${filePath}`, { force: true }); - // Wait a moment for file to be attached - cy.wait(1000); + // Wait for file to be attached (file token appears in UI) + cy.get('[data-testid="rag-upload-file-input"]') + .should('contain.text', filePath.split('/').pop()); // Click the Upload button to submit cy.contains('button', 'Upload') @@ -275,12 +276,13 @@ export function selectRagRepositoryInChat (repositoryId: string) { cy.log(`Selecting RAG repository: ${repositoryId}`); // Click the RAG repository input - cy.get('input#rag-repository-autosuggest, input[placeholder*="RAG Repository" i]') + cy.get('[data-testid="rag-repository-autosuggest"] input, input#rag-repository-autosuggest, input[placeholder*="RAG Repository" i]') .should('be.visible') .click({ force: true }); // Wait for dropdown to appear and select the repository - cy.get('[role="option"]') + cy.get('[role="option"]', { timeout: 10000 }) + .should('be.visible') .contains(repositoryId) .should('be.visible') .click(); @@ -294,7 +296,7 @@ export function selectCollectionInChat (collectionName: string) { cy.log(`Selecting collection: ${collectionName}`); // Click the collection input - cy.get('input#collection-autosuggest, input[placeholder*="collection" i]') + cy.get('[data-testid="rag-collection-autosuggest"] input, input#collection-autosuggest, input[placeholder*="collection" i]') .should('be.visible') .click({ force: true }); @@ -330,7 +332,7 @@ export function sendMessageAndVerifyRagResponse (message: string) { cy.intercept('POST', '**/chat/completions').as('chatCompletion'); // Type the message - cy.get('textarea[placeholder*="message" i]') + cy.get('[data-testid="chat-prompt-textarea"] textarea, textarea[placeholder*="message" i]') .should('be.visible') .clear() .type(message); @@ -426,7 +428,8 @@ export function deleteCollectionIfExists (collectionName: string) { .should('be.visible') .click(); - cy.wait(2000); + // Wait for modal to close after deletion + cy.get('[data-testid="confirmation-modal-delete-btn"]', { timeout: 10000 }).should('not.exist'); } }); } diff --git a/cypress/src/support/modelFormHelpers.ts b/cypress/src/support/modelFormHelpers.ts index 6f6768df3..25f6ccde6 100644 --- a/cypress/src/support/modelFormHelpers.ts +++ b/cypress/src/support/modelFormHelpers.ts @@ -48,8 +48,8 @@ export function openCreateModelWizard () { * Fill in the base model configuration for a third-party (Bedrock) model */ export function fillBedrockModelConfig (config: BedrockModelConfig) { - cy.get('input[placeholder="mistral-vllm"]').clear().type(config.modelId); - cy.get('input[placeholder*="mistralai/Mistral"]').clear().type(config.modelName); + cy.get('[data-testid="model-id-input"] input, input[placeholder="mistral-vllm"]').clear().type(config.modelId); + cy.get('[data-testid="model-name-input"] input, input[placeholder*="mistralai/Mistral"]').clear().type(config.modelName); if (config.modelDescription) { cy.get('input[placeholder*="Brief description"]').clear().type(config.modelDescription); @@ -92,10 +92,44 @@ export function waitForModelCreationSuccess (modelId: string) { } /** - * Verify model appears in the model management list + * Verify model appears in the model management list. + * After creation, the model may not appear in the initial GET /models response + * because the API is eventually consistent. Retries with page reload if needed. */ -export function verifyModelInList (modelId: string) { - cy.contains(modelId, { timeout: 10000 }).should('be.visible'); +export function verifyModelInList (modelId: string, maxRetries: number = 3) { + function checkWithRetry (attempt: number): void { + // Ensure we're on the Model Management page before waiting for API + cy.url().then((url) => { + if (!url.includes('model-management')) { + cy.window().then((win) => { + win.location.hash = '#/model-management'; + }); + cy.url({ timeout: 10000 }).should('include', 'model-management'); + } + }); + + // Now wait for the models API to load on the Model Management page + cy.wait('@getModels', { timeout: 30000 }); + + cy.get('body').then(($body) => { + if ($body.text().includes(modelId)) { + cy.contains(modelId).should('be.visible'); + } else if (attempt < maxRetries) { + cy.log(`Model ${modelId} not found (attempt ${attempt}/${maxRetries}), refreshing...`); + cy.wait(5000); + // Navigate back to Model Management and retry + cy.window().then((win) => { + win.location.hash = '#/model-management'; + }); + cy.url({ timeout: 10000 }).should('include', 'model-management'); + checkWithRetry(attempt + 1); + } else { + // Final attempt - let it fail with a clear assertion + cy.contains(modelId, { timeout: 10000 }).should('be.visible'); + } + }); + } + checkWithRetry(1); } /** @@ -110,6 +144,9 @@ export function deleteModelIfExists (modelId: string) { .find('input[type="radio"]') .click({ force: true }); + // Set up intercept before triggering delete + cy.intercept('DELETE', '**/models/*').as('deleteModel'); + // Click the Actions dropdown cy.get('[data-testid="model-actions-dropdown"]').click(); @@ -121,7 +158,9 @@ export function deleteModelIfExists (modelId: string) { .should('be.visible') .click(); - cy.wait(2000); + // Wait for delete API to complete and modal to close + cy.wait('@deleteModel', { timeout: 10000 }); + cy.get('[data-testid="confirmation-modal-delete-btn"]').should('not.exist'); } }); } @@ -130,16 +169,30 @@ export function deleteModelIfExists (modelId: string) { * Select a model in the chat interface */ export function selectModelInChat (modelId: string) { - cy.get('input[placeholder*="model" i], input[aria-label*="model" i]', { timeout: 45000 }) + // Click to open the dropdown and wait for options to load + cy.get('[data-testid="model-selection-autosuggest"] input, input[placeholder*="model" i], input[aria-label*="model" i]', { timeout: 45000 }) .first() .should('not.be.disabled') - .click({ force: true }) + .click({ force: true }); + + // Wait for dropdown options to appear + cy.get('[role="option"], [role="menuitem"]', { timeout: 15000 }) + .should('be.visible'); + + // Type to filter, then select the matching option + cy.get('[data-testid="model-selection-autosuggest"] input, input[placeholder*="model" i], input[aria-label*="model" i]') + .first() + .clear() .type(modelId); cy.get('[role="option"], [role="menuitem"]') .contains(modelId) .should('be.visible') .click(); + + // Verify the model was actually selected β€” send button becomes enabled + cy.get('button[aria-label="Send message"]', { timeout: 30000 }) + .should('not.be.disabled'); } /** @@ -150,7 +203,7 @@ export function sendChatMessage (message: string) { // Intercept the chat completions API call cy.intercept('POST', '**/v2/serve/chat/completions').as('chatInference'); - cy.get('textarea[placeholder*="message" i]') + cy.get('[data-testid="chat-prompt-textarea"] textarea, textarea[placeholder*="message" i]') .should('not.be.disabled') .type(message); @@ -181,16 +234,26 @@ export function verifyChatResponse (userMessage: string) { * Delete all chat sessions for the current user */ export function deleteAllSessions () { - // Click the Delete All Sessions button - cy.get('button[aria-label="Delete All Sessions"]') - .should('be.visible') - .click(); + cy.get('body').then(($body) => { + if ($body.find('button[aria-label="Delete All Sessions"]').length === 0) { + cy.log('No sessions to delete β€” Delete All Sessions button not found'); + return; + } - // Wait for confirmation modal and click Delete button - cy.get('[data-testid="confirmation-modal-delete-btn"]', { timeout: 5000 }) - .should('be.visible') - .click(); + // Set up intercept before triggering delete + cy.intercept('DELETE', '**/session*').as('deleteSessions'); - // Wait for deletion to complete - cy.wait(2000); + cy.get('button[aria-label="Delete All Sessions"]') + .should('be.visible') + .click(); + + // Wait for confirmation modal and click Delete button + cy.get('[data-testid="confirmation-modal-delete-btn"]', { timeout: 5000 }) + .should('be.visible') + .click(); + + // Wait for delete API to complete and modal to close + cy.wait('@deleteSessions', { timeout: 10000 }); + cy.get('[data-testid="confirmation-modal-delete-btn"]').should('not.exist'); + }); } diff --git a/cypress/src/support/projectHelpers.ts b/cypress/src/support/projectHelpers.ts new file mode 100644 index 000000000..fb02ffa3e --- /dev/null +++ b/cypress/src/support/projectHelpers.ts @@ -0,0 +1,332 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/** + * projectHelpers.ts + * Reusable helpers for Project Organization feature interactions. + */ + +// Project feature selectors +export const PROJECT_SELECTORS = { + // Segmented control for view toggle + VIEW_TOGGLE: '[data-testid="project-history-toggle"]', + HISTORY_VIEW_BUTTON: '[data-testid="history"]', + PROJECTS_VIEW_BUTTON: '[data-testid="projects"]', + + // Buttons and dropdowns + NEW_SESSION_DROPDOWN: '[data-testid="new-session-dropdown"]', + NEW_PROJECT_MENU_ITEM: '[data-testid="new-project"]', + PROJECT_ACTIONS_BUTTON: (projectName: string) => `[aria-label="Project actions for ${projectName}"]`, + RENAME_MENU_ITEM: '[data-testid="rename"]', + DELETE_MENU_ITEM: '[data-testid="delete"]', + ADD_TO_PROJECT_MENU_ITEM: '[data-testid="add-to-project"]', + REMOVE_FROM_PROJECT_MENU_ITEM: '[data-testid="remove-from-project"]', + + // Modals + MODAL_DIALOG: '[role="dialog"]', + MODAL_HEADER: '[class*="awsui_header"] h2', + PROJECT_NAME_INPUT: 'input[placeholder*="Enter project name"]', + DELETE_PROJECT_ONLY_RADIO: 'input[value="project-only"]', + DELETE_PROJECT_WITH_SESSIONS_RADIO: 'input[value="with-sessions"]', + + // Lists and items + PROJECT_LIST_ITEM: (projectName: string) => `[data-testid="project-${projectName}"]`, + SESSION_ITEM: '[data-testid="session-item"]', + SESSION_ITEM_ACTIVE: '[data-testid="session-item-active"]', + PROJECT_BADGE: '[class*="awsui_badge"]', + SESSION_ACTIONS_BUTTON: '[aria-label="Session actions"]', + + // Empty states + EMPTY_STATE: '[class*="awsui_empty"]', + EMPTY_STATE_TEXT: '[class*="awsui_empty"] [class*="awsui_content"]', +}; + +/** + * Navigate to the chat page to access session history and projects. + * Uses client-side hash navigation to preserve auth state and avoid + * React re-render race conditions that occur with link clicks. + */ +export function navigateToChatPage () { + cy.url().then((url) => { + if (!url.includes('/ai-assistant')) { + cy.window().then((win) => { + win.location.hash = '#/ai-assistant'; + }); + } + }); + + cy.url({ timeout: 10000 }).should('include', '/ai-assistant'); + + // Wait for any loading spinners to disappear + cy.get('body').then(($body) => { + if ($body.find('[class*="awsui_spinner"]').length > 0) { + cy.get('[class*="awsui_spinner"]', { timeout: 10000 }).should('not.exist'); + } + }); +} + +/** + * Switch to the Projects view using the segmented control + */ +export function switchToProjectsView () { + cy.get('[data-testid="project-history-toggle"]', { timeout: 15000 }) + .should('be.visible') + .contains('button', 'Projects') + .click(); + + // Wait for view toggle to update + cy.get('[data-testid="project-history-toggle"]') + .contains('button', 'Projects') + .should('have.attr', 'aria-pressed', 'true'); +} + +/** + * Switch to the History view using the segmented control + */ +export function switchToHistoryView () { + cy.get('[data-testid="project-history-toggle"]', { timeout: 15000 }) + .should('be.visible') + .contains('button', 'History') + .click(); + + // Wait for view toggle to update + cy.get('[data-testid="project-history-toggle"]') + .contains('button', 'History') + .should('have.attr', 'aria-pressed', 'true'); +} + +/** + * Verify the current active view + * @param view - Expected view: 'history' or 'projects' + */ +export function verifyCurrentView (view: 'history' | 'projects') { + const selector = view === 'history' + ? PROJECT_SELECTORS.HISTORY_VIEW_BUTTON + : PROJECT_SELECTORS.PROJECTS_VIEW_BUTTON; + + cy.get(selector).should('have.attr', 'aria-pressed', 'true'); +} + +/** + * Verify view selection is persisted in localStorage + * @param view - Expected view: 'history' or 'projects' + */ +export function verifyViewPersistence (view: 'history' | 'projects') { + cy.window().then((win) => { + const storedView = win.localStorage.getItem('lisa-history-view'); + expect(storedView).to.equal(view); + }); +} + +/** + * Create a new project + * @param projectName - Name of the project to create + */ +export function createProject (projectName: string) { + // Open New dropdown using data-testid + cy.get(PROJECT_SELECTORS.NEW_SESSION_DROPDOWN).click(); + + // Click New Project menu item + cy.get(PROJECT_SELECTORS.NEW_PROJECT_MENU_ITEM).click(); + + // Wait for modal header to appear + cy.contains('h2', 'New Project', { timeout: 5000 }).should('be.visible'); + + // Enter project name + cy.get(PROJECT_SELECTORS.PROJECT_NAME_INPUT) + .should('be.visible') + .clear() + .type(projectName); + + // Confirm + cy.get('button').filter(':visible').contains('Create').click(); + + // Wait for create API call to complete + cy.wait('@createProject'); + + // Wait for modal to close - check that the visible modal header is gone + cy.contains('h2', 'New Project').should('not.be.visible'); +} + +/** + * Rename an existing project + * @param currentName - Current project name + * @param newName - New project name + */ +export function renameProject (currentName: string, newName: string) { + // Open project actions menu + cy.get(PROJECT_SELECTORS.PROJECT_ACTIONS_BUTTON(currentName)) + .first() + .should('be.visible') + .click(); + + // Click Rename + cy.get(PROJECT_SELECTORS.RENAME_MENU_ITEM).click(); + + // Wait for modal header to appear + cy.contains('h2', 'Rename Project', { timeout: 5000 }).should('be.visible'); + + // Verify current name is pre-filled and enter new name + cy.get('[data-testid="rename-project-input"] input') + .should('have.value', currentName) + .clear() + .type(newName); + + // Confirm - button text is "Rename" + cy.get('button').filter(':visible').contains('Rename').click(); + + // Wait for update API call to complete + cy.wait('@updateProject'); + + // Wait for modal to close + cy.contains('h2', 'Rename Project').should('not.be.visible'); +} + +/** + * Delete a project without deleting its sessions + * @param projectName - Name of the project to delete + */ +export function deleteProjectOnly (projectName: string) { + // Open project actions menu + cy.get(PROJECT_SELECTORS.PROJECT_ACTIONS_BUTTON(projectName)) + .first() + .should('be.visible') + .click(); + + // Click Delete + cy.get(PROJECT_SELECTORS.DELETE_MENU_ITEM).first().click(); + + // Wait for modal header to appear + cy.contains('h2', 'Delete Project', { timeout: 5000 }).should('be.visible'); + + // Click "Delete project only" button + cy.get('button').filter(':visible').contains('Delete project only').click(); + + // Wait for delete API call to complete + cy.wait('@deleteProject'); + + // Wait for modal to close + cy.contains('h2', 'Delete Project').should('not.be.visible'); +} + +/** + * Delete a project and all its sessions + * @param projectName - Name of the project to delete + */ +export function deleteProjectWithSessions (projectName: string) { + // Open project actions menu + cy.get(PROJECT_SELECTORS.PROJECT_ACTIONS_BUTTON(projectName)) + .first() + .should('be.visible') + .click(); + + // Click Delete + cy.get(PROJECT_SELECTORS.DELETE_MENU_ITEM).first().click(); + + // Wait for modal header to appear + cy.contains('h2', 'Delete Project', { timeout: 5000 }).should('be.visible'); + + // Click "Delete project and sessions" button + cy.get('button').filter(':visible').contains('Delete project and sessions').click(); + + // Wait for delete API call to complete + cy.wait('@deleteProject'); + + // Wait for modal to close + cy.contains('h2', 'Delete Project').should('not.be.visible'); +} + + +/** + * Verify that a project exists in the Projects view + * @param projectName - Name of the project to verify + */ +export function verifyProjectExists (projectName: string) { + // Project names are truncated to 15 chars in ExpandableSection header + const displayName = projectName.length > 15 ? `${projectName.slice(0, 15)}...` : projectName; + cy.contains(displayName, { timeout: 10000 }).should('be.visible'); +} + +/** + * Verify that a project does not exist + * @param projectName - Name of the project that should not exist + */ +export function verifyProjectNotExists (projectName: string) { + // Project names are truncated to 15 chars in ExpandableSection header + const displayName = projectName.length > 15 ? `${projectName.slice(0, 15)}...` : projectName; + cy.contains(displayName, { timeout: 10000 }).should('not.exist'); +} + + +/** + * Enable the Projects feature via configuration (admin only) + */ +export function enableProjectsFeature () { + // Navigate to Configuration page + cy.get('a[aria-label="Configuration"]') + .should('be.visible') + .click(); + + cy.url().should('include', '/configuration'); + + // Find and enable projectOrganization toggle + cy.contains('Project Organization') + .parent() + .within(() => { + cy.get('input[type="checkbox"]').then(($checkbox) => { + if (!$checkbox.is(':checked')) { + cy.wrap($checkbox).click({ force: true }); + } + }); + }); + + // Save configuration + cy.contains('button', 'Save').click(); + + // Wait for save to complete + cy.contains('Configuration saved successfully', { timeout: 10000 }) + .should('be.visible'); +} + +/** + * Disable the Projects feature via configuration (admin only) + */ +export function disableProjectsFeature () { + // Navigate to Configuration page + cy.get('a[aria-label="Configuration"]') + .should('be.visible') + .click(); + + cy.url().should('include', '/configuration'); + + // Find and disable projectOrganization toggle + cy.contains('Project Organization') + .parent() + .within(() => { + cy.get('input[type="checkbox"]').then(($checkbox) => { + if ($checkbox.is(':checked')) { + cy.wrap($checkbox).click({ force: true }); + } + }); + }); + + // Save configuration + cy.contains('button', 'Save').click(); + + // Wait for save to complete + cy.contains('Configuration saved successfully', { timeout: 10000 }) + .should('be.visible'); +} diff --git a/cypress/src/support/promptTemplateHelpers.ts b/cypress/src/support/promptTemplateHelpers.ts index ec4a580c5..a275806d3 100644 --- a/cypress/src/support/promptTemplateHelpers.ts +++ b/cypress/src/support/promptTemplateHelpers.ts @@ -242,7 +242,13 @@ export function selectPromptTemplateInChat (templateTitle: string, templateType: .and('not.be.disabled') .click(); - // Wait for modal to close and UI to stabilize + // Wait for modal to close cy.get(modalSelector).should('not.be.visible'); - cy.wait(500); + + // Verify template was applied based on type + if (!isPersona) { + // Directive content goes to the textarea + cy.get('[data-testid="chat-prompt-textarea"] textarea, textarea[placeholder*="message" i]') + .should('not.have.value', ''); + } } diff --git a/cypress/src/support/repositoryHelpers.ts b/cypress/src/support/repositoryHelpers.ts index dba1143b9..88e3fd984 100644 --- a/cypress/src/support/repositoryHelpers.ts +++ b/cypress/src/support/repositoryHelpers.ts @@ -19,6 +19,8 @@ * Reusable helpers for repository creation and management interactions. */ +import { navigateToAdminPage } from './adminHelpers'; + export type RepositoryConfig = { repositoryId: string; knowledgeBaseName: string; @@ -39,9 +41,8 @@ export function repositoryExists (repositoryId: string): Cypress.Chainable { // Set up intercept for data sources API before selecting KB cy.intercept('GET', '**/bedrock-kb/*/data-sources').as('getDataSources'); - // Wait for the select to be visible (API already loaded in fillRepositoryConfig) - cy.get('[data-testid="knowledge-base-select"]').should('be.visible'); + // Wait for the select to be visible + cy.get('[data-testid="knowledge-base-select"]', { timeout: 10000 }).should('be.visible'); - // Click the Knowledge Base dropdown button - cy.get('[data-testid="knowledge-base-select"]') - .find('button') - .click(); + // Check if the select is disabled (no KBs available) or has the empty placeholder + return cy.get('[data-testid="knowledge-base-select"]').then(($select) => { + const button = $select.find('button'); + const selectText = $select.text(); - // Select the knowledge base by name - cy.get('[role="option"]') - .contains(knowledgeBaseName) - .should('be.visible') - .click(); + // Check for empty state indicators + if (button.is(':disabled') || + selectText.includes('No available Knowledge Bases') || + selectText.includes('Choose a Knowledge Base')) { + + // Try clicking to see if dropdown has options + cy.get('[data-testid="knowledge-base-select"]') + .find('button') + .click({ force: true }); + + // Check if any options exist + return cy.get('body').then(($body) => { + const hasOptions = $body.find('[role="listbox"] [role="option"]').length > 0; + + if (!hasOptions) { + cy.log('No Knowledge Bases available - skipping KB selection'); + // Close dropdown if open by clicking elsewhere + cy.get('body').click(0, 0); + return cy.wrap(false); + } - // Wait for data sources to load after selecting KB - cy.wait('@getDataSources', { timeout: 30000 }); + // Options exist, try to find the specific KB + const kbOption = $body.find(`[role="option"]:contains("${knowledgeBaseName}")`); + if (kbOption.length > 0) { + cy.get('[role="option"]') + .contains(knowledgeBaseName) + .click(); + cy.wait('@getDataSources', { timeout: 30000 }); + return cy.wrap(true); + } else { + cy.log(`Knowledge Base "${knowledgeBaseName}" not found - selecting first available`); + cy.get('[role="option"]').first().click(); + cy.wait('@getDataSources', { timeout: 30000 }); + return cy.wrap(true); + } + }); + } + + // Select is enabled, proceed normally + cy.get('[data-testid="knowledge-base-select"]') + .find('button') + .click(); + + // Wait for dropdown to open + cy.get('[role="listbox"]', { timeout: 10000 }).should('exist'); + + // Select the knowledge base by name or first available + return cy.get('body').then(($body) => { + const kbOption = $body.find(`[role="option"]:contains("${knowledgeBaseName}")`); + if (kbOption.length > 0) { + cy.get('[role="option"]') + .contains(knowledgeBaseName) + .click(); + } else if ($body.find('[role="option"]').length > 0) { + cy.log(`Knowledge Base "${knowledgeBaseName}" not found - selecting first available`); + cy.get('[role="option"]').first().click(); + } else { + cy.log('No Knowledge Bases available'); + cy.get('body').click(0, 0); // Close dropdown + return cy.wrap(false); + } + cy.wait('@getDataSources', { timeout: 30000 }); + return cy.wrap(true); + }); + }); } /** @@ -191,7 +253,8 @@ export function deleteRepositoryIfExists (repositoryId: string) { .should('be.visible') .click(); - cy.wait(2000); + // Wait for modal to close after deletion + cy.get('[data-testid="confirmation-modal-delete-btn"]', { timeout: 30000 }).should('not.exist'); } }); } diff --git a/ecs_model_deployer/src/lib/ecs-model.ts b/ecs_model_deployer/src/lib/ecs-model.ts index 0d4b98330..0fefc4e42 100644 --- a/ecs_model_deployer/src/lib/ecs-model.ts +++ b/ecs_model_deployer/src/lib/ecs-model.ts @@ -22,6 +22,7 @@ import { Construct } from 'constructs'; import { ECSCluster } from './ecsCluster'; import { getModelIdentifier } from './utils'; import { APP_MANAGEMENT_KEY, Ec2Metadata, EcsClusterConfig, EcsSourceType, PartialConfig } from '../../../lib/schema'; +import { createCdkId } from '../../../lib/core/utils'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; // Default memory buffer if not specified in config (2GB) @@ -99,19 +100,29 @@ export class EcsModel extends Construct { * represent the environment variables for Docker at runtime. */ private getEnvironmentVariables (config: PartialConfig, modelConfig: EcsClusterConfig): { [key: string]: string } { + const identifier = getModelIdentifier(modelConfig); const environment: { [key: string]: string } = { LOCAL_MODEL_PATH: `${config.nvmeContainerMountPath ?? '/nvme'}/model`, S3_BUCKET_MODELS: config.s3BucketModels ?? '', MODEL_NAME: modelConfig.modelName, LOCAL_CODE_PATH: modelConfig.localModelCode, // Only needed when s5cmd is used, but just keep for now AWS_REGION: config.region ?? '', // needed for s5cmd - MANAGEMENT_KEY_NAME: StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`) + MANAGEMENT_KEY_NAME: StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`), + // Used by metrics_publisher.py for CloudWatch dimensions + CLUSTER_NAME: createCdkId([config.deploymentName, identifier], 32, 2), + SERVICE_NAME: createCdkId([config.deploymentName, identifier], 32, 2), }; if (modelConfig.modelType === 'embedding') { environment.SAGEMAKER_BASE_DIR = config.nvmeContainerMountPath ?? '/nvme'; } + // Set SERVED_MODEL_NAME for TEI so it accepts the model name sent by LiteLLM + // in OpenAI-compatible requests, avoiding "model not found" warnings. + if (modelConfig.inferenceContainer === 'tei') { + environment.SERVED_MODEL_NAME = modelConfig.modelName; + } + if (config.mountS3DebUrl) { environment.S3_MOUNT_POINT = 's3-models-mount'; // More threads than files during S3 mount point copy to NVMe is fine; by default use half threads diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 507256bc9..3bbf01b19 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -17,7 +17,7 @@ // ECS Cluster Construct. import { CfnOutput, Duration, RemovalPolicy } from 'aws-cdk-lib'; import { BlockDeviceVolume, GroupMetrics, Monitoring } from 'aws-cdk-lib/aws-autoscaling'; -import { Metric, Stats } from 'aws-cdk-lib/aws-cloudwatch'; +import { Alarm, ComparisonOperator, Metric, Stats, TreatMissingData } from 'aws-cdk-lib/aws-cloudwatch'; import { InstanceType, ISecurityGroup, IVpc, SubnetSelection } from 'aws-cdk-lib/aws-ec2'; import { Alias } from 'aws-cdk-lib/aws-kms'; import { @@ -92,7 +92,7 @@ export class ECSCluster extends Construct { const cluster = new Cluster(this, createCdkId([identifier, 'Cl']), { clusterName: createCdkId([config.deploymentName, identifier], 32, 2), vpc: vpc, - containerInsightsV2: !config.region?.includes('iso') ? ContainerInsights.ENABLED : ContainerInsights.DISABLED, + containerInsightsV2: ContainerInsights.ENHANCED, }); // SNS encryption key for ECS lifecycle hooks (AppSec Finding #5) @@ -368,6 +368,80 @@ DOCKEREOF estimatedInstanceWarmup: Duration.seconds(ecsConfig.autoScalingConfig.metricConfig.duration), }); + // Model ALB alarms β€” created only when the health dashboard is enabled. + // These use concrete ALB/TargetGroup dimensions (available here at deploy + // time) so the alarms actually receive datapoints. The dashboard uses + // SEARCH expressions for dynamic discovery; alarms cannot use SEARCH. + if (config.deployHealthDashboard) { + const alarmPrefix = `${config.deploymentName}-${config.deploymentStage}-LISA-${identifier}`; + const albDims = { LoadBalancer: loadBalancer.loadBalancerFullName }; + const tgDims = { TargetGroup: targetGroup.targetGroupFullName, LoadBalancer: loadBalancer.loadBalancerFullName }; + + new Alarm(this, createCdkId([identifier, 'UnhealthyHostsAlarm']), { + alarmName: `${alarmPrefix}-UnhealthyHosts`, + alarmDescription: `Model ${identifier}: one or more containers are failing ALB health checks.`, + metric: new Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'UnHealthyHostCount', + dimensionsMap: tgDims, + statistic: 'Maximum', + period: Duration.minutes(5), + }), + threshold: 0, + comparisonOperator: ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 2, + treatMissingData: TreatMissingData.NOT_BREACHING, + }); + + new Alarm(this, createCdkId([identifier, 'Target5xxAlarm']), { + alarmName: `${alarmPrefix}-Target5xxErrors`, + alarmDescription: `Model ${identifier}: sustained HTTP 5xx errors from model container.`, + metric: new Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'HTTPCode_Target_5XX_Count', + dimensionsMap: tgDims, + statistic: 'Sum', + period: Duration.minutes(5), + }), + threshold: 10, + comparisonOperator: ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 2, + treatMissingData: TreatMissingData.NOT_BREACHING, + }); + + new Alarm(this, createCdkId([identifier, 'ConnectionErrorAlarm']), { + alarmName: `${alarmPrefix}-TargetConnectionErrors`, + alarmDescription: `Model ${identifier}: ALB cannot connect to container (crash/OOM).`, + metric: new Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'TargetConnectionErrorCount', + dimensionsMap: tgDims, + statistic: 'Sum', + period: Duration.minutes(5), + }), + threshold: 5, + comparisonOperator: ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 2, + treatMissingData: TreatMissingData.NOT_BREACHING, + }); + + new Alarm(this, createCdkId([identifier, 'HighLatencyAlarm']), { + alarmName: `${alarmPrefix}-HighP99Latency`, + alarmDescription: `Model ${identifier}: p99 response time exceeds 120s.`, + metric: new Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'TargetResponseTime', + dimensionsMap: albDims, + statistic: 'p99', + period: Duration.minutes(5), + }), + threshold: 120, + comparisonOperator: ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 3, + treatMissingData: TreatMissingData.NOT_BREACHING, + }); + } + const domain = loadBalancer.loadBalancerDnsName; endpointUrl = `${protocol}://${domain}`; diff --git a/example_config.yaml b/example_config.yaml index 27fadde20..85b5e23cb 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -4,6 +4,7 @@ authConfig: authority: clientId: adminGroup: + ragAdminGroup: jwtGroupsProperty: s3BucketModels: hf-models-gaiic ragRepositories: [] @@ -28,6 +29,11 @@ ragRepositories: [] # domainName: # restApiConfig: # sslCertIamArn: ARN of the self-signed cert to be used throughout the system +# domainName: Custom hostname for the LISA Serve ALB (optional) +# MCP Workbench uses its own ALB; set a workbench hostname the same way (optional). If omitted while restApiConfig.domainName is set, a derived host is used (e.g. lisa-serve.example β†’ lisa-mcp-workbench.example). +# mcpWorkbenchRestApiConfig: +# domainName: +# sslCertIamArn: # optional; falls back to restApiConfig.sslCertIamArn, then mcpWorkbenchEcsConfig.sslCertIamArn # Some customers will want to download required libs prior to deployment, provide a path to the zipped resources # lambdaLayerAssets: # authorizerLayerPath: /path/to/authorizer_layer.zip diff --git a/flake.nix b/flake.nix index 3edebefce..83af693ad 100644 --- a/flake.nix +++ b/flake.nix @@ -30,14 +30,12 @@ # Core development tools needed for LISA packages = with pkgs; [ awscli2 # AWS command-line interface for deployment and management - gnumake jq # JSON processor for parsing AWS responses and configuration python313Full # Python runtime for LISA backend services nodejs # Node.js runtime for CDK infrastructure and frontend tooling nodePackages.aws-cdk # AWS CDK CLI, the command line tool for CDK apps python313Packages.pre-commit-hooks # Git hook framework for code quality checks python313Packages.uv # Fast Python package installer and virtual environment manager - yq # YAML processor for configuration management ]; # Script that runs when entering the development shell diff --git a/lambda/authorizer/lambda_functions.py b/lambda/authorizer/lambda_functions.py index 729f49650..47c28e256 100644 --- a/lambda/authorizer/lambda_functions.py +++ b/lambda/authorizer/lambda_functions.py @@ -21,11 +21,16 @@ from typing import Any import boto3 -import create_env_variables # noqa: F401 +import create_env_variables # noqa: F401 # type: ignore[reportMissingImports] import jwt import requests from botocore.exceptions import ClientError from cachetools import cached, TTLCache +from utilities.audit_logging_utils import ( + get_matched_audit_prefix, + get_method_and_path_from_method_arn, + log_audit_event, +) from utilities.auth_provider import get_authorization_provider from utilities.common_functions import authorization_wrapper, get_id_token, get_property_path, retry_config from utilities.time import now_seconds @@ -43,11 +48,32 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """Handle authorization for REST API.""" logger.info("REST API authorization handler started") + # Compute request path/action for conditional audit logs. + http_method, request_path = get_method_and_path_from_method_arn(event.get("methodArn", "")) + audit_area = get_matched_audit_prefix(request_path) + + def _log_audit(decision: str, username: str, auth_type: str) -> None: + """Emit AUDIT_API_GATEWAY_REQUEST when `audit_area` is enabled.""" + if not audit_area: + return + log_audit_event( + logger, + "AUDIT_API_GATEWAY_REQUEST", + { + "area": audit_area, + "action": f"{http_method} {request_path}", + "decision": decision, + "user": {"username": username, "auth_type": auth_type}, + }, + ) + id_token = get_id_token(event) if not id_token: logger.warning("Missing id_token in request. Denying access.") logger.info(f"REST API authorization handler completed with 'Deny' for resource {event['methodArn']}") + if audit_area: + _log_audit(decision="Deny", username="unknown", auth_type="unknown") return generate_policy(effect="Deny", resource=event["methodArn"]) client_id = os.environ.get("CLIENT_ID", "") @@ -64,6 +90,8 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=username) allow_policy["context"] = {"username": username, "groups": groups, "authType": "management"} logger.debug(f"Generated policy: {allow_policy}") + if audit_area: + _log_audit(decision="Allow", username=username, auth_type="management") return allow_policy if os.environ.get("TOKEN_TABLE_NAME", None): @@ -74,6 +102,8 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=username) allow_policy["context"] = {"username": username, "groups": groups, "authType": "api_token"} logger.debug(f"Generated policy: {allow_policy}") + if audit_area: + _log_audit(decision="Allow", username=username, auth_type="api_token") return allow_policy if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority): @@ -83,10 +113,13 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: # Use auth provider for access checks (consistent with auth.py) auth_provider = get_authorization_provider() is_admin_user = auth_provider.check_admin_access(username, user_groups) + is_rag_admin_user = auth_provider.check_rag_admin_access(username, user_groups) has_app_access = auth_provider.check_app_access(username, user_groups) - if not is_admin_user and not has_app_access: + if not is_admin_user and not is_rag_admin_user and not has_app_access: logger.info(f"User {username} denied access - no valid authorization found") + if audit_area: + _log_audit(decision="Deny", username=username, auth_type="jwt") return deny_policy groups = json.dumps(user_groups) @@ -95,9 +128,13 @@ def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: logger.debug(f"Generated policy: {allow_policy}") logger.info(f"REST API authorization handler completed with 'Allow' for resource {event['methodArn']}") + if audit_area: + _log_audit(decision="Allow", username=username, auth_type="jwt") return allow_policy logger.info(f"REST API authorization handler completed with 'Deny' for resource {event['methodArn']}") + if audit_area: + _log_audit(decision="Deny", username="unknown", auth_type="unknown") return deny_policy diff --git a/lambda/configuration/lambda_functions.py b/lambda/configuration/lambda_functions.py index a70c0d7dd..fcb461934 100644 --- a/lambda/configuration/lambda_functions.py +++ b/lambda/configuration/lambda_functions.py @@ -100,13 +100,14 @@ def check_show_mcp_workbench(body: dict[str, Any], old_configuration: dict[str, from mcp_server.lambda_functions import table as mcp_servers_table # noqa: PLC0415 if new_show_mcp_value: + mcp_base = os.getenv("MCP_WORKBENCH_ENDPOINT") or os.getenv("FASTAPI_ENDPOINT") mcp_server_model = McpServerModel( id=MCPWORKBENCH_UUID, owner="lisa:public", name="MCP Workbench", description="MCP Workbench Tools", customHeaders={"Authorization": "Bearer {LISA_BEARER_TOKEN}"}, - url=f"{os.getenv('FASTAPI_ENDPOINT')}/v2/mcp/", + url=f"{mcp_base}/v2/mcp/", status=McpServerStatus.ACTIVE, ) diff --git a/lambda/dockerimagebuilder/__init__.py b/lambda/dockerimagebuilder/__init__.py index 80b371e64..284972e71 100644 --- a/lambda/dockerimagebuilder/__init__.py +++ b/lambda/dockerimagebuilder/__init__.py @@ -63,7 +63,7 @@ # Setup build environment mkdir /home/ec2-user/docker_resources aws --region ${AWS_REGION} s3 sync s3://{{BUCKET_NAME}} /home/ec2-user/docker_resources -cd /home/ec2-user/docker_resources/{{LAYER_TO_ADD}} +cd /home/ec2-user/docker_resources while [ 1 ]; do shutdown -c; @@ -72,9 +72,10 @@ function buildTagPush() { echo "Starting Docker build for {{IMAGE_ID}}" | tee -a /var/log/docker-build.log - sed -iE 's/^FROM.*/FROM {{BASE_IMAGE}}/' Dockerfile + sed -iE 's/^FROM.*/FROM {{BASE_IMAGE}}/' {{LAYER_TO_ADD}}/Dockerfile docker build -t {{IMAGE_ID}} --build-arg BASE_IMAGE={{BASE_IMAGE}} \\ - --build-arg MOUNTS3_DEB_URL={{MOUNTS3_DEB_URL}} . 2>&1 | tee -a /var/log/docker-build.log && \\ + --build-arg MOUNTS3_DEB_URL={{MOUNTS3_DEB_URL}} \\ + -f {{LAYER_TO_ADD}}/Dockerfile . 2>&1 | tee -a /var/log/docker-build.log && \\ docker tag {{IMAGE_ID}} {{ECR_URI}}:{{IMAGE_ID}} 2>&1 | tee -a /var/log/docker-build.log && \\ aws --region ${AWS_REGION} ecr get-login-password | \\ docker login --username AWS --password-stdin {{ECR_URI}} 2>&1 | tee -a /var/log/docker-build.log && \\ diff --git a/lambda/mcp_workbench/syntax_validator.py b/lambda/mcp_workbench/syntax_validator.py index 7ee227844..4f455a803 100644 --- a/lambda/mcp_workbench/syntax_validator.py +++ b/lambda/mcp_workbench/syntax_validator.py @@ -14,6 +14,9 @@ """Python syntax validation module for MCP Workbench.""" import ast +import importlib +import importlib.abc +import importlib.machinery import importlib.util import logging import os @@ -39,6 +42,51 @@ def __post_init__(self) -> None: self.missing_required_imports = [] +class _StubLoader(importlib.abc.Loader): + """Loader that creates empty stub modules for ``mcpworkbench.*``.""" + + def create_module(self, spec: importlib.machinery.ModuleSpec) -> ModuleType: + mod = ModuleType(spec.name) + mod.__path__ = [] + mod.__package__ = spec.name + mod.__spec__ = spec + return mod + + def exec_module(self, module: ModuleType) -> None: + pass + + +class _McpWorkbenchStubFinder(importlib.abc.MetaPathFinder): + """Auto-stub any ``mcpworkbench.*`` import that hasn't already been mocked. + + During Lambda-based validation we only have explicit mocks for + ``mcpworkbench.core.*``. Tools may import from other subpackages + (e.g. ``mcpworkbench.aws.*``) that don't exist in the Lambda + environment. This finder intercepts those imports and returns + lightweight stub modules so validation can proceed without + ImportErrors. + """ + + _PREFIX = "mcpworkbench." + _loader = _StubLoader() + + def find_spec( + self, + fullname: str, + path: Any = None, + target: Any = None, + ) -> importlib.machinery.ModuleSpec | None: + if fullname == "mcpworkbench" or fullname.startswith(self._PREFIX): + if fullname not in sys.modules: + spec = importlib.machinery.ModuleSpec( + fullname, + self._loader, + is_package=True, + ) + return spec + return None + + class PythonSyntaxValidator: """Validates Python code syntax and imports without execution.""" @@ -197,11 +245,17 @@ def _setup_mcp_environment(self, module: Any) -> None: # Create mock module hierarchy in sys.modules # This allows user code to do: from mcpworkbench.core.base_tool import BaseTool + # __path__ must be set so Python treats these as packages that can have submodules. if "mcpworkbench" not in sys.modules: - sys.modules["mcpworkbench"] = ModuleType("mcpworkbench") + mcpworkbench_mod = ModuleType("mcpworkbench") + mcpworkbench_mod.__path__ = [] + mcpworkbench_mod.__package__ = "mcpworkbench" + sys.modules["mcpworkbench"] = mcpworkbench_mod if "mcpworkbench.core" not in sys.modules: core_module = ModuleType("mcpworkbench.core") + core_module.__path__ = [] + core_module.__package__ = "mcpworkbench.core" sys.modules["mcpworkbench.core"] = core_module sys.modules["mcpworkbench"].core = core_module # type: ignore[attr-defined] @@ -219,6 +273,13 @@ def _setup_mcp_environment(self, module: Any) -> None: logger.info("MCP mock modules successfully injected into sys.modules") logger.info(f"Modules now in sys.modules: {[k for k in sys.modules.keys() if 'mcpworkbench' in k]}") + + # Install a catch-all finder so that imports of other mcpworkbench + # subpackages (e.g. mcpworkbench.aws.*) return stubs instead of + # raising ImportError during validation. + if not any(isinstance(f, _McpWorkbenchStubFinder) for f in sys.meta_path): + sys.meta_path.append(_McpWorkbenchStubFinder()) + logger.info("Installed _McpWorkbenchStubFinder for remaining mcpworkbench.* imports") else: logger.info("Real MCP Workbench package is already available in sys.modules") diff --git a/lambda/metrics/batch_job_metric.py b/lambda/metrics/batch_job_metric.py new file mode 100644 index 000000000..de2cd7792 --- /dev/null +++ b/lambda/metrics/batch_job_metric.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handler for publishing CloudWatch metrics on Batch job state changes. + +Captures SUBMITTED, RUNNING, SUCCEEDED, and FAILED state transitions from +EventBridge and publishes corresponding metrics to the LISA/BatchIngestion +namespace. This provides queue-level visibility regardless of how the +ingestion job was triggered (S3 event, scheduled, or manual upload). +""" + +import json +import logging +import os + +import boto3 + +logger = logging.getLogger(__name__) + +cloudwatch = boto3.client("cloudwatch") + +# Map Batch job states to CloudWatch metric names +STATE_METRIC_MAP = { + "SUBMITTED": "JobsSubmitted", + "RUNNING": "JobsStarted", + "SUCCEEDED": "JobsSucceeded", + "FAILED": "JobsFailed", +} + + +def handler(event: dict, context: dict) -> None: + """Publish a CloudWatch metric when an AWS Batch ingestion job changes state. + + Triggered by an EventBridge rule that captures Batch Job State Change + events for the ingestion job queue. + + Parameters + ---------- + event : dict + EventBridge event with Batch job state change details. + context : dict + Lambda execution context. + """ + namespace = os.environ["METRICS_NAMESPACE"] + deployment = os.environ["DEPLOYMENT_NAME"] + stage = os.environ["DEPLOYMENT_STAGE"] + + detail = event.get("detail", {}) + job_queue = detail.get("jobQueue", "unknown") + job_name = detail.get("jobName", "unknown") + status = detail.get("status", "UNKNOWN") + + metric_name = STATE_METRIC_MAP.get(status) + if not metric_name: + logger.warning(json.dumps({"message": "Unhandled job status", "status": status})) + return + + cloudwatch.put_metric_data( + Namespace=namespace, + MetricData=[ + { + "MetricName": metric_name, + "Dimensions": [ + {"Name": "DeploymentName", "Value": deployment}, + {"Name": "DeploymentStage", "Value": stage}, + {"Name": "JobQueue", "Value": os.environ.get("JOB_QUEUE_LABEL", job_queue.split("/")[-1])}, + ], + "Value": 1, + "Unit": "Count", + }, + ], + ) + logger.info(json.dumps({"status": status, "metric": metric_name, "jobName": job_name, "jobQueue": job_queue})) diff --git a/lambda/metrics/lambda_functions.py b/lambda/metrics/lambda_functions.py index 5d21e1bb4..e3d40f8e1 100644 --- a/lambda/metrics/lambda_functions.py +++ b/lambda/metrics/lambda_functions.py @@ -21,6 +21,7 @@ import boto3 import create_env_variables # noqa: F401 from botocore.exceptions import ClientError +from metrics.models import MetricsEvent from utilities.common_functions import api_wrapper, retry_config from utilities.time import iso_string, utc_now @@ -204,31 +205,32 @@ def process_metrics_sqs_event(event: dict, context: dict) -> None: for record in event.get("Records", []): try: - # Parse the message body - message = json.loads(record["body"]) - - user_id = message.get("userId") - session_id = message.get("sessionId") - user_groups = message.get("userGroups", []) - messages = message.get("messages", []) - - logger.info(f"Processing metrics for user: {user_id}, session: {session_id}.") - - if not user_id: - logger.error("SQS message missing required 'userId' field") + # Parse and validate the message body through the MetricsEvent model. + # This catches malformed messages early and gives clear field access. + raw = json.loads(record["body"]) + try: + msg = MetricsEvent.model_validate(raw) + except Exception as validation_err: + logger.error(f"SQS message failed MetricsEvent validation: {validation_err}") continue - if not session_id: - logger.error("SQS message missing required 'sessionId' field") - continue + logger.info(f"Processing metrics for user: {msg.userId}, session: {msg.sessionId}, model: {msg.modelId}.") + + # Calculate prompt/RAG/MCP metrics for a given session + session_metrics = calculate_session_metrics(msg.messages) - # Calculate metrics for a given session - session_metrics = calculate_session_metrics(messages) + # Attach token data to session_metrics when present + if msg.promptTokens is not None: + session_metrics["promptTokens"] = msg.promptTokens + if msg.completionTokens is not None: + session_metrics["completionTokens"] = msg.completionTokens + if msg.modelId is not None: + session_metrics["modelId"] = msg.modelId logger.info(f"Calculated session metrics: {session_metrics}") # Update usage metrics for given session - update_user_metrics_by_session(user_id, session_id, session_metrics, user_groups) + update_user_metrics_by_session(msg.userId, msg.sessionId, session_metrics, msg.userGroups, msg.eventType) except Exception as e: logger.error(f"Error processing SQS message: {str(e)}") @@ -358,6 +360,9 @@ def publish_metric_deltas( delta_mcp_calls: int, delta_mcp_usage: dict[str, int], user_groups: list[str], + delta_prompt_tokens: int = 0, + delta_completion_tokens: int = 0, + model_id: str | None = None, ) -> None: """Publish only metric deltas to CloudWatch to prevent double counting. @@ -373,8 +378,14 @@ def publish_metric_deltas( Change in MCP tool calls delta_mcp_usage : Dict[str, int] Changes in individual MCP tool usage - user_groups : List[str], optional + user_groups : List[str] The groups that the user belongs to + delta_prompt_tokens : int + Change in prompt (input) token count + delta_completion_tokens : int + Change in completion (output) token count + model_id : str | None + Model ID used for the request (used as a dimension for token metrics) """ try: timestamp = utc_now() @@ -394,6 +405,16 @@ def publish_metric_deltas( }, ] ) + if model_id: + metric_data.append( + { + "MetricName": "ModelPromptCount", + "Dimensions": [{"Name": "ModelId", "Value": model_id}], + "Value": delta_prompts, + "Unit": "Count", + "Timestamp": timestamp, + } + ) if delta_rag != 0: metric_data.extend( @@ -441,6 +462,69 @@ def publish_metric_deltas( } ) + # Token metrics β€” aggregate, per-user, per-model + if delta_prompt_tokens != 0: + metric_data.extend( + [ + # Aggregate totals (no dimension) + { + "MetricName": "TotalPromptTokens", + "Value": delta_prompt_tokens, + "Unit": "Count", + "Timestamp": timestamp, + }, + # Per-user + { + "MetricName": "UserPromptTokens", + "Dimensions": [{"Name": "UserId", "Value": user_id}], + "Value": delta_prompt_tokens, + "Unit": "Count", + "Timestamp": timestamp, + }, + ] + ) + if model_id: + metric_data.append( + { + "MetricName": "ModelPromptTokens", + "Dimensions": [{"Name": "ModelId", "Value": model_id}], + "Value": delta_prompt_tokens, + "Unit": "Count", + "Timestamp": timestamp, + } + ) + + if delta_completion_tokens != 0: + metric_data.extend( + [ + # Aggregate totals (no dimension) + { + "MetricName": "TotalCompletionTokens", + "Value": delta_completion_tokens, + "Unit": "Count", + "Timestamp": timestamp, + }, + # Per-user + { + "MetricName": "UserCompletionTokens", + "Dimensions": [{"Name": "UserId", "Value": user_id}], + "Value": delta_completion_tokens, + "Unit": "Count", + "Timestamp": timestamp, + }, + ] + ) + if model_id: + metric_data.append( + { + "MetricName": "ModelCompletionTokens", + "Dimensions": [{"Name": "ModelId", "Value": model_id}], + "Value": delta_completion_tokens, + "Unit": "Count", + "Timestamp": timestamp, + } + ) + # Group-level metrics if user_groups: for group in user_groups: @@ -477,8 +561,33 @@ def publish_metric_deltas( } ) + if delta_prompt_tokens != 0: + metric_data.append( + { + "MetricName": "GroupPromptTokens", + "Dimensions": [{"Name": "GroupName", "Value": group}], + "Value": delta_prompt_tokens, + "Unit": "Count", + "Timestamp": timestamp, + } + ) + + if delta_completion_tokens != 0: + metric_data.append( + { + "MetricName": "GroupCompletionTokens", + "Dimensions": [{"Name": "GroupName", "Value": group}], + "Value": delta_completion_tokens, + "Unit": "Count", + "Timestamp": timestamp, + } + ) + if metric_data: - cloudwatch.put_metric_data(Namespace="LISA/UsageMetrics", MetricData=metric_data) + # CloudWatch PutMetricData accepts max 1000 metrics per call; batch if needed + batch_size = 1000 + for i in range(0, len(metric_data), batch_size): + cloudwatch.put_metric_data(Namespace="LISA/UsageMetrics", MetricData=metric_data[i : i + batch_size]) logger.info(f"Published {len(metric_data)} metric deltas for user {user_id}") except Exception as e: @@ -486,7 +595,11 @@ def publish_metric_deltas( def update_user_metrics_by_session( - user_id: str, session_id: str, session_metrics: dict[str, Any], user_groups: list[str] + user_id: str, + session_id: str, + session_metrics: dict[str, Any], + user_groups: list[str], + event_type: str = "full", ) -> None: """Update usage metrics for a given user based on session-level metrics. @@ -500,6 +613,11 @@ def update_user_metrics_by_session( Calculated metrics for this session user_groups : List[str] The groups that the user is apart of + event_type : str + "full" β€” API token user or session-lambda event; owns all metrics. + "token_only" β€” JWT/UI passthrough event; only carries token counts, session + lambda already counted the prompts. Do not write a sessionMetrics + entry β€” that would create synthetic sessions and pollute aggregation. """ table_name = os.environ.get("USAGE_METRICS_TABLE_NAME") @@ -532,11 +650,79 @@ def update_user_metrics_by_session( if delta != 0: delta_mcp_usage[tool_name] = delta + # Calculate token counts for this event + new_prompt_tokens = session_metrics.get("promptTokens", 0) or 0 + new_completion_tokens = session_metrics.get("completionTokens", 0) or 0 + old_prompt_tokens = existing_session_metrics.get("promptTokens", 0) or 0 + old_completion_tokens = existing_session_metrics.get("completionTokens", 0) or 0 + delta_prompt_tokens = new_prompt_tokens - old_prompt_tokens + delta_completion_tokens = new_completion_tokens - old_completion_tokens + model_id = session_metrics.get("modelId") + + # event_type=="token_only": JWT/UI passthrough events that carry only token counts. + # The session lambda already counted the prompts for these requests, so we publish + # CloudWatch token metrics but must NOT write a sessionMetrics entry β€” that would + # create synthetic sessions and pollute prompt aggregation. + is_token_only_event = event_type == "token_only" + # Publish only deltas to CloudWatch (prevents double counting) - if delta_prompts != 0 or delta_rag != 0 or delta_mcp_calls != 0 or delta_mcp_usage: - publish_metric_deltas(user_id, delta_prompts, delta_rag, delta_mcp_calls, delta_mcp_usage, user_groups) + has_changes = ( + delta_prompts != 0 + or delta_rag != 0 + or delta_mcp_calls != 0 + or delta_mcp_usage + or delta_prompt_tokens != 0 + or delta_completion_tokens != 0 + ) + if has_changes: + publish_metric_deltas( + user_id, + delta_prompts, + delta_rag, + delta_mcp_calls, + delta_mcp_usage, + user_groups, + delta_prompt_tokens=delta_prompt_tokens, + delta_completion_tokens=delta_completion_tokens, + model_id=model_id, + ) - # Update DynamoDB with session-based metrics + if is_token_only_event: + # Only update the aggregate token totals in DynamoDB β€” no sessionMetrics entry. + if not user_exists: + item: dict[str, Any] = { + "userId": user_id, + "totalPrompts": 0, + "ragUsageCount": 0, + "mcpToolCallsCount": 0, + "mcpToolUsage": {}, + "totalPromptTokens": new_prompt_tokens, + "totalCompletionTokens": new_completion_tokens, + "sessionMetrics": {}, + "firstSeen": iso_string(), + "lastSeen": iso_string(), + "userGroups": set(user_groups) if user_groups else None, + } + usage_metrics_table.put_item(Item=item) + else: + existing_prompt_tokens = int(existing_item.get("totalPromptTokens", 0) or 0) + existing_completion_tokens = int(existing_item.get("totalCompletionTokens", 0) or 0) + usage_metrics_table.update_item( + Key={"userId": user_id}, + UpdateExpression=( + "SET lastSeen = :now, " + "totalPromptTokens = :total_prompt_tokens, " + "totalCompletionTokens = :total_completion_tokens" + ), + ExpressionAttributeValues={ + ":now": iso_string(), + ":total_prompt_tokens": existing_prompt_tokens + delta_prompt_tokens, + ":total_completion_tokens": existing_completion_tokens + delta_completion_tokens, + }, + ) + return + + # Full event (API user or session-lambda event with prompts) β€” update everything. if not user_exists: # Create new user with session metrics item = { @@ -545,6 +731,8 @@ def update_user_metrics_by_session( "ragUsageCount": session_metrics["ragUsage"], "mcpToolCallsCount": session_metrics["mcpToolCallsCount"], "mcpToolUsage": session_metrics["mcpToolUsage"], + "totalPromptTokens": new_prompt_tokens, + "totalCompletionTokens": new_completion_tokens, "sessionMetrics": {session_id: session_metrics}, "firstSeen": iso_string(), "lastSeen": iso_string(), @@ -567,18 +755,31 @@ def update_user_metrics_by_session( for tool_name, count in sm.get("mcpToolUsage", {}).items(): aggregate_mcp_usage[tool_name] = aggregate_mcp_usage.get(tool_name, 0) + count + # Also add current event tokens to the aggregated session total + # (session lambda events don't carry tokens; API events do) + existing_prompt_tokens = int(existing_item.get("totalPromptTokens", 0) or 0) + existing_completion_tokens = int(existing_item.get("totalCompletionTokens", 0) or 0) + updated_prompt_tokens = existing_prompt_tokens + delta_prompt_tokens + updated_completion_tokens = existing_completion_tokens + delta_completion_tokens + # Update the user record usage_metrics_table.update_item( Key={"userId": user_id}, - UpdateExpression="SET lastSeen = :now, totalPrompts = :total_prompts, ragUsageCount = :total_rag, " - "mcpToolCallsCount = :total_mcp, mcpToolUsage = :mcp_usage, " - "sessionMetrics = :session_metrics, userGroups = :groups", + UpdateExpression=( + "SET lastSeen = :now, totalPrompts = :total_prompts, ragUsageCount = :total_rag, " + "mcpToolCallsCount = :total_mcp, mcpToolUsage = :mcp_usage, " + "totalPromptTokens = :total_prompt_tokens, " + "totalCompletionTokens = :total_completion_tokens, " + "sessionMetrics = :session_metrics, userGroups = :groups" + ), ExpressionAttributeValues={ ":now": iso_string(), ":total_prompts": total_prompts, ":total_rag": total_rag, ":total_mcp": total_mcp_calls, ":mcp_usage": aggregate_mcp_usage, + ":total_prompt_tokens": updated_prompt_tokens, + ":total_completion_tokens": updated_completion_tokens, ":session_metrics": all_session_metrics, ":groups": set(user_groups) if user_groups else (existing_item.get("userGroups") or set()), }, diff --git a/lambda/metrics/models.py b/lambda/metrics/models.py index 6aa6eb756..474e047c1 100644 --- a/lambda/metrics/models.py +++ b/lambda/metrics/models.py @@ -20,10 +20,21 @@ class MetricsEvent(BaseModel): - """Event model for usage metrics published to SQS.""" + """Event model for usage metrics published to SQS. + + event_type : str + "full" β€” API token user or session-lambda event; owns all metrics. + "token_only" β€” JWT/UI passthrough event; only carries token counts, session + lambda already counted the prompts. Do not write a sessionMetrics + entry β€” that would create synthetic sessions and pollute aggregation. + """ userId: str sessionId: str messages: list[dict[str, Any]] userGroups: list[str] timestamp: str + eventType: str = "full" + modelId: str | None = None + promptTokens: int | None = None + completionTokens: int | None = None diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 71626f783..557d5c16c 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -74,6 +74,14 @@ class ModelType(StrEnum): EMBEDDING = auto() +class ModelHostingType(StrEnum): + """Defines where a model is hosted.""" + + THIRD_PARTY = auto() + LISA_HOSTED = auto() + INTERNAL_HOSTED = auto() + + class GuardrailMode(StrEnum): """Defines supported guardrail execution modes.""" @@ -466,6 +474,7 @@ class LISAModel(BaseModel): allowedGroups: list[str] | None = None guardrailsConfig: GuardrailsConfig | None = None contextWindow: int | None = None + hostingType: ModelHostingType | None = ModelHostingType.THIRD_PARTY class ApiResponseBase(BaseModel): @@ -492,6 +501,7 @@ class CreateModelRequest(BaseModel): allowedGroups: list[str] | None = None apiKey: str | None = None guardrailsConfig: GuardrailsConfig | None = None + hostingType: ModelHostingType | None = ModelHostingType.THIRD_PARTY @model_validator(mode="after") def validate_create_model_request(self) -> Self: @@ -513,6 +523,13 @@ def validate_create_model_request(self) -> Self: "autoScalingConfig, containerConfig, inferenceContainer, instanceType, and loadBalancerConfig" ) + if self.hostingType == ModelHostingType.INTERNAL_HOSTED and not self.modelUrl: + raise ValueError("modelUrl is required for INTERNAL_HOSTED models.") + if self.hostingType == ModelHostingType.INTERNAL_HOSTED and self.modelUrl: + parsed_url = urllib.parse.urlparse(self.modelUrl) + if not parsed_url.hostname or not parsed_url.hostname.lower().endswith(".elb.amazonaws.com"): + raise ValueError("modelUrl for INTERNAL_HOSTED models must target an AWS load balancer hostname.") + return self diff --git a/lambda/models/litellm_model_sync.py b/lambda/models/litellm_model_sync.py new file mode 100644 index 000000000..8dfe37685 --- /dev/null +++ b/lambda/models/litellm_model_sync.py @@ -0,0 +1,296 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handler for syncing all models from DynamoDB to LiteLLM. + +This Lambda is triggered when the LiteLLM PostgreSQL database is created or updated, +ensuring all models in the Models DynamoDB table are registered in LiteLLM. + +Note: This module intentionally does NOT import from models.state_machine.create_model +to avoid requiring GUARDRAILS_TABLE_NAME at module load time. +""" + +import json +import logging +import os +from typing import Any + +import boto3 +from models.clients.litellm_client import LiteLLMClient +from models.domain_objects import ModelStatus, ModelType +from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config +from utilities.time import now + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +ddb_resource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) +iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config) +secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) + + +def get_litellm_client() -> LiteLLMClient: + """Create a LiteLLM client with proper authentication.""" + return LiteLLMClient( + base_uri=get_rest_api_container_endpoint(), + verify=get_cert_path(iam_client), + headers={ + "Authorization": secrets_manager.get_secret_value( + SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT" + )["SecretString"], + "Content-Type": "application/json", + }, + ) + + +def build_litellm_params(model_item: dict[str, Any]) -> dict[str, Any]: + """Build LiteLLM params from a DynamoDB model item.""" + model_config = model_item.get("model_config", {}) + model_name = model_config.get("modelName", "") + model_url = model_item.get("model_url", "") + model_type = model_config.get("modelType", "").upper() + inference_container = model_config.get("inferenceContainer", "").lower() + + # Check if this is a video generation model + is_video_model = model_type == ModelType.VIDEOGEN.upper() + + # For video generation models, use empty litellm_settings to avoid drop_params error + litellm_params: dict[str, Any] = {} if is_video_model else {"drop_params": True} + + # Determine if this is a LISA-managed model (has infrastructure) + is_lisa_managed = bool(model_url and model_config.get("autoScalingConfig")) + + if is_lisa_managed: + # Determine the correct LiteLLM provider prefix based on the inference container type + if inference_container == "vllm": + provider_prefix = "hosted_vllm" + else: + provider_prefix = "openai" + # Remove duplicate openai prefixing if present + if model_name.startswith("openai/"): + model_name = model_name[len("openai/") :] + + litellm_params["model"] = f"{provider_prefix}/{model_name}" + litellm_params["api_base"] = model_url if model_url.endswith("/v1") else f"{model_url}/v1" + else: + litellm_params["model"] = model_name + + return litellm_params + + +def sync_model_to_litellm( + litellm_client: LiteLLMClient, model_table: Any, model_item: dict[str, Any], existing_model_names: set[str] +) -> dict[str, Any]: + """Sync a single model to LiteLLM. + + Args: + litellm_client: The LiteLLM client + model_table: The DynamoDB model table + model_item: The model item from DynamoDB + existing_model_names: Set of model names that already exist in LiteLLM + + Returns: + Result dictionary with model_id and status + """ + model_id = model_item.get("model_id", "") + + try: + # Check if model already exists in LiteLLM by name + if model_id in existing_model_names: + logger.info(f"Model {model_id} already exists in LiteLLM, skipping") + return {"model_id": model_id, "status": "skipped", "reason": "already_exists_in_litellm"} + + # Build litellm_params for this model + litellm_params = build_litellm_params(model_item) + + # Add the model to LiteLLM + logger.info(f"Adding model {model_id} to LiteLLM with params: {litellm_params}") + litellm_response = litellm_client.add_model( + model_name=model_id, + litellm_params=litellm_params, + ) + + # Extract the LiteLLM ID from response + if "model_info" in litellm_response and "id" in litellm_response["model_info"]: + litellm_id = litellm_response["model_info"]["id"] + elif "id" in litellm_response: + litellm_id = litellm_response["id"] + elif "model_id" in litellm_response: + litellm_id = litellm_response["model_id"] + else: + logger.warning(f"Could not extract LiteLLM ID from response for model {model_id}: {litellm_response}") + litellm_id = None + + # Update DynamoDB with the litellm_id + if litellm_id: + model_table.update_item( + Key={"model_id": model_id}, + UpdateExpression="SET litellm_id = :lid, last_modified_date = :lm", + ExpressionAttributeValues={ + ":lid": litellm_id, + ":lm": now(), + }, + ) + + logger.info(f"Successfully added model {model_id} to LiteLLM with ID {litellm_id}") + return {"model_id": model_id, "status": "synced", "litellm_id": litellm_id} + + except Exception as e: + logger.error(f"Failed to sync model {model_id} to LiteLLM: {e}", exc_info=True) + return {"model_id": model_id, "status": "failed", "error": str(e)} + + +PHYSICAL_RESOURCE_ID = "LiteLLMModelSync" + + +def _run_sync(force: bool = False) -> dict[str, Any]: + """Run the model sync logic. + + Args: + force: If True, re-sync all IN_SERVICE models regardless of existing litellm_id. + + Returns: + Dictionary with sync summary. + """ + model_table_name = os.environ.get("MODEL_TABLE_NAME") + if not model_table_name: + raise ValueError("MODEL_TABLE_NAME environment variable is not set") + + model_table = ddb_resource.Table(model_table_name) + + # Scan for all models in DynamoDB + logger.info(f"Scanning Models table: {model_table_name}") + models = [] + scan_kwargs: dict[str, Any] = {} + + while True: + response = model_table.scan(**scan_kwargs) + models.extend(response.get("Items", [])) + + if "LastEvaluatedKey" not in response: + break + scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"] + + logger.info(f"Found {len(models)} models in DynamoDB") + + # Filter for models that should be synced (IN_SERVICE status) + # In force mode, re-sync all IN_SERVICE models regardless of existing litellm_id + eligible_models = [] + already_synced = 0 + for m in models: + if m.get("model_status") == ModelStatus.IN_SERVICE: + if force or not m.get("litellm_id"): + eligible_models.append(m) + else: + already_synced += 1 + logger.info(f"Model {m.get('model_id')} already has litellm_id, skipping") + + logger.info(f"Found {len(eligible_models)} models needing sync, {already_synced} already synced") + + if not eligible_models: + logger.info("No eligible models to sync") + return { + "message": "No eligible models to sync", + "total_models": len(models), + "eligible_models": 0, + "already_synced": already_synced, + "synced": 0, + "skipped": 0, + "failed": 0, + } + + # Get existing models from LiteLLM to double-check against duplicates + try: + litellm_client = get_litellm_client() + existing_litellm_models = litellm_client.list_models() + existing_model_names: set[str] = {m.get("model_name", "") for m in existing_litellm_models} + logger.info(f"Found {len(existing_model_names)} existing models in LiteLLM") + except Exception as e: + logger.warning(f"Could not list existing LiteLLM models, proceeding anyway: {e}") + litellm_client = get_litellm_client() # Create client anyway for syncing + existing_model_names = set() + + # Sync each model + results = [] + for model_item in eligible_models: + result = sync_model_to_litellm(litellm_client, model_table, model_item, existing_model_names) + results.append(result) + + # Summarize results + synced = sum(1 for r in results if r["status"] == "synced") + skipped = sum(1 for r in results if r["status"] == "skipped") + failed = sum(1 for r in results if r["status"] == "failed") + + logger.info(f"Sync complete. Synced: {synced}, Skipped: {skipped}, Failed: {failed}") + + return { + "message": "Model sync completed", + "total_models": len(models), + "eligible_models": len(eligible_models), + "already_synced": already_synced, + "synced": synced, + "skipped": skipped, + "failed": failed, + "details": results, + } + + +def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: + """CloudFormation CustomResource handler to sync models from DynamoDB to LiteLLM. + + On Create/Update: Scans the Models DynamoDB table for IN_SERVICE models and + registers any missing ones in LiteLLM. + On Delete: No-op (returns SUCCESS β€” nothing to clean up). + + Supports a 'force' flag via ResourceProperties to re-sync all models + regardless of existing litellm_id. + + Args: + event: CloudFormation CustomResource event + context: Lambda context + + Returns: + CustomResource response dict with PhysicalResourceId, Status, and Data. + """ + request_type = event.get("RequestType", "") + logger.info(f"LiteLLM model sync invoked: RequestType={request_type}") + + # Delete is a no-op β€” nothing to clean up. + # IMPORTANT: Return the *incoming* PhysicalResourceId on Delete so the CDK + # framework doesn't reject the response for changing the physical ID. + if request_type == "Delete": + logger.info("RequestType=Delete: no-op, returning SUCCESS") + physical_id = event.get("PhysicalResourceId", PHYSICAL_RESOURCE_ID) + return {"Status": "SUCCESS", "PhysicalResourceId": physical_id} + + # Create and Update both run the sync + try: + # Check for force flag in ResourceProperties + resource_props = event.get("ResourceProperties", {}) or {} + force = bool(resource_props.get("force", False)) + logger.info(f"Starting LiteLLM model sync. Event: {json.dumps(event)}, force={force}") + + data = _run_sync(force=force) + return { + "Status": "SUCCESS", + "PhysicalResourceId": PHYSICAL_RESOURCE_ID, + "Data": data, + } + except Exception as e: + logger.error(f"LiteLLM model sync failed: {e}", exc_info=True) + return { + "Status": "FAILED", + "PhysicalResourceId": PHYSICAL_RESOURCE_ID, + "Reason": str(e), + } diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 5e7fc22cb..9dc11729d 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -17,6 +17,7 @@ import json import logging import os +import time from copy import deepcopy from datetime import datetime from typing import Any @@ -25,7 +26,14 @@ import boto3 from botocore.config import Config from models.clients.litellm_client import LiteLLMClient -from models.domain_objects import CreateModelRequest, GuardrailsTableEntry, InferenceContainer, ModelStatus, ModelType +from models.domain_objects import ( + CreateModelRequest, + GuardrailsTableEntry, + InferenceContainer, + ModelHostingType, + ModelStatus, + ModelType, +) from models.exception import ( MaxPollsExceededException, StackFailedToCreateException, @@ -615,7 +623,21 @@ def handle_add_model_to_litellm(event: dict[str, Any], context: Any) -> dict[str litellm_params["model"] = f"{provider_prefix}/{model_name}" litellm_params["api_base"] = f"{event['modelUrl']}/v1" # model's OpenAI-compliant route else: - litellm_params["model"] = event["modelName"] + model_name = event["modelName"] + if str(event.get("hostingType", "")).upper() == ModelHostingType.INTERNAL_HOSTED.value.upper(): + # Internal hosted models are registered as OpenAI-compatible providers routed through api_base. + # Normalize common user-entered prefixes so LiteLLM doesn't route via hosted_vllm or external providers. + stripped = True + while stripped: + stripped = False + for prefix in ("openai/", "hosted_vllm/"): + if model_name.startswith(prefix): + model_name = model_name[len(prefix) :] + stripped = True + litellm_params["model"] = f"openai/{model_name}" + litellm_params["api_base"] = str(event["modelUrl"]).rstrip("/") + else: + litellm_params["model"] = model_name litellm_response = litellm_client.add_model( model_name=event["modelId"], @@ -769,14 +791,49 @@ def handle_add_guardrails_to_litellm(event: dict[str, Any], context: Any) -> dic return output_dict -def _fetch_context_window_from_litellm(litellm_id: str) -> Any | None: - """Fetch max_input_tokens from LiteLLM for non-LISA-managed (Bedrock/third-party) models.""" - try: - model_info = litellm_client.get_model(litellm_id) - return int(model_info.get("model_info", {}).get("max_input_tokens")) - except Exception as e: - logger.warning(f"Could not fetch context window from LiteLLM for {litellm_id}: {e}") - return None +def _fetch_context_window_from_litellm( + litellm_id: str, + max_attempts: int = 5, + base_delay: float = 2.0, + backoff_factor: float = 2.0, +) -> Any | None: + """Fetch max_input_tokens from LiteLLM for non-LISA-managed (Bedrock/third-party) models. + + Retries with exponential backoff to handle cases where LiteLLM is queried too + quickly after a model is registered and returns a transient error. + + Args: + litellm_id: The LiteLLM model ID to query. + max_attempts: Maximum number of attempts before giving up (default: 5). + base_delay: Initial delay in seconds between retries (default: 2.0). + backoff_factor: Multiplier applied to delay after each failed attempt (default: 2.0). + + Returns: + The context window size as an int, or None if it could not be determined. + """ + last_exception: Exception | None = None + delay = base_delay + + for attempt in range(1, max_attempts + 1): + try: + model_info = litellm_client.get_model(litellm_id) + return int(model_info.get("model_info", {}).get("max_input_tokens")) + except Exception as e: + last_exception = e + if attempt < max_attempts: + logger.warning( + f"Attempt {attempt}/{max_attempts} failed to fetch context window from LiteLLM " + f"for {litellm_id}: {e}. Retrying in {delay:.1f}s..." + ) + time.sleep(delay) + delay *= backoff_factor + else: + logger.warning( + f"All {max_attempts} attempts exhausted fetching context window from LiteLLM " + f"for {litellm_id}: {last_exception}" + ) + + return None def _fetch_context_window_from_s3(model_name: Any, model_type: str) -> int | None: diff --git a/lambda/models/state_machine/delete_model.py b/lambda/models/state_machine/delete_model.py index 721f24df5..44dac1e05 100644 --- a/lambda/models/state_machine/delete_model.py +++ b/lambda/models/state_machine/delete_model.py @@ -22,6 +22,7 @@ import boto3 from models.clients.litellm_client import LiteLLMClient +from models.state_machine.failure_utils import extract_model_failure_details from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config from utilities.time import now @@ -184,3 +185,27 @@ def handle_delete_from_ddb(event: dict[str, Any], context: Any) -> dict[str, Any model_key = {"model_id": event["modelId"]} ddb_table.delete_item(Key=model_key) return event + + +def handle_failure(event: dict[str, Any], context: Any) -> dict[str, Any]: + """Set model status to Failed for unrecoverable delete workflow errors.""" + logger.error(f"Handling delete-model state machine failure: {event}") + + model_id, error_reason = extract_model_failure_details( + event=event, + default_reason="Delete model state machine failed.", + ) + if not model_id: + logger.error("Unable to determine model id from delete failure event; skipping DDB status update.") + return event + + ddb_table.update_item( + Key={"model_id": model_id}, + UpdateExpression="SET model_status = :ms, last_modified_date = :lmd, failure_reason = :fr", + ExpressionAttributeValues={ + ":ms": ModelStatus.FAILED, + ":lmd": now(), + ":fr": error_reason[:1000], + }, + ) + return event diff --git a/lambda/models/state_machine/failure_utils.py b/lambda/models/state_machine/failure_utils.py new file mode 100644 index 000000000..6c6b606bc --- /dev/null +++ b/lambda/models/state_machine/failure_utils.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for state machine failure-event parsing.""" + +import json +from typing import Any + + +def extract_model_failure_details(event: dict[str, Any], default_reason: str) -> tuple[str | None, str]: + """Extract model id and failure reason from Step Functions catch payloads.""" + raw_error = event.get("error") + catch_error: dict[str, Any] = raw_error if isinstance(raw_error, dict) else {} + cause_payload = event.get("Cause") or catch_error.get("Cause") + + cause_data: dict[str, Any] | None = None + if isinstance(cause_payload, str): + try: + parsed = json.loads(cause_payload) + if isinstance(parsed, dict): + cause_data = parsed + except Exception: + cause_data = None + + model_id = event.get("model_id") or event.get("modelId") + if not model_id and isinstance(cause_data, dict): + model_id = cause_data.get("model_id") or cause_data.get("modelId") + if not model_id: + cause_input = cause_data.get("input") + if isinstance(cause_input, dict): + model_id = cause_input.get("model_id") or cause_input.get("modelId") + + error_reason = default_reason + if isinstance(cause_data, dict): + error_reason = str(cause_data.get("errorMessage", error_reason)) + elif cause_payload is not None: + error_reason = str(cause_payload) + elif "error" in event: + error_reason = str(event.get("error")) + + return model_id, error_reason diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 47122558e..fd30ca281 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -24,6 +24,7 @@ import boto3 from models.clients.litellm_client import LiteLLMClient from models.domain_objects import GuardrailsTableEntry, ModelStatus, ModelType +from models.state_machine.failure_utils import extract_model_failure_details from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config from utilities.time import now @@ -410,12 +411,13 @@ def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: # - vLLM: Use hosted_vllm/ to pass through full model name (e.g., "openai/gpt-oss-20b") # - TGI/TEI: Use openai/ prefix (LiteLLM strips it before sending to backend) model_name = ddb_item["model_config"]["modelName"] - inference_container = ddb_item["model_config"].get("inferenceContainer", "").lower() + inference_container = ddb_item["model_config"].get("inferenceContainer", "") - if inference_container == "vllm": - # vLLM serves models with full HF repo name (e.g., "openai/gpt-oss-20b") - # hosted_vllm/ prefix ensures LiteLLM passes through the complete name - provider_prefix = "hosted_vllm" + if inference_container: + if inference_container.lower() == "vllm": + # vLLM serves models with full HF repo name (e.g., "openai/gpt-oss-20b") + # hosted_vllm/ prefix ensures LiteLLM passes through the complete name + provider_prefix = "hosted_vllm" else: # TGI and TEI use OpenAI-compatible APIs with model name stripping provider_prefix = "openai" @@ -1085,3 +1087,27 @@ def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, output_dict["should_continue_ecs_polling"] = False return output_dict + + +def handle_failure(event: dict[str, Any], context: Any) -> dict[str, Any]: + """Set model status to Failed for any unrecoverable update workflow error.""" + logger.error(f"Handling update-model state machine failure: {event}") + + model_id, error_reason = extract_model_failure_details( + event=event, + default_reason="Update model state machine failed.", + ) + if not model_id: + logger.error("Unable to determine model id from update failure event; skipping DDB status update.") + return event + + model_table.update_item( + Key={"model_id": model_id}, + UpdateExpression="SET model_status = :ms, last_modified_date = :lm, failure_reason = :fr", + ExpressionAttributeValues={ + ":ms": ModelStatus.FAILED, + ":lm": now(), + ":fr": error_reason[:1000], + }, + ) + return event diff --git a/lambda/repository/collection_service.py b/lambda/repository/collection_service.py index f205aedb4..e4f001988 100644 --- a/lambda/repository/collection_service.py +++ b/lambda/repository/collection_service.py @@ -450,6 +450,7 @@ def list_all_user_collections( username: str, user_groups: list[str], is_admin: bool, + is_rag_admin: bool = False, page_size: int = 20, pagination_token: dict[str, Any] | None = None, filter_text: str | None = None, @@ -499,15 +500,30 @@ def list_all_user_collections( logger.info(f"Estimated total collections: {estimated_total}") # Select and execute pagination strategy + effective_admin = is_admin or is_rag_admin if estimated_total > 1000: logger.info("Using scalable pagination strategy for large dataset") collections, next_token = self._paginate_large_collections( - repositories, username, user_groups, is_admin, page_size, pagination_token, filter_text, sort_params + repositories, + username, + user_groups, + effective_admin, + page_size, + pagination_token, + filter_text, + sort_params, ) else: logger.info("Using simple pagination strategy") collections, next_token = self._paginate_collections( - repositories, username, user_groups, is_admin, page_size, pagination_token, filter_text, sort_params + repositories, + username, + user_groups, + effective_admin, + page_size, + pagination_token, + filter_text, + sort_params, ) logger.info(f"Returning {len(collections)} collections") diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index 1f045081b..e422271a4 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -49,7 +49,16 @@ from repository.s3_metadata_manager import S3MetadataManager from repository.services import RepositoryServiceFactory from repository.vector_store_repo import VectorStoreRepository -from utilities.auth import admin_only, get_groups, get_user_context, get_username, is_admin, user_has_group_access +from utilities.auth import ( + admin_only, + get_groups, + get_user_context, + get_username, + is_admin, + is_rag_admin, + rag_admin_or_admin, + user_has_group_access, +) from utilities.bedrock_kb import create_s3_scan_job from utilities.bedrock_kb_discovery import ( build_pipeline_configs_from_kb_config, @@ -208,6 +217,7 @@ def similarity_search(event: dict, context: dict) -> dict[str, Any]: # Get user context for collection access username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) is_default = collection_id is not None and collection_id == repository.get("embeddingModelId") # Determine embedding model @@ -217,7 +227,7 @@ def similarity_search(event: dict, context: dict) -> dict[str, Any]: collection_id=collection_id if not is_default else None, # type: ignore[arg-type] username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, ) if collection_id else query_string_params.get("modelName") # type: ignore[union-attr] @@ -267,7 +277,13 @@ def similarity_search(event: dict, context: dict) -> dict[str, Any]: def get_repository(event: dict[str, Any], repository_id: str) -> dict[str, Any]: - """Ensures a user has access to the repository or else raises an HTTPException.""" + """Ensures a user has access to the repository or else raises an HTTPException. + + Note: RAG admins are intentionally NOT given blanket repository access here. + They must have group membership via allowedGroups. This is the security boundary + that scopes RAG admin operations to their group-accessible repositories. + The @rag_admin_or_admin decorator gates role access; this function gates repo access. + """ repo: dict[str, Any] = vs_repo.find_repository_by_id(repository_id) # Admins have access to all repositories @@ -463,7 +479,7 @@ def create_default_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper -@admin_only +@rag_admin_or_admin def create_collection(event: dict, context: dict) -> dict[str, Any]: """ Create a new collection within a vector store. @@ -561,6 +577,7 @@ def get_collection(event: dict, context: dict) -> dict[str, Any]: # Get user context username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) # Ensure repository exists and user has access repo = get_repository(event, repository_id=repository_id) @@ -583,7 +600,7 @@ def get_collection(event: dict, context: dict) -> dict[str, Any]: collection_id=collection_id, username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, ) if collection is None: @@ -595,7 +612,7 @@ def get_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper -@admin_only +@rag_admin_or_admin def update_collection(event: dict, context: dict) -> dict[str, Any]: """ Update a collection within a vector store. @@ -628,6 +645,7 @@ def update_collection(event: dict, context: dict) -> dict[str, Any]: # Get user context username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) # Ensure repository exists and user has access _ = get_repository(event, repository_id=repository_id) @@ -649,7 +667,7 @@ def update_collection(event: dict, context: dict) -> dict[str, Any]: collection_data=request, username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, ) result: dict[str, Any] = updated_collection.model_dump(mode="json") @@ -657,7 +675,7 @@ def update_collection(event: dict, context: dict) -> dict[str, Any]: @api_wrapper -@admin_only +@rag_admin_or_admin def delete_collection(event: dict, context: dict) -> dict[str, Any]: """ Delete a collection (regular or default) within a vector store. @@ -695,6 +713,7 @@ def delete_collection(event: dict, context: dict) -> dict[str, Any]: # Get user context username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) # Ensure repository exists and user has access repo = get_repository(event, repository_id=repository_id) @@ -707,7 +726,7 @@ def delete_collection(event: dict, context: dict) -> dict[str, Any]: embedding_name=embedding_name if is_default_collection else None, # None for regular collections username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, ) return result @@ -751,6 +770,7 @@ def list_collections(event: dict, context: dict) -> dict[str, Any]: # Get user context username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) # Ensure repository exists and user has access _ = get_repository(event, repository_id=repository_id) @@ -781,7 +801,7 @@ def list_collections(event: dict, context: dict) -> dict[str, Any]: repository_id=repository_id, username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, page_size=page_size, last_evaluated_key=last_evaluated_key, ) @@ -853,6 +873,10 @@ def list_user_collections(event: dict, context: dict) -> dict[str, Any]: HTTPException: If authentication fails """ # Get user context + # RAG admins pass is_rag_admin=True so they get scoped-admin collection access + # within repos they have group access to (bypasses collection-level allowedGroups). + # is_admin remains the real flag so _get_accessible_repositories still filters + # repos by group membership β€” RAG admins do NOT see all repos. username, is_admin, groups = get_user_context(event) logger.info(f"list_user_collections called by user={username}, is_admin={is_admin}") @@ -883,6 +907,7 @@ def list_user_collections(event: dict, context: dict) -> dict[str, Any]: username=username, user_groups=groups, is_admin=is_admin, + is_rag_admin=is_rag_admin(event), page_size=page_size, pagination_token=pagination_token, filter_text=filter_text, @@ -919,9 +944,9 @@ def list_user_collections(event: dict, context: dict) -> dict[str, Any]: def _ensure_document_ownership(event: dict[str, Any], docs: list[RagDocument]) -> None: - """Verify ownership of documents""" + """Verify ownership of documents. Admins and RAG admins can delete any document.""" username = get_username(event) - if is_admin(event) is False: + if not is_admin(event) and not is_rag_admin(event): for doc in docs: if not (doc.username == username): raise ValueError(f"Document {doc.document_id} is not owned by {username}") @@ -1073,6 +1098,7 @@ def ingest_documents(event: dict, context: dict) -> dict: handle_deprecated_chunking_strategy(request, query_params) username, is_admin, groups = get_user_context(event) + effective_admin = is_admin or is_rag_admin(event) repository = get_repository(event, repository_id=repository_id) # Get collection if specified @@ -1083,7 +1109,7 @@ def ingest_documents(event: dict, context: dict) -> dict: repository_id=repository_id, username=username, user_groups=groups, - is_admin=is_admin, + is_admin=effective_admin, ).model_dump() # For Bedrock KB repositories, upload metadata files BEFORE documents @@ -1322,12 +1348,13 @@ def list_jobs(event: dict[str, Any], context: dict) -> dict[str, Any]: # Get user context username, is_admin_user, _ = get_user_context(event) + effective_admin = is_admin_user or is_rag_admin(event) # Fetch jobs from repository jobs, returned_last_evaluated_key = ingestion_job_repository.list_jobs_by_repository( repository_id=params.repository_id, username=username, - is_admin=is_admin_user, + is_admin=effective_admin, time_limit_hours=params.time_limit_hours, page_size=params.page_size, last_evaluated_key=params.last_evaluated_key, @@ -1523,10 +1550,13 @@ def _validate_immutable_pipeline_fields(current_pipelines: list, new_pipelines: @api_wrapper -@admin_only +@rag_admin_or_admin def update_repository(event: dict, context: dict) -> dict[str, Any]: """ - Update a vector store configuration. This function is only accessible by administrators. + Update a vector store configuration. Accessible by administrators and RAG admins (with scoped access). + + Admins can update all fields. RAG admins with group access can only update pipeline-related fields. + RAG admins cannot change allowedGroups or other repository-level settings. If the pipeline configuration has changed, this will trigger an infrastructure deployment using the state machine, similar to repository creation. @@ -1553,13 +1583,23 @@ def update_repository(event: dict, context: dict) -> dict[str, Any]: # Parse request body try: - body = json.loads(event.get("body", {})) + body = json.loads(event.get("body", "{}")) request = UpdateVectorStoreRequest(**body) except json.JSONDecodeError as e: raise ValidationError(f"Invalid JSON in request body: {e}") except Exception as e: raise ValidationError(f"Invalid request: {e}") + # RAG admins: verify group access and restrict to pipeline-only updates + if not is_admin(event) and is_rag_admin(event): + # Verify group access to this repo + _ = get_repository(event, repository_id=repository_id) + # RAG admins can only update pipelines and bedrockKnowledgeBaseConfig + allowed_fields = {"pipelines", "bedrockKnowledgeBaseConfig"} + disallowed = set(body.keys()) - allowed_fields + if disallowed: + raise ForbiddenException(f"RAG admins cannot update the following fields: {', '.join(sorted(disallowed))}") + # Get current repository configuration to check for pipeline changes current_repo = vs_repo.find_repository_by_id(repository_id, raw_config=True) current_config = current_repo.get("config", {}) @@ -1568,6 +1608,12 @@ def update_repository(event: dict, context: dict) -> dict[str, Any]: # Build updates dictionary (only include fields that were provided) updates = request.model_dump(exclude_none=True, mode="json") + # Defense-in-depth: RAG admins can only update pipeline-related fields. + # This filters the serialized model output in case defaults were populated. + if not is_admin(event) and is_rag_admin(event): + allowed_fields = {"pipelines", "bedrockKnowledgeBaseConfig"} + updates = {k: v for k, v in updates.items() if k in allowed_fields} + # Convert bedrockKnowledgeBaseConfig to pipelines for Bedrock KB repositories repository_type = current_config.get("type") if ( diff --git a/lambda/repository/pipeline_ingest_handlers.py b/lambda/repository/pipeline_ingest_handlers.py index 33788854f..4fa6ae259 100644 --- a/lambda/repository/pipeline_ingest_handlers.py +++ b/lambda/repository/pipeline_ingest_handlers.py @@ -80,7 +80,7 @@ def handle_pipeline_ingest_event(event: dict[str, Any], context: Any) -> None: key = detail.get("key", None) if key and key.endswith(".metadata.json"): - logger.warning(f"Metadata file event reached Lambda (should be filtered by EventBridge): {key}") + logger.warning(f"Ignoring Metadata file: {key}") return repository_id = detail.get("repositoryId", None) pipeline_config = detail.get("pipelineConfig", None) diff --git a/lambda/repository/services/opensearch_repository_service.py b/lambda/repository/services/opensearch_repository_service.py index 80fbcd271..4e9dabc1d 100644 --- a/lambda/repository/services/opensearch_repository_service.py +++ b/lambda/repository/services/opensearch_repository_service.py @@ -194,4 +194,5 @@ def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) - use_ssl=True, verify_certs=True, connection_class=RequestsHttpConnection, + engine="faiss", ) diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index c2c36cad5..a658649c9 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -249,6 +249,30 @@ def _map_session( ) +def _strip_context_from_display_text(text: str) -> str: + cleaned = text.strip() + file_context_prefix = "File context:" + rag_context_prefix = "Context from document search:" + context_prefixes = (file_context_prefix, rag_context_prefix) + + if not any(cleaned.startswith(prefix) for prefix in context_prefixes): + return cleaned + + if cleaned.startswith(file_context_prefix): + return "" + + # Older sessions may have merged context + prompt into one text blob. + # Keep only the final user prompt for session list display. + parts = [part.strip() for part in cleaned.split("\n\n") if part.strip()] + if parts: + tail = parts[-1] + if not any(tail.startswith(prefix) for prefix in context_prefixes): + return tail + + lines = [line.strip() for line in cleaned.splitlines() if line.strip()] + return lines[-1] if lines else "" + + def _find_first_human_message(session: dict, user_id: str | None = None) -> str: # Check if session is encrypted if session.get("is_encrypted", False): @@ -274,13 +298,17 @@ def _find_first_human_message(session: dict, user_id: str | None = None) -> str: if msg.get("type") == "human": content = msg.get("content") if isinstance(content, str): - return content + cleaned = _strip_context_from_display_text(content) + if cleaned: + return cleaned elif isinstance(content, list): for item in content: if isinstance(item, dict): text: str = item.get("text", "") - if text and not text.startswith("File context:"): - return text + if text: + cleaned = _strip_context_from_display_text(text) + if cleaned: + return cleaned else: logger.warning(f"Unhandled human message content in session {session.get('sessionId', 'unknown')}") return "" @@ -614,12 +642,19 @@ def put_session(event: dict, context: dict) -> SuccessResponse | dict: # Only publish metrics for non-API-token users (JWT/UI users) if auth_type != "api_token" and "USAGE_METRICS_QUEUE_NAME" in os.environ: + # Extract modelId from the session configuration if available + model_id = None + if configuration and configuration.selectedModel: + model_id = configuration.selectedModel.modelId + metrics_event = MetricsEvent( userId=user_id, sessionId=session_id, messages=session_data.history, userGroups=groups, timestamp=session_data.lastUpdated, + eventType="full", + modelId=model_id, ) sqs_client.send_message( QueueUrl=os.environ["USAGE_METRICS_QUEUE_NAME"], diff --git a/lambda/utilities/audit_logging_utils.py b/lambda/utilities/audit_logging_utils.py new file mode 100644 index 000000000..9046cb730 --- /dev/null +++ b/lambda/utilities/audit_logging_utils.py @@ -0,0 +1,265 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shared helpers for API Gateway audit logging. + +Strict opt-in behavior: +- When disabled, callers must not emit audit logs. +- When enabledPaths are provided, audit only applies when the request path + matches one of the configured prefixes (prefix match with path-boundary). +- When auditAll is true, audit applies to all paths. +""" + +from __future__ import annotations + +import json +import logging +import os +from functools import lru_cache +from typing import Any + +_DEFAULT_MAX_BODY_BYTES = 20_000 + +# Keys we always redact anywhere in the JSON structure. +_SENSITIVE_KEYS = { + "password", + "token", + "secret", + "apikey", + "api_key", + "accesskey", + "privatekey", +} + +# Since `_SENSITIVE_KEYS` is already stored in lowercase, no additional `.lower()` pass is needed. +_SENSITIVE_KEYS_LOWER = _SENSITIVE_KEYS + + +def _env_bool(name: str) -> bool: + value = os.getenv(name, "").strip().lower() + return value in ("1", "true", "yes", "y", "on") + + +def _env_int(name: str, default: int) -> int: + raw = os.getenv(name, "").strip() + if not raw: + return default + try: + return int(raw) + except ValueError: + return default + + +def audit_enabled() -> bool: + return _env_bool("LISA_AUDIT_ENABLED") + + +def audit_all() -> bool: + # auditAll is only meaningful when auditing is enabled. + return audit_enabled() and _env_bool("LISA_AUDIT_AUDIT_ALL") + + +def audit_include_json_body() -> bool: + """ + When false (default), callers must not emit AUDIT_API_GATEWAY_REQUEST_BODY. + CDK sets this only when audit logging is enabled and includeJsonBody is true. + """ + return _env_bool("LISA_AUDIT_INCLUDE_JSON_BODY") + + +def log_audit_event(logger: logging.Logger, event_type: str, payload: dict[str, Any]) -> None: + """ + Emit audit data so it appears in CloudWatch log streams. + + Lambda's configured formatter (see ``setup_root_logging``) only prints the log + ``message`` string, not ``logging`` ``extra`` fields. This helper appends a + compact JSON object after the event type for search and Logs Insights (e.g. + parse the substring after the first space as JSON). + """ + record: dict[str, Any] = {"event_type": event_type, **payload} + try: + serialized = json.dumps(record, default=str, separators=(",", ":"), ensure_ascii=False) + except (TypeError, ValueError): + serialized = json.dumps( + {"event_type": event_type, "serialization_error": True}, + default=str, + separators=(",", ":"), + ) + # Use %-style args so JSON or event_type cannot break formatting if they contain "%". + logger.info("%s %s", event_type, serialized) + + +@lru_cache(maxsize=None) +def _parse_enabled_path_prefixes(raw: str) -> list[str]: + if not raw: + return [] + prefixes = [p.strip() for p in raw.split(",") if p.strip()] + normalized: list[str] = [] + for p in prefixes: + if not p.startswith("/"): + p = f"/{p}" + p = p.rstrip("/") + if p: + normalized.append(p) + return normalized + + +def enabled_path_prefixes() -> list[str]: + raw = os.getenv("LISA_AUDIT_ENABLED_PATH_PREFIXES", "") + return _parse_enabled_path_prefixes(raw) + + +def normalize_request_path(path: str) -> str: + if not path: + return "/" + normalized = path.strip() + if not normalized.startswith("/"): + normalized = f"/{normalized}" + normalized = normalized.rstrip("/") + return normalized if normalized else "/" + + +def strip_first_path_segment(path: str) -> str: + # Converts "/prod/session/123" -> "/session/123" + p = normalize_request_path(path) + parts = [part for part in p.split("/") if part] + if len(parts) <= 1: + return p + return "/" + "/".join(parts[1:]) + + +def _path_starts_with_prefix(path: str, prefix: str) -> bool: + if not prefix: + return False + if prefix == "/": + return True + + if not path.startswith(prefix): + return False + if len(path) == len(prefix): + return True + return path[len(prefix)] == "/" + + +def get_matched_audit_prefix(path: str) -> str | None: + """ + Return the matched prefix (e.g. "/session") when auditing should apply. + + Returns: + - "ALL" when auditAll is enabled + - None when strict opt-in does not match + """ + if not audit_enabled(): + return None + if audit_all(): + return "ALL" + + prefixes = enabled_path_prefixes() + if not prefixes: + return None + + p1 = normalize_request_path(path) + p2 = strip_first_path_segment(p1) + + for prefix in prefixes: + if _path_starts_with_prefix(p1, prefix) or _path_starts_with_prefix(p2, prefix): + return prefix + return None + + +def should_audit_path(path: str) -> bool: + return get_matched_audit_prefix(path) is not None + + +def get_method_and_path_from_method_arn(method_arn: str) -> tuple[str, str]: + """ + Parse execute-api methodArn into (http_method, request_path). + + Example: + arn:aws:execute-api:us-east-1:123:abc123/prod/GET/repository/foo + """ + if not method_arn: + return ("unknown", "/") + try: + # arn:...:apiId/stage/VERB/path... + parts = method_arn.split("/") + # parts[-1] includes part(s) of the resource path; join everything after method. + if len(parts) < 4: + return ("unknown", "/") + http_method = parts[2] if len(parts) > 2 else "unknown" + resource_path = "/".join(parts[3:]) if len(parts) > 3 else "" + return (http_method or "unknown", normalize_request_path(resource_path)) + except Exception: + return ("unknown", "/") + + +def sanitize_json_for_audit(value: Any) -> Any: + """ + Recursively redact sensitive keys from JSON values. + + This is intentionally permissive: it redacts by key name anywhere in the structure. + """ + if isinstance(value, dict): + sanitized: dict[str, Any] = {} + for k, v in value.items(): + key = str(k) + if key.lower() in _SENSITIVE_KEYS_LOWER: + sanitized[key] = "" + else: + sanitized[key] = sanitize_json_for_audit(v) + return sanitized + if isinstance(value, list): + return [sanitize_json_for_audit(v) for v in value] + return value + + +def sanitize_json_body_for_audit(body: Any) -> str: + """ + Convert body into a sanitized JSON string suitable for audit logging. + + Returns placeholder strings for non-JSON or oversized bodies. + """ + max_bytes = _env_int("LISA_AUDIT_MAX_BODY_BYTES", _DEFAULT_MAX_BODY_BYTES) + if body is None: + return "" + if isinstance(body, bytes): + try: + raw = body.decode("utf-8") + except UnicodeDecodeError: + return f"" + elif isinstance(body, str): + raw = body + else: + # Unexpected body type (defensive; should still avoid logging sensitive objects). + return "" + + if not raw: + return "" + + raw_bytes_len = len(raw.encode("utf-8")) + if raw_bytes_len > max_bytes: + return "" + + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return "" + + sanitized = sanitize_json_for_audit(parsed) + try: + return json.dumps(sanitized) + except TypeError: + # Extremely defensive fallback: ensure logs never explode. + return json.dumps(str(sanitized)) diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index b5a8e0670..a6e32182d 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -60,11 +60,31 @@ def is_admin(event: dict) -> bool: return result +def is_rag_admin(event: dict) -> bool: + """Get RAG admin status from event using the configured authorization provider.""" + username = get_username(event) + groups = get_groups(event) + auth_provider = get_authorization_provider() + result = auth_provider.check_rag_admin_access(username, groups) + return result + + def get_user_context(event: dict[str, Any]) -> tuple[str, bool, list[str]]: """Extract user context from event.""" return get_username(event), is_admin(event), get_groups(event) +def get_authorizer(event: Any) -> dict[str, Any]: + """Return the API Gateway Lambda authorizer context dict. + + This is a small shared helper so other parts of the codebase don't need to + re-implement the same defensive extraction logic. + """ + if not isinstance(event, dict): + return {} + return event.get("requestContext", {}).get("authorizer", {}) or {} + + def user_has_group_access(user_groups: list[str], allowed_groups: list[str]) -> bool: """ Check if user has access based on group membership. @@ -96,6 +116,18 @@ def wrapper(event: dict[str, Any], context: dict[str, Any], *args: Any, **kwargs return wrapper +def rag_admin_or_admin(func: Callable) -> Callable: + """Decorator that allows access for users with admin or RAG admin privileges.""" + + @wraps(func) + def wrapper(event: dict[str, Any], context: dict[str, Any], *args: Any, **kwargs: Any) -> Any: + if not is_admin(event) and not is_rag_admin(event): + raise ForbiddenException("User does not have permission to access this resource") + return func(event, context, *args, **kwargs) + + return wrapper + + def get_management_key() -> str: secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) secret_name = secret_name_param["Parameter"]["Value"] diff --git a/lambda/utilities/auth_provider.py b/lambda/utilities/auth_provider.py index 0f28c645d..308e314f3 100644 --- a/lambda/utilities/auth_provider.py +++ b/lambda/utilities/auth_provider.py @@ -45,6 +45,24 @@ def check_admin_access(self, username: str, groups: list[str] | None = None) -> """ pass + @abstractmethod + def check_rag_admin_access(self, username: str, groups: list[str] | None = None) -> bool: + """Check if a user has RAG admin access. + + Parameters + ---------- + username : str + The username to check RAG admin access for + groups : list[str] | None + Optional list of groups the user belongs to (used by group-based providers) + + Returns + ------- + bool + True if user has RAG admin access, False otherwise + """ + pass + @abstractmethod def check_app_access(self, username: str, groups: list[str] | None = None) -> bool: """Check if a user has general application access. @@ -70,7 +88,9 @@ class OIDCAuthorizationProvider(AuthorizationProvider): Uses JWT group claims to determine admin and app access. """ - def __init__(self, admin_group: str | None = None, user_group: str | None = None): + def __init__( + self, admin_group: str | None = None, user_group: str | None = None, rag_admin_group: str | None = None + ): """Initialize the OIDC authorization provider. Parameters @@ -79,9 +99,12 @@ def __init__(self, admin_group: str | None = None, user_group: str | None = None The admin group name. If not provided, uses ADMIN_GROUP env var at check time. user_group : str | None The user group name. If not provided, uses USER_GROUP env var at check time. + rag_admin_group : str | None + The RAG admin group name. If not provided, uses RAG_ADMIN_GROUP env var at check time. """ self._admin_group = admin_group self._user_group = user_group + self._rag_admin_group = rag_admin_group @property def admin_group(self) -> str: @@ -93,6 +116,11 @@ def user_group(self) -> str: """Get user group, reading from env if not explicitly set.""" return self._user_group if self._user_group is not None else os.environ.get("USER_GROUP", "") + @property + def rag_admin_group(self) -> str: + """Get RAG admin group, reading from env if not explicitly set.""" + return self._rag_admin_group if self._rag_admin_group is not None else os.environ.get("RAG_ADMIN_GROUP", "") + def check_admin_access(self, username: str, groups: list[str] | None = None) -> bool: """Check if user has admin access based on group membership. @@ -116,6 +144,19 @@ def check_admin_access(self, username: str, groups: list[str] | None = None) -> logger.info(f"User groups: {groups} and admin: {self.admin_group}") return is_admin + def check_rag_admin_access(self, username: str, groups: list[str] | None = None) -> bool: + """Check if user has RAG admin access based on group membership.""" + if not self.rag_admin_group: + return False + + if not groups: + logger.debug(f"No groups provided for user {username}") + return False + + is_rag_admin = self.rag_admin_group in groups + logger.info(f"User groups: {groups} and rag_admin: {self.rag_admin_group}") + return is_rag_admin + def check_app_access(self, username: str, groups: list[str] | None = None) -> bool: """Check if user has app access based on group membership. diff --git a/lambda/utilities/fastapi_middleware/request_logging_middleware.py b/lambda/utilities/fastapi_middleware/request_logging_middleware.py index 105cfe76d..9e96f009b 100644 --- a/lambda/utilities/fastapi_middleware/request_logging_middleware.py +++ b/lambda/utilities/fastapi_middleware/request_logging_middleware.py @@ -16,10 +16,18 @@ import json import logging +import os import time from typing import Any from starlette.middleware.base import BaseHTTPMiddleware, Request, RequestResponseEndpoint, Response +from utilities.audit_logging_utils import ( + audit_include_json_body, + get_matched_audit_prefix, + log_audit_event, + sanitize_json_body_for_audit, +) +from utilities.auth import get_authorizer from utilities.header_sanitizer import sanitize_headers logger = logging.getLogger(__name__) @@ -62,6 +70,37 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - # Build sanitized request data for logging log_data = self._build_log_data(request, event) + # Optional audit payload logging (strict opt-in by API Gateway path prefix). + # For Mangum apps, the FastAPI route path may not include the API Gateway base path, + # so we allow the deployer to provide it via LISA_AUDIT_API_GATEWAY_BASE_PATH. + base_path = os.getenv("LISA_AUDIT_API_GATEWAY_BASE_PATH", "") + full_path = request.url.path + if base_path: + full_path = f"{base_path.rstrip('/')}{request.url.path}" + + audit_prefix = get_matched_audit_prefix(full_path) + if audit_prefix and audit_include_json_body(): + authorizer = get_authorizer(event) + body_bytes = await request.body() + if not body_bytes: + # No payload to audit. + body_bytes = b"" + sanitized_body = sanitize_json_body_for_audit(body_bytes) + if sanitized_body: + log_audit_event( + logger, + "AUDIT_API_GATEWAY_REQUEST_BODY", + { + "area": audit_prefix, + "action": f"{request.method} {full_path}", + "user": { + "username": authorizer.get("username", "unknown"), + "auth_type": authorizer.get("authType", "unknown"), + }, + "body": sanitized_body, + }, + ) + # Log the incoming request logger.info( f"Request: {request.method} {request.url.path}", @@ -100,8 +139,8 @@ def _build_log_data(self, request: Request, event: dict[str, Any]) -> dict[str, Dictionary with sanitized request data for logging """ # Extract request context - request_context = event.get("requestContext", {}) - authorizer = request_context.get("authorizer", {}) + request_context = event.get("requestContext", {}) if isinstance(event, dict) else {} + authorizer = get_authorizer(event) identity = request_context.get("identity", {}) # Sanitize headers (redact auth, replace user-controlled headers) diff --git a/lambda/utilities/lambda_decorators.py b/lambda/utilities/lambda_decorators.py index 73f041de6..ad79c4dfd 100644 --- a/lambda/utilities/lambda_decorators.py +++ b/lambda/utilities/lambda_decorators.py @@ -20,6 +20,12 @@ from contextvars import ContextVar from typing import Any, overload +from utilities.audit_logging_utils import ( + audit_include_json_body, + get_matched_audit_prefix, + log_audit_event, + sanitize_json_body_for_audit, +) from utilities.event_parser import sanitize_event_for_logging from utilities.input_validation import DEFAULT_MAX_REQUEST_SIZE, validate_input from utilities.response_builder import generate_exception_response, generate_html_response @@ -106,6 +112,35 @@ def wrapper(event: dict[Any, Any], context: Any) -> dict[Any, Any]: sanitized_event = sanitize_event_for_logging(event) logger.info(f"Lambda {lambda_func_name}({code_func_name}) invoked with {sanitized_event}") + # Optional audit payload logging (strict opt-in by path prefix + includeJsonBody). + audit_prefix = get_matched_audit_prefix(event.get("path", "") or "") + if audit_prefix and audit_include_json_body(): + body = event.get("body") + if not body: + # No payload to audit. + body = None + if body is not None: + http_method = event.get("httpMethod", "unknown") + path = event.get("path", "") or "/" + + # API Gateway can pass non-JSON bodies; sanitize_json_body_for_audit handles placeholders. + sanitized_body = sanitize_json_body_for_audit(body) + authorizer = event.get("requestContext", {}).get("authorizer", {}) or {} + + log_audit_event( + logger, + "AUDIT_API_GATEWAY_REQUEST_BODY", + { + "area": audit_prefix, + "action": f"{http_method} {path}", + "user": { + "username": authorizer.get("username", "unknown"), + "auth_type": authorizer.get("authType", "unknown"), + }, + "body": sanitized_body, + }, + ) + try: result = f(event, context) return generate_html_response(200, result) diff --git a/lambda/utilities/response_builder.py b/lambda/utilities/response_builder.py index c6462bb24..dffd918e4 100644 --- a/lambda/utilities/response_builder.py +++ b/lambda/utilities/response_builder.py @@ -115,7 +115,7 @@ def generate_html_response(status_code: int, response_body: Any) -> dict[str, st "Content-Type": "application/json", "Cache-Control": "no-store, no-cache", "Pragma": "no-cache", - "Strict-Transport-Security": "max-age:47304000; includeSubDomains", + "Strict-Transport-Security": "max-age=47304000; includeSubDomains", "X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", }, diff --git a/lib/api-base/auditEnv.ts b/lib/api-base/auditEnv.ts new file mode 100644 index 000000000..e0556fe36 --- /dev/null +++ b/lib/api-base/auditEnv.ts @@ -0,0 +1,48 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Config } from '../schema'; + +export const LISA_AUDIT_ENABLED = 'LISA_AUDIT_ENABLED'; +export const LISA_AUDIT_AUDIT_ALL = 'LISA_AUDIT_AUDIT_ALL'; +export const LISA_AUDIT_ENABLED_PATH_PREFIXES = 'LISA_AUDIT_ENABLED_PATH_PREFIXES'; +export const LISA_AUDIT_MAX_BODY_BYTES = 'LISA_AUDIT_MAX_BODY_BYTES'; +export const LISA_AUDIT_INCLUDE_JSON_BODY = 'LISA_AUDIT_INCLUDE_JSON_BODY'; +export const LISA_AUDIT_API_GATEWAY_BASE_PATH = 'LISA_AUDIT_API_GATEWAY_BASE_PATH'; + +function normalizePrefix (prefix: string): string { + const trimmed = prefix.trim(); + if (!trimmed) return ''; + const withLeading = trimmed.startsWith('/') ? trimmed : `/${trimmed}`; + return withLeading.replace(/\/+$/, ''); +} + +export function getAuditLoggingEnv (config: Config): Record { + const audit = config.auditLoggingConfig; + const enabled = audit?.enabled ?? false; + const all = enabled && (audit?.auditAll ?? false); + const enabledPaths = (audit?.enabledPaths ?? []).map(normalizePrefix).filter(Boolean); + const maxBytes = audit?.maxRequestBodyBytes ?? 20000; + const includeJsonBody = enabled && (audit?.includeJsonBody ?? false); + + return { + [LISA_AUDIT_ENABLED]: String(enabled), + [LISA_AUDIT_AUDIT_ALL]: String(all), + [LISA_AUDIT_ENABLED_PATH_PREFIXES]: enabledPaths.join(','), + [LISA_AUDIT_MAX_BODY_BYTES]: String(maxBytes), + [LISA_AUDIT_INCLUDE_JSON_BODY]: String(includeJsonBody), + }; +} diff --git a/lib/api-base/authorizer.ts b/lib/api-base/authorizer.ts index 4f522c33c..2e1380958 100644 --- a/lib/api-base/authorizer.ts +++ b/lib/api-base/authorizer.ts @@ -29,6 +29,7 @@ import { Vpc } from '../networking/vpc'; import { getPythonRuntime } from './utils'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { LAMBDA_PATH } from '../util'; +import { getAuditLoggingEnv } from './auditEnv'; /** * Properties for RestApiGateway Construct. @@ -92,9 +93,11 @@ export class CustomAuthorizer extends Construct { AUTHORITY: config.authConfig!.authority, ADMIN_GROUP: config.authConfig!.adminGroup, USER_GROUP: config.authConfig!.userGroup, + RAG_ADMIN_GROUP: config.authConfig!.ragAdminGroup, JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, MANAGEMENT_KEY_NAME: managementKeySecretName, - ...(tokenTable ? { TOKEN_TABLE_NAME: tokenTable?.tableName } : {}) + ...(tokenTable ? { TOKEN_TABLE_NAME: tokenTable?.tableName } : {}), + ...getAuditLoggingEnv(config), }, role: role, vpc: vpc.vpc, diff --git a/lib/api-base/ecsCluster.ts b/lib/api-base/ecsCluster.ts index e6b7dc35e..d11c1c9be 100644 --- a/lib/api-base/ecsCluster.ts +++ b/lib/api-base/ecsCluster.ts @@ -205,7 +205,7 @@ export class ECSCluster extends Construct { const cluster = new Cluster(this, createCdkId([config.deploymentName, config.deploymentStage, 'Cl']), { clusterName: createCdkId([config.deploymentName, config.deploymentStage, identifier], 32, 2), vpc: vpc.vpc, - containerInsightsV2: !config.region.includes('iso') ? ContainerInsights.ENABLED : ContainerInsights.DISABLED, + containerInsightsV2: ContainerInsights.ENHANCED, }); const asgSecurityGroup = new SecurityGroup(this, 'RestAsgSecurityGroup', { diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index 20e911a8f..cad18ed58 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -109,6 +109,7 @@ export class FastApiContainer extends Construct { CLIENT_ID: config.authConfig!.clientId, ADMIN_GROUP: config.authConfig!.adminGroup, USER_GROUP: config.authConfig!.userGroup, + RAG_ADMIN_GROUP: config.authConfig!.ragAdminGroup, JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, MANAGEMENT_KEY_NAME: managementKeyName }; diff --git a/lib/api-base/utils.ts b/lib/api-base/utils.ts index 94974c9fd..5cf1e66b1 100644 --- a/lib/api-base/utils.ts +++ b/lib/api-base/utils.ts @@ -125,7 +125,7 @@ export function registerAPIEndpoint ( vpc: vpc.vpc, securityGroups, vpcSubnets: vpc.subnetSelection, - logRetention: RetentionDays.ONE_MONTH, + logRetention: RetentionDays.ONE_MONTH }); } diff --git a/lib/api-tokens/api-tokens.ts b/lib/api-tokens/api-tokens.ts index 823f0a45c..a83c55e21 100644 --- a/lib/api-tokens/api-tokens.ts +++ b/lib/api-tokens/api-tokens.ts @@ -28,6 +28,7 @@ import { Vpc } from '../networking/vpc'; import { createLambdaRole } from '../core/utils'; import { LAMBDA_PATH } from '../util'; import { Table } from 'aws-cdk-lib/aws-dynamodb'; +import { getAuditLoggingEnv, LISA_AUDIT_API_GATEWAY_BASE_PATH } from '../api-base/auditEnv'; /** * Properties for ApiTokensApi Construct. @@ -95,6 +96,8 @@ export class ApiTokensApi extends Construct { TOKEN_TABLE_NAME: tokenTable.tableName, ADMIN_GROUP: config.authConfig?.adminGroup || '', API_GROUP: config.authConfig?.apiGroup || '', + ...getAuditLoggingEnv(config), + [LISA_AUDIT_API_GATEWAY_BASE_PATH]: '/api-tokens', }; // Create Lambda role with DynamoDB permissions diff --git a/lib/chat/api/chat-assistant-stacks-api.ts b/lib/chat/api/chat-assistant-stacks-api.ts index 857276cb1..5ea3b3525 100644 --- a/lib/chat/api/chat-assistant-stacks-api.ts +++ b/lib/chat/api/chat-assistant-stacks-api.ts @@ -21,6 +21,7 @@ import { AttributeType, BillingMode, Table } from 'aws-cdk-lib/aws-dynamodb'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { IRole } from 'aws-cdk-lib/aws-iam'; import { IAuthorizer, RestApi } from 'aws-cdk-lib/aws-apigateway'; import { Vpc } from '../../networking/vpc'; @@ -75,6 +76,7 @@ export class ChatAssistantStacksApi extends Construct { const environment = { CHAT_ASSISTANT_STACKS_TABLE_NAME: this.stacksTable.tableName, ADMIN_GROUP: config.authConfig?.adminGroup || '', + ...getAuditLoggingEnv(config), }; const apis: PythonLambdaFunction[] = [ diff --git a/lib/chat/api/configuration.ts b/lib/chat/api/configuration.ts index 1e8fcec50..d8d78f6df 100644 --- a/lib/chat/api/configuration.ts +++ b/lib/chat/api/configuration.ts @@ -24,6 +24,7 @@ import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { BaseProps } from '../../schema'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { Vpc } from '../../networking/vpc'; import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resources'; import { IRole } from 'aws-cdk-lib/aws-iam'; @@ -118,12 +119,14 @@ export class ConfigurationApi extends Construct { 'editChatHistoryBuffer': { 'BOOL': 'True' }, 'editNumOfRagDocument': { 'BOOL': 'True' }, 'uploadRagDocs': { 'BOOL': config.deployRag ? 'True' : 'False' }, + 'ragSelectionAvailable': { 'BOOL': config.deployRag ? 'True' : 'False' }, 'uploadContextDocs': { 'BOOL': 'True' }, 'documentSummarization': { 'BOOL': 'True' }, 'showRagLibrary': { 'BOOL': config.deployRag ? 'True' : 'False' }, 'showMcpWorkbench': { 'BOOL': config.deployMcpWorkbench ? 'True' : 'False' }, 'showPromptTemplateLibrary': { 'BOOL': 'True' }, 'mcpConnections': { 'BOOL': config.deployMcp ? 'True' : 'False' }, + 'awsSessions': { 'BOOL': 'False' }, 'modelLibrary': { 'BOOL': 'True' }, 'encryptSession': { 'BOOL': 'False' }, 'chatAssistantStacks': { 'BOOL': 'False' }, @@ -154,12 +157,20 @@ export class ConfigurationApi extends Construct { const fastApiEndpoint = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/serve/endpoint`); - let environment = { + let environment: Record = { CONFIG_TABLE_NAME: this.configTable.tableName, FASTAPI_ENDPOINT: fastApiEndpoint, ADMIN_GROUP: config.authConfig?.adminGroup || '', + ...getAuditLoggingEnv(config), }; + if (config.deployMcpWorkbench) { + environment.MCP_WORKBENCH_ENDPOINT = StringParameter.valueForStringParameter( + this, + `${config.deploymentPrefix}/mcpWorkbench/endpoint`, + ); + } + if (mcpApi) { this.createMcpApiTable(mcpApi, lambdaRole, environment); } diff --git a/lib/chat/api/mcp.ts b/lib/chat/api/mcp.ts index 04845cda5..3a1b8ad35 100644 --- a/lib/chat/api/mcp.ts +++ b/lib/chat/api/mcp.ts @@ -25,6 +25,7 @@ import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { BaseProps } from '../../schema'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { Vpc } from '../../networking/vpc'; import { LAMBDA_PATH } from '../../util'; import { RemovalPolicy } from 'aws-cdk-lib'; @@ -116,6 +117,7 @@ export class McpApi extends Construct { ADMIN_GROUP: config.authConfig?.adminGroup || '', MCP_SERVERS_TABLE_NAME: mcpServersTable.tableName, MCP_SERVERS_BY_OWNER_INDEX_NAME: byOwnerIndex, + ...getAuditLoggingEnv(config), }; // Create API Lambda functions diff --git a/lib/chat/api/projects.ts b/lib/chat/api/projects.ts index 396e6e2de..416ca7446 100644 --- a/lib/chat/api/projects.ts +++ b/lib/chat/api/projects.ts @@ -26,6 +26,7 @@ import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { BaseProps } from '../../schema'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { Vpc } from '../../networking/vpc'; import { LAMBDA_PATH } from '../../util'; @@ -80,6 +81,7 @@ export class ProjectsApi extends Construct { SESSIONS_TABLE_NAME: sessionTable.tableName, SESSIONS_BY_USER_ID_INDEX_NAME: 'byUserId', CONFIG_TABLE_NAME: configTable.tableName, + ...getAuditLoggingEnv(config), }; const lambdaRole: IRole = createLambdaRole( diff --git a/lib/chat/api/prompt-template-api.ts b/lib/chat/api/prompt-template-api.ts index 6f8296b0b..c23b8ddd4 100644 --- a/lib/chat/api/prompt-template-api.ts +++ b/lib/chat/api/prompt-template-api.ts @@ -21,6 +21,7 @@ import { AttributeType, BillingMode, ProjectionType, Table } from 'aws-cdk-lib/a import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { IRole } from 'aws-cdk-lib/aws-iam'; import { IAuthorizer, RestApi } from 'aws-cdk-lib/aws-apigateway'; import { Vpc } from '../../networking/vpc'; @@ -107,6 +108,7 @@ export class PromptTemplateApi extends Construct { ADMIN_GROUP: config.authConfig?.adminGroup || '', PROMPT_TEMPLATES_TABLE_NAME: promptTemplatesTable.tableName, PROMPT_TEMPLATES_BY_LATEST_INDEX_NAME: byOwnerIndexName, + ...getAuditLoggingEnv(config), }; const apis: PythonLambdaFunction[] = [ diff --git a/lib/chat/api/session.ts b/lib/chat/api/session.ts index d5bc137a2..d8261c206 100644 --- a/lib/chat/api/session.ts +++ b/lib/chat/api/session.ts @@ -26,6 +26,7 @@ import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { BaseProps, RemovalPolicy } from '../../schema'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { Vpc } from '../../networking/vpc'; import { LAMBDA_PATH } from '../../util'; @@ -142,6 +143,7 @@ export class SessionApi extends Construct { CONFIG_TABLE_NAME: configTable.tableName, SESSION_ENCRYPTION_KEY_ARN: sessionEncryptionKey.keyArn, ...(projectsTableName ? { PROJECTS_TABLE_NAME: projectsTableName } : {}), + ...getAuditLoggingEnv(config), }; const lambdaRole: IRole = createLambdaRole( diff --git a/lib/chat/api/user-preferences.ts b/lib/chat/api/user-preferences.ts index 2091d63f3..8fcc21e24 100644 --- a/lib/chat/api/user-preferences.ts +++ b/lib/chat/api/user-preferences.ts @@ -25,6 +25,7 @@ import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; import { BaseProps } from '../../schema'; import { createLambdaRole } from '../../core/utils'; +import { getAuditLoggingEnv } from '../../api-base/auditEnv'; import { Vpc } from '../../networking/vpc'; import { LAMBDA_PATH } from '../../util'; import { RemovalPolicy } from 'aws-cdk-lib'; @@ -88,7 +89,8 @@ export class UserPreferencesApi extends Construct { }); const env = { - USER_PREFERENCES_TABLE_NAME: userPreferencesTable.tableName + USER_PREFERENCES_TABLE_NAME: userPreferencesTable.tableName, + ...getAuditLoggingEnv(config), }; // Create API Lambda functions diff --git a/lib/core/apiBaseConstruct.ts b/lib/core/apiBaseConstruct.ts index 725ff54ca..9b99dc1c5 100644 --- a/lib/core/apiBaseConstruct.ts +++ b/lib/core/apiBaseConstruct.ts @@ -41,11 +41,12 @@ import { LAMBDA_PATH } from '../util'; import { getPythonRuntime } from '../api-base/utils'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { EventBus } from 'aws-cdk-lib/aws-events'; -import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods, IBucket } from 'aws-cdk-lib/aws-s3'; export type LisaApiBaseProps = { vpc: Vpc; securityGroups: ISecurityGroup[]; + bucketAccessLogsBucket: IBucket; } & BaseProps & StackProps; @@ -66,12 +67,7 @@ export class LisaApiBaseConstruct extends Construct { constructor (scope: Stack, id: string, props: LisaApiBaseProps) { super(scope, id); - const { config, vpc, securityGroups } = props; - - // Get bucket access logs bucket - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) - ); + const { bucketAccessLogsBucket, config, vpc, securityGroups } = props; // Create Images S3 bucket for generated images and videos // This is created in API Base stack so it's available to both Chat and Serve stacks diff --git a/lib/core/coreConstruct.ts b/lib/core/coreConstruct.ts index 6f1559996..f713dff72 100644 --- a/lib/core/coreConstruct.ts +++ b/lib/core/coreConstruct.ts @@ -34,6 +34,7 @@ export type CoreStackProps = BaseProps & StackProps; * Creates Lambda layers */ export class CoreConstruct extends Construct { + public readonly loggingBucket: Bucket; /** * @param {Construct} scope - The parent or owner of the construct. * @param {string} id - The unique identifier for the construct within its scope. @@ -42,7 +43,7 @@ export class CoreConstruct extends Construct { super(scope, id); const { config } = props; - const loggingBucket = new Bucket(scope, 'BucketAccessLogsBucket', { + this.loggingBucket = new Bucket(scope, 'BucketAccessLogsBucket', { removalPolicy: config.removalPolicy, autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, bucketName: ([config.deploymentName, config.accountNumber, config.deploymentStage, 'bucket', 'access', 'logs'].join('-')).toLowerCase(), @@ -54,7 +55,7 @@ export class CoreConstruct extends Construct { new StringParameter(scope, 'LISABucketAccessLogsBucket', { parameterName: `${config.deploymentPrefix}/bucket/bucket-access-logs`, - stringValue: loggingBucket.bucketArn, + stringValue: this.loggingBucket.bucketArn, description: 'A bucket for access logs from other buckets to be written to.', }); diff --git a/lib/core/index.ts b/lib/core/index.ts index 28f4813fc..93beed351 100644 --- a/lib/core/index.ts +++ b/lib/core/index.ts @@ -17,6 +17,7 @@ import { Construct } from 'constructs'; import { Stack } from 'aws-cdk-lib'; +import { IBucket } from 'aws-cdk-lib/aws-s3'; import { CoreConstruct, CoreStackProps } from './coreConstruct'; export * from './coreConstruct'; @@ -27,6 +28,7 @@ export * from './apiDeploymentConstruct'; * Creates Lambda layers */ export class CoreStack extends Stack { + public readonly loggingBucket: IBucket; /** * @param {Construct} scope - The parent or owner of the construct. * @param {string} id - The unique identifier for the construct within its scope. @@ -34,6 +36,8 @@ export class CoreStack extends Stack { constructor (scope: Construct, id: string, props: CoreStackProps) { super(scope, id, props); - (new CoreConstruct(this, id + 'Resources', props)).node.addMetadata('aws:cdk:path', this.node.path); + const core = new CoreConstruct(this, id + 'Resources', props); + core.node.addMetadata('aws:cdk:path', this.node.path); + this.loggingBucket = core.loggingBucket; } } diff --git a/lib/core/layers/common/requirements.txt b/lib/core/layers/common/requirements.txt index 211039264..00a0e2d29 100644 --- a/lib/core/layers/common/requirements.txt +++ b/lib/core/layers/common/requirements.txt @@ -4,3 +4,4 @@ psycopg2-binary==2.9.11 cachetools==7.0.1 requests==2.32.5 +pydantic>=2.5.0,<3.0.0 diff --git a/lib/core/layers/index.ts b/lib/core/layers/index.ts index b47fc284d..5635b10e3 100644 --- a/lib/core/layers/index.ts +++ b/lib/core/layers/index.ts @@ -175,7 +175,7 @@ export class NodeLayer extends Construct { // Install dependencies console.log(`Building Node.js layer: ${id} at ${layerPath}`); - execSync('npm install --omit=dev --production', { + execSync('npm install --omit=dev', { cwd: nodejsDir, stdio: 'inherit', }); diff --git a/lib/docs/.vitepress/config.mts b/lib/docs/.vitepress/config.mts index b47af5ed8..573178edb 100644 --- a/lib/docs/.vitepress/config.mts +++ b/lib/docs/.vitepress/config.mts @@ -92,7 +92,10 @@ const navLinks = [ { text: 'LISA Chat UI', link: '/user/chat' }, { text: 'Document Library Management', link: '/user/document-library' }, { text: 'Model Library', link: '/user/model-library' }, - { text: 'Breaking Changes', link: '/user/breaking-changes' }, + { text: 'Prompt Template Library', link: '/user/prompt-template-library' }, + { text: 'Session History', link: '/config/session' }, + { text: 'User Preferences', link: '/config/user-preferences' }, + { text: 'Breaking Changes', link: '/config/breaking-changes' }, { text: 'Change Log', link: 'https://github.com/awslabs/LISA/releases' }, ], }, @@ -103,19 +106,18 @@ const navLinks = [ { text: 'Chat Assistant Stacks', link: '/config/chat-assistant-stacks#api-reference' }, { text: 'Collection Management (Repository)', link: '/config/collection-management-api#endpoints' }, { text: 'Bedrock Guardrails', link: '/config/guardrails#managing-guardrails-via-lisa-models-api' }, - { text: 'Hosted MCP Servers (/mcp)', link: '/config/hosted-mcp#api-operations' }, + { text: 'Hosted MCP Servers', link: '/config/hosted-mcp#api-operations' }, { text: 'Metrics', link: '/admin/api-overview#metrics-api-gateway-endpoints' }, - { text: 'Model Management (/models)', link: '/config/model-management-api#listing-models-admin-api' }, - { text: 'Project Organization (/project)', link: '/config/projects#api-reference' }, + { text: 'Model Management', link: '/config/model-management-api#listing-models-admin-api' }, + { text: 'Project Organization', link: '/config/projects#api-reference' }, { text: 'RAG Repository', link: '/config/repositories#configuration-examples' }, - // TODO: Add API documentation for the following APIs - // { text: 'MCP Workbench', link: '/config/mcp-workbench#programmatic-api-access' }, - // { text: 'Bedrock Knowledge Base (/bedrock-kb)', link: '/config/TODO-bedrock-kb#api-reference' }, - // { text: 'MCP Server Connections (/mcp-server)', link: '/config/TODO-mcp-server#api-reference' }, - // { text: 'MCP Workbench tool management REST API (/mcp-workbench)', link: '/config/TODO-mcp-workbench#api-reference' }, - // { text: 'Prompt Templates (/prompt-templates)', link: '/config/TODO-prompt-templates#api-reference' }, - // { text: 'Session (/session)', link: '/config/TODO-session#api-reference' }, - // { text: 'User Preferences (/user-preferences)', link: '/config/TODO-user-preferences#api-reference' }, + { text: 'MCP Workbench', link: '/config/mcp-workbench#programmatic-api-access' }, + { text: 'Bedrock Knowledge Base', link: '/config/repositories#bedrock-knowledge-base-api-reference' }, + { text: 'MCP Server Connections', link: '/config/mcp#api-reference' }, + { text: 'MCP Workbench', link: '/config/mcp-workbench#api-reference' }, + { text: 'Prompt Templates', link: '/config/prompt-templates#api-reference' }, + { text: 'Session', link: '/config/session#api-reference' }, + { text: 'User Preferences', link: '/config/user-preferences#api-reference' }, ], }, ]; @@ -131,6 +133,11 @@ export default defineConfig({ markdown: { config(md) { md.use(tabsMarkdownPlugin) + const defaultRender = md.render.bind(md); + md.render = (src, env) => { + src = src.replace(/Array<([^>]+)>/g, 'Array<$1>'); + return defaultRender(src, env); + }; }, }, // https://vitepress.dev/reference/default-theme-config diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index dd74939c6..5b9c1dac6 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -1,4 +1,5 @@ # Deployment + ## Prerequisites * Set up or have access to an AWS account. @@ -21,6 +22,22 @@ > To minimize version conflicts and ensure a consistent deployment environment, we recommend executing the following steps on a dedicated EC2 instance. However, LISA can be deployed from any machine that meets the prerequisites listed above. ## Deployment Steps + + +LISA uses npm scripts for build and deployment. Key commands: + +| Task | Command | +|------|---------| +| Install Python & TypeScript deps | `npm run install:python` then `npm install` | +| Stage model weights | `npm run model:check` | +| Bootstrap CDK | `npm run bootstrap` | +| Deploy (full pipeline) | `npm run deploy` | +| Build archive (ADC pre-build) | `npm run build:archive` | +| List CDK stacks | `npm run cdk:list` | + +The `npm run deploy` script runs the full pipeline: install dependencies, Docker checks, ECR login, model verification, build, and CDK deploy. Use `STACK= npm run deploy` to deploy specific stacks. + + ### Step 1: Clone the Repository Ensure you're working with the latest stable release of LISA: @@ -30,11 +47,14 @@ git clone -b main --single-branch cd lisa ``` -### Step 2a: Create/Configure `config-custom.yaml`: +### Step 2a: Create/Configure `config-custom.yaml` + Run the command below to copy the example configuration into `config-custom.yaml`. This will create the file if it doesn't exist already. + ```bash cp example_config.yaml config-custom.yaml ``` + Review the `config-custom.yaml` settings. Some settings will be configured later in this guide. ### Step 2b: Set Up Environment Variables @@ -50,8 +70,11 @@ export CDK_DOCKER=finch # Optional, only required if not using docker as contain ### Step 3: Set Up Python and TypeScript Environments -Install system dependencies and set up both Python and TypeScript environments: + - ***NOTE** The code block below has two tabs for Debian & EL/AL2* +Install system dependencies and set up both Python and TypeScript environments using the project's npm scripts: + +* ***NOTE** The code block below has two tabs for Debian & EL/AL2* :::tabs == Debian @@ -61,19 +84,16 @@ Install system dependencies and set up both Python and TypeScript environments: sudo apt-get update sudo apt-get install -y jq -# Install Python packages +# Install Python packages (for model staging) pip3 install --user --upgrade pip pip3 install yq huggingface_hub s5cmd -# Set up Python environment -make createPythonEnvironment && source .venv/bin/activate - -# Install Python Requirements -make installPythonRequirements +# Create and activate Python virtual environment +python3 -m venv .venv && source .venv/bin/activate -# Set up TypeScript environment -make createTypeScriptEnvironment -make installTypeScriptRequirements +# Install Python and TypeScript dependencies via npm scripts +npm run install:python +npm install ``` == EL / AL2 @@ -85,46 +105,83 @@ sudo yum update -y && yum install -y git jq yq # Install runtimes (use mise for installation - https://mise.jdx.dev/installing-mise.html) mise use --global python@3.13 node@24 -# Install Python packages +# Install Python packages (for model staging) pip3 install --user --upgrade pip pip3 install yq huggingface_hub s5cmd -# Set up Python environment -make createPythonEnvironment && source .venv/bin/activate +# Create and activate Python virtual environment +python3 -m venv .venv && source .venv/bin/activate + +# Install Python and TypeScript dependencies via npm scripts +npm run install:python +npm install +``` + -# Install Python Requirements -make installPythonRequirements +== MacOS -# Set up TypeScript environment -make createTypeScriptEnvironment -make installTypeScriptRequirements +```bash +# 0) Install Homebrew if not installed +/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" + +# 1) Install core tools +brew update +brew install git jq yq s5cmd + +# 2) Install and activate mise for zsh +curl https://mise.run | sh +echo 'eval "$(~.local/bin/mise activate zsh)"' >> ~/.zshrc +source ~/.zshrc + +# 3) Install runtimes +mise use --global python@3.13 node@24 + +# 4) Verify you are using mise versions +which python +which node +python --version +node --version + + +# 5) Create and activate Python virtual environment +python3 -m venv .venv +source .venv/bin/activate + +# 6) Upgrade pip and install model-staging tools +python -m pip install --upgrade pip +python -m pip install huggingface_hub yq + +# 7) Install Python and TypeScript dependencies via npm scripts +npm run install:python +npm install ``` + ::: ### Step 4: Configure LISA Edit the `config-custom.yaml` file to customize your LISA deployment. Key configurations include: -- AWS account and region settings -- Authentication settings -- Model bucket name +* AWS account and region settings +* Authentication settings +* Model bucket name ### Step 5: Configure Identity Provider In the `config-custom.yaml` file, configure the `authConfig` block for authentication. LISA supports OpenID Connect (OIDC) providers such as AWS Cognito or Keycloak. Required fields include: -- `authority`: URL of your identity provider -- `clientId`: Client ID for your application -- `adminGroup`: Group name for users with model management permissions -- `userGroup`: Group name for regular LISA users -- `jwtGroupsProperty`: Path to the groups field in the JWT token -- `additionalScopes` (optional): Extra scopes for group membership information +* `authority`: URL of your identity provider +* `clientId`: Client ID for your application +* `adminGroup`: Group name for users with model management permissions +* `userGroup`: Group name for regular LISA users +* `jwtGroupsProperty`: Path to the groups field in the JWT token +* `additionalScopes` (optional): Extra scopes for group membership information IDP Configuration examples using AWS Cognito and Keycloak can be found: [IDP Configuration Examples](/admin/idp-config) - ### Step 6: Configure LiteLLM + We utilize LiteLLM under the hood to allow LISA to respond to the [OpenAI specification](https://platform.openai.com/docs/api-reference). For LiteLLM configuration, a key must be set up so that the system may communicate with a database for tracking all the models that are added or removed using the [Model Management API](/config/model-management-api). The key must start with `sk-` and then can be any @@ -132,74 +189,66 @@ arbitrary string. We recommend generating a new UUID and then using that as the key. Configuration example is below. - ```yaml litellmConfig: db_key: sk-00000000-0000-0000-0000-000000000000 # needed for db operations, create your own key # pragma: allowlist-secret ``` -### Step 7: Set Up SSL Certificates (Development Only) - -LISA requires SSL certificates for secure communication. Choose the appropriate method based on your deployment environment. - -#### AWS Certificate Manager - -Use AWS Certificate Manager to create and manage certificates: - -1. **Create a Certificate in AWS Certificate Manager**: - - Navigate to the [AWS Certificate Manager Console](https://console.aws.amazon.com/acm) - - Request a public certificate - - For internal AWS deployments, use the domain pattern: `.people.aws.dev` - - Follow the DNS validation process to verify domain ownership - - Note: You may need access to specific AWS bindles or Route 53 hosted zones - -2. **Configure Custom Domains** in your `config-custom.yaml`: - -```yaml -restApiConfig: - sslCertIamArn: arn:aws:acm:::certificate/ - domainName: serve..people.aws.dev - -apiGatewayConfig: - domainName: chat..people.aws.dev -``` - -- For `sslCertIamArn` copy the arn from your ssl certificate from the AWS Certificate Manager. Otherwise you can manually fill it in. -- For `domainName` replace `` with your chosen subdomain. - -1. **Set Up Route 53 and Custom Domains**: - -After configuring your certificate and custom domains in `config-custom.yaml`, you need to set up DNS routing: - -**Create Route 53 Hosted Zone**: - - Navigate to Route 53 in the AWS Console - - Create a hosted zone for your domain (if it does not already exists) - - Note the hosted zone ID and name servers - -**Configure API Gateway Custom Domain** (after LISA deployment): - - Navigate to API Gateway β†’ Custom domain names - - Create a custom domain for your chat endpoint: `chat..people.aws.dev` - - Associate it with your API Gateway stage - -**Create DNS Records**: - - In Route 53, create an A record for `chat..people.aws.dev`: - - Type: A record (Alias) - - Alias target: Your API Gateway custom domain - - Create a CNAME record for `serve..people.aws.dev`: - - Type: CNAME - - Value: Your LisaServe REST API Application Load Balancer DNS name (found in EC2 β†’ Load Balancers) - -**For Internal AWS Deployments**: - - Register your DNS name using Supernova at https://supernova.amazon.dev/ - - Follow the guide at https://w.amazon.com/bin/view/SuperNova/PreOnboardingSteps/ - - Use the pattern: `{username}.people.aws.dev` - - Associate with the appropriate AWS bindle for access control +> [!IMPORTANT] +> To include prompt/response content in LiteLLM logs (published by the `LISA Serve` ECS task to CloudWatch via `litellm.log`), enable LiteLLM logging callbacks and message logging in `config-custom.yaml`. +> +> 1. Add the following to `litellmConfig`: +> ```yaml +> litellmConfig: +> litellm_settings: +> callbacks: ["otel"] +> turn_off_message_logging: false +> environment_variables: +> OTEL_EXPORTER: console +> callback_settings: +> otel: +> message_logging: true +> ``` +> +> 2. Ensure you are aware of the privacy/compliance implications: this causes request/response content to be logged. +> +> LiteLLM Proxy logging reference: https://docs.litellm.ai/docs/proxy/logging -**Redeploy LISA** - - Redeploy LISA for the changes to take effect - - After completing these steps and redeploying LISA, your application will be accessible via custom domains with valid SSL certificates, eliminating the need to accept self-signed certificates in your browser. +> [!IMPORTANT] +> API Gateway audit logging (strict opt-in): +> LISA can emit audit logs for API Gateway requests (who initiated the request, what action was taken, and a sanitized JSON body) only when enabled via `auditLoggingConfig` in `config-custom.yaml`. +> +> Example (opt-in to specific API prefixes): +> ```yaml +> auditLoggingConfig: +> enabled: true +> auditAll: false +> enabledPaths: ["/api-tokens", "/models", "/repository", "/session", "/configuration", "/prompt-templates", "/project", "/user-preferences", "/mcp", "/mcp-server", "/mcp-workbench", "/metrics", "/chat-assistant-stacks", "/bedrock-kb"] +> ``` +> +> Example (`auditAll`): +> ```yaml +> auditLoggingConfig: +> enabled: true +> auditAll: true +> ``` +> +> Optional JSON body audit (default **off**): set `includeJsonBody: true` to emit `AUDIT_API_GATEWAY_REQUEST_BODY` for opted-in paths. When `includeJsonBody` is false or omitted, request bodies are never logged, even when path auditing is enabled. +> +> When audit logging is enabled for a given API prefix, two kinds of events may appear. **They are not in the same CloudWatch log group:** +> +> | Event | What it contains | Where it is logged | +> | ----- | ---------------- | ------------------ | +> | `AUDIT_API_GATEWAY_REQUEST` | Allow/Deny, user identity, HTTP method + path (from the authorizer) | **API Gateway Lambda authorizer** log group (e.g. `…-lambda-authorizer`) | +> | `AUDIT_API_GATEWAY_REQUEST_BODY` | Sanitized JSON body (and user context from the proxy event) | **The Lambda (or service) that handles the route** β€” e.g. `put_session` for `PUT /session/{id}`, or the FastAPI/Mangum app log stream for APIs served that way β€” only if `includeJsonBody: true` | +> +> API Gateway does **not** send the HTTP body to the REST authorizer, so body audit must run in the integration that receives `event["body"]`. +> +> Each audit line is logged as **`EVENT_TYPE` followed by a compact JSON object** (same fields as before), so the full payload appears in the log message and can be parsed in CloudWatch Logs Insights (e.g. split on the first space and `parse` the JSON). +> +> Privacy note: enabling JSON body audit logging may include sensitive user data; ensure your organization’s compliance requirements are met. -### Step 8a: Customize Model Deployment (If Using LISA Serve) +### Step 7a: Customize Model Deployment (If Using LISA Serve) In the `ecsModels` section of `config-custom.yaml`, allow our deployment process to pull the model weights for you. @@ -215,11 +264,11 @@ ecsModels: baseImage: vllm/vllm-openai:latest ``` -### Step 8b: Stage Model Weights +### Step 7b: Stage Model Weights LISA requires model weights to be staged in the S3 bucket specified in your `config-custom.yaml` file, assuming the S3 bucket follows this structure: -``` +```text s3:/// s3://// s3://// @@ -229,7 +278,7 @@ s3:/// **Example:** -``` +```text s3:///mistralai/Mistral-7B-Instruct-v0.2 s3:///mistralai/Mistral-7B-Instruct-v0.2/ s3:///mistralai/Mistral-7B-Instruct-v0.2/ @@ -239,7 +288,7 @@ s3:///mistralai/Mistral-7B-Instruct-v0.2/ To automatically download and stage the model weights defined by the `ecsModels` parameter in your `config-custom.yaml`, use the following command: ```bash -make modelCheck +npm run model:check ``` This command verifies if the model's weights are already present in your S3 bucket. If not, it downloads the weights, converts them to the required format, and uploads them to your S3 bucket. Ensure adequate disk space is available for this process. @@ -249,13 +298,13 @@ This command verifies if the model's weights are already present in your S3 buck > Previously, before models could be managed through the [API](/config/model-management-api) or via the Model Management > section of the [Chatbot](/user/chat), this parameter also > dictated which models were deployed. - > **NOTE** -> For air-gapped systems, before running `make modelCheck` you should manually download model artifacts and place them in a `models` directory at the project root, using the structure: `models/`. + + +> For air-gapped systems, before running `npm run model:check` you should manually download model artifacts and place them in a `models` directory at the project root, using the structure: `models/`. > **NOTE** > This process is primarily designed and tested for HuggingFace models. For other model formats, you will need to manually create and upload safetensors. - > **NOTE** > Please valdiate that all files successfully downloaded locally AND were uploaded to the S3 Bucket. Ensure all large files such as .safetensor files exist. @@ -264,8 +313,9 @@ This command verifies if the model's weights are already present in your S3 buck If you haven't bootstrapped your AWS account for CDK: ```bash -make bootstrap +npm run bootstrap ``` + ## ADC Region Deployment Tips Amazon Dedicated Cloud (ADC) regions are isolated AWS environments designed for government customers' most sensitive workloads. These regions have restricted internet access and limited external dependencies, requiring special deployment considerations for LISA. @@ -283,16 +333,22 @@ This approach builds all necessary components in a commercial region with full i 1. Set up LISA in a commercial AWS region with internet access 2. Build all components: + ```bash - make buildArchive + npm run build:archive + ./bin/build-assets --include-images ``` + This generates: - - Lambda function zip files in `./dist/layers/*.zip` - - Docker images exported as `./dist/images/*.tar` files + + +* Lambda function zip files in `./dist/layers/*.zip` (from `build:archive`) + * Docker images exported as `./dist/images/*.tar` files (from `build-assets --include-images`) #### Step 2: Transfer to ADC Region 1. Upload Docker images to ECR in your ADC region: + ```bash # Load and tag images docker load -i lisa-rest-api.tar @@ -302,6 +358,7 @@ This approach builds all necessary components in a commercial region with full i aws ecr get-login-password --region | docker login --username AWS --password-stdin .dkr.ecr..amazonaws.com docker push .dkr.ecr..amazonaws.com/lisa-rest-api:latest ``` + You'll want to repeat this for lisa-batch-ingestion, as well as any of the LISA base model hosting containers (lisa-vllm, lisa-tgi, lisa-tei) 2. Transfer built artifacts to ADC environment @@ -341,16 +398,16 @@ restApiConfig: code: .dkr.ecr..amazonaws.com/lisa-rest-api:latest ``` - - ### Approach 2: In-Region Building This approach configures LISA to build components using repositories accessible from within the ADC region. #### Prerequisites -- ADC-accessible package repositories (PyPI mirror, npm registry, container registry) -- ADC-accessible container registries -- Network connectivity to required build dependencies + +* ADC-accessible package repositories (PyPI mirror, npm registry, container registry) + +* ADC-accessible container registries +* Network connectivity to required build dependencies #### Configuration @@ -380,6 +437,7 @@ mcpWorkbenchBuildConfig: S6_OVERLAY_ARCH_SOURCE: "./s6-overlay-x86_64.tar.xz" # Path relative to lib/serve/mcp-workbench/ RCLONE_SOURCE: "./rclone-linux-amd64.zip" # Path relative to lib/serve/mcp-workbench/ ``` + You'll also want any model hosting base containers available, e.g. vllm/vllm-openai:latest and ghcr.io/huggingface/text-embeddings-inference:latest #### Preparing Offline Build Dependencies @@ -409,11 +467,13 @@ cp -r ~/.cache/prisma* lib/serve/rest-api/PRISMA_CACHE/ ``` **Important Notes:** -- The cache is platform-specific. Generate it on a system matching your Docker base image (e.g., for `public.ecr.aws/docker/library/python:3.13-slim` which is Debian-based, so you may want to use a Debian-based system) -- The `prisma version` command downloads binaries for your current platform -- Both `prisma/` and `prisma-python/` directories are required for offline operation + +* The cache is platform-specific. Generate it on a system matching your Docker base image (e.g., for `public.ecr.aws/docker/library/python:3.13-slim` which is Debian-based, so you may want to use a Debian-based system) +* The `prisma version` command downloads binaries for your current platform +* Both `prisma/` and `prisma-python/` directories are required for offline operation **MCP Workbench dependencies** (S6 Overlay and rclone): + ```bash # Download S6 Overlay files cd lib/serve/mcp-workbench/ @@ -435,23 +495,27 @@ To utilize the prebuilt hosting model containers with self-hosted models, select Once your configuration is complete: 1. Bootstrap CDK (if not already done): + ```bash - make bootstrap + npm run bootstrap ``` 2. Deploy LISA: + ```bash - make deploy + npm run deploy ``` 3. Deploy specific stacks if needed: + ```bash - make deploy STACK=LisaServe + STACK=LisaServe npm run deploy ``` 4. List available stacks: + ```bash - make listStacks + npm run cdk:list ``` ### Testing Your Deployment @@ -464,8 +528,8 @@ pytest lisa-sdk/tests --url --verify ![LISA Cognito Setup Example](../assets/LISA_Cognito_Example.png) @@ -83,6 +84,7 @@ authConfig: clientId: your-client-id adminGroup: AdminGroup userGroup: UserGroup + ragAdminGroup: RagAdminGroup # optional: grants RAG Admin role to this Cognito group jwtGroupsProperty: cognito:groups ``` @@ -95,6 +97,7 @@ authConfig: **Cause**: Incorrect OpenID Connect scopes configuration. **Solution**: + - Verify that your App Client has the correct OpenID Connect scopes enabled: - `email` - `openid` @@ -109,25 +112,43 @@ authConfig: **Cause**: Using "Traditional Web App" instead of "Single Page Application" (SPA) when creating the App Client. **Solution**: + - Recreate your App Client and select **"Single Page Application" (SPA)** as the app type - SPA clients do not require a client secret for token exchange, which is correct for browser-based applications **Testing Tip**: Use Chrome or Firefox Developer Tools: + - Open Developer Tools (F12) - Navigate to the "Application" tab (Chrome) or "Storage" tab (Firefox) - Find and clear Cookies related to your Cognito domain - This allows you to retry the login process with a fresh authentication flow +#### "Something went wrong" / "An error was encountered with the requested page" (After Login) + +**Symptom**: After entering credentials on the Cognito login page, you are shown a generic error: "Something went wrong" or "An error was encountered with the requested page." + +**Cause**: The `redirect_uri` sent to Cognito does not exactly match the allowed callback URLs. LISA uses `origin + pathname` (no hash fragment) per the OAuth 2.0 spec. If your Cognito App Client's allowed URLs omit the path or use a different format, Cognito rejects the redirect. + +**Solution**: + +- In your App Client's "Allowed callback URLs", add both: + - `https:///` (e.g. `https://xxx.execute-api.us-east-1.amazonaws.com/dev`) + - `https:////` (with trailing slash) +- For custom domains, add `https:///` and `https://` +- Ensure no typos, correct protocol (https), and exact path (including trailing slash variants) + #### "Contact Your Administrator" Error on Login Page **Symptom**: The Cognito hosted UI displays an error message asking you to contact your administrator. **Possible Causes**: + - Incorrect callback URLs in the App Client configuration - Mismatch between the URL that Cognito is redirecting to and the allowed callback URLs - The callback URL must exactly match (including trailing slashes) **Solution**: + - Verify that your App Client's "Allowed callback URLs" include: - Your API Gateway dev stage URL: `https:///dev` - The same URL with trailing slash: `https:///dev/` @@ -141,12 +162,12 @@ like in the Cognito clients. Instead, it will be a string configured by your Key will be able to provide you with a client name or create a client for you to use for this application. Once you have this string, use that as the `clientId` within the `authConfig` block. - -``` +```yaml authConfig: authority: https://your-keycloak-server.com clientId: your-client-name adminGroup: AdminGroup userGroup: UserGroup + ragAdminGroup: RagAdminGroup # optional: grants RAG Admin role to this Keycloak role jwtGroupsProperty: realm_access.roles ``` diff --git a/lib/docs/config/collection-management-api.md b/lib/docs/config/collection-management-api.md index 817812848..63a3bc482 100644 --- a/lib/docs/config/collection-management-api.md +++ b/lib/docs/config/collection-management-api.md @@ -714,7 +714,7 @@ fetch(hardDeleteUrl, { **Important Notes:** -1. **Admin Access Required**: Only users with admin access to the collection can delete it +1. **Admin or RAG Admin Access Required**: Only Admins or RAG Admins with group access to the repository can delete collections 2. **Default Collection Protection**: The default collection (based on embedding model ID) cannot be deleted 3. **Document Cleanup**: All documents in the collection will be removed from S3, DynamoDB, and the vector store 4. **Irreversible Operation**: Hard delete is permanent and cannot be undone @@ -1214,10 +1214,11 @@ Collections inherit configuration from their parent vector store: - **Admin**: Delete collection, modify access control ### Access Rules -1. Admin users have full access to all collections -2. Non-admin users must have group membership intersection with collection's allowed groups -3. Private collections are only accessible to creator and admins -4. Vector stores with `allowUserCollections: false` prevent non-admin collection creation +1. Admin users have full access to all collections across all repositories +2. RAG Admin users can create, update, and delete collections on repositories they have group access to; they cannot modify `allowedGroups` or repository-level settings +3. Non-admin users must have group membership intersection with collection's allowed groups +4. Private collections are only accessible to creator and admins +5. Vector stores with `allowUserCollections: false` prevent non-admin collection creation ## Best Practices diff --git a/lib/docs/config/configuration-ui.md b/lib/docs/config/configuration-ui.md index f5e0ce3b8..d8cd35502 100644 --- a/lib/docs/config/configuration-ui.md +++ b/lib/docs/config/configuration-ui.md @@ -7,12 +7,14 @@ The Configuration UI is an Administrator-only page accessible via `Administratio The Chat Features section contains toggles that control which capabilities are available to users in the Chat UI. Features are organized into the following groups: ### RAG + | Toggle | Description | |--------|-------------| | Document upload from Chat | Allows users to upload documents directly from the chat interface for RAG queries. See [RAG Repository](/config/repositories) for collection setup. | | Edit number of referenced documents | Lets users adjust how many RAG documents are referenced during inference. | ### Library + | Toggle | Description | |--------|-------------| | Model Library | Exposes the [Model Library](/user/model-library) page where users can browse available models. | @@ -20,12 +22,14 @@ The Chat Features section contains toggles that control which capabilities are a | Prompt Template Library | Exposes the Prompt Template Library for creating and managing reusable prompt templates. | ### In-Context + | Toggle | Description | |--------|-------------| | Document upload to context | Allows users to upload documents directly into the conversation context. | | Document Summarization | Enables document summarization capabilities within chat sessions. | ### Advanced + | Toggle | Description | |--------|-------------| | Edit model arguments | Allows users to modify model inference parameters (temperature, top_p, etc.). | @@ -38,18 +42,21 @@ The Chat Features section contains toggles that control which capabilities are a | Chat Assistant Stacks | Enables the [Chat Assistant Stacks](/config/chat-assistant-stacks) feature for pre-configured assistant workflows. | ### MCP + | Toggle | Description | |--------|-------------| | MCP Server Connections | Enables users to configure [MCP server connections](/config/mcp). See also [LISA MCP: Self-host servers](/config/hosted-mcp). | | MCP Workbench | Provides an experimentation workbench for MCP tools. See [MCP Workbench](/config/mcp-workbench). Requires MCP Server Connections to be enabled first. | +| MCP AWS Sessions | Enables the [AWS Sessions](/config/mcp#aws-sessions) feature, allowing users to connect AWS credentials per chat session for use by MCP tools. Requires MCP Server Connections to be enabled first. | ### API Tokens + | Toggle | Description | |--------|-------------| | User managed API tokens | Allows users to create and manage their own API tokens for programmatic access to LISA Serve. See [API Token Management](/config/api-tokens). | > [!NOTE] -> Some toggles have dependencies. For example, MCP Workbench requires MCP Server Connections to be enabled. Disabling a prerequisite toggle will automatically disable its dependents. +> Some toggles have dependencies. For example, MCP Workbench and AWS Sessions require MCP Server Connections to be enabled. Disabling a prerequisite toggle will automatically disable its dependents. ## System Banner diff --git a/lib/docs/config/configuration.md b/lib/docs/config/configuration.md index d861d243e..4119d1bce 100644 --- a/lib/docs/config/configuration.md +++ b/lib/docs/config/configuration.md @@ -14,6 +14,7 @@ authConfig: clientId: adminGroup: userGroup: + ragAdminGroup: # optional: IDP group for RAG Admin role jwtGroupsProperty: ``` diff --git a/lib/docs/config/custom-branding.md b/lib/docs/config/custom-branding.md index 2a3293f2d..fea01af90 100644 --- a/lib/docs/config/custom-branding.md +++ b/lib/docs/config/custom-branding.md @@ -32,7 +32,7 @@ customDisplayName: "YourProductName" When `useCustomBranding: true` is set, LISA looks for your custom assets in the following location: -``` +```text lib/user-interface/react/public/branding/custom/ ``` @@ -48,7 +48,7 @@ Create a `custom` directory and provide these three files: ### Directory Structure -``` +```text lib/user-interface/react/public/branding/ β”œβ”€β”€ base/ # Default LISA branding (don't modify) β”‚ β”œβ”€β”€ favicon.ico @@ -63,16 +63,19 @@ lib/user-interface/react/public/branding/ ### Asset Guidelines **Favicon (`favicon.ico`)** + - Standard browser icon format - Appears in browser tabs and bookmarks - Should be simple and recognizable at small sizes **Logo (`logo.svg`)** + - Vector format for optimal rendering at any size - Used in the top navigation bar - Recommended: Display size: ~120-200px wide **Login Image (`login.png`)** + - Displayed on the authentication page ## Display Name Customization @@ -92,6 +95,7 @@ customDisplayName: "YourProductName" ``` With this configuration: + - The page title changes from "AWS LISA AI Chat Assistant" to "YourProductName AI Chat Assistant" - All references to "LISA" in the UI become "YourProductName" - Your custom logo, favicon, and login image are used @@ -105,20 +109,25 @@ Beyond assets and names, you can customize the visual theme by creating a custom LISA contains two theme files: **Base Theme (Default):** -``` + +```text lib/user-interface/react/src/theme.ts ``` + This file contains a minimal theme with an empty token configuration and should not be modified directly. This theme serves as a fallback if no custom theme is defined and will load the Cloudscape default theming. **Custom Theme (Optional):** -``` + +```text lib/user-interface/react/src/theme-custom.ts ``` + Create this file to define your custom theme. This file is gitignored, allowing you to maintain organization-specific branding without committing it to version control. When `useCustomBranding: true` is configured, LISA will automatically: + 1. Look for `theme-custom.ts` first 2. Fall back to `theme.ts` if the custom file doesn't exist 3. Use Cloudscape's default theme if neither contains customizations @@ -128,10 +137,12 @@ When `useCustomBranding: true` is configured, LISA will automatically: The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foundation/theming/) allows you to customize various visual aspects of the default Cloudscape theme such as: **Typography** + - Font families - Font sizes and weights **Colors** + - Background colors (layout, containers, inputs) - Text colors (body, headings, links) - Button colors (primary, secondary, hover states) @@ -140,11 +151,13 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun - Selection/highlight colors **Layout** + - Border radius for buttons and containers - Spacing and padding - Component sizing **Context-Specific Styling** + - Top navigation appearance - Dropdown menus - Flashbar notifications @@ -155,6 +168,7 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun 1. **Create Custom Theme File** Copy the example custom theme to create your own: + ```bash cp lib/user-interface/react/src/theme-custom.ts.example \ lib/user-interface/react/src/theme-custom.ts @@ -163,6 +177,7 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun 2. **Edit Theme Variables** Open `theme-custom.ts` and customize the theme variables at the top of the file: + ```typescript // THEME VARIABLES - Edit these to customize the entire theme @@ -179,6 +194,7 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun 3. **Configure Branding** Enable custom branding in `config-custom.yaml`: + ```yaml useCustomBranding: true customDisplayName: "YourProductName" @@ -189,12 +205,14 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun For local development testing: a. Update `lib/user-interface/react/public/env.js`: + ```js "USE_CUSTOM_BRANDING": true, "CUSTOM_DISPLAY_NAME": "YourProductName" ``` b. Start the development server: + ```bash npm run dev ``` @@ -204,8 +222,9 @@ The [Cloudscape theming system](https://cloudscape.design/foundation/visual-foun 5. **Deploy** Deploy your changes: + ```bash - make deploy + npm run deploy ``` > [!NOTE] @@ -237,6 +256,7 @@ if (window.env?.USE_CUSTOM_BRANDING) { ``` **How it works:** + 1. When `USE_CUSTOM_BRANDING` is true, LISA scans for theme files 2. If `theme-custom.ts` exists, it loads that file 3. Otherwise, it falls back to `theme.ts` (the base theme) @@ -268,6 +288,7 @@ export function getDisplayName(): string { ``` These utilities ensure: + - Assets are loaded from the correct directory - The correct display name is used throughout the application - Fallback to default LISA branding if custom assets are missing @@ -277,6 +298,7 @@ These utilities ensure: ### Complete Custom Branding Setup 1. **Update Configuration** + ```yaml # config-custom.yaml useCustomBranding: true @@ -284,11 +306,13 @@ These utilities ensure: ``` 2. **Create Custom Assets Directory** + ```bash mkdir -p lib/user-interface/react/public/branding/custom ``` 3. **Add Your Assets** + ```bash # Copy your branded assets cp /path/to/your/favicon.ico lib/user-interface/react/public/branding/custom/ @@ -297,6 +321,7 @@ These utilities ensure: ``` 4. **Customize Theme (Optional)** + ```bash # Create and edit theme-custom.ts with your color scheme cp lib/user-interface/react/src/theme-custom.ts.example \ @@ -307,8 +332,9 @@ These utilities ensure: ``` 5. **Deploy** + ```bash - make deploy + npm run deploy ``` ### Verification @@ -336,6 +362,7 @@ After deployment, verify your branding: **Issue**: Custom assets don't appear after deployment **Solutions**: + - Verify files exist in `lib/user-interface/react/public/branding/custom/` - Check file names match exactly: `favicon.ico`, `logo.svg`, `login.png` - Ensure `useCustomBranding: true` in config @@ -347,6 +374,7 @@ After deployment, verify your branding: **Issue**: "LISA" still appears instead of custom name **Solutions**: + - Verify `customDisplayName` is set in `config-custom.yaml` - Ensure config changes were deployed - Check `{LISA_URL}/{STAGE}/env.js` path for `CUSTOM_DISPLAY_NAME` and `USE_CUSTOM_BRANDING` @@ -357,6 +385,7 @@ After deployment, verify your branding: **Issue**: Custom theme colors don't appear **Solutions**: + - Verify `useCustomBranding: true` (theme is only applied when branding is enabled) - Ensure `theme-custom.ts` exists in `lib/user-interface/react/src/` - Verify theme variables are properly defined in `theme-custom.ts` @@ -368,6 +397,7 @@ After deployment, verify your branding: **Issue**: Changes to theme-custom.ts not appearing **Solutions**: + - Restart the development server (`npm run dev`) - Clear browser cache - Check for TypeScript errors in the theme file @@ -378,6 +408,7 @@ After deployment, verify your branding: **Issue**: Some assets are custom, others are default **Solutions**: + - Ensure all three asset files are present in the `custom` directory - Check file permissions are readable - Verify no typos in file names (case-sensitive on Linux) @@ -388,6 +419,7 @@ After deployment, verify your branding: **Issue**: Some components are not showing the color they were configured with in `theme-custom.ts` **Solutions**: + - Restart the development server - Clear browser cache - Change the value of the component (e.g. `#0054E3` -> `#0054E2`). Reuse of the same values can occasionally be problematic in the Cloudscape theming system. @@ -399,6 +431,7 @@ After deployment, verify your branding: Here's a complete example showing all aspects of custom branding: ### config-custom.yaml + ```yaml accountNumber: 123456789012 region: us-east-1 @@ -418,7 +451,8 @@ authConfig: ``` ### Assets Prepared -``` + +```text lib/user-interface/react/public/branding/custom/ β”œβ”€β”€ favicon.ico # Acme company icon β”œβ”€β”€ logo.svg # Acme company logo @@ -426,6 +460,7 @@ lib/user-interface/react/public/branding/custom/ ``` ### Custom Theme + ```typescript // lib/user-interface/react/src/theme-custom.ts (excerpt) const FONT_FAMILY = 'Roboto, Arial, sans-serif'; @@ -436,7 +471,9 @@ const LIGHT_TOPNAV_BACKGROUND = '#0A3D62'; ``` ### Result + After deployment, users see: + - Browser tab: "Acme AI Chat Assistant" with Acme favicon - Top navigation: Acme logo and "Acme" branding - Login page: Acme welcome image diff --git a/lib/docs/config/mcp-workbench.md b/lib/docs/config/mcp-workbench.md index 2a075f74c..5e3066515 100644 --- a/lib/docs/config/mcp-workbench.md +++ b/lib/docs/config/mcp-workbench.md @@ -24,6 +24,18 @@ The integrated browser-based editor allows administrators to write Python code a ## Configuration +### Deployment infrastructure + +The MCP Workbench **HTTP server** (streamable MCP and AWS session routes) always runs on **its own** ECS cluster and Application Load Balancer, separate from the LISA Serve REST API. The container still serves `/v2/mcp/*` and `/api/aws/*` on that load balancer’s default listener. + +The hosted MCP base URL is stored in SSM at `…/mcpWorkbench/endpoint` (and used by configuration Lambdas). It must target the **MCP Workbench** ALB, not the Serve API ALB. When you set `restApiConfig.domainName`, LISA derives a separate workbench hostname by default (for example `lisa-serve.` becomes `lisa-mcp-workbench.`, and `serve.` becomes `mcp-workbench.`) unless you override it with `mcpWorkbenchEcsConfig.domainName`. Create a DNS record for that hostname pointing at the **MCP Workbench** load balancer in EC2. + +Optional `mcpWorkbenchEcsConfig` in your deployment configuration lets you tune instance type, ASG minimum and maximum capacity, root volume size, and scaling cooldown for the workbench cluster. + +**CORS:** The browser calls the workbench from the **UI origin** (custom domain, ALB URL, or local dev), which changes with deployment and app configuration. By default, `mcpWorkbenchCorsOrigins` is `*` so the workbench container allows any origin (`CORS_ORIGINS`). Set `mcpWorkbenchCorsOrigins` in your deployment config to a comma-separated list if you need to restrict origins. The workbench hostname may still differ from the Serve API hostname; verify OIDC flows for your setup. + +**CDK:** The workbench stack is deployed in the same account and VPC as the rest of LISA. In the current stage layout it is created when `deployMcpWorkbench` is enabled (alongside the Serve stack when `deployServe` is enabled). + ### Step 1: Enable the MCP Workbench Menu 1. **Access Admin Configuration** @@ -62,20 +74,83 @@ Once the MCP Workbench connection is activated, all custom enabled tools become ### Programmatic API Access -LISA automatically hosts an MCP Server containing all MCP Workbench tools. The server is accessible through the following endpoints: +LISA automatically hosts an MCP Server containing all MCP Workbench tools. The server is accessible on the **MCP Workbench** load balancer (see SSM `…/mcpWorkbench/endpoint`), for example: -**AWS Load Balancer URL:** -``` +**AWS Load Balancer URL (example):** + +```text https://abc-rest-..elb.amazonaws.com/v2/mcp/ ``` -**Custom Domain URL (if configured):** -``` +**Custom Domain URL (if configured on that load balancer):** + +```text https:///v2/mcp/ ``` > **Authentication Required:** API access requires [Programmatic API Tokens](./api-tokens.md) for authentication. +## API Reference + +The MCP Workbench includes a REST API for managing tool source files and syntax validation in addition to hosted MCP runtime access. + +Base path: `/mcp-workbench` + +### List Tools + +- Method: `GET` +- Path: `/mcp-workbench` +- Description: Lists MCP Workbench tools available to the caller. + +### Create Tool + +- Method: `POST` +- Path: `/mcp-workbench` +- Description: Creates a new MCP Workbench tool. + +### Get Tool + +- Method: `GET` +- Path: `/mcp-workbench/{toolId}` +- Description: Retrieves a single MCP Workbench tool. + +Path parameters: + +- `toolId` (string, required): Tool identifier + +### Update Tool + +- Method: `PUT` +- Path: `/mcp-workbench/{toolId}` +- Description: Updates an existing MCP Workbench tool. + +Path parameters: + +- `toolId` (string, required): Tool identifier + +### Delete Tool + +- Method: `DELETE` +- Path: `/mcp-workbench/{toolId}` +- Description: Deletes an MCP Workbench tool. + +Path parameters: + +- `toolId` (string, required): Tool identifier + +### Validate Python Syntax + +- Method: `POST` +- Path: `/mcp-workbench/validate-syntax` +- Description: Validates Python code syntax before creating or updating tools. + +Example: + +```bash +curl -X GET "https:////mcp-workbench" \ + -H "Authorization: Bearer " +``` + ## Development Guidelines ### Creating Your First Tool @@ -149,6 +224,19 @@ Both approaches will make your tool available in the chat interface once deploye ## Advanced Usage +### AWS Sessions + +When the [AWS Sessions](/config/mcp#aws-sessions) feature is enabled, MCP Workbench tools can leverage per-session AWS credentials that users connect in the chat UI. Tools receive the caller's identity (user and session) from the request context and use it to look up stored credentials. + +To create a tool that uses AWS credentials: + +1. Import `get_caller_identity` from `mcpworkbench.aws.identity` and `get_aws_session_for_user` from the shared session service. +2. Call `get_caller_identity()` to obtain the current user and session IDs from the request headers. +3. Call `get_aws_session_for_user(user_id, session_id)` to retrieve the `AwsSessionRecord` (or handle `AwsSessionMissingError` if the user has not connected credentials). +4. Use the record's `aws_access_key_id`, `aws_secret_access_key`, `aws_session_token`, and `aws_region` to construct boto3 clients. + +See `lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py` for a complete example. Without tools that leverage these credentials, the AWS Sessions feature has no effect. + ### Adding Python Dependencies Operators can modify `lib/serve/mcp-workbench/requirements.txt` to add additional Python libraries that will be available in the MCP Workbench environment. After modifying the requirements file, you'll need to perform a CDK deployment for those additional libraries to become available to your custom tools. diff --git a/lib/docs/config/mcp.md b/lib/docs/config/mcp.md index 0845bf147..7b811053c 100644 --- a/lib/docs/config/mcp.md +++ b/lib/docs/config/mcp.md @@ -9,8 +9,9 @@ tools and perform the necessary steps to complete the task. **Activate MCP Feature & LLM** -1. Administrators must first activate the **MCP Server Connections** feature on LISA’s **Configuration** page. When active, the MCP Connections page will appear under the **Libraries** menu for all LISA users. To activate, toggle **Allow MCP Server Connections** to be active. Click **Save Changes** and then **Update.** -2. An Administrator must add an LLM to support LISA’s MCP tools capability. During model creation in Model Management, toggle **Tool Calls** to be active. +1. Administrators must first activate the **MCP Server Connections** feature on LISA’s **Configuration** page. When active, the MCP Connections page will appear under the **Libraries** menu for all LISA users. To activate, toggle **MCP Server Connections** to be active. Click **Save Changes** and then **Update.** +2. Optionally, enable **AWS Sessions** to allow users to connect AWS credentials per chat session for use by MCP tools that support it. See [AWS Sessions](#aws-sessions) below. +3. An Administrator must add an LLM to support LISA’s MCP tools capability. During model creation in Model Management, toggle **Tool Calls** to be active. **Add MCP Connections** @@ -68,4 +69,85 @@ When a user activates Autopilot Mode, that user will not be prompted to confirm Admins can edit and delete any MCP Server Connection. Non-admins can edit or delete MCP Server Connections that they created. +## API Reference + +The MCP Server Connections API manages MCP endpoints that users and administrators can enable in chat sessions. + +Base path: `/mcp-server` + +### List MCP Server Connections + +- Method: `GET` +- Path: `/mcp-server` +- Description: Lists MCP server connections available to the caller. + +### Create MCP Server Connection + +- Method: `POST` +- Path: `/mcp-server` +- Description: Creates a new MCP server connection. + +### Get MCP Server Connection + +- Method: `GET` +- Path: `/mcp-server/{serverId}` +- Description: Retrieves a specific MCP server connection. + +Path parameters: + +- `serverId` (string, required): MCP server identifier + +### Update MCP Server Connection + +- Method: `PUT` +- Path: `/mcp-server/{serverId}` +- Description: Updates an existing MCP server connection. + +Path parameters: + +- `serverId` (string, required): MCP server identifier + +### Delete MCP Server Connection + +- Method: `DELETE` +- Path: `/mcp-server/{serverId}` +- Description: Deletes an MCP server connection. + +Path parameters: + +- `serverId` (string, required): MCP server identifier + +Example: + +```bash +curl -X GET "https:////mcp-server" \ + -H "Authorization: Bearer " +``` + +## AWS Sessions + +When **AWS Sessions** is enabled (Administration β†’ Configuration β†’ MCP section), users can connect their AWS credentials to individual chat sessions. This allows MCP tools to perform AWS operations on behalf of the user using their own credentials. + +### How it works + +- **Per-session scope**: Credentials are stored per user and per chat session. Each session has its own isolated AWS connection. +- **In-memory storage**: Keys are validated and converted to short-lived session credentials stored securely in memory. Long-term credentials are never persisted. +- **MCP tool requirement**: The feature only has effect when an MCP server exposes tools that leverage these credentials. For example, the MCP Workbench sample S3 tools use connected credentials to list buckets or perform other S3 operations. Without such tools, connecting credentials has no effect. +- **Session lifecycle**: Credentials are discarded when the chat session ends. + +> **Caution:** Credentials with broad permissions can create, modify, or delete resources in your AWS account. Use IAM credentials with the minimum permissions required for the tools you intend to use. + +### Using AWS Sessions + +1. Ensure **AWS Sessions** is enabled by your Administrator (Configuration β†’ MCP β†’ AWS Sessions). +2. Start or open a chat session. +3. Open the session configuration panel (gear icon) and locate the **AWS Credentials** section. +4. Enter your Access Key ID, Secret Access Key, optional Session Token, and Region. +5. Click **Connect**. Your credentials are validated and converted to short-lived session credentials. +6. Use MCP tools that support AWS credentials (e.g., S3 list buckets) within that chat session. The LLM will invoke these tools using your connected credentials. + +> **TIP:** +> +> AWS Sessions requires MCP Server Connections to be enabled. Administrators can disable AWS Sessions independently if they do not want users connecting AWS credentials in chat. + ![MCP Example](../assets/mcp_toolchain.gif) diff --git a/lib/docs/config/model-compatibility.md b/lib/docs/config/model-compatibility.md index 44d0987dc..f13c722f7 100644 --- a/lib/docs/config/model-compatibility.md +++ b/lib/docs/config/model-compatibility.md @@ -28,4 +28,5 @@ See the [deployment](/admin/deploy) section for details on how to set up the vLL how the HuggingFace containers will serve safetensor weights downloaded from the HuggingFace website, vLLM will do the same, and our configuration will allow you to serve these artifacts automatically. vLLM does not have many supported models for embeddings, but as they become available, LISA will support them as long as the vLLM container version is updated in the config.yaml file and as long as the model's safetensors can be found in S3. + - Please see the [vLLM Environment Variables Documentation](./vllm_variables.md) before getting started with vLLM models diff --git a/lib/docs/config/model-management-api.md b/lib/docs/config/model-management-api.md index 8c9c5f314..26e75142a 100644 --- a/lib/docs/config/model-management-api.md +++ b/lib/docs/config/model-management-api.md @@ -95,7 +95,7 @@ curl -s -H "Authorization: Bearer " -X GET https:// ## Creating a Model (Admin API) -LISA provides the `/models` endpoint for creating both ECS and LiteLLM-hosted models. Depending on the request payload, infrastructure will be created or bypassed (e.g., for LiteLLM-only models). +LISA provides the `/models` endpoint for creating LISA-hosted ECS models and externally hosted models managed through LiteLLM. Externally hosted models include both third-party providers and customer internal hosted endpoints. This API accepts the same model definition parameters that were accepted in the V2 model definitions within the config.yaml file with one notable difference: the `containerConfig.image.path` field is now omitted because it corresponded with the `inferenceContainer` selection. As a convenience, this path is no longer required. @@ -170,6 +170,19 @@ POST https:///models } ``` +### Creating a Customer Internal Hosted Model: + +```json +{ + "modelId": "internal-mistral7b", + "modelName": "openai/mistral-7b-instruct", + "modelType": "textgen", + "streaming": true, + "hostingType": "INTERNAL_HOSTED", + "modelUrl": "http://internal-lisa-mistral7binstruct03-665568061.us-east-1.elb.amazonaws.com/v1" +} +``` + ### Explanation of Key Fields for Creation Payload: - `modelId`: The unique identifier for the model. This is any name you would like it to be. @@ -182,6 +195,8 @@ POST https:///models - LiteLLM-only, SageMaker: If you want to use a SageMaker Endpoint named `my-sm-endpoint`, then the `modelName` value should be `sagemaker/my-sm-endpoint`. - `modelType`: The type of model, such as text generation (textgen). - `streaming`: Whether the model supports streaming inference. +- `hostingType`: Optional hosting selector. Use `INTERNAL_HOSTED` for customer internal load balancer endpoints. +- `modelUrl`: Required for `INTERNAL_HOSTED` and used as LiteLLM `api_base` for inference routing. - `instanceType`: The type of EC2 instance to be used (only applicable for ECS models). - `containerConfig`: Details about the Docker container, memory allocation, and environment variables. - `autoScalingConfig`: Configuration related to ECS autoscaling. diff --git a/lib/docs/config/model-management-ui.md b/lib/docs/config/model-management-ui.md index bb760ff82..ba5888834 100644 --- a/lib/docs/config/model-management-ui.md +++ b/lib/docs/config/model-management-ui.md @@ -2,7 +2,12 @@ ## Configuring Models -LISA's Model Management UI allows Administrators to configure models for use with LISA. LISA supports third party models that are hosted externally to LISA that are compatible with LiteLLM. LISA also supports self-hosting models within Amazon ECS. LISA's Model Management wizard walks Administrators through configuration steps. +LISA's Model Management UI allows Administrators to configure models for use with LISA. LISA supports: +- third-party models hosted externally to LISA that are compatible with LiteLLM, +- customer internal hosted models exposed by an internal AWS load balancer URL, and +- self-hosted models running on LISA-managed Amazon ECS infrastructure. + +LISA's Model Management wizard walks Administrators through configuration steps. ## Scaling Models diff --git a/lib/docs/config/prompt-templates.md b/lib/docs/config/prompt-templates.md new file mode 100644 index 000000000..03aa8c095 --- /dev/null +++ b/lib/docs/config/prompt-templates.md @@ -0,0 +1,84 @@ +# Prompt Templates API + +LISA includes prompt template APIs to help teams standardize common prompts and reuse them across chat workflows. + +## Overview + +Prompt Templates in LISA are reusable prompt artifacts that can be created by users (or administrators), edited over time, and selected in chat workflows. They are primarily used to standardize how teams prompt models and to reduce repeated prompt authoring. + +LISA supports two common prompt styles: + +- **Directive prompts**: Instruction-focused templates that define what the model should do (for example, summarize, extract entities, classify, or generate structured output). +- **Persona prompts**: Role-focused templates that define how the model should respond (for example, tone, audience, communication style, and level of detail). + +These styles can be used independently or combined. A common pattern is to use a persona prompt to establish communication style, then a directive prompt to enforce task-specific behavior and output format. + +### Visibility and Access Model + +Prompt templates can be scoped to different audiences in LISA: + +- **Private**: Visible only to the creator; useful for personal workflows and experimentation. +- **Shared to IDP groups**: Available to specific identity-provider groups; useful for team- or role-specific prompt libraries. +- **Global**: Available to all users; useful for organization-wide standards, approved templates, and common operational workflows. + +This model lets organizations balance flexibility and governance: individuals can iterate quickly with private templates, teams can collaborate through group-scoped templates, and administrators can publish vetted global templates for broad reuse. + +### Suggested Usage + +- Use **directive prompts** for repeatable tasks that require consistent output structure. +- Use **persona prompts** for consistency in voice and audience fit. +- Use **group-shared templates** for domain teams (for example, operations, engineering, or compliance). +- Use **global templates** for officially approved prompts that should be broadly discoverable. + +## API Reference + +Base path: `/prompt-templates` + +### List Prompt Templates + +- Method: `GET` +- Path: `/prompt-templates` +- Description: Lists prompt templates available to the caller. + +### Create Prompt Template + +- Method: `POST` +- Path: `/prompt-templates` +- Description: Creates a new prompt template. + +### Get Prompt Template + +- Method: `GET` +- Path: `/prompt-templates/{promptTemplateId}` +- Description: Returns a specific prompt template. + +Path parameters: + +- `promptTemplateId` (string, required): Prompt template identifier + +### Update Prompt Template + +- Method: `PUT` +- Path: `/prompt-templates/{promptTemplateId}` +- Description: Updates a specific prompt template. + +Path parameters: + +- `promptTemplateId` (string, required): Prompt template identifier + +### Delete Prompt Template + +- Method: `DELETE` +- Path: `/prompt-templates/{promptTemplateId}` +- Description: Deletes a specific prompt template. + +Path parameters: + +- `promptTemplateId` (string, required): Prompt template identifier + +Example: + +```bash +curl -X GET "https:////prompt-templates" \ + -H "Authorization: Bearer " +``` diff --git a/lib/docs/config/repositories.md b/lib/docs/config/repositories.md index a4320e675..fd8c2b832 100644 --- a/lib/docs/config/repositories.md +++ b/lib/docs/config/repositories.md @@ -33,7 +33,7 @@ The repository-collection model provides a two-tier organizational structure ana Customers have two methods to load files into repositories configured with LISA: 1. **Manual Upload**: Load files via the chat assistant user interface (UI), or API -2. **Automated Pipeline**: (Admins-only) Configure LISA's ingestion pipelines for automated document processing +2. **Automated Pipeline**: (Admins and RAG Admins) Configure LISA's ingestion pipelines for automated document processing. Admins can configure pipelines on any repository; RAG Admins can configure pipelines on repositories they have group access to. This role is especially useful in multi-tenant environments. ## Configuration @@ -46,6 +46,7 @@ Files loaded via the chat assistant UI are limited by size, and are processed th LISA's automated document ingestion pipeline supports larger files and broader file types. Supported file types include: PDF, docx, and plain text files (.txt, .json, .yaml, xml, etc). The individual file size limit is 50 MB. LISA's pipelines offer chunking support for fixed size chunking or no chunking. For customers using Amazon Bedrock Knowledge Bases, LISA supports all chunking strategies offered by the service. LISA's automated ingestion pipelines provide customers with a flexible, scalable solution for loading documents into configured repositories and collections. Customers can set up multiple ingestion pipelines for a repository. For each pipeline they define: + - The target repository and collection - Embedding model (inherited from repository if not defined) - Chunking strategy (can be customized per pipeline) @@ -127,6 +128,7 @@ Collection access is controlled through user groups: - **Repository-level Groups**: Collections inherit allowed groups from their parent repository by default - **Collection-level Groups**: Collections can override with their own group restrictions for finer control - **Admin Access**: Administrators have full access to all collections across all repositories +- **RAG Admin Access**: RAG Admins can create, update, and delete collections on repositories they have group access to. They cannot modify repository-level settings or `allowedGroups`. This role is especially useful in multi-tenant environments. - **User Collection Creation**: Repositories can be configured to allow or restrict user-created collections via the `allowUserCollections` flag ## Configuration Examples @@ -137,7 +139,7 @@ RAG repositories and collections are configurable through the chat assistant web Repositories are created by administrators and define the underlying vector store implementation, embedding model, and default access controls. -#### Request Example: +#### Request Example ```bash curl -s -H 'Authorization: Bearer ' -XPOST -d @repository.json https:///repository @@ -174,7 +176,7 @@ curl -s -H 'Authorization: Bearer ' -XPOST -d @repository.json https } ``` -#### Response Fields: +#### Response Fields - `status`: "success" if the state machine was started successfully - `executionArn`: The state machine ARN used to deploy the repository @@ -183,7 +185,7 @@ curl -s -H 'Authorization: Bearer ' -XPOST -d @repository.json https Collections can be created by users with appropriate permissions within an existing repository. -#### Request Example: +#### Collection Request Example ```bash curl -s -H 'Authorization: Bearer ' -XPOST -d @collection.json https:///repository/my-rag-repository/collection @@ -216,7 +218,7 @@ curl -s -H 'Authorization: Bearer ' -XPOST -d @collection.json https } ``` -#### Response Fields: +#### Collection Response Fields - `collectionId`: Unique identifier for the created collection (UUID) - `repositoryId`: Parent repository identifier @@ -230,14 +232,14 @@ curl -s -H 'Authorization: Bearer ' -XPOST -d @collection.json https Retrieve all collections accessible to the current user within a repository. -#### Request Example: +#### Listing Request Example ```bash curl -s -H 'Authorization: Bearer ' \ 'https:///repository/my-rag-repository/collections?page=1&pageSize=20&sortBy=name&sortOrder=asc' ``` -#### Query Parameters: +#### Query Parameters - `page`: Page number (default: 1) - `pageSize`: Items per page (default: 20, max: 100) @@ -247,17 +249,23 @@ curl -s -H 'Authorization: Bearer ' \ ## UI Components -### RAG Repository Management (Admin) +### RAG Repository Management (Admin and RAG Admin) -Administrators access repository management through the Admin Configurations page. This interface provides: +Administrators and RAG Admins access repository management through the Administration menu. The capabilities available depend on the user's role: +**Administrators** have full access, including: - Create, update, and delete repositories - Configure vector store implementation (OpenSearch, PGVector, Bedrock Knowledge Base) - Set default embedding models and chunking strategies -- Define repository-level access controls +- Define repository-level access controls (`allowedGroups`) - Configure metadata tags - Enable or disable user-created collections +**RAG Admins** have scoped access on repositories they belong to via group membership: +- Create, update, and delete collections +- Update ingestion pipelines +- Cannot create or delete repositories, or modify `allowedGroups` + ### RAG Collection Library The Collection Library is accessible from the Document Library page and provides: diff --git a/lib/docs/config/session.md b/lib/docs/config/session.md new file mode 100644 index 000000000..9465c2556 --- /dev/null +++ b/lib/docs/config/session.md @@ -0,0 +1,88 @@ +# Session API + +LISA uses session APIs to persist and manage chat session state, including metadata updates and media attachment workflows. + +## Overview + +Session endpoints power core chat lifecycle behavior in LISA: + +- Listing a user's existing sessions +- Creating or updating a session +- Renaming sessions for better organization +- Attaching generated or uploaded images to session history +- Deleting one or all sessions for the user + +These APIs are used by the chat UI and can also be used programmatically. + +## API Reference + +Base path: `/session` + +### List Sessions + +- Method: `GET` +- Path: `/session` +- Description: Lists sessions available to the caller. + +### Delete All Caller Sessions + +- Method: `DELETE` +- Path: `/session` +- Description: Deletes all sessions for the caller. + +### Get Session + +- Method: `GET` +- Path: `/session/{sessionId}` +- Description: Returns a specific session by ID. + +Path parameters: + +- `sessionId` (string, required): Session identifier + +### Create or Update Session + +- Method: `PUT` +- Path: `/session/{sessionId}` +- Description: Creates or updates a specific session. + +Path parameters: + +- `sessionId` (string, required): Session identifier + +### Delete Session + +- Method: `DELETE` +- Path: `/session/{sessionId}` +- Description: Deletes a specific session. + +Path parameters: + +- `sessionId` (string, required): Session identifier + +### Rename Session + +- Method: `PUT` +- Path: `/session/{sessionId}/name` +- Description: Updates a session display name. + +Path parameters: + +- `sessionId` (string, required): Session identifier + +### Attach Image to Session + +- Method: `PUT` +- Path: `/session/{sessionId}/attachImage` +- Description: Attaches image metadata/content to a session. + +Path parameters: + +- `sessionId` (string, required): Session identifier + +Example: + +```bash +curl -X GET "https:////session" \ + -H "Authorization: Bearer " +``` diff --git a/lib/docs/config/user-preferences.md b/lib/docs/config/user-preferences.md new file mode 100644 index 000000000..0c5c06a07 --- /dev/null +++ b/lib/docs/config/user-preferences.md @@ -0,0 +1,40 @@ +# User Preferences API + +LISA persists user-specific behavior and UI preferences through a dedicated user preferences API. + +## Overview + +User Preferences are used to retain per-user settings across sessions, including preferences that affect chat and MCP behavior. This API provides: + +- Retrieval of current caller preferences +- Creation or update of caller preferences + +These endpoints are user-scoped and designed for personalized experience management. + +## API Reference + +Base path: `/user-preferences` + +### Get User Preferences + +- Method: `GET` +- Path: `/user-preferences` +- Description: Returns preferences for the calling user. + +### Create or Update User Preferences + +- Method: `PUT` +- Path: `/user-preferences` +- Description: Creates or updates preferences for the calling user. + +Example: + +```bash +curl -X PUT "https:////user-preferences" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "theme": "dark", + "showMcpTools": true + }' +``` diff --git a/lib/docs/docConstruct.ts b/lib/docs/docConstruct.ts index 971d26141..49a973e13 100644 --- a/lib/docs/docConstruct.ts +++ b/lib/docs/docConstruct.ts @@ -23,12 +23,12 @@ import { BaseProps } from '../schema'; import { Roles } from '../core/iam/roles'; import { DOCS_DIST_PATH } from '../util'; -import { StringParameter } from 'aws-cdk-lib/aws-ssm'; - /** * Properties for DocsStack Construct. */ -export type LisaDocsProps = BaseProps & StackProps; +export type LisaDocsProps = BaseProps & StackProps & { + bucketAccessLogsBucket: IBucket; +}; /** * User Interface Construct. @@ -44,11 +44,7 @@ export class LisaDocsConstruct extends Construct { super(scope, id); this.scope = scope; - const { config } = props; - - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) - ); + const { bucketAccessLogsBucket, config } = props; // Create Docs S3 bucket const docsBucket = new Bucket(scope, 'DocsBucket', { diff --git a/lib/docs/user/breaking-changes.md b/lib/docs/user/breaking-changes.md index d160bf854..2077586e8 100644 --- a/lib/docs/user/breaking-changes.md +++ b/lib/docs/user/breaking-changes.md @@ -61,7 +61,7 @@ upgrade: internally, rendering the ecsModels list obsolete. We recommend backing up your model settings to facilitate their redeployment through the new Model Management API with minimal downtime. 1. Networking Changes and Full Teardown: Core networking changes require a complete teardown of the existing LISA - installation using the make destroy command before upgrading. Cross-stack dependencies have been modified, + installation using the `npm run destroy` command before upgrading. Cross-stack dependencies have been modified, necessitating this full teardown to ensure proper application of the v3 infrastructure changes. Additionally, users may need to manually delete some resources, such as ECR repositories or S3 buckets, if they were populated before CloudFormation began deleting the stack. This operation is destructive and irreversible, so it is crucial to back up diff --git a/lib/mcp/mcp-server-api.ts b/lib/mcp/mcp-server-api.ts index c6cd2ab28..21a0caada 100644 --- a/lib/mcp/mcp-server-api.ts +++ b/lib/mcp/mcp-server-api.ts @@ -25,17 +25,19 @@ import { Construct } from 'constructs'; import { getPythonRuntime, registerAPIEndpoint } from '../api-base/utils'; import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; import { createCdkId, createLambdaRole } from '../core/utils'; +import { getAuditLoggingEnv } from '../api-base/auditEnv'; import { Vpc } from '../networking/vpc'; import { LAMBDA_PATH } from '../util'; import { McpServerDeployer } from './mcp-server-deployer'; import { CreateMcpServerStateMachine } from './state-machine/create-mcp-server'; import { DeleteMcpServerStateMachine } from './state-machine/delete-mcp-server'; import { UpdateMcpServerStateMachine } from './state-machine/update-mcp-server'; -import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods, IBucket } from 'aws-cdk-lib/aws-s3'; import { RemovalPolicy } from 'aws-cdk-lib'; type McpServerApiProps = { authorizer: IAuthorizer; + bucketAccessLogsBucket: IBucket; restApiId: string; rootResourceId: string; securityGroups: ISecurityGroup[]; @@ -54,7 +56,7 @@ export class McpServerApi extends Construct { constructor (scope: Construct, id: string, props: McpServerApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, bucketAccessLogsBucket, config, restApiId, rootResourceId, securityGroups, vpc } = props; // Get common layer based on arn from SSM due to issues with cross stack references const commonLambdaLayer = LayerVersion.fromLayerVersionArn( @@ -85,10 +87,6 @@ export class McpServerApi extends Construct { deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY, }); - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) - ); - const bucket = new Bucket(scope, createCdkId(['LISA', 'MCP-Hosting', config.deploymentName, config.deploymentStage]), { removalPolicy: config.removalPolicy, autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, @@ -120,12 +118,22 @@ export class McpServerApi extends Construct { if (!mcpResource) { mcpResource = restApi.root.addResource('mcp'); } + // Match hosted MCP server routes: browser MCP clients send MCP + LISA session headers (see ecsMcpServer.allowedCorsHeaders) + const mcpBrowserCorsAllowHeaders = Array.from(new Set([ + ...Cors.DEFAULT_HEADERS, + 'Accept', + 'Mcp-Session-Id', + 'X-Session-Id', + 'Last-Event-Id', + 'mcp-protocol-version', + 'X-Amz-User-Agent', + ])); // Add CORS preflight support for the /mcp resource // This ensures OPTIONS method is available even if the resource already existed // addCorsPreflight is idempotent - it won't create duplicate OPTIONS methods mcpResource.addCorsPreflight({ allowOrigins: Cors.ALL_ORIGINS, - allowHeaders: Cors.DEFAULT_HEADERS, + allowHeaders: mcpBrowserCorsAllowHeaders, }); const mcpResourceId = mcpResource.resourceId; @@ -197,6 +205,7 @@ export class McpServerApi extends Construct { DELETE_MCP_SERVER_SFN_ARN: deleteMcpServerStateMachine.stateMachineArn, UPDATE_MCP_SERVER_SFN_ARN: updateMcpServerStateMachine.stateMachineArn, ADMIN_GROUP: config.authConfig?.adminGroup || '', + ...getAuditLoggingEnv(config), }; const lambdaRole = createLambdaRole(this, config.deploymentName, 'McpServerDynamicApi', mcpServersTable.tableArn, config.roles?.LambdaExecutionRole); diff --git a/lib/mcp/mcpApiConstruct.ts b/lib/mcp/mcpApiConstruct.ts index 2bcd5d408..9805a986a 100644 --- a/lib/mcp/mcpApiConstruct.ts +++ b/lib/mcp/mcpApiConstruct.ts @@ -17,6 +17,7 @@ import { Stack, StackProps } from 'aws-cdk-lib'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { IBucket } from 'aws-cdk-lib/aws-s3'; import { Construct } from 'constructs'; import { Vpc } from '../networking/vpc'; @@ -30,6 +31,7 @@ export type LisaMcpApiProps = BaseProps & rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; + bucketAccessLogsBucket: IBucket; }; /** @@ -44,11 +46,12 @@ export class LisaMcpApiConstruct extends Construct { constructor (scope: Stack, id: string, props: LisaMcpApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, bucketAccessLogsBucket, config, restApiId, rootResourceId, securityGroups, vpc } = props; // Add MCP Server API dynamic hosting new McpServerApi(scope, 'McpServerApi', { authorizer, + bucketAccessLogsBucket, config, restApiId, rootResourceId, diff --git a/lib/metrics/index.ts b/lib/metrics/index.ts index b28b059d0..a37235e36 100644 --- a/lib/metrics/index.ts +++ b/lib/metrics/index.ts @@ -22,6 +22,7 @@ import { Construct } from 'constructs'; import { BaseProps } from '../schema'; import { Vpc } from '../networking/vpc'; import { MetricsConstruct } from './metricsConstruct'; +import { ModelHealthDashboard } from './modelHealthDashboard'; /** * Properties for LisaMetricsStack. @@ -48,5 +49,9 @@ export class LisaMetricsStack extends Stack { new MetricsConstruct(this, id, props).node.addMetadata('aws:cdk:path', this.node.path); + if (props.config.deployHealthDashboard) { + new ModelHealthDashboard(this, 'ModelHealth', { config: props.config }); + } + } } diff --git a/lib/metrics/metricsConstruct.ts b/lib/metrics/metricsConstruct.ts index 57b1f7158..536b69b64 100644 --- a/lib/metrics/metricsConstruct.ts +++ b/lib/metrics/metricsConstruct.ts @@ -33,6 +33,7 @@ import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../ import { BaseProps } from '../schema'; import { createLambdaRole } from '../core/utils'; import { Vpc } from '../networking/vpc'; +import { getAuditLoggingEnv } from '../api-base/auditEnv'; import { LAMBDA_PATH } from '../util'; import { Duration, RemovalPolicy } from 'aws-cdk-lib'; @@ -210,6 +211,43 @@ export class MetricsConstruct extends Construct { width: 8, height: 6, }), + // Total Prompts by Model Widget + new cloudwatch.GraphWidget({ + title: 'Total Prompts by Model', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,ModelId} MetricName="ModelPromptCount"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + view: cloudwatch.GraphWidgetView.BAR, + width: 8, + height: 6, + }), + // Total Prompt Tokens + Total Completion Tokens (aggregate, stacked) + new cloudwatch.GraphWidget({ + title: 'Total Tokens Over Time (Aggregate)', + left: [ + new cloudwatch.Metric({ + namespace: 'LISA/UsageMetrics', + metricName: 'TotalPromptTokens', + label: 'Input Tokens', + statistic: 'Sum', + period: Duration.hours(1), + }), + new cloudwatch.Metric({ + namespace: 'LISA/UsageMetrics', + metricName: 'TotalCompletionTokens', + label: 'Output Tokens', + statistic: 'Sum', + period: Duration.hours(1), + }), + ], + stacked: true, + width: 8, + height: 6, + }), // User Metrics section new cloudwatch.TextWidget({ @@ -304,10 +342,105 @@ export class MetricsConstruct extends Construct { width: 8, height: 6, }), + + // ── Token Usage Metrics section ────────────────────────────────── + new cloudwatch.TextWidget({ + markdown: '## **Token Usage Metrics**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Input tokens by model + new cloudwatch.GraphWidget({ + title: 'Input Tokens by Model', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,ModelId} MetricName="ModelPromptTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), + + // Output tokens by model + new cloudwatch.GraphWidget({ + title: 'Output Tokens by Model', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,ModelId} MetricName="ModelCompletionTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), + + // Input tokens by user + new cloudwatch.GraphWidget({ + title: 'Input Tokens by User', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,UserId} MetricName="UserPromptTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), + + // Output tokens by user + new cloudwatch.GraphWidget({ + title: 'Output Tokens by User', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,UserId} MetricName="UserCompletionTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), + + // Input tokens by group + new cloudwatch.GraphWidget({ + title: 'Input Tokens by Group', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,GroupName} MetricName="GroupPromptTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), + + // Output tokens by group + new cloudwatch.GraphWidget({ + title: 'Output Tokens by Group', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/UsageMetrics,GroupName} MetricName="GroupCompletionTokens"\', \'Sum\', 3600)', + label: '', + period: Duration.hours(1), + }), + ], + width: 8, + height: 6, + }), ); + // ECS Model Health is in a separate dashboard β€” see modelHealthDashboard.ts + const env = { - USAGE_METRICS_TABLE_NAME: usageMetricsTable.tableName + USAGE_METRICS_TABLE_NAME: usageMetricsTable.tableName, + ...getAuditLoggingEnv(config), }; // Create metrics API endpoints diff --git a/lib/metrics/modelHealthDashboard.ts b/lib/metrics/modelHealthDashboard.ts new file mode 100644 index 000000000..dbec598d5 --- /dev/null +++ b/lib/metrics/modelHealthDashboard.ts @@ -0,0 +1,925 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import * as cloudwatch from 'aws-cdk-lib/aws-cloudwatch'; +import { Construct } from 'constructs'; +import { Duration } from 'aws-cdk-lib'; +import { BaseProps } from '../schema'; + +/** + * CloudWatch dashboard for ECS model hosting operational health. + * + * Uses Container Insights v2 metrics (ECS/ContainerInsights namespace) + * and ALB metrics. SEARCH expressions auto-discover all model clusters + * so new model deployments appear without dashboard changes. + */ +export class ModelHealthDashboard extends Construct { + + constructor (scope: Construct, id: string, props: BaseProps) { + super(scope, id); + + const { config } = props; + // Deployment prefix used in SEARCH expressions to scope to this deployment's clusters. + // Cluster names are built via createCdkId and always start with deploymentName + // (e.g. "prod-gptoss20b"). CloudWatch SEARCH tokenizes on hyphens, so + // "prod-gptoss20b" becomes tokens ["prod", "gptoss20b"]. Using a partial match + // (no double quotes) like ClusterName=${dp} matches any ClusterName containing + // the deployment name token. Double-quoted values do exact match only β€” no wildcards. + const dp = config.deploymentName; + + const dashboard = new cloudwatch.Dashboard(this, 'ModelHealthDashboard', { + dashboardName: `${dp}-${config.deploymentStage}-LISA-Model-Health`, + start: '-P7D', + }); + + // ===================================================================== + // Task & Container Health + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '# **LISA Self-Hosted Model Health Dashboard**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + new cloudwatch.TextWidget({ + markdown: '## **Task & Container Health**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Running vs Desired Task Count per cluster/service + // Use Maximum so counts display as whole numbers instead of fractional averages. + new cloudwatch.GraphWidget({ + title: 'Running vs Desired Tasks (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="RunningTaskCount" ClusterName=${dp}', 'Maximum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="DesiredTaskCount" ClusterName=${dp}', 'Maximum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Pending tasks β€” waiting for placement (capacity issues) + // Use Maximum instead of Average so the count shows as whole numbers + // (Average over 5 min produces tiny fractions like 0.03). + new cloudwatch.GraphWidget({ + title: 'Pending Tasks (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="PendingTaskCount" ClusterName=${dp}', 'Maximum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Task set count β€” tracks deployment rollouts and circuit breaker activity + new cloudwatch.GraphWidget({ + title: 'Task Sets (Deployment Rollouts)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="TaskSetCount" ClusterName=${dp}', 'Maximum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Service deployment count β€” spikes indicate restarts or circuit breaker trips + new cloudwatch.GraphWidget({ + title: 'Deployment Count (by Service)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="DeploymentCount" ClusterName=${dp}', 'Maximum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // ALB Target Health + // ===================================================================== + // ALB metrics are published with specific dimension combos. Target-level + // metrics (HealthyHostCount, HTTP codes, etc.) use {TargetGroup, LoadBalancer}. + // Connection-level metrics (ActiveConnectionCount, etc.) use {LoadBalancer} only. + // The deployment name token (e.g. "prod") scopes results to this deployment's + // ALBs and target groups. + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **ALB Target Health**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Healthy host count per target group + new cloudwatch.GraphWidget({ + title: 'Healthy Host Count (by Target Group)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="HealthyHostCount" ${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Unhealthy host count per target group + new cloudwatch.GraphWidget({ + title: 'Unhealthy Host Count (by Target Group)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="UnHealthyHostCount" ${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Error Rates + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Error Rates**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Target 5xx β€” failed model invocations (500s from the container) + new cloudwatch.GraphWidget({ + title: 'Target 5xx Errors (Failed Invocations)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="HTTPCode_Target_5XX_Count" ${dp}', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // Target 4xx β€” client errors / bad requests to models + new cloudwatch.GraphWidget({ + title: 'Target 4xx Errors (by Model)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="HTTPCode_Target_4XX_Count" ${dp}', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // ELB 5xx β€” load balancer level errors (no healthy targets, timeouts) + new cloudwatch.GraphWidget({ + title: 'ELB 5xx Errors (by Load Balancer)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,LoadBalancer} MetricName="HTTPCode_ELB_5XX_Count" ${dp}', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + ); + + // ===================================================================== + // Latency & Throughput + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Latency & Throughput**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Target response time p50/p99 per model + // ALB publishes TargetResponseTime in seconds; multiply by 1000 for milliseconds. + // Exclude the REST API target group (contains "RestA") β€” it's the API router, not a model. + // SEARCH auto-labels include the full ALB/TG ARN path which is hard to read; + // unfortunately CloudWatch SEARCH doesn't support label customization. + // For clean per-model latency, see the Inference Engine Metrics section below + // (E2E Request Latency, TTFT, Inter-Token Latency) which use the ModelName dimension. + new cloudwatch.GraphWidget({ + title: 'Target Response Time p50 (by Model)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="TargetResponseTime" ${dp} NOT RestA NOT rest NOT MCP', 'p50', 300) * 1000`, + label: '', + period: Duration.minutes(5), + }), + ], + leftYAxis: { label: 'ms' }, + width: 12, + height: 6, + }), + + new cloudwatch.GraphWidget({ + title: 'Target Response Time p99 (by Model)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="TargetResponseTime" ${dp} NOT RestA NOT rest NOT MCP', 'p99', 300) * 1000`, + label: '', + period: Duration.minutes(5), + }), + ], + leftYAxis: { label: 'ms' }, + width: 12, + height: 6, + }), + + // Request count per model (throughput / load) β€” excludes REST API target group + new cloudwatch.GraphWidget({ + title: 'Request Count (by Model)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,TargetGroup,LoadBalancer} MetricName="RequestCount" ${dp} NOT RestA NOT rest NOT MCP', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Active connection count β€” concurrent load per ALB + new cloudwatch.GraphWidget({ + title: 'Active Connections (by Load Balancer)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,LoadBalancer} MetricName="ActiveConnectionCount" ${dp}', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // New connection count β€” rate of new connections + new cloudwatch.GraphWidget({ + title: 'New Connections (by Load Balancer)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/ApplicationELB,LoadBalancer} MetricName="NewConnectionCount" ${dp}', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Resource Utilization + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Resource Utilization**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // CPU utilization per cluster/service + new cloudwatch.GraphWidget({ + title: 'CPU Utilized (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="CpuUtilized" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // Memory utilization per cluster/service + new cloudwatch.GraphWidget({ + title: 'Memory Utilized (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="MemoryUtilized" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // GPU Cache Usage (vLLM) β€” from custom metrics publisher + // The raw metric is a 0–1 decimal; multiply by 100 for display as a percentage. + new cloudwatch.GraphWidget({ + title: 'GPU Cache Usage % (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/InferenceMetrics,ModelName} MetricName="GpuCacheUsagePercent"\', \'Average\', 300) * 100', + label: '', + period: Duration.minutes(5), + }), + ], + leftYAxis: { min: 0, max: 100, label: '%' }, + width: 8, + height: 6, + }), + + // CPU reserved vs utilized β€” shows headroom + new cloudwatch.GraphWidget({ + title: 'CPU Reserved (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="CpuReserved" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // Memory reserved vs utilized β€” shows headroom + new cloudwatch.GraphWidget({ + title: 'Memory Reserved (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="MemoryReserved" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + + // Inference Requests Running/Waiting (vLLM) β€” from custom metrics publisher + new cloudwatch.GraphWidget({ + title: 'Requests Running / Waiting (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/InferenceMetrics,ModelName} MetricName="RequestsRunning"\', \'Average\', 300)', + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{LISA/InferenceMetrics,ModelName} MetricName="RequestsWaiting"\', \'Average\', 300)', + label: '', + period: Duration.minutes(5), + }), + ], + width: 8, + height: 6, + }), + ); + + // ===================================================================== + // Network & Storage + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Network & Storage**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Network throughput RX/TX + new cloudwatch.GraphWidget({ + title: 'Network RX / TX Bytes (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="NetworkRxBytes" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="NetworkTxBytes" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Storage I/O β€” read and write bytes + new cloudwatch.GraphWidget({ + title: 'Storage Read / Write Bytes (by Cluster)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="StorageReadBytes" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{ECS/ContainerInsights,ClusterName,ServiceName} MetricName="StorageWriteBytes" ClusterName=${dp}', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Inference Engine Metrics (from metrics_publisher.py) + // ===================================================================== + // These metrics are scraped from the Prometheus /metrics endpoint of each + // inference engine (vLLM, TGI, TEI) and published to the LISA/InferenceMetrics + // CloudWatch namespace by a background script running in each container. + const metricsNs = 'LISA/InferenceMetrics'; + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Inference Engine Metrics**\nScraped from Prometheus `/metrics` endpoints via `metrics_publisher.py`', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // vLLM: Token throughput β€” derived from cumulative token counters. + // AvgPrompt/GenerationThroughputToksPerSec gauges were removed in newer vLLM versions, + // so we use DIFF on the cumulative totals divided by the period (300s) to get toks/sec. + new cloudwatch.GraphWidget({ + title: 'Token Throughput (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `DIFF(SEARCH('{${metricsNs},ModelName} MetricName="PromptTokensTotal"', 'Maximum', 300)) / 300`, + label: 'Prompt toks/s', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `DIFF(SEARCH('{${metricsNs},ModelName} MetricName="GenerationTokensTotal"', 'Maximum', 300)) / 300`, + label: 'Generation toks/s', + period: Duration.minutes(5), + }), + ], + leftYAxis: { label: 'toks/s' }, + rightYAxis: { label: 'toks/s' }, + width: 12, + height: 6, + }), + + // vLLM: E2E request latency and TTFT + new cloudwatch.GraphWidget({ + title: 'E2E Request Latency / TTFT (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="E2ERequestLatencySeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="TimeToFirstTokenSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // vLLM: Inter-token latency (TPOT) β€” key SLO metric for streaming + new cloudwatch.GraphWidget({ + title: 'Inter-Token Latency / TPOT (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="InterTokenLatencySeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // vLLM: Queue time β€” how long requests wait before processing + new cloudwatch.GraphWidget({ + title: 'Request Queue Time (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestQueueTimeSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // vLLM: Prefill and decode time breakdown + new cloudwatch.GraphWidget({ + title: 'Prefill / Decode Time (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestPrefillTimeSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestDecodeTimeSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // vLLM: Completed requests and prefix cache effectiveness + new cloudwatch.GraphWidget({ + title: 'Completed Requests (vLLM)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestSuccessTotal"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TGI/TEI: Queue size and batch size + new cloudwatch.GraphWidget({ + title: 'Queue Size (TGI / TEI)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="QueueSize"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TGI/TEI: Batch current size + new cloudwatch.GraphWidget({ + title: 'Batch Current Size (TGI / TEI)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="BatchCurrentSize"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TGI: Request success / failure counts + new cloudwatch.GraphWidget({ + title: 'TGI Request Success / Failure', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestSuccess"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestFailure"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TGI: Latency breakdown β€” queue, inference, per-token + new cloudwatch.GraphWidget({ + title: 'TGI Latency Breakdown (Queue / Inference / Per-Token)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="QueueDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="InferenceDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="MeanTimePerTokenSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TGI: Input / output token sizes per request + new cloudwatch.GraphWidget({ + title: 'TGI Avg Input / Generated Tokens per Request', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="InputLengthPerRequest"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="GeneratedTokensPerRequest"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TEI: Request duration breakdown β€” tokenization, queue, inference + new cloudwatch.GraphWidget({ + title: 'TEI Request Duration Breakdown', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="RequestDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // TEI: Tokenization / Queue / Inference time breakdown + new cloudwatch.GraphWidget({ + title: 'TEI Tokenization / Queue / Inference Time', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="TokenizationDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="QueueDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="InferenceDurationSeconds"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Metrics publisher heartbeat β€” confirms which models are reporting + new cloudwatch.GraphWidget({ + title: 'Metrics Publisher Heartbeat (by Model)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{${metricsNs},ModelName} MetricName="MetricsPublisherHeartbeat"', 'Average', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Batch Ingestion Metrics + // ===================================================================== + // AWS Batch does not publish job-level metrics to CloudWatch natively. + // All job state transitions (SUBMITTED, RUNNING, SUCCEEDED, FAILED) are + // captured via EventBridge β†’ Lambda β†’ custom CloudWatch metrics in the + // LISA/BatchIngestion namespace. This provides queue-level visibility + // regardless of how the job was triggered (S3 event, scheduled, or + // manual upload through the chat UI). + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Batch Ingestion**\nJob queue metrics from EventBridge state change events (covers all ingestion triggers)', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // Jobs submitted β€” total ingestion jobs entering the queue from any source + new cloudwatch.GraphWidget({ + title: 'Jobs Submitted (All Sources)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{LISA/BatchIngestion,DeploymentName,DeploymentStage,JobQueue} MetricName="JobsSubmitted" DeploymentName="${dp}"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Jobs succeeded vs failed β€” completion outcomes + new cloudwatch.GraphWidget({ + title: 'Jobs Succeeded vs Failed', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{LISA/BatchIngestion,DeploymentName,DeploymentStage,JobQueue} MetricName="JobsSucceeded" DeploymentName="${dp}"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + right: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{LISA/BatchIngestion,DeploymentName,DeploymentStage,JobQueue} MetricName="JobsFailed" DeploymentName="${dp}"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Jobs started β€” tracks jobs that entered RUNNING state + new cloudwatch.GraphWidget({ + title: 'Jobs Started (Running)', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{LISA/BatchIngestion,DeploymentName,DeploymentStage,JobQueue} MetricName="JobsStarted" DeploymentName="${dp}"', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // Ingestion Lambda errors β€” failures in the submission Lambdas themselves + // (kept as a secondary signal for Lambda-level issues) + new cloudwatch.GraphWidget({ + title: 'Ingestion Lambda Errors', + left: [ + new cloudwatch.MathExpression({ + expression: `SEARCH('{AWS/Lambda,FunctionName} MetricName="Errors" ${dp}-${config.deploymentStage}-ingestion', 'Sum', 300)`, + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Auto Scaling + // ===================================================================== + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Auto Scaling**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + + // ASG group size β€” instances backing the ECS clusters + new cloudwatch.GraphWidget({ + title: 'ASG Instance Count (by Group)', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{AWS/AutoScaling,AutoScalingGroupName} MetricName="GroupInServiceInstances"\', \'Average\', 300)', + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + + // ASG desired vs in-service + new cloudwatch.GraphWidget({ + title: 'ASG Desired Capacity (by Group)', + left: [ + new cloudwatch.MathExpression({ + expression: 'SEARCH(\'{AWS/AutoScaling,AutoScalingGroupName} MetricName="GroupDesiredCapacity"\', \'Average\', 300)', + label: '', + period: Duration.minutes(5), + }), + ], + width: 12, + height: 6, + }), + ); + + // ===================================================================== + // Alarms + // ===================================================================== + // NOTE: ALB alarms (unhealthy hosts, 5xx errors, connection errors, + // latency, rejected connections) were removed because: + // 1. Model ALB dimensions (TargetGroup/LoadBalancer) are dynamic and + // unknown at deploy time β€” dimensionless metrics return no data. + // 2. CloudWatch does not support SEARCH expressions in Metric Alarms. + // ALB health is monitored via the SEARCH-based dashboard widgets above. + const alarmPrefix = `${dp}-${config.deploymentStage}-LISA`; + + // Batch ingestion job failures β€” from custom metric published by + // EventBridge β†’ Lambda when Batch jobs enter FAILED state. + const batchJobFailuresAlarm = new cloudwatch.Alarm(this, 'BatchJobFailuresAlarm', { + alarmName: `${alarmPrefix}-BatchJobFailures`, + alarmDescription: 'One or more batch ingestion jobs have failed. Check AWS Batch console and CloudWatch Logs for the failed job details.', + metric: new cloudwatch.Metric({ + namespace: 'LISA/BatchIngestion', + metricName: 'JobsFailed', + dimensionsMap: { + DeploymentName: dp, + DeploymentStage: config.deploymentStage, + }, + statistic: 'Sum', + period: Duration.minutes(5), + }), + threshold: 0, + comparisonOperator: cloudwatch.ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 1, + treatMissingData: cloudwatch.TreatMissingData.NOT_BREACHING, + }); + + // Add alarm status widgets to the dashboard + dashboard.addWidgets( + new cloudwatch.TextWidget({ + markdown: '## **Alarm Status**', + width: 24, + height: 1, + background: cloudwatch.TextWidgetBackground.TRANSPARENT, + }), + new cloudwatch.AlarmStatusWidget({ + title: 'Model Health Alarms', + alarms: [ + batchJobFailuresAlarm, + ], + width: 24, + height: 4, + }), + ); + } +} diff --git a/lib/models/docker-image-builder.ts b/lib/models/docker-image-builder.ts index 6d484d244..62bfc8212 100644 --- a/lib/models/docker-image-builder.ts +++ b/lib/models/docker-image-builder.ts @@ -26,7 +26,7 @@ import { } from 'aws-cdk-lib/aws-iam'; import { Code, Function } from 'aws-cdk-lib/aws-lambda'; import { Duration, RemovalPolicy, Stack } from 'aws-cdk-lib'; -import { BlockPublicAccess, Bucket, BucketEncryption } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketEncryption, IBucket } from 'aws-cdk-lib/aws-s3'; import { BucketDeployment, Source } from 'aws-cdk-lib/aws-s3-deployment'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { createCdkId } from '../core/utils'; @@ -35,9 +35,9 @@ import { Vpc } from '../networking/vpc'; import { Roles } from '../core/iam/roles'; import { getPythonRuntime } from '../api-base/utils'; import { ECS_MODEL_PATH, LAMBDA_PATH } from '../util'; -import { StringParameter } from 'aws-cdk-lib/aws-ssm'; export type DockerImageBuilderProps = BaseProps & { + bucketAccessLogsBucket: IBucket; ecrUri: string; mountS3DebUrl: string; securityGroups: ISecurityGroup[]; @@ -52,11 +52,7 @@ export class DockerImageBuilder extends Construct { const stackName = Stack.of(scope).stackName; - const { config } = props; - - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) - ); + const { bucketAccessLogsBucket, config } = props; const ec2DockerBucket = new Bucket(this, createCdkId([stackName, 'docker-image-builder-ec2-bucket']), { enforceSSL: true, diff --git a/lib/models/litellm-sync.ts b/lib/models/litellm-sync.ts new file mode 100644 index 000000000..d9af7b5d8 --- /dev/null +++ b/lib/models/litellm-sync.ts @@ -0,0 +1,124 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +import { CustomResource, Duration } from 'aws-cdk-lib'; +import { + Effect, + ManagedPolicy, + PolicyStatement, + Role, + ServicePrincipal, +} from 'aws-cdk-lib/aws-iam'; +import { Code, Function, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { Provider } from 'aws-cdk-lib/custom-resources'; +import { Construct } from 'constructs'; + +import { getPythonRuntime } from '../api-base/utils'; +import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; +import { Vpc } from '../networking/vpc'; +import { LAMBDA_PATH } from '../util'; + +export type LiteLLMSyncConstructProps = { + modelTable: ITable; + lambdaLayers: ILayerVersion[]; + vpc: Vpc; + securityGroups: ISecurityGroup[]; +} & BaseProps; + +/** + * Construct that creates a Lambda custom resource to sync models from DynamoDB to LiteLLM. + * This is triggered on every deployment to ensure all models in the Models DynamoDB table + * are registered in LiteLLM after the database is created or updated. + */ +export class LiteLLMSyncConstruct extends Construct { + constructor (scope: Construct, id: string, props: LiteLLMSyncConstructProps) { + super(scope, id); + + const { config, modelTable, lambdaLayers, vpc, securityGroups } = props; + const lambdaPath = config.lambdaPath || LAMBDA_PATH; + + const managementKeyName = StringParameter.valueForStringParameter( + this, + `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}` + ); + + const litellmSyncRole = new Role(this, 'LiteLLMModelSyncRole', { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + }); + + // Grant permissions to read/update the specific model table + litellmSyncRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['dynamodb:Scan', 'dynamodb:GetItem', 'dynamodb:UpdateItem'], + resources: [modelTable.tableArn], + })); + + // Grant access to SSM parameters + litellmSyncRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['ssm:GetParameter'], + resources: [`arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/*`], + })); + + // Grant access to management key secret (scoped to the specific secret name) + litellmSyncRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['secretsmanager:GetSecretValue'], + resources: [`arn:${config.partition}:secretsmanager:${config.region}:${config.accountNumber}:secret:${managementKeyName}*`], + })); + + // Grant IAM access for SSL cert validation + litellmSyncRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['iam:GetServerCertificate'], + resources: ['*'], + })); + + const litellmModelSyncLambda = new Function(this, 'LiteLLMModelSync', { + runtime: getPythonRuntime(), + handler: 'models.litellm_model_sync.handler', + code: Code.fromAsset(lambdaPath), + layers: lambdaLayers, + environment: { + MODEL_TABLE_NAME: modelTable.tableName, + MANAGEMENT_KEY_NAME: managementKeyName, + LISA_API_URL_PS_NAME: `${config.deploymentPrefix}/lisaServeRestApiUri`, + REST_API_VERSION: 'v2', + RESTAPI_SSL_CERT_ARN: config.restApiConfig?.sslCertIamArn ?? '', + }, + role: litellmSyncRole, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + timeout: Duration.minutes(10), + description: 'Sync all models from DynamoDB to LiteLLM when the LiteLLM database is created or updated', + }); + + const syncProvider = new Provider(this, 'LiteLLMModelSyncProvider', { + onEventHandler: litellmModelSyncLambda, + }); + + new CustomResource(this, 'LiteLLMModelSyncResource', { + serviceToken: syncProvider.serviceToken, + properties: { timestamp: new Date().toISOString() }, // Force re-run on every deployment + }); + } +} diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index 4d0b2b2e3..dfb2a5003 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -38,6 +38,7 @@ import { Provider } from 'aws-cdk-lib/custom-resources'; import { Construct } from 'constructs'; import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../api-base/utils'; +import { getAuditLoggingEnv, LISA_AUDIT_API_GATEWAY_BASE_PATH } from '../api-base/auditEnv'; import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; import { Vpc } from '../networking/vpc'; @@ -47,10 +48,12 @@ import { DeleteModelStateMachine } from './state-machine/delete-model'; import { AttributeType, BillingMode, ITable, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { CreateModelStateMachine } from './state-machine/create-model'; import { UpdateModelStateMachine } from './state-machine/update-model'; +import { IBucket } from 'aws-cdk-lib/aws-s3'; import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; import { createCdkId, createLambdaRole } from '../core/utils'; import { Roles } from '../core/iam/roles'; import { LAMBDA_PATH } from '../util'; +import { LiteLLMSyncConstruct } from './litellm-sync'; /** * Properties for ModelsApi Construct. @@ -62,6 +65,7 @@ import { LAMBDA_PATH } from '../util'; */ type ModelsApiProps = BaseProps & { authorizer?: IAuthorizer; + bucketAccessLogsBucket: IBucket; guardrailsTable?: ITable; lisaServeEndpointUrlPs?: StringParameter; restApiId: string; @@ -77,7 +81,7 @@ export class ModelsApi extends Construct { constructor (scope: Construct, id: string, props: ModelsApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, bucketAccessLogsBucket, config, restApiId, rootResourceId, securityGroups, vpc } = props; // Use guardrailsTable passed from serve stack, or fall back to SSM parameter lookup for backward compatibility const guardrailsTable = props.guardrailsTable ?? (() => { @@ -150,6 +154,7 @@ export class ModelsApi extends Construct { }); const dockerImageBuilder = new DockerImageBuilder(this, 'docker-image-builder', { + bucketAccessLogsBucket, ecrUri: ecsModelBuildRepo.repositoryUri, mountS3DebUrl: config.mountS3DebUrl!, config: config, @@ -288,6 +293,8 @@ export class ModelsApi extends Construct { ADMIN_GROUP: config.authConfig?.adminGroup || '', MODELS_BUCKET_NAME: config.s3BucketModels, MANAGEMENT_KEY_NAME: managementKeyName, + ...getAuditLoggingEnv(config), + [LISA_AUDIT_API_GATEWAY_BASE_PATH]: '/models', // SSM parameter names for RAG tables (optional - only exist if RAG is deployed) ...(config.deployRag && { LISA_RAG_VECTOR_STORE_TABLE_PS_NAME: `${config.deploymentPrefix}/ragVectorStoreTableName`, @@ -560,6 +567,15 @@ export class ModelsApi extends Construct { properties: {}, }); + // Sync models from DynamoDB to LiteLLM on every deployment + new LiteLLMSyncConstruct(this, 'LiteLLMSync', { + config, + modelTable, + lambdaLayers, + vpc, + securityGroups, + }); + } /** diff --git a/lib/models/modelsApiConstruct.ts b/lib/models/modelsApiConstruct.ts index f1a6f8196..43dbdbc04 100644 --- a/lib/models/modelsApiConstruct.ts +++ b/lib/models/modelsApiConstruct.ts @@ -19,6 +19,7 @@ import { Stack, StackProps } from 'aws-cdk-lib'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { IBucket } from 'aws-cdk-lib/aws-s3'; import { Construct } from 'constructs'; import { Vpc } from '../networking/vpc'; @@ -35,6 +36,7 @@ export type LisaModelsApiProps = BaseProps & rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; + bucketAccessLogsBucket: IBucket; }; /** @@ -49,11 +51,12 @@ export class LisaModelsApiConstruct extends Construct { constructor (scope: Stack, id: string, props: LisaModelsApiProps) { super(scope, id); - const { authorizer, config, guardrailsTable, lisaServeEndpointUrlPs, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, bucketAccessLogsBucket, config, guardrailsTable, lisaServeEndpointUrlPs, restApiId, rootResourceId, securityGroups, vpc } = props; // Add REST API Lambdas to APIGW new ModelsApi(scope, 'ModelsApi', { authorizer, + bucketAccessLogsBucket, config, guardrailsTable, lisaServeEndpointUrlPs, diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index bed75bec2..40323f722 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -303,6 +303,9 @@ export class CreateModelStateMachine extends Construct { // State Machine definition setModelToCreating.next(createModelInfraChoice); + setModelToCreating.addCatch(handleFailureState, { + errors: ['States.ALL'], + }); createModelInfraChoice .when(Condition.booleanEquals('$.create_infra', true), startCopyDockerImage) .otherwise(addModelToLitellm); @@ -310,7 +313,7 @@ export class CreateModelStateMachine extends Construct { // Check if we need to poll for docker image or skip directly to stack creation startCopyDockerImage.next(checkImageTypeChoice); startCopyDockerImage.addCatch(handleFailureState, { // fail if ECR image verification fails - errors: ['States.TaskFailed'], + errors: ['States.ALL'], }); checkImageTypeChoice .when(Condition.stringEquals('$.image_info.image_status', 'prebuilt'), startCreateStack) @@ -319,7 +322,7 @@ export class CreateModelStateMachine extends Construct { // poll ECR image copy status loop pollDockerImageAvailable.next(pollDockerImageChoice); pollDockerImageAvailable.addCatch(handleFailureState, { // fail if exception thrown from code - errors: ['MaxPollsExceededException'], + errors: ['States.ALL'], }); pollDockerImageChoice .when(Condition.booleanEquals('$.continue_polling_docker', true), waitBeforePollingDockerImage) @@ -329,14 +332,11 @@ export class CreateModelStateMachine extends Construct { // poll CloudFormation stack status loop startCreateStack.next(pollCreateStack); startCreateStack.addCatch(handleFailureState, { // fail if CDK failed to create model stack - errors: ['StackFailedToCreateException'] + errors: ['States.ALL'] }); pollCreateStack.next(pollCreateStackChoice); pollCreateStack.addCatch(handleFailureState, { // fail if model failed or failed to create in time - errors: [ - 'MaxPollsExceededException', - 'UnexpectedCloudFormationStateException', - ], + errors: ['States.ALL'], }); pollCreateStackChoice .when(Condition.booleanEquals('$.continue_polling_stack', true), waitBeforePollingCreateStack) @@ -345,6 +345,9 @@ export class CreateModelStateMachine extends Construct { // Poll for model instances to be healthy before proceeding pollModelReady.next(pollModelReadyChoice); + pollModelReady.addCatch(handleFailureState, { + errors: ['States.ALL'], + }); pollModelReadyChoice .when(Condition.booleanEquals('$.continue_polling_capacity', true), waitBeforePollingModelReady) .otherwise(createSchedule); @@ -352,10 +355,19 @@ export class CreateModelStateMachine extends Construct { // Create schedule after model is ready createSchedule.next(addModelToLitellm); + createSchedule.addCatch(handleFailureState, { + errors: ['States.ALL'], + }); // Enrich context window after model is added to LiteLLM (non-blocking) addModelToLitellm.next(enrichContextWindow); + addModelToLitellm.addCatch(handleFailureState, { + errors: ['States.ALL'], + }); enrichContextWindow.next(checkGuardrailsChoice); + enrichContextWindow.addCatch(handleFailureState, { + errors: ['States.ALL'], + }); // Check for guardrails and add them if present checkGuardrailsChoice @@ -366,7 +378,7 @@ export class CreateModelStateMachine extends Construct { handleFailureState.next(failState); addGuardrailsToLitellm.next(successState); addGuardrailsToLitellm.addCatch(handleFailureState, { // fail if guardrail creation fails - errors: ['States.TaskFailed'], + errors: ['States.ALL'], }); const stateMachine = new StateMachine(this, 'CreateModelSM', { diff --git a/lib/models/state-machine/delete-model.ts b/lib/models/state-machine/delete-model.ts index 3f07d25f3..075a57417 100644 --- a/lib/models/state-machine/delete-model.ts +++ b/lib/models/state-machine/delete-model.ts @@ -21,6 +21,7 @@ import { Choice, Condition, DefinitionBody, + Fail, StateMachine, Succeed, Wait, @@ -173,7 +174,25 @@ export class DeleteModelStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + const handleFailure = new LambdaInvoke(this, 'HandleFailure', { + lambdaFunction: new Function(this, 'HandleFailureFunc', { + runtime: getPythonRuntime(), + handler: 'models.state_machine.delete_model.handle_failure', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + const successState = new Succeed(this, 'DeleteSuccess'); + const failState = new Fail(this, 'DeleteFailed'); const deleteStackChoice = new Choice(this, 'DeleteStackChoice'); const pollDeleteStackChoice = new Choice(this, 'PollDeleteStackChoice'); @@ -183,15 +202,35 @@ export class DeleteModelStateMachine extends Construct { // State Machine definition setModelToDeleting.next(deleteFromLitellm); + setModelToDeleting.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); deleteFromLitellm.next(deleteGuardrails); + deleteFromLitellm.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); deleteGuardrails.next(deleteStackChoice); + deleteGuardrails.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); deleteStackChoice .when(Condition.isNotNull('$.cloudformation_stack_arn'), deleteStack) .otherwise(deleteFromDdb); deleteStack.next(monitorDeleteStack); + deleteStack.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); monitorDeleteStack.next(pollDeleteStackChoice); + monitorDeleteStack.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); waitBeforePollingStackStatus.next(monitorDeleteStack); @@ -201,6 +240,11 @@ export class DeleteModelStateMachine extends Construct { deleteFromDdb.next(successState); + deleteFromDdb.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); + handleFailure.next(failState); const stateMachine = new StateMachine(this, 'DeleteModelSM', { definitionBody: DefinitionBody.fromChainable(setModelToDeleting), diff --git a/lib/models/state-machine/update-model.ts b/lib/models/state-machine/update-model.ts index 7b39ccaf3..8a3b04380 100644 --- a/lib/models/state-machine/update-model.ts +++ b/lib/models/state-machine/update-model.ts @@ -28,6 +28,7 @@ import { Choice, Condition, DefinitionBody, + Fail, StateMachine, Succeed, Wait, @@ -184,8 +185,26 @@ export class UpdateModelStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + const handleFailure = new LambdaInvoke(this, 'HandleFailure', { + lambdaFunction: new Function(this, 'HandleFailureFunc', { + runtime: getPythonRuntime(), + handler: 'models.state_machine.update_model.handle_failure', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + // terminal states const successState = new Succeed(this, 'UpdateSuccess'); + const failState = new Fail(this, 'UpdateFailed'); // choice states const hasEcsUpdateChoice = new Choice(this, 'HasEcsUpdateChoice'); @@ -207,6 +226,10 @@ export class UpdateModelStateMachine extends Construct { // State Machine definition handleJobIntake.next(hasEcsUpdateChoice); + handleJobIntake.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); // ECS update flow hasEcsUpdateChoice @@ -214,7 +237,15 @@ export class UpdateModelStateMachine extends Construct { .otherwise(hasGuardrailsUpdateChoice); handleEcsUpdate.next(handlePollEcsDeployment); + handleEcsUpdate.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); handlePollEcsDeployment.next(pollEcsDeploymentChoice); + handlePollEcsDeployment.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); pollEcsDeploymentChoice .when(Condition.booleanEquals('$.should_continue_ecs_polling', true), waitBeforePollEcsDeployment) .otherwise(hasGuardrailsUpdateChoice); @@ -226,6 +257,10 @@ export class UpdateModelStateMachine extends Construct { .otherwise(hasCapacityUpdateChoice); handleUpdateGuardrails.next(hasCapacityUpdateChoice); + handleUpdateGuardrails.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); // Existing capacity update flow hasCapacityUpdateChoice @@ -233,6 +268,10 @@ export class UpdateModelStateMachine extends Construct { .otherwise(handleFinishUpdate); handlePollCapacity.next(pollAsgChoice); + handlePollCapacity.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); pollAsgChoice.when(Condition.booleanEquals('$.should_continue_capacity_polling', true), waitBeforePollAsg) .otherwise(waitBeforeModelAvailable); waitBeforePollAsg.next(handlePollCapacity); @@ -240,6 +279,11 @@ export class UpdateModelStateMachine extends Construct { waitBeforeModelAvailable.next(handleFinishUpdate); handleFinishUpdate.next(successState); + handleFinishUpdate.addCatch(handleFailure, { + errors: ['States.ALL'], + resultPath: '$.error', + }); + handleFailure.next(failState); const stateMachine = new StateMachine(this, 'UpdateModelSM', { definitionBody: DefinitionBody.fromChainable(handleJobIntake), diff --git a/lib/rag/api/repository.ts b/lib/rag/api/repository.ts index 3664ad757..444edb606 100644 --- a/lib/rag/api/repository.ts +++ b/lib/rag/api/repository.ts @@ -85,7 +85,7 @@ export class RepositoryApi extends Construct { method: 'GET', environment: { ...baseEnvironment, - }, + } }, { name: 'list_status', diff --git a/lib/rag/ingestion/ingestion-job-construct.ts b/lib/rag/ingestion/ingestion-job-construct.ts index 67b06958e..ea5f899e8 100644 --- a/lib/rag/ingestion/ingestion-job-construct.ts +++ b/lib/rag/ingestion/ingestion-job-construct.ts @@ -28,6 +28,8 @@ import * as iam from 'aws-cdk-lib/aws-iam'; import * as batch from 'aws-cdk-lib/aws-batch'; import * as ecs from 'aws-cdk-lib/aws-ecs'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; +import * as events from 'aws-cdk-lib/aws-events'; +import * as targets from 'aws-cdk-lib/aws-events-targets'; import * as lambda from 'aws-cdk-lib/aws-lambda'; import { getPythonRuntime } from '../../api-base/utils'; import { Vpc } from '../../networking/vpc'; @@ -107,12 +109,12 @@ export class IngestionJobConstruct extends Construct { // AWS Batch Fargate compute environment for running ingestion jobs const maxvCpus = this.getMaxCpus(vpc); const computeEnv = new batch.FargateComputeEnvironment(this, 'IngestionJobFargateEnv', { + computeEnvironmentName: `${config.deploymentName}-${config.deploymentStage}-ingestion-job-compute`, vpc: vpc.vpc, vpcSubnets: vpc.subnetSelection, maxvCpus: maxvCpus, }); - // AWS Batch job queue that uses the Fargate compute environment const jobQueue = new batch.JobQueue(this, 'IngestionJobQueue', { computeEnvironments: [ { @@ -296,5 +298,45 @@ export class IngestionJobConstruct extends Construct { principal: new iam.ServicePrincipal('events.amazonaws.com'), action: 'lambda:InvokeFunction' }); + + // EventBridge rule to capture Batch job state changes and publish custom CloudWatch metrics. + // AWS Batch does not publish job-level metrics to CloudWatch natively, so we use + // EventBridge job state change events as the source of truth. This captures all + // ingestion jobs regardless of trigger (S3 event, scheduled, or manual upload). + const batchJobMetricLambda = new lambda.Function(this, 'BatchJobMetricPublisher', { + functionName: `${config.deploymentName}-${config.deploymentStage}-batch-job-metric`, + runtime: lambda.Runtime.PYTHON_3_13, + handler: 'batch_job_metric.handler', + code: lambda.Code.fromAsset(path.join(__dirname, '../../../lambda/metrics')), + environment: { + METRICS_NAMESPACE: 'LISA/BatchIngestion', + DEPLOYMENT_NAME: config.deploymentName, + DEPLOYMENT_STAGE: config.deploymentStage, + JOB_QUEUE_LABEL: `${config.deploymentName}-${config.deploymentStage}-ingestion-job`, + }, + timeout: Duration.seconds(30), + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: [vpc.securityGroups.lambdaSg], + }); + + batchJobMetricLambda.addToRolePolicy(new iam.PolicyStatement({ + actions: ['cloudwatch:PutMetricData'], + resources: ['*'], + })); + + new events.Rule(this, 'BatchJobStateChangeRule', { + ruleName: `${config.deploymentName}-${config.deploymentStage}-batch-job-state-change`, + description: 'Captures AWS Batch job state changes for ingestion pipeline and publishes CloudWatch metrics', + eventPattern: { + source: ['aws.batch'], + detailType: ['Batch Job State Change'], + detail: { + status: ['SUBMITTED', 'RUNNING', 'SUCCEEDED', 'FAILED'], + jobQueue: [jobQueue.jobQueueArn], + }, + }, + targets: [new targets.LambdaFunction(batchJobMetricLambda)], + }); } } diff --git a/lib/rag/ragConstruct.ts b/lib/rag/ragConstruct.ts index 309a0884d..ff107e18b 100644 --- a/lib/rag/ragConstruct.ts +++ b/lib/rag/ragConstruct.ts @@ -17,7 +17,7 @@ import { CfnOutput, Duration, RemovalPolicy, Stack, StackProps } from 'aws-cdk-l import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup, Port } from 'aws-cdk-lib/aws-ec2'; import { ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; -import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketEncryption, HttpMethods, IBucket } from 'aws-cdk-lib/aws-s3'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import { AttributeType, BillingMode, StreamViewType, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; @@ -29,6 +29,7 @@ import { Layer } from '../core/layers'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; import { APP_MANAGEMENT_KEY, BaseProps, Config } from '../schema'; +import { getAuditLoggingEnv } from '../api-base/auditEnv'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; import { SecurityGroupFactory } from '../networking/vpc/security-group-factory'; import { Roles } from '../core/iam/roles'; @@ -48,6 +49,7 @@ import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resour export type LisaRagProps = { authorizer: IAuthorizer; + bucketAccessLogsBucket: IBucket; endpointUrl?: StringParameter; modelsPs?: StringParameter; restApiId: string; @@ -70,7 +72,7 @@ export class LisaRagConstruct extends Construct { constructor (scope: Stack, id: string, props: LisaRagProps) { super(scope, id); this.scope = scope; - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, bucketAccessLogsBucket, config, restApiId, rootResourceId, securityGroups, vpc } = props; const endpointUrl = props.endpointUrl ?? StringParameter.fromStringParameterName( scope, @@ -84,10 +86,6 @@ export class LisaRagConstruct extends Construct { `${config.deploymentPrefix}/registeredModels`, ); - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) - ); - const bucket = new Bucket(scope, createCdkId(['LISA', 'RAG', config.deploymentName, config.deploymentStage]), { removalPolicy: config.removalPolicy, autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, @@ -211,6 +209,7 @@ export class LisaRagConstruct extends Construct { const baseEnvironment: Record = { ADMIN_GROUP: config.authConfig!.adminGroup, + RAG_ADMIN_GROUP: config.authConfig!.ragAdminGroup, BUCKET_NAME: bucket.bucketName, CHUNK_OVERLAP: config.ragFileProcessingConfig!.chunkOverlap.toString(), CHUNK_SIZE: config.ragFileProcessingConfig!.chunkSize.toString(), @@ -226,6 +225,7 @@ export class LisaRagConstruct extends Construct { REGISTERED_REPOSITORIES_PS: `${config.deploymentPrefix}/registeredRepositories`, REST_API_VERSION: 'v2', TIKTOKEN_CACHE_DIR: '/tmp', + ...getAuditLoggingEnv(config), }; // Add REST API SSL Cert ARN if it exists to be used to verify SSL calls to REST API diff --git a/lib/rag/state_machine/pipeline-state-machine.ts b/lib/rag/state_machine/pipeline-state-machine.ts index 4a7310b8f..45fb7dabc 100644 --- a/lib/rag/state_machine/pipeline-state-machine.ts +++ b/lib/rag/state_machine/pipeline-state-machine.ts @@ -280,7 +280,7 @@ export class PipelineStateMachine extends Construct { // Create the state machine this.stateMachine = new sfn.StateMachine(this, 'PipelineStateMachine', { stateMachineName: `${config.deploymentName}-${config.deploymentStage}-pipeline-state-machine`, - definition, + definitionBody: sfn.DefinitionBody.fromChainable(definition), role: stateMachineRole, timeout: Duration.minutes(30), tracingEnabled: true diff --git a/lib/rag/vector-store/state_machine/create-store.ts b/lib/rag/vector-store/state_machine/create-store.ts index 8fc0d9fb0..334f8dbd0 100644 --- a/lib/rag/vector-store/state_machine/create-store.ts +++ b/lib/rag/vector-store/state_machine/create-store.ts @@ -194,7 +194,7 @@ export class CreateStoreStateMachine extends Construct { // Create a new state machine using the definition and roles specified this.stateMachine = new sfn.StateMachine(this, 'CreateStoreStateMachine', { - definition, + definitionBody: sfn.DefinitionBody.fromChainable(definition), role, stateMachineType: sfn.StateMachineType.STANDARD, removalPolicy: config.removalPolicy diff --git a/lib/rag/vector-store/state_machine/delete-store.ts b/lib/rag/vector-store/state_machine/delete-store.ts index 9a21fb3bf..dc5dd3dfe 100644 --- a/lib/rag/vector-store/state_machine/delete-store.ts +++ b/lib/rag/vector-store/state_machine/delete-store.ts @@ -263,7 +263,7 @@ export class DeleteStoreStateMachine extends Construct { // Create a new state machine using the definition and roles specified this.stateMachine = new sfn.StateMachine(this, 'DeleteStoreStateMachine', { - definition, + definitionBody: sfn.DefinitionBody.fromChainable(definition), role, stateMachineType: sfn.StateMachineType.STANDARD, removalPolicy: config.removalPolicy diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index c7de9541f..47f965977 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -717,6 +717,7 @@ const AuthConfigSchema = z.object({ adminGroup: z.string().default('').describe('Name of the admin group.'), userGroup: z.string().default('').describe('Name of the user group.'), apiGroup: z.string().default('').describe('Name of the API group for API token access.'), + ragAdminGroup: z.string().default('').describe('Name of the RAG admin group for RAG management access.'), jwtGroupsProperty: z.string().default('').describe('Name of the JWT groups property.'), additionalScopes: z.array(z.string()).default([]).describe('Additional JWT scopes to request.'), }).describe('Configuration schema for authorization.'); @@ -747,6 +748,25 @@ const FastApiContainerConfigSchema = z.object({ ), }).describe('Configuration schema for REST API.'); +/** Custom domain / TLS for the MCP Workbench ALB only (separate from Serve’s `restApiConfig`). */ +const McpWorkbenchRestApiConfigSchema = z + .object({ + domainName: z + .string() + .nullish() + .default(null) + .describe( + 'Hostname for the MCP Workbench ALB (HTTPS listener and SSM …/mcpWorkbench/endpoint). Configure here for the same YAML shape as `restApiConfig.domainName` for LISA Serve.', + ), + sslCertIamArn: z + .string() + .nullish() + .default(null) + .describe( + 'ACM certificate ARN for the MCP Workbench ALB. Same role as `restApiConfig.sslCertIamArn` for Serve; if omitted, falls back to `mcpWorkbenchEcsConfig.sslCertIamArn` then `restApiConfig.sslCertIamArn`.', + ), + }) + .describe('Optional load balancer domain and TLS for MCP Workbench (parallel to `restApiConfig` for LISA Serve).'); const RagFileProcessingConfigSchema = z.object({ chunkSize: z.number().min(100).max(10000), @@ -776,6 +796,29 @@ const ApiGatewayConfigSchema = z .optional() .describe('Configuration schema for API Gateway Endpoint'); +const AuditLoggingConfigSchema = z + .object({ + enabled: z.boolean().default(false).describe('Whether to enable audit logging for opted-in API Gateway paths.'), + auditAll: z.boolean().default(false).describe('If true, enable audit logging for all API Gateway paths.'), + enabledPaths: z + .array(z.string().min(1)) + .default([]) + .describe('API Gateway path prefixes (e.g. "/session") to include in audit logging. Prefix match is used.'), + maxRequestBodyBytes: z + .number() + .int() + .min(1) + .default(20000) + .describe('Maximum request body bytes to include in audit logs (oversized bodies are replaced with a placeholder).'), + includeJsonBody: z + .boolean() + .default(false) + .describe( + 'When true, emit AUDIT_API_GATEWAY_REQUEST_BODY for opted-in paths. When false (default), request bodies are never included in audit logs.' + ), + }) + .describe('Audit logging configuration for API Gateway request auditing.'); + const LiteLLMConfig = z.object({ db_key: z.string().refine( (key) => key.startsWith('sk-'), // key needed for model management actions @@ -786,6 +829,9 @@ const LiteLLMConfig = z.object({ litellm_settings: z.any().optional(), router_settings: z.any().optional(), environment_variables: z.any().optional(), + // LiteLLM callback-specific settings (e.g., OpenTelemetry message logging toggles). + // This must be allowed here so Zod doesn't strip it at deploy time. + callback_settings: z.any().optional(), }) .describe('Core LiteLLM configuration - see https://litellm.vercel.app/docs/proxy/configs#all-settings for more details about each field.'); @@ -837,6 +883,9 @@ export const RawConfigObject = z.object({ partition: z.string().default('aws').describe('AWS partition for deployment.'), domain: z.string().default('amazonaws.com').describe('AWS domain for deployment'), restApiConfig: FastApiContainerConfigSchema.describe('Image override for Rest API'), + mcpWorkbenchRestApiConfig: McpWorkbenchRestApiConfigSchema.optional().describe( + 'Custom domain and certificate for the MCP Workbench ALB. Same usage as `restApiConfig.domainName` / `sslCertIamArn` for LISA Serve.', + ), mcpWorkbenchConfig: ImageAssetSchema.optional().describe('Image override for MCP Workbench'), mcpWorkbenchBuildConfig: z.object({ S6_OVERLAY_NOARCH_SOURCE: z.string().optional().describe('Override the URL with a path relative to the build directory for the architecture independent S6 overlay tar.xz.'), @@ -876,9 +925,43 @@ export const RawConfigObject = z.object({ useCustomBranding: z.boolean().optional().describe('Whether to use custom branding assets in the UI.'), customDisplayName: z.string().optional().describe('Custom display name to replace "LISA" branding in titles and descriptions. Requires "useCustomBranding" to be enabled.'), deployMetrics: z.boolean().default(true).describe('Whether to deploy Metrics stack.'), + deployHealthDashboard: z.boolean().default(true).describe('Whether to deploy the ECS Model Health CloudWatch dashboard for monitoring model container health, errors, latency, and resource utilization.'), deployMcp: z.boolean().default(true).describe('Whether to deploy LISA MCP stack.'), deployServe: z.boolean().default(true).describe('Whether to deploy LISA Serve stack.'), deployMcpWorkbench: z.boolean().default(true).describe('Whether to deploy MCP Workbench stack.'), + mcpWorkbenchEcsConfig: z + .object({ + instanceType: z.enum(VALID_INSTANCE_KEYS).optional().describe('EC2 instance type for the MCP Workbench ECS cluster.'), + blockDeviceVolumeSize: z.number().min(30).optional().describe('Root volume size (GiB) for cluster instances.'), + minCapacity: z.number().min(1).optional().describe('Minimum ASG capacity for the MCP Workbench cluster.'), + maxCapacity: z.number().min(1).optional().describe('Maximum ASG capacity for the MCP Workbench cluster.'), + cooldown: z.number().min(1).optional().describe('Cooldown (seconds) between scaling activities.'), + domainName: z + .string() + .nullish() + .describe( + 'Optional hostname for the MCP Workbench ALB (same effect as `mcpWorkbenchRestApiConfig.domainName`; use that block for parity with `restApiConfig`). ' + + 'If omitted and restApiConfig.domainName is set, a default is derived (e.g. first label `lisa-serve` β†’ `lisa-mcp-workbench`, or `serve` β†’ `mcp-workbench`) so the workbench does not reuse the Serve API hostname. ' + + 'Otherwise the published endpoint uses this ALB’s DNS name. You must create DNS for the chosen or derived name pointing at the MCP Workbench ALB.', + ), + sslCertIamArn: z + .string() + .nullish() + .describe( + 'Optional ACM certificate ARN for the MCP Workbench ALB HTTPS listener (same effect as `mcpWorkbenchRestApiConfig.sslCertIamArn`). If omitted, inherits restApiConfig.sslCertIamArn when set; ' + + 'otherwise the workbench ALB uses HTTP on port 80 (browser MCP from an https UI will fail). Set explicitly when using a dedicated workbench hostname.', + ), + }) + .optional() + .describe( + 'Optional sizing and load-balancer settings for the MCP Workbench ECS cluster. The workbench HTTP server always runs on its own ECS cluster and ALB (separate from the Serve REST API).', + ), + mcpWorkbenchCorsOrigins: z + .string() + .default('*') + .describe( + 'Comma-separated CORS allowed origins for the MCP Workbench HTTP server container (CORS_ORIGINS). Use * to allow any browser origin (typical when the UI is served from varying hosts or ports). More restrictive deployments can list explicit origins.', + ), logLevel: z.union([z.literal('DEBUG'), z.literal('INFO'), z.literal('WARNING'), z.literal('ERROR')]) .default('DEBUG') .describe('Log level for application.'), @@ -925,6 +1008,7 @@ export const RawConfigObject = z.object({ }) .optional() .describe('Configuration for local Lambda layer code'), + auditLoggingConfig: AuditLoggingConfigSchema.optional(), permissionsBoundaryAspect: z .object({ permissionsBoundaryPolicyName: z.string(), diff --git a/lib/serve/ecs-model/embedding/instructor/Dockerfile b/lib/serve/ecs-model/embedding/instructor/Dockerfile index ff18457c3..98f7332c9 100644 --- a/lib/serve/ecs-model/embedding/instructor/Dockerfile +++ b/lib/serve/ecs-model/embedding/instructor/Dockerfile @@ -50,8 +50,8 @@ RUN /opt/conda/bin/conda install s5cmd && \ ARG LOCAL_CODE_PATH WORKDIR ${LOCAL_CODE_PATH} -COPY src/inference.py src/requirements.txt ${LOCAL_CODE_PATH}/ -COPY src/entrypoint.sh entrypoint.sh +COPY embedding/instructor/src/inference.py embedding/instructor/src/requirements.txt ${LOCAL_CODE_PATH}/ +COPY embedding/instructor/src/entrypoint.sh entrypoint.sh RUN chmod +x entrypoint.sh ENTRYPOINT ["./entrypoint.sh"] diff --git a/lib/serve/ecs-model/embedding/tei/Dockerfile b/lib/serve/ecs-model/embedding/tei/Dockerfile index 295e2f88e..6630b9e65 100644 --- a/lib/serve/ecs-model/embedding/tei/Dockerfile +++ b/lib/serve/ecs-model/embedding/tei/Dockerfile @@ -17,19 +17,24 @@ RUN mkdir -p /etc/ssh && \ echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ fi -##### DOWNLOAD MOUNTPOINTS S3 +##### Download S3 mountpoints and boto3 ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 -RUN apt-get update -y && apt-get upgrade -y && apt-get install -y wget rsync && \ +RUN apt-get update -y && apt-get upgrade -y && \ + apt-get install -y wget rsync python3 python3-pip && \ wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ fi && \ apt-get install -y ./mount-s3.deb && \ + pip3 install --no-cache-dir --break-system-packages boto3 && \ rm mount-s3.deb && \ rm -rf /var/lib/apt/lists/* -COPY src/entrypoint.sh ./entrypoint.sh +# Metrics publisher for CloudWatch (scrapes Prometheus /metrics endpoint) +COPY metrics_publisher.py /opt/metrics_publisher.py + +COPY embedding/tei/src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh ENTRYPOINT ["./entrypoint.sh"] diff --git a/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh b/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh index 0e6e7f9d6..19c94f29a 100644 --- a/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh +++ b/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh @@ -186,4 +186,17 @@ echo "Starting TEI with args: ${ADDITIONAL_ARGS}" echo "TEI environment variables:" env | grep -E "^(MAX_CONCURRENT_REQUESTS|MAX_BATCH_TOKENS|MAX_BATCH_REQUESTS|MAX_CLIENT_BATCH_SIZE|TOKENIZATION_WORKERS|REVISION|DTYPE|POOLING|DEFAULT_PROMPT|DENSE_PATH|SERVED_MODEL_NAME|AUTO_TRUNCATE|PAYLOAD_LIMIT|HF_TOKEN|API_KEY|OTLP_ENDPOINT|PROMETHEUS_PORT|CORS_ALLOW_ORIGIN)=" || echo "No TEI environment variables set" +# Start metrics publisher in background (publishes Prometheus metrics to CloudWatch) +# TEI serves Prometheus metrics on the main HTTP server at /metrics (port 8080). +# The --prometheus-port flag controls a separate dedicated endpoint (default 9000) +# which may not be available in all TEI builds, so always scrape from the main port. +if [ -f /opt/metrics_publisher.py ]; then + export METRICS_ENDPOINT="http://localhost:8080/metrics" + export INFERENCE_ENGINE="tei" + echo "Starting metrics publisher daemon (endpoint: ${METRICS_ENDPOINT})..." + python3 /opt/metrics_publisher.py & + METRICS_PID=$! + echo "Metrics publisher started (PID: ${METRICS_PID})" +fi + text-embeddings-router --model-id $LOCAL_MODEL_PATH --port 8080 --json-output ${ADDITIONAL_ARGS} diff --git a/lib/serve/ecs-model/metrics_publisher.py b/lib/serve/ecs-model/metrics_publisher.py new file mode 100644 index 000000000..12bfbaedb --- /dev/null +++ b/lib/serve/ecs-model/metrics_publisher.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LISA Inference Metrics Publisher + +Background daemon that scrapes Prometheus metrics from inference engine +endpoints (vLLM, TGI, TEI) and publishes them to CloudWatch. + +Environment variables: + METRICS_PUBLISH_INTERVAL - Seconds between scrape/publish cycles (default: 60) + METRICS_ENDPOINT - Prometheus metrics URL (default: http://localhost:8080/metrics) + INFERENCE_ENGINE - Explicit engine type override: vllm, tgi, or tei (default: auto-detect) + CLUSTER_NAME - ECS cluster name (CloudWatch dimension) + SERVICE_NAME - ECS service name (CloudWatch dimension) + MODEL_NAME - Model identifier (CloudWatch dimension) + AWS_REGION - AWS region for CloudWatch API calls + METRICS_NAMESPACE - CloudWatch namespace (default: LISA/InferenceMetrics) +""" + +import json +import logging +import os +import re +import sys +import time +from urllib.error import URLError +from urllib.request import urlopen + +import boto3 +from botocore.config import Config as BotoConfig + +log_level = logging.DEBUG if os.environ.get("DEBUG", "").lower() in ("true", "1", "yes") else logging.INFO +logging.basicConfig( + level=log_level, + format="[metrics_publisher] %(asctime)s %(levelname)s %(message)s", + stream=sys.stdout, +) +log = logging.getLogger("metrics_publisher") + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +PUBLISH_INTERVAL = int(os.environ.get("METRICS_PUBLISH_INTERVAL", "60")) +METRICS_ENDPOINT = os.environ.get("METRICS_ENDPOINT", "http://localhost:8080/metrics") +CLUSTER_NAME = os.environ.get("CLUSTER_NAME", "") +SERVICE_NAME = os.environ.get("SERVICE_NAME", "") +MODEL_NAME = os.environ.get("MODEL_NAME", "") +NAMESPACE = os.environ.get("METRICS_NAMESPACE", "LISA/InferenceMetrics") + +# Engine type is required β€” set by each container's entrypoint.sh +INFERENCE_ENGINE = os.environ.get("INFERENCE_ENGINE", "").lower().strip() +if INFERENCE_ENGINE not in ("vllm", "tgi", "tei"): + log.error( + "INFERENCE_ENGINE environment variable must be set to one of: vllm, tgi, tei (got: %r)", + INFERENCE_ENGINE or "", + ) + sys.exit(1) + +# Metrics we care about, keyed by engine type. +# Each entry maps a Prometheus metric name to a CloudWatch metric name. +VLLM_METRICS = { + "vllm:gpu_cache_usage_perc": "GpuCacheUsagePercent", + "vllm:kv_cache_usage_perc": "KvCacheUsagePercent", # distinct metric for KV cache usage + "vllm:num_requests_running": "RequestsRunning", + "vllm:num_requests_waiting": "RequestsWaiting", + "vllm:num_requests_swapped": "RequestsSwapped", + "vllm:avg_prompt_throughput_toks_per_s": "AvgPromptThroughputToksPerSec", + "vllm:avg_generation_throughput_toks_per_s": "AvgGenerationThroughputToksPerSec", + "vllm:prompt_tokens_total": "PromptTokensTotal", + "vllm:generation_tokens_total": "GenerationTokensTotal", + "vllm:request_success_total": "RequestSuccessTotal", + "vllm:prefix_cache_queries": "PrefixCacheQueries", + "vllm:prefix_cache_hits": "PrefixCacheHits", +} + +# Histogram metrics β€” we extract the _sum and _count to compute averages +VLLM_HISTOGRAM_METRICS = { + "vllm:e2e_request_latency_seconds": "E2ERequestLatencySeconds", + "vllm:time_to_first_token_seconds": "TimeToFirstTokenSeconds", + "vllm:inter_token_latency_seconds": "InterTokenLatencySeconds", + "vllm:request_queue_time_seconds": "RequestQueueTimeSeconds", + "vllm:request_prefill_time_seconds": "RequestPrefillTimeSeconds", + "vllm:request_decode_time_seconds": "RequestDecodeTimeSeconds", +} + +TGI_METRICS = { + "tgi_queue_size": "QueueSize", + "tgi_batch_current_size": "BatchCurrentSize", + "tgi_batch_current_max_tokens": "BatchCurrentMaxTokens", + "tgi_request_count": "RequestCount", + "tgi_request_success": "RequestSuccess", + "tgi_request_failure": "RequestFailure", +} + +TGI_HISTOGRAM_METRICS = { + "tgi_request_duration": "RequestDurationSeconds", + "tgi_request_queue_duration": "QueueDurationSeconds", + "tgi_request_inference_duration": "InferenceDurationSeconds", + "tgi_request_mean_time_per_token_duration": "MeanTimePerTokenSeconds", + "tgi_request_generated_tokens": "GeneratedTokensPerRequest", + "tgi_request_input_length": "InputLengthPerRequest", + "tgi_batch_inference_duration": "BatchInferenceDurationSeconds", +} + +TEI_METRICS = { + "te_queue_size": "QueueSize", + "te_batch_current_size": "BatchCurrentSize", +} + +TEI_HISTOGRAM_METRICS = { + "te_request_duration": "RequestDurationSeconds", + "te_request_tokenization_duration": "TokenizationDurationSeconds", + "te_request_queue_duration": "QueueDurationSeconds", + "te_request_inference_duration": "InferenceDurationSeconds", +} + +# --------------------------------------------------------------------------- +# Prometheus text format parser (minimal, no external deps) +# --------------------------------------------------------------------------- +PROM_LINE_RE = re.compile(r"^(?P[a-zA-Z_:][a-zA-Z0-9_:]*)" r"(?:\{[^}]*\})?\s+" r"(?P[^\s]+)") + + +def parse_prometheus(text: str) -> dict[str, float]: + """Parse Prometheus exposition format into {metric_name: value}.""" + metrics: dict[str, float] = {} + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + m = PROM_LINE_RE.match(line) + if m: + try: + val = float(m.group("value")) + name = m.group("name") + # Accumulate (some metrics appear multiple times with different labels) + # For gauges we want the latest; for counters we sum across labels. + # Since we pick specific metrics, simple last-write-wins is fine for gauges, + # and for _total counters we sum. + if name.endswith("_total") or name.endswith("_count") or name.endswith("_sum"): + metrics[name] = metrics.get(name, 0.0) + val + else: + metrics[name] = val + except ValueError: + continue + return metrics + + +def build_metric_data( + metrics: dict[str, float], + engine: str, + dimensions: list[dict], +) -> list[dict]: + """Build CloudWatch MetricData entries from scraped Prometheus metrics.""" + data: list[dict] = [] + + if engine == "vllm": + gauge_map = VLLM_METRICS + hist_map = VLLM_HISTOGRAM_METRICS + elif engine == "tgi": + gauge_map = TGI_METRICS + hist_map = TGI_HISTOGRAM_METRICS + elif engine == "tei": + gauge_map = TEI_METRICS + hist_map = TEI_HISTOGRAM_METRICS + else: + return data + + # Gauge / counter metrics + for prom_name, cw_name in gauge_map.items(): + val = metrics.get(prom_name) + if val is not None: + data.append( + { + "MetricName": cw_name, + "Dimensions": dimensions, + "Value": val, + "Unit": "None", + } + ) + + # Histogram metrics β€” publish average from _sum/_count + for prom_name, cw_name in hist_map.items(): + total = metrics.get(f"{prom_name}_sum") + count = metrics.get(f"{prom_name}_count") + if total is not None and count is not None and count > 0: + # Determine unit: token/length metrics are counts, everything else is seconds + unit = "None" if cw_name.endswith("PerRequest") else "Seconds" + data.append( + { + "MetricName": cw_name, + "Dimensions": dimensions, + "Value": total / count, + "Unit": unit, + } + ) + + # Always publish engine type as a tag via a simple metric + data.append( + { + "MetricName": "MetricsPublisherHeartbeat", + "Dimensions": dimensions, + "Value": 1.0, + "Unit": "None", + } + ) + + return data + + +def publish_loop() -> None: + """Main loop: scrape β†’ parse β†’ publish, repeat.""" + dimensions = [] + if CLUSTER_NAME: + dimensions.append({"Name": "ClusterName", "Value": CLUSTER_NAME}) + if SERVICE_NAME: + dimensions.append({"Name": "ServiceName", "Value": SERVICE_NAME}) + if MODEL_NAME: + dimensions.append({"Name": "ModelName", "Value": MODEL_NAME}) + + if not dimensions: + log.warning("No dimensions configured (CLUSTER_NAME, SERVICE_NAME, MODEL_NAME). Metrics will be dimensionless.") + + boto_config = BotoConfig(retries={"max_attempts": 2, "mode": "standard"}) + region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + cw = boto3.client("cloudwatch", config=boto_config, region_name=region) + + consecutive_failures = 0 + max_failures_before_backoff = 5 + + log.info( + "Starting metrics publisher: endpoint=%s interval=%ds namespace=%s engine=%s dimensions=%s", + METRICS_ENDPOINT, + PUBLISH_INTERVAL, + NAMESPACE, + INFERENCE_ENGINE, + json.dumps(dimensions), + ) + + # Wait for the inference server to start + log.info("Waiting for inference server at %s ...", METRICS_ENDPOINT) + while True: + try: + urlopen(METRICS_ENDPOINT, timeout=5) # nosec B310 + log.info("Inference server is up.") + break + except (URLError, OSError): + time.sleep(10) + + while True: + try: + resp = urlopen(METRICS_ENDPOINT, timeout=10) # nosec B310 + text = resp.read().decode("utf-8", errors="replace") + metrics = parse_prometheus(text) + + metric_data = build_metric_data(metrics, INFERENCE_ENGINE, dimensions) + + if metric_data: + # CloudWatch accepts max 20 metrics per call; batch in chunks of 20 + for i in range(0, len(metric_data), 20): + cw.put_metric_data(Namespace=NAMESPACE, MetricData=metric_data[i : i + 20]) + log.debug("Published %d metrics to %s", len(metric_data), NAMESPACE) + + consecutive_failures = 0 + + except (URLError, OSError) as e: + consecutive_failures += 1 + log.warning("Failed to scrape metrics (attempt %d): %s", consecutive_failures, e) + except Exception as e: + consecutive_failures += 1 + log.error("Error in publish cycle (attempt %d): %s", consecutive_failures, e, exc_info=True) + + # Back off if we keep failing + sleep_time = PUBLISH_INTERVAL + if consecutive_failures > max_failures_before_backoff: + sleep_time = min(PUBLISH_INTERVAL * 4, 300) + + time.sleep(sleep_time) + + +if __name__ == "__main__": + try: + publish_loop() + except KeyboardInterrupt: + log.info("Shutting down metrics publisher.") + except Exception as e: + # Never crash the container β€” just log and exit quietly + log.error("Fatal error in metrics publisher: %s", e, exc_info=True) diff --git a/lib/serve/ecs-model/textgen/tgi/Dockerfile b/lib/serve/ecs-model/textgen/tgi/Dockerfile index 4370882b4..27f991d2d 100644 --- a/lib/serve/ecs-model/textgen/tgi/Dockerfile +++ b/lib/serve/ecs-model/textgen/tgi/Dockerfile @@ -17,18 +17,23 @@ RUN mkdir -p /etc/ssh && \ echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ fi -##### DOWNLOAD MOUNTPOINTS S3 +##### Download S3 mountpoints and boto3 ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 -RUN apt-get update -y && apt-get upgrade -y && apt-get install -y wget rsync && \ +RUN apt-get update -y && apt-get upgrade -y && \ + apt-get install -y wget rsync python3 python3-pip && \ wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ fi && \ apt-get install -y ./mount-s3.deb && \ + pip3 install --no-cache-dir --break-system-packages boto3 && \ rm mount-s3.deb && rm -rf /var/lib/apt/lists/* -COPY src/entrypoint.sh ./entrypoint.sh +# Metrics publisher for CloudWatch (scrapes Prometheus /metrics endpoint) +COPY metrics_publisher.py /opt/metrics_publisher.py + +COPY textgen/tgi/src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh ENTRYPOINT ["./entrypoint.sh"] diff --git a/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh b/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh index a0eeb3ded..2878ee660 100644 --- a/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh +++ b/lib/serve/ecs-model/textgen/tgi/src/entrypoint.sh @@ -78,4 +78,16 @@ echo "Starting TGI with args: ${startArgs[*]}" echo "TGI environment variables:" env | grep -E "^(MAX_CONCURRENT_REQUESTS|MAX_INPUT_LENGTH|MAX_TOTAL_TOKENS|MAX_BATCH_PREFILL_TOKENS|MAX_BATCH_TOTAL_TOKENS|WAITING_SERVED_RATIO|QUANTIZE|DTYPE|TRUST_REMOTE_CODE|REVISION|NUM_SHARD|CUDA_VISIBLE_DEVICES|CUDA_MEMORY_FRACTION|ATTENTION|SPECULATE|ROPE_SCALING|ROPE_FACTOR|JSON_OUTPUT|LOG_LEVEL|OTLP_ENDPOINT|TOKENIZER_CONFIG_PATH|DISABLE_CUSTOM_KERNELS)=" || echo "No TGI environment variables set" +# Start metrics publisher in background (publishes Prometheus metrics to CloudWatch) +# TGI serves Prometheus metrics on the main HTTP server at /metrics (port 8080). +if [ -f /opt/metrics_publisher.py ]; then + PROM_PORT="${PROMETHEUS_PORT:-8080}" + export METRICS_ENDPOINT="http://localhost:${PROM_PORT}/metrics" + export INFERENCE_ENGINE="tgi" + echo "Starting metrics publisher daemon (endpoint: ${METRICS_ENDPOINT})..." + python3 /opt/metrics_publisher.py & + METRICS_PID=$! + echo "Metrics publisher started (PID: ${METRICS_PID})" +fi + text-generation-launcher "${startArgs[@]}" diff --git a/lib/serve/ecs-model/vllm/Dockerfile b/lib/serve/ecs-model/vllm/Dockerfile index bcb9f43d9..e4605c11d 100644 --- a/lib/serve/ecs-model/vllm/Dockerfile +++ b/lib/serve/ecs-model/vllm/Dockerfile @@ -51,7 +51,10 @@ import tiktoken; \ [tiktoken.get_encoding(enc) for enc in tiktoken.list_encoding_names()]" && \ chmod -R 755 ${TIKTOKEN_CACHE_DIR} -COPY src/entrypoint.sh ./entrypoint.sh +# Metrics publisher for CloudWatch (scrapes Prometheus /metrics endpoint) +COPY metrics_publisher.py /opt/metrics_publisher.py + +COPY vllm/src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh ENTRYPOINT ["./entrypoint.sh"] diff --git a/lib/serve/ecs-model/vllm/src/entrypoint.sh b/lib/serve/ecs-model/vllm/src/entrypoint.sh index 4203917ac..8ecfd0a25 100644 --- a/lib/serve/ecs-model/vllm/src/entrypoint.sh +++ b/lib/serve/ecs-model/vllm/src/entrypoint.sh @@ -284,6 +284,14 @@ echo "=== VLLM Environment Variables ===" env | grep -E "^VLLM_" || echo "No VLLM_ environment variables set" echo "===================================" +# Start metrics publisher in background (publishes Prometheus metrics to CloudWatch) +if [ -f /opt/metrics_publisher.py ]; then + export INFERENCE_ENGINE="vllm" + echo "Starting metrics publisher daemon..." + python3 /opt/metrics_publisher.py & + METRICS_PID=$! + echo "Metrics publisher started (PID: ${METRICS_PID})" +fi python3 -m vllm.entrypoints.openai.api_server \ --model ${LOCAL_MODEL_PATH} \ diff --git a/lib/serve/mcp-workbench/README.md b/lib/serve/mcp-workbench/README.md index 6b73d18b4..a381bf54d 100644 --- a/lib/serve/mcp-workbench/README.md +++ b/lib/serve/mcp-workbench/README.md @@ -149,7 +149,7 @@ cors_origins: ["*"] # Advanced CORS settings (optional - will use defaults if not specified) cors_settings: - allow_methods: ["GET", "POST", "OPTIONS"] + allow_methods: ["*"] allow_headers: ["*"] allow_credentials: false expose_headers: [] @@ -260,6 +260,16 @@ This project is designed to work with the existing LISA MCP infrastructure: 4. MCP Workbench reads tools from the mounted location 5. External processes can trigger rescans via HTTP GET requests +### AWS Session Management + +MCP Workbench supports **AWS Sessions**, allowing users to connect their AWS credentials per chat session. When enabled by an Administrator, users can connect credentials in the chat UI; those credentials are validated, converted to short-lived session credentials, and stored in memory per (user, session). MCP tools can retrieve them via `get_caller_identity()` and `get_aws_session_for_user()` to perform AWS operations on behalf of the user. + +- **REST API**: `POST /api/aws/connect`, `GET /api/aws/status`, `DELETE /api/aws/connect` +- **Identity**: Extracted from `Authorization` (JWT) and `X-Session-Id` headers +- **Tool integration**: See `src/examples/sample_tools/aws_operator_tools.py` for a generic boto3-based `aws_api_call` tool using connected credentials + +The feature requires the **AWS Sessions** toggle to be enabled in Administration β†’ Configuration β†’ MCP. Without MCP tools that leverage the credentials, connecting them has no effect. + ## Development ### Project Structure diff --git a/lib/serve/mcp-workbench/pyproject.toml b/lib/serve/mcp-workbench/pyproject.toml index 435870623..5a311d16a 100644 --- a/lib/serve/mcp-workbench/pyproject.toml +++ b/lib/serve/mcp-workbench/pyproject.toml @@ -9,7 +9,8 @@ description = "A dynamic host for python files used as MCP tools" requires-python = ">=3.13" authors = [{name = "Dustin Sweigart", email = "dustinps@amazon.com"}] dependencies = [ - "fastmcp>=2.0.0", + "fastmcp>=2.10.0,<3.0.0", + "mcp>=1.26.0,<2.0.0", "pydantic>=2.0.0", "pyyaml>=6.0.2", "click==8.3.1", diff --git a/lib/serve/mcp-workbench/s6-overlay/services.d/mcpworkbench/run b/lib/serve/mcp-workbench/s6-overlay/services.d/mcpworkbench/run index 308f62260..5971a3e93 100755 --- a/lib/serve/mcp-workbench/s6-overlay/services.d/mcpworkbench/run +++ b/lib/serve/mcp-workbench/s6-overlay/services.d/mcpworkbench/run @@ -9,30 +9,28 @@ EXIT_ROUTE="${EXIT_ROUTE:-/exit}" CORS_ORIGINS="${CORS_ORIGINS:-*}" LOG_LEVEL="${LOG_LEVEL:-info}" -# Build command arguments -ARGS="--tools-dir ${TOOLS_DIR} --host ${HOST} --port ${PORT}" +# Build command arguments (array preserves spaced values as single argv elements) +ARGS=(--tools-dir "$TOOLS_DIR" --host "$HOST" --port "$PORT") # Add optional routes if set if [ -n "${RESCAN_ROUTE}" ]; then - ARGS="${ARGS} --rescan-route ${RESCAN_ROUTE}" + ARGS+=(--rescan-route "$RESCAN_ROUTE") fi if [ -n "${EXIT_ROUTE}" ]; then - ARGS="${ARGS} --exit-route ${EXIT_ROUTE}" + ARGS+=(--exit-route "$EXIT_ROUTE") fi -# Add CORS origins -if [ -n "${EXIT_ROUTE}" ]; then - ARGS="${ARGS} --cors-origins \"${CORS_ORIGINS}\"" -fi +# CORS: allow browser calls from the UI origin (varies by deployment); default * in shell and config +ARGS+=(--cors-origins "$CORS_ORIGINS") # Add verbosity based on log level case "${LOG_LEVEL}" in debug) - ARGS="${ARGS} --debug" + ARGS+=(--debug) ;; verbose) - ARGS="${ARGS} --verbose" + ARGS+=(--verbose) ;; esac @@ -41,7 +39,7 @@ echo "[mcpworkbench] Starting MCP Workbench server..." echo "[mcpworkbench] Tools directory: ${TOOLS_DIR}" echo "[mcpworkbench] Server: ${HOST}:${PORT}" echo "[mcpworkbench] MCP route: ${MCP_ROUTE}" -echo "[mcpworkbench] Arguments: ${ARGS}" +echo "[mcpworkbench] Arguments: ${ARGS[*]}" # Create tools directory if it doesn't exist mkdir -p "${TOOLS_DIR}" @@ -49,4 +47,4 @@ mkdir -p "${TOOLS_DIR}" s6-svwait -U /run/service/s3mount # Start the MCP workbench server -exec s6-setuidgid root mcpworkbench ${ARGS} +exec s6-setuidgid root mcpworkbench "${ARGS[@]}" diff --git a/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py b/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py new file mode 100644 index 000000000..4158d4753 --- /dev/null +++ b/lib/serve/mcp-workbench/src/examples/sample_tools/aws_operator_tools.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic AWS API access via boto3 using the MCP workbench AWS session. + +This sample exposes one tool that can call any boto3 client method (service + +operation + parameters). That matches IAM permissions of the connected +credentials. For production, consider restricting allowed services or operations. +""" + +from __future__ import annotations + +import re +from collections.abc import Mapping +from datetime import date, datetime +from decimal import Decimal +from typing import Any + +import boto3 +from botocore.response import StreamingBody +from mcpworkbench.aws import shared_session_service as _session_service +from mcpworkbench.aws.identity import CallerIdentityError, get_caller_identity +from mcpworkbench.aws.session_models import AwsSessionRecord +from mcpworkbench.aws.session_service import AwsSessionMissingError +from mcpworkbench.core.annotations import mcp_tool + +_SERVICE_RE = re.compile(r"^[a-z][a-z0-9-]*$") +_OPERATION_RE = re.compile(r"^[a-z][a-z0-9_]*$") +_STREAMING_BODY_READ_LIMIT = 65_536 + + +def _session_record() -> AwsSessionRecord: + try: + identity = get_caller_identity() + except CallerIdentityError as exc: + raise RuntimeError( + "Could not determine caller identity from the request. " + "Ensure the MCP connection sends Authorization and X-Session-Id headers." + ) from exc + + try: + return _session_service.get_aws_session_for_user(identity.user_id, identity.session_id) + except AwsSessionMissingError as exc: + raise RuntimeError("AWS session not connected or expired.") from exc + + +def _build_client(record: AwsSessionRecord, service_name: str, region_name: str | None) -> Any: + return boto3.client( + service_name, + aws_access_key_id=record.aws_access_key_id, + aws_secret_access_key=record.aws_secret_access_key, + aws_session_token=record.aws_session_token, + region_name=region_name or record.aws_region, + ) + + +def _to_serializable(obj: Any) -> Any: + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, Decimal): + return float(obj) + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, dict): + return {k: _to_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_to_serializable(v) for v in obj] + if isinstance(obj, bytes): + return obj.decode("utf-8", errors="replace") + if isinstance(obj, StreamingBody): + try: + chunk = obj.read(_STREAMING_BODY_READ_LIMIT) + truncated = len(chunk) >= _STREAMING_BODY_READ_LIMIT + try: + text = chunk.decode("utf-8") + except UnicodeDecodeError: + text = chunk.hex() + truncated = True + return { + "_streaming_body": True, + "content_preview": text, + "truncated": truncated, + "note": "S3 and similar APIs return a stream; only a prefix is returned here.", + } + finally: + try: + obj.close() + except Exception: + # Best-effort cleanup; ignore close errors to avoid changing caller behavior. + pass # nosec B110 + return str(obj) + + +@mcp_tool( + name="aws_api_call", + description=( + "Call any AWS API exposed as a boto3 client method using the connected AWS session. " + "Arguments: service (e.g. s3, ec2, dynamodb), operation (snake_case method name such as " + "list_buckets or describe_instances), optional parameters object for boto3 keyword " + "arguments, optional region to override the session default. " + "Respects the caller's IAM permissions; destructive or broad calls are possibleβ€”use " + "with care. Paginator workflows use multiple calls or the AWS CLI from your environment." + ), +) +def aws_api_call( + service: str, + operation: str, + parameters: dict[str, Any] | None = None, + region: str | None = None, +) -> dict[str, Any]: + if not _SERVICE_RE.match(service): + raise ValueError(f"Invalid service name {service!r}; expected a boto3 service id (letters, digits, hyphen).") + if not _OPERATION_RE.match(operation): + raise ValueError(f"Invalid operation {operation!r}; expected a snake_case boto3 client method name.") + + record = _session_record() + client = _build_client(record, service, region) + method = getattr(client, operation, None) + if method is None or not callable(method): + raise ValueError( + f"No such client method {operation!r} on service {service!r}. " + "Use boto3's snake_case names (see AWS service API docs / boto3 reference)." + ) + + if parameters is None: + params = {} + elif not isinstance(parameters, Mapping): + raise ValueError( + f"parameters must be a JSON object (mapping of string keys to boto3 keyword arguments), " + f"not {type(parameters).__name__}." + ) + else: + params = dict(parameters) + + try: + response = method(**params) + except TypeError as exc: + raise ValueError( + f"Bad parameters for {service}.{operation}: {exc}. Check required arguments in the AWS API / boto3 docs." + ) from exc + + return {"response": _to_serializable(response)} diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py new file mode 100644 index 000000000..98eb262f8 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/__init__.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AWS session management package for MCP Workbench. + +This package contains helper types and utilities for managing short-lived +AWS session credentials on a per-(user, session) basis. +""" + +from .identity import CallerIdentity as CallerIdentity +from .identity import CallerIdentityError as CallerIdentityError +from .identity import get_caller_identity as get_caller_identity +from .session_service import AwsSessionService +from .session_store import InMemoryAwsSessionStore +from .sts_client import AwsStsClient + +# Shared singletons β€” both the HTTP routes and MCP tools must use the same +# instances so credentials connected via /api/aws/connect are visible to tools. +shared_session_store = InMemoryAwsSessionStore(safety_margin_seconds=60) +shared_session_service = AwsSessionService(store=shared_session_store) +shared_sts_client = AwsStsClient() diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py new file mode 100644 index 000000000..81a6ce904 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/aws_routes.py @@ -0,0 +1,159 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from datetime import timezone +from typing import Any + +from fastapi import APIRouter, HTTPException, Request, Response, status + +from . import shared_session_store as _session_store +from . import shared_sts_client as _sts_client +from .identity import decode_jwt_payload +from .session_models import AwsSessionRecord +from .sts_client import InvalidAwsCredentialsError + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _get_identity_from_request(request: Request) -> tuple[str, str]: + """ + Extract (user_id, session_id) from the authenticated request. + + user_id is derived from the JWT ``sub`` claim in the Authorization + header (already verified by OIDCHTTPBearer middleware). + session_id comes from the ``X-Session-Id`` header sent by the frontend. + """ + # request.headers is case-insensitive; avoid converting to a plain dict + hdrs = request.headers + + # --- user_id: prefer explicit header, fall back to JWT sub claim --- + user_id: str | None = hdrs.get("x-user-id") + if not user_id: + auth_header = hdrs.get("authorization", "") + token = auth_header.removeprefix("Bearer").strip() if auth_header else "" + if token: + claims = decode_jwt_payload(token) + user_id = claims.get("sub") + logger.debug("Extracted user_id=%s from JWT sub claim", user_id) + + # --- session_id from header --- + session_id = hdrs.get("x-session-id") + + if not user_id or not session_id: + missing = [] + if not user_id: + missing.append("user_id (no JWT sub claim or X-User-Id header)") + if not session_id: + missing.append("session_id (no X-Session-Id header)") + detail = f"Missing: {'; '.join(missing)}" + logger.warning("Identity extraction failed: %s", detail) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) + + return user_id, session_id + + +@router.post("/connect", status_code=status.HTTP_200_OK) +async def connect_aws(request: Request) -> dict[str, Any]: + """ + Accept AWS static credentials, validate them, and create a short-lived STS session. + + Request body: + - accessKeyId: str + - secretAccessKey: str + - sessionToken?: str + - region: str + """ + user_id, session_id = _get_identity_from_request(request) + + try: + body = await request.json() + except Exception: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Request body must be valid JSON.", + ) + + access_key_id = body.get("accessKeyId") + secret_access_key = body.get("secretAccessKey") + session_token = body.get("sessionToken") + region = body.get("region") + + if not access_key_id or not secret_access_key or not region: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="accessKeyId, secretAccessKey, and region are required.", + ) + + try: + account_id, arn = _sts_client.validate_static_credentials( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + region=region, + ) + except InvalidAwsCredentialsError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"code": "InvalidCredentials", "message": str(exc)}, + ) from exc + + # For permanent (IAM user) credentials, duration_seconds controls the + # GetSessionToken TTL. For temporary credentials the param is ignored + # and the session record uses the STS maximum (12 h) since we cannot + # determine the real expiration of caller-provided temp creds. + record: AwsSessionRecord = _sts_client.create_session_credentials( + user_id=user_id, + session_id=session_id, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + region=region, + duration_seconds=3600, + ) + + _session_store.set_session(record) + + return { + "accountId": account_id, + "arn": arn, + "expiresAt": record.expires_at.astimezone(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + +@router.get("/status", status_code=status.HTTP_200_OK) +async def aws_status(request: Request) -> dict[str, Any]: + """Return current AWS connection status for the user/session.""" + user_id, session_id = _get_identity_from_request(request) + record = _session_store.get_session(user_id, session_id) + + if not record: + return {"connected": False} + + return { + "connected": True, + "expiresAt": record.expires_at.astimezone(timezone.utc).isoformat().replace("+00:00", "Z"), + } + + +@router.delete("/connect", status_code=status.HTTP_204_NO_CONTENT) +async def disconnect_aws(request: Request) -> Response: + """Explicitly clear AWS session credentials for the user/session.""" + user_id, session_id = _get_identity_from_request(request) + _session_store.delete_session(user_id, session_id) + return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py new file mode 100644 index 000000000..e8ba5b81d --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/identity.py @@ -0,0 +1,192 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for extracting caller identity inside MCP tool functions. + +Tool functions call :func:`get_caller_identity` to obtain the current +:class:`CallerIdentity`. On first access within a request, the function +reads HTTP headers from the underlying MCP request context and caches the +result in a ``ContextVar``. + +No FastMCP middleware is required β€” identity is resolved lazily on demand. +""" + +from __future__ import annotations + +import base64 +import contextvars +import json +import logging +from dataclasses import dataclass +from typing import Any, cast + +logger = logging.getLogger(__name__) + +_current_identity: contextvars.ContextVar[CallerIdentity | None] = contextvars.ContextVar( + "current_caller_identity", default=None +) + + +@dataclass(frozen=True) +class CallerIdentity: + user_id: str + session_id: str + + +class CallerIdentityError(Exception): + """Raised when caller identity cannot be determined from the HTTP request.""" + + +def decode_jwt_payload(token: str) -> dict: + """Extract claims from a JWT payload via base64 decode (no signature check). + + The OIDCHTTPBearer middleware already verified the signature, so this + is purely for reading claims. + """ + parts = token.split(".") + if len(parts) < 2: + return {} + payload = parts[1] + payload += "=" * ((4 - len(payload) % 4) % 4) + try: + return cast(dict[str, Any], json.loads(base64.urlsafe_b64decode(payload))) + except Exception: + return {} + + +def _extract_identity_from_headers(headers: dict[str, str]) -> CallerIdentity | None: + """Try to build a :class:`CallerIdentity` from raw HTTP headers. + + Returns ``None`` when either ``user_id`` or ``session_id`` cannot be + determined. + """ + user_id: str | None = headers.get("x-user-id") + if not user_id: + auth_header = headers.get("authorization", "") + token = auth_header.removeprefix("Bearer").strip() if auth_header else "" + if token: + claims = decode_jwt_payload(token) + user_id = claims.get("sub") + logger.debug("Extracted user_id=%s from JWT sub claim", user_id) + + session_id = headers.get("x-session-id") + + if user_id and session_id: + return CallerIdentity(user_id=user_id, session_id=session_id) + return None + + +def _get_headers_from_request_ctx() -> dict[str, str]: + """Read HTTP headers directly from the MCP low-level request context. + + Falls back to FastMCP's ``get_http_headers()`` if the direct approach + fails. Returns an empty dict if neither method succeeds. + """ + # Approach 1: read directly from the MCP request_ctx ContextVar + try: + from mcp.server.lowlevel.server import request_ctx # noqa: PLC0415 + + ctx = request_ctx.get() + request = ctx.request + if request is not None: + headers = cast( + dict[str, str], + {name.lower(): value for name, value in request.headers.items()}, + ) + logger.debug( + "identity: read %d headers from request_ctx (keys: %s)", + len(headers), + sorted(headers.keys()), + ) + return headers + logger.warning("identity: request_ctx.request is None") + except LookupError: + logger.warning("identity: request_ctx ContextVar not set") + except Exception: + logger.warning("identity: failed reading request_ctx", exc_info=True) + + # Approach 2: use FastMCP's helper (catches RuntimeError internally) + try: + from fastmcp.server.dependencies import get_http_headers # noqa: PLC0415 + + headers = cast(dict[str, str], get_http_headers(include_all=True)) + logger.debug( + "identity: fastmcp get_http_headers returned %d headers (keys: %s)", + len(headers), + sorted(headers.keys()), + ) + return headers + except Exception: + logger.warning("identity: fastmcp get_http_headers failed", exc_info=True) + + return {} + + +def _populate_identity_from_http() -> CallerIdentity | None: + """Read HTTP headers from the current MCP request and set the ContextVar. + + Must be called inside an MCP tool-call context. + + Returns the identity if successfully extracted, ``None`` otherwise. + """ + headers = _get_headers_from_request_ctx() + if not headers: + logger.warning("identity: no headers available β€” cannot extract identity") + return None + + identity = _extract_identity_from_headers(headers) + if identity: + _current_identity.set(identity) + logger.debug( + "identity: resolved user_id=%s session_id=%s", + identity.user_id, + identity.session_id, + ) + else: + has_auth = "authorization" in headers + has_session = "x-session-id" in headers + logger.warning( + "identity: extraction failed β€” authorization present=%s, " "x-session-id present=%s, header keys=%s", + has_auth, + has_session, + sorted(headers.keys()), + ) + return identity + + +def get_caller_identity() -> CallerIdentity: + """Return the caller identity for the current MCP tool invocation. + + On first call within a request, lazily reads HTTP headers from the + MCP request context and caches the result. Subsequent calls in the + same context return the cached value. + + Raises :class:`CallerIdentityError` when identity cannot be determined + (required headers absent or not in an MCP request context). + """ + identity = _current_identity.get() + if identity is not None: + return identity + + try: + identity = _populate_identity_from_http() + except Exception as exc: + raise CallerIdentityError("Could not read HTTP headers β€” not in an MCP request context.") from exc + + if identity is None: + raise CallerIdentityError( + "Cannot determine caller identity. " + "Ensure the MCP connection sends Authorization and X-Session-Id headers." + ) + return identity diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py new file mode 100644 index 000000000..c3c4e4b31 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_models.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone + + +@dataclass +class AwsSessionRecord: + """ + In-memory representation of a short-lived AWS session for a user/session. + + The fields mirror the design in LISA_Auth.md, with expires_at stored as + an aware UTC datetime. + """ + + user_id: str + session_id: str + aws_access_key_id: str + aws_secret_access_key: str + aws_session_token: str + aws_region: str + expires_at: datetime + + def is_expired(self, *, safety_margin_seconds: int = 0) -> bool: + """Return True if the record should be treated as expired.""" + now = datetime.now(timezone.utc) + effective_expiry = self.expires_at + if safety_margin_seconds > 0: + effective_expiry = effective_expiry - timedelta(seconds=safety_margin_seconds) + return now >= effective_expiry diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_service.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_service.py new file mode 100644 index 000000000..3c5243337 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_service.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +from .session_models import AwsSessionRecord +from .session_store import InMemoryAwsSessionStore + + +class AwsSessionMissingError(Exception): + """Raised when no AWS session is stored for the given user/session.""" + + +class AwsSessionExpiredError(Exception): + """Raised when an AWS session exists but is expired.""" + + +@dataclass +class AwsSessionService: + """High-level helper for retrieving AWS sessions for MCP tools.""" + + store: InMemoryAwsSessionStore + + def get_aws_session_for_user(self, user_id: str, session_id: str) -> AwsSessionRecord: + record = self.store.get_session(user_id, session_id) + if record is None: + # We intentionally don't distinguish missing vs expired here since + # InMemoryAwsSessionStore cleans up expired records on access. + raise AwsSessionMissingError("AWS session not connected or expired.") + return record diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py new file mode 100644 index 000000000..b926ecb63 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/session_store.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field + +from .session_models import AwsSessionRecord + + +@dataclass +class InMemoryAwsSessionStore: + """ + Simple in-process implementation of an AWS session store. + + This is suitable for a single MCP Workbench process. For multi-instance + deployments, a distributed store such as Redis should be used instead. + """ + + safety_margin_seconds: int = 0 + + _sessions: dict[tuple[str, str], AwsSessionRecord] = field(default_factory=dict, init=False) + + def set_session(self, record: AwsSessionRecord) -> None: + """Create or update the session for the given user/session.""" + key = (record.user_id, record.session_id) + self._sessions[key] = record + + def get_session(self, user_id: str, session_id: str) -> AwsSessionRecord | None: + """ + Retrieve the session for a given user/session, or None if missing/expired. + """ + key = (user_id, session_id) + record = self._sessions.get(key) + if record is None: + return None + + # Treat sessions as expired if past expiration or too close to expiry + if record.is_expired(safety_margin_seconds=self.safety_margin_seconds): + # Clean up expired record + self._sessions.pop(key, None) + return None + + return record + + def delete_session(self, user_id: str, session_id: str) -> None: + """Delete the session for the given user/session, if it exists.""" + key = (user_id, session_id) + self._sessions.pop(key, None) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py b/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py new file mode 100644 index 000000000..abc69fee8 --- /dev/null +++ b/lib/serve/mcp-workbench/src/mcpworkbench/aws/sts_client.py @@ -0,0 +1,152 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +import boto3 + +from .session_models import AwsSessionRecord + + +class InvalidAwsCredentialsError(Exception): + """Raised when provided AWS credentials are invalid or STS rejects them.""" + + +@dataclass +class AwsStsClient: + """ + Thin wrapper around boto3 STS client for validating credentials and + creating short-lived session credentials. + """ + + def _create_sts_client( + self, + access_key_id: str, + secret_access_key: str, + session_token: str | None, + region: str, + ) -> Any: + kwargs: dict[str, Any] = { + "aws_access_key_id": access_key_id, + "aws_secret_access_key": secret_access_key, + "region_name": region, + # Use the regional STS endpoint so traffic stays within the VPC + # when an STS VPC endpoint is configured (the global endpoint + # sts.amazonaws.com is not reachable from private subnets). + "endpoint_url": f"https://sts.{region}.amazonaws.com", + } + if session_token: + kwargs["aws_session_token"] = session_token + return boto3.client("sts", **kwargs) + + def validate_static_credentials( + self, + access_key_id: str, + secret_access_key: str, + session_token: str | None, + region: str, + ) -> tuple[str, str]: + """ + Validate credentials via GetCallerIdentity. + + Returns (account_id, arn) on success, raises InvalidAwsCredentialsError on failure. + """ + sts = self._create_sts_client(access_key_id, secret_access_key, session_token, region) + try: + identity = sts.get_caller_identity() + except Exception as exc: # noqa: BLE001 + raise InvalidAwsCredentialsError(f"STS GetCallerIdentity failed: {type(exc).__name__}: {exc}") from exc + + account_id = str(identity.get("Account")) + arn = str(identity.get("Arn")) + return account_id, arn + + # AWS STS temporary credentials can last at most 12 hours. + MAX_TEMP_CREDENTIAL_TTL_SECONDS = 43200 + + def create_session_credentials( + self, + user_id: str, + session_id: str, + access_key_id: str, + secret_access_key: str, + session_token: str | None, + region: str, + duration_seconds: int = 3600, + safety_margin_seconds: int = 60, + ) -> AwsSessionRecord: + """ + Produce an AwsSessionRecord from the provided credentials. + + * **Long-term (IAM user) credentials** (no session_token): calls + ``GetSessionToken`` to mint short-lived temporary credentials + lasting ``duration_seconds``. + * **Temporary credentials** (session_token present): stores them + directly -- AWS forbids calling ``GetSessionToken`` with + temporary credentials. There is no STS API to query when + pre-existing temporary credentials expire, so we assume the + maximum STS lifetime (12 h). The credentials will naturally + fail at call time once they truly expire. + + The returned record's ``expires_at`` is adjusted by + ``safety_margin_seconds``. + """ + now = datetime.now(timezone.utc) + + if session_token: + # We cannot determine the real expiration of caller-provided + # temporary credentials, so assume the STS maximum (12 h). + # The credentials will fail with an auth error at call time + # once they truly expire, prompting the user to reconnect. + assumed_ttl = self.MAX_TEMP_CREDENTIAL_TTL_SECONDS + expires_at = now + timedelta(seconds=assumed_ttl - safety_margin_seconds) + return AwsSessionRecord( + user_id=user_id, + session_id=session_id, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + aws_region=region, + expires_at=expires_at, + ) + + # Long-term IAM user credentials -- mint a session via STS + sts = self._create_sts_client(access_key_id, secret_access_key, None, region) + try: + response = sts.get_session_token(DurationSeconds=duration_seconds) + except Exception as exc: # noqa: BLE001 + raise InvalidAwsCredentialsError(f"STS GetSessionToken failed: {type(exc).__name__}: {exc}") from exc + + creds: dict[str, Any] = response["Credentials"] + raw_expiration: datetime = creds["Expiration"] + if raw_expiration.tzinfo is None: + raw_expiration = raw_expiration.replace(tzinfo=timezone.utc) + expires_at = raw_expiration - timedelta(seconds=safety_margin_seconds) + + if expires_at <= now: + expires_at = now + timedelta(seconds=1) + + return AwsSessionRecord( + user_id=user_id, + session_id=session_id, + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + aws_region=region, + expires_at=expires_at, + ) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py b/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py index 47280856b..69afef403 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py @@ -22,9 +22,15 @@ class CORSConfig(BaseModel): """CORS configuration settings.""" allow_origins: list[str] = Field(default=["*"], description="Allowed origins for CORS") - allow_methods: list[str] = Field(default=["GET", "POST", "OPTIONS"], description="Allowed HTTP methods") + allow_methods: list[str] = Field( + default=["*"], + description=( + "Allowed HTTP methods for CORS preflight; use * (Starlette expands to " + "all standard methods) for MCP streamable HTTP clients." + ), + ) allow_headers: list[str] = Field(default=["*"], description="Allowed headers") - allow_credentials: bool = Field(default=True, description="Allow credentials in CORS requests") + allow_credentials: bool = Field(default=False, description="Allow credentials in CORS requests") expose_headers: list[str] = Field(default=[], description="Headers to expose to the browser") max_age: int = Field(default=600, description="Maximum age for CORS preflight cache") diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py index 5a980bd18..8c7e2545d 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py @@ -199,16 +199,25 @@ class ApiTokenAuthorizer: """ def __init__(self) -> None: + table_name = os.environ.get(TOKEN_TABLE_NAME) + if not table_name: + logger.info("TOKEN_TABLE_NAME is unset; programmatic API token auth is disabled (OIDC still works).") + self._token_table = None + return ddb_resource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"]) - self._token_table = ddb_resource.Table(os.environ[TOKEN_TABLE_NAME]) + self._token_table = ddb_resource.Table(table_name) def _get_token_info(self, token: str) -> Any: """Return DDB entry for token if it exists.""" + if self._token_table is None: + return None ddb_response = self._token_table.get_item(Key={"token": token}, ReturnConsumedCapacity="NONE") return ddb_response.get("Item", None) def is_valid_api_token(self, headers: dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" + if self._token_table is None: + return False for header_name in API_KEY_HEADER_NAMES: token = get_authorization_token(headers, header_name) if token: diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py index bfb22ad2a..ec19743b5 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py @@ -23,17 +23,19 @@ from fastmcp import FastMCP from starlette.applications import Starlette -from starlette.middleware.cors import CORSMiddleware +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount, Route from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from ..aws.aws_routes import router as aws_router from ..config.models import ServerConfig from ..core.base_tool import BaseTool from ..core.tool_discovery import ToolDiscovery, ToolInfo, ToolType from ..core.tool_registry import ToolRegistry -from .auth import OIDCHTTPBearer +from .auth import is_idp_used, OIDCHTTPBearer +from .middleware import CORSMiddleware, wrap_asgi_with_cors_headers logger = logging.getLogger(__name__) @@ -146,14 +148,14 @@ async def health_check(request: Request) -> JSONResponse: return JSONResponse({"status": "healthy", "service": "mcpworkbench"}) logger.info(f"CORS Allowed Origins: {self.config.cors_settings.allow_origins}") - mcp_app.add_middleware( - CORSMiddleware, - allow_origins=self.config.cors_settings.allow_origins, - allow_methods=self.config.cors_settings.allow_methods, - allow_headers=self.config.cors_settings.allow_headers, - ) - - mcp_app.add_middleware(OIDCHTTPBearer) + # Auth only on mounted apps; CORS is applied at the root Starlette app so OPTIONS preflight + # is handled before routing (avoids FastMCP 500 on OPTIONS and missing ACAO on errors). + if is_idp_used(): + mcp_app.add_middleware(OIDCHTTPBearer) + else: + logger.info( + "USE_AUTH is false or unset: OIDC/API-token auth middleware is disabled (same as Serve REST API)." + ) # Add MCP mount routes = [ @@ -161,9 +163,22 @@ async def health_check(request: Request) -> JSONResponse: Mount("/v2/mcp", mcp_app), ] + # Mount AWS session management routes under /api/aws + from fastapi import FastAPI # noqa: PLC0415 + + aws_app = FastAPI() + if is_idp_used(): + aws_app.add_middleware(OIDCHTTPBearer) + aws_app.include_router(aws_router) + routes.append(Mount("/api/aws", aws_app)) + self._add_management_routes(mcp_app) - return Starlette(routes=routes, lifespan=mcp_app.lifespan) + return Starlette( + routes=routes, + middleware=[Middleware(CORSMiddleware, cors_config=self.config.cors_settings)], + lifespan=mcp_app.lifespan, + ) async def _register_discovered_tools(self, tools: list[ToolInfo]) -> None: """Register discovered tools with FastMCP.""" @@ -249,6 +264,8 @@ async def start(self) -> None: # Create Starlette app with both MCP and HTTP routes starlette_app = self._create_starlette_app() + # Outer ASGI wrapper so 500s from ServerErrorMiddleware still get CORS headers (browser can read body) + asgi_app = wrap_asgi_with_cors_headers(starlette_app, self.config.cors_settings) # Start server with Starlette app logger.info(f"Starting MCP Workbench server on {self.config.server_host}:{self.config.server_port}") @@ -264,7 +281,7 @@ async def start(self) -> None: import uvicorn # noqa: PLC0415 config = uvicorn.Config( - starlette_app, + asgi_app, host=self.config.server_host, port=self.config.server_port, log_level="info", diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py index 798e9c119..0be13a801 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py @@ -20,11 +20,13 @@ from datetime import datetime from typing import Any +from starlette.datastructures import MutableHeaders from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from starlette.types import ASGIApp, Message, Receive, Scope, Send from ..config.models import CORSConfig from ..core.tool_discovery import ToolDiscovery @@ -48,6 +50,79 @@ def __init__(self, app: Any, cors_config: CORSConfig) -> None: ) +def _parse_request_origin(scope: Scope) -> str | None: + if scope["type"] != "http": + return None + for key, value in scope.get("headers") or []: + if key.lower() == b"origin" and isinstance(value, bytes): + return value.decode("latin-1") + return None + + +def _access_control_allow_origin_value(cors_config: CORSConfig, request_origin: str | None) -> str | None: + origins = cors_config.allow_origins + empty_origin_wildcard = "" in origins + + if cors_config.allow_credentials: + # "" in allow_origins means "reflect any request Origin" (cannot use * with credentials). + if empty_origin_wildcard: + return request_origin if request_origin else None + if request_origin and request_origin in origins: + return request_origin + fallback = origins[0] if origins else "*" + return None if fallback == "" else fallback + if "*" in origins: + return "*" + if request_origin and request_origin in origins: + return request_origin + return origins[0] if origins else "*" + + +def _merge_vary_origin(headers: MutableHeaders) -> None: + existing = headers.get("vary") + if existing: + parts = [p.strip() for p in existing.split(",") if p.strip()] + if "Origin" not in parts: + headers["vary"] = f"{existing}, Origin" + else: + headers["vary"] = "Origin" + + +def wrap_asgi_with_cors_headers(app: ASGIApp, cors_config: CORSConfig) -> ASGIApp: + """Outer ASGI wrapper: ensure CORS headers on every HTTP response when missing. + + Starlette's outer ``ServerErrorMiddleware`` can emit error responses that bypass inner + ``CORSMiddleware``'s ``send`` wrapper, so browsers see 500 without + ``Access-Control-Allow-Origin`` and block the response body. + """ + + async def asgi(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await app(scope, receive, send) + return + + origin = _parse_request_origin(scope) + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + if "access-control-allow-origin" not in headers: + acao = _access_control_allow_origin_value(cors_config, origin) + if acao is not None: + headers["access-control-allow-origin"] = acao + if origin and acao == origin: + _merge_vary_origin(headers) + if cors_config.allow_headers and "*" in cors_config.allow_headers: + if "access-control-allow-headers" not in headers: + headers["access-control-allow-headers"] = "*" + + await send(message) + + await app(scope, receive, send_wrapper) + + return asgi + + class ExitRouteMiddleware(BaseHTTPMiddleware): """Middleware to handle application exit requests.""" diff --git a/lib/serve/mcpWorkbenchConstruct.ts b/lib/serve/mcpWorkbenchConstruct.ts index 55d34ef1d..56eb0c61e 100644 --- a/lib/serve/mcpWorkbenchConstruct.ts +++ b/lib/serve/mcpWorkbenchConstruct.ts @@ -18,7 +18,8 @@ import { IAuthorizer, IRestApi, RestApi } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { Construct } from 'constructs'; import { Vpc } from '../networking/vpc'; -import { BaseProps, Config, EcsSourceType } from '../schema'; +import { AmiHardwareType } from '../schema/cdk'; +import { APP_MANAGEMENT_KEY, BaseProps, Config, ECSConfig, Ec2Metadata, EcsSourceType } from '../schema'; import * as s3 from 'aws-cdk-lib/aws-s3'; import { Duration, RemovalPolicy, StackProps } from 'aws-cdk-lib'; import { createCdkId } from '../core/utils'; @@ -27,19 +28,21 @@ import { getPythonRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../ import * as iam from 'aws-cdk-lib/aws-iam'; import { LAMBDA_PATH, MCP_WORKBENCH_PATH } from '../util'; import { WORKBENCH_CONTAINER_MEMORY_RESERVATION, WORKBENCH_CONTAINER_MEMORY_LIMIT } from '../api-base/fastApiContainer'; +import { defaultMcpWorkbenchHostnameFromServeApiDomain } from './mcpWorkbenchDomain'; import * as lambda from 'aws-cdk-lib/aws-lambda'; import * as events from 'aws-cdk-lib/aws-events'; import * as targets from 'aws-cdk-lib/aws-events-targets'; import { ECSCluster, ECSTasks } from '../api-base/ecsCluster'; import { Ec2Service } from 'aws-cdk-lib/aws-ecs'; +import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import { BlockPublicAccess, BucketEncryption } from 'aws-cdk-lib/aws-s3'; export type McpWorkbenchConstructProps = { + bucketAccessLogsBucket: s3.IBucket; restApiId: string; rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; - apiCluster: ECSCluster; authorizer?: IAuthorizer; } & BaseProps & StackProps; @@ -49,7 +52,7 @@ export class McpWorkbenchConstruct extends Construct { constructor (scope: Construct, id: string, props: McpWorkbenchConstructProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc, apiCluster } = props; + const { authorizer, bucketAccessLogsBucket, config, restApiId, rootResourceId, securityGroups, vpc } = props; // Get common layer based on arn from SSM due to issues with cross stack references const commonLambdaLayer = lambda.LayerVersion.fromLayerVersionArn( @@ -71,14 +74,139 @@ export class McpWorkbenchConstruct extends Construct { const lambdaLayers = [commonLambdaLayer, fastapiLambdaLayer]; - const workbenchBucket = this.createWorkbenchBucket(scope, config); + const workbenchBucket = this.createWorkbenchBucket(scope, config, bucketAccessLogsBucket); this.createWorkbenchApi(restApi, config, vpc, securityGroups, workbenchBucket, lambdaLayers, authorizer); if (config.deployMcpWorkbench) { - this.createWorkbenchService(apiCluster, config, vpc); + this.createWorkbenchService(config, vpc); } } + private buildMcpWorkbenchBuildArgs (config: Config): Record { + const buildArgs: Record = { + BASE_IMAGE: config.baseImage, + PYPI_INDEX_URL: config.pypiConfig.indexUrl, + PYPI_TRUSTED_HOST: config.pypiConfig.trustedHost, + }; + if (config.mcpWorkbenchBuildConfig) { + Object.entries(config.mcpWorkbenchBuildConfig).forEach(([key, value]) => { + if (value) { + buildArgs[key] = value; + } + }); + } + return buildArgs; + } + + private buildWorkbenchEcsConfig (config: Config): ECSConfig { + const o = config.mcpWorkbenchEcsConfig ?? {}; + const instanceType = o.instanceType ?? 'm5.xlarge'; + // Workbench uses its own ALB; never reuse restApiConfig.domainName (that name resolves to the Serve ALB). + // mcpWorkbenchRestApiConfig mirrors restApiConfig for YAML parity; mcpWorkbenchEcsConfig.domainName remains supported. + const workbenchDomainName = + config.mcpWorkbenchRestApiConfig?.domainName ?? + o.domainName ?? + defaultMcpWorkbenchHostnameFromServeApiDomain(config.restApiConfig.domainName ?? undefined) ?? + null; + const workbenchSslCertArn = + config.mcpWorkbenchRestApiConfig?.sslCertIamArn ?? + o.sslCertIamArn ?? + config.restApiConfig.sslCertIamArn ?? + null; + return { + amiHardwareType: AmiHardwareType.STANDARD, + autoScalingConfig: { + blockDeviceVolumeSize: o.blockDeviceVolumeSize ?? 50, + minCapacity: o.minCapacity ?? 1, + maxCapacity: o.maxCapacity ?? 5, + cooldown: o.cooldown ?? 60, + defaultInstanceWarmup: 60, + metricConfig: { + albMetricName: 'RequestCountPerTarget', + targetValue: 1000, + duration: 60, + estimatedInstanceWarmup: 30, + }, + }, + buildArgs: this.buildMcpWorkbenchBuildArgs(config), + tasks: {}, + containerMemoryBuffer: 0, + instanceType, + internetFacing: config.restApiConfig.internetFacing, + loadBalancerConfig: { + healthCheckConfig: { + path: '/health', + interval: 60, + timeout: 30, + healthyThresholdCount: 2, + unhealthyThresholdCount: 3, + }, + domainName: workbenchDomainName, + sslCertIamArn: workbenchSslCertArn, + }, + }; + } + + private buildWorkbenchClusterEnvironment (config: Config, instanceType: string, managementKeyName: string | undefined): Record { + const environment: Record = { + LOG_LEVEL: config.logLevel, + AWS_REGION: config.region, + AWS_REGION_NAME: config.region, + THREADS: Ec2Metadata.get(instanceType).vCpus.toString(), + }; + if (config.authConfig) { + environment.USE_AUTH = 'true'; + environment.AUTHORITY = config.authConfig.authority; + environment.CLIENT_ID = config.authConfig.clientId; + environment.ADMIN_GROUP = config.authConfig.adminGroup; + environment.USER_GROUP = config.authConfig.userGroup; + environment.JWT_GROUPS_PROP = config.authConfig.jwtGroupsProperty; + environment.MANAGEMENT_KEY_NAME = managementKeyName!; + } else { + environment.USE_AUTH = 'false'; + } + if (config.region.includes('iso')) { + environment.SSL_CERT_DIR = '/etc/pki/tls/certs'; + environment.SSL_CERT_FILE = config.certificateAuthorityBundle; + environment.REQUESTS_CA_BUNDLE = config.certificateAuthorityBundle; + environment.AWS_CA_BUNDLE = config.certificateAuthorityBundle; + environment.CURL_CA_BUNDLE = config.certificateAuthorityBundle; + } + return environment; + } + + private getMcpWorkbenchTaskDefinition (config: Config) { + const mcpWorkbenchImage = config.mcpWorkbenchConfig || { + baseImage: config.baseImage, + path: MCP_WORKBENCH_PATH, + type: EcsSourceType.ASSET, + }; + + return { + environment: { + RCLONE_CONFIG_S3_REGION: config.region, + MCPWORKBENCH_BUCKET: [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(), + CORS_ORIGINS: config.mcpWorkbenchCorsOrigins, + }, + containerConfig: { + image: mcpWorkbenchImage, + healthCheckConfig: { + command: ['CMD-SHELL', 'exit 0'], + interval: 10, + startPeriod: 30, + timeout: 5, + retries: 3, + }, + environment: {}, + sharedMemorySize: 0, + privileged: true, + }, + containerMemoryReservationMiB: WORKBENCH_CONTAINER_MEMORY_RESERVATION, + memoryLimitMiB: WORKBENCH_CONTAINER_MEMORY_LIMIT, + applicationTarget: { port: 8000 }, + }; + } + private createWorkbenchApi (restApi: IRestApi, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[], workbenchBucket: s3.Bucket, lambdaLayers: lambda.ILayerVersion[], authorizer?: IAuthorizer) { const env = { @@ -181,11 +309,7 @@ export class McpWorkbenchConstruct extends Construct { }); } - private createWorkbenchBucket (scope: Construct, config: Config): s3.Bucket { - const bucketAccessLogsBucket = s3.Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - ssm.StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`), - ); - + private createWorkbenchBucket (scope: Construct, config: Config, bucketAccessLogsBucket: s3.IBucket): s3.Bucket { return new s3.Bucket(scope, createCdkId(['LISA', 'MCPWorkbench', config.deploymentName, config.deploymentStage]), { bucketName: [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(), removalPolicy: config.removalPolicy, @@ -199,47 +323,52 @@ export class McpWorkbenchConstruct extends Construct { }); } - private createWorkbenchService (apiCluster: ECSCluster, config: Config, vpc: Vpc) { + private createWorkbenchService (config: Config, vpc: Vpc) { + const ecsConfig = this.buildWorkbenchEcsConfig(config); + const managementKeyName = config.authConfig + ? ssm.StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`) + : undefined; + const environment = this.buildWorkbenchClusterEnvironment(config, ecsConfig.instanceType, managementKeyName); + // Same token table as Serve REST API (SSM from Api Base); required for ApiTokenAuthorizer in auth middleware + environment.TOKEN_TABLE_NAME = ssm.StringParameter.valueForStringParameter( + this, + `${config.deploymentPrefix}/tokenTableName`, + ); - const mcpWorkbenchImage = config.mcpWorkbenchConfig || { - baseImage: config.baseImage, - path: MCP_WORKBENCH_PATH, - type: EcsSourceType.ASSET - }; + const workbenchCluster = new ECSCluster(this, 'McpWorkbenchDedicatedEcs', { + identifier: 'McpWorkbenchDedicated', + ecsConfig, + config, + securityGroup: vpc.securityGroups.restApiAlbSg, + vpc, + environment, + }); - const mcpWorkbenchTaskDefinition = { - environment: { - RCLONE_CONFIG_S3_REGION: config.region, - MCPWORKBENCH_BUCKET: [config.deploymentName, config.deploymentStage, 'MCPWorkbench', config.accountNumber].join('-').toLowerCase(), - }, - containerConfig: { - image: mcpWorkbenchImage, - healthCheckConfig: { - command: ['CMD-SHELL', 'exit 0'], - interval: 10, - startPeriod: 30, - timeout: 5, - retries: 3 - }, - environment: {}, - sharedMemorySize: 0, - privileged: true - }, - containerMemoryReservationMiB: WORKBENCH_CONTAINER_MEMORY_RESERVATION, - memoryLimitMiB: WORKBENCH_CONTAINER_MEMORY_LIMIT, - applicationTarget: { - port: 8000, - priority: 80, - conditions: [{ - type: 'pathPatterns' as const, - values: ['/v2/mcp/*'] - }] - } - }; + const mcpWorkbenchTaskDefinition = this.getMcpWorkbenchTaskDefinition(config); + const { service } = workbenchCluster.addTask(ECSTasks.MCPWORKBENCH, mcpWorkbenchTaskDefinition); - const { service } = apiCluster.addTask(ECSTasks.MCPWORKBENCH, mcpWorkbenchTaskDefinition); + const tokenTableNameParameter = ssm.StringParameter.fromStringParameterName( + this, + createCdkId(['McpWorkbench', 'TokenTableNameParameter']), + `${config.deploymentPrefix}/tokenTableName`, + ); + const tokenTable = dynamodb.Table.fromTableName( + this, + createCdkId(['McpWorkbench', 'TokenTable']), + tokenTableNameParameter.stringValue, + ); + const mcpWorkbenchTaskRole = workbenchCluster.taskRoles[ECSTasks.MCPWORKBENCH]; + if (mcpWorkbenchTaskRole) { + tokenTable.grantReadData(mcpWorkbenchTaskRole); + } this.createS3EventHandler(config, service, vpc); + + new ssm.StringParameter(this, 'McpWorkbenchHostedEndpoint', { + parameterName: `${config.deploymentPrefix}/mcpWorkbench/endpoint`, + stringValue: workbenchCluster.endpointUrl, + description: 'Base URL for hosted MCP Workbench HTTP server (MCP path /v2/mcp/)', + }); } private createS3EventHandler (config: any, workbenchService: Ec2Service, vpc: Vpc) { diff --git a/lib/serve/mcpWorkbenchDomain.ts b/lib/serve/mcpWorkbenchDomain.ts new file mode 100644 index 000000000..81f52dab8 --- /dev/null +++ b/lib/serve/mcpWorkbenchDomain.ts @@ -0,0 +1,51 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +/** + * When `restApiConfig.domainName` is set, MCP Workbench must not reuse that hostname: it runs on a + * separate ALB, and DNS for the Serve API name targets the Serve load balancer only. + * + * If `mcpWorkbenchRestApiConfig.domainName` and `mcpWorkbenchEcsConfig.domainName` are omitted, derive a conventional workbench hostname so HTTPS + * (shared ACM cert / wildcard) and SSM `…/mcpWorkbench/endpoint` stay consistent: + * - First label ends with `-serve` β†’ replace that suffix with `-mcp-workbench` (e.g. `lisa-serve.example` β†’ `lisa-mcp-workbench.example`). + * - First label is exactly `serve` β†’ use `mcp-workbench` (e.g. `serve.alias.example` β†’ `mcp-workbench.alias.example`). + * + * Otherwise returns null so the workbench ALB DNS name is used (operators should set `mcpWorkbenchRestApiConfig.domainName` or `mcpWorkbenchEcsConfig.domainName` if they need TLS on a custom name). + */ +export function defaultMcpWorkbenchHostnameFromServeApiDomain (restApiDomain: string | null | undefined): string | null { + const trimmed = restApiDomain?.trim(); + if (!trimmed) { + return null; + } + const parts = trimmed.split('.'); + const first = parts[0]; + if (!first) { + return null; + } + + let nextFirst: string | null = null; + if (first.endsWith('-serve')) { + nextFirst = `${first.slice(0, -'-serve'.length)}-mcp-workbench`; + } else if (first === 'serve') { + nextFirst = 'mcp-workbench'; + } + + if (!nextFirst) { + return null; + } + parts[0] = nextFirst; + return parts.join('.'); +} diff --git a/lib/serve/mcpWorkbenchStack.ts b/lib/serve/mcpWorkbenchStack.ts index 4dccb042f..ecea7a6c6 100644 --- a/lib/serve/mcpWorkbenchStack.ts +++ b/lib/serve/mcpWorkbenchStack.ts @@ -19,14 +19,14 @@ import { Construct } from 'constructs'; import { BaseProps } from '../schema'; import { McpWorkbenchConstruct } from './mcpWorkbenchConstruct'; import { Vpc } from '../networking/vpc'; -import { ECSCluster } from '../api-base/ecsCluster'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; +import { IBucket } from 'aws-cdk-lib/aws-s3'; export type McpWorkbenchStackProps = { + bucketAccessLogsBucket: IBucket; vpc: Vpc; restApiId: string; rootResourceId: string; - apiCluster: ECSCluster; authorizer?: IAuthorizer; } & BaseProps & StackProps; @@ -34,15 +34,15 @@ export class McpWorkbenchStack extends Stack { constructor (scope: Construct, id: string, props: McpWorkbenchStackProps) { super(scope, id, props); - const { vpc, restApiId, rootResourceId, authorizer, apiCluster } = props; + const { vpc, restApiId, rootResourceId, authorizer, bucketAccessLogsBucket } = props; new McpWorkbenchConstruct(this, 'McpWorkbench', { ...props, + bucketAccessLogsBucket, restApiId, rootResourceId, securityGroups: [vpc.securityGroups.ecsModelAlbSg], vpc: vpc, - apiCluster, authorizer }); } diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 58e206034..6251a67d7 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -36,7 +36,8 @@ get_model_guardrails, is_guardrail_violation, ) -from utils.metrics import publish_metrics_event +from utils.metrics import extract_token_usage, publish_metrics_event +from utils.request_utils import get_lisa_end_user_id from utils.route_utils import is_anthropic_route, is_chat_route, is_lisa_public_route, is_openai_route # Local LiteLLM installation URL. By default, LiteLLM runs on port 4000. Change the port here if the @@ -257,63 +258,96 @@ def generate_response(iterator: Iterator[str | bytes]) -> Iterator[str]: yield f"{line}\n\n" -def generate_response_with_guardrail_handling(iterator: Iterator[str | bytes], model: str) -> Iterator[str]: +def generate_response_with_guardrail_handling( + iterator: Iterator[str | bytes], + model: str, + request: Request, + params: dict, +) -> Iterator[str]: """ - Generate streaming responses with guardrail violation error handling. + Generate streaming responses with guardrail violation error handling and token usage capture. - This wrapper checks each chunk in the stream for guardrail violations and converts - them into properly formatted streaming responses. - """ - for line in iterator: - if isinstance(line, bytes): - line = line.decode() - - if not line: - continue + In addition to guardrail handling, this generator watches for the SSE usage chunk + (the chunk containing ``"usage": {...}``) emitted at the end of a streaming response. + When found, it extracts prompt/completion/total token counts and publishes a unified + metrics event to SQS after the stream completes. - # Check if this line contains an error (SSE format: "data: {...}") - if line.startswith("data: "): - try: - # Extract JSON from SSE data line - data_content = line[6:].strip() # Remove "data: " prefix + All chunks are forwarded to the client unchanged β€” token capture is a side-effect only. - # Skip [DONE] marker - if data_content == "[DONE]": - yield f"{line}\n\n" - continue + Args: + iterator: The line iterator from the streaming LiteLLM response + model: The model ID used for the request + request: The original FastAPI request (passed to publish_metrics_event) + params: The original request parameters (passed to publish_metrics_event) + """ + captured_prompt_tokens: int | None = None + captured_completion_tokens: int | None = None + guardrail_triggered = False - # Try to parse as JSON to check for errors - chunk_data = json.loads(data_content) + try: + for line in iterator: + if isinstance(line, bytes): + line = line.decode() - # Check if this is an error chunk - if "error" in chunk_data: - error_msg = chunk_data.get("error", {}).get("message", "") + if not line: + continue - if is_guardrail_violation(error_msg): - logger.info("Guardrail policy violated in streaming response") + if line.startswith("data: "): + try: + data_content = line[6:].strip() # Remove "data: " prefix - guardrail_response = extract_guardrail_response(error_msg) - if guardrail_response: - # Stream the guardrail response - created = int(chunk_data.get("created", 0)) - yield from create_guardrail_streaming_response(guardrail_response, model, created) - return # Stop streaming after guardrail response + if data_content == "[DONE]": + yield f"{line}\n\n" + continue + + chunk_data = json.loads(data_content) + + # Capture token usage from the usage chunk (present near end of stream) + if "usage" in chunk_data and chunk_data["usage"]: + pt, ct = extract_token_usage(chunk_data) + if pt is not None: + captured_prompt_tokens = pt + if ct is not None: + captured_completion_tokens = ct + + # Check if this is an error chunk + if "error" in chunk_data: + error_msg = chunk_data.get("error", {}).get("message", "") + + if is_guardrail_violation(error_msg): + logger.info("Guardrail policy violated in streaming response") + + guardrail_response = extract_guardrail_response(error_msg) + if guardrail_response: + guardrail_triggered = True + created = int(chunk_data.get("created", 0)) + yield from create_guardrail_streaming_response(guardrail_response, model, created) + return # Stop streaming β€” finally block publishes metrics + else: + yield f"{line}\n\n" else: - # Could not extract guardrail response, pass through the error yield f"{line}\n\n" else: - # Different error, pass it through yield f"{line}\n\n" - else: - # Normal chunk, pass it through - yield f"{line}\n\n" - except json.JSONDecodeError: - # Not valid JSON or not in expected format, pass through as-is + except json.JSONDecodeError: + yield f"{line}\n\n" + else: yield f"{line}\n\n" - else: - # Not in SSE format, pass through as-is - yield f"{line}\n\n" + finally: + # Always publish metrics after the stream ends, regardless of how the generator exits + # (normal exhaustion, guardrail early-return, or unexpected exception). + # When a guardrail fires, the model never completes its output so there are no token + # counts β€” pass None explicitly to skip token metrics for that case. + status_code = getattr(request.state, "upstream_status_code", HTTP_200_OK) + + publish_metrics_event( + request, + params, + status_code, + prompt_tokens=None if guardrail_triggered else captured_prompt_tokens, + completion_tokens=None if guardrail_triggered else captured_completion_tokens, + ) @router.api_route("/{api_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) @@ -419,6 +453,15 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: is_video_endpoint = "video" in api_path.lower() is_image_endpoint = "image" in api_path.lower() + # LiteLLM uses a well-known header to attribute requests to an end-user. + # We set it from the human-readable username derived from JWT/session. + lisa_username = get_lisa_end_user_id( + jwt_data=jwt_data, + state_username=getattr(request.state, "username", None), + ) + if lisa_username: + headers["x-litellm-end-user-id"] = lisa_username + # Handle multipart/form-data requests (video generation with image references, image edits) if is_multipart and (is_video_endpoint or is_image_endpoint): try: @@ -443,6 +486,9 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: # Create new headers without Content-Type (requests library will set it with correct boundary) forward_headers = {"Authorization": f"Bearer {LITELLM_KEY}"} + # Preserve end-user attribution header for multipart requests if it was set above. + if "x-litellm-end-user-id" in headers: + forward_headers["x-litellm-end-user-id"] = headers["x-litellm-end-user-id"] # Forward multipart request to LiteLLM response = requests_request( @@ -469,6 +515,11 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: logger.error(f"Invalid JSON in request body: {e}") raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Invalid JSON in request body") + # If the caller didn't already set OpenAI's "user" identifier, populate it + # with the human-readable LISA username so LiteLLM can surface it in logs. + if isinstance(params, dict) and lisa_username and (not params.get("user")): + params["user"] = lisa_username + # Get model info from LiteLLM to determine the actual model provider path model_id = params.get("model") model_name = None # The actual provider/model path (e.g., "bedrock/us.anthropic.claude...") @@ -508,15 +559,13 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: if guardrail_response: return guardrail_response - # Publish metrics for streaming chat completions (API users) + # Use token-capturing, guardrail-aware generator for chat/completions. + # The generator publishes the unified metrics event (including token counts) + # after the stream ends, so no separate publish_metrics_event call is needed here. if is_chat_completion and response.status_code == HTTP_200_OK: - publish_metrics_event(request, params, response.status_code) - - # Use guardrail-aware generator for chat/completions to catch violations in the stream - if is_chat_completion: model_id = params.get("model", "") return StreamingResponse( - generate_response_with_guardrail_handling(response.iter_lines(), model_id), + generate_response_with_guardrail_handling(response.iter_lines(), model_id, request, params), status_code=response.status_code, ) else: @@ -537,8 +586,11 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: if response.status_code != HTTP_200_OK: logger.error(f"LiteLLM error response: {response.text}") - # Publish metrics for chat completions (API users) + # Pass the parsed response body so tokens can be extracted from the usage field. + response_body = response.json() if response.status_code == HTTP_200_OK else None + + # Publish metrics for non-streaming chat completions (API users). if is_chat_completion: - publish_metrics_event(request, params, response.status_code) + publish_metrics_event(request, params, response.status_code, response_body=response_body) - return JSONResponse(response.json(), status_code=response.status_code) + return JSONResponse(response_body, status_code=response.status_code) diff --git a/lib/serve/rest-api/src/auth_provider.py b/lib/serve/rest-api/src/auth_provider.py index 311a18f91..f69b22b7b 100644 --- a/lib/serve/rest-api/src/auth_provider.py +++ b/lib/serve/rest-api/src/auth_provider.py @@ -122,7 +122,11 @@ class OIDCAuthorizationProvider(AuthorizationProvider): Uses JWT group claims to determine admin and app access. """ - def __init__(self, admin_group: str | None = None, user_group: str | None = None): + def __init__( + self, + admin_group: str | None = None, + user_group: str | None = None, + ): """Initialize the OIDC authorization provider. Parameters diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index b96de0a7c..82908f727 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -41,6 +41,26 @@ echo "--------------------------------" head -20 litellm_config.yaml echo "--------------------------------" +# If LiteLLM OTEL message_logging is enabled, OTEL console output may still omit +# request/response payloads unless LiteLLM is run in a more verbose mode. +# LiteLLM uses --detailed_debug for this. +LITELLM_DETAILED_DEBUG_ARGS="" +ENABLE_LITELLM_MESSAGE_LOGGING=$(python -c 'import yaml +d=yaml.safe_load(open("litellm_config.yaml")) or {} +cs=d.get("callback_settings") or {} +otel=(cs.get("otel") if isinstance(cs, dict) else {}) or {} +print("true" if otel.get("message_logging") else "false") +') +if [ "$ENABLE_LITELLM_MESSAGE_LOGGING" = "true" ]; then + echo " - LiteLLM OTEL message_logging enabled; adding --detailed_debug" + LITELLM_DETAILED_DEBUG_ARGS="--detailed_debug" + # Ensure LiteLLM logs at least INFO so message/response payloads can appear in OTEL console output. + # (Default is WARNING when DEBUG is not set.) + if [ -z "${LITELLM_LOG_LEVEL:-}" ]; then + export LITELLM_LOG_LEVEL="INFO" + fi +fi + # Configure logging behavior based on DEBUG environment variable # Set DEBUG=true in ECS task definition to enable debug logging for all services if [ "${DEBUG}" = "true" ]; then @@ -109,7 +129,7 @@ fi # Use --num_workers to increase parallelism for embedding requests LITELLM_WORKERS=${LITELLM_WORKERS:-4} echo " - LiteLLM workers: $LITELLM_WORKERS" -litellm -c litellm_config.yaml --use_prisma_db_push --num_workers "$LITELLM_WORKERS" > litellm.log 2>&1 & +litellm -c litellm_config.yaml --use_prisma_db_push --num_workers "$LITELLM_WORKERS" $LITELLM_DETAILED_DEBUG_ARGS > litellm.log 2>&1 & LITELLM_PID=$! echo " - LiteLLM PID: $LITELLM_PID" diff --git a/lib/serve/rest-api/src/main.py b/lib/serve/rest-api/src/main.py index a8640bfaf..89d2151b1 100644 --- a/lib/serve/rest-api/src/main.py +++ b/lib/serve/rest-api/src/main.py @@ -28,6 +28,7 @@ security_middleware, validate_input_middleware, ) +from starlette.types import ASGIApp, Receive, Scope, Send logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") @@ -114,3 +115,38 @@ async def security_check(request, call_next): # type: ignore - Request size is within limits (model proxy endpoints are exempt) """ return await security_middleware(request, call_next) + + +def _parse_asgi_spec_version(spec_version: str) -> tuple[int, ...]: + """Parse ASGI spec_version like '2.4' into a tuple for comparison.""" + try: + return tuple(int(p) for p in spec_version.split(".") if p != "") + except ValueError: + return (2, 0) + + +class EnsureAsgiHttpSpec24Middleware: + """Normalize HTTP scope's ASGI spec_version to >= 2.4. + + Starlette 0.49+ ``StreamingResponse`` runs ``listen_for_disconnect`` in parallel with + the body iterator only when ``scope['asgi']['spec_version']`` is below 2.4. + In some deployments this can race with ``BaseHTTPMiddleware`` and raise: + + RuntimeError: Unexpected message received: http.request + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope.get("type") == "http": + asgi = scope.setdefault("asgi", {}) + raw = str(asgi.get("spec_version", "2.0")) + if _parse_asgi_spec_version(raw) < (2, 4): + asgi["spec_version"] = "2.4" + await self.app(scope, receive, send) + + +# Wrap the fully-built FastAPI app (Gunicorn imports ``app`` from this module). +_built_asgi_app: ASGIApp = app +app = EnsureAsgiHttpSpec24Middleware(_built_asgi_app) diff --git a/lib/serve/rest-api/src/middleware/auth_middleware.py b/lib/serve/rest-api/src/middleware/auth_middleware.py index 2119cde96..4fe7b38c7 100644 --- a/lib/serve/rest-api/src/middleware/auth_middleware.py +++ b/lib/serve/rest-api/src/middleware/auth_middleware.py @@ -100,7 +100,10 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo request.state.is_admin = authorizer.auth_provider.check_admin_access_jwt( jwt_data, authorizer.jwt_groups_property ) - request.state.username = jwt_data.get("sub", jwt_data.get("username", "unknown")) + # Resolve username: prefer cognito:username / username over the opaque UUID sub + request.state.username = ( + jwt_data.get("cognito:username") or jwt_data.get("username") or jwt_data.get("sub", "unknown") + ) request.state.groups = _extract_groups_from_jwt(jwt_data, authorizer.jwt_groups_property) elif hasattr(request.state, "api_token_info"): # API token auth diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index 3df1ea57e..a6eb56f9e 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -1,29 +1,28 @@ # AWS SDK Dependencies -# boto3 version pinned by litellm[proxy]==1.81.3 for RDS IAM token refresh -boto3==1.40.76 +# Keep boto3 aligned with LiteLLM / RDS IAM token refresh expectations (loose range, not hard-pinned). +boto3>=1.40.76,<1.41.0 # OpenTelemetry - Optional for LiteLLM Weave integration, silences import warnings opentelemetry-api>=1.20.0 opentelemetry-sdk>=1.20.0 -aiohttp==3.13.3 -backoff==2.2.1 -cachetools==7.0.2 -click==8.3.1 -cryptography==46.0.5 +aiohttp>=3.13.3,<4.0.0 +backoff>=2.2.1,<3.0.0 +cachetools>=7.0.2,<8.0.0 +click>=8.3.0,<9.0.0 +cryptography>=46.0.5,<47.0.0 fastapi>=0.120.1 gunicorn>=23.0.0,<24.0.0 -# LiteLLM - Upgraded to 1.81.3 for RDS IAM token refresh fix (PR #18795) -# Fixes: "All connection attempts failed" errors every 15 minutes with IAM auth -litellm[proxy]==1.81.3 +# LiteLLM β€” pinned by policy (orchestration / proxy compatibility). +# Fixes: "All connection attempts failed" errors every 15 minutes with IAM auth (older LiteLLM). +litellm[proxy]==1.82.4 -loguru==0.7.3 +loguru>=0.7.3,<0.8.0 pydantic>=2.5.0,<3.0.0 PyJWT>=2.10.1,<3.0.0 -prisma==0.15.0 +prisma>=0.15.0,<0.16.0 starlette>=0.40.0,<0.51.0 -# ASGI Server - Version constrained by litellm[proxy]==1.81.3 -# litellm requires uvicorn>=0.31.1,<0.32.0 -uvicorn>=0.31.1,<0.32.0 +# ASGI Server β€” litellm[proxy] extra requires uvicorn>=0.32.1,<1.0.0 +uvicorn>=0.32.1,<1.0.0 diff --git a/lib/serve/rest-api/src/utils/metrics.py b/lib/serve/rest-api/src/utils/metrics.py index 417a5e95f..15de724f7 100644 --- a/lib/serve/rest-api/src/utils/metrics.py +++ b/lib/serve/rest-api/src/utils/metrics.py @@ -14,15 +14,15 @@ """Metrics utilities for publishing usage data.""" -import json import logging import os import uuid from datetime import datetime import boto3 -from auth import get_user_context +from auth import get_user_context, is_api_user from fastapi import Request +from utils.metrics_models import MetricsEvent logger = logging.getLogger(__name__) @@ -92,14 +92,50 @@ def extract_messages_for_metrics(params: dict) -> list[dict]: return formatted_messages -def publish_metrics_event(request: Request, params: dict, response_status: int) -> None: +def extract_token_usage(response_body: dict | None) -> tuple[int | None, int | None]: """ - Publish metrics event to SQS queue for API users + Extract token usage from a LLM response body (non-streaming or SSE chunk). + + The usage structure is identical in both cases β€” LiteLLM normalises it: + {"usage": {"prompt_tokens": N, "completion_tokens": N, ...}, ...} + + Args: + response_body: The parsed JSON response or SSE chunk from LiteLLM + + Returns: + Tuple of (prompt_tokens, completion_tokens), each int or None. + """ + if not response_body: + return None, None + + usage = response_body.get("usage") + if not usage: + return None, None + + return usage.get("prompt_tokens"), usage.get("completion_tokens") + + +def publish_metrics_event( + request: Request, + params: dict, + response_status: int, + response_body: dict | None = None, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, +) -> None: + """ + Publish metrics event to SQS queue for API users. + + Includes both message-level metrics (for prompt/RAG/MCP counting) and + token-level metrics (prompt_tokens, completion_tokens) if available. Args: request: The FastAPI request object - params: The request parameters + params: The request parameters (contains messages and model) response_status: HTTP response status code + response_body: Optional parsed response JSON (used to extract tokens for non-streaming) + prompt_tokens: Optional prompt token count (provided directly for streaming) + completion_tokens: Optional completion token count (provided directly for streaming) """ # Only publish metrics for successful completions if response_status != 200: @@ -113,24 +149,55 @@ def publish_metrics_event(request: Request, params: dict, response_status: int) try: username, groups = get_user_context(request) - messages = extract_messages_for_metrics(params) - - # Generate a synthetic session ID for API users - session_id = f"api-{int(datetime.now().timestamp())}-{uuid.uuid4().hex[:8]}" - - # Create metrics event in the same format as session lambda - metrics_event = { - "userId": username, - "sessionId": session_id, - "messages": messages, - "userGroups": groups, - "timestamp": datetime.now().isoformat(), - } - - # Publish to SQS - sqs_client.send_message(QueueUrl=queue_url, MessageBody=json.dumps(metrics_event)) - - logger.info(f"Published metrics event for API user: {username}") + model_id = params.get("model") + + # If token counts were not passed directly, try to extract from response_body + if prompt_tokens is None and response_body is not None: + prompt_tokens, completion_tokens = extract_token_usage(response_body) + + is_jwt_user = not is_api_user(request) + + if is_jwt_user: + # JWT/UI user: the session lambda already publishes prompt/RAG/MCP metrics with the + # real sessionId. The passthrough only supplies the token counts that the session + # lambda cannot see (they come from the LLM response, not the session history). + # Skip entirely if there are no token counts to add. + if prompt_tokens is None and completion_tokens is None: + logger.debug("No token data for JWT user, skipping passthrough metrics publish") + return + messages = [] # Prevent double-counting prompts β€” session lambda owns this + session_id = f"ui-tokens-{uuid.uuid4().hex}" + event_type = "token_only" + else: + # API token user: publish full messages + tokens. + # The session lambda does not run for API users, so the passthrough owns all metrics. + messages = extract_messages_for_metrics(params) + session_id = f"api-{uuid.uuid4().hex}" + event_type = "full" + + # Build and validate the event through the Pydantic model before publishing + metrics_event = MetricsEvent( + userId=username, + sessionId=session_id, + messages=messages, + userGroups=groups, + timestamp=datetime.now().isoformat(), + eventType=event_type, + modelId=model_id, + promptTokens=prompt_tokens, + completionTokens=completion_tokens, + ) + + # Publish to SQS β€” exclude None fields to keep the message lean + sqs_client.send_message( + QueueUrl=queue_url, + MessageBody=metrics_event.model_dump_json(exclude_none=True), + ) + + logger.info( + f"Published metrics event for user: {username} " + f"tokens: prompt={prompt_tokens} completion={completion_tokens}" + ) except Exception as e: # Don't fail the request if metrics publishing fails diff --git a/lib/serve/rest-api/src/utils/metrics_models.py b/lib/serve/rest-api/src/utils/metrics_models.py new file mode 100644 index 000000000..90b92c22a --- /dev/null +++ b/lib/serve/rest-api/src/utils/metrics_models.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic models for metrics events published to SQS. + +This module is intentionally kept in sync with lambda/metrics/models.py. +The two files live in separate deployment contexts (FastAPI container vs Lambda) +and cannot share code directly, so any schema changes must be applied to both. +""" + +from typing import Any + +from pydantic import BaseModel + + +class MetricsEvent(BaseModel): + """Event model for usage metrics published to SQS. + + event_type : str + "full" β€” API token user or session-lambda event; owns all metrics. + "token_only" β€” JWT/UI passthrough event; only carries token counts, session + lambda already counted the prompts. Do not write a sessionMetrics + entry β€” that would create synthetic sessions and pollute aggregation. + """ + + userId: str + sessionId: str + messages: list[dict[str, Any]] + userGroups: list[str] + timestamp: str + eventType: str = "full" + modelId: str | None = None + promptTokens: int | None = None + completionTokens: int | None = None diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index 46f4f045f..2da9b1c65 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -91,3 +91,43 @@ async def wrapper(*args: Any, **kwargs: Any) -> AsyncGenerator[str]: yield f"data:{json.dumps({'event': 'error', 'data': error_message})}\n\n" return wrapper + + +def get_lisa_end_user_id( + jwt_data: dict[str, Any] | None, + state_username: str | None, +) -> str | None: + """ + Derive a human-readable end-user id for logs/spend attribution. + + LiteLLM uses the provided end-user identifier for spend/budget/logging. + We prefer the same claims used by the authorizer/session to make the + logs match what admins see in the UI/session DB. + + Precedence (highest to lowest): + 1. jwt_data["cognito:username"] (if present) + 2. jwt_data["username"] (if present and no cognito:username) + 3. jwt_data["sub"] + 4. fallback to state_username (request.state.username) + """ + candidate: str | None = None + if isinstance(jwt_data, dict): + username = jwt_data.get("username") + candidate = username if isinstance(username, str) and username else None + + cognito_username = jwt_data.get("cognito:username") + if isinstance(cognito_username, str) and cognito_username: + candidate = cognito_username + + if not candidate: + sub = jwt_data.get("sub") + if isinstance(sub, str) and sub: + candidate = sub + + if candidate: + return candidate + + if isinstance(state_username, str) and state_username: + return state_username + + return None diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index fc4778cf3..7d7f068e8 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -14,6 +14,7 @@ limitations under the License. */ import { Duration, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; +import * as cloudwatch from 'aws-cdk-lib/aws-cloudwatch'; import { ITable, Table } from 'aws-cdk-lib/aws-dynamodb'; import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; @@ -27,6 +28,7 @@ import { Effect, Policy, PolicyStatement, + Role, } from 'aws-cdk-lib/aws-iam'; import { HostedRotation } from 'aws-cdk-lib/aws-secretsmanager'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; @@ -36,7 +38,6 @@ import { AwsCustomResource, AwsCustomResourcePolicy, PhysicalResourceId } from ' import { ISecurityGroup, Port } from 'aws-cdk-lib/aws-ec2'; import { ECSTasks } from '../api-base/ecsCluster'; import { GuardrailsTable } from '../models/guardrails-table'; -import { Role } from 'aws-cdk-lib/aws-iam'; export type LisaServeApplicationProps = { vpc: Vpc; @@ -428,6 +429,69 @@ export class LisaServeApplicationConstruct extends Construct { serveRole.attachInlinePolicy(invocation_permissions); } } - }; + + // ===================================================================== + // REST API ALB Alarms + // ===================================================================== + // These alarms use the REST API ALB's concrete dimensions (known at deploy + // time). Model ALB alarms are not created here because model ALBs are + // dynamic and CloudWatch does not support SEARCH in Metric Alarms. + // Model ALB health is monitored via SEARCH-based dashboard widgets in + // the ModelHealthDashboard. + if (config.deployHealthDashboard) { + const restAlb = restApi.apiCluster.loadBalancer; + const restAlbFullName = restAlb.loadBalancerFullName; + const alarmPrefix = `${config.deploymentName}-${config.deploymentStage}-LISA`; + + new cloudwatch.Alarm(scope, 'RestApi-ELB5xxAlarm', { + alarmName: `${alarmPrefix}-RestApi-ELB5xxErrors`, + alarmDescription: 'REST API ALB is returning 5xx errors, typically meaning no healthy targets are available.', + metric: new cloudwatch.Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'HTTPCode_ELB_5XX_Count', + dimensionsMap: { LoadBalancer: restAlbFullName }, + statistic: 'Sum', + period: Duration.minutes(5), + }), + threshold: 5, + comparisonOperator: cloudwatch.ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 2, + treatMissingData: cloudwatch.TreatMissingData.NOT_BREACHING, + }); + + new cloudwatch.Alarm(scope, 'RestApi-HighLatencyAlarm', { + alarmName: `${alarmPrefix}-RestApi-HighP99Latency`, + alarmDescription: 'REST API p99 response time exceeds 120 seconds. The API may be overloaded.', + metric: new cloudwatch.Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'TargetResponseTime', + dimensionsMap: { LoadBalancer: restAlbFullName }, + statistic: 'p99', + period: Duration.minutes(5), + }), + threshold: 120, + comparisonOperator: cloudwatch.ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 3, + treatMissingData: cloudwatch.TreatMissingData.NOT_BREACHING, + }); + + new cloudwatch.Alarm(scope, 'RestApi-RejectedConnectionsAlarm', { + alarmName: `${alarmPrefix}-RestApi-RejectedConnections`, + alarmDescription: 'REST API ALB is rejecting connections, indicating the API is at maximum capacity.', + metric: new cloudwatch.Metric({ + namespace: 'AWS/ApplicationELB', + metricName: 'RejectedConnectionCount', + dimensionsMap: { LoadBalancer: restAlbFullName }, + statistic: 'Sum', + period: Duration.minutes(5), + }), + threshold: 0, + comparisonOperator: cloudwatch.ComparisonOperator.GREATER_THAN_THRESHOLD, + evaluationPeriods: 2, + treatMissingData: cloudwatch.TreatMissingData.NOT_BREACHING, + }); + } + + } } diff --git a/lib/stages.ts b/lib/stages.ts index 8b217f19c..f1236ba73 100644 --- a/lib/stages.ts +++ b/lib/stages.ts @@ -270,6 +270,7 @@ export class LisaServeApplicationStage extends Stage { ...baseStackProps, stackName: createCdkId([config.deploymentName, config.appName, 'API']), description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`, + bucketAccessLogsBucket: coreStack.loggingBucket, vpc: networkingStack.vpc, securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], }); @@ -306,6 +307,7 @@ export class LisaServeApplicationStage extends Stage { const mcpApiStack = new LisaMcpApiStack(this, 'LisaMcpApi', { ...baseStackProps, authorizer: apiBaseStack.authorizer!, + bucketAccessLogsBucket: coreStack.loggingBucket, description: `LISA-mcp: ${config.deploymentName}-${config.deploymentStage}`, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, @@ -337,6 +339,7 @@ export class LisaServeApplicationStage extends Stage { } if (config.deployServe) { + let mcpWorkbenchStackInstance: McpWorkbenchStack | undefined; const serveStack = new LisaServeApplicationStack(this, 'LisaServe', { ...baseStackProps, description: `LISA-serve: ${config.deploymentName}-${config.deploymentStage}`, @@ -359,6 +362,7 @@ export class LisaServeApplicationStage extends Stage { const modelsApiDeploymentStack = new LisaModelsApiStack(this, 'LisaModelsApiDeployment', { ...baseStackProps, authorizer: apiBaseStack.authorizer, + bucketAccessLogsBucket: coreStack.loggingBucket, description: `LISA-models: ${config.deploymentName}-${config.deploymentStage}`, lisaServeEndpointUrlPs: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, guardrailsTable: serveStack.guardrailsTable, @@ -378,27 +382,28 @@ export class LisaServeApplicationStage extends Stage { this.stacks.push(modelsApiDeploymentStack); if (config.deployMcpWorkbench) { - const mcpWorkbenchStack = new McpWorkbenchStack(this, 'LisaMcpWorkbench', { + mcpWorkbenchStackInstance = new McpWorkbenchStack(this, 'LisaMcpWorkbench', { ...baseStackProps, + bucketAccessLogsBucket: coreStack.loggingBucket, stackName: createCdkId([config.deploymentName, config.appName, 'mcp-workbench', config.deploymentStage]), description: `LISA-mcp-workbench: ${config.deploymentName}-${config.deploymentStage}`, vpc: networkingStack.vpc, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, - apiCluster: serveStack.restApi.apiCluster, authorizer: apiBaseStack.authorizer, }); - mcpWorkbenchStack.addDependency(coreStack); - mcpWorkbenchStack.addDependency(apiBaseStack); - mcpWorkbenchStack.addDependency(serveStack); - apiDeploymentStack.addDependency(mcpWorkbenchStack); - this.stacks.push(mcpWorkbenchStack); + mcpWorkbenchStackInstance.addDependency(coreStack); + mcpWorkbenchStackInstance.addDependency(apiBaseStack); + apiDeploymentStack.addDependency(mcpWorkbenchStackInstance); + this.stacks.push(mcpWorkbenchStackInstance); + serveStack.addDependency(mcpWorkbenchStackInstance); } if (config.deployRag) { const ragStack = new LisaRagStack(this, 'LisaRAG', { ...baseStackProps, authorizer: apiBaseStack.authorizer!, + bucketAccessLogsBucket: coreStack.loggingBucket, description: `LISA-rag: ${config.deploymentName}-${config.deploymentStage}`, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, @@ -438,6 +443,10 @@ export class LisaServeApplicationStage extends Stage { chatStack.addDependency(modelsApiDeploymentStack); // ChatStack reads: serve/endpoint from ServeStack chatStack.addDependency(serveStack); + // ChatStack reads: mcpWorkbench/endpoint when MCP Workbench is deployed + if (mcpWorkbenchStackInstance) { + chatStack.addDependency(mcpWorkbenchStackInstance); + } // ChatStack reads: queue-name/usage-metrics from MetricsStack (if deployMetrics) if (metricsStack) { chatStack.addDependency(metricsStack); @@ -449,6 +458,7 @@ export class LisaServeApplicationStage extends Stage { const uiStack = new UserInterfaceStack(this, 'LisaUserInterface', { ...baseStackProps, architecture: ARCHITECTURE, + bucketAccessLogsBucket: coreStack.loggingBucket, stackName: createCdkId([config.deploymentName, config.appName, 'ui', config.deploymentStage]), description: `LISA-user-interface: ${config.deploymentName}-${config.deploymentStage}`, restApiId: apiBaseStack.restApiId, @@ -460,6 +470,10 @@ export class LisaServeApplicationStage extends Stage { // UIStack reads: lisaServeRestApiUri from ServeStack uiStack.addDependency(serveStack); uiStack.addDependency(apiBaseStack); + // UIStack reads: mcpWorkbench/endpoint when MCP Workbench is deployed (AWS session + MCP browser calls) + if (mcpWorkbenchStackInstance) { + uiStack.addDependency(mcpWorkbenchStackInstance); + } apiDeploymentStack.addDependency(uiStack); this.stacks.push(uiStack); } @@ -469,7 +483,8 @@ export class LisaServeApplicationStage extends Stage { if (config.deployDocs) { const docsStack = new LisaDocsStack(this, 'LisaDocs', { - ...baseStackProps + ...baseStackProps, + bucketAccessLogsBucket: coreStack.loggingBucket, }); // DocsStack reads: bucket/bucket-access-logs from CoreStack docsStack.addDependency(coreStack); diff --git a/lib/user-interface/react/index.html b/lib/user-interface/react/index.html index 0dbad35ba..3a149fd90 100644 --- a/lib/user-interface/react/index.html +++ b/lib/user-interface/react/index.html @@ -5,31 +5,9 @@ AWS LISA AI Chat Assistant - - - - +
- diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 47fc88615..80bd31dc7 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "6.4.0", + "version": "6.5.0", "type": "module", "scripts": { "postinstall": "patch-package", @@ -65,12 +65,14 @@ "tinyglobby": "^0.2.15", "typescript": "~5.9.3", "unraw": "^3.0.0", + "@modelcontextprotocol/sdk": "~1.27.1", "use-mcp": "^0.0.21", "vitepress": "^1.6.4", "web-namespaces": "^2.0.1", "zod": "^4.1.13" }, "devDependencies": { + "@rolldown/pluginutils": "^1.0.0-beta.47", "@tailwindcss/vite": "^4.1.18", "@testing-library/jest-dom": "^6.9.1", "@testing-library/react": "^16.3.0", @@ -81,12 +83,9 @@ "@types/react": "^19.2.9", "@types/react-dom": "^19.2.3", "@types/redux-mock-store": "^1.5.0", - "@types/redux-persist": "^4.3.1", - "@types/uuid": "^11.0.0", "@typescript-eslint/eslint-plugin": "^8.49.0", "@typescript-eslint/parser": "^8.49.0", "@vitejs/plugin-react-swc": "^4.2.2", - "@vitest/coverage-istanbul": "^4.0.15", "@vitest/coverage-v8": "^4.0.15", "@vitest/ui": "^4.0.15", "eslint": "^10.0.2", diff --git a/lib/user-interface/react/src/App.test.tsx b/lib/user-interface/react/src/App.test.tsx new file mode 100644 index 000000000..d59c4b2b5 --- /dev/null +++ b/lib/user-interface/react/src/App.test.tsx @@ -0,0 +1,190 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { render, screen } from '@testing-library/react'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { MemoryRouter } from 'react-router-dom'; +import { Provider } from 'react-redux'; +import { configureStore } from '@reduxjs/toolkit'; + +import App from './App'; +import { + selectCurrentUserIsAdmin, + selectCurrentUserIsUser, + selectCurrentUserIsRagAdmin, + selectCurrentUserIsApiUser, + selectCurrentUsername, +} from './shared/reducers/user.reducer'; + +// Mock auth +vi.mock('./auth/useAuth'); + +// Mock store - useAppSelector matches by selector function reference +vi.mock('./config/store', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + useAppDispatch: vi.fn(() => vi.fn()), + useAppSelector: vi.fn(), + }; +}); + +// Mock lazy-loaded pages to avoid Suspense complexity +vi.mock('./pages/Home', () => ({ default: () =>
Home
})); +vi.mock('./pages/Chatbot', () => ({ default: () =>
Chatbot
})); +vi.mock('./pages/RepositoryManagement', () => ({ default: () =>
Repository Management
})); +vi.mock('./pages/ModelManagement', () => ({ default: () =>
Model Management
})); +vi.mock('./pages/Configuration', () => ({ default: () =>
Configuration
})); +vi.mock('./pages/ApiTokenManagement', () => ({ default: () =>
API Token Management
})); + +// Mock configuration query +vi.mock('./shared/reducers/configuration.reducer', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + useGetConfigurationQuery: vi.fn(() => ({ data: undefined, isLoading: false })), + }; +}); + +// Mock notification hook +vi.mock('./shared/hooks/useAnnouncementNotifier', () => ({ + useAnnouncementNotifier: vi.fn(), +})); + +// Mock Topbar to simplify rendering +vi.mock('./components/Topbar', () => ({ default: () =>
Topbar
})); + +// Mock system banner +vi.mock('./components/system-banner/system-banner', () => ({ default: () => null })); + +// Mock notification banner +vi.mock('./shared/notification/notification', () => ({ default: () => null })); + +// Mock confirmation modal +vi.mock('./shared/modal/confirmation-modal', () => ({ default: () => null })); + +// Mock breadcrumbs +vi.mock('./shared/breadcrumb/breadcrumbs', () => ({ Breadcrumbs: () => null })); +vi.mock('./shared/breadcrumb/breadcrumbs-change-listener', () => ({ default: () => null })); + +// Helper to create selector mock for different role combinations +type RoleMockConfig = { + isAdmin?: boolean; + isUser?: boolean; + isRagAdmin?: boolean; + isApiUser?: boolean; +}; + +const createSelectorMock = (roles: RoleMockConfig) => { + return (selector: any) => { + if (selector === selectCurrentUserIsAdmin) return roles.isAdmin ?? false; + if (selector === selectCurrentUserIsRagAdmin) return roles.isRagAdmin ?? false; + if (selector === selectCurrentUserIsUser) return roles.isUser ?? false; + if (selector === selectCurrentUserIsApiUser) return roles.isApiUser ?? false; + if (selector === selectCurrentUsername) return 'Test User'; + // Inline selectors (e.g., confirmationModal) β€” return safe defaults + return null; + }; +}; + +const mockStore = configureStore({ + reducer: { + user: () => ({ info: undefined }), + modal: () => ({ confirmationModal: null }), + }, +}); + +const renderApp = (route: string) => { + return render( + + + + + + ); +}; + +describe('Route Guards', () => { + beforeEach(async () => { + vi.clearAllMocks(); + (window as any).env = { + ...window.env, + RAG_ENABLED: true, + HOSTED_MCP_ENABLED: false, + }; + }); + + describe('RagAdminRoute (/repository-management)', () => { + it('renders children when user isRagAdmin', async () => { + const { useAuth } = await import('./auth/useAuth'); + (useAuth as any).mockReturnValue({ isAuthenticated: true, isLoading: false }); + + const { useAppSelector } = await import('./config/store'); + (useAppSelector as any).mockImplementation(createSelectorMock({ isRagAdmin: true })); + + renderApp('/repository-management'); + expect(await screen.findByTestId('repo-management-page')).toBeInTheDocument(); + }); + + it('renders children when user isAdmin', async () => { + const { useAuth } = await import('./auth/useAuth'); + (useAuth as any).mockReturnValue({ isAuthenticated: true, isLoading: false }); + + const { useAppSelector } = await import('./config/store'); + (useAppSelector as any).mockImplementation(createSelectorMock({ isAdmin: true })); + + renderApp('/repository-management'); + expect(await screen.findByTestId('repo-management-page')).toBeInTheDocument(); + }); + + it('redirects when user is regular user', async () => { + const { useAuth } = await import('./auth/useAuth'); + (useAuth as any).mockReturnValue({ isAuthenticated: true, isLoading: false }); + + const { useAppSelector } = await import('./config/store'); + (useAppSelector as any).mockImplementation(createSelectorMock({ isUser: true })); + + renderApp('/repository-management'); + expect(screen.queryByTestId('repo-management-page')).not.toBeInTheDocument(); + }); + }); + + describe('PrivateRoute (/ai-assistant)', () => { + it('renders children when user isRagAdmin', async () => { + const { useAuth } = await import('./auth/useAuth'); + (useAuth as any).mockReturnValue({ isAuthenticated: true, isLoading: false }); + + const { useAppSelector } = await import('./config/store'); + (useAppSelector as any).mockImplementation(createSelectorMock({ isRagAdmin: true })); + + renderApp('/ai-assistant'); + expect(await screen.findByTestId('chatbot-page')).toBeInTheDocument(); + }); + }); + + describe('AdminRoute (/model-management)', () => { + it('blocks rag-admin from admin-only routes', async () => { + const { useAuth } = await import('./auth/useAuth'); + (useAuth as any).mockReturnValue({ isAuthenticated: true, isLoading: false }); + + const { useAppSelector } = await import('./config/store'); + (useAppSelector as any).mockImplementation(createSelectorMock({ isRagAdmin: true })); + + renderApp('/model-management'); + expect(screen.queryByTestId('model-management-page')).not.toBeInTheDocument(); + }); + }); +}); diff --git a/lib/user-interface/react/src/App.tsx b/lib/user-interface/react/src/App.tsx index ac860dfbd..95c3e8daa 100644 --- a/lib/user-interface/react/src/App.tsx +++ b/lib/user-interface/react/src/App.tsx @@ -15,43 +15,43 @@ */ import 'regenerator-runtime/runtime'; -import { ReactElement, useEffect, useState } from 'react'; +import { lazy, ReactElement, Suspense, useEffect, useState } from 'react'; import { Navigate, Route, Routes } from 'react-router-dom'; import { AppLayout, Box } from '@cloudscape-design/components'; import Spinner from '@cloudscape-design/components/spinner'; import { useAuth } from './auth/useAuth'; -import Home from './pages/Home'; -import Chatbot from './pages/Chatbot'; import Topbar from './components/Topbar'; import SystemBanner from './components/system-banner/system-banner'; import { useAppSelector } from './config/store'; -import { selectCurrentUserIsAdmin, selectCurrentUserIsUser, selectCurrentUserIsApiUser } from './shared/reducers/user.reducer'; -import ModelManagement from './pages/ModelManagement'; -import McpManagement from './pages/McpManagement'; -import ModelLibrary from './pages/ModelLibrary'; -import RepositoryManagement from './pages/RepositoryManagement'; -import ApiTokenManagement from './pages/ApiTokenManagement'; -import UserApiToken from './pages/UserApiToken'; +import { selectCurrentUserIsAdmin, selectCurrentUserIsUser, selectCurrentUserIsApiUser, selectCurrentUserIsRagAdmin } from './shared/reducers/user.reducer'; import NotificationBanner from './shared/notification/notification'; import ConfirmationModal, { ConfirmationModalProps } from './shared/modal/confirmation-modal'; -import Configuration from './pages/Configuration'; -import ChatAssistantStacks from './pages/ChatAssistantStacks'; import { useGetConfigurationQuery } from './shared/reducers/configuration.reducer'; import { IConfiguration } from './shared/model/configuration.model'; -import DocumentLibrary from './pages/DocumentLibrary'; -import CollectionLibrary from './pages/CollectionLibrary'; import { Breadcrumbs } from './shared/breadcrumb/breadcrumbs'; import BreadcrumbsDefaultChangeListener from './shared/breadcrumb/breadcrumbs-change-listener'; -import PromptTemplatesLibrary from './pages/PromptTemplatesLibrary'; import { ConfigurationContext } from './shared/configuration.provider'; -import McpServers from '@/pages/Mcp'; -import ModelComparisonPage from './pages/ModelComparison'; -import McpWorkbench from './pages/McpWorkbench'; import ColorSchemeContext from './shared/color-scheme.provider'; import { applyMode, Mode } from '@cloudscape-design/global-styles'; import { useAnnouncementNotifier } from './shared/hooks/useAnnouncementNotifier'; +const Home = lazy(() => import('./pages/Home')); +const Chatbot = lazy(() => import('./pages/Chatbot')); +const ModelManagement = lazy(() => import('./pages/ModelManagement')); +const McpManagement = lazy(() => import('./pages/McpManagement')); +const ModelLibrary = lazy(() => import('./pages/ModelLibrary')); +const RepositoryManagement = lazy(() => import('./pages/RepositoryManagement')); +const ApiTokenManagement = lazy(() => import('./pages/ApiTokenManagement')); +const UserApiToken = lazy(() => import('./pages/UserApiToken')); +const Configuration = lazy(() => import('./pages/Configuration')); +const DocumentLibrary = lazy(() => import('./pages/DocumentLibrary')); +const CollectionLibrary = lazy(() => import('./pages/CollectionLibrary')); +const PromptTemplatesLibrary = lazy(() => import('./pages/PromptTemplatesLibrary')); +const McpServers = lazy(() => import('@/pages/Mcp')); +const ModelComparisonPage = lazy(() => import('./pages/ModelComparison')); +const McpWorkbench = lazy(() => import('./pages/McpWorkbench')); +const ChatAssistantStacks = lazy(() => import('./pages/ChatAssistantStacks')); export type RouteProps = { children: ReactElement[] | ReactElement; @@ -64,12 +64,13 @@ const PrivateRoute = ({ children }: RouteProps) => { const auth = useAuth(); const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); const isUser = useAppSelector(selectCurrentUserIsUser); + const isRagAdmin = useAppSelector(selectCurrentUserIsRagAdmin); - if (auth.isAuthenticated && (isUserAdmin || isUser)) { + if (auth.isAuthenticated && (isUserAdmin || isUser || isRagAdmin)) { return children; } else if (auth.isLoading) { return ; - } else if (auth.isAuthenticated && !isUserAdmin && !isUser) { + } else if (auth.isAuthenticated && !isUserAdmin && !isUser && !isRagAdmin) { return (

Access Denied

@@ -93,6 +94,19 @@ const AdminRoute = ({ children }: RouteProps) => { } }; +const RagAdminRoute = ({ children }: RouteProps) => { + const auth = useAuth(); + const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); + const isRagAdmin = useAppSelector(selectCurrentUserIsRagAdmin); + if (auth.isAuthenticated && (isUserAdmin || isRagAdmin)) { + return children; + } else if (auth.isLoading) { + return ; + } else { + return ; + } +}; + const ApiUserRoute = ({ children }: RouteProps) => { const auth = useAuth(); const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); @@ -106,6 +120,13 @@ const ApiUserRoute = ({ children }: RouteProps) => { } }; +const RouteLoadingFallback = () => ( +
+ + Loading page... +
+); + function App () { const [nav, setNav] = useState(null); const confirmationModal: ConfirmationModalProps = useAppSelector((state) => state.modal.confirmationModal); @@ -158,153 +179,155 @@ function App () { navigation={nav} navigationWidth={300} content={ - - - - - } - /> - - - - } - /> - - - - } - /> - {window.env.HOSTED_MCP_ENABLED && - - - } - />} - {window.env.RAG_ENABLED && - - - } - />} - - - - } - /> - }> + + + + + } + /> + + + + } + /> + - + - ) : ( - - ) - } - /> - {config?.configuration?.enabledComponents?.enableUserApiTokens && - - - } - />} - {config?.configuration?.enabledComponents?.modelLibrary && - - - } - />} - {config?.configuration?.enabledComponents?.showRagLibrary && - <> - - - - } - /> - - - - } - /> - } - {config?.configuration?.enabledComponents?.showPromptTemplateLibrary && - - - } - />} - - - - } - /> - {config?.configuration?.enabledComponents?.chatAssistantStacks && - - - } - />} - {config?.configuration?.enabledComponents?.mcpConnections && - - - } - />} - {config?.configuration?.enabledComponents?.enableModelComparisonUtility && - - + } + /> + {window.env.HOSTED_MCP_ENABLED && + + + } + />} + {window.env.RAG_ENABLED && + + + } + />} + + + + } + /> + + + + ) : ( + + ) + } + /> + {config?.configuration?.enabledComponents?.enableUserApiTokens && + + + } + />} + {config?.configuration?.enabledComponents?.modelLibrary && + + + } + />} + {config?.configuration?.enabledComponents?.showRagLibrary && + <> + + + + } + /> + + + + } + /> + } + {config?.configuration?.enabledComponents?.showPromptTemplateLibrary && + + + } + />} + + + + } + /> + {config?.configuration?.enabledComponents?.mcpConnections && + + + } + />} + {config?.configuration?.enabledComponents?.enableModelComparisonUtility && + + + } + /> } - /> - } - - - Loading configuration... -
- : - - } /> - + {config?.configuration?.enabledComponents?.chatAssistantStacks && + + + } + />} + + + Loading configuration... + + : + + } /> + + } /> {confirmationModal && } diff --git a/lib/user-interface/react/src/components/Topbar.test.tsx b/lib/user-interface/react/src/components/Topbar.test.tsx index 945f93783..2e285fb9b 100644 --- a/lib/user-interface/react/src/components/Topbar.test.tsx +++ b/lib/user-interface/react/src/components/Topbar.test.tsx @@ -25,20 +25,20 @@ import { configureStore } from '@reduxjs/toolkit'; import Topbar from './Topbar'; import ColorSchemeContext from '@/shared/color-scheme.provider'; import { Mode } from '@cloudscape-design/global-styles'; +import { + selectCurrentUserIsAdmin, + selectCurrentUserIsRagAdmin, + selectCurrentUserIsApiUser, + selectCurrentUsername, +} from '@/shared/reducers/user.reducer'; // Mock the auth abstraction vi.mock('../auth/useAuth'); -// Mock store functions +// Mock store functions - use selector reference matching vi.mock('@/config/store', () => ({ useAppDispatch: vi.fn(() => vi.fn()), - useAppSelector: vi.fn((selector) => { - const selectorStr = selector.toString(); - if (selectorStr.includes('selectCurrentUserIsAdmin')) return false; - if (selectorStr.includes('selectCurrentUserIsApiUser')) return false; - if (selectorStr.includes('selectCurrentUsername')) return 'Test User'; - return null; - }), + useAppSelector: vi.fn(), })); const mockAuth = { @@ -79,10 +79,20 @@ const renderTopbar = (props = {}) => { }; describe('Topbar', () => { - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks(); (useAuth as any).mockReturnValue(mockAuth); + // Set default selector mock (regular user, no admin roles) + const storeModule = await import('@/config/store'); + (storeModule.useAppSelector as any).mockImplementation((selector: any) => { + if (selector === selectCurrentUserIsAdmin) return false; + if (selector === selectCurrentUserIsRagAdmin) return false; + if (selector === selectCurrentUserIsApiUser) return false; + if (selector === selectCurrentUsername) return 'Test User'; + return null; + }); + // Mock window.env (window as any).env = { CLIENT_ID: 'test-client-id', @@ -102,6 +112,91 @@ describe('Topbar', () => { expect(mockAuth.signoutRedirect).toHaveBeenCalledOnce(); }); + it('shows Administration with only RAG Management for rag-admin user', async () => { + const storeModule = await import('@/config/store'); + (storeModule.useAppSelector as any).mockImplementation((selector: any) => { + if (selector === selectCurrentUserIsAdmin) return false; + if (selector === selectCurrentUserIsRagAdmin) return true; + if (selector === selectCurrentUserIsApiUser) return false; + if (selector === selectCurrentUsername) return 'RAG Admin User'; + return null; + }); + (window as any).env = { + ...window.env, + RAG_ENABLED: true, + }; + + const user = userEvent.setup(); + renderTopbar(); + + // Should see Administration dropdown (Cloudscape renders duplicate text in collapsed/expanded views) + const adminDropdowns = screen.getAllByText('Administration'); + expect(adminDropdowns.length).toBeGreaterThan(0); + + // Click to open dropdown + await user.click(adminDropdowns[0]); + + // Should see RAG Management + expect(screen.getByText('RAG Management')).toBeInTheDocument(); + + // Should NOT see admin-only items + expect(screen.queryByText('Configuration')).not.toBeInTheDocument(); + expect(screen.queryByText('Model Management')).not.toBeInTheDocument(); + expect(screen.queryByText('API Token Management')).not.toBeInTheDocument(); + }); + + it('shows all admin items for admin user', async () => { + const storeModule = await import('@/config/store'); + (storeModule.useAppSelector as any).mockImplementation((selector: any) => { + if (selector === selectCurrentUserIsAdmin) return true; + if (selector === selectCurrentUserIsRagAdmin) return false; + if (selector === selectCurrentUserIsApiUser) return false; + if (selector === selectCurrentUsername) return 'Admin User'; + return null; + }); + (window as any).env = { + ...window.env, + RAG_ENABLED: true, + }; + + const user = userEvent.setup(); + renderTopbar(); + + // Cloudscape TopNavigation renders duplicate text in collapsed/expanded views + const adminDropdowns = screen.getAllByText('Administration'); + expect(adminDropdowns.length).toBeGreaterThan(0); + await user.click(adminDropdowns[0]); + + expect(screen.getByText('Configuration')).toBeInTheDocument(); + expect(screen.getByText('Model Management')).toBeInTheDocument(); + expect(screen.getByText('RAG Management')).toBeInTheDocument(); + expect(screen.getByText('API Token Management')).toBeInTheDocument(); + }); + + it('hides Administration for rag-admin when RAG_ENABLED is false', async () => { + const storeModule = await import('@/config/store'); + (storeModule.useAppSelector as any).mockImplementation((selector: any) => { + if (selector === selectCurrentUserIsAdmin) return false; + if (selector === selectCurrentUserIsRagAdmin) return true; + if (selector === selectCurrentUserIsApiUser) return false; + if (selector === selectCurrentUsername) return 'RAG Admin User'; + return null; + }); + (window as any).env = { + ...window.env, + RAG_ENABLED: false, + }; + + renderTopbar(); + expect(screen.queryByText('Administration')).not.toBeInTheDocument(); + }); + + it('hides Administration for regular user', () => { + // Default mock already returns isAdmin=false, isRagAdmin=false + renderTopbar(); + expect(screen.queryByText('Administration')).not.toBeInTheDocument(); + }); + it('calls signinRedirect when sign in is clicked for unauthenticated user', async () => { const user = userEvent.setup(); @@ -120,9 +215,10 @@ describe('Topbar', () => { // Click the sign in option await user.click(screen.getByText('Sign in')); - // Verify that signinRedirect was called with correct redirect_uri + // Verify that signinRedirect was called with correct redirect_uri (no hash, per OAuth spec) + const { getRedirectUri } = await import('@/config/oidc.config'); expect(mockAuth.signinRedirect).toHaveBeenCalledWith({ - redirect_uri: window.location.toString(), + redirect_uri: getRedirectUri(), }); }); diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index 45c8b6e8c..4ff959d8c 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -20,11 +20,11 @@ import { useHref, useNavigate } from 'react-router-dom'; import { applyDensity, Density, Mode } from '@cloudscape-design/global-styles'; import TopNavigation, { TopNavigationProps } from '@cloudscape-design/components/top-navigation'; import { useAppDispatch, useAppSelector } from '@/config/store'; -import { selectCurrentUserIsAdmin, selectCurrentUserIsApiUser, selectCurrentUsername } from '../shared/reducers/user.reducer'; +import { selectCurrentUserIsAdmin, selectCurrentUserIsApiUser, selectCurrentUserIsRagAdmin, selectCurrentUsername } from '../shared/reducers/user.reducer'; import { IConfiguration } from '@/shared/model/configuration.model'; import { ButtonDropdownProps } from '@cloudscape-design/components'; import ColorSchemeContext from '@/shared/color-scheme.provider'; -import { OidcConfig } from '@/config/oidc.config'; +import { OidcConfig, getRedirectUri } from '@/config/oidc.config'; import { getBrandingAssetPath } from '../shared/util/branding'; import { getDisplayName } from '@/shared/util/branding'; import { useDeleteAllSessionsForUserMutation } from '@/shared/reducers/session.reducer'; @@ -43,6 +43,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { const dispatch = useAppDispatch(); const notificationService = useNotificationService(dispatch); const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); + const isUserRagAdmin = useAppSelector(selectCurrentUserIsRagAdmin); const isApiUser = useAppSelector(selectCurrentUserIsApiUser); const userName = useAppSelector(selectCurrentUsername); const { colorScheme, setColorScheme } = useContext(ColorSchemeContext); @@ -103,6 +104,8 @@ function Topbar ({ configs }: TopbarProps): ReactElement { } as ButtonDropdownProps.Item] : []) ].sort((a,b) => a.text.localeCompare(b.text)); + const showAdminDropdown = isUserAdmin || (isUserRagAdmin && window.env.RAG_ENABLED); + return ( { return window.env.API_GROUP ? userGroups.includes(window.env.API_GROUP) : false; }; +const isRagAdmin = (userGroups: any): boolean => { + return window.env.RAG_ADMIN_GROUP ? userGroups.includes(window.env.RAG_ADMIN_GROUP) : false; +}; + function AppConfigured () { const dispatch = useAppDispatch(); const [oidcUser, setOidcUser] = useState(); @@ -114,6 +119,7 @@ function AppConfigured () { isAdmin: userGroups ? isAdmin(userGroups) : false, isUser: window.env.USER_GROUP ? userGroups && isUser(userGroups) : true, isApiUser: window.env.API_GROUP ? userGroups && isApiUser(userGroups) : false, + isRagAdmin: userGroups ? isRagAdmin(userGroups) : false, }), ); } @@ -137,7 +143,9 @@ function AppConfigured () { { - if ((window.env.USER_GROUP && user && isUser(getGroups(user.profile))) || !window.env.USER_GROUP) { + const userGroups = user ? getGroups(user.profile) : undefined; + const hasAccess = userGroups && (isUser(userGroups) || isRagAdmin(userGroups) || isAdmin(userGroups)); + if ((window.env.USER_GROUP && user && hasAccess) || !window.env.USER_GROUP) { window.history.replaceState({}, document.title, `${window.location.pathname}${window.location.hash}`); setOidcUser(user); } else { diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index 114bb7b79..1960f4882 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -94,6 +94,7 @@ export default function Chat ({ sessionId, initialStack }) { const dispatch = useAppDispatch(); const navigate = useNavigate(); const config: IConfiguration = useContext(ConfigurationContext); + const ragSelectionAvailable = config?.configuration?.enabledComponents?.ragSelectionAvailable ?? true; const notificationService = useNotificationService(dispatch); const modelSelectRef = useRef(null); const bottomRef = useRef(null); @@ -244,7 +245,7 @@ export default function Chat ({ sessionId, initialStack }) { const pendingToolChainExecution = useRef<(() => Promise) | null>(null); // Use the custom hook to manage multiple MCP connections - const { tools: mcpTools, callTool, McpConnections, toolToServerMap } = useMultipleMcp(enabledServers, userPreferences?.preferences?.mcp); + const { tools: mcpTools, callTool, McpConnections, toolToServerMap } = useMultipleMcp(enabledServers, userPreferences?.preferences?.mcp, session?.sessionId); const [updatePreferences, {isSuccess: isUpdatingPreferencesSuccess, isError: isUpdatingPreferencesError, isLoading: isUpdatingPreferences}] = useUpdateUserPreferencesMutation(); // Load markdown preview preference from user preferences @@ -322,11 +323,6 @@ export default function Chat ({ sessionId, initialStack }) { return inList?.status === ModelStatus.Stopped; }, [selectedModel, modelsForDropdown]); - const hasStoppedModelsInDropdown = useMemo(() => - (modelsForDropdown || []).some((m) => m.status === ModelStatus.Stopped), - [modelsForDropdown] - ); - // Set default model if none is selected, default model is configured, and user hasn't interacted (only InService models) const availableModelsForDefault = useMemo(() => (modelsForDropdown || []).filter((m) => m.status === ModelStatus.InService), @@ -894,7 +890,7 @@ export default function Chat ({ sessionId, initialStack }) { const getButtonItemsWithAssistantMode = useCallback((...args: Parameters) => { const [config, useRag, isImageGen, isVideoGen, isConnected, isModelDel, showMd] = args; return getButtonItems(config, useRag, isImageGen, isVideoGen, isConnected, isModelDel, showMd, !!effectiveStack, !!selectedModel, loadingSession); - }, [config, effectiveStack, selectedModel, loadingSession]); + }, [effectiveStack, selectedModel, loadingSession]); const promptInputProps = useMemo(() => ({ userPrompt, @@ -1192,6 +1188,7 @@ export default function Chat ({ sessionId, initialStack }) { > 0} statusType={isFetchingModels ? 'loading' : 'finished'} loadingText='Loading models (might take few seconds)...' @@ -1205,11 +1202,6 @@ export default function Chat ({ sessionId, initialStack }) { ref={modelSelectRef} controlId='model-selection-autosuggest' /> - {hasStoppedModelsInDropdown && ( - - Some models in the list are stopped and cannot be selected. - - )} {window.env.RAG_ENABLED && !isImageGenerationMode && !isVideoGenerationMode && ( @@ -1218,6 +1210,7 @@ export default function Chat ({ sessionId, initialStack }) { setUseRag={setUseRag} setRagConfig={setRagConfig} ragConfig={ragConfig} + selectionAvailable={ragSelectionAvailable} allowedRepositoryIds={effectiveStack ? (effectiveStack.repositoryIds ?? []) : undefined} allowedCollectionIds={effectiveStack ? (effectiveStack.collectionIds ?? []) : undefined} /> diff --git a/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx b/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx index 77eca1687..607200ad0 100644 --- a/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx @@ -95,6 +95,7 @@ export const ChatPromptInput: React.FC = ({ }; return ( { if (item.type === 'text' && typeof item.text === 'string') { - if (item.text.startsWith('File context:')) return null; + if ( + item.text.startsWith('File context:') || + item.text.startsWith('Context from document search:') + ) { + return null; + } const displayableText = getDisplayableMessage(item.text, message.type === MessageTypes.AI ? ragCitationsString : undefined); @@ -255,7 +260,7 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami return ( (message.type === MessageTypes.HUMAN || message.type === MessageTypes.AI || message.type === MessageTypes.TOOL) && -
+
{(isRunning && !callingToolName && !message?.metadata?.videoGeneration) && ( -
} - filteringType='auto' - value={selectedRepositoryOption} - enteredTextLabel={(text) => `Use: "${text}"`} - onChange={handleRepositoryChange} - options={filteredRepositories?.map((repository) => ({ - value: repository.repositoryId, - label: repository?.repositoryName?.length ? repository?.repositoryName : repository.repositoryId - })) || []} - controlId='rag-repository-autosuggest' - /> - No collections available.
} - filteringType='auto' - value={selectedCollectionOption} - enteredTextLabel={(text) => `Use: "${text}"`} - onChange={handleCollectionChange} - options={collectionOptions} - controlId='rag-collection-autosuggest' - /> - + {selectionAvailable === false ? null : ( + + No repositories available.} + filteringType='auto' + value={selectedRepositoryOption} + enteredTextLabel={(text) => `Use: "${text}"`} + onChange={handleRepositoryChange} + options={filteredRepositories?.map((repository) => ({ + value: repository.repositoryId, + label: repository?.repositoryName?.length ? repository?.repositoryName : repository.repositoryId + })) || []} + controlId='rag-repository-autosuggest' + /> + No collections available.} + filteringType='auto' + value={selectedCollectionOption} + enteredTextLabel={(text) => `Use: "${text}"`} + onChange={handleCollectionChange} + options={collectionOptions} + controlId='rag-collection-autosuggest' + /> + + )} ); } diff --git a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx index 66497e30b..88f9d2476 100644 --- a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx @@ -32,6 +32,7 @@ import { IModel, ModelType } from '@/shared/model/model-management.model'; import { IConfiguration } from '@/shared/model/configuration.model'; import { LisaChatSession } from '@/components/types'; import { ModelFeatures } from '@/components/types'; +import AwsCredentialsPanel from '@/components/settings/AwsCredentialsPanel'; export type SessionConfigurationProps = { title?: string; @@ -114,76 +115,91 @@ export const SessionConfiguration = ({ size='large' > - - updateSessionConfiguration('streaming', detail.checked)} - checked={chatConfiguration.sessionConfiguration.streaming} - disabled={!selectedModel?.streaming || isRunning} - > - Stream Responses - - updateSessionConfiguration('markdownDisplay', detail.checked)} - checked={chatConfiguration.sessionConfiguration.markdownDisplay} - > - Display Responses as Markdown - - {systemConfig && systemConfig.configuration.enabledComponents.viewMetaData && + {(() => { + const items = [ updateSessionConfiguration('showMetadata', detail.checked)} - checked={chatConfiguration.sessionConfiguration.showMetadata} - disabled={isRunning} + key='streaming' + onChange={({ detail }) => updateSessionConfiguration('streaming', detail.checked)} + checked={chatConfiguration.sessionConfiguration.streaming} + disabled={!selectedModel?.streaming || isRunning} > - Show Message Metadata - } - {systemConfig && systemConfig.configuration.enabledComponents.editChatHistoryBuffer && !isImageModel && !isVideoModel && !modelOnly && - - updateSessionConfiguration('ragTopK', parseInt(detail.selectedOption.value))} - options={oneThroughTenOptions} - /> - } - {selectedModel?.features?.find((feature) => feature.name === ModelFeatures.REASONING) && - - updateSessionConfiguration('chatHistoryBufferSize', parseInt(detail.selectedOption.value))} + options={oneThroughTenOptions} + /> + + ] : []), + ...(systemConfig && systemConfig.configuration.enabledComponents.editNumOfRagDocument && !isImageModel && !isVideoModel && !modelOnly ? [ + + updateSessionConfiguration('modelArgs', {...chatConfiguration.sessionConfiguration.modelArgs, reasoning_effort: detail.selectedOption.value })} + options={reasoningEffortOptions} + /> + , + updateSessionConfiguration('showReasoningContent', detail.checked)} + checked={chatConfiguration.sessionConfiguration.showReasoningContent} + disabled={isRunning} + > + Show Reasoning Content + + ] : []), + ]; + + return ( + ({ colspan: 6 }))}> + {items} + + ); + })()} {systemConfig && systemConfig.configuration.enabledComponents.editKwargs && !isImageModel && !isVideoModel && } } + {visible && session && systemConfig?.configuration?.enabledComponents?.mcpConnections && (systemConfig?.configuration?.enabledComponents?.awsSessions ?? false) && ( + + )} {isImageModel && ( { renderWithProviders(); - const newSessionButton = screen.getByRole('button', { name: /new/i }); - await user.click(newSessionButton); + const actionsContainer = screen.getByTestId('sessions-actions'); + const [dropdownTrigger] = within(actionsContainer).getAllByRole('button'); + await user.click(dropdownTrigger); - const newChatItem = await screen.findByText('New Chat'); + const newChatItem = await screen.findByRole('menuitem', { name: /new chat/i }); await user.click(newChatItem); expect(mockNewSession).toHaveBeenCalledOnce(); diff --git a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx index 25f4e8c4f..2838aefe9 100644 --- a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx @@ -277,6 +277,7 @@ export function Sessions ({ newSession }) { {projectsEnabled ? ( { setHistoryView(detail.selectedId); @@ -313,8 +314,10 @@ export function Sessions ({ newSession }) { Found {filteredSessions.length} session{filteredSessions.length !== 1 ? 's' : ''} )} -
+
- New - + /> @@ -398,6 +400,7 @@ export function Sessions ({ newSession }) { key={item.sessionId} padding='xxs' className={item.sessionId === currentSessionId ? styles.sessionItemActive : styles.sessionItem} + data-testid={item.sessionId === currentSessionId ? 'session-item-active' : 'session-item'} > diff --git a/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx b/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx index 79fff6b72..8676f04e6 100644 --- a/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx +++ b/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx @@ -85,6 +85,15 @@ const processReasoningContent = ( return { cleanedContent: parsed.cleanedContent, reasoningContent }; }; +/** +* Checks whether caught exceptions are due to a guardrail + * being triggered so that they can be handled gracefully. + */ +const isGuardrailError = (error: any): boolean => { + const msg = error?.error?.message || error?.message || ''; + return typeof msg === 'string' && msg.toLowerCase().includes('violated guardrail policy'); +}; + /** * Parses accumulated tool call data into final tool call objects. */ @@ -902,11 +911,30 @@ export const useChatGeneration = ({ await memory.saveContext({ input: params.input }, { output: finalCleanedContent }); setIsStreaming(false); } catch (exception) { - setSession((prev) => ({ - ...prev, - history: prev.history.slice(0, -1), - })); - throw exception; + if (isGuardrailError(exception)) { + // Handle gracefully β€” same as the in-stream guardrail path + setSession((prev) => { + const lastMessage = prev.history[prev.history.length - 1]; + if (lastMessage?.type === MessageTypes.AI) { + let updatedHistory = [...prev.history.slice(0, -1), + new LisaChatMessage({ + ...lastMessage, + guardrailTriggered: true, + }) + ]; + updatedHistory = markLastUserMessageAsGuardrailTriggered(updatedHistory); + return { ...prev, history: updatedHistory }; + } + return prev; + }); + // Do NOT rethrow β€” fall through to finally block + } else { + setSession((prev) => ({ + ...prev, + history: prev.history.slice(0, -1), + })); + throw exception; + } } } else { const response = await llmClient.invoke(messages, { tools: modelSupportsTools ? openAiTools : undefined }); diff --git a/lib/user-interface/react/src/components/chatbot/hooks/mcp.hooks.tsx b/lib/user-interface/react/src/components/chatbot/hooks/mcp.hooks.tsx index 4a8ca8220..769134e42 100644 --- a/lib/user-interface/react/src/components/chatbot/hooks/mcp.hooks.tsx +++ b/lib/user-interface/react/src/components/chatbot/hooks/mcp.hooks.tsx @@ -20,11 +20,21 @@ import { McpServer } from '@/shared/reducers/mcp-server.reducer'; import { McpPreferences } from '@/shared/reducers/user-preferences.reducer'; // Individual MCP Connection Component -export const McpConnection = ({ server, onToolsChange, onConnectionChange }: { +export const McpConnection = ({ server, onToolsChange, onConnectionChange, sessionId }: { server: McpServer, onToolsChange: (tools: any[], clientName: string) => void, - onConnectionChange: (connection: any, clientName: string) => void + onConnectionChange: (connection: any, clientName: string) => void, + sessionId?: string, }) => { + const customHeaders = server.customHeaders; + const mergedHeaders = useMemo(() => { + const base: Record = { ...(customHeaders ?? {}) }; + if (sessionId) { + base['X-Session-Id'] = sessionId; + } + return Object.keys(base).length > 0 ? base : undefined; + }, [customHeaders, sessionId]); + const connection = useMcp({ url: server?.url ?? ' ', clientName: server?.name, @@ -32,7 +42,7 @@ export const McpConnection = ({ server, onToolsChange, onConnectionChange }: { autoRetry: true, debug: false, clientConfig: server?.clientConfig ?? undefined, - customHeaders: server?.customHeaders ?? undefined, + customHeaders: mergedHeaders, callbackUrl: `${window.location.origin}${window.env.API_BASE_URL.includes('.') ? '/' : window.env.API_BASE_URL}oauth/callback`, }); @@ -61,7 +71,7 @@ export const McpConnection = ({ server, onToolsChange, onConnectionChange }: { }; // Custom hook to manage multiple MCP connections dynamically -export const useMultipleMcp = (servers: McpServer[], mcpPreferences: McpPreferences) => { +export const useMultipleMcp = (servers: McpServer[], mcpPreferences: McpPreferences, sessionId?: string) => { const [allTools, setAllTools] = useState([]); const [serverToolsMap, setServerToolsMap] = useState>(new Map()); const [connectionsMap, setConnectionsMap] = useState>(new Map()); @@ -131,10 +141,11 @@ export const useMultipleMcp = (servers: McpServer[], mcpPreferences: McpPreferen callTool, McpConnections: servers?.map((server) => ( )), toolToServerMap diff --git a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx index aa6fbc570..a781a315a 100644 --- a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx +++ b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx @@ -20,6 +20,7 @@ import { SetFieldsFunction } from '../../shared/validation'; const ragOptions = { uploadRagDocs: 'Document upload from Chat', + ragSelectionAvailable: 'RAG repository & collection selection', editNumOfRagDocument: 'Edit number of referenced documents', }; @@ -48,7 +49,8 @@ const advancedOptions = { const mcpOptions = { mcpConnections: 'MCP Server Connections', - showMcpWorkbench: 'MCP Workbench' + showMcpWorkbench: 'MCP Workbench', + awsSessions: 'MCP AWS Sessions' }; const apiTokenOptions = { @@ -75,7 +77,8 @@ const dependencies: DependencyMap<{ apiTokenOptions: typeof apiTokenOptions; }> = { showMcpWorkbench: { prerequisites: ['mcpConnections'] }, - mcpConnections: { dependents: ['showMcpWorkbench'] } + awsSessions: { prerequisites: ['mcpConnections'] }, + mcpConnections: { dependents: ['showMcpWorkbench', 'awsSessions'] } }; const configurableOperations = [{ @@ -176,7 +179,7 @@ export function ActivatedUserComponents (props: ActivatedComponentConfigurationP }} checked={isChecked} disabled={isDisabled} - data-cy={`Toggle-${item}`} + data-testid={`Toggle-${item}`} > {operation.items[item]} diff --git a/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx b/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx index 242815e1f..cee8300ef 100644 --- a/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx +++ b/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx @@ -183,7 +183,7 @@ export function ConfigurationComponent (): ReactElement { } }} loading={isUpdating} - data-cy='configuration-submit' + data-testid='configuration-submit' disabled={isUpdating || _.isEmpty(changesDiff)} > Save Changes diff --git a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx index 75d3950b4..b8f0ae642 100644 --- a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx +++ b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx @@ -16,7 +16,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { screen, waitFor } from '@testing-library/react'; -import { DocumentLibraryComponent, getMatchesCountText } from './DocumentLibraryComponent'; +import { DocumentLibraryComponent, canDeleteAll, getMatchesCountText } from './DocumentLibraryComponent'; import { renderWithProviders } from '../../test/helpers/render'; import { MemoryRouter } from 'react-router-dom'; import { createMockDocument } from '../../test/factories/document.factory'; @@ -34,6 +34,7 @@ describe('DocumentLibraryComponent', () => { // Mock Redux selectors vi.spyOn(store, 'useAppSelector').mockImplementation((selector: any) => { if (selector.toString().includes('selectCurrentUsername')) return 'test-user'; + if (selector.toString().includes('selectCurrentUserIsRagAdmin')) return false; if (selector.toString().includes('selectCurrentUserIsAdmin')) return false; return null; }); @@ -335,4 +336,41 @@ describe('DocumentLibraryComponent', () => { expect(getMatchesCountText(0)).toBe('0 matches'); }); }); + + describe('canDeleteAll', () => { + const docUploadedByUser = createMockDocument({ username: 'test-user' }); + const docUploadedByOther = createMockDocument({ username: 'test-user-other', document_id: 'doc-999' }); + + it('returns false when no items are selected', () => { + expect(canDeleteAll([], 'test-user', false, false)).toBe(false); + }); + + it('allows a regular user to delete their own documents', () => { + expect(canDeleteAll([docUploadedByUser], 'test-user', false, false)).toBe(true); + }); + + it('blocks a regular user from deleting documents they do not own', () => { + expect(canDeleteAll([docUploadedByOther], 'test-user', false, false)).toBe(false); + }); + + it('allows an admin to delete documents uploaded by another user', () => { + expect(canDeleteAll([docUploadedByOther], 'test-admin', true, false)).toBe(true); + }); + + it('allows a rag admin to delete documents uploaded by another user', () => { + expect(canDeleteAll([docUploadedByOther], 'test-rag', false, true)).toBe(true); + }); + + it('blocks a regular user from a mixed selection containing a doc they do not own', () => { + expect(canDeleteAll([docUploadedByUser, docUploadedByOther], 'test-user', false, false)).toBe(false); + }); + + it('allows an admin to delete a mixed selection', () => { + expect(canDeleteAll([docUploadedByUser, docUploadedByOther], 'test-admin', true, false)).toBe(true); + }); + + it('allows a rag admin to delete a mixed selection', () => { + expect(canDeleteAll([docUploadedByUser, docUploadedByOther], 'test-rag', false, true)).toBe(true); + }); + }); }); diff --git a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx index ffff07d69..dab3e8335 100644 --- a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx +++ b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx @@ -37,7 +37,7 @@ import { DEFAULT_PREFERENCES, PAGE_SIZE_OPTIONS, TABLE_DEFINITION, TABLE_PREFERE import { useCollection } from '@cloudscape-design/collection-hooks'; import Box from '@cloudscape-design/components/box'; import { useAppDispatch, useAppSelector } from '../../config/store'; -import { selectCurrentUserIsAdmin, selectCurrentUsername } from '../../shared/reducers/user.reducer'; +import { selectCurrentUserIsAdmin, selectCurrentUserIsRagAdmin, selectCurrentUsername } from '../../shared/reducers/user.reducer'; import { RagDocument } from '../types'; import { setConfirmationModal } from '../../shared/reducers/modal.reducer'; import { useLocalStorage } from '../../shared/hooks/use-local-storage'; @@ -55,8 +55,8 @@ export function getMatchesCountText (count) { return count === 1 ? '1 match' : `${count} matches`; } -function canDeleteAll (selectedItems: ReadonlyArray, username: string, isAdmin: boolean) { - return selectedItems.length > 0 && (isAdmin || selectedItems.every((doc) => doc.username === username)); +export function canDeleteAll (selectedItems: ReadonlyArray, username: string, isAdmin: boolean, isRagAdmin: boolean) { + return selectedItems.length > 0 && (isAdmin || isRagAdmin || selectedItems.every((doc) => doc.username === username)); } function disabledDeleteReason (selectedItems: ReadonlyArray) { @@ -84,6 +84,7 @@ export function DocumentLibraryComponent ({ repositoryId, collectionId }: Docume const currentUser = useAppSelector(selectCurrentUsername); const isAdmin = useAppSelector(selectCurrentUserIsAdmin); + const isRagAdmin = useAppSelector(selectCurrentUserIsRagAdmin); const [preferences, setPreferences] = useLocalStorage('DocumentRagPreferences', DEFAULT_PREFERENCES); const dispatch = useAppDispatch(); @@ -140,7 +141,7 @@ export function DocumentLibraryComponent ({ repositoryId, collectionId }: Docume { id: 'rm', text: 'Delete', - disabled: !canDeleteAll(collectionProps.selectedItems, currentUser, isAdmin), + disabled: !canDeleteAll(collectionProps.selectedItems, currentUser, isAdmin, isRagAdmin), disabledReason: disabledDeleteReason(collectionProps.selectedItems), }, { diff --git a/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx index 82430f833..3cdb2be2d 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/BaseModelConfig.tsx @@ -20,7 +20,7 @@ import FormField from '@cloudscape-design/components/form-field'; import Input from '@cloudscape-design/components/input'; import Toggle from '@cloudscape-design/components/toggle'; import Select from '@cloudscape-design/components/select'; -import { IModelRequest, InferenceContainer, ModelType } from '../../../shared/model/model-management.model'; +import { IModelRequest, InferenceContainer, ModelHostingType, ModelType } from '../../../shared/model/model-management.model'; import { Grid, SpaceBetween } from '@cloudscape-design/components'; import { useGetInstancesQuery } from '../../../shared/reducers/model-management.reducer'; import { ModelFeatures } from '@/components/types'; @@ -49,41 +49,49 @@ export function BaseModelConfig (props: FormProps & BaseModelConf props.touchFields(['modelId'])} onChange={({ detail }) => { + props.touchFields(['modelId'])} onChange={({ detail }) => { props.setFields({ 'modelId': detail.value }); }} disabled={props.isEdit} placeholder='mistral-vllm'/> - props.touchFields(['modelName'])} onChange={({ detail }) => { + props.touchFields(['modelName'])} onChange={({ detail }) => { props.setFields({ 'modelName': detail.value }); }} disabled={props.isEdit} placeholder='mistralai/Mistral-7B-Instruct-v0.2'/> @@ -96,7 +104,7 @@ export function BaseModelConfig (props: FormProps & BaseModelConf props.setFields({ 'modelDescription': detail.value }); }} placeholder='Brief description of the model and its capabilities'/> - {!props.item.lisaHostedModel && API Key - Optional} description='API authentication key for accessing third-party model provider services.' errorText={props.formErrors?.apiKey} @@ -106,8 +114,12 @@ export function BaseModelConfig (props: FormProps & BaseModelConf }} disabled={props.isEdit} placeholder='sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'/> } Model URL - Optional} - description='Custom endpoint URL for the model API (e.g., for self-hosted or third-party services).' + label={props.item.hostingType === ModelHostingType.INTERNAL_HOSTED ? 'Model URL' : Model URL - Optional} + description={ + props.item.hostingType === ModelHostingType.INTERNAL_HOSTED ? + 'Required internal AWS load balancer endpoint for this model (for example, http://internal-xyz.elb.amazonaws.com/v1).' : + 'Custom endpoint URL for the model API (e.g., for self-hosted or third-party services).' + } errorText={props.formErrors?.modelUrl} > props.touchFields(['modelUrl'])} onChange={({ detail }) => { @@ -159,7 +171,7 @@ export function BaseModelConfig (props: FormProps & BaseModelConf disabled={props.isEdit} /> - {props.item.lisaHostedModel && ( + {(props.item.hostingType === ModelHostingType.LISA_HOSTED || props.item.lisaHostedModel) && ( <> { if (props.isEdit) { @@ -123,8 +125,8 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { } })() }) : null), - loadBalancerConfig: (state.form.lisaHostedModel ? state.form.loadBalancerConfig : null), - autoScalingConfig: (state.form.lisaHostedModel ? state.form.autoScalingConfig : null), + loadBalancerConfig: (isLisaHosted ? state.form.loadBalancerConfig : null), + autoScalingConfig: (isLisaHosted ? state.form.autoScalingConfig : null), inferenceContainer: state.form.inferenceContainer ?? null, instanceType: state.form.instanceType ? state.form.instanceType : null, modelUrl: state.form.modelUrl ? state.form.modelUrl : null @@ -148,7 +150,13 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { const changesDiff = useMemo(() => { return props.isEdit ? getJsonDifference({ ...props.selectedItems[0], - lisaHostedModel: Boolean(props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig) + lisaHostedModel: Boolean(props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig), + hostingType: ( + props.selectedItems[0].hostingType || + (props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig + ? ModelHostingType.LISA_HOSTED + : ModelHostingType.THIRD_PARTY) + ) }, toSubmit) : getJsonDifference({}, toSubmit); // eslint-disable-next-line react-hooks/exhaustive-deps @@ -347,7 +355,13 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { ...parsedValue.containerConfig, environment: props.selectedItems[0].containerConfig?.environment ? Object.entries(props.selectedItems[0].containerConfig?.environment).map(([key, value]) => ({ key, value: String(value) })) : [], }, - lisaHostedModel: Boolean(props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig) + lisaHostedModel: Boolean(props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig), + hostingType: ( + props.selectedItems[0].hostingType || + (props.selectedItems[0].containerConfig || props.selectedItems[0].autoScalingConfig || props.selectedItems[0].loadBalancerConfig + ? ModelHostingType.LISA_HOSTED + : ModelHostingType.THIRD_PARTY) + ) } }); } else { @@ -356,6 +370,8 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { ...state, form: { ...state.form, + hostingType: state.form.hostingType || ModelHostingType.THIRD_PARTY, + lisaHostedModel: false } }); } @@ -427,7 +443,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { ), isOptional: true, - onEdit: state.form.lisaHostedModel, + onEdit: isLisaHosted, forExternalModel: false }, { @@ -435,7 +451,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { content: ( ), - onEdit: state.form.lisaHostedModel, + onEdit: isLisaHosted, forExternalModel: false }, { @@ -444,7 +460,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { ), isOptional: true, - onEdit: state.form.lisaHostedModel, + onEdit: isLisaHosted, forExternalModel: false }, { @@ -479,7 +495,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { ]; const steps = allSteps.filter((step) => { - return state.form.lisaHostedModel || step.forExternalModel; + return isLisaHosted || step.forExternalModel; }); return ( diff --git a/lib/user-interface/react/src/components/repository-management/RepositoryActions.test.tsx b/lib/user-interface/react/src/components/repository-management/RepositoryActions.test.tsx new file mode 100644 index 000000000..da8258094 --- /dev/null +++ b/lib/user-interface/react/src/components/repository-management/RepositoryActions.test.tsx @@ -0,0 +1,133 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { RepositoryActions } from './RepositoryActions'; +import { renderWithProviders } from '../../test/helpers/render'; +import { createMockRepositories } from '../../test/factories/repository.factory'; + +const mockRepositories = createMockRepositories(3); + +vi.mock('../../shared/reducers/rag.reducer', async () => { + const actual: any = await vi.importActual('../../shared/reducers/rag.reducer'); + return { + ...actual, + useListRagRepositoriesQuery: vi.fn(() => ({ + data: mockRepositories, + isFetching: false, + isLoading: false, + })), + useUpdateRagRepositoryMutation: vi.fn(() => [vi.fn(), { isSuccess: false, isError: false, error: null, isLoading: false }]), + useDeleteRagRepositoryMutation: vi.fn(() => [vi.fn(), { isSuccess: false, isError: false, error: null, isLoading: false }]), + ragApi: { + ...actual.ragApi, + util: { + invalidateTags: vi.fn(), + }, + }, + }; +}); + +const defaultProps = { + selectedItems: [] as any[], + setSelectedItems: vi.fn(), + setNewRepositoryModalVisible: vi.fn(), + setEdit: vi.fn(), +}; + +const adminState = { + user: { info: { isAdmin: true, isRagAdmin: false, isUser: true, isApiUser: false } }, +}; + +const ragAdminState = { + user: { info: { isAdmin: false, isRagAdmin: true, isUser: false, isApiUser: false } }, +}; + +describe('RepositoryActions', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe('admin user', () => { + it('shows Create Repository button', async () => { + renderWithProviders(, { preloadedState: adminState }); + + await waitFor(() => { + expect(screen.getByText('Create Repository')).toBeInTheDocument(); + }); + }); + + it('shows Delete in actions dropdown', async () => { + const user = userEvent.setup(); + const propsWithSelection = { + ...defaultProps, + selectedItems: [mockRepositories[0]], + }; + renderWithProviders(, { preloadedState: adminState }); + + const actionsButton = screen.getByText('Actions'); + await user.click(actionsButton); + + await waitFor(() => { + expect(screen.getByText('Delete')).toBeInTheDocument(); + }); + }); + }); + + describe('RAG admin user', () => { + it('does not show Create Repository button', async () => { + renderWithProviders(, { preloadedState: ragAdminState }); + + await waitFor(() => { + expect(screen.queryByText('Create Repository')).not.toBeInTheDocument(); + }); + }); + + it('does not show Delete in actions dropdown', async () => { + const user = userEvent.setup(); + const propsWithSelection = { + ...defaultProps, + selectedItems: [mockRepositories[0]], + }; + renderWithProviders(, { preloadedState: ragAdminState }); + + const actionsButton = screen.getByText('Actions'); + await user.click(actionsButton); + + await waitFor(() => { + expect(screen.queryByText('Delete')).not.toBeInTheDocument(); + }); + }); + + it('shows Edit in actions dropdown', async () => { + const user = userEvent.setup(); + const propsWithSelection = { + ...defaultProps, + selectedItems: [mockRepositories[0]], + }; + renderWithProviders(, { preloadedState: ragAdminState }); + + const actionsButton = screen.getByText('Actions'); + await user.click(actionsButton); + + await waitFor(() => { + expect(screen.getByText('Edit')).toBeInTheDocument(); + }); + }); + }); +}); diff --git a/lib/user-interface/react/src/components/repository-management/RepositoryActions.tsx b/lib/user-interface/react/src/components/repository-management/RepositoryActions.tsx index e714dcc9e..b4ffe8350 100644 --- a/lib/user-interface/react/src/components/repository-management/RepositoryActions.tsx +++ b/lib/user-interface/react/src/components/repository-management/RepositoryActions.tsx @@ -23,7 +23,8 @@ import { Checkbox, SpaceBetween, } from '@cloudscape-design/components'; -import { useAppDispatch } from '@/config/store'; +import { useAppDispatch, useAppSelector } from '@/config/store'; +import { selectCurrentUserIsAdmin } from '@/shared/reducers/user.reducer'; import { useNotificationService } from '@/shared/util/hooks'; import { INotificationService } from '@/shared/notification/notification.service'; import { Action, ThunkDispatch } from '@reduxjs/toolkit'; @@ -46,6 +47,7 @@ export type RepositoryActionProps = { function RepositoryActions (props: RepositoryActionProps): ReactElement { const dispatch = useAppDispatch(); + const isAdmin = useAppSelector(selectCurrentUserIsAdmin); const notificationService = useNotificationService(dispatch); const { setEdit, setNewRepositoryModalVisible, setSelectedItems } = props; const { isFetching } = useListRagRepositoriesQuery(undefined, { @@ -64,13 +66,15 @@ function RepositoryActions (props: RepositoryActionProps): ReactElement { onClick={handleRefresh} ariaLabel='Refresh repository table' /> - {RepositoryActionButton(dispatch, notificationService, props)} - + {RepositoryActionButton(dispatch, notificationService, props, isAdmin)} + {isAdmin && ( + + )} ); } @@ -79,7 +83,7 @@ type RagRepository = RagRepositoryConfig & { legacy?: boolean }; -function RepositoryActionButton (dispatch: ThunkDispatch, notificationService: INotificationService, props: RepositoryActionProps): ReactElement { +function RepositoryActionButton (dispatch: ThunkDispatch, notificationService: INotificationService, props: RepositoryActionProps, isAdmin: boolean): ReactElement { const { setEdit, selectedItems, setSelectedItems, setNewRepositoryModalVisible } = props; const [disabledModal, setDisabledModel] = useState(false); const [showModal, setShowModal] = useState(false); @@ -158,11 +162,11 @@ function RepositoryActionButton (dispatch: ThunkDispatch, noti disabled: selectedItems.length !== 1 || selectedRepo?.legacy, disabledReason: selectedItems.length !== 1 ? '' : selectedRepo?.legacy ? 'Legacy repositories created through YAML cannot be edited.' : undefined }, - { + ...(isAdmin ? [{ id: 'rm', text: 'Delete', disabled: selectedItems.length !== 1, - }]; + }] : [])]; return ( { }); }); - it('should have Create Repository button', async () => { - renderWithProviders(); + it('should have Create Repository button for admin users', async () => { + renderWithProviders(, { + preloadedState: { user: { info: { isAdmin: true, isRagAdmin: false, isUser: true, isApiUser: false } } } + }); await waitFor(() => { expect(screen.getByText('Create Repository')).toBeInTheDocument(); diff --git a/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.test.tsx b/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.test.tsx index 25c860c9d..341f05a72 100644 --- a/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.test.tsx +++ b/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.test.tsx @@ -29,6 +29,39 @@ vi.mock('@/shared/util/hooks', () => ({ }), })); +const adminState = { + user: { info: { isAdmin: true, isRagAdmin: false, isUser: true, isApiUser: false } }, +}; + +const ragAdminState = { + user: { info: { isAdmin: false, isRagAdmin: true, isUser: false, isApiUser: false } }, +}; + +const existingRepo: RagRepositoryConfig = { + repositoryId: 'test-repo', + repositoryName: 'Test Repository', + type: RagRepositoryType.OPENSEARCH, + embeddingModelId: 'amazon.titan-embed-text-v1', + allowedGroups: ['admin'], + opensearchConfig: { + dataNodes: 2, + dataNodeInstanceType: 't3.small.search', + masterNodes: 0, + masterNodeInstanceType: 't3.small.search', + volumeSize: 10, + }, + pipelines: [ + { + autoRemove: true, + trigger: 'event' as const, + s3Bucket: 'test-bucket', + s3Prefix: 'documents/', + chunkSize: 512, + chunkOverlap: 51, + }, + ], +}; + describe('CreateRepositoryModal', () => { let mockUpdateMutation: ReturnType; let mockCreateMutation: ReturnType; @@ -71,31 +104,6 @@ describe('CreateRepositoryModal', () => { }); it('renders update modal with existing repository data', async () => { - const existingRepo: RagRepositoryConfig = { - repositoryId: 'test-repo', - repositoryName: 'Test Repository', - type: RagRepositoryType.OPENSEARCH, - embeddingModelId: 'amazon.titan-embed-text-v1', - allowedGroups: ['admin'], - opensearchConfig: { - dataNodes: 2, - dataNodeInstanceType: 't3.small.search', - masterNodes: 0, - masterNodeInstanceType: 't3.small.search', - volumeSize: 10, - }, - pipelines: [ - { - autoRemove: true, - trigger: 'event' as const, - s3Bucket: 'test-bucket', - s3Prefix: 'documents/', - chunkSize: 512, - chunkOverlap: 51, - }, - ], - }; - renderWithProviders( { setVisible={vi.fn()} selectedItems={[existingRepo]} setSelectedItems={vi.fn()} - /> + />, + { preloadedState: adminState } ); // Wait for the modal to render with update title @@ -121,31 +130,6 @@ describe('CreateRepositoryModal', () => { }); it('includes pipelines in updates when pipeline configuration changes', async () => { - const existingRepo: RagRepositoryConfig = { - repositoryId: 'test-repo', - repositoryName: 'Test Repository', - type: RagRepositoryType.OPENSEARCH, - embeddingModelId: 'amazon.titan-embed-text-v1', - allowedGroups: ['admin'], - opensearchConfig: { - dataNodes: 2, - dataNodeInstanceType: 't3.small.search', - masterNodes: 0, - masterNodeInstanceType: 't3.small.search', - volumeSize: 10, - }, - pipelines: [ - { - autoRemove: true, - trigger: 'event' as const, - s3Bucket: 'test-bucket', - s3Prefix: 'documents/', - chunkSize: 512, - chunkOverlap: 51, - }, - ], - }; - renderWithProviders( { setVisible={vi.fn()} selectedItems={[existingRepo]} setSelectedItems={vi.fn()} - /> + />, + { preloadedState: adminState } ); await waitFor(() => { @@ -190,4 +175,47 @@ describe('CreateRepositoryModal', () => { // Verify create mutation is available (not update) expect(mockCreateMutation).toBeDefined(); }); + + it('admin edit shows all wizard steps', async () => { + renderWithProviders( + , + { preloadedState: adminState } + ); + + await waitFor(() => { + expect(screen.getAllByText('Repository Configuration').length).toBeGreaterThan(0); + expect(screen.getAllByText('Pipeline Configuration').length).toBeGreaterThan(0); + expect(screen.getAllByText('Metadata & Tags').length).toBeGreaterThan(0); + expect(screen.getAllByText('Review and Update').length).toBeGreaterThan(0); + }); + }); + + it('RAG admin edit shows only Pipeline and Review steps', async () => { + renderWithProviders( + , + { preloadedState: ragAdminState } + ); + + await waitFor(() => { + expect(screen.getAllByText('Pipeline Configuration').length).toBeGreaterThan(0); + expect(screen.getAllByText('Review and Update').length).toBeGreaterThan(0); + }); + + expect(screen.queryByText('Repository Configuration')).not.toBeInTheDocument(); + expect(screen.queryByText('Metadata & Tags')).not.toBeInTheDocument(); + }); }); diff --git a/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.tsx b/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.tsx index 3ae2d68f0..e0b0e79a7 100644 --- a/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.tsx +++ b/lib/user-interface/react/src/components/repository-management/createRepository/CreateRepositoryModal.tsx @@ -17,7 +17,8 @@ import { Modal, Wizard } from '@cloudscape-design/components'; import { ReactElement, useEffect, useMemo } from 'react'; import { scrollToInvalid, useValidationReducer } from '../../../shared/validation'; -import { useAppDispatch } from '../../../config/store'; +import { useAppDispatch, useAppSelector } from '../../../config/store'; +import { selectCurrentUserIsAdmin, selectCurrentUserIsRagAdmin } from '../../../shared/reducers/user.reducer'; import { useNotificationService } from '../../../shared/util/hooks'; import { setConfirmationModal } from '../../../shared/reducers/modal.reducer'; import { useCreateRagRepositoryMutation, useUpdateRagRepositoryMutation } from '../../../shared/reducers/rag.reducer'; @@ -73,6 +74,8 @@ export function CreateRepositoryModal (props: CreateRepositoryModalProps): React metadata: { tags: [] } }) as RagRepositoryConfig; const dispatch = useAppDispatch(); + const isAdmin = useAppSelector(selectCurrentUserIsAdmin); + const isRagAdmin = useAppSelector(selectCurrentUserIsRagAdmin); const notificationService = useNotificationService(dispatch); const { @@ -225,6 +228,7 @@ export function CreateRepositoryModal (props: CreateRepositoryModalProps): React isEdit={isEdit} /> ), onEdit: true, + onRagAdminEdit: false, }, { title: 'Pipeline Configuration', @@ -241,6 +245,7 @@ export function CreateRepositoryModal (props: CreateRepositoryModalProps): React ), isOptional: true, onEdit: true, + onRagAdminEdit: true, }, { title: 'Metadata & Tags', @@ -255,6 +260,7 @@ export function CreateRepositoryModal (props: CreateRepositoryModalProps): React ), isOptional: true, onEdit: true, + onRagAdminEdit: false, }, { title: `Review and ${isEdit ? 'Update' : 'Create'}`, @@ -264,8 +270,13 @@ export function CreateRepositoryModal (props: CreateRepositoryModalProps): React info={isEdit ? 'Any changes will cause a redeployment of the vector store, which may result in data loss of previously store RAG documents.' : undefined} /> ), onEdit: state.form, + onRagAdminEdit: true, }, - ].filter((step) => isEdit ? step.onEdit : true); + ].filter((step) => { + if (isEdit && !isAdmin && isRagAdmin) return step.onRagAdminEdit; + if (isEdit) return step.onEdit; + return true; + }); function resetState () { setState({ diff --git a/lib/user-interface/react/src/components/settings/AwsCredentialsPanel.tsx b/lib/user-interface/react/src/components/settings/AwsCredentialsPanel.tsx new file mode 100644 index 000000000..4d47167f2 --- /dev/null +++ b/lib/user-interface/react/src/components/settings/AwsCredentialsPanel.tsx @@ -0,0 +1,282 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { + Box, + Button, + Container, + Form, + FormField, + Header, + Input, + SpaceBetween, + StatusIndicator, + TextContent +} from '@cloudscape-design/components'; +import { lisaAxios } from '@/shared/reducers/reducer.utils'; +import { MCP_WORKBENCH_URI } from '@/components/utils'; + +type AwsStatusResponse = { + connected: boolean; + expiresAt?: string; +}; + +type ConnectResponse = { + accountId: string; + arn: string; + expiresAt: string; +}; + +type AwsCredentialsPanelProps = { + /** Optional hook for parent components to react when connection state changes */ + onStatusChange?: (status: AwsStatusResponse) => void; + /** Optional session identifier to scope credentials per-session */ + sessionId?: string; + /** Optional header title; defaults to "AWS Credentials" */ + title?: string; +}; + +const AwsCredentialsPanel: React.FC = ({ onStatusChange, sessionId, title = 'AWS Credentials' }) => { + const [accessKeyId, setAccessKeyId] = useState(''); + const [secretAccessKey, setSecretAccessKey] = useState(''); + const [sessionToken, setSessionToken] = useState(''); + const [region, setRegion] = useState('us-east-1'); + + const [status, setStatus] = useState(null); + const [accountId, setAccountId] = useState(null); + const [arn, setArn] = useState(null); + + const [isLoadingStatus, setIsLoadingStatus] = useState(false); + const [isSubmitting, setIsSubmitting] = useState(false); + const [isDisconnecting, setIsDisconnecting] = useState(false); + const [error, setError] = useState(null); + + const expiresInMinutes = useMemo(() => { + if (!status?.connected || !status.expiresAt) return null; + try { + const expires = new Date(status.expiresAt).getTime(); + const now = Date.now(); + const diffMs = expires - now; + if (diffMs <= 0) return 0; + return Math.round(diffMs / 60000); + } catch { + return null; + } + }, [status]); + + const loadStatus = useCallback(async () => { + try { + setIsLoadingStatus(true); + setError(null); + const { data } = await lisaAxios.get(`${MCP_WORKBENCH_URI}/api/aws/status`, { + headers: sessionId ? { 'X-Session-Id': sessionId } : undefined, + }); + setStatus(data); + if (onStatusChange) onStatusChange(data); + } catch (e: any) { + setError(e.message ?? 'Failed to load AWS status'); + } finally { + setIsLoadingStatus(false); + } + }, [sessionId, onStatusChange]); + + useEffect(() => { + setStatus(null); + setAccountId(null); + setArn(null); + setError(null); + void loadStatus(); + }, [sessionId, loadStatus]); + + const handleConnect = async () => { + setError(null); + setIsSubmitting(true); + try { + const body = { + accessKeyId: accessKeyId.trim(), + secretAccessKey: secretAccessKey.trim(), + sessionToken: sessionToken.trim() || undefined, + region: region.trim() + }; + const { data } = await lisaAxios.post(`${MCP_WORKBENCH_URI}/api/aws/connect`, body, { + headers: sessionId ? { 'X-Session-Id': sessionId } : undefined, + }); + setAccountId(data.accountId); + setArn(data.arn); + const newStatus: AwsStatusResponse = { connected: true, expiresAt: data.expiresAt }; + setStatus(newStatus); + if (onStatusChange) onStatusChange(newStatus); + } catch (e: any) { + setError(e.message ?? 'Failed to connect AWS credentials'); + } finally { + setIsSubmitting(false); + } + }; + + const handleDisconnect = async () => { + setError(null); + setIsDisconnecting(true); + try { + await lisaAxios.delete(`${MCP_WORKBENCH_URI}/api/aws/connect`, { + headers: sessionId ? { 'X-Session-Id': sessionId } : undefined, + }); + const newStatus: AwsStatusResponse = { connected: false }; + setStatus(newStatus); + setAccountId(null); + setArn(null); + if (onStatusChange) onStatusChange(newStatus); + } catch (e: any) { + setError(e.message ?? 'Failed to disconnect AWS credentials'); + } finally { + setIsDisconnecting(false); + } + }; + + const isConnected = status?.connected; + const isExpired = isConnected && expiresInMinutes !== null && expiresInMinutes <= 0; + + return ( +
{title}} + actions={ + + {isConnected && ( + + )} + + + } + > + + +

+ Connect your AWS credentials to this chat session. Your keys are validated and converted to + short-lived session credentials stored securely in memory. To use them, your MCP server must + expose tools that leverage these credentials (for example, S3 list buckets or other AWS operations). + Without such tools, connecting credentials has no effect. +

+

+ Caution: Credentials with broad permissions can create, modify, or delete resources + in your AWS account. Use IAM credentials with the minimum permissions required for the tools you + intend to use. +

+
+ Connection status}> + + + {isConnected && !isExpired && expiresInMinutes != null + ? `Connected (expires in ${expiresInMinutes} minutes)` + : isConnected && isExpired + ? 'Connected (expired)' + : 'Not connected'} + + {accountId && arn && ( + + + Account ID: {accountId} + + + ARN: {arn} + + + )} + + + Credentials are discarded when your session ends. + + + {error && ( + + + {error} + + + )} + + + + Enter AWS credentials}> + + + setAccessKeyId(detail.value)} + type='text' + autoComplete='off' + /> + + + setSecretAccessKey(detail.value)} + type='password' + autoComplete='off' + /> + + + setSessionToken(detail.value)} + type='password' + autoComplete='off' + /> + + + setRegion(detail.value)} + type='text' + autoComplete='off' + /> + + + +
+
+ ); +}; + +export default AwsCredentialsPanel; diff --git a/lib/user-interface/react/src/components/utils.ts b/lib/user-interface/react/src/components/utils.ts index 58b14b767..374e6eb54 100644 --- a/lib/user-interface/react/src/components/utils.ts +++ b/lib/user-interface/react/src/components/utils.ts @@ -25,6 +25,11 @@ const stripTrailingSlash = (str) => { export const RESTAPI_URI = stripTrailingSlash(window.env.RESTAPI_URI); export const RESTAPI_VERSION = window.env.RESTAPI_VERSION; +/** Base URL for MCP Workbench HTTP (MCP stream + /api/aws). From SSM …/mcpWorkbench/endpoint (workbench ALB; distinct from Serve API when custom domains are used). */ +export const MCP_WORKBENCH_URI = window.env.MCP_WORKBENCH_URI + ? stripTrailingSlash(window.env.MCP_WORKBENCH_URI) + : RESTAPI_URI; + /** * Gets base URI for API Gateway. This can either be the APIGW execution URL directly or a * custom domain. @@ -74,7 +79,15 @@ export const getSessionDisplay = (session: LisaChatSession, maxLength?: number) export const getDisplayableMessage = (content: MessageContent, ragCitations?: string) => { if (Array.isArray(content)) { - return content.find((item) => item.type === 'text' && !item.text.startsWith('File context:'))?.text + (ragCitations ?? '') || ''; + return ( + content.find( + (item) => + item.type === 'text' && + !item.text.startsWith('File context:') && + !item.text.startsWith('Context from document search:') + )?.text + (ragCitations ?? '') + || '' + ); } return content + (ragCitations ?? ''); }; diff --git a/lib/user-interface/react/src/config/oidc.config.ts b/lib/user-interface/react/src/config/oidc.config.ts index 8e807c2b6..eb61e7326 100644 --- a/lib/user-interface/react/src/config/oidc.config.ts +++ b/lib/user-interface/react/src/config/oidc.config.ts @@ -16,6 +16,10 @@ import { AuthProviderProps } from 'react-oidc-context'; +/** OAuth redirect_uri must not include the hash fragment (RFC 6749). Use origin + pathname. */ +export const getRedirectUri = (): string => + `${window.location.origin}${window.location.pathname}`; + interface LisaOidcConfig { authority: string; client_id: string; @@ -28,8 +32,8 @@ interface LisaOidcConfig { export const OidcConfig: AuthProviderProps & LisaOidcConfig = { authority: window.env.AUTHORITY, client_id: window.env.CLIENT_ID, - redirect_uri: window.location.toString(), - post_logout_redirect_uri: window.location.toString(), + redirect_uri: getRedirectUri(), + post_logout_redirect_uri: getRedirectUri(), scope: 'openid profile email' + (window.env.CUSTOM_SCOPES ? ' ' + window.env.CUSTOM_SCOPES.join(' ') : ''), response_type: 'code', }; diff --git a/lib/user-interface/react/src/main.tsx b/lib/user-interface/react/src/main.tsx index dfc33d499..b422a0a36 100644 --- a/lib/user-interface/react/src/main.tsx +++ b/lib/user-interface/react/src/main.tsx @@ -18,32 +18,11 @@ import React from 'react'; import ReactDOM from 'react-dom/client'; import { Provider } from 'react-redux'; import './index.css'; -import AppConfigured from './components/app-configured'; import '@cloudscape-design/global-styles/index.css'; -import getStore from './config/store'; import { applyTheme } from '@cloudscape-design/components/theming'; import { Theme } from '@cloudscape-design/components/theming'; -// Conditionally apply custom theme if branding is enabled -if (window.env?.USE_CUSTOM_BRANDING) { - try { - // Vite will only include files that actually exist - const themeModules = import.meta.glob('./theme*.ts'); - - // Try custom first, fall back to base - const themeModule = themeModules['./theme-custom.ts'] - ? await themeModules['./theme-custom.ts']() - : await themeModules['./theme.ts'](); - - const { brandTheme } = themeModule as { brandTheme: Theme }; - applyTheme({ theme: brandTheme }); - console.log('Theme loaded:', themeModules['./theme-custom.ts'] ? 'custom' : 'base'); - } catch { - console.warn('No theme file found, using Cloudscape default theme'); - } -} - declare global { // eslint-disable-next-line @typescript-eslint/consistent-type-definitions interface Window { @@ -53,9 +32,11 @@ declare global { ADMIN_GROUP?: string; USER_GROUP?: string; API_GROUP?: string; + RAG_ADMIN_GROUP?: string; JWT_GROUPS_PROP?: string; CUSTOM_SCOPES: string[]; RESTAPI_URI: string; + MCP_WORKBENCH_URI?: string; RESTAPI_VERSION: string; RAG_ENABLED: boolean; HOSTED_MCP_ENABLED: boolean; @@ -70,6 +51,64 @@ declare global { } } +const baseUrl = import.meta.env.BASE_URL || '/'; +const normalizedBase = baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`; + +const loadRuntimeScript = async (scriptName: string): Promise => { + await new Promise((resolve, reject) => { + const script = document.createElement('script'); + script.src = `${normalizedBase}${scriptName}`; + script.async = false; + script.onload = () => resolve(); + script.onerror = () => reject(new Error(`Failed to load ${scriptName}`)); + document.head.appendChild(script); + }); +}; + +await loadRuntimeScript('env.js'); +try { + await loadRuntimeScript('git-info.js'); +} catch { + // git-info.js is generated at build time; not present in dev/CI + // App runs fine without it β€” window.gitInfo remains undefined +} + +const favicon = document.getElementById('favicon') as HTMLLinkElement | null; +if (favicon) { + const brandingDir = window.env?.USE_CUSTOM_BRANDING ? 'custom' : 'base'; + favicon.href = `${normalizedBase}branding/${brandingDir}/favicon.ico`; +} + +const pageTitle = document.getElementById('page-title'); +if (pageTitle) { + const displayName = window.env?.CUSTOM_DISPLAY_NAME || 'LISA'; + pageTitle.textContent = `${displayName} AI Chat Assistant`; +} + +// Conditionally apply custom theme if branding is enabled +if (window.env?.USE_CUSTOM_BRANDING) { + try { + // Vite will only include files that actually exist + const themeModules = import.meta.glob('./theme*.ts'); + + // Try custom first, fall back to base + const themeModule = themeModules['./theme-custom.ts'] + ? await themeModules['./theme-custom.ts']() + : await themeModules['./theme.ts'](); + + const { brandTheme } = themeModule as { brandTheme: Theme }; + applyTheme({ theme: brandTheme }); + console.log('Theme loaded:', themeModules['./theme-custom.ts'] ? 'custom' : 'base'); + } catch { + console.warn('No theme file found, using Cloudscape default theme'); + } +} + +const [{ default: AppConfigured }, { default: getStore }] = await Promise.all([ + import('./components/app-configured'), + import('./config/store'), +]); + const store = getStore(); ReactDOM.createRoot(document.getElementById('root')!).render( diff --git a/lib/user-interface/react/src/pages/Home.tsx b/lib/user-interface/react/src/pages/Home.tsx index d8825709c..56ae89fe3 100644 --- a/lib/user-interface/react/src/pages/Home.tsx +++ b/lib/user-interface/react/src/pages/Home.tsx @@ -17,6 +17,7 @@ import { useEffect, useState } from 'react'; import { useNavigate } from 'react-router-dom'; import { useAuth } from '../auth/useAuth'; +import { getRedirectUri } from '../config/oidc.config'; import { Alert, Box, Button, Modal } from '@cloudscape-design/components'; import { getBrandingAssetPath } from '../shared/util/branding'; @@ -47,7 +48,7 @@ export function Home ({ setNav }) {