diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 08bedb572..a94289b76 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -13,8 +13,8 @@ echo "source .venv/bin/activate" >> ~/.zshrc echo "alias deploylisa='make clean && npm ci && make deploy HEADLESS=true'" >> ~/.bashrc echo "alias deploylisa='make clean && npm ci && make deploy HEADLESS=true'" >> ~/.zshrc -pip install --upgrade pip -pip3 install yq huggingface_hub s5cmd +python -m pip install --upgrade pip +pip3 install yq==3.4.3 huggingface_hub==0.26.3 s5cmd==2.2.2 make installPythonRequirements make createTypeScriptEnvironment diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..7a3263faf --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,240 @@ +version: 2 +updates: + # Enable version updates for npm + - package-ecosystem: "npm" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + rebase-strategy: "auto" + + # Enable version updates for pip - root directory + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - lisa-sdk + - package-ecosystem: "pip" + directory: "/lisa-sdk" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - authorizer layer + - package-ecosystem: "pip" + directory: "/lib/core/layers/authorizer" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - common layer + - package-ecosystem: "pip" + directory: "/lib/core/layers/common" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - fastapi layer + - package-ecosystem: "pip" + directory: "/lib/core/layers/fastapi" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - rest-api src + - package-ecosystem: "pip" + directory: "/lib/serve/rest-api/src" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable version updates for pip - rag layer + - package-ecosystem: "pip" + directory: "/lib/rag/layer" + schedule: + interval: "weekly" + day: "tuesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable security updates for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + day: "wednesday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 5 + + # Enable updates for Docker - RAG ingestion + - package-ecosystem: "docker" + directory: "/lib/rag/ingestion/ingestion-image" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - vector store + - package-ecosystem: "docker" + directory: "/lib/rag/vector-store/state_machine" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - REST API + - package-ecosystem: "docker" + directory: "/lib/serve/rest-api" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - MCP workbench + - package-ecosystem: "docker" + directory: "/lib/serve/mcp-workbench" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - VLLM + - package-ecosystem: "docker" + directory: "/lib/serve/ecs-model/vllm" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - TEI + - package-ecosystem: "docker" + directory: "/lib/serve/ecs-model/embedding/tei" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - instructor + - package-ecosystem: "docker" + directory: "/lib/serve/ecs-model/embedding/instructor" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 + + # Enable updates for Docker - TGI + - package-ecosystem: "docker" + directory: "/lib/serve/ecs-model/textgen/tgi" + schedule: + interval: "weekly" + day: "thursday" + time: "09:00" + reviewers: + - "awslabs/lisa-maintainers" + commit-message: + prefix: "chore" + include: "scope" + open-pull-requests-limit: 3 diff --git a/.github/scripts/generate-pr-description.sh b/.github/scripts/generate-pr-description.sh new file mode 100755 index 000000000..df36f9df2 --- /dev/null +++ b/.github/scripts/generate-pr-description.sh @@ -0,0 +1,261 @@ +#!/bin/bash + +# GitHub Actions script to generate AI-powered PR descriptions using Amazon Bedrock +# This script analyzes commit history and generates formatted PR descriptions matching LISA's changelog format + +set -e # Exit on any error + +RELEASE_TAG="$1" + +if [ -z "$RELEASE_TAG" ]; then + echo "Error: Release tag is required as first argument" + exit 1 +fi + +echo "πŸ” Fetching commit history and PR details for release branch (comparing against main)..." + +# Function to check if GitHub CLI is properly authenticated +check_gh_auth() { + if ! command -v gh >/dev/null 2>&1; then + return 1 + fi + + # Check authentication status with timeout + if timeout 10s gh auth status >/dev/null 2>&1; then + return 0 + else + return 1 + fi +} + +# Function to get PR info for a commit +get_commit_pr_info() { + local commit_hash="$1" + local commit_subject="$2" + local commit_author="$3" + + # Try to find associated PRs using GitHub CLI (only if authenticated) + if [[ "$GH_AUTHENTICATED" == "true" ]]; then + # Use timeout to prevent hanging and suppress errors + local pr_info=$(timeout 15s gh pr list --search "$commit_hash" --state merged --json number,title,body --limit 1 2>/dev/null || echo "[]") + + if [[ "$pr_info" != "[]" ]] && [[ -n "$pr_info" ]] && [[ "$pr_info" != *"error"* ]]; then + local pr_number=$(echo "$pr_info" | jq -r '.[0].number // empty' 2>/dev/null) + local pr_title=$(echo "$pr_info" | jq -r '.[0].title // empty' 2>/dev/null) + local pr_body=$(echo "$pr_info" | jq -r '.[0].body // empty' 2>/dev/null) + + if [[ -n "$pr_number" && -n "$pr_title" && "$pr_number" != "null" ]]; then + echo "- $commit_subject ($commit_author)" + echo " PR #$pr_number: $pr_title" + + if [[ -n "$pr_body" && "$pr_body" != "null" ]]; then + # Truncate very long PR descriptions and clean up formatting + local cleaned_body=$(echo "$pr_body" | head -c 500 | tr '\n' ' ' | tr -s ' ') + echo " $cleaned_body" + fi + return 0 + fi + fi + fi + + # Fallback to commit message only + echo "- $commit_subject ($commit_author)" + return 1 +} + +# Check GitHub CLI authentication status +if check_gh_auth; then + echo "βœ… GitHub CLI authenticated - will attempt PR lookups" + GH_AUTHENTICATED="true" +else + echo "⚠️ GitHub CLI not available or not authenticated - using commit messages only" + echo " To enable PR lookups, run: gh auth login" + GH_AUTHENTICATED="false" +fi + +# Get commits with PR information +COMMITS="" +commit_count=0 +pr_found_count=0 +echo "πŸ“‘ Looking up commit information..." + +# Determine the correct main branch reference +MAIN_REF="" +if git rev-parse --verify main >/dev/null 2>&1; then + MAIN_REF="main" +elif git rev-parse --verify origin/main >/dev/null 2>&1; then + MAIN_REF="origin/main" +elif git rev-parse --verify refs/remotes/origin/main >/dev/null 2>&1; then + MAIN_REF="refs/remotes/origin/main" +else + echo "❌ Cannot find main branch reference. Available branches:" + git branch -a 2>/dev/null || echo "No branches found" + echo "Using develop branch as fallback for commit analysis" + MAIN_REF="HEAD~50" # Fallback to last 50 commits +fi + +echo "πŸ” Using main branch reference: $MAIN_REF" + +while IFS='|' read -r hash subject author; do + if [[ -n "$hash" ]]; then + # Get commit info and handle return code properly with set -e + commit_info=$(get_commit_pr_info "$hash" "$subject" "$author") || commit_return_code=$? + + # Check if PR info was found (return code 0 means PR found) + if [[ -z "$commit_return_code" ]]; then + pr_found_count=$((pr_found_count + 1)) + fi + + if [[ -n "$COMMITS" ]]; then + COMMITS="$COMMITS"$'\n'"$commit_info" + else + COMMITS="$commit_info" + fi + commit_count=$((commit_count + 1)) + + # Show progress for long-running operations + if [[ $((commit_count % 3)) -eq 0 ]]; then + echo " ... processed $commit_count commits" + fi + + # Reset the return code variable for next iteration + unset commit_return_code + fi +done < <(git log $MAIN_REF..HEAD --pretty=format:"%H|%s|%an" --no-merges 2>/dev/null || echo "") + +echo "βœ… Processed $commit_count commits" +if [[ "$GH_AUTHENTICATED" == "true" ]]; then + echo "πŸ“Š Found PR information for $pr_found_count commits" +fi + +# Get unique contributors from commits for acknowledgements (use email to extract GitHub username) +CONTRIBUTORS=$(git log $MAIN_REF..HEAD --pretty=format:"%ae" --no-merges 2>/dev/null | sort -u | sed 's/@.*$//' | sed 's/^/* @/' | tr '\n' '\n' || echo "") + +# Get the current version from VERSION file to use as previous version in changelog +if [ -f "VERSION" ]; then + PREVIOUS_VERSION="v$(cat VERSION)" +else + # Fallback to git tag if VERSION file doesn't exist + PREVIOUS_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "unknown") +fi + +# If no commits, use a default message +if [ -z "$COMMITS" ]; then + COMMITS="- Version update to $RELEASE_TAG" + CONTRIBUTORS="* @github_actions_lisa" +fi + +echo "πŸ“ Found commits:" +echo -e "$COMMITS" +echo "" +echo "πŸ‘₯ Contributors: $CONTRIBUTORS" +echo "🏷️ Previous version: $PREVIOUS_VERSION" +echo "" + +# Create a prompt for Bedrock to generate PR description matching LISA's changelog format +PROMPT="Please create a comprehensive pull request description for the LISA software release that covers ALL commits and PRs provided. Use this format structure: + +# $RELEASE_TAG + +## Key Features + +### [Feature Name - create as many sections as needed] +[Description of the feature and its capabilities] +**[Subcategory if applicable]:** +- **[Component]**: [Description of enhancement] +- **[Component]**: [Description of enhancement] + +### [Another Feature Name - repeat for each major feature/PR] +[Description and details] + +### [Continue for ALL features found in the commits] +[Ensure every significant commit/PR is represented] + +## Key Changes +- **[Category]**: [Description of change] +- **[Category]**: [Description of change] +- **[Category]**: [Description of change] +[Include ALL significant changes, not just a few examples] + +## Acknowledgements +$CONTRIBUTORS + +**Full Changelog**: https://github.com/awslabs/LISA/compare/$PREVIOUS_VERSION..$RELEASE_TAG + +--- + +IMPORTANT: You MUST analyze and include ALL of the following commits/PRs in your response: +$COMMITS + +Requirements: +1. Create a separate Key Features section for EVERY major feature, enhancement, or significant PR listed above +2. Do NOT limit yourself to just 2-3 features - cover ALL significant changes +3. Group related smaller commits together into logical feature sections +4. Use descriptive, professional language for each feature section +5. Ensure every PR mentioned above gets appropriate coverage in the description +6. List ALL significant changes in the Key Changes section +7. If there are many commits, prioritize PRs with detailed descriptions first, then group commits by theme + +Generate a comprehensive description that covers ALL the provided commits and PRs now:" + +# Call Bedrock to generate description +echo "πŸ€– Generating PR description with Bedrock Claude 3 Haiku..." + +# Use jq to properly construct the JSON payload to avoid escaping issues +BEDROCK_PAYLOAD=$(jq -n \ + --arg prompt "$PROMPT" \ + '{ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 3000, + "messages": [ + { + "role": "user", + "content": $prompt + } + ] + }') + +RESPONSE=$(aws bedrock-runtime invoke-model \ + --model-id "anthropic.claude-3-haiku-20240307-v1:0" \ + --body "$BEDROCK_PAYLOAD" \ + --cli-binary-format raw-in-base64-out \ + /tmp/bedrock_response.json) + +# Extract the generated description from the response +DESCRIPTION=$(jq -r '.content[0].text' /tmp/bedrock_response.json) + +# Fallback description if Bedrock fails - use LISA changelog format +if [ -z "$DESCRIPTION" ] || [ "$DESCRIPTION" = "null" ]; then + echo "⚠️ Bedrock response failed, using fallback description" + DESCRIPTION="# $RELEASE_TAG + +## Key Features + +### System Updates +This release includes version updates and system improvements to enhance LISA's stability and performance. + +## Key Changes +- **Version Management**: Updated version numbers across all package files +- **Release Process**: Automated release branch creation and versioning +- **System Maintenance**: General system updates and improvements + +## Acknowledgements +$CONTRIBUTORS + +**Full Changelog**: https://github.com/awslabs/LISA/compare/$PREVIOUS_VERSION..$RELEASE_TAG" +else + echo "βœ… Successfully generated PR description with Bedrock" +fi + +echo "" +echo "πŸ“‹ Generated PR description:" +echo "----------------------------------------" +echo "$DESCRIPTION" +echo "----------------------------------------" + +# Save description to GitHub Actions output +{ + echo 'DESCRIPTION<> $GITHUB_OUTPUT diff --git a/.github/workflows/code.deploy.demo.yml b/.github/workflows/code.deploy.demo.yml index e200c1030..0170b1dee 100644 --- a/.github/workflows/code.deploy.demo.yml +++ b/.github/workflows/code.deploy.demo.yml @@ -11,7 +11,7 @@ jobs: CheckPendingWorkflow: runs-on: ubuntu-latest steps: - - uses: ahmadnassri/action-workflow-queue@v1 + - uses: ahmadnassri/action-workflow-queue@542658b3a8270cac81ae15d401b0d974732808ac # v1 with: delay: 300000 timeout: 7200000 @@ -20,9 +20,9 @@ jobs: environment: demo runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@3d21ddcb5087c3d29b7e19fe293e3455fabe32af # v4 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} @@ -33,11 +33,11 @@ jobs: run: | echo "${{vars.CONFIG_YAML}}" > config-custom.yaml - name: Set up Python 3.11 - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.11" - name: Use Node.js 20.x - uses: actions/setup-node@v4 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: 20.x - name: Install CDK dependencies @@ -53,7 +53,7 @@ jobs: if: always() steps: - name: Send Notification that Demo Deploy Finished - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.INTERNAL_DEV_SLACK_WEBHOOK_URL }} SLACK_COLOR: ${{ contains(join(needs.*.result, ' '), 'failure') && 'failure' || 'success' }} diff --git a/.github/workflows/code.deploy.dev.yml b/.github/workflows/code.deploy.dev.yml index be2717afc..417fe49be 100644 --- a/.github/workflows/code.deploy.dev.yml +++ b/.github/workflows/code.deploy.dev.yml @@ -11,7 +11,7 @@ jobs: CheckPendingWorkflow: runs-on: ubuntu-latest steps: - - uses: ahmadnassri/action-workflow-queue@v1 + - uses: ahmadnassri/action-workflow-queue@542658b3a8270cac81ae15d401b0d974732808ac # v1 with: delay: 300000 timeout: 7200000 @@ -20,9 +20,9 @@ jobs: environment: dev runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@3d21ddcb5087c3d29b7e19fe293e3455fabe32af # v4 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} @@ -33,11 +33,11 @@ jobs: run: | echo "${{vars.CONFIG_YAML}}" > config-custom.yaml - name: Set up Python 3.11 - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.11" - name: Use Node.js 20.x - uses: actions/setup-node@v4 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: 20.x - name: Install CDK dependencies @@ -53,7 +53,7 @@ jobs: if: always() steps: - name: Send Notification that Dev Deploy Finished - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_URL }} SLACK_COLOR: ${{ contains(join(needs.*.result, ' '), 'failure') && 'failure' || 'success' }} diff --git a/.github/workflows/code.draft-release-and-tag.yml b/.github/workflows/code.draft-release-and-tag.yml index 85d53593b..b1933f23f 100644 --- a/.github/workflows/code.draft-release-and-tag.yml +++ b/.github/workflows/code.draft-release-and-tag.yml @@ -5,24 +5,28 @@ on: types: [closed] permissions: - id-token: write - contents: write + contents: read # Default read-only jobs: draft_release: runs-on: ubuntu-latest + permissions: + contents: write # Required for creating releases + id-token: write # Required for AWS authentication if: (startsWith(github.event.pull_request.head.ref, 'release/' ) || startsWith(github.event.pull_request.head.ref, 'hotfix/')) && github.event.pull_request.merged == true && github.event.pull_request.base.ref == 'main' steps: - name: Checkout Source Tag - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 with: ref: main - name: Get Version id: get-version run: | - version=$(echo ${{github.event.pull_request.head.ref}} | cut -d/ -f2) + version=$(echo "$GITHUB_HEAD_REF" | cut -d/ -f2) echo "version=$version" >> $GITHUB_OUTPUT echo "VERSION = $version" + env: + GITHUB_HEAD_REF: ${{ github.event.pull_request.head.ref }} - name: Create Release run: | gh release create ${{ steps.get-version.outputs.version }} --generate-notes -d -t "${{ steps.get-version.outputs.version }}" --target main @@ -33,15 +37,19 @@ jobs: needs: [draft_release] runs-on: ubuntu-latest if: always() + permissions: + contents: read steps: - name: Get Version id: get-version run: | - version=$(echo ${{github.event.pull_request.head.ref}} | cut -d/ -f2) + version=$(echo "$GITHUB_HEAD_REF" | cut -d/ -f2) echo "version=$version" >> $GITHUB_OUTPUT echo "VERSION = $version" + env: + GITHUB_HEAD_REF: ${{ github.event.pull_request.head.ref }} - name: Send Notification that Draft Release is Ready - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 if: (startsWith(github.event.pull_request.head.ref, 'release/' ) || startsWith(github.event.pull_request.head.ref, 'hotfix/')) && github.event.pull_request.merged == true && github.event.pull_request.base.ref == 'main' with: status: success() diff --git a/.github/workflows/code.end-to-end-test.nightly.yml b/.github/workflows/code.end-to-end-test.nightly.yml index b55035000..c96fd7471 100644 --- a/.github/workflows/code.end-to-end-test.nightly.yml +++ b/.github/workflows/code.end-to-end-test.nightly.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Send β€œE2E Tests Starting” to Slack - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_TITLE: 'E2E Tests Starting' MSG_MINIMAL: true @@ -28,9 +28,9 @@ jobs: runs-on: ubuntu-latest needs: notify_e2e_start steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Setup Node.js - uses: actions/setup-node@v3 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: '18' cache: 'npm' @@ -42,7 +42,7 @@ jobs: run: npx cypress run --config-file cypress/cypress.e2e.config.ts - name: Archive Cypress videos & screenshots if: failure() || always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 with: name: cypress-e2e-artifacts path: | @@ -56,7 +56,7 @@ jobs: if: always() steps: - name: Notify E2E results to Slack - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_COLOR: ${{ needs.e2e.result == 'success' && 'good' || 'danger' }} SLACK_TITLE: 'E2E Tests Finished' diff --git a/.github/workflows/code.hotfix.branch.yml b/.github/workflows/code.hotfix.branch.yml index 6f8bede9f..d28a36aaf 100644 --- a/.github/workflows/code.hotfix.branch.yml +++ b/.github/workflows/code.hotfix.branch.yml @@ -10,16 +10,18 @@ on: required: true permissions: - id-token: write - contents: write - pull-requests: write + contents: read # Default read-only jobs: MakeNewBranch: runs-on: ubuntu-latest + permissions: + contents: write # Required for branch creation + id-token: write # Required for AWS authentication + pull-requests: write # Required for creating PRs steps: - name: Checkout Source Tag - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 with: ref: refs/tags/${{ github.event.inputs.source_tag }} - name: Create Hotfix Branch and Update Version diff --git a/.github/workflows/code.merge.main-to-develop.yml b/.github/workflows/code.merge.main-to-develop.yml index 0a18b2596..06ad5d096 100644 --- a/.github/workflows/code.merge.main-to-develop.yml +++ b/.github/workflows/code.merge.main-to-develop.yml @@ -5,14 +5,16 @@ on: types: [released] permissions: - contents: write + contents: read # Default read-only jobs: conduct_merge: runs-on: ubuntu-latest + permissions: + contents: write # Required for merging branches steps: - name: Checkout main - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 with: ref: main ssh-key: ${{ secrets.DEPLOYMENT_SSH_KEY }} @@ -32,7 +34,7 @@ jobs: if: always() steps: - name: Send Notification that Develop is up to date - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.INTERNAL_DEV_SLACK_WEBHOOK_URL }} SLACK_COLOR: ${{ contains(join(needs.*.result, ' '), 'failure') && 'failure' || 'success' }} diff --git a/.github/workflows/code.publish.yml b/.github/workflows/code.publish.yml index f90cf80ea..bf707c271 100644 --- a/.github/workflows/code.publish.yml +++ b/.github/workflows/code.publish.yml @@ -4,19 +4,18 @@ on: types: [released] permissions: - contents: read - packages: write + contents: read # Default read-only jobs: PublishLISA: runs-on: ubuntu-latest permissions: contents: read - packages: write + packages: write # Required for npm package publishing steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 # Setup .npmrc file to publish to GitHub Packages - - uses: actions/setup-node@v4 + - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: '20.x' registry-url: 'https://npm.pkg.github.com' @@ -30,9 +29,11 @@ jobs: needs: [ PublishLISA ] runs-on: ubuntu-latest if: always() + permissions: + contents: read steps: - name: Send Notification that package has published - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.INTERNAL_DEV_SLACK_WEBHOOK_URL }} SLACK_COLOR: ${{ contains(join(needs.*.result, ' '), 'failure') && 'failure' || 'success' }} diff --git a/.github/workflows/code.release.branch.yml b/.github/workflows/code.release.branch.yml index cbe3900cd..5696a2b82 100644 --- a/.github/workflows/code.release.branch.yml +++ b/.github/workflows/code.release.branch.yml @@ -7,16 +7,19 @@ on: required: true permissions: - id-token: write - contents: write - pull-requests: write + contents: read # Default read-only jobs: MakeNewReleaseBranch: runs-on: ubuntu-latest + environment: dev + permissions: + contents: write # Required for branch creation + id-token: write # Required for AWS authentication + pull-requests: write # Required for creating PRs steps: - name: Checkout Develop Branch - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 with: ref: develop - name: Create Release Branch and Update Version @@ -29,12 +32,37 @@ jobs: echo "$( jq --arg version ${RELEASE_TAG:1} '.version = $version' package.json )" > package.json sed -E -i -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml echo ${RELEASE_TAG:1} > VERSION + # update package-lock.json + npm ci git commit -a -m "Updating version for release ${{ github.event.inputs.release_tag }}" git push origin release/${{ github.event.inputs.release_tag }} env: GITHUB_TOKEN: ${{ secrets.LEAD_ACCESS_TOKEN }} + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@3d21ddcb5087c3d29b7e19fe293e3455fabe32af # v4 + with: + aws-region: ${{ vars.AWS_REGION }} + role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} + role-session-name: GitHub_to_AWS_via_FederatedOIDC + role-duration-seconds: 14400 + - name: Generate PR Description with Bedrock + id: generate_description + run: | + # Switch back to develop branch for accurate commit comparison + git checkout develop + chmod +x .github/scripts/generate-pr-description.sh + .github/scripts/generate-pr-description.sh "${{ github.event.inputs.release_tag }}" + env: + GITHUB_TOKEN: ${{ secrets.LEAD_ACCESS_TOKEN }} + - name: Draft Pull Request run: | - gh pr create -d --title "Release ${{github.event.inputs.release_tag}} into Main" --body "Release ${{github.event.inputs.release_tag}} PR into Main" --base main --head release/${{ github.event.inputs.release_tag }} + # Switch to release branch to ensure proper context for PR creation + git checkout release/${{ github.event.inputs.release_tag }} + gh pr create -d \ + --title "Release ${{github.event.inputs.release_tag}} into Main" \ + --body "${{ steps.generate_description.outputs.DESCRIPTION }}" \ + --base main \ + --head release/${{ github.event.inputs.release_tag }} env: GH_TOKEN: ${{ github.token }} diff --git a/.github/workflows/code.smoke-test.yml b/.github/workflows/code.smoke-test.yml index 1041cab23..bbf706b9d 100644 --- a/.github/workflows/code.smoke-test.yml +++ b/.github/workflows/code.smoke-test.yml @@ -14,10 +14,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Setup Node.js - uses: actions/setup-node@v3 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: '18' cache: 'npm' @@ -42,7 +42,7 @@ jobs: - name: Archive Cypress videos & screenshots if: failure() || always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 with: name: cypress-smoke-artifacts path: | diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..d9f00cb60 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,73 @@ +name: "CodeQL Security Analysis" + +on: + # Temporarily disabled to avoid conflicts with default CodeQL setup + # push: + # branches: [ "main", "develop" ] + # pull_request: + # branches: [ "main", "develop" ] + # schedule: + # - cron: '0 8 * * 1' # Weekly on Monday at 8 AM UTC + workflow_dispatch: + +permissions: + contents: read + +jobs: + analyze: + name: Analyze Code + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + actions: read + + strategy: + fail-fast: false + matrix: + language: [ 'javascript', 'python' ] + + steps: + - name: Checkout repository + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@b36bf259c813715f76eafece573914b94412cd13 # v3 + with: + languages: ${{ matrix.language }} + config: | + name: "Custom CodeQL Analysis" + + - name: Set up Python 3.11 + if: matrix.language == 'python' + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 + with: + python-version: "3.11" + + - name: Install Python dependencies + if: matrix.language == 'python' + run: | + python -m pip install --upgrade pip + # Try hash-verified install first, fall back to regular + if [ -f "requirements-dev-hashes.txt" ]; then + pip install --require-hashes -r requirements-dev-hashes.txt + else + pip install -r requirements-dev.txt + fi + pip install -e ./lisa-sdk + + - name: Use Node.js 20.x + if: matrix.language == 'javascript' + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 + with: + node-version: 20.x + + - name: Install Node.js dependencies + if: matrix.language == 'javascript' + run: | + npm ci + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@b36bf259c813715f76eafece573914b94412cd13 # v3 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/docs.deploy.github-pages.yml b/.github/workflows/docs.deploy.github-pages.yml index ba12f72e7..faa07b402 100644 --- a/.github/workflows/docs.deploy.github-pages.yml +++ b/.github/workflows/docs.deploy.github-pages.yml @@ -19,16 +19,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 with: fetch-depth: 0 - name: Setup Node - uses: actions/setup-node@v4 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: 20 cache: npm - name: Setup Pages - uses: actions/configure-pages@v4 + uses: actions/configure-pages@1f0c5cde4bc74cd7e1254d0cb4de8d49e9068c7d # v4 - name: Install root dependencies run: | npm ci @@ -39,7 +39,7 @@ jobs: CI: "" DOCS_BASE_PATH: '/LISA/' - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@56afc609e74202658d3ffba0e8f6dda462b719fa # v3 with: path: ./lib/docs/dist deploy: @@ -52,4 +52,4 @@ jobs: steps: - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 + uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4 diff --git a/.github/workflows/issues.alert.yml b/.github/workflows/issues.alert.yml index f7995928e..e7d361f85 100644 --- a/.github/workflows/issues.alert.yml +++ b/.github/workflows/issues.alert.yml @@ -1,5 +1,7 @@ name: Alert on Issue Creation -permissions: {} +permissions: + contents: read + issues: read on: issues: types: [opened, reopened] @@ -8,9 +10,12 @@ jobs: send_slack_notification: name: Send Issue Alert Slack Notification runs-on: ubuntu-latest + permissions: + contents: read + issues: read steps: - name: Send slack notification for issue created - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 if: github.event.action == 'opened' env: SLACK_WEBHOOK: ${{ secrets.INTERNAL_DEV_SLACK_WEBHOOK_URL }} @@ -20,7 +25,7 @@ jobs: MSG_MINIMAL: 'true' SLACK_MESSAGE: ' Issue <${{ github.event.issue.html_url }}|${{ github.event.issue.title }}> created by ${{ github.event.sender.login }}' - name: Send slack notification for issue reopened - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 if: github.event.action == 'reopened' env: SLACK_WEBHOOK: ${{ secrets.INTERNAL_DEV_SLACK_WEBHOOK_URL }} diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 4a3e7d2af..cd9e03c82 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -13,10 +13,12 @@ jobs: send_starting_slack_notification: name: Send Starting Slack Notification runs-on: ubuntu-latest + permissions: + contents: read steps: - name: Send Internal PR Created Notification if: github.event_name == 'pull_request' && github.event.action == 'opened' - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_URL }} SLACK_TITLE: 'PR Created: ${{ github.event.pull_request.title }} by ${{ github.event.pull_request.user.login }}' @@ -25,7 +27,7 @@ jobs: SLACK_MESSAGE: 'PR Created ${{ github.event.pull_request.html_url }}' - name: Send Mission Solution PR Created Notification if: github.event_name == 'pull_request' && github.event.action == 'opened' - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 env: SLACK_WEBHOOK: ${{ secrets.MISSION_SOLUTION_PR_WEBHOOK }} SLACK_TITLE: '${{github.event.repository.name}} PR Created: ${{ github.event.pull_request.title }} by ${{ github.event.pull_request.user.login }}' @@ -36,10 +38,12 @@ jobs: name: CDK Tests needs: [send_starting_slack_notification] runs-on: ubuntu-latest + permissions: + contents: read steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Use Node.js 20.x - uses: actions/setup-node@v3 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: 20.x - name: Install dependencies @@ -52,16 +56,23 @@ jobs: name: Backend Tests runs-on: ubuntu-latest needs: [send_starting_slack_notification] + permissions: + contents: read steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Set up Python 3.11 - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements-dev.txt + # Try hash-verified install first, fall back to regular + if [ -f "requirements-dev-hashes.txt" ]; then + pip install --require-hashes -r requirements-dev-hashes.txt + else + pip install -r requirements-dev.txt + fi pip install -e ./lisa-sdk - name: Run tests run: | @@ -70,29 +81,33 @@ jobs: name: Run All Pre-Commit needs: [send_starting_slack_notification] runs-on: ubuntu-latest + permissions: + contents: read steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Set up Python 3.11 - uses: actions/setup-python@v3 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5 with: python-version: '3.11' - name: Use Node.js 20.x - uses: actions/setup-node@v3 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 with: node-version: 20.x - name: Install CDK dependencies working-directory: ./ run: | npm ci - - uses: pre-commit/action@v3.0.1 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 send_final_slack_notification: name: Send Final Slack Notification needs: [cdk-build, backend-build, pre-commit] runs-on: ubuntu-latest if: always() + permissions: + contents: read steps: - name: Send GitHub Action trigger data to Slack workflow - uses: rtCamp/action-slack-notify@v2 + uses: rtCamp/action-slack-notify@cdf0a2130cbcdfd82ba5fcac8e076370bf381b36 # v2 if: github.event_name != 'pull_request' env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_URL }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c1b26b4a..30385e8a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: hooks: - id: codespell entry: codespell - args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*', "-L=xdescribe"] + args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*', "-L=xdescribe,assertIn"] pass_filenames: false - repo: https://github.com/pycqa/isort diff --git a/Makefile b/Makefile index a12fb4dfa..32cb27f8f 100644 --- a/Makefile +++ b/Makefile @@ -143,18 +143,17 @@ else endif -## Set up Python interpreter environment +## Set up Python interpreter environment to match LISA deployed version createPythonEnvironment: - python3 -m venv .venv + python3.11 -m venv .venv @printf ">>> New virtual environment created. To activate run: 'source .venv/bin/activate'" ## Install Python dependencies for development installPythonRequirements: - pip3 install pip --upgrade - pip3 install -r requirements-dev.txt - pip3 install -e lisa-sdk - + CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install pip --upgrade + CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install --prefer-binary -r requirements-dev.txt + CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install -e lisa-sdk ## Set up TypeScript interpreter environment createTypeScriptEnvironment: @@ -206,7 +205,10 @@ modelCheck: fi; \ echo "Converting and uploading safetensors for model: $$MODEL_ID"; \ tgiImage=$$(yq -r '[.ecsModels[] | select(.inferenceContainer == "tgi") | .baseImage] | first' $(PROJECT_DIR)/config-custom.yaml); \ - echo $$tgiImage; \ + if [ "$$tgiImage" = "null" ] || [ -z "$$tgiImage" ]; then \ + tgiImage="ghcr.io/huggingface/text-generation-inference:latest"; \ + fi; \ + echo "Using TGI image: $$tgiImage"; \ $(PROJECT_DIR)/scripts/convert-and-upload-model.sh -m $$MODEL_ID -s $(MODEL_BUCKET) -a $$access_token -t $$tgiImage -d $$localModelDir; \ fi; \ fi; \ @@ -247,6 +249,7 @@ cleanCfn: ## Delete all misc files cleanMisc: @find . -type f -name "*.DS_Store" -delete + @find . -type d -name "TIKTOKEN_CACHE" -exec rm -rf {} + @rm -f .hf_token_cache @@ -269,6 +272,9 @@ listStacks: buildNpmModules: npm run build +buildArchive: + BUILD_ASSETS=true npm run build + define print_config @printf "\n \ DEPLOYING $(STACK) STACK APP INFRASTRUCTURE \n \ @@ -383,4 +389,4 @@ test-coverage: --cov-report term-missing \ --cov-report html:build/coverage \ --cov-report xml:build/coverage/coverage.xml \ - --cov-fail-under 83 + --cov-fail-under 85 diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..e5acff157 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,87 @@ +# πŸ”’ Security Policy + +## πŸ“‹ Supported Versions + +We actively maintain security updates for the following versions: + +| Version | Supported | +| ------- | ------------------ | +| Latest | βœ… Fully supported | +| < Latest| ❌ Security updates on best-effort basis | + +## 🚨 Reporting Security Vulnerabilities + +**Please DO NOT report security vulnerabilities through public GitHub issues.** + +Instead, please report security issues by: + +1. **Email**: Send details to the project maintainers via GitHub +2. **Private Issue**: Use GitHub's security advisory feature if available +3. **Direct Contact**: Contact repository administrators directly + +### πŸ“ What to Include + +Please include the following information: +- Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) +- Full paths of source file(s) related to the manifestation of the issue +- The location of the affected source code (tag/branch/commit or direct URL) +- Any special configuration required to reproduce the issue +- Step-by-step instructions to reproduce the issue +- Proof-of-concept or exploit code (if possible) +- Impact of the issue, including how an attacker might exploit it + +## πŸ›‘οΈ Security Measures in Place + +### **Static Analysis** +- **CodeQL**: Automated security scanning on all pull requests +- **Dependency Scanning**: Regular vulnerability detection +- **License Compliance**: Automated license validation + +### **Dependency Management** +- **Dependabot**: Automated security updates +- **Pin Dependencies**: Critical dependencies pinned by hash +- **Vulnerability Monitoring**: Continuous monitoring of known CVEs + +### **CI/CD Security** +- **Least Privilege**: GitHub Actions use minimal required permissions +- **Supply Chain Protection**: All third-party actions pinned by commit hash +- **Secure Workflows**: No dangerous workflow patterns + +### **Infrastructure Security** +- **Container Security**: Base images pinned to specific digests +- **AWS IAM**: Least privilege access controls +- **Encryption**: TLS 1.2+ for all communications + +## ⚑ Response Timeline + +- **Critical vulnerabilities**: 24-48 hours +- **High severity**: 7 days +- **Medium severity**: 30 days +- **Low severity**: Next release cycle + +## πŸ” Security Testing + +This project undergoes regular security assessments: + +- **OpenSSF Scorecard**: Monthly comprehensive security analysis +- **Dependency Scanning**: Weekly automated checks +- **Static Analysis**: On every pull request +- **Security Reviews**: Quarterly manual assessments + +## πŸ›οΈ Governance + +Security decisions follow these principles: + +1. **Defense in Depth**: Multiple layers of security controls +2. **Zero Trust**: Verify all access and communications +3. **Least Privilege**: Minimum required access permissions +4. **Continuous Monitoring**: Real-time threat detection +5. **Incident Response**: Documented response procedures + +## πŸ“š Additional Resources + +- [OpenSSF Scorecard](https://github.com/ossf/scorecard) - Security health metrics +- [NIST Cybersecurity Framework](https://www.nist.gov/cyberframework) - Security guidelines +- [OWASP Top 10](https://owasp.org/www-project-top-ten/) - Common vulnerabilities + +For questions about this security policy, please contact the project maintainers. diff --git a/VERSION b/VERSION index 03f488b07..c7cb1311a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -5.3.0 +5.3.1 diff --git a/bin/build-assets b/bin/build-assets new file mode 100755 index 000000000..4a214ae4a --- /dev/null +++ b/bin/build-assets @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +ROOT=$(pwd) + +./bin/build-lambdas +./bin/build-images --export diff --git a/bin/build-images b/bin/build-images new file mode 100755 index 000000000..ae9377f21 --- /dev/null +++ b/bin/build-images @@ -0,0 +1,142 @@ +#!/bin/bash + +set -e + +ROOT=$(pwd) +OUTPUT_DIR=$ROOT/dist/images +DOCKER_CMD=$(command -v finch >/dev/null 2>&1 && echo "finch" || echo "docker") + +# Parse command line arguments +UPLOAD=false +EXPORT=false +for arg in "$@"; do + case $arg in + --upload) + UPLOAD=true + shift + ;; + --export) + EXPORT=true + mkdir -p $OUTPUT_DIR + shift + ;; + esac +done + +# Default LISA_VERSION if not set +LISA_VERSION=${LISA_VERSION:-$(cat ./VERSION 2>/dev/null || echo "latest")} + +# ECR configuration +ACCOUNT=${AWS_ACCOUNT:-""} +REGION=${AWS_REGION:-"us-east-1"} +DOMAIN=${AWS_DOMAIN:-"amazonaws.com"} +ECR_BASE_URL=$ACCOUNT.dkr.ecr.$REGION.$DOMAIN + +# Function to build a single image +build_image() { + local dockerfile_path="$1" + local repository_name="$2" + local image_tag="$3" + local build_context_path="$4" + shift 4 + local build_args=("$@") + + echo "Building image: $repository_name:$image_tag" + echo "Context: $build_context_path" + + # Construct docker build command + local docker_cmd="$DOCKER_CMD build" + + # Add build args + for arg in "${build_args[@]}"; do + docker_cmd="$docker_cmd --build-arg $arg" + done + + # Add dockerfile, tag, and context + docker_cmd="$docker_cmd -f $build_context_path/$dockerfile_path -t $repository_name:$image_tag $build_context_path" + + echo "Executing: $docker_cmd" + eval "$docker_cmd" + echo "Successfully built $repository_name:$image_tag" + + # Upload to ECR if --upload flag is set + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + local ecr_tag="$ECR_BASE_URL/$repository_name:$image_tag" + echo "Tagging for ECR: $ecr_tag" + $DOCKER_CMD tag "$repository_name:$image_tag" "$ecr_tag" + echo "Pushing to ECR: $ecr_tag" + $DOCKER_CMD push "$ecr_tag" + echo "Successfully pushed $ecr_tag" + fi + + # Export image if --export flag is set + if [[ "$EXPORT" == "true" ]]; then + local export_file="$OUTPUT_DIR/${repository_name}_${image_tag}.tar" + echo "Exporting image to: $export_file" + $DOCKER_CMD save "$repository_name:$image_tag" -o "$export_file" + echo "Successfully exported $export_file" + fi + echo "" +} + +# Function to login to ECR +ecr_login() { + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + echo "Logging into ECR..." + aws ecr get-login-password --region $REGION | $DOCKER_CMD login --username AWS --password-stdin $ACCOUNT.dkr.ecr.$REGION.$DOMAIN + echo "ECR login successful" + echo "" + fi +} + +# Main function to build all images +build_all_images() { + echo "Starting Docker image builds..." + echo "LISA_VERSION: $LISA_VERSION" + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + echo "ECR_BASE_URL: $ECR_BASE_URL" + echo "Upload: Enabled" + else + echo "Upload: Disabled" + fi + echo "" + + ecr_login + + # lisa-rest-api + python3 scripts/cache-tiktoken-for-offline.py ./lib/serve/rest-api/TIKTOKEN_CACHE + build_image "Dockerfile" "lisa-rest-api" "$LISA_VERSION" "./lib/serve/rest-api" \ + "NODE_ENV=production" \ + "LITELLM_CONFIG=\"db_key: sk-a8814208-0388-480c-9fc7-fea59607ca38\"" \ + "BASE_IMAGE=python:3.11" + + # lisa-batch-ingestion + RAG_DIR="./lib/rag/ingestion/ingestion-image" + BUILD_DIR="${RAG_DIR}/build" + mkdir -p "$BUILD_DIR" + rsync -av --exclude='__pycache__' ./lambda/ "$BUILD_DIR/" + build_image "Dockerfile" "lisa-batch-ingestion" "$LISA_VERSION" "$RAG_DIR" "NODE_ENV=production" + + # lisa-tei + build_image "Dockerfile" "lisa-tei" "latest" "./lib/serve/ecs-model/embedding/tei" \ + "NODE_ENV=production" \ + "BASE_IMAGE=ghcr.io/huggingface/text-embeddings-inference:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + # lisa-tgi + build_image "Dockerfile" "lisa-tgi" "latest" "./lib/serve/ecs-model/textgen/tgi" \ + "NODE_ENV=production" \ + "BASE_IMAGE=ghcr.io/huggingface/text-generation-inference:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + # lisa-vllm + build_image "Dockerfile" "lisa-vllm" "latest" "./lib/serve/ecs-model/vllm" \ + "NODE_ENV=production" \ + "BASE_IMAGE=vllm/vllm-openai:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + echo "All images built successfully!" +} + +# Run the build +build_all_images diff --git a/bin/build-lambdas b/bin/build-lambdas new file mode 100755 index 000000000..aa156a939 --- /dev/null +++ b/bin/build-lambdas @@ -0,0 +1,46 @@ +#!/bin/bash +set -e + +ROOT=$(pwd) +OUTPUT_DIR=$ROOT/dist/layers +mkdir -p $OUTPUT_DIR + +PYPI_URL=${PYPI_URL:-https://pypi.org/simple/} +source .venv/bin/activate + +build_layer() { + local package_name=$1 + local source_path=$2 + local pre_build_cmd=$3 + echo "Building Lambda Layer $package_name from $source_path..." + + if [ -n "$pre_build_cmd" ]; then + eval "$pre_build_cmd" + fi + + cd $source_path + $ROOT/bin/package-lambda-layer --src . --output "$package_name.zip" --pypi $PYPI_URL --layer + mv ./build/"$package_name.zip" $OUTPUT_DIR/ + rm -rf ./build + cd $ROOT +} + +build_lambda() { + local package_name=$1 + local source_path=$2 + echo "Building Lambda $package_name from $source_path..." + cd "$source_path" + $ROOT/bin/package-lambda-layer --src . --output "$package_name.zip" --pypi $PYPI_URL + mv ./build/"$package_name.zip" $OUTPUT_DIR/ + rm -rf ./build + cd $ROOT +} + +echo "Building Python Lambda Layers..." +build_layer "AimlAdcLisaCommonLayer" "./lib/core/layers/common" +build_layer "AimlAdcLisaAuthLayer" "./lib/core/layers/authorizer" +build_layer "AimlAdcLisaFastApiLayer" "./lib/core/layers/fastapi" +build_layer "AimlAdcLisaRag" "./lib/rag/layer" "python3 scripts/cache-tiktoken-for-offline.py ./lib/rag/layer/TIKTOKEN_CACHE" + +echo "Building Python Lambdas..." +build_lambda "AimlAdcLisaLambda" "./lambda" diff --git a/bin/copy-deps.sh b/bin/copy-deps.sh deleted file mode 100755 index c972a7247..000000000 --- a/bin/copy-deps.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -function install_python_deps() { - local input_path=$1 - local output_path=$2 - local package=$3 - - echo "Installing Python dependencies for $package" - mkdir -p "${output_path}" - if ! pip install -r ${input_path}/requirements.txt --target $output_path --platform manylinux2014_x86_64 --only-binary=:all: --no-deps --no-cache-dir; then - echo "Failed to install Python dependencies for ${package}" - exit 1 - fi - - echo "${package} dependencies installed successfully" - rsync -a "${input_path}/" "${output_path}" - - echo "Optimizing ${package}" - find $output_path -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null - find $output_path -type d -name "*.dist-info" -exec rm -rf {} + 2>/dev/null - find $output_path -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null - find $output_path -type f -name "*.pyc" -delete - find $output_path -type f -name "*.pyo" -delete - find $output_path -type f -name "*.so" -exec strip {} + 2>/dev/null -} - -function setup_python_dist(){ - cd dist - - # Define the layers - PYTHON_VERSION="3.11" - DIST="." - OUTPUT_DIR="python/lib/${PYTHON_VERSION}/site-packages" - - # Create a virtual environment for isolation - python -m venv .venv - source .venv/bin/activate - - # # Install dependencies for each lambda layer - layers=("authorizer" "common" "fastapi") - layers_path="../lib/core/layers" - layers_output="${DIST}/lambdaLayer" - for layer in "${layers[@]}"; do - # ./package-lambda-layer --src="${layers_path}/${layer}" --build="./dist/$layer" --output="${layers_output}/${layer}" - layer_path="${layers_path}/${layer}" - layer_output="${layers_output}/${layer}/${OUTPUT_DIR}" - install_python_deps $layer_path $layer_output $layer - done - - # Install rag layer - rag_path="../lib/rag/layer" - rag_output="${DIST}/rag/${OUTPUT_DIR}" - rag_package="rag" - install_python_deps $rag_path $rag_output $rag_package - - # Install lisa-sdk dependencies - sdk_path="../lisa-sdk" - sdk_output="${DIST}/lisa-sdk/${OUTPUT_DIR}" - sdk_package="lisa-sdk" - install_python_deps $sdk_path $sdk_output $sdk_package - - # Deactivate virtual environment - deactivate - rm -rf .venv - echo "All Python dependencies installed successfully" - cd - -} - -function copy_dist() { - mkdir -p dist/ecs_model_deployer && rsync -av ecs_model_deployer/dist/ dist/ecs_model_deployer/ && cp ecs_model_deployer/Dockerfile dist/ecs_model_deployer/ - mkdir -p dist/vector_store_deployer && rsync -av vector_store_deployer/dist/ dist/vector_store_deployer/ && cp vector_store_deployer/Dockerfile dist/vector_store_deployer/ - mkdir -p dist/lisa-web && rsync -av lib/user-interface/react/dist/ dist/lisa-web - mkdir -p dist/docs && rsync -av lib/docs/dist/ dist/docs - cp VERSION dist/ -} - -mkdir -p dist -# setup_python_dist -copy_dist diff --git a/bin/package-lambda-layer b/bin/package-lambda-layer index 18591b8ca..16cc05c1a 100755 --- a/bin/package-lambda-layer +++ b/bin/package-lambda-layer @@ -4,76 +4,141 @@ set -e SRC=src OUTPUT=Lambda.zip EXCLUDE_PACKAGES="" -SRC_ROOT=$PWD -BUILD_DIR=$SRC_ROOT/build +BUILD_DIR=$PWD/build +IS_LAYER=0 +TMP_DIR=$BUILD_DIR/tmp/ +PYPI_URL= # Parse named parameters while [ $# -gt 0 ]; do - case "$1" in - --src=*) - SRC="${1#*=}" - ;; - --output=*) - OUTPUT="${1#*=}" - ;; - --build=*) - BUILD_DIR="${1#*=}" - ;; - --exclude=*) - EXCLUDE_PACKAGES="${1#*=}" - ;; - *) - echo "Unknown parameter: $1" - echo "Usage: $0 --src= --output= --exclude=" - exit 1 - ;; - esac + if [[ $1 == *"="* ]]; then + # Handle --param=value style + param="${1%%=*}" + value="${1#*=}" + + case "$param" in + --src) + SRC="$value" + ;; + --output) + OUTPUT="$value" + ;; + --build) + BUILD_DIR="$value" + ;; + --exclude) + EXCLUDE_PACKAGES="$value" + ;; + --pypi) + PYPI_URL="$value" + ;; + --layer) + IS_LAYER=1 + ;; + *) + echo "Unknown parameter: $param" + echo "Usage: $0 --src --output --exclude --layer" + exit 1 + ;; + esac + else + # Handle --param value style + case "$1" in + --src) + shift + SRC="$1" + ;; + --output) + shift + OUTPUT="$1" + ;; + --build) + shift + BUILD_DIR="$1" + TMP_DIR=$BUILD_DIR/tmp/python/ + ;; + --exclude) + shift + EXCLUDE_PACKAGES="$1" + ;; + --pypi) + shift + PYPI_URL="$1" + ;; + --layer) + IS_LAYER=1 + ;; + *) + echo "Unknown parameter: $1" + echo "Usage: $0 --src --output --exclude " + exit 1 + ;; + esac + fi shift done +echo "Starting" +if [ $IS_LAYER -eq 1 ]; then + TMP_DIR=$BUILD_DIR/tmp/python/ +fi + +if [ -z "$PYPI_URL" ]; then + echo "Must supply PYPI_URL via --pypi" + exit 1 +fi + +# Extract IP from PYPI_URL for trusted host +TRUSTED_HOST=$(echo $PYPI_URL | sed 's|http://||' | sed 's|/.*||') + +# Print parameters for debugging +echo "Source directory: $SRC" +echo "Output file: $OUTPUT" +echo "Build directory: $BUILD_DIR" +echo "Temp directory: $TMP_DIR" + + install_requirements() { - echo "installing requirements" - rm -rf "$BUILD_DIR" - mkdir -p "$BUILD_DIR/python" - python3 -m pip install "$SRC_ROOT" --target "${BUILD_DIR}/python" + echo "Installing requirements" + rm -rf "$TMP_DIR" + mkdir -p "$TMP_DIR" + if [ -f "$SRC/requirements.txt" ]; then + echo "Installing requirements from $SRC/requirements.txt" + echo "Using python version $(python3 --version)" + python3 -m pip install -r "$SRC/requirements.txt" --force-reinstall --no-cache-dir --target "$TMP_DIR" --index-url $PYPI_URL --trusted-host $TRUSTED_HOST + else + echo "No requirements.txt found in $SRC" + fi } build_package() { - echo "building package" + echo "Building package" if [ -d "$SRC" ]; then - cp -r "$SRC"/* "${BUILD_DIR}/python/" - fi -} - -copy_configuration() { - echo "copying configuration" - if [ -d "configuration/Packaging" ]; then - cp -a configuration/Packaging "$BUILD_DIR" + rsync -av --exclude='build' --exclude='.hatch' --exclude='.venv' "$SRC/" "$TMP_DIR/" fi } package_artifacts() { - echo "packaging" + echo "Packaging" if [ -n "$EXCLUDE_PACKAGES" ]; then echo "Removing excluded packages: $EXCLUDE_PACKAGES" for pkg in ${EXCLUDE_PACKAGES//,/ }; do echo "Removing $pkg" - rm -rf ${BUILD_DIR}/python/${pkg} - rm -rf ${BUILD_DIR}/python/${pkg}-* + rm -rf ${TMP_DIR}/${pkg} + rm -rf ${TMP_DIR}/${pkg}-* # Also remove egg-info directories - find "${BUILD_DIR}/python" -type d -name "${pkg}*egg-info" -exec rm -rf {} + + find "$TMP_DIR" -type d -name "${pkg}*egg-info" -exec rm -rf {} + done fi # AWS Lambda recommends to exclude __pycache__: https://docs.aws.amazon.com/lambda/latest/dg/python-package.html#python-package-pycache - find "${BUILD_DIR}/python" -depth -name __pycache__ -exec rm -rf {} \; - cd "${BUILD_DIR}" - zip "${BUILD_DIR}/${OUTPUT}" ./python -r - rm -rf "${BUILD_DIR}/python" + find "${TMP_DIR}" -depth -name __pycache__ -exec rm -rf {} \; + cd "${BUILD_DIR}/tmp/" + zip -r "${BUILD_DIR}/${OUTPUT}" . + rm -rf "${BUILD_DIR}/tmp" } install_requirements build_package -copy_configuration package_artifacts diff --git a/cypress/src/smoke/fixtures/models.json b/cypress/src/smoke/fixtures/models.json index af61f2e27..6a10fd5ab 100644 --- a/cypress/src/smoke/fixtures/models.json +++ b/cypress/src/smoke/fixtures/models.json @@ -2,7 +2,7 @@ "models": [ { "autoScalingConfig": { - "blockDeviceVolumeSize": 30, + "blockDeviceVolumeSize": 50, "minCapacity": 1, "maxCapacity": 1, "cooldown": 420, diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 45b95cd95..280676e6a 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -124,9 +124,9 @@ export class ECSCluster extends Construct { // we want to set these based on the task created but currently the ECSCluster for model // will only create one task, so grab these values during creation so we can set the properties // on this class - let container; - let taskRole; - let endpointUrl; + let container: ContainerDefinition | undefined; + let taskRole: IRole | undefined; + let endpointUrl: string | undefined; Object.entries(ecsConfig.tasks).forEach(([, taskDefinition]) => { const environment = taskDefinition.environment; @@ -201,7 +201,7 @@ export class ECSCluster extends Construct { } const roleId = identifier; - const taskRole = taskRoleName ? + taskRole = taskRoleName ? Role.fromRoleName(this, createCdkId([config.deploymentName, roleId]), taskRoleName) : this.createTaskRole(config.deploymentName ?? '', config.deploymentPrefix, roleId); @@ -231,7 +231,7 @@ export class ECSCluster extends Construct { : undefined; const image = CodeFactory.createImage(taskDefinition.containerConfig.image, this, identifier, ecsConfig.buildArgs); - const container = ec2TaskDefinition.addContainer(createCdkId([identifier, 'Container']), { + container = ec2TaskDefinition.addContainer(createCdkId([identifier, 'Container']), { containerName: createCdkId([config.deploymentName, identifier], 32, 2), image, environment, @@ -322,15 +322,20 @@ export class ECSCluster extends Construct { const domain = loadBalancer.loadBalancerDnsName; endpointUrl = `${protocol}://${domain}`; + }); - new CfnOutput(this, 'modelEndpointurl', { - key: 'modelEndpointUrl', - value: this.endpointUrl, - }); + // Validate endpointUrl is set before creating output + if (!endpointUrl) { + throw new Error('Failed to create endpoint URL - no tasks configured'); + } + + new CfnOutput(this, 'modelEndpointurl', { + key: 'modelEndpointUrl', + value: endpointUrl, }); // Update - this.endpointUrl = endpointUrl!; + this.endpointUrl = endpointUrl; this.container = container!; this.taskRole = taskRole!; } diff --git a/lambda/authorizer/lambda_functions.py b/lambda/authorizer/lambda_functions.py index ed6716e89..db1aca7b9 100644 --- a/lambda/authorizer/lambda_functions.py +++ b/lambda/authorizer/lambda_functions.py @@ -18,7 +18,6 @@ import os import ssl from datetime import datetime -from functools import cache from typing import Any, Dict import boto3 @@ -26,6 +25,7 @@ import jwt import requests from botocore.exceptions import ClientError +from cachetools import cached, TTLCache from utilities.common_functions import authorization_wrapper, get_id_token, get_property_path, retry_config logger = logging.getLogger(__name__) @@ -203,7 +203,7 @@ def find_jwt_username(jwt_data: dict[str, str]) -> str: return username -@cache +@cached(cache=TTLCache(maxsize=1, ttl=300)) def get_management_tokens() -> list[str]: """Return secret management tokens if they exist.""" secret_tokens: list[str] = [] diff --git a/lambda/configuration/lambda_functions.py b/lambda/configuration/lambda_functions.py index 71fe29dfd..4d96a507b 100644 --- a/lambda/configuration/lambda_functions.py +++ b/lambda/configuration/lambda_functions.py @@ -21,7 +21,6 @@ from typing import Any, Dict import boto3 -import create_env_variables # noqa: F401 from botocore.exceptions import ClientError from mcp_server.models import McpServerModel from mcp_workbench.lambda_functions import MCPWORKBENCH_UUID diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index bf0ccd79b..98e4f51d1 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -99,7 +99,7 @@ class LoadBalancerConfig(BaseModel): class AutoScalingConfig(BaseModel): """Specifies auto-scaling parameters for model deployment.""" - blockDeviceVolumeSize: Optional[NonNegativeInt] = 30 + blockDeviceVolumeSize: Optional[NonNegativeInt] = 50 minCapacity: NonNegativeInt maxCapacity: NonNegativeInt cooldown: PositiveInt diff --git a/lambda/models/model_api_key_cleanup.py b/lambda/models/model_api_key_cleanup.py new file mode 100644 index 000000000..84a750670 --- /dev/null +++ b/lambda/models/model_api_key_cleanup.py @@ -0,0 +1,321 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model API Key Cleanup Lambda + +This Lambda function removes the api_key field from existing Bedrock models +that were created with the old LiteLLM version that required api_key = "ignored". # pragma: allowlist secret +This fixes "Invalid API Key format" errors for Bedrock models that don't need API keys. + +Only models with modelName prefixed with "bedrock/" are processed. + +The cleanup runs automatically during CDK deployment via a CustomResource. +""" + +import json +import os +import sys +from typing import Any, Dict, List + +import boto3 +import psycopg2 +from utilities.common_functions import retry_config + +# Add the lambda directory to the Python path +sys.path.append("/opt/python") +sys.path.append("/var/task") + + +def get_all_dynamodb_models() -> List[Dict[str, str]]: + """Get all models from DynamoDB table with their IDs and names.""" + try: + dynamodb = boto3.client("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) + ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + # Get model table name from SSM parameter + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX") + if not deployment_prefix: + raise ValueError("DEPLOYMENT_PREFIX environment variable not set") + model_table_param = f"{deployment_prefix}/modelTableName" + + try: + table_name_response = ssm_client.get_parameter(Name=model_table_param) + table_name = table_name_response["Parameter"]["Value"] + + if not table_name: + raise ValueError("Empty table name returned from SSM") + + except Exception as e: + print(f"Could not get model table name from SSM: {e}") + return [] + + # Scan the entire DynamoDB table to get all models + response = dynamodb.scan(TableName=table_name) + + models = [] + for item in response.get("Items", []): + # Extract model_id and modelName from the item with proper validation + model_id = item.get("model_id", {}).get("S", "") + model_config = item.get("model_config", {}) + model_name = "" + + if "M" in model_config and "modelName" in model_config["M"]: + model_name = model_config["M"]["modelName"].get("S", "") + + # Only include models with both ID and name + if model_id and model_name: + models.append({"model_id": model_id, "model_name": model_name}) + + print(f"Found {len(models)} models in DynamoDB") + return models + + except Exception as e: + print(f"Error scanning DynamoDB table: {e}") + return [] + + +def get_database_connection(): + """Get database connection using connection info from SSM.""" + ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + # Get database connection info from SSM using environment variable + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX") + if not deployment_prefix: + raise ValueError("DEPLOYMENT_PREFIX environment variable not set") + db_param_name = f"{deployment_prefix}/LiteLLMDbConnectionInfo" + + try: + db_param_response = ssm_client.get_parameter(Name=db_param_name, WithDecryption=True) + db_params = json.loads(db_param_response["Parameter"]["Value"]) + except Exception as e: + raise ValueError(f"Failed to get database connection info from SSM: {e}") + + # Get database credentials from Secrets Manager + try: + secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) + secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) + secret = json.loads(secret_response["SecretString"]) + except Exception as e: + raise ValueError(f"Failed to get database credentials from Secrets Manager: {e}") + + # Validate required parameters + required_params = ["dbHost", "dbPort", "dbName", "username"] + for param in required_params: + if param not in db_params: + raise ValueError(f"Missing required database parameter: {param}") + + if "password" not in secret: + raise ValueError("Missing password in secret") + + # Create connection with proper error handling + try: + conn = psycopg2.connect( + host=db_params["dbHost"], + port=db_params["dbPort"], + database=db_params["dbName"], + user=db_params["username"], + password=secret["password"], + ) + return conn + except Exception as e: + raise ValueError(f"Failed to connect to database: {e}") + + +def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Lambda handler for Bedrock model API key cleanup. + + Only processes models with modelName prefixed with "bedrock/". + + Args: + event: CloudFormation CustomResource event + context: Lambda context + + Returns: + CloudFormation CustomResource response + """ + print("Starting Bedrock model API key cleanup...") + + # Validate environment variables + required_env_vars = ["AWS_REGION", "DEPLOYMENT_PREFIX"] + for env_var in required_env_vars: + if not os.environ.get(env_var): + error_msg = f"Missing required environment variable: {env_var}" + print(error_msg) + return {"Status": "FAILED", "PhysicalResourceId": "bedrock-auth-cleanup", "Reason": error_msg} + + conn = None + cursor = None + + try: + # Get database connection + conn = get_database_connection() + cursor = conn.cursor() + + # First, let's see what tables exist in the database + cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") + tables = cursor.fetchall() + print(f"Available tables in database: {[table[0] for table in tables]}") + + # Try to find the correct LiteLLM model table name + litellm_table = None + for table in tables: + table_name = table[0] + if "proxymodel" in table_name.lower() or table_name == "LiteLLM_ProxyModelTable": + litellm_table = table_name + print(f"Found LiteLLM model table: {table_name}") + break + + if not litellm_table: + print("No LiteLLM model table found in database. Database might not be initialized yet.") + print("Bedrock model cleanup completed! 0 Bedrock models updated (no LiteLLM tables found)") + # Return success response for CloudFormation CustomResource + return {"Status": "SUCCESS", "PhysicalResourceId": "bedrock-auth-cleanup", "Data": {"ModelsUpdated": "0"}} + + # Query all models from the LiteLLM database using the found table (use quotes for case-sensitive names) + cursor.execute(f'SELECT * FROM "{litellm_table}" LIMIT 1') + columns = [desc[0] for desc in cursor.description] + print(f"Table {litellm_table} columns: {columns}") + + # Try to find the correct column names + model_id_col = next((col for col in columns if "id" in col.lower()), None) + model_name_col = next((col for col in columns if "name" in col.lower()), None) + litellm_params_col = next((col for col in columns if "param" in col.lower()), None) + + if not all([model_id_col, model_name_col, litellm_params_col]): + print(f"Could not find required columns in table {litellm_table}") + print(f" Available columns: {columns}") + print("Bedrock model cleanup completed! 0 Bedrock models updated (LiteLLM table structure unknown)") + # Return success response for CloudFormation CustomResource + return {"Status": "SUCCESS", "PhysicalResourceId": "bedrock-auth-cleanup", "Data": {"ModelsUpdated": "0"}} + + # Query all models from the LiteLLM database + cursor.execute(f'SELECT "{model_id_col}", "{model_name_col}", "{litellm_params_col}" FROM "{litellm_table}"') + models = cursor.fetchall() + + print(f"Found {len(models)} total models in LiteLLM database") + + # Get all models from DynamoDB and check if they exist in LiteLLM + dynamodb_models = get_all_dynamodb_models() + bedrock_models_processed = 0 + + for dynamodb_model in dynamodb_models: + dynamodb_model_id = dynamodb_model["model_id"] + dynamodb_model_name = dynamodb_model["model_name"] + + # Check if this is a Bedrock model + if not dynamodb_model_name.startswith("bedrock/"): + continue + + print(f"Processing Bedrock model: {dynamodb_model_name}") + + # Find the corresponding LiteLLM model by matching the model_name (alias) + # DynamoDB model_id is actually the alias, LiteLLM model_name is the alias + matching_litellm_model = None + for model_id, model_name, litellm_params_data in models: + # Check if this LiteLLM model_name matches our DynamoDB model_id (which is the alias) + if model_name == dynamodb_model_id: + try: + # Handle both dict and JSON string formats + if isinstance(litellm_params_data, dict): + litellm_params = litellm_params_data + elif isinstance(litellm_params_data, str): + litellm_params = json.loads(litellm_params_data) if litellm_params_data else {} + else: + litellm_params = {} + except json.JSONDecodeError: + continue + + matching_litellm_model = { + "model_id": model_id, + "model_name": model_name, + "litellm_params": litellm_params, + } + break + + if not matching_litellm_model: + print(f"No matching LiteLLM model found for: {dynamodb_model_name}") + continue + + # Check if this model has an api_key to remove + if "api_key" in matching_litellm_model["litellm_params"]: + print(f"Removing api_key from: {matching_litellm_model['model_name']}") + + try: + # Remove api_key from litellm_params + clean_params = matching_litellm_model["litellm_params"].copy() + if "api_key" in clean_params: # pragma: allowlist secret + del clean_params["api_key"] + + # Update the model in the database + clean_params_json = json.dumps(clean_params) + cursor.execute( + f'UPDATE "{litellm_table}" SET "{litellm_params_col}" = %s WHERE "{model_id_col}" = %s', + (clean_params_json, matching_litellm_model["model_id"]), + ) + print(f"Successfully cleaned model: {matching_litellm_model['model_name']}") + bedrock_models_processed += 1 + + except Exception as e: + print(f"Error cleaning model {matching_litellm_model['model_name']}: {e}") + conn.rollback() + else: + print(f"Model {matching_litellm_model['model_name']} already clean") + + # Commit the changes + conn.commit() + print(f"Cleanup completed! {bedrock_models_processed} Bedrock models processed") + + # Return success response for CloudFormation CustomResource + return { + "Status": "SUCCESS", + "PhysicalResourceId": "bedrock-auth-cleanup", + "Data": {"ModelsUpdated": str(bedrock_models_processed)}, + } + + except ValueError as e: + # Handle configuration/validation errors + print(f"Configuration error: {e}") + return {"Status": "FAILED", "PhysicalResourceId": "bedrock-auth-cleanup", "Reason": str(e)} + + except Exception as e: + # Handle unexpected errors + print(f"Cleanup failed: {e}") + import traceback + + print(f"Traceback: {traceback.format_exc()}") + + # Rollback any pending database changes + if conn: + try: + conn.rollback() + except Exception as rollback_error: + print(f"Failed to rollback database changes: {rollback_error}") + + return {"Status": "FAILED", "PhysicalResourceId": "bedrock-auth-cleanup", "Reason": str(e)} + + finally: + # Ensure proper cleanup of database resources + if cursor: + try: + cursor.close() + except Exception as e: + print(f"Error closing cursor: {e}") + + if conn: + try: + conn.close() + except Exception as e: + print(f"Error closing database connection: {e}") diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 769acbaa7..b757c0238 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -30,7 +30,12 @@ StackFailedToCreateException, UnexpectedCloudFormationStateException, ) -from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config +from utilities.common_functions import ( + get_account_and_partition, + get_cert_path, + get_rest_api_container_endpoint, + retry_config, +) logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -236,15 +241,10 @@ def camelize_object(o): # type: ignore[no-untyped-def] # Handle ECR images differently - use the existing ECR image instead of the built one if event["image_info"].get("image_type") == "ecr": # For pre-existing ECR images, construct the ARN using the image repository - account_id = os.environ.get("AWS_ACCOUNT_ID", "") - if not account_id: - # Try to get account ID from the existing ECR repository ARN - ecr_repo_arn = os.environ.get("ECR_REPOSITORY_ARN", "") - if ecr_repo_arn: - account_id = ecr_repo_arn.split(":")[4] + account_id, partition = get_account_and_partition() repository_arn = ( - f"arn:aws:ecr:{os.environ['AWS_REGION']}:{account_id}:repository/{event['image_info']['image_uri']}" + f"arn:{partition}:ecr:{os.environ['AWS_REGION']}:{account_id}:repository/{event['image_info']['image_uri']}" ) prepared_event["containerConfig"]["image"] = { "repositoryArn": repository_arn, @@ -269,11 +269,17 @@ def camelize_object(o): # type: ignore[no-untyped-def] stack_name = payload.get("stackName", None) if not stack_name: + # Log the full payload for debugging + logger.error(f"ECS Model Deployer response: {payload}") + error_message = payload.get("errorMessage", "Unknown error") + error_type = payload.get("errorType", "Unknown error type") + raise StackFailedToCreateException( json.dumps( { - "error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.", + "error": f"Failed to create Model CloudFormation Stack. {error_type}: {error_message}", "event": event, + "deployer_response": payload, } ) ) @@ -351,9 +357,9 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str # Fallback to default if JSON parsing fails litellm_params = {} - litellm_params["api_key"] = event.get( - "apiKey", "ignored" - ) # pragma: allowlist-secret not a real key, but needed for LiteLLM to be happy + # Only set api_key if it's present in the event + if "apiKey" in event: + litellm_params["api_key"] = event["apiKey"] # pragma: allowlist-secret litellm_params["drop_params"] = True # drop unrecognized param instead of failing the request on it if is_lisa_managed: @@ -368,7 +374,18 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str litellm_params=litellm_params, ) - litellm_id = litellm_response["model_info"]["id"] + # Handle different LiteLLM API response structures + if "model_info" in litellm_response and "id" in litellm_response["model_info"]: + litellm_id = litellm_response["model_info"]["id"] + elif "id" in litellm_response: + litellm_id = litellm_response["id"] + elif "model_id" in litellm_response: + litellm_id = litellm_response["model_id"] + else: + # Log the actual response structure for debugging + logger.error(f"Unexpected LiteLLM response structure: {litellm_response}") + raise KeyError(f"Could not find model ID in LiteLLM response: {litellm_response}") + output_dict["litellm_id"] = litellm_id model_table.update_item( diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 8965304ab..7479cc49a 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -391,7 +391,6 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: litellm_params["model"] = f"openai/{ddb_item['model_config']['modelName']}" litellm_params["api_base"] = model_url - litellm_params["api_key"] = "ignored" # pragma: allowlist-secret not a real key, but needed for LiteLLM to be happy ddb_update_expression = "SET model_status = :ms, last_modified_date = :lm" ddb_update_values: Dict[str, Any] = { diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py index 9f4863baf..c36186be7 100644 --- a/lambda/repository/embeddings.py +++ b/lambda/repository/embeddings.py @@ -14,13 +14,14 @@ import logging import os -from typing import Any, List +from typing import List import boto3 import requests -from lisapy.langchain import LisaOpenAIEmbeddings -from utilities.common_functions import get_cert_path, retry_config -from utilities.validation import ValidationError +from pydantic import BaseModel, field_validator +from utilities.auth import get_management_key +from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config +from utilities.validation import validate_model_name, ValidationError logger = logging.getLogger(__name__) ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -30,39 +31,47 @@ lisa_api_endpoint = "" -class PipelineEmbeddings: +class RagEmbeddings(BaseModel): """ - Handles document embeddings for pipeline processing using management credentials. - - This class provides methods to embed both single queries and batches of documents - using the LISA API with management-level authentication. + Handles document embeddings through LiteLLM using management credentials. """ model_name: str - - def __init__(self, model_name: str) -> None: + token: str + lisa_api_endpoint: str + base_url: str + cert_path: str | bool + + @field_validator("model_name") + @classmethod + def validate_model_name(cls, v: str) -> str: + validate_model_name(v) + return v + + def __init__(self, model_name: str, id_token: str | None = None, **data) -> None: + # Prepare initialization data + init_data = {"model_name": model_name, **data} try: - self.model_name = model_name - # Get the management key secret name from SSM Parameter Store - secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) - secret_name = secret_name_param["Parameter"]["Value"] - - # Get the management token from Secrets Manager using the secret name - secret_response = secrets_client.get_secret_value(SecretId=secret_name) - self.token = secret_response["SecretString"] - - # Get the API endpoint from SSM - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" - - # Get certificate path for SSL verification - self.cert_path = get_cert_path(iam_client) - + # Use management token if id_token is not provided + if id_token is None: + logger.info("Using management key for ingestion") + init_data["token"] = get_management_key() + else: + init_data["token"] = id_token + + init_data["lisa_api_endpoint"] = get_rest_api_container_endpoint() + init_data["base_url"] = get_rest_api_container_endpoint() + init_data["cert_path"] = get_cert_path(iam_client) + + super().__init__(**init_data) logger.info("Successfully initialized pipeline embeddings") except Exception: logger.error("Failed to initialize pipeline embeddings", exc_info=True) raise + class Config: + arbitrary_types_allowed = True + def embed_documents(self, texts: List[str]) -> List[List[float]]: """ Generate embeddings for a list of documents. @@ -88,14 +97,13 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: response = requests.post( url, json=request_data, - headers={"Authorization": self.token, "Content-Type": "application/json"}, + headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}, verify=self.cert_path, # Use proper SSL verification timeout=300, # 5 minute timeout ) if response.status_code != 200: logger.error(f"Embedding request failed with status {response.status_code}") - logger.error(f"Response content: {response.text}") raise Exception(f"Embedding request failed with status {response.status_code}") result = response.json() @@ -148,45 +156,3 @@ def embed_query(self, text: str) -> List[float]: logger.info("Embedding single query text") return self.embed_documents([text])[0] - - -def get_embeddings_pipeline(model_name: str) -> Any: - """ - Get embeddings for pipeline requests using management token. - - Args: - model_name: Name of the embedding model to use - - Raises: - ValidationError: If model name is invalid - Exception: If API request fails - """ - logger.info("Starting pipeline embeddings request") - - return PipelineEmbeddings(model_name=model_name) - - -def get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: - """ - Initialize and return an embeddings client for the specified model. - - Args: - model_name: Name of the embedding model to use - id_token: Authentication token for API access - - Returns: - LisaOpenAIEmbeddings: Configured embeddings client - """ - global lisa_api_endpoint - - if not lisa_api_endpoint: - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] - - base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" - cert_path = get_cert_path(iam_client) - - embedding = LisaOpenAIEmbeddings( - lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path - ) - return embedding diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index 9422b81a9..a20dc0f43 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -23,7 +23,7 @@ from boto3.dynamodb.types import TypeSerializer from botocore.config import Config from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, RagDocument -from repository.embeddings import get_embeddings +from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.rag_document_repo import RagDocumentRepository @@ -40,7 +40,6 @@ region_name = os.environ["AWS_REGION"] session = boto3.Session() ssm_client = boto3.client("ssm", region_name, config=retry_config) -secrets_client = boto3.client("secretsmanager", region_name, config=retry_config) iam_client = boto3.client("iam", region_name, config=retry_config) step_functions_client = boto3.client("stepfunctions", region_name, config=retry_config) ddb_client = boto3.client("dynamodb", region_name, config=retry_config) @@ -109,6 +108,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: - queryStringParameters.query: Search query text - queryStringParameters.repositoryType: Type of repository - queryStringParameters.topK (optional): Number of results to return (default: 3) + - queryStringParameters.score (optional): Include similarity scores (default: false) context (dict): The Lambda context object Returns: @@ -122,6 +122,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: model_name = query_string_params["modelName"] query = query_string_params["query"] top_k = query_string_params.get("topK", 3) + include_score = query_string_params.get("score", "false").lower() == "true" repository_id = event["pathParameters"]["repositoryId"] repository = vs_repo.find_repository_by_id(repository_id) @@ -139,7 +140,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: repository_id=repository_id, ) else: - embeddings = get_embeddings(model_name=model_name, id_token=id_token) + embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) # empty vector stores do not have an initialize index. Return empty docs @@ -148,11 +149,11 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: ): logger.info(f"Index {model_name} does not exist. Returning empty docs.") else: - results = vs.similarity_search( - query, - k=top_k, + docs = ( + _similarity_search_with_score(vs, query, top_k, repository) + if include_score + else _similarity_search(vs, query, top_k) ) - docs = [{"page_content": r.page_content, "metadata": r.metadata} for r in results] doc_content = [ { "Document": { @@ -536,6 +537,48 @@ def delete(event: dict, context: dict) -> Any: return {"status": "success", "executionArn": response["executionArn"]} +@api_wrapper +@admin_only +def delete_index(event: dict, context: dict) -> None: + """ + Clear the vector store for the specified repository and model. + + Args: + event (dict): The Lambda event object containing path parameters + context (dict): The Lambda context object + """ + path_params = event.get("pathParameters", {}) or {} + repository_id = path_params.get("repositoryId", None) + if not repository_id: + raise ValidationError("repositoryId is required") + model_name = path_params.get("modelName", None) + if not model_name: + raise ValidationError("modelName is required") + + repository = vs_repo.find_repository_by_id(repository_id=repository_id) + id_token = get_id_token(event) + embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) + vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) + + try: + if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH): + if vs.client.indices.exists(index=model_name): + vs.client.indices.delete(index=model_name) + logger.info(f"Deleted OpenSearch index: {model_name}") + else: + logger.info(f"OpenSearch index {model_name} does not exist") + elif RepositoryType.is_type(repository, RepositoryType.PGVECTOR): + # For PGVector, delete all documents in the collection + vs.delete_collection() + logger.info(f"Deleted PGVector collection: {model_name}") + else: + logger.error(f"Unsupported repository type: {repository.get('type')}") + return {"status": "error", "message": "Repository is not supported"} + except Exception as e: + logger.error(f"Failed to clear vector store: {e}") + return {"status": "error", "message": str(e)} + + def _remove_legacy(repository_id: str) -> None: registered_repositories = ssm_client.get_parameter(Name=os.environ["REGISTERED_REPOSITORIES_PS"]) registered_repositories = json.loads(registered_repositories["Parameter"]["Value"]) @@ -549,3 +592,59 @@ def _remove_legacy(repository_id: str) -> None: Type="String", Overwrite=True, ) + + +def _similarity_search(vs, query: str, top_k: int) -> list[dict[str, Any]]: + """Perform similarity search without scores. + + Args: + vs: Vector store instance + query: Search query string + top_k: Number of top results to return + + Returns: + List of documents with page_content and metadata + """ + results = vs.similarity_search_with_score( + query, + k=top_k, + ) + + return [{"page_content": doc.page_content, "metadata": doc.metadata} for doc, score in results] + + +def _similarity_search_with_score(vs, query: str, top_k: int, repository: dict) -> list[dict[str, Any]]: + """Perform similarity search with normalized scores. + + Args: + vs: Vector store instance + query: Search query string + top_k: Number of top results to return + repository: Repository configuration dict + + Returns: + List of documents with page_content, metadata, and similarity_score + """ + results = vs.similarity_search_with_score( + query, + k=top_k, + ) + docs = [] + for i, (doc, score) in enumerate(results): + similarity_score = RepositoryType.get_type(repository=repository).calculate_similarity_score(score) + logger.info( + f"Result {i + 1}: Raw Score={score:.4f}, Similarity={similarity_score:.4f}, " + + f"Content: {doc.page_content[:200]}..." + ) + logger.info(f"Result {i + 1} metadata: {doc.metadata}") + docs.append( + { + "page_content": doc.page_content, + "metadata": {**doc.metadata, "similarity_score": similarity_score}, + } + ) + + if results and max(score for _, score in results) < 0.3: + logger.warning(f"All similarity < 0.3 for query '{query}' - possible embedding model mismatch") + + return docs diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 83598ef9a..2b82080e2 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -21,7 +21,7 @@ import boto3 from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, IngestionType, RagDocument -from repository.embeddings import get_embeddings_pipeline +from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.rag_document_repo import RagDocumentRepository @@ -48,10 +48,12 @@ def pipeline_ingest(job: IngestionJob) -> None: + texts = [] + metadatas = [] + all_ids = [] try: # chunk and save chunks in vector store repository = vs_repo.find_repository_by_id(job.repository_id) - all_ids = [] if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): ingest_document_to_kb( s3_client=s3, @@ -104,15 +106,15 @@ def pipeline_ingest(job: IngestionJob) -> None: logging.info(f"Successfully ingested document {job.s3_path} ({len(all_ids)} chunks) into {job.collection_id}") except Exception as e: ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) - error_msg = f"Failed to process document: {str(e)}" logger.error(error_msg, exc_info=True) + logger.error(f"Job: {job.model_dump_json(indent=2)}") raise Exception(error_msg) def remove_document_from_vectorstore(doc: RagDocument) -> None: # Delete from the Vector Store - embeddings = get_embeddings_pipeline(model_name=doc.collection_id) + embeddings = RagEmbeddings(model_name=doc.collection_id) vector_store = get_vector_store_client( doc.repository_id, index=doc.collection_id, @@ -280,7 +282,7 @@ def store_chunks_in_vectorstore( texts: List[str], metadatas: List[Dict], repository_id: str, embedding_model: str ) -> List[str]: """Store document chunks in vector store.""" - embeddings = get_embeddings_pipeline(model_name=embedding_model) + embeddings = RagEmbeddings(model_name=embedding_model) vs = get_vector_store_client( repository_id, index=embedding_model, diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index 68fd4648e..0971e00f1 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -21,14 +21,16 @@ from concurrent.futures import ThreadPoolExecutor from datetime import datetime from decimal import Decimal -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import boto3 import create_env_variables # noqa: F401 from botocore.exceptions import ClientError +from cachetools import cached, TTLCache from utilities.auth import get_username from utilities.common_functions import api_wrapper, get_groups, get_session_id, retry_config from utilities.encoders import convert_decimal +from utilities.session_encryption import decrypt_session_fields, migrate_session_to_encrypted, SessionEncryptionError logger = logging.getLogger(__name__) @@ -42,8 +44,55 @@ # Get model table for real-time feature validation model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME")) +# Get configuration table for system settings +config_table = dynamodb.Table(os.environ["CONFIG_TABLE_NAME"]) + executor = ThreadPoolExecutor(max_workers=10) +# Cache for configuration values to avoid repeated database queries +cache = TTLCache(maxsize=1, ttl=300) # 5 minutes + + +@cached(cache=cache) +def _is_session_encryption_enabled() -> bool: + """Check if session encryption is enabled via global configuration. + + Returns + ------- + bool + True if session encryption is enabled, False otherwise. + Defaults to False if configuration is not found or accessible. + """ + + try: + logger.debug("Querying global configuration for session encryption setting") + # Query the global configuration entry + response = config_table.query( + KeyConditionExpression="configScope = :scope", + ExpressionAttributeValues={":scope": "global"}, + ScanIndexForward=False, + Limit=1, + ) + + items = response.get("Items", []) + if items: + config_item = items[0] + configuration = config_item.get("configuration", {}) + enabled_components = configuration.get("enabledComponents", {}) + encrypt_session = enabled_components.get("encryptSession", False) # Default to False + logger.info(f"Retrieved session encryption setting from global config: {encrypt_session}") + return encrypt_session + else: + logger.warning("No global configuration found, defaulting session encryption to disabled") + return False + + except ClientError as error: + logger.error(f"Failed to query global configuration: {error}, defaulting to encryption disabled") + return False + except Exception as e: + logger.error(f"Error checking session encryption configuration: {e}, defaulting to disabled") + return False + def _get_current_model_config(model_id: str) -> Any: """Get the current model configuration from the model table. @@ -199,20 +248,41 @@ def _generate_presigned_image_url(key: str) -> str: return url -def _map_session(session: dict) -> Dict[str, Any]: +def _map_session(session: dict, user_id: Optional[str] = None) -> Dict[str, Any]: return { "sessionId": session.get("sessionId", None), "name": session.get("name", None), - "firstHumanMessage": _find_first_human_message(session), + "firstHumanMessage": _find_first_human_message(session, user_id), "startTime": session.get("startTime", None), "createTime": session.get("createTime", None), "lastUpdated": session.get( "lastUpdated", session.get("startTime", None) ), # Fallback to startTime for backward compatibility + "isEncrypted": session.get("is_encrypted", False), } -def _find_first_human_message(session: dict) -> str: +def _find_first_human_message(session: dict, user_id: Optional[str] = None) -> str: + # Check if session is encrypted + if session.get("is_encrypted", False): + # For encrypted sessions, decrypt to get the first message + try: + if user_id: + logging.info( + f"Decrypting encrypted session {session.get('sessionId', 'unknown')} " + f"to find first message for user {user_id}" + ) + decrypted_session = decrypt_session_fields(session, user_id, session.get("sessionId", "")) + # Use the decrypted session for finding the first message + session = decrypted_session + else: + # If no user_id provided, return placeholder + return "[Encrypted Session - User ID required]" + except SessionEncryptionError as e: + logging.error(f"Failed to decrypt session {session.get('sessionId', 'unknown')} to find first message: {e}") + return "[Encrypted Session - Decryption failed]" + + # For unencrypted sessions (or successfully decrypted sessions), proceed as before for msg in session.get("history", []): if msg.get("type") == "human": content = msg.get("content") @@ -225,7 +295,7 @@ def _find_first_human_message(session: dict) -> str: if text and not text.startswith("File context:"): return text else: - logger.warning(f"Unhandled human message content in session {session['sessionId']}") + logger.warning(f"Unhandled human message content in session {session.get('sessionId', 'unknown')}") return "" @@ -237,7 +307,7 @@ def list_sessions(event: dict, context: dict) -> List[Dict[str, Any]]: logger.info(f"Listing sessions for user {user_id}") sessions = _get_all_user_sessions(user_id) - return list(executor.map(_map_session, sessions)) + return list(executor.map(lambda session: _map_session(session, user_id), sessions)) def _process_image(task: Tuple[dict, str]) -> None: @@ -261,6 +331,18 @@ def get_session(event: dict, context: dict) -> dict: response = table.get_item(Key={"sessionId": session_id, "userId": user_id}) resp = response.get("Item", {}) + if not resp: + return {"statusCode": 404, "body": json.dumps({"error": "Session not found"})} + + # Check if session data is encrypted and decrypt if necessary + try: + if resp.get("is_encrypted", False): + logging.info(f"Decrypting encrypted session {session_id} for user {user_id}") + resp = decrypt_session_fields(resp, user_id, session_id) + except SessionEncryptionError as e: + logging.error(f"Failed to decrypt session {session_id}: {e}") + return {"statusCode": 500, "body": json.dumps({"error": "Failed to decrypt session data"})} + # Update configuration with current model settings before returning if resp and resp.get("configuration"): configuration = resp.get("configuration", {}) @@ -409,7 +491,86 @@ def put_session(event: dict, context: dict) -> dict: updated_temp_config = _update_session_with_current_model_config(temp_config) configuration["selectedModel"] = updated_temp_config.get("selectedModel", configuration["selectedModel"]) - # Publish event to SQS queue for metrics processing + # Check if encryption is enabled via configuration table + encryption_enabled = _is_session_encryption_enabled() + + # Prepare session data for storage + session_data = { + "history": messages, + "name": body.get("name", None), + "configuration": configuration, + "startTime": datetime.now().isoformat(), + "createTime": datetime.now().isoformat(), + "lastUpdated": datetime.now().isoformat(), + } + + # Encrypt sensitive data if encryption is enabled + if encryption_enabled: + try: + logging.info(f"Encrypting session {session_id} for user {user_id}") + encrypted_session = migrate_session_to_encrypted(session_data, user_id, session_id) + + # Update DynamoDB with encrypted data + table.update_item( + Key={"sessionId": session_id, "userId": user_id}, + UpdateExpression="SET #encrypted_history = :encrypted_history, #name = :name, " + + "#encrypted_configuration = :encrypted_configuration, #startTime = :startTime, " + + "#createTime = if_not_exists(#createTime, :createTime), #lastUpdated = :lastUpdated, " + + "#encryption_version = :encryption_version, #is_encrypted = :is_encrypted", + ExpressionAttributeNames={ + "#encrypted_history": "encrypted_history", + "#name": "name", + "#encrypted_configuration": "encrypted_configuration", + "#startTime": "startTime", + "#createTime": "createTime", + "#lastUpdated": "lastUpdated", + "#encryption_version": "encryption_version", + "#is_encrypted": "is_encrypted", + }, + ExpressionAttributeValues={ + ":encrypted_history": encrypted_session["encrypted_history"], + ":name": encrypted_session["name"], + ":encrypted_configuration": encrypted_session["encrypted_configuration"], + ":startTime": encrypted_session["startTime"], + ":createTime": encrypted_session["createTime"], + ":lastUpdated": encrypted_session["lastUpdated"], + ":encryption_version": encrypted_session["encryption_version"], + ":is_encrypted": encrypted_session["is_encrypted"], + }, + ReturnValues="UPDATED_NEW", + ) + except SessionEncryptionError as e: + logging.error(f"Failed to encrypt session {session_id}: {e}") + return {"statusCode": 500, "body": json.dumps({"error": "Failed to encrypt session data"})} + else: + # Store unencrypted data (legacy mode) + table.update_item( + Key={"sessionId": session_id, "userId": user_id}, + UpdateExpression="SET #history = :history, #name = :name, #configuration = :configuration, " + + "#startTime = :startTime, #createTime = if_not_exists(#createTime, :createTime), " + + "#lastUpdated = :lastUpdated, #is_encrypted = :is_encrypted", + ExpressionAttributeNames={ + "#history": "history", + "#name": "name", + "#configuration": "configuration", + "#startTime": "startTime", + "#createTime": "createTime", + "#lastUpdated": "lastUpdated", + "#is_encrypted": "is_encrypted", + }, + ExpressionAttributeValues={ + ":history": messages, + ":name": body.get("name", None), + ":configuration": configuration, + ":startTime": datetime.now().isoformat(), + ":createTime": datetime.now().isoformat(), + ":lastUpdated": datetime.now().isoformat(), + ":is_encrypted": False, + }, + ReturnValues="UPDATED_NEW", + ) + + # Publish event to SQS queue for metrics processing (use unencrypted data for metrics) try: if "USAGE_METRICS_QUEUE_NAME" in os.environ: # Create a copy of the event to send to SQS @@ -430,29 +591,6 @@ def put_session(event: dict, context: dict) -> dict: except Exception as e: logger.error(f"Failed to publish to metrics queue: {e}") - table.update_item( - Key={"sessionId": session_id, "userId": user_id}, - UpdateExpression="SET #history = :history, #name = :name, #configuration = :configuration, " - + "#startTime = :startTime, #createTime = if_not_exists(#createTime, :createTime), " - + "#lastUpdated = :lastUpdated", - ExpressionAttributeNames={ - "#history": "history", - "#name": "name", - "#configuration": "configuration", - "#startTime": "startTime", - "#createTime": "createTime", - "#lastUpdated": "lastUpdated", - }, - ExpressionAttributeValues={ - ":history": messages, - ":name": body.get("name", None), - ":configuration": configuration, - ":startTime": datetime.now().isoformat(), - ":createTime": datetime.now().isoformat(), - ":lastUpdated": datetime.now().isoformat(), - }, - ReturnValues="UPDATED_NEW", - ) return {"statusCode": 200, "body": json.dumps({"message": "Session updated successfully"})} except ValueError as e: return {"statusCode": 400, "body": json.dumps({"error": str(e)})} diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index c6f5f5bd6..d98466eb0 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -16,11 +16,23 @@ from functools import wraps from typing import Any, Callable, Dict +import boto3 +from botocore.config import Config from utilities.common_functions import get_groups from utilities.exceptions import HTTPException logger = logging.getLogger(__name__) +retry_config = Config( + retries={ + "max_attempts": 3, + "mode": "standard", + }, +) + +secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + def get_username(event: dict) -> str: """Get the username from the event.""" @@ -46,3 +58,10 @@ def wrapper(event: Dict[str, Any], context: Dict[str, Any], *args: Any, **kwargs return func(event, context, *args, **kwargs) return wrapper + + +def get_management_key() -> str: + secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) + secret_name = secret_name_param["Parameter"]["Value"] + secret_response = secrets_client.get_secret_value(SecretId=secret_name) + return secret_response["SecretString"] diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index def669737..c7e9079dc 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -519,3 +519,22 @@ def get_bearer_token(event, with_prefix: bool = True): # Return the token after "Bearer " return auth_header.split(" ", 1)[1].strip() + + +def get_account_and_partition() -> tuple[str, str]: + """Get AWS account ID and partition from environment or ECR repository ARN. + + Returns: + tuple[str, str]: (account_id, partition) + """ + account_id = os.environ.get("AWS_ACCOUNT_ID", "") + partition = os.environ.get("AWS_PARTITION", "aws") + + if not account_id: + ecr_repo_arn = os.environ.get("ECR_REPOSITORY_ARN", "") + if ecr_repo_arn: + arn_parts = ecr_repo_arn.split(":") + partition = arn_parts[1] + account_id = arn_parts[4] + + return account_id, partition diff --git a/lambda/utilities/constants.py b/lambda/utilities/constants.py index 0dd0b48f4..a263d03ff 100644 --- a/lambda/utilities/constants.py +++ b/lambda/utilities/constants.py @@ -16,3 +16,4 @@ PDF_FILE = "pdf" TEXT_FILE = "txt" DOCX_FILE = "docx" +RICH_TEXT_FILE = "rtf" diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index f2f9c6857..db04f85ce 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -27,7 +27,7 @@ from models.domain_objects import ChunkingStrategyType, IngestionJob from pypdf import PdfReader from pypdf.errors import PdfReadError -from utilities.constants import DOCX_FILE, PDF_FILE, TEXT_FILE +from utilities.constants import DOCX_FILE, PDF_FILE, RICH_TEXT_FILE, TEXT_FILE from utilities.exceptions import RagUploadException logger = logging.getLogger(__name__) @@ -47,12 +47,13 @@ def _extract_text_by_content_type(content_type: str, s3_object: dict) -> str: extraction_functions = { PDF_FILE: _extract_pdf_content, DOCX_FILE: _extract_docx_content, - TEXT_FILE: lambda obj: obj["Body"].read(), + TEXT_FILE: _extract_text_content, + RICH_TEXT_FILE: _extract_text_content, } extraction_function = extraction_functions.get(content_type) if extraction_function: - return str(extraction_function(s3_object)) + return extraction_function(s3_object) else: logger.error(f"File has unsupported content type: {content_type}") raise RagUploadException("Unsupported file type") @@ -126,6 +127,18 @@ def _extract_docx_content(s3_object: dict) -> str: return output +def _extract_text_content(s3_object: dict) -> str: + """ + Extracts text content from an S3 object. Decode as + utf-8 to properly read special characters + + Parameters + ---------- + s3_object (dict): an S3 object containing a text file body. + """ + return s3_object["Body"].read().decode("utf-8", errors="replace") + + def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: """Generate chunks from an ingestion job. diff --git a/lambda/utilities/repository_types.py b/lambda/utilities/repository_types.py index 2e9b054aa..0ec62cdd5 100644 --- a/lambda/utilities/repository_types.py +++ b/lambda/utilities/repository_types.py @@ -21,6 +21,19 @@ class RepositoryType(str, Enum): OPENSEARCH = "opensearch" BEDROCK_KB = "bedrock_knowledge_base" + @classmethod + def get_type(cls, repository: Dict[str, Any]) -> "RepositoryType": + return RepositoryType(repository.get("type")) + @classmethod def is_type(cls, repository: Dict[str, Any], repo_type: "RepositoryType") -> bool: - return repository.get("type") == repo_type.value + return repository.get("type") == repo_type + + def calculate_similarity_score(self, score: float) -> float: + # Convert cosine distance to similarity for PGVector + # PGVector returns cosine distance (0-2 range, lower = more similar) + # Convert to similarity (0-1 range, higher = more similar) + if self == RepositoryType.PGVECTOR: + return max(0.0, 1.0 - (score / 2.0)) + else: + return score diff --git a/lambda/utilities/session_encryption.py b/lambda/utilities/session_encryption.py new file mode 100644 index 000000000..3b57ef8b7 --- /dev/null +++ b/lambda/utilities/session_encryption.py @@ -0,0 +1,305 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for encrypting and decrypting session data.""" + +import base64 +import json +import logging +import os +from decimal import Decimal +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from cryptography.fernet import Fernet + +logger = logging.getLogger(__name__) + +# Initialize KMS client +kms_client = boto3.client("kms", region_name=os.environ.get("AWS_REGION", "us-east-1")) + + +class TypePreservingJSONEncoder(json.JSONEncoder): + """Custom JSON encoder that preserves numeric types.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, Decimal): + return float(obj) + return super().default(obj) + + +def _serialize_with_type_preservation(data: Any) -> str: + """Serialize data to JSON while preserving numeric types.""" + return json.dumps(data, cls=TypePreservingJSONEncoder) + + +def _deserialize_with_type_preservation(json_str: str) -> Any: + """Deserialize JSON string while preserving numeric types.""" + return json.loads(json_str, parse_float=float, parse_int=int) + + +class SessionEncryptionError(Exception): + """Custom exception for session encryption errors.""" + + pass + + +def _get_kms_key_arn() -> str: + """Get the KMS key ARN from environment variables.""" + key_arn = os.environ.get("SESSION_ENCRYPTION_KEY_ARN") + if not key_arn: + raise SessionEncryptionError("SESSION_ENCRYPTION_KEY_ARN environment variable not set") + return key_arn + + +def _generate_data_key(key_arn: str, encryption_context: Optional[Dict[str, str]] = None) -> tuple[bytes, bytes]: + """ + Generate a data key from KMS. + + Args: + key_arn: KMS key ARN + encryption_context: Optional encryption context + + Returns: + Tuple of (plaintext_data_key, encrypted_data_key) + """ + try: + response = kms_client.generate_data_key( + KeyId=key_arn, KeySpec="AES_256", EncryptionContext=encryption_context or {} + ) + return response["Plaintext"], response["CiphertextBlob"] + except ClientError as e: + logger.error(f"Failed to generate data key: {e}") + raise SessionEncryptionError(f"Failed to generate data key: {e}") + + +def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: Optional[Dict[str, str]] = None) -> bytes: + """ + Decrypt a data key using KMS. + + Args: + encrypted_data_key: Encrypted data key + encryption_context: Optional encryption context + + Returns: + Plaintext data key + """ + try: + response = kms_client.decrypt(CiphertextBlob=encrypted_data_key, EncryptionContext=encryption_context or {}) + return response["Plaintext"] # type: ignore + except ClientError as e: + logger.error(f"Failed to decrypt data key: {e}") + raise SessionEncryptionError(f"Failed to decrypt data key: {e}") + + +def _create_encryption_context(user_id: str, session_id: str) -> Dict[str, str]: + """ + Create encryption context for KMS operations. + + Args: + user_id: User ID + session_id: Session ID + + Returns: + Encryption context dictionary + """ + return {"userId": user_id, "sessionId": session_id, "purpose": "session-encryption"} + + +def encrypt_session_data(data: Any, user_id: str, session_id: str) -> str: + """ + Encrypt session data using KMS envelope encryption. + + Args: + data: Data to encrypt (will be JSON serialized) + user_id: User ID for encryption context + session_id: Session ID for encryption context + + Returns: + Base64 encoded string containing encrypted data key and encrypted data + """ + try: + # Create encryption context + encryption_context = _create_encryption_context(user_id, session_id) + + # Get KMS key ARN + key_arn = _get_kms_key_arn() + + # Generate data key + plaintext_key, encrypted_key = _generate_data_key(key_arn, encryption_context) + + # Serialize data to JSON while preserving numeric types + json_data = _serialize_with_type_preservation(data) + + # Encrypt data using Fernet (AES 128 in CBC mode with PKCS7 padding) + fernet = Fernet(base64.urlsafe_b64encode(plaintext_key[:32])) + encrypted_data = fernet.encrypt(json_data.encode("utf-8")) + + # Combine encrypted key and encrypted data + combined = { + "encrypted_key": base64.b64encode(encrypted_key).decode("utf-8"), + "encrypted_data": base64.b64encode(encrypted_data).decode("utf-8"), + "encryption_version": "1.0", + } + + return base64.b64encode(json.dumps(combined).encode("utf-8")).decode("utf-8") + + except Exception as e: + logger.error(f"Failed to encrypt session data: {e}") + raise SessionEncryptionError(f"Failed to encrypt session data: {e}") + + +def decrypt_session_data(encrypted_data: str, user_id: str, session_id: str) -> Any: + """ + Decrypt session data using KMS envelope encryption. + + Args: + encrypted_data: Base64 encoded encrypted data + user_id: User ID for encryption context + session_id: Session ID for encryption context + + Returns: + Decrypted and deserialized data + """ + try: + # Create encryption context + encryption_context = _create_encryption_context(user_id, session_id) + + # Decode the combined data + combined_json = base64.b64decode(encrypted_data).decode("utf-8") + combined = json.loads(combined_json) + + # Extract encrypted key and data + encrypted_key = base64.b64decode(combined["encrypted_key"]) + encrypted_data_bytes = base64.b64decode(combined["encrypted_data"]) + + # Decrypt the data key + plaintext_key = _decrypt_data_key(encrypted_key, encryption_context) + + # Decrypt the data + fernet = Fernet(base64.urlsafe_b64encode(plaintext_key[:32])) + decrypted_json = fernet.decrypt(encrypted_data_bytes).decode("utf-8") + + # Deserialize and return while preserving numeric types + return _deserialize_with_type_preservation(decrypted_json) + + except Exception as e: + logger.error(f"Failed to decrypt session data: {e}") + raise SessionEncryptionError(f"Failed to decrypt session data: {e}") + + +def is_encrypted_data(data: str) -> bool: + """ + Check if a string appears to be encrypted session data. + + Args: + data: String to check + + Returns: + True if data appears to be encrypted + """ + try: + # Try to decode as base64 + decoded = base64.b64decode(data).decode("utf-8") + parsed = json.loads(decoded) + + # Check if it has the expected structure + return ( + isinstance(parsed, dict) + and "encrypted_key" in parsed + and "encrypted_data" in parsed + and "encryption_version" in parsed + ) + except Exception: + return False + + +def migrate_session_to_encrypted(session_data: Dict[str, Any], user_id: str, session_id: str) -> Dict[str, Any]: + """ + Migrate a session from unencrypted to encrypted format. + + Args: + session_data: Session data dictionary + user_id: User ID + session_id: Session ID + + Returns: + Updated session data with encrypted fields + """ + try: + # Fields to encrypt + fields_to_encrypt = ["history", "configuration"] + + # Create a copy of the session data + encrypted_session = session_data.copy() + + # Encrypt sensitive fields + for field in fields_to_encrypt: + if field in session_data and session_data[field] is not None: + encrypted_value = encrypt_session_data(session_data[field], user_id, session_id) + encrypted_session[f"encrypted_{field}"] = encrypted_value + # Remove the unencrypted field + del encrypted_session[field] + + # Add encryption metadata + encrypted_session["encryption_version"] = "1.0" + encrypted_session["is_encrypted"] = True + + return encrypted_session + + except Exception as e: + logger.error(f"Failed to migrate session to encrypted: {e}") + raise SessionEncryptionError(f"Failed to migrate session to encrypted: {e}") + + +def decrypt_session_fields(session_data: Dict[str, Any], user_id: str, session_id: str) -> Dict[str, Any]: + """ + Decrypt encrypted fields in session data. + + Args: + session_data: Session data dictionary + user_id: User ID + session_id: Session ID + + Returns: + Session data with decrypted fields + """ + try: + # Fields that might be encrypted + encrypted_fields = ["encrypted_history", "encrypted_configuration"] + decrypted_session = session_data.copy() + + # Decrypt encrypted fields + for encrypted_field in encrypted_fields: + if encrypted_field in session_data and session_data[encrypted_field] is not None: + # Get the original field name + original_field = encrypted_field.replace("encrypted_", "") + + # Decrypt the data + decrypted_data = decrypt_session_data(session_data[encrypted_field], user_id, session_id) + decrypted_session[original_field] = decrypted_data + + # Remove the encrypted field + del decrypted_session[encrypted_field] + + # Remove encryption metadata + decrypted_session.pop("encryption_version", None) + decrypted_session.pop("is_encrypted", None) + + return decrypted_session + + except Exception as e: + logger.error(f"Failed to decrypt session fields: {e}") + raise SessionEncryptionError(f"Failed to decrypt session fields: {e}") diff --git a/lambda/utilities/vector_store.py b/lambda/utilities/vector_store.py index 46370f400..5634f47a0 100644 --- a/lambda/utilities/vector_store.py +++ b/lambda/utilities/vector_store.py @@ -18,8 +18,7 @@ import os import boto3 -from langchain_community.vectorstores.opensearch_vector_search import OpenSearchVectorSearch -from langchain_community.vectorstores.pgvector import PGVector +from langchain_community.vectorstores import OpenSearchVectorSearch, PGVector from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from opensearchpy import RequestsHttpConnection @@ -61,7 +60,7 @@ def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddin return OpenSearchVectorSearch( opensearch_url=opensearch_endpoint, - index_name=index.lower(), + index_name=index, embedding_function=embeddings, http_auth=auth, timeout=300, diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index 6bdc6276b..3f2a533df 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -92,20 +92,16 @@ export class FastApiContainer extends Construct { AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM THREADS: Ec2Metadata.get('m5.large').vCpus.toString(), LITELLM_KEY: config.litellmConfig.db_key, - TIKTOKEN_CACHE_DIR: '/app/TIKTOKEN_CACHE' + OPENAI_API_KEY: config.litellmConfig.db_key, + TIKTOKEN_CACHE_DIR: '/app/TIKTOKEN_CACHE', + USE_AUTH: 'true', + AUTHORITY: config.authConfig!.authority, + CLIENT_ID: config.authConfig!.clientId, + ADMIN_GROUP: config.authConfig!.adminGroup, + USER_GROUP: config.authConfig!.userGroup, + JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, }; - if (config.restApiConfig.internetFacing) { - baseEnvironment.USE_AUTH = 'true'; - baseEnvironment.AUTHORITY = config.authConfig!.authority; - baseEnvironment.CLIENT_ID = config.authConfig!.clientId; - baseEnvironment.ADMIN_GROUP = config.authConfig!.adminGroup; - baseEnvironment.USER_GROUP = config.authConfig!.userGroup; - baseEnvironment.JWT_GROUPS_PROP = config.authConfig!.jwtGroupsProperty; - } else { - baseEnvironment.USE_AUTH = 'false'; - } - if (tokenTable) { baseEnvironment.TOKEN_TABLE_NAME = tokenTable.tableName; } @@ -140,7 +136,7 @@ export class FastApiContainer extends Construct { const ecsConfig: ECSConfig = { amiHardwareType: AmiHardwareType.STANDARD, autoScalingConfig: { - blockDeviceVolumeSize: 30, + blockDeviceVolumeSize: 50, minCapacity: 1, maxCapacity: 5, cooldown: 60, @@ -231,7 +227,7 @@ export class FastApiContainer extends Construct { 'logs:CreateLogStream', 'logs:PutLogEvents' ], - resources: ['arn:aws:logs:*:*:*'] + resources: [`arn:${config.partition}:logs:*:*:*`] }), new PolicyStatement({ effect: Effect.ALLOW, @@ -241,8 +237,8 @@ export class FastApiContainer extends Construct { 'ecs:DescribeClusters' ], resources: [ - `arn:aws:ecs:${config.region}:*:cluster/${workbenchService?.cluster?.clusterName}*`, - `arn:aws:ecs:${config.region}:*:service/${workbenchService?.cluster?.clusterName}*/${workbenchService?.serviceName}*` + `arn:${config.partition}:ecs:${config.region}:*:cluster/${workbenchService?.cluster?.clusterName}*`, + `arn:${config.partition}:ecs:${config.region}:*:service/${workbenchService?.cluster?.clusterName}*/${workbenchService?.serviceName}*` ] }), new PolicyStatement({ @@ -251,7 +247,7 @@ export class FastApiContainer extends Construct { 'ssm:GetParameter' ], resources: [ - `arn:aws:ssm:${config.region}:*:parameter${config.deploymentPrefix}/deploymentName` + `arn:${config.partition}:ssm:${config.region}:*:parameter${config.deploymentPrefix}/deploymentName` ] }) ] diff --git a/lib/chat/api/configuration.ts b/lib/chat/api/configuration.ts index 5b9d03f30..7531bf708 100644 --- a/lib/chat/api/configuration.ts +++ b/lib/chat/api/configuration.ts @@ -53,6 +53,8 @@ type ConfigurationApiProps = { * API which Maintains config state in DynamoDB */ export class ConfigurationApi extends Construct { + public readonly configTable: dynamodb.Table; + constructor (scope: Construct, id: string, props: ConfigurationApiProps) { super(scope, id); @@ -72,7 +74,7 @@ export class ConfigurationApi extends Construct { ); // Create DynamoDB table to handle config data - const configTable = new dynamodb.Table(this, 'ConfigurationTable', { + this.configTable = new dynamodb.Table(this, 'ConfigurationTable', { partitionKey: { name: 'configScope', type: dynamodb.AttributeType.STRING, @@ -87,8 +89,8 @@ export class ConfigurationApi extends Construct { }); const mcpServersTable = dynamodb.Table.fromTableName(this, 'McpServersTable', mcpApi.mcpServersTableNameParameter.stringValue); + const lambdaRole: IRole = createLambdaRole(this, config.deploymentName, 'ConfigurationApi', this.configTable.tableArn, config.roles?.LambdaConfigurationApiExecutionRole); - const lambdaRole: IRole = createLambdaRole(this, config.deploymentName, 'ConfigurationApi', configTable.tableArn, config.roles?.LambdaConfigurationApiExecutionRole); mcpServersTable.grantReadWriteData(lambdaRole); // Populate the App Config table with default config @@ -99,7 +101,7 @@ export class ConfigurationApi extends Construct { action: 'putItem', physicalResourceId: PhysicalResourceId.of('initConfigData'), parameters: { - TableName: configTable.tableName, + TableName: this.configTable.tableName, Item: { 'versionId': {'N': '0'}, 'changedBy': {'S': 'System'}, @@ -122,6 +124,7 @@ export class ConfigurationApi extends Construct { 'showPromptTemplateLibrary': {'BOOL': 'True'}, 'mcpConnections': {'BOOL': 'True'}, 'modelLibrary': {'BOOL': 'True'}, + 'encryptSession': {'BOOL': 'False'}, }}, 'systemBanner': {'M': { 'isEnabled': {'BOOL': 'False'}, @@ -144,7 +147,7 @@ export class ConfigurationApi extends Construct { const fastApiEndpoint = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/serve/endpoint`); const environment = { - CONFIG_TABLE_NAME: configTable.tableName, + CONFIG_TABLE_NAME: this.configTable.tableName, FASTAPI_ENDPOINT: fastApiEndpoint, // add MCP_SERVERS_TABLE_NAME so we can update it if the configuration changes MCP_SERVERS_TABLE_NAME: mcpServersTable.tableName @@ -185,11 +188,11 @@ export class ConfigurationApi extends Construct { lambdaRole, ); if (f.method === 'POST' || f.method === 'PUT') { - configTable.grantWriteData(lambdaFunction); + this.configTable.grantWriteData(lambdaFunction); } else if (f.method === 'GET') { - configTable.grantReadData(lambdaFunction); + this.configTable.grantReadData(lambdaFunction); } else if (f.method === 'DELETE') { - configTable.grantReadWriteData(lambdaFunction); + this.configTable.grantReadWriteData(lambdaFunction); } }); } diff --git a/lib/chat/api/session.ts b/lib/chat/api/session.ts index 04d9d6ba7..54f24cbb1 100644 --- a/lib/chat/api/session.ts +++ b/lib/chat/api/session.ts @@ -20,6 +20,7 @@ import { Effect, IRole, PolicyStatement } from 'aws-cdk-lib/aws-iam'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { Key } from 'aws-cdk-lib/aws-kms'; import { Construct } from 'constructs'; import { getDefaultRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; @@ -40,6 +41,7 @@ import { RemovalPolicy } from 'aws-cdk-lib'; * @property {IAuthorizer} authorizer - APIGW authorizer * @property {ISecurityGroup[]} securityGroups - Security groups for Lambdas * @property {Map }importedSubnets for application. + * @property {dynamodb.Table} configTable - Configuration DynamoDB table */ type SessionApiProps = { authorizer: IAuthorizer; @@ -47,6 +49,7 @@ type SessionApiProps = { rootResourceId: string; securityGroups: ISecurityGroup[]; vpc: Vpc; + configTable: dynamodb.Table; } & BaseProps; /** @@ -56,7 +59,7 @@ export class SessionApi extends Construct { constructor (scope: Construct, id: string, props: SessionApiProps) { super(scope, id); - const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc, configTable } = props; // Get common layer based on arn from SSM due to issues with cross stack references const commonLambdaLayer = LayerVersion.fromLayerVersionArn( @@ -65,6 +68,13 @@ export class SessionApi extends Construct { StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/common`), ); + // Get FastAPI layer for cryptography support + const fastapiLambdaLayer = LayerVersion.fromLayerVersionArn( + this, + 'session-fastapi-lambda-layer', + StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/fastapi`), + ); + // Create DynamoDB table to handle chat sessions const sessionTable = new dynamodb.Table(this, 'SessionsTable', { partitionKey: { @@ -92,6 +102,19 @@ export class SessionApi extends Construct { sortKey: { name: 'startTime', type: dynamodb.AttributeType.STRING }, }); + // Create KMS key for session data encryption + const sessionEncryptionKey = new Key(this, 'SessionEncryptionKey', { + description: 'KMS key for encrypting session data at rest', + enableKeyRotation: true, + removalPolicy: config.removalPolicy, + }); + + // Store KMS key ARN in SSM parameter for cross-stack access + new StringParameter(this, 'SessionEncryptionKeyArnParameter', { + parameterName: `${config.deploymentPrefix}/sessionEncryptionKeyArn`, + stringValue: sessionEncryptionKey.keyArn, + }); + const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) ); @@ -129,6 +152,8 @@ export class SessionApi extends Construct { SESSIONS_BY_USER_ID_INDEX_NAME: byUserIdIndex, GENERATED_IMAGES_S3_BUCKET_NAME: imagesBucket.bucketName, MODEL_TABLE_NAME: modelTableName, + CONFIG_TABLE_NAME: configTable.tableName, + SESSION_ENCRYPTION_KEY_ARN: sessionEncryptionKey.keyArn, }; const lambdaRole: IRole = createLambdaRole( @@ -148,6 +173,15 @@ export class SessionApi extends Construct { }) ); + // Add permissions to read from configuration table + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['dynamodb:GetItem', 'dynamodb:Query'], + resources: [configTable.tableArn] + }) + ); + // If metrics stack deployment is enabled if (config.deployMetrics) { // Get metrics queue name from SSM @@ -166,6 +200,19 @@ export class SessionApi extends Construct { Object.assign(env, { USAGE_METRICS_QUEUE_NAME: usageMetricsQueueName }); } + // Add KMS permissions for session encryption + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'kms:GenerateDataKey', + 'kms:Decrypt', + 'kms:DescribeKey' + ], + resources: [sessionEncryptionKey.keyArn] + }) + ); + // Create API Lambda functions const apis: PythonLambdaFunction[] = [ { @@ -231,7 +278,7 @@ export class SessionApi extends Construct { this, restApi, lambdaPath, - [commonLambdaLayer], + [commonLambdaLayer, fastapiLambdaLayer], f, getDefaultRuntime(), vpc, diff --git a/lib/chat/chatConstruct.ts b/lib/chat/chatConstruct.ts index 808a0883d..470f98f7e 100644 --- a/lib/chat/chatConstruct.ts +++ b/lib/chat/chatConstruct.ts @@ -50,8 +50,8 @@ export class LisaChatApplicationConstruct extends Construct { const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; - // Add REST API Lambdas to APIGW - new SessionApi(scope, 'SessionApi', { + + const mcpApi = new McpApi(scope, 'McpApi', { authorizer, config, restApiId, @@ -60,23 +60,26 @@ export class LisaChatApplicationConstruct extends Construct { vpc, }); - const mcpApi = new McpApi(scope, 'McpApi', { + // Create Configuration API first to get the configuration table + const configurationApi = new ConfigurationApi(scope, 'ConfigurationApi', { authorizer, config, restApiId, rootResourceId, securityGroups, vpc, + mcpApi }); - new ConfigurationApi(scope, 'ConfigurationApi', { + // Add REST API Lambdas to APIGW + new SessionApi(scope, 'SessionApi', { authorizer, config, restApiId, rootResourceId, securityGroups, vpc, - mcpApi + configTable: configurationApi.configTable, }); new PromptTemplateApi(scope, 'PromptTemplateApi', { diff --git a/lib/core/coreConstruct.ts b/lib/core/coreConstruct.ts index d5d8b12b4..3b80017e7 100644 --- a/lib/core/coreConstruct.ts +++ b/lib/core/coreConstruct.ts @@ -14,18 +14,14 @@ limitations under the License. */ import * as lambda from 'aws-cdk-lib/aws-lambda'; -import { ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; import { Layer } from './layers'; import { BaseProps } from '../schema'; -import { createCdkId } from './utils'; -import { PythonLayerVersion } from '@aws-cdk/aws-lambda-python-alpha'; -import { getDefaultRuntime } from '../api-base/utils'; import { RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; -import { COMMON_LAYER_PATH, FASTAPI_LAYER_PATH, AUTHORIZER_LAYER_PATH, SDK_PATH } from '../util'; +import { COMMON_LAYER_PATH, FASTAPI_LAYER_PATH, AUTHORIZER_LAYER_PATH } from '../util'; import { Bucket } from 'aws-cdk-lib/aws-s3'; export const ARCHITECTURE = lambda.Architecture.X86_64; @@ -89,24 +85,6 @@ export class CoreConstruct extends Construct { assetPath: config.lambdaLayerAssets?.authorizerLayerPath, }); - // Build SDK Layer - let sdkLambdaLayer: ILayerVersion; - if (config.lambdaLayerAssets?.sdkLayerPath) { - sdkLambdaLayer = new LayerVersion(scope, 'SdkLayer', { - code: lambda.Code.fromAsset(config.lambdaLayerAssets?.sdkLayerPath), - compatibleRuntimes: [getDefaultRuntime()], - removalPolicy: config.removalPolicy, - description: 'LISA SDK common layer', - }); - } else { - sdkLambdaLayer = new PythonLayerVersion(scope, 'SdkLayer', { - entry: SDK_PATH, - compatibleRuntimes: [getDefaultRuntime()], - removalPolicy: config.removalPolicy, - description: 'LISA SDK common layer', - }); - } - new StringParameter(scope, 'LisaCommonLamdaLayerStringParameter', { parameterName: `${config.deploymentPrefix}/layerVersion/common`, stringValue: commonLambdaLayer.layer.layerVersionArn, @@ -124,11 +102,5 @@ export class CoreConstruct extends Construct { stringValue: authorizerLambdaLayer.layer.layerVersionArn, description: 'Layer Version ARN for LISA Authorizer Lambda Layer', }); - - new StringParameter(scope, createCdkId([config.deploymentName, config.deploymentStage, 'SdkLayer']), { - parameterName: `${config.deploymentPrefix}/layerVersion/lisa-sdk`, - stringValue: sdkLambdaLayer.layerVersionArn, - description: 'Layer Version ARN for LISA SDK Layer', - }); } } diff --git a/lib/core/layers/authorizer/requirements.txt b/lib/core/layers/authorizer/requirements.txt index fb0ff6550..c27b5c1e3 100644 --- a/lib/core/layers/authorizer/requirements.txt +++ b/lib/core/layers/authorizer/requirements.txt @@ -1,4 +1,5 @@ # urllib3<2 // Provided by Lambda -requests==2.32.4 +# cachetools==5.5.0 // provided by Common Layer +# requests==2.32.5 // provided by Common Layer cryptography==44.0.1 -PyJWT==2.9.0 +PyJWT==2.10.1 diff --git a/lib/core/layers/common/requirements.txt b/lib/core/layers/common/requirements.txt index 7488d54a0..6f0a41f2e 100644 --- a/lib/core/layers/common/requirements.txt +++ b/lib/core/layers/common/requirements.txt @@ -1,4 +1,6 @@ # boto3>=1.34.131 // Provided by Lambda # botocore>=1.34.131 // Provided by Lambda # urllib3<2 // Provided by Lambda -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.10 +cachetools==5.5.0 +requests==2.32.5 diff --git a/lib/core/layers/fastapi/requirements.txt b/lib/core/layers/fastapi/requirements.txt index 11af70d39..56fd326e0 100644 --- a/lib/core/layers/fastapi/requirements.txt +++ b/lib/core/layers/fastapi/requirements.txt @@ -1,5 +1,6 @@ # boto3==1.34.131 // Provided by Lambda +# requests==2.32.5 // provided by Common Layer fastapi==0.111.0 mangum==0.17.0 pydantic==2.8.2 -requests==2.32.4 +cryptography==44.0.1 diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index 33765d25a..e7096ca54 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -195,77 +195,144 @@ make bootstrap ``` ## ADC Region Deployment Tips -If you are deploying LISA into an ADC region with limited access to dependencies, we recommend that you build LISA in a -commercial region first, and then bring it up into your ADC region to deploy. First, do the npm and pip installs on a -computer with access to the dependencies. Then bundle it up with the libraries included and move into the ADC region. -Some properties will need to be set in the deployment file pointing to the built artifacts. From there the deployment -process is the same. - -### Using pre-built resources - -A default configuration will build the necessary containers, lambda layers, and production optimized -web application at build time. In the event that you would like to use pre-built resources due to -network connectivity reasons or other concerns with the environment where you'll be deploying LISA -you can do so. - -- For ECS containers (Models, APIs, etc) you can modify the `containerConfig` block of - the corresponding entry in `config.yaml`. For container images you can provide a path to a directory - from which a docker container will be built (default), a path to a tarball, an ECR repository arn and - optional tag, or a public registry path. - - We provide immediate support for HuggingFace TGI and TEI containers and for vLLM containers. The `example_config.yaml` - file provides examples for TGI and TEI, and the only difference for using vLLM is to change the - `inferenceContainer`, `baseImage`, and `path` options, as indicated in the snippet below. All other options can - remain the same as the model definition examples we have for the TGI or TEI models. vLLM can also support embedding - models in this way, so all you need to do is refer to the embedding model artifacts and remove the `streaming` field - to deploy the embedding model. - - vLLM has support for the OpenAI Embeddings API, but model support for it is limited because the feature is new. Currently, - the only supported embedding model with vLLM is [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct), - but this list is expected to grow over time as vLLM updates. - ```yaml - ecsModels: - - modelName: your-model-name - inferenceContainer: tgi - baseImage: ghcr.io/huggingface/text-generation-inference:2.0.1 - ``` -- If you are deploying the LISA Chat User Interface you can optionally specify the path to the pre-built - website assets using the top level `webAppAssetsPath` parameter in `config.yaml`. Specifying this path - (typically `lib/user-interface/react/dist`) will avoid using a container to build and bundle the assets - at CDK build time. -- For the lambda layers you can specify the path to a local zip archive of the layer code by including - the optional `lambdaLayerAssets` block in `config.yaml` similar to the following: +Amazon Dedicated Cloud (ADC) regions are isolated AWS environments designed for government customers' most sensitive workloads. These regions have restricted internet access and limited external dependencies, requiring special deployment considerations for LISA. -``` +There are two deployment approaches for ADC regions: + +1. **Pre-built Resources (Recommended)**: Build all components in a commercial region, then transfer to ADC +2. **In-Region Building**: Configure LISA to use ADC-accessible repositories for building components + +### Approach 1: Pre-built Resources (Recommended) + +This approach builds all necessary components in a commercial region with full internet access, then transfers them to the ADC region. + +#### Step 1: Build Components in Commercial Region + +1. Set up LISA in a commercial AWS region with internet access +2. Build all components: + ```bash + make buildArchive + ``` + This generates: + - Lambda function zip files in `./dist/layers/*.zip` + - Docker images exported as `./dist/images/*.tar` files + +#### Step 2: Transfer to ADC Region + +1. Upload Docker images to ECR in your ADC region: + ```bash + # Load and tag images + docker load -i lisa-rest-api.tar + docker tag lisa-rest-api:latest .dkr.ecr..amazonaws.com/lisa-rest-api:latest + + # Push to ADC ECR + aws ecr get-login-password --region | docker login --username AWS --password-stdin .dkr.ecr..amazonaws.com + docker push .dkr.ecr..amazonaws.com/lisa-rest-api:latest + ``` + You'll want to repeat this for lisa-batch-ingestion, as well as any of the LISA base model hosting containers (lisa-vllm, lisa-tgi, lisa-tei) + +2. Transfer built artifacts to ADC environment + +#### Step 3: Configure LISA for Pre-built Resources + +Update your `config-custom.yaml` in the ADC region: + +```yaml +# Lambda layers from pre-built archives lambdaLayerAssets: - authorizerLayerPath: lib/core/layers/authorizer_layer.zip - commonLayerPath: lib/core/layers/common_layer.zip - fastapiLayerPath: /path/to/fastapi_layer.zip - sdkLayerPath: lib/rag/layers/sdk_layer.zip + authorizerLayerPath: './dist/layers/AimlAdcLisaAuthLayer.zip' + commonLayerPath: './dist/layers/AimlAdcLisaCommonLayer.zip' + fastapiLayerPath: './dist/layers/AimlAdcLisaFastApiLayer.zip' + ragLayerPath: './dist/layers/AimlAdcLisaRag.zip' + sdkLayerPath: './dist/layers/AimlAdcLisaSdk.zip' + +# Lambda functions +lambdaPath: './dist/layers/AimlAdcLisaLambda.zip' + +# Pre-built web assets +webAppAssetsPath: './dist/lisa-web' +documentsPath: './dist/docs' +ecsModelDeployerPath: './dist/ecs_model_deployer' +vectorStoreDeployerPath: './dist/vector_store_deployer' + +# Container images from ECR +batchIngestionConfig: + type: external + code: .dkr.ecr..amazonaws.com/lisa-batch-ingestion:latest + +restApiConfig: + imageConfig: + type: external + code: .dkr.ecr..amazonaws.com/lisa-rest-api:latest ``` -### Deploying in ADC region -Now that we have everything setup we are ready to deploy. -```bash -make deploy -``` +### Approach 2: In-Region Building -By default, all stacks will be deployed but a particular stack can be deployed by providing the `STACK` argument to the `deploy` target. +This approach configures LISA to build components using repositories accessible from within the ADC region. -```bash -make deploy STACK=LisaServe -``` +#### Prerequisites +- ADC-accessible package repositories (PyPI mirror, npm registry, container registry) +- ADC-accessible container registries +- Network connectivity to required build dependencies -Available stacks can be listed by running: +#### Configuration -```bash -make listStacks +Update your `config-custom.yaml` to point to ADC-accessible repositories: + +```yaml +# Configure pip to use ADC-accessible PyPI mirror +pipConfig: + indexUrl: https://your-adc-pypi-mirror.com/simple + trustedHost: your-adc-pypi-mirror.com + +# Configure npm to use ADC-accessible registry +npmConfig: + registry: https://your-adc-npm-registry.com + +# Use ADC-accessible base images for LISA-Serve and Batch Ingestion +baseImage: /python:3.11 ``` +You'll also want any model hosting base containers available, e.g. vllm/vllm-openai:latest and ghcr.io/huggingface/text-embeddings-inference:latest + +To utilize the prebuilt hosting model containers with self-hosted models, select `type: ecr` in the Model Deployment > Container Configs. -After the `deploy` command is run, you should see many docker build outputs and eventually a CDK progress bar. The deployment should take about 10-15 minutes and will produce a single cloud formation output for the websocket URL. +### Deployment Steps -You can test the deployment with the integration test: +Once your configuration is complete: + +1. Bootstrap CDK (if not already done): + ```bash + make bootstrap + ``` + +2. Deploy LISA: + ```bash + make deploy + ``` + +3. Deploy specific stacks if needed: + ```bash + make deploy STACK=LisaServe + ``` + +4. List available stacks: + ```bash + make listStacks + ``` + +### Testing Your Deployment + +After deployment completes (10-15 minutes), test with: ```bash -pytest lisa-sdk/tests --url --verify | false +pytest lisa-sdk/tests --url --verify ``` + +### Troubleshooting ADC Deployments + +- **Build failures**: Ensure all dependencies are accessible from ADC region +- **Container pull errors**: Verify ECR repositories exist and have correct permissions +- **Lambda deployment issues**: Check that lambda zip files are properly formatted and accessible +- **Network connectivity**: Confirm VPC configuration allows required outbound connections diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index a81fc00db..0ec602a99 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -29,8 +29,10 @@ import { Role, ServicePrincipal, } from 'aws-cdk-lib/aws-iam'; -import { LayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { Code, Function, LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { CustomResource, Duration } from 'aws-cdk-lib'; +import { Provider } from 'aws-cdk-lib/custom-resources'; import { Construct } from 'constructs'; import { getDefaultRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../api-base/utils'; @@ -149,7 +151,7 @@ export class ModelsApi extends Construct { const stateMachinesLambdaRole = config.roles ? Role.fromRoleName(this, Roles.MODEL_SFN_LAMBDA_ROLE, config.roles.ModelsSfnLambdaRole) : this.createStateMachineLambdaRole(modelTable.tableArn, dockerImageBuilder.dockerImageBuilderFn.functionArn, - ecsModelDeployer.ecsModelDeployerFn.functionArn, lisaServeEndpointUrlPs.parameterArn, managementKeyName); + ecsModelDeployer.ecsModelDeployerFn.functionArn, lisaServeEndpointUrlPs.parameterArn, managementKeyName, config); const createModelStateMachine = new CreateModelStateMachine(this, 'CreateModelWorkflow', { config: config, @@ -330,6 +332,39 @@ export class ModelsApi extends Construct { ] }); lambdaFunction.role!.attachInlinePolicy(workflowPermissions); + + // Model API key cleanup - runs once per deployment version + const modelApiKeyCleanupLambda = new Function(this, 'ModelApiKeyCleanup', { + runtime: getDefaultRuntime(), + handler: 'models.model_api_key_cleanup.lambda_handler', + code: Code.fromAsset(lambdaPath), + layers: lambdaLayers, + environment: { + LISA_API_URL_PS_NAME: lisaServeEndpointUrlPs.parameterName, + MANAGEMENT_KEY_NAME: managementKeyName, + REST_API_VERSION: 'v2', + DEPLOYMENT_PREFIX: config.deploymentPrefix || '', + }, + role: stateMachinesLambdaRole, + vpc: vpc.vpc, + securityGroups: securityGroups, + timeout: Duration.minutes(5), + description: 'Remove api_key from existing Bedrock models to fix Invalid API Key format errors', + }); + + // Run cleanup automatically during deployment + const cleanupProvider = new Provider(this, 'ModelApiKeyCleanupProvider', { + onEventHandler: modelApiKeyCleanupLambda, + }); + + new CustomResource(this, 'ModelApiKeyCleanupResource', { + serviceToken: cleanupProvider.serviceToken, + properties: { + // Only runs once - increment this version number if you need to run cleanup again + CleanupVersion: '1', + }, + }); + } /** @@ -341,7 +376,7 @@ export class ModelsApi extends Construct { * @param managementKeyName - Name of the management key secret * @returns The created role */ - createStateMachineLambdaRole (modelTableArn: string, dockerImageBuilderFnArn: string, ecsModelDeployerFnArn: string, lisaServeEndpointUrlParamArn: string, managementKeyName: string): IRole { + createStateMachineLambdaRole (modelTableArn: string, dockerImageBuilderFnArn: string, ecsModelDeployerFnArn: string, lisaServeEndpointUrlParamArn: string, managementKeyName: string, config: any): IRole { return new Role(this, Roles.MODEL_SFN_LAMBDA_ROLE, { assumedBy: new ServicePrincipal('lambda.amazonaws.com'), managedPolicies: [ @@ -357,6 +392,7 @@ export class ModelsApi extends Construct { 'dynamodb:GetItem', 'dynamodb:PutItem', 'dynamodb:UpdateItem', + 'dynamodb:Scan', ], resources: [ modelTableArn, @@ -416,15 +452,6 @@ export class ModelsApi extends Construct { 'StringEquals': {'aws:ResourceTag/lisa_temporary_instance': 'true'} } }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'ssm:GetParameter' - ], - resources: [ - lisaServeEndpointUrlParamArn - ], - }), new PolicyStatement({ effect: Effect.ALLOW, actions: [ @@ -466,6 +493,36 @@ export class ModelsApi extends Construct { ], resources: ['*'], }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'bedrock:InvokeModel', + 'bedrock:InvokeModelWithResponseStream', + ], + resources: ['*'], // Bedrock model ARNs are dynamic and region-specific + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ssm:GetParameter', + ], + resources: [ + lisaServeEndpointUrlParamArn, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/lisaServeRestApiUri`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter/LISA-lisa-management-key`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/LiteLLMDbConnectionInfo`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/modelTableName`, + ], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'secretsmanager:GetSecretValue', + ], + resources: [ + `arn:${config.partition}:secretsmanager:${config.region}:${config.accountNumber}:secret:*`, // LiteLLM DB password secret + ], + }), ] }), } diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index cb5de9ddf..57c7a3d4c 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -75,6 +75,7 @@ export class CreateModelStateMachine extends Construct { RESTAPI_SSL_CERT_ARN: config.restApiConfig?.sslCertIamArn ?? '', LITELLM_CONFIG_OBJ: JSON.stringify(config.litellmConfig), AWS_ACCOUNT_ID: config.accountNumber, + AWS_PARTITION: config.partition, }; const setModelToCreating = new LambdaInvoke(this, 'SetModelToCreating', { diff --git a/lib/rag/api/repository.ts b/lib/rag/api/repository.ts index 4a8e6a360..abbe6014a 100644 --- a/lib/rag/api/repository.ts +++ b/lib/rag/api/repository.ts @@ -125,6 +125,16 @@ export class RepositoryApi extends Construct { ...baseEnvironment, }, }, + { + name: 'delete_index', + resource: 'repository', + description: 'Delete an index within a repository', + path: 'repository/{repositoryId}/index/{modelName}', + method: 'DELETE', + environment: { + ...baseEnvironment, + }, + }, { name: 'similarity_search', resource: 'repository', diff --git a/lib/rag/ingestion/ingestion-image/Dockerfile b/lib/rag/ingestion/ingestion-image/Dockerfile index 22ae46154..dd2c07ccc 100644 --- a/lib/rag/ingestion/ingestion-image/Dockerfile +++ b/lib/rag/ingestion/ingestion-image/Dockerfile @@ -1,11 +1,13 @@ -FROM public.ecr.aws/lambda/python:3.11 +ARG BASE_IMAGE=public.ecr.aws/lambda/python:3.11 +FROM ${BASE_IMAGE} ARG BUILD_DIR=build WORKDIR /workdir COPY ./requirements.txt /workdir -RUN /var/lang/bin/pip install --no-cache-dir -r /workdir/requirements.txt -t . +RUN /var/lang/bin/pip install --no-cache-dir --upgrade pip && \ + /var/lang/bin/pip install --no-cache-dir -r /workdir/requirements.txt -t . COPY ./${BUILD_DIR} /workdir diff --git a/lib/rag/ingestion/ingestion-image/requirements.txt b/lib/rag/ingestion/ingestion-image/requirements.txt index 97cd2f20d..9b1625b4f 100644 --- a/lib/rag/ingestion/ingestion-image/requirements.txt +++ b/lib/rag/ingestion/ingestion-image/requirements.txt @@ -1,6 +1,8 @@ # boto3>=1.34.131 // Provided by Lambda # botocore>=1.34.131 // Provided by Lambda # urllib3<2 // Provided by Lambda +# Pin NumPy to avoid GCC 9.3+ requirement in Docker builds +numpy<2.0.0 aioboto3==12.3.0 aiobotocore==2.11.2 aiohttp==3.12.14 @@ -10,9 +12,9 @@ cryptography==44.0.1 fastapi_utils==0.7.0 fastapi==0.115.11 gunicorn==23.0.0 -langchain-community==0.3.9 -langchain-openai==0.2.11 -langchain==0.3.9 +langchain-community==0.3.27 +langchain-core==0.3.76 +langchain-text-splitters==0.3.11 loguru==0.7.2 mangum==0.17.0 opensearch-py==2.6.0 @@ -20,12 +22,12 @@ pgvector==0.2.5 prisma==0.13.1 psycopg2-binary==2.9.9 pydantic==2.8.2 -PyJWT==2.9.0 +PyJWT==2.10.1 pynacl==1.5.0 pypdf==6.0.0 lxml==5.1.0 python-docx==1.1.0 requests-aws4auth==1.2.3 -requests==2.32.4 +requests==2.32.5 text-generation==0.7.0 uvicorn==0.29.0 diff --git a/lib/rag/ingestion/ingestion-job-construct.ts b/lib/rag/ingestion/ingestion-job-construct.ts index 37625cdb4..9e52cc066 100644 --- a/lib/rag/ingestion/ingestion-job-construct.ts +++ b/lib/rag/ingestion/ingestion-job-construct.ts @@ -29,10 +29,10 @@ import * as batch from 'aws-cdk-lib/aws-batch'; import * as ecs from 'aws-cdk-lib/aws-ecs'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import * as lambda from 'aws-cdk-lib/aws-lambda'; +import { getDefaultRuntime } from '../../api-base/utils'; import { Vpc } from '../../networking/vpc'; import path from 'path'; import { ILayerVersion } from 'aws-cdk-lib/aws-lambda'; -import { getDefaultRuntime } from '../../api-base/utils'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import * as fs from 'fs'; import * as crypto from 'crypto'; @@ -63,7 +63,7 @@ export class IngestionJobConstruct extends Construct { constructor (scope: Construct, id: string, props: IngestionJobConstructProps) { super(scope, id); - const { config, vpc, lambdaRole, layers, baseEnvironment } = props; + const { config, vpc, layers, lambdaRole, baseEnvironment } = props; const hash = crypto.randomBytes(6).toString('hex'); // DynamoDB table for tracking ingestion jobs @@ -132,10 +132,9 @@ export class IngestionJobConstruct extends Construct { // Skip actual copying during tests to avoid file not found errors if (process.env.NODE_ENV !== 'test') { fs.cpSync(path.join(__dirname, '../../../lambda'), buildDir, copyOptions); - fs.cpSync(path.join(__dirname, '../../../lisa-sdk/lisapy'), path.join(buildDir, 'lisapy'), copyOptions); } else { // For tests, we just ensure the directories exist but don't copy files - const directories = ['repository', 'prompt_templates', 'lisapy']; + const directories = ['repository', 'prompt_templates']; directories.forEach((dir) => { const dirPath = path.join(buildDir, dir); fs.mkdirSync(dirPath, { recursive: true }); @@ -190,7 +189,7 @@ export class IngestionJobConstruct extends Construct { resources: ['*'] })); - // Lambda function for handling scheduled document ingestion + // Lambda function for handling scheduled document ingestion - using container image const handlePipelineIngestScheduleLambda = new lambda.Function(this, 'handlePipelineIngestSchedule', { functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-schedule-${hash}`, runtime: getDefaultRuntime(), @@ -213,7 +212,7 @@ export class IngestionJobConstruct extends Construct { action: 'lambda:InvokeFunction' }); - // Lambda function for handling S3 event-based document ingestion + // Lambda function for handling S3 event-based document ingestion - using container image const handlePipelineIngestEvent = new lambda.Function(this, 'handlePipelineIngestEvent', { functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-event-${hash}`, runtime: getDefaultRuntime(), @@ -223,7 +222,7 @@ export class IngestionJobConstruct extends Construct { memorySize: 256, vpc: vpc!.vpc, environment: baseEnvironment, - layers: layers, + layers, role: lambdaRole }); const eventParameterName = `${config.deploymentPrefix}/ingestion/ingest/event`; @@ -236,17 +235,17 @@ export class IngestionJobConstruct extends Construct { action: 'lambda:InvokeFunction' }); - // Lambda function for handling document deletion events + // Lambda function for handling document deletion events - using container image const handlePipelineDeleteEvent = new lambda.Function(this, 'handlePipelineDeleteEvent', { functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-delete-event-${hash}`, runtime: getDefaultRuntime(), - handler: 'repository.pipeline_delete_documents.handle_pipeline_delete_event', + handler: 'repository.pipeline_ingest_documents.handle_pipeline_delete_event', code: lambda.Code.fromAsset('./lambda'), timeout: Duration.seconds(60), memorySize: 256, vpc: vpc!.vpc, environment: baseEnvironment, - layers: layers, + layers, role: lambdaRole }); const deleteParameterName = `${config.deploymentPrefix}/ingestion/delete/event`; diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index 36d98af3e..02dd6c16e 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -1,13 +1,22 @@ -# boto3>=1.34.131 // Provided by Lambda -# botocore>=1.34.131 // Provided by Lambda -# urllib3<2 // Provided by Lambda -langchain==0.3.9 -langchain-community==0.3.9 -langchain-openai==0.2.11 +# Container-based Lambda - no size constraints! +# All packages can be included since container images have 10GB limit + +# Core RAG packages +# psycopg2-binary==2.9.10 // provided by Common Layer +langchain-text-splitters==0.3.11 +langchain-community==0.3.27 +langchain-core==0.3.76 +# Required by langchain-community - Pin NumPy to avoid GCC 9.3+ +numpy<2.0.0 + +# Database and search connectors opensearch-py==2.6.0 pgvector==0.2.5 -psycopg2-binary==2.9.9 pypdf==6.0.0 -lxml==5.1.0 -python-docx==1.1.0 -requests-aws4auth==1.2.3 +lxml==5.3.0 +python-docx==1.1.2 +requests-aws4auth==1.3.1 +tiktoken==0.9.0 + +# Force urllib3 to newer version for security (resolves Poetry 1.5.1 conflict) +urllib3==2.5.0 diff --git a/lib/rag/ragConstruct.ts b/lib/rag/ragConstruct.ts index fa35491ff..509287b82 100644 --- a/lib/rag/ragConstruct.ts +++ b/lib/rag/ragConstruct.ts @@ -169,6 +169,7 @@ export class LisaRagConstruct extends Construct { REGISTERED_REPOSITORIES_PS_PREFIX: `${config.deploymentPrefix}/LisaServeRagConnectionInfo/`, REGISTERED_REPOSITORIES_PS: `${config.deploymentPrefix}/registeredRepositories`, REST_API_VERSION: 'v2', + TIKTOKEN_CACHE_DIR: '/tmp', }; // Add REST API SSL Cert ARN if it exists to be used to verify SSL calls to REST API @@ -196,12 +197,6 @@ export class LisaRagConstruct extends Construct { StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/layerVersion/common`), ); - const sdkLayer = LayerVersion.fromLayerVersionArn( - scope, - 'rag-sdk-lambda-layer', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/layerVersion/lisa-sdk`), - ); - // Pre-generate the tiktoken cache to ensure it does not attempt to fetch data from the internet at runtime. if (config.restApiConfig.imageConfig === undefined) { const cache_dir = path.join(RAG_LAYER_PATH, 'TIKTOKEN_CACHE'); @@ -233,7 +228,22 @@ export class LisaRagConstruct extends Construct { parameterName: `${config.deploymentPrefix}/layerVersion/rag`, stringValue: ragLambdaLayer.layer.layerVersionArn }); - const layers = [commonLambdaLayer, ragLambdaLayer.layer, sdkLayer]; + + const layers = [commonLambdaLayer, ragLambdaLayer.layer]; + + // Pre-generate the tiktoken cache to ensure it does not attempt to fetch data from the internet at runtime. + if (config.restApiConfig.imageConfig === undefined) { + const cache_dir = path.join(RAG_LAYER_PATH, 'TIKTOKEN_CACHE'); + // Skip tiktoken cache generation in test environment + if (process.env.NODE_ENV !== 'test') { + try { + child_process.execSync(`python3 scripts/cache-tiktoken-for-offline.py ${cache_dir}`, { stdio: 'inherit' }); + } catch (error) { + console.warn('Failed to generate tiktoken cache:', error); + // Continue execution even if cache generation fails + } + } + } // create a security group for opensearch const openSearchSg = SecurityGroupFactory.createSecurityGroup( @@ -323,7 +333,7 @@ export class LisaRagConstruct extends Construct { config, vpc, baseEnvironment, - { common: commonLambdaLayer, rag: ragLambdaLayer.layer, sdk: sdkLayer }, + { common: commonLambdaLayer, rag: ragLambdaLayer.layer }, lambdaRole, docMetaTable, subDocTable, @@ -354,7 +364,7 @@ export class LisaRagConstruct extends Construct { config: Config, vpc: Vpc, baseEnvironment: Record, - layers: { [key in 'common' | 'sdk' | 'rag']: ILayerVersion }, + layers: { [key in 'common' | 'rag']: ILayerVersion }, lambdaRole: IRole, docMetaTable: dynamodb.ITable, subDocTable: dynamodb.ITable, @@ -613,7 +623,7 @@ export class LisaRagConstruct extends Construct { rdsConfig: ragConfig.rdsConfig, repositoryId: ragConfig.repositoryId, type: ragConfig.type, - layers: [layers.common, layers.rag, layers.sdk], + layers: [layers.common, layers.rag], registeredRepositoriesParamName, ragDocumentTable: docMetaTable, ragSubDocumentTable: subDocTable, diff --git a/lib/rag/vector-store/state_machine/delete-store.ts b/lib/rag/vector-store/state_machine/delete-store.ts index cda13d22d..1c5cb23f3 100644 --- a/lib/rag/vector-store/state_machine/delete-store.ts +++ b/lib/rag/vector-store/state_machine/delete-store.ts @@ -33,8 +33,8 @@ import { LAMBDA_PATH } from '../../../util'; type DeleteStoreStateMachineProps = BaseProps & { ragVectorStoreTable: ITable, - lambdaLayers: ILayerVersion[]; vectorStoreDeployerFnArn: string; + lambdaLayers: ILayerVersion[]; vpc: Vpc, role?: iam.IRole, executionRole: iam.IRole; @@ -137,7 +137,8 @@ export class DeleteStoreStateMachine extends Construct { }).next(handleCleanupBedrockKnowledgeBase); const lambdaPath = config.lambdaPath || LAMBDA_PATH; - const cleanupDocsFunc = new Function(this, 'CleanupRepositoryDocsFunc', { + + const cleanupDocsFunc = new Function(this, 'CleanupRepositoryDocsFunc', { runtime: getDefaultRuntime(), handler: 'repository.state_machine.cleanup_repo_docs.lambda_handler', code: Code.fromAsset(lambdaPath), diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index c29fd4c76..59dd74789 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -490,7 +490,7 @@ export const MetricConfigSchema = z.object({ .describe('Metric configuration for ECS auto scaling.'); export const AutoScalingConfigSchema = z.object({ - blockDeviceVolumeSize: z.number().min(30).default(30), + blockDeviceVolumeSize: z.number().min(30).default(50), minCapacity: z.number().min(1).default(1).describe('Minimum capacity for auto scaling. Must be at least 1.'), maxCapacity: z.number().min(1).default(2).describe('Maximum capacity for auto scaling. Must be at least 1.'), defaultInstanceWarmup: z.number().default(180).describe('Default warm-up time in seconds until a newly launched instance can'), @@ -828,7 +828,7 @@ export const RawConfigObject = z.object({ privateEndpoints: z.boolean().default(false).describe('Whether to use privateEndpoints for REST API.'), s3BucketModels: z.string().describe('S3 bucket for models.'), mountS3DebUrl: z.string().describe('URL for S3-mounted Debian package.'), - imageBuilderVolumeSize: z.number().default(30).describe('EC2 volume size for image builder. Needs to be large enough for system plus inference container.'), + imageBuilderVolumeSize: z.number().default(50).describe('EC2 volume size for image builder. Needs to be large enough for system plus inference container.'), accountNumbersEcr: z .array(z.union([z.number(), z.string()])) .transform((arr) => arr.map(String)) @@ -883,7 +883,6 @@ export const RawConfigObject = z.object({ commonLayerPath: z.string().optional().describe('Lambda common layer code path'), fastapiLayerPath: z.string().optional().describe('Lambda API code path'), ragLayerPath: z.string().optional().describe('Lambda RAG layer code path'), - sdkLayerPath: z.string().optional().describe('Lambda SDK layer code path'), }) .optional() .describe('Configuration for local Lambda layer code'), diff --git a/lib/serve/ecs-model/embedding/instructor/Dockerfile b/lib/serve/ecs-model/embedding/instructor/Dockerfile index 177356c37..f84bf09ab 100644 --- a/lib/serve/ecs-model/embedding/instructor/Dockerfile +++ b/lib/serve/ecs-model/embedding/instructor/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} #### POINT TO NEW PYPI CONFIG @@ -26,6 +26,7 @@ RUN if [ "$CONDA_URL" != "" ]; then \ RUN /opt/conda/bin/conda install s5cmd && \ /opt/conda/bin/conda clean -ya && \ + pip install --no-cache-dir --upgrade pip && \ pip install -U --no-cache --upgrade torchserve torch-model-archiver torch-workflow-archiver ARG LOCAL_CODE_PATH diff --git a/lib/serve/ecs-model/embedding/instructor/src/requirements.txt b/lib/serve/ecs-model/embedding/instructor/src/requirements.txt index 533606cf1..33e2ac50b 100644 --- a/lib/serve/ecs-model/embedding/instructor/src/requirements.txt +++ b/lib/serve/ecs-model/embedding/instructor/src/requirements.txt @@ -1,3 +1,3 @@ -InstructorEmbedding -sentence-transformers -transformers +InstructorEmbedding==1.0.1 +sentence-transformers==3.3.1 +transformers==4.56.0 diff --git a/lib/serve/ecs-model/embedding/tei/Dockerfile b/lib/serve/ecs-model/embedding/tei/Dockerfile index 1e25374b0..b45ddfe99 100644 --- a/lib/serve/ecs-model/embedding/tei/Dockerfile +++ b/lib/serve/ecs-model/embedding/tei/Dockerfile @@ -1,10 +1,14 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=ghcr.io/huggingface/text-embeddings-inference:1.2.3 FROM ${BASE_IMAGE} ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL +ARG MOUNTS3_DEB_SHA256 RUN apt update -y && apt install -y wget rsync && \ - wget ${MOUNTS3_DEB_URL} && \ + wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ + if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ + echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ + fi && \ apt install -y ./mount-s3.deb && \ rm mount-s3.deb diff --git a/lib/serve/ecs-model/textgen/tgi/Dockerfile b/lib/serve/ecs-model/textgen/tgi/Dockerfile index 1e25374b0..8549e3a46 100644 --- a/lib/serve/ecs-model/textgen/tgi/Dockerfile +++ b/lib/serve/ecs-model/textgen/tgi/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=ghcr.io/huggingface/text-generation-inference:2.0.1 FROM ${BASE_IMAGE} ##### DOWNLOAD MOUNTPOINTS S3 diff --git a/lib/serve/ecs-model/vllm/Dockerfile b/lib/serve/ecs-model/vllm/Dockerfile index 1e25374b0..a35b36396 100644 --- a/lib/serve/ecs-model/vllm/Dockerfile +++ b/lib/serve/ecs-model/vllm/Dockerfile @@ -1,8 +1,9 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL +ARG MOUNTS3_DEB_SHA256 RUN apt update -y && apt install -y wget rsync && \ wget ${MOUNTS3_DEB_URL} && \ apt install -y ./mount-s3.deb && \ diff --git a/lib/serve/mcp-workbench/Dockerfile b/lib/serve/mcp-workbench/Dockerfile index 7ddc3a087..4be972214 100644 --- a/lib/serve/mcp-workbench/Dockerfile +++ b/lib/serve/mcp-workbench/Dockerfile @@ -36,11 +36,12 @@ RUN unzip -d /tmp /tmp/$(basename $RCLONE_SOURCE) && \ # Copy and install the MCP workbench package COPY . /workspace/mcpworkbench-src/ -RUN pip install -e /workspace/mcpworkbench-src/ +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -e /workspace/mcpworkbench-src/ # Install additional requirements COPY requirements.txt /workspace/ -RUN pip install -r /workspace/requirements.txt +RUN pip install --no-cache-dir -r /workspace/requirements.txt # Copy s6-overlay service definitions COPY s6-overlay/services.d /etc/services.d diff --git a/lib/serve/mcp-workbench/pyproject.toml b/lib/serve/mcp-workbench/pyproject.toml index 5dfffe655..558830536 100644 --- a/lib/serve/mcp-workbench/pyproject.toml +++ b/lib/serve/mcp-workbench/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "gunicorn==23.0.0", "pydantic==2.8.2", "PyJWT==2.9.0", - "requests==2.32.4", + "requests==2.32.5", "fastapi==0.115.11", "fastapi_utils==0.7.0", "loguru==0.7.2" diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index ceb768db6..325aefff3 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,6 +1,13 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} +# Install build dependencies for madoka package +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + make \ + && rm -rf /var/lib/apt/lists/* + # Copy LiteLLM config directly out of the LISA config.yaml file ARG LITELLM_CONFIG @@ -23,8 +30,7 @@ COPY src/ ./src COPY TIKTOKEN_CACHE ./TIKTOKEN_CACHE -# Generate the prisma binary -RUN prisma generate +# LiteLLM will handle Prisma setup at runtime with --use_prisma_db_push flag # Copy LiteLLM config directly to container, it will be updated at runtime # with LISA-hosted models. This filename is expected in the entrypoint.sh file, so do not modify diff --git a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py index 163e8303b..454fd1ced 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py @@ -27,7 +27,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.EMBEDDINGS.value}") +@router.post(f"/{RestApiResource.EMBEDDINGS}") async def embeddings(request: EmbeddingsRequest) -> JSONResponse: """Text embeddings.""" response = await handle_embeddings(request.dict()) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/generation.py b/lib/serve/rest-api/src/api/endpoints/v1/generation.py index 6413035c8..d80ff6f14 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/generation.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/generation.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.GENERATE.value}") +@router.post(f"/{RestApiResource.GENERATE}") async def generate(request: GenerateRequest) -> JSONResponse: """Text generation.""" response = await handle_generate(request.dict()) @@ -41,7 +41,7 @@ async def generate(request: GenerateRequest) -> JSONResponse: return JSONResponse(content=response, status_code=200) -@router.post(f"/{RestApiResource.GENERATE_STREAM.value}") +@router.post(f"/{RestApiResource.GENERATE_STREAM}") async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -50,7 +50,7 @@ async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: ) -@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS.value}") +@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS}") async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -59,7 +59,7 @@ async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsR ) -@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS.value}") +@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS}") async def openai_completion_generate_stream(request: OpenAICompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( diff --git a/lib/serve/rest-api/src/api/endpoints/v1/models.py b/lib/serve/rest-api/src/api/endpoints/v1/models.py index e3d374552..3bcb353c0 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/models.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/models.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.get(f"/{RestApiResource.DESCRIBE_MODEL.value}") +@router.get(f"/{RestApiResource.DESCRIBE_MODEL}") async def describe_model( provider: str = Query( None, @@ -52,7 +52,7 @@ async def describe_model( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.DESCRIBE_MODELS.value}") +@router.get(f"/{RestApiResource.DESCRIBE_MODELS}") async def describe_models( model_types: Optional[List[ModelType]] = Query( None, @@ -69,7 +69,7 @@ async def describe_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.LIST_MODELS.value}") +@router.get(f"/{RestApiResource.LIST_MODELS}") async def list_models( model_types: Optional[List[ModelType]] = Query( None, @@ -86,7 +86,7 @@ async def list_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS.value}") +@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS}") async def openai_list_models() -> JSONResponse: """List models for OpenAI Compatibility. Only returns TEXTGEN models.""" response = await handle_openai_list_models() diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 880eb2b60..1341b3469 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -25,7 +25,7 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.status import HTTP_401_UNAUTHORIZED -from ....auth import get_authorization_token, get_jwks_client, id_token_is_valid, is_idp_used, is_user_in_group +from ....auth import Authorizer # Local LiteLLM installation URL. By default, LiteLLM runs on port 4000. Change the port here if the # port was changed as part of the LiteLLM startup in entrypoint.sh @@ -102,40 +102,10 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: litellm_path = f"{LITELLM_URL}/{api_path}" headers = dict(request.headers.items()) - if not is_valid_management_token(headers): - # If not handling an OpenAI request, we will also check if the user is an Admin user before allowing the - # request, otherwise, we will block it. This prevents non-admins from invoking model management APIs - # directly. If LISA Serve is deployed without an IdP configuration, we cannot determine who is an admin - # user, so all API routes will default to being openly accessible. - if is_idp_used(): - client_id = os.environ.get("CLIENT_ID", "") - authority = os.environ.get("AUTHORITY", "") - admin_group = os.environ.get("ADMIN_GROUP", "") - user_group = os.environ.get("USER_GROUP", "") - jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "") - - id_token = get_authorization_token(headers=headers, header_name="Authorization") - jwks_client = get_jwks_client() - if jwt_data := id_token_is_valid( - id_token=id_token, authority=authority, client_id=client_id, jwks_client=jwks_client - ): - if user_group != "" and not is_user_in_group( - jwt_data=jwt_data, group=user_group, jwt_groups_property=jwt_groups_property - ): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) - if api_path not in OPENAI_ROUTES: - if not is_user_in_group( - jwt_data=jwt_data, group=admin_group, jwt_groups_property=jwt_groups_property - ): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) - else: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) + authorizer = Authorizer() + require_admin = api_path not in OPENAI_ROUTES + if not await authorizer.can_access(request, require_admin): + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough") # At this point in the request, we have already validated auth with IdP or persistent token. By using LiteLLM for # model management, LiteLLM requires an admin key, and that forces all requests to require a key as well. To avoid @@ -154,32 +124,6 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: return StreamingResponse(generate_response(response.iter_lines()), status_code=response.status_code) else: # not a streaming request response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers) + if response.status_code != 200: + logger.error(f"LiteLLM error response: {response.text}") return JSONResponse(response.json(), status_code=response.status_code) - - -def refresh_management_tokens() -> list[str]: - """Return secret management tokens if they exist.""" - secret_tokens = [] - - try: - secret_tokens.append( - secrets_manager.get_secret_value(SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT")[ - "SecretString" - ] - ) - secret_tokens.append( - secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSPREVIOUS" - )["SecretString"] - ) - except Exception: - logger.info(f"No previous secret version for {os.environ.get('MANAGEMENT_KEY_NAME')}") - - return secret_tokens - - -def is_valid_management_token(headers: dict[str, str]) -> bool: - """Return if API Token from request headers is valid if found.""" - secret_tokens = refresh_management_tokens() - token = get_authorization_token(headers=headers, header_name="Authorization").strip() - return token in secret_tokens diff --git a/lib/serve/rest-api/src/api/routes.py b/lib/serve/rest-api/src/api/routes.py index ca79631ae..08e052796 100644 --- a/lib/serve/rest-api/src/api/routes.py +++ b/lib/serve/rest-api/src/api/routes.py @@ -20,20 +20,20 @@ from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse -from ..auth import OIDCHTTPBearer +from ..auth import Authorizer from .endpoints.v2 import litellm_passthrough logger = logging.getLogger(__name__) router = APIRouter() -if os.getenv("USE_AUTH", "true").lower() == "false": - dependencies = [] - logger.info("Auth disabled") -else: - security = OIDCHTTPBearer() - dependencies = [Depends(security)] +dependencies = [] +if os.getenv("USE_AUTH", "true").lower() == "true": logger.info("Auth enabled") + security = Authorizer() + dependencies = [Depends(security)] +else: + logger.info("Auth disabled") router.include_router( litellm_passthrough.router, prefix="/v2/serve", tags=["litellm_passthrough"], dependencies=dependencies @@ -46,6 +46,17 @@ async def health_check() -> JSONResponse: This needs to match the path in the config.yaml file. """ - content = {"status": "OK"} - - return JSONResponse(content=content, status_code=200) + try: + # Basic health verification - check if required environment variables are set + required_vars = ["AWS_REGION", "LOG_LEVEL"] + missing_vars = [var for var in required_vars if not os.getenv(var)] + + if missing_vars: + content = {"status": "UNHEALTHY", "missing_env_vars": missing_vars} + return JSONResponse(content=content, status_code=503) + + content = {"status": "OK"} + return JSONResponse(content=content, status_code=200) + except Exception as e: + content = {"status": "UNHEALTHY", "error": str(e)} + return JSONResponse(content=content, status_code=503) diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index c83d86392..07f3bc07f 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -13,27 +13,28 @@ # limitations under the License. """Authentication for FastAPI app.""" +import asyncio import os import ssl import sys +import threading from datetime import datetime +from enum import Enum from pathlib import Path -from time import time from typing import Any, Dict, Optional import boto3 import jwt import requests +from cachetools import TTLCache +from cachetools.keys import hashkey from fastapi import HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from loguru import logger from starlette.status import HTTP_401_UNAUTHORIZED -# The following are field names, not passwords or tokens -API_KEY_HEADER_NAMES = [ - "Authorization", # OpenAI Bearer token format, collides with IdP, but that's okay for this use case - "Api-Key", # pragma: allowlist secret # Azure key format, can be used with Continue IDE plugin -] +from .utils.decorators import singleton + TOKEN_EXPIRATION_NAME = "tokenExpiration" # nosec B105 TOKEN_TABLE_NAME = "TOKEN_TABLE_NAME" # nosec B105 USE_AUTH = "USE_AUTH" @@ -51,9 +52,17 @@ ) -def is_idp_used() -> bool: - """Get if the identity provider is being used based on environment variable.""" - return os.environ.get(USE_AUTH, "false").lower() == "true" +# The following are field names, not passwords or tokens +class AuthHeaders(str, Enum): + """API key header names.""" + + AUTHORIZATION = "Authorization" # OpenAI Bearer token format, collides with IdP, but that's okay for this use case + API_KEY = "Api-Key" # pragma: allowlist secret # Azure key format, can be used with Continue IDE plugin + + @classmethod + def values(cls) -> list[str]: + """Return list of header values.""" + return list(cls) if not jwt.algorithms.has_crypto: @@ -106,13 +115,13 @@ def id_token_is_valid( }, ) return data - except jwt.exceptions.PyJWTError as e: + except (jwt.exceptions.PyJWTError, jwt.exceptions.DecodeError) as e: logger.exception(e) return None def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: str) -> bool: - """Check if the user is an admin.""" + """Check if the user is in group.""" props = jwt_groups_property.split(".") current_node = jwt_data for prop in props: @@ -123,7 +132,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node -def get_authorization_token(headers: Dict[str, str], header_name: str) -> str: +def get_authorization_token(headers: Dict[str, str], header_name: str = AuthHeaders.AUTHORIZATION) -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: return headers.get(header_name, "").removeprefix("Bearer").strip() @@ -133,29 +142,37 @@ def get_authorization_token(headers: Dict[str, str], header_name: str) -> str: class OIDCHTTPBearer(HTTPBearer): """OIDC based bearer token authenticator.""" - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, authority: Optional[str] = None, client_id: Optional[str] = None, **kwargs: Dict[str, Any]): super().__init__(**kwargs) - self._token_authorizer = ApiTokenAuthorizer() - self._management_token_authorizer = ManagementTokenAuthorizer() - - self._jwks_client = get_jwks_client() + self.authority = authority or os.environ.get("AUTHORITY", "") + self.client_id = client_id or os.environ.get("CLIENT_ID", "") + self.jwks_client = get_jwks_client() - async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: - """Verify the provided bearer token or API Key. API Key will take precedence over the bearer token.""" - if self._token_authorizer.is_valid_api_token(request.headers): - return None # valid API token, not continuing with OIDC auth - elif self._management_token_authorizer.is_valid_api_token(request.headers): - logger.info("looks like a valid mgmt token") - return None # valid management token, not continuing with OIDC auth + async def id_token_is_valid(self, request: Request) -> Optional[Dict[str, Any]]: + """Check whether an ID token is valid and return decoded data.""" http_auth_creds = await super().__call__(request) - if not id_token_is_valid( - id_token=http_auth_creds.credentials, - authority=os.environ["AUTHORITY"], - client_id=os.environ["CLIENT_ID"], - jwks_client=self._jwks_client, - ): - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return http_auth_creds + id_token = http_auth_creds.credentials + try: + signing_key = self.jwks_client.get_signing_key_from_jwt(id_token) + data: Dict[str, Any] = jwt.decode( + id_token, + signing_key.key, + algorithms=["RS256"], + issuer=self.authority, + audience=self.client_id, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True, + "verify_iss": True, + }, + ) + return data + except (jwt.exceptions.PyJWTError, jwt.exceptions.DecodeError) as e: + logger.exception(e) + return None class ApiTokenAuthorizer: @@ -175,12 +192,14 @@ def _get_token_info(self, token: str) -> Any: ddb_response = self._token_table.get_item(Key={"token": token}, ReturnConsumedCapacity="NONE") return ddb_response.get("Item", None) - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + async def is_valid_api_token(self, headers: Dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" - for header_name in API_KEY_HEADER_NAMES: - token = headers.get(header_name, "").removeprefix("Bearer").strip() + + for header_name in AuthHeaders.values(): + token = get_authorization_token(headers, header_name) + if token: - token_info = self._get_token_info(token) + token_info = await asyncio.to_thread(self._get_token_info, token) if token_info: token_expiration = int(token_info.get(TOKEN_EXPIRATION_NAME, datetime.max.timestamp())) current_time = int(datetime.now().timestamp()) @@ -193,33 +212,142 @@ class ManagementTokenAuthorizer: """Class for checking Management tokens against a SecretsManager secret.""" def __init__(self) -> None: - self._secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) - self._secret_tokens: list[str] = [] - self._last_run = 0 - - def _refreshTokens(self) -> None: - """Refresh secret management tokens.""" - current_time = int(time()) - if current_time - (self._last_run or 0) > 3600: - secret_tokens = [] + self._cache = TTLCache(maxsize=1, ttl=300) + self._cache_lock = threading.RLock() + self._local = threading.local() + + def _get_secrets_client(self): + """Get thread-local secrets manager client.""" + if not hasattr(self._local, "secrets_manager"): + self._local.secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) + return self._local.secrets_manager + + def get_management_tokens(self) -> list[str]: + """Return secret management tokens if they exist.""" + cache_key = hashkey() + + with self._cache_lock: + if cache_key in self._cache: + return self._cache[cache_key] + + logger.info("Updating management tokens cache") + secret_tokens = [] + secret_id = os.environ.get("MANAGEMENT_KEY_NAME") + secrets_manager = self._get_secrets_client() + + try: secret_tokens.append( - self._secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT" - )["SecretString"] + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"] ) - try: - secret_tokens.append( - self._secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSPREVIOUS" - )["SecretString"] - ) - except Exception: - logger.info(f"No previous secret version for {os.environ.get('MANAGEMENT_KEY_NAME')}") - self._secret_tokens = secret_tokens - self._last_run = current_time + secret_tokens.append( + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"] + ) + except Exception: + logger.info(f"No previous secret version for {secret_id}") - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + with self._cache_lock: + self._cache[cache_key] = secret_tokens + + return secret_tokens + + async def is_valid_api_token(self, headers: Dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" - self._refreshTokens() - token = headers.get("Authorization", "").strip() - return token in self._secret_tokens + secret_tokens = await asyncio.to_thread(self.get_management_tokens) + token = get_authorization_token(headers) + return token in secret_tokens + + +@singleton +class Authorizer: + """Composite authenticator that tries multiple authentication methods in order.""" + + def __init__(self) -> None: + self.client_id = os.environ.get("CLIENT_ID", "") + self.authority = os.environ.get("AUTHORITY", "") + self.admin_group = os.environ.get("ADMIN_GROUP", "") + self.user_group = os.environ.get("USER_GROUP", "") + self.jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "") + + self.token_authorizer = ApiTokenAuthorizer() + self.management_token_authorizer = ManagementTokenAuthorizer() + self.oidc_authorizer = OIDCHTTPBearer(authority=self.authority, client_id=self.client_id) + + async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: + jwt_data = await self.authenticate_request(request) + return jwt_data + + async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]: + """Authenticate request and return JWT data if valid, else None. Invalid requests throw an exception""" + + logger.trace(f"Authenticating request: {request.method} {request.url.path}") + # First try API tokens + logger.trace("Try API Auth Token...") + if await self.token_authorizer.is_valid_api_token(request.headers): + logger.trace("Valid API token") + return None + + # Then try management tokens + logger.trace("Try Management Auth Token...") + if await self.management_token_authorizer.is_valid_api_token(request.headers): + logger.trace("Valid Management token") + return None + + # Finally try OIDC Bearer tokens + logger.trace("Try OIDC Auth Token...") + jwt_data = await self.oidc_authorizer.id_token_is_valid(request) + if jwt_data: + logger.trace("Valid OIDC token") + return jwt_data + + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated") + + def _log_access_attempt( + self, request: Request, auth_method: str, user_id: str, endpoint: str, success: bool, reason: str = "" + ) -> None: + """Centralized logging for all authentication attempts.""" + status = "SUCCESS" if success else "FAILED" + log_msg = f"AUTH {status}: user={user_id} method={auth_method} endpoint={endpoint}" + if reason: + log_msg += f" reason={reason}" + + if success: + logger.info(log_msg) + else: + logger.warning(log_msg) + + async def can_access( + self, request: Request, require_admin: bool, jwt_data: Optional[Dict[str, Any]] = None + ) -> bool: + """Return whether the user is authorized to access the endpoint.""" + endpoint = f"{request.method} {request.url.path}" + + if jwt_data is None: + jwt_data = await self.authenticate_request(request) + + # Valid API_TOKEN will be treated as admin + if not jwt_data: + auth_method = "API_TOKEN" + user_id = "api_user" + has_access = True + reason = "Valid API/Management token" + else: + auth_method = "OIDC" + user_id = jwt_data.get("sub", jwt_data.get("username", "unknown")) + + # If user is admin, always allow access + if is_user_in_group(jwt_data, self.admin_group, self.jwt_groups_property): + has_access = True + reason = "Admin user" + # If admin is required but user is not admin, deny access + elif require_admin: + has_access = False + reason = "Admin required" + # For non-admin requests, check user group + else: + has_access = self.user_group == "" or is_user_in_group( + jwt_data=jwt_data, group=self.user_group, jwt_groups_property=self.jwt_groups_property + ) + reason = "Valid user group" if has_access else "Invalid user group" + + self._log_access_attempt(request, auth_method, user_id, endpoint, has_access, reason) + return has_access diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index 15fe25eda..2cc753bec 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -5,22 +5,65 @@ set -e HOST="0.0.0.0" PORT="8080" -# Prisma client is generated during build +echo "Starting LISA REST API Service" +echo "==================================" + +# Prisma client is now generated during build from LiteLLM's schema echo "Prisma client already generated during build" # Update LiteLLM config that was already copied from config.yaml with runtime-deployed models. # Depends on SSM Parameter for registered models. -echo "Configuring and starting LiteLLM" -# litellm_config.yaml is generated from the REST API Dockerfile from the LISA config.yaml. -# Do not modify the litellm_config.yaml name unless you change it in the Dockerfile and in the `litellm` command below. -python ./src/utils/generate_litellm_config.py -f litellm_config.yaml +echo "πŸ”§ Configuring LiteLLM..." +echo " - AWS Region: $AWS_REGION" +echo " - Models Parameter: $REGISTERED_MODELS_PS_NAME" +echo " - DB Info Parameter: $LITELLM_DB_INFO_PS_NAME" + +# Generate LiteLLM config with error handling +if ! python ./src/utils/generate_litellm_config.py -f litellm_config.yaml; then + echo "❌ Failed to generate LiteLLM configuration" + echo " This usually indicates issues with:" + echo " - SSM parameter access permissions" + echo " - Database connection parameters" + echo " - Model registration data" + exit 1 +fi + +echo "βœ… LiteLLM configuration generated successfully" + +# Verify config file exists and has content +if [ ! -f "litellm_config.yaml" ]; then + echo "❌ LiteLLM config file not found after generation" + exit 1 +fi + +echo "πŸ“„ LiteLLM config file contents:" +echo "--------------------------------" +head -20 litellm_config.yaml +echo "--------------------------------" + +# Start LiteLLM in the background with better error handling +echo "πŸš€ Starting LiteLLM server..." +echo " - Config file: litellm_config.yaml" +echo " - Port: 4000 (internal)" +echo " - Database: Prisma with auto-push enabled" + +# Start LiteLLM and capture its PID +litellm -c litellm_config.yaml --use_prisma_db_push > litellm.log 2>&1 & +LITELLM_PID=$! + +echo " - LiteLLM PID: $LITELLM_PID" +echo " - Log file: litellm.log" + +# LiteLLM is starting in the background, proceed with Gunicorn startup -# Start LiteLLM in the background, default port 4000, not exposed outside of container. -# If you need to change the port, you can specify the --port option, and then the port needs to be updated in -# src/api/endpoints/v2/litellm_passthrough.py for the LiteLLM URI -litellm -c litellm_config.yaml & +# Validate THREADS variable with default value +THREADS=${THREADS:-4} +echo "πŸš€ Starting Gunicorn with $THREADS workers..." -echo "Starting Gunicorn with $THREADS workers..." +# Start Gunicorn with Uvicorn workers +echo " - Host: $HOST" +echo " - Port: $PORT" +echo " - Workers: $THREADS" +echo " - Timeout: 600 seconds" -# Start Gunicorn with Uvicorn workers. exec gunicorn -k uvicorn.workers.UvicornWorker -t 600 -w "$THREADS" -b "$HOST:$PORT" "src.main:app" diff --git a/lib/serve/rest-api/src/handlers/generation.py b/lib/serve/rest-api/src/handlers/generation.py index bf35adb86..313b3781f 100644 --- a/lib/serve/rest-api/src/handlers/generation.py +++ b/lib/serve/rest-api/src/handlers/generation.py @@ -26,9 +26,12 @@ async def handle_generate(request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle for generate endpoint.""" model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE) - response = await model.generate(text=text, model_kwargs=model_kwargs) - - return response.dict() # type: ignore + try: + response = await model.generate(text=text, model_kwargs=model_kwargs) + return response.dict() # type: ignore + except Exception as e: + logger.error(f"Model generation failed: {e}") + raise @handle_stream_exceptions diff --git a/lib/serve/rest-api/src/handlers/models.py b/lib/serve/rest-api/src/handlers/models.py index 7955c1f83..240dddc02 100644 --- a/lib/serve/rest-api/src/handlers/models.py +++ b/lib/serve/rest-api/src/handlers/models.py @@ -115,10 +115,10 @@ async def handle_describe_models(model_types: List[ModelType]) -> DefaultDict[st response: DefaultDict[str, DefaultDict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) for model_type, providers in registered_models.items(): - response[model_type.value] = {} # type: ignore + response[model_type] = {} # type: ignore providers = providers or {} for provider, model_names in providers.items(): - response[model_type.value][provider] = [ + response[model_type][provider] = [ registered_models_cache["metadata"][f"{provider}.{model_name}"] for model_name in model_names ] # type: ignore diff --git a/lib/serve/rest-api/src/lisa_serve/__init__.py b/lib/serve/rest-api/src/lisa_serve/__init__.py index 51c90fc4e..c25b6dd2e 100644 --- a/lib/serve/rest-api/src/lisa_serve/__init__.py +++ b/lib/serve/rest-api/src/lisa_serve/__init__.py @@ -30,7 +30,7 @@ { "sink": sys.stdout, "format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | " - "{request_id} | {message}", + "{extra[request_id]} | {message}", "level": logger_level.upper(), } ] diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py index 849d9fc7a..b4b110808 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py @@ -76,18 +76,22 @@ async def embed_query(self, *, text: str, model_kwargs: Dict[str, Any]) -> Embed "text": text, } - async with ClientSession() as session: - async with session.post(self.endpoint_url, json=payload) as server_response: - server_response.raise_for_status() - server_response_json = await server_response.json() - - response = EmbedQueryResponse(embeddings=server_response_json) - - logger.debug( - f"Received: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:embed_query"}, - ) - return response + try: + async with ClientSession() as session: + async with session.post(self.endpoint_url, json=payload) as server_response: + server_response.raise_for_status() + server_response_json = await server_response.json() + + response = EmbedQueryResponse(embeddings=server_response_json) + + logger.debug( + f"Received: {escape_curly_brackets(response.json())}", + extra={"event": f"{self.__class__.__name__}:embed_query"}, + ) + return response + except Exception as e: + logger.error(f"Embedding request failed: {e}") + raise # Register the model diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py index a1415be92..bcd92224e 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py @@ -195,10 +195,12 @@ async def openai_generate_stream( AsyncGenerator[GenerateStreamResponse, None] Text generation model response with streaming. """ - request = {"prompt": text, **model_kwargs} + # Generate static values once before streaming resp_id = str(uuid.uuid4()) fingerprint = str(uuid.uuid4()) created = int(time.time()) + + request = {"prompt": text, **model_kwargs} if is_text_completion: response_class = OpenAICompletionsResponse else: diff --git a/lib/serve/rest-api/src/lisa_serve/registry/index.py b/lib/serve/rest-api/src/lisa_serve/registry/index.py index 5b0a56f8f..77e23af74 100644 --- a/lib/serve/rest-api/src/lisa_serve/registry/index.py +++ b/lib/serve/rest-api/src/lisa_serve/registry/index.py @@ -45,6 +45,6 @@ def get_assets(self, provider: str) -> Dict[str, Any]: except KeyError: raise KeyError( f"Model provider '{provider}' not found in registry. Available providers: " - f"{', '.join(list(self.registry))}" + f"{', '.join(self.registry)}" ) return model_assets # type: ignore diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index 8230b4bab..fcbcaf2b0 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -1,17 +1,20 @@ -aioboto3>=12.0.0,<15.0.0 -aiobotocore>=2.11.0,<3.0.0 +aioboto3==13.4.0 +aiobotocore==2.18.0 aiohttp==3.12.14 -boto3>=1.34.0,<1.37.0 +backoff==2.2.1 +boto3==1.36.0 +cachetools==5.5.0 click==8.1.7 -cryptography>=43.0.1,<44.0.0 +cryptography==44.0.1 fastapi==0.115.11 fastapi_utils==0.7.0 gunicorn==23.0.0 -litellm[proxy]==1.72.4 +litellm[proxy]==1.77.4 loguru==0.7.2 pydantic==2.8.2 -PyJWT==2.9.0 +PyJWT==2.10.1 text-generation==0.7.0 prisma==0.13.1 pynacl==1.5.0 +starlette==0.46.2 uvicorn==0.29.0 diff --git a/lib/serve/rest-api/src/utils/cache_manager.py b/lib/serve/rest-api/src/utils/cache_manager.py index 44ff5d749..3c94bbace 100644 --- a/lib/serve/rest-api/src/utils/cache_manager.py +++ b/lib/serve/rest-api/src/utils/cache_manager.py @@ -13,6 +13,7 @@ # limitations under the License. """Model Cache Utilities.""" +import threading from typing import Any, Dict, Optional, Tuple from .resources import ModelType, RestApiResource @@ -33,24 +34,31 @@ } MODEL_ASSETS_CACHE: Dict[str, Tuple[Any, Any]] = {} +# Thread locks for cache operations +_REGISTERED_MODELS_LOCK = threading.RLock() +_MODEL_ASSETS_LOCK = threading.RLock() + def get_registered_models_cache() -> Dict[str, Dict[str, Any]]: """Get the cache containing the registered models.""" - return REGISTERED_MODELS_CACHE + with _REGISTERED_MODELS_LOCK: + return REGISTERED_MODELS_CACHE.copy() def get_model_assets(model_key: str) -> Optional[Tuple[Any, Any]]: """Get the cache belonging to the model assets.""" - return MODEL_ASSETS_CACHE.get(model_key, None) + with _MODEL_ASSETS_LOCK: + return MODEL_ASSETS_CACHE.get(model_key) def cache_model_assets(key: str, model_assets: Tuple[Any, Any]) -> None: """Cache the specified model assets for the specified key.""" - global MODEL_ASSETS_CACHE - MODEL_ASSETS_CACHE[key] = model_assets + with _MODEL_ASSETS_LOCK: + MODEL_ASSETS_CACHE[key] = model_assets def set_registered_models_cache(models: Dict[str, Dict[str, Any]]) -> None: """Set the registered model cache to the specified models value.""" - global REGISTERED_MODELS_CACHE - REGISTERED_MODELS_CACHE = models + with _REGISTERED_MODELS_LOCK: + global REGISTERED_MODELS_CACHE + REGISTERED_MODELS_CACHE = models diff --git a/lib/serve/rest-api/src/utils/decorators.py b/lib/serve/rest-api/src/utils/decorators.py new file mode 100644 index 000000000..a550aa44c --- /dev/null +++ b/lib/serve/rest-api/src/utils/decorators.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility decorators.""" +from typing import Any, Callable, cast, Dict, TypeVar + +T = TypeVar("T") + + +def singleton(cls: type[T]) -> Callable[..., T]: + """Singleton decorator.""" + instances: Dict[type, Any] = {} + + def get_instance(*args: Any, **kwargs: Any) -> T: + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return cast(T, instances[cls]) + + return get_instance diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index 0aa046aff..6217cf3d6 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -23,14 +23,13 @@ import yaml from rds_auth import generate_auth_token, get_lambda_role_name -ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) -secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) - @click.command() @click.option("-f", "--filepath", type=click.Path(exists=True, file_okay=True, dir_okay=False, writable=True)) def generate_config(filepath: str) -> None: """Read LiteLLM configuration and rewrite it with LISA-deployed model information.""" + ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) + with open(filepath, "r") as fp: config_contents = yaml.safe_load(fp) # Get and load registered models from ParameterStore @@ -43,8 +42,6 @@ def generate_config(filepath: str) -> None: "litellm_params": { "model": f"openai/{model['modelName']}", "api_base": model["endpointUrl"] + "/v1", # Local containers require the /v1 for OpenAI API routing. - # the following is an unused placeholder to avoid LiteLLM deployment failures - "api_key": "ignored", # pragma: allowlist secret }, } for model in registered_models @@ -92,6 +89,7 @@ def get_database_credentials(db_params: dict[str, str]) -> Tuple: """Get database password from Secrets Manager or using IAM auth.""" if "passwordSecretId" in db_params: + secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) secret = json.loads(secret_response["SecretString"]) return (db_params["username"], secret["password"]) diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index a611d0fa1..cc4123e61 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -78,12 +78,17 @@ async def validate_model(request_data: Dict[str, Any], resource: RestApiResource registered_models_cache = get_registered_models_cache() supported_models = registered_models_cache[resource][provider] if model_name not in supported_models: + # Sanitize inputs for logging to prevent log injection + safe_model_name = str(model_name).replace("\n", "").replace("\r", "") + safe_resource = str(resource).replace("\n", "").replace("\r", "") + safe_supported = str(supported_models).replace("\n", "").replace("\r", "") + message = ( - f"Provider does not support model {model_name} for endpoint " - f"/{resource}, expected one of: {supported_models}" + f"Provider does not support model {safe_model_name} for endpoint " + f"/{safe_resource}, expected one of: {safe_supported}" ) logger.error(message, extra={"event": event, "status": "ERROR"}) - raise Exception(message) + raise ValueError(message) async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, Any]: @@ -113,7 +118,10 @@ async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, An # Retrieve model endpoint URL registered_models_cache = get_registered_models_cache() - endpoint_url = registered_models_cache["endpointUrls"][model_key] + try: + endpoint_url = registered_models_cache["endpointUrls"][model_key] + except KeyError: + raise KeyError(f"Model endpoint URL not found for {model_key}") # Instantiate the model model = adapter(model_name=model_name, endpoint_url=endpoint_url) @@ -158,7 +166,11 @@ async def validate_and_prepare_llm_request( task_logger.debug("Finish task", status="FINISH") - return model, model_kwargs.dict(), request_data["text"] + text = request_data.get("text") + if text is None: + raise ValueError("Missing required field: text") + + return model, model_kwargs.dict(), text def handle_stream_exceptions( diff --git a/lib/serve/rest-api/src/utils/resources.py b/lib/serve/rest-api/src/utils/resources.py index 888863922..c8929ae7a 100644 --- a/lib/serve/rest-api/src/utils/resources.py +++ b/lib/serve/rest-api/src/utils/resources.py @@ -170,7 +170,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - echo: Optional[int] = Field(False, description="Whether to prepend the prompt to the generated text.") + echo: Optional[bool] = Field(False, description="Whether to prepend the prompt to the generated text.") frequency_penalty: Optional[float] = Field(None, description="Penalty to add for text repetition.") logit_bias: Optional[Dict[Any, Any]] = Field( None, diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index a34c3edbf..26cfa9fbf 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -316,9 +316,16 @@ export class LisaServeApplicationConstruct extends Construct { // Add parameter as container environment variable for both RestAPI and RagAPI restApi.containers.forEach((container) => { container.addEnvironment('REGISTERED_MODELS_PS_NAME', this.modelsPs.parameterName); + container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); }); restApi.node.addDependency(this.modelsPs); + restApi.node.addDependency(litellmDbConnectionInfoPs); + restApi.node.addDependency(this.endpointUrl); + // Update + this.restApi = restApi; + + // Grant permissions after restApi is fully constructed // Additional permissions for REST API Role const invocation_permissions = new Policy(scope, 'ModelInvokePerms', { statements: [ @@ -344,13 +351,14 @@ export class LisaServeApplicationConstruct extends Construct { }), ] }); - letIfDefined(restApi.taskRoles[ECSTasks.REST], (serveRole) => { - this.modelsPs.grantRead(serveRole); - serveRole.attachInlinePolicy(invocation_permissions); - }); - // Update - this.restApi = restApi; + // Grant SSM parameter read access and attach invocation permissions + const restRole = restApi.taskRoles[ECSTasks.REST]; + if (restRole) { + this.modelsPs.grantRead(restRole); + litellmDbConnectionInfoPs.grantRead(restRole); + restRole.attachInlinePolicy(invocation_permissions); + } } getIAMAuthLambda (scope: Stack, config: Config, secret: ISecret, user: string, vpc: Vpc, securityGroups: ISecurityGroup[]): IFunction { diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 5963765db..5a8414a0b 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "5.3.0", + "version": "5.3.1", "type": "module", "scripts": { "dev": "vite", diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index 0c35973f1..523cb6d28 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -504,21 +504,26 @@ export default function Chat ({ sessionId }) { })); } + // Fetch RAG documents once if needed + let ragDocs = null; + if (useRag && !isImageGenerationMode) { + ragDocs = await fetchRelevantDocuments(userPrompt); + } + // Use extracted message builder utilities const messageContent = await buildMessageContent({ isImageGenerationMode, fileContext, useRag, userPrompt, - fetchRelevantDocuments, + ragDocs, }); const messageMetadata = await buildMessageMetadata({ isImageGenerationMode, useRag, - userPrompt, chatConfiguration, - fetchRelevantDocuments, + ragDocs, }); messages.push(new LisaChatMessage({ diff --git a/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx b/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx index a08966f15..de6b0da9a 100644 --- a/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx +++ b/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx @@ -21,7 +21,7 @@ export type MessageContentParams = { fileContext: string; useRag: boolean; userPrompt: string; - fetchRelevantDocuments?: (query: string) => Promise; + ragDocs?: any; }; export const buildMessageContent = async ({ @@ -29,7 +29,7 @@ export const buildMessageContent = async ({ fileContext, useRag, userPrompt, - fetchRelevantDocuments, + ragDocs, }: MessageContentParams) => { if (isImageGenerationMode) { return userPrompt; @@ -43,8 +43,7 @@ export const buildMessageContent = async ({ ]; } - if (useRag && fetchRelevantDocuments) { - const ragDocs = await fetchRelevantDocuments(userPrompt); + if (useRag && ragDocs) { return [ { type: 'text', text: 'File context: ' + formatDocumentsAsString(ragDocs.data?.docs) }, { type: 'text', text: userPrompt }, @@ -64,15 +63,13 @@ export const buildMessageContent = async ({ export const buildMessageMetadata = async ({ isImageGenerationMode, useRag, - userPrompt, chatConfiguration, - fetchRelevantDocuments, + ragDocs, }: { isImageGenerationMode: boolean; useRag: boolean; - userPrompt: string; chatConfiguration: any; - fetchRelevantDocuments?: (query: string) => Promise; + ragDocs?: any; }) => { const metadata: any = {}; @@ -81,8 +78,7 @@ export const buildMessageMetadata = async ({ metadata.imageGenerationSettings = chatConfiguration.sessionConfiguration.imageGenerationArgs; } - if (useRag && !isImageGenerationMode && fetchRelevantDocuments) { - const ragDocs = await fetchRelevantDocuments(userPrompt); + if (useRag && !isImageGenerationMode && ragDocs) { metadata.ragContext = formatDocumentsAsString(ragDocs.data?.docs, true); metadata.ragDocuments = formatDocumentTitlesAsString(ragDocs.data?.docs); } diff --git a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx index 0f5127b34..e88a180e1 100644 --- a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx +++ b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx @@ -40,7 +40,8 @@ const advancedOptions = { viewMetaData: 'View chat meta-data', deleteSessionHistory: 'Delete Session History', editChatHistoryBuffer: 'Edit chat history buffer', - enableModelComparisonUtility: 'Enable Model Comparison Utility' + enableModelComparisonUtility: 'Enable Model Comparison Utility', + encryptSession: 'Enable Session Encryption', }; const mcpOptions = { diff --git a/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx b/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx index d01841811..d287c09cf 100644 --- a/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx +++ b/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx @@ -20,10 +20,7 @@ import { useAppDispatch, useAppSelector } from '@/config/store'; import { IModel, ModelStatus } from '@/shared/model/model-management.model'; import { useNotificationService } from '@/shared/util/hooks'; import { INotificationService } from '@/shared/notification/notification.service'; -import { - modelManagementApi, - useDeleteModelMutation, useUpdateModelMutation, -} from '@/shared/reducers/model-management.reducer'; +import { useDeleteModelMutation, useUpdateModelMutation} from '@/shared/reducers/model-management.reducer'; import { MutationTrigger } from '@reduxjs/toolkit/dist/query/react/buildHooks'; import { Action, ThunkDispatch } from '@reduxjs/toolkit'; import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; @@ -38,6 +35,7 @@ export type ModelActionProps = { updateConfigMutation?: any; currentDefaultModel?: string; currentConfig?: any; + refetch?: () => void; }; function ModelActions (props: ModelActionProps): ReactElement { @@ -49,7 +47,7 @@ function ModelActions (props: ModelActionProps): ReactElement {