diff --git a/.github/workflows/code.ai-review.yml b/.github/workflows/code.ai-review.yml index 36a73d6de..cd9609472 100644 --- a/.github/workflows/code.ai-review.yml +++ b/.github/workflows/code.ai-review.yml @@ -143,7 +143,7 @@ jobs: }); - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} diff --git a/.github/workflows/code.deploy.demo.yml b/.github/workflows/code.deploy.demo.yml index 7a324bbe0..ea4d508a9 100644 --- a/.github/workflows/code.deploy.demo.yml +++ b/.github/workflows/code.deploy.demo.yml @@ -20,9 +20,9 @@ jobs: environment: demo runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@cf459bd40262a8603163308e488de922e4eb5a95 # v4 + uses: aws-actions/configure-aws-credentials@6e631f05b2a5f53c9f1e27150d5e8af2f907b03b # v4 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} @@ -33,11 +33,11 @@ jobs: run: | echo "${{vars.CONFIG_YAML}}" > config-custom.yaml - name: Set up Python 3.13 - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: "3.13" - name: Use Node.js 24.x - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24.x - name: Install CDK dependencies diff --git a/.github/workflows/code.deploy.dev.yml b/.github/workflows/code.deploy.dev.yml index 96afd248e..be2b5a029 100644 --- a/.github/workflows/code.deploy.dev.yml +++ b/.github/workflows/code.deploy.dev.yml @@ -20,9 +20,9 @@ jobs: environment: dev runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@cf459bd40262a8603163308e488de922e4eb5a95 # v4 + uses: aws-actions/configure-aws-credentials@6e631f05b2a5f53c9f1e27150d5e8af2f907b03b # v4 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} @@ -33,11 +33,11 @@ jobs: run: | echo "${{vars.CONFIG_YAML}}" > config-custom.yaml - name: Set up Python 3.13 - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: "3.13" - name: Use Node.js 24.x - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24.x - name: Install CDK dependencies diff --git a/.github/workflows/code.draft-release-and-tag.yml b/.github/workflows/code.draft-release-and-tag.yml index 13a3d8eda..13132968c 100644 --- a/.github/workflows/code.draft-release-and-tag.yml +++ b/.github/workflows/code.draft-release-and-tag.yml @@ -16,7 +16,7 @@ jobs: if: (startsWith(github.event.pull_request.head.ref, 'release/' ) || startsWith(github.event.pull_request.head.ref, 'hotfix/')) && github.event.pull_request.merged == true && github.event.pull_request.base.ref == 'main' steps: - name: Checkout Source Tag - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: ref: main - name: Get Version diff --git a/.github/workflows/code.end-to-end-test.nightly.yml b/.github/workflows/code.end-to-end-test.nightly.yml index 1922b1622..8f7e6f889 100644 --- a/.github/workflows/code.end-to-end-test.nightly.yml +++ b/.github/workflows/code.end-to-end-test.nightly.yml @@ -28,11 +28,11 @@ jobs: runs-on: ubuntu-latest needs: notify_e2e_start steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: ref: develop - name: Setup Node.js - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: '24' cache: 'npm' @@ -47,7 +47,7 @@ jobs: run: npx cypress run --config-file cypress/cypress.e2e.config.ts - name: Archive Cypress videos & screenshots if: failure() || always() - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v4 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v4 with: name: cypress-e2e-artifacts path: | diff --git a/.github/workflows/code.hotfix.branch.yml b/.github/workflows/code.hotfix.branch.yml index f38b97033..78a6eac0a 100644 --- a/.github/workflows/code.hotfix.branch.yml +++ b/.github/workflows/code.hotfix.branch.yml @@ -21,7 +21,7 @@ jobs: pull-requests: write # Required for creating PRs steps: - name: Checkout Source Tag - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: ref: refs/tags/${{ github.event.inputs.source_tag }} - name: Create Hotfix Branch and Update Version diff --git a/.github/workflows/code.merge.main-to-develop.yml b/.github/workflows/code.merge.main-to-develop.yml index d690ecf41..5022d19c8 100644 --- a/.github/workflows/code.merge.main-to-develop.yml +++ b/.github/workflows/code.merge.main-to-develop.yml @@ -14,7 +14,7 @@ jobs: contents: write # Required for merging branches steps: - name: Checkout main - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: ref: main ssh-key: ${{ secrets.DEPLOYMENT_SSH_KEY }} diff --git a/.github/workflows/code.publish.yml b/.github/workflows/code.publish.yml index 0c45b6afe..6d47235e6 100644 --- a/.github/workflows/code.publish.yml +++ b/.github/workflows/code.publish.yml @@ -25,55 +25,72 @@ jobs: contents: write # Required for uploading release assets id-token: write # Required for npm trusted publishing (OIDC) steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + # Free up disk space (~30GB+) by removing preinstalled software we don't need + - name: Free disk space + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf /opt/hostedtoolcache/go + sudo rm -rf /opt/hostedtoolcache/Ruby + sudo rm -rf /usr/local/share/powershell + sudo rm -rf /usr/local/share/chromium + sudo rm -rf /usr/local/lib/heroku + sudo rm -rf /usr/share/swift + sudo docker image prune --all --force + df -h + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 # Setup .npmrc file to publish to NpmJs Packages - - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: '24.x' registry-url: 'https://registry.npmjs.org' # Setup Python for build scripts - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.13' + # Install npm dependencies and publish package. Auth is established with NpmJs Trusted publishing. # To update, modify package at https://www.npmjs.com/package/awslabs-lisa/access # More info: https://docs.npmjs.com/trusted-publishers - run: npm ci + # Set version from input when running in test mode + - name: Set test version + if: github.event_name == 'workflow_dispatch' + run: npm version "${{ inputs.version }}" --no-git-tag-version --allow-same-version - name: Publish NPM Package - if: github.event_name == 'release' || !inputs.test_mode + if: "!(github.event_name == 'workflow_dispatch' && inputs.test_mode == true)" run: npm publish - name: Publish NPM Package (Dry Run) - if: github.event_name == 'workflow_dispatch' && inputs.test_mode + if: github.event_name == 'workflow_dispatch' && inputs.test_mode == true run: npm publish --dry-run - - # Build binary assets (lambda layers and container images) - - name: Build Lambda Layers and Container Images - run: | - # Create build directory for lambda layers - mkdir -p build - - # Build assets (runs build-lambdas and build-images --export) - ./bin/build-assets + # Install Python dependencies needed by build scripts + - name: Install Python build dependencies + run: pip install tiktoken==0.12.0 + # Build and export container images (separate from npm publish to avoid OOM) + - name: Build Container Images + run: ./bin/build-images --export env: PYPI_URL: https://pypi.org/simple LISA_VERSION: ${{ github.event_name == 'release' && github.event.release.tag_name || inputs.version }} # Upload binary assets to GitHub Release - name: Upload Release Assets - if: github.event_name == 'release' || !inputs.test_mode + if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.test_mode == false) uses: softprops/action-gh-release@v2 with: + tag_name: ${{ github.event_name == 'release' && github.event.release.tag_name || inputs.version }} files: | - dist/layers/*.zip dist/images/*.tar env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # In test mode, just list what would be uploaded - name: List Build Artifacts (Test Mode) - if: github.event_name == 'workflow_dispatch' && inputs.test_mode + if: github.event_name == 'workflow_dispatch' && inputs.test_mode == true run: | echo "=== Lambda Layers (dist/layers/*.zip) ===" ls -lh dist/layers/*.zip 2>/dev/null || echo "No zip files found" diff --git a/.github/workflows/code.release.branch.yml b/.github/workflows/code.release.branch.yml index ff786c1ad..211013afb 100644 --- a/.github/workflows/code.release.branch.yml +++ b/.github/workflows/code.release.branch.yml @@ -19,13 +19,13 @@ jobs: pull-requests: write # Required for creating PRs steps: - name: Checkout Develop Branch - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: ref: develop fetch-depth: 0 # Fetch full history for proper branch comparison ssh-key: ${{ secrets.DEPLOYMENT_SSH_KEY }} - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@cf459bd40262a8603163308e488de922e4eb5a95 # v4 + uses: aws-actions/configure-aws-credentials@6e631f05b2a5f53c9f1e27150d5e8af2f907b03b # v4 with: aws-region: ${{ vars.AWS_REGION }} role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} diff --git a/.github/workflows/code.smoke-test.yml b/.github/workflows/code.smoke-test.yml index 48f660e3c..8ea1df819 100644 --- a/.github/workflows/code.smoke-test.yml +++ b/.github/workflows/code.smoke-test.yml @@ -14,10 +14,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Setup Node.js - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: '24' cache: 'npm' @@ -42,7 +42,7 @@ jobs: - name: Archive Cypress videos & screenshots if: failure() || always() - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v4 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v4 with: name: cypress-smoke-artifacts path: | diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 650065f25..d723a16a2 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Initialize CodeQL uses: github/codeql-action/init@b36bf259c813715f76eafece573914b94412cd13 # v3 @@ -40,7 +40,7 @@ jobs: - name: Set up Python 3.13 if: matrix.language == 'python' - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: "3.13" @@ -58,7 +58,7 @@ jobs: - name: Use Node.js 24.x if: matrix.language == 'javascript' - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24.x diff --git a/.github/workflows/docs.deploy.github-pages.yml b/.github/workflows/docs.deploy.github-pages.yml index 6eaa9155c..4d77b794c 100644 --- a/.github/workflows/docs.deploy.github-pages.yml +++ b/.github/workflows/docs.deploy.github-pages.yml @@ -19,11 +19,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 with: fetch-depth: 0 - name: Setup Node - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24 cache: npm diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 2057caf39..388d1917d 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -41,9 +41,9 @@ jobs: permissions: contents: read steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Use Node.js 24.x - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24.x - name: Install dependencies @@ -59,9 +59,9 @@ jobs: permissions: contents: read steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Set up Python 3.13 - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: "3.13" - name: Install dependencies @@ -87,13 +87,13 @@ jobs: permissions: contents: read steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Set up Python 3.13 - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v5 with: python-version: '3.13' - name: Use Node.js 24.x - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v4 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v4 with: node-version: 24.x - name: Install CDK dependencies diff --git a/.gitignore b/.gitignore index 343778277..d1e3a8ea0 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ lib/rag/ingestion/ingestion-image/build .DS_Store *.iml *.code-workspace +.hf_token_cache # AI Tools .cursor diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d44717974..33a78ef94 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: hooks: - id: codespell entry: codespell - args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*,*test/cdk/stacks/__baselines__/*', "-L=xdescribe,assertIn,afterAll"] + args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*/public/*,*poetry.lock*,*coverage*,*models/*,*htmlcov*,*TIKTOKEN_CACHE/*,*test/cdk/stacks/__baselines__/*', "-L=xdescribe,assertIn,afterAll"] pass_filenames: false - repo: https://github.com/pycqa/isort diff --git a/CHANGELOG.md b/CHANGELOG.md index e05c64353..76b9ae1ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,39 @@ +# v6.3.0 + +## UI Updates +- Added RAG citation document preview side panel in Chat UI +- Exposed the document preview panel in the document library for viewing documents +- Added "Dismiss all" button for notification stacks +- Fixed "Loading Configuration..." text styling to match LISA UI using Cloudscape components +- Added last updated date/time to session displays + +## Other Key Changes +- Updated VLLM image to latest AWS deep-learning base with GPU settings for ECS, memory reservation, and tensor parallelization from GPU count +- Dockerfiles for embedding (instructor, tei), text generation (tgi), and VLLM now run OS package upgrades during build +- Removed deprecated LISA Serve V1 endpoints and supporting infrastructure +- Updated dependencies across the codebase + +## Bug Fixes +- Fixed RAG pipeline collection ID resolution (find_by_id_or_name fallback) and EventBus update mismatches on deployment +- Resolved max_tokens handling for non-Anthropic models on Anthropic routes +- Improved RAG PDF parsing quality (excessive whitespace and invisible Unicode characters) +- Addressed consistency of UI validation warnings for field format and required fields +- Added missing required role for batch ingestion +- Added cache clearing at login to prevent cache corruption issues + +## Documentation +- Added Claude Code setup guide for LISA Serve integration +- Updated deployment guide + +## Acknowledgements +* @bedanley +* @Ernest-Gray +* @estohlmann +* @gingerknight +* @jmharold + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v6.2.1..v6.3.0 + # v6.2.1 ## Bug Fixes diff --git a/VERSION b/VERSION index 024b066c0..798e38995 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -6.2.1 +6.3.0 diff --git a/bin/build-assets b/bin/build-assets index 3dcc739fb..8f211f7c2 100755 --- a/bin/build-assets +++ b/bin/build-assets @@ -10,6 +10,17 @@ export PYPI_URL=${PYPI_URL:-"https://pypi.org/simple"} export OUTPUT_DIR=$LAYER_DIR export IMAGE_DIR +# Parse arguments +INCLUDE_IMAGES=false +for arg in "$@"; do + case $arg in + --include-images) + INCLUDE_IMAGES=true + shift + ;; + esac +done + echo "Building all assets..." # Build Lambda layers (Python and Node.js) @@ -24,8 +35,10 @@ mv ./build/Lambda.zip "$LAYER_DIR/" rm -rf ./build cd "$ROOT" -# Build and export container images -echo "Building Image exports..." -./bin/build-images --export +# Build and export container images (only when explicitly requested) +if [[ "$INCLUDE_IMAGES" == "true" ]]; then + echo "Building Image exports..." + ./bin/build-images --export +fi echo "All assets built successfully!" diff --git a/bin/build-images b/bin/build-images index 3b7f68707..394cd1d98 100755 --- a/bin/build-images +++ b/bin/build-images @@ -155,7 +155,7 @@ build_all_images() { # lisa-vllm build_image "Dockerfile" "lisa-vllm" "latest" "./lib/serve/ecs-model/vllm" \ "NODE_ENV=production" \ - "BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.13-gpu-py312" \ + "BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.15-gpu-py312-ec2" \ "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" echo "All images built successfully!" diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index a8451bd2e..4ec393c94 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -27,7 +27,6 @@ import { Ec2Service, Ec2ServiceProps, Ec2TaskDefinition, - EcsOptimizedImage, HealthCheck, Host, LinuxParameters, @@ -108,7 +107,7 @@ export class ECSCluster extends Construct { const autoScalingGroup = cluster.addCapacity(createCdkId([identifier, 'ASG']), { vpcSubnets: subnetSelection, instanceType: new InstanceType(ecsConfig.instanceType), - machineImage: EcsOptimizedImage.amazonLinux2023(ecsConfig.amiHardwareType), + machineImage: CodeFactory.getEcsMachineImage(config.region, ecsConfig.amiHardwareType, ecsConfig.amiId), minCapacity: ecsConfig.autoScalingConfig.minCapacity, maxCapacity: ecsConfig.autoScalingConfig.maxCapacity, groupMetrics: [GroupMetrics.all()], @@ -149,7 +148,7 @@ export class ECSCluster extends Construct { // EC2 user data to mount ephemeral NVMe drive const MOUNT_PATH = config.nvmeHostMountPath ?? '/nvme'; const NVME_PATH = Ec2Metadata.get(ecsConfig.instanceType).nvmePath; - /* eslint-disable no-useless-escape */ + const rawUserData = `#!/bin/bash set -e # Check if NVMe is already formatted @@ -165,12 +164,28 @@ export class ECSCluster extends Construct { echo ${NVME_PATH} ${MOUNT_PATH} xfs defaults,nofail 0 2 >> /etc/fstab fi - # Update Docker root location and restart Docker service + # Configure Docker: set data-root on NVMe and ensure nvidia runtime is registered mkdir -p ${MOUNT_PATH}/docker - echo '{\"data-root\": \"${MOUNT_PATH}/docker\"}' | tee /etc/docker/daemon.json + cat > /etc/docker/daemon.json <<'DOCKEREOF' +{ + "data-root": "${MOUNT_PATH}/docker", + "runtimes": { + "nvidia": { + "path": "nvidia-container-runtime", + "runtimeArgs": [] + } + }, + "default-runtime": "nvidia" +} +DOCKEREOF + # Substitute the actual mount path + sed -i "s|\${MOUNT_PATH}|${MOUNT_PATH}|g" /etc/docker/daemon.json systemctl restart docker + + # Enable GPU support in ECS agent + echo "ECS_ENABLE_GPU_SUPPORT=true" >> /etc/ecs/ecs.config `; - /* eslint-enable no-useless-escape */ + autoScalingGroup.addUserData(rawUserData); // Create mount point for container @@ -186,6 +201,11 @@ export class ECSCluster extends Construct { mountPoints.push(nvmeMountPoint); } + // Enable GPU support in ECS agent for GPU instances without NVMe (NVMe path sets it in user data) + if (ecsConfig.amiHardwareType === AmiHardwareType.GPU && !Ec2Metadata.get(ecsConfig.instanceType).nvmePath) { + autoScalingGroup.addUserData('echo "ECS_ENABLE_GPU_SUPPORT=true" >> /etc/ecs/ecs.config'); + } + // Add permissions to use SSM in dev environment for EC2 debugging purposes only if (config.deploymentStage === 'dev') { autoScalingGroup.role.addManagedPolicy(ManagedPolicy.fromAwsManagedPolicyName('AmazonSSMFullAccess')); @@ -245,7 +265,11 @@ export class ECSCluster extends Construct { container = ec2TaskDefinition.addContainer(createCdkId([identifier, 'Container']), { containerName: createCdkId([config.deploymentName, identifier], 32, 2), image, - environment, + environment: { + ...environment, + // Required for NVIDIA container runtime when not using NVIDIA base images + ...(ecsConfig.amiHardwareType === AmiHardwareType.GPU && { NVIDIA_DRIVER_CAPABILITIES: 'utility,compute' }), + }, logging: LogDriver.awsLogs({ streamPrefix: identifier }), gpuCount: Ec2Metadata.get(ecsConfig.instanceType).gpuCount, memoryReservationMiB: taskDefinition.containerConfig.memoryReservation ?? diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index e19145bfd..82cbc47c8 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -401,6 +401,9 @@ class ContainerConfig(BaseModel): sharedMemorySize: PositiveInt healthCheckConfig: ContainerHealthCheckConfig environment: dict[str, str] | None = {} + memoryReservation: int | None = Field( + default=None, ge=0, description="Memory reservation in MiB for the container." + ) @field_validator("environment") @classmethod diff --git a/lambda/repository/collection_repo.py b/lambda/repository/collection_repo.py index 94ecd6374..4f2664f83 100644 --- a/lambda/repository/collection_repo.py +++ b/lambda/repository/collection_repo.py @@ -368,6 +368,16 @@ def count_by_repository(self, repository_id: str, status: CollectionStatus | Non logger.error(f"Failed to count collections for repository {repository_id}: {e}") raise CollectionRepositoryError(f"Failed to count collections: {str(e)}") + def find_by_id_or_name(self, collection_id: str, repository_id: str) -> RagCollectionConfig | None: + """Find a collection by UUID primary key, falling back to name lookup. + + Handles the case where a pipeline stores a user-entered name as collectionId. + """ + collection = self.find_by_id(collection_id, repository_id) + if collection: + return collection + return self.find_by_name(repository_id, collection_id) + def find_by_name(self, repository_id: str, collection_name: str) -> RagCollectionConfig | None: """ Find a collection by repository ID and name. diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index d9b04c2fb..2dad4e757 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -121,6 +121,54 @@ def list_status(event: dict, context: dict) -> dict[str, Any]: return cast(dict, vs_repo.get_repository_status()) +def enrich_metadata_with_document_id( + docs: list[dict[str, Any]], repository_id: str, search_collection_id: str +) -> list[dict[str, Any]]: + """Enrich document metadata with document_id by looking up in RAG document table. + + Works for both new documents (that have document_id in vector store) and + existing documents (that don't have it yet), ensuring backward compatibility. + + Args: + docs: Documents from vector store similarity search + repository_id: Repository ID for scoping the lookup + search_collection_id: The actual collection ID used for the search (not from metadata) + + Returns: + Documents with enriched metadata including document_id + """ + for doc in docs: + metadata = doc.get("metadata", {}) + + # Look up document_id from RAG document table using source path + source = metadata.get("source") + + if source and search_collection_id: + try: + # Query RAG document table by source using the ACTUAL collection_id from search + # Not the metadata's collectionId which may be "default" in vector store + rag_doc = doc_repo.find_one_by_source( + repository_id=repository_id, collection_id=search_collection_id, source=source + ) + + if rag_doc: + metadata["document_id"] = rag_doc.document_id + logger.info(f"Enriched metadata with document_id: {rag_doc.document_id} for source: {source}") + else: + logger.warning( + f"No RAG document found for source: {source} " + f"in repository: {repository_id}, collection: {search_collection_id}" + ) + + except Exception as e: + logger.error(f"Failed to enrich metadata for source {source}: {e}") + # Continue without document_id - frontend will handle gracefully + else: + logger.debug(f"Missing source ({source}) or collection_id ({search_collection_id}), skipping enrichment") + + return docs + + @api_wrapper def similarity_search(event: dict, context: dict) -> dict[str, Any]: """Return documents matching the query. @@ -198,6 +246,10 @@ def similarity_search(event: dict, context: dict) -> dict[str, Any]: bedrock_agent_client=bedrock_client, ) + # Enrich metadata with documentId for documents that don't have it + # Pass the actual search_collection_id (not the metadata's collectionId which may be "default") + docs = enrich_metadata_with_document_id(docs, repository_id, search_collection_id) # type: ignore[arg-type] + doc_content = [ { "Document": { diff --git a/lambda/repository/pipeline_delete_documents.py b/lambda/repository/pipeline_delete_documents.py index e3682baf4..ff4ec3b32 100644 --- a/lambda/repository/pipeline_delete_documents.py +++ b/lambda/repository/pipeline_delete_documents.py @@ -14,11 +14,10 @@ import logging import os -from typing import Any import boto3 from boto3.dynamodb.conditions import Key -from models.domain_objects import CollectionStatus, IngestionJob, IngestionStatus, IngestionType, JobActionType +from models.domain_objects import CollectionStatus, IngestionJob, IngestionStatus, JobActionType from repository.collection_repo import CollectionRepository from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService @@ -373,100 +372,3 @@ def pipeline_delete_documents(job: IngestionJob) -> None: error_msg = f"Failed to process batch deletion: {str(e)}" logger.error(error_msg, exc_info=True) raise Exception(error_msg) - - -def handle_pipeline_delete_event(event: dict[str, Any], context: Any) -> None: - """Handle pipeline document deletion for S3 ObjectRemoved events.""" - # Extract and validate inputs - logger.debug(f"Received event: {event}") - - detail = event.get("detail", {}) - bucket = detail.get("bucket", None) - key = detail.get("key", None) - repository_id = detail.get("repositoryId", None) - collection_id = detail.get("collectionId", None) - pipeline_config = detail.get("pipelineConfig", None) - s3_path = f"s3://{bucket}/{key}" - - if not repository_id: - logger.warning("No repository_id in event, skipping deletion") - return - - # Get repository to determine type and configuration - repository = vs_repo.find_repository_by_id(repository_id) - if not repository: - logger.warning(f"Repository {repository_id} not found, skipping deletion") - return - - # For Bedrock KB repositories, use data source ID as collection ID - if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): - if not collection_id: - # Fallback: try to get from bedrock config (legacy support) - bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) - - # Try new structure with dataSources array - data_sources = bedrock_config.get("dataSources", []) - if data_sources: - first_data_source = data_sources[0] - collection_id = ( - first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id - ) - else: - # Try legacy single data source ID - collection_id = bedrock_config.get("bedrockKnowledgeDatasourceId") - - if not collection_id: - logger.error(f"Bedrock KB repository {repository_id} missing data source ID") - return - - logger.info( - f"Processing Bedrock KB document deletion {s3_path} for repository {repository_id}, " - f"collection {collection_id}" - ) - else: - if not pipeline_config or not isinstance(pipeline_config, dict): - logger.warning("No pipeline_config in event, skipping deletion") - return - - embedding_model = pipeline_config.get("embeddingModel", None) - if embedding_model is None: - logger.warning("No embedding_model in pipeline_config, skipping deletion") - return - - collection_id = embedding_model - logger.info(f"Deleting object {s3_path} for repository {repository_id}/{embedding_model}") - - # Find documents by source path (idempotent - handles missing documents gracefully) - documents = rag_document_repository.find_by_source( - repository_id=repository_id, - collection_id=collection_id, - document_source=s3_path, - join_docs=False, # Don't need subdocs for deletion - ) - - if not documents: - logger.info(f"Document {s3_path} not found in tracking system, already deleted or never tracked") - return # Idempotent - success even if document doesn't exist - - # Delete each found document - for rag_document in documents: - logger.info(f"Deleting tracked document {rag_document.document_id} from {s3_path}") - - # Find or create ingestion job for deletion - ingestion_job = ingestion_job_repository.find_by_document(rag_document.document_id) - if ingestion_job is None: - ingestion_job = IngestionJob( - repository_id=repository_id, - collection_id=collection_id, - embedding_model=collection_id, # Use collection_id as embedding_model - chunk_strategy=None, - s3_path=rag_document.source, - username=rag_document.username, - ingestion_type=IngestionType.AUTO, - status=IngestionStatus.DELETE_PENDING, - ) - ingestion_job_repository.save(ingestion_job) - - # Submit deletion job - ingestion_service.create_delete_job(ingestion_job) - logger.info(f"Submitted deletion job for document {s3_path} in repository {repository_id}") diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 7b6aa883b..1a214aa6a 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Lambda function for pipeline document ingestion.""" +"""Batch container processing functions for pipeline document ingestion.""" import logging import os -from datetime import timedelta -from typing import Any import boto3 from models.domain_objects import ( - ChunkingStrategy, - FixedChunkingStrategy, IngestionJob, IngestionStatus, IngestionType, @@ -39,12 +35,11 @@ from repository.s3_metadata_manager import S3MetadataManager from repository.services.repository_service_factory import RepositoryServiceFactory from repository.vector_store_repo import VectorStoreRepository -from utilities.auth import get_username from utilities.bedrock_kb import get_datasource_bucket_for_collection, ingest_document_to_kb, S3DocumentDiscoveryService from utilities.common_functions import retry_config from utilities.file_processing import generate_chunks from utilities.repository_types import RepositoryType -from utilities.time import now, utc_now +from utilities.time import now dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) ingestion_job_table = dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) @@ -58,7 +53,6 @@ session = boto3.Session() s3 = boto3.client("s3", region_name=os.environ["AWS_REGION"], config=retry_config) bedrock_agent = boto3.client("bedrock-agent", region_name=os.environ["AWS_REGION"], config=retry_config) -ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) def pipeline_ingest(job: IngestionJob) -> None: @@ -425,277 +419,13 @@ def remove_document_from_vectorstore(doc: RagDocument) -> None: vector_store.delete(doc.subdocs) # type: ignore[union-attr] -def handle_pipeline_ingest_event(event: dict[str, Any], context: Any) -> None: - """Handle pipeline document ingestion.""" - # Extract and validate inputs - logger.debug(f"Received event: {event}") - - detail = event.get("detail", {}) - bucket = detail.get("bucket", None) - username = get_username(event) - key = detail.get("key", None) - - # Safety check: filter out metadata files (should be filtered by EventBridge) - if key and key.endswith(".metadata.json"): - logger.warning(f"Metadata file event reached Lambda (should be filtered by EventBridge): {key}") - return - repository_id = detail.get("repositoryId", None) - pipeline_config = detail.get("pipelineConfig", None) - collection_id = detail.get("collectionId", None) - s3_path = f"s3://{bucket}/{key}" - - # Get repository to determine type and configuration - repository = vs_repo.find_repository_by_id(repository_id) - - # For Bedrock KB repositories, use data source ID as collection ID - if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): - if not collection_id: - # Fallback: try to get from bedrock config (legacy support) - bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) - - # Try new structure with dataSources array - data_sources = bedrock_config.get("dataSources", []) - if data_sources: - first_data_source = data_sources[0] - collection_id_val: str | None = ( - first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id - ) - if not collection_id_val: - logger.error(f"Bedrock KB repository {repository_id} has invalid data source") - return - collection_id = collection_id_val - else: - # Try legacy single data source ID - collection_id_val = bedrock_config.get("bedrockKnowledgeDatasourceId") - if not collection_id_val: - logger.error(f"Bedrock KB repository {repository_id} missing data source ID") - return - collection_id = collection_id_val - - if not collection_id: - logger.error(f"Bedrock KB repository {repository_id} missing data source ID") - return - - embedding_model = repository.get("embeddingModelId") - chunk_strategy = NoneChunkingStrategy() # KB manages chunking - - # Set username to "system" for auto-ingestion from KB bucket - username = "system" - ingestion_type = IngestionType.AUTO - - logger.info( - f"Processing Bedrock KB document {s3_path} for repository {repository_id}, " f"collection {collection_id}" - ) - else: - # Non-Bedrock KB path (existing logic) - embedding_model = pipeline_config.get("embeddingModel", None) - - if collection_id: - collection = collection_service.get_collection( - collection_id=collection_id, repository_id=repository_id, is_admin=True, username="", user_groups=[] - ) - - if collection.embeddingModel is not None: - embedding_model = collection.embeddingModel - else: - collection_id = embedding_model - - chunk_strategy = extract_chunk_strategy(pipeline_config) - ingestion_type = IngestionType.MANUAL - - logger.info(f"Ingesting object {s3_path} for repository {repository_id}/{embedding_model}") - - # Get collection for metadata merging - collection_dict = None - if collection_id and collection_id != embedding_model: - try: - collection_obj = collection_service.get_collection( - collection_id=collection_id, repository_id=repository_id, is_admin=True, username="", user_groups=[] - ) - collection_dict = collection_obj.model_dump() if collection_obj else None - except Exception as e: - logger.warning(f"Could not fetch collection for metadata merging: {e}") - - # Merge metadata from repository, collection, and pipeline sources - merged_metadata = MetadataGenerator.merge_metadata( - repository=repository, - collection=collection_dict, - document_metadata=None, - for_bedrock_kb=False, # Keep tags as array for ingestion jobs - ) - - # Create ingestion job and save it to dynamodb - job = IngestionJob( - repository_id=repository_id, - collection_id=collection_id, - embedding_model=embedding_model, - chunk_strategy=chunk_strategy, - s3_path=s3_path, - username=username, - ingestion_type=ingestion_type, - metadata=merged_metadata, - ) - ingestion_job_repository.save(job) - ingestion_service.submit_create_job(job) - - logger.info(f"Submitted ingestion job for document {s3_path} in repository {repository_id}") - - -def handle_pipline_ingest_schedule(event: dict[str, Any], context: Any) -> None: - """ - Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. - - Args: - event: Event data containing bucket and prefix information - context: Lambda context - - Returns: - Dictionary containing array of files with their bucket and key - """ - # Log the full event for debugging - logger.debug(f"Received event: {event}") - - detail = event.get("detail", {}) - bucket = detail.get("bucket", None) - username = get_username(event) - prefix = detail.get("prefix", None) - repository_id = detail.get("repositoryId", None) - pipeline_config = detail.get("pipelineConfig", None) - embedding_model = pipeline_config.get("embeddingModel", None) - - # hard code fixed length chunking until more strategies are implemented - chunk_strategy = extract_chunk_strategy(pipeline_config) - - # Get repository for metadata merging - repository = vs_repo.find_repository_by_id(repository_id) - - try: - # Add debug logging - logger.info(f"Processing request for bucket: {bucket}, prefix: {prefix}") - - # Calculate timestamp for 24 hours ago - twenty_four_hours_ago = utc_now() - timedelta(hours=24) - - # List to store matching objects - modified_keys = [] - - # Use paginator to handle case where there are more than 1000 objects - paginator = s3.get_paginator("list_objects_v2") - - # Add debug logging for S3 list operation - logger.info(f"Listing objects in {bucket}{prefix} modified after {twenty_four_hours_ago}") - - # Iterate through all objects in the bucket/prefix - try: - for page in paginator.paginate(Bucket=bucket, Prefix=prefix): - if "Contents" not in page: - logger.info(f"No contents found in page for {bucket}{prefix}") - continue - - # Check each object's last modified time - for obj in page["Contents"]: - last_modified = obj["LastModified"] - if last_modified >= twenty_four_hours_ago: - logger.info(f"Found modified file: {obj['Key']} (Last Modified: {last_modified})") - modified_keys.append(obj["Key"]) - else: - logger.debug( - f"Skipping file {obj['Key']} - Last modified {last_modified} before cutoff " - f"{twenty_four_hours_ago}" - ) - except Exception as e: - logger.error(f"Error during S3 list operation: {str(e)}", exc_info=True) - raise - - # Merge metadata from repository and pipeline sources (no collection for scheduled jobs) - merged_metadata = MetadataGenerator.merge_metadata( - repository=repository, - collection=None, - document_metadata=None, - for_bedrock_kb=False, # Keep tags as array for ingestion jobs - ) - - # create an IngestionJob for every object created/modified - for key in modified_keys: - job = IngestionJob( - repository_id=repository_id, - collection_id=embedding_model, - chunk_strategy=chunk_strategy, - s3_path=f"s3://{bucket}/{key}", - username=username, - ingestion_type=IngestionType.AUTO, - metadata=merged_metadata, - ) - ingestion_job_repository.save(job) - ingestion_service.submit_create_job(job) - - logger.info(f"Found {len(modified_keys)} modified files in {bucket}{prefix}") - except Exception as e: - logger.error(f"Error listing objects: {str(e)}", exc_info=True) - raise e - - def batch_texts(texts: list[str], metadatas: list[dict], batch_size: int = 256) -> list[tuple[list[str], list[dict]]]: - """ - Split texts and metadata into batches of specified size. - - Args: - texts: List of text strings to batch - metadatas: List of metadata dictionaries - batch_size: Maximum size of each batch (default 256 to match embedding server limit) - Returns: - List of tuples containing (texts_batch, metadatas_batch) - """ batches = [] for i in range(0, len(texts), batch_size): - text_batch = texts[i : i + batch_size] - metadata_batch = metadatas[i : i + batch_size] - batches.append((text_batch, metadata_batch)) + batches.append((texts[i : i + batch_size], metadatas[i : i + batch_size])) return batches -def extract_chunk_strategy(pipeline_config: dict) -> ChunkingStrategy: - """ - Extract and validate chunking strategy from pipeline configuration. - - Supports both new chunkingStrategy object format and legacy flat fields for backward compatibility. - Uses Pydantic model validation to ensure data integrity. - - Args: - pipeline_config: Pipeline configuration dictionary - - Returns: - ChunkingStrategy object (validated Pydantic model) - - Raises: - ValueError: If chunking strategy type is unsupported or validation fails - """ - # Check for new chunkingStrategy object format first - if "chunkingStrategy" in pipeline_config and pipeline_config["chunkingStrategy"]: - chunking_strategy = pipeline_config["chunkingStrategy"] - chunk_type = chunking_strategy.get("type", "fixed") - - if chunk_type == "fixed": - # Use Pydantic model validation for type safety and validation - result: FixedChunkingStrategy = FixedChunkingStrategy.model_validate(chunking_strategy) - return result - else: - # Future: Handle other chunking strategy types (semantic, recursive, etc.) - raise ValueError(f"Unsupported chunking strategy type: {chunk_type}") - - # Fall back to legacy flat fields for backward compatibility - elif "chunkSize" in pipeline_config and "chunkOverlap" in pipeline_config: - chunk_size = int(pipeline_config["chunkSize"]) - chunk_overlap = int(pipeline_config["chunkOverlap"]) - # Use Pydantic model for validation - return FixedChunkingStrategy(size=chunk_size, overlap=chunk_overlap) - - # Default values if neither format is present - else: - logger.warning("No chunking strategy found in pipeline config, using defaults") - return FixedChunkingStrategy(size=512, overlap=51) - - def prepare_chunks(docs: list, repository_id: str, collection_id: str) -> tuple[list[str], list[dict]]: """Prepare texts and metadata from document chunks.""" texts = [] diff --git a/lambda/repository/pipeline_ingest_handlers.py b/lambda/repository/pipeline_ingest_handlers.py new file mode 100644 index 000000000..8327a1e98 --- /dev/null +++ b/lambda/repository/pipeline_ingest_handlers.py @@ -0,0 +1,344 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda event handlers for pipeline document ingestion and deletion.""" + +import logging +import os +from datetime import timedelta +from typing import Any, cast + +import boto3 +from models.domain_objects import ( + FixedChunkingStrategy, + IngestionJob, + IngestionStatus, + IngestionType, + NoneChunkingStrategy, +) +from repository.collection_service import CollectionService +from repository.ingestion_job_repo import IngestionJobRepository +from repository.ingestion_service import DocumentIngestionService +from repository.metadata_generator import MetadataGenerator +from repository.rag_document_repo import RagDocumentRepository +from repository.vector_store_repo import VectorStoreRepository +from utilities.auth import get_username +from utilities.common_functions import retry_config +from utilities.repository_types import RepositoryType +from utilities.time import utc_now + +dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) +ingestion_job_table = dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) +ingestion_service = DocumentIngestionService() +ingestion_job_repository = IngestionJobRepository() +vs_repo = VectorStoreRepository() +rag_document_repository = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) +collection_service = CollectionService(vector_store_repo=vs_repo, document_repo=rag_document_repository) + +logger = logging.getLogger(__name__) +session = boto3.Session() +s3 = boto3.client("s3", region_name=os.environ["AWS_REGION"], config=retry_config) +bedrock_agent = boto3.client("bedrock-agent", region_name=os.environ["AWS_REGION"], config=retry_config) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + +def extract_chunk_strategy(pipeline_config: dict) -> FixedChunkingStrategy | NoneChunkingStrategy: + if "chunkingStrategy" in pipeline_config and pipeline_config["chunkingStrategy"]: + chunking_strategy = pipeline_config["chunkingStrategy"] + chunk_type = chunking_strategy.get("type", "fixed") + if chunk_type == "fixed": + return cast(FixedChunkingStrategy, FixedChunkingStrategy.model_validate(chunking_strategy)) + else: + raise ValueError(f"Unsupported chunking strategy type: {chunk_type}") + elif "chunkSize" in pipeline_config and "chunkOverlap" in pipeline_config: + return FixedChunkingStrategy( + size=int(pipeline_config["chunkSize"]), overlap=int(pipeline_config["chunkOverlap"]) + ) + else: + logger.warning("No chunking strategy found in pipeline config, using defaults") + return FixedChunkingStrategy(size=512, overlap=51) + + +def handle_pipeline_ingest_event(event: dict[str, Any], context: Any) -> None: + """Handle pipeline document ingestion.""" + logger.debug(f"Received event: {event}") + + detail = event.get("detail", {}) + bucket = detail.get("bucket", None) + username = get_username(event) + key = detail.get("key", None) + + if key and key.endswith(".metadata.json"): + logger.warning(f"Metadata file event reached Lambda (should be filtered by EventBridge): {key}") + return + repository_id = detail.get("repositoryId", None) + pipeline_config = detail.get("pipelineConfig", None) + collection_id = detail.get("collectionId", None) + s3_path = f"s3://{bucket}/{key}" + + repository = vs_repo.find_repository_by_id(repository_id) + + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + if not collection_id: + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + collection_id_val: str | None = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + if not collection_id_val: + logger.error(f"Bedrock KB repository {repository_id} has invalid data source") + return + collection_id = collection_id_val + else: + collection_id_val = bedrock_config.get("bedrockKnowledgeDatasourceId") + if not collection_id_val: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + collection_id = collection_id_val + + if not collection_id: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + + embedding_model = repository.get("embeddingModelId") + chunk_strategy = NoneChunkingStrategy() + username = "system" + ingestion_type = IngestionType.AUTO + + logger.info( + f"Processing Bedrock KB document {s3_path} for repository {repository_id}, collection {collection_id}" + ) + else: + if not collection_id: + collection_id = pipeline_config.get("collectionId") + + embedding_model = pipeline_config.get("embeddingModel", None) + + if collection_id: + resolved = collection_service.collection_repo.find_by_id_or_name(collection_id, repository_id) + if resolved is None: + raise ValueError(f"Collection '{collection_id}' not found in repository '{repository_id}'") + collection_id = resolved.collectionId + if resolved.embeddingModel is not None: + embedding_model = resolved.embeddingModel + else: + collection_id = embedding_model + + chunk_strategy = extract_chunk_strategy(pipeline_config) + ingestion_type = IngestionType.MANUAL + + logger.info(f"Ingesting object {s3_path} for repository {repository_id}/{embedding_model}") + + collection_dict = None + if collection_id and collection_id != embedding_model: + try: + collection_obj = collection_service.get_collection( + collection_id=collection_id, repository_id=repository_id, is_admin=True, username="", user_groups=[] + ) + collection_dict = collection_obj.model_dump() if collection_obj else None + except Exception as e: + logger.warning(f"Could not fetch collection for metadata merging: {e}") + + merged_metadata = MetadataGenerator.merge_metadata( + repository=repository, + collection=collection_dict, + document_metadata=None, + for_bedrock_kb=False, + ) + + job = IngestionJob( + repository_id=repository_id, + collection_id=collection_id, + embedding_model=embedding_model, + chunk_strategy=chunk_strategy, + s3_path=s3_path, + username=username, + ingestion_type=ingestion_type, + metadata=merged_metadata, + ) + ingestion_job_repository.save(job) + ingestion_service.submit_create_job(job) + + logger.info(f"Submitted ingestion job for document {s3_path} in repository {repository_id}") + + +def handle_pipline_ingest_schedule(event: dict[str, Any], context: Any) -> None: + """Lists objects modified in the last 24 hours and submits ingestion jobs.""" + logger.debug(f"Received event: {event}") + + detail = event.get("detail", {}) + bucket = detail.get("bucket", None) + username = get_username(event) + prefix = detail.get("prefix", None) + repository_id = detail.get("repositoryId", None) + pipeline_config = detail.get("pipelineConfig", None) + embedding_model = pipeline_config.get("embeddingModel", None) + + chunk_strategy = extract_chunk_strategy(pipeline_config) + + repository = vs_repo.find_repository_by_id(repository_id) + + try: + logger.info(f"Processing request for bucket: {bucket}, prefix: {prefix}") + + twenty_four_hours_ago = utc_now() - timedelta(hours=24) + modified_keys = [] + paginator = s3.get_paginator("list_objects_v2") + + logger.info(f"Listing objects in {bucket}{prefix} modified after {twenty_four_hours_ago}") + + try: + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + if "Contents" not in page: + logger.info(f"No contents found in page for {bucket}{prefix}") + continue + for obj in page["Contents"]: + last_modified = obj["LastModified"] + if last_modified >= twenty_four_hours_ago: + logger.info(f"Found modified file: {obj['Key']} (Last Modified: {last_modified})") + modified_keys.append(obj["Key"]) + else: + logger.debug( + f"Skipping file {obj['Key']} - Last modified {last_modified}" + f" before cutoff {twenty_four_hours_ago}" + ) + except Exception as e: + logger.error(f"Error during S3 list operation: {str(e)}", exc_info=True) + raise + + merged_metadata = MetadataGenerator.merge_metadata( + repository=repository, + collection=None, + document_metadata=None, + for_bedrock_kb=False, + ) + + for key in modified_keys: + job = IngestionJob( + repository_id=repository_id, + collection_id=embedding_model, + chunk_strategy=chunk_strategy, + s3_path=f"s3://{bucket}/{key}", + username=username, + ingestion_type=IngestionType.AUTO, + metadata=merged_metadata, + ) + ingestion_job_repository.save(job) + ingestion_service.submit_create_job(job) + + logger.info(f"Found {len(modified_keys)} modified files in {bucket}{prefix}") + except Exception as e: + logger.error(f"Error listing objects: {str(e)}", exc_info=True) + raise e + + +def handle_pipeline_delete_event(event: dict[str, Any], context: Any) -> None: + """Handle pipeline document deletion for S3 ObjectRemoved events.""" + logger.debug(f"Received event: {event}") + + detail = event.get("detail", {}) + bucket = detail.get("bucket", None) + key = detail.get("key", None) + repository_id = detail.get("repositoryId", None) + collection_id = detail.get("collectionId", None) + pipeline_config = detail.get("pipelineConfig", None) + s3_path = f"s3://{bucket}/{key}" + + if not repository_id: + logger.warning("No repository_id in event, skipping deletion") + return + + repository = vs_repo.find_repository_by_id(repository_id) + if not repository: + logger.warning(f"Repository {repository_id} not found, skipping deletion") + return + + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + if not collection_id: + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + collection_id = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + else: + collection_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + + if not collection_id: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + + logger.info( + f"Processing Bedrock KB document deletion {s3_path}" + f" for repository {repository_id}, collection {collection_id}" + ) + else: + if not pipeline_config or not isinstance(pipeline_config, dict): + logger.warning("No pipeline_config in event, skipping deletion") + return + + if not collection_id: + collection_id = pipeline_config.get("collectionId") + + if collection_id: + resolved = collection_service.collection_repo.find_by_id_or_name(collection_id, repository_id) + if resolved is None: + logger.warning( + f"Collection '{collection_id}' not found in repository '{repository_id}', skipping deletion" + ) + return + collection_id = resolved.collectionId + else: + # Legacy fallback: pipelines without collectionId used embeddingModel as collection_id + embedding_model = pipeline_config.get("embeddingModel") + if embedding_model is None: + logger.warning("No collectionId or embeddingModel in pipeline_config, skipping deletion") + return + collection_id = embedding_model + + logger.info(f"Deleting object {s3_path} for repository {repository_id}/{collection_id}") + + documents = rag_document_repository.find_by_source( + repository_id=repository_id, + collection_id=collection_id, + document_source=s3_path, + join_docs=False, + ) + + if not documents: + logger.info(f"Document {s3_path} not found in tracking system, already deleted or never tracked") + return + + for rag_document in documents: + logger.info(f"Deleting tracked document {rag_document.document_id} from {s3_path}") + + ingestion_job = ingestion_job_repository.find_by_document(rag_document.document_id) + if ingestion_job is None: + ingestion_job = IngestionJob( + repository_id=repository_id, + collection_id=collection_id, + embedding_model=collection_id, + chunk_strategy=None, + s3_path=rag_document.source, + username=rag_document.username, + ingestion_type=IngestionType.AUTO, + status=IngestionStatus.DELETE_PENDING, + ) + ingestion_job_repository.save(ingestion_job) + + ingestion_service.create_delete_job(ingestion_job) + logger.info(f"Submitted deletion job for document {s3_path} in repository {repository_id}") diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py index d688da3bf..f5b418d45 100644 --- a/lambda/repository/rag_document_repo.py +++ b/lambda/repository/rag_document_repo.py @@ -199,6 +199,31 @@ def find_by_source( yield from self._yield_documents(response["Items"], join_docs=join_docs) + def find_one_by_source(self, repository_id: str, collection_id: str, source: str) -> RagDocument | None: + """Find a single document by source path. + + Args: + repository_id: Repository identifier + collection_id: Collection identifier + source: S3 source path (e.g., s3://bucket/key/file.docx) + + Returns: + First matching RagDocument if found, None otherwise + """ + try: + # Use existing find_by_source generator + docs_generator = self.find_by_source( + repository_id=repository_id, collection_id=collection_id, document_source=source, join_docs=False + ) + + # Get first document from generator + first_doc = next(docs_generator, None) + return first_doc + + except ClientError as e: + logging.error(f"Error finding document by source: {e}") + return None + def _yield_documents(self, items: list[dict], join_docs: bool) -> Generator[RagDocument]: for item in items: document = RagDocument(**item) @@ -467,7 +492,7 @@ def delete_s3_docs(self, repository_id: str, docs: list[RagDocument]) -> list[st try: logging.info(f"Removing S3 doc: {source}") self.delete_s3_object(uri=source) - except Exception as e: + except ClientError as e: logging.error(f"Failed to delete S3 object {source}: {e}") # Continue with other deletions diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index 4e92d02e1..046be2563 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -15,6 +15,8 @@ """Helper functions to parse documents for ingestion into RAG vector store.""" import logging import os +import re +import unicodedata from io import BytesIO from urllib.parse import urlparse @@ -96,8 +98,10 @@ def _extract_pdf_content(s3_object: dict) -> str: except PdfReadError as e: logger.error(f"Error reading PDF file: {e}") raise - - return "".join(page.extract_text() or "" for page in pdf_reader.pages) + raw = " ".join(page.extract_text() or "" for page in pdf_reader.pages) + raw = unicodedata.normalize("NFKC", raw) + raw = re.sub(r"[\xad\u200b\u200c\u200d\ufeff]", "", raw) + return re.sub(r"\s+", " ", raw).strip() def _extract_docx_content(s3_object: dict) -> str: diff --git a/lib/api-base/ecsCluster.ts b/lib/api-base/ecsCluster.ts index 4b70e6e00..2f367e05a 100644 --- a/lib/api-base/ecsCluster.ts +++ b/lib/api-base/ecsCluster.ts @@ -30,7 +30,6 @@ import { Ec2Service, Ec2ServiceProps, Ec2TaskDefinition, - EcsOptimizedImage, HealthCheck, Host, LinuxParameters, @@ -221,7 +220,7 @@ export class ECSCluster extends Construct { vpc: vpc.vpc, vpcSubnets: vpc.subnetSelection, instanceType: new InstanceType(ecsConfig.instanceType), - machineImage: EcsOptimizedImage.amazonLinux2023(ecsConfig.amiHardwareType), + machineImage: CodeFactory.getEcsMachineImage(config.region, ecsConfig.amiHardwareType, ecsConfig.amiId), minCapacity: ecsConfig.autoScalingConfig.minCapacity, maxCapacity: ecsConfig.autoScalingConfig.maxCapacity, groupMetrics: [GroupMetrics.all()], diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index 57d036d4c..74d1161fa 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -24,7 +24,7 @@ import { dump as yamlDump } from 'js-yaml'; import { ECSCluster, ECSTasks } from './ecsCluster'; import { BaseProps, Ec2Metadata, ECSConfig, EcsSourceType } from '../schema'; import { Vpc } from '../networking/vpc'; -import { REST_API_PATH } from '../util'; +import { REST_API_PATH, ROOT_PATH } from '../util'; import * as child_process from 'child_process'; import * as path from 'path'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; @@ -128,7 +128,8 @@ export class FastApiContainer extends Construct { // Skip tiktoken cache generation in test environment if (process.env.NODE_ENV !== 'test') { try { - child_process.execSync(`python3 scripts/cache-tiktoken-for-offline.py ${cache_dir}`, { stdio: 'inherit' }); + const scriptPath = path.join(ROOT_PATH, 'scripts', 'cache-tiktoken-for-offline.py'); + child_process.execSync(`python3 ${scriptPath} ${cache_dir}`, { stdio: 'inherit' }); } catch (error) { console.warn('Failed to generate tiktoken cache:', error); // Continue execution even if cache generation fails diff --git a/lib/core/layers/authorizer/requirements.txt b/lib/core/layers/authorizer/requirements.txt index 9c76662e4..9f044c1d9 100644 --- a/lib/core/layers/authorizer/requirements.txt +++ b/lib/core/layers/authorizer/requirements.txt @@ -1,5 +1,5 @@ # urllib3<2 // Provided by Lambda # cachetools==6.2.2 // provided by Common Layer # requests==2.32.5 // provided by Common Layer -cryptography==46.0.3 +cryptography==46.0.5 PyJWT==2.10.1 diff --git a/lib/core/layers/fastapi/requirements.txt b/lib/core/layers/fastapi/requirements.txt index c5b5b85e2..4d75a2fac 100644 --- a/lib/core/layers/fastapi/requirements.txt +++ b/lib/core/layers/fastapi/requirements.txt @@ -3,5 +3,5 @@ fastapi==0.124.2 mangum==0.19.0 pydantic==2.12.5 -cryptography==46.0.3 -starlette==0.46.2 +cryptography==46.0.5 +starlette==0.49.1 diff --git a/lib/docs/.vitepress/config.mts b/lib/docs/.vitepress/config.mts index 5c58bcc99..1db733da1 100644 --- a/lib/docs/.vitepress/config.mts +++ b/lib/docs/.vitepress/config.mts @@ -80,6 +80,7 @@ const navLinks = [ { text: 'MCP Connections: Third-party servers', link: '/config/mcp' }, { text: 'MCP Workbench: Experimentation', link: '/config/mcp-workbench' }, { text: 'Usage Analytics', link: '/config/cloudwatch' }, + { text: 'Claude Code Setup for LISA Serve', link: '/config/claude-code-setup' }, ], }, { diff --git a/lib/docs/config/claude-code-setup.md b/lib/docs/config/claude-code-setup.md new file mode 100644 index 000000000..c850c68c4 --- /dev/null +++ b/lib/docs/config/claude-code-setup.md @@ -0,0 +1,61 @@ +# Claude Code Setup for LISA Serve + +This guide explains how to configure Claude Code with LISA Serve + +### References +- [Claude Code Documentation](https://code.claude.com/docs) +- [Claude Cote LLM Gateway Configuration](https://code.claude.com/docs/en/llm-gateway) + +### Prerequisites +- LISA instance deployed and accessible +- LISA serve endpoint URL +- LISA API key (See API Key Management) +- A model deployed via LISA Serve + +### Setup Steps + +1. **Configure Claude Code Environment Variables**: + ```bash + # Set the base URL to your LISA endpoint + # Find it on cloudformation in the LISA-lisa-serve- stack in the outputs tab + export ANTHROPIC_BASE_URL=https://your-lisa-serve-endpoint.com # Typically ends in '/v2/serve' + + # You can generate an API key in the API Key Management Page on LISA's UI + export ANTHROPIC_AUTH_TOKEN=your-lisa-api-key + + # Specify your models, they must match your LISA model ids + export ANTHROPIC_MODEL=your-lisa-model-id + export ANTHROPIC_SMALL_FAST_MODEL=your-lisa-fast-model-id + + export ANTHROPIC_DEFAULT_SONNET_MODEL = your-lisa-model-id + export ANTHROPIC_DEFAULT_OPUS_MODEL = your-lisa-model-id + export ANTHROPIC_DEFAULT_HAIKU_MODEL = your-lisa-model-id + export CLAUDE_CODE_SUBAGENT_MODEL = your-lisa-model-id + + export CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS = 1 # Disable experimental beta options + + export CLAUDE_CODE_MAX_OUTPUT_TOKENS = 4192 # Adjust according to your model's requirements + export MAX_THINKING_TOKENS = 8192 # set to 0 to disable thinking + ``` + + +## Verification + +After configuration, verify your setup: + +```bash +# Start Claude Code with a simple prompt +claude "hello world" +``` + +### Testing in the VSCode extension +You have two options to test if the configuration is working +1. Open new claude code tab (This should reload the environment variables depending on your configuration) +2. Reload the vscode window + +## Troubleshooting + +### LISA Endpoint Issues +- Verify endpoint is accessible: `curl https://your-lisa-endpoint.com/health` +- Check API key is valid +- Confirm model names match LISA configuration diff --git a/lib/docs/config/model-compatibility.md b/lib/docs/config/model-compatibility.md index 8f7ea9b4c..44d0987dc 100644 --- a/lib/docs/config/model-compatibility.md +++ b/lib/docs/config/model-compatibility.md @@ -28,3 +28,4 @@ See the [deployment](/admin/deploy) section for details on how to set up the vLL how the HuggingFace containers will serve safetensor weights downloaded from the HuggingFace website, vLLM will do the same, and our configuration will allow you to serve these artifacts automatically. vLLM does not have many supported models for embeddings, but as they become available, LISA will support them as long as the vLLM container version is updated in the config.yaml file and as long as the model's safetensors can be found in S3. +- Please see the [vLLM Environment Variables Documentation](./vllm_variables.md) before getting started with vLLM models diff --git a/lib/docs/config/vllm_variables.md b/lib/docs/config/vllm_variables.md index 1e409386d..a7fc016d2 100644 --- a/lib/docs/config/vllm_variables.md +++ b/lib/docs/config/vllm_variables.md @@ -2,6 +2,7 @@ LISA Serve supports configuring vLLM model serving through environment variables. These variables allow you to control performance, memory usage, parallelization, and advanced features when deploying models with vLLM. - **NOTE:** Standard vLLM environment variables are supported and passed directly into the VLLM container. [See vLLM's documentation](https://docs.vllm.ai/en/latest/configuration/env_vars/) +- Review your ECS instance type's specifications to determine if the model you want LISA Serve to host has the proper VRAM/RAM capacity. Instances that have multiple GPUs may require the VLLM_TENSOR_PARALLEL_SIZE environment variable set to utilize all GPUs. ## Core Performance & Memory @@ -27,6 +28,7 @@ LISA Serve supports configuring vLLM model serving through environment variables | `VLLM_MAX_NUM_SEQS` | Maximum concurrent sequences | `256` | `128`, `512` | | `VLLM_ENABLE_PREFIX_CACHING` | Enable prefix caching for repeated prompts | `false` | `true` | | `VLLM_ENABLE_CHUNKED_PREFILL` | Enable chunked prefill | `false` | `true` | +| `VLLM_ASYNC_SCHEDULING` | Adds --async-scheduling for higher performance if hardware supported | `false` | `true` | ## Parallel Processing diff --git a/lib/rag/ingestion/ingestion-image/requirements.txt b/lib/rag/ingestion/ingestion-image/requirements.txt index 0fd25a55d..ccdea896d 100644 --- a/lib/rag/ingestion/ingestion-image/requirements.txt +++ b/lib/rag/ingestion/ingestion-image/requirements.txt @@ -8,14 +8,14 @@ numpy>=2.1.0 # Standardized to boto3==1.40.76 for compatibility across all components boto3==1.40.76 -aiohttp==3.13.2 +aiohttp==3.13.3 click==8.3.1 -cryptography==46.0.3 +cryptography==46.0.5 fastapi_utils==0.8.0 fastapi==0.124.2 gunicorn==23.0.0 langchain-community==0.4.1 -langchain-core==1.2.7 +langchain-core==1.2.14 langchain-text-splitters==1.1.0 loguru==0.7.3 mangum==0.19.0 @@ -25,13 +25,12 @@ prisma==0.15.0 psycopg2-binary==2.9.11 pydantic==2.12.5 PyJWT==2.10.1 -pynacl==1.6.1 -pypdf==6.4.1 +pynacl==1.6.2 +pypdf==6.6.2 lxml==6.0.2 python-docx==1.2.0 requests-aws4auth==1.3.1 requests==2.32.5 -text-generation==0.7.0 # ASGI Server - Version constrained by litellm[proxy]==1.81.3 in rest-api # Standardized to 0.38.0 for compatibility across all components uvicorn==0.38.0 diff --git a/lib/rag/ingestion/ingestion-job-construct.ts b/lib/rag/ingestion/ingestion-job-construct.ts index 4a546c1fa..f7e2940a7 100644 --- a/lib/rag/ingestion/ingestion-job-construct.ts +++ b/lib/rag/ingestion/ingestion-job-construct.ts @@ -170,6 +170,21 @@ export class IngestionJobConstruct extends Construct { 'BUILD_DIR': buildDirName }); + // Create execution role for ECS tasks to pull images from ECR and write logs + const executionRole = new iam.Role(this, 'BatchJobExecutionRole', { + roleName: `${config.deploymentName}-${config.deploymentStage}-batch-exec-role-${hash}`, + assumedBy: new iam.ServicePrincipal('ecs-tasks.amazonaws.com'), + description: 'Execution role for ECS Batch ingestion tasks', + }); + + // Add ECR permissions for pulling container images + executionRole.addManagedPolicy( + iam.ManagedPolicy.fromAwsManagedPolicyName('service-role/AmazonECSTaskExecutionRolePolicy') + ); + + // Grant CloudWatch Logs permissions + logGroup.grantWrite(executionRole); + // AWS Batch job definition specifying container configuration const jobDefinition = new batch.EcsJobDefinition(this, 'IngestionJobDefinition', { jobDefinitionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-job-${hash}`, @@ -180,6 +195,7 @@ export class IngestionJobConstruct extends Construct { cpu: 2, command: ['-m', 'repository.pipeline_ingestion', 'Ref::ACTION', 'Ref::DOCUMENT_ID'], jobRole: lambdaRole, + executionRole: executionRole, logging: new ecs.AwsLogDriver({ streamPrefix: 'batch-job', logGroup: logGroup @@ -200,9 +216,9 @@ export class IngestionJobConstruct extends Construct { // Lambda function for handling scheduled document ingestion - using container image const handlePipelineIngestScheduleLambda = new lambda.Function(this, 'handlePipelineIngestSchedule', { - functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-schedule-${hash}`, + functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-schedule`, runtime: getPythonRuntime(), - handler: 'repository.pipeline_ingest_documents.handle_pipline_ingest_schedule', + handler: 'repository.pipeline_ingest_handlers.handle_pipline_ingest_schedule', code: lambda.Code.fromAsset('./lambda'), timeout: Duration.seconds(60), memorySize: 256, @@ -228,9 +244,9 @@ export class IngestionJobConstruct extends Construct { // Lambda function for handling S3 event-based document ingestion - using container image const handlePipelineIngestEvent = new lambda.Function(this, 'handlePipelineIngestEvent', { - functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-event-${hash}`, + functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-ingest-event`, runtime: getPythonRuntime(), - handler: 'repository.pipeline_ingest_documents.handle_pipeline_ingest_event', + handler: 'repository.pipeline_ingest_handlers.handle_pipeline_ingest_event', code: lambda.Code.fromAsset('./lambda'), timeout: Duration.seconds(60), memorySize: 256, @@ -256,9 +272,9 @@ export class IngestionJobConstruct extends Construct { // Lambda function for handling document deletion events - using container image const handlePipelineDeleteEvent = new lambda.Function(this, 'handlePipelineDeleteEvent', { - functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-delete-event-${hash}`, + functionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-delete-event`, runtime: getPythonRuntime(), - handler: 'repository.pipeline_ingest_documents.handle_pipeline_delete_event', + handler: 'repository.pipeline_ingest_handlers.handle_pipeline_delete_event', code: lambda.Code.fromAsset('./lambda'), timeout: Duration.seconds(60), memorySize: 256, diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index 8906929e6..58d5d6142 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -12,7 +12,7 @@ # Heavy document processing happens in container: lib/rag/ingestion/ingestion-image/ # Core langchain package for vector store base classes -langchain-core==1.2.7 +langchain-core==1.2.14 # langchain-community for OpenSearchVectorSearch and PGVector classes # This package has many optional dependencies - we only need the core vectorstore functionality @@ -29,5 +29,4 @@ numpy==2.1.0 # psycopg2-binary provided by Common Layer # boto3/botocore provided by Lambda runtime - -urllib3==2.6.1 +urllib3==2.6.3 diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index b7d2c1113..0b25e1e37 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -560,6 +560,8 @@ export type TaskDefinition = z.infer; */ export const EcsBaseConfigSchema = z.object({ amiHardwareType: z.enum(AmiHardwareType).describe('Name of the model.'), + amiId: z.string().optional() + .describe('Optional AMI ID for a custom ECS machine image (e.g. ami-0123456789abcdef0). If not provided, the default ECS-optimized AMI will be used (AL2 for ADC/iso regions, AL2023 otherwise).'), autoScalingConfig: AutoScalingConfigSchema.describe('Configuration for auto scaling settings.'), buildArgs: z.record(z.string(), z.string()).optional() .describe('Optional build args to be applied when creating the task container if containerConfig.image.type is ASSET'), diff --git a/lib/serve/ecs-model/embedding/instructor/Dockerfile b/lib/serve/ecs-model/embedding/instructor/Dockerfile index bbde17aa5..ff18457c3 100644 --- a/lib/serve/ecs-model/embedding/instructor/Dockerfile +++ b/lib/serve/ecs-model/embedding/instructor/Dockerfile @@ -1,6 +1,8 @@ ARG BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim FROM ${BASE_IMAGE} +RUN apt-get update -y && apt-get upgrade -y && rm -rf /var/lib/apt/lists/* + # Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) RUN mkdir -p /etc/ssh && \ echo "" >> /etc/ssh/ssh_config && \ diff --git a/lib/serve/ecs-model/embedding/tei/Dockerfile b/lib/serve/ecs-model/embedding/tei/Dockerfile index 1409cba4c..295e2f88e 100644 --- a/lib/serve/ecs-model/embedding/tei/Dockerfile +++ b/lib/serve/ecs-model/embedding/tei/Dockerfile @@ -20,13 +20,14 @@ RUN mkdir -p /etc/ssh && \ ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 -RUN apt update -y && apt install -y wget rsync && \ +RUN apt-get update -y && apt-get upgrade -y && apt-get install -y wget rsync && \ wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ fi && \ - apt install -y ./mount-s3.deb && \ - rm mount-s3.deb + apt-get install -y ./mount-s3.deb && \ + rm mount-s3.deb && \ + rm -rf /var/lib/apt/lists/* COPY src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh diff --git a/lib/serve/ecs-model/textgen/tgi/Dockerfile b/lib/serve/ecs-model/textgen/tgi/Dockerfile index a6ca7a891..4370882b4 100644 --- a/lib/serve/ecs-model/textgen/tgi/Dockerfile +++ b/lib/serve/ecs-model/textgen/tgi/Dockerfile @@ -19,10 +19,14 @@ RUN mkdir -p /etc/ssh && \ ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL -RUN apt update -y && apt install -y wget rsync && \ - wget ${MOUNTS3_DEB_URL} && \ - apt install -y ./mount-s3.deb && \ - rm mount-s3.deb +ARG MOUNTS3_DEB_SHA256 +RUN apt-get update -y && apt-get upgrade -y && apt-get install -y wget rsync && \ + wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ + if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ + echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ + fi && \ + apt-get install -y ./mount-s3.deb && \ + rm mount-s3.deb && rm -rf /var/lib/apt/lists/* COPY src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh diff --git a/lib/serve/ecs-model/vllm/Dockerfile b/lib/serve/ecs-model/vllm/Dockerfile index ad3d91d06..cab16e0e5 100644 --- a/lib/serve/ecs-model/vllm/Dockerfile +++ b/lib/serve/ecs-model/vllm/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.13-gpu-py312 +ARG BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.15-gpu-py312-ec2 FROM ${BASE_IMAGE} # Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) @@ -21,11 +21,17 @@ RUN mkdir -p /etc/ssh && \ ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 RUN if command -v apt-get >/dev/null 2>&1; then \ - apt update -y && apt install -y wget rsync && \ - wget ${MOUNTS3_DEB_URL} && apt install -y ./mount-s3.deb && \ + apt-get update -y && apt-get upgrade -y && apt-get install -y wget rsync && \ + wget ${MOUNTS3_DEB_URL} -O mount-s3.deb && \ + if [ -n "${MOUNTS3_DEB_SHA256}" ]; then \ + echo "${MOUNTS3_DEB_SHA256} mount-s3.deb" | sha256sum -c; \ + fi && \ + apt-get install -y ./mount-s3.deb && \ rm mount-s3.deb && rm -rf /var/lib/apt/lists/*; \ elif command -v yum >/dev/null 2>&1; then \ - yum install -y wget rsync && wget ${MOUNTS3_DEB_URL} && \ + MOUNTS3_RPM_URL=$(echo ${MOUNTS3_DEB_URL} | sed 's/\.deb$/.rpm/') && \ + yum update -y && yum install -y wget rsync && \ + wget ${MOUNTS3_RPM_URL} -O mount-s3.rpm && \ yum install -y ./mount-s3.rpm && yum clean all && rm mount-s3.rpm; \ elif command -v apk >/dev/null 2>&1; then \ apk add --no-cache wget rsync && wget ${MOUNTS3_DEB_URL} && \ diff --git a/lib/serve/ecs-model/vllm/src/entrypoint.sh b/lib/serve/ecs-model/vllm/src/entrypoint.sh index b17f1ae27..4203917ac 100644 --- a/lib/serve/ecs-model/vllm/src/entrypoint.sh +++ b/lib/serve/ecs-model/vllm/src/entrypoint.sh @@ -43,9 +43,10 @@ declare -a vars=("S3_BUCKET_MODELS" "LOCAL_MODEL_PATH" "MODEL_NAME" "S3_MOUNT_PO # VLLM_BLOCK_SIZE - Memory block size (8/16/32) # VLLM_SEED - Random seed for reproducibility # VLLM_FLOAT32_MATMUL_PRECISION - Float32 matmul precision (ieee/tf32) +# VLLM_ASYNC_SCHEDULING - Adds --async-scheduling for higher performance # # ATTENTION & BACKENDS: -# VLLM_ATTENTION_BACKEND - Attention backend (FLASH_ATTN/XFORMERS/ROCM_FLASH/TORCH_SDPA/FLASHINFER/etc) +# VLLM_ATTENTION_BACKEND - Attention backend, read natively by vLLM (FLASH_ATTN/FLASHINFER/XFORMERS) # VLLM_ENABLE_PREFIX_CACHING - Enable prefix caching (true/false, default: true) # VLLM_ENABLE_CHUNKED_PREFILL - Enable chunked prefill (true/false, default: true) # VLLM_MAX_CHUNKED_PREFILL_TOKENS - Max tokens per prefill chunk @@ -150,14 +151,17 @@ TOTAL_MEM_GB=$((TOTAL_MEM_KB / 1024 / 1024)) echo "Total system memory: ${TOTAL_MEM_GB}GB" # Check GPU availability +GPU_MEM_GB=0 if command -v nvidia-smi &> /dev/null; then - GPU_INFO=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -1) - GPU_MEM_MB=${GPU_INFO} - GPU_MEM_GB=$((GPU_MEM_MB / 1024)) - echo "GPU memory available: ${GPU_MEM_GB}GB" + GPU_INFO=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '[:space:]') + if [[ "${GPU_INFO}" =~ ^[0-9]+$ ]]; then + GPU_MEM_GB=$((GPU_INFO / 1024)) + echo "GPU memory available: ${GPU_MEM_GB}GB" + else + echo "Warning: nvidia-smi returned unexpected output: '${GPU_INFO}', assuming no GPU" + fi else echo "No GPU detected or nvidia-smi not available" - GPU_MEM_GB=0 fi # Memory warnings and recommendations @@ -172,11 +176,12 @@ fi # Validate tensor parallel configuration if [[ -n "${VLLM_TENSOR_PARALLEL_SIZE}" ]] && [[ ${VLLM_TENSOR_PARALLEL_SIZE} -gt 1 ]]; then - if [[ ${GPU_MEM_GB} -eq 0 ]]; then - echo "Error: Tensor parallelism requires GPU but no GPU detected" - exit 1 + GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l || echo 0) + if [[ ${GPU_COUNT} -eq 0 ]]; then + echo "Warning: Tensor parallelism requested (${VLLM_TENSOR_PARALLEL_SIZE}) but no GPUs detected - proceeding anyway" + else + echo "Using tensor parallelism with ${VLLM_TENSOR_PARALLEL_SIZE} GPUs (${GPU_COUNT} detected)" fi - echo "Using tensor parallelism with ${VLLM_TENSOR_PARALLEL_SIZE} GPUs" fi # Start the webserver @@ -233,6 +238,9 @@ if [[ -n "${VLLM_TENSOR_PARALLEL_SIZE}" ]]; then echo " --tensor-parallel-size ${VLLM_TENSOR_PARALLEL_SIZE}" fi +# Attention backend override - read natively by vLLM as env var, no CLI arg needed +# Valid values: FLASH_ATTN, FLASHINFER, XFORMERS + # Quantization method if [[ -n "${VLLM_QUANTIZATION}" ]]; then ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --quantization ${VLLM_QUANTIZATION}" @@ -259,9 +267,23 @@ if [[ -n "${VLLM_TOOL_CALL_PARSER}" ]]; then echo " --tool-call-parser ${VLLM_TOOL_CALL_PARSER}" fi +if [[ "${VLLM_ASYNC_SCHEDULING}" == "true" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --async-scheduling" + echo " --async-scheduling" +fi + +# Max parallel loading workers (avoids RAM OOM with tensor parallelism + large models) +if [[ -n "${VLLM_MAX_PARALLEL_LOADING_WORKERS}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --max-parallel-loading-workers ${VLLM_MAX_PARALLEL_LOADING_WORKERS}" + echo " --max-parallel-loading-workers ${VLLM_MAX_PARALLEL_LOADING_WORKERS}" +fi + echo "Starting vLLM with args: ${ADDITIONAL_ARGS}" -echo "vLLM environment variables:" -env | grep -E "^(VLLM_|MAX_TOTAL_TOKENS)=" || echo "No vLLM environment variables set" +# Print all VLLM_ environment variables at startup +echo "=== VLLM Environment Variables ===" +env | grep -E "^VLLM_" || echo "No VLLM_ environment variables set" +echo "===================================" + python3 -m vllm.entrypoints.openai.api_server \ --model ${LOCAL_MODEL_PATH} \ diff --git a/lib/serve/mcp-workbench/pyproject.toml b/lib/serve/mcp-workbench/pyproject.toml index de687c4f6..73ad83aa5 100644 --- a/lib/serve/mcp-workbench/pyproject.toml +++ b/lib/serve/mcp-workbench/pyproject.toml @@ -15,9 +15,9 @@ dependencies = [ "click==8.3.1", "starlette>=0.40.0,<0.51.0", "uvicorn>=0.31.1,<0.32.0", - "aiohttp==3.13.2", + "aiohttp==3.13.3", "boto3==1.40.76", - "cryptography==46.0.3", + "cryptography==46.0.5", "gunicorn>=23.0.0,<24.0.0", "pydantic>=2.5.0,<3.0.0", "PyJWT>=2.10.1,<3.0.0", diff --git a/lib/serve/rest-api/.coveragerc b/lib/serve/rest-api/.coveragerc index 5d1b54764..10210bfe9 100644 --- a/lib/serve/rest-api/.coveragerc +++ b/lib/serve/rest-api/.coveragerc @@ -6,13 +6,8 @@ omit = */src/*/*/*/__init__.py */src/*/*/*/*/__init__.py # Exclude FastAPI endpoint wrappers (thin wrappers around handlers) - */src/api/endpoints/v1/*.py */src/api/endpoints/v2/*.py # Exclude main application file (requires full integration test) */src/main.py - # Exclude model adapters (require actual model endpoints) - */src/lisa_serve/ecs/textgen/*.py - */src/lisa_serve/ecs/embedding/*.py - */src/lisa_serve/base/*.py # Exclude routes (thin wrappers) */src/api/routes.py diff --git a/lib/serve/rest-api/src/api/endpoints/v1/__init__.py b/lib/serve/rest-api/src/api/endpoints/v1/__init__.py deleted file mode 100644 index 509c4c9e2..000000000 --- a/lib/serve/rest-api/src/api/endpoints/v1/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" diff --git a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py deleted file mode 100644 index 1d7932b76..000000000 --- a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Embedding routes.""" - -import logging - -from fastapi import APIRouter -from fastapi.responses import JSONResponse -from starlette.status import HTTP_200_OK - -from ....handlers.embeddings import handle_embeddings -from ....utils.resources import EmbeddingsRequest, RestApiResource - -logger = logging.getLogger(__name__) - -router = APIRouter() - - -@router.post(f"/{RestApiResource.EMBEDDINGS}") -async def embeddings(request: EmbeddingsRequest) -> JSONResponse: - """Text embeddings.""" - response = await handle_embeddings(request.dict()) - - return JSONResponse(content=response, status_code=HTTP_200_OK) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/generation.py b/lib/serve/rest-api/src/api/endpoints/v1/generation.py deleted file mode 100644 index 85ed986cd..000000000 --- a/lib/serve/rest-api/src/api/endpoints/v1/generation.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generation routes.""" - -import logging - -from fastapi import APIRouter -from fastapi.responses import JSONResponse, StreamingResponse -from starlette.status import HTTP_200_OK - -from ....handlers.generation import handle_generate, handle_generate_stream, handle_openai_generate_stream -from ....utils.resources import ( - GenerateRequest, - GenerateStreamRequest, - OpenAIChatCompletionsRequest, - OpenAICompletionsRequest, - RestApiResource, -) - -logger = logging.getLogger(__name__) - -router = APIRouter() - - -@router.post(f"/{RestApiResource.GENERATE}") -async def generate(request: GenerateRequest) -> JSONResponse: - """Text generation.""" - response = await handle_generate(request.dict()) - - return JSONResponse(content=response, status_code=HTTP_200_OK) - - -@router.post(f"/{RestApiResource.GENERATE_STREAM}") -async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: - """Text generation with streaming.""" - return StreamingResponse( - handle_generate_stream(request.dict()), - media_type="text/event-stream", - ) - - -@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS}") -async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsRequest) -> StreamingResponse: - """Text generation with streaming.""" - return StreamingResponse( - handle_openai_generate_stream(request.dict()), - media_type="text/event-stream", - ) - - -@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS}") -async def openai_completion_generate_stream(request: OpenAICompletionsRequest) -> StreamingResponse: - """Text generation with streaming.""" - return StreamingResponse( - handle_openai_generate_stream(request.dict(), is_text_completion=True), - media_type="text/event-stream", - ) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/models.py b/lib/serve/rest-api/src/api/endpoints/v1/models.py deleted file mode 100644 index 3fdef4ed7..000000000 --- a/lib/serve/rest-api/src/api/endpoints/v1/models.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model information routes.""" - -import logging - -from fastapi import APIRouter, Query -from fastapi.responses import JSONResponse -from starlette.status import HTTP_200_OK - -from ....handlers.models import ( - handle_describe_model, - handle_describe_models, - handle_list_models, - handle_openai_list_models, -) -from ....utils.resources import ModelType, RestApiResource - -logger = logging.getLogger(__name__) - -router = APIRouter() - - -@router.get(f"/{RestApiResource.DESCRIBE_MODEL}") -async def describe_model( - provider: str = Query( - None, - description="Model provider name.", - alias="provider", - ), - model_name: str = Query( - None, - description="Name of model.", - alias="modelName", - ), -) -> JSONResponse: - """Describe model by provider and model name.""" - response = await handle_describe_model(provider, model_name) - - return JSONResponse(content=response, status_code=HTTP_200_OK) - - -@router.get(f"/{RestApiResource.DESCRIBE_MODELS}") -async def describe_models( - model_types: list[ModelType] | None = Query( - None, - description="The types of models to list. If not provided, all types will be listed.", - alias="modelTypes", - ), -) -> JSONResponse: - """Describe models by model type.""" - if model_types is None: - model_types = list(ModelType) - - response = await handle_describe_models(model_types) - - return JSONResponse(content=response, status_code=HTTP_200_OK) - - -@router.get(f"/{RestApiResource.LIST_MODELS}") -async def list_models( - model_types: list[ModelType] | None = Query( - None, - description="The types of models to list. If not provided, all types will be listed.", - alias="modelTypes", - ), -) -> JSONResponse: - """List models by model type.""" - if model_types is None: - model_types = list(ModelType) - - response = await handle_list_models(model_types) - - return JSONResponse(content=response, status_code=HTTP_200_OK) - - -@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS}") -async def openai_list_models() -> JSONResponse: - """List models for OpenAI Compatibility. Only returns TEXTGEN models.""" - response = await handle_openai_list_models() - return JSONResponse(content=response, status_code=HTTP_200_OK) diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index ba3b581bb..58e206034 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -17,8 +17,10 @@ import json import logging import os +import time import uuid from collections.abc import Iterator +from typing import Any, cast import boto3 from auth import extract_user_groups_from_jwt @@ -54,6 +56,14 @@ router = APIRouter() +# Model info cache with TTL (Time To Live) +# Cache structure: {model_id: {"data": model_info, "timestamp": cache_time}} +_model_info_cache: dict[str, dict[str, Any]] = {} +# Cache TTL in seconds (default: 5 minutes) +# Stale entries are automatically refreshed when accessed after TTL expires +# This ensures deleted/recreated models get fresh data within 5 minutes +MODEL_INFO_CACHE_TTL = int(os.environ.get("MODEL_INFO_CACHE_TTL", "300")) + def _generate_presigned_video_url(key: str, content_type: str = "video/mp4") -> str: """Generate a presigned URL for video content stored in S3.""" @@ -159,6 +169,85 @@ def handle_guardrail_violation_response( return None +def invalidate_model_cache(model_id: str | None = None) -> None: + """ + Manually invalidate model info cache. + + Note: This function is available for manual/programmatic cache clearing but is not + automatically triggered. The cache relies on TTL expiration for normal operation. + + Args: + model_id: Specific model to invalidate. If None, clears entire cache. + """ + if model_id is None: + _model_info_cache.clear() + logger.info("Cleared entire model info cache") + elif model_id in _model_info_cache: + del _model_info_cache[model_id] + logger.info(f"Invalidated cache for model {model_id}") + + +def get_model_info(model_id: str, use_cache: bool = True) -> dict | None: + """ + Get model information from LiteLLM for a given model ID. + + Uses a TTL-based cache to reduce API calls while ensuring deleted/recreated + models are eventually refreshed. + + Args: + model_id: User-defined model ID (model_name in LiteLLM) + use_cache: Whether to use cached data (default True). Set False to force refresh. + + Returns: + Model info dict with litellm_params, or None if not found + """ + current_time = time.time() + + # Check cache first if enabled + if use_cache and model_id in _model_info_cache: + cache_entry = _model_info_cache[model_id] + cache_age = current_time - cache_entry["timestamp"] + + # Return cached data if still fresh + if cache_age < MODEL_INFO_CACHE_TTL: + logger.debug(f"Cache hit for model {model_id} (age: {cache_age:.1f}s)") + model_info = cast(dict[Any, Any], cache_entry["data"]) + return model_info + else: + logger.debug(f"Cache expired for model {model_id} (age: {cache_age:.1f}s)") + + # Cache miss or expired - fetch from LiteLLM + try: + headers = {"Authorization": f"Bearer {LITELLM_KEY}"} + response = requests_request( + method="GET", + url=f"{LITELLM_URL}/model/info", + headers=headers, + timeout=2, + ) + + if response.status_code == HTTP_200_OK: + all_models = response.json().get("data", []) + # Filter to find the specific model by model_name + for model in all_models: + if model.get("model_name") == model_id: + model_info = cast(dict[Any, Any], model) + # Update cache + _model_info_cache[model_id] = {"data": model_info, "timestamp": current_time} + logger.debug(f"Cached model info for {model_id}") + return model_info + + # Model not found - remove from cache if present + if model_id in _model_info_cache: + logger.info(f"Model {model_id} no longer exists, removing from cache") + del _model_info_cache[model_id] + + except Exception as e: + logger.error(f"Failed to get model info for {model_id}: {e}") + + return None + + def generate_response(iterator: Iterator[str | bytes]) -> Iterator[str]: """For streaming responses, generate strings instead of bytes objects so that clients recognize the LLM output.""" for line in iterator: @@ -380,25 +469,34 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: logger.error(f"Invalid JSON in request body: {e}") raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Invalid JSON in request body") + # Get model info from LiteLLM to determine the actual model provider path + model_id = params.get("model") + model_name = None # The actual provider/model path (e.g., "bedrock/us.anthropic.claude...") + if model_id: + model_info = get_model_info(model_id) + if model_info: + model_name = model_info.get("litellm_params", {}).get("model") + logger.debug(f"model_id: {model_id}, model_name: {model_name}") + # Apply guardrails BEFORE sending to LiteLLM for chat/completions requests # This adds guardrail configuration to the request so LiteLLM enforces them is_chat_completion = is_chat_route(api_path) if is_chat_completion: - model_id = params.get("model") if model_id and jwt_data: await apply_guardrails_to_request(params, model_id, jwt_data) # Validate and cap max_tokens if needed for Claude Code requests if is_anthropic_route(api_path): - model_id = params.get("model") - - # Check for anthropic specific headers + # Check for anthropic specific headers and reset the max token parameter to None + # so LiteLLM handles the max_token value. Only if it's not an Anthropic model if model_id and "anthropic-beta" in headers and "anthropic-version" in headers: - # reset max token parameter to null so LiteLLM handles the max_token value - if "max_tokens" in params: - params["max_tokens"] = None - if "max_completion_tokens" in params: - params["max_completion_tokens"] = None + + # Only nullify max_tokens if the model is NOT an Anthropic model + if model_name and ".anthropic" not in model_name: + if "max_tokens" in params: + params["max_tokens"] = None + if "max_completion_tokens" in params: + params["max_completion_tokens"] = None is_streaming = params.get("stream", False) if is_streaming: diff --git a/lib/serve/rest-api/src/handlers/embeddings.py b/lib/serve/rest-api/src/handlers/embeddings.py deleted file mode 100644 index f6fcdde43..000000000 --- a/lib/serve/rest-api/src/handlers/embeddings.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Embedding route handlers.""" -import logging -from typing import Any - -from utils.request_utils import RegistryProtocol, validate_and_prepare_llm_request -from utils.resources import RestApiResource - -logger = logging.getLogger(__name__) - - -async def handle_embeddings(request_data: dict[str, Any], registry: RegistryProtocol | None = None) -> dict[str, Any]: - """Handle for embeddings endpoint. - - Parameters - ---------- - request_data : dict[str, Any] - Request data - registry : RegistryProtocol | None - Optional registry for dependency injection (testing) - - Returns - ------- - dict[str, Any] - Embeddings response - """ - model, model_kwargs, text = await validate_and_prepare_llm_request( - request_data, RestApiResource.EMBEDDINGS, registry - ) - response = await model.embed_query(text=text, model_kwargs=model_kwargs) - - return response.dict() # type: ignore diff --git a/lib/serve/rest-api/src/handlers/generation.py b/lib/serve/rest-api/src/handlers/generation.py deleted file mode 100644 index a796c98b0..000000000 --- a/lib/serve/rest-api/src/handlers/generation.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generation route handlers - refactored for testability.""" -import json -import logging -from collections.abc import AsyncGenerator -from typing import Any - -from services.text_processing import ( - map_openai_params_to_lisa, - parse_model_provider_from_string, - render_context_from_messages, -) -from utils.request_utils import ( - handle_stream_exceptions, - RegistryProtocol, - validate_and_prepare_llm_request, -) -from utils.resources import RestApiResource - -logger = logging.getLogger(__name__) - - -async def handle_generate(request_data: dict[str, Any], registry: RegistryProtocol | None = None) -> dict[str, Any]: - """Handle for generate endpoint. - - Parameters - ---------- - request_data : dict[str, Any] - Request data - registry : RegistryProtocol | None - Optional registry for dependency injection (testing) - - Returns - ------- - dict[str, Any] - Generation response - """ - model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE, registry) - try: - response = await model.generate(text=text, model_kwargs=model_kwargs) - return response.dict() # type: ignore - except Exception as e: - logger.error(f"Model generation failed: {e}") - raise - - -@handle_stream_exceptions -async def handle_generate_stream( - request_data: dict[str, Any], registry: RegistryProtocol | None = None -) -> AsyncGenerator[str]: - """Handle for generate_stream endpoint. - - Parameters - ---------- - request_data : dict[str, Any] - Request data - registry : RegistryProtocol | None - Optional registry for dependency injection (testing) - - Yields - ------ - str - Streaming response chunks - """ - model, model_kwargs, text = await validate_and_prepare_llm_request( - request_data, RestApiResource.GENERATE_STREAM, registry - ) - async for response in model.generate_stream(text=text, model_kwargs=model_kwargs): - yield f"data:{json.dumps(response.dict(exclude_none=True))}\n\n" - - -@handle_stream_exceptions -async def handle_openai_generate_stream( - request_data: dict[str, Any], is_text_completion: bool = False, registry: RegistryProtocol | None = None -) -> AsyncGenerator[str]: - """Handle for openai_generate_stream endpoint. - - Parameters - ---------- - request_data : dict[str, Any] - Request data - is_text_completion : bool - Whether this is a text completion request - registry : RegistryProtocol | None - Optional registry for dependency injection (testing) - - Yields - ------ - str - Streaming response chunks - """ - # Map OpenAI parameters to LISA parameters - mapped_kwargs = map_openai_params_to_lisa(request_data) - - # Extract text based on completion type - if is_text_completion: - text = request_data["prompt"] # text is already a string - else: - text = render_context_from_messages(request_data["messages"]) # convert list to string - - # Parse model and provider - model_name, provider = parse_model_provider_from_string(request_data["model"]) - - # Build LISA request - lisa_request_data = { - "modelName": model_name, - "provider": provider, - "text": text, - "streaming": request_data.get("stream", False), - "modelKwargs": mapped_kwargs, - } - - model, model_kwargs, text = await validate_and_prepare_llm_request( - lisa_request_data, RestApiResource.GENERATE_STREAM, registry - ) - - async for response in model.openai_generate_stream( - text=text, model_kwargs=model_kwargs, is_text_completion=is_text_completion - ): - yield f"data:{json.dumps(response.dict(exclude_none=True))}\n\n" - - if is_text_completion: - yield "data: [DONE]\n\n" - - -# Keep backward compatibility - these are now just aliases to the service functions -render_context = render_context_from_messages -parse_model_provider_names = parse_model_provider_from_string diff --git a/lib/serve/rest-api/src/handlers/models.py b/lib/serve/rest-api/src/handlers/models.py deleted file mode 100644 index bc8fcafb5..000000000 --- a/lib/serve/rest-api/src/handlers/models.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model route handlers - refactored for testability.""" - -import logging -from typing import Any, DefaultDict - -from fastapi import HTTPException -from services.model_service import ModelService -from starlette.status import HTTP_404_NOT_FOUND -from utils.cache_manager import get_registered_models_cache -from utils.resources import ModelType - -logger = logging.getLogger(__name__) - - -def _get_model_service() -> ModelService: - """Factory function to create ModelService with current cache. - - This allows for dependency injection in tests. - """ - return ModelService(get_registered_models_cache()) - - -async def handle_list_models( - model_types: list[ModelType], model_service: ModelService | None = None -) -> dict[ModelType, dict[str, list[str]]]: - """Handle for list_models endpoint. - - Parameters - ---------- - model_types : List[ModelType] - Model types to list - model_service : ModelService | None - Optional model service for dependency injection (testing) - - Returns - ------- - Dict[ModelType, Dict[str, List[str]]] - List of model names by model type and model provider - """ - service = model_service or _get_model_service() - return service.list_models(model_types) - - -async def handle_openai_list_models(model_service: ModelService | None = None) -> dict[str, Any]: - """Handle for list_models endpoint. - - Parameters - ---------- - model_service : ModelService | None - Optional model service for dependency injection (testing) - - Returns - ------- - Dict[str, Any] - OpenAI-compatible response object to list Models - """ - service = model_service or _get_model_service() - return service.list_models_openai_format() - - -async def handle_describe_model( - provider: str, model_name: str, model_service: ModelService | None = None -) -> dict[str, Any]: - """Handle for describe_model endpoint. - - Parameters - ---------- - provider : str - Model provider name - model_name : str - Model name - model_service : ModelService | None - Optional model service for dependency injection (testing) - - Returns - ------- - Dict[str, Any] - Model metadata - - Raises - ------ - HTTPException - If model metadata not found - """ - service = model_service or _get_model_service() - metadata = service.get_model_metadata(provider, model_name) - - if not metadata: - error_message = f"Metadata for provider {provider} and model {model_name} not found." - logger.error(error_message, extra={"event": "handle_describe_model", "status": "ERROR"}) - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=error_message) - - return metadata - - -async def handle_describe_models( - model_types: list[ModelType], model_service: ModelService | None = None -) -> DefaultDict[str, DefaultDict[str, dict[str, Any]]]: - """Handle for describe_models endpoint. - - Parameters - ---------- - model_types : List[ModelType] - Model types to list - model_service : ModelService | None - Optional model service for dependency injection (testing) - - Returns - ------- - DefaultDict[str, DefaultDict[str, Dict[str, Any]]] - Model metadata by model type, model provider, and model name - """ - service = model_service or _get_model_service() - return service.describe_models(model_types) diff --git a/lib/serve/rest-api/src/lisa_serve/__init__.py b/lib/serve/rest-api/src/lisa_serve/__init__.py index 852a6cad8..63d616dcf 100644 --- a/lib/serve/rest-api/src/lisa_serve/__init__.py +++ b/lib/serve/rest-api/src/lisa_serve/__init__.py @@ -18,8 +18,6 @@ from loguru import logger -from .ecs import * # noqa: F403,F401 - # Configure custom logger logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") diff --git a/lib/serve/rest-api/src/lisa_serve/base/__init__.py b/lib/serve/rest-api/src/lisa_serve/base/__init__.py deleted file mode 100644 index 034d2b609..000000000 --- a/lib/serve/rest-api/src/lisa_serve/base/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" -# flake8: noqa -from .base import * diff --git a/lib/serve/rest-api/src/lisa_serve/base/base.py b/lib/serve/rest-api/src/lisa_serve/base/base.py deleted file mode 100644 index 964b8de99..000000000 --- a/lib/serve/rest-api/src/lisa_serve/base/base.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base model adapters and responses.""" -import re -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from typing import Any - -from pydantic import BaseModel, Field - -############# -# RESPONSES # -############# - - -class EmbedQueryResponse(BaseModel): - """Response for embed_query method.""" - - embeddings: list[list[float]] = Field(..., description="Batch of text embeddings.") - - -class GenerateResponse(BaseModel): - """Response for generate method.""" - - generatedText: str = Field(..., description="Generated text.") - generatedTokens: int | None = Field(..., description="Number of generated tokens.") - finishReason: str | None = Field(None, description="Reason for finishing text generation.") - - -class Token(BaseModel): - """Token for generate_stream method.""" - - text: str = Field(..., description="Token text.") - special: bool | None = Field(None, description="Whether token is a special token.") - - -class GenerateStreamResponse(BaseModel): - """Response for generate_stream method.""" - - token: Token - generatedTokens: int | None = Field(..., description="Number of generated tokens.") - finishReason: str | None = Field(None, description="Reason for finishing text generation.") - - -class OpenAIChatCompletionsDelta(BaseModel): - """Token content from Chat Completions endpoint.""" - - content: str = Field(..., description="The contents of the chunk message.") - role: str = Field(..., description="The role of the author of this message.") - - -class OpenAIChatCompletionsChoice(BaseModel): - """Text choice object from Chat Completions endpoint.""" - - delta: OpenAIChatCompletionsDelta = Field( - ..., description="A chat completion delta generated by streamed model responses." - ) - finish_reason: str | None = Field(..., description="The reason the model stopped generating tokens.") - index: int = Field(..., description="The index of the choice in the list of choices.") - - -class OpenAIChatCompletionsResponse(BaseModel): - """Response from Chat Completions endpoint.""" - - id: str = Field(..., description="A unique identifier for the chat completion. Each chunk has the same ID.") - choices: list[OpenAIChatCompletionsChoice] = Field( - ..., description="A list of chat completion choices. Can be more than one if n is greater than 1." - ) - created: int = Field( - ..., - description=" ".join( - [ - "The Unix timestamp (in seconds) of when the chat completion was created.", - "Each chunk has the same timestamp.", - ] - ), - ) - model: str = Field(..., description="The model to generate the completion.") - system_fingerprint: str = Field( - ..., description="This fingerprint represents the backend configuration that the model runs with." - ) - object: str = Field("chat.completion.chunk", description="The object type, which is always chat.completion.chunk.") - - -class OpenAICompletionsChoice(BaseModel): - """Text choice object from Completions endpoint.""" - - text: str = Field(..., description="A chat completion delta generated by streamed model responses.") - finish_reason: str | None = Field(..., description="The reason the model stopped generating tokens.") - index: int = Field(..., description="The index of the choice in the list of choices.") - - -class OpenAICompletionsResponse(BaseModel): - """Response from Completions endpoint.""" - - id: str = Field(..., description="A unique identifier for the chat completion. Each chunk has the same ID.") - choices: list[OpenAICompletionsChoice] = Field( - ..., description="A list of chat completion choices. Can be more than one if n is greater than 1." - ) - created: int = Field( - ..., - description=" ".join( - [ - "The Unix timestamp (in seconds) of when the chat completion was created.", - "Each chunk has the same timestamp.", - ] - ), - ) - model: str = Field(..., description="The model to generate the completion.") - system_fingerprint: str = Field( - ..., description="This fingerprint represents the backend configuration that the model runs with." - ) - object: str = Field("text_completion", description="The object type, which is always chat.completion.chunk.") - - -############ -# ADAPTERS # -############ - - -class EmbeddingModelAdapter(ABC): - """Abstract base class for embedding model adapters. - - Parameters - ---------- - model_name : str - Model name. - - endpoint_url : str, default=None - Endpoint URL. - """ - - def __init__(self, *, model_name: str, endpoint_url: str | None = None) -> None: - self.model_name = model_name - self.endpoint_url = endpoint_url - - @abstractmethod - def embed_query(self, *, text: str, model_kwargs: dict[str, Any]) -> EmbedQueryResponse: - """Embed query. - - Parameters - ---------- - text : str - Input text to embed. - - model_kwargs : Dict[str, Any] - Arguments to embedding model. - - Returns - ------- - EmbedQueryResponse - Embedding model response. - """ - pass - - -class TextGenModelAdapter(ABC): - """Abstract base class for text generation model adapters. - - Parameters - ---------- - model_name : str - Model name. - - endpoint_url : str, default=None - Endpoint URL. - """ - - def __init__(self, *, model_name: str, endpoint_url: str | None = None) -> None: - self.model_name = model_name - self.endpoint_url = endpoint_url - - @abstractmethod - def generate(self, *, text: str, model_kwargs: dict[str, Any]) -> GenerateResponse: - """Text generation. - - Parameters - ---------- - text : str - Prompt input text. - - model_kwargs : Dict[str, Any] - Arguments to text generation model. - - Returns - ------- - GenerateResponse - Text generation model response. - """ - pass - - -class StreamTextGenModelAdapter(ABC): - """Abstract base class for text generation model adapters with streaming option.""" - - @abstractmethod - def generate_stream( - self, - *, - text: str, - model_kwargs: dict[str, Any], - ) -> AsyncGenerator[GenerateStreamResponse]: - """Text generation with token streaming. - - Parameters - ---------- - text : str - Prompt input text. - - model_kwargs : Dict[str, Any] - Arguments to text generation model. - - Returns - ------- - AsyncGenerator[GenerateStreamResponse, None] - Text generation model response with streaming. - """ - pass - - -def escape_curly_brackets(s: str) -> str: - """Escapes curly brackets in the given string for downstream use with `str.format()`. - - Parameters - ---------- - s : str - String to be escaped. - - Returns - ------- - str - Escaped string. - """ - return re.sub(r"({|})", r"\1\1", s) diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/__init__.py b/lib/serve/rest-api/src/lisa_serve/ecs/__init__.py deleted file mode 100644 index 4b87aa82b..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" -# flake8: noqa -from .embedding import * -from .textgen import * diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/__init__.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/__init__.py deleted file mode 100644 index 328123eac..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" -# flake8: noqa -from .instructor import * -from .tei import * diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py deleted file mode 100644 index 2bd532b19..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model adapter and kwargs validator for ECS embedding instructor model endpoints.""" -from typing import Any - -from aiohttp import ClientSession -from loguru import logger -from pydantic import BaseModel - -from ...base import EmbeddingModelAdapter, EmbedQueryResponse, escape_curly_brackets -from ...registry import registry - - -class EcsEmbeddingInstructorValidator(BaseModel): - """Model kwargs validator for ECS embedding instructor model endpoints. - - Parameters - ---------- - instruction : str, default="Represent the document:" - Instructor for customized embeddings. - """ - - instruction: str = "Represent the document:" - - -class EcsEmbeddingInstructorAdapter(EmbeddingModelAdapter): - """Model adapter for ECS embedding instructor model endpoints. - - Parameters - ---------- - model_name : str - Model name. - - endpoint_url : str - Endpoint URL. - """ - - def __init__(self, *, model_name: str, endpoint_url: str) -> None: - super().__init__(model_name=model_name, endpoint_url=endpoint_url) - - # PyTorch DLC has the endpoint at path /predictions/model - self.endpoint_url = f"{self.endpoint_url.rstrip('/')}/predictions/model" # type: ignore - - async def embed_query(self, *, text: str, model_kwargs: dict[str, Any]) -> EmbedQueryResponse: # type: ignore - """Embed data. - - Parameters - ---------- - text : str - Input text to embed. - - model_kwargs : Dict[str, Any] - Arguments and configurations specific to the model. - - Returns - ------- - EmbedQueryResponse - Embedding model response. - """ - # Unpack instruction - instruction = model_kwargs["instruction"] - payload = { - "instruction": instruction, - "text": text, - } - - try: - async with ClientSession() as session: - async with session.post(self.endpoint_url, json=payload) as server_response: - server_response.raise_for_status() - server_response_json = await server_response.json() - - response = EmbedQueryResponse(embeddings=server_response_json) - - logger.debug( - f"Received: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:embed_query"}, - ) - return response - except Exception as e: - logger.error(f"Embedding request failed: {e}") - raise - - -# Register the model -registry.register( - provider="ecs.embedding.instructor", - adapter=EcsEmbeddingInstructorAdapter, - validator=EcsEmbeddingInstructorValidator, -) diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py deleted file mode 100644 index df2e2f21f..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model adapter and kwargs validator for ECS embedding instructor model endpoints.""" -from typing import Any - -from aiohttp import ClientSession -from loguru import logger -from pydantic import BaseModel - -from ...base import EmbeddingModelAdapter, EmbedQueryResponse, escape_curly_brackets -from ...registry import registry - - -class EcsEmbeddingTeiValidator(BaseModel): - """Model kwargs validator for ECS TEI model endpoints. - - Parameters - ---------- - normalize : bool, default=True - Normalizes embeddings when enabled. - truncate : bool, default=True - Truncates inputs to model context length when enabled. - """ - - normalize: bool = True - truncate: bool = True - - -class EcsEmbeddingTeiAdapter(EmbeddingModelAdapter): - """Model adapter for ECS TEI model endpoints. - - Parameters - ---------- - model_name : str - Model name. - - endpoint_url : str - Endpoint URL. - """ - - def __init__(self, *, model_name: str, endpoint_url: str) -> None: - super().__init__(model_name=model_name, endpoint_url=endpoint_url) - - self.endpoint_url = endpoint_url.rstrip("/") - - async def embed_query(self, *, text: str | list[str], model_kwargs: dict[str, Any]) -> EmbedQueryResponse: # type: ignore # noqa: E501 - """Embed data. - - Parameters - ---------- - text : Union[str, list[str]] - Input text(s) to embed. - - model_kwargs : Dict[str, Any] - Arguments and configurations specific to the model. - - Returns - ------- - EmbedQueryResponse - Embedding model response. - """ - # Unpack instruction - payload = {"inputs": text, **model_kwargs} - - async with ClientSession() as session: - async with session.post( - self.endpoint_url, json=payload, headers={"Content-Type": "application/json"} - ) as server_response: - server_response.raise_for_status() - server_response_json = await server_response.json() - - response = EmbedQueryResponse(embeddings=server_response_json) - - logger.debug( - f"Received: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:embed_query"}, - ) - return response - - -# Register the model -registry.register( - provider="ecs.embedding.tei", - adapter=EcsEmbeddingTeiAdapter, - validator=EcsEmbeddingTeiValidator, -) diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/__init__.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/__init__.py deleted file mode 100644 index 5f36930dc..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" -# flake8: noqa -from .tgi import * diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py deleted file mode 100644 index 29a4dcde6..000000000 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model adapter and kwargs validator for ECS text generation TGI model endpoints.""" -import time -import uuid -from collections.abc import AsyncGenerator -from typing import Any - -from loguru import logger -from pydantic import BaseModel, confloat, Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt -from text_generation import AsyncClient - -from ...base import ( - escape_curly_brackets, - GenerateResponse, - GenerateStreamResponse, - OpenAIChatCompletionsChoice, - OpenAIChatCompletionsDelta, - OpenAIChatCompletionsResponse, - OpenAICompletionsChoice, - OpenAICompletionsResponse, - StreamTextGenModelAdapter, - TextGenModelAdapter, - Token, -) -from ...registry import registry - - -class EcsTextGenTgiValidator(BaseModel): - """Model kwargs validator for ECS text generation TGI model endpoints. - - Parameters - ---------- - max_new_tokens : int, default=50 - Maximum number of generated tokens. - - top_k : int, default=None - The number of highest probability vocabulary tokens to keep for top-k-filtering. - - top_p : float, default=None - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` - or higher are kept for generation. - - typical_p : float, default=None - Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) - for more information. - - temperature : float, default=None - Value used to divide the logits distribution. - - repetition_penalty : float, default=None - Penalty to add for text repetition. - - return_full_text : bool, default=False - Whether to prepend the prompt to the generated text. - - truncate : int, default=None - Whether to truncate input tokens to given size. - - stop_sequences : List[str], default=[] - Stop generating tokens if a member of `stop` is generated. - - seed : int, default=None - Random sampling seed. - - do_sample : bool, default=False - Activate logits sampling. - - watermark : bool, default=False - Watermark output response. - """ - - max_new_tokens: NonNegativeInt = 50 - top_k: NonNegativeInt | None = None - top_p: confloat(gt=0.0, lt=1.0) | None = None # type: ignore - typical_p: confloat(gt=0.0, lt=1.0) | None = None # type: ignore - temperature: NonNegativeFloat | None = None - repetition_penalty: PositiveFloat | None = None - return_full_text: bool = False - truncate: PositiveInt | None = None - stop_sequences: list[str] = Field(default_factory=list) - seed: PositiveInt | None = None - do_sample: bool = False - watermark: bool = False - - -class EcsTextGenTgiAdapter(TextGenModelAdapter, StreamTextGenModelAdapter): - """Model adapter for ECS text generation TGI model endpoints. - - Parameters - ---------- - model_name : str - Model name. - - endpoint_url : str - Endpoint URL. - """ - - def __init__(self, *, model_name: str, endpoint_url: str) -> None: - super().__init__(model_name=model_name, endpoint_url=endpoint_url) - - # Define client - self.client = AsyncClient(endpoint_url, timeout=60) - - async def generate(self, *, text: str, model_kwargs: dict[str, Any]) -> GenerateResponse: # type: ignore - """Text generation. - - Parameters - ---------- - text : str - Prompt input text. - - model_kwargs : Dict[str, Any] - Arguments to text generation model. - - Returns - ------- - GenerateResponse - Text generation model response. - """ - request = {"prompt": text, **model_kwargs} - resp = await self.client.generate(**request) - response = GenerateResponse( - generatedText=resp.generated_text, - generatedTokens=resp.details.generated_tokens, - finishReason=resp.details.finish_reason, - ) - logger.debug( - f"Response: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:generate"}, - ) - return response - - async def generate_stream( - self, *, text: str, model_kwargs: dict[str, Any] - ) -> AsyncGenerator[GenerateStreamResponse]: - """Text generation with token streaming. - - Parameters - ---------- - text : str - Prompt input text. - - model_kwargs : Dict[str, Any] - Arguments to text generation model. - - Returns - ------- - AsyncGenerator[GenerateStreamResponse, None] - Text generation model response with streaming. - """ - request = {"prompt": text, **model_kwargs} - async for resp in self.client.generate_stream(**request): - response = GenerateStreamResponse( - token=Token(text=resp.token.text, special=resp.token.special), - generatedTokens=resp.details.generated_tokens if resp.details else None, - finishReason=resp.details.finish_reason if resp.details else None, - ) - logger.debug( - f"Response: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:generate_stream"}, - ) - yield response - - async def openai_generate_stream( - self, *, text: str, model_kwargs: dict[str, Any], is_text_completion: bool - ) -> AsyncGenerator[GenerateStreamResponse]: - """Text generation with token streaming, conforming to the OpenAI API specification. - - Parameters - ---------- - text : str - Prompt input text. - - model_kwargs : Dict[str, Any] - Arguments to text generation model. - - is_text_completion : bool - Tells if this is a request from the /completions API (True) or if it is from the - /chat/completions API (False) - - Returns - ------- - AsyncGenerator[GenerateStreamResponse, None] - Text generation model response with streaming. - """ - # Generate static values once before streaming - resp_id = str(uuid.uuid4()) - fingerprint = str(uuid.uuid4()) - created = int(time.time()) - - request = {"prompt": text, **model_kwargs} - if is_text_completion: - response_class = OpenAICompletionsResponse - else: - response_class = OpenAIChatCompletionsResponse - async for resp in self.client.generate_stream(**request): - response = response_class( - id=resp_id, - created=created, - model=self.model_name, - object="text_completion" if is_text_completion else "chat.completion.chunk", - system_fingerprint=fingerprint, - choices=[ - ( - OpenAICompletionsChoice( - index=0, - finish_reason=resp.details.finish_reason if resp.details else None, - text=resp.token.text, - ) - if is_text_completion - else OpenAIChatCompletionsChoice( - index=0, - finish_reason=resp.details.finish_reason if resp.details else None, - delta=OpenAIChatCompletionsDelta(content=resp.token.text, role="assistant"), - ) - ) - ], - ) - logger.debug( - f"Response: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:generate_stream"}, - ) - yield response - - -# Register the model -registry.register( - provider="ecs.textgen.tgi", - adapter=EcsTextGenTgiAdapter, - validator=EcsTextGenTgiValidator, -) diff --git a/lib/serve/rest-api/src/lisa_serve/registry/__init__.py b/lib/serve/rest-api/src/lisa_serve/registry/__init__.py deleted file mode 100644 index 0ef0b751f..000000000 --- a/lib/serve/rest-api/src/lisa_serve/registry/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Public imports.""" -from .index import ModelRegistry - -registry = ModelRegistry() diff --git a/lib/serve/rest-api/src/lisa_serve/registry/index.py b/lib/serve/rest-api/src/lisa_serve/registry/index.py deleted file mode 100644 index eb2dfc992..000000000 --- a/lib/serve/rest-api/src/lisa_serve/registry/index.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model registry.""" -from typing import Any - - -class ModelRegistry: - """Registry for model providers.""" - - def __init__(self) -> None: - self.registry: dict[str, Any] = {} - - def register(self, *, provider: str, adapter: Any, validator: Any) -> None: - """Register the adapter and validator for the model provider. - - Parameters - ---------- - provider : str - Model provider name. - - adapter : Any - Model adapter. - - validator : Any - Model kwargs validator. - """ - self.registry[provider] = {"adapter": adapter, "validator": validator} - - def get_assets(self, provider: str) -> dict[str, Any]: - """Get model registry entry.""" - try: - model_assets = self.registry[provider] - except KeyError: - raise KeyError( - f"Model provider '{provider}' not found in registry. Available providers: " - f"{', '.join(self.registry)}" - ) - return model_assets # type: ignore diff --git a/lib/serve/rest-api/src/main.py b/lib/serve/rest-api/src/main.py index 34a9cb745..a8640bfaf 100644 --- a/lib/serve/rest-api/src/main.py +++ b/lib/serve/rest-api/src/main.py @@ -13,16 +13,13 @@ # limitations under the License. """REST API.""" -import json import os import sys from contextlib import asynccontextmanager -import boto3 from api.routes import router from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from lisa_serve.registry import registry from loguru import logger from middleware import ( auth_middleware, @@ -31,8 +28,6 @@ security_middleware, validate_input_middleware, ) -from services.model_registration import ModelRegistrationService -from utils.cache_manager import set_registered_models_cache logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") @@ -59,33 +54,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # type: ignore - """REST API start and update task.""" - event = "start_and_update_task" - task_logger = logger.bind(event=event) - task_logger.debug("Start task", status="START") - - # Create model registration service - registration_service = ModelRegistrationService(registry) - - try: - verify_path = os.getenv("SSL_CERT_FILE") or None - # Use synchronous boto3 client - this runs once at startup so async isn't needed - # This avoids aiobotocore dependency which has version conflicts with litellm's boto3 - ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], verify=verify_path) - response = ssm_client.get_parameter(Name=os.environ["REGISTERED_MODELS_PS_NAME"]) - - registered_models = json.loads(response["Parameter"]["Value"]) - - # Register all models using the service - new_models = registration_service.register_models(registered_models) - - # Update the global cache - set_registered_models_cache(new_models) - except Exception: - task_logger.exception("An unknown error occurred", status="ERROR") - + """REST API lifespan.""" yield - task_logger.debug("Finished API Lifespan task", status="FINISH") app = FastAPI(lifespan=lifespan) diff --git a/lib/serve/rest-api/src/middleware/auth_middleware.py b/lib/serve/rest-api/src/middleware/auth_middleware.py index a30ab8c8f..2119cde96 100644 --- a/lib/serve/rest-api/src/middleware/auth_middleware.py +++ b/lib/serve/rest-api/src/middleware/auth_middleware.py @@ -119,7 +119,6 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo return await call_next(request) except HTTPException as e: - # For OpenAI/Anthropic routes, provide more specific error messages if is_openai_route: logger.warning(f"Authentication failed for OpenAI/Anthropic route {path}: {e.detail}") return JSONResponse( @@ -132,7 +131,11 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Respo } }, ) - raise + logger.warning(f"Authentication failed for {path}: {e.detail}") + return JSONResponse( + status_code=e.status_code, + content={"error": "Unauthorized", "message": e.detail}, + ) except Exception as e: logger.error(f"Authentication error: {e}") if is_openai_route: diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index 29b54bfc8..25e1b5124 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -6,13 +6,12 @@ boto3==1.40.76 opentelemetry-api>=1.20.0 opentelemetry-sdk>=1.20.0 -aiohttp==3.13.2 +aiohttp==3.13.3 backoff==2.2.1 cachetools==6.2.2 click==8.3.1 -cryptography==46.0.3 +cryptography==46.0.5 fastapi>=0.120.1 -fastapi_utils==0.8.0 gunicorn>=23.0.0,<24.0.0 # LiteLLM - Upgraded to 1.81.3 for RDS IAM token refresh fix (PR #18795) @@ -22,9 +21,7 @@ litellm[proxy]==1.81.3 loguru==0.7.3 pydantic>=2.5.0,<3.0.0 PyJWT>=2.10.1,<3.0.0 -text-generation==0.7.0 prisma==0.15.0 -pynacl>=1.5.0,<2.0.0 starlette>=0.40.0,<0.51.0 # ASGI Server - Version constrained by litellm[proxy]==1.81.3 diff --git a/lib/serve/rest-api/src/services/model_registration.py b/lib/serve/rest-api/src/services/model_registration.py deleted file mode 100644 index 3a66650c3..000000000 --- a/lib/serve/rest-api/src/services/model_registration.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model registration service.""" -from typing import Any, Protocol - -from utils.resources import ModelType, RestApiResource - - -class RegistryProtocol(Protocol): - """Protocol for model registry.""" - - def get_assets(self, provider: str) -> dict[str, Any]: - """Get model assets for a provider.""" - ... - - -class ModelRegistrationService: - """Service for registering models from configuration.""" - - # Supported inference containers - SUPPORTED_CONTAINERS = ["tgi", "tei", "instructor"] - - def __init__(self, registry: RegistryProtocol): - """Initialize the service. - - Parameters - ---------- - registry : RegistryProtocol - The model registry to use for getting validators - """ - self.registry = registry - - def create_empty_cache(self) -> dict[str, dict[str, Any]]: - """Create an empty model cache structure. - - Returns - ------- - dict[str, dict[str, Any]] - Empty cache with all required keys - """ - return { - ModelType.EMBEDDING: {}, - ModelType.TEXTGEN: {}, - RestApiResource.EMBEDDINGS: {}, - RestApiResource.GENERATE: {}, - RestApiResource.GENERATE_STREAM: {}, - "metadata": {}, - "endpointUrls": {}, - } - - def is_supported_container(self, inference_container: str) -> bool: - """Check if inference container is supported. - - Parameters - ---------- - inference_container : str - The inference container name - - Returns - ------- - bool - True if supported, False otherwise - """ - return inference_container in self.SUPPORTED_CONTAINERS - - def register_model(self, model: dict[str, Any], cache: dict[str, dict[str, Any]]) -> None: - """Register a single model into the cache. - - Parameters - ---------- - model : dict[str, Any] - Model configuration with keys: provider, modelName, modelType, endpointUrl, streaming - cache : dict[str, dict[str, Any]] - The cache to update - """ - provider = model["provider"] - model_name = model["modelName"] - model_type = model["modelType"] - - # provider format is `modelHosting.modelType.inferenceContainer` - # example: "ecs.textgen.tgi" - parts = provider.split(".") - if len(parts) != 3: - return # Invalid provider format - - inference_container = parts[2] - - # Skip unsupported containers - if not self.is_supported_container(inference_container): - return - - # Get default model kwargs from validator - validator = self.registry.get_assets(provider)["validator"] - model_kwargs = validator().dict() - - # Build model key - model_key = f"{provider}.{model_name}" - - # Store endpoint URL - cache["endpointUrls"][model_key] = model["endpointUrl"] - - # Store metadata - cache["metadata"][model_key] = { - "provider": provider, - "modelName": model_name, - "modelType": model_type, - "modelKwargs": model_kwargs, - } - if "streaming" in model: - cache["metadata"][model_key]["streaming"] = model["streaming"] - - # Register by model type and resource - if model_type == ModelType.EMBEDDING: - cache[RestApiResource.EMBEDDINGS].setdefault(provider, []).append(model_name) - cache[ModelType.EMBEDDING].setdefault(provider, []).append(model_name) - elif model_type == ModelType.TEXTGEN: - cache[RestApiResource.GENERATE].setdefault(provider, []).append(model_name) - cache[ModelType.TEXTGEN].setdefault(provider, []).append(model_name) - if model.get("streaming", False): - cache[RestApiResource.GENERATE_STREAM].setdefault(provider, []).append(model_name) - - def register_models(self, models: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: - """Register multiple models. - - Parameters - ---------- - models : list[dict[str, Any]] - List of model configurations - - Returns - ------- - dict[str, dict[str, Any]] - The populated cache - """ - cache = self.create_empty_cache() - - for model in models: - try: - self.register_model(model, cache) - except Exception: # nosec B112 - # Skip models that fail to register - this is intentional - # to allow partial registration when some models are misconfigured - continue - - return cache diff --git a/lib/serve/rest-api/src/services/model_service.py b/lib/serve/rest-api/src/services/model_service.py deleted file mode 100644 index daef3f9ba..000000000 --- a/lib/serve/rest-api/src/services/model_service.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Service for model operations - follows Single Responsibility Principle.""" - -import time -from collections import defaultdict -from typing import Any, DefaultDict - -from utils.resources import ModelType - - -class ModelService: - """Service class for model-related operations. - - This class encapsulates all model listing and description logic, - making it easy to test without external dependencies. - """ - - def __init__(self, models_cache: dict[str, Any]): - """Initialize with models cache. - - Parameters - ---------- - models_cache : dict - The registered models cache - """ - self.models_cache = models_cache - - def list_models(self, model_types: list[ModelType]) -> dict[ModelType, dict[str, list[str]]]: - """List models by type. - - Parameters - ---------- - model_types : List[ModelType] - Model types to list - - Returns - ------- - Dict[ModelType, Dict[str, List[str]]] - List of model names by model type and provider - """ - return {model_type: self.models_cache.get(model_type, {}) for model_type in model_types} - - def list_models_openai_format(self) -> dict[str, Any]: - """List models in OpenAI-compatible format. - - Returns - ------- - Dict[str, Any] - OpenAI-compatible response with text generation models - """ - textgen_models = self.models_cache.get(ModelType.TEXTGEN, {}) - - model_payload: list[dict[str, Any]] = [] - for provider, models in textgen_models.items(): - model_payload.extend( - {"id": f"{model} ({provider})", "object": "model", "created": int(time.time()), "owned_by": "LISA"} - for model in models - ) - - return {"data": model_payload, "object": "list"} - - def get_model_metadata(self, provider: str, model_name: str) -> dict[str, Any] | None: - """Get metadata for a specific model. - - Parameters - ---------- - provider : str - Model provider name - model_name : str - Model name - - Returns - ------- - Dict[str, Any] | None - Model metadata or None if not found - """ - model_key = f"{provider}.{model_name}" - metadata_cache = self.models_cache.get("metadata", {}) - result = metadata_cache.get(model_key) - return result if result is not None else None - - def describe_models(self, model_types: list[ModelType]) -> DefaultDict[str, DefaultDict[str, dict[str, Any]]]: - """Get detailed metadata for models by type. - - Parameters - ---------- - model_types : List[ModelType] - Model types to describe - - Returns - ------- - DefaultDict[str, DefaultDict[str, Dict[str, Any]]] - Model metadata by type, provider, and name - """ - registered_models = self.list_models(model_types) - metadata_cache = self.models_cache.get("metadata", {}) - response: DefaultDict[str, DefaultDict[str, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) - - for model_type, providers in registered_models.items(): - response[model_type] = {} # type: ignore - providers = providers or {} - for provider, model_names in providers.items(): - response[model_type][provider] = [ - metadata_cache[f"{provider}.{model_name}"] - for model_name in model_names - if f"{provider}.{model_name}" in metadata_cache - ] # type: ignore - - return response diff --git a/lib/serve/rest-api/src/utils/cache_manager.py b/lib/serve/rest-api/src/utils/cache_manager.py deleted file mode 100644 index 9f3b87660..000000000 --- a/lib/serve/rest-api/src/utils/cache_manager.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model Cache Utilities.""" -import threading -from typing import Any - -from .resources import ModelType, RestApiResource - -# Cache structure containing different types of information related to registered models. -# - ModelType keys (EMBEDDING, TEXTGEN) are used for quick lookup of models by type. -# - RestApiResource keys (EMBEDDINGS, GENERATE, GENERATE_STREAM) contain models by endpoint. -# - 'metadata' contains detailed information about each model. -# - 'endpointUrls' contains the URLs for model instantiation. -REGISTERED_MODELS_CACHE: dict[str, dict[str, Any]] = { - ModelType.EMBEDDING: {}, - ModelType.TEXTGEN: {}, - RestApiResource.EMBEDDINGS: {}, - RestApiResource.GENERATE: {}, - RestApiResource.GENERATE_STREAM: {}, - "metadata": {}, - "endpointUrls": {}, -} -MODEL_ASSETS_CACHE: dict[str, tuple[Any, Any]] = {} - -# Thread locks for cache operations -_REGISTERED_MODELS_LOCK = threading.RLock() -_MODEL_ASSETS_LOCK = threading.RLock() - - -def get_registered_models_cache() -> dict[str, dict[str, Any]]: - """Get the cache containing the registered models.""" - with _REGISTERED_MODELS_LOCK: - return REGISTERED_MODELS_CACHE.copy() - - -def get_model_assets(model_key: str) -> tuple[Any, Any] | None: - """Get the cache belonging to the model assets.""" - with _MODEL_ASSETS_LOCK: - return MODEL_ASSETS_CACHE.get(model_key) - - -def cache_model_assets(key: str, model_assets: tuple[Any, Any]) -> None: - """Cache the specified model assets for the specified key.""" - with _MODEL_ASSETS_LOCK: - MODEL_ASSETS_CACHE[key] = model_assets - - -def set_registered_models_cache(models: dict[str, dict[str, Any]]) -> None: - """Set the registered model cache to the specified models value.""" - with _REGISTERED_MODELS_LOCK: - global REGISTERED_MODELS_CACHE - REGISTERED_MODELS_CACHE = models diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index b4aa86094..46f4f045f 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -18,33 +18,9 @@ import sys import traceback from collections.abc import AsyncGenerator, Callable -from typing import Any, Protocol +from typing import Any from loguru import logger -from utils.cache_manager import cache_model_assets, get_model_assets, get_registered_models_cache -from utils.resources import RestApiResource - - -class RegistryProtocol(Protocol): - """Protocol for model registry - allows dependency injection.""" - - def get_assets(self, provider: str) -> dict[str, Any]: - """Get model assets for a provider.""" - ... - - -def _get_default_registry() -> RegistryProtocol: - """Lazy import of registry to avoid import-time dependencies. - - This function is only called at runtime, not at import time, - allowing tests to mock the registry without importing lisa_serve. - """ - # Import here to avoid circular dependencies and allow test mocking - # This is intentionally not at module level - from lisa_serve.registry import registry # noqa: PLC0415 - - return registry - logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") @@ -69,141 +45,6 @@ def _get_default_registry() -> RegistryProtocol: ) -async def validate_model(request_data: dict[str, Any], resource: RestApiResource) -> None: - """Validate that the selected model is registered and supported for the specified resource. - - Parameters - ---------- - request_data : Dict[str, Any] - Request data. - - resource : RestApiResource - REST API resource. - - Raises - ------ - Exception - If the selected model is not registered or not supported for the specified resource. - - Returns - ------- - None - None. - - """ - event = "validate_model" - - provider = request_data["provider"] - model_name = request_data["modelName"] - - registered_models_cache = get_registered_models_cache() - supported_models = registered_models_cache[resource][provider] - if model_name not in supported_models: - # Sanitize inputs for logging to prevent log injection - safe_model_name = str(model_name).replace("\n", "").replace("\r", "") - safe_resource = str(resource).replace("\n", "").replace("\r", "") - safe_supported = str(supported_models).replace("\n", "").replace("\r", "") - - message = ( - f"Provider does not support model {safe_model_name} for endpoint " - f"/{safe_resource}, expected one of: {safe_supported}" - ) - logger.error(message, extra={"event": event, "status": "ERROR"}) - raise ValueError(message) - - -async def get_model_and_validator( - request_data: dict[str, Any], registry: RegistryProtocol | None = None -) -> tuple[Any, Any]: - """Get model and model kwargs validator. - - Parameters - ---------- - request_data : Dict[str, Any] - Request data. - registry : RegistryProtocol | None - Optional registry for dependency injection (testing). - - Returns - ------- - Tuple - The model and model kwargs validator. - """ - provider = request_data["provider"] - model_name = request_data["modelName"] - model_key = f"{provider}.{model_name}" - - # Try to get model and validator from the cache - model_assets = get_model_assets(model_key) - if not model_assets: - # If not cached, retrieve model assets from registry - if registry is None: - registry = _get_default_registry() - - registry_assets = registry.get_assets(provider) - adapter = registry_assets["adapter"] - validator = registry_assets["validator"] - - # Retrieve model endpoint URL - registered_models_cache = get_registered_models_cache() - try: - endpoint_url = registered_models_cache["endpointUrls"][model_key] - except KeyError: - raise KeyError(f"Model endpoint URL not found for {model_key}") - - # Instantiate the model - model = adapter(model_name=model_name, endpoint_url=endpoint_url) - - # Store model and validator in the cache - model_assets = (model, validator) - cache_model_assets(model_key, model_assets) - - return model_assets - - -async def validate_and_prepare_llm_request( - request_data: dict[str, Any], resource: RestApiResource, registry: RegistryProtocol | None = None -) -> tuple[Any, Any, str]: - """Validate and prepare data for LLM (Language Model) requests. - - Parameters - ---------- - request_data : Dict[str, Any] - Request data. - - resource : RestApiResource - REST API resource. - - registry : RegistryProtocol | None - Optional registry for dependency injection (testing). - - Returns - ------- - Tuple - The model, prepared model kwargs, and text for processing. - """ - event = "validate_and_prepare_llm_request" - task_logger = logger.bind(event=event) - task_logger.debug("Start task", status="START") - - # Validate the requested model is registered - await validate_model(request_data, resource) - - # Instantiate the model and get the model kwargs validator - model, validator = await get_model_and_validator(request_data, registry) - - # Verify model kwargs - model_kwargs = validator(**request_data["modelKwargs"]) - - task_logger.debug("Finish task", status="FINISH") - - text = request_data.get("text") - if text is None: - raise ValueError("Missing required field: text") - - return model, model_kwargs.dict(), text - - def handle_stream_exceptions( func: Callable[..., AsyncGenerator[str]], ) -> Callable[..., AsyncGenerator[str]]: diff --git a/lib/serve/rest-api/src/utils/resources.py b/lib/serve/rest-api/src/utils/resources.py deleted file mode 100644 index 17934b00a..000000000 --- a/lib/serve/rest-api/src/utils/resources.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""REST API resources.""" -from enum import Enum -from typing import Any - -from pydantic import BaseModel, Field - - -class RestApiResource(str, Enum): - """REST API resource.""" - - # Model info - LIST_MODELS = "listModels" - DESCRIBE_MODEL = "describeModel" - DESCRIBE_MODELS = "describeModels" - - # Run models - EMBEDDINGS = "embeddings" - GENERATE = "generate" - GENERATE_STREAM = "generateStream" - - # OpenAI API Compatibility - OPENAI_LIST_MODELS = "openai/models" - OPENAI_COMPLETIONS = "openai/completions" - OPENAI_CHAT_COMPLETIONS = "openai/chat/completions" - - -class ModelType(str, Enum): - """Valid model types.""" - - EMBEDDING = "embedding" - TEXTGEN = "textgen" - VIDEOGEN = "videogen" - - -class _BaseModelRequest(BaseModel): - """Base model resource.""" - - provider: str = Field(..., description="The backend provider for the model.") - modelName: str = Field(..., description="The model name.") - text: str | list[str] = Field(..., description="The input text(s) to be processed by the model.") - modelKwargs: dict[str, Any] = Field(default={}, description="Arguments to the model.") - - -class EmbeddingsRequest(_BaseModelRequest): - """Create text embeddings.""" - - -class GenerateRequest(_BaseModelRequest): - """Run text generation.""" - - -class GenerateStreamRequest(_BaseModelRequest): - """Run text generation with streaming.""" - - -class OpenAIChatCompletionsRequest(BaseModel): - """Run text generation for Chat Completions for OpenAI API. - - Additional documentation at https://platform.openai.com/docs/api-reference/chat/create - """ - - messages: list[dict[str, str]] = Field(..., description="A list of messages comprising the conversation so far.") - model: str = Field(..., description="ID of the model to use.") - frequency_penalty: float | None = Field(None, description="Penalty to add for text repetition.") - logit_bias: dict[Any, Any] | None = Field( - None, description="Modify the likelihood of specified tokens appearing in the completion." - ) - logprobs: bool | None = Field( - False, - description=" ".join( - [ - "Whether to return log probabilities of the output tokens or not. If true, returns", - "the log probabilities of each output token returned in the content of message.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - top_logprobs: int | None = Field( - None, - description=" ".join( - [ - "An integer between 0 and 20 specifying the number of most likely tokens to return", - "at each token position, each with an associated log probability. logprobs must be", - "set to true if this parameter is used.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - max_tokens: int | None = Field(50, description="Maximum number of generated tokens.") - n: int | None = Field( - 1, - description=" ".join( - [ - "How many completions to generate for each prompt.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - presence_penalty: float | None = Field( - 0, - description=" ".join( - [ - "Number increasing the model's likelihood to talk about new topics.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - seed: int | None = Field(None, description="Random sampling seed.") - stop: list[str] | None = Field( - default_factory=list, description="Stop generating tokens if a member of `stop` is generated." - ) - stream: bool | None = Field( - False, - description=" ".join( - [ - "Whether to stream back partial progress. If set, tokens will be sent as data-only", - "server-sent events as they become available, with the stream terminated by a", - "data: [DONE] message.", - ] - ), - ) - top_p: float | None = Field( - None, - description=" ".join( - [ - "If set to < 1, only the smallest set of most probable tokens with probabilities", - "that add up to `top_p` or higher are kept for generation.", - ] - ), - ) - temperature: float | None = Field(None, description="Value used to divide the logits distribution.") - - -class OpenAICompletionsRequest(BaseModel): - """Run text generation for Completions for OpenAI API. - - Additional documentation at https://platform.openai.com/docs/api-reference/completions - """ - - model: str = Field(..., description="ID of the model to use.") - prompt: Any = Field( - ..., - description=" ".join( - [ - "The prompt(s) to generate completions for, encoded as a string, array of strings,", - "array of tokens, or array of token arrays.", - ] - ), - ) - best_of: int | None = Field( - 1, - description=" ".join( - [ - 'Generates best_of completions server-side and returns the "best"', - "(the one with the highest log probability per token). Results cannot be streamed.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - echo: bool | None = Field(False, description="Whether to prepend the prompt to the generated text.") - frequency_penalty: float | None = Field(None, description="Penalty to add for text repetition.") - logit_bias: dict[Any, Any] | None = Field( - None, - description=" ".join( - [ - "Modify the likelihood of specified tokens appearing in the completion.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - logprobs: int | None = Field( - None, - description=" ".join( - [ - "Include the log probabilities on the logprobs most likely output tokens,", - "as well the chosen tokens. For example, if logprobs is 5, the API will", - "return a list of the 5 most likely tokens. The API will always return the", - "logprob of the sampled token, so there may be up to logprobs+1", - "elements in the response.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - max_tokens: int | None = Field( - 50, description="The maximum number of tokens that can be generated in the completion." - ) - n: int | None = Field( - 1, - description=" ".join( - [ - "How many completions to generate for each prompt.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - presence_penalty: float | None = Field( - 0, - description=" ".join( - [ - "Number increasing the model's likelihood to talk about new topics.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - seed: int | None = Field(None, description="Random sampling seed.") - stop: Any | None = Field( - default_factory=list, description="Stop generating tokens if a member of `stop` is generated." - ) - stream: bool | None = Field( - False, - description=" ".join( - [ - "Whether to stream back partial progress.", - "If set, tokens will be sent as data-only server-sent events as they become available,", - "with the stream terminated by a data: [DONE] message.", - ] - ), - ) - suffix: str | None = Field( - None, - description=" ".join( - [ - "The suffix that comes after a completion of inserted text.", - "This parameter is only supported for gpt-3.5-turbo-instruct.", - "This option is ignored for TGI/Hugging Face models.", - ] - ), - ) - temperature: float | None = Field(1.0, description="Value used to divide the logits distribution.") - top_p: float | None = Field( - None, - description=" ".join( - [ - "If set to < 1, only the smallest set of most probable tokens with", - "probabilities that add up to `top_p` or higher are kept for generation.", - ] - ), - ) diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index 452b3b04d..fc4778cf3 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -258,6 +258,7 @@ export class LisaServeApplicationConstruct extends Construct { // This runs when switching to IAM auth or updating the configuration // Pass parameters via payload since the Lambda is shared // Use Stack.of(scope).toJsonString() to properly resolve CDK tokens in the payload + // Include timestamp to force re-run on every deployment const lambdaInvokeParams = { service: 'Lambda', action: 'invoke', @@ -271,6 +272,7 @@ export class LisaServeApplicationConstruct extends Construct { dbName: config.restApiConfig.rdsConfig.dbName, dbUser: config.restApiConfig.rdsConfig.username, iamName: serveRole.roleName, + timestamp: new Date().toISOString(), // Force re-run on every deployment }) }, }; @@ -314,6 +316,7 @@ export class LisaServeApplicationConstruct extends Construct { container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); container.addEnvironment('GUARDRAILS_TABLE_NAME', guardrailsTableName); container.addEnvironment('GENERATED_IMAGES_S3_BUCKET_NAME', imagesBucketName); + container.addEnvironment('MODEL_INFO_CACHE_TTL', '300'); // Add metrics queue URL if provided if (props.metricsQueueUrl) { // Get the queue URL from SSM parameter diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index b70fa8a42..7b08ec285 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "6.2.1", + "version": "6.3.0", "type": "module", "scripts": { "dev": "vite", @@ -43,6 +43,7 @@ "luxon": "^3.7.2", "mermaid": "^11.12.2", "oidc-client-ts": "^3.1.0", + "pdfjs-dist": "^5.4.624", "react": "^19.2.1", "react-ace": "^14.0.1", "react-dom": "^19.2.1", diff --git a/lib/user-interface/react/src/App.tsx b/lib/user-interface/react/src/App.tsx index b5788c92a..70785a0cf 100644 --- a/lib/user-interface/react/src/App.tsx +++ b/lib/user-interface/react/src/App.tsx @@ -17,7 +17,7 @@ import 'regenerator-runtime/runtime'; import { ReactElement, useEffect, useState } from 'react'; import { Navigate, Route, Routes } from 'react-router-dom'; -import { AppLayout } from '@cloudscape-design/components'; +import { AppLayout, Box } from '@cloudscape-design/components'; import Spinner from '@cloudscape-design/components/spinner'; import { useAuth } from './auth/useAuth'; @@ -287,7 +287,7 @@ function App () { configLoading ?
- Loading configuration... + Loading configuration...
: diff --git a/lib/user-interface/react/src/components/Topbar.test.tsx b/lib/user-interface/react/src/components/Topbar.test.tsx index c4bbd45ac..945f93783 100644 --- a/lib/user-interface/react/src/components/Topbar.test.tsx +++ b/lib/user-interface/react/src/components/Topbar.test.tsx @@ -31,7 +31,6 @@ vi.mock('../auth/useAuth'); // Mock store functions vi.mock('@/config/store', () => ({ - purgeStore: vi.fn(), useAppDispatch: vi.fn(() => vi.fn()), useAppSelector: vi.fn((selector) => { const selectorStr = selector.toString(); @@ -93,19 +92,13 @@ describe('Topbar', () => { it('calls signoutRedirect when sign out is clicked', async () => { const user = userEvent.setup(); - const { purgeStore } = await import('@/config/store'); renderTopbar(); - // Click the user profile dropdown button (the button with user icon) const userButton = screen.getByRole('button', { expanded: false }); await user.click(userButton); - - // Click the sign out option await user.click(screen.getByText('Sign out')); - // Verify that purgeStore and signoutRedirect were called - expect(purgeStore).toHaveBeenCalledOnce(); expect(mockAuth.signoutRedirect).toHaveBeenCalledOnce(); }); @@ -132,4 +125,5 @@ describe('Topbar', () => { redirect_uri: window.location.toString(), }); }); + }); diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index f438ff439..1a288bfa4 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -19,7 +19,7 @@ import { useAuth } from '../auth/useAuth'; import { useHref, useNavigate } from 'react-router-dom'; import { applyDensity, Density, Mode } from '@cloudscape-design/global-styles'; import TopNavigation, { TopNavigationProps } from '@cloudscape-design/components/top-navigation'; -import { purgeStore, useAppDispatch, useAppSelector } from '@/config/store'; +import { useAppDispatch, useAppSelector } from '@/config/store'; import { selectCurrentUserIsAdmin, selectCurrentUserIsApiUser, selectCurrentUsername } from '../shared/reducers/user.reducer'; import { IConfiguration } from '@/shared/model/configuration.model'; import { ButtonDropdownProps } from '@cloudscape-design/components'; @@ -223,7 +223,6 @@ function Topbar ({ configs }: TopbarProps): ReactElement { auth.signinRedirect({ redirect_uri: window.location.toString() }); break; case 'signout': - await purgeStore(); await auth.removeUser(); await auth.signoutRedirect({ extraQueryParams: { diff --git a/lib/user-interface/react/src/components/app-configured.tsx b/lib/user-interface/react/src/components/app-configured.tsx index 28e0d63e6..0362e8310 100644 --- a/lib/user-interface/react/src/components/app-configured.tsx +++ b/lib/user-interface/react/src/components/app-configured.tsx @@ -24,8 +24,31 @@ import Spinner from '@cloudscape-design/components/spinner'; import { OidcConfig } from '../config/oidc.config'; import { User, UserProfile } from 'oidc-client-ts'; -import { purgeStore, useAppDispatch } from '../config/store'; +import { useAppDispatch } from '../config/store'; import { updateUserState } from '../shared/reducers/user.reducer'; +import { useAuth } from '../auth/useAuth'; + +function UserStateSync () { + const dispatch = useAppDispatch(); + const auth = useAuth(); + + useEffect(() => { + if (auth.user) { + const userGroups = getGroups(auth.user.profile); + dispatch(updateUserState({ + name: auth.user.profile.name, + preferred_username: auth.user.profile.preferred_username, + email: auth.user.profile.email, + groups: userGroups, + isAdmin: userGroups ? isAdmin(userGroups) : false, + isUser: window.env.USER_GROUP ? userGroups && isUser(userGroups) : true, + isApiUser: window.env.API_GROUP ? userGroups && isApiUser(userGroups) : false, + })); + } + }, [auth.user, dispatch]); + + return null; +} function OAuthCallback () { useEffect(() => { @@ -118,8 +141,6 @@ function AppConfigured () { window.history.replaceState({}, document.title, `${window.location.pathname}${window.location.hash}`); setOidcUser(user); } else { - // User not authorized - purge store and remove user from OIDC storage - await purgeStore(); // Clear OIDC session storage to force re-authentication const oidcStorageKey = `oidc.user:${window.env.AUTHORITY}:${window.env.CLIENT_ID}`; sessionStorage.removeItem(oidcStorageKey); @@ -127,6 +148,7 @@ function AppConfigured () { } }} > + } /> diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index b593dc5df..ce4af84a7 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -45,6 +45,7 @@ import { useAttachImageToSessionMutation, useGetSessionHealthQuery, useLazyGetSessionByIdQuery, + useListSessionsQuery, useUpdateSessionMutation, } from '@/shared/reducers/session.reducer'; import { useAppDispatch, useAppSelector } from '@/config/store'; @@ -83,6 +84,9 @@ import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; import ConfirmationModal from '@/shared/modal/confirmation-modal'; import { selectCurrentUsername } from '@/shared/reducers/user.reducer'; import { conditionalDeps } from '../utils'; +import { formatDate } from '@/shared/util/formats'; +import DocumentSidePanel from './components/DocumentSidePanel'; +import { useDocumentSidePanel } from '@/shared/hooks/useDocumentSidePanel'; export default function Chat ({ sessionId }) { const dispatch = useAppDispatch(); @@ -125,6 +129,7 @@ export default function Chat ({ sessionId }) { const [userPrompt, setUserPrompt] = useState(''); const [fileContext, setFileContext] = useState(''); const [fileContextName, setFileContextName] = useState(''); + const [fileContextFiles, setFileContextFiles] = useState>([]); const [dirtySession, setDirtySession] = useState(false); const [isConnected, setIsConnected] = useState(false); const [useRag, setUseRag] = useState(false); @@ -139,6 +144,17 @@ export default function Chat ({ sessionId }) { const [updatingAutoApprovalForTool, setUpdatingAutoApprovalForTool] = useState(null); const [showMarkdownPreview, setShowMarkdownPreview] = useState(false); + // Document side panel management + const { showDocSidePanel, selectedDocumentForPanel, handleOpenDocument, handleCloseDocPanel } = useDocumentSidePanel(); + + // Close document side panel when session changes + useEffect(() => { + if (showDocSidePanel) { + handleCloseDocPanel(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [sessionId]); + // Get color scheme context for markdown preview const { colorScheme } = useContext(ColorSchemeContext); const isDarkMode = colorScheme === Mode.Dark; @@ -248,6 +264,13 @@ export default function Chat ({ sessionId }) { setRagConfig } = useSession(sessionId, getSessionById); + // Get sessions list lastUpdated timestamp + const { data: sessions } = useListSessionsQuery(null, { refetchOnMountOrArgChange: 5 }); + const currentSessionSummary = useMemo(() => + sessions?.find((s) => s.sessionId === session.sessionId), + [sessions, session.sessionId] + ); + const { modelsOptions, handleModelChange } = useModels( allModels, chatConfiguration, @@ -775,12 +798,14 @@ export default function Chat ({ sessionId }) { isVideoGenerationMode, fileContext, fileContextName, + fileContextFiles, config, useRag, showMarkdownPreview, setUserPrompt, setFileContext, setFileContextName, + setFileContextFiles, handleAction, handleKeyPress, handleButtonClick, @@ -797,6 +822,7 @@ export default function Chat ({ sessionId }) { isVideoGenerationMode, fileContext, fileContextName, + fileContextFiles, config, useRag, showMarkdownPreview, @@ -853,9 +879,10 @@ export default function Chat ({ sessionId }) { fileContext={fileContext} setFileContext={setFileContext} setFileContextName={setFileContextName} + setFileContextFiles={setFileContextFiles} selectedModel={selectedModel} // eslint-disable-next-line react-hooks/exhaustive-deps - />), conditionalDeps([modals.contextUpload], [modals.contextUpload], [modals.contextUpload, openModal, closeModal, fileContext, setFileContext, setFileContextName, selectedModel]))} + />), conditionalDeps([modals.contextUpload], [modals.contextUpload], [modals.contextUpload, openModal, closeModal, fileContext, setFileContext, setFileContextName, setFileContextFiles, selectedModel]))} {useMemo(() => ( )} -
- - - {loadingSession && ( - - - - Loading session... - Please wait while we load your conversation history - - - )} - - {useMemo(() => { - if (loadingSession) return null; - - return session.history.map((message, idx) => ( + {/* Chat messages area */} +
+ + + {loadingSession && ( + + + + Loading session... + Please wait while we load your conversation history + + + )} + + {useMemo(() => { + if (loadingSession) return null; + + return session.history.map((message, idx) => ()); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [session.history, chatConfiguration, loadingSession])} + + {!loadingSession && (isRunning || callingToolName) && !isStreaming && !isImageGenerationMode && !isVideoGenerationMode && )); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [session.history, chatConfiguration, loadingSession])} - - {!loadingSession && (isRunning || callingToolName) && !isStreaming && !isImageGenerationMode && !isVideoGenerationMode && } - {!loadingSession && session.history.length === 0 && sessionId === undefined && ( - - )} -
- + onOpenDocument={handleOpenDocument} + />} + {!loadingSession && session.history.length === 0 && sessionId === undefined && ( + + )} +
+ +
+ + {/* Document side panel */} + {showDocSidePanel && ( + + )}
+
e.preventDefault()}> @@ -1041,7 +1090,7 @@ export default function Chat ({ sessionId }) { )} - + {enabledServers && enabledServers.length > 0 && selectedModel?.features?.filter((feature) => feature.name === ModelFeatures.TOOL_CALLS)?.length && true ? ( {enabledServers.length} MCP Servers - {openAiTools?.length || 0} tools @@ -1051,6 +1100,13 @@ export default function Chat ({ sessionId }) { : ( This model does not have Tool Calling enabled )} + + {!loadingSession && session.history.length > 0 && (currentSessionSummary?.lastUpdated) && ( + + Last updated: {formatDate(currentSessionSummary?.lastUpdated)} + + )} + {isConnected ? 'Connected' : 'Disconnected'} diff --git a/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx b/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx index 07c81799a..77eca1687 100644 --- a/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/ChatPromptInput.tsx @@ -29,13 +29,14 @@ type ChatPromptInputProps = { isImageGenerationMode: boolean; isVideoGenerationMode: boolean; fileContext: string; - fileContextName: string; + fileContextFiles: Array<{name: string, content: string}>; config: IConfiguration; useRag: boolean; showMarkdownPreview: boolean; setUserPrompt: (value: string) => void; setFileContext: (value: string) => void; setFileContextName: (value: string) => void; + setFileContextFiles: React.Dispatch>>; handleAction: () => void; handleKeyPress: (event: any) => void; handleButtonClick: (event: { detail: { id: string } }) => void; @@ -61,18 +62,37 @@ export const ChatPromptInput: React.FC = ({ isImageGenerationMode, isVideoGenerationMode, fileContext, - fileContextName, + fileContextFiles, config, useRag, showMarkdownPreview, setUserPrompt, setFileContext, setFileContextName, + setFileContextFiles, handleAction, handleKeyPress, handleButtonClick, getButtonItems, }) => { + // Handler for removing individual files + const handleRemoveFile = (fileNameToRemove: string) => { + const remainingFiles = fileContextFiles.filter((f) => f.name !== fileNameToRemove); + + if (remainingFiles.length === 0) { + // No files left, clear everything + setFileContext(''); + setFileContextName(''); + setFileContextFiles([]); + } else { + // Update with remaining files + const combinedContext = remainingFiles.map((f) => f.content).join('\n\n'); + const fileNames = remainingFiles.map((f) => f.name).join(', '); + setFileContext(`File context:\n${combinedContext}`); + setFileContextName(fileNames); + setFileContextFiles(remainingFiles); + } + }; return ( = ({ } secondaryContent={ - fileContext && ( + fileContext && fileContextFiles.length > 0 && ( { - setFileContext(''); - setFileContextName(''); + items={fileContextFiles.map((file) => ({ + file: new File([file.content], file.name) + }))} + onDismiss={(event) => { + // The event.detail contains the fileIndex + const dismissedIndex = (event.detail as any).fileIndex; + if (dismissedIndex !== undefined && fileContextFiles[dismissedIndex]) { + handleRemoveFile(fileContextFiles[dismissedIndex].name); + } }} alignment='horizontal' showFileSize={false} showFileLastModified={false} showFileThumbnail={false} i18nStrings={{ - removeFileAriaLabel: () => 'Remove file', + removeFileAriaLabel: (fileIndex) => `Remove file ${fileContextFiles[fileIndex]?.name || fileIndex + 1}`, limitShowFewer: 'Show fewer files', limitShowMore: 'Show more files', errorIconAriaLabel: 'Error', diff --git a/lib/user-interface/react/src/components/chatbot/components/DocumentSidePanel.tsx b/lib/user-interface/react/src/components/chatbot/components/DocumentSidePanel.tsx new file mode 100644 index 000000000..7d6635531 --- /dev/null +++ b/lib/user-interface/react/src/components/chatbot/components/DocumentSidePanel.tsx @@ -0,0 +1,255 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import React, { useEffect, useState } from 'react'; +import { + Box, + Button, + Header, + SpaceBetween, + Spinner, + StatusIndicator, + Container, +} from '@cloudscape-design/components'; +import { useDispatch } from 'react-redux'; +import { getFileType, normalizeDocumentName } from '@/components/utils'; +import { useLazyDownloadRagDocumentQuery } from '@/shared/reducers/rag.reducer'; +import { useNotificationService } from '@/shared/util/hooks'; + +export type DocumentSidePanelProps = { + visible: boolean; + onClose: () => void; + document: { + documentId: string; + repositoryId: string; + name: string; + source: string; + } | null; +}; + +export function DocumentSidePanel ({ visible, onClose, document }: DocumentSidePanelProps) { + const dispatch = useDispatch(); + const notificationService = useNotificationService(dispatch); + const [downloadUrl, { isLoading: isLoadingUrl }] = useLazyDownloadRagDocumentQuery(); + const [documentUrl, setDocumentUrl] = useState(null); + const [textContent, setTextContent] = useState(''); + const [error, setError] = useState(null); + + const fileType = document ? getFileType(document.name) : 'txt'; + + // Load document URL when document changes + useEffect(() => { + if (!visible || !document) { + // Revoke object URL to prevent memory leaks + if (documentUrl) { + URL.revokeObjectURL(documentUrl); + } + setDocumentUrl(null); + setTextContent(''); + setError(null); + return; + } + + loadDocument(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [visible, document?.documentId]); + + // Cleanup object URL on unmount + useEffect(() => { + return () => { + if (documentUrl) { + URL.revokeObjectURL(documentUrl); + } + }; + }, [documentUrl]); + + const loadDocument = async () => { + if (!document) return; + + try { + setError(null); + + const urlResponse = await downloadUrl({ + documentId: document.documentId, + repositoryId: document.repositoryId, + }).unwrap(); + + if (fileType === 'pdf') { + // Fetch PDF as blob and create object URL with correct MIME type + // This ensures browser displays it inline instead of downloading + const response = await fetch(urlResponse); + if (!response.ok) { + throw new Error(`Failed to fetch document: ${response.status} ${response.statusText}`); + } + const blob = await response.blob(); + + // Create a new blob with the correct MIME type + const pdfBlob = new Blob([blob], { type: 'application/pdf' }); + const objectUrl = URL.createObjectURL(pdfBlob); + setDocumentUrl(objectUrl); + } else if (fileType === 'txt') { + // For text files, fetch and display content + const response = await fetch(urlResponse); + if (!response.ok) { + throw new Error(`Failed to fetch document: ${response.status} ${response.statusText}`); + } + const text = await response.text(); + setTextContent(text); + } + } catch (err) { + const errorMessage = err instanceof Error ? err.message : 'Unknown error'; + notificationService.generateNotification( + `Failed to load document: ${errorMessage}`, + 'error' + ); + setError('Failed to load document. Please try again.'); + } + }; + + const handleDownload = async () => { + if (!document) return; + + try { + const url = await downloadUrl({ + documentId: document.documentId, + repositoryId: document.repositoryId, + }).unwrap(); + + window.open(url, '_blank', 'noopener, noreferrer'); + } catch (err) { + const errorMessage = err instanceof Error ? err.message : 'Unknown error'; + notificationService.generateNotification( + `Failed to download document: ${errorMessage}`, + 'error' + ); + } + }; + + if (!visible) { + return null; + } + + return ( +
+ + +