diff --git a/.github/workflows/code.ai-review.yml b/.github/workflows/code.ai-review.yml new file mode 100644 index 000000000..d8cd48857 --- /dev/null +++ b/.github/workflows/code.ai-review.yml @@ -0,0 +1,49 @@ +name: AI Code Review + +permissions: + id-token: write + contents: read + pull-requests: write + +on: + pull_request: + pull_request_review_comment: + types: [created] + +concurrency: + group: ${{ github.repository }}-${{ github.event.number || github.head_ref || + github.sha }}-${{ github.workflow }}-${{ github.event_name == + 'pull_request_review_comment' && 'pr_comment' || 'pr' }} + cancel-in-progress: ${{ github.event_name != 'pull_request_review_comment' }} + +jobs: + review: + environment: dev + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: ${{ vars.AWS_REGION }} + role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT }}:role/${{ vars.ROLE_NAME_TO_ASSUME }} + role-session-name: GitHub_to_AWS_via_FederatedOIDC + role-duration-seconds: 7200 + - name: PR Review + uses: tmokmss/bedrock-pr-reviewer@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + debug: false + summarize: true + summarize_release_notes: true + review_file_diff: | + - Do NOT provide general feedback, summaries, explanations of changes, statements of following existing patterns, or praises for making good additions. + - Focus solely on offering specific, objective insights based on the given context and refrain from making broad comments about potential impacts on the system or question intentions behind the changes. + - Comments should have actionable changes + - Disregard formatting, stylistic, or import issues, since this should be taken care of with linters and tests + - Ignore verification comments unless there is evidence that contradicts the statement. + - Ignore functional imports for test classes. Other classes should not import within functions + review_simple_changes: false + review_comment_lgtm: false + bedrock_light_model: ${{ vars.BEDROCK_LIGHT_MODEL }} + bedrock_heavy_model: ${{ vars.BEDROCK_HEAVY_MODEL }} diff --git a/.github/workflows/code.end-to-end-test.nightly.yml b/.github/workflows/code.end-to-end-test.nightly.yml index 3db8163a7..f7eced262 100644 --- a/.github/workflows/code.end-to-end-test.nightly.yml +++ b/.github/workflows/code.end-to-end-test.nightly.yml @@ -29,6 +29,8 @@ jobs: needs: notify_e2e_start steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v4 + with: + ref: develop - name: Setup Node.js uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v4 with: diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index a2c57c880..85dd631a8 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -51,7 +51,7 @@ jobs: npm ci - name: Run tests run: | - npm run test + npm run test -ci backend-build: name: Backend Tests runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 5c733858b..93fd7d7de 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,8 @@ lib/rag/ingestion/ingestion-image/build *.code-workspace .cursor memory-bank/ +.kiro/ +.amazonq/ # Coverage Statistic Folders coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41ab94870..38efde3e4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ default_language_version: node: system + python: python3.11 repos: - repo: local hooks: @@ -9,7 +10,7 @@ repos: entry: scripts/verify-config.sh verbose: true language: script - files: config.yaml + files: config-base.yaml - repo: https://github.com/PyCQA/bandit rev: '1.7.10' @@ -49,7 +50,7 @@ repos: pass_filenames: false - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 7.0.0 hooks: - id: isort name: isort (python) @@ -65,7 +66,8 @@ repos: - id: ruff args: - --exit-non-zero-on-fix - - --per-file-ignores=test/**/*.py:E402 + - --per-file-ignores=test/**/*.py:E402,test/**/*.py:PLC0415 + - --fix exclude: \.ipynb$ - repo: https://github.com/pycqa/flake8 @@ -77,11 +79,10 @@ repos: - flake8-bugbear - flake8-comprehensions - flake8-debugger - - flake8-string-format args: - --max-line-length=120 - --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends - - --ignore=B008,E203, W503 # Ignore error for function calls in argument defaults + - --ignore=B008,E203,W503 # Ignore error for function calls in argument defaults exclude: ^(__init__.py$|.*\/__init__.py$) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08bc39d92..d9542a669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,64 @@ +# v6.0.0 +Happy Thanksgiving! We are proud to announce the launch of our next major version, 6.0.0! This launch aligns with AWS re:invent in Las Vegas from Dec 1-5th. LISA 6.0.0 includes major enhancements to LISA's RAG capabilities. It also includes a new standalone solution, LISA MCP. + +We hope you enjoy this release as much as we enjoyed building it. Please reach out to our product team via the "Contact us" button in the readme. Our product roadmap is customer driven, and we want to hear your feedback, questions, and needs as we look to 2026. + + +## Breaking Changes +- **API Token Table Migration**: The API token table has been renamed and moved from the Serve stack to the API Base stack (`LisaServeTokenTable` → `LisaApiBaseTokenTable`). **Export all existing API keys before upgrading** and recreate them in the new table after deployment. This affects admin keys, service accounts, and any programmatic API access. +- **Management Key Secret Migration**: The LISA management key secret has been moved to the API Base stack with a new name format: `${deploymentName}-management-key` (removed `lisa-` prefix). **Update any scripts or integrations that reference the secret by name.** The secret value will be auto-generated during deployment; export from AWS Secrets Manager before upgrading if you need to preserve the existing value. Code using the SSM parameter `${deploymentPrefix}/appManagementKeySecretName` will continue to work without changes. +- **Existing Bedrock Knowledge Base Repositories** must be redeployed to support the new collections infrastructure. This is a simple update operation that creates the necessary infrastructure for automatic data source collection creation. Use the repository update API or UI to redeploy existing Bedrock Knowledge Base repositories. + +## Key Features +### LISA MCP +LISA MCP is a standalone infrastructure-as-code solution that allows administrators to deploy and host any Model Context Protocol (MCP) servers directly within LISA. This enterprise hosting platform provides turn-key infrastructure deployment, automatic scaling, and secure networking, allowing organizations to build and operate custom MCP tools without managing underlying infrastructure. +#### Enterprise Hosting Capabilities +- **Turn-key Deployment**: Deploy STDIO, HTTP, or SSE MCP servers through a single API call or intuitive UI workflow, eliminating the need for manual infrastructure configuration +- **Dynamic Container Management**: Bring your own pre-built container images or point to S3 artifacts that are automatically packaged into containers at deployment time +- **Automatic Scaling**: Configure auto-scaling policies with customizable min/max capacity, CPU, and memory settings to handle varying workloads efficiently +- **Secure VPC Networking**: All MCP servers run within your private VPC with Application and Network Load Balancers, ensuring traffic never leaves your secure network boundaries +- **API Gateway Integration**: Hosted MCP servers are automatically exposed through LISA's existing API Gateway, inheriting the same authentication, authorization, and security controls (API keys, IDP lockdown, JWT group enforcement) used across the platform +#### Administrative Control +- **MCP Management UI**: Complete lifecycle management through a dedicated admin interface where administrators create, update, start, stop, and delete hosted MCP servers +- **Group-Based Access Control**: Restrict server visibility and usage to specific identity provider groups or make them available organization-wide +- **Lifecycle Automation**: Step Functions orchestrate provisioning, health monitoring, failure handling, and connection publishing with full auditability through DynamoDB status records +- **Health Monitoring**: Built-in health checks at both the container and load balancer levels ensure reliable service availability +#### Integration & Extensibility +- **External Integration Support**: Hosted MCP servers can be accessed by external agents, copilots, RPA bots, or SaaS workloads using the same API Gateway endpoints and authentication mechanisms +- **mcp-proxy Support**: STDIO servers are automatically wrapped with `mcp-proxy` and exposed over HTTP, simplifying deployment of standard MCP server implementations +- **UI & API Parity**: Manage servers through either the MCP Management admin page or REST API endpoints (`/mcp`), providing flexibility for automation and programmatic workflows +### LISA RAG Collections +LISA's RAG capabilities just got a major upgrade! We've completely reimagined how you organize and manage RAG documents with the introduction of Collections. Collections transform how you structure your RAG content. Think of repositories as filing cabinets and collections as the organized drawers within—each with its own configuration. +#### Flexible Document Organization + +- **Custom Chunking Strategies**: Configure different chunking approaches per collection (fixed-size or no chunking). If using a Bedrock Knowledge Base all service chunking strategies are supported +- **Flexible Embedding Models**: Each collection can use its own embedding model, optimizing retrieval for specific document types +- **Access Control**: Set collection-level permissions with group-based access control, making it easy to share some collections organization-wide while keeping others restricted within the same repository +- **Rich Metadata Support**: Tag documents with custom metadata at the repository, collection, or document level for powerful filtering and organization +#### Intelligent Document Lifecycle Management + +- **Smart Deletion Workflows**: Delete collections asynchronously with optimized cleanup for each supported Repository +- **Document Preservation**: User-managed documents in Bedrock Knowledge Bases are automatically preserved during collection operations, ensuring you never lose important content +- **Enhanced UI Experience**: Browse, filter, and sort collections with visual status indicators, intuitive creation wizards, and document library integration with breadcrumb navigation +- **Admin-Controlled Operations**: Collection creation, updates, and deletion are restricted to administrators while regular users can continue to view and upload documents to collections they have permission to access +- **Backward Compatibility**: Existing repositories automatically get a virtual "Default" collection using the repository's embedding model with zero downtime and no database migrations required +#### Bedrock Knowledge Base Updates + +- **Automatic Collection Creation**: Each Bedrock Knowledge Base Data Source gets its own collection with LISA's management capabilities +- **Custom Metadata & Tagging**: Add LISA's metadata to your Bedrock Knowledge Base documents for enhanced organization and filtering +### Other Enhancements +- Updated the prompt area to auto-expand from 2 rows to 20 rows when typing a large prompt. +- Updates for easier prisma client generation +- Enhanced logging in LISA Rest ECS cluster to include LiteLLM logs + +## Acknowledgements +* @bedanley +* @dustins +* @estohlmann +* @jmharold + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v5.4.0..v6.0.0 + # v5.4.0 ## Key Features @@ -32,8 +93,6 @@ Enhanced the user experience of the MCP Workbench with tool validation, error di * @estohlmann * @jmharold -**Full Changelog**: https://github.com/awslabs/LISA/compare/v5.3.2..v5.4.0 - # v5.3.2 ## Key Features diff --git a/Makefile b/Makefile index 16ce9c41f..e2706c5fd 100644 --- a/Makefile +++ b/Makefile @@ -372,4 +372,4 @@ test-coverage: --cov-report term-missing \ --cov-report html:build/coverage \ --cov-report xml:build/coverage/coverage.xml \ - --cov-fail-under 85 + --cov-fail-under 83 diff --git a/README.md b/README.md index c5b0b93cd..6bdea20b7 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,27 @@ # LLM Inference Solution for Amazon Dedicated Cloud (LISA) [![Full Documentation](https://img.shields.io/badge/Full%20Documentation-blue?style=for-the-badge&logo=Vite&logoColor=white)](https://awslabs.github.io/LISA/) +[![Contact Us](https://img.shields.io/badge/Contact%20Us-green?style=for-the-badge&logo=maildotru&logoColor=white)](mailto:lisa-product-team@amazon.com) ## What is LISA? -Our large language model (LLM) inference solution for the Amazon Dedicated Cloud (ADC), LISA, is an open source infrastructure-as-code solution. Customers deploy LISA directly into an Amazon Web Services (AWS) account. While specially designed for ADC regions that support government customers' most sensitive workloads, LISA is also compatible with commercial regions. LISA supports model self-hosting via Amazon Elastic Container Service (ECS). LISA's LiteLLM support also makes it compatible with 100+ models hosted by external model providers, including Amazon Bedrock. LISA further complements Amazon Bedrock by accelerating GenAI adoption. LISA's optional chat assistant user interface (UI) supports model management, model prompting, document summarization, chat session management, prompt libraries, retrieval augmented generation (RAG), automated document ingestion pipelines, and other advanced features. Customers can choose to integrate custom UIs directly with LISA, relying on LISA for centralized model orchestration, chat session management, and RAG. LISA is scalable and ready to support production use cases. The roadmap is customer-driven, with new capabilities launching monthly. +Our large language model (LLM) inference solution for the Amazon Dedicated Cloud (ADC), LISA, is open source infrastructure-as-code. Customers deploy it directly into an Amazon Web Services (AWS) account in any region. LISA is scalable and ready to support production use cases. + +LISA accelerates GenAI adoption by offering built-in configurability with Amazon Bedrock models, Knowledge Bases, and Guardrails. Also by offering advanced capabilities like an optional enterprise-ready chat user interface (UI) with configurable features, authentication, resource access control, centralized model orchestration via LiteLLM, model self-hosting via Amazon ECS, retrieval augmented generation (RAG), APIs, and broad model context protocol (MCP) support and features. LISA is also compatible with OpenAI’s API specification making it easily configurable with supporting solutions. For example, the Continue plugin for VSCode and JetBrains integrated development environments (IDE). + +LISA's roadmap is customer-driven, with new capabilities launching monthly. Reach out to the product team to ask questions, provide feedback, and send feature requests via the "Contact Us" button above. + ## Key Features -* **Open source**: No subscription or licensing fees. LISA costs are based on service usage. The roadmap is customer-driven with monthly releases. LISA is backed by a software development team. -* **Model Flexibility**: Bring your own models for self-hosting, or quickly configure LISA with 100+ models supported by third-party model providers, including Amazon Bedrock. +* **Open Source**: No subscription or licensing fees. LISA costs are based on service usage. +* **Ongoing Releases**: The product roadmap is customer-driven with releases typically every 2-4 weeks. LISA is backed by a software development team that builds production grade solutions to accelerate customers' GenAI adoption. +* **Model Flexibility**: Bring your own models for self-hosting, or quickly configure LISA with 100+ models supported by third-party model providers, including Amazon Bedrock and Jumpstart. * **Model Orchestration**: Centralize and standardize unique API calls to third-party model providers automatically with LISA via LiteLLM. LISA standardizes the unique API calls into the OpenAI format automatically. All that is required is an API key, model name, and API endpoint. * **Modular Components**: Accelerate GenAI adoption with secure, scalable software. LISA supports various use cases through configurable components: model serving and orchestration, chat user interface with advanced capabilities, authentication, retrieval augmented generation (RAG), Anthropic’s Model Context Protocol (MCP), and APIs. -* **CodeGen**: Supports OpenAI’s API specification, making LISA easily configurable with compatible solutions like the Continue plugin for VSCode and JetBrains integrated development environments (IDEs). This allows users to select from any LISA configured model to support LLM prompting directly in their IDE. +* **CodeGen**: LISA supports OpenAI’s API specification, making it easily configurable with compatible solutions like the Continue plugin for VSCode and JetBrains IDEs. * **FedRAMP**: Leverages FedRAMP High compliant services. + +## Major Components +LISA’s four major components include Serve, a Chat UI, RAG, and MCP. LISA Serve and LISA MCP are standalone, foundational core solutions with APIs for customers not leveraging LISA’s Chat UI. Both LISA’s Chat UI and RAG are optional components, but must be used with Serve. + +Read more in the Architecture Overview section of LISA's documentation site linked above. + ## Deployment Prerequisites ### Pre-Deployment Steps * Set up or have access to an AWS account. diff --git a/VERSION b/VERSION index 8a30e8f94..09b254e90 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -5.4.0 +6.0.0 diff --git a/cypress/package.json b/cypress/package.json index d5d4c7dd8..15c0ba574 100644 --- a/cypress/package.json +++ b/cypress/package.json @@ -17,7 +17,8 @@ "cypress:smoke:run": "cypress run --config-file cypress.smoke.config.ts", "clean": "rm -rf node_modules/", "lint:fix": "eslint --fix src/", - "format": "eslint --fix src/" + "format": "eslint --fix src/", + "test": "echo \"E2E tests run separately via cypress:e2e:run or cypress:smoke:run\"" }, "lint-staged": { "**/*.{js,jsx,ts,tsx}": [ diff --git a/cypress/src/smoke/fixtures/env.json b/cypress/src/smoke/fixtures/env.json index 32e949417..60a171023 100644 --- a/cypress/src/smoke/fixtures/env.json +++ b/cypress/src/smoke/fixtures/env.json @@ -7,5 +7,6 @@ "RESTAPI_URI": "", "RESTAPI_VERSION": "v2", "RAG_ENABLED": true, + "HOSTED_MCP_ENABLED": true, "API_BASE_URL": "/dev/" } diff --git a/cypress/src/support/adminHelpers.ts b/cypress/src/support/adminHelpers.ts index 817cc7e23..a70af5128 100644 --- a/cypress/src/support/adminHelpers.ts +++ b/cypress/src/support/adminHelpers.ts @@ -37,7 +37,7 @@ export function expandAdminMenu () { .should('be.visible'); cy.get('[role="menuitem"]') - .should('have.length', 2) + .should('have.length', 4) .then(($items) => { const labels = $items .map((_, el) => Cypress.$(el).text().trim()) @@ -45,6 +45,8 @@ export function expandAdminMenu () { expect(labels).to.deep.equal([ 'Configuration', 'Model Management', + 'RAG Management', + 'MCP Management', ]); }); } diff --git a/ecs_model_deployer/package.json b/ecs_model_deployer/package.json index f69091b27..b55b8e6e3 100644 --- a/ecs_model_deployer/package.json +++ b/ecs_model_deployer/package.json @@ -9,7 +9,7 @@ "pack:prod": "cd ./dist && npm i --omit dev", "copy-dist": "mkdir -p ../dist/ecs_model_deployer && cp -r ./dist/* ../dist/ecs_model_deployer/", "clean": "rm -rf ./dist/", - "test": "echo \"Error: no test specified\" && exit 1" + "test": "echo \"No tests for ECS model deployer package\"" }, "author": "", "license": "Apache-2.0", diff --git a/ecs_model_deployer/src/lib/ecs-model.ts b/ecs_model_deployer/src/lib/ecs-model.ts index 8e6b32b8f..f77511499 100644 --- a/ecs_model_deployer/src/lib/ecs-model.ts +++ b/ecs_model_deployer/src/lib/ecs-model.ts @@ -21,7 +21,7 @@ import { Construct } from 'constructs'; import { ECSCluster } from './ecsCluster'; import { getModelIdentifier } from './utils'; -import { Ec2Metadata, EcsClusterConfig, EcsSourceType, PartialConfig } from '../../../lib/schema'; +import { APP_MANAGEMENT_KEY, Ec2Metadata, EcsClusterConfig, EcsSourceType, PartialConfig } from '../../../lib/schema'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; // This is the amount of memory to buffer (or subtract off) from the total instance memory, if we don't include this, @@ -106,7 +106,7 @@ export class EcsModel extends Construct { MODEL_NAME: modelConfig.modelName, LOCAL_CODE_PATH: modelConfig.localModelCode, // Only needed when s5cmd is used, but just keep for now AWS_REGION: config.region ?? '', // needed for s5cmd - MANAGEMENT_KEY_NAME: StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`) + MANAGEMENT_KEY_NAME: StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`) }; if (modelConfig.modelType === 'embedding') { diff --git a/lambda/configuration/lambda_functions.py b/lambda/configuration/lambda_functions.py index 5f0ea4c81..9b763b802 100644 --- a/lambda/configuration/lambda_functions.py +++ b/lambda/configuration/lambda_functions.py @@ -40,7 +40,7 @@ def get_configuration(event: dict, context: dict) -> Dict[str, Any]: return _get_configurations(config_scope) -def _get_configurations(config_scope: str) -> dict[str, Any]: +def _get_configurations(config_scope: str) -> list[dict[str, Any]]: response = {} try: response = table.query( @@ -55,7 +55,7 @@ def _get_configurations(config_scope: str) -> dict[str, Any]: else: logger.exception("Error fetching session") - return response.get("Items", {}) # type: ignore [no-any-return] + return response.get("Items", []) # type: ignore [no-any-return] @api_wrapper diff --git a/lambda/dockerimagebuilder/__init__.py b/lambda/dockerimagebuilder/__init__.py index f50fae2a5..d0d1d4186 100644 --- a/lambda/dockerimagebuilder/__init__.py +++ b/lambda/dockerimagebuilder/__init__.py @@ -56,7 +56,9 @@ systemctl start docker # Start CloudWatch agent with configuration -/opt/aws/amazon-cloudwatch-agent/bin/amazon-cloudwatch-agent-ctl -a fetch-config -m ec2 -c file:/opt/aws/amazon-cloudwatch-agent/etc/amazon-cloudwatch-agent.json -s +/opt/aws/amazon-cloudwatch-agent/bin/amazon-cloudwatch-agent-ctl \\ + -a fetch-config -m ec2 \\ + -c file:/opt/aws/amazon-cloudwatch-agent/etc/amazon-cloudwatch-agent.json -s # Setup build environment mkdir /home/ec2-user/docker_resources @@ -71,9 +73,11 @@ function buildTagPush() { echo "Starting Docker build for {{IMAGE_ID}}" | tee -a /var/log/docker-build.log sed -iE 's/^FROM.*/FROM {{BASE_IMAGE}}/' Dockerfile - docker build -t {{IMAGE_ID}} --build-arg BASE_IMAGE={{BASE_IMAGE}} --build-arg MOUNTS3_DEB_URL={{MOUNTS3_DEB_URL}} . 2>&1 | tee -a /var/log/docker-build.log && \ - docker tag {{IMAGE_ID}} {{ECR_URI}}:{{IMAGE_ID}} 2>&1 | tee -a /var/log/docker-build.log && \ - aws --region ${AWS_REGION} ecr get-login-password | docker login --username AWS --password-stdin {{ECR_URI}} 2>&1 | tee -a /var/log/docker-build.log && \ + docker build -t {{IMAGE_ID}} --build-arg BASE_IMAGE={{BASE_IMAGE}} \\ + --build-arg MOUNTS3_DEB_URL={{MOUNTS3_DEB_URL}} . 2>&1 | tee -a /var/log/docker-build.log && \\ + docker tag {{IMAGE_ID}} {{ECR_URI}}:{{IMAGE_ID}} 2>&1 | tee -a /var/log/docker-build.log && \\ + aws --region ${AWS_REGION} ecr get-login-password | \\ + docker login --username AWS --password-stdin {{ECR_URI}} 2>&1 | tee -a /var/log/docker-build.log && \\ docker push {{ECR_URI}}:{{IMAGE_ID}} 2>&1 | tee -a /var/log/docker-build.log echo "Build completed with exit code $?" | tee -a /var/log/docker-build.log return $? diff --git a/lambda/management_key.py b/lambda/management_key.py index afe6b4053..df650bb5e 100644 --- a/lambda/management_key.py +++ b/lambda/management_key.py @@ -17,6 +17,7 @@ import json import logging import os +import string from datetime import datetime from typing import Any, Dict @@ -132,7 +133,7 @@ def test_secret(secret_arn: str, token: str) -> None: raise ValueError("New secret is invalid - too short or empty") # Additional validation - ensure it doesn't contain punctuation (as per generation config) - if any(char in new_secret for char in "!@#$%^&*()_+-=[]{}|;:,.<>?"): # noqa: P103 + if any(char in new_secret for char in string.punctuation): raise ValueError("New secret contains punctuation when it shouldn't") logger.info(f"Secret test passed for version {token}") diff --git a/lambda/mcp_server/lambda_functions.py b/lambda/mcp_server/lambda_functions.py index b193487fa..30cf822d5 100644 --- a/lambda/mcp_server/lambda_functions.py +++ b/lambda/mcp_server/lambda_functions.py @@ -16,22 +16,36 @@ import json import logging import os +import re +import uuid from decimal import Decimal from functools import reduce from typing import Any, Dict, List, Optional import boto3 from boto3.dynamodb.conditions import Attr, Key -from utilities.auth import get_username, is_admin -from utilities.common_functions import api_wrapper, get_bearer_token, get_groups, get_item, retry_config +from utilities.auth import admin_only, get_user_context +from utilities.common_functions import api_wrapper, get_bearer_token, get_item, retry_config -from .models import McpServerModel, McpServerStatus +from .models import ( + HostedMcpServerModel, + HostedMcpServerStatus, + McpServerModel, + McpServerStatus, + UpdateHostedMcpServerRequest, +) logger = logging.getLogger(__name__) # Initialize the DynamoDB resource and the table using environment variables dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) table = dynamodb.Table(os.environ["MCP_SERVERS_TABLE_NAME"]) +stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config) + + +def _normalize_server_name(name: str) -> str: + """Normalize server name to match CDK resource naming (alphanumeric only).""" + return re.sub(r"[^a-zA-Z0-9]", "", name) def replace_bearer_token_header(mcp_server: dict, replacement: str): @@ -141,7 +155,7 @@ def _get_mcp_servers( @api_wrapper def get(event: dict, context: dict) -> Any: """Retrieve a specific mcp server from DynamoDB.""" - user_id = get_username(event) + user_id, is_admin_user, groups = get_user_context(event) mcp_server_id = get_mcp_server_id(event) # Check if showPlaceholder query parameter is present @@ -157,8 +171,8 @@ def get(event: dict, context: dict) -> Any: # Check if the user is authorized to get the mcp server is_owner = item["owner"] == user_id or item["owner"] == "lisa:public" - groups = item.get("groups", []) - if is_owner or is_admin(event) or _is_member(get_groups(event), groups): + item_groups = item.get("groups", []) + if is_owner or is_admin_user or _is_member(groups, item_groups): # add extra attribute so the frontend doesn't have to determine this if is_owner: item["isOwner"] = True @@ -198,12 +212,11 @@ def _set_can_use( @api_wrapper def list(event: dict, context: dict) -> Dict[str, Any]: """List mcp servers for a user from DynamoDB.""" - user_id = get_username(event) + user_id, is_admin_user, groups = get_user_context(event) bearer_token = get_bearer_token(event) - groups = get_groups(event) - if is_admin(event): + if is_admin_user: logger.info(f"Listing all mcp servers for user {user_id} (is_admin)") return _set_can_use(_get_mcp_servers(replace_bearer_token=bearer_token), user_id, groups) @@ -217,7 +230,7 @@ def list(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper def create(event: dict, context: dict) -> Any: """Create a new mcp server in DynamoDB.""" - user_id = get_username(event) + user_id, _, _ = get_user_context(event) body = json.loads(event["body"], parse_float=Decimal) body["owner"] = ( user_id if body.get("owner", None) != "lisa:public" else body["owner"] @@ -232,7 +245,7 @@ def create(event: dict, context: dict) -> Any: @api_wrapper def update(event: dict, context: dict) -> Any: """Update an existing mcp server in DynamoDB.""" - user_id = get_username(event) + user_id, is_admin_user, groups = get_user_context(event) mcp_server_id = get_mcp_server_id(event) body = json.loads(event["body"], parse_float=Decimal) body["owner"] = user_id if body.get("owner", None) != "lisa:public" else body["owner"] @@ -249,7 +262,7 @@ def update(event: dict, context: dict) -> Any: raise ValueError(f"MCP Server {mcp_server_model} not found.") # Check if the user is authorized to update the mcp server - if is_admin(event) or item["owner"] == user_id: + if is_admin_user or item["owner"] == user_id: # Check if switching to global if item["owner"] != mcp_server_model.owner: table.delete_item(Key={"id": mcp_server_id, "owner": item["owner"]}) @@ -264,7 +277,7 @@ def update(event: dict, context: dict) -> Any: @api_wrapper def delete(event: dict, context: dict) -> Dict[str, str]: """Logically delete a mcp server from DynamoDB.""" - user_id = get_username(event) + user_id, is_admin_user, _ = get_user_context(event) mcp_server_id = get_mcp_server_id(event) # Query for the mcp server @@ -275,7 +288,7 @@ def delete(event: dict, context: dict) -> Dict[str, str]: raise ValueError(f"MCP Server {mcp_server_id} not found.") # Check if the user is authorized to delete the mcp server - if is_admin(event) or item["owner"] == user_id: + if is_admin_user or item["owner"] == user_id: logger.info(f"Deleting mcp server {mcp_server_id} for user {user_id}") table.delete_item(Key={"id": mcp_server_id, "owner": item.get("owner")}) return {"status": "ok"} @@ -286,3 +299,240 @@ def delete(event: dict, context: dict) -> Dict[str, str]: def get_mcp_server_id(event: dict) -> str: """Extract the mcp server id from the event's path parameters.""" return str(event["pathParameters"]["serverId"]) + + +@api_wrapper +@admin_only +def create_hosted_mcp_server(event: dict, context: dict) -> Any: + """Trigger the state machine to create a LISA Hosted MCP server.""" + user_id, is_admin_user, groups = get_user_context(event) + body = json.loads(event["body"], parse_float=Decimal) + body["owner"] = user_id if body.get("owner", None) != "lisa:public" else body["owner"] + body["id"] = str(uuid.uuid4()) + + # Check if the user is authorized to create Hosted MCP server + if is_admin_user: + # Validate and parse the hosted server configuration + hosted_server_model = HostedMcpServerModel(**body) + + # Check if normalized name is unique + normalized_name = _normalize_server_name(hosted_server_model.name) + if not normalized_name: + raise ValueError("Server name must contain at least one alphanumeric character.") + + # Scan all items to check for duplicate normalized names + items = [] + scan_arguments = {} + while True: + response = table.scan(**scan_arguments) + items.extend(response.get("Items", [])) + + if "LastEvaluatedKey" in response: + scan_arguments["ExclusiveStartKey"] = response["LastEvaluatedKey"] + else: + break + + # Check if any existing server has the same normalized name + for item in items: + existing_name = item.get("name", "") + existing_normalized = _normalize_server_name(existing_name) + if existing_normalized == normalized_name and item.get("id") != body["id"]: + raise ValueError( + f"Server name '{hosted_server_model.name}' conflicts with existing server '{existing_name}'. " + f"Normalized names must be unique (alphanumeric characters only)." + ) + + # persist initial record + table.put_item(Item=hosted_server_model.model_dump(exclude_none=True)) + + # kick off state machine + sfn_arn = os.environ.get("CREATE_MCP_SERVER_SFN_ARN") + if not sfn_arn: + raise ValueError("CREATE_MCP_SERVER_SFN_ARN not configured") + stepfunctions.start_execution( + stateMachineArn=sfn_arn, + input=json.dumps(hosted_server_model.model_dump(exclude_none=True)), + ) + + result = hosted_server_model.model_dump(exclude_none=True) + result["status"] = HostedMcpServerStatus.CREATING + return result + raise ValueError(f"Not authorized to create hosted MCP server. User {user_id} is not an admin.") + + +@api_wrapper +@admin_only +def list_hosted_mcp_servers(event: dict, context: dict) -> Dict[str, Any]: + """List all hosted MCP servers from DynamoDB.""" + user_id, is_admin_user, groups = get_user_context(event) + + # Check if the user is authorized to list hosted MCP servers + if is_admin_user: + logger.info(f"Listing all hosted MCP servers for user {user_id} (is_admin)") + # Get all items from the table + items = [] + scan_arguments = {} + while True: + response = table.scan(**scan_arguments) + items.extend(response.get("Items", [])) + + if "LastEvaluatedKey" in response: + scan_arguments["ExclusiveStartKey"] = response["LastEvaluatedKey"] + else: + break + + return {"Items": items} + + raise ValueError(f"Not authorized to list hosted MCP servers. User {user_id} is not an admin.") + + +@api_wrapper +@admin_only +def get_hosted_mcp_server(event: dict, context: dict) -> Any: + """Retrieve a specific hosted MCP server from DynamoDB.""" + user_id, is_admin_user, groups = get_user_context(event) + mcp_server_id = get_mcp_server_id(event) + + # Query for the mcp server + response = table.query(KeyConditionExpression=Key("id").eq(mcp_server_id), Limit=1, ScanIndexForward=False) + item = get_item(response) + + if item is None: + raise ValueError(f"Hosted MCP Server {mcp_server_id} not found.") + + # Check if the user is authorized to get the hosted mcp server + if is_admin_user: + return item + + raise ValueError(f"Not authorized to get hosted MCP server {mcp_server_id}. User {user_id} is not an admin.") + + +@api_wrapper +@admin_only +def delete_hosted_mcp_server(event: dict, context: dict) -> Any: + """Trigger the state machine to delete a LISA Hosted MCP server.""" + user_id, is_admin_user, groups = get_user_context(event) + mcp_server_id = get_mcp_server_id(event) + + # Check if server exists + response = table.query(KeyConditionExpression=Key("id").eq(mcp_server_id), Limit=1, ScanIndexForward=False) + item = get_item(response) + + if item is None: + raise ValueError(f"Hosted MCP Server {mcp_server_id} not found.") + + # Validate server status - only allow deletion if in specific states + server_status = item.get("status", "") + allowed_statuses = [ + HostedMcpServerStatus.IN_SERVICE, + HostedMcpServerStatus.STOPPED, + HostedMcpServerStatus.FAILED, + ] + if server_status not in allowed_statuses: + raise ValueError( + f"Cannot delete server {mcp_server_id} with status '{server_status}'. " + f"Only servers with status '{HostedMcpServerStatus.IN_SERVICE}', " + f"'{HostedMcpServerStatus.STOPPED}', or '{HostedMcpServerStatus.FAILED}' can be deleted." + ) + + # Kick off state machine + sfn_arn = os.environ.get("DELETE_MCP_SERVER_SFN_ARN") + if not sfn_arn: + raise ValueError("DELETE_MCP_SERVER_SFN_ARN not configured") + + stepfunctions.start_execution( + stateMachineArn=sfn_arn, + input=json.dumps({"id": mcp_server_id}), + ) + + return {"message": f"Deletion initiated for hosted MCP server {mcp_server_id}"} + + +@api_wrapper +@admin_only +def update_hosted_mcp_server(event: dict, context: dict) -> Any: + """Trigger the state machine to update a LISA Hosted MCP server.""" + user_id, is_admin_user, groups = get_user_context(event) + mcp_server_id = get_mcp_server_id(event) + + # Check if server exists + response = table.query(KeyConditionExpression=Key("id").eq(mcp_server_id), Limit=1, ScanIndexForward=False) + item = get_item(response) + + if item is None: + raise ValueError(f"Hosted MCP Server {mcp_server_id} not found.") + + server_status = item.get("status", "") + + # Validate server is not actively mutating or failed before starting + if server_status not in (HostedMcpServerStatus.IN_SERVICE, HostedMcpServerStatus.STOPPED): + raise ValueError( + f"Server cannot be updated when it is not in the '{HostedMcpServerStatus.IN_SERVICE}' or " + f"'{HostedMcpServerStatus.STOPPED}' states" + ) + + # Parse and validate update request + body = json.loads(event["body"], parse_float=Decimal) + update_request = UpdateHostedMcpServerRequest(**body) + + # Validate enable/disable state transitions + if update_request.enabled is not None: + # Force capacity changes and enable/disable operations to happen in separate requests + if update_request.autoScalingConfig is not None: + raise ValueError("Start or Stop operations and AutoScaling changes must happen in separate requests.") + # Server cannot be enabled if it isn't already stopped + if update_request.enabled and server_status != HostedMcpServerStatus.STOPPED: + raise ValueError(f"Server cannot be enabled when it is not in the '{HostedMcpServerStatus.STOPPED}' state.") + # Server cannot be stopped if it isn't already in service + elif not update_request.enabled and server_status != HostedMcpServerStatus.IN_SERVICE: + raise ValueError( + f"Server cannot be stopped when it is not in the '{HostedMcpServerStatus.IN_SERVICE}' state." + ) + + # Validate auto-scaling config + if update_request.autoScalingConfig is not None: + stack_name = item.get("stack_name") + if not stack_name: + raise ValueError("Cannot update AutoScaling Config for server that does not have a CloudFormation stack.") + + asg_config = update_request.autoScalingConfig.model_dump(exclude_none=True) + current_asg_config = item.get("autoScalingConfig", {}) + + # Validate min <= max + min_capacity = asg_config.get("minCapacity", current_asg_config.get("minCapacity", 1)) + max_capacity = asg_config.get("maxCapacity", current_asg_config.get("maxCapacity", 1)) + + if min_capacity > max_capacity: + raise ValueError(f"Min capacity ({min_capacity}) cannot be greater than max capacity ({max_capacity}).") + + # Validate min and max are positive + if min_capacity < 1: + raise ValueError("Min capacity must be at least 1.") + if max_capacity < 1: + raise ValueError("Max capacity must be at least 1.") + + # Validate container config updates + if ( + update_request.environment is not None + or update_request.cpu is not None + or update_request.memoryLimitMiB is not None + or update_request.containerHealthCheckConfig is not None + ): + stack_name = item.get("stack_name") + if not stack_name: + raise ValueError("Cannot update container config for server that does not have a CloudFormation stack.") + + # Kick off state machine + sfn_arn = os.environ.get("UPDATE_MCP_SERVER_SFN_ARN") + if not sfn_arn: + raise ValueError("UPDATE_MCP_SERVER_SFN_ARN not configured") + + # Package server ID and request payload into single payload for step functions + state_machine_payload = {"server_id": mcp_server_id, "update_payload": update_request.model_dump()} + stepfunctions.start_execution( + stateMachineArn=sfn_arn, + input=json.dumps(state_machine_payload), + ) + + # Return current server config (status will be updated by state machine) + return item diff --git a/lambda/mcp_server/models.py b/lambda/mcp_server/models.py index 7db48941b..b880bd71c 100644 --- a/lambda/mcp_server/models.py +++ b/lambda/mcp_server/models.py @@ -14,18 +14,29 @@ import uuid from datetime import datetime -from enum import Enum -from typing import List, Optional +from enum import StrEnum +from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self +from utilities.validation import validate_any_fields_defined -class McpServerStatus(str, Enum): - """Enum representing the prompt template type.""" +class HostedMcpServerStatus(StrEnum): + """Defines possible MCP server deployment states.""" + + CREATING = "Creating" + IN_SERVICE = "InService" + STARTING = "Starting" + STOPPING = "Stopping" + STOPPED = "Stopped" + UPDATING = "Updating" + DELETING = "Deleting" + FAILED = "Failed" - def __str__(self) -> str: - """Represent the enum as a string.""" - return str(self.value) + +class McpServerStatus(StrEnum): + """Enum representing the prompt template type.""" ACTIVE = "active" INACTIVE = "inactive" @@ -62,7 +73,210 @@ class McpServerModel(BaseModel): clientConfig: Optional[dict] = Field(default_factory=lambda: None) # Status of the server set by admins - status: Optional[McpServerStatus] = Field(default=McpServerStatus.INACTIVE) + status: Optional[McpServerStatus] = Field(default=McpServerStatus.ACTIVE) # Groups of the MCP server groups: Optional[List[str]] = Field(default_factory=lambda: None) + + +class LoadBalancerHealthCheckConfig(BaseModel): + """Specifies health check parameters for load balancer configuration.""" + + path: str = Field(min_length=1) + interval: int = Field(gt=0) + timeout: int = Field(gt=0) + healthyThresholdCount: int = Field(gt=0) + unhealthyThresholdCount: int = Field(gt=0) + + +class LoadBalancerConfig(BaseModel): + """Defines load balancer settings.""" + + healthCheckConfig: LoadBalancerHealthCheckConfig + + +class ContainerHealthCheckConfig(BaseModel): + """Specifies container health check parameters.""" + + command: Union[str, List[str]] + interval: int = Field(gt=0) + startPeriod: int = Field(ge=0) + timeout: int = Field(gt=0) + retries: int = Field(gt=0) + + +class AutoScalingConfig(BaseModel): + """Auto-scaling configuration for hosted MCP servers.""" + + minCapacity: int + maxCapacity: int + targetValue: Optional[int] = Field(default=None) + metricName: Optional[str] = Field(default=None) + duration: Optional[int] = Field(default=None) + cooldown: Optional[int] = Field(default=None) + + +class AutoScalingConfigUpdate(BaseModel): + """Updatable auto-scaling configuration for hosted MCP servers (all fields optional).""" + + minCapacity: Optional[int] = Field(default=None) + maxCapacity: Optional[int] = Field(default=None) + targetValue: Optional[int] = Field(default=None) + metricName: Optional[str] = Field(default=None) + duration: Optional[int] = Field(default=None) + cooldown: Optional[int] = Field(default=None) + + +class HostedMcpServerModel(BaseModel): + """ + A Pydantic model representing a hosted MCP server configuration. + This model is used for creating MCP servers that are deployed on ECS Fargate. + """ + + # Unique identifier for the mcp server + id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) + + # Timestamp of when the mcp server was created + created: Optional[str] = Field(default_factory=lambda: datetime.now().isoformat()) + + # Owner of the MCP server + owner: str + + # Name of the MCP server + name: str + + # Description of the MCP server + description: Optional[str] = Field(default_factory=lambda: None) + + # Command to start the server + startCommand: str + + # Port number (optional, used for HTTP/SSE servers) + port: Optional[int] = Field(default=None) + + # Server type: 'stdio', 'http', or 'sse' + serverType: str + + # Container image (optional) + # If provided without s3Path: use as pre-built container image + # If provided with s3Path: use as base image for building from S3 artifacts + image: Optional[str] = Field(default=None) + + # S3 path to server artifacts (binaries, Python files, etc.) + # If provided with image: image is used as base image for building + # If provided without image: default base image is used + s3Path: Optional[str] = Field(default=None) + + # Auto-scaling configuration + autoScalingConfig: AutoScalingConfig + + # Load balancer configuration (optional, will use defaults if not provided) + loadBalancerConfig: Optional[LoadBalancerConfig] = Field(default=None) + + # Container health check configuration (optional, will use defaults if not provided) + containerHealthCheckConfig: Optional[ContainerHealthCheckConfig] = Field(default=None) + + # Environment variables for the container + environment: Optional[Dict[str, str]] = Field(default_factory=lambda: None) + + # IAM role ARN for task execution (optional, will be auto-created if not provided) + taskExecutionRoleArn: Optional[str] = Field(default=None) + + # IAM role ARN for running tasks (optional, will be auto-created if not provided) + taskRoleArn: Optional[str] = Field(default=None) + + # Fargate CPU units (defaults to 256 which equals 0.25 vCPU) + cpu: Optional[int] = Field(default=256) + + # Fargate memory limit in MiB (defaults to 512 MiB) + memoryLimitMiB: Optional[int] = Field(default=512) + + # Groups of the MCP server (for authorization) + groups: Optional[List[str]] = Field(default_factory=lambda: None) + + # Status of the server + status: Optional[HostedMcpServerStatus] = Field(default=HostedMcpServerStatus.CREATING) + + +class UpdateHostedMcpServerRequest(BaseModel): + """Specifies parameters for hosted MCP server update requests.""" + + enabled: Optional[bool] = None + autoScalingConfig: Optional[AutoScalingConfigUpdate] = None + environment: Optional[Dict[str, str]] = None + containerHealthCheckConfig: Optional[ContainerHealthCheckConfig] = None + loadBalancerConfig: Optional[LoadBalancerConfig] = None + cpu: Optional[int] = None + memoryLimitMiB: Optional[int] = None + description: Optional[str] = None + groups: Optional[List[str]] = None + + @model_validator(mode="after") + def validate_update_request(self) -> Self: + """Validates update request parameters.""" + fields = [ + self.enabled, + self.autoScalingConfig, + self.environment, + self.containerHealthCheckConfig, + self.loadBalancerConfig, + self.cpu, + self.memoryLimitMiB, + self.description, + self.groups, + ] + if not validate_any_fields_defined(fields): + raise ValueError( + "At least one field out of enabled, autoScalingConfig, environment, " + "containerHealthCheckConfig, loadBalancerConfig, cpu, memoryLimitMiB, " + "description, or groups must be defined in request payload." + ) + return self + + @field_validator("autoScalingConfig") + @classmethod + def validate_autoscaling_config(cls, config: Optional[AutoScalingConfig]) -> Optional[AutoScalingConfig]: + """Validates auto-scaling configuration.""" + if config is not None and not config: + raise ValueError("The autoScalingConfig must not be null if defined in request payload.") + return config + + @field_validator("containerHealthCheckConfig") + @classmethod + def validate_container_health_check_config( + cls, config: Optional[ContainerHealthCheckConfig] + ) -> Optional[ContainerHealthCheckConfig]: + """Validates container health check configuration.""" + if config is not None and not config: + raise ValueError("The containerHealthCheckConfig must not be null if defined in request payload.") + return config + + @field_validator("loadBalancerConfig") + @classmethod + def validate_load_balancer_config(cls, config: Optional[LoadBalancerConfig]) -> Optional[LoadBalancerConfig]: + """Validates load balancer configuration.""" + if config is not None and not config: + raise ValueError("The loadBalancerConfig must not be null if defined in request payload.") + return config + + @field_validator("cpu") + @classmethod + def validate_cpu(cls, cpu: Optional[int]) -> Optional[int]: + """Validates CPU units.""" + if cpu is not None: + # Fargate CPU must be in valid units: 256, 512, 1024, 2048, 4096 + valid_cpu_values = [256, 512, 1024, 2048, 4096] + if cpu not in valid_cpu_values: + raise ValueError(f"CPU must be one of {valid_cpu_values}") + return cpu + + @field_validator("memoryLimitMiB") + @classmethod + def validate_memory(cls, memory: Optional[int]) -> Optional[int]: + """Validates memory limit.""" + if memory is not None: + if memory < 512: + raise ValueError("Memory limit must be at least 512 MiB") + if memory > 30720: + raise ValueError("Memory limit must be at most 30720 MiB") + return memory diff --git a/lambda/mcp_server/state_machine/__init__.py b/lambda/mcp_server/state_machine/__init__.py new file mode 100644 index 000000000..4139ae4d0 --- /dev/null +++ b/lambda/mcp_server/state_machine/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lambda/mcp_server/state_machine/create_mcp_server.py b/lambda/mcp_server/state_machine/create_mcp_server.py new file mode 100644 index 000000000..6947f4b18 --- /dev/null +++ b/lambda/mcp_server/state_machine/create_mcp_server.py @@ -0,0 +1,327 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handlers for CreateMcpServer state machine.""" + +import json +import logging +import os +import re +from copy import deepcopy +from datetime import datetime, UTC +from typing import Any, Dict, Optional + +import boto3 +from botocore.config import Config +from mcp_server.models import HostedMcpServerModel, HostedMcpServerStatus, McpServerStatus + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1}) +lambdaClient = boto3.client("lambda", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +cfnClient = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +ssmClient = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +mcp_servers_table = ddbResource.Table(os.environ["MCP_SERVERS_TABLE_NAME"]) + +MAX_POLLS = 60 + + +def handle_set_server_to_creating(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Set DDB entry to CREATING status.""" + logger.info(f"Setting MCP server to CREATING status: {event.get('id')}") + output_dict = deepcopy(event) + + server_id = event.get("id") + + if not server_id: + raise ValueError("Missing required field: id") + + mcp_servers_table.update_item( + Key={"id": server_id}, + UpdateExpression="SET #status = :status, last_modified = :lm", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":status": HostedMcpServerStatus.CREATING, + ":lm": int(datetime.now(UTC).timestamp()), + }, + ) + + output_dict["server_status"] = HostedMcpServerStatus.CREATING + return output_dict + + +def handle_deploy_server(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Invoke MCP server deployer to create infrastructure.""" + logger.info(f"Deploying MCP server: {event.get('id')}") + output_dict = deepcopy(event) + + try: + # Validate and build server config using Pydantic model + # Exclude fields that the deployer doesn't need (owner, description, created, status) + server_config_model = HostedMcpServerModel.model_validate(event) + server_config = server_config_model.model_dump( + exclude_none=True, exclude={"owner", "description", "created", "status"} + ) + + logger.info(f"Sending server config to deployer: {json.dumps(server_config)}") + + # Invoke the MCP server deployer + response = lambdaClient.invoke( + FunctionName=os.environ["MCP_SERVER_DEPLOYER_FN_ARN"], + Payload=json.dumps({"mcpServerConfig": server_config}), + ) + + payload = response["Payload"].read() + payload = json.loads(payload) + stack_name = payload.get("stackName", None) + + if not stack_name: + logger.error(f"MCP Server Deployer response: {payload}") + raise ValueError(f"Failed to create MCP server stack: {payload}") + + response = cfnClient.describe_stacks(StackName=stack_name) + stack_arn = response["Stacks"][0]["StackId"] + + mcp_servers_table.update_item( + Key={"id": event.get("id")}, + UpdateExpression="SET #status = :status, stack_name = :stack_name, cloudformation_stack_arn = :stack_arn," + + " last_modified = :lm", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":status": HostedMcpServerStatus.CREATING, + ":stack_name": stack_name, + ":stack_arn": stack_arn, + ":lm": int(datetime.now(UTC).timestamp()), + }, + ) + except Exception as e: + logger.error(f"Error deploying MCP server: {str(e)}") + raise Exception( + json.dumps( + { + "error": f"Error deploying MCP server: {str(e)}", + "event": event, + } + ) + ) + + output_dict["stack_name"] = stack_name + output_dict["stack_arn"] = stack_arn + output_dict["poll_count"] = 0 + output_dict["continue_polling"] = True + return output_dict + + +def handle_poll_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Poll CloudFormation stack status.""" + logger.info(f"Polling deployment status for stack: {event.get('stack_name')}") + output_dict = deepcopy(event) + + stack_name = event.get("stack_name") + stack_arn = event.get("stack_arn") + poll_count = event.get("poll_count", 0) + + if poll_count > MAX_POLLS: + raise Exception(f"Max polls exceeded for stack {stack_name}") + + try: + response = cfnClient.describe_stacks(StackName=stack_arn) + stack_status = response["Stacks"][0]["StackStatus"] + + logger.info(f"Stack {stack_name} status: {stack_status}") + + # Check if stack creation is complete + if stack_status in ["CREATE_COMPLETE", "UPDATE_COMPLETE"]: + output_dict["continue_polling"] = False + output_dict["stack_status"] = stack_status + elif stack_status.endswith("_FAILED") or stack_status.endswith("ROLLBACK_COMPLETE"): + raise Exception( + json.dumps( + { + "error": f"Stack {stack_name} failed with status: {stack_status}", + "event": event, + } + ) + ) + else: + # Still in progress + output_dict["poll_count"] = poll_count + 1 + output_dict["continue_polling"] = True + except Exception as e: + logger.error(f"Error polling stack status: {str(e)}") + raise Exception( + json.dumps( + { + "error": f"Error polling stack status: {str(e)}", + "event": event, + } + ) + ) + + return output_dict + + +def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: + """Get MCP connections table name from SSM parameter if chat is deployed.""" + try: + response = ssmClient.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") + return response["Parameter"]["Value"] + except ssmClient.exceptions.ParameterNotFound: + logger.info("MCP connections table SSM parameter not found, chat may not be deployed") + return None + except Exception as e: + logger.warning(f"Error getting MCP connections table name: {str(e)}") + return None + + +def _get_api_gateway_url(deployment_prefix: str) -> Optional[str]: + """Get API Gateway base URL from SSM parameter.""" + try: + response = ssmClient.get_parameter(Name=f"{deployment_prefix}/LisaApiUrl") + return response["Parameter"]["Value"] + except Exception as e: + logger.warning(f"Error getting API Gateway URL: {str(e)}") + return None + + +def _normalize_server_identifier(server_id: str) -> str: + """Normalize server identifier to match CDK resource naming (alphanumeric only).""" + return re.sub(r"[^a-zA-Z0-9]", "", server_id) + + +def handle_add_server_to_active(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Set server status to IN_SERVICE after successful deployment.""" + logger.info(f"Setting MCP server to IN_SERVICE: {event.get('id')}") + output_dict = deepcopy(event) + + server_id = event.get("id") + stack_name = event.get("stack_name") + name = event.get("name") + description = event.get("description") + idp_groups = event.get("groups", []) + owner = event.get("owner", "lisa:public") if idp_groups != [] else "lisa:public" + + # Update server status to IN_SERVICE + mcp_servers_table.update_item( + Key={"id": server_id}, + UpdateExpression="SET #status = :status, stack_name = :stack_name, last_modified = :lm", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":status": HostedMcpServerStatus.IN_SERVICE, + ":stack_name": stack_name, + ":lm": int(datetime.now(UTC).timestamp()), + }, + ) + + # Create connection entry in MCP Connections table if chat is deployed + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX", "") + if deployment_prefix: + mcp_connections_table_name = _get_mcp_connections_table_name(deployment_prefix) + if mcp_connections_table_name: + try: + api_gateway_url = _get_api_gateway_url(deployment_prefix) + if api_gateway_url: + # Normalize server ID to match what CDK uses for resource naming + normalized_id = _normalize_server_identifier(name) + # Construct API Gateway URL for the hosted server + server_url = f"{api_gateway_url}/mcp/{normalized_id}/mcp" + + # Format groups with "group:" prefix if not already present + formatted_groups = [] + for group in idp_groups: + if group.startswith("group:"): + formatted_groups.append(group) + else: + formatted_groups.append(f"group:{group}") + + # Create connection entry + mcp_connections_table = ddbResource.Table(mcp_connections_table_name) + connection_entry = { + "id": server_id, + "owner": owner, + "url": server_url, + "name": name, + "created": datetime.now().isoformat(), + "customHeaders": {"Authorization": "Bearer {LISA_BEARER_TOKEN}"}, + "status": McpServerStatus.ACTIVE, + } + + if description: + connection_entry["description"] = description + + if formatted_groups: + connection_entry["groups"] = formatted_groups + + mcp_connections_table.put_item(Item=connection_entry) + logger.info(f"Created MCP connection entry for server {server_id} in connections table") + else: + logger.warning("Could not get API Gateway URL, skipping connection entry creation") + except Exception as e: + logger.error(f"Error creating MCP connection entry: {str(e)}") + # Don't fail the state machine if connection entry creation fails + else: + logger.info( + "MCP connections table not found, skipping connection entry creation (chat may not be deployed)" + ) + + output_dict["server_status"] = "InService" + return output_dict + + +def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Handle failure in the state machine.""" + logger.error(f"Handling MCP server creation failure: {event}") + + # Update server status to failed + try: + # Parse the error from Step Functions + cause_data = json.loads(event["Cause"]) + error_message = cause_data["errorMessage"] + + # Try to parse the error message as JSON (for our custom exceptions) + try: + error_dict = json.loads(error_message) + if isinstance(error_dict, dict) and "error" in error_dict: + error_reason = error_dict["error"] + original_event = error_dict.get("event", event) + else: + # If it's not our expected format, use the raw error message + error_reason = str(error_dict) if error_dict else "Unknown error" + original_event = event + except (json.JSONDecodeError, TypeError): + # If error_message is not JSON, use it directly + error_reason = error_message + original_event = event + + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.error(f"Error parsing failure event: {str(e)}") + error_reason = f"Failed to parse error details: {str(e)}" + original_event = event + + logger.error(f"Failure reason: {error_reason}, ServerId: {original_event.get('id', 'unknown')}") + + mcp_servers_table.update_item( + Key={"id": original_event.get("id", "unknown")}, + UpdateExpression="SET #status = :status, error_message = :error, last_modified = :lm", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":status": HostedMcpServerStatus.FAILED, + ":error": event.get("error", "Unknown error"), + ":lm": int(datetime.now(UTC).timestamp()), + }, + ) + + return event diff --git a/lambda/mcp_server/state_machine/delete_mcp_server.py b/lambda/mcp_server/state_machine/delete_mcp_server.py new file mode 100644 index 000000000..b5985220e --- /dev/null +++ b/lambda/mcp_server/state_machine/delete_mcp_server.py @@ -0,0 +1,179 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handlers for DeleteMcpServer state machine.""" + +import logging +import os +from copy import deepcopy +from datetime import datetime, UTC +from typing import Any, Dict, Optional +from uuid import uuid4 + +import boto3 +from boto3.dynamodb.conditions import Attr +from botocore.config import Config +from botocore.exceptions import ClientError +from mcp_server.models import HostedMcpServerStatus + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1}) +cfnClient = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +ssmClient = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +mcp_servers_table = ddbResource.Table(os.environ["MCP_SERVERS_TABLE_NAME"]) + +# DDB and Payload fields +STACK_NAME = "stack_name" +STACK_ARN = "cloudformation_stack_arn" + + +def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: + """Get MCP connections table name from SSM parameter if chat is deployed.""" + try: + response = ssmClient.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") + return response["Parameter"]["Value"] + except ssmClient.exceptions.ParameterNotFound: + logger.info("MCP connections table SSM parameter not found, chat may not be deployed") + return None + except Exception as e: + logger.warning(f"Error getting MCP connections table name: {str(e)}") + return None + + +def handle_set_server_to_deleting(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Start deletion workflow based on user-specified server input.""" + output_dict = deepcopy(event) + server_id = event["id"] + logger.info(f"Starting deletion workflow for MCP server: {server_id}") + server_key = {"id": server_id} + item = mcp_servers_table.get_item( + Key=server_key, + ConsistentRead=True, + ReturnConsumedCapacity="NONE", + ).get("Item", None) + if not item: + raise RuntimeError(f"Requested MCP server '{server_id}' was not found in DynamoDB table.") + stack_name = item.get(STACK_NAME, None) + stack_arn = item.get(STACK_ARN, None) + # Convert stack name to ARN if stack_name exists + if stack_name: + output_dict[STACK_NAME] = stack_name + output_dict[STACK_ARN] = stack_arn # Use stack name as ARN + else: + output_dict[STACK_ARN] = None + + mcp_servers_table.update_item( + Key=server_key, + UpdateExpression="SET last_modified = :lm, #status = :ms", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":lm": int(datetime.now(UTC).timestamp()), + ":ms": HostedMcpServerStatus.DELETING, + }, + ) + return output_dict + + +def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Initialize stack deletion.""" + output_dict = deepcopy(event) + stack_arn = event.get(STACK_ARN) + if not stack_arn: + raise ValueError("Stack arn not found in event") + + # Get the actual stack ARN before deleting + + logger.info(f"Deleting CloudFormation stack: {stack_arn}") + client_request_token = str(uuid4()) + cfnClient.delete_stack( + StackName=stack_arn, + ClientRequestToken=client_request_token, + ) + return output_dict + + +def handle_monitor_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Get stack status while it is being deleted and evaluate if state machine should continue polling.""" + output_dict = deepcopy(event) + # Prefer ARN if available, fall back to stack name + stack_identifier = event.get(STACK_ARN) or event.get(STACK_NAME) + if not stack_identifier: + raise ValueError("Stack ARN or name not found in event") + + try: + stack_metadata = cfnClient.describe_stacks(StackName=stack_identifier)["Stacks"][0] + stack_status = stack_metadata["StackStatus"] + continue_polling = True # stack not done yet, so continue monitoring + if stack_status == "DELETE_COMPLETE": + continue_polling = False # stack finished, allow state machine to stop polling + elif stack_status.endswith("COMPLETE") or stack_status.endswith("FAILED"): + # Didn't expect anything else, so raise error to fail state machine + raise RuntimeError(f"Stack entered unexpected terminal state '{stack_status}'.") + except ClientError as e: + # Check if the error is because the stack doesn't exist (ValidationError) + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "ValidationError": + # Stack doesn't exist - this means it was successfully deleted + # CloudFormation removes stacks completely after DELETE_COMPLETE, so ValidationError is expected + logger.info(f"Stack {stack_identifier} no longer exists (successfully deleted)") + continue_polling = False # Stack is gone, deletion is complete + else: + # Re-raise unexpected ClientErrors + logger.error(f"Error monitoring stack deletion: {str(e)}") + raise + except Exception as e: + # Re-raise unexpected errors + logger.error(f"Error monitoring stack deletion: {str(e)}") + raise + + output_dict["continue_polling"] = continue_polling + return output_dict + + +def handle_delete_from_ddb(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Delete item from DDB after successful deletion workflow and remove from connections table.""" + server_id = event["id"] + server_key = {"id": server_id} + + # Delete from MCP Connections table if chat is deployed + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX", "") + if deployment_prefix: + mcp_connections_table_name = _get_mcp_connections_table_name(deployment_prefix) + if mcp_connections_table_name: + try: + mcp_connections_table = ddbResource.Table(mcp_connections_table_name) + # The connections table uses (id, owner) as composite key + # We need to query/scan to find the entry with this server ID + # Since we don't know the owner, we'll scan for the id + response = mcp_connections_table.scan(FilterExpression=Attr("id").eq(server_id)) + + # Delete the matching item (there should only be one) + for item in response.get("Items", []): + mcp_connections_table.delete_item(Key={"id": item["id"], "owner": item["owner"]}) + logger.info( + f"Deleted MCP connection entry for server {server_id} (owner: {item['owner']}) " + + "from connections table" + ) + except Exception as e: + logger.warning(f"Error deleting from MCP connections table: {str(e)}") + # Continue with deletion from main table even if connections table deletion fails + + # Delete from main MCP servers table + mcp_servers_table.delete_item(Key=server_key) + logger.info(f"Deleted MCP server {server_id} from DynamoDB table") + + return event diff --git a/lambda/mcp_server/state_machine/update_mcp_server.py b/lambda/mcp_server/state_machine/update_mcp_server.py new file mode 100644 index 000000000..e058ac34b --- /dev/null +++ b/lambda/mcp_server/state_machine/update_mcp_server.py @@ -0,0 +1,1000 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handlers for UpdateMcpServer state machine.""" + +import logging +import os +import re +from copy import deepcopy +from datetime import datetime, UTC +from typing import Any, Callable, Dict, List, Optional + +import boto3 +from boto3.dynamodb.conditions import Attr +from botocore.config import Config +from mcp_server.models import HostedMcpServerStatus, McpServerStatus + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1}) +ddbResource = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +mcp_servers_table = ddbResource.Table(os.environ["MCP_SERVERS_TABLE_NAME"]) +ecs_client = boto3.client("ecs", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +cfn_client = boto3.client("cloudformation", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=lambdaConfig) +application_autoscaling_client = boto3.client( + "application-autoscaling", region_name=os.environ["AWS_REGION"], config=lambdaConfig +) + +MAX_POLLS = 30 + + +def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: + """Get MCP connections table name from SSM parameter if chat is deployed.""" + try: + response = ssm_client.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") + return response["Parameter"]["Value"] + except ssm_client.exceptions.ParameterNotFound: + logger.info("MCP connections table SSM parameter not found, chat may not be deployed") + return None + except Exception as e: + logger.warning(f"Error getting MCP connections table name: {str(e)}") + return None + + +def _normalize_server_identifier(server_id: str) -> str: + """Normalize server identifier to match CDK resource naming (alphanumeric only).""" + return re.sub(r"[^a-zA-Z0-9]", "", server_id) + + +def _update_simple_field(server_config: Dict[str, Any], field_name: str, value: Any, server_id: str) -> None: + """Update a simple field in server_config.""" + logger.info(f"Setting {field_name} to '{value}' for server '{server_id}'") + server_config[field_name] = value + + +def _update_container_config( + server_config: Dict[str, Any], container_config: Dict[str, Any], server_id: str +) -> Dict[str, Any]: + """Handle container config update. + + Returns: + Dict containing any container config metadata needed for ECS updates + """ + logger.info(f"Updating container configuration for server '{server_id}'") + + container_metadata = {} + + if container_config.get("environment") is not None: + env_vars = container_config["environment"] + env_vars_to_delete = [] + + # Handle environment variable deletion markers + for key, value in env_vars.items(): + if value == "LISA_MARKED_FOR_DELETION": + env_vars_to_delete.append(key) + + for key in env_vars_to_delete: + del env_vars[key] + + server_config["environment"] = env_vars + + # Store deletion info for ECS update + if env_vars_to_delete: + container_metadata["env_vars_to_delete"] = env_vars_to_delete + logger.info(f"Deleted environment variables for server '{server_id}': {env_vars_to_delete}") + logger.info(f"Updated environment variables for server '{server_id}': {env_vars}") + + # Update CPU + if container_config.get("cpu") is not None: + server_config["cpu"] = int(container_config["cpu"]) + logger.info(f"Updated CPU for server '{server_id}': {container_config['cpu']}") + + # Update memory + if container_config.get("memoryLimitMiB") is not None: + server_config["memoryLimitMiB"] = int(container_config["memoryLimitMiB"]) + logger.info(f"Updated memory for server '{server_id}': {container_config['memoryLimitMiB']}") + + # Update container health check configuration + health_check_updates = {} + if container_config.get("containerHealthCheckConfig") is not None: + health_check_config = container_config["containerHealthCheckConfig"] + if health_check_config.get("command") is not None: + health_check_updates["command"] = health_check_config["command"] + if health_check_config.get("interval") is not None: + health_check_updates["interval"] = health_check_config["interval"] + if health_check_config.get("timeout") is not None: + health_check_updates["timeout"] = health_check_config["timeout"] + if health_check_config.get("startPeriod") is not None: + health_check_updates["startPeriod"] = health_check_config["startPeriod"] + if health_check_config.get("retries") is not None: + health_check_updates["retries"] = health_check_config["retries"] + + if health_check_updates: + server_config["containerHealthCheckConfig"] = health_check_updates + logger.info( + f"Updated container health check configuration for server '{server_id}': {health_check_updates}" + ) + + # Update load balancer health check configuration + if container_config.get("loadBalancerConfig") is not None: + server_config["loadBalancerConfig"] = container_config["loadBalancerConfig"] + logger.info(f"Updated load balancer configuration for server '{server_id}'") + + return container_metadata + + +def _get_metadata_update_handlers(server_config: Dict[str, Any], server_id: str) -> Dict[str, Callable[..., Any]]: + """Return a dictionary mapping field names to their update handlers.""" + return { + "description": lambda value: _update_simple_field(server_config, "description", value, server_id), + "groups": lambda value: _update_simple_field(server_config, "groups", value, server_id), + "environment": lambda value: _update_container_config(server_config, {"environment": value}, server_id), + "cpu": lambda value: _update_container_config(server_config, {"cpu": value}, server_id), + "memoryLimitMiB": lambda value: _update_container_config(server_config, {"memoryLimitMiB": value}, server_id), + "containerHealthCheckConfig": lambda value: _update_container_config( + server_config, {"containerHealthCheckConfig": value}, server_id + ), + "loadBalancerConfig": lambda value: _update_container_config( + server_config, {"loadBalancerConfig": value}, server_id + ), + } + + +def _process_metadata_updates( + server_config: Dict[str, Any], update_payload: Dict[str, Any], server_id: str +) -> tuple[bool, Dict[str, Any]]: + """ + Process metadata updates. + + Args: + server_config: The server configuration dictionary to update + update_payload: The payload containing updates + server_id: The server ID for logging purposes + + Returns: + tuple: (has_updates: bool, metadata: Dict containing update metadata) + """ + update_handlers = _get_metadata_update_handlers(server_config, server_id) + has_updates = False + update_metadata = {} + + # Handle container config fields specially + container_config_fields = [ + "environment", + "cpu", + "memoryLimitMiB", + "containerHealthCheckConfig", + "loadBalancerConfig", + ] + container_config_updates = {} + + for field_name in container_config_fields: + if field_name in update_payload and update_payload[field_name] is not None: + container_config_updates[field_name] = update_payload[field_name] + has_updates = True + + if container_config_updates: + container_metadata = _update_container_config(server_config, container_config_updates, server_id) + if container_metadata: + update_metadata["container"] = container_metadata + + # Handle simple fields + simple_fields = ["description", "groups"] + for field_name in simple_fields: + if field_name in update_payload and update_payload[field_name] is not None: + update_handlers[field_name](update_payload[field_name]) + has_updates = True + + return has_updates, update_metadata + + +def _update_mcp_connections_table_status(server_id: str, status: str) -> None: + """Update MCP Connections table status for a server.""" + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX", "") + if not deployment_prefix: + logger.info("No deployment prefix found, skipping MCP Connections table update") + return + + mcp_connections_table_name = _get_mcp_connections_table_name(deployment_prefix) + if not mcp_connections_table_name: + logger.info("MCP connections table not found, skipping status update") + return + + try: + mcp_connections_table = ddbResource.Table(mcp_connections_table_name) + # Scan for the connection entry with this server ID + response = mcp_connections_table.scan(FilterExpression=Attr("id").eq(server_id)) + + # Update the matching item(s) + for item in response.get("Items", []): + mcp_connections_table.update_item( + Key={"id": item["id"], "owner": item["owner"]}, + UpdateExpression="SET #status = :status", + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={":status": status}, + ) + logger.info(f"Updated MCP connection status for server {server_id} (owner: {item['owner']}) to {status}") + except Exception as e: + logger.warning(f"Error updating MCP Connections table status: {str(e)}") + # Don't fail the update if connection table update fails + + +def _update_mcp_connections_table_metadata( + server_id: str, description: Optional[str] = None, groups: Optional[List[str]] = None +) -> None: + """Update MCP Connections table metadata (description, groups) for a server.""" + deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX", "") + if not deployment_prefix: + logger.info("No deployment prefix found, skipping MCP Connections metadata update") + return + + mcp_connections_table_name = _get_mcp_connections_table_name(deployment_prefix) + if not mcp_connections_table_name: + logger.info("MCP connections table not found, skipping metadata update") + return + + # Nothing to update + if description is None and groups is None: + return + + # Format groups with "group:" prefix if not already present + formatted_groups: Optional[List[str]] = None + if groups is not None: + formatted_groups = [] + for group in groups: + if group.startswith("group:"): + formatted_groups.append(group) + else: + formatted_groups.append(f"group:{group}") + + try: + mcp_connections_table = ddbResource.Table(mcp_connections_table_name) + # Scan for the connection entry with this server ID + response = mcp_connections_table.scan(FilterExpression=Attr("id").eq(server_id)) + + for item in response.get("Items", []): + update_expression_parts = [] + expr_attr_names: Dict[str, str] = {} + expr_attr_values: Dict[str, Any] = {} + + if description is not None: + update_expression_parts.append("#d = :desc") + expr_attr_names["#d"] = "description" + expr_attr_values[":desc"] = description + + if formatted_groups is not None: + update_expression_parts.append("#g = :groups") + expr_attr_names["#g"] = "groups" + expr_attr_values[":groups"] = formatted_groups + + if update_expression_parts: + update_expression = "SET " + ", ".join(update_expression_parts) + mcp_connections_table.update_item( + Key={"id": item["id"], "owner": item["owner"]}, + UpdateExpression=update_expression, + ExpressionAttributeNames=expr_attr_names if expr_attr_names else None, + ExpressionAttributeValues=expr_attr_values if expr_attr_values else None, + ) + logger.info( + f"Updated MCP connection metadata for server {server_id} (owner: {item['owner']}): " + f"{'description ' if description is not None else ''}" + f"{'groups ' if groups is not None else ''}".strip() + ) + except Exception as e: + logger.warning(f"Error updating MCP Connections table metadata: {str(e)}") + # Don't fail the update if connection table update fails + + +def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Handle initial UpdateMcpServer job submission. + + This handler will perform the following actions: + 1. Determine if any metadata (description, groups, environment, etc.) changes are required + 2. Determine if any AutoScaling changes are required + 3. Determine if enable/disable operation is required + 4. Commit changes to the database + """ + output_dict = deepcopy(event) + + server_id = event["server_id"] + logger.info(f"Processing UpdateMcpServer request for '{server_id}' with payload: {event}") + server_key = {"id": server_id} + ddb_item = mcp_servers_table.get_item( + Key=server_key, + ConsistentRead=True, + ).get("Item", None) + if not ddb_item: + raise RuntimeError(f"Requested server '{server_id}' was not found in DynamoDB table.") + + server_status = ddb_item.get("status") + stack_name = ddb_item.get("stack_name", None) + + if not stack_name: + raise RuntimeError("Cannot update server that does not have a CloudFormation stack.") + + output_dict["stack_name"] = stack_name + + # Two checks for enabling: check that value was not omitted, then check that it was actually True. + is_activation_request = event["update_payload"].get("enabled", None) is not None + is_enable = event["update_payload"].get("enabled", False) + is_disable = is_activation_request and not is_enable + + is_autoscaling_update = event["update_payload"].get("autoScalingConfig", None) is not None + + if is_activation_request and is_autoscaling_update: + raise RuntimeError( + "Cannot request AutoScaling updates at the same time as an enable or disable operation. " + "Please perform those as two separate actions." + ) + + # set up DDB update expression to accumulate info as more options are processed + ddb_update_expression = "SET #status = :ms, last_modified = :lm" + ddb_update_values = { + ":ms": HostedMcpServerStatus.UPDATING, + ":lm": int(datetime.now(UTC).timestamp()), + } + ExpressionAttributeNames = {"#status": "status"} + + # Process metadata updates (description, groups, environment, CPU, memory, health checks) + server_config = {} + # Copy existing server config fields that can be updated + for field in [ + "description", + "groups", + "environment", + "cpu", + "memoryLimitMiB", + "containerHealthCheckConfig", + "loadBalancerConfig", + "autoScalingConfig", + ]: + if field in ddb_item: + server_config[field] = ddb_item[field] + + # Ensure autoScalingConfig exists if we're going to update it + if is_autoscaling_update and "autoScalingConfig" not in server_config: + raise RuntimeError("Cannot update auto-scaling config for server that does not have auto-scaling configured.") + + has_metadata_update, update_metadata = _process_metadata_updates(server_config, event["update_payload"], server_id) + + if is_activation_request: + logger.info(f"Detected enable or disable activity for '{server_id}'") + if is_enable: + if server_status != HostedMcpServerStatus.STOPPED: + raise RuntimeError( + f"Server cannot be enabled when it is not in the '{HostedMcpServerStatus.STOPPED}' state." + ) + + # Set status to Starting instead of Updating to signify that it can't be accessed by a user yet + ddb_update_values[":ms"] = HostedMcpServerStatus.STARTING + + # Get current auto-scaling config to determine min capacity + if "autoScalingConfig" not in server_config: + raise RuntimeError("Cannot enable server that does not have auto-scaling configured.") + min_capacity = server_config["autoScalingConfig"].get("minCapacity", 1) + logger.info(f"Starting server '{server_id}' with min capacity of {min_capacity}.") + + # Store min capacity for later use in handle_finish_update + output_dict["min_capacity"] = min_capacity + + # Scale ECS service to min capacity now (will be polled in handle_poll_capacity) + try: + service_arn, cluster_arn, _ = get_ecs_resources_from_stack(stack_name) + ecs_client.update_service(cluster=cluster_arn, service=service_arn, desiredCount=int(min_capacity)) + logger.info(f"Scaled ECS service to {min_capacity} for server '{server_id}'") + except Exception as e: + logger.error(f"Error scaling ECS service to {min_capacity}: {str(e)}") + raise RuntimeError(f"Failed to scale ECS service to {min_capacity}: {str(e)}") + else: + # Only if we are deactivating a server, we update MCP Connections table + logger.info(f"Updating MCP Connections table for server '{server_id}' because of 'disable' activity.") + _update_mcp_connections_table_status(server_id, McpServerStatus.INACTIVE) + + # set status to Stopping instead of Updating to signify why it was removed + ddb_update_values[":ms"] = HostedMcpServerStatus.STOPPING + + # Scale ECS service to 0 (will be handled immediately) + output_dict["desired_capacity"] = 0 + + # Update ECS service immediately for disable + try: + # Get ECS resources from stack + service_arn, cluster_arn, _ = get_ecs_resources_from_stack(stack_name) + + # Also set Application Auto Scaling target min/max to 0 to prevent immediate scale-up by policies + try: + service_name = service_arn.split("/")[-1] + cluster_name = cluster_arn.split("/")[-1] + scalable_target_id = f"service/{cluster_name}/{service_name}" + application_autoscaling_client.register_scalable_target( + ServiceNamespace="ecs", + ResourceId=scalable_target_id, + ScalableDimension="ecs:service:DesiredCount", + MinCapacity=0, + MaxCapacity=0, + ) + logger.info( + "Updated scalable target to MinCapacity=0, MaxCapacity=0 for server " + + f"'{server_id}' ({scalable_target_id})" + ) + except Exception as e: + logger.warning(f"Could not update scalable target to 0 for server '{server_id}': {str(e)}") + + # Update service to 0 desired count + ecs_client.update_service(cluster=cluster_arn, service=service_arn, desiredCount=0) + logger.info(f"Scaled ECS service to 0 for server '{server_id}'") + except Exception as e: + logger.error(f"Error scaling ECS service to 0: {str(e)}") + raise RuntimeError(f"Failed to scale ECS service to 0: {str(e)}") + + if is_autoscaling_update: + asg_config = event["update_payload"]["autoScalingConfig"] + # Stage metadata updates regardless of immediate capacity changes or not + # Merge updates with existing config (autoScalingConfig already exists, validated above) + updated_min_capacity = None + updated_max_capacity = None + if minCapacity := asg_config.get("minCapacity"): + server_config["autoScalingConfig"]["minCapacity"] = int(minCapacity) + updated_min_capacity = int(minCapacity) + if maxCapacity := asg_config.get("maxCapacity"): + server_config["autoScalingConfig"]["maxCapacity"] = int(maxCapacity) + updated_max_capacity = int(maxCapacity) + if cooldown := asg_config.get("cooldown"): + server_config["autoScalingConfig"]["cooldown"] = int(cooldown) + if targetValue := asg_config.get("targetValue"): + server_config["autoScalingConfig"]["targetValue"] = int(targetValue) + if metricName := asg_config.get("metricName"): + server_config["autoScalingConfig"]["metricName"] = metricName + if duration := asg_config.get("duration"): + server_config["autoScalingConfig"]["duration"] = int(duration) + + # If server is running, apply update immediately via Application Auto Scaling + if server_status == HostedMcpServerStatus.IN_SERVICE: + try: + # Get ECS resources from stack + service_arn, cluster_arn, _ = get_ecs_resources_from_stack(stack_name) + + # Get service name from ARN + service_name = service_arn.split("/")[-1] + cluster_name = cluster_arn.split("/")[-1] + + # Update scalable target + scalable_target_id = f"service/{cluster_name}/{service_name}" + + update_params = { + "ServiceNamespace": "ecs", + "ResourceId": scalable_target_id, + "ScalableDimension": "ecs:service:DesiredCount", + } + + # Use updated values if provided, otherwise use current values from server_config + if updated_min_capacity is not None: + update_params["MinCapacity"] = updated_min_capacity + else: + update_params["MinCapacity"] = server_config["autoScalingConfig"].get("minCapacity", 1) + + if updated_max_capacity is not None: + update_params["MaxCapacity"] = updated_max_capacity + else: + update_params["MaxCapacity"] = server_config["autoScalingConfig"].get("maxCapacity", 1) + + application_autoscaling_client.register_scalable_target(**update_params) + logger.info(f"Updated auto-scaling configuration for server '{server_id}': {update_params}") + + # Note: Scaling policies would need to be updated separately if they changed + # For now, we only update min/max capacity + except Exception as e: + logger.error(f"Error updating auto-scaling configuration: {str(e)}") + raise RuntimeError(f"Failed to update auto-scaling configuration: {str(e)}") + + has_metadata_update = True + + if has_metadata_update: + # Update server config in DynamoDB - only include fields that were actually updated + update_payload = event["update_payload"] + + if "description" in update_payload and update_payload["description"] is not None: + ddb_update_expression += ", description = :desc" + ddb_update_values[":desc"] = server_config.get("description") + + if "groups" in update_payload and update_payload["groups"] is not None: + ddb_update_expression += ", groups = :groups" + ddb_update_values[":groups"] = server_config.get("groups", []) + + # Update container config fields if they were in the update payload + if "environment" in update_payload and update_payload["environment"] is not None: + ddb_update_expression += ", environment = :env" + ddb_update_values[":env"] = server_config.get("environment", {}) + if "cpu" in update_payload and update_payload["cpu"] is not None: + ddb_update_expression += ", cpu = :cpu" + ddb_update_values[":cpu"] = server_config.get("cpu") + if "memoryLimitMiB" in update_payload and update_payload["memoryLimitMiB"] is not None: + ddb_update_expression += ", memoryLimitMiB = :memory" + ddb_update_values[":memory"] = server_config.get("memoryLimitMiB") + if "containerHealthCheckConfig" in update_payload and update_payload["containerHealthCheckConfig"] is not None: + ddb_update_expression += ", containerHealthCheckConfig = :health" + ddb_update_values[":health"] = server_config.get("containerHealthCheckConfig") + if "loadBalancerConfig" in update_payload and update_payload["loadBalancerConfig"] is not None: + ddb_update_expression += ", loadBalancerConfig = :lb" + ddb_update_values[":lb"] = server_config.get("loadBalancerConfig") + if "autoScalingConfig" in update_payload and update_payload["autoScalingConfig"] is not None: + ddb_update_expression += ", autoScalingConfig = :asg" + ddb_update_values[":asg"] = server_config.get("autoScalingConfig") + + # Pass through container metadata for ECS updates + if update_metadata.get("container"): + output_dict["container_metadata"] = update_metadata["container"] + + logger.info(f"Server '{server_id}' update expression: {ddb_update_expression}") + logger.info(f"Server '{server_id}' update values: {list(ddb_update_values.keys())}") + + mcp_servers_table.update_item( + Key=server_key, + UpdateExpression=ddb_update_expression, + ExpressionAttributeValues=ddb_update_values, + ExpressionAttributeNames=ExpressionAttributeNames, + ) + + # If metadata changed, reflect updates in MCP Connections table + try: + update_payload = event["update_payload"] + should_update_description = "description" in update_payload and update_payload["description"] is not None + should_update_groups = "groups" in update_payload and update_payload["groups"] is not None + if should_update_description or should_update_groups: + _update_mcp_connections_table_metadata( + server_id=server_id, + description=server_config.get("description") if should_update_description else None, + groups=server_config.get("groups") if should_update_groups else None, + ) + except Exception as e: + # Don't fail the update flow for connection metadata update errors + logger.warning(f"Non-fatal error updating MCP Connections metadata: {str(e)}") + + # Determine if ECS update is needed (container config changes for running servers) + needs_ecs_update = ( + event["update_payload"].get("environment") is not None + or event["update_payload"].get("cpu") is not None + or event["update_payload"].get("memoryLimitMiB") is not None + or event["update_payload"].get("containerHealthCheckConfig") is not None + ) and server_status == HostedMcpServerStatus.IN_SERVICE + + # We only need to poll for activation so that we know when to update the MCP Connections table + # For Hosted MCP servers, also poll on disable to wait for tasks to deprovision fully + output_dict["has_capacity_update"] = is_enable or is_disable + output_dict["is_disable"] = is_disable + output_dict["needs_ecs_update"] = needs_ecs_update + output_dict["initial_server_status"] = server_status # needed for simple metadata updates + output_dict["current_server_status"] = ddb_update_values[":ms"] # for state machine debugging / visibility + + return output_dict + + +def get_ecs_resources_from_stack(stack_name: str) -> tuple[str, str, str]: + """Extract ECS service name, cluster name, and current task definition ARN from CloudFormation.""" + try: + resources = cfn_client.describe_stack_resources(StackName=stack_name)["StackResources"] + + service_arn = None + cluster_arn = None + + for resource in resources: + if resource["ResourceType"] == "AWS::ECS::Service": + service_arn = resource["PhysicalResourceId"] + elif resource["ResourceType"] == "AWS::ECS::Cluster": + cluster_arn = resource["PhysicalResourceId"] + + if not service_arn or not cluster_arn: + raise RuntimeError(f"Could not find ECS service or cluster in stack {stack_name}") + + # Get current task definition from service + service_info = ecs_client.describe_services(cluster=cluster_arn, services=[service_arn])["services"][0] + + current_task_def_arn = service_info["taskDefinition"] + + return service_arn, cluster_arn, current_task_def_arn + + except Exception as e: + logger.error(f"Error getting ECS resources from stack {stack_name}: {str(e)}") + raise RuntimeError(f"Failed to get ECS resources from CloudFormation stack: {str(e)}") + + +def create_updated_task_definition( + task_definition_arn: str, + updated_env_vars: Optional[Dict[str, str]] = None, + env_vars_to_delete: Optional[List[str]] = None, + updated_cpu: Optional[int] = None, + updated_memory: Optional[int] = None, + updated_health_check: Optional[Dict[str, Any]] = None, +) -> str: + """Create new task definition revision with updated configuration. + + Args: + task_definition_arn: ARN of the current task definition + updated_env_vars: Environment variables to add/update + env_vars_to_delete: List of environment variable names to delete + updated_cpu: Updated CPU units + updated_memory: Updated memory limit in MiB + updated_health_check: Updated container health check configuration + """ + try: + if env_vars_to_delete is None: + env_vars_to_delete = [] + if updated_env_vars is None: + updated_env_vars = {} + + # Get current task definition + task_def_response = ecs_client.describe_task_definition(taskDefinition=task_definition_arn) + task_def = task_def_response["taskDefinition"] + + # Create new task definition with updated configuration + new_task_def = { + "family": task_def["family"], + "volumes": task_def.get("volumes", []), + "containerDefinitions": [], + } + + # Add optional fields only if they have valid values + if task_def.get("taskRoleArn"): + new_task_def["taskRoleArn"] = task_def["taskRoleArn"] + if task_def.get("executionRoleArn"): + new_task_def["executionRoleArn"] = task_def["executionRoleArn"] + if task_def.get("networkMode"): + new_task_def["networkMode"] = task_def["networkMode"] + if task_def.get("requiresCompatibilities"): + new_task_def["requiresCompatibilities"] = task_def["requiresCompatibilities"] + + # Update CPU and memory if provided + if updated_cpu is not None: + new_task_def["cpu"] = str(updated_cpu) + elif task_def.get("cpu") is not None: + new_task_def["cpu"] = str(task_def["cpu"]) + + if updated_memory is not None: + new_task_def["memory"] = str(updated_memory) + elif task_def.get("memory") is not None: + new_task_def["memory"] = str(task_def["memory"]) + + # Update container definitions + for container in task_def["containerDefinitions"]: + new_container = container.copy() + + # Start with existing environment variables from the task definition + existing_env = {env["name"]: env["value"] for env in container.get("environment", [])} + logger.info(f"Existing environment variables: {list(existing_env.keys())}") + + # Apply updates/additions + existing_env.update(updated_env_vars) + logger.info(f"Environment variables after update: {list(existing_env.keys())}") + + # Remove deleted variables + for var_name in env_vars_to_delete: + if var_name in existing_env: + del existing_env[var_name] + logger.info(f"Deleted environment variable: {var_name}") + + logger.info(f"Final environment variables: {list(existing_env.keys())}") + + # Set the new environment variables + new_container["environment"] = [{"name": key, "value": value} for key, value in existing_env.items()] + + # Update health check configuration if provided + if updated_health_check: + current_health_check = new_container.get("healthCheck", {}) + + # Update individual health check fields + if updated_health_check.get("command") is not None: + current_health_check["command"] = updated_health_check["command"] + if updated_health_check.get("interval") is not None: + current_health_check["interval"] = int(updated_health_check["interval"]) + if updated_health_check.get("timeout") is not None: + current_health_check["timeout"] = int(updated_health_check["timeout"]) + if updated_health_check.get("startPeriod") is not None: + current_health_check["startPeriod"] = int(updated_health_check["startPeriod"]) + if updated_health_check.get("retries") is not None: + current_health_check["retries"] = int(updated_health_check["retries"]) + + new_container["healthCheck"] = current_health_check + logger.info(f"Updated container health check: {current_health_check}") + + new_task_def["containerDefinitions"].append(new_container) + + # Register new task definition + response = ecs_client.register_task_definition(**new_task_def) + new_task_def_arn = str(response["taskDefinition"]["taskDefinitionArn"]) + + logger.info(f"Created new task definition: {new_task_def_arn}") + return new_task_def_arn + + except Exception as e: + logger.error(f"Error creating updated task definition: {str(e)}") + raise RuntimeError(f"Failed to create updated task definition: {str(e)}") + + +def update_ecs_service(cluster_arn: str, service_arn: str, task_definition_arn: str) -> None: + """Update ECS service to use new task definition.""" + try: + ecs_client.update_service( + cluster=cluster_arn, + service=service_arn, + taskDefinition=task_definition_arn, + forceNewDeployment=True, + ) + logger.info(f"Updated ECS service {service_arn} to use task definition {task_definition_arn}") + + except Exception as e: + logger.error(f"Error updating ECS service: {str(e)}") + raise RuntimeError(f"Failed to update ECS service: {str(e)}") + + +def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Update ECS task definition with new environment variables and update service. + + This handler will: + 1. Retrieve current task definition from ECS + 2. Create new task definition revision with updated configuration + 3. Update ECS service to use new task definition + 4. Set up for deployment monitoring + """ + output_dict = deepcopy(event) + server_id = event["server_id"] + + logger.info(f"Starting ECS update for server '{server_id}'") + + try: + # Get current server info from DDB (consistent read to ensure we see the latest env/cpu/memory) + ddb_item = mcp_servers_table.get_item(Key={"id": server_id}, ConsistentRead=True)["Item"] + stack_name = ddb_item.get("stack_name") + + if not stack_name: + raise RuntimeError(f"No CloudFormation stack found for server '{server_id}'") + + # Get ECS service and task definition from CloudFormation stack + service_arn, cluster_arn, task_definition_arn = get_ecs_resources_from_stack(stack_name) + + # Get updated environment variables from server config + updated_env_vars = ddb_item.get("environment", {}) + + # Get environment variables to delete from container metadata (if available) + env_vars_to_delete = [] + if container_metadata := event.get("container_metadata"): + env_vars_to_delete = container_metadata.get("env_vars_to_delete", []) + + logger.info(f"Environment variables to delete: {env_vars_to_delete}") + + # Get updated CPU and memory + updated_cpu = ddb_item.get("cpu") + updated_memory = ddb_item.get("memoryLimitMiB") + + # Get updated health check config + updated_health_check = ddb_item.get("containerHealthCheckConfig") + + # Create new task definition with updated configuration + new_task_def_arn = create_updated_task_definition( + task_definition_arn, updated_env_vars, env_vars_to_delete, updated_cpu, updated_memory, updated_health_check + ) + + # Update ECS service to use new task definition + update_ecs_service(cluster_arn, service_arn, new_task_def_arn) + + # Set up tracking for deployment monitoring + output_dict["new_task_definition_arn"] = new_task_def_arn + output_dict["ecs_service_arn"] = service_arn + output_dict["ecs_cluster_arn"] = cluster_arn + output_dict["remaining_ecs_polls"] = MAX_POLLS + + logger.info(f"Successfully initiated ECS update for server '{server_id}'") + + except Exception as e: + logger.error(f"ECS update failed for server '{server_id}': {str(e)}") + output_dict["ecs_update_error"] = str(e) + + return output_dict + + +def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Monitor ECS service deployment progress. + + This handler will: + 1. Check if ECS service deployment is complete + 2. Return boolean for continued polling if needed + 3. Handle deployment failures + """ + output_dict = deepcopy(event) + server_id = event["server_id"] + + # Check if there was an error in the ECS update step + if event.get("ecs_update_error"): + logger.error(f"ECS update error for server '{server_id}': {event['ecs_update_error']}") + output_dict["should_continue_ecs_polling"] = False + return output_dict + + cluster_name = event["ecs_cluster_arn"] + service_name = event["ecs_service_arn"] + new_task_def_arn = event["new_task_definition_arn"] + + try: + # Get service deployment status + services = ecs_client.describe_services(cluster=cluster_name, services=[service_name])["services"] + + if not services: + raise RuntimeError(f"ECS service {service_name} not found") + + service = services[0] + deployments = service["deployments"] + + # Check if deployment is stable + is_deployment_stable = True + primary_deployment = None + + # Look for our deployment + for deployment in deployments: + task_def = deployment["taskDefinition"] + # Handle both full ARN and family:revision format + if task_def == new_task_def_arn or ( + new_task_def_arn.endswith(task_def.split(":")[-1]) + and task_def.startswith(new_task_def_arn.split(":")[0]) + ): + primary_deployment = deployment + logger.info( + f"Found matching deployment: status={deployment['status']}, " + f"rolloutState={deployment.get('rolloutState', 'N/A')}" + ) + if deployment["status"] != "PRIMARY" or deployment.get("rolloutState") != "COMPLETED": + is_deployment_stable = False + logger.info( + f"Deployment not yet stable: status={deployment['status']}, " + f"rolloutState={deployment.get('rolloutState', 'N/A')}" + ) + else: + logger.info("Deployment is stable and completed") + break + + if not primary_deployment: + logger.warning(f"Could not find deployment for task definition {new_task_def_arn}") + logger.warning(f"Available task definitions: {[d['taskDefinition'] for d in deployments]}") + is_deployment_stable = False + + # Check polling limits + remaining_polls = event.get("remaining_ecs_polls", MAX_POLLS) - 1 + if remaining_polls <= 0 and not is_deployment_stable: + logger.error(f"ECS deployment polling timeout for server '{server_id}'") + output_dict["ecs_polling_error"] = ( + f"ECS deployment did not complete within expected time for server '{server_id}'" + ) + output_dict["should_continue_ecs_polling"] = False + return output_dict + + should_continue_polling = not is_deployment_stable and remaining_polls > 0 + + output_dict["should_continue_ecs_polling"] = should_continue_polling + output_dict["remaining_ecs_polls"] = remaining_polls + + if is_deployment_stable: + logger.info(f"ECS deployment completed successfully for server '{server_id}'") + else: + logger.info( + f"ECS deployment still in progress for server '{server_id}', remaining polls: {remaining_polls}" + ) + + except Exception as e: + logger.error(f"Error polling ECS deployment for server '{server_id}': {str(e)}") + output_dict["ecs_polling_error"] = f"Error polling ECS deployment: {str(e)}" + output_dict["should_continue_ecs_polling"] = False + + return output_dict + + +def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Poll ECS service to confirm if the capacity is done updating. + + This handler will: + 1. Get the ECS service's current status. If it is still updating, then exit with a + boolean to indicate for more polling + 2. If the service status has completed, validate that it has the desired number of + running tasks + 3. If both match, then discontinue polling + """ + output_dict = deepcopy(event) + server_id = event["server_id"] + stack_name = event["stack_name"] + logger.info(f"Polling capacity for server {server_id}, Stack: {stack_name}") + + try: + service_arn, cluster_arn, _ = get_ecs_resources_from_stack(stack_name) + service_info = ecs_client.describe_services(cluster=cluster_arn, services=[service_arn])["services"][0] + + desired_count = service_info["desiredCount"] + running_count = service_info["runningCount"] + + remaining_polls = event.get("remaining_capacity_polls", MAX_POLLS) - 1 + if remaining_polls <= 0: + output_dict["polling_error"] = ( + f"Server '{server_id}' did not start healthy tasks in expected amount of time." + ) + + should_continue_polling = desired_count != running_count and remaining_polls > 0 + + output_dict["should_continue_capacity_polling"] = should_continue_polling + output_dict["remaining_capacity_polls"] = remaining_polls + + logger.info( + f"Server '{server_id}' capacity: desired={desired_count}, running={running_count}, " + f"continue_polling={should_continue_polling}" + ) + + except Exception as e: + logger.error(f"Error polling capacity for server '{server_id}': {str(e)}") + output_dict["polling_error"] = f"Error polling capacity: {str(e)}" + output_dict["should_continue_capacity_polling"] = False + + return output_dict + + +def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Finalize update in DDB. + + 1. If the server was enabled from the Stopped state, update MCP Connections table to ACTIVE, + set status to InService in DDB + 2. If the server was disabled from the InService state, set status to Stopped + 3. Commit changes to DDB + """ + output_dict = deepcopy(event) + + server_id = event["server_id"] + server_key = {"id": server_id} + stack_name = event["stack_name"] + + ddb_update_expression = "SET #status = :ms, last_modified = :lm" + ddb_update_values: Dict[str, Any] = { + ":lm": int(datetime.now(UTC).timestamp()), + } + ExpressionAttributeNames = {"#status": "status"} + + if polling_error := event.get("polling_error", None): + logger.error(f"{polling_error} Setting ECS service back to 0 tasks.") + try: + service_arn, cluster_arn, _ = get_ecs_resources_from_stack(stack_name) + ecs_client.update_service(cluster=cluster_arn, service=service_arn, desiredCount=0) + except Exception as e: + logger.error(f"Error scaling service to 0: {str(e)}") + ddb_update_values[":ms"] = HostedMcpServerStatus.STOPPED + elif event["is_disable"]: + ddb_update_values[":ms"] = HostedMcpServerStatus.STOPPED + elif event["has_capacity_update"]: + ddb_update_values[":ms"] = HostedMcpServerStatus.IN_SERVICE + + # Update MCP Connections table to ACTIVE (service was already scaled in handle_job_intake) + _update_mcp_connections_table_status(server_id, McpServerStatus.ACTIVE) + logger.info(f"Updated MCP Connections table to ACTIVE for server '{server_id}'") + else: # No polling error, not disabled, and no capacity update means this was a metadata update, keep initial state + ddb_update_values[":ms"] = event["initial_server_status"] + + mcp_servers_table.update_item( + Key=server_key, + UpdateExpression=ddb_update_expression, + ExpressionAttributeValues=ddb_update_values, + ExpressionAttributeNames=ExpressionAttributeNames, + ) + + output_dict["current_server_status"] = ddb_update_values[":ms"] + + return output_dict diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index d2282d114..1ac628a30 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -14,12 +14,17 @@ """Defines domain objects for model endpoint interactions.""" +from __future__ import annotations + +import json import logging +import re import time +import urllib.parse import uuid from dataclasses import dataclass from datetime import datetime, timezone -from enum import Enum +from enum import auto, Enum, StrEnum from typing import Annotated, Any, Dict, Generator, List, Optional, TypeAlias, Union from uuid import uuid4 @@ -27,30 +32,27 @@ from pydantic.functional_validators import AfterValidator, field_validator, model_validator from typing_extensions import Self from utilities.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, MIN_PAGE_SIZE -from utilities.validators import validate_all_fields_defined, validate_any_fields_defined, validate_instance_type +from utilities.validation import ( + validate_all_fields_defined, + validate_any_fields_defined, + validate_instance_type, + ValidationError, +) logger = logging.getLogger(__name__) -class InferenceContainer(str, Enum): +class InferenceContainer(StrEnum): """Defines supported inference container types.""" - def __str__(self) -> str: - """Returns string representation of the enum value.""" - return str(self.value) - - TGI = "tgi" - TEI = "tei" - VLLM = "vllm" + TGI = auto() + TEI = auto() + VLLM = auto() -class ModelStatus(str, Enum): +class ModelStatus(StrEnum): """Defines possible model deployment states.""" - def __str__(self) -> str: - """Returns string representation of the enum value.""" - return str(self.value) - CREATING = "Creating" IN_SERVICE = "InService" STARTING = "Starting" @@ -61,28 +63,20 @@ def __str__(self) -> str: FAILED = "Failed" -class ModelType(str, Enum): +class ModelType(StrEnum): """Defines supported model categories.""" - def __str__(self) -> str: - """Returns string representation of the enum value.""" - return str(self.value) + TEXTGEN = auto() + IMAGEGEN = auto() + EMBEDDING = auto() - TEXTGEN = "textgen" - IMAGEGEN = "imagegen" - EMBEDDING = "embedding" - -class GuardrailMode(str, Enum): +class GuardrailMode(StrEnum): """Defines supported guardrail execution modes.""" - def __str__(self) -> str: - """Returns string representation of the enum value.""" - return str(self.value) - - PRE_CALL = "pre_call" - DURING_CALL = "during_call" - POST_CALL = "post_call" + PRE_CALL = auto() + DURING_CALL = auto() + POST_CALL = auto() class GuardrailConfig(BaseModel): @@ -425,20 +419,32 @@ class DeleteModelResponse(ApiResponseBase): pass -class IngestionType(str, Enum): - """Specifies whether ingestion was automatic or manual.""" +class IngestionType(StrEnum): + """Specifies how document was ingested into the system.""" - AUTO = "auto" - MANUAL = "manual" + AUTO = auto() # Automatic ingestion via pipeline (event-driven) + MANUAL = auto() # Manual ingestion via API (user-initiated) + EXISTING = auto() # Pre-existing document discovered in KB (user-managed) -RagDocumentDict: TypeAlias = Dict[str, Any] +class JobActionType(StrEnum): + """Defines deletion job types.""" + DOCUMENT_INGESTION = auto() + DOCUMENT_BATCH_INGESTION = auto() + DOCUMENT_DELETION = auto() + DOCUMENT_BATCH_DELETION = auto() + COLLECTION_DELETION = auto() -class ChunkingStrategyType(str, Enum): + +RagDocumentDict = Dict[str, Any] + + +class ChunkingStrategyType(StrEnum): """Defines supported document chunking strategies.""" - FIXED = "fixed" + FIXED = auto() + NONE = auto() class IngestionStatus(str, Enum): @@ -454,23 +460,54 @@ class IngestionStatus(str, Enum): DELETE_COMPLETED = "DELETE_COMPLETED" DELETE_FAILED = "DELETE_FAILED" + def is_terminal(self) -> bool: + """Check if status is terminal.""" + return self in [ + IngestionStatus.INGESTION_COMPLETED, + IngestionStatus.INGESTION_FAILED, + IngestionStatus.DELETE_COMPLETED, + IngestionStatus.DELETE_FAILED, + ] + + def is_success(self) -> bool: + """Check if status is success.""" + return self in [ + IngestionStatus.INGESTION_COMPLETED, + IngestionStatus.DELETE_COMPLETED, + ] + class FixedChunkingStrategy(BaseModel): """Defines parameters for fixed-size document chunking.""" type: ChunkingStrategyType = ChunkingStrategyType.FIXED - size: int - overlap: int + size: int = Field(ge=100, le=10000) + overlap: int = Field(ge=0) + + @model_validator(mode="after") + def validate_overlap(self) -> Self: + """Validates overlap is not more than half of chunk size.""" + if self.overlap > self.size / 2: + raise ValueError( + f"chunk overlap ({self.overlap}) must be less than or equal to " f"half of chunk size ({self.size / 2})" + ) + return self -ChunkingStrategy: TypeAlias = Union[FixedChunkingStrategy] +class NoneChunkingStrategy(BaseModel): + """Defines parameters for no-chunking strategy - documents ingested as-is.""" + + type: ChunkingStrategyType = ChunkingStrategyType.NONE + + +ChunkingStrategy = Union[FixedChunkingStrategy, NoneChunkingStrategy] class RagSubDocument(BaseModel): """Represents a sub-document entity for DynamoDB storage.""" document_id: str - subdocs: list[str] = Field(default_factory=lambda: []) + subdocs: List[str] = Field(default_factory=lambda: []) index: Optional[int] = Field(default=None) sk: Optional[str] = None @@ -484,7 +521,7 @@ class RagDocument(BaseModel): pk: Optional[str] = None document_id: str = Field(default_factory=lambda: str(uuid.uuid4())) - repository_id: str + repository_id: str = Field(min_length=3, max_length=20) collection_id: str document_name: str source: str @@ -541,16 +578,31 @@ class IngestionJob(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) s3_path: str - collection_id: str + collection_id: Optional[str] = Field( + default=None, description="Collection ID for full deletion, None for default collection deletion" + ) document_id: Optional[str] = Field(default=None) repository_id: str chunk_strategy: Optional[ChunkingStrategy] = Field(default=None) + embedding_model: Optional[str] = Field( + default=None, description="Embedding model name, used as index identifier for default collections" + ) username: Optional[str] = Field(default=None) + ingestion_type: IngestionType = Field( + default=IngestionType.MANUAL, description="How the document was ingested (MANUAL, AUTO, or EXISTING)" + ) status: IngestionStatus = IngestionStatus.INGESTION_PENDING created_date: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) error_message: Optional[str] = Field(default=None) document_name: Optional[str] = Field(default=None) auto: Optional[bool] = Field(default=None) + metadata: Optional[dict] = Field(default=None) + job_type: Optional[JobActionType] = Field(default=None, description="Type of deletion job") + collection_deletion: bool = Field(default=False, description="Indicates this is a collection deletion job") + s3_paths: Optional[List[str]] = Field(default=None, description="List of S3 paths for batch ingestion operations") + document_ids: Optional[List[str]] = Field( + default=None, description="List of document IDs from completed batch operations" + ) def __init__(self, **data: Any) -> None: super().__init__(**data) @@ -558,6 +610,27 @@ def __init__(self, **data: Any) -> None: self.document_name = self.s3_path.split("/")[-1] if self.s3_path else "" self.auto = self.username == "system" + @model_validator(mode="after") + def validate_collection_deletion_identifiers(self) -> Self: + """Validate that for collection deletion jobs, exactly one of collection_id or embedding_model is provided.""" + if self.collection_deletion: + has_collection_id = self.collection_id is not None + has_embedding_model = self.embedding_model is not None + + # XOR: exactly one must be true + if has_collection_id == has_embedding_model: + if not has_collection_id and not has_embedding_model: + raise ValueError( + "For collection deletion jobs, either collection_id or embedding_model must be provided" + ) + else: + raise ValueError( + "For collection deletion jobs, only one of collection_id or " + "embedding_model should be provided, not both" + ) + + return self + class PaginatedResponse(BaseModel): """Base class for paginated API responses.""" @@ -602,3 +675,549 @@ def parse_page_size( """Parse and validate page size with configurable limits.""" page_size = int(query_params.get("pageSize", str(default))) return max(MIN_PAGE_SIZE, min(page_size, max_size)) + + @staticmethod + def parse_last_evaluated_key(query_params: Dict[str, str], key_fields: List[str]) -> Optional[Dict[str, str]]: + """Parse last evaluated key from query parameters. + + Args: + query_params: Query string parameters dictionary + key_fields: List of field names to extract from query params (e.g., ['collectionId', 'status', 'createdAt']) + + Returns: + Dictionary with last evaluated key fields, or None if no key present + + Notes: + Query params should be formatted as lastEvaluatedKey{FieldName}, e.g.: + - lastEvaluatedKeyCollectionId + - lastEvaluatedKeyStatus + - lastEvaluatedKeyCreatedAt + """ + # Check if any lastEvaluatedKey fields are present + has_key = any(f"lastEvaluatedKey{field.capitalize()}" in query_params for field in key_fields) + + if not has_key: + return None + + last_evaluated_key = {} + for field in key_fields: + # Convert field name to camelCase for query param (e.g., collectionId -> CollectionId) + param_name = f"lastEvaluatedKey{field[0].upper()}{field[1:]}" + + if param_name in query_params: + last_evaluated_key[field] = urllib.parse.unquote(query_params[param_name]) + + return last_evaluated_key if last_evaluated_key else None + + @staticmethod + def parse_last_evaluated_key_v2(query_params: Dict[str, str]) -> Optional[Dict[str, Any]]: + """Parse v2 pagination token from query parameters. + + The v2 token format supports scalable pagination with per-repository cursors. + It is passed as a JSON string in the lastEvaluatedKey query parameter. + + Args: + query_params: Query string parameters dictionary + + Returns: + Dictionary with v2 pagination token structure, or None if not present + + Token Structure: + { + "version": "v2", + "repositoryCursors": { + "repo-id": { + "lastEvaluatedKey": {...}, + "exhausted": bool + } + }, + "globalOffset": int, + "filters": { + "filter": str, + "sortBy": str, + "sortOrder": str + } + } + """ + if "lastEvaluatedKey" not in query_params: + return None + + try: + token_str = urllib.parse.unquote(query_params["lastEvaluatedKey"]) + token = json.loads(token_str) + + # Validate it's a v2 token + if not isinstance(token, dict) or token.get("version") != "v2": + return None + + return token + except (json.JSONDecodeError, ValueError, TypeError): + return None + + +@dataclass +class FilterParams: + """Shared filtering parameter handling for collections.""" + + filter_text: Optional[str] = None + status_filter: Optional[CollectionStatus] = None + + @staticmethod + def from_query_params(query_params: Dict[str, str]) -> FilterParams: + """Parse filter parameters from query string parameters. + + Args: + query_params: Query string parameters dictionary + + Returns: + FilterParams object with parsed filter parameters + + Raises: + ValidationError: If status value is invalid + """ + filter_text = query_params.get("filter") + + status_filter = None + if "status" in query_params: + try: + status_filter = CollectionStatus(query_params["status"]) + except ValueError: + raise ValidationError(f"Invalid status value: {query_params['status']}") + + return FilterParams(filter_text=filter_text, status_filter=status_filter) + + +@dataclass +class SortParams: + """Shared sorting parameter handling for collections.""" + + sort_by: CollectionSortBy = None # Will be set to default in from_query_params + sort_order: SortOrder = None # Will be set to default in from_query_params + + @staticmethod + def from_query_params(query_params: Dict[str, str]) -> SortParams: + """Parse sort parameters from query string parameters. + + Args: + query_params: Query string parameters dictionary + + Returns: + SortParams object with parsed sort parameters + + Raises: + ValidationError: If sortBy or sortOrder values are invalid + """ + + sort_by = CollectionSortBy.CREATED_AT + if "sortBy" in query_params: + try: + sort_by = CollectionSortBy(query_params["sortBy"]) + except ValueError: + raise ValidationError(f"Invalid sortBy value: {query_params['sortBy']}") + + sort_order = SortOrder.DESC + if "sortOrder" in query_params: + try: + sort_order = SortOrder(query_params["sortOrder"]) + except ValueError: + raise ValidationError(f"Invalid sortOrder value: {query_params['sortOrder']}") + + return SortParams(sort_by=sort_by, sort_order=sort_order) + + +# ============================================================================ +# Collection Management Models +# ============================================================================ + + +class CollectionStatus(StrEnum): + """Defines possible states for a collection.""" + + ACTIVE = "ACTIVE" + ARCHIVED = "ARCHIVED" + DELETED = "DELETED" + DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS" + DELETE_FAILED = "DELETE_FAILED" + + +class VectorStoreStatus(StrEnum): + """Defines possible states for a vector store deployment.""" + + CREATE_IN_PROGRESS = "CREATE_IN_PROGRESS" + CREATE_COMPLETE = "CREATE_COMPLETE" + CREATE_FAILED = "CREATE_FAILED" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_COMPLETE = "UPDATE_COMPLETE" + UPDATE_COMPLETE_CLEANUP_IN_PROGRESS = "UPDATE_COMPLETE_CLEANUP_IN_PROGRESS" + UNKNOWN = "UNKNOWN" + + +class PipelineTrigger(StrEnum): + """Defines trigger types for collection pipelines.""" + + EVENT = auto() + SCHEDULE = auto() + + +class PipelineConfig(BaseModel): + """Defines pipeline configuration for automated document ingestion.""" + + autoRemove: bool = Field(default=True, description="Automatically remove documents after ingestion") + chunkOverlap: Optional[int] = Field( + default=None, ge=0, description="Chunk overlap for pipeline ingestion (deprecated, use chunkingStrategy)" + ) + chunkSize: Optional[int] = Field( + default=None, + ge=100, + le=10000, + description="Chunk size for pipeline ingestion (deprecated, use chunkingStrategy)", + ) + chunkingStrategy: Optional[ChunkingStrategy] = Field( + default=None, description="Chunking strategy for documents in this pipeline" + ) + collectionId: Optional[str] = Field( + default=None, description="Collection ID for this pipeline (for Bedrock KB, this is the data source ID)" + ) + s3Bucket: str = Field(min_length=1, description="S3 bucket for pipeline source") + s3Prefix: str = Field(description="S3 prefix for pipeline source") + trigger: PipelineTrigger = Field(description="Pipeline trigger type") + + @model_validator(mode="after") + def validate_chunking_config(self) -> Self: + """Validates that either chunkingStrategy or legacy chunk fields are provided.""" + has_legacy = self.chunkSize is not None and self.chunkOverlap is not None + has_new = self.chunkingStrategy is not None + + if not has_legacy and not has_new: + raise ValueError( + "Chunking configuration required: provide either 'chunkingStrategy' or both " + "'chunkSize' and 'chunkOverlap'" + ) + + # Validate that if one legacy field is provided, both must be provided + if (self.chunkSize is not None or self.chunkOverlap is not None) and not has_legacy: + raise ValueError("When using legacy chunking fields, both 'chunkSize' and 'chunkOverlap' must be provided") + + # If legacy fields provided but no chunkingStrategy, create one + if has_legacy and not has_new: + self.chunkingStrategy = FixedChunkingStrategy( + type=ChunkingStrategyType.FIXED, size=self.chunkSize, overlap=self.chunkOverlap + ) + + return self + + +class CollectionMetadata(BaseModel): + """Defines metadata for a collection.""" + + tags: List[str] = Field(default_factory=list, max_length=50, description="Metadata tags for the collection") + customFields: Dict[str, Any] = Field(default_factory=dict, description="Custom metadata fields") + + @field_validator("tags") + @classmethod + def validate_tags(cls, tags: List[str]) -> List[str]: + """Validates metadata tags.""" + tag_pattern = re.compile(r"^[a-zA-Z0-9_-]+$") + for tag in tags: + if len(tag) > 50: + raise ValueError("Each tag must be 50 characters or less") + if not tag_pattern.match(tag): + raise ValueError( + f"Tag '{tag}' contains invalid characters. " + "Tags must contain only alphanumeric characters, hyphens, and underscores" + ) + return tags + + @classmethod + def merge(cls, parent: Optional[CollectionMetadata], child: Optional[CollectionMetadata]) -> CollectionMetadata: + """Merges parent and child metadata. + + Args: + parent: Parent vector store metadata + child: Collection-specific metadata + + Returns: + Merged metadata with combined tags and merged custom fields + """ + if parent is None and child is None: + return cls() + if parent is None: + return child or cls() + if child is None: + return parent + + # Combine tags (deduplicate while preserving order) + merged_tags = list(dict.fromkeys(parent.tags + child.tags)) + + # Merge custom fields (child overrides parent) + merged_custom_fields = {**parent.customFields, **child.customFields} + + return cls(tags=merged_tags, customFields=merged_custom_fields) + + +class RagCollectionConfig(BaseModel): + """Represents a RAG collection configuration.""" + + collectionId: str = Field(default_factory=lambda: str(uuid4()), description="Unique collection identifier") + repositoryId: str = Field(min_length=1, description="Parent vector store ID") + name: Optional[str] = Field(default=None, max_length=100, description="User-friendly collection name") + description: Optional[str] = Field(default=None, description="Collection description") + chunkingStrategy: Optional[ChunkingStrategy] = Field(default=None, description="Chunking strategy for documents") + allowChunkingOverride: bool = Field( + default=True, description="Allow users to override chunking strategy during ingestion" + ) + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Collection-specific metadata (merged with parent)" + ) + allowedGroups: Optional[List[str]] = Field(default=None, description="User groups with access to collection") + embeddingModel: Optional[str] = Field(description="Embedding model ID (can be set at creation, immutable after)") + createdBy: str = Field(min_length=1, description="User ID of creator") + createdAt: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Creation timestamp") + updatedAt: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Last update timestamp") + status: CollectionStatus = Field(default=CollectionStatus.ACTIVE, description="Collection status") + pipelines: List[PipelineConfig] = Field(default_factory=list, description="Automated ingestion pipelines") + default: bool = Field(default=False, description="Indicates if this is a default collection for Bedrock KB") + dataSourceId: Optional[str] = Field( + default=None, description="Bedrock KB data source ID for filtering (Bedrock KB only)" + ) + + model_config = ConfigDict(use_enum_values=True, validate_default=True) + + @field_validator("name") + @classmethod + def validate_name(cls, name: Optional[str]) -> Optional[str]: + """Validates collection name.""" + if name is not None: + if len(name) > 100: + raise ValueError("Collection name must be 100 characters or less") + # Allow alphanumeric, spaces, hyphens, underscores + if not all(c.isalnum() or c in " -_" for c in name): + raise ValueError( + "Collection name must contain only alphanumeric characters, spaces, hyphens, and underscores" + ) + return name + + @field_validator("allowedGroups") + @classmethod + def validate_allowed_groups(cls, groups: Optional[List[str]]) -> Optional[List[str]]: + """Validates allowed groups.""" + if groups is not None and len(groups) == 0: + # Empty list should be treated as None (inherit from parent) + return None + return groups + + +class IngestDocumentRequest(BaseModel): + """Request model for ingesting documents.""" + + keys: List[str] = Field(description="S3 keys to ingest") + collectionId: Optional[str] = Field(default=None, description="Target collection ID") + embeddingModel: Optional[Dict[str, str]] = Field(default=None, description="Embedding model config") + chunkingStrategy: Optional[Dict[str, Any]] = Field(default=None, description="Chunking strategy override") + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata") + + +class ListCollectionsResponse(PaginatedResponse): + """Response model for listing collections.""" + + collections: List[RagCollectionConfig] = Field(description="List of collections") + totalCount: Optional[int] = Field(default=None, description="Total number of collections") + currentPage: Optional[int] = Field(default=None, description="Current page number") + totalPages: Optional[int] = Field(default=None, description="Total number of pages") + + +class CollectionSortBy(StrEnum): + """Defines sort options for collection listing.""" + + NAME = "name" + CREATED_AT = "createdAt" + UPDATED_AT = "updatedAt" + + +class SortOrder(StrEnum): + """Defines sort order options.""" + + ASC = auto() + DESC = auto() + + +class RepositoryMetadata(BaseModel): + """Defines metadata for a repository/vector store.""" + + tags: List[str] = Field(default_factory=list, description="Tags for categorizing the repository") + customFields: Optional[Dict[str, Any]] = Field(default=None, description="Custom metadata fields") + + +class OpenSearchNewClusterConfig(BaseModel): + """Configuration for creating a new OpenSearch cluster.""" + + dataNodes: int = Field( + default=2, + ge=1, + description="The number of data nodes (instances) to use in the Amazon OpenSearch Service domain.", + ) + dataNodeInstanceType: str = Field(default="r7g.large.search", description="The instance type for your data nodes") + masterNodes: int = Field(default=0, ge=0, description="The number of instances to use for the master node") + masterNodeInstanceType: str = Field( + default="r7g.large.search", + description="The hardware configuration of the computer that hosts the dedicated master node", + ) + volumeSize: int = Field( + default=20, + ge=20, + description=( + "The size (in GiB) of the EBS volume for each data node. The minimum and maximum size of " + "an EBS volume depends on the EBS volume type and the instance type to which it is attached." + ), + ) + volumeType: str = Field( + default="gp3", description="The EBS volume type to use with the Amazon OpenSearch Service domain" + ) + multiAzWithStandby: bool = Field( + default=False, description="Indicates whether Multi-AZ with Standby deployment option is enabled." + ) + + +class OpenSearchExistingClusterConfig(BaseModel): + """Configuration for using an existing OpenSearch cluster.""" + + endpoint: str = Field(min_length=1, description="Existing OpenSearch Cluster endpoint") + + +# Union type for OpenSearch configurations +OpenSearchConfig = Union[OpenSearchNewClusterConfig, OpenSearchExistingClusterConfig] + + +class RdsInstanceConfig(BaseModel): + """Configuration schema for RDS Instances needed for LiteLLM scaling or PGVector RAG operations. + + The optional fields can be omitted to create a new database instance, otherwise fill in all fields + to use an existing database instance. + """ + + username: str = Field(default="postgres", description="The username used for database connection.") + passwordSecretId: Optional[str] = Field( + default=None, description="The SecretsManager Secret ID that stores the existing database password." + ) + dbHost: Optional[str] = Field(default=None, description="The database hostname for the existing database instance.") + dbName: str = Field(default="postgres", description="The name of the database for the database instance.") + dbPort: int = Field( + default=5432, + description="The port of the existing database instance or the port to be opened on the database instance.", + ) + + +class BedrockDataSource(BaseModel): + """Configuration for a single Bedrock Knowledge Base data source.""" + + id: str = Field(min_length=1, description="The ID of the Bedrock Knowledge Base data source") + name: str = Field(min_length=1, description="The name of the Bedrock Knowledge Base data source") + s3Uri: str = Field(min_length=1, description="The S3 URI of the data source (s3://bucket/prefix)") + + @field_validator("s3Uri") + @classmethod + def validate_s3_uri(cls, v: str) -> str: + """Validate S3 URI format.""" + if not v.startswith("s3://"): + raise ValueError("S3 URI must start with s3://") + return v + + +class BedrockKnowledgeBaseConfig(BaseModel): + """Configuration for Bedrock Knowledge Base with multiple data sources. + + Stores the KB ID and array of data sources. Backend converts to pipelines. + """ + + knowledgeBaseId: str = Field(min_length=1, description="The ID of the Bedrock Knowledge Base") + dataSources: List[BedrockDataSource] = Field( + min_length=1, description="Array of data sources in this Knowledge Base" + ) + + +class VectorStoreConfig(BaseModel): + """Represents a vector store/repository configuration.""" + + repositoryId: str = Field(description="Unique identifier for the repository") + repositoryName: Optional[str] = Field(default=None, description="User-friendly name for the repository") + description: Optional[str] = Field(default=None, description="Description of the repository") + embeddingModelId: Optional[str] = Field(default=None, description="Default embedding model ID") + type: str = Field(description="Type of vector store (opensearch, pgvector, bedrock_knowledge_base)") + allowedGroups: List[str] = Field(default_factory=list, description="User groups with access to this repository") + metadata: Optional[RepositoryMetadata] = Field(default=None, description="Repository metadata") + pipelines: Optional[List[PipelineConfig]] = Field(default=None, description="Automated ingestion pipelines") + # Type-specific configurations + opensearchConfig: Optional[Union[OpenSearchNewClusterConfig, OpenSearchExistingClusterConfig]] = Field( + default=None, description="OpenSearch configuration" + ) + rdsConfig: Optional[RdsInstanceConfig] = Field(default=None, description="RDS/PGVector configuration") + bedrockKnowledgeBaseConfig: Optional[BedrockKnowledgeBaseConfig] = Field( + default=None, description="Bedrock Knowledge Base configuration with data sources" + ) + # Status and timestamps + status: Optional[VectorStoreStatus] = Field(default=None, description="Repository Status") + createdAt: Optional[datetime] = Field(default=None, description="Creation timestamp") + updatedAt: Optional[datetime] = Field(default=None, description="Last update timestamp") + + +class UpdateVectorStoreRequest(BaseModel): + """Request model for updating a vector store.""" + + repositoryName: Optional[str] = Field(default=None, description="User-friendly name") + description: Optional[str] = Field(default=None, description="Description of the repository") + embeddingModelId: Optional[str] = Field(default=None, description="Default embedding model ID") + allowedGroups: Optional[List[str]] = Field(default=None, description="User groups with access") + metadata: Optional[RepositoryMetadata] = Field(default=None, description="Repository metadata") + pipelines: Optional[List[PipelineConfig]] = Field(default=None, description="Automated ingestion pipelines") + bedrockKnowledgeBaseConfig: Optional[BedrockKnowledgeBaseConfig] = Field( + default=None, description="Bedrock Knowledge Base configuration" + ) + + +class KnowledgeBaseMetadata(BaseModel): + """Metadata for a Bedrock Knowledge Base.""" + + knowledgeBaseId: str = Field(description="Knowledge Base ID") + name: str = Field(description="Knowledge Base name") + description: Optional[str] = Field(default="", description="Knowledge Base description") + status: str = Field(description="Knowledge Base status (ACTIVE, CREATING, DELETING, etc.)") + createdAt: Optional[datetime] = Field(default=None, description="Creation timestamp") + updatedAt: Optional[datetime] = Field(default=None, description="Last update timestamp") + + +class DataSourceMetadata(BaseModel): + """Metadata for a Bedrock Knowledge Base data source.""" + + dataSourceId: str = Field(description="Data Source ID") + name: str = Field(description="Data Source name") + description: Optional[str] = Field(default="", description="Data Source description") + status: str = Field(description="Data Source status (AVAILABLE, CREATING, DELETING, etc.)") + s3Bucket: str = Field(description="S3 bucket for the data source") + s3Prefix: str = Field(default="", description="S3 prefix for the data source") + createdAt: Optional[datetime] = Field(default=None, description="Creation timestamp") + updatedAt: Optional[datetime] = Field(default=None, description="Last update timestamp") + managed: Optional[bool] = Field(default=False, description="Whether this data source is managed by a collection") + collectionId: Optional[str] = Field(default=None, description="Collection ID if managed") + + @field_validator("s3Bucket") + @classmethod + def validate_s3_bucket(cls, v: str) -> str: + """Validate S3 bucket name format.""" + if not v: + raise ValueError("S3 bucket cannot be empty") + # Basic S3 bucket name validation + if not re.match(r"^[a-z0-9][a-z0-9.-]*[a-z0-9]$", v): + raise ValueError(f"Invalid S3 bucket name format: {v}") + return v + + +class DataSourceSelection(BaseModel): + """Represents a user's selection of a data source for collection creation. + + Frontend sends this to backend, which creates pipelines and collections. + """ + + dataSourceId: str = Field(description="Data Source ID") + dataSourceName: str = Field(description="Data Source name") + s3Bucket: str = Field(description="S3 bucket for the data source") + s3Prefix: str = Field(default="", description="S3 prefix for the data source") diff --git a/lambda/models/handler/get_model_handler.py b/lambda/models/handler/get_model_handler.py index c6b82d520..198f592aa 100644 --- a/lambda/models/handler/get_model_handler.py +++ b/lambda/models/handler/get_model_handler.py @@ -16,7 +16,7 @@ from typing import List, Optional -from utilities.common_functions import user_has_group_access +from utilities.auth import user_has_group_access from ..domain_objects import GetModelResponse from ..exception import ModelNotFoundError diff --git a/lambda/models/handler/list_models_handler.py b/lambda/models/handler/list_models_handler.py index c1f27ef90..fd2ae7541 100644 --- a/lambda/models/handler/list_models_handler.py +++ b/lambda/models/handler/list_models_handler.py @@ -16,7 +16,7 @@ from typing import List, Optional -from utilities.common_functions import user_has_group_access +from utilities.auth import user_has_group_access from ..domain_objects import ListModelsResponse from .base_handler import BaseApiHandler diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 595b03655..1340d9e3a 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -24,8 +24,8 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from mangum import Mangum -from utilities.auth import is_admin -from utilities.common_functions import get_groups, retry_config +from utilities.auth import get_groups, is_admin +from utilities.common_functions import retry_config from utilities.fastapi_middleware.aws_api_gateway_middleware import AWSAPIGatewayMiddleware from .domain_objects import ( diff --git a/lambda/models/model_api_key_cleanup.py b/lambda/models/model_api_key_cleanup.py index 84a750670..f4b6afa14 100644 --- a/lambda/models/model_api_key_cleanup.py +++ b/lambda/models/model_api_key_cleanup.py @@ -27,6 +27,7 @@ import json import os import sys +import traceback from typing import Any, Dict, List import boto3 @@ -185,7 +186,13 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return {"Status": "SUCCESS", "PhysicalResourceId": "bedrock-auth-cleanup", "Data": {"ModelsUpdated": "0"}} # Query all models from the LiteLLM database using the found table (use quotes for case-sensitive names) - cursor.execute(f'SELECT * FROM "{litellm_table}" LIMIT 1') + + # Use psycopg2's identifier quoting to prevent SQL injection + cursor.execute( + psycopg2.sql.SQL("SELECT * FROM {} LIMIT 1").format( # noqa: S608, P103 + psycopg2.sql.Identifier(litellm_table) + ) + ) columns = [desc[0] for desc in cursor.description] print(f"Table {litellm_table} columns: {columns}") @@ -202,7 +209,14 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return {"Status": "SUCCESS", "PhysicalResourceId": "bedrock-auth-cleanup", "Data": {"ModelsUpdated": "0"}} # Query all models from the LiteLLM database - cursor.execute(f'SELECT "{model_id_col}", "{model_name_col}", "{litellm_params_col}" FROM "{litellm_table}"') + cursor.execute( + psycopg2.sql.SQL("SELECT {}, {}, {} FROM {}").format( # noqa: S608, P103 + psycopg2.sql.Identifier(model_id_col), + psycopg2.sql.Identifier(model_name_col), + psycopg2.sql.Identifier(litellm_params_col), + psycopg2.sql.Identifier(litellm_table), + ) + ) models = cursor.fetchall() print(f"Found {len(models)} total models in LiteLLM database") @@ -262,7 +276,11 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Update the model in the database clean_params_json = json.dumps(clean_params) cursor.execute( - f'UPDATE "{litellm_table}" SET "{litellm_params_col}" = %s WHERE "{model_id_col}" = %s', + psycopg2.sql.SQL("UPDATE {} SET {} = %s WHERE {} = %s").format( # noqa: S608, P103 + psycopg2.sql.Identifier(litellm_table), + psycopg2.sql.Identifier(litellm_params_col), + psycopg2.sql.Identifier(model_id_col), + ), (clean_params_json, matching_litellm_model["model_id"]), ) print(f"Successfully cleaned model: {matching_litellm_model['model_name']}") @@ -293,8 +311,6 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: except Exception as e: # Handle unexpected errors print(f"Cleanup failed: {e}") - import traceback - print(f"Traceback: {traceback.format_exc()}") # Rollback any pending database changes diff --git a/lambda/prompt_templates/lambda_functions.py b/lambda/prompt_templates/lambda_functions.py index 68fc278e9..1e3f34ba1 100644 --- a/lambda/prompt_templates/lambda_functions.py +++ b/lambda/prompt_templates/lambda_functions.py @@ -22,8 +22,8 @@ import boto3 from boto3.dynamodb.conditions import Attr, Key -from utilities.auth import get_username, is_admin -from utilities.common_functions import api_wrapper, get_groups, get_item, retry_config +from utilities.auth import get_user_context, get_username +from utilities.common_functions import api_wrapper, get_item, retry_config from .models import PromptTemplateModel @@ -85,7 +85,7 @@ def _get_prompt_templates( @api_wrapper def get(event: dict, context: dict) -> Any: """Retrieve a specific prompt template from DynamoDB.""" - user_id = get_username(event) + user_id, is_admin, groups = get_user_context(event) prompt_template_id = get_prompt_template_id(event) # Query for the latest prompt template revision @@ -97,7 +97,7 @@ def get(event: dict, context: dict) -> Any: # Check if the user is authorized to get the prompt template is_owner = item["owner"] == user_id - if is_owner or is_admin(event) or is_member(get_groups(event), item["groups"]): + if is_owner or is_admin or is_member(groups, item["groups"]): # add extra attribute so the frontend doesn't have to determine this if is_owner: item["isOwner"] = True @@ -117,15 +117,14 @@ def is_member(user_groups: List[str], prompt_groups: List[str]) -> bool: def list(event: dict, context: dict) -> Dict[str, Any]: """List prompt templates for a user from DynamoDB.""" query_params = event.get("queryStringParameters", {}) - user_id = get_username(event) + user_id, is_admin, groups = get_user_context(event) # Check whether to list public or private templates if query_params.get("public") == "true": - if is_admin(event): + if is_admin: logger.info(f"Listing all templates for user {user_id} (is_admin)") return _get_prompt_templates(latest=True) else: - groups = get_groups(event) logger.info(f"Listing public templates for user {user_id} with groups {groups}") return _get_prompt_templates(groups=groups, latest=True) else: @@ -149,7 +148,7 @@ def create(event: dict, context: dict) -> Any: @api_wrapper def update(event: dict, context: dict) -> Any: """Update an existing prompt template in DynamoDB.""" - user_id = get_username(event) + user_id, is_admin, _ = get_user_context(event) prompt_template_id = get_prompt_template_id(event) body = json.loads(event["body"], parse_float=Decimal) prompt_template_model = PromptTemplateModel(**body) @@ -165,7 +164,7 @@ def update(event: dict, context: dict) -> Any: raise ValueError(f"Prompt template {prompt_template_model} not found.") # Check if the user is authorized to update the prompt template - if is_admin(event) or item["owner"] == user_id: + if is_admin or item["owner"] == user_id: logger.info(f"Removing latest attribute from prompt_template with ID {prompt_template_id} for user {user_id}") # Remove latest attribute indicating no longer the latest version @@ -189,7 +188,7 @@ def update(event: dict, context: dict) -> Any: @api_wrapper def delete(event: dict, context: dict) -> Dict[str, str]: """Logically delete a prompt template from DynamoDB.""" - user_id = get_username(event) + user_id, is_admin, _ = get_user_context(event) prompt_template_id = get_prompt_template_id(event) # Query for the latest prompt template revision @@ -200,7 +199,7 @@ def delete(event: dict, context: dict) -> Dict[str, str]: raise ValueError(f"Prompt template {prompt_template_id} not found.") # Check if the user is authorized to delete the prompt template - if is_admin(event) or item["owner"] == user_id: + if is_admin or item["owner"] == user_id: logger.info(f"Removing latest attribute from prompt_template with ID {prompt_template_id} for user {user_id}") # Logical delete by removing the latest attribute diff --git a/lambda/prompt_templates/models.py b/lambda/prompt_templates/models.py index 71392d14a..c6226fe53 100644 --- a/lambda/prompt_templates/models.py +++ b/lambda/prompt_templates/models.py @@ -14,19 +14,15 @@ import uuid from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field -class PromptTemplateType(str, Enum): +class PromptTemplateType(StrEnum): """Enum representing the prompt template type.""" - def __str__(self) -> str: - """Represent the enum as a string.""" - return str(self.value) - PERSONA = "persona" DIRECTIVE = "directive" diff --git a/lambda/repository/collection_repo.py b/lambda/repository/collection_repo.py new file mode 100644 index 000000000..35632eb09 --- /dev/null +++ b/lambda/repository/collection_repo.py @@ -0,0 +1,403 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Collection repository for DynamoDB operations.""" + +import logging +import os +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple + +import boto3 +from boto3.dynamodb.conditions import Key +from botocore.exceptions import ClientError +from models.domain_objects import CollectionSortBy, CollectionStatus, RagCollectionConfig, SortOrder +from utilities.common_functions import retry_config +from utilities.encoders import convert_decimal + +logger = logging.getLogger(__name__) + + +class CollectionRepositoryError(Exception): + """Exception raised for errors in collection repository operations.""" + + pass + + +class CollectionRepository: + """Collection repository for DynamoDB operations.""" + + def __init__(self, table_name: Optional[str] = None) -> None: + """ + Initialize the Collection Repository. + + Args: + table_name: Optional table name override for testing + """ + dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) + table_name = table_name or os.environ.get("LISA_RAG_COLLECTIONS_TABLE", "LisaRagCollectionsTable") + self.table = dynamodb.Table(table_name) + logger.info(f"Initialized CollectionRepository with table: {table_name}") + + def create(self, collection: RagCollectionConfig) -> RagCollectionConfig: + """ + Create a new collection in DynamoDB. + + Args: + collection: The collection configuration to create + + Returns: + The created collection configuration + + Raises: + CollectionRepositoryError: If creation fails + """ + try: + # Ensure timestamps are set + now = datetime.now(timezone.utc) + if not collection.createdAt: + collection.createdAt = now + if not collection.updatedAt: + collection.updatedAt = now + + # Convert to dict for DynamoDB + item = collection.model_dump() + + # Convert datetime objects to ISO strings + item["createdAt"] = collection.createdAt.isoformat() + item["updatedAt"] = collection.updatedAt.isoformat() + + # Put item with condition to prevent overwriting + self.table.put_item( + Item=item, + ConditionExpression="attribute_not_exists(collectionId)", + ) + + logger.info(f"Created collection: {collection.collectionId}") + return collection + + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + raise CollectionRepositoryError(f"Collection with ID '{collection.collectionId}' already exists") + logger.error(f"Failed to create collection: {e}") + raise CollectionRepositoryError(f"Failed to create collection: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error creating collection: {e}") + raise CollectionRepositoryError(f"Unexpected error creating collection: {str(e)}") + + def find_by_id(self, collection_id: str, repository_id: str) -> Optional[RagCollectionConfig]: + """ + Find a collection by its ID and repository ID. + + Args: + collection_id: The collection ID + repository_id: The repository ID + + Returns: + The collection configuration if found, None otherwise + + Raises: + CollectionRepositoryError: If retrieval fails + """ + try: + response = self.table.get_item( + Key={ + "collectionId": collection_id, + "repositoryId": repository_id, + }, + ConsistentRead=True, + ) + + if "Item" not in response: + return None + + item = convert_decimal(response["Item"]) + return RagCollectionConfig(**item) + + except Exception as e: + logger.error(f"Failed to find collection {collection_id}: {e}") + raise CollectionRepositoryError(f"Failed to find collection: {str(e)}") + + def update( + self, + collection_id: str, + repository_id: str, + updates: Dict[str, Any], + expected_version: Optional[str] = None, + ) -> RagCollectionConfig: + """ + Update a collection with optimistic locking. + + Args: + collection_id: The collection ID + repository_id: The repository ID + updates: Dictionary of fields to update + expected_version: Expected updatedAt timestamp for optimistic locking + + Returns: + The updated collection configuration + + Raises: + CollectionRepositoryError: If update fails + """ + try: + # Build update expression + update_expr_parts = [] + expr_attr_names = {} + expr_attr_values = {} + + # Always update the updatedAt timestamp + updates["updatedAt"] = datetime.now(timezone.utc).isoformat() + + for key, value in updates.items(): + # Skip immutable fields + if key in ["collectionId", "repositoryId", "createdBy", "createdAt", "embeddingModel"]: + logger.warning(f"Skipping immutable field: {key}") + continue + + # Use attribute names to handle reserved words + attr_name = f"#{key}" + attr_value = f":{key}" + update_expr_parts.append(f"{attr_name} = {attr_value}") + expr_attr_names[attr_name] = key + expr_attr_values[attr_value] = value + + if not update_expr_parts: + raise CollectionRepositoryError("No valid fields to update") + + update_expression = "SET " + ", ".join(update_expr_parts) + + # Build condition expression for optimistic locking + condition_expression = "attribute_exists(collectionId)" + if expected_version: + condition_expression += " AND updatedAt = :expected_version" + expr_attr_values[":expected_version"] = expected_version + + # Perform update + response = self.table.update_item( + Key={ + "collectionId": collection_id, + "repositoryId": repository_id, + }, + UpdateExpression=update_expression, + ExpressionAttributeNames=expr_attr_names, + ExpressionAttributeValues=expr_attr_values, + ConditionExpression=condition_expression, + ReturnValues="ALL_NEW", + ) + + item = convert_decimal(response["Attributes"]) + logger.info(f"Updated collection: {collection_id}") + return RagCollectionConfig(**item) + + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + if expected_version: + raise CollectionRepositoryError("Collection was modified by another process. Please retry.") + raise CollectionRepositoryError(f"Collection '{collection_id}' not found") + logger.error(f"Failed to update collection {collection_id}: {e}") + raise CollectionRepositoryError(f"Failed to update collection: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error updating collection {collection_id}: {e}") + raise CollectionRepositoryError(f"Unexpected error updating collection: {str(e)}") + + def delete(self, collection_id: str, repository_id: str) -> bool: + """ + Delete a collection from DynamoDB. + + Args: + collection_id: The collection ID + repository_id: The repository ID + + Returns: + True if deletion was successful + + Raises: + CollectionRepositoryError: If deletion fails + """ + try: + self.table.delete_item( + Key={ + "collectionId": collection_id, + "repositoryId": repository_id, + }, + ConditionExpression="attribute_exists(collectionId)", + ) + logger.info(f"Deleted collection: {collection_id}") + return True + + except ClientError as e: + if e.response["Error"]["Code"] == "ConditionalCheckFailedException": + raise CollectionRepositoryError(f"Collection '{collection_id}' not found") + logger.error(f"Failed to delete collection {collection_id}: {e}") + raise CollectionRepositoryError(f"Failed to delete collection: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error deleting collection {collection_id}: {e}") + raise CollectionRepositoryError(f"Unexpected error deleting collection: {str(e)}") + + def list_by_repository( + self, + repository_id: str, + page_size: int = 20, + last_evaluated_key: Optional[Dict[str, str]] = None, + filter_text: Optional[str] = None, + status_filter: Optional[CollectionStatus] = None, + sort_by: CollectionSortBy = CollectionSortBy.CREATED_AT, + sort_order: SortOrder = SortOrder.DESC, + ) -> Tuple[List[RagCollectionConfig], Optional[Dict[str, str]]]: + """ + List collections for a repository with pagination, filtering, and sorting. + + Args: + repository_id: The repository ID + page_size: Number of items per page (max 100) + last_evaluated_key: Pagination token from previous request + filter_text: Optional text to filter by name or description + status_filter: Optional status to filter by + sort_by: Field to sort by + sort_order: Sort order (asc/desc) + + Returns: + Tuple of (list of collections, last_evaluated_key for pagination) + + Raises: + CollectionRepositoryError: If listing fails + """ + try: + # Limit page size to 100 + page_size = min(page_size, 100) + + # Determine which index to use based on filters + if status_filter: + index_name = "StatusIndex" + key_condition = Key("repositoryId").eq(repository_id) & Key("status").eq(status_filter.value) + else: + index_name = "RepositoryIndex" + key_condition = Key("repositoryId").eq(repository_id) + + # Build query parameters + query_params = { + "IndexName": index_name, + "KeyConditionExpression": key_condition, + "Limit": page_size, + "ScanIndexForward": (sort_order == SortOrder.ASC), + } + + if last_evaluated_key: + query_params["ExclusiveStartKey"] = last_evaluated_key + + # Execute query + response = self.table.query(**query_params) + + # Convert items to collection objects + items = [convert_decimal(item) for item in response.get("Items", [])] + collections = [RagCollectionConfig(**item) for item in items] + + # Apply text filter if provided (post-query filtering) + if filter_text: + filter_text_lower = filter_text.lower() + collections = [ + c + for c in collections + if (c.name and filter_text_lower in c.name.lower()) + or (c.description and filter_text_lower in c.description.lower()) + ] + + # Sort collections if needed (for non-default sort fields) + if sort_by != CollectionSortBy.CREATED_AT: + reverse = sort_order == SortOrder.DESC + if sort_by == CollectionSortBy.NAME: + collections.sort(key=lambda c: c.name or "", reverse=reverse) + elif sort_by == CollectionSortBy.UPDATED_AT: + collections.sort(key=lambda c: c.updatedAt, reverse=reverse) + + # Get pagination token + next_key = response.get("LastEvaluatedKey") + + logger.info(f"Listed {len(collections)} collections for repository {repository_id}") + return collections, next_key + + except Exception as e: + logger.error(f"Failed to list collections for repository {repository_id}: {e}") + raise CollectionRepositoryError(f"Failed to list collections: {str(e)}") + + def count_by_repository(self, repository_id: str, status: Optional[CollectionStatus] = None) -> int: + """ + Count collections in a repository. + + Args: + repository_id: The repository ID + status: Optional status filter + + Returns: + Number of collections + + Raises: + CollectionRepositoryError: If count fails + """ + try: + if status: + index_name = "StatusIndex" + key_condition = Key("repositoryId").eq(repository_id) & Key("status").eq(status.value) + else: + index_name = "RepositoryIndex" + key_condition = Key("repositoryId").eq(repository_id) + + response = self.table.query( + IndexName=index_name, + KeyConditionExpression=key_condition, + Select="COUNT", + ) + + count = response.get("Count", 0) + logger.info(f"Counted {count} collections for repository {repository_id}") + return count + + except Exception as e: + logger.error(f"Failed to count collections for repository {repository_id}: {e}") + raise CollectionRepositoryError(f"Failed to count collections: {str(e)}") + + def find_by_name(self, repository_id: str, collection_name: str) -> Optional[RagCollectionConfig]: + """ + Find a collection by repository ID and name. + + Args: + repository_id: The repository ID + name: The collection name + + Returns: + The collection if found, None otherwise + + Raises: + CollectionRepositoryError: If search fails + """ + try: + response = self.table.query( + IndexName="RepositoryIndex", + KeyConditionExpression=Key("repositoryId").eq(repository_id), + FilterExpression="#name = :name", + ExpressionAttributeNames={"#name": "name"}, + ExpressionAttributeValues={":name": collection_name}, + Limit=1, + ) + + items = response.get("Items", []) + if items: + return RagCollectionConfig(**convert_decimal(items[0])) + + return None + + except Exception as e: + logger.error(f"Failed to find collection by name '{collection_name}': {e}") + raise CollectionRepositoryError(f"Failed to find collection by name: {str(e)}") diff --git a/lambda/repository/collection_service.py b/lambda/repository/collection_service.py new file mode 100644 index 000000000..4750fee6c --- /dev/null +++ b/lambda/repository/collection_service.py @@ -0,0 +1,999 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Collection service for business logic.""" + +import heapq +import logging +import os +from typing import Any, Dict, List, Optional, Tuple + +import boto3 +from models.domain_objects import ( + CollectionSortBy, + CollectionStatus, + IngestionJob, + IngestionStatus, + IngestionType, + JobActionType, + RagCollectionConfig, + SortOrder, + SortParams, +) +from repository.collection_repo import CollectionRepository +from repository.ingestion_job_repo import IngestionJobRepository +from repository.ingestion_service import DocumentIngestionService +from repository.rag_document_repo import RagDocumentRepository +from repository.services import RepositoryServiceFactory +from repository.vector_store_repo import VectorStoreRepository +from utilities.validation import ValidationError + +logger = logging.getLogger(__name__) + +# Initialize AWS clients +sfn_client = boto3.client("stepfunctions") +ssm_client = boto3.client("ssm") + + +class CollectionService: + """Service for collection operations.""" + + def __init__( + self, + collection_repo: Optional[CollectionRepository] = None, + vector_store_repo: Optional[VectorStoreRepository] = None, + document_repo: Optional[RagDocumentRepository] = None, + ): + self.collection_repo = collection_repo or CollectionRepository() + self.vector_store_repo = vector_store_repo or VectorStoreRepository() + self.document_repo = document_repo or RagDocumentRepository( + os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"] + ) + + def has_access( + self, + collection: RagCollectionConfig, + username: str, + user_groups: List[str], + is_admin: bool, + require_write: bool = False, + ) -> bool: + """Check if user has access to a collection.""" + if is_admin: + return True + if collection.createdBy == username: + return True + + if require_write: + return False + + # Public collection (empty allowedGroups means accessible to all) + allowed_groups = collection.allowedGroups or [] + if not allowed_groups: + return True + + # Check if user has at least one matching group + if bool(set(user_groups) & set(allowed_groups)): + return True + + return False + + def create_collection( + self, + collection: RagCollectionConfig, + username: str, + ) -> RagCollectionConfig: + """Create a new collection with name uniqueness validation. + + Args: + collection: Collection configuration to create + username: Username creating the collection + + Returns: + Created collection + + Raises: + ValidationError: If collection name already exists in repository + """ + + # Check if collection name already exists in this repository + existing = self.collection_repo.find_by_name(collection.repositoryId, collection.name) + if existing: + raise ValidationError( + f"Collection with name '{collection.name}' already exists in repository '{collection.repositoryId}'" + ) + + # Set fields + collection.createdBy = username + + return self.collection_repo.create(collection) + + def get_collection( + self, + repository_id: str, + collection_id: str, + username: str, + user_groups: List[str], + is_admin: bool, + ) -> RagCollectionConfig: + """Get a collection with access control. + + Args: + repository_id: Repository ID + collection_id: Collection ID + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + """ + collection = self.collection_repo.find_by_id(collection_id, repository_id) + if not collection: + raise ValidationError(f"Collection {collection_id} not found") + if not self.has_access(collection, username, user_groups, is_admin): + raise ValidationError(f"Permission denied for collection {collection_id}") + return collection + + def list_collections( + self, + repository_id: str, + username: str, + user_groups: List[str], + is_admin: bool, + page_size: int = 20, + last_evaluated_key: Optional[Dict[str, str]] = None, + ) -> Tuple[List[RagCollectionConfig], Optional[Dict[str, str]]]: + """List collections with access control. + + For Bedrock KB repositories, default collections are persisted to the database + and will be included in the query results automatically. + + For other repository types, a virtual default collection is generated if needed. + """ + collections, key = self.collection_repo.list_by_repository( + repository_id, page_size=page_size, last_evaluated_key=last_evaluated_key + ) + filtered = [c for c in collections if self.has_access(c, username, user_groups, is_admin)] + + # On first page, check if virtual default collection needs to be added + # (only for non-Bedrock KB repositories) + if not last_evaluated_key: + repository = self.vector_store_repo.find_repository_by_id(repository_id) + + # Only create virtual collection if repository supports it + if repository: + service = RepositoryServiceFactory.create_service(repository) + if service.should_create_default_collection(): + default_collection = service.create_default_collection() + if default_collection: + # Check if a collection with the default embedding model ID already exists + existing_ids = {c.collectionId for c in filtered} + if default_collection.collectionId not in existing_ids: + filtered.append(default_collection) + + return filtered, key + + def update_collection( + self, + collection_id: str, + repository_id: str, + collection_data: Any, + username: str, + user_groups: List[str], + is_admin: bool, + ) -> RagCollectionConfig: + """Update a collection with access control and name uniqueness validation. + + Args: + collection_id: Collection ID to update + repository_id: Repository ID + request: RagCollectionConfig with fields to update + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + + Returns: + Updated collection + + Raises: + ValidationError: If name already exists or access denied + """ + collection = self.collection_repo.find_by_id(collection_id, repository_id) + if not collection: + raise ValidationError(f"Collection {collection_id} not found") + if not self.has_access(collection, username, user_groups, is_admin, require_write=True): + raise ValidationError(f"Permission denied to update collection {collection_id}") + + # Define updatable fields + updatable_fields = [ + "name", + "description", + "chunkingStrategy", + "allowedGroups", + "metadata", + "allowChunkingOverride", + "pipelines", + ] + + # Build updates dictionary from request + updates = { + field: getattr(collection_data, field) + for field in updatable_fields + if hasattr(collection_data, field) and getattr(collection_data, field) is not None + } + + # For default collections, prevent changing immutable fields + if collection.default: + # Prevent changing default status or data source ID + if hasattr(collection_data, "default") and collection_data.default != collection.default: + raise ValidationError("Cannot change default status of a default collection") + if hasattr(collection_data, "dataSourceId") and collection_data.dataSourceId != collection.dataSourceId: + raise ValidationError("Cannot change data source ID of a default collection") + + # Remove these fields from updates if present + updates.pop("default", None) + updates.pop("dataSourceId", None) + + # Special validation for name changes + if "name" in updates and updates["name"] != collection.name: + existing = self.collection_repo.find_by_name(repository_id, updates["name"]) + if existing and existing.collectionId != collection_id: + raise ValidationError( + f"Collection with name '{updates['name']}' already exists in repository '{repository_id}'" + ) + + # Update collection + updated = self.collection_repo.update(collection_id, repository_id, updates) + return updated + + def delete_collection( + self, + repository_id: str, + collection_id: Optional[str], + embedding_name: Optional[str], + username: str, + user_groups: List[str], + is_admin: bool, + ) -> Dict[str, Any]: + """Delete a collection with access control. + + Args: + repository_id: Repository ID + collection_id: Collection ID (None for default collections) + embedding_name: Embedding model name (None for regular collections) + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + + Returns: + Dictionary with deletion type and job ID + """ + # Validate that at least one identifier is provided + if not collection_id and not embedding_name: + raise ValidationError("Either collection_id or embedding_name must be provided") + + # Determine deletion type + is_default_collection = embedding_name is not None + deletion_type = "partial" if is_default_collection else "full" + + logger.info( + f"Starting {deletion_type} deletion for repository {repository_id}, " + f"collection_id={collection_id}, embedding_name={embedding_name}" + ) + + # For regular collections, verify access and update status + if not is_default_collection: + collection = self.collection_repo.find_by_id(collection_id, repository_id) + if not collection: + raise ValidationError(f"Collection {collection_id} not found") + if not self.has_access(collection, username, user_groups, is_admin, require_write=True): + raise ValidationError(f"Permission denied to delete collection {collection_id}") + + # Update collection status to DELETE_IN_PROGRESS + self.collection_repo.update(collection_id, repository_id, {"status": CollectionStatus.DELETE_IN_PROGRESS}) + + embedding_model = None # Don't set embedding_model for regular collections + else: + # For default collections, use embedding_name directly + embedding_model = embedding_name + + # Get document counts by ingestion type for summary + try: + all_docs, _, _ = self.document_repo.list_all( + repository_id=repository_id, + collection_id=collection_id if not is_default_collection else embedding_name, + limit=10000, # Get all documents for accurate count + ) + + lisa_managed_count = sum( + 1 for d in all_docs if d.ingestion_type in [IngestionType.MANUAL, IngestionType.AUTO] + ) + user_managed_count = sum(1 for d in all_docs if d.ingestion_type == IngestionType.EXISTING) + except Exception as e: + logger.warning(f"Failed to get document counts: {e}") + lisa_managed_count = None + user_managed_count = None + + # Create deletion job + try: + ingestion_job_repo = IngestionJobRepository() + ingestion_service = DocumentIngestionService() + + deletion_job = IngestionJob( + repository_id=repository_id, + collection_id=collection_id if not is_default_collection else None, + s3_path="", # Not applicable for collection deletion + embedding_model=embedding_model, # Only set for default collections + username=username, + status=IngestionStatus.DELETE_PENDING, + job_type=JobActionType.COLLECTION_DELETION, + collection_deletion=True, + ) + + # Save and submit the deletion job + ingestion_job_repo.save(deletion_job) + ingestion_service.create_delete_job(deletion_job) + + logger.info(f"Submitted {deletion_type} deletion job {deletion_job.id} " f"for repository {repository_id}") + + response = { + "jobId": deletion_job.id, + "deletionType": deletion_type, + "status": deletion_job.status, + } + + # Add summary if counts available + if lisa_managed_count is not None and user_managed_count is not None: + response["summary"] = { + "lisaManagedDocuments": lisa_managed_count, + "userManagedDocuments": user_managed_count, + "action": ( + "LISA-managed documents (MANUAL/AUTO) will be deleted, " + "user-managed documents (EXISTING) preserved" + ), + } + + return response + + except Exception as e: + logger.error(f"Failed to submit deletion job: {e}", exc_info=True) + + # Update collection status to DELETE_FAILED (only for regular collections) + if not is_default_collection: + self.collection_repo.update(collection_id, repository_id, {"status": CollectionStatus.DELETE_FAILED}) + + raise + + def get_collection_by_name( + self, + repository_id: str, + collection_name: str, + username: str, + user_groups: List[str], + is_admin: bool, + ) -> RagCollectionConfig: + """Get a collection by name with access control.""" + collection = self.collection_repo.find_by_name(repository_id, collection_name) + if not collection: + raise ValidationError(f"Collection '{collection_name}' not found") + if not self.has_access(collection, username, user_groups, is_admin): + raise ValidationError(f"Permission denied for collection '{collection_name}'") + return collection + + def count_collections(self, repository_id: str) -> int: + """Count total collections in a repository. + + Args: + repository_id: Repository ID + + Returns: + Total count of collections + """ + count = self.collection_repo.count_by_repository(repository_id) + return int(count) if count is not None else 0 + + def get_collection_model( + self, + repository_id: str, + collection_id: str, + username: str, + user_groups: List[str], + is_admin: bool, + ) -> Optional[str]: + """Get embedding model from collection or repository default. + + Args: + repository_id: Repository ID + collection_id: Collection ID + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + + Returns: + Embedding model name from collection or repository default + """ + if collection_id: + try: + collection = self.collection_repo.find_by_id(collection_id, repository_id) + if collection and collection.embeddingModel: + return collection.embeddingModel + except ValidationError: + # Collection not found, fall back to repository default + pass + + # Fall back to repository default embedding model + repository = self.vector_store_repo.find_repository_by_id(repository_id) + embedding_model_id = repository.get("embeddingModelId") + return str(embedding_model_id) if embedding_model_id is not None else None + + def list_all_user_collections( + self, + username: str, + user_groups: List[str], + is_admin: bool, + page_size: int = 20, + pagination_token: Optional[Dict[str, Any]] = None, + filter_text: Optional[str] = None, + sort_params: Optional[SortParams] = None, + ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + List all collections user has access to across all repositories. + + This method orchestrates the complete workflow: + 1. Get accessible repositories + 2. Estimate collection count + 3. Select pagination strategy + 4. Execute query and return results + + Args: + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + page_size: Number of items per page + pagination_token: Pagination token from previous request + filter_text: Optional text filter for name/description + sort_params: Optional SortParams object for sorting (defaults to createdAt desc) + + Returns: + Tuple of (list of enriched collections, pagination token) + """ + # Use default sort params if not provided + if sort_params is None: + sort_params = SortParams(sort_by=CollectionSortBy.CREATED_AT, sort_order=SortOrder.DESC) + + logger.info( + f"Listing all user collections for user={username}, is_admin={is_admin}, " + f"page_size={page_size}, filter={filter_text}, sort_by={sort_params.sort_by.value}" + ) + + # Get repositories user can access + repositories = self._get_accessible_repositories(username, user_groups, is_admin) + logger.debug(f"User has access to {len(repositories)} repositories") + + if not repositories: + logger.info("User has no accessible repositories, returning empty list") + return [], None + + # Estimate total collections + estimated_total = self._estimate_total_collections(repositories) + logger.info(f"Estimated total collections: {estimated_total}") + + # Select and execute pagination strategy + if estimated_total > 1000: + logger.info("Using scalable pagination strategy for large dataset") + collections, next_token = self._paginate_large_collections( + repositories, username, user_groups, is_admin, page_size, pagination_token, filter_text, sort_params + ) + else: + logger.info("Using simple pagination strategy") + collections, next_token = self._paginate_collections( + repositories, username, user_groups, is_admin, page_size, pagination_token, filter_text, sort_params + ) + + logger.info(f"Returning {len(collections)} collections") + return collections, next_token + + def _get_accessible_repositories( + self, username: str, user_groups: List[str], is_admin: bool + ) -> List[Dict[str, Any]]: + """ + Get all repositories user has access to. + + Args: + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + + Returns: + List of repository configurations user can access + """ + all_repos = self.vector_store_repo.get_registered_repositories() + + if is_admin: + logger.debug(f"Admin user has access to all {len(all_repos)} repositories") + return all_repos + + accessible = [repo for repo in all_repos if self._has_repository_access(user_groups, repo)] + logger.debug(f"User has access to {len(accessible)} of {len(all_repos)} repositories") + return accessible + + def _has_repository_access(self, user_groups: List[str], repository: Dict[str, Any]) -> bool: + """ + Check if user has access to repository based on groups. + + Args: + user_groups: User groups for access control + repository: Repository configuration + + Returns: + True if user has access, False otherwise + """ + allowed_groups = repository.get("allowedGroups", []) + + # Public repository (no group restrictions) + if not allowed_groups: + return True + + # Check if user has at least one matching group + has_access = bool(set(user_groups) & set(allowed_groups)) + logger.debug( + f"Repository {repository.get('repositoryId')} access check: " + f"user_groups={user_groups}, allowed_groups={allowed_groups}, has_access={has_access}" + ) + return has_access + + def _enrich_with_repository_metadata( + self, collections: List[RagCollectionConfig], repositories: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Enrich collections with repository metadata. + + Args: + collections: List of collection configurations + repositories: List of repository configurations + + Returns: + List of enriched collection dictionaries with repositoryName + """ + # Create repository lookup map + repo_map = {repo["repositoryId"]: repo for repo in repositories} + + enriched = [] + for collection in collections: + collection_dict = collection.model_dump(mode="json") + + # Add repository name + repo_id = collection.repositoryId + repository = repo_map.get(repo_id) + if repository: + collection_dict["repositoryName"] = repository.get("repositoryName", repo_id) + else: + # Fallback if repository not in map + logger.warning(f"Repository {repo_id} not found in accessible repositories") + collection_dict["repositoryName"] = repo_id + + enriched.append(collection_dict) + + return enriched + + def _estimate_total_collections(self, repositories: List[Dict[str, Any]]) -> int: + """ + Estimate total number of collections across repositories. + + Args: + repositories: List of repository configurations + + Returns: + Estimated total collection count + """ + total = 0 + for repo in repositories: + try: + count = self.collection_repo.count_by_repository(repo["repositoryId"]) + total += count + except Exception as e: + logger.warning(f"Failed to count collections for repository {repo['repositoryId']}: {e}") + # Continue with other repositories + + return total + + def _paginate_collections( + self, + repositories: List[Dict[str, Any]], + username: str, + user_groups: List[str], + is_admin: bool, + page_size: int, + pagination_token: Optional[Dict[str, Any]], + filter_text: Optional[str], + sort_params: SortParams, + ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + Simple pagination strategy for small-to-medium deployments. + + Aggregates all collections in memory, applies filtering and sorting, + then returns requested page. + + Args: + repositories: List of accessible repositories + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + page_size: Number of items per page + pagination_token: Pagination token from previous request + filter_text: Optional text filter + sort_params: SortParams object for sorting + + Returns: + Tuple of (list of enriched collections, next pagination token) + """ + # Parse pagination token + offset = 0 + if pagination_token and pagination_token.get("version") == "v1": + offset = pagination_token.get("offset", 0) + # Verify filter consistency + token_filter = pagination_token.get("filters", {}) + if ( + token_filter.get("filter") != filter_text + or token_filter.get("sortBy") != sort_params.sort_by + or token_filter.get("sortOrder") != sort_params.sort_order + ): + logger.warning("Pagination token filters don't match request, resetting to offset 0") + offset = 0 + + # Aggregate all collections from accessible repositories + all_collections: List[RagCollectionConfig] = [] + + for repo in repositories: + repo_id = repo["repositoryId"] + try: + # Query collections for this repository (fetch up to 100 per repo) + collections, _ = self.collection_repo.list_by_repository( + repository_id=repo_id, + page_size=100, + last_evaluated_key=None, + ) + + # Filter by collection-level permissions + accessible = [c for c in collections if self.has_access(c, username, user_groups, is_admin)] + + # Check if default collection needs to be added + service = RepositoryServiceFactory.create_service(repo) + default_collection = ( + service.create_default_collection() if service.should_create_default_collection() else None + ) + if default_collection: + # Check if a collection with the default embedding model ID already exists + existing_ids = {c.collectionId for c in accessible} + if default_collection.collectionId not in existing_ids: + accessible.append(default_collection) + + all_collections.extend(accessible) + logger.debug(f"Repository {repo_id}: {len(accessible)} accessible collections") + + except Exception as e: + logger.error(f"Failed to query collections for repository {repo_id}: {e}") + # Continue with other repositories + + # Apply text filtering + if filter_text: + all_collections = [c for c in all_collections if self._matches_filter(c, filter_text)] + logger.debug(f"After filtering: {len(all_collections)} collections") + + # Apply sorting + all_collections = self._sort_collections(all_collections, sort_params) + + # Apply pagination + start_idx = offset + end_idx = start_idx + page_size + page_collections = all_collections[start_idx:end_idx] + + # Enrich with repository metadata + enriched = self._enrich_with_repository_metadata(page_collections, repositories) + + # Build next token if more pages exist + next_token = None + if end_idx < len(all_collections): + next_token = { + "version": "v1", + "offset": end_idx, + "filters": { + "filter": filter_text, + "sortBy": sort_params.sort_by.value, + "sortOrder": sort_params.sort_order.value, + }, + } + + return enriched, next_token + + def _matches_filter(self, collection: RagCollectionConfig, filter_text: str) -> bool: + """ + Check if collection matches text filter. + + Args: + collection: Collection to check + filter_text: Text to search for (case-insensitive) + + Returns: + True if collection name or description contains filter text + """ + filter_lower = filter_text.lower() + + # Check name + if collection.name and filter_lower in collection.name.lower(): + return True + + # Check description + if collection.description and filter_lower in collection.description.lower(): + return True + + return False + + def _sort_collections( + self, collections: List[RagCollectionConfig], sort_params: SortParams + ) -> List[RagCollectionConfig]: + """ + Sort collections by specified field and order. + + Args: + collections: List of collections to sort + sort_params: SortParams object containing sort field and order + + Returns: + Sorted list of collections + """ + reverse = sort_params.sort_order == SortOrder.DESC + + if sort_params.sort_by == CollectionSortBy.NAME: + return sorted(collections, key=lambda c: c.name or "", reverse=reverse) + elif sort_params.sort_by == CollectionSortBy.UPDATED_AT: + return sorted(collections, key=lambda c: c.updatedAt, reverse=reverse) + else: # Default to createdAt + return sorted(collections, key=lambda c: c.createdAt, reverse=reverse) + + def _paginate_large_collections( + self, + repositories: List[Dict[str, Any]], + username: str, + user_groups: List[str], + is_admin: bool, + page_size: int, + pagination_token: Optional[Dict[str, Any]], + filter_text: Optional[str], + sort_params: SortParams, + ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """ + Scalable pagination strategy for large deployments. + + Uses incremental merge with per-repository cursors to handle + 1000+ collections efficiently without loading all into memory. + + Args: + repositories: List of accessible repositories + username: Username for access control + user_groups: User groups for access control + is_admin: Whether user is admin + page_size: Number of items per page + pagination_token: Pagination token from previous request + filter_text: Optional text filter + sort_params: SortParams object for sorting + + Returns: + Tuple of (list of enriched collections, next pagination token) + """ + # Initialize or restore repository cursors + if pagination_token and pagination_token.get("version") == "v2": + cursors = pagination_token.get("repositoryCursors", {}) + global_offset = pagination_token.get("globalOffset", 0) + # Convert lists back to sets + seen_ids_raw = pagination_token.get("seenCollectionIds", {}) + seen_collection_ids = {repo_id: set(id_list) for repo_id, id_list in seen_ids_raw.items()} + + # Verify filter consistency + token_filters = pagination_token.get("filters", {}) + if ( + token_filters.get("filter") != filter_text + or token_filters.get("sortBy") != sort_params.sort_by.value + or token_filters.get("sortOrder") != sort_params.sort_order.value + ): + logger.warning("Pagination token filters don't match, resetting cursors") + cursors = {} + global_offset = 0 + seen_collection_ids = {} + else: + cursors = {} + global_offset = 0 + seen_collection_ids = {} + + # Initialize cursors for new repositories + for repo in repositories: + repo_id = repo["repositoryId"] + if repo_id not in cursors: + cursors[repo_id] = {"lastEvaluatedKey": None, "exhausted": False} + if repo_id not in seen_collection_ids: + seen_collection_ids[repo_id] = set() + + # Fetch batches from each non-exhausted repository + batches = [] + for repo in repositories: + repo_id = repo["repositoryId"] + cursor = cursors[repo_id] + + if cursor["exhausted"]: + continue + + try: + # Query collections for this repository + collections, next_key = self.collection_repo.list_by_repository( + repository_id=repo_id, + page_size=page_size, # Fetch page_size per repo + last_evaluated_key=cursor["lastEvaluatedKey"], + ) + + # Filter by collection-level permissions + accessible = [c for c in collections if self.has_access(c, username, user_groups, is_admin)] + + # Track seen collection IDs for this repository + for c in accessible: + seen_collection_ids[repo_id].add(c.collectionId) + + # On first fetch for this repository, check if default collection needs to be added + if not cursor["lastEvaluatedKey"]: + service = RepositoryServiceFactory.create_service(repo) + if service.should_create_default_collection(): + default_collection = service.create_default_collection() + if default_collection: + # Check if we've seen a collection with the default embedding model ID + if default_collection.collectionId not in seen_collection_ids[repo_id]: + accessible.append(default_collection) + seen_collection_ids[repo_id].add(default_collection.collectionId) + + # Apply text filtering + if filter_text: + accessible = [c for c in accessible if self._matches_filter(c, filter_text)] + + batches.append({"repositoryId": repo_id, "collections": accessible, "nextKey": next_key}) + + # Update cursor + cursors[repo_id]["lastEvaluatedKey"] = next_key + cursors[repo_id]["exhausted"] = next_key is None + + logger.debug( + f"Repository {repo_id}: fetched {len(accessible)} collections, " + f"exhausted={cursors[repo_id]['exhausted']}" + ) + + except Exception as e: + logger.error(f"Failed to query collections for repository {repo_id}: {e}") + cursors[repo_id]["exhausted"] = True + + # Merge batches using heap for efficient sorting + merged = self._merge_sorted_batches(batches, sort_params.sort_by.value, sort_params.sort_order.value) + + # Extract requested page + start_idx = global_offset + end_idx = start_idx + page_size + page_collections = merged[start_idx:end_idx] + + # Enrich with repository metadata + enriched = self._enrich_with_repository_metadata(page_collections, repositories) + + # Determine if more pages exist + has_more = (end_idx < len(merged)) or any(not c["exhausted"] for c in cursors.values()) + + # Build next token + next_token = None + if has_more: + # If we consumed all merged results, reset offset for next fetch + new_offset = end_idx if end_idx < len(merged) else 0 + + # Convert sets to lists for JSON serialization + serializable_seen_ids = {repo_id: list(id_set) for repo_id, id_set in seen_collection_ids.items()} + + next_token = { + "version": "v2", + "repositoryCursors": cursors, + "globalOffset": new_offset, + "seenCollectionIds": serializable_seen_ids, + "filters": { + "filter": filter_text, + "sortBy": sort_params.sort_by.value, + "sortOrder": sort_params.sort_order.value, + }, + } + + return enriched, next_token + + def _merge_sorted_batches( + self, batches: List[Dict[str, Any]], sort_by: str, sort_order: str + ) -> List[RagCollectionConfig]: + """ + Merge pre-sorted batches from multiple repositories using min-heap. + + Time Complexity: O(N log K) where N = total collections, K = number of repositories + Space Complexity: O(N) for merged result + + Args: + batches: List of batch dictionaries with collections from each repository + sort_by: Field to sort by + sort_order: Sort order (asc/desc) + + Returns: + Merged and sorted list of collections + """ + if not batches: + return [] + + # Create heap with first item from each batch + heap: List[Tuple[Any, str, int, Dict[str, Any]]] = [] + + for batch in batches: + if batch["collections"]: + collection = batch["collections"][0] + sort_key = self._get_sort_key(collection, sort_by) + + # For descending order, negate numeric keys or reverse string comparison + if sort_order.lower() == "desc": + if isinstance(sort_key, str): + # For strings, we'll reverse the final list instead + pass + else: + # For datetime/numeric, negate for heap + sort_key = -sort_key.timestamp() if hasattr(sort_key, "timestamp") else -sort_key + + heapq.heappush(heap, (sort_key, batch["repositoryId"], 0, batch)) + + merged = [] + while heap: + _, repo_id, idx, batch = heapq.heappop(heap) + merged.append(batch["collections"][idx]) + + # Add next item from same batch + next_idx = idx + 1 + if next_idx < len(batch["collections"]): + next_collection = batch["collections"][next_idx] + next_sort_key = self._get_sort_key(next_collection, sort_by) + + if sort_order.lower() == "desc": + if isinstance(next_sort_key, str): + pass + else: + next_sort_key = ( + -next_sort_key.timestamp() if hasattr(next_sort_key, "timestamp") else -next_sort_key + ) + + heapq.heappush(heap, (next_sort_key, repo_id, next_idx, batch)) + + # For descending string sorts, reverse the final list + if sort_order.lower() == "desc" and sort_by == "name": + merged.reverse() + + return merged + + def _get_sort_key(self, collection: RagCollectionConfig, sort_by: str) -> Any: + """ + Extract sort key from collection. + + Args: + collection: Collection to extract key from + sort_by: Field to sort by + + Returns: + Sort key value + """ + if sort_by == "name": + return collection.name or "" + elif sort_by == "updatedAt": + return collection.updatedAt + else: # Default to createdAt + return collection.createdAt diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py index c36186be7..b691783bf 100644 --- a/lambda/repository/embeddings.py +++ b/lambda/repository/embeddings.py @@ -89,11 +89,10 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: if not texts: raise ValidationError("No texts provided for embedding") - logger.info(f"Embedding {len(texts)} documents") + logger.info(f"Embedding {len(texts)} documents using {self.model_name}") try: url = f"{self.base_url}/embeddings" request_data = {"input": texts, "model": self.model_name} - response = requests.post( url, json=request_data, diff --git a/lambda/repository/ingestion_job_repo.py b/lambda/repository/ingestion_job_repo.py index 4868da717..a1727463d 100644 --- a/lambda/repository/ingestion_job_repo.py +++ b/lambda/repository/ingestion_job_repo.py @@ -14,6 +14,7 @@ """Repository for ingestion job DynamoDB operations.""" +import logging import os from datetime import datetime, timedelta, timezone from typing import Dict, Optional @@ -22,8 +23,13 @@ from models.domain_objects import IngestionJob, IngestionStatus from utilities.common_functions import retry_config -dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) -ingestion_job_table = dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) +logger = logging.getLogger(__name__) + + +def _get_ingestion_job_table(): + """Lazy initialization of DynamoDB table.""" + dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) + return dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) class IngestionJobListResponse: @@ -46,14 +52,33 @@ def __init__(self, message: str): class IngestionJobRepository: def __init__(self): - self.ddb_client = boto3.client("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) - self.table_name = os.environ["LISA_INGESTION_JOB_TABLE_NAME"] + self._ddb_client = None + self._table_name = None + self._batch_client = None + + @property + def ddb_client(self): + if self._ddb_client is None: + self._ddb_client = boto3.client("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) + return self._ddb_client + + @property + def table_name(self): + if self._table_name is None: + self._table_name = os.environ["LISA_INGESTION_JOB_TABLE_NAME"] + return self._table_name + + @property + def batch_client(self): + if self._batch_client is None: + self._batch_client = boto3.client("batch", region_name=os.environ["AWS_REGION"], config=retry_config) + return self._batch_client def save(self, job: IngestionJob) -> None: - ingestion_job_table.put_item(Item=job.model_dump(exclude_none=True)) + _get_ingestion_job_table().put_item(Item=job.model_dump(exclude_none=True)) def find_by_id(self, id: str) -> IngestionJob: - response = ingestion_job_table.get_item(Key={"id": id}) + response = _get_ingestion_job_table().get_item(Key={"id": id}) if not response.get("Item"): raise Exception(f"Ingestion job with id {id} not found") @@ -61,7 +86,7 @@ def find_by_id(self, id: str) -> IngestionJob: return IngestionJob(**response.get("Item")) def find_by_path(self, s3_path: str) -> list[IngestionJob]: - response = ingestion_job_table.query( + response = _get_ingestion_job_table().query( IndexName="s3Path", KeyConditionExpression="s3Path = :path", ExpressionAttributeValues={":path": s3_path} ) @@ -69,7 +94,7 @@ def find_by_path(self, s3_path: str) -> list[IngestionJob]: return [IngestionJob(**item) for item in items] def find_by_document(self, document_id: str) -> Optional[IngestionJob]: - response = ingestion_job_table.query( + response = _get_ingestion_job_table().query( IndexName="documentId", KeyConditionExpression="document_id = :document_id", ExpressionAttributeValues={":document_id": document_id}, @@ -84,7 +109,7 @@ def find_by_document(self, document_id: str) -> Optional[IngestionJob]: def update_status(self, job: IngestionJob, status: IngestionStatus) -> IngestionJob: job.status = status - ingestion_job_table.update_item( + _get_ingestion_job_table().update_item( Key={ "id": job.id, }, @@ -118,10 +143,6 @@ def list_jobs_by_repository( Returns: Tuple of (list of job dictionaries, last_evaluated_key for next page) """ - import logging - - logger = logging.getLogger(__name__) - time_threshold = datetime.now(timezone.utc) - timedelta(hours=time_limit_hours) time_threshold_str = time_threshold.isoformat() @@ -148,7 +169,7 @@ def list_jobs_by_repository( if last_evaluated_key: query_params["ExclusiveStartKey"] = last_evaluated_key - response = ingestion_job_table.query(**query_params) + response = _get_ingestion_job_table().query(**query_params) logger.info(f"GSI query returned {len(response.get('Items', []))} items") @@ -165,3 +186,72 @@ def list_jobs_by_repository( last_evaluated_key_response = response.get("LastEvaluatedKey") return jobs, last_evaluated_key_response + + def get_batch_job_status(self, job_id: str) -> Optional[str]: + """Get the status of a batch job by job ID. + + Args: + job_id: AWS Batch job ID + + Returns: + Job status (SUBMITTED, PENDING, RUNNABLE, STARTING, RUNNING, SUCCEEDED, FAILED) or None + """ + response = self.batch_client.describe_jobs(jobs=[job_id]) + if response.get("jobs"): + return response["jobs"][0].get("status") + return None + + def find_batch_job_for_document(self, document_id: str, job_queue: str) -> Optional[Dict]: + """Find the batch job associated with a document ingestion. + + Args: + document_id: Document ID + job_queue: Batch job queue name + + Returns: + Dict with jobId and status, or None if not found + """ + job_name_prefix = f"document-ingest-{document_id}" + + for status in ["RUNNING", "SUCCEEDED", "FAILED", "PENDING", "RUNNABLE", "STARTING"]: + try: + response = self.batch_client.list_jobs(jobQueue=job_queue, jobStatus=status) + for job in response.get("jobSummaryList", []): + if job["jobName"].startswith(job_name_prefix): + return {"jobId": job["jobId"], "status": status, "jobName": job["jobName"]} + except Exception as e: # nosec B112 + logger.debug(f"Error listing jobs with status {status}: {e}") + + return None + + def find_pending_collection_deletions(self, repository_id: str) -> list[IngestionJob]: + """Find all pending collection deletion jobs for a repository. + + Args: + repository_id: Repository ID + + Returns: + List of pending collection deletion jobs + """ + try: + response = _get_ingestion_job_table().query( + IndexName="repository_id-created_date-index", + KeyConditionExpression="repository_id = :repo_id", + FilterExpression=( + "attribute_exists(collection_deletion) AND collection_deletion = :true AND " + "(#status = :pending OR #status = :in_progress)" + ), + ExpressionAttributeNames={"#status": "status"}, + ExpressionAttributeValues={ + ":repo_id": repository_id, + ":true": True, + ":pending": IngestionStatus.DELETE_PENDING, + ":in_progress": IngestionStatus.DELETE_IN_PROGRESS, + }, + ) + + items = response.get("Items", []) + return [IngestionJob(**item) for item in items] + except Exception as e: + logger.error(f"Error finding pending collection deletions for repository {repository_id}: {e}") + return [] diff --git a/lambda/repository/ingestion_service.py b/lambda/repository/ingestion_service.py index a405c16c0..66251893a 100644 --- a/lambda/repository/ingestion_service.py +++ b/lambda/repository/ingestion_service.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import os +from typing import Any, Dict, Optional import boto3 -from models.domain_objects import Enum, IngestionJob +from models.domain_objects import Enum, FixedChunkingStrategy, IngestDocumentRequest, IngestionJob, IngestionType logger = logging.getLogger(__name__) @@ -37,8 +38,65 @@ def _submit_job(self, job: IngestionJob, action: IngestionAction) -> None: ) logger.info(f"Submitted {action} job for document {job.id}: {response['jobId']}") - def create_ingest_job(self, job: IngestionJob) -> None: + def submit_create_job(self, job: IngestionJob) -> None: self._submit_job(job, IngestionAction("ingest")) def create_delete_job(self, job: IngestionJob) -> None: self._submit_job(job, IngestionAction("delete")) + + def create_ingestion_job( + self, + repository: dict, + collection: Optional[dict], + request: IngestDocumentRequest, + query_params: dict, + s3_path: str, + username: str, + metadata: Optional[Dict[str, Any]] = None, + ingestion_type: IngestionType = IngestionType.MANUAL, + ) -> IngestionJob: + + # Determine collection_id + collection_id = ( + request.collectionId + or (request.embeddingModel.get("modelName") if request.embeddingModel else None) + or repository.get("embeddingModelId") + ) + + # Determine chunking strategy + chunk_strategy = None + if collection and request.chunkingStrategy and collection.get("allowChunkingOverride"): + try: + chunk_strategy = ( + FixedChunkingStrategy(**request.chunkingStrategy) + if request.chunkingStrategy.get("type", "").upper() == "FIXED" + else collection.get("chunkingStrategy") + ) + except Exception: + chunk_strategy = collection.get("chunkingStrategy") + elif collection: + chunk_strategy = collection.get("chunkingStrategy") + if not chunk_strategy: + chunk_strategy = FixedChunkingStrategy( + size=str(query_params.get("chunkSize", 1000)), + overlap=str(query_params.get("chunkOverlap", 200)), + ) + + # Get embedding model + embedding_model = collection.get("embeddingModel") if collection else repository.get("embeddingModelId") + source = "collection" if collection else "repository" + logger.info(f"Using embedding model for ingestion: {embedding_model} (from {source})") + + job = IngestionJob( + repository_id=repository.get("repositoryId"), + collection_id=collection_id, + chunk_strategy=chunk_strategy, + embedding_model=embedding_model, + s3_path=s3_path, + username=username, + metadata=metadata, + ingestion_type=ingestion_type, + ) + logger.info(f"Created ingestion job with embedding_model: {embedding_model}") + + return job diff --git a/lambda/repository/job_status.py b/lambda/repository/job_status.py index 9d81c92e0..36cbd00c2 100644 --- a/lambda/repository/job_status.py +++ b/lambda/repository/job_status.py @@ -12,15 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Pydantic models for job status responses.""" +"""Job status helper functions.""" -from pydantic import BaseModel +from models.domain_objects import IngestionStatus -class JobStatus(BaseModel): - """Job status details returned by list_jobs_by_repository.""" +def is_terminal_status(status: IngestionStatus) -> bool: + """Check if status is terminal.""" + return status in [ + IngestionStatus.INGESTION_COMPLETED, + IngestionStatus.INGESTION_FAILED, + IngestionStatus.DELETE_COMPLETED, + IngestionStatus.DELETE_FAILED, + ] - status: str - document: str - auto: bool - created_date: str + +def is_success_status(status: IngestionStatus) -> bool: + """Check if status is success.""" + return status in [ + IngestionStatus.INGESTION_COMPLETED, + IngestionStatus.DELETE_COMPLETED, + ] diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index e825799ac..6cf4f0ada 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -13,37 +13,54 @@ # limitations under the License. """Lambda functions for RAG repository API.""" + import json import logging import os import urllib.parse +from types import SimpleNamespace from typing import Any, cast, Dict, List, Optional import boto3 from boto3.dynamodb.types import TypeSerializer from botocore.config import Config from models.domain_objects import ( - FixedChunkingStrategy, + FilterParams, + IngestDocumentRequest, IngestionJob, IngestionStatus, + IngestionType, ListJobsResponse, PaginationParams, PaginationResult, + RagCollectionConfig, RagDocument, + SortParams, + UpdateVectorStoreRequest, + VectorStoreConfig, + VectorStoreStatus, ) +from repository.collection_service import CollectionService from repository.config.params import ListJobsParams -from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService +from repository.metadata_generator import MetadataGenerator from repository.rag_document_repo import RagDocumentRepository +from repository.s3_metadata_manager import S3MetadataManager +from repository.services import RepositoryServiceFactory from repository.vector_store_repo import VectorStoreRepository -from utilities.auth import admin_only, get_user_context, get_username, is_admin -from utilities.bedrock_kb import retrieve_documents -from utilities.common_functions import api_wrapper, get_groups, get_id_token, retry_config, user_has_group_access +from utilities.auth import admin_only, get_groups, get_user_context, get_username, is_admin, user_has_group_access +from utilities.bedrock_kb import create_s3_scan_job +from utilities.bedrock_kb_discovery import ( + build_pipeline_configs_from_kb_config, + get_available_data_sources, + list_knowledge_bases, +) +from utilities.bedrock_kb_validation import validate_bedrock_kb_exists +from utilities.common_functions import api_wrapper, retry_config from utilities.exceptions import HTTPException from utilities.repository_types import RepositoryType from utilities.validation import ValidationError -from utilities.vector_store import get_vector_store_client logger = logging.getLogger(__name__) region_name = os.environ["AWS_REGION"] @@ -68,6 +85,7 @@ vs_repo = VectorStoreRepository() ingestion_service = DocumentIngestionService() ingestion_job_repository = IngestionJobRepository() +collection_service = CollectionService(vector_store_repo=vs_repo, document_repo=doc_repo) @api_wrapper @@ -82,13 +100,12 @@ def list_all(event: dict, context: dict) -> List[Dict[str, Any]]: Returns: List of repository configurations user can access """ - user_groups = get_groups(event) + _, is_admin, groups = get_user_context(event) registered_repositories = vs_repo.get_registered_repositories() - admin_override = is_admin(event) return [ repo for repo in registered_repositories - if admin_override or user_has_group_access(user_groups, repo.get("allowedGroups", [])) + if is_admin or user_has_group_access(groups, repo.get("allowedGroups", [])) ] @@ -113,7 +130,10 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: Args: event (dict): The Lambda event object containing: - - queryStringParameters.modelName: Name of the embedding model + - queryStringParameters.modelName (optional): Name of the embedding model + (not needed if collectionId provided) + - queryStringParameters.collectionName (optional): Collection ID to search within. Will override + any modelName. - queryStringParameters.query: Search query text - queryStringParameters.repositoryType: Type of repository - queryStringParameters.topK (optional): Number of results to return (default: 3) @@ -127,41 +147,57 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: Raises: ValidationError: If required parameters are missing or invalid """ - query_string_params = event["queryStringParameters"] - model_name = query_string_params["modelName"] - query = query_string_params["query"] - top_k = query_string_params.get("topK", 3) + query_string_params = event.get("queryStringParameters") + path_params = event.get("pathParameters") + query = query_string_params.get("query") + top_k = int(query_string_params.get("topK", 3)) include_score = query_string_params.get("score", "false").lower() == "true" - repository_id = event["pathParameters"]["repositoryId"] + repository_id = path_params.get("repositoryId") + collection_id = query_string_params.get("collectionId") repository = get_repository(event, repository_id=repository_id) - id_token = get_id_token(event) + # Get user context for collection access + username, is_admin, groups = get_user_context(event) - docs: List[Dict[str, Any]] = [] - if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): - docs = retrieve_documents( - bedrock_runtime_client=bedrock_client, - repository=repository, - query=query, - top_k=int(top_k), + is_default = collection_id is not None and collection_id == repository.get("embeddingModelId") + # Determine embedding model + model_name = ( + collection_service.get_collection_model( repository_id=repository_id, + collection_id=collection_id if not is_default else None, + username=username, + user_groups=groups, + is_admin=is_admin, ) - else: - embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) - vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) + if collection_id + else query_string_params.get("modelName") + ) + + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + # No collectionId will query the entire Knowledge base. Reserve for Admins. + if collection_id is None and not is_admin: + raise ValidationError("collectionId is required when searching Bedrock Knowledge Bases") + elif not model_name: + raise ValidationError("modelName is required when collectionId is not provided") + + # Use repository service for similarity search + service = RepositoryServiceFactory.create_service(repository) + + # Use collection_id as vector store index if provided, otherwise use model_name + search_collection_id = collection_id or model_name + logger.info(f"Searching in collection: {search_collection_id} with embedding model: {model_name}") + + # Delegate to service for retrieval - service handles repository-specific logic + docs = service.retrieve_documents( + query=query, + collection_id=search_collection_id, + top_k=top_k, + model_name=model_name, + include_score=include_score, + bedrock_agent_client=bedrock_client, + ) - # empty vector stores do not have an initialize index. Return empty docs - if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH) and not vs.client.indices.exists( - index=model_name - ): - logger.info(f"Index {model_name} does not exist. Returning empty docs.") - else: - docs = ( - _similarity_search_with_score(vs, query, top_k, repository) - if include_score - else _similarity_search(vs, query, top_k) - ) doc_content = [ { "Document": { @@ -177,23 +213,598 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: return doc_return -def get_repository(event: dict[str, Any], repository_id: str) -> None: - repo = vs_repo.find_repository_by_id(repository_id) - """Ensures a user has access to the repository or else raises an HTTPException""" - if is_admin(event) is False: - user_groups = json.loads(event["requestContext"]["authorizer"]["groups"]) or [] - if not user_has_group_access(user_groups, repo.get("allowedGroups", [])): - raise HTTPException(status_code=403, message="User does not have permission to access this repository") +def get_repository(event: dict[str, Any], repository_id: str) -> dict[str, Any]: + """Ensures a user has access to the repository or else raises an HTTPException.""" + repo: dict[str, Any] = vs_repo.find_repository_by_id(repository_id) + + # Admins have access to all repositories + if is_admin(event): + return repo + + # Non-admins must have matching group access + user_groups = get_groups(event) + if not user_has_group_access(user_groups, repo.get("allowedGroups", [])): + raise HTTPException(status_code=403, message="User does not have permission to access this repository") + return repo -def _ensure_document_ownership(event: dict[str, Any], docs: list[dict[str, Any]]) -> None: +def create_bedrock_collection(event: dict, context: dict) -> Dict[str, Any]: + """ + Create collections for a Bedrock Knowledge Base repository based on pipeline configurations. + This is called by the state machine during repository creation. + + Each pipeline configuration represents a data source and should have a corresponding collection. + + Args: + event (dict): The Lambda event object containing: + - ragConfig: Repository configuration with repositoryId and pipelines + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing: + - collections: List of created collection configurations + - count: Number of collections created + + Raises: + ValidationError: If validation fails + HTTPException: If repository not found + """ + try: + # Extract repository config from state machine input + rag_config = event.get("ragConfig", {}) + repository_id = rag_config.get("repositoryId") + + if not repository_id: + raise ValidationError("repositoryId is required in ragConfig") + + logger.info(f"Creating collection(s) for Bedrock KB repository: {repository_id}") + + # Get repository configuration + repository = vs_repo.find_repository_by_id(repository_id=repository_id) + + # Get pipeline configurations - each pipeline should have a collectionId + pipelines = repository.get("pipelines", []) + + if not pipelines: + raise ValidationError(f"No pipelines found in Bedrock KB repository {repository_id}") + + logger.info(f"Found {len(pipelines)} pipeline(s) to create collections for") + + # Use repository service to create collections + service = RepositoryServiceFactory.create_service(repository) + + # Create a collection for each pipeline (skip if already exists) + created_collections = [] + skipped_collections = [] + for pipeline in pipelines: + collection_id = pipeline.get("collectionId") + if not collection_id: + logger.warning(f"Pipeline missing collectionId, skipping: {pipeline}") + continue + + collection_name = pipeline.get("collectionName", collection_id) + s3_bucket = pipeline.get("s3Bucket", "") + s3_prefix = pipeline.get("s3Prefix", "") + s3_uri = f"s3://{s3_bucket}/{s3_prefix}" if s3_bucket else "" + + # Check if collection already exists + try: + existing_collection = collection_service.get_collection( + repository_id=repository_id, + collection_id=collection_id, + username="system", + user_groups=[], + is_admin=True, + ) + if existing_collection: + logger.info(f"Collection {collection_id} already exists, skipping creation") + skipped_collections.append(existing_collection.model_dump(mode="json")) + continue + except (HTTPException, ValidationError): + # Collection doesn't exist, proceed with creation + pass + + logger.info( + f"Creating collection for pipeline with collectionId={collection_id}, " + f"collectionName={collection_name}, s3Uri={s3_uri}" + ) + + # Create collection using service helper + collection = service._create_collection_for_data_source( + data_source_id=collection_id, s3_uri=s3_uri, is_default=False, collection_name=collection_name + ) + + # Save the collection + collection_service.create_collection(collection=collection, username="system") + logger.info(f"Successfully saved collection: {collection.collectionId}") + created_collections.append(collection.model_dump(mode="json")) + + # Create S3 scan job to ingest existing documents + if s3_bucket: + logger.info(f"Creating S3 scan job for bucket {s3_bucket} with prefix '{s3_prefix}'") + job_id = create_s3_scan_job( + ingestion_job_repository=ingestion_job_repository, + ingestion_service=ingestion_service, + repository_id=repository_id, + collection_id=collection_id, + embedding_model=collection.embeddingModel, + s3_bucket=s3_bucket, + s3_prefix=s3_prefix, + ) + logger.info(f"Created S3 scan job {job_id} for collection {collection_id}") + + if not created_collections and not skipped_collections: + raise ValidationError(f"Failed to create any collections for repository {repository_id}") + + # Return all created and skipped collections + all_collections = created_collections + skipped_collections + result: dict[str, Any] = { + "collections": all_collections, + "count": len(all_collections), + "created": len(created_collections), + "skipped": len(skipped_collections), + } + logger.info(f"Collection summary: {len(created_collections)} created, {len(skipped_collections)} skipped") + logger.info(f"Successfully created {len(created_collections)} collection(s) for repository {repository_id}") + return result + + except Exception as e: + logger.error(f"Error creating Bedrock collection(s): {str(e)}") + raise + + +@api_wrapper +@admin_only +def create_collection(event: dict, context: dict) -> Dict[str, Any]: + """ + Create a new collection within a vector store. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The parent repository ID + - body: JSON with collection configuration (RagCollectionConfig) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing the created collection configuration + + Raises: + ValidationError: If validation fails or user lacks permission + HTTPException: If repository not found or access denied + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + + if not repository_id: + raise ValidationError("repositoryId is required") + + # Get user context + username, _, _ = get_user_context(event) + + # Ensure repository exists and user has access + repository = get_repository(event, repository_id=repository_id) + + # Block user-created collections for Bedrock Knowledge Base repositories + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + raise ValidationError( + "Bedrock Knowledge Base repositories do not support user created collections. " + "Update the repository to add a new datasource collection." + ) + + # Parse request body + try: + body = json.loads(event.get("body", {})) + # Add required fields + body["repositoryId"] = repository_id + body["createdBy"] = username + collection = RagCollectionConfig(**body) + except json.JSONDecodeError as e: + raise ValidationError(f"Invalid JSON in request body: {e}") + except Exception as e: + raise ValidationError(f"Invalid request: {e}") + + # Create collection via service + created_collection = collection_service.create_collection( + collection=collection, + username=username, + ) + + # Return collection configuration + result: dict[str, Any] = created_collection.model_dump(mode="json") + return result + + +@api_wrapper +def get_collection(event: dict, context: dict) -> Dict[str, Any]: + """ + Get a collection by ID within a vector store. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The parent repository ID + - pathParameters.collectionId: The collection ID + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing the collection configuration + + Raises: + ValidationError: If collection not found or user lacks permission + HTTPException: If repository not found or access denied + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + collection_id = path_params.get("collectionId") + + if not repository_id: + raise ValidationError("repositoryId is required") + if not collection_id: + raise ValidationError("collectionId is required") + + # Get user context + username, is_admin, groups = get_user_context(event) + + # Ensure repository exists and user has access + repo = get_repository(event, repository_id=repository_id) + + if repo.get("embeddingModelId") == collection_id: + # Not a real collection - create virtual default collection + service = RepositoryServiceFactory.create_service(repo) + collection = service.create_default_collection() + else: + # Get collection via service (includes access control check) + collection = collection_service.get_collection( + repository_id=repository_id, + collection_id=collection_id, + username=username, + user_groups=groups, + is_admin=is_admin, + ) + + if collection is None: + raise HTTPException( + status_code=404, message=f"Collection '{collection_id}' not found in repository '{repository_id}'" + ) + + # Return collection configuration + result: dict[str, Any] = collection.model_dump(mode="json") + return result + + +@api_wrapper +@admin_only +def update_collection(event: dict, context: dict) -> Dict[str, Any]: + """ + Update a collection within a vector store. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The parent repository ID + - pathParameters.collectionId: The collection ID + - body: JSON with partial collection updates (RagCollectionConfig) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing: + - collection: The updated collection configuration + - warnings: List of warning messages (e.g., chunking strategy changes) + + Raises: + ValidationError: If validation fails or user lacks permission + HTTPException: If repository or collection not found or access denied + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + collection_id = path_params.get("collectionId") + + if not repository_id: + raise ValidationError("repositoryId is required") + if not collection_id: + raise ValidationError("collectionId is required") + + # Get user context + username, is_admin, groups = get_user_context(event) + + # Ensure repository exists and user has access + _ = get_repository(event, repository_id=repository_id) + + # Parse request body - accept partial updates as a dictionary + try: + body = json.loads(event.get("body", {})) + except json.JSONDecodeError as e: + raise ValidationError(f"Invalid JSON in request body: {e}") + + # Create a simple namespace object to hold the update fields + # The service layer expects an object with attributes, not a dict + request = SimpleNamespace(**body) + + # Update collection via service (includes access control check) + updated_collection = collection_service.update_collection( + collection_id=collection_id, + repository_id=repository_id, + collection_data=request, + username=username, + user_groups=groups, + is_admin=is_admin, + ) + + result: dict[str, Any] = updated_collection.model_dump(mode="json") + return result + + +@api_wrapper +@admin_only +def delete_collection(event: dict, context: dict) -> Dict[str, Any]: + """ + Delete a collection (regular or default) within a vector store. + + Path: /repository/{repositoryId}/collection/{collectionId} + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The parent repository ID + - pathParameters.collectionId: The collection ID (optional for default collections) + - queryStringParameters.embeddingName: Embedding model name (for default collections) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: Dictionary with deletion type and job ID + + Raises: + ValidationError: If validation fails or user lacks permission + HTTPException: If repository or collection not found or access denied + """ + # Extract parameters + path_params = event.get("pathParameters", {}) + query_params = event.get("queryStringParameters", {}) or {} + + repository_id = path_params.get("repositoryId") + collection_id = path_params.get("collectionId") # May be None for default collections + embedding_name = query_params.get("embeddingName") # For default collections + + if not repository_id: + raise ValidationError("repositoryId is required") + + # Validate that we have either collectionId or embeddingName + if not collection_id and not embedding_name: + raise ValidationError("Either collectionId or embeddingName must be provided") + + # Get user context + username, is_admin, groups = get_user_context(event) + + # Ensure repository exists and user has access + repo = get_repository(event, repository_id=repository_id) + + is_default_collection = repo.get("embeddingModelId") == collection_id + # Delete collection via service + result: Dict[str, Any] = collection_service.delete_collection( + repository_id=repository_id, + collection_id=collection_id, # None for default collections + embedding_name=embedding_name if is_default_collection else None, # None for regular collections + username=username, + user_groups=groups, + is_admin=is_admin, + ) + + return result + + +@api_wrapper +def list_collections(event: dict, context: dict) -> Dict[str, Any]: + """ + List collections in a repository with pagination, filtering, and sorting. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The parent repository ID + - queryStringParameters.page: Page number (optional, default: 1) + - queryStringParameters.pageSize: Items per page (optional, default: 20, max: 100) + - queryStringParameters.filter: Text filter for name/description (optional) + - queryStringParameters.status: Status filter (ACTIVE, ARCHIVED, DELETED) (optional) + - queryStringParameters.sortBy: Sort field (name, createdAt, updatedAt) (optional, default: createdAt) + - queryStringParameters.sortOrder: Sort order (asc, desc) (optional, default: desc) + - queryStringParameters.lastEvaluatedKey*: Pagination token fields (optional) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing: + - collections: List of collection configurations + - pagination: Pagination metadata (totalCount, currentPage, totalPages) + - lastEvaluatedKey: Pagination token for next page + - hasNextPage: Whether there are more pages + - hasPreviousPage: Whether there is a previous page + + Raises: + ValidationError: If validation fails or user lacks permission + HTTPException: If repository not found or access denied + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + + if not repository_id: + raise ValidationError("repositoryId is required") + + # Get user context + username, is_admin, groups = get_user_context(event) + + # Ensure repository exists and user has access + _ = get_repository(event, repository_id=repository_id) + + # Parse query parameters + query_params = event.get("queryStringParameters", {}) or {} + + # Parse pagination parameters using PaginationParams composition object + page_size = PaginationParams.parse_page_size(query_params, default=20, max_size=100) + + # Define key fields based on the potential DynamoDB indexes being used + # collectionId is always present, status and createdAt are optional depending on the index + key_fields = ["collectionId", "status", "createdAt"] + last_evaluated_key = PaginationParams.parse_last_evaluated_key(query_params, key_fields) + + # Parse filter parameters using FilterParams composition object + filter_params = FilterParams.from_query_params(query_params) + filter_text = filter_params.filter_text + status_filter = filter_params.status_filter + + # Parse sort parameters using SortParams composition object + # sort_params = SortParams.from_query_params(query_params) + # sort_by = sort_params.sort_by + # sort_order = sort_params.sort_order + + # List collections via service (includes access control filtering) + collections, next_key = collection_service.list_collections( + repository_id=repository_id, + username=username, + user_groups=groups, + is_admin=is_admin, + page_size=page_size, + last_evaluated_key=last_evaluated_key, + ) + + # Calculate pagination metadata + pagination_result = PaginationResult.from_keys( + original_key=last_evaluated_key, + returned_key=next_key, + ) + + # Get total count (optional - can be expensive for large datasets) + total_count = None + current_page = None + total_pages = None + + # Only calculate total count if no filters are applied (for performance) + if not filter_text and not status_filter: + try: + total_count = collection_service.count_collections(repository_id=repository_id) + + # Calculate page numbers if we have total count + if total_count is not None: + total_pages = (total_count + page_size - 1) // page_size + # Estimate current page based on whether we have a last_evaluated_key + current_page = 1 if not last_evaluated_key else None + except Exception as e: + logger.warning(f"Failed to get total count for repository {repository_id}: {e}") + + # Build response + response = { + "collections": [c.model_dump(mode="json") for c in collections if c is not None], + "pagination": { + "totalCount": total_count, + "currentPage": current_page, + "totalPages": total_pages, + }, + "lastEvaluatedKey": next_key, + "hasNextPage": pagination_result.has_next_page, + "hasPreviousPage": pagination_result.has_previous_page, + } + + return response + + +@api_wrapper +def list_user_collections(event: dict, context: dict) -> Dict[str, Any]: + """ + List all collections user has access to across all repositories. + + Args: + event (dict): The Lambda event object containing: + - queryStringParameters.pageSize: Items per page (optional, default: 20, max: 100) + - queryStringParameters.filter: Text filter for name/description (optional) + - queryStringParameters.sortBy: Sort field (name, createdAt, updatedAt) (optional, default: createdAt) + - queryStringParameters.sortOrder: Sort order (asc, desc) (optional, default: desc) + - queryStringParameters.lastEvaluatedKey: Pagination token (optional, JSON string) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: A dictionary containing: + - collections: List of collection configurations with repositoryName + - pagination: Pagination metadata + - lastEvaluatedKey: Pagination token for next page + - hasNextPage: Whether there are more pages + - hasPreviousPage: Whether there is a previous page + + Raises: + ValidationError: If validation fails + HTTPException: If authentication fails + """ + # Get user context + username, is_admin, groups = get_user_context(event) + logger.info(f"list_user_collections called by user={username}, is_admin={is_admin}") + + # Parse query parameters + query_params = event.get("queryStringParameters", {}) or {} + + # Parse pagination parameters + page_size = PaginationParams.parse_page_size(query_params, default=20, max_size=100) + + # Parse pagination token + pagination_token = None + if "lastEvaluatedKey" in query_params: + try: + pagination_token = json.loads(query_params["lastEvaluatedKey"]) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Failed to parse pagination token: {e}") + # Continue without token (start from beginning) + + # Parse filter parameters + filter_params = FilterParams.from_query_params(query_params) + filter_text = filter_params.filter_text + + # Parse sort parameters + sort_params = SortParams.from_query_params(query_params) + + # List collections via service + collections, next_token = collection_service.list_all_user_collections( + username=username, + user_groups=groups, + is_admin=is_admin, + page_size=page_size, + pagination_token=pagination_token, + filter_text=filter_text, + sort_params=sort_params, + ) + + # Calculate pagination metadata + has_next_page = next_token is not None + has_previous_page = pagination_token is not None + + # Encode next token as JSON string if present + encoded_next_token = None + if next_token: + try: + encoded_next_token = json.dumps(next_token) + except Exception as e: + logger.error(f"Failed to encode pagination token: {e}") + + # Build response + response = { + "collections": collections, + "pagination": { + "totalCount": None, # Not calculated for cross-repository queries + "currentPage": None, + "totalPages": None, + }, + "lastEvaluatedKey": encoded_next_token, + "hasNextPage": has_next_page, + "hasPreviousPage": has_previous_page, + } + + logger.info(f"Returning {len(collections)} collections, hasNextPage={has_next_page}") + return response + + +def _ensure_document_ownership(event: dict[str, Any], docs: list[RagDocument]) -> None: """Verify ownership of documents""" username = get_username(event) if is_admin(event) is False: for doc in docs: - if not (doc.get("username") == username): - raise ValueError(f"Document {doc.get('document_id')} is not owned by {username}") + if not (doc.username == username): + raise ValueError(f"Document {doc.document_id} is not owned by {username}") @api_wrapper @@ -204,7 +815,7 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: Args: event (dict): The Lambda event object containing: - pathParameters.repositoryId: The repository id of VectorStore - - queryStringParameters.collectionId: The collection identifier + - queryStringParameters.collectionId: The collection ID - queryStringParameters.repositoryType: Type of repository of VectorStore - queryStringParameters.documentIds (optional): Array of document IDs to purge - queryStringParameters.documentName (optional): Name of document to purge @@ -222,11 +833,14 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: repository_id = path_params.get("repositoryId") query_string_params = event.get("queryStringParameters", {}) or {} collection_id = query_string_params.get("collectionId", None) + body = json.loads(event.get("body", "")) document_ids = body.get("documentIds", None) if not document_ids: raise ValidationError("No 'documentIds' parameter supplied") + if not repository_id: + raise ValidationError("repositoryId is required") # Ensure repo access _ = get_repository(event, repository_id=repository_id) @@ -244,6 +858,7 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: # delete s3 files if doc_repo.delete_s3_docs(repository_id, rag_documents) + jobs = [] for rag_document in rag_documents: logger.info(f"Deleting document {rag_document.model_dump()}") @@ -254,6 +869,7 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: document_id=rag_document.document_id, repository_id=rag_document.repository_id, collection_id=rag_document.collection_id, + embedding_model=rag_document.collection_id, # Not needed for deletion chunk_strategy=None, s3_path=rag_document.source, username=rag_document.username, @@ -264,61 +880,175 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: ingestion_service.create_delete_job(ingestion_job) logger.info(f"Deleting document {rag_document.source} for repository {rag_document.repository_id}") - return {"documentIds": document_ids} + jobs.append( + { + "jobId": ingestion_job.id, + "documentId": ingestion_job.document_id, + "status": ingestion_job.status, + "s3Path": ingestion_job.s3_path, + } + ) + return {"jobs": jobs} -@api_wrapper -def ingest_documents(event: dict, context: dict) -> dict: - """Ingest documents into the RAG repository. + +def handle_deprecated_chunking_strategy(request: IngestDocumentRequest, query_params: dict) -> None: + """Handle deprecated chunkSize and chunkOverlap query parameters. + + This function provides backward compatibility by migrating legacy query parameters + to the new chunkingStrategy format. It logs deprecation warnings to encourage + migration to the new API format. Args: - event (dict): The Lambda event object containing: - - body.embeddingModel.modelName: Document collection id - - body.keys: List of s3 keys to ingest - - pathParameters.repositoryId: Repository id (VectorStore) - - queryStringParameters.repositoryType: Repository type (VectorStore) - - queryStringParameters.chunkSize (optional): Size of text chunks - - queryStringParameters.chunkOverlap (optional): Overlap between chunks - context (dict): The Lambda context object + request: The IngestDocumentRequest object to potentially modify + query_params: Query string parameters from the HTTP request - Returns: - dict: A dictionary containing: - - ids (list): List of generated document IDs - - count (int): Total number of documents ingested + Side Effects: + - Logs deprecation warning if legacy parameters are detected + - Modifies request.chunkingStrategy if legacy parameters are present + and chunkingStrategy is not already set - Raises: - ValidationError: If required parameters are missing or invalid + Deprecated Parameters: + - chunkSize: Size of each chunk (use chunkingStrategy.size instead) + - chunkOverlap: Overlap between chunks (use chunkingStrategy.overlap instead) """ + if "chunkSize" in query_params or "chunkOverlap" in query_params: + logger.warning( + "DEPRECATION WARNING: Query parameters 'chunkSize' and 'chunkOverlap' are deprecated. " + "Please use the 'chunkingStrategy' object in the request body instead. " + "Legacy parameters will be removed in a future version." + ) + + # Migrate legacy parameters to new format if chunkingStrategy not provided + if not request.chunkingStrategy: + chunk_size = int(query_params.get("chunkSize", 512)) + chunk_overlap = int(query_params.get("chunkOverlap", 51)) + + # Create chunkingStrategy from legacy parameters + request.chunkingStrategy = {"type": "fixed", "size": chunk_size, "overlap": chunk_overlap} + logger.info( + f"Migrated legacy parameters to chunkingStrategy: " f"size={chunk_size}, overlap={chunk_overlap}" + ) + if "collectionId" in query_params: + request.collectionId = query_params.get("collectionId") + + +@api_wrapper +def ingest_documents(event: dict, context: dict) -> dict: + """Ingest documents into the RAG repository.""" body = json.loads(event["body"]) - embedding_model = body["embeddingModel"] - model_name = embedding_model["modelName"] + request = IngestDocumentRequest(**body) + repository_id = event.get("pathParameters", {}).get("repositoryId") + query_params = event.get("queryStringParameters", {}) or {} bucket = os.environ["BUCKET_NAME"] - path_params = event.get("pathParameters", {}) - repository_id = path_params.get("repositoryId") - - query_string_params = event["queryStringParameters"] - chunk_size = int(query_string_params["chunkSize"]) if "chunkSize" in query_string_params else None - chunk_overlap = int(query_string_params["chunkOverlap"]) if "chunkOverlap" in query_string_params else None - logger.info(f"using repository {repository_id}") + # Handle deprecated chunking parameters + handle_deprecated_chunking_strategy(request, query_params) - username = get_username(event) - _ = get_repository(event, repository_id=repository_id) + username, is_admin, groups = get_user_context(event) + repository = get_repository(event, repository_id=repository_id) - ingestion_document_ids = [] - for key in body["keys"]: - job = IngestionJob( + # Get collection if specified + collection: Optional[dict[str, Any]] = None + if request.collectionId: + collection = collection_service.get_collection( + collection_id=request.collectionId, repository_id=repository_id, - collection_id=model_name, - chunk_strategy=FixedChunkingStrategy(size=str(chunk_size), overlap=str(chunk_overlap)), + username=username, + user_groups=groups, + is_admin=is_admin, + ).model_dump() + + # For Bedrock KB repositories, upload metadata files BEFORE documents + is_bedrock_kb = RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB) + if is_bedrock_kb: + + metadata_generator = MetadataGenerator() + s3_metadata_manager = S3MetadataManager() + + # Get collection object for metadata generation + collection_obj = None + if request.collectionId: + try: + collection_obj = collection_service.get_collection( + collection_id=request.collectionId, + repository_id=repository_id, + username=username, + user_groups=groups, + is_admin=is_admin, + ) + except Exception as e: + logger.warning(f"Could not fetch collection for metadata: {e}") + + # Upload metadata files first + for key in request.keys: + try: + # Generate metadata content + metadata_content = metadata_generator.generate_metadata_json( + repository=repository, collection=collection_obj, document_metadata=None + ) + + # Upload metadata file + s3_metadata_manager.upload_metadata_file( + s3_client=s3, bucket=bucket, document_key=key, metadata_content=metadata_content + ) + logger.info(f"Uploaded metadata file for {key}") + except Exception as e: + logger.error(f"Failed to upload metadata file for {key}: {e}") + # Continue with document upload even if metadata fails + + # Create jobs + jobs = [] + for key in request.keys: + job = ingestion_service.create_ingestion_job( + repository=repository, + collection=collection, + request=request, + query_params=query_params, s3_path=f"s3://{bucket}/{key}", username=username, + metadata=None, + ingestion_type=IngestionType.MANUAL, ) ingestion_job_repository.save(job) - ingestion_service.create_ingest_job(job) - ingestion_document_ids.append(job.id) + ingestion_service.submit_create_job(job) + jobs.append({"jobId": job.id, "documentId": job.document_id, "status": job.status, "s3Path": job.s3_path}) + + collection_id = job.collection_id + collection_name: Optional[str] = None + if collection: + collection_name = collection.get("name") + if not collection_name: + collection_name = collection_id + result: dict[str, Any] = {"jobs": jobs, "collectionId": collection_id, "collectionName": collection_name} + return result + + +@api_wrapper +def get_document(event: dict, context: dict) -> Dict[str, Any]: + """Get a document by ID. + + Args: + event (dict): The Lambda event object containing: + path_params: + repositoryId - the repository + documentId - the document - return {"ingestionJobIds": ingestion_document_ids} + Returns: + dict: The document object + """ + path_params = event.get("pathParameters", {}) or {} + repository_id = path_params.get("repositoryId") + document_id = path_params.get("documentId") + if repository_id is None or document_id is None: + raise ValidationError("Must set the repositoryId and documentId") + if not isinstance(repository_id, str): + raise ValidationError("repositoryId must be a string") + _ = get_repository(event, repository_id=repository_id) + doc = doc_repo.find_by_id(document_id=document_id) + + result: dict[str, Any] = doc.model_dump() + return result @api_wrapper @@ -340,6 +1070,8 @@ def download_document(event: dict, context: dict) -> str: repository_id = path_params.get("repositoryId") document_id = path_params.get("documentId") + if not repository_id: + raise ValidationError("repositoryId is required") _ = get_repository(event, repository_id=repository_id) doc = doc_repo.find_by_id(document_id=document_id) @@ -404,7 +1136,7 @@ def list_docs(event: dict, context: dict) -> dict[str, Any]: Args: event (dict): The Lambda event object containing query parameters - pathParameters.repositoryId: The repository id to list documents for - - queryStringParameters.collectionId: The collection id to list documents for + - queryStringParameters.collectionName: The collection name to list documents for context (dict): The Lambda context object Returns: @@ -419,8 +1151,14 @@ def list_docs(event: dict, context: dict) -> dict[str, Any]: query_string_params = event.get("queryStringParameters", {}) or {} collection_id = query_string_params.get("collectionId") + last_evaluated: Optional[dict[str, Optional[str]]] = None + if not repository_id: + raise ValidationError("repositoryId is required") + # Validate repository access + _ = get_repository(event, repository_id=repository_id) + if "lastEvaluatedKeyPk" in query_string_params: last_evaluated = { "pk": ( @@ -472,11 +1210,13 @@ def list_jobs(event: Dict[str, Any], context: dict) -> Dict[str, Any]: # Extract and validate parameters params = ListJobsParams.from_event(event) + if not params.repository_id: + raise ValidationError("repositoryId is required") # Validate repository access _ = get_repository(event, repository_id=params.repository_id) # Get user context - username, is_admin_user = get_user_context(event) + username, is_admin_user, _ = get_user_context(event) # Fetch jobs from repository jobs, returned_last_evaluated_key = ingestion_job_repository.list_jobs_by_repository( @@ -500,7 +1240,8 @@ def list_jobs(event: Dict[str, Any], context: dict) -> Dict[str, Any]: hasPreviousPage=pagination.has_previous_page, ) - return response.model_dump() + result: dict[str, Any] = response.model_dump() + return result @api_wrapper @@ -509,9 +1250,12 @@ def create(event: dict, context: dict) -> Any: """ Create a new process execution using AWS Step Functions. This function is only accessible by administrators. + For Bedrock Knowledge Base repositories, automatically adds a default pipeline configuration + if none is provided, using the datasource S3 bucket for event-driven ingestion. + Args: event (dict): The Lambda event object containing: - - body: A JSON string with the process creation details. + - body: A JSON string with the process creation details containing VectorStoreConfig. context (dict): The Lambda context object. Returns: @@ -521,13 +1265,43 @@ def create(event: dict, context: dict) -> Any: Raises: ValueError: If the user is not an administrator. + ValidationError: If the request body is invalid. """ # Fetch the Step Function ARN from SSM Parameter Store parameter_name = os.environ["LISA_RAG_CREATE_STATE_MACHINE_ARN_PARAMETER"] state_machine_arn = ssm_client.get_parameter(Name=parameter_name) - # Deserialize the event body and prepare input for Step Functions - input_data = json.loads(event["body"]) + # Deserialize the event body and parse as VectorStoreConfig + try: + body = json.loads(event["body"]) + vector_store_config = VectorStoreConfig(**body) + except json.JSONDecodeError as e: + raise ValidationError(f"Invalid JSON in request body: {e}") + except Exception as e: + raise ValidationError(f"Invalid VectorStoreConfig: {e}") + + # Auto-convert Bedrock KB config to pipelines + if vector_store_config.type == RepositoryType.BEDROCK_KB: + if not vector_store_config.bedrockKnowledgeBaseConfig: + raise ValidationError("Bedrock Knowledge Base configuration is required") + + if ( + not vector_store_config.bedrockKnowledgeBaseConfig.dataSources + or len(vector_store_config.bedrockKnowledgeBaseConfig.dataSources) == 0 + ): + raise ValidationError( + "Bedrock Knowledge Base repositories require at least one data source. " + "Please select at least one data source." + ) + # Convert bedrockKnowledgeBaseConfig to pipelines + vector_store_config.pipelines = build_pipeline_configs_from_kb_config( + vector_store_config.bedrockKnowledgeBaseConfig + ) + + # Convert to dictionary for Step Functions input + rag_config = vector_store_config.model_dump(mode="json", exclude_none=True) + input_data = {"ragConfig": rag_config} + serializer = TypeSerializer() # Start Step Function execution @@ -536,7 +1310,7 @@ def create(event: dict, context: dict) -> Any: input=json.dumps( { "body": input_data, - "config": {key: serializer.serialize(value) for key, value in input_data["ragConfig"].items()}, + "config": {key: serializer.serialize(value) for key, value in rag_config.items()}, } ), ) @@ -545,12 +1319,171 @@ def create(event: dict, context: dict) -> Any: return {"status": "success", "executionArn": response["executionArn"]} +@api_wrapper +def get_repository_by_id(event: dict, context: dict) -> Dict[str, Any]: + """ + Get a vector store configuration by ID. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The repository ID to retrieve + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: The repository configuration with default values for new fields + + Raises: + ValidationError: If repositoryId is missing + HTTPException: If repository not found or access denied + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + + if not repository_id: + raise ValidationError("repositoryId is required") + + # Get repository and check access + repository: dict[str, Any] = get_repository(event, repository_id) + + return repository + + +@api_wrapper +@admin_only +def update_repository(event: dict, context: dict) -> Dict[str, Any]: + """ + Update a vector store configuration. This function is only accessible by administrators. + + If the pipeline configuration has changed, this will trigger an infrastructure deployment + using the state machine, similar to repository creation. + + Args: + event (dict): The Lambda event object containing: + - pathParameters.repositoryId: The repository ID to update + - body: JSON with fields to update (UpdateVectorStoreRequest) + context (dict): The Lambda context object + + Returns: + Dict[str, Any]: The updated repository configuration with executionArn if deployment triggered + + Raises: + ValidationError: If validation fails + HTTPException: If repository not found + """ + # Extract path parameters + path_params = event.get("pathParameters", {}) + repository_id = path_params.get("repositoryId") + + if not repository_id: + raise ValidationError("repositoryId is required") + + # Parse request body + try: + body = json.loads(event.get("body", {})) + request = UpdateVectorStoreRequest(**body) + except json.JSONDecodeError as e: + raise ValidationError(f"Invalid JSON in request body: {e}") + except Exception as e: + raise ValidationError(f"Invalid request: {e}") + + # Get current repository configuration to check for pipeline changes + current_repo = vs_repo.find_repository_by_id(repository_id, raw_config=True) + current_config = current_repo.get("config", {}) + current_pipelines = current_config.get("pipelines") + + # Build updates dictionary (only include fields that were provided) + updates = request.model_dump(exclude_none=True, mode="json") + + # Convert bedrockKnowledgeBaseConfig to pipelines for Bedrock KB repositories + repository_type = current_config.get("type") + if ( + repository_type == RepositoryType.BEDROCK_KB + and hasattr(request, "bedrockKnowledgeBaseConfig") + and request.bedrockKnowledgeBaseConfig is not None + ): + # Validate at least one data source + if ( + not request.bedrockKnowledgeBaseConfig.dataSources + or len(request.bedrockKnowledgeBaseConfig.dataSources) == 0 + ): + raise ValidationError( + "Bedrock Knowledge Base repositories require at least one collection. " + "Please select at least one data source." + ) + # Convert bedrockKnowledgeBaseConfig to pipelines + updates["pipelines"] = build_pipeline_configs_from_kb_config(request.bedrockKnowledgeBaseConfig) + logger.info(f"Converted {len(request.bedrockKnowledgeBaseConfig.dataSources)} data sources to pipeline configs") + + # Check if pipeline configuration has changed + # Use the converted pipelines from updates if available, otherwise use request.pipelines + new_pipelines = updates.get("pipelines") if "pipelines" in updates else request.pipelines + require_deployment = False + + if new_pipelines is not None: + # For Bedrock KB repositories, only check if data source IDs (collectionIds) have changed + if repository_type == RepositoryType.BEDROCK_KB: + current_collection_ids = {p.get("collectionId") for p in (current_pipelines or []) if p.get("collectionId")} + new_collection_ids = {p.get("collectionId") for p in new_pipelines if p.get("collectionId")} + + if current_collection_ids != new_collection_ids: + added = new_collection_ids - current_collection_ids + removed = current_collection_ids - new_collection_ids + logger.info(f"Bedrock KB data sources changed: added={list(added)}, removed={list(removed)}") + require_deployment = True + else: + logger.info("Bedrock KB data sources unchanged, no deployment needed") + else: + # For other repository types, compare full pipeline configs + require_deployment = new_pipelines != current_pipelines + + # Set status based on deployment requirement + status = VectorStoreStatus.UPDATE_IN_PROGRESS if require_deployment else VectorStoreStatus.UPDATE_COMPLETE + + # Update repository + updated_config: dict[str, Any] = vs_repo.update(repository_id, updates, status=status) + + # Trigger infrastructure deployment if pipeline changed + if require_deployment: + logger.info(f"Pipeline configuration changed for repository {repository_id}, triggering deployment") + + # Fetch the Step Function ARN from SSM Parameter Store + parameter_name = os.environ["LISA_RAG_CREATE_STATE_MACHINE_ARN_PARAMETER"] + state_machine_arn = ssm_client.get_parameter(Name=parameter_name) + + # Prepare input data for state machine (similar to create) + serializer = TypeSerializer() + rag_config = updated_config.copy() + rag_config["repositoryId"] = repository_id + # Remove status field - it will be set by the state machine + rag_config.pop("status", None) + + input_data = {"ragConfig": rag_config} + + # Start Step Function execution + response = step_functions_client.start_execution( + stateMachineArn=state_machine_arn["Parameter"]["Value"], + input=json.dumps( + { + "body": input_data, + "config": {key: serializer.serialize(value) for key, value in rag_config.items()}, + } + ), + ) + + logger.info(f"Started state machine execution: {response['executionArn']}") + updated_config["executionArn"] = response["executionArn"] + + return updated_config + + @api_wrapper @admin_only def delete(event: dict, context: dict) -> Any: """ Delete a vector store process using AWS Step Functions. This function ensures that the user is an administrator or owns the vector store being deleted. + Also deletes all associated collections and their documents. Args: event (dict): The Lambda event object containing: @@ -572,6 +1505,36 @@ def delete(event: dict, context: dict) -> Any: raise ValidationError("repositoryId is required") repository = vs_repo.find_repository_by_id(repository_id=repository_id, raw_config=True) + + # Delete all collections associated with this repository + try: + logger.info(f"Deleting all collections for repository: {repository_id}") + collections, _ = collection_service.list_collections( + repository_id=repository_id, + username="admin", + user_groups=[], + is_admin=True, + page_size=1000, # Get all collections + ) + + for collection in collections: + try: + logger.info(f"Deleting collection: {collection.collectionId}") + collection_service.delete_collection( + collection_id=collection.collectionId, + repository_id=repository_id, + embedding_name=collection.embeddingModel if collection.default else None, + username="admin", + user_groups=[], + is_admin=True, + ) + except Exception as e: + logger.error(f"Error deleting collection {collection.collectionId}: {str(e)}") + # Continue with other collections even if one fails + except Exception as e: + logger.error(f"Error listing/deleting collections for repository {repository_id}: {str(e)}") + # Continue with repository deletion even if collection cleanup fails + if repository.get("legacy", False) is True: _remove_legacy(repository_id) vs_repo.delete(repository_id=repository_id) @@ -591,48 +1554,6 @@ def delete(event: dict, context: dict) -> Any: return {"status": "success", "executionArn": response["executionArn"]} -@api_wrapper -@admin_only -def delete_index(event: dict, context: dict) -> None: - """ - Clear the vector store for the specified repository and model. - - Args: - event (dict): The Lambda event object containing path parameters - context (dict): The Lambda context object - """ - path_params = event.get("pathParameters", {}) or {} - repository_id = path_params.get("repositoryId", None) - if not repository_id: - raise ValidationError("repositoryId is required") - model_name = path_params.get("modelName", None) - if not model_name: - raise ValidationError("modelName is required") - - repository = vs_repo.find_repository_by_id(repository_id=repository_id) - id_token = get_id_token(event) - embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) - vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) - - try: - if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH): - if vs.client.indices.exists(index=model_name): - vs.client.indices.delete(index=model_name) - logger.info(f"Deleted OpenSearch index: {model_name}") - else: - logger.info(f"OpenSearch index {model_name} does not exist") - elif RepositoryType.is_type(repository, RepositoryType.PGVECTOR): - # For PGVector, delete all documents in the collection - vs.delete_collection() - logger.info(f"Deleted PGVector collection: {model_name}") - else: - logger.error(f"Unsupported repository type: {repository.get('type')}") - return {"status": "error", "message": "Repository is not supported"} - except Exception as e: - logger.error(f"Failed to clear vector store: {e}") - return {"status": "error", "message": str(e)} - - def _remove_legacy(repository_id: str) -> None: registered_repositories = ssm_client.get_parameter(Name=os.environ["REGISTERED_REPOSITORIES_PS"]) registered_repositories = json.loads(registered_repositories["Parameter"]["Value"]) @@ -648,57 +1569,113 @@ def _remove_legacy(repository_id: str) -> None: ) -def _similarity_search(vs, query: str, top_k: int) -> list[dict[str, Any]]: - """Perform similarity search without scores. +@api_wrapper +def list_bedrock_knowledge_bases(event: dict, context: dict) -> Dict[str, Any]: + """ + List all ACTIVE Bedrock Knowledge Bases in the AWS account. + + Marks KBs as unavailable if they're already associated with a repository. Args: - vs: Vector store instance - query: Search query string - top_k: Number of top results to return + event: Lambda event + context: Lambda context Returns: - List of documents with page_content and metadata + Dictionary with: + - knowledgeBases: List of ACTIVE Knowledge Base metadata with availability status + - totalKnowledgeBases: Count of ACTIVE KBs + + Raises: + ValidationError: If discovery fails """ - results = vs.similarity_search_with_score( - query, - k=top_k, + logger.info("Listing all ACTIVE Knowledge Bases") + + # Create bedrock-agent client + bedrock_agent_client = boto3.client("bedrock-agent", region_name, config=retry_config) + + # Get all knowledge bases and filter to ACTIVE only + all_kbs = list_knowledge_bases(bedrock_agent_client) + active_kbs = [kb for kb in all_kbs if kb.status == "ACTIVE"] + + # Get all existing repositories to check which KBs are already in use + existing_repos = vs_repo.get_registered_repositories() + used_kb_ids = set() + + for repo in existing_repos: + config = repo.get("config", {}) + bedrock_config = config.get("bedrockKnowledgeBaseConfig") + if bedrock_config and isinstance(bedrock_config, dict): + kb_id = bedrock_config.get("knowledgeBaseId") + if kb_id: + used_kb_ids.add(kb_id) + + # Convert to dictionaries and mark KBs as available or unavailable + kb_list = [] + for kb in active_kbs: + kb_dict = kb.model_dump(mode="json") + kb_dict["available"] = kb.knowledgeBaseId not in used_kb_ids + if not kb_dict["available"]: + kb_dict["unavailableReason"] = "Already associated with another repository" + kb_list.append(kb_dict) + + logger.info( + f"Found {len(active_kbs)} ACTIVE Knowledge Bases out of {len(all_kbs)} total, " + f"{len(used_kb_ids)} already in use" ) - return [{"page_content": doc.page_content, "metadata": doc.metadata} for doc, score in results] + return {"knowledgeBases": kb_list, "totalKnowledgeBases": len(kb_list)} -def _similarity_search_with_score(vs, query: str, top_k: int, repository: dict) -> list[dict[str, Any]]: - """Perform similarity search with normalized scores. +@api_wrapper +def list_bedrock_data_sources(event: dict, context: dict) -> Dict[str, Any]: + """ + List data sources for a specific Bedrock Knowledge Base. Args: - vs: Vector store instance - query: Search query string - top_k: Number of top results to return - repository: Repository configuration dict + event: Lambda event containing: + - pathParameters.kbId: Knowledge Base ID + - queryStringParameters.repositoryId (optional): Repository ID to check managed data sources + - queryStringParameters.refresh (optional): Force refresh cache (default: false) + context: Lambda context Returns: - List of documents with page_content, metadata, and similarity_score + Dictionary with: + - knowledgeBase: KB metadata (id, name, description) + - availableDataSources: Data sources not yet managed + - managedDataSources: Data sources already managed by collections + - totalDataSources: Total count + + Raises: + ValidationError: If KB not found or discovery fails """ - results = vs.similarity_search_with_score( - query, - k=top_k, - ) - docs = [] - for i, (doc, score) in enumerate(results): - similarity_score = RepositoryType.get_type(repository=repository).calculate_similarity_score(score) - logger.info( - f"Result {i + 1}: Raw Score={score:.4f}, Similarity={similarity_score:.4f}, " - + f"Content: {doc.page_content[:200]}..." - ) - logger.info(f"Result {i + 1} metadata: {doc.metadata}") - docs.append( - { - "page_content": doc.page_content, - "metadata": {**doc.metadata, "similarity_score": similarity_score}, - } - ) + path_params = event.get("pathParameters", {}) + query_params = event.get("queryStringParameters") or {} + + kb_id = path_params.get("kbId") + if not kb_id: + raise ValidationError("kbId is required") + + repository_id = query_params.get("repositoryId") + + logger.info(f"Listing data sources for KB {kb_id}, repository={repository_id}") - if results and max(score for _, score in results) < 0.3: - logger.warning(f"All similarity < 0.3 for query '{query}' - possible embedding model mismatch") + # Create bedrock-agent client + bedrock_agent_client = boto3.client("bedrock-agent", region_name, config=retry_config) - return docs + # Validate KB exists and get metadata + kb_config = validate_bedrock_kb_exists(kb_id, bedrock_agent_client) + + # Get available and managed data sources + data_sources = get_available_data_sources( + kb_id=kb_id, + repository_id=repository_id, + bedrock_agent_client=bedrock_agent_client, + ) + + return { + "knowledgeBase": { + "id": kb_id, + "name": kb_config.get("name"), + }, + "dataSources": [ds.model_dump(mode="json") for ds in data_sources], + } diff --git a/lambda/repository/metadata_generator.py b/lambda/repository/metadata_generator.py new file mode 100644 index 000000000..4a6bc1a02 --- /dev/null +++ b/lambda/repository/metadata_generator.py @@ -0,0 +1,333 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metadata generator for Bedrock KB documents.""" + +import json +import logging +import re +import time +from typing import Any, Dict, Optional + +import boto3 +from models.domain_objects import CollectionMetadata, RagCollectionConfig +from utilities.validation import ValidationError + +logger = logging.getLogger(__name__) + +# Bedrock KB metadata limits +MAX_METADATA_SIZE_BYTES = 10240 # 10KB +MAX_METADATA_KEY_LENGTH = 100 +MAX_METADATA_VALUE_LENGTH = 1000 + +# Reserved Bedrock KB field names +RESERVED_FIELDS = { + "x-amz-bedrock-kb-source-uri", + "x-amz-bedrock-kb-data-source-id", + "x-amz-bedrock-kb-chunk-id", +} + + +class MetadataGenerator: + """Generator for Bedrock KB metadata files.""" + + def __init__(self, cloudwatch_client=None): + """Initialize metadata generator with caching. + + Args: + cloudwatch_client: Optional CloudWatch client for metrics (defaults to creating one) + """ + self._collection_cache: Dict[str, Dict[str, Any]] = {} + self._cache_ttl = 300 # 5 minutes + self.cloudwatch_client = cloudwatch_client or boto3.client("cloudwatch") + + def _emit_metric( + self, + metric_name: str, + value: float = 1.0, + repository_id: Optional[str] = None, + collection_id: Optional[str] = None, + ) -> None: + """Emit CloudWatch metric. + + Args: + metric_name: Name of the metric + value: Metric value + repository_id: Optional repository ID for dimensions + collection_id: Optional collection ID for dimensions + """ + try: + dimensions = [] + if repository_id: + dimensions.append({"Name": "RepositoryId", "Value": repository_id}) + if collection_id: + dimensions.append({"Name": "CollectionId", "Value": collection_id}) + + metric_data = { + "MetricName": metric_name, + "Value": value, + "Unit": "Count", + } + + if dimensions: + metric_data["Dimensions"] = dimensions + + self.cloudwatch_client.put_metric_data(Namespace="LISA/BedrockKB", MetricData=[metric_data]) + except Exception as e: + logger.warning(f"Failed to emit CloudWatch metric {metric_name}: {e}") + + def generate_metadata_json( + self, + repository: Dict[str, Any], + collection: Optional[RagCollectionConfig], + document_metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Generate metadata.json content for Bedrock KB. + + Merges metadata from three sources with precedence: + 1. Repository metadata (lowest) + 2. Collection metadata (medium) + 3. Document metadata (highest) + + Args: + repository: Repository configuration dictionary + collection: Collection configuration (optional) + document_metadata: Document-specific metadata (optional) + + Returns: + Dictionary with metadataAttributes structure for Bedrock KB + + Raises: + ValidationError: If metadata validation fails + """ + # Start with empty metadata + merged_metadata: Dict[str, Any] = {} + + # Merge repository metadata + repo_metadata = repository.get("metadata") + if repo_metadata: + if isinstance(repo_metadata, dict): + # Add tags as individual fields + for tag in repo_metadata.get("tags", []): + merged_metadata[f"tag_{tag}"] = True + elif hasattr(repo_metadata, "tags"): + for tag in repo_metadata.tags: + merged_metadata[f"tag_{tag}"] = True + + # Merge collection metadata + if collection and collection.metadata: + coll_metadata = collection.metadata + if isinstance(coll_metadata, CollectionMetadata): + # Add tags as individual fields + for tag in coll_metadata.tags: + merged_metadata[f"tag_{tag}"] = True + elif isinstance(coll_metadata, dict): + for tag in coll_metadata.get("tags", []): + merged_metadata[f"tag_{tag}"] = True + + # Add collection identifiers + if collection: + merged_metadata["collectionId"] = collection.collectionId + merged_metadata["collectionName"] = collection.name or collection.collectionId + + # Add repository identifier + merged_metadata["repositoryId"] = repository.get("repositoryId", "") + + # Merge document-specific metadata (highest precedence) + if document_metadata: + merged_metadata.update(document_metadata) + + # Create comma-separated tags from all tag_ fields + tag_keys = [key[4:] for key in merged_metadata.keys() if key.startswith("tag_")] + if tag_keys: + merged_metadata["tags"] = ",".join(sorted(tag_keys)) + + # Validate merged metadata + self.validate_metadata(merged_metadata) + + # Return in Bedrock KB format + return {"metadataAttributes": merged_metadata} + + def validate_metadata( + self, + metadata: Dict[str, Any], + repository_id: Optional[str] = None, + collection_id: Optional[str] = None, + ) -> bool: + """Validate metadata against Bedrock KB requirements. + + Args: + metadata: Metadata dictionary to validate + repository_id: Optional repository ID for metrics + collection_id: Optional collection ID for metrics + + Returns: + True if valid + + Raises: + ValidationError: If validation fails + """ + try: + # Check total size + metadata_json = json.dumps(metadata) + metadata_size = len(metadata_json.encode("utf-8")) + if metadata_size > MAX_METADATA_SIZE_BYTES: + self._emit_metric("MetadataValidationFailed", 1.0, repository_id, collection_id) + raise ValidationError( + f"Metadata size ({metadata_size} bytes) exceeds limit ({MAX_METADATA_SIZE_BYTES} bytes)" + ) + + # Validate each key-value pair + for key, value in metadata.items(): + # Validate key + self._validate_metadata_key(key) + + # Validate value type and size + self._validate_metadata_value(key, value) + + return True + except ValidationError: + self._emit_metric("MetadataValidationFailed", 1.0, repository_id, collection_id) + raise + + def _validate_metadata_key(self, key: str) -> None: + """Validate metadata key. + + Args: + key: Metadata key to validate + + Raises: + ValidationError: If key is invalid + """ + # Check for reserved fields + if key in RESERVED_FIELDS: + raise ValidationError(f"Metadata key '{key}' is reserved by Bedrock KB") + + # Check length + if len(key) > MAX_METADATA_KEY_LENGTH: + raise ValidationError(f"Metadata key '{key}' exceeds maximum length of {MAX_METADATA_KEY_LENGTH}") + + # Check format (alphanumeric, underscore, hyphen only) + if not re.match(r"^[a-zA-Z0-9_-]+$", key): + raise ValidationError( + f"Metadata key '{key}' contains invalid characters. " + "Only alphanumeric, underscore, and hyphen are allowed" + ) + + def _validate_metadata_value(self, key: str, value: Any) -> None: + """Validate metadata value. + + Args: + key: Metadata key (for error messages) + value: Metadata value to validate + + Raises: + ValidationError: If value is invalid + """ + # Check type + if not isinstance(value, (str, int, float, bool, list)): + raise ValidationError( + f"Metadata value for key '{key}' has invalid type '{type(value).__name__}'. " + "Only string, number, boolean, and array are allowed" + ) + + # Check string length + if isinstance(value, str) and len(value) > MAX_METADATA_VALUE_LENGTH: + raise ValidationError( + f"Metadata value for key '{key}' exceeds maximum length of {MAX_METADATA_VALUE_LENGTH}" + ) + + # Check array elements + if isinstance(value, list): + for item in value: + if not isinstance(item, (str, int, float, bool)): + raise ValidationError( + f"Metadata array for key '{key}' contains invalid type '{type(item).__name__}'. " + "Array elements must be string, number, or boolean" + ) + if isinstance(item, str) and len(item) > MAX_METADATA_VALUE_LENGTH: + raise ValidationError( + f"Metadata array element for key '{key}' exceeds maximum length of {MAX_METADATA_VALUE_LENGTH}" + ) + + def get_metadata_s3_key(self, document_s3_key: str) -> str: + """Generate S3 key for metadata file. + + Args: + document_s3_key: S3 key of the document + + Returns: + S3 key for the metadata file (document_key + ".metadata.json") + """ + return f"{document_s3_key}.metadata.json" + + def get_collection_metadata_cached( + self, collection_id: str, repository_id: str, collection_repo + ) -> Optional[Dict[str, Any]]: + """Get collection metadata with caching. + + Args: + collection_id: Collection ID + repository_id: Repository ID + collection_repo: Collection repository instance + + Returns: + Collection metadata dictionary or None + """ + cache_key = f"{repository_id}#{collection_id}" + cached = self._collection_cache.get(cache_key) + + # Check cache + if cached and time.time() - cached["timestamp"] < self._cache_ttl: + logger.debug(f"Using cached metadata for collection {collection_id}") + return cached["metadata"] + + # Fetch from DynamoDB + try: + collection = collection_repo.find_by_id(collection_id, repository_id) + if collection and collection.metadata: + metadata = collection.metadata + if isinstance(metadata, CollectionMetadata): + # Flatten metadata for Bedrock KB compatibility + metadata_dict = {} + # Add tags as an array field + if metadata.tags: + metadata_dict["tags"] = metadata.tags + else: + metadata_dict = metadata + + # Cache result + self._collection_cache[cache_key] = {"metadata": metadata_dict, "timestamp": time.time()} + + logger.debug(f"Cached metadata for collection {collection_id}") + return metadata_dict + except Exception as e: + logger.warning(f"Failed to fetch collection metadata: {e}") + + return None + + def clear_cache(self, collection_id: Optional[str] = None, repository_id: Optional[str] = None) -> None: + """Clear metadata cache. + + Args: + collection_id: Specific collection to clear (optional) + repository_id: Specific repository to clear (optional) + """ + if collection_id and repository_id: + cache_key = f"{repository_id}#{collection_id}" + self._collection_cache.pop(cache_key, None) + logger.debug(f"Cleared cache for collection {collection_id}") + else: + self._collection_cache.clear() + logger.debug("Cleared all metadata cache") diff --git a/lambda/repository/pipeline_delete_documents.py b/lambda/repository/pipeline_delete_documents.py index 307305412..f5eb30da7 100644 --- a/lambda/repository/pipeline_delete_documents.py +++ b/lambda/repository/pipeline_delete_documents.py @@ -17,13 +17,16 @@ from typing import Any, Dict import boto3 -from models.domain_objects import IngestionJob, IngestionStatus, IngestionType +from boto3.dynamodb.conditions import Key +from models.domain_objects import CollectionStatus, IngestionJob, IngestionStatus, IngestionType, JobActionType +from repository.collection_repo import CollectionRepository from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.pipeline_ingest_documents import remove_document_from_vectorstore from repository.rag_document_repo import RagDocumentRepository +from repository.services.repository_service_factory import RepositoryServiceFactory from repository.vector_store_repo import VectorStoreRepository -from utilities.bedrock_kb import delete_document_from_kb +from utilities.bedrock_kb import bulk_delete_documents_from_kb, delete_document_from_kb from utilities.common_functions import retry_config from utilities.repository_types import RepositoryType @@ -36,9 +39,184 @@ s3 = boto3.client("s3", region_name=os.environ["AWS_REGION"], config=retry_config) bedrock_agent = boto3.client("bedrock-agent", region_name=os.environ["AWS_REGION"], config=retry_config) rag_document_repository = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) +collection_repo = CollectionRepository() + + +def drop_opensearch_index(repository_id: str, collection_id: str) -> None: + """Drop OpenSearch index using repository service. + + Args: + repository_id: Repository ID + collection_id: Collection ID + """ + try: + logger.info(f"Dropping OpenSearch index for collection {collection_id}") + + repository = vs_repo.find_repository_by_id(repository_id) + service = RepositoryServiceFactory.create_service(repository) + + # Delegate to service layer + service.delete_collection(collection_id, s3_client=s3) + + except Exception as e: + logger.error(f"Failed to drop OpenSearch index: {e}", exc_info=True) + # Don't raise - continue with document deletion even if index drop fails + + +def drop_pgvector_collection(repository_id: str, collection_id: str) -> None: + """Drop PGVector collection using repository service. + + Args: + repository_id: Repository ID + collection_id: Collection ID + """ + try: + logger.info(f"Dropping PGVector collection for {collection_id}") + + repository = vs_repo.find_repository_by_id(repository_id) + service = RepositoryServiceFactory.create_service(repository) + + # Delegate to service layer + service.delete_collection(collection_id, s3_client=s3) + + except Exception as e: + logger.error(f"Failed to drop PGVector collection: {e}", exc_info=True) + # Don't raise - continue with document deletion even if collection drop fails + + +def pipeline_delete_collection(job: IngestionJob) -> None: + """ + Delete all documents in a collection. + + Steps: + 1. Drop vector store index for collection (if supported) + 2. Delete all documents from DynamoDB (which also handles subdocuments) + 3. Update collection status to DELETED + + Note: Dropping the index removes all embeddings, so we don't need to + delete them individually from the vector store. + + Args: + job: Ingestion job with collection deletion details + """ + try: + logger.info(f"Deleting collection {job.collection_id} in repository {job.repository_id}") + + repository = vs_repo.find_repository_by_id(job.repository_id) + + # Drop index for faster cleanup (OpenSearch/PGVector) + # This removes all embeddings from the vector store + if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH): + drop_opensearch_index(job.repository_id, job.collection_id) + elif RepositoryType.is_type(repository, RepositoryType.PGVECTOR): + drop_pgvector_collection(job.repository_id, job.collection_id) + elif RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + # For Bedrock KB, use bulk delete for efficiency + # Only delete LISA-managed documents (MANUAL/AUTO), preserve EXISTING + logger.info("Bedrock KB collection - bulk deleting LISA-managed documents from knowledge base") + pk = f"{job.repository_id}#{job.collection_id}" + + dynamodb = boto3.resource("dynamodb") + doc_table = dynamodb.Table(os.environ["RAG_DOCUMENT_TABLE"]) + + response = doc_table.query(KeyConditionExpression=Key("pk").eq(pk)) + documents = response.get("Items", []) + + # Continue pagination if needed + while "LastEvaluatedKey" in response: + response = doc_table.query( + KeyConditionExpression=Key("pk").eq(pk), ExclusiveStartKey=response["LastEvaluatedKey"] + ) + documents.extend(response.get("Items", [])) + + logger.info(f"Found {len(documents)} total documents in collection") + + # Separate by ingestion type + lisa_managed = [doc for doc in documents if doc.get("ingestion_type") in ["manual", "auto"]] + user_managed = [doc for doc in documents if doc.get("ingestion_type") == "existing"] + + logger.info( + f"Collection {job.collection_id}: " + f"lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}" + ) + + # Extract S3 paths for LISA-managed documents only + s3_paths = [doc.get("source", "") for doc in lisa_managed if doc.get("source")] + + if s3_paths: + try: + bulk_delete_documents_from_kb( + s3_client=s3, + bedrock_agent_client=bedrock_agent, + repository=repository, + s3_paths=s3_paths, + data_source_id=job.collection_id, + ) + logger.info( + f"Successfully bulk deleted {len(s3_paths)} LISA-managed documents from KB, " + f"preserved {len(user_managed)} user-managed documents" + ) + except Exception as e: + logger.error(f"Failed to bulk delete documents from Bedrock KB: {e}") + # Continue with DynamoDB deletion even if KB deletion fails + else: + logger.info("No LISA-managed documents to delete from KB") + + # Delete all documents and subdocuments from DynamoDB + # This method handles pagination and batch deletion + logger.info(f"Deleting all documents from DynamoDB for collection {job.collection_id}") + rag_document_repository.delete_all(job.repository_id, job.collection_id) + logger.info("Successfully deleted all documents from DynamoDB") + + # Delete collection DB entry + is_default_collection = job.embedding_model is not None + if not is_default_collection: + collection_repo.delete(job.collection_id, job.repository_id) + + # Update job status + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_COMPLETED) + logger.info(f"Successfully deleted collection {job.collection_id}") + + except Exception as e: + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_FAILED) + logger.error(f"Failed to delete collection: {str(e)}", exc_info=True) + + # Update collection status to DELETE_FAILED + try: + collection_repo.update(job.collection_id, job.repository_id, {"status": CollectionStatus.DELETE_FAILED}) + except Exception as update_error: + logger.error(f"Failed to update collection status: {update_error}") + + raise def pipeline_delete(job: IngestionJob) -> None: + """ + Route deletion job to appropriate handler based on job type. + + Args: + job: Ingestion job with deletion details + """ + # Check job type and route accordingly + if job.job_type == JobActionType.COLLECTION_DELETION: + logger.info(f"Routing to collection deletion for job {job.id}") + pipeline_delete_collection(job) + elif job.job_type == JobActionType.DOCUMENT_BATCH_DELETION: + logger.info(f"Routing to batch document deletion for job {job.id}") + pipeline_delete_documents(job) + else: + # Default to single document deletion + logger.info(f"Routing to document deletion for job {job.id}") + pipeline_delete_document(job) + + +def pipeline_delete_document(job: IngestionJob) -> None: + """ + Delete a single document. + + Args: + job: Ingestion job with document deletion details + """ try: logger.info(f"Deleting document {job.s3_path} for repository {job.repository_id}") @@ -75,9 +253,126 @@ def pipeline_delete(job: IngestionJob) -> None: raise Exception(error_msg) -def handle_pipeline_delete_event(event: Dict[str, Any], context: Any) -> None: - """Handle pipeline document ingestion.""" +def pipeline_delete_documents(job: IngestionJob) -> None: + """ + Delete multiple documents in batch (up to 100 at a time). + + Processes documents from document_ids field containing list of document IDs. + + Args: + job: Ingestion job with batch deletion details + """ + try: + logger.info(f"Starting batch deletion for job {job.id}") + + # Extract document list from document_ids field + if not job.document_ids: + raise ValueError("Batch deletion job missing 'document_ids' field") + + document_ids = job.document_ids + if not isinstance(document_ids, list): + raise ValueError("'document_ids' must be a list") + + if len(document_ids) > 100: + raise ValueError(f"Batch size {len(document_ids)} exceeds maximum of 100 documents") + logger.info(f"Processing {len(document_ids)} documents in batch deletion") + + # Update job status + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_IN_PROGRESS) + + # Get repository for vector store operations + repository = vs_repo.find_repository_by_id(job.repository_id) + is_bedrock_kb = RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB) + + # Process each document + successful = 0 + failed = 0 + errors = [] + # For Bedrock KB, group S3 paths by data source (collection_id) + s3_paths_by_data_source = {} + + for document_id in document_ids: + try: + # Find associated RagDocument + rag_document = rag_document_repository.find_by_id(document_id, join_docs=True) + + if rag_document: + # For Bedrock KB, collect S3 paths for bulk deletion grouped by data source + if is_bedrock_kb: + data_source_id = rag_document.collection_id + if data_source_id not in s3_paths_by_data_source: + s3_paths_by_data_source[data_source_id] = [] + s3_paths_by_data_source[data_source_id].append(rag_document.source) + else: + # Remove from vector store immediately for non-Bedrock + remove_document_from_vectorstore(rag_document) + + # Remove from DDB + rag_document_repository.delete_by_id(rag_document.document_id) + successful += 1 + logger.info(f"Successfully deleted document {document_id}") + else: + # Document not found, count as successful (idempotent) + successful += 1 + logger.warning(f"Document {document_id} not found, skipping") + + except Exception as e: + failed += 1 + error_msg = f"Failed to delete document {document_id}: {str(e)}" + logger.error(error_msg, exc_info=True) + errors.append(error_msg) + + # For Bedrock KB, perform bulk deletion per data source + if is_bedrock_kb and s3_paths_by_data_source: + for data_source_id, s3_paths in s3_paths_by_data_source.items(): + try: + bulk_delete_documents_from_kb( + s3_client=s3, + bedrock_agent_client=bedrock_agent, + repository=repository, + s3_paths=s3_paths, + data_source_id=data_source_id, + ) + logger.info( + f"Successfully bulk deleted {len(s3_paths)} documents from Bedrock KB " + "data source {data_source_id}" + ) + except Exception as e: + logger.error( + "Failed to bulk delete from Bedrock KB data source %s: %s", + data_source_id, + e, + exc_info=True, + ) + # Documents already deleted from DynamoDB, continue with partial success + # This is acceptable because DynamoDB is source of truth + + # Update job with results in metadata + if not job.metadata: + job.metadata = {} + job.metadata["results"] = { + "successful": successful, + "failed": failed, + "errors": errors[:10], # Limit error messages + } + + if failed == 0: + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_COMPLETED) + logger.info(f"Batch deletion completed: {successful} successful, {failed} failed") + else: + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_FAILED) + logger.warning(f"Batch deletion completed with errors: {successful} successful, {failed} failed") + + except Exception as e: + ingestion_job_repository.update_status(job, IngestionStatus.DELETE_FAILED) + error_msg = f"Failed to process batch deletion: {str(e)}" + logger.error(error_msg, exc_info=True) + raise Exception(error_msg) + + +def handle_pipeline_delete_event(event: Dict[str, Any], context: Any) -> None: + """Handle pipeline document deletion for S3 ObjectRemoved events.""" # Extract and validate inputs logger.debug(f"Received event: {event}") @@ -85,36 +380,89 @@ def handle_pipeline_delete_event(event: Dict[str, Any], context: Any) -> None: bucket = detail.get("bucket", None) key = detail.get("key", None) repository_id = detail.get("repositoryId", None) + collection_id = detail.get("collectionId", None) pipeline_config = detail.get("pipelineConfig", None) - if not pipeline_config or not isinstance(pipeline_config, dict): - # If pipeline_config is missing or not a dict, skip + s3_path = f"s3://{bucket}/{key}" + + if not repository_id: + logger.warning("No repository_id in event, skipping deletion") return - embedding_model = pipeline_config.get("embeddingModel", None) - if embedding_model is None: - # If embedding_model is missing, skip + + # Get repository to determine type and configuration + repository = vs_repo.find_repository_by_id(repository_id) + if not repository: + logger.warning(f"Repository {repository_id} not found, skipping deletion") return - s3_key = f"s3://{bucket}/{key}" - logger.info(f"Deleting object {s3_key} for repository {repository_id}/{embedding_model}") + # For Bedrock KB repositories, use data source ID as collection ID + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + if not collection_id: + # Fallback: try to get from bedrock config (legacy support) + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + + # Try new structure with dataSources array + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + collection_id = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + else: + # Try legacy single data source ID + collection_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + + if not collection_id: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + + logger.info( + f"Processing Bedrock KB document deletion {s3_path} for repository {repository_id}, " + f"collection {collection_id}" + ) + else: + if not pipeline_config or not isinstance(pipeline_config, dict): + logger.warning("No pipeline_config in event, skipping deletion") + return + + embedding_model = pipeline_config.get("embeddingModel", None) + if embedding_model is None: + logger.warning("No embedding_model in pipeline_config, skipping deletion") + return + + collection_id = embedding_model + logger.info(f"Deleting object {s3_path} for repository {repository_id}/{embedding_model}") + + # Find documents by source path (idempotent - handles missing documents gracefully) + documents = rag_document_repository.find_by_source( + repository_id=repository_id, + collection_id=collection_id, + document_source=s3_path, + join_docs=False, # Don't need subdocs for deletion + ) + + if not documents: + logger.info(f"Document {s3_path} not found in tracking system, already deleted or never tracked") + return # Idempotent - success even if document doesn't exist + + # Delete each found document + for rag_document in documents: + logger.info(f"Deleting tracked document {rag_document.document_id} from {s3_path}") - # Currently there could be RagDocuments without a corresponding IngestionJob, so lookup by RagDocument first - # and then find or create the corresponding IngestionJob. In the future it should be possible to lookup - # directly by IngestionJob - for rag_document in rag_document_repository.find_by_source( - repository_id=repository_id, collection_id=embedding_model, document_source=s3_key, join_docs=True - ): - logger.info(f"deleting doc {rag_document.model_dump()}") + # Find or create ingestion job for deletion ingestion_job = ingestion_job_repository.find_by_document(rag_document.document_id) if ingestion_job is None: ingestion_job = IngestionJob( repository_id=repository_id, - collection_id=embedding_model, + collection_id=collection_id, + embedding_model=collection_id, # Use collection_id as embedding_model chunk_strategy=None, s3_path=rag_document.source, username=rag_document.username, ingestion_type=IngestionType.AUTO, status=IngestionStatus.DELETE_PENDING, ) + ingestion_job_repository.save(ingestion_job) + # Submit deletion job ingestion_service.create_delete_job(ingestion_job) - logger.info(f"Deleting document {s3_key} for repository {ingestion_job.repository_id}") + logger.info(f"Submitted deletion job for document {s3_path} in repository {repository_id}") diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 2b82080e2..09f86e625 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -20,55 +20,200 @@ from typing import Any, Dict, List import boto3 -from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, IngestionType, RagDocument +from models.domain_objects import ( + ChunkingStrategy, + FixedChunkingStrategy, + IngestionJob, + IngestionStatus, + IngestionType, + JobActionType, + NoneChunkingStrategy, + RagDocument, +) +from repository.collection_service import CollectionService from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService +from repository.metadata_generator import MetadataGenerator from repository.rag_document_repo import RagDocumentRepository +from repository.s3_metadata_manager import S3MetadataManager +from repository.services.repository_service_factory import RepositoryServiceFactory from repository.vector_store_repo import VectorStoreRepository from utilities.auth import get_username -from utilities.bedrock_kb import ingest_document_to_kb +from utilities.bedrock_kb import get_datasource_bucket_for_collection, ingest_document_to_kb, S3DocumentDiscoveryService from utilities.common_functions import retry_config from utilities.file_processing import generate_chunks from utilities.repository_types import RepositoryType -from utilities.vector_store import get_vector_store_client dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) ingestion_job_table = dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) ingestion_service = DocumentIngestionService() ingestion_job_repository = IngestionJobRepository() vs_repo = VectorStoreRepository() +rag_document_repository = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) +collection_service = CollectionService(vector_store_repo=vs_repo, document_repo=rag_document_repository) logger = logging.getLogger(__name__) session = boto3.Session() s3 = boto3.client("s3", region_name=os.environ["AWS_REGION"], config=retry_config) bedrock_agent = boto3.client("bedrock-agent", region_name=os.environ["AWS_REGION"], config=retry_config) ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) -rag_document_repository = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) def pipeline_ingest(job: IngestionJob) -> None: - texts = [] - metadatas = [] - all_ids = [] + """ + Ingest a single document or batch of documents. + + Routes to appropriate handler based on job type. + """ + if job.job_type == JobActionType.DOCUMENT_BATCH_INGESTION: + pipeline_ingest_documents(job) + else: + pipeline_ingest_document(job) + + +def pipeline_ingest_document(job: IngestionJob) -> None: + """Ingest a single document.""" + texts: list[str] = [] + metadatas: list[dict] = [] + all_ids: list[str] = [] try: # chunk and save chunks in vector store repository = vs_repo.find_repository_by_id(job.repository_id) if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): - ingest_document_to_kb( - s3_client=s3, - bedrock_agent_client=bedrock_agent, - job=job, - repository=repository, + # Bedrock KB path: Copy document to KB bucket and track + # Get KB bucket for this collection (supports multiple config formats) + try: + kb_bucket = get_datasource_bucket_for_collection( + repository=repository, + collection_id=job.collection_id, + ) + except ValueError as e: + error_msg = str(e) + logger.error(error_msg) + job.status = IngestionStatus.INGESTION_FAILED + job.error_message = error_msg + ingestion_job_repository.save(job) + raise + + # Determine if document needs to be copied to KB bucket + source_bucket = job.s3_path.split("/")[2] + needs_copy = source_bucket != kb_bucket + + # Set canonical KB path + kb_s3_path = f"s3://{kb_bucket}/{os.path.basename(job.s3_path)}" + + if needs_copy: + # Document uploaded to LISA bucket, needs to be copied to KB bucket + logger.info( + f"Document {job.s3_path} uploaded to LISA bucket. " f"Copying to KB data source bucket {kb_bucket}" + ) + + # Check if document already exists (idempotent operation) + existing_docs = list( + rag_document_repository.find_by_source( + job.repository_id, job.collection_id, kb_s3_path, join_docs=False + ) ) - else: - documents = generate_chunks(job) - texts, metadatas = prepare_chunks(documents, job.repository_id) - all_ids = store_chunks_in_vectorstore(texts, metadatas, job.repository_id, job.collection_id) + + if existing_docs and not needs_copy: + # Document already tracked and in KB bucket, update upload_date and return + existing_doc = existing_docs[0] + existing_doc.upload_date = int(datetime.now(timezone.utc).timestamp() * 1000) + rag_document_repository.save(existing_doc) + + job.status = IngestionStatus.INGESTION_COMPLETED + job.document_id = existing_doc.document_id + ingestion_job_repository.save(job) + logger.info(f"Document {kb_s3_path} already tracked, updated upload_date") + return + + # Copy document to KB bucket if needed (user upload via LISA) + if needs_copy: + try: + # This will copy file to KB bucket, delete from source, and trigger KB ingestion + ingest_document_to_kb( + s3_client=s3, + bedrock_agent_client=bedrock_agent, + job=job, + repository=repository, + ) + logger.info(f"Copied document from {job.s3_path} to {kb_s3_path}") + except Exception as e: + logger.error(f"Failed to copy document to KB bucket: {e}") + raise + + rag_document = RagDocument( + repository_id=job.repository_id, + collection_id=job.collection_id, + document_name=os.path.basename(kb_s3_path), + source=kb_s3_path, # Use KB bucket path as canonical source + subdocs=[], # Empty - KB manages chunks internally + chunk_strategy=NoneChunkingStrategy(), # KB manages chunking + username=job.username, + ingestion_type=job.ingestion_type, + ) + rag_document_repository.save(rag_document) + + # Generate and upload metadata.json file for Bedrock KB + try: + metadata_generator = MetadataGenerator() + s3_metadata_manager = S3MetadataManager() + + # Get collection for metadata + collection = None + try: + collection = collection_service.get_collection( + collection_id=job.collection_id, + repository_id=job.repository_id, + username="system", + user_groups=[], + is_admin=True, + ) + except Exception as e: + logger.warning(f"Could not fetch collection for metadata: {e}") + + # Generate metadata content + metadata_content = metadata_generator.generate_metadata_json( + repository=repository, collection=collection, document_metadata=job.metadata + ) + + # Extract bucket and key from S3 path + bucket_name = kb_s3_path.split("/")[2] + document_key = "/".join(kb_s3_path.split("/")[3:]) + + # Upload metadata file + s3_metadata_manager.upload_metadata_file( + s3_client=s3, bucket=bucket_name, document_key=document_key, metadata_content=metadata_content + ) + logger.info(f"Created metadata file for {kb_s3_path}") + except Exception as e: + logger.error(f"Failed to create metadata file for {kb_s3_path}: {e}") + # Continue with ingestion even if metadata fails + + job.status = IngestionStatus.INGESTION_COMPLETED + job.document_id = rag_document.document_id + ingestion_job_repository.save(job) + logger.info( + f"Tracked document {kb_s3_path} for Bedrock KB repository {job.repository_id}. " + f"KB will handle ingestion automatically." + ) + return # Early return for Bedrock KB path + + # Non-Bedrock KB path + documents = generate_chunks(job) + texts, metadatas = prepare_chunks(documents, job.repository_id, job.collection_id) + all_ids = store_chunks_in_vectorstore( + texts=texts, + metadatas=metadatas, + repository_id=job.repository_id, + collection_id=job.collection_id, + embedding_model=job.embedding_model, + ) # remove old - for rag_document in rag_document_repository.find_by_source( - job.repository_id, job.collection_id, job.s3_path, join_docs=True + for rag_document in list( + rag_document_repository.find_by_source(job.repository_id, job.collection_id, job.s3_path, join_docs=True) ): prev_job = ingestion_job_repository.find_by_document(rag_document.document_id) @@ -112,12 +257,163 @@ def pipeline_ingest(job: IngestionJob) -> None: raise Exception(error_msg) +def pipeline_ingest_documents(job: IngestionJob) -> None: + """ + Ingest multiple documents in batch (up to 100 at a time). + + Processes documents from s3_paths field containing list of S3 paths. + If s3_paths is empty, triggers S3 bucket scan to discover existing documents. + """ + try: + logger.info(f"Starting batch ingestion for job {job.id}") + + # Check if this is an S3 discovery scan job (empty s3_paths) + if not job.s3_paths: + # Handle S3 bucket scanning for existing documents + _handle_s3_discovery_scan(job) + return + + # Normal batch ingestion path + # Extract document list from s3_paths field + + document_paths = job.s3_paths + if not isinstance(document_paths, list): + raise ValueError("'s3_paths' must be a list") + + if len(document_paths) > 100: + raise ValueError(f"Batch size {len(document_paths)} exceeds maximum of 100 documents") + + logger.info(f"Processing {len(document_paths)} documents in batch") + + # Update job status + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_IN_PROGRESS) + + # Process each document and collect document IDs + successful = 0 + failed = 0 + errors = [] + document_ids = [] + + for s3_path in document_paths: + try: + # Create individual job for each document + doc_job = IngestionJob( + repository_id=job.repository_id, + collection_id=job.collection_id, + embedding_model=job.embedding_model, + chunk_strategy=job.chunk_strategy, + s3_path=s3_path, + username=job.username, + metadata=job.metadata, + ingestion_type=job.ingestion_type, + job_type=JobActionType.DOCUMENT_INGESTION, + ) + + # Process the document + pipeline_ingest_document(doc_job) + successful += 1 + document_ids.append(doc_job.document_id) + logger.info(f"Successfully ingested document {s3_path}") + + except Exception as e: + failed += 1 + error_msg = f"Failed to ingest {s3_path}: {str(e)}" + logger.error(error_msg, exc_info=True) + errors.append(error_msg) + + # Update job with document IDs + job.document_ids = document_ids + + if failed == 0: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_COMPLETED) + logger.info(f"Batch ingestion completed: {successful} successful, {failed} failed") + else: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) + logger.warning(f"Batch ingestion completed with errors: {successful} successful, {failed} failed") + + except Exception as e: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) + error_msg = f"Failed to process batch ingestion: {str(e)}" + logger.error(error_msg, exc_info=True) + raise Exception(error_msg) + + +def _handle_s3_discovery_scan(job: IngestionJob) -> None: + """ + Handle S3 bucket scanning for existing documents. + + Delegates to S3DocumentDiscoveryService for the actual work. + + Args: + job: Batch ingestion job with empty s3_paths (signals scan mode) + """ + try: + logger.info(f"Starting S3 discovery scan for job {job.id}") + + # Extract bucket and prefix from s3_path field + # Format: s3://bucket/prefix or s3://bucket/ + if not job.s3_path or not job.s3_path.startswith("s3://"): + raise ValueError("S3 scan job missing valid 's3_path' field") + + # Parse s3://bucket/prefix + path_parts = job.s3_path.replace("s3://", "").split("/", 1) + s3_bucket = path_parts[0] + s3_prefix = path_parts[1] if len(path_parts) > 1 else "" + + # Remove trailing slash from prefix if present + if s3_prefix.endswith("/"): + s3_prefix = s3_prefix[:-1] + + # Update job status + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_IN_PROGRESS) + + # Initialize discovery service + metadata_generator = MetadataGenerator() + s3_metadata_manager = S3MetadataManager() + + discovery_service = S3DocumentDiscoveryService( + s3_client=s3, + bedrock_agent_client=bedrock_agent, + rag_document_repository=rag_document_repository, + metadata_generator=metadata_generator, + s3_metadata_manager=s3_metadata_manager, + collection_service=collection_service, + vector_store_repo=vs_repo, + ) + + # Perform discovery and ingestion + result = discovery_service.discover_and_ingest_documents( + repository_id=job.repository_id, + collection_id=job.collection_id, + s3_bucket=s3_bucket, + s3_prefix=s3_prefix, + ingestion_type=job.ingestion_type, + ) + + # Update job with results + job.document_ids = result.document_ids + + if result.failed == 0: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_COMPLETED) + else: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) + + except Exception as e: + ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) + error_msg = f"Failed to process S3 discovery scan: {str(e)}" + logger.error(error_msg, exc_info=True) + raise Exception(error_msg) + + def remove_document_from_vectorstore(doc: RagDocument) -> None: - # Delete from the Vector Store + """Delete document from vector store using repository service.""" + vs_repo = VectorStoreRepository() + repository = vs_repo.find_repository_by_id(doc.repository_id) + + service = RepositoryServiceFactory.create_service(repository) embeddings = RagEmbeddings(model_name=doc.collection_id) - vector_store = get_vector_store_client( - doc.repository_id, - index=doc.collection_id, + vector_store = service.get_vector_store_client( + collection_id=doc.collection_id, embeddings=embeddings, ) vector_store.delete(doc.subdocs) @@ -132,28 +428,85 @@ def handle_pipeline_ingest_event(event: Dict[str, Any], context: Any) -> None: bucket = detail.get("bucket", None) username = get_username(event) key = detail.get("key", None) + + # Safety check: filter out metadata files (should be filtered by EventBridge) + if key and key.endswith(".metadata.json"): + logger.warning(f"Metadata file event reached Lambda (should be filtered by EventBridge): {key}") + return repository_id = detail.get("repositoryId", None) pipeline_config = detail.get("pipelineConfig", None) - embedding_model = pipeline_config.get("embeddingModel", None) + collection_id = detail.get("collectionId", None) s3_path = f"s3://{bucket}/{key}" - logger.info(f"Ingesting object {s3_path} for repository {repository_id}/{embedding_model}") + # Get repository to determine type and configuration + repository = vs_repo.find_repository_by_id(repository_id) + + # For Bedrock KB repositories, use data source ID as collection ID + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): + if not collection_id: + # Fallback: try to get from bedrock config (legacy support) + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + + # Try new structure with dataSources array + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + if isinstance(first_data_source, dict): + collection_id = first_data_source.get("id") + else: + collection_id = getattr(first_data_source, "id", None) + else: + # Try legacy single data source ID + collection_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + + if not collection_id: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + + embedding_model = repository.get("embeddingModelId") + chunk_strategy = NoneChunkingStrategy() # KB manages chunking + + # Set username to "system" for auto-ingestion from KB bucket + username = "system" + ingestion_type = IngestionType.AUTO - chunk_strategy = extract_chunk_strategy(pipeline_config) + logger.info( + f"Processing Bedrock KB document {s3_path} for repository {repository_id}, " f"collection {collection_id}" + ) + else: + # Non-Bedrock KB path (existing logic) + embedding_model = pipeline_config.get("embeddingModel", None) + + if collection_id: + collection = collection_service.get_collection( + collection_id=collection_id, repository_id=repository_id, is_admin=True, username="", user_groups=[] + ) + + if collection.embeddingModel is not None: + embedding_model = collection.embeddingModel + else: + collection_id = embedding_model - # create ingestion job and save it to dynamodb + chunk_strategy = extract_chunk_strategy(pipeline_config) + ingestion_type = IngestionType.MANUAL + + logger.info(f"Ingesting object {s3_path} for repository {repository_id}/{embedding_model}") + + # Create ingestion job and save it to dynamodb job = IngestionJob( repository_id=repository_id, - collection_id=embedding_model, + collection_id=collection_id, + embedding_model=embedding_model, chunk_strategy=chunk_strategy, s3_path=s3_path, username=username, - ingestion_type=IngestionType.MANUAL, + ingestion_type=ingestion_type, + metadata=None, ) ingestion_job_repository.save(job) - ingestion_service.create_ingest_job(job) + ingestion_service.submit_create_job(job) - logger.info(f"Ingesting document {s3_path} for repository {repository_id}") + logger.info(f"Submitted ingestion job for document {s3_path} in repository {repository_id}") def handle_pipline_ingest_schedule(event: Dict[str, Any], context: Any) -> None: @@ -228,9 +581,10 @@ def handle_pipline_ingest_schedule(event: Dict[str, Any], context: Any) -> None: s3_path=f"s3://{bucket}/{key}", username=username, ingestion_type=IngestionType.AUTO, + metadata=None, ) ingestion_job_repository.save(job) - ingestion_service.create_ingest_job(job) + ingestion_service.submit_create_job(job) logger.info(f"Found {len(modified_keys)} modified files in {bucket}{prefix}") except Exception as e: @@ -257,15 +611,48 @@ def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) return batches -def extract_chunk_strategy(pipeline_config: Dict) -> FixedChunkingStrategy: - """Extract and validate configuration parameters.""" - chunk_size = int(pipeline_config["chunkSize"]) - chunk_overlap = int(pipeline_config["chunkOverlap"]) +def extract_chunk_strategy(pipeline_config: Dict) -> ChunkingStrategy: + """ + Extract and validate chunking strategy from pipeline configuration. + + Supports both new chunkingStrategy object format and legacy flat fields for backward compatibility. + Uses Pydantic model validation to ensure data integrity. + + Args: + pipeline_config: Pipeline configuration dictionary - return FixedChunkingStrategy(size=chunk_size, overlap=chunk_overlap) + Returns: + ChunkingStrategy object (validated Pydantic model) + Raises: + ValueError: If chunking strategy type is unsupported or validation fails + """ + # Check for new chunkingStrategy object format first + if "chunkingStrategy" in pipeline_config and pipeline_config["chunkingStrategy"]: + chunking_strategy = pipeline_config["chunkingStrategy"] + chunk_type = chunking_strategy.get("type", "fixed") + + if chunk_type == "fixed": + # Use Pydantic model validation for type safety and validation + return FixedChunkingStrategy.model_validate(chunking_strategy) + else: + # Future: Handle other chunking strategy types (semantic, recursive, etc.) + raise ValueError(f"Unsupported chunking strategy type: {chunk_type}") -def prepare_chunks(docs: List, repository_id: str) -> tuple[List[str], List[Dict]]: + # Fall back to legacy flat fields for backward compatibility + elif "chunkSize" in pipeline_config and "chunkOverlap" in pipeline_config: + chunk_size = int(pipeline_config["chunkSize"]) + chunk_overlap = int(pipeline_config["chunkOverlap"]) + # Use Pydantic model for validation + return FixedChunkingStrategy(size=chunk_size, overlap=chunk_overlap) + + # Default values if neither format is present + else: + logger.warning("No chunking strategy found in pipeline config, using defaults") + return FixedChunkingStrategy(size=512, overlap=51) + + +def prepare_chunks(docs: List, repository_id: str, collection_id: str) -> tuple[List[str], List[Dict]]: """Prepare texts and metadata from document chunks.""" texts = [] metadatas = [] @@ -273,19 +660,23 @@ def prepare_chunks(docs: List, repository_id: str) -> tuple[List[str], List[Dict for doc in docs: texts.append(doc.page_content) doc.metadata["repository_id"] = repository_id + doc.metadata["collection_id"] = collection_id metadatas.append(doc.metadata) return texts, metadatas def store_chunks_in_vectorstore( - texts: List[str], metadatas: List[Dict], repository_id: str, embedding_model: str + texts: List[str], metadatas: List[Dict], repository_id: str, collection_id: str, embedding_model: str ) -> List[str]: - """Store document chunks in vector store.""" + """Store document chunks in vector store using repository service.""" + vs_repo = VectorStoreRepository() + repository = vs_repo.find_repository_by_id(repository_id) + + service = RepositoryServiceFactory.create_service(repository) embeddings = RagEmbeddings(model_name=embedding_model) - vs = get_vector_store_client( - repository_id, - index=embedding_model, + vs = service.get_vector_store_client( + collection_id=collection_id, embeddings=embeddings, ) diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py index 8dce81e07..8c3583179 100644 --- a/lambda/repository/rag_document_repo.py +++ b/lambda/repository/rag_document_repo.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import os +from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Generator, Optional import boto3 @@ -325,33 +326,151 @@ def _get_subdoc_ids(self, entries: list[RagSubDocument]) -> list[str]: return [doc for entry in entries for doc in entry.subdocs] def delete_s3_object(self, uri: str) -> None: - """Delete an object from S3. + """Delete an object and its metadata file from S3. Args: - key: The key of the object to delete + uri: The S3 URI of the object to delete (s3://bucket/key) """ try: bucket, key = uri.replace("s3://", "").split("/", 1) + + # Delete document logging.info(f"Deleting S3 object: {bucket}/{key}") self.s3_client.delete_object(Bucket=bucket, Key=key) + + # Delete metadata file + metadata_key = f"{key}.metadata.json" + try: + logging.info(f"Deleting metadata file: {bucket}/{metadata_key}") + self.s3_client.delete_object(Bucket=bucket, Key=metadata_key) + except ClientError as e: + # Metadata file may not exist (idempotent) + if e.response["Error"]["Code"] != "NoSuchKey": + logging.warning(f"Failed to delete metadata file: {e}") + except ClientError as e: logging.error(f"Error deleting S3 object: {e.response['Error']['Message']}") raise + def delete_all(self, repository_id: str, collection_id: str) -> None: + """Delete all documents and subdocuments for a collection. + + Args: + repository_id: Repository ID + collection_id: Collection ID + """ + pk = RagDocument.createPartitionKey(repository_id, collection_id) + doc_ids = [] + + # Query and delete documents, collecting doc_ids in single pass + response = self.doc_table.query(KeyConditionExpression=Key("pk").eq(pk), ProjectionExpression="pk,document_id") + with self.doc_table.batch_writer() as batch: + for item in response["Items"]: + doc_ids.append(item["document_id"]) + batch.delete_item(Key={"pk": item["pk"], "document_id": item["document_id"]}) + + while "LastEvaluatedKey" in response: + response = self.doc_table.query( + KeyConditionExpression=Key("pk").eq(pk), + ProjectionExpression="pk,document_id", + ExclusiveStartKey=response["LastEvaluatedKey"], + ) + with self.doc_table.batch_writer() as batch: + for item in response["Items"]: + doc_ids.append(item["document_id"]) + batch.delete_item(Key={"pk": item["pk"], "document_id": item["document_id"]}) + + # Delete subdocuments in parallel + def delete_subdocs(doc_id: str) -> None: + response = self.subdoc_table.query( + KeyConditionExpression=Key("document_id").eq(doc_id), ProjectionExpression="document_id,sk" + ) + with self.subdoc_table.batch_writer() as batch: + for item in response["Items"]: + batch.delete_item(Key={"document_id": item["document_id"], "sk": item["sk"]}) + while "LastEvaluatedKey" in response: + response = self.subdoc_table.query( + KeyConditionExpression=Key("document_id").eq(doc_id), + ProjectionExpression="document_id,sk", + ExclusiveStartKey=response["LastEvaluatedKey"], + ) + with self.subdoc_table.batch_writer() as batch: + for item in response["Items"]: + batch.delete_item(Key={"document_id": item["document_id"], "sk": item["sk"]}) + + if doc_ids: + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(delete_subdocs, doc_id) for doc_id in doc_ids] + for future in as_completed(futures): + future.result() + def delete_s3_docs(self, repository_id: str, docs: list[RagDocument]) -> list[str]: - """Remove documents from S3""" + """Remove documents from S3 based on ingestion type. + + Only deletes S3 objects for MANUAL and AUTO ingestion types. + EXISTING documents are preserved in S3 (user-managed). + + Args: + repository_id: The repository ID + docs: List of RagDocument objects + + Returns: + List of S3 URIs that were removed + """ repo = self.vs_repo.find_repository_by_id(repository_id=repository_id) - pipelines = { - pipeline.get("embeddingModel"): pipeline.get("autoRemove", False) is True - for pipeline in repo.get("pipelines", []) - } - removed_source: list[str] = [ - doc.source - for doc in docs - if doc and (doc.ingestion_type != IngestionType.AUTO or pipelines.get(doc.collection_id)) - ] + + # Build mapping of collection IDs to autoRemove setting + collection_auto_remove = {} + for pipeline in repo.get("pipelines", []): + embedding_model = pipeline.get("embeddingModel") + auto_remove = pipeline.get("autoRemove", False) is True + if embedding_model: + collection_auto_remove[embedding_model] = auto_remove + + # Determine which documents should be removed from S3 + removed_source: list[str] = [] + preserved_count = 0 + + for doc in docs: + if not doc: + continue + + doc_source = doc.source + doc_ingestion_type = doc.ingestion_type + doc_collection_id = doc.collection_id + + if not doc_source: + continue + + # EXISTING documents: never remove from S3 (user-managed) + if doc_ingestion_type == IngestionType.EXISTING: + logging.info(f"Preserving user-managed document in S3: {doc_source}") + preserved_count += 1 + continue + + # MANUAL ingestion: always remove from S3 + if doc_ingestion_type == IngestionType.MANUAL: + removed_source.append(doc_source) + continue + + # AUTO ingestion: only remove if pipeline exists and has autoRemove enabled + if doc_ingestion_type == IngestionType.AUTO: + auto_remove = collection_auto_remove.get(doc_collection_id, False) + if auto_remove: + removed_source.append(doc_source) + else: + logging.info(f"Preserving AUTO document (autoRemove=False or no pipeline): {doc_source}") + preserved_count += 1 + + # Delete from S3 for source in removed_source: - logging.info(f"Removing S3 doc: {source}") - self.delete_s3_object(uri=source) + try: + logging.info(f"Removing S3 doc: {source}") + self.delete_s3_object(uri=source) + except Exception as e: + logging.error(f"Failed to delete S3 object {source}: {e}") + # Continue with other deletions + + logging.info(f"S3 deletion complete: deleted={len(removed_source)}, preserved={preserved_count}") return removed_source diff --git a/lambda/repository/s3_metadata_manager.py b/lambda/repository/s3_metadata_manager.py new file mode 100644 index 000000000..85b758b0a --- /dev/null +++ b/lambda/repository/s3_metadata_manager.py @@ -0,0 +1,255 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""S3 metadata file manager for Bedrock KB documents.""" + +import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +import boto3 +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + +# Retry configuration +MAX_RETRIES = 3 +RETRY_DELAY = 1 # seconds + + +class S3MetadataManager: + """Manager for S3 metadata file operations.""" + + def __init__(self, cloudwatch_client=None): + """Initialize S3 metadata manager. + + Args: + cloudwatch_client: Optional CloudWatch client for metrics (defaults to creating one) + """ + self.cloudwatch_client = cloudwatch_client or boto3.client("cloudwatch") + + def _emit_metric( + self, + metric_name: str, + value: float = 1.0, + repository_id: Optional[str] = None, + collection_id: Optional[str] = None, + ) -> None: + """Emit CloudWatch metric. + + Args: + metric_name: Name of the metric + value: Metric value + repository_id: Optional repository ID for dimensions + collection_id: Optional collection ID for dimensions + """ + try: + dimensions = [] + if repository_id: + dimensions.append({"Name": "RepositoryId", "Value": repository_id}) + if collection_id: + dimensions.append({"Name": "CollectionId", "Value": collection_id}) + + metric_data = { + "MetricName": metric_name, + "Value": value, + "Unit": "Count", + } + + if dimensions: + metric_data["Dimensions"] = dimensions + + self.cloudwatch_client.put_metric_data(Namespace="LISA/BedrockKB", MetricData=[metric_data]) + except Exception as e: + logger.warning(f"Failed to emit CloudWatch metric {metric_name}: {e}") + + def upload_metadata_file( + self, + s3_client, + bucket: str, + document_key: str, + metadata_content: Dict[str, Any], + repository_id: Optional[str] = None, + collection_id: Optional[str] = None, + ) -> str: + """Upload metadata.json file to S3. + + Args: + s3_client: Boto3 S3 client + bucket: S3 bucket name + document_key: S3 key of the document + metadata_content: Metadata content dictionary + repository_id: Optional repository ID for metrics + collection_id: Optional collection ID for metrics + + Returns: + S3 key of the uploaded metadata file + + Raises: + ClientError: If S3 upload fails after retries + """ + metadata_key = f"{document_key}.metadata.json" + metadata_json = json.dumps(metadata_content, indent=2) + + logger.info( + f"Uploading metadata file: s3://{bucket}/{metadata_key}", + extra={ + "repository_id": repository_id, + "collection_id": collection_id, + "document_key": document_key, + "metadata_key": metadata_key, + }, + ) + + # Upload with retries + for attempt in range(MAX_RETRIES): + try: + s3_client.put_object( + Bucket=bucket, + Key=metadata_key, + Body=metadata_json.encode("utf-8"), + ContentType="application/json", + ) + logger.info(f"Successfully uploaded metadata file: {metadata_key}") + + # Emit success metric + self._emit_metric("MetadataFileCreated", 1.0, repository_id, collection_id) + + return metadata_key + + except ClientError as e: + error_code = e.response["Error"]["Code"] + + # Don't retry on permission errors + if error_code == "AccessDenied": + logger.error(f"Access denied uploading metadata file: {metadata_key}") + self._emit_metric("MetadataFileUploadFailed", 1.0, repository_id, collection_id) + raise + + # Retry on transient errors + if attempt < MAX_RETRIES - 1: + logger.warning(f"Retry {attempt + 1}/{MAX_RETRIES} for metadata upload: {metadata_key}") + continue + else: + logger.error(f"Failed to upload metadata file after {MAX_RETRIES} attempts: {metadata_key}") + self._emit_metric("MetadataFileUploadFailed", 1.0, repository_id, collection_id) + raise + + def delete_metadata_file( + self, + s3_client, + bucket: str, + document_key: str, + repository_id: Optional[str] = None, + collection_id: Optional[str] = None, + ) -> None: + """Delete metadata.json file from S3. + + Args: + s3_client: Boto3 S3 client + bucket: S3 bucket name + document_key: S3 key of the document + repository_id: Optional repository ID for metrics + collection_id: Optional collection ID for metrics + + Note: + This operation is idempotent - no error if file doesn't exist + """ + metadata_key = f"{document_key}.metadata.json" + + logger.info( + f"Deleting metadata file: s3://{bucket}/{metadata_key}", + extra={ + "repository_id": repository_id, + "collection_id": collection_id, + "document_key": document_key, + "metadata_key": metadata_key, + }, + ) + + try: + s3_client.delete_object(Bucket=bucket, Key=metadata_key) + logger.info(f"Successfully deleted metadata file: {metadata_key}") + + # Emit success metric + self._emit_metric("MetadataFileDeleted", 1.0, repository_id, collection_id) + + except ClientError as e: + error_code = e.response["Error"]["Code"] + + # Idempotent - file already deleted + if error_code == "NoSuchKey": + logger.info(f"Metadata file already deleted: {metadata_key}") + return + + # Log other errors but don't fail + logger.warning(f"Failed to delete metadata file: {metadata_key}, error: {e}") + + def batch_upload_metadata(self, s3_client, bucket: str, documents: List[Tuple[str, Dict[str, Any]]]) -> List[str]: + """Upload multiple metadata files in batch. + + Args: + s3_client: Boto3 S3 client + bucket: S3 bucket name + documents: List of (document_key, metadata_content) tuples + + Returns: + List of successfully uploaded metadata file S3 keys + """ + uploaded_keys = [] + failed_count = 0 + + logger.info(f"Batch uploading {len(documents)} metadata files") + + for document_key, metadata_content in documents: + try: + metadata_key = self.upload_metadata_file(s3_client, bucket, document_key, metadata_content) + uploaded_keys.append(metadata_key) + except Exception as e: + logger.error(f"Failed to upload metadata for {document_key}: {e}") + failed_count += 1 + # Continue with other uploads + + logger.info( + f"Batch upload complete: {len(uploaded_keys)} succeeded, {failed_count} failed out of {len(documents)}" + ) + + return uploaded_keys + + def batch_delete_metadata(self, s3_client, bucket: str, document_keys: List[str]) -> int: + """Delete multiple metadata files in batch. + + Args: + s3_client: Boto3 S3 client + bucket: S3 bucket name + document_keys: List of document S3 keys + + Returns: + Number of successfully deleted metadata files + """ + deleted_count = 0 + + logger.info(f"Batch deleting {len(document_keys)} metadata files") + + for document_key in document_keys: + try: + self.delete_metadata_file(s3_client, bucket, document_key) + deleted_count += 1 + except Exception as e: + logger.error(f"Failed to delete metadata for {document_key}: {e}") + # Continue with other deletions + + logger.info(f"Batch delete complete: {deleted_count} deleted out of {len(document_keys)}") + + return deleted_count diff --git a/lambda/repository/services/__init__.py b/lambda/repository/services/__init__.py new file mode 100644 index 000000000..dca8abc2f --- /dev/null +++ b/lambda/repository/services/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Repository service implementations for handling different vector store types.""" + +from .repository_service import RepositoryService +from .repository_service_factory import RepositoryServiceFactory + +__all__ = ["RepositoryService", "RepositoryServiceFactory"] diff --git a/lambda/repository/services/bedrock_kb_repository_service.py b/lambda/repository/services/bedrock_kb_repository_service.py new file mode 100644 index 000000000..450fc5798 --- /dev/null +++ b/lambda/repository/services/bedrock_kb_repository_service.py @@ -0,0 +1,451 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bedrock Knowledge Base repository service implementation.""" + +import logging +import os +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import boto3 +from boto3.dynamodb.conditions import Key +from models.domain_objects import ( + CollectionMetadata, + CollectionStatus, + IngestionJob, + IngestionType, + NoneChunkingStrategy, + RagCollectionConfig, + RagDocument, +) +from repository.rag_document_repo import RagDocumentRepository +from utilities.bedrock_kb import bulk_delete_documents_from_kb, delete_document_from_kb + +from .repository_service import RepositoryService + +logger = logging.getLogger(__name__) + + +class BedrockKBRepositoryService(RepositoryService): + """Service for Bedrock Knowledge Base repository operations. + + Bedrock KB manages its own ingestion, chunking, and embedding pipeline. + LISA only tracks documents and delegates actual operations to Bedrock. + """ + + def supports_custom_collections(self) -> bool: + """Bedrock KB only supports default collections (data sources).""" + return False + + def should_create_default_collection(self) -> bool: + """Bedrock KB does not need virtual default collections.""" + return False + + def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + """For Bedrock KB, collection ID is the data source ID. + + Extracts the data source ID from the pipeline config's collectionId field, + which should match one of the data sources in bedrockKnowledgeBaseConfig. + """ + # The pipeline config should have a collectionId that matches a data source ID + collection_id = pipeline_config.get("collectionId") + + if collection_id: + return collection_id + + # Fallback: try to get from bedrock config (legacy support) + bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) + + # Try new structure with dataSources array + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + data_source_id = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + return data_source_id + + # Try legacy single data source ID + data_source_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + if data_source_id: + return data_source_id + + raise ValueError(f"Bedrock KB repository {self.repository_id} missing data source ID") + + def ingest_document( + self, + job: IngestionJob, + texts: List[str], + metadatas: List[Dict[str, Any]], + ) -> RagDocument: + """Track document for Bedrock KB - KB handles actual ingestion. + + Bedrock KB automatically ingests documents from its S3 data source. + LISA only tracks the document metadata for querying and management. + """ + bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) + kb_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + + # Validate document is from KB bucket + kb_s3_path = self._validate_and_normalize_path(job.s3_path, kb_bucket) + + # Check if document already tracked (idempotent) + rag_document_repository = RagDocumentRepository( + os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"] + ) + + existing_docs = list( + rag_document_repository.find_by_source(job.repository_id, job.collection_id, kb_s3_path, join_docs=False) + ) + + if existing_docs: + # Update existing document timestamp + existing_doc = existing_docs[0] + existing_doc.upload_date = int(datetime.now(timezone.utc).timestamp() * 1000) + rag_document_repository.save(existing_doc) + logger.info(f"Document {kb_s3_path} already tracked, updated timestamp") + return existing_doc + + # Create new document tracking entry + rag_document = RagDocument( + repository_id=job.repository_id, + collection_id=job.collection_id, + document_name=os.path.basename(kb_s3_path), + source=kb_s3_path, + subdocs=[], # KB manages chunks internally + chunk_strategy=NoneChunkingStrategy(), + username=job.username, + ingestion_type=job.ingestion_type, + ) + rag_document_repository.save(rag_document) + + logger.info(f"Tracked document {kb_s3_path} for Bedrock KB. " f"KB will handle ingestion automatically.") + return rag_document + + def delete_document( + self, + document: RagDocument, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete document from Bedrock KB.""" + if not bedrock_agent_client: + raise ValueError("Bedrock agent client required for KB operations") + + # Create minimal job for deletion + job = IngestionJob( + repository_id=document.repository_id, + collection_id=document.collection_id, + s3_path=document.source, + username=document.username, + ingestion_type=document.ingestion_type, + ) + + delete_document_from_kb( + s3_client=s3_client, + bedrock_agent_client=bedrock_agent_client, + job=job, + repository=self.repository, + ) + + def delete_collection( + self, + collection_id: str, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete all LISA-managed documents from Bedrock KB collection. + + Only deletes documents with ingestion_type MANUAL or AUTO. + Preserves user-managed documents (ingestion_type EXISTING). + """ + if not bedrock_agent_client: + raise ValueError("Bedrock agent client required for KB operations") + + dynamodb = boto3.resource("dynamodb") + doc_table = dynamodb.Table(os.environ["RAG_DOCUMENT_TABLE"]) + + pk = f"{self.repository_id}#{collection_id}" + + # Query all documents in collection + response = doc_table.query(KeyConditionExpression=Key("pk").eq(pk)) + documents = response.get("Items", []) + + # Handle pagination + while "LastEvaluatedKey" in response: + response = doc_table.query( + KeyConditionExpression=Key("pk").eq(pk), ExclusiveStartKey=response["LastEvaluatedKey"] + ) + documents.extend(response.get("Items", [])) + + logger.info(f"Found {len(documents)} total documents in collection") + + # Separate by ingestion type + lisa_managed = [ + doc for doc in documents if doc.get("ingestion_type") in [IngestionType.MANUAL, IngestionType.AUTO] + ] + user_managed = [doc for doc in documents if doc.get("ingestion_type") == IngestionType.EXISTING] + + logger.info( + f"Collection {collection_id}: " f"lisa_managed={len(lisa_managed)}, user_managed={len(user_managed)}" + ) + + # Extract S3 paths for LISA-managed documents + s3_paths = [doc.get("source", "") for doc in lisa_managed if doc.get("source")] + + if s3_paths: + bulk_delete_documents_from_kb( + s3_client=s3_client, + bedrock_agent_client=bedrock_agent_client, + repository=self.repository, + s3_paths=s3_paths, + data_source_id=collection_id, + ) + logger.info( + f"Bulk deleted {len(s3_paths)} LISA-managed documents, " + f"preserved {len(user_managed)} user-managed documents" + ) + else: + logger.info("No LISA-managed documents to delete from KB") + + def retrieve_documents( + self, + query: str, + collection_id: str, + top_k: int, + model_name: str, + include_score: bool = False, + bedrock_agent_client: Optional[Any] = None, + ) -> List[Dict[str, Any]]: + """Retrieve documents from Bedrock KB using retrieve API. + + Args: + query: Search query + collection_id: Collection to search (data source ID) + top_k: Number of results to return + model_name: Embedding model name (not used for Bedrock KB) + include_score: Whether to include similarity scores in metadata + bedrock_agent_client: Bedrock agent client for KB operations + + Returns: + List of documents with page_content and metadata + """ + if not bedrock_agent_client: + raise ValueError("Bedrock agent client required for KB operations") + + bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) + # Support both field names for backward compatibility + kb_id = bedrock_config.get("knowledgeBaseId", bedrock_config.get("bedrockKnowledgeBaseId")) + + if not kb_id: + raise ValueError( + f"Bedrock KB repository '{self.repository_id}' is missing required field " + f"'bedrockKnowledgeBaseId' or 'knowledgeBaseId' in bedrockKnowledgeBaseConfig. " + f"Please update the repository configuration with the actual AWS Bedrock Knowledge Base ID " + f"(e.g., 'KB123456' or a UUID format, not the LISA repository ID)." + ) + + # Use Bedrock retrieve API with data source filter + logger.info(f"Retrieving from KB: kb_id={kb_id}, data_source={collection_id}, query={query[:50]}...") + + # Build retrieve params with data source filter + retrieve_params = { + "knowledgeBaseId": kb_id, + "retrievalQuery": {"text": query}, + "retrievalConfiguration": { + "vectorSearchConfiguration": { + "numberOfResults": top_k, + } + }, + } + + # Add data source filter if collection_id is provided + # collection_id corresponds to the data source ID in Bedrock KB + if collection_id: + retrieve_params["retrievalConfiguration"]["vectorSearchConfiguration"]["filter"] = { + "equals": { + "key": "x-amz-bedrock-kb-data-source-id", + "value": collection_id, + } + } + logger.info(f"Filtering to data source: {collection_id}") + + try: + response = bedrock_agent_client.retrieve(**retrieve_params) + except Exception as e: + logger.error(f"Bedrock retrieve failed for KB {kb_id}: {str(e)}") + if "filter" in retrieve_params.get("retrievalConfiguration", {}).get("vectorSearchConfiguration", {}): + logger.error( + "Filter may not be supported. Ensure metadata field 'x-amz-bedrock-kb-data-source-id' " + "is configured in the Knowledge Base." + ) + raise + + # Transform Bedrock results to standard format + documents = [] + for result in response.get("retrievalResults", []): + metadata = result.get("metadata", {}).copy() + + # Add score to metadata if requested + if include_score: + metadata["similarity_score"] = result.get("score", 0.0) + + # Add location info to metadata + location = result.get("location", {}) + if location: + metadata["source"] = location.get("s3Location", {}).get("uri", "") + + documents.append( + { + "page_content": result.get("content", {}).get("text", ""), + "metadata": metadata, + } + ) + + return documents + + def validate_document_source(self, s3_path: str) -> str: + """Validate document is from KB data source bucket.""" + bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) + kb_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + + return self._validate_and_normalize_path(s3_path, kb_bucket) + + def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Optional[Any]: + """Bedrock KB does not use external vector store clients.""" + return None + + def _create_collection_for_data_source( + self, data_source_id: str, s3_uri: str = "", is_default: bool = False, collection_name: Optional[str] = None + ) -> RagCollectionConfig: + """Create a collection configuration for a specific data source. + + Args: + data_source_id: The data source ID to use as collection ID + s3_uri: Optional S3 URI for the data source + is_default: Whether this is the default collection + collection_name: Optional collection name (defaults to data_source_id if not provided) + + Returns: + Collection configuration for the data source + """ + embedding_model = self.repository.get("embeddingModelId") + # Use provided collection_name or fall back to data_source_id + display_name = collection_name or f"{self.repository.get('name', self.repository_id)}-{data_source_id}" + + # Get KB name for description + kb_name = self.repository.get("repositoryName") or self.repository.get("name", "Knowledge Base") + + # Set tags and description based on whether this is default + if is_default: + tags = ["default", "bedrock-kb"] + description = f"Default collection for Bedrock Knowledge Base: {kb_name}" + else: + tags = ["bedrock-kb", "data-source"] + description = f"Auto-created collection for {kb_name}" + + collection = RagCollectionConfig( + collectionId=data_source_id, + repositoryId=self.repository_id, + name=display_name, + description=description, + embeddingModel=embedding_model, + chunkingStrategy=None, # KB controls chunking + allowedGroups=self.repository.get("allowedGroups", []), + createdBy=self.repository.get("createdBy", "system"), + status=CollectionStatus.ACTIVE, + metadata=CollectionMetadata(tags=tags, customFields={"s3Uri": s3_uri} if s3_uri else {}), + allowChunkingOverride=False, # KB controls chunking + pipelines=self.repository.get("pipelines", []), + default=is_default, + dataSourceId=data_source_id, + createdAt=datetime.now(timezone.utc), + updatedAt=datetime.now(timezone.utc), + ) + + return collection + + def create_default_collection(self, ingest_docs=False) -> Optional[RagCollectionConfig]: + """Create a default collection for Bedrock KB repository. + + For Bedrock KB, the collection ID is the data source ID. + If multiple data sources exist, returns the first one. + + Returns: + Default collection configuration for Bedrock KB + """ + try: + bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) + + # Handle new structure with dataSources array + data_sources = bedrock_config.get("dataSources", []) + + # Also check for legacy single data source ID + legacy_data_source_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + + if not data_sources and not legacy_data_source_id: + logger.warning(f"Bedrock KB repository {self.repository_id} missing data source ID") + return None + + # Use first data source from array, or legacy single ID + if data_sources: + first_data_source = data_sources[0] + data_source_id = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + s3_uri = ( + first_data_source.get("s3Uri", "") + if isinstance(first_data_source, dict) + else getattr(first_data_source, "s3Uri", "") + ) + else: + data_source_id = legacy_data_source_id + s3_uri = "" + + # Use helper method to create collection + default_collection = self._create_collection_for_data_source( + data_source_id=data_source_id, s3_uri=s3_uri, is_default=True + ) + + logger.info(f"Created virtual default collection for Bedrock KB repository {self.repository_id}") + + if ingest_docs: + # Ingest existing documents from S3 bucket if s3pipeline is configured + s3_bucket = bedrock_config.get("s3pipeline") + + if s3_bucket: + logger.info( + f"S3 pipeline configured with bucket {s3_bucket}. " + f"Document ingestion requires additional dependencies not available in this context." + ) + + return default_collection + + except Exception as e: + logger.error(f"Failed to create default collection for Bedrock KB repository {self.repository_id}: {e}") + return None + + def _validate_and_normalize_path(self, s3_path: str, expected_bucket: str) -> str: + """Validate S3 path is from expected bucket and normalize.""" + source_bucket = s3_path.split("/")[2] if s3_path.startswith("s3://") else None + + if source_bucket != expected_bucket: + logger.warning( + f"Document {s3_path} not from KB bucket {expected_bucket}. " f"Normalizing to KB bucket path." + ) + # Normalize to KB bucket path + return f"s3://{expected_bucket}/{os.path.basename(s3_path)}" + + return s3_path diff --git a/lambda/repository/services/opensearch_repository_service.py b/lambda/repository/services/opensearch_repository_service.py new file mode 100644 index 000000000..70545d52f --- /dev/null +++ b/lambda/repository/services/opensearch_repository_service.py @@ -0,0 +1,195 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenSearch repository service implementation.""" + +import json +import logging +import os +from typing import Any + +import boto3 +from langchain_community.vectorstores import OpenSearchVectorSearch +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore +from opensearchpy import RequestsHttpConnection +from repository.embeddings import RagEmbeddings +from requests_aws4auth import AWS4Auth +from utilities.common_functions import retry_config +from utilities.repository_types import RepositoryType + +from .vector_store_repository_service import VectorStoreRepositoryService + +logger = logging.getLogger(__name__) +session = boto3.Session() +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + +class OpenSearchRepositoryService(VectorStoreRepositoryService): + """Service for OpenSearch repository operations. + + Inherits common vector store behavior from VectorStoreRepositoryService. + Only implements OpenSearch-specific index management. + """ + + def retrieve_documents( + self, + query: str, + collection_id: str, + top_k: int, + model_name: str, + include_score: bool = False, + bedrock_agent_client: Any = None, + ) -> list[dict[str, Any]]: + """Retrieve documents from OpenSearch with index existence check. + + Args: + query: Search query + collection_id: Collection to search + top_k: Number of results to return + model_name: Embedding model name to use for query embedding + include_score: Whether to include similarity scores in metadata + bedrock_agent_client: Not used for OpenSearch + + Returns: + List of documents with page_content and metadata + """ + # Create embeddings and vector store client once + embeddings = RagEmbeddings(model_name=model_name) + vector_store = self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + # Check if index exists before searching + if hasattr(vector_store, "client") and hasattr(vector_store.client, "indices"): + if not vector_store.client.indices.exists(index=collection_id): + logger.info(f"Collection {collection_id} does not exist. Returning empty docs.") + return [] + + # Perform similarity search + results = vector_store.similarity_search_with_score(query, k=top_k) + + documents = [] + for i, (doc, score) in enumerate(results): + doc_dict = { + "page_content": doc.page_content, + "metadata": doc.metadata.copy() if doc.metadata else {}, + } + + if include_score: + # OpenSearch scores are already normalized (0-1 range) + normalized_score = self._normalize_similarity_score(score) + doc_dict["metadata"]["similarity_score"] = normalized_score + + logger.info( + f"Result {i + 1}: Raw Score={score:.4f}, Similarity={normalized_score:.4f}, " + f"Content: {doc.page_content[:200]}..." + ) + logger.info(f"Result {i + 1} metadata: {doc.metadata}") + + documents.append(doc_dict) + + # Warn if all scores are low (possible embedding model mismatch) + if include_score and results: + max_score = max(self._normalize_similarity_score(score) for _, score in results) + if max_score < 0.3: + logger.warning( + f"All similarity scores < 0.3 for query '{query}' - " "possible embedding model mismatch" + ) + + return documents + + def _drop_collection_index(self, collection_id: str) -> None: + """Drop OpenSearch index for collection.""" + try: + logger.info(f"Dropping OpenSearch index for collection {collection_id}") + + embeddings = RagEmbeddings(model_name=collection_id) + vector_store = self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + # Drop the index if it exists + if hasattr(vector_store, "client") and hasattr(vector_store.client, "indices"): + index_name = f"{self.repository_id}_{collection_id}".lower() + if vector_store.client.indices.exists(index=index_name): + vector_store.client.indices.delete(index=index_name) + logger.info(f"Dropped OpenSearch index: {index_name}") + else: + logger.info(f"OpenSearch index {index_name} does not exist") + else: + logger.warning("Vector store client does not support index operations") + + except Exception as e: + logger.error(f"Failed to drop OpenSearch index: {e}", exc_info=True) + # Don't raise - continue with document deletion + + # OpenSearch uses default score normalization (0-1 range already) + + def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) -> VectorStore: + """Get OpenSearch vector store client. + + Args: + collection_id: Collection identifier + embeddings: Embeddings adapter + + Returns: + OpenSearchVectorSearch client instance + + Raises: + ValueError: If repository is not registered or not an OpenSearch repository + """ + prefix = os.environ.get("REGISTERED_REPOSITORIES_PS_PREFIX") + parameter_name = f"{prefix}{self.repository_id}" + + try: + connection_info = ssm_client.get_parameter(Name=parameter_name) + connection_info = json.loads(connection_info["Parameter"]["Value"]) + except ssm_client.exceptions.ParameterNotFound: + logger.error( + f"Repository '{self.repository_id}' not found in SSM Parameter Store. " + f"Parameter: {parameter_name}. " + f"Ensure the repository is registered before use." + ) + raise ValueError( + f"Repository '{self.repository_id}' is not registered. " + f"Please register the repository before performing operations." + ) + + if not RepositoryType.is_type(connection_info, RepositoryType.OPENSEARCH): + raise ValueError(f"Repository {self.repository_id} is not an OpenSearch repository") + + credentials = session.get_credentials() + auth = AWS4Auth( + credentials.access_key, + credentials.secret_key, + session.region_name, + "es", + session_token=credentials.token, + ) + + opensearch_endpoint = f"https://{connection_info.get('endpoint')}" + + return OpenSearchVectorSearch( + opensearch_url=opensearch_endpoint, + index_name=collection_id, + embedding_function=embeddings, + http_auth=auth, + timeout=300, + use_ssl=True, + verify_certs=True, + connection_class=RequestsHttpConnection, + ) diff --git a/lambda/repository/services/pgvector_repository_service.py b/lambda/repository/services/pgvector_repository_service.py new file mode 100644 index 000000000..f34515484 --- /dev/null +++ b/lambda/repository/services/pgvector_repository_service.py @@ -0,0 +1,136 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PGVector repository service implementation.""" + +import json +import logging +import os + +import boto3 +from langchain_community.vectorstores import PGVector +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore +from repository.embeddings import RagEmbeddings +from utilities.common_functions import get_lambda_role_name, retry_config +from utilities.rds_auth import generate_auth_token +from utilities.repository_types import RepositoryType + +from .vector_store_repository_service import VectorStoreRepositoryService + +logger = logging.getLogger(__name__) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) +secretsmanager_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) + + +class PGVectorRepositoryService(VectorStoreRepositoryService): + """Service for PGVector repository operations. + + Inherits common vector store behavior from VectorStoreRepositoryService. + Only implements PGVector-specific collection management and score normalization. + """ + + def _drop_collection_index(self, collection_id: str) -> None: + """Drop PGVector collection table.""" + try: + logger.info(f"Dropping PGVector collection for {collection_id}") + + embeddings = RagEmbeddings(model_name=collection_id) + vector_store = self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + # Drop the collection if supported + if hasattr(vector_store, "delete_collection"): + vector_store.delete_collection() + logger.info(f"Dropped PGVector collection: {collection_id}") + else: + logger.warning("Vector store does not support collection deletion") + + except Exception as e: + logger.error(f"Failed to drop PGVector collection: {e}", exc_info=True) + # Don't raise - continue with document deletion + + def _normalize_similarity_score(self, score: float) -> float: + """Convert PGVector cosine distance to similarity score. + + PGVector returns cosine distance (0-2 range, lower = more similar). + Convert to similarity (0-1 range, higher = more similar). + + Args: + score: Cosine distance from PGVector + + Returns: + Similarity score in 0-1 range + """ + return max(0.0, 1.0 - (score / 2.0)) + + def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) -> VectorStore: + """Get PGVector vector store client. + + Args: + collection_id: Collection identifier + embeddings: Embeddings adapter + + Returns: + PGVector client instance + + Raises: + ValueError: If repository is not registered or not a PGVector repository + """ + prefix = os.environ.get("REGISTERED_REPOSITORIES_PS_PREFIX") + parameter_name = f"{prefix}{self.repository_id}" + + try: + connection_info = ssm_client.get_parameter(Name=parameter_name) + connection_info = json.loads(connection_info["Parameter"]["Value"]) + except ssm_client.exceptions.ParameterNotFound: + logger.error( + f"Repository '{self.repository_id}' not found in SSM Parameter Store. " + f"Parameter: {parameter_name}. " + f"Ensure the repository is registered before use." + ) + raise ValueError( + f"Repository '{self.repository_id}' is not registered. " + f"Please register the repository before performing operations." + ) + + if not RepositoryType.is_type(connection_info, RepositoryType.PGVECTOR): + raise ValueError(f"Repository {self.repository_id} is not a PGVector repository") + + if "passwordSecretId" in connection_info: + # Provides backwards compatibility to non-IAM authenticated vector stores + secrets_response = secretsmanager_client.get_secret_value(SecretId=connection_info.get("passwordSecretId")) + user = connection_info.get("username") + password = json.loads(secrets_response.get("SecretString")).get("password") + else: + # Use IAM auth token to connect + user = get_lambda_role_name() + password = generate_auth_token(connection_info.get("dbHost"), connection_info.get("dbPort"), user) + + connection_string = PGVector.connection_string_from_db_params( + driver="psycopg2", + host=connection_info.get("dbHost"), + port=connection_info.get("dbPort"), + database=connection_info.get("dbName"), + user=user, + password=password, + ) + + return PGVector( + collection_name=collection_id, + connection_string=connection_string, + embedding_function=embeddings, + ) diff --git a/lambda/repository/services/repository_service.py b/lambda/repository/services/repository_service.py new file mode 100644 index 000000000..31482ee95 --- /dev/null +++ b/lambda/repository/services/repository_service.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base service interface for repository operations.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from models.domain_objects import IngestionJob, RagCollectionConfig, RagDocument + + +class RepositoryService(ABC): + """Abstract base class defining repository-specific operations. + + Each repository type (OpenSearch, PGVector, Bedrock KB) implements this + interface to provide type-specific behavior for document management. + """ + + def __init__(self, repository: Dict[str, Any]): + """Initialize service with repository configuration. + + Args: + repository: Repository configuration dictionary + """ + self.repository = repository + self.repository_id = repository.get("repositoryId") + + @abstractmethod + def supports_custom_collections(self) -> bool: + """Check if repository supports user-defined collections. + + Returns: + True if custom collections are supported, False otherwise + """ + pass + + @abstractmethod + def should_create_default_collection(self) -> bool: + """Check if a default/virtual collection should be created. + + Returns: + True if default collection should be created, False otherwise + """ + pass + + @abstractmethod + def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + """Extract collection ID from pipeline configuration. + + Args: + pipeline_config: Pipeline configuration dictionary + + Returns: + Collection ID to use for operations + """ + pass + + @abstractmethod + def ingest_document( + self, + job: IngestionJob, + texts: List[str], + metadatas: List[Dict[str, Any]], + ) -> RagDocument: + """Ingest a document into the repository. + + Args: + job: Ingestion job with document details + texts: List of text chunks + metadatas: List of metadata dictionaries for each chunk + + Returns: + RagDocument representing the ingested document + """ + pass + + @abstractmethod + def delete_document( + self, + document: RagDocument, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete a document from the repository. + + Args: + document: Document to delete + s3_client: S3 client for file operations + bedrock_agent_client: Bedrock agent client (for Bedrock KB only) + """ + pass + + @abstractmethod + def delete_collection( + self, + collection_id: str, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete an entire collection from the repository. + + Args: + collection_id: Collection to delete + s3_client: S3 client for file operations + bedrock_agent_client: Bedrock agent client (for Bedrock KB only) + """ + pass + + @abstractmethod + def retrieve_documents( + self, + query: str, + collection_id: str, + top_k: int, + model_name: str, + include_score: bool = False, + bedrock_agent_client: Optional[Any] = None, + ) -> List[Dict[str, Any]]: + """Retrieve documents matching a query. + + Args: + query: Search query + collection_id: Collection to search + top_k: Number of results to return + model_name: Embedding model name to use for query embedding + include_score: Whether to include similarity scores in results + bedrock_agent_client: Bedrock agent client (for Bedrock KB only) + + Returns: + List of matching documents with page_content and metadata + """ + pass + + @abstractmethod + def validate_document_source(self, s3_path: str) -> str: + """Validate and normalize document source path. + + Args: + s3_path: S3 path to validate + + Returns: + Normalized S3 path + + Raises: + ValueError: If path is invalid for this repository type + """ + pass + + @abstractmethod + def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Optional[Any]: + """Get vector store client for this repository. + + Args: + collection_id: Collection identifier + embeddings: Embeddings adapter + + Returns: + Vector store client, or None if not applicable (e.g., Bedrock KB) + """ + pass + + @abstractmethod + def create_default_collection(self) -> Optional[RagCollectionConfig]: + """Create a default collection for this repository. + + Returns: + Default collection configuration, or None if not applicable + """ + pass diff --git a/lambda/repository/services/repository_service_factory.py b/lambda/repository/services/repository_service_factory.py new file mode 100644 index 000000000..88f2b9d35 --- /dev/null +++ b/lambda/repository/services/repository_service_factory.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating repository service instances.""" + +from typing import Any, Dict, Type + +from utilities.repository_types import RepositoryType + +from .bedrock_kb_repository_service import BedrockKBRepositoryService +from .opensearch_repository_service import OpenSearchRepositoryService +from .pgvector_repository_service import PGVectorRepositoryService +from .repository_service import RepositoryService + + +class RepositoryServiceFactory: + """Factory for creating repository-specific service instances. + + Encapsulates repository-specific behavior, eliminating the need for + conditional logic throughout the codebase. + """ + + # Registry mapping repository types to service classes + _services: Dict[RepositoryType, Type[RepositoryService]] = { + RepositoryType.OPENSEARCH: OpenSearchRepositoryService, + RepositoryType.PGVECTOR: PGVectorRepositoryService, + RepositoryType.BEDROCK_KB: BedrockKBRepositoryService, + } + + @classmethod + def create_service(cls, repository: Dict[str, Any]) -> RepositoryService: + """Create appropriate service instance for repository type. + + Args: + repository: Repository configuration dictionary + + Returns: + Service instance for the repository type + + Raises: + ValueError: If repository type is not supported + """ + repo_type = RepositoryType.get_type(repository) + + service_class = cls._services.get(repo_type) + if not service_class: + raise ValueError( + f"Unsupported repository type: {repo_type}. " f"Supported types: {list(cls._services.keys())}" + ) + + return service_class(repository) + + @classmethod + def register_service(cls, repo_type: RepositoryType, service_class: Type[RepositoryService]) -> None: + """Register a new service class for a repository type. + + Allows extending the factory with new repository types without + modifying the factory code (Open/Closed Principle). + + Args: + repo_type: Repository type to register + service_class: Service class to use for this type + """ + cls._services[repo_type] = service_class + + @classmethod + def get_supported_types(cls) -> list[RepositoryType]: + """Get list of supported repository types. + + Returns: + List of registered repository types + """ + return list(cls._services.keys()) diff --git a/lambda/repository/services/vector_store_repository_service.py b/lambda/repository/services/vector_store_repository_service.py new file mode 100644 index 000000000..89910ea60 --- /dev/null +++ b/lambda/repository/services/vector_store_repository_service.py @@ -0,0 +1,333 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base implementation for vector store-based repository services (OpenSearch, PGVector). + +This class provides common functionality for repositories that use traditional +vector stores with chunking and embedding pipelines. +""" + +import logging +import os +from abc import abstractmethod +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import boto3 +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore +from models.domain_objects import ( + CollectionMetadata, + CollectionStatus, + IngestionJob, + RagCollectionConfig, + RagDocument, + VectorStoreStatus, +) +from repository.embeddings import RagEmbeddings +from repository.rag_document_repo import RagDocumentRepository +from utilities.common_functions import retry_config + +from .repository_service import RepositoryService + +logger = logging.getLogger(__name__) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + +class VectorStoreRepositoryService(RepositoryService): + """Base implementation for vector store-based repository services. + + Provides common functionality for OpenSearch and PGVector repositories + that share similar ingestion, deletion, and retrieval patterns. + + Subclasses only need to implement repository-specific operations like + index/collection dropping and score normalization. + """ + + def supports_custom_collections(self) -> bool: + """Vector stores support custom collections.""" + return True + + def should_create_default_collection(self) -> bool: + """Vector stores create virtual default collections.""" + return True + + def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + """Extract collection ID from pipeline config or use embedding model.""" + collection_id = pipeline_config.get("collectionId") + if not collection_id: + collection_id = pipeline_config.get("embeddingModel") + return collection_id + + def ingest_document( + self, + job: IngestionJob, + texts: List[str], + metadatas: List[Dict[str, Any]], + ) -> RagDocument: + """Ingest document into vector store with chunking and embedding.""" + # Store chunks in vector store + all_ids = self._store_chunks( + texts=texts, + metadatas=metadatas, + collection_id=job.collection_id, + embedding_model=job.embedding_model, + ) + + # Create document record + rag_document = RagDocument( + repository_id=job.repository_id, + collection_id=job.collection_id, + document_name=os.path.basename(job.s3_path), + source=job.s3_path, + subdocs=all_ids, + chunk_strategy=job.chunk_strategy, + username=job.username, + ingestion_type=job.ingestion_type, + ) + + rag_document_repository = RagDocumentRepository( + os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"] + ) + rag_document_repository.save(rag_document) + + logger.info( + f"Ingested document {job.s3_path} ({len(all_ids)} chunks) " + f"into {self.repository.get('type')} collection {job.collection_id}" + ) + return rag_document + + def delete_document( + self, + document: RagDocument, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete document from vector store.""" + embeddings = RagEmbeddings(model_name=document.collection_id) + vector_store = self._get_vector_store_client( + collection_id=document.collection_id, + embeddings=embeddings, + ) + vector_store.delete(document.subdocs) + + def delete_collection( + self, + collection_id: str, + s3_client: Any, + bedrock_agent_client: Optional[Any] = None, + ) -> None: + """Delete collection from vector store. + + Delegates to subclass-specific implementation for dropping + indexes/collections. + """ + self._drop_collection_index(collection_id) + + def retrieve_documents( + self, + query: str, + collection_id: str, + top_k: int, + model_name: str, + include_score: bool = False, + bedrock_agent_client: Optional[Any] = None, + ) -> List[Dict[str, Any]]: + """Retrieve documents from vector store using similarity search. + + Args: + query: Search query + collection_id: Collection to search + top_k: Number of results to return + model_name: Embedding model name to use for query embedding + include_score: Whether to include similarity scores in metadata + bedrock_agent_client: Not used for vector stores + + Returns: + List of documents with page_content and metadata + """ + embeddings = RagEmbeddings(model_name=model_name) + vector_store = self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + results = vector_store.similarity_search_with_score(query, k=top_k) + + documents = [] + for i, (doc, score) in enumerate(results): + doc_dict = { + "page_content": doc.page_content, + "metadata": doc.metadata.copy() if doc.metadata else {}, + } + + if include_score: + # Normalize score based on repository type + normalized_score = self._normalize_similarity_score(score) + doc_dict["metadata"]["similarity_score"] = normalized_score + + logger.info( + f"Result {i + 1}: Raw Score={score:.4f}, Similarity={normalized_score:.4f}, " + f"Content: {doc.page_content[:200]}..." + ) + logger.info(f"Result {i + 1} metadata: {doc.metadata}") + + documents.append(doc_dict) + + # Warn if all scores are low (possible embedding model mismatch) + if include_score and results: + max_score = max(self._normalize_similarity_score(score) for _, score in results) + if max_score < 0.3: + logger.warning( + f"All similarity scores < 0.3 for query '{query}' - " "possible embedding model mismatch" + ) + + return documents + + def validate_document_source(self, s3_path: str) -> str: + """Vector stores accept any valid S3 path.""" + if not s3_path.startswith("s3://"): + raise ValueError(f"Invalid S3 path: {s3_path}") + return s3_path + + def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Any: + """Get vector store client for this repository.""" + return self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + # Protected methods for subclass customization + + @abstractmethod + def _drop_collection_index(self, collection_id: str) -> None: + """Drop collection index/table (repository-specific). + + Args: + collection_id: Collection to drop + """ + pass + + def _normalize_similarity_score(self, score: float) -> float: + """Normalize similarity score to 0-1 range. + + Default implementation returns score as-is (for OpenSearch). + Subclasses can override for different scoring systems (e.g., PGVector). + + Args: + score: Raw similarity score from vector store + + Returns: + Normalized score in 0-1 range + """ + return score + + def create_default_collection(self) -> Optional[RagCollectionConfig]: + """Create a default collection for vector store repositories. + + Returns: + Default collection configuration using repository's embedding model + """ + try: + # Check if repository is active + active = self.repository.get("status", VectorStoreStatus.UNKNOWN) in [ + VectorStoreStatus.CREATE_COMPLETE, + VectorStoreStatus.UPDATE_COMPLETE, + VectorStoreStatus.UPDATE_COMPLETE_CLEANUP_IN_PROGRESS, + VectorStoreStatus.UPDATE_IN_PROGRESS, + ] + + if not active: + logger.info(f"Repository {self.repository_id} is not active") + return None + + embedding_model = self.repository.get("embeddingModelId") + if not embedding_model: + logger.info(f"Repository {self.repository_id} has no default embedding model") + return None + + # Use embedding model as collection ID + collection_id = embedding_model + sanitized_name = f"{self.repository.get('name', self.repository_id)}-{embedding_model}".replace(".", "-") + + default_collection = RagCollectionConfig( + collectionId=collection_id, + repositoryId=self.repository_id, + name=sanitized_name, + description="Default collection using repository's embedding model", + embeddingModel=embedding_model, + chunkingStrategy=self.repository.get("chunkingStrategy"), + allowedGroups=self.repository.get("allowedGroups", []), + createdBy=self.repository.get("createdBy", "system"), + status=CollectionStatus.ACTIVE, + metadata=CollectionMetadata(tags=["default"], customFields={}), + allowChunkingOverride=True, + pipelines=self.repository.get("pipelines", []), + default=True, + createdAt=datetime.now(timezone.utc), + updatedAt=datetime.now(timezone.utc), + ) + + logger.info(f"Created virtual default collection for repository {self.repository_id}") + return default_collection + + except Exception as e: + logger.error(f"Failed to create default collection for repository {self.repository_id}: {e}") + return None + + def _store_chunks( + self, + texts: List[str], + metadatas: List[Dict[str, Any]], + collection_id: str, + embedding_model: str, + ) -> List[str]: + """Store document chunks in vector store.""" + embeddings = RagEmbeddings(model_name=embedding_model) + vector_store = self._get_vector_store_client( + collection_id=collection_id, + embeddings=embeddings, + ) + + all_ids = [] + batch_size = 500 + + for i in range(0, len(texts), batch_size): + text_batch = texts[i : i + batch_size] + metadata_batch = metadatas[i : i + batch_size] + + batch_ids = vector_store.add_texts(texts=text_batch, metadatas=metadata_batch) + + if not batch_ids: + raise Exception(f"Failed to store batch {i // batch_size + 1}") + + all_ids.extend(batch_ids) + + if not all_ids: + raise Exception("Failed to store any documents in vector store") + + return all_ids + + @abstractmethod + def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) -> VectorStore: + """Get vector store client for this repository type. + + Args: + collection_id: Collection identifier + embeddings: Embeddings adapter + + Returns: + Vector store client instance + """ + pass diff --git a/lambda/repository/state_machine/cleanup_repo_docs.py b/lambda/repository/state_machine/cleanup_repo_docs.py index 23f55ac8b..9e1924137 100644 --- a/lambda/repository/state_machine/cleanup_repo_docs.py +++ b/lambda/repository/state_machine/cleanup_repo_docs.py @@ -16,6 +16,7 @@ import os from typing import Any, Dict +from models.domain_objects import IngestionType from pydantic import BaseModel from repository.rag_document_repo import RagDocumentRepository @@ -25,7 +26,11 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: """ - Remove documents associated with a repository + Remove LISA-managed documents from repository. + + Only deletes documents with ingestion_type of MANUAL or AUTO. + Preserves EXISTING documents (user-managed). + Args: event: Event data containing bucket and prefix information context: Lambda context @@ -37,14 +42,27 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: stack_name = event.get("stackName") last_evaluated = event.get("lastEvaluated") - docs, last_evaluated = doc_repo.list_all(repository_id=repository_id, last_evaluated_key=last_evaluated) - for doc in docs: + # Get all documents + docs, last_evaluated, _ = doc_repo.list_all(repository_id=repository_id, last_evaluated_key=last_evaluated) + + # Filter to LISA-managed only (MANUAL or AUTO) + lisa_managed = [d for d in docs if d.ingestion_type in [IngestionType.MANUAL, IngestionType.AUTO]] + + logger.info( + f"Repository cleanup: total={len(docs)}, " + f"lisa_managed={len(lisa_managed)}, " + f"preserved={len(docs) - len(lisa_managed)}" + ) + + # Delete from DynamoDB + for doc in lisa_managed: doc_repo.delete_by_id(doc.document_id) - doc_repo.delete_s3_docs(repository_id=repository_id, docs=docs) + # Delete from S3 (only LISA-managed) + doc_repo.delete_s3_docs(repository_id=repository_id, docs=lisa_managed) # Ensure JSON-serializable payload for Step Functions when Pydantic models are provided - serializable_docs = [doc.model_dump() if isinstance(doc, BaseModel) else doc for doc in docs] + serializable_docs = [doc.model_dump() if isinstance(doc, BaseModel) else doc for doc in lisa_managed] return { "repositoryId": repository_id, "stackName": stack_name, diff --git a/lambda/repository/state_machine/wait_for_collection_deletions.py b/lambda/repository/state_machine/wait_for_collection_deletions.py new file mode 100644 index 000000000..00f08d0e4 --- /dev/null +++ b/lambda/repository/state_machine/wait_for_collection_deletions.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wait for all collection deletion jobs to complete before deleting repository.""" + +import logging +from typing import Any, Dict + +from repository.ingestion_job_repo import IngestionJobRepository + +logger = logging.getLogger(__name__) + + +def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Check if all collection deletion jobs for a repository are complete. + + Args: + event: Event data containing repositoryId and stackName + context: Lambda context + + Returns: + Dictionary with completion status and job counts + """ + repository_id = event.get("repositoryId") + stack_name = event.get("stackName") + + logger.info(f"Checking collection deletion jobs for repository {repository_id}") + + job_repo = IngestionJobRepository() + + # Query all jobs for this repository + pending_jobs = job_repo.find_pending_collection_deletions(repository_id) + + pending_count = len(pending_jobs) + all_complete = pending_count == 0 + + logger.info( + f"Repository {repository_id}: " f"pending_collection_deletions={pending_count}, " f"all_complete={all_complete}" + ) + + return { + "repositoryId": repository_id, + "stackName": stack_name, + "allCollectionDeletionsComplete": all_complete, + "pendingDeletionCount": pending_count, + } diff --git a/lambda/repository/vector_store_repo.py b/lambda/repository/vector_store_repo.py index f06195659..56ba11564 100644 --- a/lambda/repository/vector_store_repo.py +++ b/lambda/repository/vector_store_repo.py @@ -13,9 +13,11 @@ # limitations under the License. import logging import os +import time from typing import Any, cast, List import boto3 +from models.domain_objects import VectorStoreStatus from utilities.common_functions import retry_config from utilities.encoders import convert_decimal @@ -30,7 +32,7 @@ def __init__(self) -> None: self.table = dynamodb.Table(os.environ["LISA_RAG_VECTOR_STORE_TABLE"]) def get_registered_repositories(self) -> List[dict]: - """Get a list of all registered RAG repositories.""" + """Get a list of all registered RAG repositories with default values for new fields.""" response = self.table.scan() items = response["Items"] while "LastEvaluatedKey" in response: @@ -41,9 +43,17 @@ def get_registered_repositories(self) -> List[dict]: registered_repositories = [] for item in items: + config = item.get("config", {}) + config["status"] = item.get("status", VectorStoreStatus.UNKNOWN) if item.get("legacy", False): - item["config"]["legacy"] = True - registered_repositories.append(item["config"]) + config["legacy"] = True + + if "metadata" not in config: + config["metadata"] = {"tags": []} + elif isinstance(config["metadata"], dict) and "tags" not in config["metadata"]: + config["metadata"]["tags"] = [] + + registered_repositories.append(config) return registered_repositories @@ -74,7 +84,7 @@ def find_repository_by_id(self, repository_id: str, raw_config: bool = False) -> repository_id: The ID of the repository to find. raw_config: return the full object in dynamo, instead of just the repository config portion Returns: - The repository configuration. + The repository configuration with default values for new fields. Raises: ValueError: If the repository is not found or the table does not exist. @@ -90,7 +100,64 @@ def find_repository_by_id(self, repository_id: str, raw_config: bool = False) -> raise ValueError(f"Repository with ID '{repository_id}' not found") repository: dict[str, Any] = convert_decimal(response.get("Item")) - return repository if raw_config else cast(dict[str, Any], repository.get("config", {})) + + if raw_config: + return repository + + # Get config and apply defaults for backward compatibility + config = cast(dict[str, Any], repository.get("config", {})) + config["status"] = repository.get("status") + + if "metadata" not in config: + config["metadata"] = {"tags": []} + elif isinstance(config["metadata"], dict) and "tags" not in config["metadata"]: + config["metadata"]["tags"] = [] + + return config + + def update(self, repository_id: str, updates: dict[str, Any], status: str | None = None) -> dict[str, Any]: + """ + Update a repository configuration. + + Args: + repository_id: The ID of the repository to update. + updates: Dictionary of fields to update in the config. + status: Optional status to set (if None, status is not updated). + + Returns: + The updated repository configuration. + + Raises: + ValueError: If the update fails or repository not found. + """ + try: + current = self.table.get_item(Key={"repositoryId": repository_id}) + if "Item" not in current: + raise ValueError(f"Repository with ID '{repository_id}' not found") + + # Keep original config with Decimal types intact + config: dict[str, Any] = current["Item"].get("config", {}) + config.update(updates) + + update_expr = "SET #config = :config, #updatedAt = :updatedAt" + expr_names = {"#config": "config", "#updatedAt": "updatedAt"} + expr_values = {":config": config, ":updatedAt": int(time.time() * 1000)} + + if status is not None: + update_expr += ", #status = :status" + expr_names["#status"] = "status" + expr_values[":status"] = status + + self.table.update_item( + Key={"repositoryId": repository_id}, + UpdateExpression=update_expr, + ExpressionAttributeNames=expr_names, + ExpressionAttributeValues=expr_values, + ) + + return config + except Exception as e: + raise ValueError(f"Failed to update repository: {repository_id}", e) def delete(self, repository_id: str) -> bool: """ diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index 0971e00f1..72987a97d 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -27,8 +27,8 @@ import create_env_variables # noqa: F401 from botocore.exceptions import ClientError from cachetools import cached, TTLCache -from utilities.auth import get_username -from utilities.common_functions import api_wrapper, get_groups, get_session_id, retry_config +from utilities.auth import get_user_context, get_username +from utilities.common_functions import api_wrapper, get_session_id, retry_config from utilities.encoders import convert_decimal from utilities.session_encryption import decrypt_session_fields, migrate_session_to_encrypted, SessionEncryptionError @@ -469,7 +469,7 @@ def rename_session(event: dict, context: dict) -> dict: def put_session(event: dict, context: dict) -> dict: """Append the message to the record in DynamoDB.""" try: - user_id = get_username(event) + user_id, _, groups = get_user_context(event) session_id = get_session_id(event) try: @@ -578,7 +578,7 @@ def put_session(event: dict, context: dict) -> dict: "userId": user_id, "sessionId": session_id, "messages": messages, - "userGroups": get_groups(event), + "userGroups": groups, "timestamp": datetime.now().isoformat(), } sqs_client.send_message( diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index 239ec6854..67c051df1 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os from functools import wraps -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, List, Tuple import boto3 from botocore.config import Config -from utilities.common_functions import get_groups from utilities.exceptions import HTTPException logger = logging.getLogger(__name__) @@ -40,6 +40,12 @@ def get_username(event: dict) -> str: return username +def get_groups(event: Any) -> List[str]: + """Get user groups from event.""" + groups: List[str] = json.loads(event.get("requestContext", {}).get("authorizer", {}).get("groups", "[]")) + return groups + + def is_admin(event: dict) -> bool: """Get admin status from event.""" admin_group = os.environ.get("ADMIN_GROUP", "") @@ -48,9 +54,28 @@ def is_admin(event: dict) -> bool: return admin_group in groups -def get_user_context(event: Dict[str, Any]) -> Tuple[str, bool]: +def get_user_context(event: Dict[str, Any]) -> Tuple[str, bool, List[str]]: """Extract user context from event.""" - return get_username(event), is_admin(event) + return get_username(event), is_admin(event), get_groups(event) + + +def user_has_group_access(user_groups: List[str], allowed_groups: List[str]) -> bool: + """ + Check if user has access based on group membership. + + Args: + user_groups: List of groups the user belongs to + allowed_groups: List of groups allowed to access the resource + + Returns: + True if user has access (either no restrictions or user has required group) + """ + # Public resource (no group restrictions) + if not allowed_groups: + return True + + # Check if user has at least one matching group + return len(set(user_groups).intersection(set(allowed_groups))) > 0 def admin_only(func: Callable) -> Callable: diff --git a/lambda/utilities/bedrock_kb.py b/lambda/utilities/bedrock_kb.py index d80e2e4dd..7b122fbc7 100644 --- a/lambda/utilities/bedrock_kb.py +++ b/lambda/utilities/bedrock_kb.py @@ -20,90 +20,670 @@ from __future__ import annotations +import logging import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple +from models.domain_objects import ( + IngestionJob, + IngestionType, + JobActionType, + NoneChunkingStrategy, + RagCollectionConfig, + RagDocument, +) -def retrieve_documents( - bedrock_runtime_client: Any, +logger = logging.getLogger(__name__) + + +class S3DocumentDiscoveryResult: + """Result of S3 document discovery operation.""" + + def __init__( + self, + discovered: int = 0, + skipped: int = 0, + successful: int = 0, + failed: int = 0, + document_ids: Optional[List[str]] = None, + errors: Optional[List[str]] = None, + ): + self.discovered = discovered + self.skipped = skipped + self.successful = successful + self.failed = failed + self.document_ids = document_ids or [] + self.errors = errors or [] + + +class S3DocumentDiscoveryService: + """Service for discovering and tracking existing documents in S3 buckets.""" + + def __init__( + self, + s3_client: Any, + bedrock_agent_client: Any, + rag_document_repository: Any, + metadata_generator: Any, + s3_metadata_manager: Any, + collection_service: Any, + vector_store_repo: Any, + ): + """Initialize S3 document discovery service. + + Args: + s3_client: boto3 S3 client + bedrock_agent_client: boto3 bedrock-agent client + rag_document_repository: Repository for RagDocument persistence + metadata_generator: MetadataGenerator instance + s3_metadata_manager: S3MetadataManager instance + collection_service: CollectionService instance + vector_store_repo: VectorStoreRepository instance + """ + self.s3_client = s3_client + self.bedrock_agent_client = bedrock_agent_client + self.rag_document_repository = rag_document_repository + self.metadata_generator = metadata_generator + self.s3_metadata_manager = s3_metadata_manager + self.collection_service = collection_service + self.vector_store_repo = vector_store_repo + + def discover_and_ingest_documents( + self, + repository_id: str, + collection_id: str, + s3_bucket: str, + s3_prefix: str = "", + ingestion_type: IngestionType = IngestionType.EXISTING, + ) -> S3DocumentDiscoveryResult: + """ + Discover and ingest existing documents from S3 bucket. + + Scans S3 bucket, creates metadata.json files, creates RagDocument entries, + and triggers Bedrock KB sync. + + Args: + repository_id: Repository identifier + collection_id: Collection identifier + s3_bucket: S3 bucket to scan + s3_prefix: Optional S3 prefix to scan within bucket + ingestion_type: Type of ingestion (default: EXISTING) + + Returns: + S3DocumentDiscoveryResult with operation statistics + """ + logger.info(f"Starting S3 document discovery for bucket {s3_bucket} with prefix '{s3_prefix}'") + + result = S3DocumentDiscoveryResult() + + try: + # Get repository configuration + repository = self.vector_store_repo.find_repository_by_id(repository_id) + + # Get collection for metadata generation + collection = self._get_collection(repository_id, collection_id) + + # Scan S3 bucket for documents + documents_to_process, skipped_count = self._scan_s3_bucket(s3_bucket, s3_prefix) + result.discovered = len(documents_to_process) + result.skipped = skipped_count + + if not documents_to_process: + logger.info(f"No valid documents found in S3 bucket {s3_bucket} with prefix '{s3_prefix}'") + return result + + logger.info(f"Found {len(documents_to_process)} documents to process") + + # Process each document + for document_key in documents_to_process: + try: + s3_path = f"s3://{s3_bucket}/{document_key}" + + # Check if document already exists (idempotent) + if self._document_exists(repository_id, collection_id, s3_path): + existing_doc = next( + self.rag_document_repository.find_by_source( + repository_id, collection_id, s3_path, join_docs=False + ) + ) + result.document_ids.append(existing_doc.document_id) + result.successful += 1 + logger.info(f"Document {s3_path} already tracked, skipping") + continue + + # Create metadata.json file + self._create_metadata_file( + repository=repository, + collection=collection, + s3_bucket=s3_bucket, + document_key=document_key, + repository_id=repository_id, + collection_id=collection_id, + ) + + # Create RagDocument entry + document_id = self._create_rag_document( + repository_id=repository_id, + collection_id=collection_id, + s3_path=s3_path, + ingestion_type=ingestion_type, + ) + + result.document_ids.append(document_id) + result.successful += 1 + logger.info(f"Tracked existing document {s3_path}") + + except Exception as e: + result.failed += 1 + error_msg = f"Failed to process {document_key}: {str(e)}" + logger.error(error_msg, exc_info=True) + result.errors.append(error_msg) + + # Trigger Bedrock KB sync if documents were processed + if result.successful > 0: + self._trigger_kb_sync(repository, collection_id, result.successful) + + logger.info( + f"S3 discovery completed: {result.successful} successful, " + f"{result.failed} failed, {result.skipped} skipped" + ) + + return result + + except Exception as e: + logger.error(f"Failed to discover S3 documents: {str(e)}", exc_info=True) + raise + + def _scan_s3_bucket(self, s3_bucket: str, s3_prefix: str) -> Tuple[List[str], int]: + """ + Scan S3 bucket and return list of document keys. + + Args: + s3_bucket: S3 bucket name + s3_prefix: S3 prefix to scan + + Returns: + Tuple of (document_keys, skipped_count) + """ + list_params = {"Bucket": s3_bucket, "Delimiter": "/"} + if s3_prefix: + list_params["Prefix"] = s3_prefix if s3_prefix.endswith("/") else f"{s3_prefix}/" + + response = self.s3_client.list_objects_v2(**list_params) + + if "Contents" not in response: + return [], 0 + + documents_to_process = [] + skipped_count = 0 + + for obj in response["Contents"]: + key = obj["Key"] + # Skip metadata files and directories + if key.endswith("/") or key.endswith(".metadata.json"): + skipped_count += 1 + continue + documents_to_process.append(key) + + return documents_to_process, skipped_count + + def _get_collection(self, repository_id: str, collection_id: str) -> Optional[RagCollectionConfig]: + """Get collection configuration.""" + try: + return self.collection_service.get_collection( + collection_id=collection_id, + repository_id=repository_id, + username="system", + user_groups=[], + is_admin=True, + ) + except Exception as e: + logger.warning(f"Could not fetch collection for metadata: {e}") + return None + + def _document_exists(self, repository_id: str, collection_id: str, s3_path: str) -> bool: + """Check if document already exists in repository.""" + existing_docs = list( + self.rag_document_repository.find_by_source(repository_id, collection_id, s3_path, join_docs=False) + ) + return len(existing_docs) > 0 + + def _create_metadata_file( + self, + repository: Dict[str, Any], + collection: Optional[RagCollectionConfig], + s3_bucket: str, + document_key: str, + repository_id: str, + collection_id: str, + ) -> None: + """Create and upload metadata.json file for document.""" + try: + metadata_content = self.metadata_generator.generate_metadata_json( + repository=repository, + collection=collection, + document_metadata=None, # No document-specific metadata for existing docs + ) + + self.s3_metadata_manager.upload_metadata_file( + s3_client=self.s3_client, + bucket=s3_bucket, + document_key=document_key, + metadata_content=metadata_content, + repository_id=repository_id, + collection_id=collection_id, + ) + logger.info(f"Created metadata file for s3://{s3_bucket}/{document_key}") + except Exception as e: + logger.error(f"Failed to create metadata for {document_key}: {str(e)}") + # Continue - metadata is optional + + def _create_rag_document( + self, + repository_id: str, + collection_id: str, + s3_path: str, + ingestion_type: IngestionType, + ) -> str: + """Create and save RagDocument entry.""" + rag_document = RagDocument( + repository_id=repository_id, + collection_id=collection_id, + document_name=os.path.basename(s3_path), + source=s3_path, + subdocs=[], # Empty - KB manages chunks internally + chunk_strategy=NoneChunkingStrategy(), # KB manages chunking + username="system", # System-discovered + ingestion_type=ingestion_type, + ) + self.rag_document_repository.save(rag_document) + return rag_document.document_id + + def _trigger_kb_sync(self, repository: Dict[str, Any], collection_id: str, document_count: int) -> None: + """Trigger Bedrock KB sync for ingested documents.""" + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + knowledge_base_id = bedrock_config.get("knowledgeBaseId", bedrock_config.get("bedrockKnowledgeBaseId")) + + if not knowledge_base_id: + logger.warning("No knowledge base ID found, skipping KB sync") + return + + logger.info(f"Triggering Bedrock KB sync for collection {collection_id}") + try: + self.bedrock_agent_client.start_ingestion_job( + knowledgeBaseId=knowledge_base_id, + dataSourceId=collection_id, + ) + logger.info(f"Successfully triggered KB sync for {document_count} documents") + except Exception as e: + logger.error(f"Failed to trigger KB sync: {str(e)}") + # Don't fail - documents are already tracked + + +logger = logging.getLogger(__name__) + + +def get_datasource_bucket_for_collection( repository: Dict[str, Any], - query: str, - top_k: int, - repository_id: str, -) -> List[Dict[str, Any]]: - """Retrieve documents from Bedrock Knowledge Base. + collection_id: str, +) -> str: + """ + Get the S3 bucket for a specific collection/data source. + + Supports multiple configuration formats: + - Legacy: bedrockKnowledgeDatasourceS3Bucket (single bucket) + - New: dataSources array with id and s3Uri per data source + - Pipeline: pipelines array with collectionId and s3Bucket Args: - bedrock_runtime_client: boto3 bedrock-agent-runtime client repository: Repository configuration dictionary - query: Text query to search - top_k: Number of results to return - repository_id: Repository identifier to include in metadata + collection_id: Collection/data source ID Returns: - List of documents in the format expected by callers + S3 bucket name + + Raises: + ValueError: If bucket cannot be determined """ bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + repository_id = repository.get("repositoryId", "unknown") - response = bedrock_runtime_client.retrieve( - knowledgeBaseId=bedrock_config.get("bedrockKnowledgeBaseId", None), - retrievalQuery={"text": query}, - retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": int(top_k)}}, - ) + # Try legacy format first + legacy_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + if legacy_bucket: + return legacy_bucket - docs: List[Dict[str, Any]] = [] - for doc in response.get("retrievalResults", []): - uri = (doc.get("location", {}) or {}).get("s3Location", {}).get("uri") - name = uri.split("/")[-1] if uri else None - docs.append( - { - "page_content": (doc.get("content", {}) or {}).get("text", ""), - "metadata": { - "source": uri, - "name": name, - "repository_id": repository_id, - }, - } - ) + # Try pipelines array (most common in current configs) + pipelines = repository.get("pipelines", []) + for pipeline in pipelines: + # Handle both dict and object formats + pipeline_collection_id = pipeline.get("collectionId") if isinstance(pipeline, dict) else pipeline.collectionId + s3_bucket = pipeline.get("s3Bucket") if isinstance(pipeline, dict) else pipeline.s3Bucket + + if pipeline_collection_id == collection_id and s3_bucket: + return s3_bucket - return docs + # Try dataSources array + data_sources = bedrock_config.get("dataSources", []) + for data_source in data_sources: + # Handle both dict and object formats + ds_id = data_source.get("id") if isinstance(data_source, dict) else data_source.id + s3_uri = data_source.get("s3Uri") if isinstance(data_source, dict) else data_source.s3Uri + + if ds_id == collection_id: + # Extract bucket from s3Uri (format: s3://bucket/ or s3://bucket/prefix) + if s3_uri and s3_uri.startswith("s3://"): + bucket = s3_uri[5:].split("/")[0] + if bucket: + return bucket + + logger.error(f"Invalid s3Uri format for data source {ds_id}: {s3_uri}") + raise ValueError( + f"Data source {ds_id} has invalid s3Uri format: {s3_uri}. " + "Expected format: s3://bucket-name/ or s3://bucket-name/prefix" + ) + + # No matching configuration found + available_pipelines = [p.get("collectionId") if isinstance(p, dict) else p.collectionId for p in pipelines] + logger.error( + f"Repository {repository_id} missing S3 bucket configuration. " + f"Collection ID: {collection_id}, Available pipelines: {available_pipelines}, " + f"Available data sources: {[ds.get('id') if isinstance(ds, dict) else ds.id for ds in data_sources]}" + ) + raise ValueError( + f"Cannot determine S3 bucket for collection {collection_id}. " + "Repository configuration must include either:\n" + "- 'bedrockKnowledgeDatasourceS3Bucket' (legacy single bucket)\n" + f"- A pipeline with collectionId='{collection_id}' and s3Bucket field\n" + f"- A data source in 'dataSources' array with id='{collection_id}' and s3Uri" + ) def ingest_document_to_kb( s3_client: Any, bedrock_agent_client: Any, - job: Any, + job: IngestionJob, repository: Dict[str, Any], ) -> None: - """Copy the source object into the KB datasource bucket and trigger ingestion.""" + """ + Copy the source object into the KB datasource bucket and trigger ingestion. S3 will + kick off another IngestionJob to store the document in the collection DB + """ bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + # Get datasource bucket for this collection (supports multiple config formats) + datasource_bucket = get_datasource_bucket_for_collection( + repository=repository, + collection_id=job.collection_id, + ) + source_bucket = job.s3_path.split("/")[2] + source_key = job.s3_path.split(source_bucket + "/")[1] + s3_client.copy_object( - CopySource={"Bucket": source_bucket, "Key": job.s3_path.split(source_bucket + "/")[1]}, - Bucket=bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket", None), + CopySource={"Bucket": source_bucket, "Key": source_key}, + Bucket=datasource_bucket, Key=os.path.basename(job.s3_path), ) + + s3_client.delete_object(Bucket=source_bucket, Key=source_key) + + # Use collection_id from job as data source ID + data_source_id = job.collection_id + # Support both field names for backward compatibility + kb_id = bedrock_config.get("bedrockKnowledgeBaseId") or bedrock_config.get("knowledgeBaseId") + + if not kb_id: + logger.error(f"Repository {repository.get('repositoryId')} missing knowledge base ID") + raise ValueError( + "Bedrock KB repository is missing required field 'bedrockKnowledgeBaseId' or 'knowledgeBaseId'. " + "Please update the repository configuration with the actual AWS Bedrock Knowledge Base ID." + ) + bedrock_agent_client.start_ingestion_job( - knowledgeBaseId=bedrock_config.get("bedrockKnowledgeBaseId", None), - dataSourceId=bedrock_config.get("bedrockKnowledgeDatasourceId", None), + knowledgeBaseId=kb_id, + dataSourceId=data_source_id, ) def delete_document_from_kb( s3_client: Any, bedrock_agent_client: Any, - job: Any, + job: IngestionJob, repository: Dict[str, Any], ) -> None: """Remove the source object from the KB datasource bucket and re-sync the KB.""" bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + # Get datasource bucket for this collection (supports multiple config formats) + datasource_bucket = get_datasource_bucket_for_collection( + repository=repository, + collection_id=job.collection_id, + ) + s3_client.delete_object( - Bucket=bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket", None), + Bucket=datasource_bucket, Key=os.path.basename(job.s3_path), ) + + # Use collection_id from job as data source ID + data_source_id = job.collection_id + # Support both field names for backward compatibility + kb_id = bedrock_config.get("bedrockKnowledgeBaseId") or bedrock_config.get("knowledgeBaseId") + + if not kb_id: + logger.error(f"Repository {repository.get('repositoryId')} missing knowledge base ID") + raise ValueError( + "Bedrock KB repository is missing required field 'bedrockKnowledgeBaseId' or 'knowledgeBaseId'. " + "Please update the repository configuration with the actual AWS Bedrock Knowledge Base ID." + ) + + bedrock_agent_client.start_ingestion_job( + knowledgeBaseId=kb_id, + dataSourceId=data_source_id, + ) + + +def bulk_delete_documents_from_kb( + s3_client: Any, + bedrock_agent_client: Any, + repository: Dict[str, Any], + s3_paths: List[str], + data_source_id: Optional[str] = None, +) -> None: + """Bulk delete documents from KB datasource bucket and trigger single ingestion. + + Args: + s3_client: boto3 S3 client + bedrock_agent_client: boto3 bedrock-agent client + repository: Repository configuration dictionary + s3_paths: List of S3 paths to delete + data_source_id: Optional data source ID. If not provided, will try to get from config. + """ + bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) + datasource_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + + # Batch delete from S3 (max 1000 per request) + batch_size = 1000 + for i in range(0, len(s3_paths), batch_size): + batch = s3_paths[i : i + batch_size] + delete_objects = [{"Key": os.path.basename(path)} for path in batch] + + if delete_objects: + s3_client.delete_objects(Bucket=datasource_bucket, Delete={"Objects": delete_objects}) + + # Determine data source ID + if not data_source_id: + # Try new structure with dataSources array + data_sources = bedrock_config.get("dataSources", []) + if data_sources: + first_data_source = data_sources[0] + data_source_id = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + else: + # Try legacy single data source ID + data_source_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + + # Trigger single ingestion job to sync KB + # Support both field names for backward compatibility + kb_id = bedrock_config.get("bedrockKnowledgeBaseId") or bedrock_config.get("knowledgeBaseId") + + if not kb_id: + logger.error(f"Repository {repository.get('repositoryId')} missing knowledge base ID") + raise ValueError( + "Bedrock KB repository is missing required field 'bedrockKnowledgeBaseId' or 'knowledgeBaseId'. " + "Please update the repository configuration with the actual AWS Bedrock Knowledge Base ID." + ) + bedrock_agent_client.start_ingestion_job( - knowledgeBaseId=bedrock_config.get("bedrockKnowledgeBaseId", None), - dataSourceId=bedrock_config.get("bedrockKnowledgeDatasourceId", None), + knowledgeBaseId=kb_id, + dataSourceId=data_source_id, + ) + + +def ingest_bedrock_s3_documents( + s3_client: Any, + ingestion_job_repository: Any, + ingestion_service: Any, + repository_id: str, + collection_id: str, + s3_bucket: str, + embedding_model: str, + s3_prefix: str = "", + batch_size: int = 100, +) -> Tuple[int, int]: + """ + Discover and create ingestion jobs for existing documents in S3 bucket. + + Scans S3 bucket for documents and creates batch ingestion jobs. + Skips metadata files and directories. + + Args: + s3_client: boto3 S3 client + ingestion_job_repository: Repository for saving ingestion jobs + ingestion_service: Service for submitting jobs + repository_id: Repository identifier + collection_id: Collection identifier + s3_bucket: S3 bucket to scan + embedding_model: Embedding model identifier + s3_prefix: Optional S3 prefix to scan within bucket + batch_size: Number of documents per batch job (default: 100) + + Returns: + Tuple of (discovered_count, skipped_count) + """ + logger.info(f"Discovering documents in S3 bucket {s3_bucket} with prefix '{s3_prefix}'") + + try: + # List objects in S3 bucket + list_params = {"Bucket": s3_bucket, "Delimiter": "/"} + if s3_prefix: + list_params["Prefix"] = s3_prefix if s3_prefix.endswith("/") else f"{s3_prefix}/" + + response = s3_client.list_objects_v2(**list_params) + + if "Contents" not in response: + logger.info(f"No objects found in bucket {s3_bucket}") + return 0, 0 + + # Filter valid documents + documents_to_process = [] + skipped_count = 0 + + for obj in response["Contents"]: + key = obj["Key"] + # Skip metadata files and directories + if key.endswith("/") or key.endswith(".metadata.json"): + skipped_count += 1 + continue + documents_to_process.append(key) + + discovered_count = len(documents_to_process) + + if not documents_to_process: + logger.info(f"No valid documents found in bucket {s3_bucket}") + return discovered_count, skipped_count + + logger.info(f"Found {discovered_count} documents to process, {skipped_count} skipped") + + # Create batch jobs + for i in range(0, len(documents_to_process), batch_size): + batch = documents_to_process[i : i + batch_size] + s3_paths = [f"s3://{s3_bucket}/{key}" for key in batch] + + job = IngestionJob( + repository_id=repository_id, + collection_id=collection_id, + embedding_model=embedding_model, + chunk_strategy=NoneChunkingStrategy(), + s3_path=s3_paths[0] if s3_paths else "", # First path as primary + username="system", # System-initiated + ingestion_type=IngestionType.EXISTING, # Mark as pre-existing documents + job_type=JobActionType.DOCUMENT_BATCH_INGESTION, + s3_paths=s3_paths, + ) + + ingestion_job_repository.save(job) + ingestion_service.submit_create_job(job) + + logger.info(f"Created {(len(documents_to_process) + batch_size - 1) // batch_size} batch jobs") + return discovered_count, skipped_count + + except Exception as e: + logger.error(f"Failed to discover S3 documents: {str(e)}", exc_info=True) + return 0, 0 + + +def create_s3_scan_job( + ingestion_job_repository: Any, + ingestion_service: Any, + repository_id: str, + collection_id: str, + embedding_model: str, + s3_bucket: str, + s3_prefix: str = "", +) -> str: + """ + Create a batch ingestion job to scan and ingest existing S3 documents. + + This creates a batch job with empty s3_paths that will be processed by + pipeline_ingest_documents. The empty s3_paths signals that the S3 bucket + should be scanned to discover existing documents. + + Args: + ingestion_job_repository: Repository for saving ingestion jobs + ingestion_service: Service for submitting jobs + repository_id: Repository identifier + collection_id: Collection identifier + embedding_model: Embedding model identifier + s3_bucket: S3 bucket to scan + s3_prefix: Optional S3 prefix to scan within bucket + + Returns: + Job ID of the created scan job + """ + logger.info(f"Creating S3 scan job for bucket {s3_bucket} with prefix '{s3_prefix}'") + + # Store bucket/prefix in s3_path field for the scan job + # Format: s3://bucket/prefix (or just s3://bucket if no prefix) + scan_path = f"s3://{s3_bucket}/{s3_prefix}" if s3_prefix else f"s3://{s3_bucket}/" + + # Create batch job with empty s3_paths - this signals S3 scan mode + job = IngestionJob( + repository_id=repository_id, + collection_id=collection_id, + embedding_model=embedding_model, + chunk_strategy=NoneChunkingStrategy(), + s3_path=scan_path, # Store scan location + username="system", # System-initiated + ingestion_type=IngestionType.EXISTING, # Mark as pre-existing documents + job_type=JobActionType.DOCUMENT_BATCH_INGESTION, + s3_paths=[], # Empty list signals S3 scan mode ) + + ingestion_job_repository.save(job) + ingestion_service.submit_create_job(job) + + logger.info(f"Created S3 scan job {job.id} for bucket {s3_bucket}") + return job.id diff --git a/lambda/utilities/bedrock_kb_discovery.py b/lambda/utilities/bedrock_kb_discovery.py new file mode 100644 index 000000000..bea8ffb93 --- /dev/null +++ b/lambda/utilities/bedrock_kb_discovery.py @@ -0,0 +1,304 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Discovery service for Bedrock Knowledge Base data sources. + +This module provides functionality to discover and list Knowledge Bases and their +data sources from AWS Bedrock Agent APIs. It supports caching and pagination for +efficient resource discovery. +""" + +import logging +from typing import Any, Dict, List, Optional + +import boto3 +from botocore.exceptions import ClientError +from models.domain_objects import ( + ChunkingStrategyType, + DataSourceMetadata, + KnowledgeBaseMetadata, + PipelineTrigger, +) +from utilities.validation import ValidationError + +logger = logging.getLogger(__name__) + + +def list_knowledge_bases( + bedrock_agent_client: Optional[Any] = None, +) -> List[KnowledgeBaseMetadata]: + """ + List all Knowledge Bases accessible in the AWS account. + + Args: + bedrock_agent_client: Optional boto3 bedrock-agent client + + Returns: + List of KnowledgeBaseMetadata objects + + Raises: + ValidationError: If API call fails + """ + if not bedrock_agent_client: + bedrock_agent_client = boto3.client("bedrock-agent") + + try: + knowledge_bases = [] + next_token = None + + # Handle pagination + while True: + list_params = {"maxResults": 100} # Maximum allowed + if next_token: + list_params["nextToken"] = next_token + + response = bedrock_agent_client.list_knowledge_bases(**list_params) + + for kb_summary in response.get("knowledgeBaseSummaries", []): + knowledge_bases.append(KnowledgeBaseMetadata(**kb_summary)) + + # Check for more pages + next_token = response.get("nextToken") + if not next_token: + break + + logger.info(f"Discovered {len(knowledge_bases)} Knowledge Bases") + return knowledge_bases + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "AccessDeniedException": + raise ValidationError( + "Access denied to list Knowledge Bases. " "Please check IAM permissions for bedrock:ListKnowledgeBases." + ) + elif error_code == "ThrottlingException": + raise ValidationError("Rate limit exceeded while listing Knowledge Bases. " "Please try again later.") + else: + raise ValidationError(f"Failed to list Knowledge Bases: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error listing Knowledge Bases: {str(e)}", exc_info=True) + raise ValidationError(f"Unexpected error listing Knowledge Bases: {str(e)}") + + +def discover_kb_data_sources( + kb_id: str, + bedrock_agent_client: Optional[Any] = None, +) -> List[DataSourceMetadata]: + """ + Discover all data sources in a Bedrock Knowledge Base. + + Args: + kb_id: Knowledge Base ID + bedrock_agent_client: Optional boto3 bedrock-agent client + + Returns: + List of DataSourceMetadata objects + + Raises: + ValidationError: If KB doesn't exist or API call fails + """ + if not bedrock_agent_client: + bedrock_agent_client = boto3.client("bedrock-agent") + + try: + data_sources = [] + next_token = None + + # Handle pagination + while True: + list_params = { + "knowledgeBaseId": kb_id, + "maxResults": 100, # Maximum allowed + } + if next_token: + list_params["nextToken"] = next_token + + response = bedrock_agent_client.list_data_sources(**list_params) + + for ds_summary in response.get("dataSourceSummaries", []): + # Get detailed configuration for each data source + ds_detail = bedrock_agent_client.get_data_source( + knowledgeBaseId=kb_id, + dataSourceId=ds_summary["dataSourceId"], + ) + + data_source = ds_detail.get("dataSource", {}) + + # Extract S3 configuration and add to data_source + s3_config = extract_s3_configuration(data_source) + data_source["s3Bucket"] = s3_config["bucket"] + data_source["s3Prefix"] = s3_config["prefix"] + + ds_metadata = DataSourceMetadata(**data_source) + data_sources.append(ds_metadata) + + # Check for more pages + next_token = response.get("nextToken") + if not next_token: + break + + logger.info(f"Discovered {len(data_sources)} data sources in KB {kb_id}") + + return data_sources + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "ResourceNotFoundException": + raise ValidationError( + f"Knowledge Base '{kb_id}' not found. " f"Please verify the KB ID in the AWS Bedrock console." + ) + elif error_code == "AccessDeniedException": + raise ValidationError( + f"Access denied to Knowledge Base '{kb_id}'. " + f"Please check IAM permissions for bedrock:ListDataSources and bedrock:GetDataSource." + ) + elif error_code == "ThrottlingException": + raise ValidationError( + f"Rate limit exceeded while discovering data sources for KB '{kb_id}'. " f"Please try again later." + ) + else: + raise ValidationError(f"Failed to discover data sources: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error discovering data sources for KB {kb_id}: {str(e)}", exc_info=True) + raise ValidationError(f"Unexpected error discovering data sources: {str(e)}") + + +def extract_s3_configuration(data_source: Dict[str, Any]) -> Dict[str, str]: + """Extract S3 bucket and prefix from data source configuration. + + Args: + data_source: Data source configuration dictionary + + Returns: + Dictionary with 'bucket' and 'prefix' keys + """ + data_source_config = data_source.get("dataSourceConfiguration", {}) + s3_config = data_source_config.get("s3Configuration", {}) + + # Extract bucket from ARN (format: arn:aws:s3:::bucket-name) + bucket_arn = s3_config.get("bucketArn", "") + bucket = bucket_arn.split(":::")[-1] if bucket_arn else "" + + # Get first inclusion prefix if available + inclusion_prefixes = s3_config.get("inclusionPrefixes", []) + prefix = inclusion_prefixes[0] if inclusion_prefixes else "" + + return {"bucket": bucket, "prefix": prefix} + + +def build_pipeline_configs_from_kb_config( + kb_config: Any, +) -> List[Dict[str, Any]]: + """Build PipelineConfigs from BedrockKnowledgeBaseConfig. + + Args: + kb_config: BedrockKnowledgeBaseConfig object with knowledgeBaseId and dataSources array + + Returns: + List of PipelineConfig dictionaries + + Raises: + ValidationError: If duplicate data source IDs or S3 URIs found + """ + + pipeline_configs = [] + data_source_ids = set() + s3_uris = set() + + # Extract data sources (handle both dict and object) + if isinstance(kb_config, dict): + data_sources = kb_config.get("dataSources", []) + else: + data_sources = kb_config.dataSources + + for data_source in data_sources: + # Extract fields (handle both dict and object) + if isinstance(data_source, dict): + data_source_id = data_source.get("id") + data_source_name = data_source.get("name") + s3_uri = data_source.get("s3Uri") + else: + data_source_id = data_source.id + data_source_name = data_source.name + s3_uri = data_source.s3Uri + + # Validate required fields + if not data_source_id: + raise ValidationError("Data source ID is required") + if not data_source_name: + raise ValidationError("Data source name is required") + if not s3_uri: + raise ValidationError("S3 URI is required") + if not s3_uri.startswith("s3://"): + raise ValidationError(f"Invalid S3 URI format: {s3_uri}") + + # Check for duplicate data source IDs + if data_source_id in data_source_ids: + raise ValidationError(f"Duplicate data source ID: {data_source_id}") + data_source_ids.add(data_source_id) + + # Check for duplicate S3 URIs + if s3_uri in s3_uris: + raise ValidationError(f"Duplicate S3 URI: {s3_uri}") + s3_uris.add(s3_uri) + + # Parse S3 URI (s3://bucket/prefix) + s3_parts = s3_uri[5:].split("/", 1) # Remove s3:// and split + s3_bucket = s3_parts[0] + s3_prefix = s3_parts[1] if len(s3_parts) > 1 else "" + + # Build pipeline config with collectionId set to dataSourceId, and store name separately + pipeline_config = { + "s3Bucket": s3_bucket, + "s3Prefix": s3_prefix, + "collectionId": data_source_id, # Use data source ID as collection ID + "collectionName": data_source_name, # Store data source name for collection creation + "trigger": PipelineTrigger.EVENT.value, + "autoRemove": True, + "chunkingStrategy": {"type": ChunkingStrategyType.NONE.value}, + } + pipeline_configs.append(pipeline_config) + + logger.info(f"Built {len(pipeline_configs)} pipeline configs from {len(data_sources)} data sources") + return pipeline_configs + + +def get_available_data_sources( + kb_id: str, + repository_id: Optional[str] = None, + bedrock_agent_client: Optional[Any] = None, +) -> List[DataSourceMetadata]: + """ + Get all data sources for a Knowledge Base. + + Args: + kb_id: Knowledge Base ID + repository_id: Optional repository ID (unused, for API compatibility) + bedrock_agent_client: Optional boto3 bedrock-agent client + + Returns: + List of DataSourceMetadata objects + + Raises: + ValidationError: If KB doesn't exist or API call fails + """ + # Get all data sources for the KB + all_data_sources = discover_kb_data_sources( + kb_id=kb_id, + bedrock_agent_client=bedrock_agent_client, + ) + + logger.info(f"Found {len(all_data_sources)} data sources for KB {kb_id}") + + return all_data_sources diff --git a/lambda/utilities/bedrock_kb_validation.py b/lambda/utilities/bedrock_kb_validation.py new file mode 100644 index 000000000..a4c65d86f --- /dev/null +++ b/lambda/utilities/bedrock_kb_validation.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validation utilities for Bedrock Knowledge Base operations.""" + +import logging +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from utilities.validation import ValidationError + +logger = logging.getLogger(__name__) + + +def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Optional[Any] = None) -> Dict[str, Any]: + """ + Validate that a Bedrock Knowledge Base exists and is accessible. + + Args: + kb_id: Knowledge Base ID to validate + bedrock_agent_client: Optional boto3 bedrock-agent client (creates one if not provided) + + Returns: + Knowledge Base configuration dictionary + + Raises: + ValidationError: If KB doesn't exist or is not accessible + """ + if not bedrock_agent_client: + bedrock_agent_client = boto3.client("bedrock-agent") + + try: + response = bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id) + kb_config = response.get("knowledgeBase", {}) + + logger.info(f"Validated Knowledge Base {kb_id}: {kb_config.get('name')}") + return kb_config + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + + if error_code == "ResourceNotFoundException": + raise ValidationError( + f"Knowledge Base '{kb_id}' not found. " f"Please verify the KB ID in the AWS Bedrock console." + ) + elif error_code == "AccessDeniedException": + raise ValidationError( + f"Access denied to Knowledge Base '{kb_id}'. " + f"Please check IAM permissions for bedrock:GetKnowledgeBase." + ) + else: + raise ValidationError(f"Failed to validate Knowledge Base '{kb_id}': {str(e)}") + except Exception as e: + raise ValidationError(f"Unexpected error validating Knowledge Base '{kb_id}': {str(e)}") + + +def validate_data_source_exists( + kb_id: str, data_source_id: str, bedrock_agent_client: Optional[Any] = None +) -> Dict[str, Any]: + """ + Validate that a data source exists in a Bedrock Knowledge Base. + + Args: + kb_id: Knowledge Base ID + data_source_id: Data Source ID to validate + bedrock_agent_client: Optional boto3 bedrock-agent client + + Returns: + Data source configuration dictionary + + Raises: + ValidationError: If data source doesn't exist or is not accessible + """ + if not bedrock_agent_client: + bedrock_agent_client = boto3.client("bedrock-agent") + + try: + response = bedrock_agent_client.get_data_source(knowledgeBaseId=kb_id, dataSourceId=data_source_id) + data_source_config = response.get("dataSource", {}) + + logger.info(f"Validated Data Source {data_source_id} in KB {kb_id}: " f"{data_source_config.get('name')}") + return data_source_config + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + + if error_code == "ResourceNotFoundException": + raise ValidationError( + f"Data Source '{data_source_id}' not found in Knowledge Base '{kb_id}'. " + f"Please verify the Data Source ID in the AWS Bedrock console." + ) + elif error_code == "AccessDeniedException": + raise ValidationError( + f"Access denied to Data Source '{data_source_id}'. " + f"Please check IAM permissions for bedrock:GetDataSource." + ) + else: + raise ValidationError(f"Failed to validate Data Source '{data_source_id}': {str(e)}") + except Exception as e: + raise ValidationError(f"Unexpected error validating Data Source '{data_source_id}': {str(e)}") + + +def validate_bedrock_kb_repository( + kb_id: str, data_source_id: str, bedrock_agent_client: Optional[Any] = None +) -> tuple[Dict[str, Any], Dict[str, Any]]: + """ + Validate both Knowledge Base and Data Source exist. + + Args: + kb_id: Knowledge Base ID + data_source_id: Data Source ID + bedrock_agent_client: Optional boto3 bedrock-agent client + + Returns: + Tuple of (kb_config, data_source_config) + + Raises: + ValidationError: If validation fails + """ + if not bedrock_agent_client: + bedrock_agent_client = boto3.client("bedrock-agent") + + # Validate KB exists + kb_config = validate_bedrock_kb_exists(kb_id, bedrock_agent_client) + + # Validate data source exists + data_source_config = validate_data_source_exists(kb_id, data_source_id, bedrock_agent_client) + + logger.info(f"Successfully validated Bedrock KB repository: KB={kb_id}, DataSource={data_source_id}") + + return kb_config, data_source_config diff --git a/lambda/utilities/chunking_strategy_factory.py b/lambda/utilities/chunking_strategy_factory.py new file mode 100644 index 000000000..d3f5d3fbb --- /dev/null +++ b/lambda/utilities/chunking_strategy_factory.py @@ -0,0 +1,192 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory pattern for creating chunking strategies.""" +import logging +import os +from abc import ABC, abstractmethod +from typing import List + +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_core.documents import Document +from models.domain_objects import ChunkingStrategy, ChunkingStrategyType, FixedChunkingStrategy +from utilities.exceptions import RagUploadException + +logger = logging.getLogger(__name__) + +DEFAULT_STRATEGY = FixedChunkingStrategy(size=os.getenv("CHUNK_SIZE", "512"), overlap=os.getenv("CHUNK_OVERLAP", "51")) + + +class ChunkingStrategyHandler(ABC): + """Abstract base class for chunking strategy handlers.""" + + @abstractmethod + def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy) -> List[Document]: + """ + Chunk documents according to the strategy. + + Parameters + ---------- + docs : List[Document] + List of documents to chunk + strategy : ChunkingStrategy + The chunking strategy configuration + + Returns + ------- + List[Document] + List of chunked documents + """ + pass + + +class FixedSizeChunkingHandler(ChunkingStrategyHandler): + """Handler for fixed-size chunking strategy.""" + + def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> List[Document]: + """ + Chunk documents using fixed-size strategy with RecursiveCharacterTextSplitter. + + Parameters + ---------- + docs : List[Document] + List of documents to chunk + strategy : ChunkingStrategy + The chunking strategy configuration (FixedChunkingStrategy) + + Returns + ------- + List[Document] + List of chunked documents + """ + # Handle both legacy (size/overlap) and new (chunkSize/chunkOverlap) formats + chunk_size = strategy.size + chunk_overlap = strategy.overlap + + # Apply defaults from environment if not specified + if not chunk_size: + chunk_size = int(os.getenv("CHUNK_SIZE", "512")) + if not chunk_overlap: + chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "51")) + + # Validate parameters + if chunk_size < 100 or chunk_size > 10000: + raise RagUploadException("Invalid chunk size: must be between 100 and 10000") + + if chunk_overlap < 0 or chunk_overlap >= chunk_size: + raise RagUploadException("Invalid chunk overlap: must be non-negative and less than chunk size") + + logger.info(f"Chunking documents with fixed size strategy: size={chunk_size}, overlap={chunk_overlap}") + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + length_function=len, + ) + return text_splitter.split_documents(docs) # type: ignore [no-any-return] + + +class NoneChunkingHandler(ChunkingStrategyHandler): + """Handler for no-chunking strategy - returns documents as-is.""" + + def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy) -> List[Document]: + """ + Return documents without chunking. + + Parameters + ---------- + docs : List[Document] + List of documents to process + strategy : ChunkingStrategy + The chunking strategy configuration (NoneChunkingStrategy) + + Returns + ------- + List[Document] + Original list of documents unmodified + """ + logger.info(f"Processing {len(docs)} documents with NONE chunking strategy (no chunking)") + return docs + + +class ChunkingStrategyFactory: + """Factory for creating and executing chunking strategies.""" + + _handlers = { + ChunkingStrategyType.FIXED: FixedSizeChunkingHandler(), + ChunkingStrategyType.NONE: NoneChunkingHandler(), + } + + @classmethod + def chunk_documents(cls, docs: List[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> List[Document]: + """ + Chunk documents using the appropriate strategy handler. + + Parameters + ---------- + docs : List[Document] + List of documents to chunk + strategy : ChunkingStrategy + The chunking strategy configuration + + Returns + ------- + List[Document] + List of chunked documents + + Raises + ------ + ValueError + If the chunking strategy type is not supported + """ + if strategy is None: + strategy = DEFAULT_STRATEGY + handler = cls._handlers.get(strategy.type) + if not handler: + supported_strategies = ", ".join([s.value for s in cls._handlers.keys()]) + logger.error( + f"Unsupported chunking strategy: {strategy.type}. Supported strategies: {supported_strategies}" + ) + raise ValueError(f"Unsupported chunking strategy: {strategy.type}") + + return handler.chunk_documents(docs, strategy) + + @classmethod + def register_handler(cls, strategy_type: ChunkingStrategyType, handler: ChunkingStrategyHandler) -> None: + """ + Register a new chunking strategy handler. + + This allows for extending the factory with additional chunking strategies. + + Parameters + ---------- + strategy_type : ChunkingStrategyType + The strategy type to register + handler : ChunkingStrategyHandler + The handler instance for this strategy + """ + cls._handlers[strategy_type] = handler + logger.info(f"Registered chunking strategy handler: {strategy_type.value}") + + @classmethod + def get_supported_strategies(cls) -> List[ChunkingStrategyType]: + """ + Get list of supported chunking strategy types. + + Returns + ------- + List[ChunkingStrategyType] + List of supported strategy types + """ + return list(cls._handlers.keys()) diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index c7e9079dc..24c70d48e 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -20,9 +20,10 @@ import os import tempfile from contextvars import ContextVar +from datetime import datetime from decimal import Decimal from functools import cache -from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, cast, Dict, Optional, TypeVar, Union import boto3 from botocore.config import Config @@ -213,6 +214,8 @@ class DecimalEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: if isinstance(obj, Decimal): return float(obj) + if isinstance(obj, datetime): + return obj.isoformat() return super().default(obj) @@ -261,28 +264,34 @@ def generate_exception_response( Dict[str, Union[str, int, Dict[str, str]]] An HTML response. """ + # Check for ValidationError from utilities.validation status_code = 400 - if hasattr(e, "response"): # i.e. validate the exception was from an API call + error_message: str + if type(e).__name__ == "ValidationError": + error_message = str(e) + logger.exception(e) + elif hasattr(e, "response"): # i.e. validate the exception was from an API call metadata = e.response.get("ResponseMetadata") if metadata: status_code = metadata.get("HTTPStatusCode", 400) + error_message = str(e) logger.exception(e) elif hasattr(e, "http_status_code"): status_code = e.http_status_code - e = e.message # type: ignore [assignment] + error_message = getattr(e, "message", str(e)) logger.exception(e) elif hasattr(e, "status_code"): status_code = e.status_code - e = e.message # type: ignore [assignment] + error_message = getattr(e, "message", str(e)) logger.exception(e) else: error_msg = str(e) if error_msg in ["'requestContext'", "'pathParameters'", "'body'"]: - e = f"Missing event parameter: {error_msg}" # type: ignore [assignment] + error_message = f"Missing event parameter: {error_msg}" else: - e = f"Bad Request: {error_msg}" # type: ignore [assignment] + error_message = f"Bad Request: {error_msg}" logger.exception(e) - return generate_html_response(status_code, e) # type: ignore [arg-type] + return generate_html_response(status_code, error_message) # type: ignore [arg-type] def get_id_token(event: dict) -> str: @@ -363,24 +372,12 @@ def get_rest_api_container_endpoint() -> str: return f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" -def get_username(event: dict) -> str: - """Get the username from the event.""" - username: str = event.get("requestContext", {}).get("authorizer", {}).get("username", "system") - return username - - def get_session_id(event: dict) -> str: """Get the session ID from the event.""" session_id: str = event.get("pathParameters", {}).get("sessionId") return session_id -def get_groups(event: Any) -> List[str]: - """Get user groups from event.""" - groups: List[str] = json.loads(event.get("requestContext", {}).get("authorizer", {}).get("groups", "[]")) - return groups - - def get_principal_id(event: Any) -> str: """Get principal from event.""" principal: str = event.get("requestContext", {}).get("authorizer", {}).get("principal", "") @@ -466,25 +463,6 @@ def get_item(response: Any) -> Any: return items[0] if items else None -def user_has_group_access(user_groups: List[str], allowed_groups: List[str]) -> bool: - """ - Check if user has access based on group membership. - - Args: - user_groups: List of groups the user belongs to - allowed_groups: List of groups allowed to access the resource - - Returns: - True if user has access (either no restrictions or user has required group) - """ - # Public resource (no group restrictions) - if not allowed_groups: - return True - - # Check if user has at least one matching group - return len(set(user_groups).intersection(set(allowed_groups))) > 0 - - def get_property_path(data: dict[str, Any], property_path: str) -> Optional[Any]: """Get the value represented by a property path.""" props = property_path.split(".") diff --git a/lambda/utilities/exceptions.py b/lambda/utilities/exceptions.py index ae88844cb..85ec0eea7 100644 --- a/lambda/utilities/exceptions.py +++ b/lambda/utilities/exceptions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + """Exceptions from handling RAG documents.""" @@ -24,3 +25,13 @@ def __init__(self, status_code: int = 400, message: str = "Bad Request") -> None self.http_status_code = status_code self.message = message super().__init__(self.message) + + +class NotFoundException(HTTPException): + def __init__(self, detail: str = "Not Found"): + super().__init__(404, detail) + + +class UnauthorizedException(HTTPException): + def __init__(self, detail: str = "Unauthorized"): + super().__init__(401, detail) diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index db04f85ce..e8a557539 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -16,17 +16,16 @@ import logging import os from io import BytesIO -from typing import Any, List, Optional from urllib.parse import urlparse import boto3 import docx from botocore.exceptions import ClientError -from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document -from models.domain_objects import ChunkingStrategyType, IngestionJob +from models.domain_objects import IngestionJob from pypdf import PdfReader from pypdf.errors import PdfReadError +from utilities.chunking_strategy_factory import ChunkingStrategyFactory from utilities.constants import DOCX_FILE, PDF_FILE, RICH_TEXT_FILE, TEXT_FILE from utilities.exceptions import RagUploadException @@ -35,8 +34,25 @@ s3 = session.client("s3", region_name=os.environ["AWS_REGION"]) -def _get_metadata(s3_uri: str, name: str) -> dict: - return {"source": s3_uri, "name": name} +def _get_metadata(s3_uri: str, name: str, metadata: dict | None = None) -> dict: + """ + Create metadata dictionary for a document. + + Args: + s3_uri: S3 URI of the document + name: Name of the document + metadata: Optional additional metadata to merge into the result + + Returns: + Dictionary containing document metadata + """ + base_metadata = {"source": s3_uri, "name": name} + + # Merge additional metadata if provided + if metadata: + base_metadata.update(metadata) + + return base_metadata def _get_s3_uri(bucket: str, key: str) -> str: @@ -59,26 +75,6 @@ def _extract_text_by_content_type(content_type: str, s3_object: dict) -> str: raise RagUploadException("Unsupported file type") -def _generate_chunks(docs: List[Document], chunk_size: Optional[int], chunk_overlap: Optional[int]) -> List[Document]: - if not chunk_size: - chunk_size = int(os.getenv("CHUNK_SIZE", "512")) - if not chunk_overlap: - chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "51")) - - if chunk_size < 100 or chunk_size > 10000: - raise RagUploadException("Invalid chunk size") - - if chunk_overlap < 0 or chunk_overlap >= chunk_size: - raise RagUploadException("Invalid chunk overlap") - - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - length_function=len, - ) - return text_splitter.split_documents(docs) # type: ignore [no-any-return] - - def _extract_pdf_content(s3_object: dict) -> str: """Return text extracted from PDF. @@ -140,15 +136,24 @@ def _extract_text_content(s3_object: dict) -> str: def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: - """Generate chunks from an ingestion job. + """Generate chunks from an ingestion job using the configured chunking strategy. Parameters ---------- - ingestion_job (IngestionJob): Ingestion job containing file information and chunking strategy + ingestion_job : IngestionJob + Ingestion job containing file information and chunking strategy Returns ------- - List[Document]: List of document chunks for the processed file + list[Document] + List of document chunks for the processed file + + Raises + ------ + RagUploadException + If S3 path is invalid or file processing fails + ValueError + If chunking strategy is not supported """ # Parse S3 URI using urlparse parsed_uri = urlparse(ingestion_job.s3_path) @@ -166,29 +171,25 @@ def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: logger.error(f"Error getting object from S3: {key}") raise e - chunk_strategy = ingestion_job.chunk_strategy - if chunk_strategy.type == ChunkingStrategyType.FIXED: - return generate_fixed_chunks(ingestion_job, content_type, s3_object) - - logger.error(f"Unrecognized chunk strategy {chunk_strategy.type}") - raise Exception("Unrecognized chunk strategy") - - -def generate_fixed_chunks(ingestion_job: IngestionJob, content_type: str, s3_object: Any) -> list[Document]: - # Get chunk parameters from chunking strategy - chunk_size = ingestion_job.chunk_strategy.size if ingestion_job.chunk_strategy.size else None - chunk_overlap = ingestion_job.chunk_strategy.overlap if ingestion_job.chunk_strategy.overlap else None - # Extract text and create initial document extracted_text = _extract_text_by_content_type(content_type=content_type, s3_object=s3_object) basename = os.path.basename(ingestion_job.s3_path) - docs = [Document(page_content=extracted_text, metadata=_get_metadata(s3_uri=ingestion_job.s3_path, name=basename))] - # Generate chunks using existing helper function - doc_chunks = _generate_chunks(docs, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + # Pass metadata from IngestionJob to be merged into document metadata + docs = [ + Document( + page_content=extracted_text, + metadata=_get_metadata(s3_uri=ingestion_job.s3_path, name=basename, metadata=ingestion_job.metadata), + ) + ] + + # Use factory to chunk documents based on strategy + logger.info(f"Processing document with chunking strategy: {ingestion_job.chunk_strategy.type}") + doc_chunks = ChunkingStrategyFactory.chunk_documents(docs, ingestion_job.chunk_strategy) # Update part number of doc metadata for i, doc in enumerate(doc_chunks): doc.metadata["part"] = i + 1 + logger.info(f"Generated {len(doc_chunks)} chunks for document: {basename}") return doc_chunks diff --git a/lambda/utilities/repository_types.py b/lambda/utilities/repository_types.py index 0ec62cdd5..1800d951f 100644 --- a/lambda/utilities/repository_types.py +++ b/lambda/utilities/repository_types.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from typing import Any, Dict @@ -22,11 +24,11 @@ class RepositoryType(str, Enum): BEDROCK_KB = "bedrock_knowledge_base" @classmethod - def get_type(cls, repository: Dict[str, Any]) -> "RepositoryType": + def get_type(cls, repository: Dict[str, Any]) -> RepositoryType: return RepositoryType(repository.get("type")) @classmethod - def is_type(cls, repository: Dict[str, Any], repo_type: "RepositoryType") -> bool: + def is_type(cls, repository: Dict[str, Any], repo_type: RepositoryType) -> bool: return repository.get("type") == repo_type def calculate_similarity_score(self, score: float) -> float: diff --git a/lambda/utilities/validation.py b/lambda/utilities/validation.py index 781a0484a..7ab79c5cd 100644 --- a/lambda/utilities/validation.py +++ b/lambda/utilities/validation.py @@ -14,8 +14,12 @@ """Validation utilities for Lambda functions.""" import logging +from typing import Any, List + +import botocore.session logger = logging.getLogger(__name__) +sess = botocore.session.Session() class ValidationError(Exception): @@ -24,6 +28,11 @@ class ValidationError(Exception): pass +# Alias for backward compatibility and +# clarity with Pydantic ValidationError +RequestValidationError = ValidationError + + class SecurityError(Exception): """Custom exception for security-related errors.""" @@ -51,6 +60,48 @@ def validate_model_name(model_name: str) -> bool: return True +def validate_instance_type(type: str) -> str: + """Validate that the type is a valid EC2 instance type. + + Args: + type: EC2 instance type to validate + + Returns: + str: The validated instance type + + Raises: + ValueError: If instance type is invalid + """ + if type in sess.get_service_model("ec2").shape_for("InstanceType").enum: + return type + + raise ValueError("Invalid EC2 instance type.") + + +def validate_all_fields_defined(fields: List[Any]) -> bool: + """Validate that all fields are non-null in the field list. + + Args: + fields: List of fields to validate + + Returns: + bool: True if all fields are non-null, False otherwise + """ + return all((field is not None for field in fields)) + + +def validate_any_fields_defined(fields: List[Any]) -> bool: + """Validate that at least one field is non-null in the field list. + + Args: + fields: List of fields to validate + + Returns: + bool: True if at least one field is non-null, False otherwise + """ + return any((field is not None for field in fields)) + + def safe_error_response(error: Exception) -> dict: """Create a safe error response that doesn't leak implementation details. diff --git a/lambda/utilities/validators.py b/lambda/utilities/validators.py deleted file mode 100644 index 1804c523a..000000000 --- a/lambda/utilities/validators.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functional validators for use with Pydantic.""" - -from typing import Any, List - -import botocore.session - -sess = botocore.session.Session() - - -def validate_instance_type(type: str) -> str: - """Validate that the type is a valid EC2 instance type.""" - if type in sess.get_service_model("ec2").shape_for("InstanceType").enum: - return type - - raise ValueError("Invalid EC2 instance type.") - - -def validate_all_fields_defined(fields: List[Any]) -> bool: - """Validate that all fields are non-null in the field list.""" - return all((field is not None for field in fields)) - - -def validate_any_fields_defined(fields: List[Any]) -> bool: - """Validate that at least one field is non-null in the field list.""" - return any((field is not None for field in fields)) diff --git a/lambda/utilities/vector_store.py b/lambda/utilities/vector_store.py deleted file mode 100644 index 5634f47a0..000000000 --- a/lambda/utilities/vector_store.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Helper to return Langchain vector store corresponding to backing store.""" -import json -import logging -import os - -import boto3 -from langchain_community.vectorstores import OpenSearchVectorSearch, PGVector -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore -from opensearchpy import RequestsHttpConnection -from requests_aws4auth import AWS4Auth -from utilities.common_functions import get_lambda_role_name, retry_config -from utilities.rds_auth import generate_auth_token -from utilities.repository_types import RepositoryType - -from . import create_env_variables # noqa type: ignore - -opensearch_endpoint = "" -logger = logging.getLogger(__name__) -session = boto3.Session() -ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) -secretsmanager_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) - - -def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddings) -> VectorStore: - """Return Langchain VectorStore corresponding to the specified store. - - Creates a langchain vector store based on the specified embeddigs adapter and backing store. - """ - prefix = os.environ.get("REGISTERED_REPOSITORIES_PS_PREFIX") - connection_info = ssm_client.get_parameter(Name=f"{prefix}{repository_id}") - connection_info = json.loads(connection_info["Parameter"]["Value"]) - if RepositoryType.is_type(connection_info, RepositoryType.OPENSEARCH): - service = "es" - credentials = session.get_credentials() - - auth = AWS4Auth( - credentials.access_key, - credentials.secret_key, - session.region_name, - service, - session_token=credentials.token, - ) - - opensearch_endpoint = f"https://{connection_info.get('endpoint')}" - - return OpenSearchVectorSearch( - opensearch_url=opensearch_endpoint, - index_name=index, - embedding_function=embeddings, - http_auth=auth, - timeout=300, - use_ssl=True, - verify_certs=True, - connection_class=RequestsHttpConnection, - ) - - elif RepositoryType.is_type(connection_info, RepositoryType.PGVECTOR): - if "passwordSecretId" in connection_info: - # provides backwards compatibility to non-iam authenticated vector stores - secrets_response = secretsmanager_client.get_secret_value(SecretId=connection_info.get("passwordSecretId")) - user = connection_info.get("username") - password = json.loads(secrets_response.get("SecretString")).get("password") - else: - # use IAM auth token to connect - user = get_lambda_role_name() - password = generate_auth_token(connection_info.get("dbHost"), connection_info.get("dbPort"), user) - - connection_string = PGVector.connection_string_from_db_params( - driver="psycopg2", - host=connection_info.get("dbHost"), - port=connection_info.get("dbPort"), - database=connection_info.get("dbName"), - user=user, - password=password, - ) - return PGVector( - collection_name=index, - connection_string=connection_string, - embedding_function=embeddings, - ) - - raise ValueError(f"Unrecognized RAG store: '{repository_id}'") diff --git a/lib/api-base/authorizer.ts b/lib/api-base/authorizer.ts index 2f0a6c7a6..0b636c37a 100644 --- a/lib/api-base/authorizer.ts +++ b/lib/api-base/authorizer.ts @@ -44,6 +44,7 @@ export type AuthorizerProps = { vpc: Vpc; securityGroups: ISecurityGroup[]; tokenTable: ITable | undefined; + managementKeySecretName: string; } & BaseProps; /** @@ -61,7 +62,7 @@ export class CustomAuthorizer extends Construct { constructor (scope: Construct, id: string, props: AuthorizerProps) { super(scope, id); - const { config, role, vpc, securityGroups, tokenTable } = props; + const { config, role, vpc, securityGroups, tokenTable, managementKeySecretName } = props; const commonLambdaLayer = LayerVersion.fromLayerVersionArn( this, @@ -75,8 +76,6 @@ export class CustomAuthorizer extends Construct { StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/authorizer`), ); - const managementKeySecretNameStringParameter = StringParameter.fromStringParameterName(this, createCdkId([id, 'managementKeyStringParameter']), `${config.deploymentPrefix}/managementKeySecretName`); - // Create Lambda authorizer const lambdaPath = config.lambdaPath || LAMBDA_PATH; const authorizerLambda = new Function(this, 'AuthorizerLambda', { @@ -95,7 +94,7 @@ export class CustomAuthorizer extends Construct { ADMIN_GROUP: config.authConfig!.adminGroup, USER_GROUP: config.authConfig!.userGroup, JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, - MANAGEMENT_KEY_NAME: managementKeySecretNameStringParameter.stringValue, + MANAGEMENT_KEY_NAME: managementKeySecretName, ...(tokenTable ? { TOKEN_TABLE_NAME: tokenTable?.tableName } : {}) }, role: role, @@ -108,7 +107,7 @@ export class CustomAuthorizer extends Construct { tokenTable.grantReadData(authorizerLambda); } - const managementKeySecret = Secret.fromSecretNameV2(this, createCdkId([id, 'managementKey']), managementKeySecretNameStringParameter.stringValue); + const managementKeySecret = Secret.fromSecretNameV2(this, createCdkId([id, 'managementKey']), managementKeySecretName); managementKeySecret.grantRead(authorizerLambda); // Update diff --git a/lib/core/apiBaseConstruct.ts b/lib/core/apiBaseConstruct.ts index 8467309c7..73693b0e1 100644 --- a/lib/core/apiBaseConstruct.ts +++ b/lib/core/apiBaseConstruct.ts @@ -14,19 +14,38 @@ limitations under the License. */ -import { Stack, StackProps } from 'aws-cdk-lib'; + import { Authorizer, Cors, EndpointType, RestApi, StageOptions } from 'aws-cdk-lib/aws-apigateway'; -import { Construct } from 'constructs'; + +import { AttributeType, BillingMode, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { CustomAuthorizer } from '../api-base/authorizer'; -import { BaseProps } from '../schema'; +import { Duration, Stack, StackProps } from 'aws-cdk-lib'; +import { ITable, Table } from 'aws-cdk-lib/aws-dynamodb'; +import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { Construct } from 'constructs'; +import { Code, Function, } from 'aws-cdk-lib/aws-lambda'; + +import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; -import { Role } from 'aws-cdk-lib/aws-iam'; -import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { APP_MANAGEMENT_KEY, BaseProps, Config } from '../schema'; +import { + Effect, + ManagedPolicy, + PolicyDocument, + PolicyStatement, + Role, + ServicePrincipal, +} from 'aws-cdk-lib/aws-iam'; +import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; +import { LAMBDA_PATH } from '../util'; +import { getDefaultRuntime } from '../api-base/utils'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { EventBus } from 'aws-cdk-lib/aws-events'; export type LisaApiBaseProps = { vpc: Vpc; - tokenTable: ITable | undefined; + securityGroups: ISecurityGroup[]; } & BaseProps & StackProps; @@ -39,11 +58,47 @@ export class LisaApiBaseConstruct extends Construct { public readonly restApiId: string; public readonly rootResourceId: string; public readonly restApiUrl: string; + public readonly tokenTable?: ITable; + public readonly managementKeySecretName: string; constructor (scope: Stack, id: string, props: LisaApiBaseProps) { super(scope, id); - const { config, vpc, tokenTable } = props; + const { config, vpc, securityGroups } = props; + + // TokenTable is now managed in API Base so it's independent of Serve + // Create the table - if it already exists from previous Serve deployment, + // CloudFormation will handle the conflict. For new deployments, it will be created. + let tokenTable: ITable | undefined; + + // Use new table name to avoid conflicts with existing Serve stack deployments + const tableName = `${config.deploymentName}-LISAApiBaseTokenTable`; + const tokenTableNameParam = `${config.deploymentPrefix}/tokenTableName`; + + // Create the table with new name + // Serve stack will automatically use the new table via SSM parameter reference + tokenTable = new Table(scope, 'TokenTable', { + tableName: tableName, + partitionKey: { + name: 'token', + type: AttributeType.STRING, + }, + billingMode: BillingMode.PAY_PER_REQUEST, + encryption: TableEncryption.AWS_MANAGED, + removalPolicy: config.removalPolicy, + }); + + // Store token table name in SSM for cross-stack reference + new StringParameter(scope, 'TokenTableNameParameter', { + parameterName: tokenTableNameParam, + stringValue: tokenTable.tableName, + description: 'DynamoDB table name for API tokens', + }); + + this.tokenTable = tokenTable; + + const { managementKeySecretName } = this.createManagementKeySecret(scope, config, vpc, securityGroups); + this.managementKeySecretName = managementKeySecretName; const deployOptions: StageOptions = { stageName: config.deploymentStage, @@ -56,8 +111,9 @@ export class LisaApiBaseConstruct extends Construct { const authorizer = new CustomAuthorizer(scope, 'LisaApiAuthorizer', { config: config, securityGroups: [vpc.securityGroups.lambdaSg], - tokenTable, + tokenTable: this.tokenTable, vpc, + managementKeySecretName: this.managementKeySecretName, ...(config.roles && { role: Role.fromRoleName(scope, 'AuthorizerRole', config.roles.RestApiAuthorizerRole), @@ -85,4 +141,73 @@ export class LisaApiBaseConstruct extends Construct { this.rootResourceId = restApi.restApiRootResourceId; this.restApiUrl = restApi.url; } + + private createManagementKeySecret (scope: Stack, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[]): { managementKeySecretName: string } { + const managementKeySecretName = `${config.deploymentName}-management-key`; + + const managementEventBus = new EventBus(scope, createCdkId([scope.node.id, 'managementEventBus']), { + eventBusName: `${config.deploymentName}-management-events`, + }); + + const managementKeySecret = new Secret(scope, createCdkId([scope.node.id, 'managementKeySecret']), { + secretName: managementKeySecretName, + description: 'LISA management key secret', + generateSecretString: { + excludePunctuation: true, + passwordLength: 16 + }, + removalPolicy: config.removalPolicy + }); + + const rotationLambda = new Function(scope, createCdkId([scope.node.id, 'managementKeyRotationLambda']), { + runtime: getDefaultRuntime(), + handler: 'management_key.handler', + code: Code.fromAsset(config.lambdaPath || LAMBDA_PATH), + timeout: Duration.minutes(5), + environment: { + EVENT_BUS_NAME: managementEventBus.eventBusName, + }, + role: new Role(scope, createCdkId([scope.node.id, 'managementKeyRotationRole']), { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + inlinePolicies: { + 'SecretsManagerRotation': new PolicyDocument({ + statements: [ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'secretsmanager:DescribeSecret', + 'secretsmanager:GetSecretValue', + 'secretsmanager:PutSecretValue', + 'secretsmanager:UpdateSecretVersionStage' + ], + resources: [managementKeySecret.secretArn] + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['events:PutEvents'], + resources: [managementEventBus.eventBusArn] + }) + ] + }) + } + }), + securityGroups: securityGroups, + vpc: vpc.vpc, + }); + + managementKeySecret.addRotationSchedule('RotationSchedule', { + automaticallyAfter: Duration.days(30), + rotationLambda: rotationLambda + }); + + new StringParameter(scope, createCdkId(['AppManagementKeySecretName']), { + parameterName: `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`, + stringValue: managementKeySecret.secretName, + }); + + return { managementKeySecretName }; + } } diff --git a/lib/core/apiDeploymentConstruct.ts b/lib/core/apiDeploymentConstruct.ts index a02f9b11b..e2fec2034 100644 --- a/lib/core/apiDeploymentConstruct.ts +++ b/lib/core/apiDeploymentConstruct.ts @@ -47,7 +47,7 @@ export class LisaApiDeploymentConstruct extends Construct { // https://github.com/aws/aws-cdk/issues/25582 (deployment as any).resource.stageName = config.deploymentStage; - const api_url = `https://${restApiId}.execute-api.${Aws.REGION}.${Aws.URL_SUFFIX}/${config.deploymentStage}`; + const api_url = config.apiGatewayConfig?.domainName ? `https://${config.apiGatewayConfig?.domainName}` : `https://${restApiId}.execute-api.${Aws.REGION}.${Aws.URL_SUFFIX}/${config.deploymentStage}`; new StringParameter(scope, 'LisaApiDeploymentStringParameter', { parameterName: `${config.deploymentPrefix}/LisaApiUrl`, stringValue: api_url, diff --git a/lib/core/api_base.ts b/lib/core/api_base.ts index 00fed76d9..b46e4868e 100644 --- a/lib/core/api_base.ts +++ b/lib/core/api_base.ts @@ -15,6 +15,7 @@ */ import { Stack } from 'aws-cdk-lib'; import { Authorizer, RestApi } from 'aws-cdk-lib/aws-apigateway'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { Construct } from 'constructs'; import { LisaApiBaseConstruct, LisaApiBaseProps } from './apiBaseConstruct'; @@ -22,6 +23,7 @@ import { LisaApiBaseConstruct, LisaApiBaseProps } from './apiBaseConstruct'; * LisaApiBase Stack */ export class LisaApiBaseStack extends Stack { + public readonly tokenTable?: ITable; public readonly restApi: RestApi; public readonly authorizer?: Authorizer; public readonly restApiId: string; @@ -38,5 +40,6 @@ export class LisaApiBaseStack extends Stack { this.restApiId = api.restApi.restApiId; this.rootResourceId = api.restApi.restApiRootResourceId; this.restApiUrl = api.restApi.url; + this.tokenTable = api.tokenTable; } } diff --git a/lib/core/iam/ecs.json b/lib/core/iam/ecs.json index 17710ad00..d0995ff0c 100644 --- a/lib/core/iam/ecs.json +++ b/lib/core/iam/ecs.json @@ -206,7 +206,7 @@ "dynamodb:Query", "dynamodb:Scan" ], - "Resource": "arn:${AWS::Partition}:dynamodb:${AWS::Region}:${AWS::AccountId}:table/*-LISAApiTokenTable" + "Resource": "arn:${AWS::Partition}:dynamodb:${AWS::Region}:${AWS::AccountId}:table/*-LISAApi*TokenTable" } ] } diff --git a/lib/core/iam/rag.json b/lib/core/iam/rag.json index 3b99a8bec..d97afbd9c 100644 --- a/lib/core/iam/rag.json +++ b/lib/core/iam/rag.json @@ -66,7 +66,18 @@ "Effect": "Allow", "Action": [ "bedrock:StartIngestionJob", - "bedrock:Retrieve" + "bedrock:Retrieve", + "bedrock:ListKnowledgeBases", + "bedrock:GetKnowledgeBase", + "bedrock:ListDataSources", + "bedrock:GetDataSource" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "cloudwatch:PutMetricData" ], "Resource": "*" } diff --git a/lib/core/iam/roles.ts b/lib/core/iam/roles.ts index 34788c7d9..dcd131a06 100644 --- a/lib/core/iam/roles.ts +++ b/lib/core/iam/roles.ts @@ -33,6 +33,7 @@ export enum Roles { ECS_MCPWORKBENCH_API_ROLE = 'ECSMcpWorkbenchApiRole', LAMBDA_CONFIGURATION_API_EXECUTION_ROLE = 'LambdaConfigurationApiExecutionRole', LAMBDA_EXECUTION_ROLE = 'LambdaExecutionRole', + MCP_SERVER_DEPLOYER_ROLE = 'McpServerDeployerRole', MODEL_API_ROLE = 'ModelApiRole', MODEL_SFN_LAMBDA_ROLE = 'ModelsSfnLambdaRole', MODEL_SFN_ROLE = 'ModelSfnRole', @@ -60,6 +61,7 @@ export const RoleNames: Record = { [Roles.ECS_MCPWORKBENCH_API_ROLE]: 'ECSMcpWorkbenchApiRole', [Roles.LAMBDA_CONFIGURATION_API_EXECUTION_ROLE]: 'LambdaConfigurationApiExecutionRole', [Roles.LAMBDA_EXECUTION_ROLE]: 'LambdaExecutionRole', + [Roles.MCP_SERVER_DEPLOYER_ROLE]: 'McpServerDeployerRole', [Roles.MODEL_API_ROLE]: 'ModelApiRole', [Roles.MODEL_SFN_LAMBDA_ROLE]: 'ModelsSfnLambdaRole', [Roles.MODEL_SFN_ROLE]: 'ModelSfnRole', diff --git a/lib/docs/.vitepress/config.mts b/lib/docs/.vitepress/config.mts index b51a4b1b9..8bda7c3f6 100644 --- a/lib/docs/.vitepress/config.mts +++ b/lib/docs/.vitepress/config.mts @@ -57,7 +57,7 @@ const navLinks = [ { text: 'Model Management UI', link: '/config/model-management-ui' }, { text: 'Guardrails', link: '/config/guardrails' }, { text: 'Usage & Features', link: '/config/usage' }, - { text: 'RAG Vector Stores', link: '/config/vector-stores' }, + { text: 'RAG Repository', link: '/config/repositories' }, { text: 'Langfuse Tracing', link: '/config/langfuse-tracing'}, { text: 'Configuration Schema', @@ -68,8 +68,9 @@ const navLinks = [ { text: 'Role Overrides', link: '/config/role-overrides' }, ], }, - { text: 'Model Context Protocol (MCP)', link: '/config/mcp' }, - { text: 'MCP Workbench', link: '/config/mcp-workbench' }, + { text: 'LISA MCP: Self-host servers', link: '/config/hosted-mcp' }, + { text: 'MCP Connections: Third-party servers', link: '/config/mcp' }, + { text: 'MCP Workbench: Experimentation', link: '/config/mcp-workbench' }, { text: 'Usage Analytics', link: '/config/cloudwatch' }, ], }, diff --git a/lib/docs/admin/architecture.md b/lib/docs/admin/architecture.md index d6ed34e4d..3a21cee2c 100644 --- a/lib/docs/admin/architecture.md +++ b/lib/docs/admin/architecture.md @@ -1,82 +1,42 @@ # Architecture Overview -LISA’s major components include Serve, a Chat user interface (UI), and retrieval augmented generation (RAG). -Serve is required but the remaining components are optional. LISA also offers APIs for customers using LISA for model -hosting and orchestration. - -![LISA Overall Architecture](../assets/LisaArchitecture.png) - -* **Serve:** This is the core of LISA. Serve supports model self-hosting in scalable Amazon ECS clusters. -Through LiteLLM, Serve is also compatible with 100+ models hosted by external model providers like Amazon Bedrock. -* **Chat UI:** Customers prompt LLMs, receive responses, create and modify prompt templates and personas, adjust model -arguments, use advanced features like RAG and MCP, manage session history, export conversations and images, upload files to vector stores, and upload files for non-RAG in context references. Administrators can add, remove, and update models. They can control chat features available to customers without requiring code changes or application re-deployment. Administrators can set up vector stores and document ingestion pipelines to support RAG. The chat UI also offers user authentication. Administrators configure LISA with their identity provider (IdP). Out of the box LISA supports the OIDC protocol. -* **RAG:** LISA is compatible with Amazon OpenSearch or PostgreSQL's PGVector extension in Amazon RDS. LISA offers -automatic document ingestion pipelines for customers to routinely load files into their vector stores. LISA will also -soon support Amazon Bedrock Knowledge bases. -* **APIs:** Customers leveraging LISA for model hosting and orchestration can integrate LISA with existing mission -tooling or alternative front ends. LISA uses Amazon DynamoDB to store Tokens to interact with the exposed APIs. - * Inference requests through LiteLLM support prompting LLMs configured with LISA. Prompts can support LISA’s RAG and - MCP and features. +LISA’s four major components include Serve, a Chat user interface (UI), retrieval augmented generation (RAG), and model context protocol (MCP). + +LISA Serve and LISA MCP are standalone, core solutions with APIs for customers not leveraging LISA’s Chat UI. Both LISA’s Chat UI and RAG are optional components, but must be used with Serve. + +* **Serve:** LISA’s foundational, core component offering centralized model orchestration and inference. Through LiteLLM, Serve is compatible with 100+ models hosted by external model providers like Amazon Bedrock and Jumpstart. Serve also supports secure model self-hosting in scalable Amazon ECS clusters. +* **Chat UI:** LISA’s configurable UI supports customers prompting LLMs, creating and sharing prompt templates, leveraging RAG and MCP capabilities, comparing model responses, and managing chat session history. Administrators can manage advanced features, configure resources through wizards to manage models, RAG repositories, document ingestion pipelines, and MCP servers. Administrators configure Chat with an identity provider (IdP) to manage user access. Enterprise groups can then be associated with resources to manage access to: models, Amazon Guardrails, RAG repositories, RAG collections, and MCP tools. The chat UI must be used with Serve. +* **RAG:** LISA is compatible with Amazon OpenSearch, Amazon Bedrock Knowledge Bases, and PostgreSQL's PGVector extension in Amazon RDS. LISA offers document ingestion with LangChain along with automated pipelines for customers to routinely load files into their repositories and collections. RAG must be used with Serve. +* **MCP:** LISA’s core component offering scalable MCP server self-hosting. It supports hosting STDIO, HTTP, and SSE servers in AWS Fargate (via Amazon Elastic Container Service or ECS) clusters with custom or prebuilt images. LISA MCP also supports importing custom resources for hosting via S3. LISA MCP can be deployed alongside Serve or independently. +* **APIs:** Customers leveraging LISA Serve for model orchestration and hosting, or LISA MCP for tool hosting can opt to securely integrate directly with mission tooling or alternative front ends. LISA uses Amazon DynamoDB to store Tokens to interact with the exposed APIs. + * Inference requests through LiteLLM support prompting LLMs configured with LISA. Prompts include LISA’s RAG and MCP features. * Chat session API supports session history, conversation continuity and management. * Model management API supports deploying, updating, and deleting third party and internally hosted models. + * MCP API supports deploying, updating, deleting, and calling internally hosted MCP tools. -## Serve +## LISA Serve ![LISA Serve Architecture](../assets/LisaServe.png) -LISA Serve is the foundational, core component of the solution. It provides model self-hosting and integration with -compatible external model providers. Serve supports text generation, image generation, and embedding models. Serve’s -components are designed for scale and reliability. Serve can be accessed via LISA’s REST APIs, or through LISA’s chat -user interface (UI). Regardless of origin, all inference requests are routed via an Application Load Balancer (ALB), -which serves as the main entry point to LISA Serve. The ALB forwards requests through the LiteLLM proxy, hosted in its -own scalable Amazon Elastic Container Service (ECS) cluster with Amazon Elastic Compute Cloud (EC2) instance. LiteLLM -routes traffic to the appropriate model. - -Self-hosted model traffic is directed to model specific ALBs, which enable autoscaling in the event of heavy traffic. -Each self-hosted model has their own Amazon ECS cluster and Amazon EC2 instance. Text generation and image generation -models compatible with Hugging Face’s +LISA Serve provides model self-hosting and integration with compatible external model providers. Serve supports text generation, image generation, and embedding models. Serve’s components are designed for scale and reliability. Serve can be accessed via LISA’s REST APIs, or through LISA’s chat +UI. Regardless of origin, all inference requests are routed via an Application Load Balancer (ALB), which serves as the main entry point to LISA Serve. The ALB forwards requests through the LiteLLM proxy, hosted in its own scalable Amazon Elastic Container Service (ECS) cluster with Amazon Elastic Compute Cloud (EC2) instance. LiteLLM routes traffic to the appropriate model. + +Self-hosted model traffic is directed to model specific ALBs, which enable autoscaling in the event of heavy traffic. Each self-hosted model has its own Amazon ECS cluster and Amazon EC2 instance. Text generation and image generation models compatible with Hugging Face’s [Text Generation Inference (TGI)](https://huggingface.co/docs/text-generation-inference/en/index) and -[vLLM](https://docs.vllm.ai/en/latest/) images are supported. Embedding models compatible with Huffing Face’s +[vLLM](https://docs.vllm.ai/en/latest/) images are supported. Embedding models compatible with Hugging Face’s [Text Embedding Inference (TEI)](https://huggingface.co/docs/text-embeddings-inference/en/index) and -[vLLM](https://docs.vllm.ai/en/latest/) images are also supported. LISA uses **** S3 for loading the model weights. +[vLLM](https://docs.vllm.ai/en/latest/) images are also supported. LISA uses Amazon S3 for loading the model weights. **Technical Notes:** -* RAG operations are managed through `lambda/rag/lambda_functions.py`, which handles embedding generation and document -retrieval via OpenSearch and PostgreSQL. -* Direct requests to the LISA Serve ALB entrypoint must utilize the OpenAI API spec, which we support through the use -* of the LiteLLM proxy. - -## Chat UI -![LISA Chatbot Architecture](../assets/LisaChat.png) - -LISA provides a customizable chat user interface (UI). The UI is hosted as a static website in Amazon S3, and is fronted -by Amazon API Gateway. Customers prompt models and view real-time responses. The UI is integrated with LISA Serve, Chat -APIs, Model Management APIs, and RAG. LISA’s chat UI supports integration with an OIDC identity provider to handle user -authentication. LISA can be accessible to all users, or limited to a single enterprise user group. Users added to the -Administrator role have access to application configuration. - -**LISA’s chat UI features include:** +* RAG operations are managed through `lambda/rag/lambda_functions.py`, which handles embedding generation and document retrieval via OpenSearch and PostgreSQL. +* Direct requests to the LISA Serve ALB entrypoint must utilize the OpenAI API spec, which we support through the use of the LiteLLM proxy. +* LISA supports OpenAI's API spec, which means LISA can be easily configured with the Continue plugin for use with Jetbrains or VS Code integrated development environments (IDE). -* Prompting text and image generation LLMs and receiving responses -* Viewing, deleting, and exporting chat history -* Supports streaming responses, viewing metadata, and Markdown formatting -* Creating and sharing directive prompt and persona templates in a Prompt Library -* Advanced model args like max tokens, Top P, Temperature, stop words -* Referencing vector stores to support RAG -* Uploading docs into vector stores -* Uploading docs into non-RAG in context -* RAG document library -* Non-RAG in context Document summarization feature -* Model Context Protocol (MCP) support -* Administrators control which features are available without having to make code changes -* Administrators can configure models with LISA via the model manage wizard -* Administrators can add and manage vector stores and manage group access, and automatic ingestion pipelines - -## Model Management +### Model Management ![LISA Model Management Architecture](../assets/LisaModelManagement.png) -The Model Management is responsible for managing the entire lifecycle of models in LISA. This includes creation, updating, +Use Model Management for managing the entire lifecycle of models configured or hosted with LISA. This includes creation, updating, deletion of models deployed on ECS or third party provided. LISA handles scaling of these operations, ensuring that the underlying infrastructure is managed efficiently. @@ -107,3 +67,68 @@ security, networking, and infrastructure components are automatically deployed a `ecs_model_deployer/src/lib/lisa_model_stack.ts`. * ECS Cluster: ECS cluster and task definitions are located in `ecs_model_deployer/src/lib/ecsCluster.ts`, with model containers specified in `ecs_model_deployer/src/lib/ecs-model.ts`. + + +## LISA MCP +![LISA MCP Architecture](../assets/LisaMcp.png) + +LISA MCP is a standalone product that provides scalable infrastructure for deploying and hosting Model Context Protocol (MCP) servers. It allows customers to self-host MCP servers for enterprise use. LISA MCP can be deployed independently of LISA Serve or configured to work seamlessly with LISA Serve and the Chat UI. + +Each MCP server deployed via LISA MCP is provisioned on AWS Fargate via Amazon ECS, fronted by Application Load Balancers (ALBs) and Network Load Balancers (NLBs), and published through the existing API Gateway. This architecture allows chat sessions to securely invoke MCP tools without leaving your VPC. All routes remain protected by the same API Gateway Lambda authorizer patterns that guards the rest of LISA, ensuring API Keys, IDP lockdown, and JWT group enforcement continue to apply automatically. + +**Server Types:** LISA MCP supports all MCP server types: +* **STDIO servers:** Automatically wrapped with `mcp-proxy` and exposed over HTTP on port 8080 +* **HTTP servers:** Direct HTTP endpoints using the configured port (default 8000) +* **SSE servers:** Server-Sent Events endpoints for streaming responses + +**Networking Architecture:** The networking follows a layered approach: +* **API Gateway** receives MCP traffic on `/mcp/{serverId}` routes +* **Network Load Balancer (NLB)** terminates the API Gateway VPC Link and forwards to the Application Load Balancer +* **Application Load Balancer (ALB)** provides HTTP features including health checks, routing, and load balancing +* **ECS Fargate** hosts the MCP server containers within your VPC using the same subnets and security groups as the MCP API stack + +**Lifecycle Management:** AWS Step Functions orchestrate the complete lifecycle of MCP servers, handling creation, update, deletion, start, and stop workflows. Each workflow provisions the required resources using CloudFormation templates, which manage infrastructure components like ECS Fargate services, load balancers, VPC Links, and auto-scaling configurations. + +**Key Features:** +* Turn-key hosting for STDIO, HTTP, or SSE MCP servers with a single API/UI workflow +* Dynamic container builds from pre-built images or S3 artifacts synced at deploy time +* Auto-scaling with configurable Fargate min/max capacity, custom metrics, and scaling targets per server +* Secure VPC networking with private ALB for internal traffic and NLB + VPC Link for API Gateway access +* Group-aware routing to limit server visibility to specific identity provider groups or make them public +* External integrations via API Gateway URLs, enabling trusted third-party agents, copilots, or workflow engines to invoke hosted MCP servers using the same credentials and auth controls + +**Technical Notes:** + +* MCP Server Lifecycle: Lifecycle operations such as create, update, delete, start, and stop are orchestrated by Step Functions workflows (`CreateMcpServer`, `UpdateMcpServer`, `DeleteMcpServer`). The MCP API Handler Lambda validates requests and manages server metadata in DynamoDB. +* CloudFormation: Infrastructure components are provisioned using CloudFormation templates synthesized by the MCP server deployer Lambda, as defined in `mcp_server_deployer/src/lib/ecsMcpServer.ts`. +* ECS Fargate: Each MCP server runs in its own ECS Fargate cluster with dedicated ALB and NLB. The Fargate cluster configuration is located in `mcp_server_deployer/src/lib/ecsFargateCluster.ts`. +* Authentication: API Gateway enforces the same Lambda authorizer used across LISA (JWT validation + optional API key checks). The `{LISA_BEARER_TOKEN}` placeholder in connection details is automatically replaced with the user's bearer token at connection time. +* Data Storage: Server metadata is stored in the `MCP_SERVERS_TABLE` DynamoDB table. When `DEPLOYMENT_PREFIX` is configured, completed servers are published to `McpConnectionsTable` so the chat application can surface them alongside externally hosted connections. + + +## Chat UI +![LISA Chatbot Architecture](../assets/LisaChat.png) + +LISA provides a configurable chat user interface (UI). The UI is hosted as a static website in Amazon S3, and is fronted +by Amazon API Gateway. Customers prompt models and view responses. The UI is integrated with LISA Serve, Chat +APIs, Model Management APIs, and RAG. LISA’s chat UI supports integration with an OIDC identity provider to handle user +authentication. LISA can be accessible to all users, or limited to a single enterprise user group. Users added to the +Administrator role have access to application configuration. + +**Features:** + +* Prompting text and image generation LLMs and receiving responses +* Viewing, deleting, and exporting chat history +* Supports streaming responses, viewing metadata, RAG citations +* Supports Markdown, mermaid, and math formatting +* Creating and sharing directive prompt and persona templates in a Prompt Library +* Supports advanced model args like max tokens, Top P, Temperature, stop words +* Referencing vector stores for RAG, and doc uploads +* RAG document library +* Uploading docs into non-RAG in context +* Non-RAG in context Document summarization feature +* Model Context Protocol (MCP) support for LISA MCP, along with MCP Workbench, and third party MCP Connections features +* Administrators control which features are available without having to make code changes via the Configuration page +* Administrators configure models with LISA via the model manage wizard in Model Management +* Administrators add and manage vector stores and manage group access, and automatic ingestion pipelines via RAG Management +* Administrators configure hosted MCP Servers via the MCP Management diff --git a/lib/docs/assets/LisaArchitecture.png b/lib/docs/assets/LisaArchitecture.png deleted file mode 100644 index 884e275bf..000000000 Binary files a/lib/docs/assets/LisaArchitecture.png and /dev/null differ diff --git a/lib/docs/assets/LisaChat.png b/lib/docs/assets/LisaChat.png index f9d958e0a..0aacd722f 100644 Binary files a/lib/docs/assets/LisaChat.png and b/lib/docs/assets/LisaChat.png differ diff --git a/lib/docs/assets/LisaMcp.png b/lib/docs/assets/LisaMcp.png new file mode 100644 index 000000000..52cbc168f Binary files /dev/null and b/lib/docs/assets/LisaMcp.png differ diff --git a/lib/docs/assets/LisaServe.png b/lib/docs/assets/LisaServe.png index 22365e94f..e58b07e20 100644 Binary files a/lib/docs/assets/LisaServe.png and b/lib/docs/assets/LisaServe.png differ diff --git a/lib/docs/config/collection-management-api.md b/lib/docs/config/collection-management-api.md new file mode 100644 index 000000000..817812848 --- /dev/null +++ b/lib/docs/config/collection-management-api.md @@ -0,0 +1,1690 @@ +# Collection Management API + +The Collection Management API provides endpoints for creating, reading, updating, and deleting collections within RAG vector stores. Collections enable organizing documents with different chunking strategies and access controls without requiring infrastructure changes. + +## Base URL Structure + +All collection endpoints are accessed through LISA's main API Gateway with the following structure: +``` +https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/{repositoryId}/collection +``` + +## Authentication + +All API endpoints require proper authentication through LISA's configured authorization mechanism. Ensure your requests include valid authorization headers as configured in your LISA deployment. + +## Endpoints + +### Create Collection + +Create a new collection within a vector store. + +**Endpoint:** `POST /repository/{repositoryId}/collection` + +**Path Parameters:** +- `repositoryId` (string, required): The parent vector store repository ID + +**Request Body:** + +```json +{ + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowedGroups": ["legal-team", "compliance"], + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "private": false, + "allowChunkingOverride": true +} +``` + +**Request Body Schema:** + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | Yes | Collection name (1-100 characters) | +| `description` | string | No | Collection description | +| `embeddingModel` | string | No | Embedding model ID (inherits from parent if omitted) | +| `chunkingStrategy` | object | No | Chunking strategy configuration (inherits from parent if omitted) | +| `chunkingStrategy.type` | enum | No | Strategy type: `FIXED` | +| `chunkingStrategy.parameters` | object | No | Strategy-specific parameters | +| `allowedGroups` | array[string] | No | User groups with access (inherits from parent if omitted) | +| `metadata` | object | No | Collection-specific metadata | +| `metadata.tags` | array[string] | No | Metadata tags (max 50 tags, 50 chars each) | +| `private` | boolean | No | Whether collection is private to creator (default: false) | +| `allowChunkingOverride` | boolean | No | Allow chunking strategy override during ingestion (default: true) | +| `pipelines` | array[object] | No | Automated ingestion pipelines | + +**Chunking Strategy Types:** + +1. **FIXED**: Fixed-size chunks with overlap + ```json + { + "type": "fixed", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200 + } + } + ``` + +2. **SEMANTIC**: Semantic-based chunking + ```json + { + "type": "SEMANTIC", + "parameters": { + "threshold": 0.5 + } + } + ``` + +3. **RECURSIVE**: Recursive text splitting with custom separators + ```json + { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + } + ``` + +**Response (200 OK):** + +```json +{ + "collectionId": "550e8400-e29b-41d4-a716-446655440000", + "repositoryId": "repo-123", + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "allowedGroups": ["legal-team", "compliance"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-456", + "createdAt": "2025-10-13T10:30:00Z", + "updatedAt": "2025-10-13T10:30:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] +} +``` + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 400 | Bad Request - Invalid input | `{"error": "Collection name must be unique within repository"}` | +| 403 | Forbidden - Insufficient permissions | `{"error": "User does not have write access to repository"}` | +| 404 | Not Found - Repository not found | `{"error": "Repository 'repo-123' not found"}` | +| 500 | Internal Server Error | `{"error": "Failed to create collection"}` | + +**Example cURL Request:** + +```bash +curl -X POST "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection" \ + -H "Authorization: Bearer {YOUR_TOKEN}" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowedGroups": ["legal-team", "compliance"], + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "private": false + }' +``` + +**Example Python Request:** + +```python +import requests +import json + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}", + "Content-Type": "application/json" +} +payload = { + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowedGroups": ["legal-team", "compliance"], + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "private": False +} + +response = requests.post(url, headers=headers, json=payload) +if response.status_code == 200: + collection = response.json() + print(f"Created collection: {collection['collectionId']}") +else: + print(f"Error: {response.status_code} - {response.text}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}', + 'Content-Type': 'application/json' +}; +const payload = { + name: 'Legal Documents', + description: 'Collection for legal contracts and agreements', + chunkingStrategy: { + type: 'RECURSIVE', + parameters: { + chunkSize: 1000, + chunkOverlap: 200, + separators: ['\n\n', '\n', '. ', ' '] + } + }, + allowedGroups: ['legal-team', 'compliance'], + metadata: { + tags: ['legal', 'contracts', 'confidential'] + }, + private: false +}; + +fetch(url, { + method: 'POST', + headers: headers, + body: JSON.stringify(payload) +}) + .then(response => { + if (response.status === 200) { + return response.json(); + } + throw new Error(`Error: ${response.status}`); + }) + .then(collection => { + console.log(`Created collection: ${collection.collectionId}`); + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +### Get Collection + +Retrieve a collection by ID within a vector store. + +**Endpoint:** `GET /repository/{repositoryId}/collection/{collectionId}` + +**Path Parameters:** +- `repositoryId` (string, required): The parent vector store repository ID +- `collectionId` (string, required): The collection ID (UUID) + +**Response (200 OK):** + +```json +{ + "collectionId": "550e8400-e29b-41d4-a716-446655440000", + "repositoryId": "repo-123", + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "allowedGroups": ["legal-team", "compliance"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-456", + "createdAt": "2025-10-13T10:30:00Z", + "updatedAt": "2025-10-13T10:30:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] +} +``` + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 403 | Forbidden - Insufficient permissions | `{"error": "Permission denied: User does not have read access to collection"}` | +| 404 | Not Found - Collection not found | `{"error": "Collection '550e8400-e29b-41d4-a716-446655440000' not found"}` | +| 404 | Not Found - Repository not found | `{"error": "Repository 'repo-123' not found"}` | +| 500 | Internal Server Error | `{"error": "Failed to retrieve collection"}` | + +**Example cURL Request:** + +```bash +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example Python Request:** + +```python +import requests + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}" +} + +response = requests.get(url, headers=headers) +if response.status_code == 200: + collection = response.json() + print(f"Collection: {collection['name']}") + print(f"Status: {collection['status']}") + print(f"Allowed Groups: {collection['allowedGroups']}") +else: + print(f"Error: {response.status_code} - {response.text}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}' +}; + +fetch(url, { + method: 'GET', + headers: headers +}) + .then(response => { + if (response.status === 200) { + return response.json(); + } + throw new Error(`Error: ${response.status}`); + }) + .then(collection => { + console.log(`Collection: ${collection.name}`); + console.log(`Status: ${collection.status}`); + console.log(`Allowed Groups: ${collection.allowedGroups}`); + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +### Update Collection + +Update a collection's configuration within a vector store. Supports partial updates - only specified fields will be modified. + +**Endpoint:** `PUT /repository/{repositoryId}/collection/{collectionId}` + +**Path Parameters:** +- `repositoryId` (string, required): The parent vector store repository ID +- `collectionId` (string, required): The collection ID (UUID) + +**Request Body:** + +All fields are optional. Only include fields you want to update. + +```json +{ + "name": "Updated Legal Documents", + "description": "Updated description for legal contracts", + "chunkingStrategy": { + "type": "FIXED", + "parameters": { + "chunkSize": 1500, + "chunkOverlap": 300 + } + }, + "allowedGroups": ["legal-team", "compliance", "audit"], + "metadata": { + "tags": ["legal", "contracts", "confidential", "2025"] + }, + "private": false, + "allowChunkingOverride": true, + "status": "ACTIVE" +} +``` + +**Request Body Schema:** + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | No | Collection name (1-100 characters) | +| `description` | string | No | Collection description | +| `chunkingStrategy` | object | No | Chunking strategy configuration | +| `chunkingStrategy.type` | enum | No | Strategy type: `FIXED`, `SEMANTIC`, or `RECURSIVE` | +| `chunkingStrategy.parameters` | object | No | Strategy-specific parameters | +| `allowedGroups` | array[string] | No | User groups with access | +| `metadata` | object | No | Collection-specific metadata | +| `metadata.tags` | array[string] | No | Metadata tags (max 50 tags, 50 chars each) | +| `private` | boolean | No | Whether collection is private to creator | +| `allowChunkingOverride` | boolean | No | Allow chunking strategy override during ingestion | +| `pipelines` | array[object] | No | Automated ingestion pipelines | +| `status` | enum | No | Collection status: `ACTIVE`, `ARCHIVED`, or `DELETED` | + +**Immutable Fields:** + +The following fields cannot be modified after creation and will be ignored if included in the request: +- `collectionId` +- `repositoryId` +- `embeddingModel` +- `createdBy` +- `createdAt` + +**Response (200 OK):** + +```json +{ + "collection": { + "collectionId": "550e8400-e29b-41d4-a716-446655440000", + "repositoryId": "repo-123", + "name": "Updated Legal Documents", + "description": "Updated description for legal contracts", + "chunkingStrategy": { + "type": "FIXED", + "parameters": { + "chunkSize": 1500, + "chunkOverlap": 300 + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["legal", "contracts", "confidential", "2025"] + }, + "allowedGroups": ["legal-team", "compliance", "audit"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-456", + "createdAt": "2025-10-13T10:30:00Z", + "updatedAt": "2025-10-13T14:45:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] + }, + "warnings": [ + "Changing chunking strategy will only affect new documents. Existing documents will retain their original chunking. Consider re-ingesting existing documents if needed." + ] +} +``` + +**Response Fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `collection` | object | The updated collection configuration | +| `warnings` | array[string] | Optional warnings about the update (e.g., chunking strategy changes) | + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 400 | Bad Request - Invalid input | `{"error": "Collection name must be unique within repository"}` | +| 403 | Forbidden - Insufficient permissions | `{"error": "Permission denied: User does not have write access to collection"}` | +| 404 | Not Found - Collection not found | `{"error": "Collection '550e8400-e29b-41d4-a716-446655440000' not found"}` | +| 404 | Not Found - Repository not found | `{"error": "Repository 'repo-123' not found"}` | +| 500 | Internal Server Error | `{"error": "Failed to update collection"}` | + +**Example cURL Request:** + +```bash +curl -X PUT "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" \ + -H "Authorization: Bearer {YOUR_TOKEN}" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Updated Legal Documents", + "description": "Updated description for legal contracts", + "metadata": { + "tags": ["legal", "contracts", "confidential", "2025"] + } + }' +``` + +**Example Python Request:** + +```python +import requests +import json + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}", + "Content-Type": "application/json" +} +payload = { + "name": "Updated Legal Documents", + "description": "Updated description for legal contracts", + "metadata": { + "tags": ["legal", "contracts", "confidential", "2025"] + } +} + +response = requests.put(url, headers=headers, json=payload) +if response.status_code == 200: + result = response.json() + collection = result['collection'] + print(f"Updated collection: {collection['collectionId']}") + print(f"New name: {collection['name']}") + + # Check for warnings + if 'warnings' in result: + print("Warnings:") + for warning in result['warnings']: + print(f" - {warning}") +else: + print(f"Error: {response.status_code} - {response.text}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}', + 'Content-Type': 'application/json' +}; +const payload = { + name: 'Updated Legal Documents', + description: 'Updated description for legal contracts', + metadata: { + tags: ['legal', 'contracts', 'confidential', '2025'] + } +}; + +fetch(url, { + method: 'PUT', + headers: headers, + body: JSON.stringify(payload) +}) + .then(response => { + if (response.status === 200) { + return response.json(); + } + throw new Error(`Error: ${response.status}`); + }) + .then(result => { + const collection = result.collection; + console.log(`Updated collection: ${collection.collectionId}`); + console.log(`New name: ${collection.name}`); + + // Check for warnings + if (result.warnings) { + console.log('Warnings:'); + result.warnings.forEach(warning => { + console.log(` - ${warning}`); + }); + } + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +**Partial Update Example:** + +You can update just one field without affecting others: + +```bash +# Update only the description +curl -X PUT "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" \ + -H "Authorization: Bearer {YOUR_TOKEN}" \ + -H "Content-Type: application/json" \ + -d '{ + "description": "New description only" + }' +``` + +**Chunking Strategy Change Warning:** + +When updating the chunking strategy on a collection that already has documents, the API will return a warning: + +```json +{ + "collection": { ... }, + "warnings": [ + "Changing chunking strategy will only affect new documents. Existing documents will retain their original chunking. Consider re-ingesting existing documents if needed." + ] +} +``` + +This warning indicates that: +1. New documents uploaded after the change will use the new chunking strategy +2. Existing documents will keep their original chunking +3. You may want to re-ingest existing documents to apply the new strategy + +### Delete Collection + +Delete a collection within a vector store. This operation requires admin access and will remove all associated documents from S3 and the vector store. + +**Endpoint:** `DELETE /repository/{repositoryId}/collection/{collectionId}` + +**Path Parameters:** +- `repositoryId` (string, required): The parent vector store repository ID +- `collectionId` (string, required): The collection ID (UUID) + +**Query Parameters:** + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `hardDelete` | boolean | No | false | Whether to permanently delete (true) or soft delete (false) | + +**Deletion Behavior:** + +- **Soft Delete (default)**: Marks the collection status as `DELETED` but retains the record in the database +- **Hard Delete**: Permanently removes the collection record from the database +- **Document Cleanup**: Both deletion types remove all associated documents from: + - S3 storage + - DynamoDB document table + - Vector store embeddings + +**Response (204 No Content):** + +No response body is returned on successful deletion. + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 400 | Bad Request - Cannot delete default collection | `{"error": "Cannot delete the default collection"}` | +| 403 | Forbidden - Insufficient permissions | `{"error": "Permission denied: User does not have admin access to collection"}` | +| 404 | Not Found - Collection not found | `{"error": "Collection '550e8400-e29b-41d4-a716-446655440000' not found"}` | +| 404 | Not Found - Repository not found | `{"error": "Repository 'repo-123' not found"}` | +| 500 | Internal Server Error | `{"error": "Failed to delete collection"}` | + +**Example cURL Request (Soft Delete):** + +```bash +curl -X DELETE "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example cURL Request (Hard Delete):** + +```bash +curl -X DELETE "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000?hardDelete=true" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example Python Request:** + +```python +import requests + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}" +} + +# Soft delete (default) +response = requests.delete(url, headers=headers) +if response.status_code == 204: + print("Collection soft deleted successfully") +else: + print(f"Error: {response.status_code} - {response.text}") + +# Hard delete +params = {"hardDelete": "true"} +response = requests.delete(url, headers=headers, params=params) +if response.status_code == 204: + print("Collection permanently deleted") +else: + print(f"Error: {response.status_code} - {response.text}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collection/550e8400-e29b-41d4-a716-446655440000'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}' +}; + +// Soft delete (default) +fetch(url, { + method: 'DELETE', + headers: headers +}) + .then(response => { + if (response.status === 204) { + console.log('Collection soft deleted successfully'); + } else { + throw new Error(`Error: ${response.status}`); + } + }) + .catch(error => { + console.error('Error:', error); + }); + +// Hard delete +const hardDeleteUrl = `${url}?hardDelete=true`; +fetch(hardDeleteUrl, { + method: 'DELETE', + headers: headers +}) + .then(response => { + if (response.status === 204) { + console.log('Collection permanently deleted'); + } else { + throw new Error(`Error: ${response.status}`); + } + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +**Important Notes:** + +1. **Admin Access Required**: Only users with admin access to the collection can delete it +2. **Default Collection Protection**: The default collection (based on embedding model ID) cannot be deleted +3. **Document Cleanup**: All documents in the collection will be removed from S3, DynamoDB, and the vector store +4. **Irreversible Operation**: Hard delete is permanent and cannot be undone +5. **Soft Delete Recovery**: Soft-deleted collections can be restored by updating the status back to `ACTIVE` + +**Deletion Confirmation Workflow:** + +Before deleting a collection, it's recommended to: + +1. **Check document count**: Use the GET endpoint to see how many documents will be affected +2. **Warn users**: Display a confirmation dialog showing the collection name and document count +3. **Require confirmation**: Ask users to type the collection name to confirm deletion +4. **Log the action**: Ensure audit logs capture who deleted the collection and when + +**Example Confirmation Flow:** + +```python +import requests + +def delete_collection_with_confirmation(repository_id, collection_id, token): + """Delete a collection with user confirmation.""" + url = f"https://{{API-GATEWAY-DOMAIN}}/{{STAGE}}/repository/{repository_id}/collection/{collection_id}" + headers = {"Authorization": f"Bearer {token}"} + + # Step 1: Get collection details + response = requests.get(url, headers=headers) + if response.status_code != 200: + print(f"Error fetching collection: {response.status_code}") + return False + + collection = response.json() + collection_id = collection['name'] + + # Step 2: Get document count (from list_docs endpoint) + docs_url = f"https://{{API-GATEWAY-DOMAIN}}/{{STAGE}}/repository/{repository_id}/documents" + docs_response = requests.get( + docs_url, + headers=headers, + params={"collectionId": collection_id, "pageSize": 1} + ) + + doc_count = 0 + if docs_response.status_code == 200: + doc_count = docs_response.json().get('totalDocuments', 0) + + # Step 3: Display warning and get confirmation + print(f"\nWARNING: You are about to delete collection '{collection_id}'") + print(f"This will remove {doc_count} documents from S3 and the vector store.") + print("This action cannot be undone.") + + confirmation = input(f"\nType the collection name '{collection_id}' to confirm: ") + + if confirmation != collection_id: + print("Deletion cancelled - name did not match") + return False + + # Step 4: Delete the collection + response = requests.delete(url, headers=headers) + if response.status_code == 204: + print(f"Collection '{collection_id}' deleted successfully") + return True + else: + print(f"Error deleting collection: {response.status_code} - {response.text}") + return False + +# Usage +delete_collection_with_confirmation( + "repo-123", + "550e8400-e29b-41d4-a716-446655440000", + "{YOUR_TOKEN}" +) +``` + +### List Collections + +List collections in a repository with pagination, filtering, and sorting. + +**Endpoint:** `GET /repository/{repositoryId}/collections` + +**Path Parameters:** +- `repositoryId` (string, required): The parent vector store repository ID + +**Query Parameters:** + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `pageSize` | integer | No | 20 | Number of items per page (max: 100) | +| `filter` | string | No | - | Text filter for name/description (substring match) | +| `status` | enum | No | - | Status filter: `ACTIVE`, `ARCHIVED`, or `DELETED` | +| `sortBy` | enum | No | `createdAt` | Sort field: `name`, `createdAt`, or `updatedAt` | +| `sortOrder` | enum | No | `desc` | Sort order: `asc` or `desc` | +| `lastEvaluatedKeyCollectionId` | string | No | - | Pagination token: collection ID from previous response | +| `lastEvaluatedKeyRepositoryId` | string | No | - | Pagination token: repository ID from previous response | +| `lastEvaluatedKeyStatus` | string | No | - | Pagination token: status from previous response (if status filter used) | +| `lastEvaluatedKeyCreatedAt` | string | No | - | Pagination token: createdAt from previous response | + +**Response (200 OK):** + +```json +{ + "collections": [ + { + "collectionId": "550e8400-e29b-41d4-a716-446655440000", + "repositoryId": "repo-123", + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "allowedGroups": ["legal-team", "compliance"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-456", + "createdAt": "2025-10-13T10:30:00Z", + "updatedAt": "2025-10-13T10:30:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] + }, + { + "collectionId": "660e8400-e29b-41d4-a716-446655440001", + "repositoryId": "repo-123", + "name": "Technical Documentation", + "description": "Collection for technical manuals and guides", + "chunkingStrategy": { + "type": "FIXED", + "parameters": { + "chunkSize": 1500, + "chunkOverlap": 300 + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["technical", "documentation"] + }, + "allowedGroups": ["engineering", "support"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-789", + "createdAt": "2025-10-12T15:20:00Z", + "updatedAt": "2025-10-12T15:20:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] + } + ], + "pagination": { + "totalCount": 45, + "currentPage": 1, + "totalPages": 3 + }, + "lastEvaluatedKey": { + "collectionId": "660e8400-e29b-41d4-a716-446655440001", + "repositoryId": "repo-123", + "createdAt": "2025-10-12T15:20:00Z" + }, + "hasNextPage": true, + "hasPreviousPage": false +} +``` + +**Response Fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `collections` | array[object] | List of collection configurations | +| `pagination` | object | Pagination metadata | +| `pagination.totalCount` | integer | Total number of collections (null if filters applied) | +| `pagination.currentPage` | integer | Current page number (null if using lastEvaluatedKey) | +| `pagination.totalPages` | integer | Total number of pages (null if filters applied) | +| `lastEvaluatedKey` | object | Pagination token for next page (null if no more pages) | +| `hasNextPage` | boolean | Whether there are more pages | +| `hasPreviousPage` | boolean | Whether there is a previous page | + +**Access Control Filtering:** + +- **Admin users**: See all collections in the repository +- **Non-admin users**: See only collections where: + - User's groups intersect with collection's allowed groups, AND + - Collection is not private OR user is the creator + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 400 | Bad Request - Invalid parameters | `{"error": "Invalid sortBy value: invalid"}` | +| 403 | Forbidden - Insufficient permissions | `{"error": "Permission denied: User does not have read access to repository"}` | +| 404 | Not Found - Repository not found | `{"error": "Repository 'repo-123' not found"}` | +| 500 | Internal Server Error | `{"error": "Failed to list collections"}` | + +**Example cURL Request (Basic):** + +```bash +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example cURL Request (With Filters):** + +```bash +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections?pageSize=50&filter=legal&status=ACTIVE&sortBy=name&sortOrder=asc" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example cURL Request (Pagination):** + +```bash +# First page +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections?pageSize=20" \ + -H "Authorization: Bearer {YOUR_TOKEN}" + +# Next page (using lastEvaluatedKey from previous response) +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections?pageSize=20&lastEvaluatedKeyCollectionId=660e8400-e29b-41d4-a716-446655440001&lastEvaluatedKeyRepositoryId=repo-123&lastEvaluatedKeyCreatedAt=2025-10-12T15:20:00Z" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example Python Request:** + +```python +import requests +import urllib.parse + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}" +} + +# Basic request +response = requests.get(url, headers=headers) +if response.status_code == 200: + result = response.json() + collections = result['collections'] + print(f"Found {len(collections)} collections") + + for collection in collections: + print(f" - {collection['name']} ({collection['collectionId']})") + + # Check for more pages + if result['hasNextPage']: + print("More pages available") + last_key = result['lastEvaluatedKey'] + print(f"Next page token: {last_key}") +else: + print(f"Error: {response.status_code} - {response.text}") + +# Request with filters +params = { + "pageSize": 50, + "filter": "legal", + "status": "ACTIVE", + "sortBy": "name", + "sortOrder": "asc" +} +response = requests.get(url, headers=headers, params=params) +if response.status_code == 200: + result = response.json() + print(f"Filtered results: {len(result['collections'])} collections") + +# Pagination example +def get_all_collections(repository_id, page_size=20): + """Fetch all collections with pagination.""" + all_collections = [] + last_evaluated_key = None + + while True: + params = {"pageSize": page_size} + + # Add pagination token if available + if last_evaluated_key: + params["lastEvaluatedKeyCollectionId"] = last_evaluated_key["collectionId"] + params["lastEvaluatedKeyRepositoryId"] = last_evaluated_key["repositoryId"] + if "createdAt" in last_evaluated_key: + params["lastEvaluatedKeyCreatedAt"] = last_evaluated_key["createdAt"] + + response = requests.get( + f"https://{{API-GATEWAY-DOMAIN}}/{{STAGE}}/repository/{repository_id}/collections", + headers=headers, + params=params + ) + + if response.status_code != 200: + print(f"Error: {response.status_code} - {response.text}") + break + + result = response.json() + all_collections.extend(result['collections']) + + # Check if there are more pages + if not result['hasNextPage']: + break + + last_evaluated_key = result['lastEvaluatedKey'] + + return all_collections + +# Get all collections +all_collections = get_all_collections("repo-123") +print(f"Total collections: {len(all_collections)}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/repo-123/collections'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}' +}; + +// Basic request +fetch(url, { + method: 'GET', + headers: headers +}) + .then(response => { + if (response.status === 200) { + return response.json(); + } + throw new Error(`Error: ${response.status}`); + }) + .then(result => { + const collections = result.collections; + console.log(`Found ${collections.length} collections`); + + collections.forEach(collection => { + console.log(` - ${collection.name} (${collection.collectionId})`); + }); + + // Check for more pages + if (result.hasNextPage) { + console.log('More pages available'); + console.log('Next page token:', result.lastEvaluatedKey); + } + }) + .catch(error => { + console.error('Error:', error); + }); + +// Request with filters +const params = new URLSearchParams({ + pageSize: '50', + filter: 'legal', + status: 'ACTIVE', + sortBy: 'name', + sortOrder: 'asc' +}); + +fetch(`${url}?${params}`, { + method: 'GET', + headers: headers +}) + .then(response => response.json()) + .then(result => { + console.log(`Filtered results: ${result.collections.length} collections`); + }) + .catch(error => { + console.error('Error:', error); + }); + +// Pagination example +async function getAllCollections(repositoryId, pageSize = 20) { + const allCollections = []; + let lastEvaluatedKey = null; + + while (true) { + const params = new URLSearchParams({ pageSize: pageSize.toString() }); + + // Add pagination token if available + if (lastEvaluatedKey) { + params.append('lastEvaluatedKeyCollectionId', lastEvaluatedKey.collectionId); + params.append('lastEvaluatedKeyRepositoryId', lastEvaluatedKey.repositoryId); + if (lastEvaluatedKey.createdAt) { + params.append('lastEvaluatedKeyCreatedAt', lastEvaluatedKey.createdAt); + } + } + + const response = await fetch( + `https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/${repositoryId}/collections?${params}`, + { method: 'GET', headers: headers } + ); + + if (response.status !== 200) { + console.error(`Error: ${response.status}`); + break; + } + + const result = await response.json(); + allCollections.push(...result.collections); + + // Check if there are more pages + if (!result.hasNextPage) { + break; + } + + lastEvaluatedKey = result.lastEvaluatedKey; + } + + return allCollections; +} + +// Get all collections +getAllCollections('repo-123') + .then(collections => { + console.log(`Total collections: ${collections.length}`); + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +**Filtering Examples:** + +1. **Filter by name/description:** + ``` + GET /repository/repo-123/collections?filter=legal + ``` + Returns collections with "legal" in name or description + +2. **Filter by status:** + ``` + GET /repository/repo-123/collections?status=ACTIVE + ``` + Returns only active collections + +3. **Combined filters:** + ``` + GET /repository/repo-123/collections?filter=legal&status=ACTIVE + ``` + Returns active collections with "legal" in name or description + +**Sorting Examples:** + +1. **Sort by name (ascending):** + ``` + GET /repository/repo-123/collections?sortBy=name&sortOrder=asc + ``` + +2. **Sort by creation date (newest first):** + ``` + GET /repository/repo-123/collections?sortBy=createdAt&sortOrder=desc + ``` + +3. **Sort by last update (oldest first):** + ``` + GET /repository/repo-123/collections?sortBy=updatedAt&sortOrder=asc + ``` + +**Pagination Notes:** + +1. **Page Size**: Maximum 100 items per page. Default is 20. +2. **Total Count**: Only available when no filters are applied (for performance reasons) +3. **Pagination Token**: Use `lastEvaluatedKey` from response to get next page +4. **URL Encoding**: Ensure pagination token values are URL-encoded when constructing URLs + +## Inheritance Rules + +Collections inherit configuration from their parent vector store: + +1. **Embedding Model**: If not specified, inherits from parent vector store +2. **Allowed Groups**: If not specified or empty array, inherits from parent vector store +3. **Chunking Strategy**: If not specified, inherits from parent vector store's first pipeline +4. **Metadata**: Merged with parent vector store metadata (collection tags take precedence) + +## Validation Rules + +### Collection Name +- Required for creation +- Maximum 100 characters +- Must be unique within repository +- Allowed characters: alphanumeric, spaces, hyphens, underscores + +### Allowed Groups +- Must be subset of parent repository's allowed groups +- Empty array inherits from parent + +### Chunking Strategy Parameters +- `chunkSize`: 100-10000 +- `chunkOverlap`: 0 to chunkSize/2 +- `separators`: non-empty array for RECURSIVE strategy + +### Metadata Tags +- Maximum 50 tags per collection +- Each tag maximum 50 characters +- Allowed characters: alphanumeric, hyphens, underscores + +## Access Control + +### Permission Levels +- **Read**: View collection configuration, query documents +- **Write**: Upload documents, update collection metadata +- **Admin**: Delete collection, modify access control + +### Access Rules +1. Admin users have full access to all collections +2. Non-admin users must have group membership intersection with collection's allowed groups +3. Private collections are only accessible to creator and admins +4. Vector stores with `allowUserCollections: false` prevent non-admin collection creation + +## Best Practices + +1. **Use Descriptive Names**: Choose clear, descriptive names for collections to make them easy to identify +2. **Organize by Content Type**: Create separate collections for different document types (e.g., legal, technical, marketing) +3. **Optimize Chunking Strategy**: Select chunking strategies appropriate for your content: + - Use `FIXED` for uniform documents + - Use `SEMANTIC` for documents with clear semantic boundaries + - Use `RECURSIVE` for documents with hierarchical structure +4. **Manage Access Control**: Use allowed groups to restrict access to sensitive collections +5. **Use Private Collections**: Mark collections as private for personal or temporary collections +6. **Tag Collections**: Use metadata tags for easier filtering and organization + +## Troubleshooting + +### Common Errors + +**"Collection name must be unique within repository"** +- Solution: Choose a different name or check existing collections + +**"User does not have write access to repository"** +- Solution: Ensure user is in one of the repository's allowed groups or is an admin + +**"Allowed groups must be subset of parent repository groups"** +- Solution: Only specify groups that exist in the parent repository's allowed groups + +**"Chunk size must be between 100 and 10000"** +- Solution: Adjust chunk size to be within the valid range + +**"Cannot create collection: allowUserCollections is false"** +- Solution: Contact an administrator to enable user collections or have an admin create the collection + + +### List User Collections (Cross-Repository) + +List all collections the user has access to across all repositories. This endpoint aggregates collections from multiple repositories based on user permissions. + +**Endpoint:** `GET /repository/collections` + +**Query Parameters:** + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `pageSize` | integer | No | 20 | Number of items per page (max: 100) | +| `filter` | string | No | - | Text filter for name/description (substring match) | +| `sortBy` | enum | No | `createdAt` | Sort field: `name`, `createdAt`, or `updatedAt` | +| `sortOrder` | enum | No | `desc` | Sort order: `asc` or `desc` | +| `lastEvaluatedKey` | string | No | - | Pagination token (JSON string) from previous response | + +**Response (200 OK):** + +```json +{ + "collections": [ + { + "collectionId": "550e8400-e29b-41d4-a716-446655440000", + "repositoryId": "repo-123", + "repositoryName": "Legal Repository", + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "RECURSIVE", + "parameters": { + "chunkSize": 1000, + "chunkOverlap": 200, + "separators": ["\n\n", "\n", ". ", " "] + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "allowedGroups": ["legal-team", "compliance"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-456", + "createdAt": "2025-10-13T10:30:00Z", + "updatedAt": "2025-10-13T10:30:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] + }, + { + "collectionId": "660e8400-e29b-41d4-a716-446655440001", + "repositoryId": "repo-456", + "repositoryName": "Technical Repository", + "name": "Technical Documentation", + "description": "Collection for technical manuals and guides", + "chunkingStrategy": { + "type": "FIXED", + "parameters": { + "chunkSize": 1500, + "chunkOverlap": 300 + } + }, + "allowChunkingOverride": true, + "metadata": { + "tags": ["technical", "documentation"] + }, + "allowedGroups": ["engineering", "support"], + "embeddingModel": "amazon.titan-embed-text-v1", + "createdBy": "user-789", + "createdAt": "2025-10-12T15:20:00Z", + "updatedAt": "2025-10-12T15:20:00Z", + "status": "ACTIVE", + "private": false, + "pipelines": [] + } + ], + "pagination": { + "totalCount": null, + "currentPage": null, + "totalPages": null + }, + "lastEvaluatedKey": "{\"version\":\"v1\",\"offset\":20,\"filters\":{\"filter\":null,\"sortBy\":\"createdAt\",\"sortOrder\":\"desc\"}}", + "hasNextPage": true, + "hasPreviousPage": false +} +``` + +**Response Fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `collections` | array[object] | List of collection configurations with repository names | +| `collections[].repositoryName` | string | Name of the parent repository (enriched field) | +| `pagination` | object | Pagination metadata (totalCount not available for cross-repo queries) | +| `lastEvaluatedKey` | string | Pagination token for next page (JSON string, null if no more pages) | +| `hasNextPage` | boolean | Whether there are more pages | +| `hasPreviousPage` | boolean | Whether there is a previous page | + +**Access Control:** + +This endpoint implements a **repository-first permission model**: + +1. **Repository-Level Filtering**: First filters repositories based on user's group membership +2. **Collection-Level Filtering**: Then filters collections within accessible repositories +3. **Admin Access**: Admin users see all collections from all repositories +4. **Non-Admin Access**: Non-admin users see collections where: + - User's groups intersect with repository's allowed groups, AND + - User's groups intersect with collection's allowed groups (or collection inherits from repository), AND + - Collection is not private OR user is the creator + +**Pagination Strategy:** + +The endpoint automatically selects an appropriate pagination strategy based on dataset size: + +- **Simple Strategy** (<1000 collections): In-memory aggregation with offset-based pagination +- **Scalable Strategy** (1000+ collections): Incremental merge with per-repository cursors + +**Pagination Token Format:** + +The pagination token is a JSON string with two possible formats: + +**V1 Token (Simple Strategy):** +```json +{ + "version": "v1", + "offset": 20, + "filters": { + "filter": "legal", + "sortBy": "name", + "sortOrder": "asc" + } +} +``` + +**V2 Token (Scalable Strategy):** +```json +{ + "version": "v2", + "repositoryCursors": { + "repo-123": { + "lastEvaluatedKey": {"collectionId": "...", "repositoryId": "..."}, + "exhausted": false + }, + "repo-456": { + "lastEvaluatedKey": null, + "exhausted": true + } + }, + "globalOffset": 0, + "filters": { + "filter": null, + "sortBy": "createdAt", + "sortOrder": "desc" + } +} +``` + +**Error Responses:** + +| Status Code | Description | Example | +|-------------|-------------|---------| +| 400 | Bad Request - Invalid parameters | `{"error": "Invalid sortBy value. Must be one of: name, createdAt, updatedAt"}` | +| 401 | Unauthorized - Missing authentication | `{"error": "Authentication required"}` | +| 500 | Internal Server Error | `{"error": "Failed to retrieve collections"}` | + +**Example cURL Request (Basic):** + +```bash +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example cURL Request (With Filters):** + +```bash +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections?pageSize=50&filter=legal&sortBy=name&sortOrder=asc" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example cURL Request (Pagination):** + +```bash +# First page +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections?pageSize=20" \ + -H "Authorization: Bearer {YOUR_TOKEN}" + +# Next page (using lastEvaluatedKey from previous response) +NEXT_TOKEN='{"version":"v1","offset":20,"filters":{"filter":null,"sortBy":"createdAt","sortOrder":"desc"}}' +curl -X GET "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections?pageSize=20&lastEvaluatedKey=$(echo $NEXT_TOKEN | jq -R -r @uri)" \ + -H "Authorization: Bearer {YOUR_TOKEN}" +``` + +**Example Python Request:** + +```python +import requests +import json +import urllib.parse + +url = "https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections" +headers = { + "Authorization": "Bearer {YOUR_TOKEN}" +} + +# Basic request +response = requests.get(url, headers=headers) +if response.status_code == 200: + result = response.json() + collections = result['collections'] + print(f"Found {len(collections)} collections across all repositories") + + for collection in collections: + print(f" - {collection['name']} (Repository: {collection['repositoryName']})") + + # Check for more pages + if result['hasNextPage']: + print("More pages available") + next_token = result['lastEvaluatedKey'] + print(f"Next page token: {next_token}") +else: + print(f"Error: {response.status_code} - {response.text}") + +# Request with filters +params = { + "pageSize": 50, + "filter": "legal", + "sortBy": "name", + "sortOrder": "asc" +} +response = requests.get(url, headers=headers, params=params) +if response.status_code == 200: + result = response.json() + print(f"Filtered results: {len(result['collections'])} collections") + +# Pagination example +def get_all_user_collections(page_size=20): + """Fetch all accessible collections with pagination.""" + all_collections = [] + last_evaluated_key = None + + while True: + params = {"pageSize": page_size} + + # Add pagination token if available + if last_evaluated_key: + params["lastEvaluatedKey"] = last_evaluated_key + + response = requests.get( + f"https://{{API-GATEWAY-DOMAIN}}/{{STAGE}}/repository/collections", + headers=headers, + params=params + ) + + if response.status_code != 200: + print(f"Error: {response.status_code} - {response.text}") + break + + result = response.json() + all_collections.extend(result['collections']) + + # Check if there are more pages + if not result['hasNextPage']: + break + + last_evaluated_key = result['lastEvaluatedKey'] + + return all_collections + +# Get all accessible collections +all_collections = get_all_user_collections() +print(f"Total accessible collections: {len(all_collections)}") + +# Group by repository +from collections import defaultdict +by_repo = defaultdict(list) +for collection in all_collections: + by_repo[collection['repositoryName']].append(collection['name']) + +print("\nCollections by repository:") +for repo_name, coll_names in by_repo.items(): + print(f" {repo_name}: {len(coll_names)} collections") + for name in coll_names: + print(f" - {name}") +``` + +**Example JavaScript Request:** + +```javascript +const url = 'https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections'; +const headers = { + 'Authorization': 'Bearer {YOUR_TOKEN}' +}; + +// Basic request +fetch(url, { + method: 'GET', + headers: headers +}) + .then(response => { + if (response.status === 200) { + return response.json(); + } + throw new Error(`Error: ${response.status}`); + }) + .then(result => { + const collections = result.collections; + console.log(`Found ${collections.length} collections across all repositories`); + + collections.forEach(collection => { + console.log(` - ${collection.name} (Repository: ${collection.repositoryName})`); + }); + + // Check for more pages + if (result.hasNextPage) { + console.log('More pages available'); + console.log('Next page token:', result.lastEvaluatedKey); + } + }) + .catch(error => { + console.error('Error:', error); + }); + +// Request with filters +const params = new URLSearchParams({ + pageSize: '50', + filter: 'legal', + sortBy: 'name', + sortOrder: 'asc' +}); + +fetch(`${url}?${params}`, { + method: 'GET', + headers: headers +}) + .then(response => response.json()) + .then(result => { + console.log(`Filtered results: ${result.collections.length} collections`); + }) + .catch(error => { + console.error('Error:', error); + }); + +// Pagination example +async function getAllUserCollections(pageSize = 20) { + const allCollections = []; + let lastEvaluatedKey = null; + + while (true) { + const params = new URLSearchParams({ pageSize: pageSize.toString() }); + + // Add pagination token if available + if (lastEvaluatedKey) { + params.append('lastEvaluatedKey', lastEvaluatedKey); + } + + const response = await fetch( + `https://{API-GATEWAY-DOMAIN}/{STAGE}/repository/collections?${params}`, + { method: 'GET', headers: headers } + ); + + if (response.status !== 200) { + console.error(`Error: ${response.status}`); + break; + } + + const result = await response.json(); + allCollections.push(...result.collections); + + // Check if there are more pages + if (!result.hasNextPage) { + break; + } + + lastEvaluatedKey = result.lastEvaluatedKey; + } + + return allCollections; +} + +// Get all accessible collections +getAllUserCollections() + .then(collections => { + console.log(`Total accessible collections: ${collections.length}`); + + // Group by repository + const byRepo = collections.reduce((acc, collection) => { + const repoName = collection.repositoryName; + if (!acc[repoName]) { + acc[repoName] = []; + } + acc[repoName].push(collection.name); + return acc; + }, {}); + + console.log('\nCollections by repository:'); + Object.entries(byRepo).forEach(([repoName, collNames]) => { + console.log(` ${repoName}: ${collNames.length} collections`); + collNames.forEach(name => { + console.log(` - ${name}`); + }); + }); + }) + .catch(error => { + console.error('Error:', error); + }); +``` + +**Use Cases:** + +1. **Document Library UI**: Display all collections user can access in a single view +2. **Collection Browser**: Allow users to browse and search across all their collections +3. **Access Audit**: Verify which collections a user has access to +4. **Collection Discovery**: Help users find collections they didn't know existed + +**Performance Considerations:** + +- **Small Deployments** (<1000 collections): Response time ~100-200ms +- **Large Deployments** (1000+ collections): Response time ~200-500ms per page +- **Caching**: Consider caching results for frequently accessed data +- **Pagination**: Use appropriate page sizes (20-50) for optimal performance + +**Comparison with Single-Repository Endpoint:** + +| Feature | `/repository/{repositoryId}/collection` | `/repository/collections` | +|---------|----------------------------------------|---------------------------| +| Scope | Single repository | All accessible repositories | +| Repository Name | Not included | Included in response | +| Total Count | Available | Not available (performance) | +| Pagination | DynamoDB native | Hybrid (simple/scalable) | +| Use Case | Repository-specific operations | Cross-repository browsing | + +**Best Practices:** + +1. **Use Appropriate Page Size**: Start with default (20) and adjust based on UI needs +2. **Handle Pagination Tokens**: Store and pass tokens as opaque strings +3. **Filter Early**: Use filter parameter to reduce result set size +4. **Cache Results**: Cache responses for read-heavy workloads +5. **Monitor Performance**: Track response times for large datasets diff --git a/lib/docs/config/hosted-mcp.md b/lib/docs/config/hosted-mcp.md new file mode 100644 index 000000000..d5003843a --- /dev/null +++ b/lib/docs/config/hosted-mcp.md @@ -0,0 +1,253 @@ +# Hosted MCP Servers + +## Overview + +LISA MCP provides scalable infrastructure to support the deployment and hosting of first-party MCP servers and tools. +It is a stand-alone solution that can either be deployed independently of LISA Serve, or configured to work seamlessly +with LISA Serve. Each MCP server deployed via LISA MCP is provisioned on AWS Fargate via Amazon ECS, fronted by +Application/Network Load Balancers, and published through the existing API Gateway. This allows chat sessions to securely invoke +MCP tools without leaving your VPC. Every route remains protected by the same API Gateway Lambda authorizer that guards the rest +of LISA, so API Keys, IDP lockdown, and JWT group enforcement continue to apply automatically. Because the endpoints are +standard HTTP routes behind API Gateway, you can also share them with trusted third-party agents, copilots, or workflow engines +outside of LISA while preserving the same authentication store; you may issue API keys, short-lived JWTs, or IDP credentials, and those external consumers can use the MCP server just like LISA chat clients. The Create, Update, Delete workflows are orchestrated by +Step Functions and are auditable through DynamoDB status records. + +## Key Features + +- **Turn‑key hosting** – Deploy STDIO, HTTP, or SSE MCP servers with a single API/UI workflow +- **Dynamic container builds** – Bring a pre-built image or point to S3 artifacts that are turned into a container at deploy time +- **mcp-proxy support** – STDIO servers are automatically wrapped with `mcp-proxy` and exposed over HTTP +- **Auto scaling** – Configure Fargate min/max capacity, custom metrics, and scaling targets per server +- **Secure networking** – Private VPC networking with ALB for internal traffic and NLB + VPC Link for API Gateway access +- **Group-aware routing** – Limit server visibility to specific identity provider groups or make them public (`lisa:public`) +- **Lifecycle automation** – Step Functions manage provisioning, health polling, failure handling, and connection publishing +- **UI & API parity** – Manage servers through the MCP Management admin page or the `/mcp` REST endpoints +- **External integrations** – Exposed via API Gateway URLs so external copilots, RPA bots, or SaaS workloads can invoke + the hosted MCP server using the same credentials and auth controls you already enforce in LISA + +## Architecture + +### Workflow + +1. **Create request** – Admin issues `POST /{stage}/mcp` (or uses the UI) with a `HostedMcpServerModel` payload. +2. **DynamoDB record** – The Lambda API validates the payload, enforces unique normalized names, and writes the server + record with status `CREATING`. +3. **State machine** – The `CreateMcpServer` Step Functions workflow executes: + - `handle_set_server_to_creating` – persists status + - `handle_deploy_server` – calls the MCP server deployer Lambda with the sanitized config + - `handle_poll_deployment` – waits for the CloudFormation stack to finish + - `handle_add_server_to_active` – marks the record `IN_SERVICE` and (optionally) publishes a connection for the chat UI +4. **Deployer Lambda** – Synthesizes a dedicated CloudFormation stack that builds/launches an ECS Fargate service, + load balancers, VPC Link integration, and optional auto scaling targets. +5. **API Gateway** – Receives MCP traffic on `/mcp/{serverId}` and forwards through the VPC Link/NLB to the Fargate task. + +### Data Storage + +- **`MCP_SERVERS_TABLE`** – Primary metadata store (status, scaling config, networking details, groups). +- **`McpConnectionsTable`** (optional) – When `DEPLOYMENT_PREFIX` is set, completed servers are published here so the + chat application can surface them alongside externally hosted connections. +- **SSM** – Stores the Lisa API base URL (`LisaApiUrl`) and the optional hosted connections table name. + +### Networking + +- ECS tasks run inside your VPC using the same subnets/security groups as the MCP API stack. +- An **Application Load Balancer** fronts internal traffic while a **Network Load Balancer** terminates the API Gateway + VPC Link. +- STDIO servers always expose port `8080` (via `mcp-proxy`); HTTP/SSE servers use the configured `port` (default `8000`). +- API clients send JWTs; the MCP server receives `Authorization: Bearer {LISA_BEARER_TOKEN}`, which LISA replaces per user + when establishing a connection. +- API Gateway enforces the same Lambda authorizer used across LISA (JWT validation + optional API key checks). If + **API Key Required** or **IDP Lockdown** is enabled at the stage, hosted MCP endpoints automatically inherit those + protections—no extra configuration is necessary on the server itself. +- External consumers (agents, other apps, automation) call the same API Gateway URLs; simply provision them API keys or + federated identities and they gain access to the MCP server without any direct network connectivity to the VPC. + +## Prerequisites + +- Administrator access to LISA and the MCP Management UI. +- MCP Server Connections feature enabled (see [Model Context Protocol (MCP)](./mcp.md)). +- AWS resources created by `deploylisa` (state machines, MCP API stack, hosting bucket, etc.). +- S3 bucket path (if you need to sync binaries, Python files, or configuration at container start). +- Optional pre-built ECR image ARN or Docker Hub image reference (if not using dynamic builds). +- Identified server type (`stdio`, `http`, or `sse`) and the exact `startCommand` to launch it. +- IAM task execution/task roles if your container must call other AWS services (otherwise defaults are generated). + +## Deployment Flow + +1. **Prepare artifacts** + - Upload your MCP server files to the hosting bucket (e.g., `s3:///servers/my-server/`), or publish a container + image that already includes the server runtime. +2. **Send create request** + - Use the REST API or MCP Management UI to submit the configuration (see examples below). +3. **Monitor progress** + - The UI surfaces status transitions (`CREATING → IN_SERVICE`). For API-only workflows, poll `GET /{stage}/mcp/{id}` or + view the Step Functions execution in CloudWatch for stack details. +4. **Publish to users** + - Once `IN_SERVICE`, the server automatically appears on the MCP Management table. If groups are defined, only members + of those groups will see the connection by default. + +## Configuration Reference + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `name` | string | ✅ | Human-friendly name (must normalize to a unique alphanumeric identifier). | +| `description` | string | | Optional UI description. | +| `serverType` | `'stdio' \| 'http' \| 'sse'` | ✅ | Determines networking and entrypoint behavior. STDIO servers run behind `mcp-proxy`. | +| `startCommand` | string | ✅ | Command executed inside the container (e.g., `python server.py`). | +| `s3Path` | string | | Optional `bucket/path` pointing at artifacts to sync into `/app/server`. | +| `image` | string | | Container image to start from. If omitted, a base image is selected automatically based on `serverType`. | +| `port` | number | | TCP port exposed by HTTP/SSE servers. STDIO servers default to `8080`. | +| `cpu` | number | | Fargate CPU units (256, 512, 1024, 2048, 4096). Defaults to 256. | +| `memoryLimitMiB` | number | | Fargate memory in MiB (min 512, max 30720). Defaults to 512. | +| `autoScalingConfig.minCapacity` | number | ✅ | Minimum number of tasks (must be ≥ 1). | +| `autoScalingConfig.maxCapacity` | number | ✅ | Maximum number of tasks (must be ≥ minCapacity). | +| `autoScalingConfig.targetValue` | number | | Target metric value (e.g., requests per target). | +| `autoScalingConfig.metricName` | string | | Optional custom metric identifier. | +| `autoScalingConfig.duration` | number | | Scaling lookback window (seconds). | +| `autoScalingConfig.cooldown` | number | | Cooldown between scaling actions (seconds). | +| `loadBalancerConfig.healthCheckConfig.*` | object | | Optional ALB health check overrides (`path`, `interval`, `timeout`, `healthyThresholdCount`, `unhealthyThresholdCount`). | +| `containerHealthCheckConfig.*` | object | | Optional ECS task health checks (`command`, `interval`, `startPeriod`, `timeout`, `retries`). | +| `environment` | map | | Extra environment variables for your server. Avoid putting secrets here—use AWS Secrets Manager or SSM and reference them from your code. | +| `taskExecutionRoleArn` | string | | Execution role to pull private images / read from ECR or S3. Generated automatically if omitted. | +| `taskRoleArn` | string | | Task role used by your server code to call AWS APIs. Generated automatically if omitted. | +| `groups` | string[] | | Identity provider groups allowed to see/use the server. Prefix is added automatically if missing (`group:finance`). Empty/null defaults to `lisa:public`. | + +## Example: Create a Hosted MCP Server + +```bash +curl -X POST https://api.example.com/prod/mcp \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "docs-mcp-workbench", + "description": "Company knowledge base tools", + "serverType": "stdio", + "startCommand": "python main.py", + "autoScalingConfig": { "minCapacity": 1, "maxCapacity": 2, "targetValue": 80 }, + "s3Path": "servers/docs-mcp/", + "cpu": 512, + "memoryLimitMiB": 1024, + "environment": { + "TOOLS_DIR": "/app/server/tools", + "LOG_LEVEL": "info" + }, + "loadBalancerConfig": { + "healthCheckConfig": { + "path": "/health", + "interval": 30, + "timeout": 5, + "healthyThresholdCount": 2, + "unhealthyThresholdCount": 3 + } + }, + "containerHealthCheckConfig": { + "command": "curl -f http://localhost:8080/healthz", + "interval": 30, + "startPeriod": 10, + "timeout": 5, + "retries": 3 + }, + "groups": ["group:admins"] + }' +``` + +Response (truncated): + +```json +{ + "id": "3f5a…", + "name": "docs-mcp-workbench", + "status": "Creating", + "autoScalingConfig": { + "minCapacity": 1, + "maxCapacity": 2 + }, + "stack_name": null +} +``` + +## API Operations + +| Method & Path | Description | +|---------------|-------------| +| `POST /{stage}/mcp` | Create a hosted server (admin only). | +| `GET /{stage}/mcp` | List hosted servers (admin only). | +| `GET /{stage}/mcp/{serverId}` | Retrieve a single hosted server, including current status and stack info. | +| `PUT /{stage}/mcp/{serverId}` | Update auto scaling, environment, health checks, or enable/disable the service. Only servers in `IN_SERVICE` or `STOPPED` can be updated. | +| `DELETE /{stage}/mcp/{serverId}` | Begin the teardown workflow. Only `IN_SERVICE`, `STOPPED`, or `FAILED` servers can be deleted. | + +> **Tip:** The Update API accepts an `UpdateHostedMcpServerRequest` payload. Use the `enabled` flag to start/stop an +> existing server; supply `autoScalingConfig`, `environment`, or health-check fields to update those aspects. The request +> must include at least one field or validation will fail. + +## MCP Management UI Workflow + +1. **Navigate** – Select **Admin → MCP Management** in the top navigation (admin-only). +2. **Create** – Click **Create hosted MCP server** to open the wizard: + - **Server details** – Name, owner visibility (public vs. private), server type, start command, optional S3 path/image, + environment variables, and group assignments. + - **Scaling** – Min/max capacity, CPU/memory, optional custom metric target. + - **Health checks** – Configure ALB and container checks or accept the defaults. +3. **Review & launch** – Submit the form. A banner will display the provisioning status and surface Step Functions failure + messages if any arise. +4. **Manage** – Use the action bar to **Edit**, **Delete**, **Start**, or **Stop** selected servers. Bulk actions are + available when multiple rows are selected. +5. **Monitor** – Columns expose current status, stack name, owner, groups, and timestamps. Use the table preferences panel + to adjust visible columns or export the data. + +## Working with S3 Artifacts + +- Uploaded files are synced into `/app/server` before the `startCommand` runs. Ensure your script either resides at the + root of that directory or adjusts paths accordingly. +- Make your scripts executable (`chmod +x`) when copying to S3 if they are shell/binary files. +- If you provide both `image` and `s3Path`, the image is used as the base layer and the artifacts overwrite/add files at + runtime. This is useful for extending a golden image with per-server content. +- Grant the specified task role `s3:GetObject` permissions on the hosting bucket path. When using the default role, LISA + automatically injects the policy. + +## Best Practices + +- **Unique names** – Server names are normalized to alphanumeric characters for stack/resource naming. Choose descriptive + names and avoid collisions. +- **Secrets management** – Do not embed secrets in `environment`. Instead, fetch them from AWS Secrets Manager or SSM in + your server code. +- **Scaling guardrails** – Start with conservative `minCapacity` values and validate CPU/memory usage before increasing + `maxCapacity`. +- **Health checks** – Provide both container and load balancer health checks so the workflow can detect failures early. +- **Group scoping** – Restrict access to high-privilege tools by assigning identity provider groups at creation time. +- **Testing** – Use the [MCP Workbench](./mcp-workbench.md) to iterate on tools locally, then package the same files for + hosted deployment. +- **Monitoring** – Use CloudWatch metrics/alarms for the ECS service, Application Load Balancer, and the Step Functions + workflows to detect regressions quickly. + +## Troubleshooting + +### Create API returns *“CREATE_MCP_SERVER_SFN_ARN not configured”* +- **Cause:** Environment variables were not set when the MCP API Lambda was deployed. +- **Resolution:** Re-run `deploylisa` or manually set `CREATE_MCP_SERVER_SFN_ARN`, `DELETE_MCP_SERVER_SFN_ARN`, and + `UPDATE_MCP_SERVER_SFN_ARN` on the MCP API Lambda, then retry. + +### Error *“Server name conflicts with existing server”* +- **Cause:** Another record normalizes to the same alphanumeric identifier (e.g., `Docs-MCP` vs `docs_mcp`). +- **Resolution:** Choose a different name or delete the prior server before re-creating it. + +### Stack stuck in `CREATING` +- **Cause:** CloudFormation deployment failed (missing IAM roles, invalid container image, unreachable S3 path). +- **Resolution:** Inspect the `CreateMcpServer` Step Functions execution, then open the CloudFormation stack events to + identify the failing resource. Fix the underlying issue and re-run the create workflow. + +### Hosted server is `IN_SERVICE` but unreachable +- **Cause:** Incorrect `port`, health check, or security group settings. +- **Resolution:** Verify the ALB target group health, container logs, and that the application is listening on the + expected port. For STDIO servers, ensure the `startCommand` launches an MCP-compatible process that `mcp-proxy` can wrap. + +### Bearer token placeholder not replaced +- **Cause:** Custom headers still show `{LISA_BEARER_TOKEN}`. +- **Resolution:** The placeholder is replaced at connection time. Make sure the consuming application sends an + `Authorization` header when invoking the MCP connection. The API automatically replaces the placeholder right before + returning connection details. + +### Update API rejects payload +- **Cause:** The `UpdateHostedMcpServerRequest` validator requires at least one field; it also blocks simultaneous + enable/disable and auto scaling changes. +- **Resolution:** Split enable/disable operations from scaling updates, and include only the fields you intend to change. diff --git a/lib/docs/config/mcp.md b/lib/docs/config/mcp.md index b2a3abf03..0845bf147 100644 --- a/lib/docs/config/mcp.md +++ b/lib/docs/config/mcp.md @@ -1,8 +1,8 @@ # Model Context Protocol (MCP) LISA supports Model Context Protocol (MCP), a popular open standard that enables developers to securely connect AI -assistants to systems where data lives. Customers can connect MCP servers with LISA and use the tools hosted on that -server. For example, if an MCP server is added to LISA that supports email/calendar actions then LISA customers can +assistants to systems where data lives. Customers can connect 3rd Party and LISA hosted MCP servers within the chat UI and use the tools hosted on those +servers. For example, if an MCP server is added to LISA that supports email/calendar actions then LISA customers can prompt for supported tasks. In this case, customers could request help sending calendar invites on their behalf to specific colleagues based on everyone’s availability. The LLM would automatically engage the appropriate MCP server tools and perform the necessary steps to complete the task. diff --git a/lib/docs/config/repositories.md b/lib/docs/config/repositories.md new file mode 100644 index 000000000..2d82c2e0f --- /dev/null +++ b/lib/docs/config/repositories.md @@ -0,0 +1,293 @@ +# Retrieval-Augmented Generation (RAG) + +Retrieval-Augmented Generation (RAG) is a technique that enhances language models by combining retrieval and generation. Instead of relying solely on pre-trained knowledge, RAG first retrieves relevant external documents (e.g., from a database, search engine, or vector store) and then uses them to generate more accurate and context-aware responses. + +## RAG Repositories and Collections + +LISA RAG introduces a hierarchical architecture for managing RAG content through **repositories** and **collections**: + +- **Repository**: The top-level container that defines the underlying vector store implementation (OpenSearch, PGVector, or Bedrock Knowledge Base), embedding model, and access controls. Repositories are created and managed by administrators. Repository access can be restricted to specific enterprise groups. + +- **Collection**: Within repositories, collections support a logical grouping of documents. One repository can support many collections. Collection access can be restricted to specific enterprise groups. Collections enable flexible organization of content with their own chunking strategies, metadata tags, and access controls. Administrators create and manage collections via API or UI. Users can view and upload documents within a collection using LISA's Document Library and RAG file upload. + +### Architecture Overview + +The repository-collection model provides a two-tier organizational structure analogous to filing cabinets (repositories) containing organized drawers (collections). This architecture enables: + +- **Multi-Backend Support**: Unified interface across OpenSearch, PGVector, and AWS Bedrock Knowledge Base implementations +- **Configuration Isolation**: Each collection maintains independent chunking strategies, embedding models, and access controls +- **Scalable Organization**: Organize documents by department, project, content type, or security classification without infrastructure changes +- **Backward Compatibility**: Existing repositories automatically include a default collection based on the embedding model ID + +### Key Benefits + +- **Dynamic Management**: Create, update, and delete collections via API without infrastructure changes +- **Optimized Chunking**: Configure chunking strategies per collection to match content type (legal documents, code, customer support tickets) +- **Granular Access Control**: Enforce user group-based permissions at both the repository and collection level +- **Multi-tenancy**: Within repositories, further manage access by restricting collections access (e.g., by enterprise groups for specific organizations, departments, or teams) +- **Enhanced Metadata**: Tag documents with collection-specific metadata for powerful filtering +- **Flexible Embedding Models**: Each collection can use its own embedding model, optimizing retrieval for specific document types + +### Document Ingestion Methods + +Customers have two methods to load files into repositories configured with LISA: + +1. **Manual Upload**: Load files via the chat assistant user interface (UI), or API +2. **Automated Pipeline**: (Admins-only) Configure LISA's ingestion pipelines for automated document processing + +## Configuration + +### Chat Assistant UI + +Files loaded via the chat assistant UI are limited by size, and are processed through a batch job. The status of the job can be viewed within the RAG File Upload dialog. When uploading documents through the UI, you can select a specific collection within a repository. If no collection is specified, documents are ingested into the default collection, which defaults to the embedding model associated with the parent repository. + +### Automated Document Repository Ingestion Pipeline + +LISA's automated document ingestion pipeline supports larger files and broader file types. Supported file types include: PDF, docx, and plain text. The individual file size limit is 50 MB. LISA's pipelines offer chunking support for fixed size chunking or no chunking. For customers using Amazon Bedrock Knowledge Bases, LISA supports all chunking strategies offered by the service. LISA's automated ingestion pipelines provide customers with a flexible, scalable solution for loading documents into configured repositories and collections. + +Customers can set up multiple ingestion pipelines for a repository. For each pipeline they define: +- The target repository and collection +- Embedding model (inherited from repository if not defined) +- Chunking strategy (can be customized per pipeline) +- Ingestion trigger (event-based or daily schedule) +- S3 bucket and prefix to monitor + +Pipelines can be configured at both the repository level (for default collection ingestion) and at the collection level (for targeted ingestion). Each pipeline can run based on an event trigger or daily schedule. Pre-processing converts files into the necessary format, then processing ingests the files with the specified embedding model and loads the data into the designated collection within the repository. + +LISA also supports deleting files and content from repositories, as well as listing the file names and dates ingested. When `autoRemove` is enabled, deleting a document from the repository will also remove it from S3, and vice versa. + +#### Benefits + +The automated ingestion pipeline provides: + +1. **Flexibility**: Accommodates various data sources and formats +2. **Efficiency**: Streamlines the document ingestion process with pre-processing and intelligent indexing +3. **Customization**: Allows customers to choose and easily switch between preferred vector stores +4. **Integration**: Leverages existing LISA capabilities while extending functionality + +#### Use Cases + +Common use cases for automated ingestion include: + +- Large-scale document ingestion for enterprise customers +- Integration of external mission-critical data sources +- Customized knowledge base creation for specific industries or applications +- Department or project-specific document collections with isolated access +- Content-type optimized chunking strategies (legal, technical, conversational) + +> **_NOTE:_** Event ingestion requires Amazon EventBridge to be enabled on the S3 bucket. You can enable this in the bucket's properties configuration page in the AWS Management Console. + +## Managing Collections + +### Collection Lifecycle + +Collections can be created, updated, and deleted through the LISA UI or API. Each collection maintains: + +- **Chunking Strategy**: Optimized for the content type (fixed size or none) +- **Embedding Model**: Inherited from repository or customized per collection +- **Access Control**: User group restrictions inherited from the repository or customized per collection +- **Metadata Tags**: Custom tags for organizing and filtering documents +- **Privacy Settings**: Collections can be marked as private for restricted visibility +- **Ingestion Pipelines**: Dedicated pipelines for automated document ingestion + +Collections support flexible chunking configuration with multiple override levels: + +- **Default Strategy**: Inherited from the repository configuration +- **Collection Strategy**: Override at the collection level for content-specific optimization +- **Pipeline Strategy**: Further override at the ingestion pipeline level +- **API Override**: Optionally allow per-document chunking strategy via API (controlled by `allowChunkingOverride` flag) + +### Default Collections + +Every repository includes a default collection based on the embedding model ID. This ensures backward compatibility with existing LISA deployments (pre v6.0). When no collection is specified during document ingestion or retrieval, the default collection is used. + +Default collections provide: + +- **Automatic Creation**: Generated automatically during repository creation with no additional configuration +- **Zero Downtime Migration**: Existing documents remain accessible through default collections without database migrations +- **Optional Adoption**: Collections are completely optional—repositories continue to function without explicit collection configuration +- **Preserved Documents**: All existing documents remain accessible through default collections after upgrade + +### Document Lifecycle Management + +LISA implements intelligent document lifecycle management that respects how content is created and maintained: + +- **Ingestion Type Tracking**: The system distinguishes between LISA-managed documents, pipeline-generated content, and user-managed documents in Bedrock Knowledge Bases +- **Asynchronous Deletion**: Collection deletion operations execute asynchronously with optimized cleanup strategies per repository type: + - OpenSearch: Drops the entire index before document deletion + - PGVector: Drops the collection table/schema + - Bedrock Knowledge Base: Performs bulk document deletion +- **Document Preservation**: User-managed documents in Bedrock Knowledge Bases are automatically preserved during collection operations, ensuring external content is not inadvertently removed +- **Status Tracking**: Collections maintain status indicators (ACTIVE, DELETE_IN_PROGRESS, DELETE_FAILED) for monitoring lifecycle operations + +### Collection Permissions + +Collection access is controlled through user groups: + +- **Repository-level Groups**: Collections inherit allowed groups from their parent repository by default +- **Collection-level Groups**: Collections can override with their own group restrictions for finer control +- **Admin Access**: Administrators have full access to all collections across all repositories +- **User Collection Creation**: Repositories can be configured to allow or restrict user-created collections via the `allowUserCollections` flag + +## Configuration Examples + +RAG repositories and collections are configurable through the chat assistant web UI or programmatically via the API, allowing customers to tailor the ingestion process to their specific needs. + +### Creating a Repository + +Repositories are created by administrators and define the underlying vector store implementation, embedding model, and default access controls. + +#### Request Example: + +```bash +curl -s -H 'Authorization: Bearer ' -XPOST -d @repository.json https:///repository +``` + +```json +// repository.json +{ + "repositoryId": "my-rag-repository", + "repositoryName": "My RAG Repository", + "type": "pgvector", + "embeddingModelId": "amazon.titan-embed-text-v1", + "rdsConfig": { + "username": "postgres" + }, + "allowedGroups": ["engineering", "data-science"], + "metadata": { + "tags": ["production", "customer-docs"] + }, + "allowUserCollections": true, + "pipelines": [ + { + "chunkingStrategy": { + "type": "fixed", + "size": 512, + "overlap": 51 + }, + "trigger": "event", + "s3Bucket": "my-ingestion-bucket", + "s3Prefix": "documents/", + "autoRemove": true + } + ] +} +``` + +#### Response Fields: + +- `status`: "success" if the state machine was started successfully +- `executionArn`: The state machine ARN used to deploy the repository + +### Creating a Collection + +Collections can be created by users with appropriate permissions within an existing repository. + +#### Request Example: + +```bash +curl -s -H 'Authorization: Bearer ' -XPOST -d @collection.json https:///repository/my-rag-repository/collection +``` + +```json +// collection.json +{ + "name": "Legal Documents", + "description": "Collection for legal contracts and agreements", + "chunkingStrategy": { + "type": "fixed", + "size": 512, + "overlap": 51 + }, + "allowChunkingOverride": false, + "metadata": { + "tags": ["legal", "contracts", "confidential"] + }, + "allowedGroups": ["legal-team", "compliance"], + "private": true, + "pipelines": [ + { + "s3Bucket": "legal-docs-bucket", + "s3Prefix": "contracts/", + "trigger": "event", + "autoRemove": true + } + ] +} +``` + +#### Response Fields: + +- `collectionId`: Unique identifier for the created collection (UUID) +- `repositoryId`: Parent repository identifier +- `name`: User-friendly collection name +- `embeddingModel`: Inherited from parent repository +- `createdBy`: User ID of collection creator +- `createdAt`: Creation timestamp (ISO 8601) +- `status`: Collection status (ACTIVE) + +### Listing Collections + +Retrieve all collections accessible to the current user within a repository. + +#### Request Example: + +```bash +curl -s -H 'Authorization: Bearer ' \ + 'https:///repository/my-rag-repository/collections?page=1&pageSize=20&sortBy=name&sortOrder=asc' +``` + +#### Query Parameters: + +- `page`: Page number (default: 1) +- `pageSize`: Items per page (default: 20, max: 100) +- `filter`: Filter by name or description (optional) +- `sortBy`: Sort field - `name`, `createdAt`, or `updatedAt` (default: `createdAt`) +- `sortOrder`: Sort order - `asc` or `desc` (default: `desc`) + +## UI Components + +### RAG Repository Management (Admin) + +Administrators access repository management through the Admin Configurations page. This interface provides: + +- Create, update, and delete repositories +- Configure vector store implementation (OpenSearch, PGVector, Bedrock Knowledge Base) +- Set default embedding models and chunking strategies +- Define repository-level access controls +- Configure metadata tags +- Enable or disable user-created collections + +### RAG Collection Library + +The Collection Library is accessible from the Document Library page and provides: + +- Browse collections within accessible repositories +- Create new collections (if permitted) +- Update collection settings +- Delete collections (if permitted) +- View collection metadata and statistics +- Filter document collection + +Collections are organized in a tree structure, similar to folders, making it intuitive to navigate and manage documents. + +### Chat Interface + +The chat interface includes repository and collection selection: + +- Select a repository from available options +- Choose a specific collection within the repository +- Default collection is used if none specified +- Embedding model is automatically determined by the collection + +### Document Library + +The Document Library displays documents organized by collection: + +- Tree view showing repository → collection → documents hierarchy +- Filter and search within specific collections +- Upload documents to selected collections +- View document metadata including collection assignment +- Delete documents with optional S3 removal (when `autoRemove` is enabled) + + diff --git a/lib/docs/config/vector-stores.md b/lib/docs/config/vector-stores.md deleted file mode 100644 index 626c78f52..000000000 --- a/lib/docs/config/vector-stores.md +++ /dev/null @@ -1,79 +0,0 @@ -# Retrieval-Augmented Generation (RAG) - -Retrieval-Augmented Generation (RAG) is a technique that enhances language models by combining retrieval and generation. Instead of relying solely on pre-trained knowledge, RAG first retrieves relevant external documents (e.g., from a database, search engine, or vector store) and then uses them to generate more accurate and context-aware responses. - -Customers have two methods to load files into vector stores configured with LISA. Customers can either manually load files via the chatbot user interface (UI), or via an ingestion pipeline. - -## Configuration - -### Chat UI - -Files loaded via the chatbot UI are limited by Lambda's service limits on document file size and volume. - -### Automated Document Vector Store Ingestion Pipeline - -The Automated Document Ingestion Pipeline is designed to enhance LISA's RAG capabilities. Documents loaded via a pipeline are not subject to these limits, further expanding LISA’s ingestion capabilities. This pipeline feature supports the following document file types: PDF, docx, and plain text. The individual file size limit is 50 MB. - -This feature provides customers with a flexible, scalable solution for loading documents into configured vector stores. - -Customers can set up many ingestion pipelines. For each pipeline, they define the vector store and embedding model, and ingestion trigger. Each pipeline can be set up to run based on an event trigger, or to run daily. From there, pre-processing kicks off to convert files into the necessary format. From there, processing kicks off to ingest the files with the specified embedding model and loads the data into the designated vector store. This feature leverages LISA’s existing chunking and vectorizing capabilities. - -LISA also supports deleting files and content from a vector store, as well as listing the file names and dates ingested. - -### Benefits -1. **Flexibility**: Accommodates various data sources and formats -2. **Efficiency**: Streamlines the document ingestion process with pre-processing and intelligent indexing -3. **Customization**: Allows customers to choose and easily switch between preferred vector stores -4. **Integration**: Leverages existing LISA capabilities while extending functionality - -### Use Cases -- Large-scale document ingestion for enterprise customers -- Integration of external mission-critical data sources -- Customized knowledge base creation for specific industries or applications - -This new Automated Document Ingestion Pipeline significantly expands LISA's capabilities, providing customers with a powerful tool for managing and utilizing their document-based knowledge more effectively. - -> **_NOTE:_** Event ingestion requires Amazon EventBridge to be enabled on the S3 bucket. You can enable this in the bucket's properties configuration page in the AWS Management Console. -### Configuration Example - -RAG repositories and Automated Ingestion Pipelines are configurable through the chatbot web UI or programmatically via the API for managing RAG repositories, allowing customers to tailor the ingestion process to their specific needs. - -#### Request Example: - -```bash -curl -s -H 'Authorization: Bearer ' -XPOST -d @body.json https:///models -``` - -#### Response Example: - -```json -// body.json -{ - "ragConfig": { - "repositoryId": "my-vector-store", - "repositoryName": "My Vector Store", - "type": "pgvector", - "rdsConfig": { - "username": "postgres" - }, - "pipelines": [ - { - "chunkOverlap": 51, - "chunkSize": 256, - "embeddingModel": "titan-embed-text-v1", - "trigger": "event", - "s3Bucket": "my-ingestion-bucket", - "s3Prefix": "/some/path/to/watch" - } - ] - } -} -``` - -#### Explanation of Response Fields: - -- `status`: "success" if the state machine was started successfully. -- `executionArn`: The state machine ARN used to deploy the vector store. - - - diff --git a/lib/docs/package.json b/lib/docs/package.json index bbb871314..7d1b5e088 100644 --- a/lib/docs/package.json +++ b/lib/docs/package.json @@ -10,7 +10,8 @@ "docs:dev": "vitepress dev .", "docs:build": "vitepress build .", "docs:preview": "vitepress preview .", - "clean": "rm -rf ./dist ./node_modules" + "clean": "rm -rf ./dist ./node_modules", + "test": "echo \"No tests for documentation package\"" }, "author": "", "license": "Apache-2.0", diff --git a/lib/docs/user/breaking-changes.md b/lib/docs/user/breaking-changes.md index 264cb9af0..d160bf854 100644 --- a/lib/docs/user/breaking-changes.md +++ b/lib/docs/user/breaking-changes.md @@ -1,5 +1,29 @@ # Breaking Changes +## v6.0.0 + +Beginning with LISA v6.0.0, the API token table is no longer owned by the Serve stack—it's been moved into the API Base +stack so MCP hosting and future API workloads can scale independently. As part of this move the DynamoDB table is renamed +(`LisaServeTokenTable` → `LisaApiBaseTokenTable`). CloudFormation cannot migrate the data automatically, so **admins must +export all existing API keys before upgrading** and then create the corresponding records in the new table after the +deployment completes. If you rely on programmatic API access (admin keys, service accounts, automations, etc.), +make sure to capture the current values so they can be re-added once the new table exists. + +Additionally, the LISA management key secret has been moved from the Serve stack to the API Base stack, and the secret +name has changed from `${deploymentName}-lisa-management-key` to `${deploymentName}-management-key` (removed the +`lisa-` prefix). The new secret will be auto-generated with a new value during deployment. **If you have scripts, +automations, or integrations that reference the management key by its secret name, you must update them to use the new +name.** If you need to preserve the existing management key value, export it from AWS Secrets Manager before upgrading +and manually update the new secret after deployment completes. The SSM parameter `${deploymentPrefix}/appManagementKeySecretName` +will automatically point to the new secret name, so code that references the secret via this parameter will continue to +work without changes. + +Finally, existing Bedrock Knowledge Base repositories must be redeployed to support the new collections infrastructure. +This is a simple update operation that creates the necessary infrastructure for automatic data source collection creation. +**Use the repository update API or UI to redeploy existing Bedrock Knowledge Base repositories** after upgrading to +v6.0.0. This migration enables the new collections features for Bedrock Knowledge Base repositories, including +automatic collection creation for each data source and enhanced metadata capabilities. + ## v4.0.0 With the release of LISA v4.0, we introduced a significant update to the configuration and functionality of RAG diff --git a/lib/mcp/index.ts b/lib/mcp/index.ts new file mode 100644 index 000000000..323a9c06c --- /dev/null +++ b/lib/mcp/index.ts @@ -0,0 +1,40 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +import { Stack } from 'aws-cdk-lib'; +import { Construct } from 'constructs'; + +import { LisaMcpApiConstruct, LisaMcpApiProps } from './mcpApiConstruct'; + +export * from './mcpApiConstruct'; +export { McpServerDeployer } from './mcp-server-deployer'; +export { McpServerApi } from './mcp-server-api'; +export { CreateMcpServerStateMachine } from './state-machine/create-mcp-server'; + +/** + * Lisa MCP Server API Stack. + */ +export class LisaMcpApiStack extends Stack { + /** + * @param {Construct} scope - The parent or owner of the construct. + * @param {string} id - The unique identifier for the construct within its scope. + * @param {LisaMcpApiProps} props - Properties for the Stack. + */ + constructor (scope: Construct, id: string, props: LisaMcpApiProps) { + super(scope, id, props); + + (new LisaMcpApiConstruct(this, id + 'Resources', props)).node.addMetadata('aws:cdk:path', this.node.path); + } +} diff --git a/lib/mcp/mcp-server-api.ts b/lib/mcp/mcp-server-api.ts new file mode 100644 index 000000000..8d863cd9e --- /dev/null +++ b/lib/mcp/mcp-server-api.ts @@ -0,0 +1,490 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Cors, IAuthorizer, RestApi } from 'aws-cdk-lib/aws-apigateway'; +import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; +import { Effect, IRole, ManagedPolicy, Policy, PolicyDocument, PolicyStatement, Role, ServicePrincipal } from 'aws-cdk-lib/aws-iam'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { IFunction, LayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { Construct } from 'constructs'; + +import { getDefaultRuntime, registerAPIEndpoint } from '../api-base/utils'; +import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; +import { createCdkId, createLambdaRole } from '../core/utils'; +import { Vpc } from '../networking/vpc'; +import { LAMBDA_PATH } from '../util'; +import { McpServerDeployer } from './mcp-server-deployer'; +import { CreateMcpServerStateMachine } from './state-machine/create-mcp-server'; +import { DeleteMcpServerStateMachine } from './state-machine/delete-mcp-server'; +import { UpdateMcpServerStateMachine } from './state-machine/update-mcp-server'; +import { Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { RemovalPolicy } from 'aws-cdk-lib'; + +type McpServerApiProps = { + authorizer: IAuthorizer; + restApiId: string; + rootResourceId: string; + securityGroups: ISecurityGroup[]; + vpc: Vpc; +} & BaseProps; + +/** + * API for managing MCP server dynamic hosting infrastructure + */ +export class McpServerApi extends Construct { + readonly createStateMachineArn: string; + readonly deleteStateMachineArn: string; + readonly updateStateMachineArn: string; + readonly mcpServerDeployerFn: IFunction; + + constructor (scope: Construct, id: string, props: McpServerApiProps) { + super(scope, id); + + const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + + // Get common layer based on arn from SSM due to issues with cross stack references + const commonLambdaLayer = LayerVersion.fromLayerVersionArn( + this, + 'mcpserver-common-lambda-layer', + StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/common`), + ); + + const fastapiLambdaLayer = LayerVersion.fromLayerVersionArn( + this, + 'mcpserver-fastapi-lambda-layer', + StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/fastapi`), + ); + + const lambdaLayers = [commonLambdaLayer, fastapiLambdaLayer]; + + // Get management key name + const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`); + + const mcpServersTable = new dynamodb.Table(this, 'HostMcpServerTable', { + partitionKey: { + name: 'id', + type: dynamodb.AttributeType.STRING + }, + billingMode: dynamodb.BillingMode.PAY_PER_REQUEST, + encryption: dynamodb.TableEncryption.AWS_MANAGED, + removalPolicy: config.removalPolicy, + }); + + const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', + StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) + ); + + const bucket = new Bucket(scope, createCdkId(['LISA', 'MCP-Hosting', config.deploymentName, config.deploymentStage]), { + removalPolicy: config.removalPolicy, + autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, + enforceSSL: true, + cors: [ + { + allowedMethods: [HttpMethods.GET, HttpMethods.POST], + allowedHeaders: ['*'], + allowedOrigins: ['*'], + exposedHeaders: ['Access-Control-Allow-Origin'], + }, + ], + serverAccessLogsBucket: bucketAccessLogsBucket, + serverAccessLogsPrefix: 'logs/mcp-hosting-bucket/' + }); + + // Get reference to REST API first (will be reused) + const restApi = RestApi.fromRestApiAttributes(this, 'RestApi', { + restApiId: restApiId, + rootResourceId: rootResourceId, + }); + + // Create or get the /mcp resource explicitly to capture its ID + // This resource ID is needed for the deployer to reference it when creating server routes + // We do this before registerAPIEndpoint so we can pass the ID to the deployer + let mcpResource = restApi.root.getResource('mcp'); + if (!mcpResource) { + mcpResource = restApi.root.addResource('mcp'); + } + // Add CORS preflight support for the /mcp resource + // This ensures OPTIONS method is available even if the resource already existed + // addCorsPreflight is idempotent - it won't create duplicate OPTIONS methods + mcpResource.addCorsPreflight({ + allowOrigins: Cors.ALL_ORIGINS, + allowHeaders: Cors.DEFAULT_HEADERS, + }); + const mcpResourceId = mcpResource.resourceId; + + // Create MCP server deployer + // Pass authorizer ID so deployed servers can use the same authorizer + const authorizerId = authorizer.authorizerId; + const mcpServerDeployer = new McpServerDeployer(this, 'mcp-server-deployer', { + securityGroupId: vpc.securityGroups.ecsModelAlbSg.securityGroupId, + config: config, + vpc: vpc, + restApiId: restApiId, + rootResourceId: rootResourceId, + hostingBucketArn: bucket.bucketArn, + mcpResourceId: mcpResourceId, + authorizerId: authorizerId, + }); + + this.mcpServerDeployerFn = mcpServerDeployer.mcpServerDeployerFn; + + // Create state machine Lambda role + const stateMachinesLambdaRole = this.createStateMachineLambdaRole( + mcpServersTable.tableArn, + mcpServerDeployer.mcpServerDeployerFn.functionArn, + managementKeyName, + config + ); + + // Create state machine for creating MCP servers + const createMcpServerStateMachine = new CreateMcpServerStateMachine(this, 'CreateMcpServerWorkflow', { + config: config, + mcpServerTable: mcpServersTable, + lambdaLayers: lambdaLayers, + role: stateMachinesLambdaRole, + vpc: vpc, + securityGroups: securityGroups, + mcpServerDeployerFnArn: mcpServerDeployer.mcpServerDeployerFn.functionArn, + managementKeyName: managementKeyName, + }); + + this.createStateMachineArn = createMcpServerStateMachine.stateMachineArn; + + // Create state machine for deleting MCP servers + const deleteMcpServerStateMachine = new DeleteMcpServerStateMachine(this, 'DeleteMcpServerWorkflow', { + config: config, + mcpServerTable: mcpServersTable, + lambdaLayers: lambdaLayers, + role: stateMachinesLambdaRole, + vpc: vpc, + securityGroups: securityGroups, + }); + + this.deleteStateMachineArn = deleteMcpServerStateMachine.stateMachineArn; + + // Create state machine for updating MCP servers + const updateMcpServerStateMachine = new UpdateMcpServerStateMachine(this, 'UpdateMcpServerWorkflow', { + config: config, + mcpServerTable: mcpServersTable, + lambdaLayers: lambdaLayers, + role: stateMachinesLambdaRole, + vpc: vpc, + securityGroups: securityGroups, + }); + + this.updateStateMachineArn = updateMcpServerStateMachine.stateMachineArn; + + const env = { + MCP_SERVERS_TABLE_NAME: mcpServersTable.tableName, + CREATE_MCP_SERVER_SFN_ARN: createMcpServerStateMachine.stateMachineArn, + DELETE_MCP_SERVER_SFN_ARN: deleteMcpServerStateMachine.stateMachineArn, + UPDATE_MCP_SERVER_SFN_ARN: updateMcpServerStateMachine.stateMachineArn, + ADMIN_GROUP: config.authConfig?.adminGroup || '', + }; + + const lambdaRole = createLambdaRole(this, config.deploymentName, 'McpServerDynamicApi', mcpServersTable.tableArn, config.roles?.LambdaExecutionRole); + const lambdaPath = config.lambdaPath || LAMBDA_PATH; + + // Create the API Lambda function to trigger the MCP server create state machine + // Note: registerAPIEndpoint will use getOrCreateResource which will find the existing /mcp resource + const lambdaFunction = registerAPIEndpoint( + this, + restApi, + lambdaPath, + lambdaLayers, + { + name: 'create_hosted_mcp_server', + resource: 'mcp_server', + description: 'Create LISA MCP hosted server', + path: 'mcp', + method: 'POST', + environment: env + }, + getDefaultRuntime(), + vpc, + securityGroups, + authorizer, + lambdaRole, + ); + + // Register GET endpoint for listing hosted MCP servers + registerAPIEndpoint( + this, + restApi, + lambdaPath, + lambdaLayers, + { + name: 'list_hosted_mcp_servers', + resource: 'mcp_server', + description: 'List LISA MCP hosted servers', + path: 'mcp', + method: 'GET', + environment: env + }, + getDefaultRuntime(), + vpc, + securityGroups, + authorizer, + lambdaRole, + ); + + // Register GET endpoint for getting a specific hosted MCP server by ID + registerAPIEndpoint( + this, + restApi, + lambdaPath, + lambdaLayers, + { + name: 'get_hosted_mcp_server', + resource: 'mcp_server', + description: 'Get LISA MCP hosted server by ID', + path: 'mcp/{serverId}', + method: 'GET', + environment: env + }, + getDefaultRuntime(), + vpc, + securityGroups, + authorizer, + lambdaRole, + ); + + // Register DELETE endpoint for deleting a hosted MCP server by ID + registerAPIEndpoint( + this, + restApi, + lambdaPath, + lambdaLayers, + { + name: 'delete_hosted_mcp_server', + resource: 'mcp_server', + description: 'Delete LISA MCP hosted server by ID', + path: 'mcp/{serverId}', + method: 'DELETE', + environment: env + }, + getDefaultRuntime(), + vpc, + securityGroups, + authorizer, + lambdaRole, + ); + + // Register PUT endpoint for updating a hosted MCP server by ID + registerAPIEndpoint( + this, + restApi, + lambdaPath, + lambdaLayers, + { + name: 'update_hosted_mcp_server', + resource: 'mcp_server', + description: 'Update LISA MCP hosted server by ID', + path: 'mcp/{serverId}', + method: 'PUT', + environment: env + }, + getDefaultRuntime(), + vpc, + securityGroups, + authorizer, + lambdaRole, + ); + + // Grant permissions for state machine invocation + const workflowPermissions = new Policy(this, 'McpServerApiStateMachinePerms', { + statements: [ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'states:StartExecution', + ], + resources: [ + createMcpServerStateMachine.stateMachineArn, + deleteMcpServerStateMachine.stateMachineArn, + updateMcpServerStateMachine.stateMachineArn, + ], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'dynamodb:GetItem', + 'dynamodb:Scan', + 'dynamodb:PutItem', + 'dynamodb:UpdateItem', + 'dynamodb:DeleteItem', + ], + resources: [ + mcpServersTable.tableArn, + `${mcpServersTable.tableArn}/*` + ], + }), + ] + }); + lambdaFunction.role!.attachInlinePolicy(workflowPermissions); + } + + /** + * Creates a role for the state machine lambdas + * @param mcpServerTableArn - Arn of the MCP server table + * @param mcpServerDeployerFnArn - Arn of the MCP server deployer lambda + * @param managementKeyName - Name of the management key secret + * @param config - Config object + * @returns The created role + */ + createStateMachineLambdaRole (mcpServerTableArn: string, mcpServerDeployerFnArn: string, managementKeyName: string, config: any): IRole { + const statements: PolicyStatement[] = [ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'dynamodb:DeleteItem', + 'dynamodb:GetItem', + 'dynamodb:PutItem', + 'dynamodb:UpdateItem', + 'dynamodb:Scan', + ], + resources: [ + mcpServerTableArn, + `${mcpServerTableArn}/*`, + ] + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'cloudformation:CreateStack', + 'cloudformation:DeleteStack', + 'cloudformation:DescribeStacks', + 'cloudformation:DescribeStackResources', + ], + resources: [ + // Limit CloudFormation permissions to MCP server stacks that this deployment creates. + `arn:${config.partition}:cloudformation:${config.region}:${config.accountNumber}:stack/${config.appName}-${config.deploymentName}-${config.deploymentStage}-mcp-server-*`, + ], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ecs:DescribeTaskDefinition', + 'ecs:RegisterTaskDefinition', + 'ecs:UpdateService', + 'ecs:DescribeServices', + ], + resources: ['*'], // ECS resources are dynamic and created by CloudFormation + }), + // Allow passing task/execution roles to ECS when registering task definitions + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['iam:PassRole'], + resources: ['*'], + conditions: { + StringEquals: { + 'iam:PassedToService': 'ecs-tasks.amazonaws.com' + } + } + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'application-autoscaling:RegisterScalableTarget', + 'application-autoscaling:DescribeScalableTargets', + 'application-autoscaling:DeregisterScalableTarget', + ], + resources: ['*'], // Application Auto Scaling resources are dynamic + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'lambda:InvokeFunction' + ], + resources: [ + mcpServerDeployerFnArn + ] + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ec2:CreateNetworkInterface', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeSubnets', + 'ec2:DeleteNetworkInterface', + 'ec2:AssignPrivateIpAddresses', + 'ec2:UnassignPrivateIpAddresses' + ], + resources: ['*'], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ssm:GetParameter', + ], + resources: [ + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/lisaServeRestApiUri`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter/LISA-management-key`, + ], + }), + ]; + + // Add permissions for MCP Connections table if chat is deployed + // This table is created in the chat stack and stores user-facing MCP connections + if (config.deployChat) { + // Reference the table by ARN pattern (table will be created in chat stack) + // CDK generates table names, so we use a pattern that matches the naming convention + // Format: {StackName}-{ConstructId}{Hash} or {StackName}-{ConstructId}-{Hash} + // Stack name format: {deploymentName}-{appName}-chat-{deploymentStage} + // Construct ID: McpServersTable + const stackNamePattern = `${config.deploymentName}-${config.appName}-chat-${config.deploymentStage}`; + const mcpConnectionsTableArnPattern = `arn:${config.partition}:dynamodb:${config.region}:${config.accountNumber}:table/${stackNamePattern}-*McpServersTable*`; + statements.push( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'dynamodb:PutItem', + 'dynamodb:UpdateItem', + 'dynamodb:GetItem', + 'dynamodb:DeleteItem', + 'dynamodb:Scan', + ], + resources: [ + mcpConnectionsTableArnPattern, + ], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ssm:GetParameter', + ], + resources: [ + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/table/mcpServersTable`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/LisaApiUrl`, + ], + }) + ); + } + + return new Role(this, 'McpServerSfnLambdaRole', { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + inlinePolicies: { + lambdaPermissions: new PolicyDocument({ + statements: statements, + }), + } + }); + } +} diff --git a/lib/mcp/mcp-server-deployer.ts b/lib/mcp/mcp-server-deployer.ts new file mode 100644 index 000000000..a84f395d2 --- /dev/null +++ b/lib/mcp/mcp-server-deployer.ts @@ -0,0 +1,136 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { Construct } from 'constructs'; +import { IFunction, Runtime } from 'aws-cdk-lib/aws-lambda'; +import { + Effect, + IRole, + ManagedPolicy, + PolicyDocument, + PolicyStatement, + Role, + ServicePrincipal, +} from 'aws-cdk-lib/aws-iam'; +import { Duration, Size, Stack } from 'aws-cdk-lib'; + +import { createCdkId } from '../core/utils'; +import { BaseProps, Config } from '../schema'; +import { Vpc } from '../networking/vpc'; +import { NodejsFunction } from 'aws-cdk-lib/aws-lambda-nodejs'; +import { CodeFactory, MCP_SERVER_DEPLOYER_DIST_PATH } from '../util'; + +export type McpServerDeployerProps = { + securityGroupId: string; + config: Config; + vpc: Vpc; + restApiId: string; + rootResourceId: string; + hostingBucketArn: string; + mcpResourceId: string; + authorizerId?: string; +} & BaseProps; + +export class McpServerDeployer extends Construct { + readonly mcpServerDeployerFn: IFunction; + + constructor (scope: Construct, id: string, props: McpServerDeployerProps) { + super(scope, id); + const stackName = Stack.of(scope).stackName; + const { config } = props; + + const role = config.roles ? + Role.fromRoleName(this, createCdkId([stackName, 'ecs-model-deployer-role']), config.roles.McpServerDeployerRole) : + this.createRole(stackName); + + const stripped_config = { + 'appName': props.config.appName, + 'deploymentName': props.config.deploymentName, + 'deploymentPrefix': props.config.deploymentPrefix, + 'region': props.config.region, + 'deploymentStage': props.config.deploymentStage, + 'removalPolicy': props.config.removalPolicy, + 'subnets': props.config.subnets, + 'taskRole': props.config.roles?.ECSModelTaskRole, + 'certificateAuthorityBundle': props.config.certificateAuthorityBundle, + 'pypiConfig': props.config.pypiConfig, + }; + + const functionId = createCdkId([stackName, 'mcp_server_deployer', 'Fn']); + const mcpServerDeployerPath = config.mcpServerDeployerPath || MCP_SERVER_DEPLOYER_DIST_PATH; + this.mcpServerDeployerFn = new NodejsFunction(this, functionId, { + functionName: functionId, + code: CodeFactory.createCode(mcpServerDeployerPath), + timeout: Duration.minutes(10), + ephemeralStorageSize: Size.mebibytes(2048), + runtime: Runtime.NODEJS_18_X, + handler: 'index.handler', + memorySize: 1024, + role, + environment: { + 'LISA_VPC_ID': props.vpc.vpc.vpcId, + 'LISA_SECURITY_GROUP_ID': props.securityGroupId, + 'LISA_CONFIG': JSON.stringify(stripped_config), + 'LISA_REST_API_ID': props.restApiId, + 'LISA_ROOT_RESOURCE_ID': props.rootResourceId, + 'LISA_HOSTING_BUCKET_ARN': props.hostingBucketArn, + 'LISA_MCP_RESOURCE_ID': props.mcpResourceId, + ...(props.authorizerId && { 'LISA_AUTHORIZER_ID': props.authorizerId }), + }, + vpcSubnets: props.vpc.subnetSelection, + vpc: props.vpc.vpc, + securityGroups: [props.vpc.securityGroups.lambdaSg], + }); + } + + + /** + * Create MCP Server Deployer role + * @param stackName - deployment stack name + * @returns new role + */ + createRole (stackName: string): IRole { + return new Role(this, createCdkId([stackName, 'mcp-server-deployer-role']), { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + inlinePolicies: { + lambdaPermissions: new PolicyDocument({ + statements: [ + new PolicyStatement({ + actions: ['sts:AssumeRole'], + resources: ['arn:*:iam::*:role/cdk-*'], + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'ec2:CreateNetworkInterface', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeSubnets', + 'ec2:DeleteNetworkInterface', + 'ec2:AssignPrivateIpAddresses', + 'ec2:UnassignPrivateIpAddresses', + ], + resources: ['*'], + }), + ], + }), + + }, + }); + } +} diff --git a/lib/mcp/mcpApiConstruct.ts b/lib/mcp/mcpApiConstruct.ts new file mode 100644 index 000000000..2bcd5d408 --- /dev/null +++ b/lib/mcp/mcpApiConstruct.ts @@ -0,0 +1,59 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Stack, StackProps } from 'aws-cdk-lib'; +import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { Construct } from 'constructs'; + +import { Vpc } from '../networking/vpc'; +import { McpServerApi } from './mcp-server-api'; +import { BaseProps } from '../schema'; + +export type LisaMcpApiProps = BaseProps & + StackProps & { + authorizer: IAuthorizer; + restApiId: string; + rootResourceId: string; + securityGroups: ISecurityGroup[]; + vpc: Vpc; + }; + +/** + * Lisa MCP Server API Construct. + */ +export class LisaMcpApiConstruct extends Construct { + /** + * @param {Stack} scope - The parent or owner of the construct. + * @param {string} id - The unique identifier for the construct within its scope. + * @param {LisaMcpApiProps} props - Properties for the Stack. + */ + constructor (scope: Stack, id: string, props: LisaMcpApiProps) { + super(scope, id); + + const { authorizer, config, restApiId, rootResourceId, securityGroups, vpc } = props; + + // Add MCP Server API dynamic hosting + new McpServerApi(scope, 'McpServerApi', { + authorizer, + config, + restApiId, + rootResourceId, + securityGroups, + vpc, + }); + } +} diff --git a/lib/mcp/state-machine/constants.ts b/lib/mcp/state-machine/constants.ts new file mode 100644 index 000000000..f633ccf22 --- /dev/null +++ b/lib/mcp/state-machine/constants.ts @@ -0,0 +1,23 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { Duration } from 'aws-cdk-lib'; +import { WaitTime } from 'aws-cdk-lib/aws-stepfunctions'; + +export const LAMBDA_MEMORY: number = 128; +export const LAMBDA_TIMEOUT: Duration = Duration.seconds(60); +export const OUTPUT_PATH: string = '$.Payload'; +export const POLLING_TIMEOUT: WaitTime = WaitTime.duration(Duration.seconds(60)); diff --git a/lib/mcp/state-machine/create-mcp-server.ts b/lib/mcp/state-machine/create-mcp-server.ts new file mode 100644 index 000000000..40966aaf8 --- /dev/null +++ b/lib/mcp/state-machine/create-mcp-server.ts @@ -0,0 +1,191 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { + Choice, + Condition, + DefinitionBody, + Fail, + StateMachine, + Succeed, + Wait, +} from 'aws-cdk-lib/aws-stepfunctions'; +import { Construct } from 'constructs'; +import { Duration } from 'aws-cdk-lib'; +import { BaseProps } from '../../schema'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { Code, Function, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { IRole } from 'aws-cdk-lib/aws-iam'; +import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH, POLLING_TIMEOUT } from './constants'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { LambdaInvoke } from 'aws-cdk-lib/aws-stepfunctions-tasks'; +import { Vpc } from '../../networking/vpc'; +import { getDefaultRuntime } from '../../api-base/utils'; +import { LAMBDA_PATH } from '../../util'; + +type CreateMcpServerStateMachineProps = BaseProps & { + mcpServerTable: ITable, + lambdaLayers: ILayerVersion[]; + mcpServerDeployerFnArn: string; + vpc: Vpc, + securityGroups: ISecurityGroup[]; + managementKeyName: string; + role?: IRole, + executionRole?: IRole; +}; + +/** + * State Machine for creating MCP servers. + */ +export class CreateMcpServerStateMachine extends Construct { + readonly stateMachineArn: string; + + constructor (scope: Construct, id: string, props: CreateMcpServerStateMachineProps) { + super(scope, id); + + const { config, mcpServerTable, lambdaLayers, mcpServerDeployerFnArn, role, vpc, securityGroups, managementKeyName, executionRole } = props; + const lambdaPath = config.lambdaPath || LAMBDA_PATH; + const environment = { + MCP_SERVER_DEPLOYER_FN_ARN: mcpServerDeployerFnArn, + MCP_SERVERS_TABLE_NAME: mcpServerTable.tableName, + MANAGEMENT_KEY_NAME: managementKeyName, + RESTAPI_SSL_CERT_ARN: config.restApiConfig?.sslCertIamArn ?? '', + DEPLOYMENT_PREFIX: config.deploymentPrefix ?? '', + }; + + const setServerToCreating = new LambdaInvoke(this, 'SetServerToCreating', { + lambdaFunction: new Function(this, 'SetServerToCreatingFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.create_mcp_server.handle_set_server_to_creating', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const deployServer = new LambdaInvoke(this, 'DeployServer', { + lambdaFunction: new Function(this, 'DeployServerFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.create_mcp_server.handle_deploy_server', + code: Code.fromAsset(lambdaPath), + timeout: Duration.minutes(8), + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const pollDeployment = new LambdaInvoke(this, 'PollDeployment', { + lambdaFunction: new Function(this, 'PollDeploymentFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.create_mcp_server.handle_poll_deployment', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const pollDeploymentChoice = new Choice(this, 'PollDeploymentChoice'); + const waitBeforePolling = new Wait(this, 'WaitBeforePolling', { + time: POLLING_TIMEOUT, + }); + + const addServerToActive = new LambdaInvoke(this, 'AddServerToActive', { + lambdaFunction: new Function(this, 'AddServerToActiveFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.create_mcp_server.handle_add_server_to_active', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const handleFailureState = new LambdaInvoke(this, 'HandleFailure', { + lambdaFunction: new Function(this, 'HandleFailureFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.create_mcp_server.handle_failure', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const successState = new Succeed(this, 'CreateSuccess'); + const failState = new Fail(this, 'CreateFailed'); + + // State Machine definition + setServerToCreating.next(deployServer); + deployServer.addCatch(handleFailureState, { + errors: ['States.TaskFailed'], + }); + deployServer.next(pollDeployment); + pollDeployment.addCatch(handleFailureState, { + errors: ['States.TaskFailed'], + }); + pollDeployment.next(pollDeploymentChoice); + pollDeploymentChoice + .when(Condition.booleanEquals('$.continue_polling', true), waitBeforePolling) + .otherwise(addServerToActive); + waitBeforePolling.next(pollDeployment); + + // terminal states + handleFailureState.next(failState); + addServerToActive.next(successState); + + const stateMachine = new StateMachine(this, 'CreateMcpServerSM', { + definitionBody: DefinitionBody.fromChainable(setServerToCreating), + ...(executionRole && + { + role: executionRole + }) + }); + + this.stateMachineArn = stateMachine.stateMachineArn; + } +} diff --git a/lib/mcp/state-machine/delete-mcp-server.ts b/lib/mcp/state-machine/delete-mcp-server.ts new file mode 100644 index 000000000..e6be513cb --- /dev/null +++ b/lib/mcp/state-machine/delete-mcp-server.ts @@ -0,0 +1,172 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +import { Construct } from 'constructs'; +import { LambdaInvoke } from 'aws-cdk-lib/aws-stepfunctions-tasks'; +import { + Choice, + Condition, + DefinitionBody, + StateMachine, + Succeed, + Wait, +} from 'aws-cdk-lib/aws-stepfunctions'; +import { Code, Function, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { BaseProps } from '../../schema'; +import { IRole } from 'aws-cdk-lib/aws-iam'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH, POLLING_TIMEOUT } from '../../models/state-machine/constants'; +import { Vpc } from '../../networking/vpc'; +import { getDefaultRuntime } from '../../api-base/utils'; +import { LAMBDA_PATH } from '../../util'; + +type DeleteMcpServerStateMachineProps = BaseProps & { + mcpServerTable: ITable, + lambdaLayers: ILayerVersion[], + vpc: Vpc, + securityGroups: ISecurityGroup[]; + role?: IRole, + executionRole?: IRole; +}; + + +/** + * State Machine for deleting MCP servers. + */ +export class DeleteMcpServerStateMachine extends Construct { + readonly stateMachineArn: string; + + constructor (scope: Construct, id: string, props: DeleteMcpServerStateMachineProps) { + super(scope, id); + + const { config, mcpServerTable, lambdaLayers, role, vpc, securityGroups, executionRole } = props; + + const environment = { // Environment variables to set in all Lambda functions + MCP_SERVERS_TABLE_NAME: mcpServerTable.tableName, + DEPLOYMENT_PREFIX: config.deploymentPrefix ?? '', + }; + const lambdaPath = config.lambdaPath || LAMBDA_PATH; + + // Needs to return if server has a stack to delete. Updates server state to DELETING. + // Input payload to state machine contains the server ID that we want to delete. + const setServerToDeleting = new LambdaInvoke(this, 'SetServerToDeleting', { + lambdaFunction: new Function(this, 'SetServerToDeletingFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.delete_mcp_server.handle_set_server_to_deleting', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const deleteStack = new LambdaInvoke(this, 'DeleteStack', { + lambdaFunction: new Function(this, 'DeleteStackFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.delete_mcp_server.handle_delete_stack', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const monitorDeleteStack = new LambdaInvoke(this, 'MonitorDeleteStack', { + lambdaFunction: new Function(this, 'MonitorDeleteStackFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.delete_mcp_server.handle_monitor_delete_stack', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const deleteFromDdb = new LambdaInvoke(this, 'DeleteFromDdb', { + lambdaFunction: new Function(this, 'DeleteFromDdbFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.delete_mcp_server.handle_delete_from_ddb', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const successState = new Succeed(this, 'DeleteSuccess'); + + const deleteStackChoice = new Choice(this, 'DeleteStackChoice'); + const pollDeleteStackChoice = new Choice(this, 'PollDeleteStackChoice'); + const waitBeforePollingStackStatus = new Wait(this, 'WaitBeforePollDeleteStack', { + time: POLLING_TIMEOUT, + }); + + // State Machine definition + setServerToDeleting.next(deleteStackChoice); + + deleteStackChoice + .when(Condition.isNotNull('$.cloudformation_stack_arn'), deleteStack) + .otherwise(deleteFromDdb); + + deleteStack.next(monitorDeleteStack); + monitorDeleteStack.next(pollDeleteStackChoice); + + waitBeforePollingStackStatus.next(monitorDeleteStack); + + pollDeleteStackChoice + .when(Condition.booleanEquals('$.continue_polling', true), waitBeforePollingStackStatus) + .otherwise(deleteFromDdb); + + + deleteFromDdb.next(successState); + + const stateMachine = new StateMachine(this, 'DeleteMcpServerSM', { + definitionBody: DefinitionBody.fromChainable(setServerToDeleting), + ...(executionRole && + { + role: executionRole + }) + }); + + this.stateMachineArn = stateMachine.stateMachineArn; + } +} diff --git a/lib/mcp/state-machine/update-mcp-server.ts b/lib/mcp/state-machine/update-mcp-server.ts new file mode 100644 index 000000000..10d03e681 --- /dev/null +++ b/lib/mcp/state-machine/update-mcp-server.ts @@ -0,0 +1,211 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { BaseProps } from '../../schema'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { Code, Function, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { IRole } from 'aws-cdk-lib/aws-iam'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { Construct } from 'constructs'; +import { LambdaInvoke } from 'aws-cdk-lib/aws-stepfunctions-tasks'; +import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH, POLLING_TIMEOUT } from '../../models/state-machine/constants'; +import { + Choice, + Condition, + DefinitionBody, + StateMachine, + Succeed, + Wait, +} from 'aws-cdk-lib/aws-stepfunctions'; +import { Vpc } from '../../networking/vpc'; +import { getDefaultRuntime } from '../../api-base/utils'; +import { LAMBDA_PATH } from '../../util'; + +type UpdateMcpServerStateMachineProps = BaseProps & { + mcpServerTable: ITable, + lambdaLayers: ILayerVersion[], + vpc: Vpc, + securityGroups: ISecurityGroup[]; + role?: IRole, + executionRole?: IRole; +}; + + +/** + * State Machine for updating MCP servers. + */ +export class UpdateMcpServerStateMachine extends Construct { + readonly stateMachineArn: string; + + constructor (scope: Construct, id: string, props: UpdateMcpServerStateMachineProps) { + super(scope, id); + + const { + config, + mcpServerTable, + lambdaLayers, + role, + vpc, + securityGroups, + executionRole + } = props; + + const environment = { // Environment variables to set in all Lambda functions + MCP_SERVERS_TABLE_NAME: mcpServerTable.tableName, + DEPLOYMENT_PREFIX: config.deploymentPrefix ?? '', + }; + const lambdaPath = config.lambdaPath || LAMBDA_PATH; + const handleJobIntake = new LambdaInvoke(this, 'HandleJobIntake', { + lambdaFunction: new Function(this, 'HandleJobIntakeFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.update_mcp_server.handle_job_intake', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const handlePollCapacity = new LambdaInvoke(this, 'HandlePollCapacity', { + lambdaFunction: new Function(this, 'HandlePollCapacityFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.update_mcp_server.handle_poll_capacity', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const handleEcsUpdate = new LambdaInvoke(this, 'HandleEcsUpdate', { + lambdaFunction: new Function(this, 'HandleEcsUpdateFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.update_mcp_server.handle_ecs_update', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const handlePollEcsDeployment = new LambdaInvoke(this, 'HandlePollEcsDeployment', { + lambdaFunction: new Function(this, 'HandlePollEcsDeploymentFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.update_mcp_server.handle_poll_ecs_deployment', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const handleFinishUpdate = new LambdaInvoke(this, 'HandleFinishUpdate', { + lambdaFunction: new Function(this, 'HandleFinishUpdateFunc', { + runtime: getDefaultRuntime(), + handler: 'mcp_server.state_machine.update_mcp_server.handle_finish_update', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + // terminal states + const successState = new Succeed(this, 'UpdateSuccess'); + + // choice states + const hasEcsUpdateChoice = new Choice(this, 'HasEcsUpdateChoice'); + const hasCapacityUpdateChoice = new Choice(this, 'HasCapacityUpdateChoice'); + const pollAsgChoice = new Choice(this, 'PollAsgChoice'); + const pollEcsDeploymentChoice = new Choice(this, 'PollEcsDeploymentChoice'); + + // wait states + const waitBeforePollAsg = new Wait(this, 'WaitBeforePollAsg', { + time: POLLING_TIMEOUT + }); + const waitBeforePollEcsDeployment = new Wait(this, 'WaitBeforePollEcsDeployment', { + time: POLLING_TIMEOUT + }); + + // State Machine definition + handleJobIntake.next(hasEcsUpdateChoice); + + // ECS update flow + hasEcsUpdateChoice + .when(Condition.booleanEquals('$.needs_ecs_update', true), handleEcsUpdate) + .otherwise(hasCapacityUpdateChoice); + + handleEcsUpdate.next(handlePollEcsDeployment); + handlePollEcsDeployment.next(pollEcsDeploymentChoice); + pollEcsDeploymentChoice + .when(Condition.booleanEquals('$.should_continue_ecs_polling', true), waitBeforePollEcsDeployment) + .otherwise(hasCapacityUpdateChoice); + waitBeforePollEcsDeployment.next(handlePollEcsDeployment); + + // Capacity update flow + hasCapacityUpdateChoice + .when(Condition.booleanEquals('$.has_capacity_update', true), handlePollCapacity) + .otherwise(handleFinishUpdate); + + handlePollCapacity.next(pollAsgChoice); + pollAsgChoice.when(Condition.booleanEquals('$.should_continue_capacity_polling', true), waitBeforePollAsg) + .otherwise(handleFinishUpdate); + waitBeforePollAsg.next(handlePollCapacity); + + handleFinishUpdate.next(successState); + + const stateMachine = new StateMachine(this, 'UpdateMcpServerSM', { + definitionBody: DefinitionBody.fromChainable(handleJobIntake), + ...(executionRole && + { + role: executionRole + }) + }); + + this.stateMachineArn = stateMachine.stateMachineArn; + + } +} diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index 1689e0445..7f8d0fcf5 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -36,7 +36,7 @@ import { Provider } from 'aws-cdk-lib/custom-resources'; import { Construct } from 'constructs'; import { getDefaultRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../api-base/utils'; -import { BaseProps } from '../schema'; +import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; import { Vpc } from '../networking/vpc'; import { ECSModelDeployer } from './ecs-model-deployer'; @@ -143,7 +143,7 @@ export class ModelsApi extends Construct { vpc }); - const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`); + const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`); const stateMachineExecutionRole = config.roles ? { executionRole: Role.fromRoleName(this, Roles.MODEL_SFN_ROLE, config.roles.ModelSfnRole) } : @@ -533,7 +533,7 @@ export class ModelsApi extends Construct { resources: [ lisaServeEndpointUrlParamArn, `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/lisaServeRestApiUri`, - `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter/LISA-lisa-management-key`, + `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter/LISA-management-key`, `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/LiteLLMDbConnectionInfo`, `arn:${config.partition}:ssm:${config.region}:${config.accountNumber}:parameter${config.deploymentPrefix}/modelTableName`, ], diff --git a/lib/rag/api/repository.ts b/lib/rag/api/repository.ts index 84caa897b..404b92324 100644 --- a/lib/rag/api/repository.ts +++ b/lib/rag/api/repository.ts @@ -18,7 +18,7 @@ import { Duration } from 'aws-cdk-lib'; import { IAuthorizer, RestApi } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { IRole } from 'aws-cdk-lib/aws-iam'; -import { ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { IFunction, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; import { getDefaultRuntime, PythonLambdaFunction, registerAPIEndpoint } from '../../api-base/utils'; @@ -53,6 +53,8 @@ type RepositoryApiProps = { * API for RAG repository operations */ export class RepositoryApi extends Construct { + public createCollectionFunction: IFunction; + constructor (scope: Construct, id: string, props: RepositoryApiProps) { super(scope, id); @@ -116,20 +118,30 @@ export class RepositoryApi extends Construct { }, }, { - name: 'delete', + name: 'get_repository_by_id', resource: 'repository', - description: 'Delete a repository', + description: 'Get a repository by ID', path: 'repository/{repositoryId}', - method: 'DELETE', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'update_repository', + resource: 'repository', + description: 'Update a repository', + path: 'repository/{repositoryId}', + method: 'PUT', environment: { ...baseEnvironment, }, }, { - name: 'delete_index', + name: 'delete', resource: 'repository', - description: 'Delete an index within a repository', - path: 'repository/{repositoryId}/index/{modelName}', + description: 'Delete a repository', + path: 'repository/{repositoryId}', method: 'DELETE', environment: { ...baseEnvironment, @@ -166,6 +178,16 @@ export class RepositoryApi extends Construct { ...baseEnvironment, }, }, + { + name: 'get_document', + resource: 'repository', + description: 'Get a document by ID', + path: 'repository/{repositoryId}/{documentId}', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, { name: 'download_document', resource: 'repository', @@ -195,12 +217,92 @@ export class RepositoryApi extends Construct { environment: { ...baseEnvironment, }, + }, + { + name: 'list_collections', + resource: 'repository', + description: 'List all collections within a repository', + path: 'repository/{repositoryId}/collection', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'list_user_collections', + resource: 'repository', + description: 'List all collections user has access to across all repositories', + path: 'repository/collections', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'create_collection', + resource: 'repository', + description: 'Create a new collection within a repository', + path: 'repository/{repositoryId}/collection', + method: 'POST', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'get_collection', + resource: 'repository', + description: 'Get a collection by ID within a repository', + path: 'repository/{repositoryId}/collection/{collectionId}', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'update_collection', + resource: 'repository', + description: 'Update a collection within a repository', + path: 'repository/{repositoryId}/collection/{collectionId}', + method: 'PUT', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'delete_collection', + resource: 'repository', + description: 'Delete a collection within a repository', + path: 'repository/{repositoryId}/collection/{collectionId}', + method: 'DELETE', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'list_bedrock_knowledge_bases', + resource: 'repository', + description: 'List all ACTIVE Bedrock Knowledge Bases', + path: 'bedrock-kb', + method: 'GET', + environment: { + ...baseEnvironment, + }, + }, + { + name: 'list_bedrock_data_sources', + resource: 'repository', + description: 'List data sources for a Bedrock Knowledge Base', + path: 'bedrock-kb/{kbId}/data-sources', + method: 'GET', + environment: { + ...baseEnvironment, + }, } ]; const lambdaPath = config.lambdaPath || LAMBDA_PATH; apis.forEach((f) => { - registerAPIEndpoint( + const lambdaFunction = registerAPIEndpoint( this, restApi, lambdaPath, @@ -212,6 +314,11 @@ export class RepositoryApi extends Construct { authorizer, lambdaExecutionRole, ); + + // Capture create_collection Lambda for backward compatibility + if (f.name === 'create_collection') { + this.createCollectionFunction = lambdaFunction; + } }); } } diff --git a/lib/rag/ingestion/ingestion-job-construct.ts b/lib/rag/ingestion/ingestion-job-construct.ts index 918f5c210..038bbfacb 100644 --- a/lib/rag/ingestion/ingestion-job-construct.ts +++ b/lib/rag/ingestion/ingestion-job-construct.ts @@ -210,12 +210,16 @@ export class IngestionJobConstruct extends Construct { layers: layers, role: lambdaRole }); + const scheduleAlias = new lambda.Alias(this, 'ScheduleLambdaAlias', { + aliasName: 'live', + version: handlePipelineIngestScheduleLambda.currentVersion + }); const scheduleParameterName = `${config.deploymentPrefix}/ingestion/ingest/schedule`; new StringParameter(this, 'IngestionJobScheduleLambdaArn', { parameterName: scheduleParameterName, - stringValue: handlePipelineIngestScheduleLambda.functionArn + stringValue: scheduleAlias.functionArn }); - handlePipelineIngestScheduleLambda.addPermission('AllowEventBridgeInvoke', { + scheduleAlias.addPermission('AllowEventBridgeInvoke', { principal: new iam.ServicePrincipal('events.amazonaws.com'), action: 'lambda:InvokeFunction' }); @@ -233,12 +237,16 @@ export class IngestionJobConstruct extends Construct { layers, role: lambdaRole }); + const eventAlias = new lambda.Alias(this, 'EventLambdaAlias', { + aliasName: 'live', + version: handlePipelineIngestEvent.currentVersion + }); const eventParameterName = `${config.deploymentPrefix}/ingestion/ingest/event`; new StringParameter(this, 'IngestionJobEventLambdaArn', { parameterName: eventParameterName, - stringValue: handlePipelineIngestEvent.functionArn + stringValue: eventAlias.functionArn }); - handlePipelineIngestEvent.addPermission('AllowEventBridgeInvoke', { + eventAlias.addPermission('AllowEventBridgeInvoke', { principal: new iam.ServicePrincipal('events.amazonaws.com'), action: 'lambda:InvokeFunction' }); @@ -256,13 +264,17 @@ export class IngestionJobConstruct extends Construct { layers, role: lambdaRole }); + const deleteAlias = new lambda.Alias(this, 'DeleteLambdaAlias', { + aliasName: 'live', + version: handlePipelineDeleteEvent.currentVersion + }); const deleteParameterName = `${config.deploymentPrefix}/ingestion/delete/event`; new StringParameter(this, 'DeletionJobEventLambdaArn', { parameterName: deleteParameterName, - stringValue: handlePipelineDeleteEvent.functionArn + stringValue: deleteAlias.functionArn }); - handlePipelineDeleteEvent.addPermission('AllowEventBridgeInvoke', { + deleteAlias.addPermission('AllowEventBridgeInvoke', { principal: new iam.ServicePrincipal('events.amazonaws.com'), action: 'lambda:InvokeFunction' }); diff --git a/lib/rag/ragConstruct.ts b/lib/rag/ragConstruct.ts index 509287b82..d97894133 100644 --- a/lib/rag/ragConstruct.ts +++ b/lib/rag/ragConstruct.ts @@ -28,7 +28,7 @@ import { ARCHITECTURE } from '../core'; import { Layer } from '../core/layers'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; -import { BaseProps, Config, RDSConfig } from '../schema'; +import { APP_MANAGEMENT_KEY, BaseProps, Config, RDSConfig } from '../schema'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; import { SecurityGroupFactory } from '../networking/vpc/security-group-factory'; import { Roles } from '../core/iam/roles'; @@ -152,6 +152,62 @@ export class LisaRagConstruct extends Construct { removalPolicy: config.removalPolicy, }); + // Create Collections table + const collectionsTableName = createCdkId([config.deploymentName, 'RagCollectionsTable']); + const collectionsTable = new Table(scope, collectionsTableName, { + partitionKey: { + name: 'collectionId', + type: AttributeType.STRING, + }, + sortKey: { + name: 'repositoryId', + type: AttributeType.STRING + }, + billingMode: dynamodb.BillingMode.PAY_PER_REQUEST, + encryption: dynamodb.TableEncryption.AWS_MANAGED, + removalPolicy: config.removalPolicy, + timeToLiveAttribute: 'ttl', + }); + + // Add GSI for querying collections by repository + collectionsTable.addGlobalSecondaryIndex({ + indexName: 'RepositoryIndex', + partitionKey: { + name: 'repositoryId', + type: AttributeType.STRING, + }, + sortKey: { + name: 'createdAt', + type: AttributeType.STRING, + } + }); + + // Add GSI for filtering collections by status + collectionsTable.addGlobalSecondaryIndex({ + indexName: 'StatusIndex', + partitionKey: { + name: 'repositoryId', + type: AttributeType.STRING, + }, + sortKey: { + name: 'status', + type: AttributeType.STRING, + } + }); + + // Add GSI to document table for querying documents by collection + docMetaTable.addGlobalSecondaryIndex({ + indexName: 'CollectionIndex', + partitionKey: { + name: 'collectionId', + type: AttributeType.STRING, + }, + sortKey: { + name: 'createdAt', + type: AttributeType.STRING, + } + }); + const modelTableNameStringParameter = StringParameter.fromStringParameterName(this, 'ModelTableNameStringParameter', `${config.deploymentPrefix}/modelTableName`); const baseEnvironment: Record = { @@ -160,8 +216,9 @@ export class LisaRagConstruct extends Construct { CHUNK_OVERLAP: config.ragFileProcessingConfig!.chunkOverlap.toString(), CHUNK_SIZE: config.ragFileProcessingConfig!.chunkSize.toString(), LISA_API_URL_PS_NAME: endpointUrl.parameterName, + LISA_RAG_COLLECTIONS_TABLE: collectionsTable.tableName, LOG_LEVEL: config.logLevel, - MANAGEMENT_KEY_SECRET_NAME_PS: `${config.deploymentPrefix}/managementKeySecretName`, + MANAGEMENT_KEY_SECRET_NAME_PS: `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`, MODEL_TABLE_NAME: modelTableNameStringParameter.stringValue, RAG_DOCUMENT_TABLE: docMetaTable.tableName, RAG_SUB_DOCUMENT_TABLE: subDocTable.tableName, @@ -320,15 +377,6 @@ export class LisaRagConstruct extends Construct { layers, }); - new VectorStoreCreator(scope, 'VectorStoreCreatorStack', { - config, - vpc, - ragVectorStoreTable, - stackName: createCdkId([config.appName, config.deploymentName, config.deploymentStage, 'vectorstore-creator']), - baseEnvironment, - layers - }); - this.legacyRepositories( config, vpc, @@ -354,10 +402,20 @@ export class LisaRagConstruct extends Construct { lambdaExecutionRole: lambdaRole, }); + new VectorStoreCreator(scope, 'VectorStoreCreatorStack', { + config, + vpc, + ragVectorStoreTable, + stackName: createCdkId([config.appName, config.deploymentName, config.deploymentStage, 'vectorstore-creator']), + baseEnvironment, + layers + }); + modelsPs.grantRead(lambdaRole); endpointUrl.grantRead(lambdaRole); docMetaTable.grantReadWriteData(lambdaRole); subDocTable.grantReadWriteData(lambdaRole); + collectionsTable.grantReadWriteData(lambdaRole); } legacyRepositories ( diff --git a/lib/rag/state_machine/pipeline-state-machine.ts b/lib/rag/state_machine/pipeline-state-machine.ts new file mode 100644 index 000000000..790f783bf --- /dev/null +++ b/lib/rag/state_machine/pipeline-state-machine.ts @@ -0,0 +1,298 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { Construct } from 'constructs'; +import { Duration, Stack } from 'aws-cdk-lib'; +import { BaseProps } from '../../schema'; +import { ILayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { Effect, PolicyStatement, Role, ServicePrincipal } from 'aws-cdk-lib/aws-iam'; +import { Vpc } from '../../networking/vpc'; +import { ITable } from 'aws-cdk-lib/aws-dynamodb'; +import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { createCdkId } from '../../core/utils'; +import { Roles } from '../../core/iam/roles'; +import * as sfn from 'aws-cdk-lib/aws-stepfunctions'; +import * as tasks from 'aws-cdk-lib/aws-stepfunctions-tasks'; +import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH } from './constants'; +import { getDefaultRuntime } from '../../api-base/utils'; +import * as lambda from 'aws-cdk-lib/aws-lambda'; + +export type PipelineStateMachineProps = BaseProps & { + vpc: Vpc; + layers: ILayerVersion[]; + collectionsTable: ITable; + baseEnvironment: Record; +}; + +/** + * Pipeline State Machine for managing collection pipeline lifecycle. + * + * This construct creates a Step Functions state machine that orchestrates + * the creation, update, and deletion of EventBridge rules for collection pipelines. + */ +export class PipelineStateMachine extends Construct { + public readonly stateMachine: sfn.StateMachine; + public readonly stateMachineArn: string; + + constructor (scope: Construct, id: string, props: PipelineStateMachineProps) { + super(scope, id); + + const { config, vpc, layers, collectionsTable, baseEnvironment } = props; + const stack = Stack.of(this); + + // Get the Lambda execution role from SSM parameter + const lambdaExecutionRole = Role.fromRoleArn( + this, + createCdkId([Roles.RAG_LAMBDA_EXECUTION_ROLE, 'PipelineSM']), + StringParameter.valueForStringParameter( + this, + `${config.deploymentPrefix}/roles/${createCdkId([config.deploymentName, Roles.RAG_LAMBDA_EXECUTION_ROLE])}`, + ), + ); + + // Grant permissions for EventBridge rule management + lambdaExecutionRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'events:PutRule', + 'events:DeleteRule', + 'events:PutTargets', + 'events:RemoveTargets', + 'events:DescribeRule', + 'events:ListTargetsByRule', + 'events:ListRules' + ], + resources: [ + `arn:${config.partition}:events:${config.region}:${stack.account}:rule/${config.deploymentName}-*` + ] + })); + + // Grant permissions for Lambda invocation permissions + lambdaExecutionRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'lambda:AddPermission', + 'lambda:RemovePermission', + 'lambda:GetPolicy' + ], + resources: [ + `arn:${config.partition}:lambda:${config.region}:${stack.account}:function:${config.deploymentName}-*` + ] + })); + + // Grant DynamoDB permissions for collections table + collectionsTable.grantReadWriteData(lambdaExecutionRole); + + const lambdaEnvironment = { + ...baseEnvironment, + COLLECTIONS_TABLE_NAME: collectionsTable.tableName, + DEPLOYMENT_NAME: config.deploymentName, + DEPLOYMENT_STAGE: config.deploymentStage, + PARTITION: config.partition, + REGION: config.region + }; + + // Lambda function for input validation + const validateInputLambda = new lambda.Function(this, 'ValidateInputLambda', { + functionName: `${config.deploymentName}-${config.deploymentStage}-pipeline-validate-input`, + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.pipeline_validate_input.handler', + code: lambda.Code.fromAsset('./lambda'), + timeout: Duration.seconds(30), + memorySize: 256, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + environment: lambdaEnvironment, + layers: layers, + role: lambdaExecutionRole + }); + + // Lambda function for creating pipeline rules + const createPipelineRulesLambda = new lambda.Function(this, 'CreatePipelineRulesLambda', { + functionName: `${config.deploymentName}-${config.deploymentStage}-pipeline-create-rules`, + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.pipeline_create_rules.handler', + code: lambda.Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + environment: lambdaEnvironment, + layers: layers, + role: lambdaExecutionRole + }); + + // Lambda function for updating pipeline rules + const updatePipelineRulesLambda = new lambda.Function(this, 'UpdatePipelineRulesLambda', { + functionName: `${config.deploymentName}-${config.deploymentStage}-pipeline-update-rules`, + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.pipeline_update_rules.handler', + code: lambda.Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + environment: lambdaEnvironment, + layers: layers, + role: lambdaExecutionRole + }); + + // Lambda function for deleting pipeline rules + const deletePipelineRulesLambda = new lambda.Function(this, 'DeletePipelineRulesLambda', { + functionName: `${config.deploymentName}-${config.deploymentStage}-pipeline-delete-rules`, + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.pipeline_delete_rules.handler', + code: lambda.Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + environment: lambdaEnvironment, + layers: layers, + role: lambdaExecutionRole + }); + + // Lambda function for updating collection status + const updateCollectionStatusLambda = new lambda.Function(this, 'UpdateCollectionStatusLambda', { + functionName: `${config.deploymentName}-${config.deploymentStage}-pipeline-update-status`, + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.pipeline_update_status.handler', + code: lambda.Code.fromAsset('./lambda'), + timeout: Duration.seconds(30), + memorySize: 256, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + environment: lambdaEnvironment, + layers: layers, + role: lambdaExecutionRole + }); + + // Define Step Functions tasks + const validateInput = new tasks.LambdaInvoke(this, 'Validate Input', { + lambdaFunction: validateInputLambda, + outputPath: OUTPUT_PATH, + resultPath: '$.validationResult' + }); + + const createRules = new tasks.LambdaInvoke(this, 'Create Pipeline Rules', { + lambdaFunction: createPipelineRulesLambda, + outputPath: OUTPUT_PATH, + resultPath: '$.createResult' + }); + + const updateRules = new tasks.LambdaInvoke(this, 'Update Pipeline Rules', { + lambdaFunction: updatePipelineRulesLambda, + outputPath: OUTPUT_PATH, + resultPath: '$.updateResult' + }); + + const deleteRules = new tasks.LambdaInvoke(this, 'Delete Pipeline Rules', { + lambdaFunction: deletePipelineRulesLambda, + outputPath: OUTPUT_PATH, + resultPath: '$.deleteResult' + }); + + const updateStatusSuccess = new tasks.LambdaInvoke(this, 'Update Status - Success', { + lambdaFunction: updateCollectionStatusLambda, + payload: sfn.TaskInput.fromObject({ + 'repositoryId.$': '$.repositoryId', + 'collectionId.$': '$.collectionId', + 'status': 'ACTIVE' + }), + outputPath: OUTPUT_PATH + }); + + const updateStatusFailed = new tasks.LambdaInvoke(this, 'Update Status - Failed', { + lambdaFunction: updateCollectionStatusLambda, + payload: sfn.TaskInput.fromObject({ + 'repositoryId.$': '$.repositoryId', + 'collectionId.$': '$.collectionId', + 'status': 'PIPELINE_FAILED', + 'error.$': '$.error' + }), + outputPath: OUTPUT_PATH + }); + + const succeed = new sfn.Succeed(this, 'Pipeline Operation Succeeded'); + const fail = new sfn.Fail(this, 'Pipeline Operation Failed', { + cause: 'Pipeline operation failed', + error: 'PipelineOperationError' + }); + + // Define operation routing choice + const operationChoice = new sfn.Choice(this, 'Route by Operation') + .when(sfn.Condition.stringEquals('$.operation', 'CREATE'), createRules) + .when(sfn.Condition.stringEquals('$.operation', 'UPDATE'), updateRules) + .when(sfn.Condition.stringEquals('$.operation', 'DELETE'), deleteRules) + .otherwise(fail); + + // Define state machine workflow + const definition = validateInput + .next(operationChoice); + + createRules + .addCatch(updateStatusFailed.next(fail), { + resultPath: '$.error' + }) + .next(updateStatusSuccess) + .next(succeed); + + updateRules + .addCatch(updateStatusFailed.next(fail), { + resultPath: '$.error' + }) + .next(updateStatusSuccess) + .next(succeed); + + deleteRules + .addCatch(updateStatusFailed.next(fail), { + resultPath: '$.error' + }) + .next(updateStatusSuccess) + .next(succeed); + + // Create IAM role for Step Functions + const stateMachineRole = new Role(this, 'PipelineStateMachineRole', { + assumedBy: new ServicePrincipal('states.amazonaws.com'), + description: 'Role for Pipeline State Machine' + }); + + // Grant permissions to invoke Lambda functions + validateInputLambda.grantInvoke(stateMachineRole); + createPipelineRulesLambda.grantInvoke(stateMachineRole); + updatePipelineRulesLambda.grantInvoke(stateMachineRole); + deletePipelineRulesLambda.grantInvoke(stateMachineRole); + updateCollectionStatusLambda.grantInvoke(stateMachineRole); + + // Create the state machine + this.stateMachine = new sfn.StateMachine(this, 'PipelineStateMachine', { + stateMachineName: `${config.deploymentName}-${config.deploymentStage}-pipeline-state-machine`, + definition, + role: stateMachineRole, + timeout: Duration.minutes(30), + tracingEnabled: true + }); + + this.stateMachineArn = this.stateMachine.stateMachineArn; + + // Store state machine ARN in SSM for Collection API to use + new StringParameter(this, 'PipelineStateMachineArnParameter', { + parameterName: `${config.deploymentPrefix}/pipeline/statemachine/arn`, + stringValue: this.stateMachineArn, + description: 'ARN of the Pipeline State Machine' + }); + } +} diff --git a/lib/rag/vector-store/state_machine/create-store.ts b/lib/rag/vector-store/state_machine/create-store.ts index 8b2d30131..185fb5ab7 100644 --- a/lib/rag/vector-store/state_machine/create-store.ts +++ b/lib/rag/vector-store/state_machine/create-store.ts @@ -14,7 +14,7 @@ limitations under the License. */ import { Construct } from 'constructs'; -import { BaseProps } from '../../../schema'; +import { BaseProps, VectorStoreStatus, } from '../../../schema'; import * as ddb from 'aws-cdk-lib/aws-dynamodb'; import * as lambda from 'aws-cdk-lib/aws-lambda'; import * as iam from 'aws-cdk-lib/aws-iam'; @@ -25,6 +25,7 @@ import { Vpc } from '../../../networking/vpc'; import * as ssm from 'aws-cdk-lib/aws-ssm'; type CreateStoreStateMachineProps = BaseProps & { + createBedrockCollectionFnArn: string; executionRole: iam.IRole; parameterName: string, role?: iam.IRole, @@ -40,21 +41,26 @@ export class CreateStoreStateMachine extends Construct { constructor (scope: Construct, id: string, props: CreateStoreStateMachineProps) { super(scope, id); - const { config, executionRole, parameterName, role, vectorStoreConfigTable, vectorStoreDeployerFnArn } = props; + const { config, createBedrockCollectionFnArn, executionRole, parameterName, role, vectorStoreConfigTable, vectorStoreDeployerFnArn } = props; + + // Get reference to the Bedrock collection creation Lambda + const createBedrockCollectionFn = lambda.Function.fromFunctionArn( + this, + 'CreateBedrockCollectionFunction', + createBedrockCollectionFnArn + ); // Task to create an entry in DynamoDB for the vector store const createVectorStoreEntry = new tasks.DynamoPutItem(this, 'CreateVectorStoreEntry', { table: vectorStoreConfigTable, item: { repositoryId: tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.body.ragConfig.repositoryId')), - status: tasks.DynamoAttributeValue.fromString('CREATE_IN_PROGRESS'), + status: tasks.DynamoAttributeValue.fromString(VectorStoreStatus.CREATE_IN_PROGRESS), config: tasks.DynamoAttributeValue.mapFromJsonPath('$.config') }, resultPath: '$.dynamoResult', }); - const createVectorStoreInfraChoice = new sfn.Choice(this, 'CreateVectorStoreInfraChoice'); - // Task to invoke a Lambda function to deploy the vector store const deployVectorStore = new tasks.LambdaInvoke(this, 'DeployVectorStore', { lambdaFunction: lambda.Function.fromFunctionArn(this, 'VectorStoreDeployer', vectorStoreDeployerFnArn), @@ -68,7 +74,7 @@ export class CreateStoreStateMachine extends Construct { }); // Task to check the deployment status using a Lambda function - const checkDeploymentStatus = new tasks.CallAwsService(this, 'DescribeStack', { + const checkDeploymentStatus = new tasks.CallAwsService(this, 'CheckDeploymentStatus', { service: 'cloudformation', action: 'describeStacks', parameters: { @@ -87,26 +93,25 @@ export class CreateStoreStateMachine extends Construct { time: sfn.WaitTime.duration(Duration.seconds(30)), }); - // Task to update the status of the vector store entry to 'COMPLETED' on successful deployment - const updateBedrockKBSuccess = new tasks.DynamoUpdateItem(this, 'UpdateBedrockKBSuccess', { - table: vectorStoreConfigTable, - key: { repositoryId: tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.body.ragConfig.repositoryId')) }, - updateExpression: 'SET #status = :status', - expressionAttributeNames: { '#status': 'status' }, - expressionAttributeValues: { - ':status': tasks.DynamoAttributeValue.fromString('CREATE_COMPLETE') - }, + // Task to create default collection for Bedrock KB + const createDefaultCollectionTask = new tasks.LambdaInvoke(this, 'CreateDefaultCollection', { + lambdaFunction: createBedrockCollectionFn, + payload: sfn.TaskInput.fromObject({ + ragConfig: sfn.JsonPath.objectAt('$.body.ragConfig'), + }), + resultPath: '$.collectionResult', }); // Task to update the status of the vector store entry to 'COMPLETED' on successful deployment + // For Bedrock KB without pipelines, stackName may be null const updateSuccessStatus = new tasks.DynamoUpdateItem(this, 'UpdateSuccessStatus', { table: vectorStoreConfigTable, key: { repositoryId: tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.body.ragConfig.repositoryId')) }, updateExpression: 'SET #status = :status, #stackName = :stackName', expressionAttributeNames: { '#status': 'status', '#stackName': 'stackName' }, expressionAttributeValues: { - ':status': tasks.DynamoAttributeValue.fromString('CREATE_COMPLETE'), - ':stackName': tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.deployResult.stackName') ?? '') + ':status': tasks.DynamoAttributeValue.fromString(VectorStoreStatus.CREATE_COMPLETE), + ':stackName': tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.deployResult.stackName')) }, }); @@ -114,47 +119,59 @@ export class CreateStoreStateMachine extends Construct { const updateFailureStatus = new tasks.DynamoUpdateItem(this, 'UpdateFailureStatus', { table: vectorStoreConfigTable, key: { repositoryId: tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.body.ragConfig.repositoryId')) }, - updateExpression: 'SET #status = :status, #stackName = :stackName', - expressionAttributeNames: { '#status': 'status', '#stackName': 'stackName' }, + updateExpression: 'SET #status = :status', + expressionAttributeNames: { '#status': 'status' }, expressionAttributeValues: { - ':status': tasks.DynamoAttributeValue.fromString('CREATE_FAILED'), - ':stackName': tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.deployResult.stackName')) + ':status': tasks.DynamoAttributeValue.fromString(VectorStoreStatus.CREATE_FAILED) }, }); + // Check if this is a Bedrock KB repository to create default collections + const skipCollectionCreation = new sfn.Pass(this, 'SkipCollectionCreation'); + + const checkIfBedrockKB = new sfn.Choice(this, 'IsBedrockKB?') + .when( + sfn.Condition.stringEquals('$.body.ragConfig.type', 'bedrock_knowledge_base'), + createDefaultCollectionTask + ) + .otherwise(skipCollectionCreation); + + // Both paths converge to update success status + createDefaultCollectionTask.next(updateSuccessStatus); + skipCollectionCreation.next(updateSuccessStatus); + // Define the sequence of tasks and conditions in the state machine + const deploymentComplete = new sfn.Choice(this, 'DeploymentComplete?') + .when( + sfn.Condition.and( + sfn.Condition.isPresent('$.deployResult.status'), + sfn.Condition.or( + sfn.Condition.stringEquals('$.deployResult.status', VectorStoreStatus.CREATE_IN_PROGRESS), + sfn.Condition.stringEquals('$.deployResult.status', VectorStoreStatus.UPDATE_IN_PROGRESS), + sfn.Condition.stringEquals('$.deployResult.status', VectorStoreStatus.UPDATE_COMPLETE_CLEANUP_IN_PROGRESS), + ), + ), + wait.next(checkDeploymentStatus) + ) + .when( + sfn.Condition.and( + sfn.Condition.isPresent('$.deployResult.status'), + sfn.Condition.or( + sfn.Condition.stringEquals('$.deployResult.status', VectorStoreStatus.CREATE_COMPLETE), + sfn.Condition.stringEquals('$.deployResult.status', VectorStoreStatus.UPDATE_COMPLETE), + ), + ), + checkIfBedrockKB + ) + .otherwise(updateFailureStatus); + + checkDeploymentStatus.next(deploymentComplete); + const definition = createVectorStoreEntry - .next(createVectorStoreInfraChoice - .when(sfn.Condition.and(sfn.Condition.stringEquals('$.body.ragConfig.type', 'bedrock_knowledge_base'), - sfn.Condition.isNotPresent('$.body.ragConfig.pipelines[0]')), updateBedrockKBSuccess) - .otherwise(deployVectorStore.addCatch(updateFailureStatus) - .next( - checkDeploymentStatus.next( - new sfn.Choice(this, 'DeploymentComplete?') - .when( - sfn.Condition.and( - sfn.Condition.isPresent('$.deployResult.status'), - sfn.Condition.or( - sfn.Condition.stringEquals('$.deployResult.status', 'CREATE_IN_PROGRESS'), - sfn.Condition.stringEquals('$.deployResult.status', 'UPDATE_IN_PROGRESS'), - sfn.Condition.stringEquals('$.deployResult.status', 'UPDATE_COMPLETE_CLEANUP_IN_PROGRESS'), - ), - ), - wait.next(checkDeploymentStatus) - ) - .when( - sfn.Condition.and( - sfn.Condition.isPresent('$.deployResult.status'), - sfn.Condition.or( - sfn.Condition.stringEquals('$.deployResult.status', 'CREATE_COMPLETE'), - sfn.Condition.stringEquals('$.deployResult.status', 'UPDATE_COMPLETE'), - ), - ), - updateSuccessStatus - ) - .otherwise(updateFailureStatus) - ) - ))); + .next( + deployVectorStore.addCatch(updateFailureStatus, { resultPath: '$.error' }) + .next(checkDeploymentStatus) + ); // Create a new state machine using the definition and roles specified this.stateMachine = new sfn.StateMachine(this, 'CreateStoreStateMachine', { diff --git a/lib/rag/vector-store/state_machine/delete-store.ts b/lib/rag/vector-store/state_machine/delete-store.ts index 1c5cb23f3..dc4b4a034 100644 --- a/lib/rag/vector-store/state_machine/delete-store.ts +++ b/lib/rag/vector-store/state_machine/delete-store.ts @@ -14,7 +14,7 @@ limitations under the License. */ import { Construct } from 'constructs'; -import { BaseProps } from '../../../schema'; +import { BaseProps, VectorStoreStatus, } from '../../../schema'; import { ITable } from 'aws-cdk-lib/aws-dynamodb'; import { Code, Function, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; import * as iam from 'aws-cdk-lib/aws-iam'; @@ -120,21 +120,45 @@ export class DeleteStoreStateMachine extends Construct { updateExpression: 'SET #status = :status', expressionAttributeNames: { '#status': 'status' }, expressionAttributeValues: { - ':status': tasks.DynamoAttributeValue.fromString('DELETE_IN_PROGRESS'), + ':status': tasks.DynamoAttributeValue.fromString(VectorStoreStatus.DELETE_IN_PROGRESS), }, resultPath: '$.updateDynamoDbResult', }); - const handleCleanupBedrockKnowledgeBase = new Choice(this, 'BedrockKnowledgeBase') - .when(sfn.Condition.and(sfn.Condition.stringEquals('$.ddbResult.Item.config.M.type.S', 'bedrock_knowledge_base'), - sfn.Condition.isNull('$.stackName')), deleteDynamoDbEntry) - .otherwise(deleteStack); + const handleStackDeletion = new Choice(this, 'HasStackName') + .when(sfn.Condition.isPresent('$.stackName'), deleteStack) + .otherwise(deleteDynamoDbEntry); + + // Check if stackName exists in DDB (for Bedrock KB without pipelines, it may be NULL) + const checkStackNameExists = new Choice(this, 'CheckStackNameExists') + .when( + sfn.Condition.isPresent('$.ddbResult.Item.stackName.S'), + new sfn.Pass(this, 'ExtractStackName', { + parameters: { + 'repositoryId.$': '$.repositoryId', + 'stackName.$': '$.ddbResult.Item.stackName.S', + 'documents.$': '$.documents', + 'lastEvaluated.$': '$.lastEvaluated', + 'ddbResult.$': '$.ddbResult', + }, + }).next(handleStackDeletion) + ) + .otherwise( + new sfn.Pass(this, 'HandleMissingStackName', { + parameters: { + 'repositoryId.$': '$.repositoryId', + 'documents.$': '$.documents', + 'lastEvaluated.$': '$.lastEvaluated', + 'ddbResult.$': '$.ddbResult', + }, + }).next(handleStackDeletion) + ); const getRepoFromDdb = new tasks.DynamoGetItem(this, 'GetRepoFromDdb', { table: ragVectorStoreTable, key: { repositoryId: tasks.DynamoAttributeValue.fromString(sfn.JsonPath.stringAt('$.repositoryId')) }, resultPath: '$.ddbResult', - }).next(handleCleanupBedrockKnowledgeBase); + }).next(checkStackNameExists); const lambdaPath = config.lambdaPath || LAMBDA_PATH; @@ -150,9 +174,22 @@ export class DeleteStoreStateMachine extends Construct { role: executionRole, }); - // Allow the Step Functions role to invoke the cleanup lambda + const waitForCollectionDeletionsFunc = new Function(this, 'WaitForCollectionDeletionsFunc', { + runtime: getDefaultRuntime(), + handler: 'repository.state_machine.wait_for_collection_deletions.lambda_handler', + code: Code.fromAsset(lambdaPath), + timeout: Duration.seconds(30), + memorySize: LAMBDA_MEMORY, + vpc: vpc.vpc, + environment: environment, + layers: lambdaLayers, + role: executionRole, + }); + + // Allow the Step Functions role to invoke the lambdas if (role) { cleanupDocsFunc.grantInvoke(role); + waitForCollectionDeletionsFunc.grantInvoke(role); } const hasMoreDocs = new Choice(this, 'HasMoreDocs') @@ -176,17 +213,37 @@ export class DeleteStoreStateMachine extends Construct { outputPath: OUTPUT_PATH, }); + // Wait for collection deletions to complete + const waitForCollectionDeletions = new LambdaInvoke(this, 'WaitForCollectionDeletions', { + lambdaFunction: waitForCollectionDeletionsFunc, + payload: sfn.TaskInput.fromObject({ + 'repositoryId.$': '$.repositoryId', + 'stackName.$': '$.stackName', + }), + outputPath: OUTPUT_PATH, + }); + + const waitForCollectionDeletionsRetry = new sfn.Wait(this, 'WaitForCollectionDeletionsRetry', { + time: sfn.WaitTime.duration(Duration.seconds(10)), + }).next(waitForCollectionDeletions); + + const checkCollectionDeletionsComplete = new Choice(this, 'CheckCollectionDeletionsComplete') + .when(Condition.booleanEquals('$.allCollectionDeletionsComplete', true), cleanupDocs.next(hasMoreDocs)) + .otherwise(waitForCollectionDeletionsRetry); + + waitForCollectionDeletions.next(checkCollectionDeletionsComplete); + const shouldSkipCleanup = new Choice(this, 'ShouldSkipCleanup') .when(Condition.and(Condition.isPresent('$.skipDocumentRemoval'), Condition.booleanEquals('$.skipDocumentRemoval', true)), - handleCleanupBedrockKnowledgeBase) - .otherwise(cleanupDocs.next(hasMoreDocs)); + getRepoFromDdb) + .otherwise(waitForCollectionDeletions); deleteStack.next(checkStackStatus.addCatch(deleteDynamoDbEntry, { resultPath: '$.error' })) .next( new sfn.Choice(this, 'DeletionSuccessful?') - .when(sfn.Condition.stringEquals('$.checkResult.status', 'DELETE_FAILED'), updateFailureStatus) + .when(sfn.Condition.stringEquals('$.checkResult.status', VectorStoreStatus.DELETE_FAILED), updateFailureStatus) .otherwise(wait.next(checkStackStatus)) ); // Define the sequence of tasks and conditions in the state machine diff --git a/lib/rag/vector-store/vector-store-creator.ts b/lib/rag/vector-store/vector-store-creator.ts index 089150d48..b893e01b2 100644 --- a/lib/rag/vector-store/vector-store-creator.ts +++ b/lib/rag/vector-store/vector-store-creator.ts @@ -26,8 +26,9 @@ import { Roles } from '../../core/iam/roles'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import * as ssm from 'aws-cdk-lib/aws-ssm'; import { ILayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda'; -import { CodeFactory, VECTOR_STORE_DEPLOYER_DIST_PATH } from '../../util'; +import { CodeFactory, LAMBDA_PATH, VECTOR_STORE_DEPLOYER_DIST_PATH } from '../../util'; import { NodejsFunction } from 'aws-cdk-lib/aws-lambda-nodejs'; +import { getDefaultRuntime } from '../../api-base/utils'; export type VectorStoreCreatorStackProps = StackProps & BaseProps & { ragVectorStoreTable: CfnOutput; @@ -176,13 +177,33 @@ export class VectorStoreCreatorStack extends Construct { securityGroups: [props.vpc.securityGroups.lambdaSg], }); + // Create Lambda for Bedrock collection creation + const createBedrockCollectionFn = new lambda.Function(this, 'CreateBedrockCollectionFn', { + functionName: createCdkId([config.deploymentName, config.deploymentStage, 'create_bedrock_collection']), + runtime: getDefaultRuntime(), + handler: 'repository.lambda_functions.create_bedrock_collection', + code: lambda.Code.fromAsset(config.lambdaPath || LAMBDA_PATH), + timeout: Duration.minutes(5), + memorySize: 512, + role: lambdaExecutionRole, + environment: baseEnvironment, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: [vpc.securityGroups.lambdaSg], + layers: layers, + }); + + // Grant permissions + vectorStoreTable.grantReadWriteData(createBedrockCollectionFn); + // Allow the state machine to invoke the deployer Lambda this.vectorStoreCreatorFn.grantInvoke(stateMachineRole); + createBedrockCollectionFn.grantInvoke(stateMachineRole); // Minimal policies for state machine role stateMachineRole.addToPolicy(new iam.PolicyStatement({ actions: ['lambda:InvokeFunction'], - resources: [this.vectorStoreCreatorFn.functionArn], + resources: [this.vectorStoreCreatorFn.functionArn, createBedrockCollectionFn.functionArn], })); stateMachineRole.addToPolicy(new iam.PolicyStatement({ actions: ['cloudformation:DescribeStacks', 'cloudformation:DeleteStack'], @@ -195,6 +216,7 @@ export class VectorStoreCreatorStack extends Construct { new CreateStoreStateMachine(this, 'CreateVectorStoreStateMachine', { config: props.config, + createBedrockCollectionFnArn: createBedrockCollectionFn.functionArn, executionRole: lambdaExecutionRole, parameterName: baseEnvironment['LISA_RAG_CREATE_STATE_MACHINE_ARN_PARAMETER'], role: stateMachineRole, diff --git a/lib/schema/collectionSchema.ts b/lib/schema/collectionSchema.ts new file mode 100644 index 000000000..ac503d9a8 --- /dev/null +++ b/lib/schema/collectionSchema.ts @@ -0,0 +1,204 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { z } from 'zod'; +import { + RagRepositoryPipeline, + ChunkingStrategySchema, +} from './ragSchema'; + +/** + * Enum for collection status + */ +export enum CollectionStatus { + ACTIVE = 'ACTIVE', + ARCHIVED = 'ARCHIVED', + DELETED = 'DELETED', + DELETE_IN_PROGRESS = 'DELETE_IN_PROGRESS', + DELETE_FAILED = 'DELETE_FAILED' +} + + +// Future chunking strategies (not yet implemented): +// +// export const SemanticChunkingStrategySchema = z.object({ +// type: z.literal(ChunkingStrategyType.SEMANTIC).describe('Semantic chunking strategy type'), +// threshold: z.number().min(0.0).max(1.0).describe('Similarity threshold for semantic boundaries'), +// chunkSize: z.number().min(100).max(10000).default(1000).optional().describe('Maximum chunk size'), +// }); +// +// export const RecursiveChunkingStrategySchema = z.object({ +// type: z.literal(ChunkingStrategyType.RECURSIVE).describe('Recursive chunking strategy type'), +// chunkSize: z.number().min(100).max(10000).describe('Target size of each chunk'), +// chunkOverlap: z.number().min(0).describe('Overlap between chunks'), +// separators: z.array(z.string()).min(1).default(['\n\n', '\n', '. ', ' ']).describe('Separators to use for recursive splitting'), +// }).refine( +// (data) => data.chunkOverlap <= data.chunkSize / 2, +// { message: 'chunkOverlap must be less than or equal to half of chunkSize' } +// ); +// +// When implementing new strategies: +// 1. Add the strategy type to ChunkingStrategyType enum (uncomment above) +// 2. Create the schema (uncomment and modify above) +// 3. Update ChunkingStrategySchema to be a union: z.union([FixedSizeChunkingStrategySchema, SemanticChunkingStrategySchema, ...]) +// 4. Implement the backend handler in chunking_strategy_factory.py + +/** + * Pipeline configuration schema - reusing from ragSchema + */ +export const PipelineConfigSchema = RagRepositoryPipeline; + +/** + * Collection metadata schema + */ +export const CollectionMetadataSchema = z.object({ + tags: z.array( + z.string() + .max(50) + .regex(/^[a-zA-Z0-9_-]+$/, 'Tags must contain only alphanumeric characters, hyphens, and underscores') + ).max(50).default([]).describe('Metadata tags for the collection (max 50 tags, each max 50 chars)'), + customFields: z.record(z.any()).default({}).describe('Custom metadata fields'), +}); + +/** + * RAG Collection configuration schema + */ +export const RagCollectionConfigSchema = z.object({ + collectionId: z.string().uuid().describe('Unique collection identifier (UUID)'), + repositoryId: z.string().min(1).describe('Parent vector store ID'), + name: z.string() + .max(100) + .regex(/^[a-zA-Z0-9 _-]+$/, 'Collection name must contain only alphanumeric characters, spaces, hyphens, and underscores') + .optional() + .describe('User-friendly collection name'), + description: z.string().optional().describe('Collection description'), + chunkingStrategy: ChunkingStrategySchema.optional().describe('Chunking strategy for documents (inherits from parent if omitted)'), + allowChunkingOverride: z.boolean().default(true).describe('Allow users to override chunking strategy during ingestion'), + metadata: CollectionMetadataSchema.optional().describe('Collection-specific metadata (merged with parent metadata)'), + allowedGroups: z.array(z.string()).optional().describe('User groups with access to collection (inherits from parent if omitted)'), + embeddingModel: z.string().min(1).describe('Embedding model ID (can be set at creation, inherits from parent if omitted, immutable after creation)'), + createdBy: z.string().min(1).describe('User ID of creator'), + createdAt: z.string().datetime().describe('Creation timestamp (ISO 8601)'), + updatedAt: z.string().datetime().describe('Last update timestamp (ISO 8601)'), + status: z.nativeEnum(CollectionStatus).default(CollectionStatus.ACTIVE).describe('Collection status'), + pipelines: z.array(PipelineConfigSchema).default([]).describe('Automated ingestion pipelines'), + default: z.boolean().default(false).optional().describe('Indicates if this is a default collection (virtual, no DB entry)'), +}); + +/** + * Collection sort options + */ +export enum CollectionSortBy { + NAME = 'name', + CREATED_AT = 'createdAt', + UPDATED_AT = 'updatedAt', +} + +/** + * Sort order options + */ +export enum SortOrder { + ASC = 'asc', + DESC = 'desc', +} + +/** + * List collections query parameters schema + */ +export const ListCollectionsQuerySchema = z.object({ + page: z.number().int().min(1).default(1).describe('Page number'), + pageSize: z.number().int().min(1).max(100).default(20).describe('Number of items per page'), + filter: z.string().optional().describe('Filter by name or description (substring match)'), + sortBy: z.nativeEnum(CollectionSortBy).default(CollectionSortBy.CREATED_AT).describe('Sort field'), + sortOrder: z.nativeEnum(SortOrder).default(SortOrder.DESC).describe('Sort order'), + status: z.nativeEnum(CollectionStatus).optional().describe('Filter by status'), +}); + +/** + * List collections response schema + */ +export const ListCollectionsResponseSchema = z.object({ + collections: z.array(RagCollectionConfigSchema).describe('List of collections'), + totalCount: z.number().int().optional().describe('Total number of collections'), + currentPage: z.number().int().optional().describe('Current page number'), + totalPages: z.number().int().optional().describe('Total number of pages'), + hasNextPage: z.boolean().default(false).describe('Whether there is a next page'), + hasPreviousPage: z.boolean().default(false).describe('Whether there is a previous page'), + lastEvaluatedKey: z.record(z.string()).optional().describe('Last evaluated key for pagination'), +}); + +/** + * Type exports + */ +// ChunkingStrategy types are re-exported from ragSchema above +// export type SemanticChunkingStrategy = z.infer; // Not yet implemented +// export type RecursiveChunkingStrategy = z.infer; // Not yet implemented +// PipelineConfig type is exported from ragSchema via PipelineConfigSchema +export type CollectionMetadata = z.infer; +export type RagCollectionConfig = z.infer; +export type ListCollectionsQuery = z.infer; +export type ListCollectionsResponse = z.infer; + +/** + * Inheritance rules documentation + */ +export const COLLECTION_INHERITANCE_RULES = { + embeddingModel: { + inherited: true, + mutableAtCreation: true, + mutableAfterCreation: false, + description: 'Inherits from parent if not specified at creation. Can be overridden at creation time but becomes immutable after creation.', + }, + allowedGroups: { + inherited: true, + mutable: true, + constraint: 'Must be a subset of parent vector store\'s allowedGroups', + description: 'Inherits from parent if not specified, can be restricted but not expanded', + }, + chunkingStrategy: { + inherited: true, + mutable: true, + description: 'Inherits from parent if not specified, can be overridden per collection', + }, + metadata: { + inherited: true, + mergeStrategy: 'composite', + mutable: true, + description: 'Merged from parent and collection. Tags are combined (deduplicated), custom fields from collection override parent on conflict.', + }, +} as const; + +/** + * Immutable fields that cannot be changed after creation + */ +export const IMMUTABLE_FIELDS = [ + 'collectionId', + 'repositoryId', + 'embeddingModel', + 'createdBy', + 'createdAt', +] as const; + +/** + * Validation rules + */ +export const VALIDATION_RULES = { + nameUniqueness: 'Collection name must be unique within the parent repository', + allowedGroupsSubset: 'Collection allowedGroups must be a subset of parent repository allowedGroups', + chunkOverlapConstraint: 'chunkOverlap must be less than or equal to chunkSize/2', + tagsLimit: 'Maximum 50 tags per collection, each tag maximum 50 characters', + nameCharacters: 'Name must contain only alphanumeric characters, spaces, hyphens, and underscores', +} as const; diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index 7734d74f6..ca3e91d2e 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -795,6 +795,7 @@ const RoleConfig = z.object({ S3ReaderRole: z.string().max(64).optional(), UIDeploymentRole: z.string().max(64).optional(), VectorStoreCreatorRole: z.string().max(64).optional(), + McpServerDeployerRole: z.string().max(64), }) .describe('Role overrides used across stacks.'); @@ -855,6 +856,8 @@ export const RawConfigObject = z.object({ deployDocs: z.boolean().default(true).describe('Whether to deploy docs stacks.'), deployUi: z.boolean().default(true).describe('Whether to deploy UI stacks.'), deployMetrics: z.boolean().default(true).describe('Whether to deploy Metrics stack.'), + deployMcp: z.boolean().default(true).describe('Whether to deploy LISA MCP stack.'), + deployServe: z.boolean().default(true).describe('Whether to deploy LISA Serve stack.'), deployMcpWorkbench: z.boolean().default(true).describe('Whether to deploy MCP Workbench stack.'), logLevel: z.union([z.literal('DEBUG'), z.literal('INFO'), z.literal('WARNING'), z.literal('ERROR')]) .default('DEBUG') @@ -888,6 +891,7 @@ export const RawConfigObject = z.object({ deploymentPrefix: z.string().optional().describe('Prefix for deployment resources.'), webAppAssetsPath: z.string().optional().describe('Optional path to precompiled webapp assets. If not specified the web application will be built at deploy time.'), ecsModelDeployerPath: z.string().optional().describe('Optional path to precompiled ecs model deployer. If not specified the ecs model deployer will be built at deploy time.'), + mcpServerDeployerPath: z.string().optional().describe('Optional path to precompiled mcp server deployer. If not specified the mcp server deployer will be built at deploy time.'), vectorStoreDeployerPath: z.string().optional().describe('Optional path to precompiled vector store deployer. If not specified the vector store deployer will be built at deploy time.'), documentsPath: z.string().optional().describe('Optional path to precompiled LISA documents. If not specified the LISA docs will be built at deploy time.'), lambdaPath: z.any().optional().describe('Optional path to precompiled LISA lambda. If not specified the LISA lambda will be built at deploy time.'), diff --git a/lib/schema/constants.ts b/lib/schema/constants.ts new file mode 100644 index 000000000..e18b0fa50 --- /dev/null +++ b/lib/schema/constants.ts @@ -0,0 +1,17 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export const APP_MANAGEMENT_KEY = 'appManagementKeySecretName'; diff --git a/lib/schema/index.ts b/lib/schema/index.ts index 284c17f80..0f72d9211 100644 --- a/lib/schema/index.ts +++ b/lib/schema/index.ts @@ -15,5 +15,7 @@ */ export * from './configSchema'; export * from './ragSchema'; +export * from './collectionSchema'; export * from './cdk'; export * from './schema'; +export * from './constants'; diff --git a/lib/schema/ragSchema.ts b/lib/schema/ragSchema.ts index 72b68c761..a3a077f76 100644 --- a/lib/schema/ragSchema.ts +++ b/lib/schema/ragSchema.ts @@ -16,6 +16,61 @@ import { z } from 'zod'; import { EbsDeviceVolumeType } from './cdk'; +/** + * Enum for chunking strategy types + */ +export enum ChunkingStrategyType { + FIXED = 'fixed', + NONE = 'none', +} + +/** + * Fixed size chunking strategy schema + */ +export const FixedSizeChunkingStrategySchema = z.object({ + type: z.literal(ChunkingStrategyType.FIXED).describe('Fixed size chunking strategy type'), + size: z.number().min(100).max(10000).default(512).describe('Size of each chunk in characters'), + overlap: z.number().min(0).default(51).describe('Overlap between chunks in characters'), +}).refine( + (data) => data.overlap <= data.size / 2, + { message: 'overlap must be less than or equal to half of size' } +); + +/** + * None chunking strategy schema - documents ingested as-is without chunking + */ +export const NoneChunkingStrategySchema = z.object({ + type: z.literal(ChunkingStrategyType.NONE).describe('No chunking - documents ingested as-is'), +}); + +/** + * Union of all chunking strategy types + */ +export const ChunkingStrategySchema = z.union([ + FixedSizeChunkingStrategySchema, + NoneChunkingStrategySchema, +]); + +export type ChunkingStrategy = z.infer; +export type FixedSizeChunkingStrategy = z.infer; +export type NoneChunkingStrategy = z.infer; + +/** + * Defines possible states for a vector store deployment. + * These statuses are used by both create-store and delete-store state machines. + */ +export enum VectorStoreStatus { + CREATE_IN_PROGRESS = 'CREATE_IN_PROGRESS', + CREATE_COMPLETE = 'CREATE_COMPLETE', + CREATE_FAILED = 'CREATE_FAILED', + UPDATE_IN_PROGRESS = 'UPDATE_IN_PROGRESS', + UPDATE_COMPLETE = 'UPDATE_COMPLETE', + UPDATE_COMPLETE_CLEANUP_IN_PROGRESS = 'UPDATE_COMPLETE_CLEANUP_IN_PROGRESS', + DELETE_IN_PROGRESS = 'DELETE_IN_PROGRESS', + DELETE_FAILED = 'DELETE_FAILED', + UNKNOWN = 'UNKNOWN', +} + /** * Enum for different types of RAG repositories available */ @@ -25,12 +80,15 @@ export enum RagRepositoryType { BEDROCK_KNOWLEDGE_BASE = 'bedrock_knowledge_base', } +export const BedrockDataSource = z.object({ + id: z.string().describe('The ID of the Bedrock Knowledge Base data source'), + name: z.string().describe('The name of the Bedrock Knowledge Base data source'), + s3Uri: z.string().regex(/^s3:\/\/[a-z0-9][a-z0-9.-]*[a-z0-9](\/.*)?$/, 'Must be a valid S3 URI (s3://bucket/prefix)').describe('The S3 URI of the data source'), +}); + export const BedrockKnowledgeBaseInstanceConfig = z.object({ - bedrockKnowledgeBaseName: z.string().describe('The name of the Bedrock Knowledge Base.'), - bedrockKnowledgeBaseId: z.string().describe('The id of the Bedrock Knowledge Base.'), - bedrockKnowledgeDatasourceName: z.string().describe('The name of the Bedrock Knowledge Datasource.'), - bedrockKnowledgeDatasourceId: z.string().describe('The id of the Bedrock Knowledge Datasource.'), - bedrockKnowledgeDatasourceS3Bucket: z.string().describe('The S3 bucket of the Bedrock Knowledge Base.'), + knowledgeBaseId: z.string().describe('The ID of the Bedrock Knowledge Base'), + dataSources: z.array(BedrockDataSource).min(1).describe('Array of data sources in this Knowledge Base'), }); export const OpenSearchNewClusterConfig = z.object({ @@ -57,9 +115,11 @@ const triggerSchema = z.object({ }); export const RagRepositoryPipeline = z.object({ - chunkSize: z.number().default(512).describe('The size of the chunks used for document segmentation.'), - chunkOverlap: z.number().default(51).describe('The size of the overlap between chunks.'), - embeddingModel: z.string().describe('The embedding model used for document ingestion in this pipeline.'), + chunkSize: z.number().optional().default(512).describe('The size of the chunks used for document segmentation.'), + chunkOverlap: z.number().optional().default(51).describe('The size of the overlap between chunks.'), + chunkingStrategy: ChunkingStrategySchema.optional().describe('Chunking strategy for documents in this pipeline.'), + embeddingModel: z.string().optional().describe('The embedding model used for document ingestion in this pipeline.'), + collectionId: z.string().optional().describe('The collection ID to ingest documents into.'), s3Bucket: z.string().describe('The S3 bucket monitored by this pipeline for document processing.'), s3Prefix: z.string() .regex(/^(?!.*(?:^|\/)\.\.?(\/|$)).*/, 'Prefix cannot contain relative path components (ie `.` or `..`)') @@ -89,14 +149,20 @@ export type RdsConfig = z.infer; export type BedrockKnowledgeBaseConfig = z.infer; +export const RagRepositoryMetadata = z.object({ + tags: z.array(z.string()).default([]).describe('Tags for categorizing and organizing the repository.'), + customFields: z.record(z.any()).optional().describe('Custom metadata fields for the repository.'), +}); + export const RagRepositoryConfigSchema = z .object({ repositoryId: z.string() .nonempty() - .regex(/^[a-z0-9-]{1,63}/, 'Only lowercase alphanumeric characters and \'-\' are supported.') + .regex(/^[a-z0-9-]{3,20}/, 'Only lowercase alphanumeric characters and \'-\' are supported.') .regex(/^(?!-).*(? { return !((input.type === RagRepositoryType.OPENSEARCH && input.opensearchConfig === undefined) || diff --git a/lib/serve/index.ts b/lib/serve/index.ts index 28edb3d10..ec5c7a973 100644 --- a/lib/serve/index.ts +++ b/lib/serve/index.ts @@ -34,6 +34,7 @@ export class LisaServeApplicationStack extends Stack { public readonly tokenTable?: ITable; public readonly guardrailsTableNamePs: StringParameter; public readonly guardrailsTable: ITable; + public readonly ecsCluster: any; /** * @param {Construct} scope - The parent or owner of the construct. @@ -51,5 +52,6 @@ export class LisaServeApplicationStack extends Stack { this.tokenTable = app.tokenTable; this.guardrailsTableNamePs = app.guardrailsTableNamePs; this.guardrailsTable = app.guardrailsTable; + this.ecsCluster = app.ecsCluster; } } diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py index 4e9ad3d9a..ccb3fc983 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py index e648c5bcf..907891e00 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py @@ -32,7 +32,6 @@ # The following are field names, not passwords or tokens API_KEY_HEADER_NAMES = [ - "authorization", "Authorization", # OpenAI Bearer token format, collides with IdP, but that's okay for this use case "Api-Key", # pragma: allowlist secret # Azure key format, can be used with Continue IDE plugin ] @@ -108,7 +107,7 @@ def id_token_is_valid( }, ) return data - except jwt.exceptions.PyJWTError as e: + except (jwt.exceptions.PyJWTError, jwt.exceptions.DecodeError) as e: logger.exception(e) return None @@ -125,7 +124,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node -def get_authorization_token(headers: Dict[str, str], header_name: str) -> str: +def get_authorization_token(headers: Dict[str, str], header_name: str = "Authorization") -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: return headers.get(header_name, "").removeprefix("Bearer").strip() @@ -165,9 +164,8 @@ async def dispatch(self, request: Request, call_next) -> Response: valid = True else: for header_name in API_KEY_HEADER_NAMES: - authorization = request.headers.get(header_name, "").strip() - id_token = authorization.split(" ")[-1] - if len(id_token) > 0 and id_token_is_valid( + id_token = get_authorization_token(request.headers, header_name) + if id_token and id_token_is_valid( id_token=id_token, authority=os.environ["AUTHORITY"], client_id=os.environ["CLIENT_ID"], @@ -181,6 +179,11 @@ async def dispatch(self, request: Request, call_next) -> Response: return JSONResponse( status_code=401, content={"detail": "Unauthorized"}, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "*", + "Access-Control-Allow-Headers": "*", + }, ) return await call_next(request) @@ -206,8 +209,8 @@ def _get_token_info(self, token: str) -> Any: def is_valid_api_token(self, headers: Dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" for header_name in API_KEY_HEADER_NAMES: - token = headers.get(header_name, "").removeprefix("Bearer").strip() - if len(token) > 0: + token = get_authorization_token(headers, header_name) + if token: token_info = self._get_token_info(token) if token_info: token_expiration = int(token_info.get(TOKEN_EXPIRATION_NAME, datetime.max.timestamp())) @@ -251,8 +254,8 @@ def is_valid_api_token(self, headers: Dict[str, str]) -> bool: self._refreshTokens() for header_name in API_KEY_HEADER_NAMES: - token = headers.get(header_name, "").strip() - if token in self._secret_tokens: + token = get_authorization_token(headers, header_name) + if token and token in self._secret_tokens: return True return False diff --git a/lib/serve/mcpWorkbenchStack.ts b/lib/serve/mcpWorkbenchStack.ts index d95a677ae..4dccb042f 100644 --- a/lib/serve/mcpWorkbenchStack.ts +++ b/lib/serve/mcpWorkbenchStack.ts @@ -46,6 +46,4 @@ export class McpWorkbenchStack extends Stack { authorizer }); } - - } diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index 9681ce9bd..aef78ed7e 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -34,15 +34,7 @@ COPY ${NODEENV_CACHE_DIR} /tmp/nodeenv-cache/ # Pre-cache nodeenv for prisma-client-py # If the copied directory has content, use it (for offline environments) # Otherwise, download it during build (requires internet) -RUN mkdir -p /root/.cache/prisma-python && \ - if [ -d "/tmp/nodeenv-cache" ] && [ -n "$(ls /tmp/nodeenv-cache 2>/dev/null)" ]; then \ - echo "Using pre-cached nodeenv from host" && \ - cp -r /tmp/nodeenv-cache /root/.cache/prisma-python/nodeenv && \ - rm -rf /tmp/nodeenv-cache; \ - else \ - echo "Downloading nodeenv (requires internet)" && \ - python -m nodeenv /root/.cache/prisma-python/nodeenv; \ - fi +RUN python -m prisma version # Copy the source code into the container COPY src/ ./src diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index 2cc753bec..63180044f 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -41,11 +41,34 @@ echo "--------------------------------" head -20 litellm_config.yaml echo "--------------------------------" +# Configure logging behavior based on DEBUG environment variable +# Set DEBUG=true in ECS task definition to enable debug logging for all services +if [ "${DEBUG}" = "true" ]; then + LOG_LEVEL="DEBUG" + GUNICORN_LOG_LEVEL="debug" + PRISMA_LOG_LEVEL="info,query" +else + LOG_LEVEL="INFO" + GUNICORN_LOG_LEVEL="info" + PRISMA_LOG_LEVEL="warn" +fi + +# Configure LiteLLM logging +export LITELLM_LOG=${LOG_LEVEL} +export LITELLM_JSON_LOGS=${LITELLM_JSON_LOGS:-false} +export LITELLM_DISABLE_HEALTH_CHECK_LOGS=${LITELLM_DISABLE_HEALTH_CHECK_LOGS:-false} + +# Configure Prisma logging +export PRISMA_LOG_LEVEL=${PRISMA_LOG_LEVEL} + # Start LiteLLM in the background with better error handling echo "🚀 Starting LiteLLM server..." echo " - Config file: litellm_config.yaml" echo " - Port: 4000 (internal)" echo " - Database: Prisma with auto-push enabled" +echo " - Debug mode: ${DEBUG:-false}" +echo " - Log level: $LOG_LEVEL" +echo " - Prisma log level: $PRISMA_LOG_LEVEL" # Start LiteLLM and capture its PID litellm -c litellm_config.yaml --use_prisma_db_push > litellm.log 2>&1 & @@ -54,6 +77,12 @@ LITELLM_PID=$! echo " - LiteLLM PID: $LITELLM_PID" echo " - Log file: litellm.log" +# Tail the log file to stdout so Docker can capture it +tail -f litellm.log & +TAIL_PID=$! + +echo " - Log tail PID: $TAIL_PID" + # LiteLLM is starting in the background, proceed with Gunicorn startup # Validate THREADS variable with default value @@ -65,5 +94,8 @@ echo " - Host: $HOST" echo " - Port: $PORT" echo " - Workers: $THREADS" echo " - Timeout: 600 seconds" +echo " - Log level: $GUNICORN_LOG_LEVEL" -exec gunicorn -k uvicorn.workers.UvicornWorker -t 600 -w "$THREADS" -b "$HOST:$PORT" "src.main:app" +exec gunicorn -k uvicorn.workers.UvicornWorker -t 600 -w "$THREADS" -b "$HOST:$PORT" \ + --log-level "$GUNICORN_LOG_LEVEL" \ + "src.main:app" diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index df5be76bd..3f156fc3f 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -14,7 +14,7 @@ limitations under the License. */ import { Duration, Stack, StackProps } from 'aws-cdk-lib'; -import { AttributeType, BillingMode, ITable, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; +import { ITable, Table } from 'aws-cdk-lib/aws-dynamodb'; import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; @@ -23,17 +23,16 @@ import { FastApiContainer } from '../api-base/fastApiContainer'; import { ECSCluster } from '../api-base/ecsCluster'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; -import { BaseProps, Config } from '../schema'; +import { APP_MANAGEMENT_KEY, BaseProps, Config } from '../schema'; import { Effect, - ManagedPolicy, Policy, PolicyDocument, PolicyStatement, Role, ServicePrincipal, } from 'aws-cdk-lib/aws-iam'; -import { HostedRotation, ISecret, Secret } from 'aws-cdk-lib/aws-secretsmanager'; +import { HostedRotation, ISecret } from 'aws-cdk-lib/aws-secretsmanager'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; import { SecurityGroupFactory } from '../networking/vpc/security-group-factory'; import { LAMBDA_PATH, REST_API_PATH } from '../util'; @@ -41,7 +40,6 @@ import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resour import { getDefaultRuntime } from '../api-base/utils'; import { ISecurityGroup, Port } from 'aws-cdk-lib/aws-ec2'; import { ECSTasks } from '../api-base/ecsCluster'; -import { EventBus } from 'aws-cdk-lib/aws-events'; import { GuardrailsTable } from '../models/guardrails-table'; export type LisaServeApplicationProps = { @@ -59,7 +57,6 @@ export class LisaServeApplicationConstruct extends Construct { public readonly endpointUrl: StringParameter; public readonly tokenTable?: ITable; public readonly ecsCluster: ECSCluster; - public readonly managementKeySecretName: string; public readonly guardrailsTableNamePs: StringParameter; public readonly guardrailsTable: ITable; @@ -72,24 +69,22 @@ export class LisaServeApplicationConstruct extends Construct { super(scope, id); const { config, vpc, securityGroups } = props; - let tokenTable; - if (config.restApiConfig.internetFacing) { - // Create DynamoDB Table for enabling API token usage - tokenTable = new Table(scope, 'TokenTable', { - tableName: `${config.deploymentName}-LISAApiTokenTable`, - partitionKey: { - name: 'token', - type: AttributeType.STRING, - }, - billingMode: BillingMode.PAY_PER_REQUEST, - encryption: TableEncryption.AWS_MANAGED, - removalPolicy: config.removalPolicy, - }); - } + // TokenTable is now created in API Base, reference it from SSM parameter + // API Base stack must be deployed before Serve stack (dependency is set in stages.ts) + const tokenTableNameParameter = StringParameter.fromStringParameterName( + scope, + 'TokenTableNameParameter', + `${config.deploymentPrefix}/tokenTableName` + ); + // Reference the table by name (table is created in API Base stack) + let tokenTable = Table.fromTableName( + scope, + 'TokenTable', + tokenTableNameParameter.stringValue + ); this.tokenTable = tokenTable; - const { managementKeySecretName } = this.createManagementKeySecret(scope, config, vpc, securityGroups); - this.managementKeySecretName = managementKeySecretName; + const managementKeySecretNameStringParameter = StringParameter.fromStringParameterName(this, createCdkId([id, 'managementKeyStringParameter']), `${config.deploymentPrefix}/${APP_MANAGEMENT_KEY}`); // Create guardrails table in serve stack to avoid circular dependency const guardrailsTableConstruct = new GuardrailsTable(scope, 'GuardrailsTable', { @@ -112,7 +107,7 @@ export class LisaServeApplicationConstruct extends Construct { securityGroup: vpc.securityGroups.restApiAlbSg, tokenTable: tokenTable, vpc: vpc, - managementKeyName: managementKeySecretName + managementKeyName: managementKeySecretNameStringParameter.stringValue }); // LiteLLM requires a PostgreSQL database to support multiple-instance scaling with dynamic model management. @@ -373,73 +368,4 @@ export class LisaServeApplicationConstruct extends Construct { ); } - private createManagementKeySecret (scope: Stack, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[]): { managementKeySecretName: string } { - const managementKeySecretName = `${config.deploymentName}-lisa-management-key`; - - const managementEventBus = new EventBus(scope, createCdkId([scope.node.id, 'managementEventBus']), { - eventBusName: `${config.deploymentName}-lisa-management-events`, - }); - - const managementKeySecret = new Secret(scope, createCdkId([scope.node.id, 'managementKeySecret']), { - secretName: managementKeySecretName, - description: 'LISA management key secret', - generateSecretString: { - excludePunctuation: true, - passwordLength: 16 - }, - removalPolicy: config.removalPolicy - }); - - const rotationLambda = new Function(scope, createCdkId([scope.node.id, 'managementKeyRotationLambda']), { - runtime: getDefaultRuntime(), - handler: 'management_key.handler', - code: Code.fromAsset(config.lambdaPath || LAMBDA_PATH), - timeout: Duration.minutes(5), - environment: { - EVENT_BUS_NAME: managementEventBus.eventBusName, - }, - role: new Role(scope, createCdkId([scope.node.id, 'managementKeyRotationRole']), { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - managedPolicies: [ - ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), - ], - inlinePolicies: { - 'SecretsManagerRotation': new PolicyDocument({ - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'secretsmanager:DescribeSecret', - 'secretsmanager:GetSecretValue', - 'secretsmanager:PutSecretValue', - 'secretsmanager:UpdateSecretVersionStage' - ], - resources: [managementKeySecret.secretArn] - }), - new PolicyStatement({ - effect: Effect.ALLOW, - actions: ['events:PutEvents'], - resources: [managementEventBus.eventBusArn] - }) - ] - }) - } - }), - securityGroups: securityGroups, - vpc: vpc.vpc, - }); - - managementKeySecret.addRotationSchedule('RotationSchedule', { - automaticallyAfter: Duration.days(30), - rotationLambda: rotationLambda - }); - - new StringParameter(scope, createCdkId(['ManagementKeySecretName']), { - parameterName: `${config.deploymentPrefix}/managementKeySecretName`, - stringValue: managementKeySecret.secretName, - }); - - return { managementKeySecretName }; - } - } diff --git a/lib/stages.ts b/lib/stages.ts index 061681c9f..1baa94080 100644 --- a/lib/stages.ts +++ b/lib/stages.ts @@ -48,6 +48,7 @@ import { McpWorkbenchStack } from './serve/mcpWorkbenchStack'; import { UserInterfaceStack } from './user-interface'; import { LisaDocsStack } from './docs'; import { LisaMetricsStack } from './metrics'; +import { LisaMcpApiStack } from './mcp'; import fs from 'node:fs'; import { VERSION_PATH } from './util'; @@ -249,26 +250,14 @@ export class LisaServeApplicationStage extends Stage { }); this.stacks.push(coreStack); - const serveStack = new LisaServeApplicationStack(this, 'LisaServe', { - ...baseStackProps, - description: `LISA-serve: ${config.deploymentName}-${config.deploymentStage}`, - stackName: createCdkId([config.deploymentName, config.appName, 'serve', config.deploymentStage]), - vpc: networkingStack.vpc, - securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], - }); - this.stacks.push(serveStack); - serveStack.addDependency(networkingStack); - serveStack.addDependency(iamStack); - const apiBaseStack = new LisaApiBaseStack(this, 'LisaApiBase', { ...baseStackProps, - tokenTable: serveStack.tokenTable, stackName: createCdkId([config.deploymentName, config.appName, 'API']), description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`, - vpc: networkingStack.vpc + vpc: networkingStack.vpc, + securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], }); apiBaseStack.addDependency(coreStack); - apiBaseStack.addDependency(serveStack); this.stacks.push(apiBaseStack); const apiDeploymentStack = new LisaApiDeploymentStack(this, 'LisaApiDeployment', { @@ -279,115 +268,149 @@ export class LisaServeApplicationStage extends Stage { }); apiDeploymentStack.addDependency(apiBaseStack); - const modelsApiDeploymentStack = new LisaModelsApiStack(this, 'LisaModelsApiDeployment', { - ...baseStackProps, - authorizer: apiBaseStack.authorizer, - description: `LISA-models: ${config.deploymentName}-${config.deploymentStage}`, - guardrailsTable: serveStack.guardrailsTable, - lisaServeEndpointUrlPs: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, - restApiId: apiBaseStack.restApiId, - rootResourceId: apiBaseStack.rootResourceId, - stackName: createCdkId([config.deploymentName, config.appName, 'models', config.deploymentStage]), - securityGroups: [networkingStack.vpc.securityGroups.ecsModelAlbSg], - vpc: networkingStack.vpc, - }); - modelsApiDeploymentStack.addDependency(serveStack); - apiDeploymentStack.addDependency(modelsApiDeploymentStack); - this.stacks.push(modelsApiDeploymentStack); - - if (config.deployMcpWorkbench) { - const mcpWorkbenchStack = new McpWorkbenchStack(this, 'LisaMcpWorkbench', { - ...baseStackProps, - stackName: createCdkId([config.deploymentName, config.appName, 'mcp-workbench', config.deploymentStage]), - description: `LISA-mcp-workbench: ${config.deploymentName}-${config.deploymentStage}`, - vpc: networkingStack.vpc, - restApiId: apiBaseStack.restApiId, - rootResourceId: apiBaseStack.rootResourceId, - apiCluster: serveStack.restApi.apiCluster, - authorizer: apiBaseStack.authorizer, - }); - mcpWorkbenchStack.addDependency(coreStack); - mcpWorkbenchStack.addDependency(apiBaseStack); - mcpWorkbenchStack.addDependency(serveStack); - apiDeploymentStack.addDependency(mcpWorkbenchStack); - this.stacks.push(mcpWorkbenchStack); - } - - if (config.deployRag) { - const ragStack = new LisaRagStack(this, 'LisaRAG', { + if (config.deployMcp) { + const mcpApiStack = new LisaMcpApiStack(this, 'LisaMcpApi', { ...baseStackProps, authorizer: apiBaseStack.authorizer!, - description: `LISA-rag: ${config.deploymentName}-${config.deploymentStage}`, + description: `LISA-mcp: ${config.deploymentName}-${config.deploymentStage}`, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, - endpointUrl: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, - modelsPs: config.restApiConfig.internetFacing ? serveStack.modelsPs : undefined, - stackName: createCdkId([config.deploymentName, config.appName, 'rag', config.deploymentStage]), - securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + stackName: createCdkId([config.deploymentName, config.appName, 'mcp', config.deploymentStage]), + securityGroups: [networkingStack.vpc.securityGroups.ecsModelAlbSg], vpc: networkingStack.vpc, }); - ragStack.addDependency(coreStack); - ragStack.addDependency(iamStack); - ragStack.addDependency(apiBaseStack); - this.stacks.push(ragStack); - apiDeploymentStack.addDependency(ragStack); + apiDeploymentStack.addDependency(mcpApiStack); + mcpApiStack.addDependency(apiBaseStack); + this.stacks.push(mcpApiStack); } - // Declare metricsStack here so that we can reference it in chatStack - let metricsStack: LisaMetricsStack | undefined; - if (config.deployMetrics) { - metricsStack = new LisaMetricsStack(this, 'LisaMetrics', { + + + if (config.deployServe) { + const serveStack = new LisaServeApplicationStack(this, 'LisaServe', { ...baseStackProps, - authorizer: apiBaseStack.authorizer!, - stackName: createCdkId([config.deploymentName, config.appName, 'metrics', config.deploymentStage]), - description: `LISA-metrics: ${config.deploymentName}-${config.deploymentStage}`, - restApiId: apiBaseStack.restApiId, - rootResourceId: apiBaseStack.rootResourceId, - securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + description: `LISA-serve: ${config.deploymentName}-${config.deploymentStage}`, + stackName: createCdkId([config.deploymentName, config.appName, 'serve', config.deploymentStage]), vpc: networkingStack.vpc, + securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], }); - metricsStack.addDependency(apiBaseStack); - metricsStack.addDependency(coreStack); - apiDeploymentStack.addDependency(metricsStack); - this.stacks.push(metricsStack); - } + this.stacks.push(serveStack); + serveStack.addDependency(networkingStack); + serveStack.addDependency(iamStack); + serveStack.addDependency(apiBaseStack); - if (config.deployChat) { - const chatStack = new LisaChatApplicationStack(this, 'LisaChat', { + const modelsApiDeploymentStack = new LisaModelsApiStack(this, 'LisaModelsApiDeployment', { ...baseStackProps, - authorizer: apiBaseStack.authorizer!, - stackName: createCdkId([config.deploymentName, config.appName, 'chat', config.deploymentStage]), - description: `LISA-chat: ${config.deploymentName}-${config.deploymentStage}`, + authorizer: apiBaseStack.authorizer, + description: `LISA-models: ${config.deploymentName}-${config.deploymentStage}`, + guardrailsTable: serveStack.guardrailsTable, + lisaServeEndpointUrlPs: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, - securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + stackName: createCdkId([config.deploymentName, config.appName, 'models', config.deploymentStage]), + securityGroups: [networkingStack.vpc.securityGroups.ecsModelAlbSg], vpc: networkingStack.vpc, }); - chatStack.addDependency(apiBaseStack); - chatStack.addDependency(coreStack); - if (metricsStack) { - chatStack.addDependency(metricsStack); + modelsApiDeploymentStack.addDependency(serveStack); + apiDeploymentStack.addDependency(modelsApiDeploymentStack); + this.stacks.push(modelsApiDeploymentStack); + + if (config.deployMcpWorkbench) { + const mcpWorkbenchStack = new McpWorkbenchStack(this, 'LisaMcpWorkbench', { + ...baseStackProps, + stackName: createCdkId([config.deploymentName, config.appName, 'mcp-workbench', config.deploymentStage]), + description: `LISA-mcp-workbench: ${config.deploymentName}-${config.deploymentStage}`, + vpc: networkingStack.vpc, + restApiId: apiBaseStack.restApiId, + rootResourceId: apiBaseStack.rootResourceId, + apiCluster: serveStack.restApi.apiCluster, + authorizer: apiBaseStack.authorizer, + }); + mcpWorkbenchStack.addDependency(coreStack); + mcpWorkbenchStack.addDependency(apiBaseStack); + mcpWorkbenchStack.addDependency(serveStack); + apiDeploymentStack.addDependency(mcpWorkbenchStack); + this.stacks.push(mcpWorkbenchStack); } - apiDeploymentStack.addDependency(chatStack); - this.stacks.push(chatStack); - if (config.deployUi) { - const uiStack = new UserInterfaceStack(this, 'LisaUserInterface', { + if (config.deployRag) { + const ragStack = new LisaRagStack(this, 'LisaRAG', { ...baseStackProps, - architecture: ARCHITECTURE, - stackName: createCdkId([config.deploymentName, config.appName, 'ui', config.deploymentStage]), - description: `LISA-user-interface: ${config.deploymentName}-${config.deploymentStage}`, + authorizer: apiBaseStack.authorizer!, + description: `LISA-rag: ${config.deploymentName}-${config.deploymentStage}`, restApiId: apiBaseStack.restApiId, rootResourceId: apiBaseStack.rootResourceId, + endpointUrl: config.restApiConfig.internetFacing ? serveStack.endpointUrl : undefined, + modelsPs: config.restApiConfig.internetFacing ? serveStack.modelsPs : undefined, + stackName: createCdkId([config.deploymentName, config.appName, 'rag', config.deploymentStage]), + securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + vpc: networkingStack.vpc, }); - uiStack.addDependency(chatStack); - uiStack.addDependency(serveStack); - uiStack.addDependency(apiBaseStack); - apiDeploymentStack.addDependency(uiStack); - this.stacks.push(uiStack); + ragStack.addDependency(coreStack); + ragStack.addDependency(iamStack); + ragStack.addDependency(apiBaseStack); + this.stacks.push(ragStack); + apiDeploymentStack.addDependency(ragStack); } + + // Declare metricsStack here so that we can reference it in chatStack + let metricsStack: LisaMetricsStack | undefined; + if (config.deployMetrics) { + metricsStack = new LisaMetricsStack(this, 'LisaMetrics', { + ...baseStackProps, + authorizer: apiBaseStack.authorizer!, + stackName: createCdkId([config.deploymentName, config.appName, 'metrics', config.deploymentStage]), + description: `LISA-metrics: ${config.deploymentName}-${config.deploymentStage}`, + restApiId: apiBaseStack.restApiId, + rootResourceId: apiBaseStack.rootResourceId, + securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + vpc: networkingStack.vpc, + }); + metricsStack.addDependency(apiBaseStack); + metricsStack.addDependency(coreStack); + apiDeploymentStack.addDependency(metricsStack); + this.stacks.push(metricsStack); + } + + if (config.deployChat) { + const chatStack = new LisaChatApplicationStack(this, 'LisaChat', { + ...baseStackProps, + authorizer: apiBaseStack.authorizer!, + stackName: createCdkId([config.deploymentName, config.appName, 'chat', config.deploymentStage]), + description: `LISA-chat: ${config.deploymentName}-${config.deploymentStage}`, + restApiId: apiBaseStack.restApiId, + rootResourceId: apiBaseStack.rootResourceId, + securityGroups: [networkingStack.vpc.securityGroups.lambdaSg], + vpc: networkingStack.vpc, + }); + chatStack.addDependency(apiBaseStack); + chatStack.addDependency(coreStack); + if (metricsStack) { + chatStack.addDependency(metricsStack); + } + apiDeploymentStack.addDependency(chatStack); + this.stacks.push(chatStack); + + if (config.deployUi) { + const uiStack = new UserInterfaceStack(this, 'LisaUserInterface', { + ...baseStackProps, + architecture: ARCHITECTURE, + stackName: createCdkId([config.deploymentName, config.appName, 'ui', config.deploymentStage]), + description: `LISA-user-interface: ${config.deploymentName}-${config.deploymentStage}`, + restApiId: apiBaseStack.restApiId, + rootResourceId: apiBaseStack.rootResourceId, + }); + uiStack.addDependency(chatStack); + uiStack.addDependency(serveStack); + uiStack.addDependency(apiBaseStack); + apiDeploymentStack.addDependency(uiStack); + this.stacks.push(uiStack); + } + } + } + if (config.deployDocs) { const docsStack = new LisaDocsStack(this, 'LisaDocs', { ...baseStackProps diff --git a/lib/user-interface/react/.gitignore b/lib/user-interface/react/.gitignore index a547bf36d..de9ee66e8 100644 --- a/lib/user-interface/react/.gitignore +++ b/lib/user-interface/react/.gitignore @@ -11,6 +11,7 @@ node_modules dist dist-ssr *.local +coverage # Editor directories and files .vscode/* diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 90f8a8715..5db2b053d 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "5.4.0", + "version": "6.0.0", "type": "module", "scripts": { "dev": "vite", @@ -10,7 +10,11 @@ "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", "preview": "vite preview", "format": "prettier --ignore-path .gitignore --write \"**/*.+(tsx|js|ts|json)\"", - "clean": "rm -rf ./dist ./node_modules" + "clean": "rm -rf ./dist ./node_modules", + "test": "vitest run", + "test:ui": "vitest --ui", + "test:watch": "vitest --watch", + "test:coverage": "vitest run --coverage" }, "dependencies": { "@cloudscape-design/chat-components": "^1.0.61", @@ -30,13 +34,13 @@ "@swc/core": "^1.11.8", "ace-builds": "^1.43.2", "axios": "^1.8.2", + "dompurify": "^3.2.5", "fdir": "^6.5.0", "git-repo-info": "^2.1.1", "jszip": "^3.10.1", "langchain": "^0.3.15", "lodash": "^4.17.21", "luxon": "^3.5.0", - "dompurify": "^3.2.5", "mermaid": "^11.10.1", "react": "^18.3.1", "react-ace": "^14.0.1", @@ -61,6 +65,9 @@ "vitepress": "^1.6.4" }, "devDependencies": { + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.0", + "@testing-library/user-event": "^14.6.1", "@types/ace": "^0.0.52", "@types/markdown-it": "^14.1.2", "@types/react": "^18.3.18", @@ -71,6 +78,9 @@ "@typescript-eslint/eslint-plugin": "^8.0.0", "@typescript-eslint/parser": "^8.0.0", "@vitejs/plugin-react-swc": "^3.7.2", + "@vitest/coverage-istanbul": "^3.2.4", + "@vitest/coverage-v8": "^4.0.6", + "@vitest/ui": "^3.2.4", "autoprefixer": "^10.4.20", "eslint": "^9.0.0", "eslint-config-prettier": "^10.0.1", @@ -78,12 +88,14 @@ "eslint-plugin-prettier": "^5.2.3", "eslint-plugin-react-hooks": "^5.2.0", "eslint-plugin-react-refresh": "^0.4.18", + "jsdom": "^26.1.0", "linkify-it": "^5.0.0", "markdown-it": "^14.1.0", "postcss": "^8.5.1", "prettier": "^3.4.2", "redux-mock-store": "^1.5.5", "uuid": "^13.0.0", - "vite": "^6.3.4" + "vite": "^6.3.4", + "vitest": "^3.2.4" } } diff --git a/lib/user-interface/react/src/App.tsx b/lib/user-interface/react/src/App.tsx index cb4e27b81..a38fcd8fb 100644 --- a/lib/user-interface/react/src/App.tsx +++ b/lib/user-interface/react/src/App.tsx @@ -28,14 +28,16 @@ import SystemBanner from './components/system-banner/system-banner'; import { useAppSelector } from './config/store'; import { selectCurrentUserIsAdmin, selectCurrentUserIsUser } from './shared/reducers/user.reducer'; import ModelManagement from './pages/ModelManagement'; +import McpManagement from './pages/McpManagement'; import ModelLibrary from './pages/ModelLibrary'; +import RepositoryManagement from './pages/RepositoryManagement'; import NotificationBanner from './shared/notification/notification'; import ConfirmationModal, { ConfirmationModalProps } from './shared/modal/confirmation-modal'; import Configuration from './pages/Configuration'; import { useGetConfigurationQuery } from './shared/reducers/configuration.reducer'; import { IConfiguration } from './shared/model/configuration.model'; import DocumentLibrary from './pages/DocumentLibrary'; -import RepositoryLibrary from './pages/RepositoryLibrary'; +import CollectionLibrary from './pages/CollectionLibrary'; import { Breadcrumbs } from './shared/breadcrumb/breadcrumbs'; import BreadcrumbsDefaultChangeListener from './shared/breadcrumb/breadcrumbs-change-listener'; import PromptTemplatesLibrary from './pages/PromptTemplatesLibrary'; @@ -142,7 +144,7 @@ function App () { notifications={} stickyNotifications={true} navigation={nav} - navigationWidth={450} + navigationWidth={300} content={ } /> + {window.env.HOSTED_MCP_ENABLED && + + + } + />} + + + + } + /> {config?.configuration?.enabledComponents?.modelLibrary && - + } /> diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index 2e246379e..04e3ea8dd 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -37,7 +37,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { const auth = useAuth(); const isUserAdmin = useAppSelector(selectCurrentUserIsAdmin); const userName = useAppSelector(selectCurrentUsername); - const {colorScheme, setColorScheme} = useContext(ColorSchemeContext); + const { colorScheme, setColorScheme } = useContext(ColorSchemeContext); const libraryItems = [ ...(configs?.configuration.enabledComponents?.modelLibrary ? [{ @@ -136,6 +136,26 @@ function Topbar ({ configs }: TopbarProps): ReactElement { external: false, href: '/model-management', } as ButtonDropdownProps.Item, + { + id: 'repository-management', + type: 'button', + variant: 'link', + text: 'RAG Management', + disableUtilityCollapse: false, + external: false, + href: '/repository-management', + } as ButtonDropdownProps.Item, + ...(window.env.HOSTED_MCP_ENABLED ? [ + { + id: 'mcp-management', + type: 'button', + variant: 'link', + text: 'MCP Management', + disableUtilityCollapse: false, + external: false, + href: '/mcp-management', + } as ButtonDropdownProps.Item, + ] : []), ...(configs?.configuration.enabledComponents?.showMcpWorkbench ? [{ id: 'mcp-workbench', type: 'button', @@ -169,26 +189,27 @@ function Topbar ({ configs }: TopbarProps): ReactElement { iconName: 'user-profile', items: [ { id: 'version-info', text: `LISA v${window.gitInfo?.revisionTag}`, disabled: true }, - { id: 'color-mode', text: colorScheme === Mode.Light ? 'Dark mode' : 'Light mode', iconSvg: ( - - {' '} - - {' '} - - ) + { + id: 'color-mode', text: colorScheme === Mode.Light ? 'Dark mode' : 'Light mode', iconSvg: ( + + {' '} + + {' '} + + ) }, auth.isAuthenticated ? { id: 'signout', text: 'Sign out' } : { id: 'signin', text: 'Sign in' }, ], diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index de2ef75f6..8e9c954fb 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -121,6 +121,7 @@ export default function Chat ({ sessionId }) { const [modelFilterValue, setModelFilterValue] = useState(''); const [hasUserInteractedWithModel, setHasUserInteractedWithModel] = useState(false); const [mermaidRenderComplete, setMermaidRenderComplete] = useState(0); + const [dynamicMaxRows, setDynamicMaxRows] = useState(8); // Callback to handle Mermaid diagram rendering completion const handleMermaidRenderComplete = useCallback(() => { @@ -158,6 +159,27 @@ export default function Chat ({ sessionId }) { } }, [userPreferences, userName]); + useEffect(() => { + const calculateMaxRows = () => { + const LINE_HEIGHT = 24; // pixels per row + const RESERVED_UI_HEIGHT = 280; // model selector, buttons, status + const MAX_INPUT_PERCENTAGE = 0.5; // 50% of viewport max + + const availableHeight = window.innerHeight - RESERVED_UI_HEIGHT; + const maxInputHeight = availableHeight * MAX_INPUT_PERCENTAGE; + const calculatedMaxRows = Math.floor(maxInputHeight / LINE_HEIGHT); + + // Clamp between 3 and 12 rows + const clampedMaxRows = Math.max(3, Math.min(12, calculatedMaxRows)); + setDynamicMaxRows(clampedMaxRows); + }; + + calculateMaxRows(); + window.addEventListener('resize', calculateMaxRows); + return () => window.removeEventListener('resize', calculateMaxRows); + }, []); + + // Custom hooks const { session, @@ -251,11 +273,11 @@ export default function Chat ({ sessionId }) { return getRelevantDocuments({ query, repositoryId: ragConfig.repositoryId, - repositoryType: ragConfig.repositoryType, - modelName: ragConfig.embeddingModel?.modelId, + collectionId: ragConfig.collection?.collectionId, topK: ragTopK, + modelName: !ragConfig.collection?.collectionId ? ragConfig.embeddingModel?.modelId : undefined, }); - }, [getRelevantDocuments, chatConfiguration.sessionConfiguration, ragConfig.repositoryId, ragConfig.repositoryType, ragConfig.embeddingModel?.modelId]); + }, [getRelevantDocuments, chatConfiguration.sessionConfiguration, ragConfig.repositoryId, ragConfig.collection, ragConfig.embeddingModel]); const { isRunning, setIsRunning, isStreaming, generateResponse, stopGeneration } = useChatGeneration({ chatConfiguration, @@ -598,7 +620,7 @@ export default function Chat ({ sessionId }) { }, [shouldShowStopButton, userPrompt.length, isRunning, callingToolName, loadingSession, handleSendGenerateRequest]); return ( -
+
{/* MCP Connections - invisible components that manage the connections */} {McpConnections} {useMemo(() => ( )} -
+
{useMemo(() => session.history.map((message, idx) => ( { // Add timestamp to filename for RAG uploads to not conflict with existing S3 files @@ -219,8 +218,12 @@ export const RagUploadModal = ({ const [ingestingFiles, setIngestingFiles] = useState(false); const [ingestionStatus, setIngestionStatus] = useState(''); const [ingestionType, setIngestionType] = useState(StatusTypes.LOADING); - const [chunkSize, setChunkSize] = useState(512); - const [chunkOverlap, setChunkOverlap] = useState(51); + const [overrideChunkingStrategy, setOverrideChunkingStrategy] = useState(false); + const [chunkingStrategy, setChunkingStrategy] = useState({ + type: ChunkingStrategyType.FIXED, + size: 512, + overlap: 51, + }); const dispatch = useAppDispatch(); const [getPresignedUrl] = useLazyGetPresignedUrlQuery(); const notificationService = useNotificationService(dispatch); @@ -238,7 +241,7 @@ export const RagUploadModal = ({ const s3UploadRequest = uploadToS3Request(urlResponse.data, file); const uploadResp = await uploadToS3Mutation(s3UploadRequest); - if (uploadResp.error) { + if ('error' in uploadResp) { handleError(`Error encountered while uploading file ${file.name}`); return false; } @@ -253,23 +256,23 @@ export const RagUploadModal = ({ setIngestionStatus('Ingesting documents into the selected repository...'); try { // Ingest all of the documents which uploaded successfully - const ingestResp = await ingestDocuments({ documents: fileKeys, repositoryId: ragConfig.repositoryId, - embeddingModel: { id: ragConfig.embeddingModel.modelId, modelType: ragConfig.embeddingModel.modelType, streaming: ragConfig.embeddingModel.streaming }, + collectionId: ragConfig.collection?.collectionId, repostiroyType: ragConfig.repositoryType, - chunkSize, - chunkOverlap + chunkingStrategy: overrideChunkingStrategy ? chunkingStrategy : undefined, }); - if (ingestResp.error) { + if ('error' in ingestResp) { throw new Error('Failed to ingest documents into RAG'); } else { setIngestionType(StatusTypes.SUCCESS); - const jobIds = ingestResp.data?.ingestionJobIds || []; - setIngestionStatus(`Successfully ingested documents into the selected repository. Job IDs: ${jobIds.join(', ')}`); - notificationService.generateNotification(`Successfully ingested ${fileKeys.length} document(s) into the selected repository. Job IDs: ${jobIds.join(', ')}`, 'success'); + const jobs = ingestResp.data?.jobs || []; + const jobIds = jobs.map((job) => job.jobId); + const collectionName = ingestResp.data?.collectionName || ingestResp.data?.collectionId || 'repository'; + setIngestionStatus(`Successfully submitted documents for ingestion into ${collectionName}. Job IDs: ${jobIds.join(', ')}`); + notificationService.generateNotification(`Successfully submitted ${fileKeys.length} document(s) for ingestion into ${collectionName}. ${jobs.length} job(s) created.`, 'success'); setShowRagUploadModal(false); } } catch { @@ -332,38 +335,41 @@ export const RagUploadModal = ({

- - - { - const intVal = parseInt(event.detail.value); - if (intVal >= 0) { - setChunkSize(intVal); - } - }} - /> - - - { - const intVal = parseInt(event.detail.value); - if (intVal >= 0) { - setChunkOverlap(intVal); - } - }} - /> - - + + {/* Chunking Strategy Override Checkbox - Hidden for Bedrock repositories */} + {ragConfig.repositoryType !== RagRepositoryType.BEDROCK_KNOWLEDGE_BASE && ( + setOverrideChunkingStrategy(detail.checked)} + > + Override default chunking strategy + + )} + + {/* Chunking Strategy Form - Only shown when override is enabled and not Bedrock */} + {overrideChunkingStrategy && ragConfig.repositoryType !== RagRepositoryType.BEDROCK_KNOWLEDGE_BASE && ( + { + if (values.chunkingStrategy !== undefined) { + setChunkingStrategy(values.chunkingStrategy); + } else if (values['chunkingStrategy.size'] !== undefined) { + setChunkingStrategy({ + ...chunkingStrategy, + size: values['chunkingStrategy.size'], + }); + } else if (values['chunkingStrategy.overlap'] !== undefined) { + setChunkingStrategy({ + ...chunkingStrategy, + overlap: values['chunkingStrategy.overlap'], + }); + } + }} + touchFields={() => { }} + formErrors={{}} + /> + )} + setSelectedFiles(detail.value)} value={selectedFiles} @@ -394,7 +400,6 @@ export const RagUploadModal = ({
diff --git a/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx b/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx index addca3234..c2b7f91af 100644 --- a/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx @@ -18,9 +18,12 @@ import { Autosuggest, Grid, SpaceBetween } from '@cloudscape-design/components'; import { useEffect, useMemo, useState, useRef } from 'react'; import { useGetAllModelsQuery } from '@/shared/reducers/model-management.reducer'; import { IModel, ModelStatus, ModelType } from '@/shared/model/model-management.model'; -import { useListRagRepositoriesQuery } from '@/shared/reducers/rag.reducer'; +import { useListRagRepositoriesQuery, useListCollectionsQuery, RagCollectionConfig } from '@/shared/reducers/rag.reducer'; +import { RagRepositoryType, VectorStoreStatus } from '#root/lib/schema'; +import { CollectionStatus } from '#root/lib/schema/collectionSchema'; export type RagConfig = { + collection?: RagCollectionConfig; embeddingModel: IModel; repositoryId: string; repositoryType: string; @@ -33,40 +36,67 @@ type RagControlProps = { ragConfig: RagConfig; }; -export default function RagControls ({isRunning, setUseRag, setRagConfig, ragConfig }: RagControlProps) { +export default function RagControls ({ isRunning, setUseRag, setRagConfig, ragConfig }: RagControlProps) { const { data: repositories, isLoading: isLoadingRepositories } = useListRagRepositoriesQuery(undefined, { refetchOnMountOrArgChange: 5 }); - const { data: allModels, isLoading: isLoadingModels } = useGetAllModelsQuery(undefined, {refetchOnMountOrArgChange: 5, + + const { data: collections, isLoading: isLoadingCollections } = useListCollectionsQuery( + { repositoryId: ragConfig?.repositoryId }, + { + skip: !ragConfig?.repositoryId, + refetchOnMountOrArgChange: 5 + } + ); + + const { data: allModels } = useGetAllModelsQuery(undefined, { + refetchOnMountOrArgChange: 5, selectFromResult: (state) => ({ isLoading: state.isLoading, data: (state.data || []).filter((model) => model.modelType === ModelType.embedding && model.status === ModelStatus.InService), - })}); + }) + }); - const [userHasSelectedModel, setUserHasSelectedModel] = useState(false); + const [userHasSelectedCollection, setUserHasSelectedCollection] = useState(false); const lastRepositoryIdRef = useRef(undefined); const selectedRepositoryOption = ragConfig?.repositoryId ?? ''; - const selectedEmbeddingOption = ragConfig?.embeddingModel?.modelId ?? ''; - - const embeddingOptions = useMemo(() => { - if (!allModels || !selectedRepositoryOption) return []; - - const repository = repositories?.find((repo) => repo.repositoryId === selectedRepositoryOption); - const defaultModelId = repository?.embeddingModelId; - - return allModels.map((model) => ({ - value: model.modelId, - label: model.modelId + (model.modelId === defaultModelId ? ' (default)' : '') - })); - }, [allModels, repositories, selectedRepositoryOption]); + const selectedCollectionOption = ragConfig?.collection?.name ?? ''; + + const filteredRepositories = useMemo(() => { + if (!repositories) return []; + return repositories.filter((repo) => + repo.status === VectorStoreStatus.CREATE_COMPLETE || repo.status === VectorStoreStatus.UPDATE_COMPLETE + ); + }, [repositories]); + + const collectionOptions = useMemo(() => { + if (!collections) return []; + // Filter to only show ACTIVE collections + return collections + .filter((collection) => collection.status === CollectionStatus.ACTIVE) + .map((collection) => ({ + value: collection.collectionId, + label: collection.name, + })); + }, [collections]); + // Update useRag flag based on repository type and configuration useEffect(() => { - setUseRag(!!selectedEmbeddingOption && !!selectedRepositoryOption); - }, [selectedRepositoryOption, selectedEmbeddingOption, setUseRag]); - - // Effect for handling repository changes and auto-selection + const hasRepository = !!ragConfig?.repositoryId; + const hasCollection = !!ragConfig?.collection; + const hasEmbeddingModel = !!ragConfig?.embeddingModel; + const isBedrockRepo = ragConfig?.repositoryType === RagRepositoryType.BEDROCK_KNOWLEDGE_BASE; + // For Bedrock repositories: require both repository AND collection + // For non-Bedrock repositories: require repository AND embedding model (or collection as alternative) + if (isBedrockRepo) { + setUseRag(hasRepository && hasCollection); + } else { + setUseRag(hasRepository && (hasEmbeddingModel || hasCollection)); + } + }, [ragConfig?.repositoryId, ragConfig?.repositoryType, ragConfig?.collection, ragConfig?.embeddingModel, setUseRag]); + // Effect for handling repository changes, default collection, and default embedding model selection useEffect(() => { const currentRepositoryId = ragConfig?.repositoryId; const repositoryHasChanged = currentRepositoryId !== lastRepositoryIdRef.current; @@ -74,41 +104,63 @@ export default function RagControls ({isRunning, setUseRag, setRagConfig, ragCon // Update tracking and reset user selection flag when repository changes if (repositoryHasChanged) { lastRepositoryIdRef.current = currentRepositoryId; - setUserHasSelectedModel(false); + setUserHasSelectedCollection(false); } - // Auto-select default model when repository changes or no model is set - if (currentRepositoryId && repositories && allModels) { - const repository = repositories.find((repo) => repo.repositoryId === currentRepositoryId); + if (currentRepositoryId && filteredRepositories && allModels && !userHasSelectedCollection) { + const repository = filteredRepositories.find((repo) => repo.repositoryId === currentRepositoryId); + const isNonBedrockRepo = repository?.type !== RagRepositoryType.BEDROCK_KNOWLEDGE_BASE; + + // For non-bedrock repositories, auto-select the first available collection if it exists + if (isNonBedrockRepo && collections && collections.length > 0 && !ragConfig?.collection) { + const activeCollections = collections.filter((c) => c.status === CollectionStatus.ACTIVE); + if (activeCollections.length > 0) { + const defaultCollection = activeCollections[0]; + const embeddingModel = allModels.find((model) => model.modelId === defaultCollection.embeddingModel); + + setRagConfig((config) => ({ + ...config, + collection: defaultCollection, + embeddingModel: embeddingModel, + })); + return; + } + } - if (repository?.embeddingModelId) { + // Set default embedding model when no collection is selected + if (repository?.embeddingModelId && !ragConfig?.collection) { const defaultModel = allModels.find((model) => model.modelId === repository.embeddingModelId); - if (defaultModel) { - const shouldAutoSwitch = repositoryHasChanged || - (!ragConfig?.embeddingModel && !userHasSelectedModel); - - if (shouldAutoSwitch) { - setRagConfig((config) => ({ - ...config, - embeddingModel: defaultModel, - })); - } + if (defaultModel && !ragConfig?.embeddingModel) { + setRagConfig((config) => ({ + ...config, + embeddingModel: defaultModel, + })); } } } - }, [ragConfig?.repositoryId, ragConfig?.embeddingModel, repositories, allModels, userHasSelectedModel, setRagConfig]); + }, [ + ragConfig?.repositoryId, + ragConfig?.collection, + ragConfig?.embeddingModel, + filteredRepositories, + allModels, + collections, + userHasSelectedCollection, + setRagConfig + ]); const handleRepositoryChange = ({ detail }) => { const newRepositoryId = detail.value; - setUserHasSelectedModel(false); // Reset when repository changes + setUserHasSelectedCollection(false); // Reset collection selection flag if (newRepositoryId) { - const repository = repositories?.find((repo) => repo.repositoryId === newRepositoryId); + const repository = filteredRepositories?.find((repo) => repo.repositoryId === newRepositoryId); setRagConfig((config) => ({ ...config, repositoryId: newRepositoryId, repositoryType: repository?.type || 'unknown', + collection: undefined, // Clear collection when repository changes embeddingModel: undefined, // Clear current model so useEffect can set default })); } else { @@ -116,27 +168,45 @@ export default function RagControls ({isRunning, setUseRag, setRagConfig, ragCon ...config, repositoryId: undefined, repositoryType: undefined, + collection: undefined, embeddingModel: undefined, })); } }; - const handleModelChange = ({ detail }) => { - const newModelId = detail.value; - setUserHasSelectedModel(true); // Mark that user has made an explicit choice + const handleCollectionChange = ({ detail }) => { + const newCollectionId = detail.value; + setUserHasSelectedCollection(true); + + if (newCollectionId) { + const collection = collections?.find( + (c) => c.collectionId === newCollectionId + ); + if (collection) { + // Find the embedding model from allModels + const embeddingModel = allModels?.find( + (model) => model.modelId === collection.embeddingModel + ); - if (newModelId) { - const model = allModels.find((model) => model.modelId === newModelId); - if (model) { setRagConfig((config) => ({ ...config, - embeddingModel: model, + collection: collection, + embeddingModel: embeddingModel, })); } } else { + // User cleared collection - fall back to repository default + const repository = filteredRepositories?.find( + (repo) => repo.repositoryId === ragConfig?.repositoryId + ); + const defaultModel = allModels?.find( + (model) => model.modelId === repository?.embeddingModelId + ); + setRagConfig((config) => ({ ...config, - embeddingModel: undefined, + collection: undefined, + embeddingModel: defaultModel, })); } }; @@ -159,22 +229,22 @@ export default function RagControls ({isRunning, setUseRag, setRagConfig, ragCon value={selectedRepositoryOption} enteredTextLabel={(text) => `Use: "${text}"`} onChange={handleRepositoryChange} - options={repositories?.map((repository) => ({ + options={filteredRepositories?.map((repository) => ({ value: repository.repositoryId, label: repository?.repositoryName?.length ? repository?.repositoryName : repository.repositoryId })) || []} /> No embedding models available.
} + statusType={isLoadingCollections ? 'loading' : 'finished'} + loadingText='Loading collections...' + placeholder='Select a collection (optional)' + empty={
No collections available.
} filteringType='auto' - value={selectedEmbeddingOption} + value={selectedCollectionOption} enteredTextLabel={(text) => `Use: "${text}"`} - onChange={handleModelChange} - options={embeddingOptions} + onChange={handleCollectionChange} + options={collectionOptions} /> diff --git a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx index bc8e879b7..485f6e376 100644 --- a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx @@ -212,7 +212,7 @@ export function Sessions ({ newSession }) {
- + = { - showMcpWorkbench: {prerequisites: ['mcpConnections'] }, - mcpConnections: {dependents: ['showMcpWorkbench']} + showMcpWorkbench: { prerequisites: ['mcpConnections'] }, + mcpConnections: { dependents: ['showMcpWorkbench'] } }; const configurableOperations = [{ - header: 'RAG Components', + header: 'RAG', items: ragOptions }, { - header: 'Library Components', + header: 'Library', items: libraryOptions }, { - header: 'In-Context Components', + header: 'In-Context', items: inContextOptions }, { - header: 'Advanced Components', + header: 'Advanced', items: advancedOptions }, { - header: 'MCP Components', + header: 'MCP', items: mcpOptions }]; @@ -141,7 +141,7 @@ export function ActivatedUserComponents (props: ActivatedComponentConfigurationP - Activated Chat UI Components + Chat Features
}> diff --git a/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx b/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx index 7e5e90bab..2b7a026c2 100644 --- a/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx +++ b/lib/user-interface/react/src/components/configuration/ConfigurationComponent.tsx @@ -28,7 +28,6 @@ import { selectCurrentUsername } from '../../shared/reducers/user.reducer'; import { getJsonDifference } from '../../shared/util/validationUtils'; import { setConfirmationModal } from '../../shared/reducers/modal.reducer'; import { useNotificationService } from '../../shared/util/hooks'; -import RepositoryTable from './RepositoryTable'; import { mcpServerApi } from '@/shared/reducers/mcp-server.reducer'; export type ConfigState = { @@ -151,9 +150,9 @@ export function ConfigurationComponent (): ReactElement {
- LISA App Configuration + LISA Feature Configuration
-
- RAG Repository Configuration -
-
); } diff --git a/lib/user-interface/react/src/components/configuration/createRepository/BedrockKnowledgeBaseConfigForm.tsx b/lib/user-interface/react/src/components/configuration/createRepository/BedrockKnowledgeBaseConfigForm.tsx deleted file mode 100644 index 5473340bf..000000000 --- a/lib/user-interface/react/src/components/configuration/createRepository/BedrockKnowledgeBaseConfigForm.tsx +++ /dev/null @@ -1,78 +0,0 @@ -/** - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -import Container from '@cloudscape-design/components/container'; -import { Header, SpaceBetween } from '@cloudscape-design/components'; -import FormField from '@cloudscape-design/components/form-field'; -import Input from '@cloudscape-design/components/input'; -import React, { ReactElement } from 'react'; -import { FormProps } from '../../../shared/form/form-props'; -import { BedrockKnowledgeBaseConfig as BedrockKnowledgeBaseConfigSchema, BedrockKnowledgeBaseInstanceConfig } from '#root/lib/schema'; - -type BedrockKnowledgeBaseConfigProps = { - isEdit: boolean -}; - -export function BedrockKnowledgeBaseConfigForm (props: FormProps & BedrockKnowledgeBaseConfigProps): ReactElement { - const { item, touchFields, setFields, formErrors, isEdit } = props; - - return ( - Bedrock Knowledge Base Config}> - - - touchFields(['bedrockKnowledgeBaseConfig.bedrockKnowledgeBaseName'])} - onChange={({ detail }) => setFields({ 'bedrockKnowledgeBaseConfig.bedrockKnowledgeBaseName': detail.value })} - placeholder='Knowledge Base Name' disabled={isEdit} /> - - - touchFields(['bedrockKnowledgeBaseConfig.bedrockKnowledgeBaseId'])} - onChange={({ detail }) => setFields({ 'bedrockKnowledgeBaseConfig.bedrockKnowledgeBaseId': detail.value })} - placeholder='Knowledge Base ID' disabled={isEdit} /> - - - touchFields(['bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceName'])} - onChange={({ detail }) => setFields({ 'bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceName': detail.value })} - placeholder='Knowledge Base Datasource Name' disabled={isEdit} /> - - - touchFields(['bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceId'])} - onChange={({ detail }) => setFields({ 'bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceId': detail.value })} - placeholder='Knowledge Base Datasource ID' disabled={isEdit} /> - - - touchFields(['bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceS3Bucket'])} - onChange={({ detail }) => setFields({ 'bedrockKnowledgeBaseConfig.bedrockKnowledgeDatasourceS3Bucket': detail.value })} - placeholder='Knowledge Base Datasource S3 Bucket' disabled={isEdit} /> - - - - ); -} diff --git a/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.test.tsx b/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.test.tsx new file mode 100644 index 000000000..b85c92346 --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.test.tsx @@ -0,0 +1,309 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { screen, waitFor } from '@testing-library/react'; +import { CollectionLibraryComponent } from './CollectionLibraryComponent'; +import { renderWithProviders } from '../../test/helpers/render'; +import { + createMockCollections, + createMockCollection, +} from '../../test/factories/collection.factory'; +import { MemoryRouter } from 'react-router-dom'; +import * as ragReducer from '../../shared/reducers/rag.reducer'; +import * as modelReducer from '../../shared/reducers/model-management.reducer'; + +const mockNavigate = vi.fn(); + +vi.mock('react-router-dom', async () => { + const actual: any = await vi.importActual('react-router-dom'); + return { + ...actual, + useNavigate: () => mockNavigate, + }; +}); + +describe('CollectionLibraryComponent', () => { + beforeEach(() => { + vi.clearAllMocks(); + + // Default mocks + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: [], + isLoading: false, + isError: false, + error: undefined, + refetch: vi.fn(), + } as any); + + vi.spyOn(ragReducer, 'useDeleteCollectionMutation').mockReturnValue([ + vi.fn(), + { isLoading: false, isError: false, error: undefined }, + ] as any); + + // Mock model management API + vi.spyOn(modelReducer, 'useGetAllModelsQuery').mockReturnValue({ + data: [], + isFetching: false, + isLoading: false, + isError: false, + error: undefined, + refetch: vi.fn(), + } as any); + }); + + describe('Rendering', () => { + it('should display collections in table format', async () => { + const mockCollections = createMockCollections(3); + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: mockCollections, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + // Check for the Collections header + expect(screen.getByText('Collections')).toBeInTheDocument(); + // Check for the count + expect(screen.getByText('3')).toBeInTheDocument(); + // Check for column headers - use getAllByText since modal also has "Collection Name" + const collectionNameHeaders = screen.getAllByText('Collection Name'); + expect(collectionNameHeaders.length).toBeGreaterThan(0); + }); + }); + + it('should render collections header with count', async () => { + const mockCollections = createMockCollections(5); + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: mockCollections, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('Collections')).toBeInTheDocument(); + expect(screen.getByText('5')).toBeInTheDocument(); + }); + }); + + it('should display collection data in table rows', async () => { + const mockCollection = createMockCollection({ + name: 'Engineering Docs', + collectionId: 'eng-123', + repositoryId: 'repo-456', + }); + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: [mockCollection], + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('Engineering Docs')).toBeInTheDocument(); + expect(screen.getByText('repo-456')).toBeInTheDocument(); + }); + }); + + it('should show loading state', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: undefined, + isLoading: true, + } as any); + + renderWithProviders( + + + + ); + + expect(screen.getByText('Loading collections')).toBeInTheDocument(); + }); + + it('should show empty state when no collections', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: [], + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('No collections')).toBeInTheDocument(); + }); + }); + }); + + describe('Navigation', () => { + it('should have link to document library', async () => { + const mockCollection = createMockCollection({ + collectionId: 'col-123', + repositoryId: 'repo-456', + name: 'Test Collection', + }); + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: [mockCollection], + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const link = screen.getByText('Test Collection'); + expect(link).toBeInTheDocument(); + expect(link.closest('a')).toHaveAttribute('href', '#/document-library/repo-456/col-123'); + }); + }); + }); + + describe('Actions Button', () => { + it('should render Actions button for admin users', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: createMockCollections(1), + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('Actions')).toBeInTheDocument(); + }); + }); + + it('should not render Actions button for non-admin users', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: createMockCollections(1), + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.queryByText('Actions')).not.toBeInTheDocument(); + }); + }); + + it('should disable Actions button when no collection is selected', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: createMockCollections(1), + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const actionsButton = screen.getByText('Actions').closest('button'); + expect(actionsButton).toBeDisabled(); + }); + }); + }); + + describe('Refresh Functionality', () => { + it('should render refresh button', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: createMockCollections(1), + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const refreshButton = screen.getByLabelText('Refresh collections'); + expect(refreshButton).toBeInTheDocument(); + }); + }); + }); + + describe('Filter Functionality', () => { + it('should render filter input', async () => { + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: createMockCollections(3), + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByPlaceholderText('Find collections')).toBeInTheDocument(); + }); + }); + }); + + describe('Pagination', () => { + it('should handle large number of collections', async () => { + // Create enough collections to test pagination behavior + const mockCollections = createMockCollections(25); + vi.spyOn(ragReducer, 'useListAllCollectionsQuery').mockReturnValue({ + data: mockCollections, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + // Verify the component renders successfully with many items + expect(screen.getByText('Collections')).toBeInTheDocument(); + expect(screen.getByText('25')).toBeInTheDocument(); + }); + }); + }); +}); diff --git a/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.tsx b/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.tsx new file mode 100644 index 000000000..a573f10ac --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/CollectionLibraryComponent.tsx @@ -0,0 +1,280 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { ReactElement, useState } from 'react'; +import { Box, Button, ButtonDropdown, CollectionPreferences, Header, Icon, Pagination, Table, TextFilter } from '@cloudscape-design/components'; +import SpaceBetween from '@cloudscape-design/components/space-between'; +import { + COLLECTION_COLUMN_DEFINITIONS, + getCollectionTablePreference, + getDefaultCollectionPreferences, + PAGE_SIZE_OPTIONS, +} from '@/components/document-library/CollectionTableConfig'; +import { ragApi, useDeleteCollectionMutation, useListAllCollectionsQuery } from '@/shared/reducers/rag.reducer'; +import { useLocalStorage } from '@/shared/hooks/use-local-storage'; +import { useCollection } from '@cloudscape-design/collection-hooks'; +import { useAppDispatch } from '@/config/store'; +import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; +import { CreateCollectionModal } from '@/components/document-library/createCollection/CreateCollectionModal'; +import { CollectionStatus } from '#root/lib/schema/collectionSchema'; + +type CollectionLibraryComponentProps = { + admin?: boolean; +}; + +export function CollectionLibraryComponent ({ admin = false }: CollectionLibraryComponentProps): ReactElement { + const { + data: allCollections, + isLoading: fetchingCollections, + } = useListAllCollectionsQuery(undefined, { refetchOnMountOrArgChange: 5 }); + + const [deleteCollection, { isLoading: isDeleteLoading }] = useDeleteCollectionMutation(); + const dispatch = useAppDispatch(); + + const [preferences, setPreferences] = useLocalStorage( + 'CollectionLibraryPreferences', + getDefaultCollectionPreferences() + ); + + // Modal state + const [createCollectionModalVisible, setCreateCollectionModalVisible] = useState(false); + const [isEdit, setIsEdit] = useState(false); + + const { items, actions, filteredItemsCount, collectionProps, filterProps, paginationProps } = useCollection( + allCollections ?? [], + { + filtering: { + empty: ( + + + No collections + + + ), + }, + pagination: { pageSize: preferences.pageSize }, + sorting: { + defaultState: { + sortingColumn: { + sortingField: 'name', + }, + }, + }, + selection: { + trackBy: (item) => `${item.repositoryId}#${item.collectionId}`, + }, + } + ); + + const selectedCollection = collectionProps.selectedItems.length === 1 ? collectionProps.selectedItems[0] : null; + const isDefaultCollection = (selectedCollection as any)?.default === true; + const collectionStatus = selectedCollection?.status; + + // Determine which actions should be disabled based on status + const isEditDisabled = !selectedCollection || + isDefaultCollection || + collectionStatus === CollectionStatus.ARCHIVED || + collectionStatus === CollectionStatus.DELETED || + collectionStatus === CollectionStatus.DELETE_IN_PROGRESS; + + const isDeleteDisabled = !selectedCollection || + collectionStatus === CollectionStatus.DELETED || + collectionStatus === CollectionStatus.DELETE_IN_PROGRESS; + + const getEditDisabledReason = () => { + if (!selectedCollection) return 'Please select a collection'; + if (isDefaultCollection) return 'Cannot edit default collection'; + if (collectionStatus === CollectionStatus.ARCHIVED) return 'Cannot edit archived collection'; + if (collectionStatus === CollectionStatus.DELETED) return 'Cannot edit deleted collection'; + if (collectionStatus === CollectionStatus.DELETE_IN_PROGRESS) return 'Cannot edit collection being deleted'; + return undefined; + }; + + const getDeleteDisabledReason = () => { + if (!selectedCollection) return 'Please select a collection'; + if (collectionStatus === CollectionStatus.DELETED) return 'Collection already deleted'; + if (collectionStatus === CollectionStatus.DELETE_IN_PROGRESS) return 'Collection deletion in progress'; + return undefined; + }; + + const handleSelectionChange = ({ detail }) => { + if (admin) { + actions.setSelectedItems(detail.selectedItems); + } + // Navigation is now handled by onRowClick to separate selection from navigation + }; + + const handleAction = async (e: any) => { + switch (e.detail.id) { + case 'edit': { + setIsEdit(true); + setCreateCollectionModalVisible(true); + break; + } + case 'delete': { + if (!selectedCollection) return; + + dispatch( + setConfirmationModal({ + action: 'Delete', + resourceName: 'Collection', + onConfirm: () => + deleteCollection({ + repositoryId: selectedCollection.repositoryId, + collectionId: selectedCollection.collectionId, + embeddingModel: selectedCollection.embeddingModel, + default: (selectedCollection as any).default, + }), + description: ( +
+ Are you sure you want to delete the collection{' '} + {selectedCollection.name || selectedCollection.collectionId}? +
+
+ {isDefaultCollection ? ( + <> + Note: This will remove all documents in the default collection, + but the collection will remain visible in the Collection Library. This is a clean up operation. +
+
+ + ) : ( + <>This action cannot be undone. + )} +
+ ), + }), + ); + break; + } + default: + console.error('Action not implemented', e.detail.id); + } + }; + + return ( + <> + {admin && ( + + )} + + } + header={ +
+ + {admin && ( + <> + + Actions + + + + )} + + } + > + Collections +
+ } + pagination={} + preferences={ + setPreferences(detail)} + contentDisplayPreference={{ + title: 'Select visible columns', + options: getCollectionTablePreference(), + }} + pageSizePreference={{ title: 'Page size', options: PAGE_SIZE_OPTIONS }} + /> + } + /> + + ); +} + +export default CollectionLibraryComponent; diff --git a/lib/user-interface/react/src/components/document-library/CollectionTableConfig.tsx b/lib/user-interface/react/src/components/document-library/CollectionTableConfig.tsx new file mode 100644 index 000000000..26d0923ce --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/CollectionTableConfig.tsx @@ -0,0 +1,133 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { CollectionPreferencesProps, TableProps } from '@cloudscape-design/components'; +import { DEFAULT_PAGE_SIZE_OPTIONS } from '@/shared/preferences/common-preferences'; +import Badge from '@cloudscape-design/components/badge'; +import Link from '@cloudscape-design/components/link'; +import StatusIndicator, { StatusIndicatorProps } from '@cloudscape-design/components/status-indicator'; +import { ReactNode } from 'react'; +import { RagCollectionConfig } from '@/shared/reducers/rag.reducer'; +import { CollectionStatus } from '#root/lib/schema'; + +export const PAGE_SIZE_OPTIONS = DEFAULT_PAGE_SIZE_OPTIONS('Collections'); + +export type CollectionTableRow = TableProps.ColumnDefinition & { + visible: boolean; + header: string; +}; + +export const COLLECTION_COLUMN_DEFINITIONS: ReadonlyArray = [ + { + id: 'name', + header: 'Collection Name', + cell: (collection) => ( + <> + + {collection.name || collection.collectionId} + + {(collection as any).default === true && ( + <> Global + )} + + ), + sortingField: 'name', + visible: true, + isRowHeader: true, + }, + { + id: 'collectionId', + header: 'Collection ID', + cell: (collection) => collection.collectionId, + sortingField: 'collectionId', + visible: false, + }, + { + id: 'repositoryId', + header: 'Repository', + cell: (collection) => ( + + {collection.repositoryId} + + ), + sortingField: 'repositoryId', + visible: true, + }, + { + id: 'embeddingModel', + header: 'Embedding Model', + cell: (collection) => collection.embeddingModel || '-', + visible: true, + }, + { + id: 'allowedGroups', + header: 'Allowed Groups', + cell: (collection) => { + if (!collection.allowedGroups || collection.allowedGroups.length === 0) { + return (public); + } + return collection.allowedGroups.join(', '); + }, + visible: true, + }, + { + id: 'status', + header: 'Status', + cell: (collection) => getStatusIndicator(collection.status), + visible: true, + }, +]; + +function getStatusIndicator (status: CollectionStatus): ReactNode { + let type: StatusIndicatorProps.Type; + switch (status) { + case 'ACTIVE': + type = 'success'; + break; + case 'DELETE_IN_PROGRESS': + type = 'pending'; + break; + case 'ARCHIVED': + case 'DELETED': + type = 'stopped'; + break; + case 'DELETE_FAILED': + type = 'error'; + break; + } + return {status}; +} + +export function getCollectionTablePreference (): ReadonlyArray { + return COLLECTION_COLUMN_DEFINITIONS.map((c) => ({ + id: c.id!, + label: c.header, + })); +} + +export function getCollectionTableColumnDisplay (): CollectionPreferencesProps.ContentDisplayItem[] { + return COLLECTION_COLUMN_DEFINITIONS.map((c) => ({ + id: c.id!, + visible: c.visible, + })); +} + +export function getDefaultCollectionPreferences (): CollectionPreferencesProps.Preferences { + return { + pageSize: PAGE_SIZE_OPTIONS[0].value, + contentDisplay: getCollectionTableColumnDisplay(), + }; +} diff --git a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx new file mode 100644 index 000000000..75d3950b4 --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.test.tsx @@ -0,0 +1,338 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { screen, waitFor } from '@testing-library/react'; +import { DocumentLibraryComponent, getMatchesCountText } from './DocumentLibraryComponent'; +import { renderWithProviders } from '../../test/helpers/render'; +import { MemoryRouter } from 'react-router-dom'; +import { createMockDocument } from '../../test/factories/document.factory'; +import * as ragReducer from '../../shared/reducers/rag.reducer'; +import * as store from '../../config/store'; + +vi.mock('../../shared/util/downloader', () => ({ + downloadFile: vi.fn(), +})); + +describe('DocumentLibraryComponent', () => { + beforeEach(() => { + vi.clearAllMocks(); + + // Mock Redux selectors + vi.spyOn(store, 'useAppSelector').mockImplementation((selector: any) => { + if (selector.toString().includes('selectCurrentUsername')) return 'test-user'; + if (selector.toString().includes('selectCurrentUserIsAdmin')) return false; + return null; + }); + + vi.spyOn(store, 'useAppDispatch').mockReturnValue(vi.fn() as any); + + // Default mocks for queries + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [], + totalDocuments: 0, + hasNextPage: false, + }, + isLoading: false, + } as any); + + vi.spyOn(ragReducer, 'useGetCollectionQuery').mockReturnValue({ + data: null, + } as any); + + vi.spyOn(ragReducer, 'useDeleteRagDocumentsMutation').mockReturnValue([ + vi.fn(), + { isLoading: false }, + ] as any); + + vi.spyOn(ragReducer, 'useLazyDownloadRagDocumentQuery').mockReturnValue([ + vi.fn(), + { isFetching: false }, + ] as any); + }); + + describe('Rendering', () => { + it('should render document table with repository ID in header', async () => { + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('repo-123 Documents')).toBeInTheDocument(); + }); + }); + + it('should render collection name in header when collectionId is provided', async () => { + vi.spyOn(ragReducer, 'useGetCollectionQuery').mockReturnValue({ + data: { + collectionId: 'col-123', + name: 'Engineering Docs', + }, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('Engineering Docs Documents')).toBeInTheDocument(); + }); + }); + + it('should display documents in table', async () => { + const mockDocs = [ + createMockDocument({ document_name: 'doc1.pdf' }), + createMockDocument({ document_name: 'doc2.pdf', document_id: 'doc-456' }), + ]; + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: mockDocs, + totalDocuments: 2, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('doc1.pdf')).toBeInTheDocument(); + expect(screen.getByText('doc2.pdf')).toBeInTheDocument(); + }); + }); + + it('should show loading state', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: undefined, + isLoading: true, + } as any); + + renderWithProviders( + + + + ); + + expect(screen.getByText('Loading documents')).toBeInTheDocument(); + }); + + it('should show empty state when no documents', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [], + totalDocuments: 0, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('No documents')).toBeInTheDocument(); + }); + }); + + it('should display document count in header', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 42, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('(42)')).toBeInTheDocument(); + }); + }); + }); + + describe('Actions Button', () => { + it('should render Actions button', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 1, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByText('Actions')).toBeInTheDocument(); + }); + }); + + it('should disable Actions button when no documents selected', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 1, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const actionsButton = screen.getByText('Actions').closest('button'); + expect(actionsButton).toBeDisabled(); + }); + }); + }); + + describe('Refresh Functionality', () => { + it('should render refresh button', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 1, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const refreshButton = screen.getByLabelText('Refresh documents'); + expect(refreshButton).toBeInTheDocument(); + }); + }); + }); + + describe('Filter Functionality', () => { + it('should render filter input', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 1, + hasNextPage: false, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + const filterInput = screen.getByRole('searchbox'); + expect(filterInput).toBeInTheDocument(); + }); + }); + }); + + describe('Pagination', () => { + it('should render pagination controls', async () => { + vi.spyOn(ragReducer, 'useListRagDocumentsQuery').mockReturnValue({ + data: { + documents: [createMockDocument()], + totalDocuments: 50, + hasNextPage: true, + }, + isLoading: false, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(screen.getByLabelText('Next page')).toBeInTheDocument(); + expect(screen.getByLabelText('Previous page')).toBeInTheDocument(); + }); + }); + }); + + describe('Collection Filtering', () => { + it('should fetch collection data when collectionId is provided', async () => { + vi.spyOn(ragReducer, 'useGetCollectionQuery').mockReturnValue({ + data: { collectionId: 'col-123', name: 'Test Collection' }, + } as any); + + renderWithProviders( + + + + ); + + await waitFor(() => { + expect(ragReducer.useGetCollectionQuery).toHaveBeenCalled(); + }); + }); + }); + + describe('Utility Functions', () => { + it('should return correct matches count text for single match', () => { + expect(getMatchesCountText(1)).toBe('1 match'); + }); + + it('should return correct matches count text for multiple matches', () => { + expect(getMatchesCountText(5)).toBe('5 matches'); + }); + + it('should return correct matches count text for zero matches', () => { + expect(getMatchesCountText(0)).toBe('0 matches'); + }); + }); +}); diff --git a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx index a7f269085..5ad0c14f9 100644 --- a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx +++ b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx @@ -29,6 +29,7 @@ import SpaceBetween from '@cloudscape-design/components/space-between'; import { ragApi, useDeleteRagDocumentsMutation, + useGetCollectionQuery, useLazyDownloadRagDocumentQuery, useListRagDocumentsQuery, } from '../../shared/reducers/rag.reducer'; @@ -46,6 +47,7 @@ import { downloadFile } from '../../shared/util/downloader'; type DocumentLibraryComponentProps = { repositoryId?: string; + collectionId?: string; }; export function getMatchesCountText (count) { @@ -60,7 +62,7 @@ function disabledDeleteReason (selectedItems: ReadonlyArray) { return selectedItems.length === 0 ? 'Please select an item' : 'You are not an owner of all selected items'; } -export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryComponentProps): ReactElement { +export function DocumentLibraryComponent ({ repositoryId, collectionId }: DocumentLibraryComponentProps): ReactElement { const [deleteMutation, { isLoading: isDeleteLoading }] = useDeleteRagDocumentsMutation(); const [currentPage, setCurrentPage] = useState(1); @@ -80,9 +82,16 @@ export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryCompo const [preferences, setPreferences] = useLocalStorage('DocumentRagPreferences', DEFAULT_PREFERENCES); const dispatch = useAppDispatch(); + // Fetch collection data if collectionId is provided + const { data: collectionData } = useGetCollectionQuery( + { repositoryId, collectionId }, + { skip: !repositoryId || !collectionId } + ); + const { data: paginatedDocs, isLoading } = useListRagDocumentsQuery( { repositoryId, + collectionId, lastEvaluatedKey: lastEvaluatedKey || undefined, pageSize: preferences.pageSize }, @@ -214,7 +223,9 @@ export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryCompo } > - {repositoryId} Documents + {collectionId && collectionData + ? `${collectionData.name || collectionId} Documents` + : `${repositoryId} Documents`} } pagination={ diff --git a/lib/user-interface/react/src/components/document-library/RepositoryLibraryComponent.tsx b/lib/user-interface/react/src/components/document-library/RepositoryLibraryComponent.tsx deleted file mode 100644 index f152dba1b..000000000 --- a/lib/user-interface/react/src/components/document-library/RepositoryLibraryComponent.tsx +++ /dev/null @@ -1,128 +0,0 @@ -/** - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -import { ReactElement, useEffect, useState } from 'react'; -import { Box, Cards, CollectionPreferences, Header, Pagination, TextFilter } from '@cloudscape-design/components'; -import SpaceBetween from '@cloudscape-design/components/space-between'; -import { - CARD_DEFINITIONS, - DEFAULT_PREFERENCES, - PAGE_SIZE_OPTIONS, - VISIBLE_CONTENT_OPTIONS, -} from './RepositoryLibraryConfig'; -import { useListRagRepositoriesQuery } from '../../shared/reducers/rag.reducer'; -import { useLocalStorage } from '../../shared/hooks/use-local-storage'; -import { useNavigate } from 'react-router-dom'; -import { RagRepositoryConfig } from '#root/lib/schema'; - -export function RepositoryLibraryComponent (): ReactElement { - const { - data: allRepos, - isLoading: fetchingRepos, - } = useListRagRepositoriesQuery(undefined, { refetchOnMountOrArgChange: 5 }); - - const [matchedRepos, setMatchedRepos] = useState([]); - const [searchText, setSearchText] = useState(''); - const [numberOfPages, setNumberOfPages] = useState(1); - const [currentPageIndex, setCurrentPageIndex] = useState(1); - const [selectedItems, setSelectedItems] = useState([]); - const [preferences, setPreferences] = useLocalStorage('RagPreferences', DEFAULT_PREFERENCES); - const [count, setCount] = useState(0); - - const navigate = useNavigate(); - - useEffect(() => { - let newPageCount: number; - if (searchText) { - const filteredRepos = allRepos.filter((repo) => JSON.stringify(repo).toLowerCase().includes(searchText.toLowerCase())); - setMatchedRepos( - filteredRepos.slice(preferences.pageSize * (currentPageIndex - 1), preferences.pageSize * currentPageIndex), - ); - newPageCount = Math.ceil(filteredRepos.length / preferences.pageSize); - setCount(filteredRepos.length); - } else { - setMatchedRepos(allRepos ? allRepos.slice(preferences.pageSize * (currentPageIndex - 1), preferences.pageSize * currentPageIndex) : []); - newPageCount = Math.ceil(allRepos ? (allRepos.length / preferences.pageSize) : 1); - setCount(allRepos ? allRepos.length : 0); - } - - if (newPageCount < numberOfPages) { - setCurrentPageIndex(1); - } - setNumberOfPages(newPageCount); - }, [allRepos, searchText, preferences, currentPageIndex, numberOfPages]); - - return ( - <> - setSelectedItems(detail?.selectedItems ?? [])} - selectedItems={selectedItems} - ariaLabels={{ - itemSelectionLabel: (e, t) => `select ${t.modelName}`, - selectionGroupLabel: 'Repo selection', - }} - cardDefinition={CARD_DEFINITIONS(navigate)} - visibleSections={preferences.visibleContent} - loadingText='Loading repos' - items={matchedRepos} - trackBy='repositoryId' - variant='full-page' - loading={fetchingRepos && !allRepos} - cardsPerRow={[{ cards: 3 }]} - header={ -
- Repositories -
- } - filter={ { - setSearchText(detail.filteringText); - }} />} - pagination={ setCurrentPageIndex(detail.currentPageIndex)} - pagesCount={numberOfPages} />} - preferences={ - setPreferences(detail)} - pageSizePreference={{ - title: 'Page size', - options: PAGE_SIZE_OPTIONS, - }} - visibleContentPreference={{ - title: 'Select visible columns', - options: VISIBLE_CONTENT_OPTIONS, - }} - /> - } - empty={ - - - No repositories - - - } - /> - - ); -} - -export default RepositoryLibraryComponent; diff --git a/lib/user-interface/react/src/components/document-library/createCollection/AccessControlForm.tsx b/lib/user-interface/react/src/components/document-library/createCollection/AccessControlForm.tsx new file mode 100644 index 000000000..8d0dff0f0 --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/createCollection/AccessControlForm.tsx @@ -0,0 +1,67 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import React, { ReactElement } from 'react'; +import FormField from '@cloudscape-design/components/form-field'; +import { Alert, SpaceBetween } from '@cloudscape-design/components'; +import { ArrayInputField } from '../../../shared/form/array-input'; +import { CommonFieldsForm } from '../../../shared/form/CommonFieldsForm'; +import { RagCollectionConfig } from '#root/lib/schema'; +import { ModifyMethod } from '../../../shared/form/form-props'; + +export type AccessControlFormProps = { + item: RagCollectionConfig; + setFields(values: { [key: string]: any }, method?: ModifyMethod): void; + touchFields(fields: string[], method?: ModifyMethod): void; + formErrors: any; +}; + +export function AccessControlForm (props: AccessControlFormProps): ReactElement { + const { item, touchFields, setFields, formErrors } = props; + + return ( + + + Access control is optional. If no groups are specified, the collection will be + accessible to all users. You can also inherit access controls from the parent repository. + + + {/* Common Fields (Allowed Groups) */} + + + {/* Metadata Tags */} + + setFields({ 'metadata.tags': tags })} + placeholder='Add tag' + /> + + + ); +} diff --git a/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.test.tsx b/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.test.tsx new file mode 100644 index 000000000..191ada8bc --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.test.tsx @@ -0,0 +1,374 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { render, screen } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { describe, it, expect, vi } from 'vitest'; +import { ChunkingConfigForm } from './ChunkingConfigForm'; +import { ChunkingStrategyType } from '#root/lib/schema'; + +describe('ChunkingConfigForm', () => { + const mockSetFields = vi.fn(); + const mockTouchFields = vi.fn(); + + beforeEach(() => { + mockSetFields.mockClear(); + mockTouchFields.mockClear(); + }); + + describe('Dropdown Options', () => { + it('displays FIXED and NONE options in dropdown', async () => { + const user = userEvent.setup(); + + render( + + ); + + // Click the dropdown to open it (find by text content) + const dropdown = screen.getByRole('button'); + await user.click(dropdown); + + // Verify both options are present using getAllByText since they appear multiple times + const fixedOptions = screen.getAllByText('Fixed Size'); + const noneOptions = screen.getAllByText('None (No Chunking)'); + + expect(fixedOptions.length).toBeGreaterThan(0); + expect(noneOptions.length).toBeGreaterThan(0); + }); + }); + + describe('FIXED Strategy Selection', () => { + it('shows size and overlap fields when FIXED is selected', () => { + render( + + ); + + // Verify size and overlap fields are visible + expect(screen.getByLabelText(/chunk size/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/chunk overlap/i)).toBeInTheDocument(); + }); + + it('displays correct values for FIXED strategy', () => { + render( + + ); + + const sizeInput = screen.getByLabelText(/chunk size/i) as HTMLInputElement; + const overlapInput = screen.getByLabelText(/chunk overlap/i) as HTMLInputElement; + + expect(sizeInput.value).toBe('1024'); + expect(overlapInput.value).toBe('100'); + }); + + it('calls setFields when size is changed', async () => { + const user = userEvent.setup(); + + render( + + ); + + const sizeInput = screen.getByLabelText(/chunk size/i); + + // Type directly without clearing (which triggers onChange for each character) + await user.click(sizeInput); + await user.keyboard('{Control>}a{/Control}'); // Select all + await user.keyboard('1024'); + + // Verify setFields was called (it will be called multiple times during typing) + expect(mockSetFields).toHaveBeenCalled(); + + // Verify at least one call contains the size field + const calls = mockSetFields.mock.calls; + const hasSizeCall = calls.some((call) => + call[0] && typeof call[0]['chunkingStrategy.size'] === 'number' + ); + expect(hasSizeCall).toBe(true); + }); + + it('calls setFields when overlap is changed', async () => { + const user = userEvent.setup(); + + render( + + ); + + const overlapInput = screen.getByLabelText(/chunk overlap/i); + + // Type directly without clearing (which triggers onChange for each character) + await user.click(overlapInput); + await user.keyboard('{Control>}a{/Control}'); // Select all + await user.keyboard('100'); + + // Verify setFields was called (it will be called multiple times during typing) + expect(mockSetFields).toHaveBeenCalled(); + + // Verify at least one call contains the overlap field + const calls = mockSetFields.mock.calls; + const hasOverlapCall = calls.some((call) => + call[0] && typeof call[0]['chunkingStrategy.overlap'] === 'number' + ); + expect(hasOverlapCall).toBe(true); + }); + }); + + describe('NONE Strategy Selection', () => { + it('hides size and overlap fields when NONE is selected', () => { + render( + + ); + + // Verify size and overlap fields are NOT visible + expect(screen.queryByLabelText(/chunk size/i)).not.toBeInTheDocument(); + expect(screen.queryByLabelText(/chunk overlap/i)).not.toBeInTheDocument(); + }); + + it('displays NONE as selected option', () => { + render( + + ); + + // The selected option should show "None (No Chunking)" + expect(screen.getByText('None (No Chunking)')).toBeInTheDocument(); + }); + + it('calls setFields with NONE strategy when NONE is selected', async () => { + const user = userEvent.setup(); + + render( + + ); + + // Click dropdown and select NONE + const dropdown = screen.getByRole('button'); + await user.click(dropdown); + + const noneOption = screen.getByText('None (No Chunking)'); + await user.click(noneOption); + + expect(mockSetFields).toHaveBeenCalledWith({ + chunkingStrategy: { type: ChunkingStrategyType.NONE } + }); + }); + }); + + describe('Strategy Switching', () => { + it('switches from FIXED to NONE and hides fields', async () => { + const { rerender } = render( + + ); + + // Verify FIXED fields are visible + expect(screen.getByLabelText(/chunk size/i)).toBeInTheDocument(); + + // Switch to NONE + rerender( + + ); + + // Verify fields are hidden + expect(screen.queryByLabelText(/chunk size/i)).not.toBeInTheDocument(); + expect(screen.queryByLabelText(/chunk overlap/i)).not.toBeInTheDocument(); + }); + + it('switches from NONE to FIXED and shows fields', async () => { + const { rerender } = render( + + ); + + // Verify fields are hidden + expect(screen.queryByLabelText(/chunk size/i)).not.toBeInTheDocument(); + + // Switch to FIXED + rerender( + + ); + + // Verify fields are visible + expect(screen.getByLabelText(/chunk size/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/chunk overlap/i)).toBeInTheDocument(); + }); + + it('calls setFields with default FIXED values when switching to FIXED', async () => { + const user = userEvent.setup(); + + render( + + ); + + // Click dropdown and select FIXED + const dropdown = screen.getByRole('button'); + await user.click(dropdown); + + const fixedOption = screen.getByText('Fixed Size'); + await user.click(fixedOption); + + expect(mockSetFields).toHaveBeenCalledWith({ + chunkingStrategy: { + type: ChunkingStrategyType.FIXED, + size: 512, + overlap: 51 + } + }); + }); + }); + + describe('Form Validation', () => { + it('displays error for size field', () => { + render( + + ); + + expect(screen.getByText('Size must be between 100 and 10000')).toBeInTheDocument(); + }); + + it('displays error for overlap field', () => { + render( + + ); + + expect(screen.getByText('Overlap must be less than size/2')).toBeInTheDocument(); + }); + + it('calls touchFields when size input loses focus', async () => { + const user = userEvent.setup(); + + render( + + ); + + const sizeInput = screen.getByLabelText(/chunk size/i); + await user.click(sizeInput); + await user.tab(); // Blur the input + + expect(mockTouchFields).toHaveBeenCalledWith(['chunkingStrategy.size']); + }); + + it('calls touchFields when overlap input loses focus', async () => { + const user = userEvent.setup(); + + render( + + ); + + const overlapInput = screen.getByLabelText(/chunk overlap/i); + await user.click(overlapInput); + await user.tab(); // Blur the input + + expect(mockTouchFields).toHaveBeenCalledWith(['chunkingStrategy.overlap']); + }); + }); + + describe('Default Behavior', () => { + it('defaults to FIXED strategy when item is undefined', () => { + render( + + ); + + // Should show "Fixed Size" as selected + expect(screen.getByText('Fixed Size')).toBeInTheDocument(); + }); + }); +}); diff --git a/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.tsx b/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.tsx new file mode 100644 index 000000000..9fd62911e --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/createCollection/ChunkingConfigForm.tsx @@ -0,0 +1,130 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import React, { ReactElement } from 'react'; +import FormField from '@cloudscape-design/components/form-field'; +import Input from '@cloudscape-design/components/input'; +import Select from '@cloudscape-design/components/select'; +import { SpaceBetween } from '@cloudscape-design/components'; +import { ChunkingStrategy, ChunkingStrategyType } from '#root/lib/schema'; +import { ModifyMethod } from '../../../shared/form/form-props'; + +// Utility function to create default chunking strategy +function createDefaultChunkingStrategy () { + return { + type: ChunkingStrategyType.FIXED, + size: 512, + overlap: 51, + }; +} + +export type ChunkingConfigFormProps = { + item: ChunkingStrategy | undefined; + setFields(values: { [key: string]: any }, method?: ModifyMethod): void; + touchFields(fields: string[], method?: ModifyMethod): void; + formErrors: any; + disabled?: boolean; +}; + +export function ChunkingConfigForm (props: ChunkingConfigFormProps): ReactElement { + const { item, touchFields, setFields, formErrors, disabled = false } = props; + + // Chunking type options + const chunkingTypeOptions = [ + { label: 'Fixed Size', value: ChunkingStrategyType.FIXED }, + { label: 'None (No Chunking)', value: ChunkingStrategyType.NONE }, + // Future: { label: 'Semantic', value: ChunkingStrategyType.SEMANTIC }, + // Future: { label: 'Recursive', value: ChunkingStrategyType.RECURSIVE }, + ]; + + return ( + + {/* Chunking Type */} + + { + setFields({ + 'chunkingStrategy.size': Number(detail.value) + }); + }} + onBlur={() => touchFields(['chunkingStrategy.size'])} + disabled={disabled} + /> + + + + { + setFields({ + 'chunkingStrategy.overlap': Number(detail.value) + }); + }} + onBlur={() => touchFields(['chunkingStrategy.overlap'])} + disabled={disabled} + /> + + + )} + + ); +} diff --git a/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.test.tsx b/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.test.tsx new file mode 100644 index 000000000..0627a9e28 --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.test.tsx @@ -0,0 +1,311 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { screen } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { CollectionConfigForm } from './CollectionConfigForm'; +import { renderWithProviders, createMockQueryHook } from '../../../test/helpers/render'; +import * as ragReducer from '../../../shared/reducers/rag.reducer'; +import * as modelManagementReducer from '../../../shared/reducers/model-management.reducer'; +import { ModelStatus, ModelType } from '../../../shared/model/model-management.model'; + +describe('CollectionConfigForm', () => { + const mockSetFields = vi.fn(); + const mockTouchFields = vi.fn(); + const mockFormErrors = {}; + + const mockRepositories = [ + { + repositoryId: 'repo-1', + repositoryName: 'Repository 1', + }, + { + repositoryId: 'repo-2', + repositoryName: 'Repository 2', + }, + ]; + + const mockEmbeddingModels = [ + { + modelId: 'model-1', + modelName: 'Embedding Model 1', + modelType: ModelType.embedding, + status: ModelStatus.InService, + }, + ]; + + const mockItem = { + repositoryId: '', + name: '', + description: '', + embeddingModel: '', + chunkingStrategy: undefined, + allowedGroups: [], + metadata: { tags: [], customFields: {} }, + allowChunkingOverride: true, + pipelines: [], + }; + + beforeEach(() => { + vi.clearAllMocks(); + + vi.spyOn(ragReducer, 'useListRagRepositoriesQuery').mockReturnValue({ + data: mockRepositories, + isLoading: false, + isError: false, + error: undefined, + refetch: vi.fn(), + } as any); + + vi.spyOn(modelManagementReducer, 'useGetAllModelsQuery').mockImplementation( + createMockQueryHook(mockEmbeddingModels) as any + ); + }); + + describe('Rendering', () => { + it('should render all form fields', () => { + renderWithProviders( + + ); + + expect(screen.getByText('Collection Name')).toBeInTheDocument(); + expect(screen.getByText('Description (optional)')).toBeInTheDocument(); + expect(screen.getByText('Repository')).toBeInTheDocument(); + expect(screen.getByText('Embedding Model')).toBeInTheDocument(); + }); + + it('should render collection name input', () => { + renderWithProviders( + + ); + + const input = screen.getByPlaceholderText('Documents'); + expect(input).toBeInTheDocument(); + }); + }); + + describe('User Interactions', () => { + it('should call setFields when collection name changes', async () => { + const user = userEvent.setup(); + + renderWithProviders( + + ); + + const input = screen.getByPlaceholderText('Documents'); + await user.type(input, 'Test Collection'); + + expect(mockSetFields).toHaveBeenCalledWith({ name: 'T' }); + }); + + it('should call touchFields when collection name loses focus', async () => { + const user = userEvent.setup(); + + renderWithProviders( + + ); + + const input = screen.getByPlaceholderText('Documents'); + await user.click(input); + await user.tab(); + + expect(mockTouchFields).toHaveBeenCalledWith(['name']); + }); + + it('should call setFields when description changes', async () => { + const user = userEvent.setup(); + + renderWithProviders( + + ); + + const textarea = screen.getByPlaceholderText('Collection of documents'); + await user.type(textarea, 'Test'); + + expect(mockSetFields).toHaveBeenCalledWith({ description: 'T' }); + }); + + it('should render repository select with options', async () => { + renderWithProviders( + + ); + + // Repository select should be present + expect(screen.getByText('Repository')).toBeInTheDocument(); + expect(screen.getByText('Select a repository')).toBeInTheDocument(); + }); + }); + + describe('Edit Mode', () => { + it('should disable repository field when isEdit is true', () => { + renderWithProviders( + + ); + + // Repository field should be present and disabled + expect(screen.getByText('Repository')).toBeInTheDocument(); + }); + + it('should disable embedding model field when isEdit is true', () => { + renderWithProviders( + + ); + + const input = screen.getByPlaceholderText('Select an embedding model'); + expect(input).toBeDisabled(); + }); + + + it('should enable repository field when isEdit is false', () => { + renderWithProviders( + + ); + + // Repository field should be present + expect(screen.getByText('Repository')).toBeInTheDocument(); + const selectTrigger = screen.getByText('Select a repository'); + expect(selectTrigger).toBeInTheDocument(); + }); + + it('should enable embedding model field when isEdit is false', () => { + renderWithProviders( + + ); + + const input = screen.getByPlaceholderText('Select an embedding model'); + expect(input).not.toBeDisabled(); + }); + + }); + + describe('Error Handling', () => { + it('should display error for collection name', () => { + const errorMessage = 'Collection name is required'; + renderWithProviders( + + ); + + expect(screen.getByText(errorMessage)).toBeInTheDocument(); + }); + + it('should display error for repository', () => { + const errorMessage = 'Repository is required'; + renderWithProviders( + + ); + + expect(screen.getByText(errorMessage)).toBeInTheDocument(); + }); + }); + + describe('Loading States', () => { + it('should show loading state for repositories', () => { + vi.spyOn(ragReducer, 'useListRagRepositoriesQuery').mockReturnValue({ + data: undefined, + isLoading: true, + isError: false, + error: undefined, + refetch: vi.fn(), + } as any); + + renderWithProviders( + + ); + + // Repository field should be present + expect(screen.getByText('Repository')).toBeInTheDocument(); + }); + }); +}); diff --git a/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.tsx b/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.tsx new file mode 100644 index 000000000..2b83a6338 --- /dev/null +++ b/lib/user-interface/react/src/components/document-library/createCollection/CollectionConfigForm.tsx @@ -0,0 +1,175 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import React, { ReactElement, useMemo } from 'react'; +import { FormProps } from '../../../shared/form/form-props'; +import FormField from '@cloudscape-design/components/form-field'; +import Input from '@cloudscape-design/components/input'; +import Select from '@cloudscape-design/components/select'; +import { SpaceBetween, Textarea } from '@cloudscape-design/components'; +import { useListRagRepositoriesQuery } from '../../../shared/reducers/rag.reducer'; +import { CommonFieldsForm } from '../../../shared/form/CommonFieldsForm'; +import { RagCollectionConfig, RagRepositoryType, VectorStoreStatus } from '#root/lib/schema'; + +export type CollectionConfigProps = { + isEdit: boolean; +}; + +export function CollectionConfigForm ( + props: FormProps & CollectionConfigProps +): ReactElement { + const { item, touchFields, setFields, formErrors, isEdit } = props; + + // Fetch repositories for dropdown + const { data: repositories, isLoading: isLoadingRepos } = useListRagRepositoriesQuery(undefined, { + refetchOnMountOrArgChange: 5 + }); + + // Repository options + const repositoryOptions = useMemo(() => { + if (!repositories || !Array.isArray(repositories)) { + return []; + } + return repositories + .filter((repository) => + repository.status && [ + VectorStoreStatus.CREATE_COMPLETE, + VectorStoreStatus.UPDATE_COMPLETE, + VectorStoreStatus.UPDATE_COMPLETE_CLEANUP_IN_PROGRESS, + VectorStoreStatus.UPDATE_IN_PROGRESS, + ].includes(repository.status) + ) + // BRK not supported yet + .filter((repo) => repo.type !== RagRepositoryType.BEDROCK_KNOWLEDGE_BASE) + .map((repo) => ({ + label: repo.repositoryName || repo.repositoryId, + value: repo.repositoryId, + })); + }, [repositories]); + + return ( + + {/* Collection Name */} + + setFields({ name: detail.value })} + onBlur={() => touchFields(['name'])} + placeholder='Documents' + /> + + + {/* Description */} + +