diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9b1982fe6..d44717974 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -107,7 +107,7 @@ repos: - --config-file=pyproject.toml - --install-types - --non-interactive - exclude: ^test + exclude: ^test/ - repo: https://github.com/pre-commit/mirrors-eslint rev: 'v9.39.1' @@ -125,11 +125,11 @@ repos: # hooks: # - id: python-safety-dependencies-check -# - repo: https://github.com/asottile/pyupgrade -# rev: v3.19.0 -# hooks: -# - id: pyupgrade -# args: [--py313-plus] +- repo: https://github.com/asottile/pyupgrade + rev: v3.19.0 + hooks: + - id: pyupgrade + args: [--py313-plus] # - repo: https://github.com/bridgecrewio/checkov # rev: '3.2.327' diff --git a/Makefile b/Makefile index ea53341d1..e3ba78f73 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,9 @@ createTypeScriptEnvironment installTypeScriptRequirements \ deploy destroy \ clean cleanTypeScript cleanPython cleanCfn cleanMisc \ - help dockerCheck dockerLogin listStacks modelCheck buildNpmModules + help dockerCheck dockerLogin listStacks modelCheck buildNpmModules \ + test test-coverage test-lambda test-mcp-workbench test-sdk test-rest-api test-sdk-integ test-integ test-rag-integ test-metadata-integ \ + lock-poetry validate-deps ################################################################################# # GLOBALS # @@ -138,6 +140,7 @@ installPythonRequirements: CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install pip --upgrade CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install --prefer-binary -r requirements-dev.txt CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install -e lisa-sdk + CC=/usr/bin/gcc10-gcc CXX=/usr/bin/gcc10-g++ pip3 install -e lib/serve/mcp-workbench ## Set up TypeScript interpreter environment createTypeScriptEnvironment: @@ -366,14 +369,104 @@ help: }' \ | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars') -## Run Python tests with coverage report +## Run all Python unit tests (non-integration) with coverage report test-coverage: - pytest --verbose \ + @echo "Running lambda tests with coverage..." + @pytest test/lambda --verbose \ --cov lambda \ --cov-report term-missing \ --cov-report html:build/coverage \ --cov-report xml:build/coverage/coverage.xml \ --cov-fail-under 83 + @echo "" + @echo "Running MCP Workbench tests with coverage..." + @pytest test/mcp-workbench --verbose \ + --cov lib/serve/mcp-workbench/src \ + --cov-report term-missing \ + --cov-report html:build/coverage-mcp \ + --cov-report xml:build/coverage-mcp/coverage.xml \ + --cov-append \ + --cov-fail-under 83 + @echo "" + @echo "Running SDK tests with coverage..." + @pytest test/sdk --verbose \ + --cov lisa-sdk/lisapy \ + --cov-report term-missing \ + --cov-report html:build/coverage-sdk \ + --cov-report xml:build/coverage-sdk/coverage.xml \ + --cov-append \ + --cov-fail-under 80 + @echo "" + @echo "Running REST API tests with coverage..." + @pytest test/rest-api --verbose \ + --cov lib/serve/rest-api/src \ + --cov-config lib/serve/rest-api/.coveragerc \ + --cov-report term-missing \ + --cov-report html:build/coverage-rest-api \ + --cov-report xml:build/coverage-rest-api/coverage.xml \ + --cov-append \ + --cov-fail-under 80 + + +## Run all Python unit tests (non-integration) without coverage +test: + @echo "Running lambda tests..." + @pytest test/lambda --verbose + @echo "" + @echo "Running MCP Workbench tests..." + @pytest test/mcp-workbench --verbose + @echo "" + @echo "Running SDK tests..." + @pytest test/sdk --verbose + @echo "" + @echo "Running REST API tests..." + @pytest test/rest-api --verbose + +## Run lambda tests only +test-lambda: + pytest test/lambda --verbose + +## Run MCP Workbench tests only +test-mcp-workbench: + pytest test/mcp-workbench --verbose + +## Run LISA SDK unit tests only +test-sdk: + pytest test/sdk --verbose + +## Run REST API unit tests only +test-rest-api: + pytest test/rest-api --verbose + +## Run LISA SDK integration tests (requires deployed LISA environment) +test-sdk-integ: + @echo "Running LISA SDK integration tests..." + @echo "Note: These tests require a deployed LISA environment with:" + @echo " - --api or --url argument for API endpoint" + @echo " - --region, --deployment, --profile arguments" + @echo " - AWS credentials configured" + @echo "" + @echo "Example: pytest test/integration/sdk --api https://your-api.com --region us-west-2" + @echo "" + pytest test/integration/sdk --verbose + +## Run integration tests (Python-based) +test-integ: + pytest test/python --verbose + +## Run RAG integration tests (requires deployed LISA environment) +test-rag-integ: + @echo "Running RAG integration tests..." + @echo "Note: These tests require a deployed LISA environment with:" + @echo " - LISA_API_URL environment variable set" + @echo " - LISA_DEPLOYMENT_NAME environment variable set" + @echo " - AWS credentials configured" + @echo "" + pytest test/integration --verbose + +## Run repository metadata preservation integration tests +test-metadata-integ: + pytest test/integration/test_repository_update_metadata_preservation.py --verbose ## Regenerate all Poetry lock files lock-poetry: diff --git a/assets/LisaArchitecture.png b/assets/LisaArchitecture.png deleted file mode 100644 index 884e275bf..000000000 Binary files a/assets/LisaArchitecture.png and /dev/null differ diff --git a/assets/LisaChat.png b/assets/LisaChat.png deleted file mode 100644 index f9d958e0a..000000000 Binary files a/assets/LisaChat.png and /dev/null differ diff --git a/assets/LisaModelManagement.png b/assets/LisaModelManagement.png deleted file mode 100644 index 61d47c444..000000000 Binary files a/assets/LisaModelManagement.png and /dev/null differ diff --git a/assets/LisaServe.png b/assets/LisaServe.png deleted file mode 100644 index 22365e94f..000000000 Binary files a/assets/LisaServe.png and /dev/null differ diff --git a/bin/build-images b/bin/build-images index afece822c..dd887b992 100755 --- a/bin/build-images +++ b/bin/build-images @@ -116,7 +116,7 @@ build_all_images() { build_image "Dockerfile" "lisa-rest-api" "$LISA_VERSION" "./lib/serve/rest-api" \ "NODE_ENV=production" \ "LITELLM_CONFIG=\"db_key: sk-a8814208-0388-480c-9fc7-fea59607ca38\"" \ - "BASE_IMAGE=python:3.13-slim" + "BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim" # lisa-batch-ingestion RAG_DIR="./lib/rag/ingestion/ingestion-image" @@ -130,7 +130,7 @@ build_all_images() { MCP_DIR="./lib/serve/mcp-workbench" build_image "Dockerfile" "lisa-mcp-workbench" "$LISA_VERSION" "$MCP_DIR" \ "NODE_ENV=production" \ - "BASE_IMAGE=python:3.13-slim" + "BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim" else echo "deployMcpWorkbench is disabled, skipping lisa-mcp-workbench build" echo "" @@ -151,7 +151,7 @@ build_all_images() { # lisa-vllm build_image "Dockerfile" "lisa-vllm" "latest" "./lib/serve/ecs-model/vllm" \ "NODE_ENV=production" \ - "BASE_IMAGE=vllm/vllm-openai:latest" \ + "BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.13-gpu-py312" \ "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" echo "All images built successfully!" diff --git a/cypress/src/e2e/fixtures/test-document.txt b/cypress/src/e2e/fixtures/test-document.txt new file mode 100644 index 000000000..8e3a89712 --- /dev/null +++ b/cypress/src/e2e/fixtures/test-document.txt @@ -0,0 +1 @@ +In the quiet town of Maplewood, there lived a cat named Whiskers. 𝒲hiskers, unlike other cats, had a penchant for adventure. One sunny afternoon, while exploring the attic, he stumbled upon an old, dusty book. 𝒞urious, he pawed at the cover until it opened to a page detailing a hidden treasure in the nearby forest. 𝒲hiskers, with his adventurous spirit, decided to embark on a quest. Armed with nothing but his curiosity, he ventured into the woods, guided by the book's cryptic clues. As the sun set, he found himself at the edge of a shimmering lake, where the treasure was said to be hidden. To his surprise, the lake reflected not gold, but the beauty of the world around him. 𝒲hiskers realized that sometimes, the greatest treasure is the journey itself. 𝒞ontent with his discovery, he returned home, ready for his next adventure. diff --git a/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts b/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts index 25b795189..c0c84f995 100644 --- a/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts +++ b/cypress/src/e2e/specs/bedrock-model-workflow.e2e.spec.ts @@ -25,7 +25,8 @@ import { runBedrockModelWorkflowTests } from '../../shared/specs/bedrock-model-w describe('Bedrock Model Workflow (E2E)', () => { before(() => { - cy.clearAllSessionStorage(); + // Clear Cypress session cache to allow fresh login + Cypress.session.clearAllSavedSessions(); }); beforeEach(() => { diff --git a/cypress/src/e2e/support/commands.ts b/cypress/src/e2e/support/commands.ts index 6ec87d14d..20201a9d5 100644 --- a/cypress/src/e2e/support/commands.ts +++ b/cypress/src/e2e/support/commands.ts @@ -138,7 +138,7 @@ Cypress.Commands.add('loginAs', (role = 'user') => { .click({ force: true }); }); - // Wait for redirect back to app + // Wait for redirect back to app and allow configuration to load cy.wait(2000); }); }); @@ -155,16 +155,57 @@ Cypress.Commands.add('loginAs', (role = 'user') => { expect(hasOidcToken).to.equal(true); }); }, - cacheAcrossSpecs: true, + cacheAcrossSpecs: false, } ); - // After session restore/setup, Cypress clears the page - // We must visit again and wait for APIs + // After session restore/setup, Cypress clears the page which may have cancelled + // in-flight API requests. Selectively clear API cache reducers to ensure cancelled + // requests don't pollute the cache, while preserving user preferences. + cy.window().then((win) => { + const persistedState = win.localStorage.getItem('persist:lisa'); + if (persistedState) { + try { + const state = JSON.parse(persistedState); + // Clear only API cache reducers that may have stale/cancelled data + // Preserve: user, userPreferences, notification, modal, breadcrumbGroup + const apiReducersToReset = [ + 'models', // modelManagementApi.reducerPath + 'configuration', // configurationApi.reducerPath + 'sessions', // sessionApi.reducerPath + 'rag', // ragApi.reducerPath + 'promptTemplates', // promptTemplateApi.reducerPath + 'mcpServers', // mcpServerApi.reducerPath + 'mcpTools', // mcpToolsApi.reducerPath + 'apiTokens', // apiTokenApi.reducerPath + 'userPreferences', // userPreferencesApi.reducerPath + ]; + apiReducersToReset.forEach((key) => { + if (state[key]) { + delete state[key]; + } + }); + win.localStorage.setItem('persist:lisa', JSON.stringify(state)); + } catch { + // If parsing fails, remove the entire persisted state + win.localStorage.removeItem('persist:lisa'); + } + } + }); + + // Set up intercepts BEFORE visiting so they catch all requests + // cy.session() clears all intercepts, so we must set them up fresh here setupApiIntercepts(); + + // Visit the app - intercepts are now ready to catch requests cy.visit(BASE_URL); + + // Wait for app to be ready using DOM-based assertions waitForAppReady(); - waitForCriticalApis(); + + // Now wait for the critical configuration API to complete + // This ensures the app has loaded its configuration before tests proceed + // waitForCriticalApis(); log.snapshot('after'); log.end(); diff --git a/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts b/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts index 9c3c2d25b..318763487 100644 --- a/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts +++ b/cypress/src/shared/specs/bedrock-model-workflow.shared.spec.ts @@ -22,7 +22,7 @@ */ import { navigateToAdminPage } from '../../support/adminHelpers'; -import { navigateAndVerifyChatPage } from '../../support/chatHelpers'; +import { insertChatPrompt, navigateAndVerifyChatPage, sendMessageWithButton, verifyChatResponseReceived } from '../../support/chatHelpers'; import { BedrockModelConfig, openCreateModelWizard, @@ -57,12 +57,34 @@ import { verifyPromptTemplateInList, deletePromptTemplateIfExists, selectPromptTemplateInChat, - selectDirectiveAndSend, + promptTemplateExists, + PromptTemplateType, } from '../../support/promptTemplateHelpers'; +import { + CollectionConfig, + navigateToRagManagement, + waitForRepositoryReady, + getAutoCreatedCollectionInfo, + renameCollection, + uploadDocument, + waitForDocumentIngested, + selectRagRepositoryInChat, + selectCollectionInChat, +} from '../../support/collectionHelpers'; + + +// Use date-based naming for easier debugging and test reusability +function getTodayDateString (): string { + const today = new Date(); + const month = String(today.getMonth() + 1).padStart(2, '0'); + const day = String(today.getDate()).padStart(2, '0'); + const year = today.getFullYear(); + return `${month}-${day}-${year}`; +} // Amazon Nova Micro - cheapest Bedrock serverless model const DEFAULT_TEST_MODEL: BedrockModelConfig = { - modelId: `e2e-nova-micro-${Date.now()}`, + modelId: `e2e-nova-micro-${getTodayDateString()}`, modelName: 'bedrock/us.amazon.nova-micro-v1:0', modelDescription: 'E2E test model - Amazon Nova Micro', streaming: true, @@ -71,20 +93,43 @@ const DEFAULT_TEST_MODEL: BedrockModelConfig = { export type BedrockWorkflowTestOptions = { modelConfig?: BedrockModelConfig; repositoryConfig?: RepositoryConfig; + collectionConfig?: CollectionConfig; promptTemplateConfig?: PromptTemplateConfig; skipChat?: boolean; skipCleanup?: boolean; + testDocumentPath?: string; }; export function runBedrockModelWorkflowTests (options: BedrockWorkflowTestOptions = {}) { + const dateString = getTodayDateString(); const testModel = options.modelConfig || DEFAULT_TEST_MODEL; const testRepository: RepositoryConfig = options.repositoryConfig || { - repositoryId: `e2e-repo-${Date.now()}`, + repositoryId: `e2e-repo-${dateString}`, knowledgeBaseName: 'test-bedrock-kb', dataSourceIndex: 0, }; + const testCollection: CollectionConfig = options.collectionConfig || { + collectionId: `e2e-collection-${dateString}`, + collectionName: `E2E Test Collection ${dateString}`, + repositoryId: testRepository.repositoryId, + }; + const testDocumentPath = options.testDocumentPath || 'test-document.txt'; + + // Track test state for dependencies + const testState = { + modelCreated: false, + repositoryCreated: false, + repositoryReady: false, + collectionRenamed: false, + collectionId: '', // Store the actual collection ID + documentUploaded: false, + documentIngested: false, + personaTemplateCreated: false, + directiveTemplateCreated: false, + }; + const testPromptTemplatePersona: PromptTemplateConfig = { - title: `E2E Magic 8 Ball Persona ${Date.now()}`, + title: `E2E Magic 8 Ball Persona ${dateString}`, body: `You are a Magic 8 Ball—a mystical oracle that responds to yes/no questions with cryptic, fate-laden answers. You speak only in the traditional Magic 8 Ball responses, selecting one at random for each query. Never explain yourself, provide reasoning, or deviate from these phrases. Positive Responses: It is certain @@ -112,54 +157,149 @@ My sources say no Outlook not so good Very doubtful Respond with only one phrase per message, chosen randomly. Treat every input as a question seeking guidance from the universe.`, - type: 'system', + type: PromptTemplateType.Persona, sharePublic: true, }; const testPromptTemplateDirective: PromptTemplateConfig = { - title: `E2E Test Directive ${Date.now()}`, + title: `E2E Test Directive ${dateString}`, body: 'Is it going to rain', - type: 'user', + type: PromptTemplateType.Directive, sharePublic: true, }; - it('Admin creates a Bedrock model via wizard', () => { + it('Admin creates a Bedrock model via wizard (or uses existing)', () => { + // Ensure app is fully ready before navigating + cy.get('header button[aria-label="Libraries"]', { timeout: 30000 }).should('be.visible'); + cy.get('header button[aria-label="Administration"]', { timeout: 30000 }).should('be.visible'); + navigateToAdminPage('Model Management'); - openCreateModelWizard(); - fillBedrockModelConfig(testModel); - completeBedrockModelWizard(); - waitForModelCreationSuccess(testModel.modelId); + // Wait for models API to load and check if model already exists + cy.wait('@getModels', { timeout: 30000 }).then((interception) => { + const models = interception.response?.body || { models: [] }; + const modelExists = models.models.some((model: any) => model.modelId === testModel.modelId); + + if (modelExists) { + cy.log(`Model ${testModel.modelId} already exists, skipping creation`); + testState.modelCreated = true; + } else { + openCreateModelWizard(); + fillBedrockModelConfig(testModel); + completeBedrockModelWizard(); + waitForModelCreationSuccess(testModel.modelId); + testState.modelCreated = true; + } + }); }); - it('New model appears in Model Management list', () => { + it('New model appears in Model Management list', function () { + if (!testState.modelCreated) { + this.skip(); + } + navigateToAdminPage('Model Management'); verifyModelInList(testModel.modelId); }); - it('Admin creates a repository with the new Bedrock model', () => { + it('Admin creates a Bedrock Knowledgebase repository (or uses existing)', () => { navigateToRepositoryManagement(); - openCreateRepositoryWizard(); - fillRepositoryConfig(testRepository); - selectKnowledgeBase(testRepository.knowledgeBaseName); - selectDataSource(testRepository.dataSourceIndex); - skipToCreateRepository(); - completeRepositoryWizard(); - waitForRepositoryCreationSuccess(testRepository.repositoryId); + // Wait for repositories API to load and check if repository already exists + cy.wait('@getRepositories', { timeout: 30000 }).then((interception) => { + const repositories = interception.response?.body || []; + const repoExists = repositories.some((repo: any) => repo.repositoryId === testRepository.repositoryId); + + if (repoExists) { + cy.log(`Repository ${testRepository.repositoryId} already exists, skipping creation`); + testState.repositoryCreated = true; + } else { + openCreateRepositoryWizard(); + fillRepositoryConfig(testRepository); + selectKnowledgeBase(testRepository.knowledgeBaseName); + selectDataSource(testRepository.dataSourceIndex); + skipToCreateRepository(); + completeRepositoryWizard(); + waitForRepositoryCreationSuccess(testRepository.repositoryId); + testState.repositoryCreated = true; + } + }); }); - it('New repository appears in RAG Management list', () => { + it('New repository appears in RAG Management list', function () { + if (!testState.repositoryCreated) { + this.skip(); + } + navigateToRepositoryManagement(); verifyRepositoryInList(testRepository.repositoryId); }); + it('Wait for repository to be fully created and ready', function () { + if (!testState.repositoryCreated) { + this.skip(); + } + + navigateToRepositoryManagement(); + waitForRepositoryReady(testRepository.repositoryId, 300000); + testState.repositoryReady = true; + }); + + it('Rename auto-created collection to known name', function () { + if (!testState.repositoryReady) { + this.skip(); + } + + navigateToRagManagement(); + + // Get the auto-created collection info (name and ID) and rename it + getAutoCreatedCollectionInfo(testRepository.repositoryId).then((collectionInfo) => { + cy.log(`Auto-created collection: ${collectionInfo.name} (ID: ${collectionInfo.id})`); + testState.collectionId = collectionInfo.id; // Store the collection ID + renameCollection(collectionInfo.name, testCollection.collectionName); + testState.collectionRenamed = true; + }); + }); + + it('Upload test document to collection via chat page', function () { + if (!testState.collectionRenamed) { + this.skip(); + } + + // Navigate to chat page + navigateAndVerifyChatPage(); + + // Select model, repository, and collection + selectModelInChat(testModel.modelId); + selectRagRepositoryInChat(testRepository.repositoryId); + selectCollectionInChat(testCollection.collectionName); + + // Upload the document + uploadDocument(testDocumentPath); + testState.documentUploaded = true; + }); + + it('Wait for document to be ingested', function () { + if (!testState.documentUploaded) { + this.skip(); + } + waitForDocumentIngested(testRepository.repositoryId, testState.collectionId, testDocumentPath, 300000); + testState.documentIngested = true; + }); + it('Admin creates a persona prompt template', () => { navigateToPromptTemplates(); - openCreatePromptTemplateWizard(); - fillPromptTemplateConfig(testPromptTemplatePersona); - completePromptTemplateWizard(); - waitForPromptTemplateCreationSuccess(testPromptTemplatePersona.title); + promptTemplateExists(testPromptTemplatePersona.title).then((exists) => { + if (exists) { + cy.log(`Prompt template ${testPromptTemplatePersona.title} already exists, skipping creation`); + return; + } + + openCreatePromptTemplateWizard(); + fillPromptTemplateConfig(testPromptTemplatePersona); + completePromptTemplateWizard(); + waitForPromptTemplateCreationSuccess(testPromptTemplatePersona.title); + }); }); it('Persona prompt template appears in Prompt Templates list', () => { @@ -170,10 +310,17 @@ Respond with only one phrase per message, chosen randomly. Treat every input as it('Admin creates a directive prompt template', () => { navigateToPromptTemplates(); - openCreatePromptTemplateWizard(); - fillPromptTemplateConfig(testPromptTemplateDirective); - completePromptTemplateWizard(); - waitForPromptTemplateCreationSuccess(testPromptTemplateDirective.title); + promptTemplateExists(testPromptTemplateDirective.title).then((exists) => { + if (exists) { + cy.log(`Prompt template ${testPromptTemplateDirective.title} already exists, skipping creation`); + return; + } + + openCreatePromptTemplateWizard(); + fillPromptTemplateConfig(testPromptTemplateDirective); + completePromptTemplateWizard(); + waitForPromptTemplateCreationSuccess(testPromptTemplateDirective.title); + }); }); it('Directive prompt template appears in Prompt Templates list', () => { @@ -181,42 +328,55 @@ Respond with only one phrase per message, chosen randomly. Treat every input as verifyPromptTemplateInList(testPromptTemplateDirective.title); }); - it('User selects model, applies persona, inserts directive, and sends message', () => { + it('Send chat message with persona and directive', () => { navigateAndVerifyChatPage(); selectModelInChat(testModel.modelId); // Apply the Magic 8 Ball persona (system prompt) - selectPromptTemplateInChat(testPromptTemplatePersona.title, 'system'); - // Insert directive template and send message - selectDirectiveAndSend(testPromptTemplateDirective.title); + selectPromptTemplateInChat(testPromptTemplatePersona.title, PromptTemplateType.Persona); + selectPromptTemplateInChat(testPromptTemplateDirective.title, PromptTemplateType.Directive); + sendMessageWithButton(); + verifyChatResponseReceived(); }); - it('Cleanup: delete all chat sessions', () => { + it('Send chat message with rag response', () => { navigateAndVerifyChatPage(); - deleteAllSessions(); + selectModelInChat(testModel.modelId); + selectRagRepositoryInChat(testRepository.repositoryId); + selectCollectionInChat(testCollection.collectionName); + insertChatPrompt('Who is Whiskers?'); + sendMessageWithButton(); + verifyChatResponseReceived(); }); - it('Cleanup: delete test repository', () => { - navigateToRepositoryManagement(); - cy.wait(2000); - deleteRepositoryIfExists(testRepository.repositoryId); - }); + if (!options.skipCleanup) { + it('Cleanup: delete all chat sessions', () => { + navigateAndVerifyChatPage(); + deleteAllSessions(); + }); - it('Cleanup: delete persona prompt template', () => { - navigateToPromptTemplates(); - cy.wait(2000); - deletePromptTemplateIfExists(testPromptTemplatePersona.title); - }); + it('Cleanup: delete test repository', () => { + navigateToRepositoryManagement(); + cy.wait(2000); + deleteRepositoryIfExists(testRepository.repositoryId); + }); - it('Cleanup: delete directive prompt template', () => { - navigateToPromptTemplates(); - cy.wait(2000); - deletePromptTemplateIfExists(testPromptTemplateDirective.title); - }); + it('Cleanup: delete persona prompt template', () => { + navigateToPromptTemplates(); + cy.wait(2000); + deletePromptTemplateIfExists(testPromptTemplatePersona.title); + }); - it('Cleanup: delete test model', () => { - navigateToAdminPage('Model Management'); - cy.wait(2000); - deleteModelIfExists(testModel.modelId); - }); + it('Cleanup: delete directive prompt template', () => { + navigateToPromptTemplates(); + cy.wait(2000); + deletePromptTemplateIfExists(testPromptTemplateDirective.title); + }); + + it('Cleanup: delete test model', () => { + navigateToAdminPage('Model Management'); + cy.wait(2000); + deleteModelIfExists(testModel.modelId); + }); + } } diff --git a/cypress/src/shared/specs/user.shared.spec.ts b/cypress/src/shared/specs/user.shared.spec.ts index 109bef61f..adbfc890e 100644 --- a/cypress/src/shared/specs/user.shared.spec.ts +++ b/cypress/src/shared/specs/user.shared.spec.ts @@ -28,7 +28,7 @@ import { checkNoAdminButton } from '../../support/adminHelpers'; export function runUserTests () { it('Non-admin does not see the Administration button', () => { // Wait for configuration to load before checking UI - cy.wait('@getConfiguration', { timeout: 30000 }); + // cy.wait('@getConfiguration', { timeout: 30000 }); checkNoAdminButton(); }); diff --git a/cypress/src/support/chatHelpers.ts b/cypress/src/support/chatHelpers.ts index 32c2cd3ac..2a71854c6 100644 --- a/cypress/src/support/chatHelpers.ts +++ b/cypress/src/support/chatHelpers.ts @@ -181,10 +181,30 @@ export function sendMessageWithButton () { } /** - * Verify that a chat response was received + * Insert text into the chat prompt input + * @param text - The text to insert into the chat input + */ +export function insertChatPrompt (text: string) { + cy.get(CHAT_SELECTORS.MESSAGE_INPUT) + .should('be.visible') + .and('not.be.disabled') + .clear() + .type(text, { delay: 0 }); +} + +/** + * Verify that a chat response was received and is complete * @param minMessages - Minimum number of messages expected (default: 2 for user + assistant) */ export function verifyChatResponseReceived (minMessages: number = 2) { - cy.get('[data-testid="chat-message"]', { timeout: 30000 }) + // Wait for "Generating response" box to disappear, indicating the response is complete + cy.get('[data-testid="generating-response-box"]', { timeout: 60000 }).should('not.exist'); + + // Wait for AI response message + cy.get('[data-testid="chat-message-ai"]', { timeout: 30000 }) + .should('have.length.at.least', 1); + + // Verify total message count (user + assistant messages) + cy.get('[data-testid^="chat-message-"]', { timeout: 30000 }) .should('have.length.at.least', minMessages); } diff --git a/cypress/src/support/collectionHelpers.ts b/cypress/src/support/collectionHelpers.ts new file mode 100644 index 000000000..e118b00eb --- /dev/null +++ b/cypress/src/support/collectionHelpers.ts @@ -0,0 +1,432 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/** + * collectionHelpers.ts + * Reusable helpers for RAG collection management and document operations. + */ + +export type CollectionConfig = { + collectionId: string; + collectionName: string; + repositoryId: string; +}; + +/** + * Navigate to the RAG Management page + */ +export function navigateToRagManagement () { + cy.visit('/#/repository-management'); + cy.url().should('include', '/repository-management'); + cy.wait(1000); +} + +/** + * Get the API base URL from the application's environment + */ +function getApiBaseUrl (): Cypress.Chainable { + return cy.window().then((win: any) => { + const apiBaseUrl = win.env?.API_BASE_URL || ''; + return apiBaseUrl.replace(/\/+$/, ''); // Remove trailing slashes + }); +} + +/** + * Get the authentication token from session storage + */ +function getAuthToken (): Cypress.Chainable { + return cy.window().then((win) => { + // Find the OIDC token in sessionStorage + const oidcKey = Object.keys(win.sessionStorage).find((key) => key.startsWith('oidc.user:')); + if (oidcKey) { + const oidcData = JSON.parse(win.sessionStorage.getItem(oidcKey) || '{}'); + return oidcData.id_token || oidcData.access_token || null; + } + return null; + }); +} + +/** + * Make an authenticated API request + * @param method - HTTP method (GET, POST, PUT, DELETE, etc.) + * @param path - API path (e.g., '/repository', '/collections') + * @param options - Additional request options (body, headers, etc.) + */ +function makeAuthenticatedRequest ( + method: string, + path: string, + options: Partial = {} +): Cypress.Chainable> { + return getApiBaseUrl().then((apiBaseUrl) => { + return getAuthToken().then((token) => { + return cy.request({ + method, + url: `${apiBaseUrl}${path}`, + headers: { + ...(token ? { Authorization: `Bearer ${token}` } : {}), + ...(options.headers || {}), + }, + failOnStatusCode: false, + ...options, + }); + }); + }); +} + +/** + * Wait for repository to be fully created (up to 5 minutes) + * Checks repository status until it's CREATE_COMPLETE or UPDATE_COMPLETE + */ +export function waitForRepositoryReady (repositoryId: string, timeoutMs: number = 300000) { + cy.log(`Waiting for repository ${repositoryId} to be ready...`); + + const startTime = Date.now(); + const checkInterval = 10000; // Check every 10 seconds + + function checkRepositoryStatus (): Cypress.Chainable { + return makeAuthenticatedRequest('GET', '/repository').then((response) => { + if (response.status === 200 && Array.isArray(response.body)) { + const repository = response.body.find((repo: any) => repo.repositoryId === repositoryId); + + if (repository) { + cy.log(`Repository ${repositoryId} status: ${repository.status}`); + + if (repository.status === 'CREATE_COMPLETE' || repository.status === 'UPDATE_COMPLETE') { + cy.log(`Repository ${repositoryId} is ready with status: ${repository.status}`); + return cy.wrap(null); // Success - stop checking + } + + if (repository.status === 'CREATE_FAILED' || repository.status === 'UPDATE_FAILED') { + throw new Error(`Repository ${repositoryId} creation failed with status: ${repository.status}`); + } + } + } + + const elapsed = Date.now() - startTime; + if (elapsed < timeoutMs) { + return cy.wait(checkInterval).then(() => checkRepositoryStatus()); + } else { + throw new Error(`Repository ${repositoryId} did not become ready within ${timeoutMs}ms`); + } + }); + } + + checkRepositoryStatus(); +} + +/** + * Rename a collection via the collections table + * Finds the auto-created collection and renames it to a known name + */ +export function renameCollection (oldName: string, newName: string) { + cy.log(`Renaming collection from "${oldName}" to "${newName}"`); + + // Find the collection row by the link text (collection name is in a link) + cy.contains('a', oldName) + .should('be.visible') + .closest('tr') + .within(() => { + // Select the radio button for this row + cy.get('input[type="radio"]') + .first() + .click({ force: true }); + }); + + // Click the Actions button using data-testid (find button within the wrapper) + cy.get('[data-testid="collection-actions-button"]') + .find('button') + .should('be.visible') + .and('not.be.disabled') + .click(); + + // Click Edit from the dropdown menu (data-testid is on the li element) + cy.get('li[data-testid="edit"]') + .should('be.visible') + .click(); + + // Wait for the edit modal/wizard to open + cy.get('[role="dialog"]').should('be.visible'); + + // Update the collection name in the form (look for input with label "Collection Name") + cy.get('label') + .contains('Collection Name') + .invoke('attr', 'for') + .then((inputId) => { + cy.get(`#${inputId}`) + .should('be.visible') + .clear() + .type(newName); + }); + + // Click "Skip to Update" button to go to final step + cy.contains('button', 'Skip to Update') + .scrollIntoView() + .should('be.visible') + .click(); + + // Click "Update Collection" button to save changes + cy.contains('button', 'Update Collection') + .should('be.visible') + .and('not.be.disabled') + .click(); + + // Wait for success notification + cy.contains(/successfully.*updated/i, { timeout: 10000 }) + .should('be.visible'); +} + +/** + * Upload a document to a collection via the chat page + * Note: Model, repository, and collection must already be selected in the chat UI + */ +export function uploadDocument (filePath: string) { + cy.log(`Uploading document: ${filePath}`); + + // Click the "Upload to RAG" button + cy.get('button[data-testid="upload-to-rag"]') + .should('be.visible') + .and('not.be.disabled') + .click(); + + // Wait for upload dialog to open + cy.get('[role="dialog"]') + .filter(':visible') + .first() + .should('be.visible') + .within(() => { + // Select the file using the file input within the data-testid wrapper + // Path is relative to cypress directory + cy.get('[data-testid="rag-upload-file-input"]') + .find('input[type="file"]') + .selectFile(`src/e2e/fixtures/${filePath}`, { force: true }); + + // Wait a moment for file to be attached + cy.wait(1000); + + // Click the Upload button to submit + cy.contains('button', 'Upload') + .should('be.visible') + .and('not.be.disabled') + .click(); + }); + + // Wait for upload success notification + cy.contains(/successfully.*uploaded/i, { timeout: 30000 }) + .should('be.visible'); +} + +/** + * Wait for document to be ingested (up to 5 minutes) + * Checks document status until it appears in the documents list + * @param repositoryId - The repository ID + * @param collectionId - The collection ID (not the name) + * @param documentName - The document filename to wait for + * @param timeoutMs - Maximum time to wait in milliseconds + */ +export function waitForDocumentIngested (repositoryId: string, collectionId: string, documentName: string, timeoutMs: number = 300000) { + cy.log(`Waiting for document "${documentName}" to be ingested in collection ${collectionId}...`); + + const startTime = Date.now(); + const checkInterval = 10000; // Check every 10 seconds + + function checkDocumentStatus (): any { + return makeAuthenticatedRequest('GET', `/repository/${repositoryId}/document?collectionId=${collectionId}&pageSize=100`).then((response) => { + if (response.status === 200 && response.body.documents) { + // Look for document by name (matches the uploaded filename with timestamp prefix) + const document = response.body.documents.find((doc: any) => + doc.document_name && doc.document_name.includes(documentName) + ); + + if (document) { + cy.log(`Document "${documentName}" found in collection (${document.document_name})`); + return; // Success - stop checking + } + } + + const elapsed = Date.now() - startTime; + if (elapsed < timeoutMs) { + return cy.wait(checkInterval).then(() => checkDocumentStatus()); + } else { + throw new Error(`Document "${documentName}" was not found in collection within ${timeoutMs}ms`); + } + }); + } + + checkDocumentStatus(); +} + +/** + * Select RAG repository in chat + */ +export function selectRagRepositoryInChat (repositoryId: string) { + cy.log(`Selecting RAG repository: ${repositoryId}`); + + // Click the RAG repository input + cy.get('input#rag-repository-autosuggest, input[placeholder*="RAG Repository" i]') + .should('be.visible') + .click({ force: true }); + + // Wait for dropdown to appear and select the repository + cy.get('[role="option"]') + .contains(repositoryId) + .should('be.visible') + .click(); +} + +/** + * Select collection in chat + * If the collection doesn't exist, selects the first available option + */ +export function selectCollectionInChat (collectionName: string) { + cy.log(`Selecting collection: ${collectionName}`); + + // Click the collection input + cy.get('input#collection-autosuggest, input[placeholder*="collection" i]') + .should('be.visible') + .click({ force: true }); + + // Wait for dropdown to appear + cy.get('[role="option"]', { timeout: 10000 }).should('be.visible'); + + // Try to find the collection by name + cy.get('body').then(($body) => { + const collectionOption = $body.find(`[role="option"]:contains("${collectionName}")`); + + if (collectionOption.length > 0) { + // Collection found - click it + cy.get('[role="option"]') + .contains(collectionName) + .click(); + } else { + // Collection not found - select first option + cy.log(`⚠️ Collection "${collectionName}" not found. Selecting first available collection.`); + cy.get('[role="option"]') + .first() + .click(); + } + }); +} + +/** + * Send a message and verify RAG response with sources + */ +export function sendMessageAndVerifyRagResponse (message: string) { + cy.log(`Sending message: ${message}`); + + // Set up intercept for chat completions API + cy.intercept('POST', '**/chat/completions').as('chatCompletion'); + + // Type the message + cy.get('textarea[placeholder*="message" i]') + .should('be.visible') + .clear() + .type(message); + + // Send the message + cy.get('button[aria-label="Send message"]') + .should('be.visible') + .and('not.be.disabled') + .click(); + + // Wait for the chat completion to finish + cy.wait('@chatCompletion', { timeout: 60000 }); + + // Wait for response to complete (look for chat messages - user + assistant) + cy.get('[data-testid="chat-message"]', { timeout: 60000 }) + .should('have.length.at.least', 2); + + // Verify source references are present in the response + cy.contains(/source|reference|citation/i, { timeout: 10000 }) + .should('be.visible'); +} + +/** + * Get the auto-created collection name and ID for a repository + * Returns both the name and ID of the auto-created collection + */ +export function getAutoCreatedCollectionInfo (repositoryId: string): Cypress.Chainable<{name: string, id: string}> { + return makeAuthenticatedRequest('GET', `/repository/${repositoryId}/collection`).then((response) => { + if (response.body && response.body.collections && Array.isArray(response.body.collections)) { + const collections = response.body.collections; + + if (collections.length === 0) { + throw new Error(`No collections found for repository ${repositoryId}`); + } + + // Find the auto-created collection (default or has dataSourceId for Bedrock KB) + const autoCreatedCollection = collections.find( + (collection: any) => collection.default === true || collection.dataSourceId !== null + ); + + if (autoCreatedCollection) { + const collectionName = autoCreatedCollection.name; + const collectionId = autoCreatedCollection.collectionId; + Cypress.log({ name: 'getAutoCreatedCollectionInfo', message: `Found auto-created collection: ${collectionName} (ID: ${collectionId})` }); + return cy.wrap({ name: collectionName, id: collectionId }); + } + + // Fallback to first collection if no default found + const firstCollection = collections[0]; + Cypress.log({ name: 'getAutoCreatedCollectionInfo', message: `Using first collection: ${firstCollection.name} (ID: ${firstCollection.collectionId})` }); + return cy.wrap({ name: firstCollection.name, id: firstCollection.collectionId }); + } + throw new Error(`Failed to fetch collections for repository ${repositoryId}`); + }); +} + +/** + * Get the auto-created collection name for a repository + * Returns the name of the auto-created collection (typically marked with default: true or has dataSourceId) + * @deprecated Use getAutoCreatedCollectionInfo instead to get both name and ID + */ +export function getAutoCreatedCollectionName (repositoryId: string): Cypress.Chainable { + return getAutoCreatedCollectionInfo(repositoryId).then((info) => info.name); +} + +/** + * Delete a collection by name (for cleanup) + */ +export function deleteCollectionIfExists (collectionName: string) { + cy.get('body').then(($body) => { + if ($body.text().includes(collectionName)) { + // Select the collection + cy.contains(collectionName) + .closest('tr, [data-testid*="collection"]') + .find('input[type="radio"], input[type="checkbox"]') + .first() + .click({ force: true }); + + // Click the Actions dropdown or Delete button + cy.get('[data-testid="collection-actions-dropdown"], button') + .contains(/actions|delete/i) + .click(); + + // Click Delete from the dropdown menu if needed + cy.get('body').then(($body) => { + if ($body.find('[role="menuitem"]').length > 0) { + cy.contains('[role="menuitem"]', 'Delete').click(); + } + }); + + // Wait for confirmation modal and click Delete button + cy.get('[data-testid="confirmation-modal-delete-btn"]', { timeout: 5000 }) + .should('be.visible') + .click(); + + cy.wait(2000); + } + }); +} diff --git a/cypress/src/support/modelFormHelpers.ts b/cypress/src/support/modelFormHelpers.ts index 3df8e952d..6f6768df3 100644 --- a/cypress/src/support/modelFormHelpers.ts +++ b/cypress/src/support/modelFormHelpers.ts @@ -26,6 +26,16 @@ export type BedrockModelConfig = { streaming?: boolean; }; +/** + * Check if a model exists in the model management list + * @returns Cypress.Chainable + */ +export function modelExists (modelId: string): Cypress.Chainable { + return cy.get('body').then(($body) => { + return $body.text().includes(modelId); + }); +} + /** * Open the Create Model wizard modal */ diff --git a/cypress/src/support/promptTemplateHelpers.ts b/cypress/src/support/promptTemplateHelpers.ts index 7051e5c67..be49aa9b8 100644 --- a/cypress/src/support/promptTemplateHelpers.ts +++ b/cypress/src/support/promptTemplateHelpers.ts @@ -19,13 +19,31 @@ * Contains reusable helpers for prompt template creation and management. */ +/** + * Prompt template types - mirrors PromptTemplateType from React app + */ +export enum PromptTemplateType { + Persona = 'persona', + Directive = 'directive' +} + export type PromptTemplateConfig = { title: string; body: string; - type?: 'system' | 'user'; + type?: PromptTemplateType; sharePublic?: boolean; }; +/** + * Check if a prompt template exists in the prompt templates list + * @returns Cypress.Chainable + */ +export function promptTemplateExists (templateTitle: string): Cypress.Chainable { + return cy.get('body').then(($body) => { + return $body.text().includes(templateTitle); + }); +} + /** * Navigate to Prompt Templates Library page */ @@ -77,7 +95,7 @@ export function fillPromptTemplateConfig (config: PromptTemplateConfig) { .should('be.visible') .click(); - const typeLabel = config.type === 'system' ? 'Persona' : 'Directive'; + const typeLabel = config.type === PromptTemplateType.Persona ? 'Persona' : 'Directive'; cy.get('[role="listbox"]') .should('be.visible') .contains('[role="option"]', typeLabel) @@ -168,54 +186,6 @@ export function deletePromptTemplateIfExists (templateTitle: string) { }); } -/** - * Select a prompt template in chat - * @param templateTitle - The title of the template to select - * @param templateType - The type of template ('system' for Persona, 'user' for Directive) - */ -export function selectPromptTemplateInChat (templateTitle: string, templateType: 'system' | 'user' = 'user') { - if (templateType === 'system') { - // For Persona templates, use the "Edit Persona" button in Additional Configuration dropdown - cy.contains('button', 'Additional Configuration') - .should('be.visible') - .click(); - - cy.contains('[role="menuitem"]', 'Edit Persona') - .should('be.visible') - .click(); - } else { - // For Directive templates, use the "Insert Prompt Template" button - cy.get('button[aria-label="Insert Prompt Template"]') - .should('be.visible') - .click(); - } - - // Wait for modal to open - cy.get('[role="dialog"]') - .should('be.visible') - .within(() => { - // Search for and select the template - cy.get('input[placeholder="Search by title"]') - .should('be.visible') - .type(templateTitle); - - // Select from the dropdown - cy.contains('[role="option"]', templateTitle) - .should('be.visible') - .click(); - - // Click the Use button - const buttonText = templateType === 'system' ? 'Use Persona' : 'Use Prompt'; - cy.contains('button', buttonText) - .should('be.visible') - .and('not.be.disabled') - .click(); - }); - - // Wait for modal to close - cy.get('[role="dialog"]').should('not.exist'); -} - /** * Send a message that's already in the input field by clicking the send button */ @@ -227,20 +197,46 @@ export function sendMessageWithButton () { } /** - * Verify that a chat response was received - * @param minMessages - Minimum number of messages expected (default: 2 for user + assistant) + * Select a prompt template in chat using the Welcome Screen buttons + * @param templateTitle - The title of the template to select + * @param templateType - The type of template (Persona or Directive) */ -export function verifyChatResponseReceived (minMessages: number = 2) { - cy.get('[data-testid="chat-message"]', { timeout: 30000 }) - .should('have.length.at.least', minMessages); -} +export function selectPromptTemplateInChat (templateTitle: string, templateType: PromptTemplateType = PromptTemplateType.Directive) { + // Use the Welcome Screen buttons (Select Persona / Select Directive) + // These are visible when there's no chat history + const isPersona = templateType === PromptTemplateType.Persona; + const selectButtonTestId = isPersona ? 'select-persona-button' : 'select-directive-button'; + const useButtonTestId = '[data-testid="use-prompt-button"]'; + const modalSelector = '[data-testid="prompt-template-modal"]'; + + // Click the Select Persona/Directive button using data-testid + cy.get(`[data-testid="${selectButtonTestId}"]`, { timeout: 10000 }) + .click(); -/** - * Select a directive prompt template, which inserts text into the message input, then send it - * @param templateTitle - The title of the directive template to select - */ -export function selectDirectiveAndSend (templateTitle: string) { - selectPromptTemplateInChat(templateTitle, 'user'); - sendMessageWithButton(); - verifyChatResponseReceived(); + // Wait for the modal to open + cy.get(modalSelector).should('be.visible'); + + // Type in the autosuggest input to search for the template + cy.get(`${modalSelector} input[placeholder="Search by title"]`) + .should('be.visible') + .type(templateTitle); + + // Wait for dropdown options to appear and select the correct one + // Avoid the first entry prefixed with "Use" by selecting the option that matches the exact title + cy.get('[role="option"]') + .filter(`:contains("${templateTitle}")`) + .not(':contains("Use")') + .first() + .should('be.visible') + .click(); + + // Click the Use Persona/Directive button using data-testid + cy.get(`${modalSelector} ${useButtonTestId}`) + .should('be.visible') + .and('not.be.disabled') + .click(); + + // Wait for modal to close and UI to stabilize + cy.get(modalSelector).should('not.be.visible'); + cy.wait(500); } diff --git a/cypress/src/support/repositoryHelpers.ts b/cypress/src/support/repositoryHelpers.ts index 233313ece..dba1143b9 100644 --- a/cypress/src/support/repositoryHelpers.ts +++ b/cypress/src/support/repositoryHelpers.ts @@ -25,6 +25,16 @@ export type RepositoryConfig = { dataSourceIndex?: number; }; +/** + * Check if a repository exists in the repository management list + * @returns Cypress.Chainable + */ +export function repositoryExists (repositoryId: string): Cypress.Chainable { + return cy.get('body').then(($body) => { + return $body.text().includes(repositoryId); + }); +} + /** * Navigate to the repository management page */ @@ -118,7 +128,10 @@ export function selectDataSource (index: number = 0) { * Skip to the create step in the repository wizard */ export function skipToCreateRepository () { - cy.contains('button', 'Skip to Create').should('be.visible').click(); + cy.contains('button', 'Skip to Create') + .scrollIntoView() + .should('be.visible') + .click(); } /** diff --git a/ecs_model_deployer/src/lib/ecs-model.ts b/ecs_model_deployer/src/lib/ecs-model.ts index f77511499..0d4b98330 100644 --- a/ecs_model_deployer/src/lib/ecs-model.ts +++ b/ecs_model_deployer/src/lib/ecs-model.ts @@ -24,9 +24,8 @@ import { getModelIdentifier } from './utils'; import { APP_MANAGEMENT_KEY, Ec2Metadata, EcsClusterConfig, EcsSourceType, PartialConfig } from '../../../lib/schema'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; -// This is the amount of memory to buffer (or subtract off) from the total instance memory, if we don't include this, -// the container can have a hard time finding available RAM resources to start and the tasks will fail deployment -const CONTAINER_MEMORY_BUFFER = 1024 * 5; +// Default memory buffer if not specified in config (2GB) +const DEFAULT_CONTAINER_MEMORY_BUFFER = 1024 * 2; /** * Properties for the EcsModel Construct. @@ -72,7 +71,7 @@ export class EcsModel extends Construct { amiHardwareType: AmiHardwareType.GPU, autoScalingConfig: modelConfig.autoScalingConfig, buildArgs: this.getBuildArguments(config, modelConfig), - containerMemoryBuffer: CONTAINER_MEMORY_BUFFER, + containerMemoryBuffer: modelConfig.containerMemoryBuffer ?? DEFAULT_CONTAINER_MEMORY_BUFFER, instanceType: modelConfig.instanceType, internetFacing: false, loadBalancerConfig: modelConfig.loadBalancerConfig, diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 1618d38dd..06ca41377 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -19,6 +19,7 @@ import { CfnOutput, Duration, RemovalPolicy } from 'aws-cdk-lib'; import { BlockDeviceVolume, GroupMetrics, Monitoring } from 'aws-cdk-lib/aws-autoscaling'; import { Metric, Stats } from 'aws-cdk-lib/aws-cloudwatch'; import { InstanceType, ISecurityGroup, IVpc, SubnetSelection } from 'aws-cdk-lib/aws-ec2'; +import { Alias } from 'aws-cdk-lib/aws-kms'; import { Cluster, ContainerDefinition, @@ -94,11 +95,18 @@ export class ECSCluster extends Construct { containerInsightsV2: !config.region?.includes('iso') ? ContainerInsights.ENABLED : ContainerInsights.DISABLED, }); - // Create auto scaling group + // SNS encryption key for ECS lifecycle hooks (AppSec Finding #5) + const snsEncryptionKey = Alias.fromAliasName( + this, + createCdkId([identifier, 'SnsKey']), + 'alias/aws/sns' + ); + + // Create auto scaling group with SNS topic encryption for lifecycle hooks const autoScalingGroup = cluster.addCapacity(createCdkId([identifier, 'ASG']), { vpcSubnets: subnetSelection, instanceType: new InstanceType(ecsConfig.instanceType), - machineImage: EcsOptimizedImage.amazonLinux2(ecsConfig.amiHardwareType), + machineImage: EcsOptimizedImage.amazonLinux2023(ecsConfig.amiHardwareType), minCapacity: ecsConfig.autoScalingConfig.minCapacity, maxCapacity: ecsConfig.autoScalingConfig.maxCapacity, cooldown: Duration.seconds(ecsConfig.autoScalingConfig.cooldown), @@ -114,6 +122,7 @@ export class ECSCluster extends Construct { }), }, ], + topicEncryptionKey: snsEncryptionKey, }); new CfnOutput(this, 'autoScalingGroup', { @@ -273,6 +282,10 @@ export class ECSCluster extends Construct { }); // Add listener + // Note: This ALB is internal (internetFacing: false) and uses HTTP only. + // If HTTPS is enabled in the future, use SslPolicy.TLS13_RES for AppSec compliance. + // SslPolicy.TLS13_RES maps to ELBSecurityPolicy-TLS13-1-2-2021-06 + // This policy provides forward secrecy with ECDHE cipher suites and excludes RSA key exchange. const listenerProps: BaseApplicationListenerProps = { port: 80, open: false, @@ -298,8 +311,14 @@ export class ECSCluster extends Construct { }, port: 80, targets: [service], + // Slow start gives new targets time to warm up before receiving full traffic + slowStart: Duration.seconds(60), }); + // Configure target group for LLM workloads which may have long response times + // This prevents 504 Gateway Timeout errors during model inference + targetGroup.setAttribute('deregistration_delay.timeout_seconds', '30'); + // ALB metric for ASG to use for auto scaling EC2 instances // TODO: Update this to step scaling for embedding models?? const requestCountPerTargetMetric = new Metric({ diff --git a/example_config.yaml b/example_config.yaml index 173db24b2..70484107d 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -67,15 +67,17 @@ ragRepositories: [] # If adding an existing PGVector database, this configurations assumes: # 1. The database has been configured to have pgvector installed and enabled: https://aws.amazon.com/about-aws/whats-new/2023/05/amazon-rds-postgresql-pgvector-ml-model-integration/ # 2. The database is accessible by RAG-related lambda functions (add inbound PostgreSQL access on the database's security group for all Lambda RAG security groups) -# 3. A secret ID exists in SecretsManager holding the database password within a json block of '{"password":"your_password_here"}'. This is the same format that RDS natively provides a password in SecretsManager. -# If the passwordSecretId or dbHost are not provided, then a sample database will be created for you. Only the username is required. +# 3. If using password auth (iamRdsAuth: false), a secret ID exists in SecretsManager holding the database password within a json block of '{"password":"your_password_here"}'. +# If using IAM auth (default), the database must have IAM authentication enabled. +# If the dbHost is not provided, then a sample database will be created for you. # - repositoryId: pgvector-rag # type: pgvector # rdsConfig: # username: postgres -# passwordSecretId: # password ID as stored in SecretsManager. Example: "rds!db-aa88493d-be8d-4a3f-96dc-c668165f7826" +# passwordSecretId: # password ID as stored in SecretsManager (only needed if iamRdsAuth: false). Example: "rds!db-aa88493d-be8d-4a3f-96dc-c668165f7826" # dbHost: # Host name of database. Example hostname from RDS: "my-db-name.291b2f03.us-east-1.rds.amazonaws.com" # dbName: postgres +# iamRdsAuth: true # Set to false to use password-based authentication instead of IAM auth (default: false) # You can optionally provide a list of models and the deployment process will ensure they exist in your model bucket and try to download them if they don't exist # ecsModels: # - modelName: mistralai/Mistral-7B-Instruct-v0.2 diff --git a/lambda/api_tokens/handler.py b/lambda/api_tokens/handler.py index 3478989c7..66768b39f 100644 --- a/lambda/api_tokens/handler.py +++ b/lambda/api_tokens/handler.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import Any from uuid import uuid4 from boto3.dynamodb.conditions import Key @@ -37,10 +37,10 @@ class CreateTokenAdminHandler: """Admin creates token for any user or system""" - def __init__(self, token_table): + def __init__(self, token_table: Any) -> None: self.token_table = token_table - def _get_user_token(self, username: str) -> Optional[dict]: + def _get_user_token(self, username: str) -> dict | None: """Query for existing token by username using GSI""" response = self.token_table.query( IndexName="username-index", KeyConditionExpression=Key("username").eq(username), Limit=1 @@ -48,7 +48,9 @@ def _get_user_token(self, username: str) -> Optional[dict]: items = response.get("Items", []) return items[0] if items else None - def __call__(self, username: str, request: CreateTokenAdminRequest, created_by: str, is_admin: bool): + def __call__( + self, username: str, request: CreateTokenAdminRequest, created_by: str, is_admin: bool + ) -> CreateTokenResponse: # Authorization: Only admins can create tokens for other users if not is_admin: raise UnauthorizedError("Only admins can create tokens for other users") @@ -97,10 +99,10 @@ def __call__(self, username: str, request: CreateTokenAdminRequest, created_by: class CreateTokenUserHandler: """User creates their own token""" - def __init__(self, token_table): + def __init__(self, token_table: Any) -> None: self.token_table = token_table - def _get_user_token(self, username: str) -> Optional[dict]: + def _get_user_token(self, username: str) -> dict | None: """Query for existing token by username using GSI""" response = self.token_table.query( IndexName="username-index", KeyConditionExpression=Key("username").eq(username), Limit=1 @@ -110,7 +112,7 @@ def _get_user_token(self, username: str) -> Optional[dict]: def __call__( self, request: CreateTokenUserRequest, username: str, user_groups: list[str], is_admin: bool, is_api_user: bool - ): + ) -> CreateTokenResponse: # Authorization: User must be admin or in apiGroup if not is_admin and not is_api_user: raise ForbiddenError("User must be in the API group to create tokens") @@ -156,7 +158,7 @@ def __call__( class ListTokensHandler: """List tokens - admins see all, users see only their own""" - def __init__(self, token_table): + def __init__(self, token_table: Any) -> None: self.token_table = token_table def __call__(self, username: str, is_admin: bool) -> ListTokensResponse: @@ -203,7 +205,7 @@ def __call__(self, username: str, is_admin: bool) -> ListTokensResponse: class GetTokenHandler: """Get specific token details""" - def __init__(self, token_table): + def __init__(self, token_table: Any) -> None: self.token_table = token_table def __call__(self, token_uuid: str, username: str, is_admin: bool) -> TokenInfo: @@ -263,7 +265,7 @@ def __call__(self, token_uuid: str, username: str, is_admin: bool) -> TokenInfo: class DeleteTokenHandler: """Delete token - handles both modern and legacy tokens""" - def __init__(self, token_table): + def __init__(self, token_table: Any) -> None: self.token_table = token_table def __call__(self, token_uuid: str, username: str, is_admin: bool) -> DeleteTokenResponse: diff --git a/lambda/api_tokens/lambda_functions.py b/lambda/api_tokens/lambda_functions.py index 9e686e353..94c90343b 100644 --- a/lambda/api_tokens/lambda_functions.py +++ b/lambda/api_tokens/lambda_functions.py @@ -13,19 +13,17 @@ # limitations under the License. """APIGW endpoints for managing API tokens.""" +import logging import os -from typing import Annotated, Union +from typing import Annotated import boto3 -from fastapi import FastAPI, HTTPException, Path, Request -from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware +from fastapi import HTTPException, Path, Request from fastapi.responses import JSONResponse from mangum import Mangum from utilities.auth import get_user_context, is_api_user from utilities.common_functions import retry_config -from utilities.fastapi_middleware.aws_api_gateway_middleware import AWSAPIGatewayMiddleware +from utilities.fastapi_factory import create_fastapi_app from .domain_objects import ( CreateTokenAdminRequest, @@ -35,7 +33,7 @@ ListTokensResponse, TokenInfo, ) -from .exception import ForbiddenError, TokenAlreadyExistsError, TokenNotFoundError, UnauthorizedError +from .exception import TokenAlreadyExistsError, TokenNotFoundError from .handler import ( CreateTokenAdminHandler, CreateTokenUserHandler, @@ -44,17 +42,9 @@ ListTokensHandler, ) -app = FastAPI(redirect_slashes=False, lifespan="off", docs_url="/docs", openapi_url="/openapi.json") -app.add_middleware(AWSAPIGatewayMiddleware) +logger = logging.getLogger(__name__) -# Enable CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], -) +app = create_fastapi_app() # Initialize boto3 resources dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -67,29 +57,9 @@ async def token_not_found_handler(request: Request, exc: TokenNotFoundError) -> return JSONResponse(status_code=404, content={"message": str(exc)}) -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: - """Handle exception when request fails validation and translate to a 422 error.""" - return JSONResponse( - status_code=422, content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"} - ) - - -@app.exception_handler(UnauthorizedError) -async def unauthorized_handler(request: Request, exc: UnauthorizedError) -> JSONResponse: - """Handle unauthorized access attempts and translate to a 401 error.""" - return JSONResponse(status_code=401, content={"message": str(exc)}) - - -@app.exception_handler(ForbiddenError) -async def forbidden_handler(request: Request, exc: ForbiddenError) -> JSONResponse: - """Handle forbidden access attempts and translate to a 403 error.""" - return JSONResponse(status_code=403, content={"message": str(exc)}) - - @app.exception_handler(TokenAlreadyExistsError) @app.exception_handler(ValueError) -async def user_error_handler(request: Request, exc: Union[TokenAlreadyExistsError, ValueError]) -> JSONResponse: +async def user_error_handler(request: Request, exc: TokenAlreadyExistsError | ValueError) -> JSONResponse: """Handle errors when customer requests options that cannot be processed.""" return JSONResponse(status_code=400, content={"message": str(exc)}) diff --git a/lambda/authorizer/lambda_functions.py b/lambda/authorizer/lambda_functions.py index 243048a99..fc7be3e2e 100644 --- a/lambda/authorizer/lambda_functions.py +++ b/lambda/authorizer/lambda_functions.py @@ -18,7 +18,7 @@ import logging import os import ssl -from typing import Any, Dict +from typing import Any import boto3 import create_env_variables # noqa: F401 @@ -38,7 +38,7 @@ @authorization_wrapper -def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [no-untyped-def] +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """Handle authorization for REST API.""" logger.info("REST API authorization handler started") @@ -108,7 +108,7 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i return deny_policy -def generate_policy(*, effect: str, resource: str, username: str = "username") -> Dict[str, Any]: +def generate_policy(*, effect: str, resource: str, username: str = "username") -> dict[str, Any]: """Generate IAM policy.""" policy = { "principalId": username, @@ -159,10 +159,10 @@ def is_valid_api_token(token: str) -> dict | None: logger.info(f"Token expired at {token_expiration}") return None - return token_info + return token_info # type: ignore[no-any-return] -def id_token_is_valid(*, id_token: str, client_id: str, authority: str) -> Dict[str, Any] | None: +def id_token_is_valid(*, id_token: str, client_id: str, authority: str) -> dict[str, Any] | None: """Check whether an ID token is valid and return decoded data.""" if not jwt.algorithms.has_crypto: logger.error("No crypto support for JWT, please install the cryptography dependency") diff --git a/lambda/configuration/lambda_functions.py b/lambda/configuration/lambda_functions.py index 86db08b6c..15dc33378 100644 --- a/lambda/configuration/lambda_functions.py +++ b/lambda/configuration/lambda_functions.py @@ -18,7 +18,7 @@ import os import time from decimal import Decimal -from typing import Any, Dict +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -33,15 +33,15 @@ @api_wrapper -def get_configuration(event: dict, context: dict) -> Dict[str, Any]: +def get_configuration(event: dict, context: dict) -> dict[str, Any]: """List configuration entries by configScope from DynamoDB.""" config_scope = event["queryStringParameters"]["configScope"] - return _get_configurations(config_scope) + return _get_configurations(config_scope) # type: ignore[return-value] def _get_configurations(config_scope: str) -> list[dict[str, Any]]: - response = {} + response: dict[str, Any] = {} try: response = table.query( KeyConditionExpression="#s = :configScope", @@ -55,11 +55,12 @@ def _get_configurations(config_scope: str) -> list[dict[str, Any]]: else: logger.exception("Error fetching session") - return response.get("Items", []) # type: ignore [no-any-return] + items = response.get("Items", []) + return items if isinstance(items, list) else [] @api_wrapper -def update_configuration(event: dict, context: dict) -> None: +def update_configuration(event: dict, context: dict) -> dict[str, str]: """Update configuration in DynamoDB.""" # from https://stackoverflow.com/a/71446846 body = json.loads(event["body"], parse_float=Decimal) @@ -74,9 +75,12 @@ def update_configuration(event: dict, context: dict) -> None: table.put_item(Item=body) except ClientError: logger.exception("Error updating session in DynamoDB") + raise + return {"status": "ok"} -def check_show_mcp_workbench(body, old_configuration): + +def check_show_mcp_workbench(body: dict[str, Any], old_configuration: dict[str, Any]) -> None: old_show_mcp_value = get_property_path(old_configuration, "configuration.enabledComponents.showMcpWorkbench") new_show_mcp_value = get_property_path(body, "configuration.enabledComponents.showMcpWorkbench") diff --git a/lambda/dockerimagebuilder/__init__.py b/lambda/dockerimagebuilder/__init__.py index d0d1d4186..80b371e64 100644 --- a/lambda/dockerimagebuilder/__init__.py +++ b/lambda/dockerimagebuilder/__init__.py @@ -16,7 +16,7 @@ import os import shlex import uuid -from typing import Any, Dict +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -87,7 +87,7 @@ """ -def handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [no-untyped-def] +def handler(event: dict[str, Any], context) -> dict[str, Any]: # type: ignore [no-untyped-def] logger.info(f"Starting Docker image builder with event: {event}") base_image = event["base_image"] @@ -99,7 +99,7 @@ def handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [ ec2_resource = boto3.resource("ec2", region_name=os.environ["AWS_REGION"]) ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) - response = ssm_client.get_parameter(Name="/aws/service/ami-amazon-linux-latest/amzn2-ami-hvm-x86_64-gp2") + response = ssm_client.get_parameter(Name="/aws/service/ami-amazon-linux-latest/al2023-ami-kernel-default-x86_64") ami_id = response["Parameter"]["Value"] image_tag = str(uuid.uuid4()) @@ -125,7 +125,13 @@ def handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [ "UserData": rendered_userdata, "IamInstanceProfile": {"Arn": os.environ["LISA_INSTANCE_PROFILE"]}, "BlockDeviceMappings": [ - {"DeviceName": "/dev/xvda", "Ebs": {"VolumeSize": int(os.environ["LISA_IMAGEBUILDER_VOLUME_SIZE"])}} + { + "DeviceName": "/dev/xvda", + "Ebs": { + "VolumeSize": int(os.environ["LISA_IMAGEBUILDER_VOLUME_SIZE"]), + "Encrypted": True, + }, + } ], "TagSpecifications": [ { diff --git a/lambda/management_key.py b/lambda/management_key.py index c1013b810..55ba3abc8 100644 --- a/lambda/management_key.py +++ b/lambda/management_key.py @@ -18,7 +18,7 @@ import logging import os import string -from typing import Any, Dict +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -33,7 +33,7 @@ events_client = boto3.client("events", region_name=os.environ["AWS_REGION"], config=retry_config) -def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """ AWS Secrets Manager rotation handler for management key. @@ -180,7 +180,7 @@ def finish_secret(secret_arn: str, token: str) -> None: raise -def publish_rotation_event(secret_arn: str, new_version: str, old_version: str) -> None: +def publish_rotation_event(secret_arn: str, new_version: str, old_version: str | None) -> None: """ Publish a management key rotation event to EventBridge. """ diff --git a/lambda/mcp_server/lambda_functions.py b/lambda/mcp_server/lambda_functions.py index 30cf822d5..36b614b83 100644 --- a/lambda/mcp_server/lambda_functions.py +++ b/lambda/mcp_server/lambda_functions.py @@ -13,6 +13,8 @@ # limitations under the License. """Lambda functions for managing MCP Servers in AWS DynamoDB.""" +from __future__ import annotations + import json import logging import os @@ -20,7 +22,7 @@ import uuid from decimal import Decimal from functools import reduce -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from boto3.dynamodb.conditions import Attr, Key @@ -48,7 +50,7 @@ def _normalize_server_name(name: str) -> str: return re.sub(r"[^a-zA-Z0-9]", "", name) -def replace_bearer_token_header(mcp_server: dict, replacement: str): +def replace_bearer_token_header(mcp_server: dict, replacement: str) -> None: """Replace {LISA_BEARER_TOKEN} placeholder with actual bearer token in custom headers.""" custom_headers = mcp_server.get("customHeaders", {}) for key, value in custom_headers.items(): @@ -56,7 +58,7 @@ def replace_bearer_token_header(mcp_server: dict, replacement: str): custom_headers[key] = value.replace("{LISA_BEARER_TOKEN}", replacement) -def _build_groups_condition(groups: List[str]) -> Any: +def _build_groups_condition(groups: list[str]) -> Any: """Build DynamoDB condition for groups filtering.""" # Servers with no groups (groups attribute doesn't exist, is null, or is empty array) should be included no_groups_condition = Attr("groups").not_exists() | Attr("groups").eq(None) | Attr("groups").eq([]) @@ -70,11 +72,11 @@ def _build_groups_condition(groups: List[str]) -> Any: def _get_mcp_servers( - user_id: Optional[str] = None, - active: Optional[bool] = None, - replace_bearer_token: Optional[str] = None, - groups: Optional[List] = None, -) -> Dict[str, Any]: + user_id: str | None = None, + active: bool | None = None, + replace_bearer_token: str | None = None, + groups: list[str] | None = None, +) -> dict[str, Any]: """Helper function to retrieve mcp servers from DynamoDB.""" filter_expression = None condition = None @@ -123,7 +125,7 @@ def _get_mcp_servers( condition = _build_groups_condition(groups) filter_expression = condition if filter_expression is None else filter_expression & condition - scan_arguments = { + scan_arguments: dict[str, Any] = { "TableName": os.environ["MCP_SERVERS_TABLE_NAME"], "IndexName": os.environ["MCP_SERVERS_BY_OWNER_INDEX_NAME"], } @@ -188,17 +190,17 @@ def get(event: dict, context: dict) -> Any: raise ValueError(f"Not authorized to get {mcp_server_id}.") -def _is_member(user_groups: List[str], prompt_groups: List[str]) -> bool: +def _is_member(user_groups: list[str], prompt_groups: list[str]) -> bool: return bool(set(user_groups) & set(prompt_groups)) def _set_can_use( - connections: Dict[str, Any], user_id: Optional[str] = None, groups: Optional[List[str]] = None -) -> Dict[str, Any]: + connections: dict[str, Any], user_id: str | None = None, groups: list[str] | None = None +) -> dict[str, Any]: if groups is None: groups = [] items = connections.get("Items", []) - formatted_groups = [f"group:{group}" for group in groups] + formatted_groups: list[str] = [f"group:{group}" for group in groups] for item in items: item["canUse"] = ( _is_member(formatted_groups, item.get("groups", [])) @@ -210,7 +212,7 @@ def _set_can_use( @api_wrapper -def list(event: dict, context: dict) -> Dict[str, Any]: +def list_mcp_servers(event: dict, context: dict) -> dict[str, Any]: """List mcp servers for a user from DynamoDB.""" user_id, is_admin_user, groups = get_user_context(event) @@ -275,7 +277,7 @@ def update(event: dict, context: dict) -> Any: @api_wrapper -def delete(event: dict, context: dict) -> Dict[str, str]: +def delete(event: dict, context: dict) -> dict[str, str]: """Logically delete a mcp server from DynamoDB.""" user_id, is_admin_user, _ = get_user_context(event) mcp_server_id = get_mcp_server_id(event) @@ -322,7 +324,7 @@ def create_hosted_mcp_server(event: dict, context: dict) -> Any: # Scan all items to check for duplicate normalized names items = [] - scan_arguments = {} + scan_arguments: dict[str, Any] = {} while True: response = table.scan(**scan_arguments) items.extend(response.get("Items", [])) @@ -362,7 +364,7 @@ def create_hosted_mcp_server(event: dict, context: dict) -> Any: @api_wrapper @admin_only -def list_hosted_mcp_servers(event: dict, context: dict) -> Dict[str, Any]: +def list_hosted_mcp_servers(event: dict, context: dict) -> dict[str, Any]: """List all hosted MCP servers from DynamoDB.""" user_id, is_admin_user, groups = get_user_context(event) @@ -371,7 +373,7 @@ def list_hosted_mcp_servers(event: dict, context: dict) -> Dict[str, Any]: logger.info(f"Listing all hosted MCP servers for user {user_id} (is_admin)") # Get all items from the table items = [] - scan_arguments = {} + scan_arguments: dict[str, Any] = {} while True: response = table.scan(**scan_arguments) items.extend(response.get("Items", [])) diff --git a/lambda/mcp_server/models.py b/lambda/mcp_server/models.py index d4ddb45cf..f0fb5154f 100644 --- a/lambda/mcp_server/models.py +++ b/lambda/mcp_server/models.py @@ -14,10 +14,9 @@ import uuid from enum import StrEnum -from typing import Dict, List, Optional, Union +from typing import Self from pydantic import BaseModel, Field, field_validator, model_validator -from typing_extensions import Self from utilities.time import iso_string from utilities.validation import validate_any_fields_defined @@ -49,10 +48,10 @@ class McpServerModel(BaseModel): """ # Unique identifier for the mcp server - id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) + id: str | None = Field(default_factory=lambda: str(uuid.uuid4())) # Timestamp of when the mcp server was created - created: Optional[str] = Field(default_factory=iso_string) + created: str | None = Field(default_factory=iso_string) # Owner of the MCP user owner: str @@ -64,19 +63,19 @@ class McpServerModel(BaseModel): name: str # Description of the MCP server - description: Optional[str] = Field(default_factory=lambda: None) + description: str | None = Field(default_factory=lambda: None) # Custom headers for the MCP client - customHeaders: Optional[dict] = Field(default_factory=lambda: None) + customHeaders: dict | None = Field(default_factory=lambda: None) # Custom client properties for the MCP client - clientConfig: Optional[dict] = Field(default_factory=lambda: None) + clientConfig: dict | None = Field(default_factory=lambda: None) # Status of the server set by admins - status: Optional[McpServerStatus] = Field(default=McpServerStatus.ACTIVE) + status: McpServerStatus | None = Field(default=McpServerStatus.ACTIVE) # Groups of the MCP server - groups: Optional[List[str]] = Field(default_factory=lambda: None) + groups: list[str] | None = Field(default_factory=lambda: None) class LoadBalancerHealthCheckConfig(BaseModel): @@ -98,7 +97,7 @@ class LoadBalancerConfig(BaseModel): class ContainerHealthCheckConfig(BaseModel): """Specifies container health check parameters.""" - command: Union[str, List[str]] + command: str | list[str] interval: int = Field(gt=0) startPeriod: int = Field(ge=0) timeout: int = Field(gt=0) @@ -110,21 +109,21 @@ class AutoScalingConfig(BaseModel): minCapacity: int maxCapacity: int - targetValue: Optional[int] = Field(default=None) - metricName: Optional[str] = Field(default=None) - duration: Optional[int] = Field(default=None) - cooldown: Optional[int] = Field(default=None) + targetValue: int | None = Field(default=None) + metricName: str | None = Field(default=None) + duration: int | None = Field(default=None) + cooldown: int | None = Field(default=None) class AutoScalingConfigUpdate(BaseModel): """Updatable auto-scaling configuration for hosted MCP servers (all fields optional).""" - minCapacity: Optional[int] = Field(default=None) - maxCapacity: Optional[int] = Field(default=None) - targetValue: Optional[int] = Field(default=None) - metricName: Optional[str] = Field(default=None) - duration: Optional[int] = Field(default=None) - cooldown: Optional[int] = Field(default=None) + minCapacity: int | None = Field(default=None) + maxCapacity: int | None = Field(default=None) + targetValue: int | None = Field(default=None) + metricName: str | None = Field(default=None) + duration: int | None = Field(default=None) + cooldown: int | None = Field(default=None) class HostedMcpServerModel(BaseModel): @@ -134,10 +133,10 @@ class HostedMcpServerModel(BaseModel): """ # Unique identifier for the mcp server - id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) + id: str | None = Field(default_factory=lambda: str(uuid.uuid4())) # Timestamp of when the mcp server was created - created: Optional[str] = Field(default_factory=iso_string) + created: str | None = Field(default_factory=iso_string) # Owner of the MCP server owner: str @@ -146,13 +145,13 @@ class HostedMcpServerModel(BaseModel): name: str # Description of the MCP server - description: Optional[str] = Field(default_factory=lambda: None) + description: str | None = Field(default_factory=lambda: None) # Command to start the server startCommand: str # Port number (optional, used for HTTP/SSE servers) - port: Optional[int] = Field(default=None) + port: int | None = Field(default=None) # Server type: 'stdio', 'http', or 'sse' serverType: str @@ -160,56 +159,56 @@ class HostedMcpServerModel(BaseModel): # Container image (optional) # If provided without s3Path: use as pre-built container image # If provided with s3Path: use as base image for building from S3 artifacts - image: Optional[str] = Field(default=None) + image: str | None = Field(default=None) # S3 path to server artifacts (binaries, Python files, etc.) # If provided with image: image is used as base image for building # If provided without image: default base image is used - s3Path: Optional[str] = Field(default=None) + s3Path: str | None = Field(default=None) # Auto-scaling configuration autoScalingConfig: AutoScalingConfig # Load balancer configuration (optional, will use defaults if not provided) - loadBalancerConfig: Optional[LoadBalancerConfig] = Field(default=None) + loadBalancerConfig: LoadBalancerConfig | None = Field(default=None) # Container health check configuration (optional, will use defaults if not provided) - containerHealthCheckConfig: Optional[ContainerHealthCheckConfig] = Field(default=None) + containerHealthCheckConfig: ContainerHealthCheckConfig | None = Field(default=None) # Environment variables for the container - environment: Optional[Dict[str, str]] = Field(default_factory=lambda: None) + environment: dict[str, str] | None = Field(default_factory=lambda: None) # IAM role ARN for task execution (optional, will be auto-created if not provided) - taskExecutionRoleArn: Optional[str] = Field(default=None) + taskExecutionRoleArn: str | None = Field(default=None) # IAM role ARN for running tasks (optional, will be auto-created if not provided) - taskRoleArn: Optional[str] = Field(default=None) + taskRoleArn: str | None = Field(default=None) # Fargate CPU units (defaults to 256 which equals 0.25 vCPU) - cpu: Optional[int] = Field(default=256) + cpu: int | None = Field(default=256) # Fargate memory limit in MiB (defaults to 512 MiB) - memoryLimitMiB: Optional[int] = Field(default=512) + memoryLimitMiB: int | None = Field(default=512) # Groups of the MCP server (for authorization) - groups: Optional[List[str]] = Field(default_factory=lambda: None) + groups: list[str] | None = Field(default_factory=lambda: None) # Status of the server - status: Optional[HostedMcpServerStatus] = Field(default=HostedMcpServerStatus.CREATING) + status: HostedMcpServerStatus | None = Field(default=HostedMcpServerStatus.CREATING) class UpdateHostedMcpServerRequest(BaseModel): """Specifies parameters for hosted MCP server update requests.""" - enabled: Optional[bool] = None - autoScalingConfig: Optional[AutoScalingConfigUpdate] = None - environment: Optional[Dict[str, str]] = None - containerHealthCheckConfig: Optional[ContainerHealthCheckConfig] = None - loadBalancerConfig: Optional[LoadBalancerConfig] = None - cpu: Optional[int] = None - memoryLimitMiB: Optional[int] = None - description: Optional[str] = None - groups: Optional[List[str]] = None + enabled: bool | None = None + autoScalingConfig: AutoScalingConfigUpdate | None = None + environment: dict[str, str] | None = None + containerHealthCheckConfig: ContainerHealthCheckConfig | None = None + loadBalancerConfig: LoadBalancerConfig | None = None + cpu: int | None = None + memoryLimitMiB: int | None = None + description: str | None = None + groups: list[str] | None = None @model_validator(mode="after") def validate_update_request(self) -> Self: @@ -235,7 +234,7 @@ def validate_update_request(self) -> Self: @field_validator("autoScalingConfig") @classmethod - def validate_autoscaling_config(cls, config: Optional[AutoScalingConfig]) -> Optional[AutoScalingConfig]: + def validate_autoscaling_config(cls, config: AutoScalingConfig | None) -> AutoScalingConfig | None: """Validates auto-scaling configuration.""" if config is not None and not config: raise ValueError("The autoScalingConfig must not be null if defined in request payload.") @@ -244,8 +243,8 @@ def validate_autoscaling_config(cls, config: Optional[AutoScalingConfig]) -> Opt @field_validator("containerHealthCheckConfig") @classmethod def validate_container_health_check_config( - cls, config: Optional[ContainerHealthCheckConfig] - ) -> Optional[ContainerHealthCheckConfig]: + cls, config: ContainerHealthCheckConfig | None + ) -> ContainerHealthCheckConfig | None: """Validates container health check configuration.""" if config is not None and not config: raise ValueError("The containerHealthCheckConfig must not be null if defined in request payload.") @@ -253,7 +252,7 @@ def validate_container_health_check_config( @field_validator("loadBalancerConfig") @classmethod - def validate_load_balancer_config(cls, config: Optional[LoadBalancerConfig]) -> Optional[LoadBalancerConfig]: + def validate_load_balancer_config(cls, config: LoadBalancerConfig | None) -> LoadBalancerConfig | None: """Validates load balancer configuration.""" if config is not None and not config: raise ValueError("The loadBalancerConfig must not be null if defined in request payload.") @@ -261,7 +260,7 @@ def validate_load_balancer_config(cls, config: Optional[LoadBalancerConfig]) -> @field_validator("cpu") @classmethod - def validate_cpu(cls, cpu: Optional[int]) -> Optional[int]: + def validate_cpu(cls, cpu: int | None) -> int | None: """Validates CPU units.""" if cpu is not None: # Fargate CPU must be in valid units: 256, 512, 1024, 2048, 4096 @@ -272,7 +271,7 @@ def validate_cpu(cls, cpu: Optional[int]) -> Optional[int]: @field_validator("memoryLimitMiB") @classmethod - def validate_memory(cls, memory: Optional[int]) -> Optional[int]: + def validate_memory(cls, memory: int | None) -> int | None: """Validates memory limit.""" if memory is not None: if memory < 512: diff --git a/lambda/mcp_server/state_machine/create_mcp_server.py b/lambda/mcp_server/state_machine/create_mcp_server.py index 0eb5a1cf7..7d65a36dd 100644 --- a/lambda/mcp_server/state_machine/create_mcp_server.py +++ b/lambda/mcp_server/state_machine/create_mcp_server.py @@ -19,7 +19,7 @@ import os import re from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any import boto3 from botocore.config import Config @@ -39,7 +39,7 @@ MAX_POLLS = 60 -def handle_set_server_to_creating(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_set_server_to_creating(event: dict[str, Any], context: Any) -> dict[str, Any]: """Set DDB entry to CREATING status.""" logger.info(f"Setting MCP server to CREATING status: {event.get('id')}") output_dict = deepcopy(event) @@ -63,7 +63,7 @@ def handle_set_server_to_creating(event: Dict[str, Any], context: Any) -> Dict[s return output_dict -def handle_deploy_server(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_deploy_server(event: dict[str, Any], context: Any) -> dict[str, Any]: """Invoke MCP server deployer to create infrastructure.""" logger.info(f"Deploying MCP server: {event.get('id')}") output_dict = deepcopy(event) @@ -125,7 +125,7 @@ def handle_deploy_server(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_poll_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_deployment(event: dict[str, Any], context: Any) -> dict[str, Any]: """Poll CloudFormation stack status.""" logger.info(f"Polling deployment status for stack: {event.get('stack_name')}") output_dict = deepcopy(event) @@ -174,11 +174,11 @@ def handle_poll_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any return output_dict -def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: +def _get_mcp_connections_table_name(deployment_prefix: str) -> str | None: """Get MCP connections table name from SSM parameter if chat is deployed.""" try: response = ssmClient.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") - return response["Parameter"]["Value"] + return response["Parameter"]["Value"] # type: ignore[no-any-return] except ssmClient.exceptions.ParameterNotFound: logger.info("MCP connections table SSM parameter not found, chat may not be deployed") return None @@ -187,11 +187,11 @@ def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: return None -def _get_api_gateway_url(deployment_prefix: str) -> Optional[str]: +def _get_api_gateway_url(deployment_prefix: str) -> str | None: """Get API Gateway base URL from SSM parameter.""" try: response = ssmClient.get_parameter(Name=f"{deployment_prefix}/LisaApiUrl") - return response["Parameter"]["Value"] + return response["Parameter"]["Value"] # type: ignore[no-any-return] except Exception as e: logger.warning(f"Error getting API Gateway URL: {str(e)}") return None @@ -202,7 +202,7 @@ def _normalize_server_identifier(server_id: str) -> str: return re.sub(r"[^a-zA-Z0-9]", "", server_id) -def handle_add_server_to_active(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_add_server_to_active(event: dict[str, Any], context: Any) -> dict[str, Any]: """Set server status to IN_SERVICE after successful deployment.""" logger.info(f"Setting MCP server to IN_SERVICE: {event.get('id')}") output_dict = deepcopy(event) @@ -233,7 +233,7 @@ def handle_add_server_to_active(event: Dict[str, Any], context: Any) -> Dict[str if mcp_connections_table_name: try: api_gateway_url = _get_api_gateway_url(deployment_prefix) - if api_gateway_url: + if api_gateway_url and name: # Normalize server ID to match what CDK uses for resource naming normalized_id = _normalize_server_identifier(name) # Construct API Gateway URL for the hosted server @@ -281,7 +281,7 @@ def handle_add_server_to_active(event: Dict[str, Any], context: Any) -> Dict[str return output_dict -def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_failure(event: dict[str, Any], context: Any) -> dict[str, Any]: """Handle failure in the state machine.""" logger.error(f"Handling MCP server creation failure: {event}") diff --git a/lambda/mcp_server/state_machine/delete_mcp_server.py b/lambda/mcp_server/state_machine/delete_mcp_server.py index 4ac074173..3cfdd0a99 100644 --- a/lambda/mcp_server/state_machine/delete_mcp_server.py +++ b/lambda/mcp_server/state_machine/delete_mcp_server.py @@ -17,7 +17,7 @@ import logging import os from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any from uuid import uuid4 import boto3 @@ -41,11 +41,11 @@ STACK_ARN = "cloudformation_stack_arn" -def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: +def _get_mcp_connections_table_name(deployment_prefix: str) -> str | None: """Get MCP connections table name from SSM parameter if chat is deployed.""" try: response = ssmClient.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") - return response["Parameter"]["Value"] + return response["Parameter"]["Value"] # type: ignore[no-any-return] except ssmClient.exceptions.ParameterNotFound: logger.info("MCP connections table SSM parameter not found, chat may not be deployed") return None @@ -54,7 +54,7 @@ def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: return None -def handle_set_server_to_deleting(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_set_server_to_deleting(event: dict[str, Any], context: Any) -> dict[str, Any]: """Start deletion workflow based on user-specified server input.""" output_dict = deepcopy(event) server_id = event["id"] @@ -88,7 +88,7 @@ def handle_set_server_to_deleting(event: Dict[str, Any], context: Any) -> Dict[s return output_dict -def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Initialize stack deletion.""" output_dict = deepcopy(event) stack_arn = event.get(STACK_ARN) @@ -106,7 +106,7 @@ def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_monitor_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_monitor_delete_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Get stack status while it is being deleted and evaluate if state machine should continue polling.""" output_dict = deepcopy(event) # Prefer ARN if available, fall back to stack name @@ -144,7 +144,7 @@ def handle_monitor_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str return output_dict -def handle_delete_from_ddb(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_from_ddb(event: dict[str, Any], context: Any) -> dict[str, Any]: """Delete item from DDB after successful deletion workflow and remove from connections table.""" server_id = event["id"] server_key = {"id": server_id} diff --git a/lambda/mcp_server/state_machine/update_mcp_server.py b/lambda/mcp_server/state_machine/update_mcp_server.py index 6c77213ab..11a60461f 100644 --- a/lambda/mcp_server/state_machine/update_mcp_server.py +++ b/lambda/mcp_server/state_machine/update_mcp_server.py @@ -17,8 +17,9 @@ import logging import os import re +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional +from typing import Any import boto3 from boto3.dynamodb.conditions import Attr @@ -42,11 +43,11 @@ MAX_POLLS = 30 -def _get_mcp_connections_table_name(deployment_prefix: str) -> Optional[str]: +def _get_mcp_connections_table_name(deployment_prefix: str) -> str | None: """Get MCP connections table name from SSM parameter if chat is deployed.""" try: response = ssm_client.get_parameter(Name=f"{deployment_prefix}/table/mcpServersTable") - return response["Parameter"]["Value"] + return response["Parameter"]["Value"] # type: ignore[no-any-return] except ssm_client.exceptions.ParameterNotFound: logger.info("MCP connections table SSM parameter not found, chat may not be deployed") return None @@ -60,15 +61,15 @@ def _normalize_server_identifier(server_id: str) -> str: return re.sub(r"[^a-zA-Z0-9]", "", server_id) -def _update_simple_field(server_config: Dict[str, Any], field_name: str, value: Any, server_id: str) -> None: +def _update_simple_field(server_config: dict[str, Any], field_name: str, value: Any, server_id: str) -> None: """Update a simple field in server_config.""" logger.info(f"Setting {field_name} to '{value}' for server '{server_id}'") server_config[field_name] = value def _update_container_config( - server_config: Dict[str, Any], container_config: Dict[str, Any], server_id: str -) -> Dict[str, Any]: + server_config: dict[str, Any], container_config: dict[str, Any], server_id: str +) -> dict[str, Any]: """Handle container config update. Returns: @@ -137,7 +138,7 @@ def _update_container_config( return container_metadata -def _get_metadata_update_handlers(server_config: Dict[str, Any], server_id: str) -> Dict[str, Callable[..., Any]]: +def _get_metadata_update_handlers(server_config: dict[str, Any], server_id: str) -> dict[str, Callable[..., Any]]: """Return a dictionary mapping field names to their update handlers.""" return { "description": lambda value: _update_simple_field(server_config, "description", value, server_id), @@ -155,8 +156,8 @@ def _get_metadata_update_handlers(server_config: Dict[str, Any], server_id: str) def _process_metadata_updates( - server_config: Dict[str, Any], update_payload: Dict[str, Any], server_id: str -) -> tuple[bool, Dict[str, Any]]: + server_config: dict[str, Any], update_payload: dict[str, Any], server_id: str +) -> tuple[bool, dict[str, Any]]: """ Process metadata updates. @@ -234,7 +235,7 @@ def _update_mcp_connections_table_status(server_id: str, status: str) -> None: def _update_mcp_connections_table_metadata( - server_id: str, description: Optional[str] = None, groups: Optional[List[str]] = None + server_id: str, description: str | None = None, groups: list[str] | None = None ) -> None: """Update MCP Connections table metadata (description, groups) for a server.""" deployment_prefix = os.environ.get("DEPLOYMENT_PREFIX", "") @@ -252,7 +253,7 @@ def _update_mcp_connections_table_metadata( return # Format groups with "group:" prefix if not already present - formatted_groups: Optional[List[str]] = None + formatted_groups: list[str] | None = None if groups is not None: formatted_groups = [] for group in groups: @@ -268,8 +269,8 @@ def _update_mcp_connections_table_metadata( for item in response.get("Items", []): update_expression_parts = [] - expr_attr_names: Dict[str, str] = {} - expr_attr_values: Dict[str, Any] = {} + expr_attr_names: dict[str, str] = {} + expr_attr_values: dict[str, Any] = {} if description is not None: update_expression_parts.append("#d = :desc") @@ -299,7 +300,7 @@ def _update_mcp_connections_table_metadata( # Don't fail the update if connection table update fails -def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_job_intake(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Handle initial UpdateMcpServer job submission. @@ -484,12 +485,12 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Use updated values if provided, otherwise use current values from server_config if updated_min_capacity is not None: - update_params["MinCapacity"] = updated_min_capacity + update_params["MinCapacity"] = updated_min_capacity # type: ignore[assignment] else: update_params["MinCapacity"] = server_config["autoScalingConfig"].get("minCapacity", 1) if updated_max_capacity is not None: - update_params["MaxCapacity"] = updated_max_capacity + update_params["MaxCapacity"] = updated_max_capacity # type: ignore[assignment] else: update_params["MaxCapacity"] = server_config["autoScalingConfig"].get("maxCapacity", 1) @@ -615,11 +616,11 @@ def get_ecs_resources_from_stack(stack_name: str) -> tuple[str, str, str]: def create_updated_task_definition( task_definition_arn: str, - updated_env_vars: Optional[Dict[str, str]] = None, - env_vars_to_delete: Optional[List[str]] = None, - updated_cpu: Optional[int] = None, - updated_memory: Optional[int] = None, - updated_health_check: Optional[Dict[str, Any]] = None, + updated_env_vars: dict[str, str] | None = None, + env_vars_to_delete: list[str] | None = None, + updated_cpu: int | None = None, + updated_memory: int | None = None, + updated_health_check: dict[str, Any] | None = None, ) -> str: """Create new task definition revision with updated configuration. @@ -741,7 +742,7 @@ def update_ecs_service(cluster_arn: str, service_arn: str, task_definition_arn: raise RuntimeError(f"Failed to update ECS service: {str(e)}") -def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Update ECS task definition with new environment variables and update service. @@ -807,7 +808,7 @@ def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Monitor ECS service deployment progress. @@ -901,7 +902,7 @@ def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, return output_dict -def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Poll ECS service to confirm if the capacity is done updating. @@ -948,7 +949,7 @@ def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Finalize update in DDB. @@ -964,7 +965,7 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: stack_name = event["stack_name"] ddb_update_expression = "SET #status = :ms, last_modified = :lm" - ddb_update_values: Dict[str, Any] = { + ddb_update_values: dict[str, Any] = { ":lm": now(), } ExpressionAttributeNames = {"#status": "status"} diff --git a/lambda/mcp_workbench/lambda_functions.py b/lambda/mcp_workbench/lambda_functions.py index 1efca3966..4d6ebf3ad 100644 --- a/lambda/mcp_workbench/lambda_functions.py +++ b/lambda/mcp_workbench/lambda_functions.py @@ -18,7 +18,7 @@ import os import uuid from decimal import Decimal -from typing import Any, Dict, Optional +from typing import Any import boto3 import botocore.exceptions @@ -49,7 +49,7 @@ class MCPToolModel(BaseModel): contents: str # Timestamp of when the tool was created/updated - updated_at: Optional[str] = Field(default_factory=iso_string) + updated_at: str | None = Field(default_factory=iso_string) @property def s3_key(self) -> str: @@ -113,7 +113,7 @@ def read(event: dict, context: dict) -> Any: @api_wrapper -def list(event: dict, context: dict) -> Dict[str, Any]: +def list(event: dict, context: dict) -> dict[str, Any]: """List all tools from S3.""" if not is_admin(event): raise ValueError("Only admin users can access tools.") @@ -227,7 +227,7 @@ def update(event: dict, context: dict) -> Any: @api_wrapper -def delete(event: dict, context: dict) -> Dict[str, str]: +def delete(event: dict, context: dict) -> dict[str, str]: """Delete a tool from S3.""" if not is_admin(event): raise ValueError("Only admin users can access tools.") @@ -260,7 +260,7 @@ def delete(event: dict, context: dict) -> Dict[str, str]: @api_wrapper -def validate_syntax(event: dict, context: dict) -> Dict[str, Any]: +def validate_syntax(event: dict, context: dict) -> dict[str, Any]: """Validate Python code syntax without execution.""" if not is_admin(event): raise ValueError("Only admin users can validate code syntax.") diff --git a/lambda/mcp_workbench/mcp_mocks.py b/lambda/mcp_workbench/mcp_mocks.py index a500d5756..2f0f040b2 100644 --- a/lambda/mcp_workbench/mcp_mocks.py +++ b/lambda/mcp_workbench/mcp_mocks.py @@ -21,8 +21,9 @@ """ from abc import ABC, abstractmethod +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any class BaseTool(ABC): diff --git a/lambda/mcp_workbench/s3_event_handler.py b/lambda/mcp_workbench/s3_event_handler.py index 3be8a5dc6..f72a2f189 100644 --- a/lambda/mcp_workbench/s3_event_handler.py +++ b/lambda/mcp_workbench/s3_event_handler.py @@ -17,7 +17,7 @@ import json import logging import os -from typing import Any, Dict +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -32,7 +32,7 @@ ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) -def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Handle S3 events from EventBridge and trigger MCP Workbench service redeployment. @@ -144,7 +144,7 @@ def get_service_name() -> str: raise -def force_service_deployment(cluster_name: str, service_name: str) -> Dict[str, Any]: +def force_service_deployment(cluster_name: str, service_name: str) -> dict[str, Any]: """ Force a new deployment of the specified ECS service. """ @@ -155,7 +155,7 @@ def force_service_deployment(cluster_name: str, service_name: str) -> Dict[str, response = ecs_client.update_service(cluster=cluster_name, service=service_name, forceNewDeployment=True) logger.info(f"Successfully triggered new deployment for service '{service_name}'") - return response + return dict(response) # Convert to dict to satisfy return type except ClientError as e: error_code = e.response.get("Error", {}).get("Code", "Unknown") @@ -174,7 +174,7 @@ def force_service_deployment(cluster_name: str, service_name: str) -> Dict[str, raise -def validate_s3_event(event: Dict[str, Any]) -> bool: +def validate_s3_event(event: dict[str, Any]) -> bool: """ Validate that the event is a proper S3 event from EventBridge. """ diff --git a/lambda/mcp_workbench/syntax_validator.py b/lambda/mcp_workbench/syntax_validator.py index 46fd5362e..7ee227844 100644 --- a/lambda/mcp_workbench/syntax_validator.py +++ b/lambda/mcp_workbench/syntax_validator.py @@ -20,7 +20,7 @@ import sys from dataclasses import dataclass from types import ModuleType -from typing import Any, Dict, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -30,8 +30,8 @@ class ValidationResult: """Result of Python code validation.""" is_valid: bool - syntax_errors: List[Dict[str, Any]] - missing_required_imports: Optional[List[str]] = None + syntax_errors: list[dict[str, Any]] + missing_required_imports: list[str] | None = None def __post_init__(self) -> None: """Initialize list fields if None.""" @@ -118,7 +118,7 @@ def validate_code(self, code: str) -> ValidationResult: is_valid=is_valid, syntax_errors=syntax_errors, missing_required_imports=missing_required_imports ) - def _validate_module_execution(self, code: str) -> List[Dict[str, Any]]: + def _validate_module_execution(self, code: str) -> list[dict[str, Any]]: """Validate code by attempting to execute it as a module.""" errors = [] @@ -222,7 +222,7 @@ def _setup_mcp_environment(self, module: Any) -> None: else: logger.info("Real MCP Workbench package is already available in sys.modules") - def _check_required_mcp_imports(self, tree: ast.AST) -> List[str]: + def _check_required_mcp_imports(self, tree: ast.AST) -> list[str]: """Check if required MCP imports are present in the AST.""" missing_required = [] @@ -248,9 +248,9 @@ def _check_required_mcp_imports(self, tree: ast.AST) -> List[str]: return missing_required - def _collect_imports(self, tree: ast.AST) -> Dict[str, Any]: + def _collect_imports(self, tree: ast.AST) -> dict[str, Any]: """Collect all import statements from the AST.""" - imports: Dict[str, Any] = { + imports: dict[str, Any] = { "modules": set(), # Direct module imports: import os "from_imports": {}, # From imports: from os import path -> {'os': {'path'}} "aliases": {}, # Import aliases: import numpy as np -> {'np': 'numpy'} @@ -290,7 +290,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: visitor.visit(tree) return imports - def _format_syntax_error(self, syntax_error: SyntaxError) -> Dict[str, Any]: + def _format_syntax_error(self, syntax_error: SyntaxError) -> dict[str, Any]: """Format a SyntaxError into a standardized error dictionary.""" return { "type": "SyntaxError", diff --git a/lambda/metrics/lambda_functions.py b/lambda/metrics/lambda_functions.py index 03efb3e69..5d21e1bb4 100644 --- a/lambda/metrics/lambda_functions.py +++ b/lambda/metrics/lambda_functions.py @@ -16,7 +16,7 @@ import json import logging import os -from typing import Any, Dict, List +from typing import Any import boto3 import create_env_variables # noqa: F401 @@ -75,14 +75,14 @@ def get_user_metrics_all(event: dict, context: dict) -> dict: total_mcp_tool_calls = sum(item.get("mcpToolCallsCount", 0) for item in items) # Collect all unique user groups - all_user_groups: Dict[str, int] = {} + all_user_groups: dict[str, int] = {} for item in items: if item.get("userGroups"): for group in item["userGroups"]: all_user_groups[group] = all_user_groups.get(group, 0) + 1 # Collect all MCP tool usage across users - all_mcp_tool_usage: Dict[str, int] = {} + all_mcp_tool_usage: dict[str, int] = {} for item in items: if item.get("mcpToolUsage"): for tool_name, count in item["mcpToolUsage"].items(): @@ -126,14 +126,14 @@ def count_unique_users_and_publish_metric() -> Any: raise -def count_users_by_group_and_publish_metric() -> Dict[str, int]: +def count_users_by_group_and_publish_metric() -> dict[str, int]: """Count users in each group and publish metrics to CloudWatch.""" try: # Scan the table to get users with groups response = usage_metrics_table.scan(ProjectionExpression="userGroups") # Count users in each group - group_counts: Dict[str, int] = {} + group_counts: dict[str, int] = {} for item in response.get("Items", []): if "userGroups" in item: for group in item["userGroups"]: @@ -234,7 +234,7 @@ def process_metrics_sqs_event(event: dict, context: dict) -> None: logger.error(f"Error processing SQS message: {str(e)}") -def count_rag_usage(messages: List[Dict[str, Any]]) -> int: +def count_rag_usage(messages: list[dict[str, Any]]) -> int: """Count occurrences of 'File context:' in all human messages to determine RAG usage. Parameters: @@ -290,7 +290,7 @@ def count_rag_usage(messages: List[Dict[str, Any]]) -> int: return file_context_count -def calculate_session_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]: +def calculate_session_metrics(messages: list[dict[str, Any]]) -> dict[str, Any]: """Calculate metrics for a complete session. Parameters: @@ -307,7 +307,7 @@ def calculate_session_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]: return {"totalPrompts": 0, "ragUsage": 0, "mcpToolCallsCount": 0, "mcpToolUsage": {}} total_prompts = 0 - mcp_tool_usage: Dict[str, int] = {} + mcp_tool_usage: dict[str, int] = {} # Count human messages for total prompts for message in messages: @@ -356,8 +356,8 @@ def publish_metric_deltas( delta_prompts: int, delta_rag: int, delta_mcp_calls: int, - delta_mcp_usage: Dict[str, int], - user_groups: List[str], + delta_mcp_usage: dict[str, int], + user_groups: list[str], ) -> None: """Publish only metric deltas to CloudWatch to prevent double counting. @@ -486,7 +486,7 @@ def publish_metric_deltas( def update_user_metrics_by_session( - user_id: str, session_id: str, session_metrics: Dict[str, Any], user_groups: List[str] + user_id: str, session_id: str, session_metrics: dict[str, Any], user_groups: list[str] ) -> None: """Update usage metrics for a given user based on session-level metrics. @@ -562,7 +562,7 @@ def update_user_metrics_by_session( total_mcp_calls = sum(sm.get("mcpToolCallsCount", 0) for sm in all_session_metrics.values()) # Aggregate MCP tool usage across all sessions - aggregate_mcp_usage: Dict[str, int] = {} + aggregate_mcp_usage: dict[str, int] = {} for sm in all_session_metrics.values(): for tool_name, count in sm.get("mcpToolUsage", {}).items(): aggregate_mcp_usage[tool_name] = aggregate_mcp_usage.get(tool_name, 0) + count diff --git a/lambda/models/clients/litellm_client.py b/lambda/models/clients/litellm_client.py index 11fc6546c..07afea04a 100644 --- a/lambda/models/clients/litellm_client.py +++ b/lambda/models/clients/litellm_client.py @@ -14,9 +14,9 @@ """Client for interfacing with the LiteLLM proxy's management options directly.""" -from typing import Any, Dict, List, Union +from typing import Any -import requests +import requests # type: ignore[import-untyped,unused-ignore] from starlette.datastructures import Headers from ..exception import ModelNotFoundError @@ -25,13 +25,13 @@ class LiteLLMClient: """Client definition for interfacing directly with LiteLLM management operations.""" - def __init__(self, base_uri: str, headers: Headers, verify: Union[str, bool], timeout: int = 30): + def __init__(self, base_uri: str, headers: Headers, verify: str | bool, timeout: int = 30): self._base_uri = base_uri self._headers = headers self._timeout = timeout self._verify = verify - def list_models(self) -> List[Dict[str, Any]]: + def list_models(self) -> list[dict[str, Any]]: """ Retrieve all models from the database. @@ -46,10 +46,10 @@ def list_models(self) -> List[Dict[str, Any]]: verify=self._verify, ) all_models = resp.json() - models_list: List[Dict[str, Any]] = all_models["data"] + models_list: list[dict[str, Any]] = all_models["data"] return models_list - def add_model(self, model_name: str, litellm_params: Dict[str, str]) -> Dict[str, Any]: + def add_model(self, model_name: str, litellm_params: dict[str, str]) -> dict[str, Any]: """ Add a new model configuration to the database. @@ -86,7 +86,7 @@ def delete_model(self, identifier: str) -> None: verify=self._verify, ) - def get_model(self, identifier: str) -> Dict[str, Any]: + def get_model(self, identifier: str) -> dict[str, Any]: """ Get model metadata from the database. @@ -99,7 +99,7 @@ def get_model(self, identifier: str) -> Dict[str, Any]: raise ModelNotFoundError("Specified model was not found.") return filtered_models[0] - def create_guardrail(self, guardrail_config: Dict[str, Any]) -> Dict[str, Any]: + def create_guardrail(self, guardrail_config: dict[str, Any]) -> dict[str, Any]: """ Create a new guardrail configuration in LiteLLM. @@ -120,7 +120,7 @@ def create_guardrail(self, guardrail_config: Dict[str, Any]) -> Dict[str, Any]: resp.raise_for_status() return resp.json() # type: ignore [no-any-return] - def update_guardrail(self, guardrail_id: str, guardrail_config: Dict[str, Any]) -> Dict[str, Any]: + def update_guardrail(self, guardrail_id: str, guardrail_config: dict[str, Any]) -> dict[str, Any]: """ Update an existing guardrail configuration in LiteLLM. @@ -156,7 +156,7 @@ def delete_guardrail(self, guardrail_id: str) -> None: ) resp.raise_for_status() - def get_guardrail_info(self, guardrail_id: str) -> Dict[str, Any]: + def get_guardrail_info(self, guardrail_id: str) -> dict[str, Any]: """ Get information about a specific guardrail. @@ -175,7 +175,7 @@ def get_guardrail_info(self, guardrail_id: str) -> Dict[str, Any]: resp.raise_for_status() return resp.json() # type: ignore [no-any-return] - def apply_guardrail(self, guardrail_name: str, text: str) -> Dict[str, Any]: + def apply_guardrail(self, guardrail_name: str, text: str) -> dict[str, Any]: """ Apply a guardrail to text content for validation. diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index c8459a7e7..1c0c14bac 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -21,17 +21,18 @@ import re import urllib.parse import uuid +from collections.abc import Generator from dataclasses import dataclass from datetime import datetime, timedelta from enum import auto, Enum, StrEnum -from typing import Annotated, Any, Dict, Generator, List, Literal, Optional, TypeAlias, Union +from typing import Annotated, Any, Literal, Self, TypeAlias, Union from uuid import uuid4 from zoneinfo import ZoneInfo from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt from pydantic.functional_validators import AfterValidator, field_validator, model_validator -from typing_extensions import Self from utilities.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, MIN_PAGE_SIZE +from utilities.healthcheck_validator import validate_healthcheck_command from utilities.time import now, utc_now from utilities.validation import ( validate_all_fields_defined, @@ -69,6 +70,7 @@ class ModelType(StrEnum): TEXTGEN = auto() IMAGEGEN = auto() + VIDEOGEN = auto() EMBEDDING = auto() @@ -87,13 +89,13 @@ class GuardrailConfig(BaseModel): guardrailIdentifier: str = Field(min_length=1) guardrailVersion: str = Field(default="DRAFT") mode: GuardrailMode = Field(default=GuardrailMode.PRE_CALL) - description: Optional[str] = None - allowedGroups: List[str] = Field(default_factory=list) - markedForDeletion: Optional[bool] = Field(default=False) + description: str | None = None + allowedGroups: list[str] = Field(default_factory=list) + markedForDeletion: bool | None = Field(default=False) # Type alias for guardrails configuration - maps guardrail IDs to their configs -GuardrailsConfig: TypeAlias = Dict[str, GuardrailConfig] +GuardrailsConfig: TypeAlias = dict[str, GuardrailConfig] class GuardrailRequest(BaseModel): @@ -121,8 +123,8 @@ class GuardrailsTableEntry(BaseModel): guardrailIdentifier: str guardrailVersion: str mode: str - description: Optional[str] - allowedGroups: List[str] + description: str | None + allowedGroups: list[str] createdDate: int = Field(default_factory=lambda: now()) lastModifiedDate: int = Field(default_factory=lambda: now()) @@ -213,13 +215,13 @@ def validate_stop_after_start(self) -> Self: class WeeklySchedule(BaseModel): """Defines schedule for each day of the week with one start/stop time per day""" - monday: Optional[DaySchedule] = None - tuesday: Optional[DaySchedule] = None - wednesday: Optional[DaySchedule] = None - thursday: Optional[DaySchedule] = None - friday: Optional[DaySchedule] = None - saturday: Optional[DaySchedule] = None - sunday: Optional[DaySchedule] = None + monday: DaySchedule | None = None + tuesday: DaySchedule | None = None + wednesday: DaySchedule | None = None + thursday: DaySchedule | None = None + friday: DaySchedule | None = None + saturday: DaySchedule | None = None + sunday: DaySchedule | None = None @model_validator(mode="after") def validate_daily_schedules(self) -> Self: @@ -254,18 +256,18 @@ class BaseSchedulingConfig(BaseModel): # Schedule metadata and tracking scheduleEnabled: bool = False - lastScheduleUpdate: Optional[str] = None - scheduledActionArns: Optional[List[str]] = None + lastScheduleUpdate: str | None = None + scheduledActionArns: list[str] | None = None # Status tracking scheduleConfigured: bool = False lastScheduleFailed: bool = False # Next scheduled action info (computed field) - nextScheduledAction: Optional[NextScheduledAction] = None + nextScheduledAction: NextScheduledAction | None = None # Failure tracking - lastScheduleFailure: Optional[ScheduleFailure] = None + lastScheduleFailure: ScheduleFailure | None = None @field_validator("timezone") @classmethod @@ -317,14 +319,14 @@ def validate_recurring_schedule_exclusivity(self) -> Self: class AutoScalingConfig(BaseModel): """Specifies auto-scaling parameters for model deployment.""" - blockDeviceVolumeSize: Optional[NonNegativeInt] = 50 + blockDeviceVolumeSize: NonNegativeInt | None = 50 minCapacity: PositiveInt maxCapacity: PositiveInt - desiredCapacity: Optional[PositiveInt] = None + desiredCapacity: PositiveInt | None = None cooldown: PositiveInt defaultInstanceWarmup: PositiveInt metricConfig: MetricConfig - scheduling: Optional[SchedulingConfig] = None + scheduling: SchedulingConfig | None = None @model_validator(mode="after") def validate_auto_scaling_config(self) -> Self: @@ -343,11 +345,11 @@ def validate_auto_scaling_config(self) -> Self: class AutoScalingInstanceConfig(BaseModel): """Defines instance count parameters for auto-scaling updates.""" - minCapacity: Optional[PositiveInt] = None - maxCapacity: Optional[PositiveInt] = None - desiredCapacity: Optional[PositiveInt] = None - cooldown: Optional[PositiveInt] = None - defaultInstanceWarmup: Optional[PositiveInt] = None + minCapacity: PositiveInt | None = None + maxCapacity: PositiveInt | None = None + desiredCapacity: PositiveInt | None = None + cooldown: PositiveInt | None = None + defaultInstanceWarmup: PositiveInt | None = None @model_validator(mode="after") def validate_auto_scaling_instance_config(self) -> Self: @@ -371,12 +373,19 @@ def validate_auto_scaling_instance_config(self) -> Self: class ContainerHealthCheckConfig(BaseModel): """Specifies container health check parameters.""" - command: Union[str, List[str]] + command: str | list[str] interval: PositiveInt startPeriod: PositiveInt timeout: PositiveInt retries: PositiveInt + @field_validator("command") + @classmethod + def validate_command(cls, command: str | list[str]) -> str | list[str]: + """Validates healthcheck command format for ECS compatibility.""" + validate_healthcheck_command(command) + return command + class ContainerConfigImage(BaseModel): """Defines container image configuration.""" @@ -391,14 +400,14 @@ class ContainerConfig(BaseModel): image: ContainerConfigImage sharedMemorySize: PositiveInt healthCheckConfig: ContainerHealthCheckConfig - environment: Optional[Dict[str, str]] = {} + environment: dict[str, str] | None = {} @field_validator("environment") @classmethod - def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]: + def validate_environment(cls, environment: dict[str, str]) -> dict[str, str]: """Validates environment variable key names.""" if environment: - if not all((key for key in environment.keys())): + if not all(key for key in environment.keys()): raise ValueError("Empty strings are not allowed for environment variable key names.") return environment @@ -406,20 +415,20 @@ def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]: class ContainerConfigUpdatable(BaseModel): """Specifies container configuration fields that can be updated.""" - environment: Optional[Dict[str, str]] = None - sharedMemorySize: Optional[PositiveInt] = None - healthCheckCommand: Optional[Union[str, List[str]]] = None - healthCheckInterval: Optional[PositiveInt] = None - healthCheckTimeout: Optional[PositiveInt] = None - healthCheckStartPeriod: Optional[PositiveInt] = None - healthCheckRetries: Optional[PositiveInt] = None + environment: dict[str, str] | None = None + sharedMemorySize: PositiveInt | None = None + healthCheckCommand: str | list[str] | None = None + healthCheckInterval: PositiveInt | None = None + healthCheckTimeout: PositiveInt | None = None + healthCheckStartPeriod: PositiveInt | None = None + healthCheckRetries: PositiveInt | None = None @field_validator("environment") @classmethod - def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]: + def validate_environment(cls, environment: dict[str, str]) -> dict[str, str]: """Validates environment variable key names.""" if environment: - if not all((key for key in environment.keys())): + if not all(key for key in environment.keys()): raise ValueError("Empty strings are not allowed for environment variable key names.") return environment @@ -427,7 +436,7 @@ def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]: class ModelFeature(BaseModel): """Defines model feature attributes.""" - __exceptions: List[Any] = [] + __exceptions: list[Any] = [] name: str overview: str @@ -438,21 +447,21 @@ def __init__(self, **kwargs: Any) -> None: class LISAModel(BaseModel): """Defines core model attributes and configuration.""" - autoScalingConfig: Optional[AutoScalingConfig] = None - containerConfig: Optional[ContainerConfig] = None - inferenceContainer: Optional[InferenceContainer] = None - instanceType: Optional[Annotated[str, AfterValidator(validate_instance_type)]] = None - loadBalancerConfig: Optional[LoadBalancerConfig] = None + autoScalingConfig: AutoScalingConfig | None = None + containerConfig: ContainerConfig | None = None + inferenceContainer: InferenceContainer | None = None + instanceType: Annotated[str, AfterValidator(validate_instance_type)] | None = None + loadBalancerConfig: LoadBalancerConfig | None = None modelId: str modelName: str - modelDescription: Optional[str] = None + modelDescription: str | None = None modelType: ModelType - modelUrl: Optional[str] = None + modelUrl: str | None = None status: ModelStatus streaming: bool - features: Optional[List[ModelFeature]] = None - allowedGroups: Optional[List[str]] = None - guardrailsConfig: Optional[GuardrailsConfig] = None + features: list[ModelFeature] | None = None + allowedGroups: list[str] | None = None + guardrailsConfig: GuardrailsConfig | None = None class ApiResponseBase(BaseModel): @@ -464,21 +473,21 @@ class ApiResponseBase(BaseModel): class CreateModelRequest(BaseModel): """Specifies parameters for model creation requests.""" - autoScalingConfig: Optional[AutoScalingConfig] = None - containerConfig: Optional[ContainerConfig] = None - inferenceContainer: Optional[InferenceContainer] = None - instanceType: Optional[Annotated[str, AfterValidator(validate_instance_type)]] = None - loadBalancerConfig: Optional[LoadBalancerConfig] = None + autoScalingConfig: AutoScalingConfig | None = None + containerConfig: ContainerConfig | None = None + inferenceContainer: InferenceContainer | None = None + instanceType: Annotated[str, AfterValidator(validate_instance_type)] | None = None + loadBalancerConfig: LoadBalancerConfig | None = None modelId: str = Field(min_length=1) modelName: str = Field(min_length=1) - modelDescription: Optional[str] = None + modelDescription: str | None = None modelType: ModelType - modelUrl: Optional[str] = None - streaming: Optional[bool] = False - features: Optional[List[ModelFeature]] = None - allowedGroups: Optional[List[str]] = None - apiKey: Optional[str] = None - guardrailsConfig: Optional[GuardrailsConfig] = None + modelUrl: str | None = None + streaming: bool | None = False + features: list[ModelFeature] | None = None + allowedGroups: list[str] | None = None + apiKey: str | None = None + guardrailsConfig: GuardrailsConfig | None = None @model_validator(mode="after") def validate_create_model_request(self) -> Self: @@ -512,7 +521,7 @@ class CreateModelResponse(ApiResponseBase): class ListModelsResponse(BaseModel): """Defines response structure for model listing.""" - models: List[LISAModel] + models: list[LISAModel] class GetModelResponse(ApiResponseBase): @@ -524,15 +533,15 @@ class GetModelResponse(ApiResponseBase): class UpdateModelRequest(BaseModel): """Specifies parameters for model update requests.""" - autoScalingInstanceConfig: Optional[AutoScalingInstanceConfig] = None - enabled: Optional[bool] = None - modelType: Optional[ModelType] = None - modelDescription: Optional[str] = None - streaming: Optional[bool] = None - allowedGroups: Optional[List[str]] = None - features: Optional[List[ModelFeature]] = None - containerConfig: Optional[ContainerConfigUpdatable] = None - guardrailsConfig: Optional[GuardrailsConfig] = None + autoScalingInstanceConfig: AutoScalingInstanceConfig | None = None + enabled: bool | None = None + modelType: ModelType | None = None + modelDescription: str | None = None + streaming: bool | None = None + allowedGroups: list[str] | None = None + features: list[ModelFeature] | None = None + containerConfig: ContainerConfigUpdatable | None = None + guardrailsConfig: GuardrailsConfig | None = None @model_validator(mode="after") def validate_update_model_request(self) -> Self: @@ -600,8 +609,8 @@ class GetScheduleResponse(BaseModel): """Response object for getting schedule configuration.""" modelId: str - scheduling: Dict[str, Any] - nextScheduledAction: Optional[Dict[str, str]] = None + scheduling: dict[str, Any] + nextScheduledAction: dict[str, str] | None = None class DeleteScheduleResponse(BaseModel): @@ -620,11 +629,11 @@ class GetScheduleStatusResponse(BaseModel): scheduleConfigured: bool lastScheduleFailed: bool scheduleStatus: str - scheduleType: Optional[str] = None + scheduleType: str | None = None timezone: str - nextScheduledAction: Optional[Dict[str, str]] = None - lastScheduleUpdate: Optional[str] = None - lastScheduleFailure: Optional[Dict[str, Any]] = None + nextScheduledAction: dict[str, str] | None = None + lastScheduleUpdate: str | None = None + lastScheduleFailure: dict[str, Any] | None = None class IngestionType(StrEnum): @@ -645,7 +654,7 @@ class JobActionType(StrEnum): COLLECTION_DELETION = auto() -RagDocumentDict = Dict[str, Any] +RagDocumentDict = dict[str, Any] class ChunkingStrategyType(StrEnum): @@ -715,9 +724,9 @@ class RagSubDocument(BaseModel): """Represents a sub-document entity for DynamoDB storage.""" document_id: str - subdocs: List[str] = Field(default_factory=lambda: []) - index: Optional[int] = Field(default=None) - sk: Optional[str] = None + subdocs: list[str] = Field(default_factory=lambda: []) + index: int | None = Field(default=None) + sk: str | None = None def __init__(self, **data: Any) -> None: super().__init__(**data) @@ -727,18 +736,18 @@ def __init__(self, **data: Any) -> None: class RagDocument(BaseModel): """Represents a RAG document entity for DynamoDB storage.""" - pk: Optional[str] = None + pk: str | None = None document_id: str = Field(default_factory=lambda: str(uuid.uuid4())) repository_id: str = Field(min_length=3, max_length=20) collection_id: str document_name: str source: str username: str - subdocs: List[str] = Field(default_factory=lambda: [], exclude=True) + subdocs: list[str] = Field(default_factory=lambda: [], exclude=True) chunk_strategy: ChunkingStrategy ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL) upload_date: int = Field(default_factory=lambda: now()) - chunks: Optional[int] = 0 + chunks: int | None = 0 model_config = ConfigDict(use_enum_values=True, validate_default=True) def __init__(self, **data: Any) -> None: @@ -753,7 +762,7 @@ def createPartitionKey(repository_id: str, collection_id: str) -> str: """Generates a partition key from repository and collection IDs.""" return f"{repository_id}#{collection_id}" - def chunk_doc(self, chunk_size: int = 1000) -> Generator[RagSubDocument, None, None]: + def chunk_doc(self, chunk_size: int = 1000) -> Generator[RagSubDocument]: """Segments document into smaller sub-documents.""" total_subdocs = len(self.subdocs) for start_index in range(0, total_subdocs, chunk_size): @@ -763,16 +772,16 @@ def chunk_doc(self, chunk_size: int = 1000) -> Generator[RagSubDocument, None, N ) @staticmethod - def join_docs(documents: List[RagDocumentDict]) -> List[RagDocumentDict]: + def join_docs(documents: list[RagDocumentDict]) -> list[RagDocumentDict]: """Combines multiple sub-documents into a single document.""" - grouped_docs: dict[str, List[RagDocumentDict]] = {} + grouped_docs: dict[str, list[RagDocumentDict]] = {} for doc in documents: doc_id = doc.get("document_id", "") if doc_id not in grouped_docs: grouped_docs[doc_id] = [] grouped_docs[doc_id].append(doc) - joined_docs: List[RagDocumentDict] = [] + joined_docs: list[RagDocumentDict] = [] for docs in grouped_docs.values(): joined_doc = docs[0] joined_doc["subdocs"] = [sub_doc for doc in docs for sub_doc in (doc.get("subdocs", []) or [])] @@ -786,29 +795,29 @@ class IngestionJob(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) s3_path: str - collection_id: Optional[str] = Field( + collection_id: str | None = Field( default=None, description="Collection ID for full deletion, None for default collection deletion" ) - document_id: Optional[str] = Field(default=None) + document_id: str | None = Field(default=None) repository_id: str - chunk_strategy: Optional[ChunkingStrategy] = Field(default=None) - embedding_model: Optional[str] = Field( + chunk_strategy: ChunkingStrategy | None = Field(default=None) + embedding_model: str | None = Field( default=None, description="Embedding model name, used as index identifier for default collections" ) - username: Optional[str] = Field(default=None) + username: str | None = Field(default=None) ingestion_type: IngestionType = Field( default=IngestionType.MANUAL, description="How the document was ingested (MANUAL, AUTO, or EXISTING)" ) status: IngestionStatus = IngestionStatus.INGESTION_PENDING created_date: str = Field(default_factory=lambda: utc_now().isoformat()) - error_message: Optional[str] = Field(default=None) - document_name: Optional[str] = Field(default=None) - auto: Optional[bool] = Field(default=None) - metadata: Optional[dict] = Field(default=None) - job_type: Optional[JobActionType] = Field(default=None, description="Type of deletion job") + error_message: str | None = Field(default=None) + document_name: str | None = Field(default=None) + auto: bool | None = Field(default=None) + metadata: dict | None = Field(default=None) + job_type: JobActionType | None = Field(default=None, description="Type of deletion job") collection_deletion: bool = Field(default=False, description="Indicates this is a collection deletion job") - s3_paths: Optional[List[str]] = Field(default=None, description="List of S3 paths for batch ingestion operations") - document_ids: Optional[List[str]] = Field( + s3_paths: list[str] | None = Field(default=None, description="List of S3 paths for batch ingestion operations") + document_ids: list[str] | None = Field( default=None, description="List of document IDs from completed batch operations" ) @@ -843,7 +852,7 @@ def validate_collection_deletion_identifiers(self) -> Self: class PaginatedResponse(BaseModel): """Base class for paginated API responses.""" - lastEvaluatedKey: Optional[Dict[str, str]] = None + lastEvaluatedKey: dict[str, str] | None = None hasNextPage: bool = False hasPreviousPage: bool = False @@ -851,7 +860,7 @@ class PaginatedResponse(BaseModel): class ListJobsResponse(PaginatedResponse): """Response structure for listing ingestion jobs with pagination.""" - jobs: List[IngestionJob] + jobs: list[IngestionJob] @dataclass @@ -862,9 +871,7 @@ class PaginationResult: has_previous_page: bool @classmethod - def from_keys( - cls, original_key: Optional[Dict[str, str]], returned_key: Optional[Dict[str, str]] - ) -> "PaginationResult": + def from_keys(cls, original_key: dict[str, str] | None, returned_key: dict[str, str] | None) -> PaginationResult: """Create pagination result from keys.""" return cls(has_next_page=returned_key is not None, has_previous_page=original_key is not None) @@ -874,18 +881,18 @@ class PaginationParams: """Shared pagination parameter handling.""" page_size: int = DEFAULT_PAGE_SIZE - last_evaluated_key: Optional[Dict[str, str]] = None + last_evaluated_key: dict[str, str] | None = None @staticmethod def parse_page_size( - query_params: Dict[str, str], default: int = DEFAULT_PAGE_SIZE, max_size: int = MAX_PAGE_SIZE + query_params: dict[str, str], default: int = DEFAULT_PAGE_SIZE, max_size: int = MAX_PAGE_SIZE ) -> int: """Parse and validate page size with configurable limits.""" page_size = int(query_params.get("pageSize", str(default))) return max(MIN_PAGE_SIZE, min(page_size, max_size)) @staticmethod - def parse_last_evaluated_key(query_params: Dict[str, str], key_fields: List[str]) -> Optional[Dict[str, str]]: + def parse_last_evaluated_key(query_params: dict[str, str], key_fields: list[str]) -> dict[str, str] | None: """Parse last evaluated key from query parameters. Args: @@ -918,7 +925,7 @@ def parse_last_evaluated_key(query_params: Dict[str, str], key_fields: List[str] return last_evaluated_key if last_evaluated_key else None @staticmethod - def parse_last_evaluated_key_v2(query_params: Dict[str, str]) -> Optional[Dict[str, Any]]: + def parse_last_evaluated_key_v2(query_params: dict[str, str]) -> dict[str, Any] | None: """Parse v2 pagination token from query parameters. The v2 token format supports scalable pagination with per-repository cursors. @@ -967,11 +974,11 @@ def parse_last_evaluated_key_v2(query_params: Dict[str, str]) -> Optional[Dict[s class FilterParams: """Shared filtering parameter handling for collections.""" - filter_text: Optional[str] = None - status_filter: Optional[CollectionStatus] = None + filter_text: str | None = None + status_filter: CollectionStatus | None = None @staticmethod - def from_query_params(query_params: Dict[str, str]) -> FilterParams: + def from_query_params(query_params: dict[str, str]) -> FilterParams: """Parse filter parameters from query string parameters. Args: @@ -999,11 +1006,11 @@ def from_query_params(query_params: Dict[str, str]) -> FilterParams: class SortParams: """Shared sorting parameter handling for collections.""" - sort_by: CollectionSortBy = None # Will be set to default in from_query_params - sort_order: SortOrder = None # Will be set to default in from_query_params + sort_by: CollectionSortBy | None = None # Will be set to default in from_query_params + sort_order: SortOrder | None = None # Will be set to default in from_query_params @staticmethod - def from_query_params(query_params: Dict[str, str]) -> SortParams: + def from_query_params(query_params: dict[str, str]) -> SortParams: """Parse sort parameters from query string parameters. Args: @@ -1071,24 +1078,24 @@ class PipelineConfig(BaseModel): """Defines pipeline configuration for automated document ingestion.""" autoRemove: bool = Field(default=True, description="Automatically remove documents after ingestion") - chunkOverlap: Optional[int] = Field( + chunkOverlap: int | None = Field( default=None, ge=0, description="Chunk overlap for pipeline ingestion (deprecated, use chunkingStrategy)" ) - chunkSize: Optional[int] = Field( + chunkSize: int | None = Field( default=None, ge=0, description="Chunk size for pipeline ingestion (deprecated, use chunkingStrategy)", ) - chunkingStrategy: Optional[ChunkingStrategy] = Field( + chunkingStrategy: ChunkingStrategy | None = Field( default=None, description="Chunking strategy for documents in this pipeline" ) - collectionId: Optional[str] = Field( + collectionId: str | None = Field( default=None, description="Collection ID for this pipeline (for Bedrock KB, this is the data source ID)" ) s3Bucket: str = Field(min_length=1, description="S3 bucket for pipeline source") s3Prefix: str = Field(description="S3 prefix for pipeline source") trigger: PipelineTrigger = Field(description="Pipeline trigger type") - metadata: Optional[CollectionMetadata] = Field( + metadata: CollectionMetadata | None = Field( default_factory=lambda: CollectionMetadata(tags=[]), description="Metadata for the pipeline including tags" ) @@ -1110,6 +1117,9 @@ def validate_chunking_config(self) -> Self: # If legacy fields provided but no chunkingStrategy, create one if has_legacy and not has_new: + # At this point we know both are not None due to has_legacy check + if self.chunkSize is None or self.chunkOverlap is None: + raise ValueError("chunkSize and chunkOverlap must both be set") self.chunkingStrategy = FixedChunkingStrategy( type=ChunkingStrategyType.FIXED, size=self.chunkSize, overlap=self.chunkOverlap ) @@ -1120,12 +1130,12 @@ def validate_chunking_config(self) -> Self: class CollectionMetadata(BaseModel): """Defines metadata for a collection.""" - tags: List[str] = Field(default_factory=list, max_length=50, description="Metadata tags for the collection") - customFields: Dict[str, Any] = Field(default_factory=dict, description="Custom metadata fields") + tags: list[str] = Field(default_factory=list, max_length=50, description="Metadata tags for the collection") + customFields: dict[str, Any] = Field(default_factory=dict, description="Custom metadata fields") @field_validator("tags") @classmethod - def validate_tags(cls, tags: List[str]) -> List[str]: + def validate_tags(cls, tags: list[str]) -> list[str]: """Validates metadata tags.""" tag_pattern = re.compile(r"^[a-zA-Z0-9_-]+$") for tag in tags: @@ -1139,7 +1149,7 @@ def validate_tags(cls, tags: List[str]) -> List[str]: return tags @classmethod - def merge(cls, parent: Optional[CollectionMetadata], child: Optional[CollectionMetadata]) -> CollectionMetadata: + def merge(cls, parent: CollectionMetadata | None, child: CollectionMetadata | None) -> CollectionMetadata: """Merges parent and child metadata. Args: @@ -1170,17 +1180,17 @@ class RagCollectionConfig(BaseModel): collectionId: str = Field(default_factory=lambda: str(uuid4()), description="Unique collection identifier") repositoryId: str = Field(min_length=1, description="Parent repository ID this collection belongs to") - name: Optional[str] = Field(default=None, max_length=100, description="User-friendly collection name") - description: Optional[str] = Field(default=None, description="Collection description") - chunkingStrategy: Optional[ChunkingStrategy] = Field(default=None, description="Chunking strategy for documents") + name: str | None = Field(default=None, max_length=100, description="User-friendly collection name") + description: str | None = Field(default=None, description="Collection description") + chunkingStrategy: ChunkingStrategy | None = Field(default=None, description="Chunking strategy for documents") allowChunkingOverride: bool = Field( default=True, description="Allow users to override chunking strategy during ingestion" ) - metadata: Optional[CollectionMetadata] = Field( + metadata: CollectionMetadata | None = Field( default=None, description="Collection-specific metadata (merged with parent)" ) - allowedGroups: Optional[List[str]] = Field(default=None, description="User groups with access to collection") - embeddingModel: Optional[str] = Field( + allowedGroups: list[str] | None = Field(default=None, description="User groups with access to collection") + embeddingModel: str | None = Field( default=None, description="Embedding model ID (can be set at creation, immutable after)" ) createdBy: str = Field(min_length=1, description="User ID of creator") @@ -1188,10 +1198,10 @@ class RagCollectionConfig(BaseModel): updatedAt: datetime = Field(default_factory=utc_now, description="Last update timestamp") status: CollectionStatus = Field(default=CollectionStatus.ACTIVE, description="Collection status") default: bool = Field(default=False, description="Indicates if this is a default collection") - dataSourceId: Optional[str] = Field( + dataSourceId: str | None = Field( default=None, description="Bedrock KB data source ID for filtering (Bedrock KB only)" ) - pipelines: Optional[List[PipelineConfig]] = Field( + pipelines: list[PipelineConfig] | None = Field( default=None, description="Pipeline configurations for this collection" ) @@ -1199,7 +1209,7 @@ class RagCollectionConfig(BaseModel): @field_validator("name") @classmethod - def validate_name(cls, name: Optional[str]) -> Optional[str]: + def validate_name(cls, name: str | None) -> str | None: """Validates collection name.""" if name is not None: if len(name) > 100: @@ -1213,7 +1223,7 @@ def validate_name(cls, name: Optional[str]) -> Optional[str]: @field_validator("allowedGroups") @classmethod - def validate_allowed_groups(cls, groups: Optional[List[str]]) -> Optional[List[str]]: + def validate_allowed_groups(cls, groups: list[str] | None) -> list[str] | None: """Validates allowed groups.""" if groups is not None and len(groups) == 0: # Empty list should be treated as None (inherit from parent) @@ -1224,20 +1234,20 @@ def validate_allowed_groups(cls, groups: Optional[List[str]]) -> Optional[List[s class IngestDocumentRequest(BaseModel): """Request model for ingesting documents.""" - keys: List[str] = Field(description="S3 keys to ingest") - collectionId: Optional[str] = Field(default=None, description="Target collection ID") - embeddingModel: Optional[Dict[str, str]] = Field(default=None, description="Embedding model config") - chunkingStrategy: Optional[Dict[str, Any]] = Field(default=None, description="Chunking strategy override") - metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata") + keys: list[str] = Field(description="S3 keys to ingest") + collectionId: str | None = Field(default=None, description="Target collection ID") + embeddingModel: dict[str, str] | None = Field(default=None, description="Embedding model config") + chunkingStrategy: dict[str, Any] | None = Field(default=None, description="Chunking strategy override") + metadata: dict[str, Any] | None = Field(default=None, description="Additional metadata") class ListCollectionsResponse(PaginatedResponse): """Response model for listing collections.""" - collections: List[RagCollectionConfig] = Field(description="List of collections") - totalCount: Optional[int] = Field(default=None, description="Total number of collections") - currentPage: Optional[int] = Field(default=None, description="Current page number") - totalPages: Optional[int] = Field(default=None, description="Total number of pages") + collections: list[RagCollectionConfig] = Field(description="List of collections") + totalCount: int | None = Field(default=None, description="Total number of collections") + currentPage: int | None = Field(default=None, description="Current page number") + totalPages: int | None = Field(default=None, description="Total number of pages") class CollectionSortBy(StrEnum): @@ -1258,8 +1268,8 @@ class SortOrder(StrEnum): class RepositoryMetadata(BaseModel): """Defines metadata for a repository/vector store.""" - tags: List[str] = Field(default_factory=list, description="Tags for categorizing the repository") - customFields: Optional[Dict[str, Any]] = Field(default=None, description="Custom metadata fields") + tags: list[str] = Field(default_factory=list, description="Tags for categorizing the repository") + customFields: dict[str, Any] | None = Field(default=None, description="Custom metadata fields") class OpenSearchNewClusterConfig(BaseModel): @@ -1306,14 +1316,16 @@ class RdsInstanceConfig(BaseModel): """Configuration schema for RDS Instances needed for LiteLLM scaling or PGVector RAG operations. The optional fields can be omitted to create a new database instance, otherwise fill in all fields - to use an existing database instance. + to use an existing database instance. By default, IAM authentication is used. Set iamRdsAuth + to false in config to use password-based authentication. """ username: str = Field(default="postgres", description="The username used for database connection.") - passwordSecretId: Optional[str] = Field( - default=None, description="The SecretsManager Secret ID that stores the existing database password." + passwordSecretId: str | None = Field( + default=None, + description="The SecretsManager Secret ID that stores the existing database password.", ) - dbHost: Optional[str] = Field(default=None, description="The database hostname for the existing database instance.") + dbHost: str | None = Field(default=None, description="The database hostname for the existing database instance.") dbName: str = Field(default="postgres", description="The name of the database for the database instance.") dbPort: int = Field( default=5432, @@ -1344,7 +1356,7 @@ class BedrockKnowledgeBaseConfig(BaseModel): """ knowledgeBaseId: str = Field(min_length=1, description="The ID of the Bedrock Knowledge Base") - dataSources: List[BedrockDataSource] = Field( + dataSources: list[BedrockDataSource] = Field( min_length=1, description="Array of data sources in this Knowledge Base" ) @@ -1353,38 +1365,38 @@ class VectorStoreConfig(BaseModel): """Represents a vector store/repository configuration.""" repositoryId: str = Field(description="Unique identifier for the repository") - repositoryName: Optional[str] = Field(default=None, description="User-friendly name for the repository") - description: Optional[str] = Field(default=None, description="Description of the repository") - embeddingModelId: Optional[str] = Field(default=None, description="Default embedding model ID") + repositoryName: str | None = Field(default=None, description="User-friendly name for the repository") + description: str | None = Field(default=None, description="Description of the repository") + embeddingModelId: str | None = Field(default=None, description="Default embedding model ID") type: str = Field(description="Type of vector store (opensearch, pgvector, bedrock_knowledge_base)") - allowedGroups: List[str] = Field(default_factory=list, description="User groups with access to this repository") - metadata: Optional[RepositoryMetadata] = Field(default=None, description="Repository metadata") - pipelines: Optional[List[PipelineConfig]] = Field(default=None, description="Automated ingestion pipelines") + allowedGroups: list[str] = Field(default_factory=list, description="User groups with access to this repository") + metadata: RepositoryMetadata | None = Field(default=None, description="Repository metadata") + pipelines: list[PipelineConfig] | None = Field(default=None, description="Automated ingestion pipelines") # Type-specific configurations - opensearchConfig: Optional[Union[OpenSearchNewClusterConfig, OpenSearchExistingClusterConfig]] = Field( + opensearchConfig: OpenSearchNewClusterConfig | OpenSearchExistingClusterConfig | None = Field( default=None, description="OpenSearch configuration" ) - rdsConfig: Optional[RdsInstanceConfig] = Field(default=None, description="RDS/PGVector configuration") - bedrockKnowledgeBaseConfig: Optional[BedrockKnowledgeBaseConfig] = Field( + rdsConfig: RdsInstanceConfig | None = Field(default=None, description="RDS/PGVector configuration") + bedrockKnowledgeBaseConfig: BedrockKnowledgeBaseConfig | None = Field( default=None, description="Bedrock Knowledge Base configuration with data sources" ) # Status and timestamps - status: Optional[VectorStoreStatus] = Field(default=None, description="Repository Status") + status: VectorStoreStatus | None = Field(default=None, description="Repository Status") createdBy: str = Field(description="Creation user") createdAt: datetime = Field(default_factory=utc_now, description="Creation timestamp") - updatedAt: Optional[datetime] = Field(default_factory=utc_now, description="Last update timestamp") + updatedAt: datetime | None = Field(default_factory=utc_now, description="Last update timestamp") class UpdateVectorStoreRequest(BaseModel): """Request model for updating a vector store.""" - repositoryName: Optional[str] = Field(default=None, description="User-friendly name") - description: Optional[str] = Field(default=None, description="Description of the repository") - embeddingModelId: Optional[str] = Field(default=None, description="Default embedding model ID") - allowedGroups: Optional[List[str]] = Field(default=None, description="User groups with access") - metadata: Optional[RepositoryMetadata] = Field(default=None, description="Repository metadata") - pipelines: Optional[List[PipelineConfig]] = Field(default=None, description="Automated ingestion pipelines") - bedrockKnowledgeBaseConfig: Optional[BedrockKnowledgeBaseConfig] = Field( + repositoryName: str | None = Field(default=None, description="User-friendly name") + description: str | None = Field(default=None, description="Description of the repository") + embeddingModelId: str | None = Field(default=None, description="Default embedding model ID") + allowedGroups: list[str] | None = Field(default=None, description="User groups with access") + metadata: RepositoryMetadata | None = Field(default=None, description="Repository metadata") + pipelines: list[PipelineConfig] | None = Field(default=None, description="Automated ingestion pipelines") + bedrockKnowledgeBaseConfig: BedrockKnowledgeBaseConfig | None = Field( default=None, description="Bedrock Knowledge Base configuration" ) @@ -1394,10 +1406,10 @@ class KnowledgeBaseMetadata(BaseModel): knowledgeBaseId: str = Field(description="Knowledge Base ID") name: str = Field(description="Knowledge Base name") - description: Optional[str] = Field(default="", description="Knowledge Base description") + description: str | None = Field(default="", description="Knowledge Base description") status: str = Field(description="Knowledge Base status (ACTIVE, CREATING, DELETING, etc.)") createdAt: datetime = Field(default_factory=utc_now, description="Creation timestamp") - updatedAt: Optional[datetime] = Field(default=None, description="Last update timestamp") + updatedAt: datetime | None = Field(default=None, description="Last update timestamp") class DataSourceMetadata(BaseModel): @@ -1405,14 +1417,14 @@ class DataSourceMetadata(BaseModel): dataSourceId: str = Field(description="Data Source ID") name: str = Field(description="Data Source name") - description: Optional[str] = Field(default="", description="Data Source description") + description: str | None = Field(default="", description="Data Source description") status: str = Field(description="Data Source status (AVAILABLE, CREATING, DELETING, etc.)") s3Bucket: str = Field(description="S3 bucket for the data source") s3Prefix: str = Field(default="", description="S3 prefix for the data source") createdAt: datetime = Field(default_factory=utc_now, description="Creation timestamp") - updatedAt: Optional[datetime] = Field(default=None, description="Last update timestamp") - managed: Optional[bool] = Field(default=False, description="Whether this data source is managed by a collection") - collectionId: Optional[str] = Field(default=None, description="Collection ID if managed") + updatedAt: datetime | None = Field(default=None, description="Last update timestamp") + managed: bool | None = Field(default=False, description="Whether this data source is managed by a collection") + collectionId: str | None = Field(default=None, description="Collection ID if managed") @field_validator("s3Bucket") @classmethod diff --git a/lambda/models/handler/create_model_handler.py b/lambda/models/handler/create_model_handler.py index a9ffbc1d9..0da28fc1d 100644 --- a/lambda/models/handler/create_model_handler.py +++ b/lambda/models/handler/create_model_handler.py @@ -17,6 +17,7 @@ import os from models.exception import ModelAlreadyExistsError +from utilities.time import now from ..domain_objects import CreateModelRequest, CreateModelResponse, ModelStatus from .base_handler import BaseApiHandler @@ -26,7 +27,7 @@ class CreateModelHandler(BaseApiHandler): """Handler class for CreateModel requests.""" - def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse: # type: ignore + def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse: """Create model infrastructure and add model data to LiteLLM database.""" model_id = create_request.modelId @@ -37,6 +38,21 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse: self.validate(create_request) + # Create initial DynamoDB record before starting state machine + # This ensures model_config exists even if state machine fails + model_config_data = create_request.model_dump() + model_config_data.pop("guardrailsConfig", None) + + self._model_table.put_item( + Item={ + "model_id": model_id, + "model_status": ModelStatus.CREATING, + "model_config": model_config_data, + "model_description": create_request.modelDescription, + "last_modified_date": now(), + } + ) + self._stepfunctions.start_execution( stateMachineArn=os.environ["CREATE_SFN_ARN"], input=create_request.model_dump_json() ) diff --git a/lambda/models/handler/delete_model_handler.py b/lambda/models/handler/delete_model_handler.py index e9776af07..cb48d90b2 100644 --- a/lambda/models/handler/delete_model_handler.py +++ b/lambda/models/handler/delete_model_handler.py @@ -35,7 +35,7 @@ class DeleteModelHandler(BaseApiHandler): """Handler class for DeleteModel requests.""" - def __call__(self, model_id: str) -> DeleteModelResponse: # type: ignore + def __call__(self, model_id: str) -> DeleteModelResponse: """Kick off state machine to delete infrastructure and remove model reference from LiteLLM.""" table_item = self._model_table.get_item(Key={"model_id": model_id}).get("Item", None) if not table_item: @@ -108,7 +108,7 @@ def _get_vector_store_table_name(self) -> str | None: response = ssm_client.get_parameter(Name=parameter_name) table_name = response["Parameter"]["Value"] logger.debug(f"Retrieved RAG vector store table name from SSM: {table_name}") - return table_name + return table_name # type: ignore[no-any-return] except ClientError as e: if e.response["Error"]["Code"] == "ParameterNotFound": logger.debug(f"SSM parameter {parameter_name} not found - RAG not deployed") @@ -131,7 +131,7 @@ def _get_collection_table_name(self) -> str | None: response = ssm_client.get_parameter(Name=parameter_name) table_name = response["Parameter"]["Value"] logger.debug(f"Retrieved RAG collections table name from SSM: {table_name}") - return table_name + return table_name # type: ignore[no-any-return] except ClientError as e: if e.response["Error"]["Code"] == "ParameterNotFound": logger.debug(f"SSM parameter {parameter_name} not found - RAG not deployed") diff --git a/lambda/models/handler/get_model_handler.py b/lambda/models/handler/get_model_handler.py index 198f592aa..207d7687a 100644 --- a/lambda/models/handler/get_model_handler.py +++ b/lambda/models/handler/get_model_handler.py @@ -14,7 +14,6 @@ """Handler for GetModel requests.""" -from typing import List, Optional from utilities.auth import user_has_group_access @@ -27,9 +26,7 @@ class GetModelHandler(BaseApiHandler): """Handler class for GetModel requests.""" - def __call__( - self, model_id: str, user_groups: Optional[List[str]] = None, is_admin: bool = False - ) -> GetModelResponse: + def __call__(self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False) -> GetModelResponse: """Get model metadata from LiteLLM and translate to a model management response object.""" ddb_item = self._model_table.get_item(Key={"model_id": model_id}).get("Item", None) if not ddb_item: diff --git a/lambda/models/handler/list_models_handler.py b/lambda/models/handler/list_models_handler.py index fd2ae7541..5b65ff299 100644 --- a/lambda/models/handler/list_models_handler.py +++ b/lambda/models/handler/list_models_handler.py @@ -14,7 +14,6 @@ """Handler for ListModels requests.""" -from typing import List, Optional from utilities.auth import user_has_group_access @@ -26,7 +25,7 @@ class ListModelsHandler(BaseApiHandler): """Handler class for ListModels requests.""" - def __call__(self, user_groups: Optional[List[str]] = None, is_admin: bool = False) -> ListModelsResponse: + def __call__(self, user_groups: list[str] | None = None, is_admin: bool = False) -> ListModelsResponse: """Call handler to get all models from DynamoDB and transform results into API response format.""" ddb_models = [] models_response = self._model_table.scan() diff --git a/lambda/models/handler/schedule_handlers.py b/lambda/models/handler/schedule_handlers.py index ff812622c..fda0a115f 100644 --- a/lambda/models/handler/schedule_handlers.py +++ b/lambda/models/handler/schedule_handlers.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Any, List, Optional +from typing import Any from ..domain_objects import ( DeleteScheduleResponse, @@ -48,7 +48,7 @@ def __call__( self, model_id: str, schedule_config: SchedulingConfig, - user_groups: Optional[List[str]] = None, + user_groups: list[str] | None = None, is_admin: bool = False, ) -> UpdateScheduleResponse: """Create or update a schedule for a model""" @@ -86,7 +86,7 @@ class GetScheduleHandler(ScheduleBaseHandler): """Handler class for GetSchedule requests""" def __call__( - self, model_id: str, user_groups: Optional[List[str]] = None, is_admin: bool = False + self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> GetScheduleResponse: """Get current schedule configuration for a model""" # Validate model exists and user access @@ -111,7 +111,7 @@ class DeleteScheduleHandler(ScheduleBaseHandler): """Handler class for DeleteSchedule requests""" def __call__( - self, model_id: str, user_groups: Optional[List[str]] = None, is_admin: bool = False + self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> DeleteScheduleResponse: """Delete a schedule for a model""" # Validate model exists, user access, and model status @@ -132,7 +132,7 @@ class GetScheduleStatusHandler(ScheduleBaseHandler): """Handler class for GetScheduleStatus requests""" def __call__( - self, model_id: str, user_groups: Optional[List[str]] = None, is_admin: bool = False + self, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False ) -> GetScheduleStatusResponse: """Get current schedule status and next scheduled action for a model""" # Validate model exists and user access diff --git a/lambda/models/handler/update_model_handler.py b/lambda/models/handler/update_model_handler.py index 4d21d631f..b338d1884 100644 --- a/lambda/models/handler/update_model_handler.py +++ b/lambda/models/handler/update_model_handler.py @@ -26,7 +26,7 @@ class UpdateModelHandler(BaseApiHandler): """Handler class for UpdateModel requests.""" - def __call__(self, model_id: str, update_request: UpdateModelRequest) -> UpdateModelResponse: # type: ignore + def __call__(self, model_id: str, update_request: UpdateModelRequest) -> UpdateModelResponse: """Call handler to update model metadata or scaling config based on user request.""" ddb_item = self._model_table.get_item(Key={"model_id": model_id}).get("Item", None) if not ddb_item: diff --git a/lambda/models/handler/utils.py b/lambda/models/handler/utils.py index 00486a276..334650926 100644 --- a/lambda/models/handler/utils.py +++ b/lambda/models/handler/utils.py @@ -14,7 +14,8 @@ """Common utility functions across all API handlers.""" -from typing import Any, Dict, List, Optional +import logging +from typing import Any from utilities.auth import user_has_group_access from utilities.validation import ValidationError @@ -22,19 +23,22 @@ from ..domain_objects import GuardrailConfig, LISAModel from ..exception import InvalidStateTransitionError, ModelNotFoundError +logger = logging.getLogger(__name__) -def to_lisa_model(model_dict: Dict[str, Any]) -> LISAModel: + +def to_lisa_model(model_dict: dict[str, Any]) -> LISAModel: """Convert DDB model entry dictionary to a LISAModel object.""" - model_dict["model_config"]["status"] = model_dict["model_status"] + model_config = model_dict.get("model_config", {}) + model_config["status"] = model_dict.get("model_status", "Unknown") if "model_url" in model_dict: - model_dict["model_config"]["modelUrl"] = model_dict["model_url"] - lisa_model: LISAModel = LISAModel.model_validate(model_dict["model_config"]) + model_config["modelUrl"] = model_dict["model_url"] + lisa_model: LISAModel = LISAModel.model_validate(model_config) return lisa_model def get_model_and_validate_access( - model_table, model_id: str, user_groups: Optional[List[str]] = None, is_admin: bool = False -) -> Dict[str, Any]: + model_table: Any, model_id: str, user_groups: list[str] | None = None, is_admin: bool = False +) -> dict[str, Any]: """ Get model from DynamoDB and validate user access @@ -66,16 +70,16 @@ def get_model_and_validate_access( if not user_has_group_access(user_groups, allowed_groups): raise ValidationError(f"Access denied to access model {model_id}") - return model_item + return model_item # type: ignore[no-any-return] def get_model_and_validate_status( - model_table, + model_table: Any, model_id: str, - allowed_statuses: List[str] = None, - user_groups: Optional[List[str]] = None, + allowed_statuses: list[str] | None = None, + user_groups: list[str] | None = None, is_admin: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Get model from DynamoDB, validate user access, and check model status @@ -111,12 +115,12 @@ def get_model_and_validate_status( return model_item -def create_guardrail_config(item: Dict[str, Any]) -> GuardrailConfig: +def create_guardrail_config(item: dict[str, Any]) -> GuardrailConfig: """Create a GuardrailConfig object from a DynamoDB guardrail item.""" return GuardrailConfig(**item) -def attach_guardrails_to_model(model: LISAModel, guardrail_items: List[Dict[str, Any]]) -> None: +def attach_guardrails_to_model(model: LISAModel, guardrail_items: list[dict[str, Any]]) -> None: """Build guardrails config from DDB items and attach to model.""" if not guardrail_items: return @@ -126,17 +130,17 @@ def attach_guardrails_to_model(model: LISAModel, guardrail_items: List[Dict[str, } -def fetch_guardrails_for_model(guardrails_table, model_id: str) -> List[Dict[str, Any]]: +def fetch_guardrails_for_model(guardrails_table: Any, model_id: str) -> list[dict[str, Any]]: """Query guardrails table for a specific model ID.""" guardrails_response = guardrails_table.query( IndexName="ModelIdIndex", KeyConditionExpression="modelId = :modelId", ExpressionAttributeValues={":modelId": model_id}, ) - return guardrails_response.get("Items", []) + return guardrails_response.get("Items", []) # type: ignore[no-any-return] -def fetch_all_guardrails(guardrails_table) -> List[Dict[str, Any]]: +def fetch_all_guardrails(guardrails_table: Any) -> list[dict[str, Any]]: """Scan all guardrails from the table with pagination.""" all_guardrails = [] guardrails_response = guardrails_table.scan() @@ -151,9 +155,9 @@ def fetch_all_guardrails(guardrails_table) -> List[Dict[str, Any]]: return all_guardrails -def group_guardrails_by_model(guardrail_items: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: +def group_guardrails_by_model(guardrail_items: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]: """Group guardrail items by modelId.""" - guardrails_by_model: Dict[str, List[Dict[str, Any]]] = {} + guardrails_by_model: dict[str, list[dict[str, Any]]] = {} for item in guardrail_items: model_id = item["modelId"] if model_id not in guardrails_by_model: diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 549933c2f..8eb0a4d57 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -13,20 +13,19 @@ # limitations under the License. """APIGW endpoints for managing models.""" +import logging import os -from typing import Annotated, Union +from typing import Annotated +from urllib.parse import urlparse import boto3 import botocore.session -from fastapi import FastAPI, HTTPException, Path, Request -from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware +from fastapi import HTTPException, Path, Request from fastapi.responses import JSONResponse from mangum import Mangum -from utilities.auth import get_groups, is_admin +from utilities.auth import get_groups, get_username, is_admin from utilities.common_functions import retry_config -from utilities.fastapi_middleware.aws_api_gateway_middleware import AWSAPIGatewayMiddleware +from utilities.fastapi_factory import create_fastapi_app from .domain_objects import ( CreateModelRequest, @@ -55,18 +54,10 @@ UpdateScheduleHandler, ) +logger = logging.getLogger(__name__) + sess = botocore.session.Session() -app = FastAPI(redirect_slashes=False, lifespan="off", docs_url="/docs", openapi_url="/openapi.json") -app.add_middleware(AWSAPIGatewayMiddleware) - -# Enable CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], -) +app = create_fastapi_app() autoscaling = boto3.client("autoscaling", region_name=os.environ["AWS_REGION"], config=retry_config) dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -97,31 +88,97 @@ async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> return JSONResponse(status_code=404, content={"detail": str(exc)}) -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: - """Handle exception when request fails validation and and translate to a 422 error.""" - return JSONResponse( - status_code=422, content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"} - ) - - @app.exception_handler(InvalidStateTransitionError) @app.exception_handler(ModelAlreadyExistsError) @app.exception_handler(ValueError) async def user_error_handler( - request: Request, exc: Union[InvalidStateTransitionError, ModelAlreadyExistsError, ValueError] + request: Request, exc: InvalidStateTransitionError | ModelAlreadyExistsError | ValueError ) -> JSONResponse: """Handle errors when customer requests options that cannot be processed.""" return JSONResponse(status_code=400, content={"detail": str(exc)}) +@app.exception_handler(ModelInUseError) +async def model_in_use_handler(request: Request, exc: ModelInUseError) -> JSONResponse: + """Handle exception when attempting to delete a model that is in use.""" + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + @app.post(path="", include_in_schema=False) @app.post(path="/") async def create_model(create_request: CreateModelRequest, request: Request) -> CreateModelResponse: """Endpoint to create a model.""" - admin_status, _ = get_admin_status_and_groups(request) + admin_status, user_groups = get_admin_status_and_groups(request) if not admin_status: raise HTTPException(status_code=403, detail="User does not have permission to create models.") + + # Extract user context for audit logging + event = request.scope.get("aws.event", {}) + username = get_username(event) if event else "unknown" + auth_type = event.get("requestContext", {}).get("authorizer", {}).get("authType", "unknown") + source_ip = event.get("requestContext", {}).get("identity", {}).get("sourceIp", "unknown") + + # Extract container image and healthcheck details for audit logging + container_image = None + registry_domain = None + healthcheck_command = None + + if create_request.containerConfig: + container_image = create_request.containerConfig.image.baseImage + # Extract registry domain from image URL + try: + if "://" in container_image: + registry_domain = urlparse(container_image).netloc + elif "/" in container_image: + registry_domain = container_image.split("/")[0] + else: + registry_domain = "unknown" + except Exception: + registry_domain = "parse_error" + + healthcheck_command = create_request.containerConfig.healthCheckConfig.command + + # Log CreateModel request for security audit + logger.info( + "CreateModel request", + extra={ + "event_type": "CREATE_MODEL_REQUEST", + "user": { + "username": username, + "groups": user_groups, + "auth_type": auth_type, + "source_ip": source_ip, + }, + "model": { + "model_id": create_request.modelId, + "model_name": create_request.modelName, + "instance_type": create_request.instanceType if hasattr(create_request, "instanceType") else None, + "auto_scaling": ( + { + "min_capacity": ( + create_request.autoScalingConfig.minCapacity if create_request.autoScalingConfig else None + ), + "max_capacity": ( + create_request.autoScalingConfig.maxCapacity if create_request.autoScalingConfig else None + ), + } + if create_request.autoScalingConfig + else None + ), + }, + "container": ( + { + "base_image": container_image, + "registry_domain": registry_domain, + "image_type": create_request.containerConfig.image.type if create_request.containerConfig else None, + "healthcheck_command": healthcheck_command, + } + if create_request.containerConfig + else None + ), + }, + ) + create_handler = CreateModelHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, @@ -129,9 +186,44 @@ async def create_model(create_request: CreateModelRequest, request: Request) -> guardrails_table_resource=guardrails_table, ) try: - return create_handler(create_request=create_request) + response = create_handler(create_request=create_request) + + # Log successful creation + logger.info( + "CreateModel request successful", + extra={ + "event_type": "CREATE_MODEL_SUCCESS", + "model_id": create_request.modelId, + "username": username, + }, + ) + + return response except ModelAlreadyExistsError as e: + # Log failure + logger.warning( + "CreateModel request failed - model already exists", + extra={ + "event_type": "CREATE_MODEL_FAILURE", + "model_id": create_request.modelId, + "username": username, + "error": str(e), + }, + ) raise HTTPException(status_code=409, detail=str(e)) + except Exception as e: + # Log unexpected failure + logger.error( + "CreateModel request failed with unexpected error", + extra={ + "event_type": "CREATE_MODEL_ERROR", + "model_id": create_request.modelId, + "username": username, + "error": str(e), + }, + exc_info=True, + ) + raise @app.get(path="", include_in_schema=False) diff --git a/lambda/models/model_api_key_cleanup.py b/lambda/models/model_api_key_cleanup.py index f4b6afa14..0555f7d6f 100644 --- a/lambda/models/model_api_key_cleanup.py +++ b/lambda/models/model_api_key_cleanup.py @@ -28,18 +28,19 @@ import os import sys import traceback -from typing import Any, Dict, List +from typing import Any import boto3 import psycopg2 -from utilities.common_functions import retry_config +from utilities.common_functions import get_lambda_role_name, retry_config +from utilities.rds_auth import generate_auth_token # Add the lambda directory to the Python path sys.path.append("/opt/python") sys.path.append("/var/task") -def get_all_dynamodb_models() -> List[Dict[str, str]]: +def get_all_dynamodb_models() -> list[dict[str, str]]: """Get all models from DynamoDB table with their IDs and names.""" try: dynamodb = boto3.client("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -87,8 +88,8 @@ def get_all_dynamodb_models() -> List[Dict[str, str]]: return [] -def get_database_connection(): - """Get database connection using connection info from SSM.""" +def get_database_connection() -> Any: + """Get database connection using password auth or IAM auth based on config.""" ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) # Get database connection info from SSM using environment variable @@ -103,38 +104,47 @@ def get_database_connection(): except Exception as e: raise ValueError(f"Failed to get database connection info from SSM: {e}") - # Get database credentials from Secrets Manager - try: - secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) - secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) - secret = json.loads(secret_response["SecretString"]) - except Exception as e: - raise ValueError(f"Failed to get database credentials from Secrets Manager: {e}") - # Validate required parameters - required_params = ["dbHost", "dbPort", "dbName", "username"] + required_params = ["dbHost", "dbPort", "dbName"] for param in required_params: if param not in db_params: raise ValueError(f"Missing required database parameter: {param}") - if "password" not in secret: - raise ValueError("Missing password in secret") + # Check if using password auth (passwordSecretId present) or IAM auth + if "passwordSecretId" in db_params: + # Password auth: get credentials from Secrets Manager + try: + secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) + secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) + secret = json.loads(secret_response["SecretString"]) + except Exception as e: + raise ValueError(f"Failed to get database credentials from Secrets Manager: {e}") + + if "password" not in secret: + raise ValueError("Missing password in secret") - # Create connection with proper error handling + user = db_params.get("username", "postgres") + password = secret["password"] + else: + # IAM auth: generate auth token + user = get_lambda_role_name() + password = generate_auth_token(db_params["dbHost"], db_params["dbPort"], user) + + # Create connection try: conn = psycopg2.connect( host=db_params["dbHost"], port=db_params["dbPort"], database=db_params["dbName"], - user=db_params["username"], - password=secret["password"], + user=user, + password=password, ) return conn except Exception as e: raise ValueError(f"Failed to connect to database: {e}") -def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Lambda handler for Bedrock model API key cleanup. @@ -189,9 +199,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Use psycopg2's identifier quoting to prevent SQL injection cursor.execute( - psycopg2.sql.SQL("SELECT * FROM {} LIMIT 1").format( # noqa: S608, P103 - psycopg2.sql.Identifier(litellm_table) - ) + psycopg2.sql.SQL("SELECT * FROM {} LIMIT 1").format(psycopg2.sql.Identifier(litellm_table)) # noqa: S608 ) columns = [desc[0] for desc in cursor.description] print(f"Table {litellm_table} columns: {columns}") @@ -210,7 +218,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Query all models from the LiteLLM database cursor.execute( - psycopg2.sql.SQL("SELECT {}, {}, {} FROM {}").format( # noqa: S608, P103 + psycopg2.sql.SQL("SELECT {}, {}, {} FROM {}").format( # noqa: S608 psycopg2.sql.Identifier(model_id_col), psycopg2.sql.Identifier(model_name_col), psycopg2.sql.Identifier(litellm_params_col), @@ -276,7 +284,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Update the model in the database clean_params_json = json.dumps(clean_params) cursor.execute( - psycopg2.sql.SQL("UPDATE {} SET {} = %s WHERE {} = %s").format( # noqa: S608, P103 + psycopg2.sql.SQL("UPDATE {} SET {} = %s WHERE {} = %s").format( # noqa: S608 psycopg2.sql.Identifier(litellm_table), psycopg2.sql.Identifier(litellm_params_col), psycopg2.sql.Identifier(model_id_col), diff --git a/lambda/models/scheduling/schedule_management.py b/lambda/models/scheduling/schedule_management.py index 79275ac3b..4ecbc2aa2 100644 --- a/lambda/models/scheduling/schedule_management.py +++ b/lambda/models/scheduling/schedule_management.py @@ -16,7 +16,7 @@ import logging import os from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any from zoneinfo import ZoneInfo import boto3 @@ -43,7 +43,7 @@ model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME")) -def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """Main Lambda handler for schedule management operations""" try: logger.info(f"Processing schedule management request: {json.dumps(event, default=str)}") @@ -70,7 +70,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return {"statusCode": 500, "body": json.dumps({"error": "ScheduleManagementError", "message": str(e)})} -def update_schedule(event: Dict[str, Any]) -> Dict[str, Any]: +def update_schedule(event: dict[str, Any]) -> dict[str, Any]: """Update an existing schedule for a model""" model_id = event["modelId"] schedule_config = event.get("scheduleConfig") @@ -126,7 +126,7 @@ def update_schedule(event: Dict[str, Any]) -> Dict[str, Any]: raise RuntimeError(f"Failed to update schedule: {str(e)}") -def delete_schedule(event: Dict[str, Any]) -> Dict[str, Any]: +def delete_schedule(event: dict[str, Any]) -> dict[str, Any]: """Delete a schedule for a model""" model_id = event["modelId"] @@ -166,7 +166,7 @@ def delete_schedule(event: Dict[str, Any]) -> Dict[str, Any]: raise RuntimeError(f"Failed to delete schedule: {str(e)}") -def get_schedule(event: Dict[str, Any]) -> Dict[str, Any]: +def get_schedule(event: dict[str, Any]) -> dict[str, Any]: """Get current schedule configuration for a model""" model_id = event["modelId"] @@ -199,7 +199,7 @@ def get_schedule(event: Dict[str, Any]) -> Dict[str, Any]: raise RuntimeError(f"Failed to get schedule: {str(e)}") -def create_scheduled_actions(model_id: str, auto_scaling_group: str, schedule_config: SchedulingConfig) -> List[str]: +def create_scheduled_actions(model_id: str, auto_scaling_group: str, schedule_config: SchedulingConfig) -> list[str]: """Create Auto Scaling scheduled actions based on schedule configuration""" scheduled_action_arns = [] @@ -222,7 +222,7 @@ def create_scheduled_actions(model_id: str, auto_scaling_group: str, schedule_co return scheduled_action_arns -def create_scheduling_config(schedule_data: Dict[str, Any]) -> SchedulingConfig: +def create_scheduling_config(schedule_data: dict[str, Any]) -> SchedulingConfig: """Create the appropriate scheduling config instance based on scheduleType""" schedule_type = schedule_data.get("scheduleType") @@ -234,7 +234,7 @@ def create_scheduling_config(schedule_data: Dict[str, Any]) -> SchedulingConfig: raise ValueError(f"Unknown schedule type: {schedule_type}") -def get_existing_asg_capacity(auto_scaling_group: str) -> Dict[str, int]: +def get_existing_asg_capacity(auto_scaling_group: str) -> dict[str, int]: """Get the existing Auto Scaling Group's current capacity configuration""" try: response = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[auto_scaling_group]) @@ -255,7 +255,7 @@ def get_existing_asg_capacity(auto_scaling_group: str) -> Dict[str, int]: raise RuntimeError(f"Failed to get ASG capacity: {str(e)}") -def get_model_baseline_capacity(model_id: str) -> Dict[str, int]: +def get_model_baseline_capacity(model_id: str) -> dict[str, int]: """Get the baseline capacity configuration from the model's DynamoDB record""" try: response = model_table.get_item(Key={"model_id": model_id}) @@ -392,7 +392,7 @@ def scale_immediately(auto_scaling_group: str, day_schedule: DaySchedule, timezo def create_recurring_scheduled_actions( model_id: str, auto_scaling_group: str, day_schedule: DaySchedule, timezone_name: str -) -> List[str]: +) -> list[str]: """Create scheduled actions for recurring schedule""" scheduled_action_arns = [] @@ -460,7 +460,7 @@ def create_recurring_scheduled_actions( def create_daily_scheduled_actions( model_id: str, auto_scaling_group: str, daily_schedule: WeeklySchedule, timezone_name: str -) -> List[str]: +) -> list[str]: """Create scheduled actions for daily schedule (different times each day with one start/stop time per day)""" scheduled_action_arns = [] @@ -566,7 +566,7 @@ def construct_scheduled_action_arn(auto_scaling_group: str, action_name: str) -> ) -def delete_scheduled_actions(scheduled_action_arns: List[str]) -> None: +def delete_scheduled_actions(scheduled_action_arns: list[str]) -> None: """Delete Auto Scaling scheduled actions by ARN""" for arn in scheduled_action_arns: try: @@ -588,7 +588,7 @@ def delete_scheduled_actions(scheduled_action_arns: List[str]) -> None: raise -def cleanup_scheduled_actions(scheduled_action_arns: List[str]) -> None: +def cleanup_scheduled_actions(scheduled_action_arns: list[str]) -> None: """Clean up scheduled actions (used for error recovery)""" for arn in scheduled_action_arns: try: @@ -654,7 +654,7 @@ def cleanup_scheduled_actions_by_name_pattern(auto_scaling_group: str, model_id: logger.error(f"Failed to cleanup scheduled actions by pattern for model {model_id}: {e}") -def calculate_next_scheduled_action(schedule_config: SchedulingConfig, timezone_name: str) -> Optional[Dict[str, str]]: +def calculate_next_scheduled_action(schedule_config: SchedulingConfig, timezone_name: str) -> dict[str, str] | None: """Calculate the next scheduled action (START or STOP) based on the schedule configuration""" try: tz = ZoneInfo(timezone_name) @@ -671,7 +671,7 @@ def calculate_next_scheduled_action(schedule_config: SchedulingConfig, timezone_ return None -def _calculate_next_recurring_action(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> Dict[str, str]: +def _calculate_next_recurring_action(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str]: """Calculate next action for recurring schedule""" # Parse schedule times start_hour, start_minute = map(int, day_schedule.startTime.split(":")) @@ -699,9 +699,7 @@ def _calculate_next_recurring_action(day_schedule: DaySchedule, now: datetime, t return {"action": "START", "scheduledTime": tomorrow_start.isoformat()} -def _calculate_next_daily_action( - daily_schedule: WeeklySchedule, now: datetime, tz: ZoneInfo -) -> Optional[Dict[str, str]]: +def _calculate_next_daily_action(daily_schedule: WeeklySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str] | None: """Calculate next action for daily schedule""" current_weekday = now.weekday() @@ -739,7 +737,7 @@ def _calculate_next_daily_action( return None -def _get_next_action_for_today(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> Optional[Dict[str, str]]: +def _get_next_action_for_today(day_schedule: DaySchedule, now: datetime, tz: ZoneInfo) -> dict[str, str] | None: """Get next action for today's schedule only""" today = now.date() @@ -765,7 +763,7 @@ def _get_next_action_for_today(day_schedule: DaySchedule, now: datetime, tz: Zon return None -def merge_schedule_data(model_id: str, partial_update: Dict[str, Any]) -> Dict[str, Any]: +def merge_schedule_data(model_id: str, partial_update: dict[str, Any]) -> dict[str, Any]: """Merge partial schedule update with existing schedule data""" # Get existing schedule data from model_config.autoScalingConfig.scheduling existing_data = {} @@ -804,7 +802,7 @@ def merge_schedule_data(model_id: str, partial_update: Dict[str, Any]) -> Dict[s return merged_data -def get_existing_scheduled_action_arns(model_id: str) -> List[str]: +def get_existing_scheduled_action_arns(model_id: str) -> list[str]: """Get existing scheduled action ARNs for a model""" try: response = model_table.get_item(Key={"model_id": model_id}) @@ -817,7 +815,7 @@ def get_existing_scheduled_action_arns(model_id: str) -> List[str]: auto_scaling_config = model_config.get("autoScalingConfig", {}) scheduling_config = auto_scaling_config.get("scheduling", {}) - return scheduling_config.get("scheduledActionArns", []) + return scheduling_config.get("scheduledActionArns", []) # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to get existing scheduled actions for model {model_id}: {e}") @@ -825,7 +823,7 @@ def get_existing_scheduled_action_arns(model_id: str) -> List[str]: def update_model_schedule_record( - model_id: str, scheduling_config: Optional[SchedulingConfig], scheduled_action_arns: List[str], enabled: bool + model_id: str, scheduling_config: SchedulingConfig | None, scheduled_action_arns: list[str], enabled: bool ) -> None: """Update model record in DynamoDB with schedule information""" try: diff --git a/lambda/models/scheduling/schedule_monitoring.py b/lambda/models/scheduling/schedule_monitoring.py index fad7de176..666e0beaa 100644 --- a/lambda/models/scheduling/schedule_monitoring.py +++ b/lambda/models/scheduling/schedule_monitoring.py @@ -15,7 +15,7 @@ import json import logging import os -from typing import Any, Dict, Optional +from typing import Any import boto3 from botocore.config import Config @@ -35,7 +35,7 @@ model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME")) -def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """Main Lambda handler for CloudWatch Events from Auto Scaling Groups""" logger.info(f"Processing event - RequestId: {context.aws_request_id}") @@ -58,7 +58,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return {"statusCode": 500, "body": json.dumps({"error": "ScheduleMonitoringError", "message": str(e)})} -def handle_autoscaling_event(event: Dict[str, Any]) -> Dict[str, Any]: +def handle_autoscaling_event(event: dict[str, Any]) -> dict[str, Any]: """Handle Auto Scaling Group CloudWatch events""" try: detail = event.get("detail", {}) @@ -85,7 +85,7 @@ def handle_autoscaling_event(event: Dict[str, Any]) -> Dict[str, Any]: raise ValueError(f"Failed to handle Auto Scaling event: {str(e)}") -def handle_successful_scaling(model_id: str, auto_scaling_group: str, detail: Dict[str, Any]) -> Dict[str, Any]: +def handle_successful_scaling(model_id: str, auto_scaling_group: str, detail: dict[str, Any]) -> dict[str, Any]: """Handle successful Auto Scaling actions using ASG state""" try: # Check ASG state to determine model status @@ -137,7 +137,7 @@ def handle_successful_scaling(model_id: str, auto_scaling_group: str, detail: Di raise -def sync_model_status(event: Dict[str, Any]) -> Dict[str, Any]: +def sync_model_status(event: dict[str, Any]) -> dict[str, Any]: """Manually sync model status using ASG state""" model_id = event.get("modelId") if not model_id: @@ -207,7 +207,7 @@ def sync_model_status(event: Dict[str, Any]) -> Dict[str, Any]: raise ValueError(f"Failed to sync status: {str(e)}") -def find_model_by_asg_name(asg_name: str) -> Optional[str]: +def find_model_by_asg_name(asg_name: str) -> str | None: """Find model ID by looking up which model uses the given Auto Scaling Group""" try: response = model_table.scan( @@ -217,7 +217,7 @@ def find_model_by_asg_name(asg_name: str) -> Optional[str]: ) if response["Items"]: - return response["Items"][0]["model_id"] + return response["Items"][0]["model_id"] # type: ignore[no-any-return] return None @@ -249,7 +249,7 @@ def update_model_status(model_id: str, new_status: ModelStatus, reason: str) -> raise -def get_model_info(model_id: str) -> Optional[Dict[str, Any]]: +def get_model_info(model_id: str) -> dict[str, Any] | None: """Get model information from DynamoDB""" try: response = model_table.get_item(Key={"model_id": model_id}) @@ -257,7 +257,7 @@ def get_model_info(model_id: str) -> Optional[Dict[str, Any]]: if "Item" not in response: return None - return response["Item"] + return response["Item"] # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to get model info for {model_id}: {e}") diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 539ca8c8e..0be5b750e 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -19,13 +19,13 @@ import os from copy import deepcopy from datetime import datetime -from typing import Any, Dict +from typing import Any from zoneinfo import ZoneInfo import boto3 from botocore.config import Config from models.clients.litellm_client import LiteLLMClient -from models.domain_objects import CreateModelRequest, GuardrailsTableEntry, InferenceContainer, ModelStatus +from models.domain_objects import CreateModelRequest, GuardrailsTableEntry, InferenceContainer, ModelStatus, ModelType from models.exception import ( MaxPollsExceededException, StackFailedToCreateException, @@ -81,11 +81,11 @@ def get_container_path(inference_container_type: InferenceContainer) -> str: return path_mapping[inference_container_type] -def adjust_initial_capacity_for_schedule(prepared_event: Dict[str, Any]) -> None: +def adjust_initial_capacity_for_schedule(prepared_event: dict[str, Any]) -> None: """Adjust Auto Scaling Group initial capacity based on schedule configuration""" try: # Check if scheduling is configured - auto_scaling_config = prepared_event.get("autoScalingConfig", {}) + auto_scaling_config = prepared_event.get("autoScalingConfig", {}) or {} scheduling_config = auto_scaling_config.get("scheduling") if ( @@ -172,22 +172,20 @@ def adjust_initial_capacity_for_schedule(prepared_event: Dict[str, Any]) -> None logger.info("Using original capacity settings due to scheduling error") -def handle_set_model_to_creating(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_set_model_to_creating(event: dict[str, Any], context: Any) -> dict[str, Any]: """Set DDB entry to CREATING status.""" logger.info(f"Setting model to CREATING status: {event.get('modelId')}") output_dict = deepcopy(event) request = CreateModelRequest.model_validate(event) is_lisa_managed = all( - ( - bool(request_param) - for request_param in ( - request.autoScalingConfig, - request.containerConfig, - request.inferenceContainer, - request.instanceType, - request.loadBalancerConfig, - ) + bool(request_param) + for request_param in ( + request.autoScalingConfig, + request.containerConfig, + request.inferenceContainer, + request.instanceType, + request.loadBalancerConfig, ) ) @@ -213,7 +211,7 @@ def handle_set_model_to_creating(event: Dict[str, Any], context: Any) -> Dict[st return output_dict -def handle_start_copy_docker_image(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_start_copy_docker_image(event: dict[str, Any], context: Any) -> dict[str, Any]: """Start process for copying Docker image into local AWS account.""" logger.info(f"Starting Docker image copy for model: {event.get('modelId')}") output_dict = deepcopy(event) @@ -282,7 +280,7 @@ def handle_start_copy_docker_image(event: Dict[str, Any], context: Any) -> Dict[ return output_dict -def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_docker_image_available(event: dict[str, Any], context: Any) -> dict[str, Any]: """Check that Docker image is available in account or not.""" output_dict = deepcopy(event) @@ -320,7 +318,7 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D return output_dict -def handle_start_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_start_create_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Start model infrastructure creation.""" output_dict = deepcopy(event) request = CreateModelRequest.model_validate(event) @@ -363,7 +361,7 @@ def camelize_object(o): # type: ignore[no-untyped-def] } # Remove scheduling configuration from autoScalingConfig before sending to ECS deployer - if "autoScalingConfig" in prepared_event and "scheduling" in prepared_event["autoScalingConfig"]: + if prepared_event.get("autoScalingConfig") and "scheduling" in prepared_event["autoScalingConfig"]: del prepared_event["autoScalingConfig"]["scheduling"] # Log the complete payload being sent (excluding large environment variables) @@ -448,7 +446,7 @@ def camelize_object(o): # type: ignore[no-untyped-def] return output_dict -def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_create_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Check that model infrastructure creation has completed or not.""" output_dict = deepcopy(event) stack = cfnClient.describe_stacks(StackName=event["stack_name"])["Stacks"][0] @@ -486,11 +484,93 @@ def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, A ) -def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +autoscaling_client = boto3.client("autoscaling", region_name=os.environ["AWS_REGION"], config=retry_config) + + +def handle_poll_model_ready(event: dict[str, Any], context: Any) -> dict[str, Any]: + """ + Poll ASG to confirm model instances are healthy before marking as InService. + + This handler checks that the Auto Scaling Group has healthy instances running + before proceeding to add the model to LiteLLM. This ensures the model is actually + ready to serve requests, not just that the infrastructure was created. + """ + output_dict = deepcopy(event) + model_id = event.get("modelId", "unknown") + asg_name = event.get("autoScalingGroup") + + if not asg_name: + logger.warning(f"No ASG name found for model {model_id}, skipping capacity check") + output_dict["continue_polling_capacity"] = False + output_dict["remaining_capacity_polls"] = 0 + return output_dict + + logger.info(f"Polling capacity for model {model_id}, ASG: {asg_name}") + + try: + asg_info = autoscaling_client.describe_auto_scaling_groups(AutoScalingGroupNames=[asg_name])[ + "AutoScalingGroups" + ][0] + + desired_capacity = asg_info["DesiredCapacity"] + instances = asg_info.get("Instances", []) + num_healthy_instances = sum( + 1 + for instance in instances + if instance.get("HealthStatus") == "Healthy" and instance.get("LifecycleState") == "InService" + ) + + logger.info( + f"ASG {asg_name}: desired={desired_capacity}, healthy_in_service={num_healthy_instances}, " + f"total_instances={len(instances)}" + ) + + # Initialize or decrement remaining polls + remaining_polls = event.get("remaining_capacity_polls", 60) - 1 # ~30 minutes at 30s intervals + output_dict["remaining_capacity_polls"] = remaining_polls + + if remaining_polls <= 0: + logger.error(f"Model '{model_id}' did not start healthy instances in expected amount of time.") + # Continue anyway - the model will be added to LiteLLM but may not be ready + # This allows the user to see the model and troubleshoot + output_dict["continue_polling_capacity"] = False + output_dict["capacity_timeout"] = True + return output_dict + + # Check if we have the desired number of healthy instances + # For scheduled models that start with 0 capacity, we consider them ready + if desired_capacity == 0: + logger.info(f"Model {model_id} has desired capacity of 0 (scheduled), marking as ready") + output_dict["continue_polling_capacity"] = False + elif num_healthy_instances >= desired_capacity: + logger.info(f"Model {model_id} has {num_healthy_instances}/{desired_capacity} healthy instances, ready!") + output_dict["continue_polling_capacity"] = False + else: + logger.info( + f"Model {model_id} waiting for instances: {num_healthy_instances}/{desired_capacity} healthy. " + f"Polls remaining: {remaining_polls}" + ) + output_dict["continue_polling_capacity"] = True + + except Exception as e: + logger.error(f"Error checking ASG status for model {model_id}: {e}") + # On error, continue polling if we have polls remaining + remaining_polls = event.get("remaining_capacity_polls", 60) - 1 + output_dict["remaining_capacity_polls"] = remaining_polls + output_dict["continue_polling_capacity"] = remaining_polls > 0 + + return output_dict + + +def handle_add_model_to_litellm(event: dict[str, Any], context: Any) -> dict[str, Any]: """Add model to LiteLLM once it is created.""" output_dict = deepcopy(event) is_lisa_managed = event["create_infra"] + # Check if this is a video generation model + model_type = event.get("modelType", "").upper() + is_video_model = model_type == ModelType.VIDEOGEN.upper() + # Parse the JSON string from environment variable litellm_config_str = os.environ.get("LITELLM_CONFIG_OBJ", json.dumps({})) try: @@ -503,7 +583,14 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str # Only set api_key if it's present in the event if "apiKey" in event: litellm_params["api_key"] = event["apiKey"] # pragma: allowlist-secret - litellm_params["drop_params"] = True # drop unrecognized param instead of failing the request on it + + # For video generation models, use empty litellm_settings to avoid drop_params error + if is_video_model: + litellm_params = {} + if "apiKey" in event: + litellm_params["api_key"] = event["apiKey"] # pragma: allowlist-secret + else: + litellm_params["drop_params"] = True # drop unrecognized param instead of failing the request on it if is_lisa_managed: # get load balancer from cloudformation stack @@ -547,7 +634,7 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str ) # If scheduling is configured, sync model status to ensure it reflects actual ASG state - scheduling_config = event.get("autoScalingConfig", {}).get("scheduling") + scheduling_config = (event.get("autoScalingConfig", {}) or {}).get("scheduling") auto_scaling_group = event.get("autoScalingGroup") if scheduling_config and auto_scaling_group: @@ -562,7 +649,7 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str return output_dict -def handle_add_guardrails_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_add_guardrails_to_litellm(event: dict[str, Any], context: Any) -> dict[str, Any]: """Add guardrails to LiteLLM and store them in DynamoDB.""" logger.info(f"Adding guardrails to LiteLLM for model: {event.get('modelId')}") output_dict = deepcopy(event) @@ -664,7 +751,7 @@ def handle_add_guardrails_to_litellm(event: Dict[str, Any], context: Any) -> Dic return output_dict -def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_failure(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Handle failures from state machine. diff --git a/lambda/models/state_machine/delete_model.py b/lambda/models/state_machine/delete_model.py index 73589ed44..721f24df5 100644 --- a/lambda/models/state_machine/delete_model.py +++ b/lambda/models/state_machine/delete_model.py @@ -17,7 +17,7 @@ import logging import os from copy import deepcopy -from typing import Any, Dict +from typing import Any from uuid import uuid4 import boto3 @@ -54,7 +54,7 @@ LITELLM_ID = "litellm_id" -def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_set_model_to_deleting(event: dict[str, Any], context: Any) -> dict[str, Any]: """Start deletion workflow based on user-specified model input.""" output_dict = deepcopy(event) model_id = event["modelId"] @@ -81,14 +81,14 @@ def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[st return output_dict -def handle_delete_from_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_from_litellm(event: dict[str, Any], context: Any) -> dict[str, Any]: """Delete model reference from LiteLLM.""" if event[LITELLM_ID]: # if non-null ID litellm_client.delete_model(identifier=event[LITELLM_ID]) return event -def handle_delete_guardrails(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_guardrails(event: dict[str, Any], context: Any) -> dict[str, Any]: """Delete all guardrails associated with the model from both LiteLLM and DynamoDB.""" logger.info(f"Deleting guardrails for model: {event.get('modelId')}") output_dict = deepcopy(event) @@ -150,7 +150,7 @@ def handle_delete_guardrails(event: Dict[str, Any], context: Any) -> Dict[str, A return output_dict -def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Initialize stack deletion.""" stack_arn = event[CFN_STACK_ARN] logger.info(f"Deleting CloudFormation stack: {stack_arn}") @@ -162,7 +162,7 @@ def handle_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return event # no payload mutations needed between this and next state -def handle_monitor_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_monitor_delete_stack(event: dict[str, Any], context: Any) -> dict[str, Any]: """Get stack status while it is being deleted and evaluate if state machine should continue polling.""" output_dict = deepcopy(event) stack_arn = event[CFN_STACK_ARN] @@ -179,7 +179,7 @@ def handle_monitor_delete_stack(event: Dict[str, Any], context: Any) -> Dict[str return output_dict -def handle_delete_from_ddb(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_delete_from_ddb(event: dict[str, Any], context: Any) -> dict[str, Any]: """Delete item from DDB after successful deletion workflow.""" model_key = {"model_id": event["modelId"]} ddb_table.delete_item(Key=model_key) diff --git a/lambda/models/state_machine/schedule_handlers.py b/lambda/models/state_machine/schedule_handlers.py index 8923b9058..8efa510ac 100644 --- a/lambda/models/state_machine/schedule_handlers.py +++ b/lambda/models/state_machine/schedule_handlers.py @@ -15,7 +15,7 @@ import json import logging import os -from typing import Any, Dict +from typing import Any import boto3 from botocore.config import Config @@ -32,7 +32,7 @@ model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME")) -def handle_schedule_creation(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_schedule_creation(event: dict[str, Any], context: Any) -> dict[str, Any]: """Create Auto Scaling scheduled actions for the model if scheduling is configured""" logger.info(f"Processing schedule creation for model: {event.get('modelId')}") output_dict = event.copy() @@ -84,7 +84,7 @@ def handle_schedule_creation(event: Dict[str, Any], context: Any) -> Dict[str, A return output_dict -def handle_schedule_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_schedule_update(event: dict[str, Any], context: Any) -> dict[str, Any]: """Update Auto Scaling scheduled actions when schedule configuration changes""" logger.info(f"Processing schedule update for model: {event.get('modelId')}") output_dict = event.copy() @@ -126,7 +126,7 @@ def handle_schedule_update(event: Dict[str, Any], context: Any) -> Dict[str, Any return output_dict -def handle_cleanup_schedule(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_cleanup_schedule(event: dict[str, Any], context: Any) -> dict[str, Any]: """Clean up scheduled actions before deleting the model""" logger.info(f"Cleaning up schedule for model: {event.get('modelId')}") output_dict = event.copy() diff --git a/lambda/models/state_machine/update_model.py b/lambda/models/state_machine/update_model.py index 029caaf4c..bda54faba 100644 --- a/lambda/models/state_machine/update_model.py +++ b/lambda/models/state_machine/update_model.py @@ -17,12 +17,13 @@ import json import logging import os +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional +from typing import Any import boto3 from models.clients.litellm_client import LiteLLMClient -from models.domain_objects import GuardrailsTableEntry, ModelStatus +from models.domain_objects import GuardrailsTableEntry, ModelStatus, ModelType from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config from utilities.time import now @@ -50,15 +51,15 @@ logging.basicConfig(level=logging.INFO) -def _update_simple_field(model_config: Dict[str, Any], field_name: str, value: Any, model_id: str) -> None: +def _update_simple_field(model_config: dict[str, Any], field_name: str, value: Any, model_id: str) -> None: """Update a simple field in model_config.""" logger.info(f"Setting {field_name} to '{value}' for model '{model_id}'") model_config[field_name] = value def _update_container_config( - model_config: Dict[str, Any], container_config: Dict[str, Any], model_id: str -) -> Dict[str, Any]: + model_config: dict[str, Any], container_config: dict[str, Any], model_id: str +) -> dict[str, Any]: """Handle container config update. Returns: @@ -119,7 +120,7 @@ def _update_container_config( return container_metadata -def _get_metadata_update_handlers(model_config: Dict[str, Any], model_id: str) -> Dict[str, Callable[..., Any]]: +def _get_metadata_update_handlers(model_config: dict[str, Any], model_id: str) -> dict[str, Callable[..., Any]]: """Return a dictionary mapping field names to their update handlers.""" return { "modelType": lambda value: _update_simple_field(model_config, "modelType", value, model_id), @@ -132,8 +133,8 @@ def _get_metadata_update_handlers(model_config: Dict[str, Any], model_id: str) - def _process_metadata_updates( - model_config: Dict[str, Any], update_payload: Dict[str, Any], model_id: str -) -> tuple[bool, Dict[str, Any]]: + model_config: dict[str, Any], update_payload: dict[str, Any], model_id: str +) -> tuple[bool, dict[str, Any]]: """ Process metadata updates. @@ -163,7 +164,7 @@ def _process_metadata_updates( return has_updates, update_metadata -def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_job_intake(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Handle initial UpdateModel job submission. @@ -338,7 +339,7 @@ def handle_job_intake(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_capacity(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Poll autoscaling and target group to confirm if the capacity is done updating. @@ -369,7 +370,7 @@ def handle_poll_capacity(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_finish_update(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Finalize update in DDB. @@ -388,6 +389,10 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: )["Item"] model_url = ddb_item["model_url"] + # Check if this is a video generation model + model_type = ddb_item.get("model_config", {}).get("modelType", "").upper() + is_video_model = model_type == ModelType.VIDEOGEN.value.upper() + # Parse the JSON string from environment variable litellm_config_str = os.environ.get("LITELLM_CONFIG_OBJ", json.dumps({})) try: @@ -397,11 +402,15 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: # Fallback to default if JSON parsing fails litellm_params = {} + # For video generation models, use empty litellm_settings to avoid drop_params error + if is_video_model: + litellm_params = {} + litellm_params["model"] = f"openai/{ddb_item['model_config']['modelName']}" litellm_params["api_base"] = model_url ddb_update_expression = "SET model_status = :ms, last_modified_date = :lm" - ddb_update_values: Dict[str, Any] = { + ddb_update_values: dict[str, Any] = { ":lm": now(), } @@ -441,7 +450,7 @@ def handle_finish_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_update_guardrails(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_update_guardrails(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Update guardrails for a model in LiteLLM and DynamoDB. @@ -734,9 +743,9 @@ def get_ecs_resources_from_stack(stack_name: str) -> tuple[str, str, str]: def create_updated_task_definition( task_definition_arn: str, - updated_env_vars: Dict[str, str], - env_vars_to_delete: Optional[List[str]] = None, - updated_container_config: Optional[Dict[str, Any]] = None, + updated_env_vars: dict[str, str], + env_vars_to_delete: list[str] | None = None, + updated_container_config: dict[str, Any] | None = None, ) -> str: """Create new task definition revision with updated environment variables and container config. @@ -861,7 +870,7 @@ def update_ecs_service(cluster_arn: str, service_arn: str, task_definition_arn: raise RuntimeError(f"Failed to update ECS service: {str(e)}") -def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_ecs_update(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Update ECS task definition with new environment variables and update service. @@ -909,6 +918,7 @@ def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: update_ecs_service(cluster_arn, service_arn, new_task_def_arn) # Set up tracking for deployment monitoring + output_dict["old_task_definition_arn"] = task_definition_arn # Save old task def for cleanup output_dict["new_task_definition_arn"] = new_task_def_arn output_dict["ecs_service_arn"] = service_arn output_dict["ecs_cluster_arn"] = cluster_arn @@ -923,14 +933,15 @@ def handle_ecs_update(event: Dict[str, Any], context: Any) -> Dict[str, Any]: return output_dict -def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_poll_ecs_deployment(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Monitor ECS service deployment progress. This handler will: 1. Check if ECS service deployment is complete - 2. Return boolean for continued polling if needed - 3. Handle deployment failures + 2. Verify that tasks are actually running and healthy + 3. Return boolean for continued polling if needed + 4. Handle deployment failures """ output_dict = deepcopy(event) model_id = event["model_id"] @@ -968,16 +979,49 @@ def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, and task_def.startswith(new_task_def_arn.split(":")[0]) ): primary_deployment = deployment + running_count = deployment.get("runningCount", 0) + desired_count = deployment.get("desiredCount", 0) + pending_count = deployment.get("pendingCount", 0) + rollout_state = deployment.get("rolloutState", "N/A") + logger.info( f"Found matching deployment: status={deployment['status']}, " - f"rolloutState={deployment.get('rolloutState', 'N/A')}" + f"rolloutState={rollout_state}, " + f"running={running_count}, desired={desired_count}, pending={pending_count}" ) - if deployment["status"] != "PRIMARY" or deployment.get("rolloutState") != "COMPLETED": + + # For daemon services, desiredCount may be 0 or match the number of instances + # We need to check that: + # 1. Deployment is PRIMARY + # 2. rolloutState is COMPLETED (or IN_PROGRESS for daemon services) + # 3. There are no pending tasks + # 4. Running count matches desired count (or running > 0 for daemon services) + if deployment["status"] != "PRIMARY": is_deployment_stable = False - logger.info( - f"Deployment not yet stable: status={deployment['status']}, " - f"rolloutState={deployment.get('rolloutState', 'N/A')}" - ) + logger.info(f"Deployment not PRIMARY: status={deployment['status']}") + elif rollout_state == "FAILED": + logger.error(f"Deployment FAILED for model '{model_id}'") + output_dict["ecs_polling_error"] = f"ECS deployment failed for model '{model_id}'" + output_dict["should_continue_ecs_polling"] = False + return output_dict + elif pending_count > 0: + is_deployment_stable = False + logger.info(f"Deployment has pending tasks: {pending_count}") + elif running_count == 0: + is_deployment_stable = False + logger.info("Deployment has no running tasks yet") + elif rollout_state not in ["COMPLETED", None]: + # For daemon services, rolloutState might not be COMPLETED immediately + # but if we have running tasks and no pending, we can consider it stable + if running_count > 0 and pending_count == 0: + logger.info( + f"Deployment has running tasks ({running_count}) with no pending, " + f"considering stable despite rolloutState={rollout_state}" + ) + is_deployment_stable = True + else: + is_deployment_stable = False + logger.info(f"Deployment rolloutState not COMPLETED: {rollout_state}") else: logger.info("Deployment is stable and completed") break @@ -1004,6 +1048,16 @@ def handle_poll_ecs_deployment(event: Dict[str, Any], context: Any) -> Dict[str, if is_deployment_stable: logger.info(f"ECS deployment completed successfully for model '{model_id}'") + + # Deregister old task definition to keep things clean + old_task_def_arn = event.get("old_task_definition_arn") + if old_task_def_arn: + try: + ecs_client.deregister_task_definition(taskDefinition=old_task_def_arn) + logger.info(f"Deregistered old task definition: {old_task_def_arn}") + except Exception as deregister_error: + # Log but don't fail - deregistration is cleanup, not critical + logger.warning(f"Failed to deregister old task definition {old_task_def_arn}: {deregister_error}") else: logger.info(f"ECS deployment still in progress for model '{model_id}', remaining polls: {remaining_polls}") diff --git a/lambda/prompt_templates/lambda_functions.py b/lambda/prompt_templates/lambda_functions.py index 1e3f34ba1..62afe3057 100644 --- a/lambda/prompt_templates/lambda_functions.py +++ b/lambda/prompt_templates/lambda_functions.py @@ -13,12 +13,14 @@ # limitations under the License. """Lambda functions for managing prompt templates in AWS DynamoDB.""" +from __future__ import annotations + import json import logging import os from decimal import Decimal from functools import reduce -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from boto3.dynamodb.conditions import Attr, Key @@ -35,10 +37,10 @@ def _get_prompt_templates( - user_id: Optional[str] = None, - groups: Optional[List] = None, - latest: Optional[bool] = None, -) -> Dict[str, Any]: + user_id: str | None = None, + groups: list[str] | None = None, + latest: bool | None = None, +) -> dict[str, Any]: """Helper function to retrieve prompt templates from DynamoDB.""" filter_expression = None @@ -60,7 +62,7 @@ def _get_prompt_templates( condition = reduce(lambda a, b: a | b, conditions, condition) filter_expression = condition if filter_expression is None else filter_expression & condition - scan_arguments = { + scan_arguments: dict[str, Any] = { "TableName": os.environ["PROMPT_TEMPLATES_TABLE_NAME"], "IndexName": os.environ["PROMPT_TEMPLATES_BY_LATEST_INDEX_NAME"], } @@ -106,7 +108,7 @@ def get(event: dict, context: dict) -> Any: raise ValueError(f"Not authorized to get {prompt_template_id}.") -def is_member(user_groups: List[str], prompt_groups: List[str]) -> bool: +def is_member(user_groups: list[str], prompt_groups: list[str]) -> bool: if "lisa:public" in prompt_groups: return True @@ -114,7 +116,7 @@ def is_member(user_groups: List[str], prompt_groups: List[str]) -> bool: @api_wrapper -def list(event: dict, context: dict) -> Dict[str, Any]: +def list_prompt(event: dict, context: dict) -> dict[str, Any]: """List prompt templates for a user from DynamoDB.""" query_params = event.get("queryStringParameters", {}) user_id, is_admin, groups = get_user_context(event) @@ -186,7 +188,7 @@ def update(event: dict, context: dict) -> Any: @api_wrapper -def delete(event: dict, context: dict) -> Dict[str, str]: +def delete(event: dict, context: dict) -> dict[str, str]: """Logically delete a prompt template from DynamoDB.""" user_id, is_admin, _ = get_user_context(event) prompt_template_id = get_prompt_template_id(event) diff --git a/lambda/prompt_templates/models.py b/lambda/prompt_templates/models.py index 6fe7db806..ef0811581 100644 --- a/lambda/prompt_templates/models.py +++ b/lambda/prompt_templates/models.py @@ -14,7 +14,7 @@ import uuid from enum import StrEnum -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field from utilities.time import iso_string @@ -34,32 +34,32 @@ class PromptTemplateModel(BaseModel): """ # Unique identifier for the prompt template - id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) + id: str | None = Field(default_factory=lambda: str(uuid.uuid4())) # Timestamp of when the prompt template was created - created: Optional[str] = Field(default_factory=iso_string) + created: str | None = Field(default_factory=iso_string) # Owner of the prompt template owner: str # List of groups that have access to the prompt template - groups: List[str] = Field(default=[]) + groups: list[str] = Field(default=[]) # Title of the prompt template title: str # Current revision number of the prompt template - revision: Optional[int] = Field(default=1) + revision: int | None = Field(default=1) # Flag indicating if this is the latest revision - latest: Optional[bool] = Field(default=True) + latest: bool | None = Field(default=True) type: PromptTemplateType = Field(default=PromptTemplateType.PERSONA) # The main body content of the prompt template body: str - def new_revision(self, update: Dict[str, Any]) -> "PromptTemplateModel": + def new_revision(self, update: dict[str, Any]) -> "PromptTemplateModel": """ Create a new revision of the current prompt template. @@ -69,6 +69,7 @@ def new_revision(self, update: Dict[str, Any]) -> "PromptTemplateModel": Returns: PromptTemplateModel: A new instance of PromptTemplateModel with updated attributes. """ - return self.model_copy( # type: ignore + result: PromptTemplateModel = self.model_copy( update=update | {"created": iso_string(), "revision": (self.revision or 0) + 1} ) + return result diff --git a/lambda/repository/collection_repo.py b/lambda/repository/collection_repo.py index 2ed367d89..94ecd6374 100644 --- a/lambda/repository/collection_repo.py +++ b/lambda/repository/collection_repo.py @@ -16,7 +16,7 @@ import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import boto3 from boto3.dynamodb.conditions import Attr, Key @@ -38,7 +38,7 @@ class CollectionRepositoryError(Exception): class CollectionRepository: """Collection repository for DynamoDB operations.""" - def __init__(self, table_name: Optional[str] = None) -> None: + def __init__(self, table_name: str | None = None) -> None: """ Initialize the Collection Repository. @@ -96,7 +96,7 @@ def create(self, collection: RagCollectionConfig) -> RagCollectionConfig: logger.error(f"Unexpected error creating collection: {e}") raise CollectionRepositoryError(f"Unexpected error creating collection: {str(e)}") - def find_by_id(self, collection_id: str, repository_id: str) -> Optional[RagCollectionConfig]: + def find_by_id(self, collection_id: str, repository_id: str) -> RagCollectionConfig | None: """ Find a collection by its ID and repository ID. @@ -133,8 +133,8 @@ def update( self, collection_id: str, repository_id: str, - updates: Dict[str, Any], - expected_version: Optional[str] = None, + updates: dict[str, Any], + expected_version: str | None = None, ) -> RagCollectionConfig: """ Update a collection with optimistic locking. @@ -250,12 +250,12 @@ def list_by_repository( self, repository_id: str, page_size: int = 20, - last_evaluated_key: Optional[Dict[str, str]] = None, - filter_text: Optional[str] = None, - status_filter: Optional[CollectionStatus] = None, + last_evaluated_key: dict[str, str] | None = None, + filter_text: str | None = None, + status_filter: CollectionStatus | None = None, sort_by: CollectionSortBy = CollectionSortBy.CREATED_AT, sort_order: SortOrder = SortOrder.DESC, - ) -> Tuple[List[RagCollectionConfig], Optional[Dict[str, str]]]: + ) -> tuple[list[RagCollectionConfig], dict[str, str] | None]: """ List collections for a repository with pagination, filtering, and sorting. @@ -332,7 +332,7 @@ def list_by_repository( logger.error(f"Failed to list collections for repository {repository_id}: {e}") raise CollectionRepositoryError(f"Failed to list collections: {str(e)}") - def count_by_repository(self, repository_id: str, status: Optional[CollectionStatus] = None) -> int: + def count_by_repository(self, repository_id: str, status: CollectionStatus | None = None) -> int: """ Count collections in a repository. @@ -362,13 +362,13 @@ def count_by_repository(self, repository_id: str, status: Optional[CollectionSta count = response.get("Count", 0) logger.info(f"Counted {count} collections for repository {repository_id}") - return count + return count # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to count collections for repository {repository_id}: {e}") raise CollectionRepositoryError(f"Failed to count collections: {str(e)}") - def find_by_name(self, repository_id: str, collection_name: str) -> Optional[RagCollectionConfig]: + def find_by_name(self, repository_id: str, collection_name: str) -> RagCollectionConfig | None: """ Find a collection by repository ID and name. @@ -402,7 +402,7 @@ def find_by_name(self, repository_id: str, collection_name: str) -> Optional[Rag logger.error(f"Failed to find collection by name '{collection_name}': {e}") raise CollectionRepositoryError(f"Failed to find collection by name: {str(e)}") - def find_collections_using_model(self, model_id: str) -> List[Dict[str, str]]: + def find_collections_using_model(self, model_id: str) -> list[dict[str, str]]: """ Find all collections that use a specific embedding model. Excludes collections with status indicating they are deleted or archived. diff --git a/lambda/repository/collection_service.py b/lambda/repository/collection_service.py index 4750fee6c..1debd8019 100644 --- a/lambda/repository/collection_service.py +++ b/lambda/repository/collection_service.py @@ -18,7 +18,7 @@ import heapq import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import boto3 from models.domain_objects import ( @@ -52,9 +52,9 @@ class CollectionService: def __init__( self, - collection_repo: Optional[CollectionRepository] = None, - vector_store_repo: Optional[VectorStoreRepository] = None, - document_repo: Optional[RagDocumentRepository] = None, + collection_repo: CollectionRepository | None = None, + vector_store_repo: VectorStoreRepository | None = None, + document_repo: RagDocumentRepository | None = None, ): self.collection_repo = collection_repo or CollectionRepository() self.vector_store_repo = vector_store_repo or VectorStoreRepository() @@ -66,7 +66,7 @@ def has_access( self, collection: RagCollectionConfig, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, require_write: bool = False, ) -> bool: @@ -109,7 +109,7 @@ def create_collection( """ # Check if collection name already exists in this repository - existing = self.collection_repo.find_by_name(collection.repositoryId, collection.name) + existing = self.collection_repo.find_by_name(collection.repositoryId, collection.name) # type: ignore[arg-type] if existing: raise ValidationError( f"Collection with name '{collection.name}' already exists in repository '{collection.repositoryId}'" @@ -125,7 +125,7 @@ def get_collection( repository_id: str, collection_id: str, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, ) -> RagCollectionConfig: """Get a collection with access control. @@ -148,11 +148,11 @@ def list_collections( self, repository_id: str, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, page_size: int = 20, - last_evaluated_key: Optional[Dict[str, str]] = None, - ) -> Tuple[List[RagCollectionConfig], Optional[Dict[str, str]]]: + last_evaluated_key: dict[str, str] | None = None, + ) -> tuple[list[RagCollectionConfig], dict[str, str] | None]: """List collections with access control. For Bedrock KB repositories, default collections are persisted to the database @@ -189,7 +189,7 @@ def update_collection( repository_id: str, collection_data: Any, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, ) -> RagCollectionConfig: """Update a collection with access control and name uniqueness validation. @@ -259,12 +259,12 @@ def update_collection( def delete_collection( self, repository_id: str, - collection_id: Optional[str], - embedding_name: Optional[str], + collection_id: str | None, + embedding_name: str | None, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Delete a collection with access control. Args: @@ -293,6 +293,9 @@ def delete_collection( # For regular collections, verify access and update status if not is_default_collection: + if collection_id is None: + raise ValidationError("collection_id is required for non-default collections") + collection = self.collection_repo.find_by_id(collection_id, repository_id) if not collection: raise ValidationError(f"Collection {collection_id} not found") @@ -300,7 +303,11 @@ def delete_collection( raise ValidationError(f"Permission denied to delete collection {collection_id}") # Update collection status to DELETE_IN_PROGRESS - self.collection_repo.update(collection_id, repository_id, {"status": CollectionStatus.DELETE_IN_PROGRESS}) + self.collection_repo.update( + collection_id, + repository_id, + {"status": CollectionStatus.DELETE_IN_PROGRESS}, + ) embedding_model = None # Don't set embedding_model for regular collections else: @@ -354,7 +361,7 @@ def delete_collection( # Add summary if counts available if lisa_managed_count is not None and user_managed_count is not None: - response["summary"] = { + response["summary"] = { # type: ignore[assignment] "lisaManagedDocuments": lisa_managed_count, "userManagedDocuments": user_managed_count, "action": ( @@ -369,8 +376,12 @@ def delete_collection( logger.error(f"Failed to submit deletion job: {e}", exc_info=True) # Update collection status to DELETE_FAILED (only for regular collections) - if not is_default_collection: - self.collection_repo.update(collection_id, repository_id, {"status": CollectionStatus.DELETE_FAILED}) + if not is_default_collection and collection_id is not None: + self.collection_repo.update( + collection_id, + repository_id, + {"status": CollectionStatus.DELETE_FAILED}, + ) raise @@ -379,7 +390,7 @@ def get_collection_by_name( repository_id: str, collection_name: str, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, ) -> RagCollectionConfig: """Get a collection by name with access control.""" @@ -407,9 +418,9 @@ def get_collection_model( repository_id: str, collection_id: str, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, - ) -> Optional[str]: + ) -> str | None: """Get embedding model from collection or repository default. Args: @@ -439,13 +450,13 @@ def get_collection_model( def list_all_user_collections( self, username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, page_size: int = 20, - pagination_token: Optional[Dict[str, Any]] = None, - filter_text: Optional[str] = None, - sort_params: Optional[SortParams] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + pagination_token: dict[str, Any] | None = None, + filter_text: str | None = None, + sort_params: SortParams | None = None, + ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: """ List all collections user has access to across all repositories. @@ -473,7 +484,8 @@ def list_all_user_collections( logger.info( f"Listing all user collections for user={username}, is_admin={is_admin}, " - f"page_size={page_size}, filter={filter_text}, sort_by={sort_params.sort_by.value}" + f"page_size={page_size}, filter={filter_text}, " + f"sort_by={sort_params.sort_by.value}" # type: ignore[union-attr] ) # Get repositories user can access @@ -504,8 +516,8 @@ def list_all_user_collections( return collections, next_token def _get_accessible_repositories( - self, username: str, user_groups: List[str], is_admin: bool - ) -> List[Dict[str, Any]]: + self, username: str, user_groups: list[str], is_admin: bool + ) -> list[dict[str, Any]]: """ Get all repositories user has access to. @@ -527,7 +539,7 @@ def _get_accessible_repositories( logger.debug(f"User has access to {len(accessible)} of {len(all_repos)} repositories") return accessible - def _has_repository_access(self, user_groups: List[str], repository: Dict[str, Any]) -> bool: + def _has_repository_access(self, user_groups: list[str], repository: dict[str, Any]) -> bool: """ Check if user has access to repository based on groups. @@ -553,8 +565,8 @@ def _has_repository_access(self, user_groups: List[str], repository: Dict[str, A return has_access def _enrich_with_repository_metadata( - self, collections: List[RagCollectionConfig], repositories: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + self, collections: list[RagCollectionConfig], repositories: list[dict[str, Any]] + ) -> list[dict[str, Any]]: """ Enrich collections with repository metadata. @@ -586,7 +598,7 @@ def _enrich_with_repository_metadata( return enriched - def _estimate_total_collections(self, repositories: List[Dict[str, Any]]) -> int: + def _estimate_total_collections(self, repositories: list[dict[str, Any]]) -> int: """ Estimate total number of collections across repositories. @@ -609,15 +621,15 @@ def _estimate_total_collections(self, repositories: List[Dict[str, Any]]) -> int def _paginate_collections( self, - repositories: List[Dict[str, Any]], + repositories: list[dict[str, Any]], username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, page_size: int, - pagination_token: Optional[Dict[str, Any]], - filter_text: Optional[str], + pagination_token: dict[str, Any] | None, + filter_text: str | None, sort_params: SortParams, - ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: """ Simple pagination strategy for small-to-medium deployments. @@ -652,7 +664,7 @@ def _paginate_collections( offset = 0 # Aggregate all collections from accessible repositories - all_collections: List[RagCollectionConfig] = [] + all_collections: list[RagCollectionConfig] = [] for repo in repositories: repo_id = repo["repositoryId"] @@ -709,8 +721,8 @@ def _paginate_collections( "offset": end_idx, "filters": { "filter": filter_text, - "sortBy": sort_params.sort_by.value, - "sortOrder": sort_params.sort_order.value, + "sortBy": sort_params.sort_by.value, # type: ignore[union-attr] + "sortOrder": sort_params.sort_order.value, # type: ignore[union-attr] }, } @@ -740,8 +752,8 @@ def _matches_filter(self, collection: RagCollectionConfig, filter_text: str) -> return False def _sort_collections( - self, collections: List[RagCollectionConfig], sort_params: SortParams - ) -> List[RagCollectionConfig]: + self, collections: list[RagCollectionConfig], sort_params: SortParams + ) -> list[RagCollectionConfig]: """ Sort collections by specified field and order. @@ -763,15 +775,15 @@ def _sort_collections( def _paginate_large_collections( self, - repositories: List[Dict[str, Any]], + repositories: list[dict[str, Any]], username: str, - user_groups: List[str], + user_groups: list[str], is_admin: bool, page_size: int, - pagination_token: Optional[Dict[str, Any]], - filter_text: Optional[str], + pagination_token: dict[str, Any] | None, + filter_text: str | None, sort_params: SortParams, - ) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]: """ Scalable pagination strategy for large deployments. @@ -803,8 +815,8 @@ def _paginate_large_collections( token_filters = pagination_token.get("filters", {}) if ( token_filters.get("filter") != filter_text - or token_filters.get("sortBy") != sort_params.sort_by.value - or token_filters.get("sortOrder") != sort_params.sort_order.value + or token_filters.get("sortBy") != sort_params.sort_by.value # type: ignore[union-attr] + or token_filters.get("sortOrder") != sort_params.sort_order.value # type: ignore[union-attr] ): logger.warning("Pagination token filters don't match, resetting cursors") cursors = {} @@ -878,7 +890,11 @@ def _paginate_large_collections( cursors[repo_id]["exhausted"] = True # Merge batches using heap for efficient sorting - merged = self._merge_sorted_batches(batches, sort_params.sort_by.value, sort_params.sort_order.value) + merged = self._merge_sorted_batches( + batches, + sort_params.sort_by.value, # type: ignore[union-attr] + sort_params.sort_order.value, # type: ignore[union-attr] + ) # Extract requested page start_idx = global_offset @@ -907,16 +923,16 @@ def _paginate_large_collections( "seenCollectionIds": serializable_seen_ids, "filters": { "filter": filter_text, - "sortBy": sort_params.sort_by.value, - "sortOrder": sort_params.sort_order.value, + "sortBy": sort_params.sort_by.value, # type: ignore[union-attr] + "sortOrder": sort_params.sort_order.value, # type: ignore[union-attr] }, } return enriched, next_token def _merge_sorted_batches( - self, batches: List[Dict[str, Any]], sort_by: str, sort_order: str - ) -> List[RagCollectionConfig]: + self, batches: list[dict[str, Any]], sort_by: str, sort_order: str + ) -> list[RagCollectionConfig]: """ Merge pre-sorted batches from multiple repositories using min-heap. @@ -935,7 +951,7 @@ def _merge_sorted_batches( return [] # Create heap with first item from each batch - heap: List[Tuple[Any, str, int, Dict[str, Any]]] = [] + heap: list[tuple[Any, str, int, dict[str, Any]]] = [] for batch in batches: if batch["collections"]: diff --git a/lambda/repository/config/params.py b/lambda/repository/config/params.py index 35e03ce92..4b65e4e85 100644 --- a/lambda/repository/config/params.py +++ b/lambda/repository/config/params.py @@ -17,7 +17,7 @@ import json import urllib.parse from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from utilities.constants import DEFAULT_PAGE_SIZE, DEFAULT_TIME_LIMIT_HOURS, MAX_PAGE_SIZE, MIN_PAGE_SIZE from utilities.validation import ValidationError @@ -29,11 +29,11 @@ class ListJobsParams: repository_id: str page_size: int = 10 - last_evaluated_key: Optional[Dict[str, Any]] = None + last_evaluated_key: dict[str, Any] | None = None time_limit_hours: int = DEFAULT_TIME_LIMIT_HOURS @classmethod - def from_event(cls, event: Dict[str, Any]) -> "ListJobsParams": + def from_event(cls, event: dict[str, Any]) -> "ListJobsParams": """Extract and validate parameters from Lambda event.""" path_params = event.get("pathParameters", {}) query_params = event.get("queryStringParameters", {}) or {} @@ -50,25 +50,25 @@ def from_event(cls, event: Dict[str, Any]) -> "ListJobsParams": ) @staticmethod - def _parse_time_limit(query_params: Dict[str, str]) -> int: + def _parse_time_limit(query_params: dict[str, str]) -> int: """Parse time limit from query parameters.""" return int(query_params.get("timeLimit", str(DEFAULT_TIME_LIMIT_HOURS))) @staticmethod - def _parse_page_size(query_params: Dict[str, str]) -> int: + def _parse_page_size(query_params: dict[str, str]) -> int: """Parse and validate page size from query parameters.""" page_size = int(query_params.get("pageSize", str(DEFAULT_PAGE_SIZE))) return max(MIN_PAGE_SIZE, min(page_size, MAX_PAGE_SIZE)) @staticmethod - def _parse_last_evaluated_key(query_params: Dict[str, str]) -> Optional[Dict[str, str]]: + def _parse_last_evaluated_key(query_params: dict[str, str]) -> dict[str, str] | None: """Parse lastEvaluatedKey with specific error handling.""" if "lastEvaluatedKey" not in query_params: return None try: decoded = urllib.parse.unquote(query_params["lastEvaluatedKey"]) - return json.loads(decoded) + return json.loads(decoded) # type: ignore[no-any-return] except json.JSONDecodeError as e: raise ValidationError(f"Invalid JSON in lastEvaluatedKey: {e}") except (TypeError, ValueError) as e: diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py index b691783bf..666d48845 100644 --- a/lambda/repository/embeddings.py +++ b/lambda/repository/embeddings.py @@ -14,11 +14,13 @@ import logging import os -from typing import List +from typing import Any import boto3 import requests -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, ConfigDict, field_validator +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry from utilities.auth import get_management_key from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config from utilities.validation import validate_model_name, ValidationError @@ -30,12 +32,39 @@ lisa_api_endpoint = "" +# Module-level session with connection pooling for better performance +# This reuses TCP connections across multiple embedding requests +_http_session: requests.Session | None = None + + +def _get_http_session() -> requests.Session: + """Get or create a shared HTTP session with connection pooling.""" + global _http_session + if _http_session is None: + _http_session = requests.Session() + # Configure retry strategy for transient failures + retry_strategy = Retry( + total=2, + backoff_factor=0.5, + status_forcelist=[502, 503, 504], + ) + adapter = HTTPAdapter( + pool_connections=10, # Number of connection pools + pool_maxsize=20, # Max connections per pool + max_retries=retry_strategy, + ) + _http_session.mount("http://", adapter) + _http_session.mount("https://", adapter) + return _http_session + class RagEmbeddings(BaseModel): """ Handles document embeddings through LiteLLM using management credentials. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + model_name: str token: str lisa_api_endpoint: str @@ -48,7 +77,7 @@ def validate_model_name(cls, v: str) -> str: validate_model_name(v) return v - def __init__(self, model_name: str, id_token: str | None = None, **data) -> None: + def __init__(self, model_name: str, id_token: str | None = None, **data: Any) -> None: # Prepare initialization data init_data = {"model_name": model_name, **data} try: @@ -69,10 +98,7 @@ def __init__(self, model_name: str, id_token: str | None = None, **data) -> None logger.error("Failed to initialize pipeline embeddings", exc_info=True) raise - class Config: - arbitrary_types_allowed = True - - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: """ Generate embeddings for a list of documents. @@ -92,8 +118,12 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: logger.info(f"Embedding {len(texts)} documents using {self.model_name}") try: url = f"{self.base_url}/embeddings" - request_data = {"input": texts, "model": self.model_name} - response = requests.post( + # Use encoding_format="float" to ensure embeddings are returned as float arrays + request_data = {"input": texts, "model": self.model_name, "encoding_format": "float"} + + # Use shared session with connection pooling for better performance + session = _get_http_session() + response = session.post( url, json=request_data, headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}, @@ -103,6 +133,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: if response.status_code != 200: logger.error(f"Embedding request failed with status {response.status_code}") + logger.error(f"Embedding error response body: {response.text}") raise Exception(f"Embedding request failed with status {response.status_code}") result = response.json() @@ -149,7 +180,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True) raise - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: if not text or not isinstance(text, str): raise ValidationError("Invalid query text") diff --git a/lambda/repository/ingestion_job_repo.py b/lambda/repository/ingestion_job_repo.py index 21f536220..bdfa6351c 100644 --- a/lambda/repository/ingestion_job_repo.py +++ b/lambda/repository/ingestion_job_repo.py @@ -17,7 +17,7 @@ import logging import os from datetime import timedelta -from typing import Dict, Optional +from typing import Any import boto3 from models.domain_objects import IngestionJob, IngestionStatus @@ -27,14 +27,14 @@ logger = logging.getLogger(__name__) -def _get_ingestion_job_table(): +def _get_ingestion_job_table() -> Any: """Lazy initialization of DynamoDB table.""" dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) return dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) class IngestionJobListResponse: - def __init__(self, jobs: list[IngestionJob], continuation_token: Optional[str]): + def __init__(self, jobs: list[IngestionJob], continuation_token: str | None): self.jobs = jobs self.continuation_token = continuation_token @@ -52,25 +52,25 @@ def __init__(self, message: str): class IngestionJobRepository: - def __init__(self): - self._ddb_client = None - self._table_name = None - self._batch_client = None + def __init__(self) -> None: + self._ddb_client: Any = None + self._table_name: str | None = None + self._batch_client: Any = None @property - def ddb_client(self): + def ddb_client(self) -> Any: if self._ddb_client is None: self._ddb_client = boto3.client("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) return self._ddb_client @property - def table_name(self): + def table_name(self) -> str: if self._table_name is None: self._table_name = os.environ["LISA_INGESTION_JOB_TABLE_NAME"] return self._table_name @property - def batch_client(self): + def batch_client(self) -> Any: if self._batch_client is None: self._batch_client = boto3.client("batch", region_name=os.environ["AWS_REGION"], config=retry_config) return self._batch_client @@ -94,7 +94,7 @@ def find_by_path(self, s3_path: str) -> list[IngestionJob]: items = response.get("Items", []) return [IngestionJob(**item) for item in items] - def find_by_document(self, document_id: str) -> Optional[IngestionJob]: + def find_by_document(self, document_id: str) -> IngestionJob | None: response = _get_ingestion_job_table().query( IndexName="documentId", KeyConditionExpression="document_id = :document_id", @@ -129,8 +129,8 @@ def list_jobs_by_repository( is_admin: bool, time_limit_hours: int = 1, page_size: int = 10, - last_evaluated_key: Optional[Dict[str, str]] = None, - ) -> tuple[list[IngestionJob], Optional[Dict[str, str]]]: + last_evaluated_key: dict[str, str] | None = None, + ) -> tuple[list[IngestionJob], dict[str, str] | None]: """List ingestion jobs filtered by repository, user permissions, and time limit with pagination. Args: @@ -164,7 +164,9 @@ def list_jobs_by_repository( # Add username filter for non-admin users if not is_admin: query_params["FilterExpression"] = "username = :username" - query_params["ExpressionAttributeValues"].update({":username": username, ":system_username": "system"}) + expr_attr_values = query_params["ExpressionAttributeValues"] + if isinstance(expr_attr_values, dict): + expr_attr_values.update({":username": username, ":system_username": "system"}) # Add pagination token if provided if last_evaluated_key: @@ -188,7 +190,7 @@ def list_jobs_by_repository( return jobs, last_evaluated_key_response - def get_batch_job_status(self, job_id: str) -> Optional[str]: + def get_batch_job_status(self, job_id: str) -> str | None: """Get the status of a batch job by job ID. Args: @@ -199,10 +201,10 @@ def get_batch_job_status(self, job_id: str) -> Optional[str]: """ response = self.batch_client.describe_jobs(jobs=[job_id]) if response.get("jobs"): - return response["jobs"][0].get("status") + return response["jobs"][0].get("status") # type: ignore[no-any-return] return None - def find_batch_job_for_document(self, document_id: str, job_queue: str) -> Optional[Dict]: + def find_batch_job_for_document(self, document_id: str, job_queue: str) -> dict | None: """Find the batch job associated with a document ingestion. Args: diff --git a/lambda/repository/ingestion_service.py b/lambda/repository/ingestion_service.py index b622891be..f544038c9 100644 --- a/lambda/repository/ingestion_service.py +++ b/lambda/repository/ingestion_service.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Dict, Optional +from typing import Any import boto3 from models.domain_objects import Enum, FixedChunkingStrategy, IngestDocumentRequest, IngestionJob, IngestionType @@ -48,12 +48,12 @@ def create_delete_job(self, job: IngestionJob) -> None: def create_ingestion_job( self, repository: dict, - collection: Optional[dict], + collection: dict | None, request: IngestDocumentRequest, query_params: dict, s3_path: str, username: str, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ingestion_type: IngestionType = IngestionType.MANUAL, ) -> IngestionJob: @@ -113,9 +113,9 @@ def create_ingestion_job( def _merge_metadata_for_ingestion( self, repository: dict, - collection: Optional[dict], - document_metadata: Optional[Dict[str, Any]] = None, - ) -> Optional[Dict[str, Any]]: + collection: dict | None, + document_metadata: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: """ Merge metadata from repository, collection, and document sources for ingestion jobs. @@ -133,7 +133,7 @@ def _merge_metadata_for_ingestion( Returns: Merged metadata dictionary or None if no metadata sources exist """ - merged_metadata: Dict[str, Any] = {} + merged_metadata: dict[str, Any] = {} # 1. Merge repository metadata (lowest precedence) repo_metadata = repository.get("metadata") diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index bc095f680..c9133d05f 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -19,7 +19,7 @@ import os import urllib.parse from types import SimpleNamespace -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast import boto3 from boto3.dynamodb.types import TypeSerializer @@ -89,7 +89,7 @@ @api_wrapper -def list_all(event: dict, context: dict) -> List[Dict[str, Any]]: +def list_all(event: dict, context: dict) -> list[dict[str, Any]]: """ List all available repositories that the user has access to. @@ -122,7 +122,7 @@ def list_status(event: dict, context: dict) -> dict[str, Any]: @api_wrapper -def similarity_search(event: dict, context: dict) -> Dict[str, Any]: +def similarity_search(event: dict, context: dict) -> dict[str, Any]: """Return documents matching the query. Conducts similarity search against the vector store returning the top K @@ -149,11 +149,11 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: """ query_string_params = event.get("queryStringParameters") path_params = event.get("pathParameters") - query = query_string_params.get("query") - top_k = int(query_string_params.get("topK", 3)) - include_score = query_string_params.get("score", "false").lower() == "true" - repository_id = path_params.get("repositoryId") - collection_id = query_string_params.get("collectionId") + query = query_string_params.get("query") # type: ignore[union-attr] + top_k = int(query_string_params.get("topK", 3)) # type: ignore[union-attr] + include_score = query_string_params.get("score", "false").lower() == "true" # type: ignore[union-attr] + repository_id = path_params.get("repositoryId") # type: ignore[union-attr] + collection_id = query_string_params.get("collectionId") # type: ignore[union-attr] repository = get_repository(event, repository_id=repository_id) @@ -165,13 +165,13 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: model_name = ( collection_service.get_collection_model( repository_id=repository_id, - collection_id=collection_id if not is_default else None, + collection_id=collection_id if not is_default else None, # type: ignore[arg-type] username=username, user_groups=groups, is_admin=is_admin, ) if collection_id - else query_string_params.get("modelName") + else query_string_params.get("modelName") # type: ignore[union-attr] ) if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): @@ -191,9 +191,9 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: # Delegate to service for retrieval - service handles repository-specific logic docs = service.retrieve_documents( query=query, - collection_id=search_collection_id, + collection_id=search_collection_id, # type: ignore[arg-type] top_k=top_k, - model_name=model_name, + model_name=model_name, # type: ignore[arg-type] include_score=include_score, bedrock_agent_client=bedrock_client, ) @@ -229,7 +229,7 @@ def get_repository(event: dict[str, Any], repository_id: str) -> dict[str, Any]: return repo -def create_bedrock_collection(event: dict, context: dict) -> Dict[str, Any]: +def create_bedrock_collection(event: dict, context: dict) -> dict[str, Any]: """ Create collections for a Bedrock Knowledge Base repository based on pipeline configurations. This is called by the state machine during repository creation. @@ -311,7 +311,7 @@ def create_bedrock_collection(event: dict, context: dict) -> Dict[str, Any]: ) # Create collection using service helper - collection = service._create_collection_for_data_source( + collection = service._create_collection_for_data_source( # type: ignore[attr-defined] data_source_id=collection_id, s3_uri=s3_uri, is_default=False, collection_name=collection_name ) @@ -366,7 +366,7 @@ def create_bedrock_collection(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper @admin_only -def create_collection(event: dict, context: dict) -> Dict[str, Any]: +def create_collection(event: dict, context: dict) -> dict[str, Any]: """ Create a new collection within a vector store. @@ -434,7 +434,7 @@ def create_collection(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper -def get_collection(event: dict, context: dict) -> Dict[str, Any]: +def get_collection(event: dict, context: dict) -> dict[str, Any]: """ Get a collection by ID within a vector store. @@ -493,7 +493,7 @@ def get_collection(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper @admin_only -def update_collection(event: dict, context: dict) -> Dict[str, Any]: +def update_collection(event: dict, context: dict) -> dict[str, Any]: """ Update a collection within a vector store. @@ -555,7 +555,7 @@ def update_collection(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper @admin_only -def delete_collection(event: dict, context: dict) -> Dict[str, Any]: +def delete_collection(event: dict, context: dict) -> dict[str, Any]: """ Delete a collection (regular or default) within a vector store. @@ -598,7 +598,7 @@ def delete_collection(event: dict, context: dict) -> Dict[str, Any]: is_default_collection = repo.get("embeddingModelId") == collection_id # Delete collection via service - result: Dict[str, Any] = collection_service.delete_collection( + result: dict[str, Any] = collection_service.delete_collection( repository_id=repository_id, collection_id=collection_id, # None for default collections embedding_name=embedding_name if is_default_collection else None, # None for regular collections @@ -611,7 +611,7 @@ def delete_collection(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper -def list_collections(event: dict, context: dict) -> Dict[str, Any]: +def list_collections(event: dict, context: dict) -> dict[str, Any]: """ List collections in a repository with pagination, filtering, and sorting. @@ -724,7 +724,7 @@ def list_collections(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper -def list_user_collections(event: dict, context: dict) -> Dict[str, Any]: +def list_user_collections(event: dict, context: dict) -> dict[str, Any]: """ List all collections user has access to across all repositories. @@ -825,7 +825,7 @@ def _ensure_document_ownership(event: dict[str, Any], docs: list[RagDocument]) - @api_wrapper -def delete_documents(event: dict, context: dict) -> Dict[str, Any]: +def delete_documents(event: dict, context: dict) -> dict[str, Any]: """Purge all records related to the specified document from the RAG repository. If a documentId is supplied, a single document will be removed. If a documentName is supplied, all documents with that name will be removed @@ -864,7 +864,14 @@ def delete_documents(event: dict, context: dict) -> Dict[str, Any]: rag_documents: list[RagDocument] = [] if document_ids: - rag_documents = [doc_repo.find_by_id(document_id=document_id) for document_id in document_ids] + rag_documents = [ + doc + for doc in ( + doc_repo.find_by_id(document_id=document_id) + for document_id in document_ids # type: ignore[arg-type,unused-ignore] + ) + if doc is not None + ] if not rag_documents: raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}") @@ -966,7 +973,7 @@ def ingest_documents(event: dict, context: dict) -> dict: repository = get_repository(event, repository_id=repository_id) # Get collection if specified - collection: Optional[dict[str, Any]] = None + collection: dict[str, Any] | None = None if request.collectionId and request.collectionId != repository.get("embeddingModelId"): collection = collection_service.get_collection( collection_id=request.collectionId, @@ -998,7 +1005,10 @@ def ingest_documents(event: dict, context: dict) -> dict: # Upload metadata file try: s3_metadata_manager.upload_metadata_file( - s3_client=s3, bucket=bucket, document_key=key, metadata_content=job.metadata + s3_client=s3, + bucket=bucket, + document_key=key, + metadata_content=job.metadata, # type: ignore[arg-type] ) logger.info(f"Uploaded metadata file for {key}") except Exception as e: @@ -1007,7 +1017,7 @@ def ingest_documents(event: dict, context: dict) -> dict: jobs.append({"jobId": job.id, "documentId": job.document_id, "status": job.status, "s3Path": job.s3_path}) collection_id = job.collection_id - collection_name: Optional[str] = None + collection_name: str | None = None if collection: collection_name = collection.get("name") if not collection_name: @@ -1017,7 +1027,7 @@ def ingest_documents(event: dict, context: dict) -> dict: @api_wrapper -def get_document(event: dict, context: dict) -> Dict[str, Any]: +def get_document(event: dict, context: dict) -> dict[str, Any]: """Get a document by ID. Args: @@ -1039,7 +1049,7 @@ def get_document(event: dict, context: dict) -> Dict[str, Any]: _ = get_repository(event, repository_id=repository_id) doc = doc_repo.find_by_id(document_id=document_id) - result: dict[str, Any] = doc.model_dump() + result: dict[str, Any] = doc.model_dump() # type: ignore[union-attr] return result @@ -1065,9 +1075,9 @@ def download_document(event: dict, context: dict) -> str: if not repository_id: raise ValidationError("repositoryId is required") _ = get_repository(event, repository_id=repository_id) - doc = doc_repo.find_by_id(document_id=document_id) + doc = doc_repo.find_by_id(document_id=document_id) # type: ignore[arg-type] - source = doc.source + source = doc.source # type: ignore[union-attr] bucket, key = source.replace("s3://", "").split("/", 1) url: str = s3.generate_presigned_url( @@ -1144,7 +1154,7 @@ def list_docs(event: dict, context: dict) -> dict[str, Any]: query_string_params = event.get("queryStringParameters", {}) or {} collection_id = query_string_params.get("collectionId") - last_evaluated: Optional[dict[str, Optional[str]]] = None + last_evaluated: dict[str, str | None] | None = None if not repository_id: raise ValidationError("repositoryId is required") @@ -1186,7 +1196,7 @@ def list_docs(event: dict, context: dict) -> dict[str, Any]: @api_wrapper -def list_jobs(event: Dict[str, Any], context: dict) -> Dict[str, Any]: +def list_jobs(event: dict[str, Any], context: dict) -> dict[str, Any]: """List ingestion jobs for a specific repository with filtering and pagination. Args: @@ -1291,7 +1301,7 @@ def create(event: dict, context: dict) -> Any: "Please select at least one data source." ) # Convert bedrockKnowledgeBaseConfig to pipelines - vector_store_config.pipelines = build_pipeline_configs_from_kb_config( + vector_store_config.pipelines = build_pipeline_configs_from_kb_config( # type: ignore[assignment] vector_store_config.bedrockKnowledgeBaseConfig ) @@ -1317,7 +1327,7 @@ def create(event: dict, context: dict) -> Any: @api_wrapper -def get_repository_by_id(event: dict, context: dict) -> Dict[str, Any]: +def get_repository_by_id(event: dict, context: dict) -> dict[str, Any]: """ Get a vector store configuration by ID. @@ -1411,7 +1421,7 @@ def _validate_immutable_pipeline_fields(current_pipelines: list, new_pipelines: @api_wrapper @admin_only -def update_repository(event: dict, context: dict) -> Dict[str, Any]: +def update_repository(event: dict, context: dict) -> dict[str, Any]: """ Update a vector store configuration. This function is only accessible by administrators. @@ -1519,11 +1529,13 @@ def update_repository(event: dict, context: dict) -> Dict[str, Any]: # If metadata provided but missing tags, preserve existing tags elif "tags" not in current_meta and "tags" in existing_meta: pipeline["metadata"]["tags"] = existing_meta["tags"] - logger.info(f"Preserved tags for collection {collection_id}: {existing_meta['tags']}") + logger.info(f"Preserved tags for collection {collection_id}: " f"{existing_meta['tags']}") # Check if pipeline configuration has changed # Use the converted pipelines from updates if available, otherwise use request.pipelines - new_pipelines = updates.get("pipelines") if "pipelines" in updates else request.pipelines + new_pipelines = ( + updates.get("pipelines") if "pipelines" in updates else request.pipelines # type: ignore[assignment] + ) # Validate immutable pipeline fields for existing repositories if new_pipelines is not None and current_pipelines: @@ -1550,7 +1562,7 @@ def update_repository(event: dict, context: dict) -> Dict[str, Any]: # Check if pipelines were added or removed if current_pipeline_keys != new_pipeline_keys: - added = new_pipeline_keys - current_pipeline_keys + added = new_pipeline_keys - current_pipeline_keys # type: ignore[assignment] removed = current_pipeline_keys - new_pipeline_keys logger.info(f"Pipeline changes detected: added={list(added)}, removed={list(removed)}") require_deployment = True @@ -1692,7 +1704,7 @@ def _remove_legacy(repository_id: str) -> None: @api_wrapper -def list_bedrock_knowledge_bases(event: dict, context: dict) -> Dict[str, Any]: +def list_bedrock_knowledge_bases(event: dict, context: dict) -> dict[str, Any]: """ List all ACTIVE Bedrock Knowledge Bases in the AWS account. @@ -1749,7 +1761,7 @@ def list_bedrock_knowledge_bases(event: dict, context: dict) -> Dict[str, Any]: @api_wrapper -def list_bedrock_data_sources(event: dict, context: dict) -> Dict[str, Any]: +def list_bedrock_data_sources(event: dict, context: dict) -> dict[str, Any]: """ List data sources for a specific Bedrock Knowledge Base. diff --git a/lambda/repository/metadata_generator.py b/lambda/repository/metadata_generator.py index 89425215e..fe3f149f2 100644 --- a/lambda/repository/metadata_generator.py +++ b/lambda/repository/metadata_generator.py @@ -17,7 +17,7 @@ import json import logging import re -from typing import Any, Dict, Optional +from typing import Any from models.domain_objects import RagCollectionConfig from utilities.validation import ValidationError @@ -58,11 +58,11 @@ def _extract_tags_from_metadata(metadata: Any) -> set: @staticmethod def merge_metadata( - repository: Dict[str, Any], - collection: Optional[Dict[str, Any]], - document_metadata: Optional[Dict[str, Any]] = None, + repository: dict[str, Any], + collection: dict[str, Any] | None, + document_metadata: dict[str, Any] | None = None, for_bedrock_kb: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Merge metadata from repository, collection, and document sources. @@ -79,11 +79,11 @@ def merge_metadata( Returns: Merged metadata dictionary """ - merged_metadata: Dict[str, Any] = {} + merged_metadata: dict[str, Any] = {} all_tags: set = set() # Helper function to merge non-tag metadata - def merge_non_tag_metadata(metadata_source: Dict[str, Any]) -> None: + def merge_non_tag_metadata(metadata_source: dict[str, Any]) -> None: for key, value in metadata_source.items(): if key != "tags" and not isinstance(value, dict): merged_metadata[key] = value @@ -127,10 +127,10 @@ def merge_non_tag_metadata(metadata_source: Dict[str, Any]) -> None: @staticmethod def generate_metadata_json( - repository: Dict[str, Any], - collection: Optional[RagCollectionConfig], - document_metadata: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + repository: dict[str, Any], + collection: RagCollectionConfig | None, + document_metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: """Generate metadata.json content for Bedrock KB. Merges metadata from three sources with precedence: @@ -174,7 +174,7 @@ def generate_metadata_json( return {"metadataAttributes": merged_metadata} @staticmethod - def validate_metadata(metadata: Dict[str, Any]) -> bool: + def validate_metadata(metadata: dict[str, Any]) -> bool: """Validate metadata against Bedrock KB requirements. Args: diff --git a/lambda/repository/pipeline_delete_documents.py b/lambda/repository/pipeline_delete_documents.py index f5eb30da7..e3682baf4 100644 --- a/lambda/repository/pipeline_delete_documents.py +++ b/lambda/repository/pipeline_delete_documents.py @@ -14,7 +14,7 @@ import logging import os -from typing import Any, Dict +from typing import Any import boto3 from boto3.dynamodb.conditions import Key @@ -107,9 +107,9 @@ def pipeline_delete_collection(job: IngestionJob) -> None: # Drop index for faster cleanup (OpenSearch/PGVector) # This removes all embeddings from the vector store if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH): - drop_opensearch_index(job.repository_id, job.collection_id) + drop_opensearch_index(job.repository_id, job.collection_id) # type: ignore[arg-type] elif RepositoryType.is_type(repository, RepositoryType.PGVECTOR): - drop_pgvector_collection(job.repository_id, job.collection_id) + drop_pgvector_collection(job.repository_id, job.collection_id) # type: ignore[arg-type] elif RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): # For Bedrock KB, use bulk delete for efficiency # Only delete LISA-managed documents (MANUAL/AUTO), preserve EXISTING @@ -165,13 +165,13 @@ def pipeline_delete_collection(job: IngestionJob) -> None: # Delete all documents and subdocuments from DynamoDB # This method handles pagination and batch deletion logger.info(f"Deleting all documents from DynamoDB for collection {job.collection_id}") - rag_document_repository.delete_all(job.repository_id, job.collection_id) + rag_document_repository.delete_all(job.repository_id, job.collection_id) # type: ignore[arg-type] logger.info("Successfully deleted all documents from DynamoDB") # Delete collection DB entry is_default_collection = job.embedding_model is not None if not is_default_collection: - collection_repo.delete(job.collection_id, job.repository_id) + collection_repo.delete(job.collection_id, job.repository_id) # type: ignore[arg-type] # Update job status ingestion_job_repository.update_status(job, IngestionStatus.DELETE_COMPLETED) @@ -183,7 +183,11 @@ def pipeline_delete_collection(job: IngestionJob) -> None: # Update collection status to DELETE_FAILED try: - collection_repo.update(job.collection_id, job.repository_id, {"status": CollectionStatus.DELETE_FAILED}) + collection_repo.update( + job.collection_id, # type: ignore[arg-type] + job.repository_id, + {"status": CollectionStatus.DELETE_FAILED}, + ) except Exception as update_error: logger.error(f"Failed to update collection status: {update_error}") @@ -221,7 +225,7 @@ def pipeline_delete_document(job: IngestionJob) -> None: logger.info(f"Deleting document {job.s3_path} for repository {job.repository_id}") # Find associated RagDocument - rag_document = rag_document_repository.find_by_id(job.document_id, join_docs=True) + rag_document = rag_document_repository.find_by_id(job.document_id, join_docs=True) # type: ignore[arg-type] if rag_document: # Actually remove from vector store @@ -290,7 +294,7 @@ def pipeline_delete_documents(job: IngestionJob) -> None: failed = 0 errors = [] # For Bedrock KB, group S3 paths by data source (collection_id) - s3_paths_by_data_source = {} + s3_paths_by_data_source = {} # type: ignore[var-annotated] for document_id in document_ids: try: @@ -371,7 +375,7 @@ def pipeline_delete_documents(job: IngestionJob) -> None: raise Exception(error_msg) -def handle_pipeline_delete_event(event: Dict[str, Any], context: Any) -> None: +def handle_pipeline_delete_event(event: dict[str, Any], context: Any) -> None: """Handle pipeline document deletion for S3 ObjectRemoved events.""" # Extract and validate inputs logger.debug(f"Received event: {event}") diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 79402a0d2..d3ba11755 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -17,7 +17,7 @@ import logging import os from datetime import timedelta -from typing import Any, Dict, List +from typing import Any import boto3 from models.domain_objects import ( @@ -87,7 +87,7 @@ def pipeline_ingest_document(job: IngestionJob) -> None: try: kb_bucket = get_datasource_bucket_for_collection( repository=repository, - collection_id=job.collection_id, + collection_id=job.collection_id, # type: ignore[arg-type] ) except ValueError as e: error_msg = str(e) @@ -113,7 +113,7 @@ def pipeline_ingest_document(job: IngestionJob) -> None: # Check if document already exists (idempotent operation) existing_docs = list( rag_document_repository.find_by_source( - job.repository_id, job.collection_id, kb_s3_path, join_docs=False + job.repository_id, job.collection_id, kb_s3_path, join_docs=False # type: ignore[arg-type] ) ) @@ -165,7 +165,7 @@ def pipeline_ingest_document(job: IngestionJob) -> None: collection = None try: collection = collection_service.get_collection( - collection_id=job.collection_id, + collection_id=job.collection_id, # type: ignore[arg-type] repository_id=job.repository_id, username="system", user_groups=[], @@ -203,18 +203,23 @@ def pipeline_ingest_document(job: IngestionJob) -> None: # Non-Bedrock KB path documents = generate_chunks(job) - texts, metadatas = prepare_chunks(documents, job.repository_id, job.collection_id) + texts, metadatas = prepare_chunks(documents, job.repository_id, job.collection_id) # type: ignore[arg-type] all_ids = store_chunks_in_vectorstore( texts=texts, metadatas=metadatas, repository_id=job.repository_id, - collection_id=job.collection_id, - embedding_model=job.embedding_model, + collection_id=job.collection_id, # type: ignore[arg-type] + embedding_model=job.embedding_model, # type: ignore[arg-type] ) # remove old for rag_document in list( - rag_document_repository.find_by_source(job.repository_id, job.collection_id, job.s3_path, join_docs=True) + rag_document_repository.find_by_source( + job.repository_id, + job.collection_id, # type: ignore[arg-type] + job.s3_path, + join_docs=True, + ) ): prev_job = ingestion_job_repository.find_by_document(rag_document.document_id) @@ -323,7 +328,7 @@ def pipeline_ingest_documents(job: IngestionJob) -> None: errors.append(error_msg) # Update job with document IDs - job.document_ids = document_ids + job.document_ids = document_ids # type: ignore[assignment] if failed == 0: ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_COMPLETED) @@ -385,7 +390,7 @@ def _handle_s3_discovery_scan(job: IngestionJob) -> None: # Perform discovery and ingestion result = discovery_service.discover_and_ingest_documents( repository_id=job.repository_id, - collection_id=job.collection_id, + collection_id=job.collection_id, # type: ignore[arg-type] s3_bucket=s3_bucket, s3_prefix=s3_prefix, ingestion_type=job.ingestion_type, @@ -417,10 +422,10 @@ def remove_document_from_vectorstore(doc: RagDocument) -> None: collection_id=doc.collection_id, embeddings=embeddings, ) - vector_store.delete(doc.subdocs) + vector_store.delete(doc.subdocs) # type: ignore[union-attr] -def handle_pipeline_ingest_event(event: Dict[str, Any], context: Any) -> None: +def handle_pipeline_ingest_event(event: dict[str, Any], context: Any) -> None: """Handle pipeline document ingestion.""" # Extract and validate inputs logger.debug(f"Received event: {event}") @@ -452,13 +457,20 @@ def handle_pipeline_ingest_event(event: Dict[str, Any], context: Any) -> None: data_sources = bedrock_config.get("dataSources", []) if data_sources: first_data_source = data_sources[0] - if isinstance(first_data_source, dict): - collection_id = first_data_source.get("id") - else: - collection_id = getattr(first_data_source, "id", None) + collection_id_val: str | None = ( + first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id + ) + if not collection_id_val: + logger.error(f"Bedrock KB repository {repository_id} has invalid data source") + return + collection_id = collection_id_val else: # Try legacy single data source ID - collection_id = bedrock_config.get("bedrockKnowledgeDatasourceId") + collection_id_val = bedrock_config.get("bedrockKnowledgeDatasourceId") + if not collection_id_val: + logger.error(f"Bedrock KB repository {repository_id} missing data source ID") + return + collection_id = collection_id_val if not collection_id: logger.error(f"Bedrock KB repository {repository_id} missing data source ID") @@ -529,7 +541,7 @@ def handle_pipeline_ingest_event(event: Dict[str, Any], context: Any) -> None: logger.info(f"Submitted ingestion job for document {s3_path} in repository {repository_id}") -def handle_pipline_ingest_schedule(event: Dict[str, Any], context: Any) -> None: +def handle_pipline_ingest_schedule(event: dict[str, Any], context: Any) -> None: """ Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. @@ -623,7 +635,7 @@ def handle_pipline_ingest_schedule(event: Dict[str, Any], context: Any) -> None: raise e -def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]: +def batch_texts(texts: list[str], metadatas: list[dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]: """ Split texts and metadata into batches of specified size. @@ -642,7 +654,7 @@ def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) return batches -def extract_chunk_strategy(pipeline_config: Dict) -> ChunkingStrategy: +def extract_chunk_strategy(pipeline_config: dict) -> ChunkingStrategy: """ Extract and validate chunking strategy from pipeline configuration. @@ -665,7 +677,8 @@ def extract_chunk_strategy(pipeline_config: Dict) -> ChunkingStrategy: if chunk_type == "fixed": # Use Pydantic model validation for type safety and validation - return FixedChunkingStrategy.model_validate(chunking_strategy) + result: FixedChunkingStrategy = FixedChunkingStrategy.model_validate(chunking_strategy) + return result else: # Future: Handle other chunking strategy types (semantic, recursive, etc.) raise ValueError(f"Unsupported chunking strategy type: {chunk_type}") @@ -683,7 +696,7 @@ def extract_chunk_strategy(pipeline_config: Dict) -> ChunkingStrategy: return FixedChunkingStrategy(size=512, overlap=51) -def prepare_chunks(docs: List, repository_id: str, collection_id: str) -> tuple[List[str], List[Dict]]: +def prepare_chunks(docs: list, repository_id: str, collection_id: str) -> tuple[list[str], list[dict]]: """Prepare texts and metadata from document chunks.""" texts = [] metadatas = [] @@ -696,8 +709,8 @@ def prepare_chunks(docs: List, repository_id: str, collection_id: str) -> tuple[ def store_chunks_in_vectorstore( - texts: List[str], metadatas: List[Dict], repository_id: str, collection_id: str, embedding_model: str -) -> List[str]: + texts: list[str], metadatas: list[dict], repository_id: str, collection_id: str, embedding_model: str +) -> list[str]: """Store document chunks in vector store using repository service.""" vs_repo = VectorStoreRepository() repository = vs_repo.find_repository_by_id(repository_id) @@ -717,7 +730,7 @@ def store_chunks_in_vectorstore( for i, (text_batch, metadata_batch) in enumerate(batches, 1): logger.info(f"Processing batch {i}/{total_batches} with {len(text_batch)} texts") - batch_ids = vs.add_texts(texts=text_batch, metadatas=metadata_batch) + batch_ids = vs.add_texts(texts=text_batch, metadatas=metadata_batch) # type: ignore[union-attr] if not batch_ids: raise Exception(f"Failed to store documents in vector store for batch {i}") all_ids.extend(batch_ids) diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py index 8c3583179..d688da3bf 100644 --- a/lambda/repository/rag_document_repo.py +++ b/lambda/repository/rag_document_repo.py @@ -13,8 +13,8 @@ # limitations under the License. import logging import os +from collections.abc import Generator from concurrent.futures import as_completed, ThreadPoolExecutor -from typing import Generator, Optional import boto3 from boto3.dynamodb.conditions import Key @@ -94,7 +94,7 @@ def save(self, document: RagDocument) -> None: logging.error(f"Error saving document: {e.response['Error']['Message']}") raise - def find_by_id(self, document_id: str, join_docs: bool = False) -> Optional[RagDocument]: + def find_by_id(self, document_id: str, join_docs: bool = False) -> RagDocument | None: """Query documents using GSI. Args: @@ -167,7 +167,7 @@ def find_by_name( def find_by_source( self, repository_id: str, collection_id: str, document_source: str, join_docs: bool = False - ) -> Generator[RagDocument, None, None]: + ) -> Generator[RagDocument]: """Get a list of documents from the RagDocTable by source. Args: @@ -199,7 +199,7 @@ def find_by_source( yield from self._yield_documents(response["Items"], join_docs=join_docs) - def _yield_documents(self, items: list[dict], join_docs: bool) -> Generator[RagDocument, None, None]: + def _yield_documents(self, items: list[dict], join_docs: bool) -> Generator[RagDocument]: for item in items: document = RagDocument(**item) if join_docs: @@ -209,11 +209,11 @@ def _yield_documents(self, items: list[dict], join_docs: bool) -> Generator[RagD def list_all( self, repository_id: str, - collection_id: Optional[str] = None, - last_evaluated_key: Optional[dict] = None, + collection_id: str | None = None, + last_evaluated_key: dict | None = None, limit: int = 100, join_docs: bool = False, - ) -> tuple[list[RagDocument], Optional[dict], int]: + ) -> tuple[list[RagDocument], dict | None, int]: """List all documents in a collection. Args: @@ -261,7 +261,7 @@ def list_all( logging.error(f"Error listing documents: {e.response['Error']['Message']}") raise - def count_documents(self, repository_id: str, collection_id: Optional[str] = None) -> int: + def count_documents(self, repository_id: str, collection_id: str | None = None) -> int: """Count total documents in a repository/collection. Args: repository_id: Repository ID @@ -269,7 +269,7 @@ def count_documents(self, repository_id: str, collection_id: Optional[str] = Non Returns: Total number of documents """ - count = 0 + count: int = 0 # Count all rag documents using repo id only if not collection_id: response = self.doc_table.query( @@ -277,11 +277,11 @@ def count_documents(self, repository_id: str, collection_id: Optional[str] = Non KeyConditionExpression=Key("repository_id").eq(repository_id), Select="COUNT", ) - count = response.get("Count", 0) + count = int(response.get("Count", 0)) else: pk = RagDocument.createPartitionKey(repository_id, collection_id) response = self.doc_table.query(KeyConditionExpression=Key("pk").eq(pk), Select="COUNT") - count = response.get("Count", 0) + count = int(response.get("Count", 0)) return count def find_subdocs_by_id(self, document_id: str) -> list[RagSubDocument]: diff --git a/lambda/repository/s3_metadata_manager.py b/lambda/repository/s3_metadata_manager.py index 599ef176e..ac64c3177 100644 --- a/lambda/repository/s3_metadata_manager.py +++ b/lambda/repository/s3_metadata_manager.py @@ -16,7 +16,7 @@ import json import logging -from typing import Any, Dict, List, Tuple +from typing import Any from botocore.exceptions import ClientError @@ -32,10 +32,10 @@ class S3MetadataManager: def upload_metadata_file( self, - s3_client, + s3_client: Any, bucket: str, document_key: str, - metadata_content: Dict[str, Any], + metadata_content: dict[str, Any], ) -> str: """Upload metadata.json file to S3. @@ -90,10 +90,12 @@ def upload_metadata_file( else: logger.error(f"Failed to upload metadata file after {MAX_RETRIES} attempts: {metadata_key}") raise + # This should never be reached due to the raise above, but mypy needs it + raise RuntimeError(f"Failed to upload metadata file: {metadata_key}") # pragma: no cover def delete_metadata_file( self, - s3_client, + s3_client: Any, bucket: str, document_key: str, ) -> None: @@ -132,7 +134,9 @@ def delete_metadata_file( # Log other errors but don't fail logger.warning(f"Failed to delete metadata file: {metadata_key}, error: {e}") - def batch_upload_metadata(self, s3_client, bucket: str, documents: List[Tuple[str, Dict[str, Any]]]) -> List[str]: + def batch_upload_metadata( + self, s3_client: Any, bucket: str, documents: list[tuple[str, dict[str, Any]]] + ) -> list[str]: """Upload multiple metadata files in batch. Args: @@ -163,7 +167,7 @@ def batch_upload_metadata(self, s3_client, bucket: str, documents: List[Tuple[st return uploaded_keys - def batch_delete_metadata(self, s3_client, bucket: str, document_keys: List[str]) -> int: + def batch_delete_metadata(self, s3_client: Any, bucket: str, document_keys: list[str]) -> int: """Delete multiple metadata files in batch. Args: diff --git a/lambda/repository/services/bedrock_kb_repository_service.py b/lambda/repository/services/bedrock_kb_repository_service.py index 13fbf7773..e4fe516b4 100644 --- a/lambda/repository/services/bedrock_kb_repository_service.py +++ b/lambda/repository/services/bedrock_kb_repository_service.py @@ -16,7 +16,7 @@ import logging import os -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from boto3.dynamodb.conditions import Key @@ -55,14 +55,14 @@ def should_create_default_collection(self) -> bool: """Bedrock KB does not need virtual default collections.""" return False - def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + def get_collection_id_from_config(self, pipeline_config: dict[str, Any]) -> str: """For Bedrock KB, collection ID is the data source ID. Extracts the data source ID from the pipeline config's collectionId field, which should match one of the data sources in bedrockKnowledgeBaseConfig. """ # The pipeline config should have a collectionId that matches a data source ID - collection_id = pipeline_config.get("collectionId") + collection_id: str | None = pipeline_config.get("collectionId") if collection_id: return collection_id @@ -74,10 +74,11 @@ def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: data_sources = bedrock_config.get("dataSources", []) if data_sources: first_data_source = data_sources[0] - data_source_id = ( + data_source_id: str | None = ( first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id ) - return data_source_id + if data_source_id: + return data_source_id # Try legacy single data source ID data_source_id = bedrock_config.get("bedrockKnowledgeDatasourceId") @@ -89,8 +90,8 @@ def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: def ingest_document( self, job: IngestionJob, - texts: List[str], - metadatas: List[Dict[str, Any]], + texts: list[str], + metadatas: list[dict[str, Any]], ) -> RagDocument: """Track document for Bedrock KB - KB handles actual ingestion. @@ -108,8 +109,17 @@ def ingest_document( os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"] ) + # Ensure collection_id is not None + if not job.collection_id: + raise ValueError("collection_id is required for document ingestion") + existing_docs = list( - rag_document_repository.find_by_source(job.repository_id, job.collection_id, kb_s3_path, join_docs=False) + rag_document_repository.find_by_source( + job.repository_id, + job.collection_id, + kb_s3_path, + join_docs=False, + ) ) if existing_docs: @@ -140,7 +150,7 @@ def delete_document( self, document: RagDocument, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete document from Bedrock KB.""" if not bedrock_agent_client: @@ -166,7 +176,7 @@ def delete_collection( self, collection_id: str, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete all LISA-managed documents from Bedrock KB collection. @@ -229,8 +239,8 @@ def retrieve_documents( top_k: int, model_name: str, include_score: bool = False, - bedrock_agent_client: Optional[Any] = None, - ) -> List[Dict[str, Any]]: + bedrock_agent_client: Any | None = None, + ) -> list[dict[str, Any]]: """Retrieve documents from Bedrock KB using retrieve API. Args: @@ -249,7 +259,7 @@ def retrieve_documents( bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) # Support both field names for backward compatibility - kb_id = bedrock_config.get("knowledgeBaseId", bedrock_config.get("bedrockKnowledgeBaseId")) + kb_id: str | None = bedrock_config.get("knowledgeBaseId", bedrock_config.get("bedrockKnowledgeBaseId")) if not kb_id: raise ValueError( @@ -263,7 +273,7 @@ def retrieve_documents( logger.info(f"Retrieving from KB: kb_id={kb_id}, data_source={collection_id}, query={query[:50]}...") # Build retrieve params with data source filter - retrieve_params = { + retrieve_params: dict[str, Any] = { "knowledgeBaseId": kb_id, "retrievalQuery": {"text": query}, "retrievalConfiguration": { @@ -276,7 +286,8 @@ def retrieve_documents( # Add data source filter if collection_id is provided # collection_id corresponds to the data source ID in Bedrock KB if collection_id: - retrieve_params["retrievalConfiguration"]["vectorSearchConfiguration"]["filter"] = { + vector_search_config = retrieve_params["retrievalConfiguration"]["vectorSearchConfiguration"] + vector_search_config["filter"] = { "equals": { "key": "x-amz-bedrock-kb-data-source-id", "value": collection_id, @@ -302,7 +313,9 @@ def retrieve_documents( ) logger.error(f"Bedrock retrieve failed for KB {kb_id}: {error_message}") - if "filter" in retrieve_params.get("retrievalConfiguration", {}).get("vectorSearchConfiguration", {}): + retrieval_config = retrieve_params.get("retrievalConfiguration", {}) + vector_search = retrieval_config.get("vectorSearchConfiguration", {}) + if "filter" in vector_search: logger.error( "Filter may not be supported. Ensure metadata field 'x-amz-bedrock-kb-data-source-id' " "is configured in the Knowledge Base." @@ -338,16 +351,19 @@ def retrieve_documents( def validate_document_source(self, s3_path: str) -> str: """Validate document is from KB data source bucket.""" bedrock_config = self.repository.get("bedrockKnowledgeBaseConfig", {}) - kb_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + kb_bucket: str | None = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") + + if not kb_bucket: + raise ValueError("KB bucket not configured") return self._validate_and_normalize_path(s3_path, kb_bucket) - def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Optional[Any]: + def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Any | None: """Bedrock KB does not use external vector store clients.""" return None def _create_collection_for_data_source( - self, data_source_id: str, s3_uri: str = "", is_default: bool = False, collection_name: Optional[str] = None + self, data_source_id: str, s3_uri: str = "", is_default: bool = False, collection_name: str | None = None ) -> RagCollectionConfig: """Create a collection configuration for a specific data source. @@ -396,7 +412,7 @@ def _create_collection_for_data_source( return collection - def create_default_collection(self, ingest_docs=False) -> Optional[RagCollectionConfig]: + def create_default_collection(self, ingest_docs: bool = False) -> RagCollectionConfig | None: """Create a default collection for Bedrock KB repository. For Bedrock KB, the collection ID is the data source ID. @@ -421,9 +437,13 @@ def create_default_collection(self, ingest_docs=False) -> Optional[RagCollection # Use first data source from array, or legacy single ID if data_sources: first_data_source = data_sources[0] - data_source_id = ( + data_source_id: str | None = ( first_data_source.get("id") if isinstance(first_data_source, dict) else first_data_source.id ) + if not data_source_id: + logger.warning(f"Bedrock KB repository {self.repository_id} has invalid data source") + return None + s3_uri = ( first_data_source.get("s3Uri", "") if isinstance(first_data_source, dict) @@ -431,6 +451,9 @@ def create_default_collection(self, ingest_docs=False) -> Optional[RagCollection ) else: data_source_id = legacy_data_source_id + if not data_source_id: + logger.warning(f"Bedrock KB repository {self.repository_id} missing data source ID") + return None s3_uri = "" # Use helper method to create collection diff --git a/lambda/repository/services/pgvector_repository_service.py b/lambda/repository/services/pgvector_repository_service.py index f34515484..aa65d46a1 100644 --- a/lambda/repository/services/pgvector_repository_service.py +++ b/lambda/repository/services/pgvector_repository_service.py @@ -110,15 +110,18 @@ def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) - if not RepositoryType.is_type(connection_info, RepositoryType.PGVECTOR): raise ValueError(f"Repository {self.repository_id} is not a PGVector repository") + # Check if using password auth (passwordSecretId present) or IAM auth if "passwordSecretId" in connection_info: - # Provides backwards compatibility to non-IAM authenticated vector stores + # Password auth: get credentials from Secrets Manager secrets_response = secretsmanager_client.get_secret_value(SecretId=connection_info.get("passwordSecretId")) user = connection_info.get("username") password = json.loads(secrets_response.get("SecretString")).get("password") + use_ssl = False else: - # Use IAM auth token to connect + # IAM auth: generate auth token user = get_lambda_role_name() password = generate_auth_token(connection_info.get("dbHost"), connection_info.get("dbPort"), user) + use_ssl = True # IAM auth requires SSL connection_string = PGVector.connection_string_from_db_params( driver="psycopg2", @@ -129,6 +132,9 @@ def _get_vector_store_client(self, collection_id: str, embeddings: Embeddings) - password=password, ) + if use_ssl: + connection_string = f"{connection_string}?sslmode=require" + return PGVector( collection_name=collection_id, connection_string=connection_string, diff --git a/lambda/repository/services/repository_service.py b/lambda/repository/services/repository_service.py index 31482ee95..36bc21eca 100644 --- a/lambda/repository/services/repository_service.py +++ b/lambda/repository/services/repository_service.py @@ -15,7 +15,7 @@ """Base service interface for repository operations.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any from models.domain_objects import IngestionJob, RagCollectionConfig, RagDocument @@ -27,7 +27,7 @@ class RepositoryService(ABC): interface to provide type-specific behavior for document management. """ - def __init__(self, repository: Dict[str, Any]): + def __init__(self, repository: dict[str, Any]): """Initialize service with repository configuration. Args: @@ -55,7 +55,7 @@ def should_create_default_collection(self) -> bool: pass @abstractmethod - def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + def get_collection_id_from_config(self, pipeline_config: dict[str, Any]) -> str: """Extract collection ID from pipeline configuration. Args: @@ -70,8 +70,8 @@ def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: def ingest_document( self, job: IngestionJob, - texts: List[str], - metadatas: List[Dict[str, Any]], + texts: list[str], + metadatas: list[dict[str, Any]], ) -> RagDocument: """Ingest a document into the repository. @@ -90,7 +90,7 @@ def delete_document( self, document: RagDocument, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete a document from the repository. @@ -106,7 +106,7 @@ def delete_collection( self, collection_id: str, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete an entire collection from the repository. @@ -125,8 +125,8 @@ def retrieve_documents( top_k: int, model_name: str, include_score: bool = False, - bedrock_agent_client: Optional[Any] = None, - ) -> List[Dict[str, Any]]: + bedrock_agent_client: Any | None = None, + ) -> list[dict[str, Any]]: """Retrieve documents matching a query. Args: @@ -158,7 +158,7 @@ def validate_document_source(self, s3_path: str) -> str: pass @abstractmethod - def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Optional[Any]: + def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Any | None: """Get vector store client for this repository. Args: @@ -171,7 +171,7 @@ def get_vector_store_client(self, collection_id: str, embeddings: Any) -> Option pass @abstractmethod - def create_default_collection(self) -> Optional[RagCollectionConfig]: + def create_default_collection(self) -> RagCollectionConfig | None: """Create a default collection for this repository. Returns: diff --git a/lambda/repository/services/repository_service_factory.py b/lambda/repository/services/repository_service_factory.py index 88f2b9d35..a1f0e1828 100644 --- a/lambda/repository/services/repository_service_factory.py +++ b/lambda/repository/services/repository_service_factory.py @@ -14,7 +14,7 @@ """Factory for creating repository service instances.""" -from typing import Any, Dict, Type +from typing import Any from utilities.repository_types import RepositoryType @@ -32,14 +32,14 @@ class RepositoryServiceFactory: """ # Registry mapping repository types to service classes - _services: Dict[RepositoryType, Type[RepositoryService]] = { + _services: dict[RepositoryType, type[RepositoryService]] = { RepositoryType.OPENSEARCH: OpenSearchRepositoryService, RepositoryType.PGVECTOR: PGVectorRepositoryService, RepositoryType.BEDROCK_KB: BedrockKBRepositoryService, } @classmethod - def create_service(cls, repository: Dict[str, Any]) -> RepositoryService: + def create_service(cls, repository: dict[str, Any]) -> RepositoryService: """Create appropriate service instance for repository type. Args: @@ -62,7 +62,7 @@ def create_service(cls, repository: Dict[str, Any]) -> RepositoryService: return service_class(repository) @classmethod - def register_service(cls, repo_type: RepositoryType, service_class: Type[RepositoryService]) -> None: + def register_service(cls, repo_type: RepositoryType, service_class: type[RepositoryService]) -> None: """Register a new service class for a repository type. Allows extending the factory with new repository types without diff --git a/lambda/repository/services/vector_store_repository_service.py b/lambda/repository/services/vector_store_repository_service.py index 00b4b0bc8..e84edadfa 100644 --- a/lambda/repository/services/vector_store_repository_service.py +++ b/lambda/repository/services/vector_store_repository_service.py @@ -21,7 +21,7 @@ import logging import os from abc import abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from langchain_core.embeddings import Embeddings @@ -63,26 +63,34 @@ def should_create_default_collection(self) -> bool: """Vector stores create virtual default collections.""" return True - def get_collection_id_from_config(self, pipeline_config: Dict[str, Any]) -> str: + def get_collection_id_from_config(self, pipeline_config: dict[str, Any]) -> str: """Extract collection ID from pipeline config or use embedding model.""" - collection_id = pipeline_config.get("collectionId") + collection_id: str | None = pipeline_config.get("collectionId") if not collection_id: collection_id = pipeline_config.get("embeddingModel") + if not collection_id: + raise ValueError("No collection ID or embedding model found in pipeline config") return collection_id def ingest_document( self, job: IngestionJob, - texts: List[str], - metadatas: List[Dict[str, Any]], + texts: list[str], + metadatas: list[dict[str, Any]], ) -> RagDocument: """Ingest document into vector store with chunking and embedding.""" # Store chunks in vector store + collection_id_str: str = job.collection_id if job.collection_id else "" + embedding_model_str: str = job.embedding_model if job.embedding_model else "" + + if not collection_id_str or not embedding_model_str: + raise ValueError("collection_id and embedding_model are required for ingestion") + all_ids = self._store_chunks( texts=texts, metadatas=metadatas, - collection_id=job.collection_id, - embedding_model=job.embedding_model, + collection_id=collection_id_str, + embedding_model=embedding_model_str, ) # Create document record @@ -112,7 +120,7 @@ def delete_document( self, document: RagDocument, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete document from vector store.""" embeddings = RagEmbeddings(model_name=document.collection_id) @@ -126,7 +134,7 @@ def delete_collection( self, collection_id: str, s3_client: Any, - bedrock_agent_client: Optional[Any] = None, + bedrock_agent_client: Any | None = None, ) -> None: """Delete collection from vector store. @@ -142,8 +150,8 @@ def retrieve_documents( top_k: int, model_name: str, include_score: bool = False, - bedrock_agent_client: Optional[Any] = None, - ) -> List[Dict[str, Any]]: + bedrock_agent_client: Any | None = None, + ) -> list[dict[str, Any]]: """Retrieve documents from vector store using similarity search. Args: @@ -233,7 +241,7 @@ def _normalize_similarity_score(self, score: float) -> float: """ return score - def create_default_collection(self) -> Optional[RagCollectionConfig]: + def create_default_collection(self) -> RagCollectionConfig | None: """Create a default collection for vector store repositories. Returns: @@ -288,11 +296,11 @@ def create_default_collection(self) -> Optional[RagCollectionConfig]: def _store_chunks( self, - texts: List[str], - metadatas: List[Dict[str, Any]], + texts: list[str], + metadatas: list[dict[str, Any]], collection_id: str, embedding_model: str, - ) -> List[str]: + ) -> list[str]: """Store document chunks in vector store.""" embeddings = RagEmbeddings(model_name=embedding_model) vector_store = self._get_vector_store_client( diff --git a/lambda/repository/state_machine/cleanup_repo_docs.py b/lambda/repository/state_machine/cleanup_repo_docs.py index 9e1924137..4edcbbebb 100644 --- a/lambda/repository/state_machine/cleanup_repo_docs.py +++ b/lambda/repository/state_machine/cleanup_repo_docs.py @@ -14,7 +14,7 @@ import logging import os -from typing import Any, Dict +from typing import Any from models.domain_objects import IngestionType from pydantic import BaseModel @@ -24,7 +24,7 @@ doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) -def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any] | Any: """ Remove LISA-managed documents from repository. @@ -43,7 +43,10 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: last_evaluated = event.get("lastEvaluated") # Get all documents - docs, last_evaluated, _ = doc_repo.list_all(repository_id=repository_id, last_evaluated_key=last_evaluated) + docs, last_evaluated, _ = doc_repo.list_all( + repository_id=repository_id, # type: ignore[arg-type] + last_evaluated_key=last_evaluated, + ) # Filter to LISA-managed only (MANUAL or AUTO) lisa_managed = [d for d in docs if d.ingestion_type in [IngestionType.MANUAL, IngestionType.AUTO]] @@ -59,7 +62,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: doc_repo.delete_by_id(doc.document_id) # Delete from S3 (only LISA-managed) - doc_repo.delete_s3_docs(repository_id=repository_id, docs=lisa_managed) + doc_repo.delete_s3_docs(repository_id=repository_id, docs=lisa_managed) # type: ignore[arg-type] # Ensure JSON-serializable payload for Step Functions when Pydantic models are provided serializable_docs = [doc.model_dump() if isinstance(doc, BaseModel) else doc for doc in lisa_managed] diff --git a/lambda/repository/state_machine/list_modified_objects.py b/lambda/repository/state_machine/list_modified_objects.py index 4ae8ecdd2..87748c5f5 100644 --- a/lambda/repository/state_machine/list_modified_objects.py +++ b/lambda/repository/state_machine/list_modified_objects.py @@ -17,7 +17,7 @@ import logging import os from datetime import timedelta -from typing import Any, Dict +from typing import Any import boto3 from utilities.time import utc_now @@ -76,7 +76,7 @@ def validate_bucket_prefix(bucket: str, prefix: str) -> bool: return True -def handle_list_modified_objects(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: +def handle_list_modified_objects(event: dict[str, Any], context: Any) -> dict[str, Any] | Any: """ Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. diff --git a/lambda/repository/state_machine/wait_for_collection_deletions.py b/lambda/repository/state_machine/wait_for_collection_deletions.py index 00f08d0e4..0c4e96f51 100644 --- a/lambda/repository/state_machine/wait_for_collection_deletions.py +++ b/lambda/repository/state_machine/wait_for_collection_deletions.py @@ -15,14 +15,14 @@ """Wait for all collection deletion jobs to complete before deleting repository.""" import logging -from typing import Any, Dict +from typing import Any from repository.ingestion_job_repo import IngestionJobRepository logger = logging.getLogger(__name__) -def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def lambda_handler(event: dict[str, Any], context: Any) -> dict[str, Any]: """ Check if all collection deletion jobs for a repository are complete. @@ -41,7 +41,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: job_repo = IngestionJobRepository() # Query all jobs for this repository - pending_jobs = job_repo.find_pending_collection_deletions(repository_id) + pending_jobs = job_repo.find_pending_collection_deletions(repository_id) # type: ignore[arg-type] pending_count = len(pending_jobs) all_complete = pending_count == 0 diff --git a/lambda/repository/vector_store_repo.py b/lambda/repository/vector_store_repo.py index 25a0bc569..0eaab8d57 100644 --- a/lambda/repository/vector_store_repo.py +++ b/lambda/repository/vector_store_repo.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, cast, List +from typing import Any, cast import boto3 from boto3.dynamodb.conditions import Attr @@ -34,7 +34,7 @@ def __init__(self, table_name: str | None = None) -> None: table_name = os.environ["LISA_RAG_VECTOR_STORE_TABLE"] self.table = dynamodb.Table(table_name) - def get_registered_repositories(self) -> List[dict]: + def get_registered_repositories(self) -> list[dict]: """Get a list of all registered RAG repositories with default values for new fields.""" response = self.table.scan() items = response["Items"] @@ -181,7 +181,7 @@ def delete(self, repository_id: str) -> bool: except Exception as e: raise ValueError(f"Failed to delete repository: {repository_id}", e) - def find_repositories_using_model(self, model_id: str) -> List[dict]: + def find_repositories_using_model(self, model_id: str) -> list[dict]: """ Find all repositories that use a specific model. Excludes repositories with status indicating they are deleted or archived. diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index 9d25d5a46..1efc079fb 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -20,15 +20,16 @@ import uuid from concurrent.futures import ThreadPoolExecutor from decimal import Decimal -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import boto3 import create_env_variables # noqa: F401 from botocore.exceptions import ClientError -from cachetools import cached, TTLCache +from cachetools import cached, TTLCache # type: ignore[import-untyped,unused-ignore] from utilities.auth import get_user_context, get_username from utilities.common_functions import api_wrapper, get_session_id, retry_config from utilities.encoders import convert_decimal +from utilities.input_validation import MAX_LARGE_REQUEST_SIZE from utilities.session_encryption import decrypt_session_fields, migrate_session_to_encrypted, SessionEncryptionError from utilities.time import iso_string @@ -50,7 +51,7 @@ executor = ThreadPoolExecutor(max_workers=10) # Cache for configuration values to avoid repeated database queries -cache = TTLCache(maxsize=1, ttl=300) # 5 minutes +cache: TTLCache = TTLCache(maxsize=1, ttl=300) # 5 minutes @cached(cache=cache) @@ -81,7 +82,7 @@ def _is_session_encryption_enabled() -> bool: enabled_components = configuration.get("enabledComponents", {}) encrypt_session = enabled_components.get("encryptSession", False) # Default to False logger.info(f"Retrieved session encryption setting from global config: {encrypt_session}") - return encrypt_session + return encrypt_session # type: ignore[no-any-return] else: logger.warning("No global configuration found, defaulting session encryption to disabled") return False @@ -119,7 +120,7 @@ def _get_current_model_config(model_id: str) -> Any: return {} -def _update_session_with_current_model_config(session_config: Dict[str, Any]) -> Dict[str, Any]: +def _update_session_with_current_model_config(session_config: dict[str, Any]) -> dict[str, Any]: """Update session configuration with the most recent model configuration. Parameters @@ -176,7 +177,7 @@ def _update_session_with_current_model_config(session_config: Dict[str, Any]) -> return updated_config -def _get_all_user_sessions(user_id: str) -> List[Dict[str, Any]]: +def _get_all_user_sessions(user_id: str) -> list[dict[str, Any]]: """Get all sessions for a user from DynamoDB. Parameters @@ -205,7 +206,33 @@ def _get_all_user_sessions(user_id: str) -> List[Dict[str, Any]]: return response.get("Items", []) # type: ignore [no-any-return] -def _delete_user_session(session_id: str, user_id: str) -> Dict[str, bool]: +def _extract_video_s3_keys(session: dict) -> list[str]: + """Extract all video S3 keys from a session's history. + + Parameters + ---------- + session : dict + The session object containing history. + + Returns + ------- + list[str] + A list of S3 keys for videos in the session. + """ + video_keys: list[str] = [] + for message in session.get("history", []): + content = message.get("content") + if isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "video_url": + video_url = item.get("video_url", {}) + s3_key = video_url.get("s3_key") + if s3_key: + video_keys.append(s3_key) + return video_keys + + +def _delete_user_session(session_id: str, user_id: str) -> dict[str, bool]: """Delete a session from DynamoDB. Parameters @@ -222,9 +249,39 @@ def _delete_user_session(session_id: str, user_id: str) -> Dict[str, bool]: """ deleted = False try: + # First, get the session to extract any video S3 keys before deleting + response = table.get_item(Key={"sessionId": session_id, "userId": user_id}) + session = response.get("Item", {}) + + # Decrypt session if encrypted to access history for video keys + if session.get("is_encrypted", False): + try: + logger.info(f"Decrypting session {session_id} to extract video keys for deletion") + session = decrypt_session_fields(session, user_id, session_id) + except SessionEncryptionError as e: + logger.warning(f"Failed to decrypt session {session_id} for video cleanup: {e}") + # Continue with deletion even if decryption fails - videos may remain orphaned + + # Extract video S3 keys from the session history + video_keys = _extract_video_s3_keys(session) + + # Delete the session from DynamoDB table.delete_item(Key={"sessionId": session_id, "userId": user_id}) + + # Delete associated images from S3 bucket = s3_resource.Bucket(s3_bucket_name) bucket.objects.filter(Prefix=f"images/{session_id}").delete() + + # Delete associated videos from S3 + if video_keys: + logger.info(f"Deleting {len(video_keys)} videos from S3 for session {session_id}") + for video_key in video_keys: + try: + s3_client.delete_object(Bucket=s3_bucket_name, Key=video_key) + logger.debug(f"Deleted video: {video_key}") + except ClientError as video_error: + logger.warning(f"Failed to delete video {video_key}: {video_error}") + deleted = True except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": @@ -248,7 +305,21 @@ def _generate_presigned_image_url(key: str) -> str: return url -def _map_session(session: dict, user_id: Optional[str] = None) -> Dict[str, Any]: +def _generate_presigned_video_url(key: str) -> str: + url: str = s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": s3_bucket_name, + "Key": key, + "ResponseContentType": "video/mp4", + "ResponseCacheControl": "no-cache", + "ResponseContentDisposition": "inline", + }, + ) + return url + + +def _map_session(session: dict, user_id: str | None = None) -> dict[str, Any]: return { "sessionId": session.get("sessionId", None), "name": session.get("name", None), @@ -262,7 +333,7 @@ def _map_session(session: dict, user_id: Optional[str] = None) -> Dict[str, Any] } -def _find_first_human_message(session: dict, user_id: Optional[str] = None) -> str: +def _find_first_human_message(session: dict, user_id: str | None = None) -> str: # Check if session is encrypted if session.get("is_encrypted", False): # For encrypted sessions, decrypt to get the first message @@ -300,7 +371,7 @@ def _find_first_human_message(session: dict, user_id: Optional[str] = None) -> s @api_wrapper -def list_sessions(event: dict, context: dict) -> List[Dict[str, Any]]: +def list_sessions(event: dict, context: dict) -> list[dict[str, Any]]: """List sessions by user ID from DynamoDB.""" user_id = get_username(event) @@ -310,13 +381,22 @@ def list_sessions(event: dict, context: dict) -> List[Dict[str, Any]]: return list(executor.map(lambda session: _map_session(session, user_id), sessions)) -def _process_image(task: Tuple[dict, str]) -> None: +def _process_image(task: tuple[dict, str]) -> None: msg, key = task try: image_url = _generate_presigned_image_url(key) msg["image_url"]["url"] = image_url except Exception as e: - print(f"Error uploading to S3: {e}") + print(f"Error generating presigned image URL: {e}") + + +def _process_video(task: tuple[dict, str]) -> None: + msg, key = task + try: + video_url = _generate_presigned_video_url(key) + msg["video_url"]["url"] = video_url + except Exception as e: + print(f"Error generating presigned video URL: {e}") @api_wrapper @@ -357,16 +437,22 @@ def get_session(event: dict, context: dict) -> dict: resp["configuration"] = configuration # Create a list of tasks for parallel processing - tasks = [] + image_tasks = [] + video_tasks = [] for message in resp.get("history", []): - if isinstance(message.get("content", None), List): + if isinstance(message.get("content", None), list): for item in message.get("content", None): if item.get("type", None) == "image_url": s3_key = item.get("image_url", {}).get("s3_key", None) if s3_key: - tasks.append((item, s3_key)) + image_tasks.append((item, s3_key)) + elif item.get("type", None) == "video_url": + s3_key = item.get("video_url", {}).get("s3_key", None) + if s3_key: + video_tasks.append((item, s3_key)) - list(executor.map(_process_image, tasks)) + list(executor.map(_process_image, image_tasks)) + list(executor.map(_process_video, video_tasks)) return resp # type: ignore [no-any-return] except ValueError as e: return {"statusCode": 400, "body": json.dumps({"error": str(e)})} @@ -383,7 +469,7 @@ def delete_session(event: dict, context: dict) -> dict: @api_wrapper -def delete_user_sessions(event: dict, context: dict) -> Dict[str, bool]: +def delete_user_sessions(event: dict, context: dict) -> dict[str, bool]: """Delete sessions by user ID from DyanmoDB.""" user_id = get_username(event) @@ -395,7 +481,7 @@ def delete_user_sessions(event: dict, context: dict) -> Dict[str, bool]: return {"deleted": True} -@api_wrapper +@api_wrapper(max_request_size=MAX_LARGE_REQUEST_SIZE) def attach_image_to_session(event: dict, context: dict) -> dict: """Append the message to the record in DynamoDB.""" try: @@ -465,7 +551,7 @@ def rename_session(event: dict, context: dict) -> dict: return {"statusCode": 400, "body": json.dumps({"error": str(e)})} -@api_wrapper +@api_wrapper(max_request_size=MAX_LARGE_REQUEST_SIZE) def put_session(event: dict, context: dict) -> dict: """Append the message to the record in DynamoDB.""" try: diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index 47c7db756..b5baba8d2 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -16,8 +16,9 @@ import logging import os import secrets +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Dict, List, Tuple +from typing import Any import boto3 from botocore.config import Config @@ -42,9 +43,9 @@ def get_username(event: dict) -> str: return username -def get_groups(event: Any) -> List[str]: +def get_groups(event: Any) -> list[str]: """Get user groups from event.""" - groups: List[str] = json.loads(event.get("requestContext", {}).get("authorizer", {}).get("groups", "[]")) + groups: list[str] = json.loads(event.get("requestContext", {}).get("authorizer", {}).get("groups", "[]")) return groups @@ -56,12 +57,12 @@ def is_admin(event: dict) -> bool: return admin_group in groups -def get_user_context(event: Dict[str, Any]) -> Tuple[str, bool, List[str]]: +def get_user_context(event: dict[str, Any]) -> tuple[str, bool, list[str]]: """Extract user context from event.""" return get_username(event), is_admin(event), get_groups(event) -def user_has_group_access(user_groups: List[str], allowed_groups: List[str]) -> bool: +def user_has_group_access(user_groups: list[str], allowed_groups: list[str]) -> bool: """ Check if user has access based on group membership. @@ -84,7 +85,7 @@ def admin_only(func: Callable) -> Callable: """Annotation to wrap is_admin""" @wraps(func) - def wrapper(event: Dict[str, Any], context: Dict[str, Any], *args: Any, **kwargs: Any) -> Any: + def wrapper(event: dict[str, Any], context: dict[str, Any], *args: Any, **kwargs: Any) -> Any: if not is_admin(event): raise HTTPException(status_code=403, message="User does not have permission to access this repository") return func(event, context, *args, **kwargs) @@ -96,7 +97,7 @@ def get_management_key() -> str: secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) secret_name = secret_name_param["Parameter"]["Value"] secret_response = secrets_client.get_secret_value(SecretId=secret_name) - return secret_response["SecretString"] + return secret_response["SecretString"] # type: ignore[no-any-return] # API token utility functions diff --git a/lambda/utilities/aws_helpers.py b/lambda/utilities/aws_helpers.py new file mode 100644 index 000000000..1ae06e207 --- /dev/null +++ b/lambda/utilities/aws_helpers.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS-specific helper utilities.""" + +import logging +import os +import tempfile +from functools import cache +from typing import Any, cast + +import boto3 +from botocore.config import Config + +logger = logging.getLogger(__name__) + +# Boto3 retry configuration +retry_config = Config( + retries={ + "max_attempts": 3, + "mode": "standard", + }, +) + +# Global SSM client +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + +# Global certificate file handle +_cert_file = None + + +@cache +def get_cert_path(iam_client: Any) -> str | bool: + """ + Get certificate path for SSL validation against LISA Serve endpoint. + + This function retrieves IAM server certificates for SSL verification. + For ACM certificates or when no certificate is specified, it returns + True to use default verification. + + Parameters + ---------- + iam_client : Any + Boto3 IAM client instance. + + Returns + ------- + Union[str, bool] + Path to certificate file, or True to use default verification. + + Example + ------- + >>> iam = boto3.client("iam") + >>> cert_path = get_cert_path(iam) + >>> if isinstance(cert_path, str): + ... # Use custom certificate + ... requests.get(url, verify=cert_path) + ... else: + ... # Use default verification + ... requests.get(url, verify=True) + """ + global _cert_file + + cert_arn = os.environ.get("RESTAPI_SSL_CERT_ARN") + if not cert_arn: + logger.info("No SSL certificate ARN specified, using default verification") + return True + + # For ACM certificates, use default verification since they are trusted AWS certificates + if ":acm:" in cert_arn: + logger.info("ACM certificate detected, using default SSL verification") + return True + + try: + # Clean up previous cert file if it exists + if _cert_file and os.path.exists(_cert_file.name): + try: + os.unlink(_cert_file.name) + except Exception as e: + logger.warning(f"Failed to clean up previous cert file: {e}") + + # Get the certificate name from the ARN + cert_name = cert_arn.split("/")[1] + logger.info(f"Retrieving certificate '{cert_name}' from IAM") + + # Get the certificate from IAM + rest_api_cert = iam_client.get_server_certificate(ServerCertificateName=cert_name) + cert_body = rest_api_cert["ServerCertificate"]["CertificateBody"] + + # Create a new temporary file + _cert_file = tempfile.NamedTemporaryFile(delete=False) + _cert_file.write(cert_body.encode("utf-8")) + _cert_file.flush() + + logger.info(f"Certificate saved to temporary file: {_cert_file.name}") + return _cert_file.name + + except Exception as e: + logger.error(f"Failed to get certificate from IAM: {e}", exc_info=True) + # If we fail to get the cert, return True to fall back to default verification + return True + + +@cache +def get_rest_api_container_endpoint() -> str: + """ + Get REST API container base URI from SSM Parameter Store. + + Returns + ------- + str + The REST API container endpoint URL. + + Example + ------- + >>> endpoint = get_rest_api_container_endpoint() + >>> endpoint + 'https://api.example.com/v1/serve' + """ + lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) + lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] + return f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" + + +def _get_lambda_role_arn() -> str: + """ + Get the ARN of the Lambda execution role. + + Returns + ------- + str + The full ARN of the Lambda execution role. + + Example + ------- + >>> _get_lambda_role_arn() + 'arn:aws:sts::123456789012:assumed-role/MyLambdaRole/MyFunction' + """ + sts = boto3.client("sts", region_name=os.environ["AWS_REGION"]) + identity = sts.get_caller_identity() + return cast(str, identity["Arn"]) + + +def get_lambda_role_name() -> str: + """ + Extract the role name from the Lambda execution role ARN. + + Returns + ------- + str + The name of the Lambda execution role without the full ARN. + + Example + ------- + >>> get_lambda_role_name() + 'MyLambdaRole' + """ + arn = _get_lambda_role_arn() + parts = arn.split(":assumed-role/")[1].split("/") + return parts[0] + + +def get_account_and_partition() -> tuple[str, str]: + """ + Get AWS account ID and partition from environment or ECR repository ARN. + + Returns + ------- + tuple[str, str] + Tuple of (account_id, partition). + + Example + ------- + >>> account_id, partition = get_account_and_partition() + >>> account_id + '123456789012' + >>> partition + 'aws' + """ + account_id = os.environ.get("AWS_ACCOUNT_ID", "") + partition = os.environ.get("AWS_PARTITION", "aws") + + if not account_id: + ecr_repo_arn = os.environ.get("ECR_REPOSITORY_ARN", "") + if ecr_repo_arn: + arn_parts = ecr_repo_arn.split(":") + partition = arn_parts[1] + account_id = arn_parts[4] + + return account_id, partition diff --git a/lambda/utilities/bedrock_kb.py b/lambda/utilities/bedrock_kb.py index 8326bfc65..344fef29f 100644 --- a/lambda/utilities/bedrock_kb.py +++ b/lambda/utilities/bedrock_kb.py @@ -22,7 +22,7 @@ import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from models.domain_objects import ( IngestionJob, @@ -45,8 +45,8 @@ def __init__( skipped: int = 0, successful: int = 0, failed: int = 0, - document_ids: Optional[List[str]] = None, - errors: Optional[List[str]] = None, + document_ids: list[str] | None = None, + errors: list[str] | None = None, ): self.discovered = discovered self.skipped = skipped @@ -194,7 +194,7 @@ def discover_and_ingest_documents( logger.error(f"Failed to discover S3 documents: {str(e)}", exc_info=True) raise - def _scan_s3_bucket(self, s3_bucket: str, s3_prefix: str) -> Tuple[List[str], int]: + def _scan_s3_bucket(self, s3_bucket: str, s3_prefix: str) -> tuple[list[str], int]: """ Scan S3 bucket and return list of document keys. @@ -227,10 +227,10 @@ def _scan_s3_bucket(self, s3_bucket: str, s3_prefix: str) -> Tuple[List[str], in return documents_to_process, skipped_count - def _get_collection(self, repository_id: str, collection_id: str) -> Optional[RagCollectionConfig]: + def _get_collection(self, repository_id: str, collection_id: str) -> RagCollectionConfig | None: """Get collection configuration.""" try: - return self.collection_service.get_collection( + return self.collection_service.get_collection( # type: ignore[no-any-return] collection_id=collection_id, repository_id=repository_id, username="system", @@ -250,8 +250,8 @@ def _document_exists(self, repository_id: str, collection_id: str, s3_path: str) def _create_metadata_file( self, - repository: Dict[str, Any], - collection: Optional[RagCollectionConfig], + repository: dict[str, Any], + collection: RagCollectionConfig | None, s3_bucket: str, document_key: str, repository_id: str, @@ -297,7 +297,7 @@ def _create_rag_document( self.rag_document_repository.save(rag_document) return rag_document.document_id - def _trigger_kb_sync(self, repository: Dict[str, Any], collection_id: str, document_count: int) -> None: + def _trigger_kb_sync(self, repository: dict[str, Any], collection_id: str, document_count: int) -> None: """Trigger Bedrock KB sync for ingested documents.""" bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) knowledge_base_id = bedrock_config.get("knowledgeBaseId", bedrock_config.get("bedrockKnowledgeBaseId")) @@ -322,7 +322,7 @@ def _trigger_kb_sync(self, repository: Dict[str, Any], collection_id: str, docum def get_datasource_bucket_for_collection( - repository: Dict[str, Any], + repository: dict[str, Any], collection_id: str, ) -> str: """ @@ -349,7 +349,7 @@ def get_datasource_bucket_for_collection( # Try legacy format first legacy_bucket = bedrock_config.get("bedrockKnowledgeDatasourceS3Bucket") if legacy_bucket: - return legacy_bucket + return legacy_bucket # type: ignore[no-any-return] # Try pipelines array (most common in current configs) pipelines = repository.get("pipelines", []) @@ -359,7 +359,7 @@ def get_datasource_bucket_for_collection( s3_bucket = pipeline.get("s3Bucket") if isinstance(pipeline, dict) else pipeline.s3Bucket if pipeline_collection_id == collection_id and s3_bucket: - return s3_bucket + return s3_bucket # type: ignore[no-any-return] # Try dataSources array data_sources = bedrock_config.get("dataSources", []) @@ -373,7 +373,7 @@ def get_datasource_bucket_for_collection( if s3_uri and s3_uri.startswith("s3://"): bucket = s3_uri[5:].split("/")[0] if bucket: - return bucket + return bucket # type: ignore[no-any-return] logger.error(f"Invalid s3Uri format for data source {ds_id}: {s3_uri}") raise ValueError( @@ -401,7 +401,7 @@ def ingest_document_to_kb( s3_client: Any, bedrock_agent_client: Any, job: IngestionJob, - repository: Dict[str, Any], + repository: dict[str, Any], ) -> None: """ Copy the source object into the KB datasource bucket and trigger ingestion. S3 will @@ -410,6 +410,9 @@ def ingest_document_to_kb( bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) # Get datasource bucket for this collection (supports multiple config formats) + if job.collection_id is None: + raise ValueError("collection_id is required for Bedrock KB operations") + datasource_bucket = get_datasource_bucket_for_collection( repository=repository, collection_id=job.collection_id, @@ -448,12 +451,15 @@ def delete_document_from_kb( s3_client: Any, bedrock_agent_client: Any, job: IngestionJob, - repository: Dict[str, Any], + repository: dict[str, Any], ) -> None: """Remove the source object from the KB datasource bucket and re-sync the KB.""" bedrock_config = repository.get("bedrockKnowledgeBaseConfig", {}) # Get datasource bucket for this collection (supports multiple config formats) + if job.collection_id is None: + raise ValueError("collection_id is required for Bedrock KB operations") + datasource_bucket = get_datasource_bucket_for_collection( repository=repository, collection_id=job.collection_id, @@ -485,9 +491,9 @@ def delete_document_from_kb( def bulk_delete_documents_from_kb( s3_client: Any, bedrock_agent_client: Any, - repository: Dict[str, Any], - s3_paths: List[str], - data_source_id: Optional[str] = None, + repository: dict[str, Any], + s3_paths: list[str], + data_source_id: str | None = None, ) -> None: """Bulk delete documents from KB datasource bucket and trigger single ingestion. @@ -550,8 +556,8 @@ def ingest_bedrock_s3_documents( embedding_model: str, s3_prefix: str = "", batch_size: int = 100, - metadata: Optional[Dict[str, Any]] = None, -) -> Tuple[int, int]: + metadata: dict[str, Any] | None = None, +) -> tuple[int, int]: """ Discover and create ingestion jobs for existing documents in S3 bucket. @@ -643,7 +649,7 @@ def create_s3_scan_job( embedding_model: str, s3_bucket: str, s3_prefix: str = "", - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> str: """ Create a batch ingestion job to scan and ingest existing S3 documents. diff --git a/lambda/utilities/bedrock_kb_discovery.py b/lambda/utilities/bedrock_kb_discovery.py index bea8ffb93..39b286293 100644 --- a/lambda/utilities/bedrock_kb_discovery.py +++ b/lambda/utilities/bedrock_kb_discovery.py @@ -20,7 +20,7 @@ """ import logging -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -36,8 +36,8 @@ def list_knowledge_bases( - bedrock_agent_client: Optional[Any] = None, -) -> List[KnowledgeBaseMetadata]: + bedrock_agent_client: Any | None = None, +) -> list[KnowledgeBaseMetadata]: """ List all Knowledge Bases accessible in the AWS account. @@ -93,8 +93,8 @@ def list_knowledge_bases( def discover_kb_data_sources( kb_id: str, - bedrock_agent_client: Optional[Any] = None, -) -> List[DataSourceMetadata]: + bedrock_agent_client: Any | None = None, +) -> list[DataSourceMetadata]: """ Discover all data sources in a Bedrock Knowledge Base. @@ -174,7 +174,7 @@ def discover_kb_data_sources( raise ValidationError(f"Unexpected error discovering data sources: {str(e)}") -def extract_s3_configuration(data_source: Dict[str, Any]) -> Dict[str, str]: +def extract_s3_configuration(data_source: dict[str, Any]) -> dict[str, str]: """Extract S3 bucket and prefix from data source configuration. Args: @@ -199,7 +199,7 @@ def extract_s3_configuration(data_source: Dict[str, Any]) -> Dict[str, str]: def build_pipeline_configs_from_kb_config( kb_config: Any, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Build PipelineConfigs from BedrockKnowledgeBaseConfig. Args: @@ -276,9 +276,9 @@ def build_pipeline_configs_from_kb_config( def get_available_data_sources( kb_id: str, - repository_id: Optional[str] = None, - bedrock_agent_client: Optional[Any] = None, -) -> List[DataSourceMetadata]: + repository_id: str | None = None, + bedrock_agent_client: Any | None = None, +) -> list[DataSourceMetadata]: """ Get all data sources for a Knowledge Base. diff --git a/lambda/utilities/bedrock_kb_validation.py b/lambda/utilities/bedrock_kb_validation.py index a4c65d86f..7b529bcb5 100644 --- a/lambda/utilities/bedrock_kb_validation.py +++ b/lambda/utilities/bedrock_kb_validation.py @@ -15,7 +15,7 @@ """Validation utilities for Bedrock Knowledge Base operations.""" import logging -from typing import Any, Dict, Optional +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Optional[Any] = None) -> Dict[str, Any]: +def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Any | None = None) -> dict[str, Any]: """ Validate that a Bedrock Knowledge Base exists and is accessible. @@ -46,7 +46,7 @@ def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Optional[Any] = kb_config = response.get("knowledgeBase", {}) logger.info(f"Validated Knowledge Base {kb_id}: {kb_config.get('name')}") - return kb_config + return kb_config # type: ignore[no-any-return] except ClientError as e: error_code = e.response.get("Error", {}).get("Code", "") @@ -67,8 +67,8 @@ def validate_bedrock_kb_exists(kb_id: str, bedrock_agent_client: Optional[Any] = def validate_data_source_exists( - kb_id: str, data_source_id: str, bedrock_agent_client: Optional[Any] = None -) -> Dict[str, Any]: + kb_id: str, data_source_id: str, bedrock_agent_client: Any | None = None +) -> dict[str, Any]: """ Validate that a data source exists in a Bedrock Knowledge Base. @@ -91,7 +91,7 @@ def validate_data_source_exists( data_source_config = response.get("dataSource", {}) logger.info(f"Validated Data Source {data_source_id} in KB {kb_id}: " f"{data_source_config.get('name')}") - return data_source_config + return data_source_config # type: ignore[no-any-return] except ClientError as e: error_code = e.response.get("Error", {}).get("Code", "") @@ -113,8 +113,8 @@ def validate_data_source_exists( def validate_bedrock_kb_repository( - kb_id: str, data_source_id: str, bedrock_agent_client: Optional[Any] = None -) -> tuple[Dict[str, Any], Dict[str, Any]]: + kb_id: str, data_source_id: str, bedrock_agent_client: Any | None = None +) -> tuple[dict[str, Any], dict[str, Any]]: """ Validate both Knowledge Base and Data Source exist. diff --git a/lambda/utilities/chunking_strategy_factory.py b/lambda/utilities/chunking_strategy_factory.py index cc7c8b632..0e9762a0e 100644 --- a/lambda/utilities/chunking_strategy_factory.py +++ b/lambda/utilities/chunking_strategy_factory.py @@ -16,7 +16,6 @@ import logging import os from abc import ABC, abstractmethod -from typing import List from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -25,27 +24,30 @@ logger = logging.getLogger(__name__) -DEFAULT_STRATEGY = FixedChunkingStrategy(size=os.getenv("CHUNK_SIZE", "512"), overlap=os.getenv("CHUNK_OVERLAP", "51")) +DEFAULT_STRATEGY = FixedChunkingStrategy( + size=int(os.getenv("CHUNK_SIZE", "512")), + overlap=int(os.getenv("CHUNK_OVERLAP", "51")), +) class ChunkingStrategyHandler(ABC): """Abstract base class for chunking strategy handlers.""" @abstractmethod - def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy) -> List[Document]: + def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> list[Document]: """ Chunk documents according to the strategy. Parameters ---------- - docs : List[Document] + docs : list[Document] List of documents to chunk strategy : ChunkingStrategy The chunking strategy configuration Returns ------- - List[Document] + list[Document] List of chunked documents """ pass @@ -54,22 +56,26 @@ def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy) -> L class FixedSizeChunkingHandler(ChunkingStrategyHandler): """Handler for fixed-size chunking strategy.""" - def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> List[Document]: + def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> list[Document]: """ Chunk documents using fixed-size strategy with RecursiveCharacterTextSplitter. Parameters ---------- - docs : List[Document] + docs : list[Document] List of documents to chunk strategy : ChunkingStrategy The chunking strategy configuration (FixedChunkingStrategy) Returns ------- - List[Document] + list[Document] List of chunked documents """ + # Ensure we have a FixedChunkingStrategy + if not isinstance(strategy, FixedChunkingStrategy): + raise ValueError(f"Expected FixedChunkingStrategy, got {type(strategy).__name__}") + # Handle both legacy (size/overlap) and new (chunkSize/chunkOverlap) formats chunk_size = strategy.size chunk_overlap = strategy.overlap @@ -94,26 +100,27 @@ def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy = DEF chunk_overlap=chunk_overlap, length_function=len, ) - return text_splitter.split_documents(docs) # type: ignore [no-any-return] + result: list[Document] = text_splitter.split_documents(docs) + return result class NoneChunkingHandler(ChunkingStrategyHandler): """Handler for no-chunking strategy - returns documents as-is.""" - def chunk_documents(self, docs: List[Document], strategy: ChunkingStrategy) -> List[Document]: + def chunk_documents(self, docs: list[Document], strategy: ChunkingStrategy) -> list[Document]: """ Return documents without chunking. Parameters ---------- - docs : List[Document] + docs : list[Document] List of documents to process strategy : ChunkingStrategy The chunking strategy configuration (NoneChunkingStrategy) Returns ------- - List[Document] + list[Document] Original list of documents unmodified """ logger.info(f"Processing {len(docs)} documents with NONE chunking strategy (no chunking)") @@ -129,20 +136,20 @@ class ChunkingStrategyFactory: } @classmethod - def chunk_documents(cls, docs: List[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> List[Document]: + def chunk_documents(cls, docs: list[Document], strategy: ChunkingStrategy = DEFAULT_STRATEGY) -> list[Document]: """ Chunk documents using the appropriate strategy handler. Parameters ---------- - docs : List[Document] + docs : list[Document] List of documents to chunk strategy : ChunkingStrategy The chunking strategy configuration Returns ------- - List[Document] + list[Document] List of chunked documents Raises @@ -180,13 +187,13 @@ def register_handler(cls, strategy_type: ChunkingStrategyType, handler: Chunking logger.info(f"Registered chunking strategy handler: {strategy_type.value}") @classmethod - def get_supported_strategies(cls) -> List[ChunkingStrategyType]: + def get_supported_strategies(cls) -> list[ChunkingStrategyType]: """ Get list of supported chunking strategy types. Returns ------- - List[ChunkingStrategyType] + list[ChunkingStrategyType] List of supported strategy types """ return list(cls._handlers.keys()) diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index 24c70d48e..3027f1032 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -12,37 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common helper functions for RAG Lambdas.""" -import copy -import functools -import json +""" +Common helper functions for RAG Lambdas. + +DEPRECATED: This module is maintained for backward compatibility. +New code should import from the specific utility modules: + +- lambda_decorators: api_wrapper, authorization_wrapper +- response_builder: generate_html_response, generate_exception_response, DecimalEncoder +- event_parser: get_session_id, get_principal_id, get_bearer_token, get_id_token +- aws_helpers: get_cert_path, get_rest_api_container_endpoint, get_lambda_role_name, get_account_and_partition +- dict_helpers: merge_fields, get_property_path, get_item +""" import logging -import os -import tempfile -from contextvars import ContextVar -from datetime import datetime -from decimal import Decimal -from functools import cache -from typing import Any, Callable, cast, Dict, Optional, TypeVar, Union - -import boto3 -from botocore.config import Config +from collections.abc import Callable +from typing import Any, TypeVar + +# Re-export from organized modules for backward compatibility +from utilities.aws_helpers import ( + get_account_and_partition, + get_cert_path, + get_lambda_role_name, + get_rest_api_container_endpoint, + retry_config, + ssm_client, +) +from utilities.dict_helpers import get_item, get_property_path, merge_fields +from utilities.event_parser import get_bearer_token, get_id_token, get_principal_id, get_session_id +from utilities.lambda_decorators import api_wrapper, authorization_wrapper, ctx_context +from utilities.response_builder import DecimalEncoder, generate_exception_response, generate_html_response from . import create_env_variables # noqa type: ignore -retry_config = Config( - retries={ - "max_attempts": 3, - "mode": "standard", - }, -) -ctx_context: ContextVar[Any] = ContextVar("lamdbacontext") F = TypeVar("F", bound=Callable[..., Any]) logger = logging.getLogger(__name__) logging_configured = False -ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) - class LambdaContextFilter(logging.Filter): """Filter for logging to include request id and function name.""" @@ -98,421 +103,33 @@ def setup_root_logging() -> None: setup_root_logging() -def _sanitize_event(event: Dict[str, Dict[str, Any]]) -> str: - """Sanitize event before logging. - - Parameters - ---------- - event : Dict[str, Dict[str, Any]] - The lambda event. - - Returns - ------- - str - The sanitized event as a JSON-formatted string. - """ - # First normalize keys for our object - sanitized = copy.deepcopy(event) - if "headers" in event: - for key in event["headers"]: - if key != key.lower(): - sanitized["headers"][key.lower()] = event["headers"][key] - del sanitized["headers"][key] - if "multiValueHeaders" in sanitized: - for key in event["multiValueHeaders"]: - if key != key.lower(): - sanitized["multiValueHeaders"][key.lower()] = event["multiValueHeaders"][key] - del sanitized["multiValueHeaders"][key] - - if "headers" in sanitized and "authorization" in sanitized["headers"]: - sanitized["headers"]["authorization"] = "" - if "multiValueHeaders" in sanitized and "authorization" in sanitized["headers"]: - sanitized["multiValueHeaders"]["authorization"] = [""] - return json.dumps(sanitized) - - -def api_wrapper(f: F) -> F: - """Wrap the lambda function. - - Parameters - ---------- - f : F - The function to be wrapped. - - Returns - ------- - F - The wrapped function. - """ - - @functools.wraps(f) - def wrapper(event: dict, context: dict) -> Dict[str, Union[str, int, Dict[str, str]]]: - """Wrap Lambda event. - - Parameters - ---------- - event : dict - Lambda event. - context : dict - Lambda context. - - Returns - ------- - Dict[str, Union[str, int, Dict[str, str]]] - _description_ - """ - ctx_context.set(context) - code_func_name = f.__name__ - lambda_func_name = context.function_name # type: ignore [attr-defined] - logger.info(f"Lambda {lambda_func_name}({code_func_name}) invoked with {_sanitize_event(event)}") - try: - result = f(event, context) - return generate_html_response(200, result) - except Exception as e: - return generate_exception_response(e) - - return wrapper # type: ignore [return-value] - - -def authorization_wrapper(f: F) -> F: - """Wrap the lambda function. - - Parameters - ---------- - f : F - The function to be wrapped. - - Returns - ------- - F - The wrapped function. - """ - - @functools.wraps(f) - def wrapper(event: dict, context: dict) -> F: - """Wrap Lambda event. - - Parameters - ---------- - event : dict - Lambda event. - context : dict - Lambda context. - - Returns - ------- - F - The wrapped function. - """ - ctx_context.set(context) - return f(event, context) # type: ignore [no-any-return] - - return wrapper # type: ignore [return-value] - - -class DecimalEncoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, Decimal): - return float(obj) - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - -def generate_html_response(status_code: int, response_body: dict) -> Dict[str, Union[str, int, Dict[str, str]]]: - """Generate a response for an API call. - - Parameters - ---------- - status_code : int - HTTP status code. - response_body : dict - Response body. - - Returns - ------- - Dict[str, Union[str, int, Dict[str, str]]] - An HTML response. - """ - return { - "statusCode": status_code, - "body": json.dumps(response_body, cls=DecimalEncoder), - "headers": { - "Access-Control-Allow-Origin": "*", - "Content-Type": "application/json", - "Cache-Control": "no-store, no-cache", - "Pragma": "no-cache", - "Strict-Transport-Security": "max-age:47304000; includeSubDomains", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - }, - } - - -def generate_exception_response( - e: Exception, -) -> Dict[str, Union[str, int, Dict[str, str]]]: - """Generate a response for an exception used for all exceptions that are not caught by the API. - - Parameters - ---------- - e : Exception - Exception that was caught. - - Returns - ------- - Dict[str, Union[str, int, Dict[str, str]]] - An HTML response. - """ - # Check for ValidationError from utilities.validation - status_code = 400 - error_message: str - if type(e).__name__ == "ValidationError": - error_message = str(e) - logger.exception(e) - elif hasattr(e, "response"): # i.e. validate the exception was from an API call - metadata = e.response.get("ResponseMetadata") - if metadata: - status_code = metadata.get("HTTPStatusCode", 400) - error_message = str(e) - logger.exception(e) - elif hasattr(e, "http_status_code"): - status_code = e.http_status_code - error_message = getattr(e, "message", str(e)) - logger.exception(e) - elif hasattr(e, "status_code"): - status_code = e.status_code - error_message = getattr(e, "message", str(e)) - logger.exception(e) - else: - error_msg = str(e) - if error_msg in ["'requestContext'", "'pathParameters'", "'body'"]: - error_message = f"Missing event parameter: {error_msg}" - else: - error_message = f"Bad Request: {error_msg}" - logger.exception(e) - return generate_html_response(status_code, error_message) # type: ignore [arg-type] - - -def get_id_token(event: dict) -> str: - """Return token from event request headers. - - Extracts bearer token from authorization header in lambda event. - """ - auth_header = None - - if "authorization" in event["headers"]: - auth_header = event["headers"]["authorization"] - elif "Authorization" in event["headers"]: - auth_header = event["headers"]["Authorization"] - else: - raise ValueError("Missing authorization token.") - - # remove bearer token prefix if present - return str(auth_header).removeprefix("Bearer ").removeprefix("bearer ").strip() - - -_cert_file = None - - -@cache -def get_cert_path(iam_client: Any) -> Union[str, bool]: - """ - Get cert path for IAM certs for SSL validation against LISA Serve endpoint. - - Returns the path to the certificate file for SSL verification, or True to use - default verification if no certificate ARN is specified. - """ - global _cert_file - - cert_arn = os.environ.get("RESTAPI_SSL_CERT_ARN") - if not cert_arn: - logger.info("No SSL certificate ARN specified, using default verification") - return True - # For ACM certificates, use default verification since they are trusted AWS certificates - elif ":acm:" in cert_arn: - logger.info("ACM certificate detected, using default SSL verification") - return True - - try: - # Clean up previous cert file if it exists - if _cert_file and os.path.exists(_cert_file.name): - try: - os.unlink(_cert_file.name) - except Exception as e: - logger.warning(f"Failed to clean up previous cert file: {e}") - - # Get the certificate name from the ARN - cert_name = cert_arn.split("/")[1] - logger.info(f"Retrieving certificate '{cert_name}' from IAM") - - # Get the certificate from IAM - rest_api_cert = iam_client.get_server_certificate(ServerCertificateName=cert_name) - cert_body = rest_api_cert["ServerCertificate"]["CertificateBody"] - - # Create a new temporary file - _cert_file = tempfile.NamedTemporaryFile(delete=False) - _cert_file.write(cert_body.encode("utf-8")) - _cert_file.flush() - - logger.info(f"Certificate saved to temporary file: {_cert_file.name}") - return _cert_file.name - - except Exception as e: - logger.error(f"Failed to get certificate from IAM: {e}", exc_info=True) - # If we fail to get the cert, return True to fall back to default verification - return True - - -@cache -def get_rest_api_container_endpoint() -> str: - """Get REST API container base URI from SSM Parameter Store.""" - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] - return f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" - - -def get_session_id(event: dict) -> str: - """Get the session ID from the event.""" - session_id: str = event.get("pathParameters", {}).get("sessionId") - return session_id - - -def get_principal_id(event: Any) -> str: - """Get principal from event.""" - principal: str = event.get("requestContext", {}).get("authorizer", {}).get("principal", "") - return principal - - -def merge_fields(source: dict, target: dict, fields: list[str]) -> dict: - """ - Merge specified fields from source dictionary to target dictionary. - Supports both top-level and nested fields using dot notation. - - Args: - source: Source dictionary to copy fields from - target: Target dictionary to copy fields into - fields: List of field names, can use dot notation for nested fields - - Returns: - Updated target dictionary - """ - - def get_nested_value(obj: dict[str, Any], path: list[str]) -> Any: - current: Any = obj - for key in path: - if not isinstance(current, dict): - return None - current = current.get(key) - if current is None: - return None - return current - - def set_nested_value(obj: dict, path: list[str], value: Any) -> None: - current = obj - for key in path[:-1]: - if key not in current: - current[key] = {} - current = current[key] - if value is not None: - current[path[-1]] = value - - for field in fields: - if "." in field: - # Handle nested fields - keys = field.split(".") - value = get_nested_value(source, keys) - if value is not None: - set_nested_value(target, keys, value) - else: - # Handle top-level fields - if field in source: - target[field] = source[field] - - return target - - -def _get_lambda_role_arn() -> str: - """Get the ARN of the Lambda execution role. - - Returns - ------- - str - The full ARN of the Lambda execution role - """ - sts = boto3.client("sts", region_name=os.environ["AWS_REGION"]) - identity = sts.get_caller_identity() - return cast(str, identity["Arn"]) # This will include the role name - - -def get_lambda_role_name() -> str: - """Extract the role name from the Lambda execution role ARN. - - Returns - ------- - str - The name of the Lambda execution role without the full ARN - """ - arn = _get_lambda_role_arn() - parts = arn.split(":assumed-role/")[1].split("/") - return parts[0] # This is the role name - - -def get_item(response: Any) -> Any: - items = response.get("Items", []) - return items[0] if items else None - - -def get_property_path(data: dict[str, Any], property_path: str) -> Optional[Any]: - """Get the value represented by a property path.""" - props = property_path.split(".") - current_node = data - for prop in props: - if prop in current_node: - current_node = current_node[prop] - else: - return None - - return current_node - - -def get_bearer_token(event, with_prefix: bool = True): - """ - Extracts a Bearer token from the Authorization header in a Lambda event. - - Args: - event (dict): AWS Lambda event (API Gateway / ALB proxy style). - - Returns: - str | None: The token string if present and properly formatted, else None. - """ - headers = event.get("headers") or {} - # Headers may vary in casing - auth_header = headers.get("Authorization") or headers.get("authorization") - if not auth_header: - return None - - if not auth_header.lower().startswith("bearer "): - return None - - # Return the token after "Bearer " - return auth_header.split(" ", 1)[1].strip() - - -def get_account_and_partition() -> tuple[str, str]: - """Get AWS account ID and partition from environment or ECR repository ARN. - - Returns: - tuple[str, str]: (account_id, partition) - """ - account_id = os.environ.get("AWS_ACCOUNT_ID", "") - partition = os.environ.get("AWS_PARTITION", "aws") - - if not account_id: - ecr_repo_arn = os.environ.get("ECR_REPOSITORY_ARN", "") - if ecr_repo_arn: - arn_parts = ecr_repo_arn.split(":") - partition = arn_parts[1] - account_id = arn_parts[4] - - return account_id, partition +# Export all public functions for backward compatibility +__all__ = [ + # Lambda decorators + "api_wrapper", + "authorization_wrapper", + "ctx_context", + # Response builders + "generate_html_response", + "generate_exception_response", + "DecimalEncoder", + # Event parsers + "get_session_id", + "get_principal_id", + "get_bearer_token", + "get_id_token", + # AWS helpers + "get_cert_path", + "get_rest_api_container_endpoint", + "get_lambda_role_name", + "get_account_and_partition", + "retry_config", + "ssm_client", + # Dict helpers + "merge_fields", + "get_property_path", + "get_item", + # Logging + "LambdaContextFilter", + "setup_root_logging", +] diff --git a/lambda/utilities/db_setup_iam_auth.py b/lambda/utilities/db_setup_iam_auth.py index 07a65cc7e..46c30c2ba 100644 --- a/lambda/utilities/db_setup_iam_auth.py +++ b/lambda/utilities/db_setup_iam_auth.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import logging import os from typing import Any @@ -19,73 +20,228 @@ import psycopg2 from botocore.exceptions import ClientError +logger = logging.getLogger(__name__) -def get_db_credentials(secret_arn: str) -> Any: - """Retrieve database credentials from Secrets Manager""" + +class IamAuthSetupRequest: + """Request payload for IAM database user setup.""" + + def __init__( + self, + secret_arn: str, + db_host: str, + db_port: int, + db_name: str, + db_user: str, + iam_name: str, + ): + self.secret_arn = secret_arn + self.db_host = db_host + self.db_port = db_port + self.db_name = db_name + self.db_user = db_user + self.iam_name = iam_name + + @classmethod + def from_event(cls, event: dict[str, Any]) -> "IamAuthSetupRequest": + """Parse and validate request from Lambda event payload.""" + required_fields = { + "secretArn": "secret_arn", # pragma: allowlist secret + "dbHost": "db_host", + "dbPort": "db_port", + "dbName": "db_name", + "dbUser": "db_user", + "iamName": "iam_name", + } + + missing = [field for field in required_fields if field not in event] + if missing: + raise ValueError(f"Missing required fields: {', '.join(missing)}") + + return cls( + secret_arn=str(event["secretArn"]), + db_host=str(event["dbHost"]), + db_port=int(event["dbPort"]), + db_name=str(event["dbName"]), + db_user=str(event["dbUser"]), + iam_name=str(event["iamName"]), + ) + + +def get_db_credentials(secret_arn: str) -> Any | None: + """Retrieve database credentials from Secrets Manager. + + Returns None if the secret doesn't exist (already deleted after bootstrap). + """ client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) try: response = client.get_secret_value(SecretId=secret_arn) except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "ResourceNotFoundException": + logger.info(f"Bootstrap secret not found (already deleted): {secret_arn}") + return None raise Exception(f"Error retrieving secrets: {e}") secret = response["SecretString"] - secret_dict = json.loads(secret) # Converting string to dictionary + secret_dict = json.loads(secret) return secret_dict -def create_db_user(db_host: str, db_port: str, db_name: str, db_user: str, secret_arn: str, iam_name: str) -> None: - """Create a PostgreSQL user for IAM authentication""" - # Get credentials from Secrets Manager +def delete_bootstrap_secret(secret_arn: str) -> bool: + """Delete the bootstrap password secret from Secrets Manager. + + Returns True if secret was deleted, False if deletion was skipped or failed. + """ + client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) + + try: + client.delete_secret(SecretId=secret_arn, ForceDeleteWithoutRecovery=True) + logger.info(f"Successfully deleted bootstrap secret: {secret_arn}") + return True + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "ResourceNotFoundException": + logger.info(f"Bootstrap secret already deleted: {secret_arn}") + return True + logger.error(f"Failed to delete bootstrap secret: {e}") + return False + + +def create_db_user(db_host: str, db_port: str, db_name: str, db_user: str, secret_arn: str, iam_name: str) -> bool: + """Create a PostgreSQL user for IAM authentication. + + Returns True if user was created/updated, False if skipped (secret not found). + """ + logger.info(f"Starting IAM user creation for: {iam_name}") + logger.info(f"Database connection details - host: {db_host}, port: {db_port}, dbname: {db_name}, user: {db_user}") + credentials = get_db_credentials(secret_arn) - # Connect to the database as the admin user - conn = psycopg2.connect(dbname=db_name, user=db_user, password=credentials["password"], host=db_host, port=db_port) + if credentials is None: + logger.info("Bootstrap secret not found - IAM user was likely already created in a previous run") + logger.info("Skipping user creation to avoid errors. If permissions need updating, manually run SQL grants.") + return False + + logger.info("Successfully retrieved bootstrap credentials from Secrets Manager") + logger.info(f"Connecting to database at {db_host}:{db_port}/{db_name} as {db_user}") + + try: + conn = psycopg2.connect( + dbname=db_name, user=db_user, password=credentials["password"], host=db_host, port=db_port + ) + except psycopg2.Error as e: + logger.error(f"Failed to connect to database: {e}") + raise Exception(f"Failed to connect to database: {e}") + cursor = conn.cursor() - # Attempt to create the database user for IAM authentication + # Create vector extension (requires superuser privileges from bootstrap user) try: + logger.info("Creating vector extension if not exists") + cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") + conn.commit() + logger.info("Vector extension created or already exists") + except psycopg2.Error as e: + conn.rollback() + logger.error(f"Error creating vector extension: {e}") + raise Exception(f"Error creating vector extension: {e}") + + try: + logger.info(f"Creating database user: {iam_name}") cursor.execute(f'CREATE USER "{iam_name}"') conn.commit() except psycopg2.Error as e: - # Log but ignore the error if the user already exists - if e.pgcode not in ["23505", "42710"]: # Unique violation error code + conn.rollback() # Must rollback failed transaction before executing more commands + if e.pgcode not in ["23505", "42710"]: + logger.error(f"Error creating user: {e}") raise Exception(f"Error creating user: {e}") + logger.info(f"User {iam_name} already exists (pgcode: {e.pgcode})") - # Other SQL commands to configure user privileges sql_commands = [ f'GRANT rds_iam to "{iam_name}"', + # Schema-level permissions f'GRANT USAGE, CREATE ON SCHEMA public TO "{iam_name}"', + f'GRANT ALL ON SCHEMA public TO "{iam_name}"', + # Existing object permissions f'GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{iam_name}"', f'GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO "{iam_name}"', f'GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA public TO "{iam_name}"', f'GRANT ALL PRIVILEGES ON ALL PROCEDURES IN SCHEMA public TO "{iam_name}"', + # Default privileges for future objects f'ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL PRIVILEGES ON TABLES TO "{iam_name}"', + f'ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL PRIVILEGES ON SEQUENCES TO "{iam_name}"', + f'ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL PRIVILEGES ON FUNCTIONS TO "{iam_name}"', + # Database-level permissions f'GRANT CONNECT ON DATABASE "{db_name}" TO "{iam_name}"', + f'GRANT CREATE ON DATABASE "{db_name}" TO "{iam_name}"', f'GRANT ALL PRIVILEGES ON DATABASE "{db_name}" TO "{iam_name}"', + # RDS-specific admin role (provides elevated privileges without SUPERUSER) + f'GRANT rds_superuser TO "{iam_name}"', ] try: for command in sql_commands: + logger.info(f"Executing: {command}") cursor.execute(command) conn.commit() + logger.info("Successfully granted all privileges to IAM user") except psycopg2.Error as e: + logger.error(f"Error granting privileges to user: {e}") raise Exception(f"Error granting privileges to user: {e}") finally: cursor.close() conn.close() + return True + def handler(event: dict[str, Any], context: Any) -> dict[str, Any]: - """Lambda handler""" - # Extract parameters from the environment and event - secret_arn = os.environ["SECRET_ARN"] - db_host = os.environ["DB_HOST"] - db_port = os.environ["DB_PORT"] - db_name = os.environ["DB_NAME"] - db_user = os.environ["DB_USER"] - iam_name = os.environ["IAM_NAME"] - - # Call function to create DB user - create_db_user(db_host, db_port, db_name, db_user, secret_arn, iam_name) - - return {"statusCode": 200, "body": "Database user created successfully"} + """Lambda handler for IAM database user setup. + + Creates an IAM-authenticated PostgreSQL user. The bootstrap secret is kept + for CloudFormation compatibility (not deleted) even though it won't be used + for authentication after IAM auth is enabled. + """ + logger.info(f"IAM auth setup Lambda invoked with event: {json.dumps(event)}") + + try: + request = IamAuthSetupRequest.from_event(event) + logger.info( + f"""Parsed request - dbHost: {request.db_host}, dbPort: {request.db_port}, dbName: {request.db_name}, + iamName: {request.iam_name}""" + ) + except (ValueError, KeyError, TypeError) as e: + logger.error(f"Invalid request payload: {e}") + return {"statusCode": 400, "body": json.dumps({"error": f"Invalid request payload: {e}"})} + + try: + user_created = create_db_user( + request.db_host, + str(request.db_port), + request.db_name, + request.db_user, + request.secret_arn, + request.iam_name, + ) + + # Note: We no longer delete the bootstrap secret to maintain CloudFormation compatibility + # The secret remains but is not used for authentication when IAM auth is enabled + logger.info("IAM user setup complete. Bootstrap secret retained for CloudFormation compatibility.") + + result = { + "statusCode": 200, + "body": json.dumps( + { + "message": "Database user setup complete", + "userCreated": user_created, + "secretDeleted": False, # Secret is retained + } + ), + } + logger.info(f"IAM auth setup completed successfully: {result}") + return result + + except Exception as e: + logger.error(f"IAM auth setup failed: {e}") + return {"statusCode": 500, "body": json.dumps({"error": str(e)})} diff --git a/lambda/utilities/dict_helpers.py b/lambda/utilities/dict_helpers.py new file mode 100644 index 000000000..3db830ed8 --- /dev/null +++ b/lambda/utilities/dict_helpers.py @@ -0,0 +1,140 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic dictionary manipulation utilities.""" + +from typing import Any + + +def merge_fields(source: dict, target: dict, fields: list[str]) -> dict: + """ + Merge specified fields from source dictionary to target dictionary. + + Supports both top-level and nested fields using dot notation. + + Parameters + ---------- + source : dict + Source dictionary to copy fields from. + target : dict + Target dictionary to copy fields into. + fields : list[str] + List of field names, can use dot notation for nested fields. + + Returns + ------- + dict + Updated target dictionary. + + Example + ------- + >>> source = {"user": {"name": "John", "age": 30}, "status": "active"} + >>> target = {"id": "123"} + >>> merge_fields(source, target, ["user.name", "status"]) + {'id': '123', 'user': {'name': 'John'}, 'status': 'active'} + """ + + def get_nested_value(obj: dict[str, Any], path: list[str]) -> Any: + """Get value from nested dictionary using path.""" + current: Any = obj + for key in path: + if not isinstance(current, dict): + return None + current = current.get(key) + if current is None: + return None + return current + + def set_nested_value(obj: dict, path: list[str], value: Any) -> None: + """Set value in nested dictionary using path.""" + current = obj + for key in path[:-1]: + if key not in current: + current[key] = {} + current = current[key] + if value is not None: + current[path[-1]] = value + + for field in fields: + if "." in field: + # Handle nested fields + keys = field.split(".") + value = get_nested_value(source, keys) + if value is not None: + set_nested_value(target, keys, value) + else: + # Handle top-level fields + if field in source: + target[field] = source[field] + + return target + + +def get_property_path(data: dict[str, Any], property_path: str) -> Any | None: + """ + Get value from nested dictionary using dot-notation path. + + Parameters + ---------- + data : dict[str, Any] + Dictionary to extract value from. + property_path : str + Dot-notation path to the property (e.g., "user.address.city"). + + Returns + ------- + Optional[Any] + The value at the specified path, or None if path doesn't exist. + + Example + ------- + >>> data = {"user": {"address": {"city": "Seattle"}}} + >>> get_property_path(data, "user.address.city") + 'Seattle' + >>> get_property_path(data, "user.phone") + None + """ + props = property_path.split(".") + current_node = data + for prop in props: + if prop in current_node: + current_node = current_node[prop] + else: + return None + + return current_node + + +def get_item(response: Any) -> Any: + """ + Extract first item from DynamoDB query/scan response. + + Parameters + ---------- + response : Any + DynamoDB query or scan response. + + Returns + ------- + Any + First item from the response, or None if no items. + + Example + ------- + >>> response = {"Items": [{"id": "123", "name": "John"}]} + >>> get_item(response) + {'id': '123', 'name': 'John'} + """ + items = response.get("Items", []) + return items[0] if items else None diff --git a/lambda/utilities/event_parser.py b/lambda/utilities/event_parser.py new file mode 100644 index 000000000..50a104e1a --- /dev/null +++ b/lambda/utilities/event_parser.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for parsing API Gateway Lambda events.""" + +import copy +import json +from typing import Any + +from utilities.header_sanitizer import sanitize_headers + + +def sanitize_event_for_logging(event: dict[str, Any]) -> str: + """ + Sanitize Lambda event before logging. + + This function sanitizes the event by: + 1. Normalizing header keys to lowercase + 2. Redacting authorization headers + 3. Replacing security-critical headers with server-controlled values + + Parameters + ---------- + event : Dict[str, Any] + The Lambda event from API Gateway. + + Returns + ------- + str + The sanitized event as a JSON-formatted string. + + Example + ------- + >>> event = { + ... "headers": {"Authorization": "Bearer token123"}, + ... "path": "/users/123" + ... } + >>> sanitized = sanitize_event_for_logging(event) + >>> "token123" in sanitized + False + """ + # Deep copy to avoid modifying original event + sanitized = copy.deepcopy(event) + + # Normalize header keys to lowercase + if "headers" in event: + for key in event["headers"]: + if key != key.lower(): + sanitized["headers"][key.lower()] = event["headers"][key] + del sanitized["headers"][key] + + if "multiValueHeaders" in sanitized: + for key in event["multiValueHeaders"]: + if key != key.lower(): + sanitized["multiValueHeaders"][key.lower()] = event["multiValueHeaders"][key] + del sanitized["multiValueHeaders"][key] + + # Redact authorization headers + if "headers" in sanitized and "authorization" in sanitized["headers"]: + sanitized["headers"]["authorization"] = "" + if "multiValueHeaders" in sanitized and "authorization" in sanitized["multiValueHeaders"]: + sanitized["multiValueHeaders"]["authorization"] = [""] + + # Sanitize security-critical headers to prevent log injection + if "headers" in sanitized: + sanitized["headers"] = sanitize_headers(sanitized["headers"], event) + + return json.dumps(sanitized) + + +def get_session_id(event: dict) -> str: + """ + Extract session ID from Lambda event path parameters. + + Parameters + ---------- + event : dict + Lambda event from API Gateway. + + Returns + ------- + str + The session ID from path parameters. + + Example + ------- + >>> event = {"pathParameters": {"sessionId": "sess-123"}} + >>> get_session_id(event) + 'sess-123' + """ + session_id: str = event.get("pathParameters", {}).get("sessionId") + return session_id + + +def get_principal_id(event: dict) -> str: + """ + Extract principal ID from Lambda event authorizer context. + + Parameters + ---------- + event : dict + Lambda event from API Gateway. + + Returns + ------- + str + The principal ID from authorizer context. + + Example + ------- + >>> event = { + ... "requestContext": { + ... "authorizer": {"principal": "user-123"} + ... } + ... } + >>> get_principal_id(event) + 'user-123' + """ + principal: str = event.get("requestContext", {}).get("authorizer", {}).get("principal", "") + return principal + + +def get_bearer_token(event: dict) -> str | None: + """ + Extract Bearer token from Authorization header in Lambda event. + + Parameters + ---------- + event : dict + Lambda event from API Gateway. + + Returns + ------- + Optional[str] + The token string if present and properly formatted, else None. + + Example + ------- + >>> event = {"headers": {"Authorization": "Bearer abc123"}} + >>> get_bearer_token(event) + 'abc123' + """ + headers = event.get("headers") or {} + # Headers may vary in casing + auth_header: str | None = headers.get("Authorization") or headers.get("authorization") + if not auth_header: + return None + + if not auth_header.lower().startswith("bearer "): + return None + + # Return the token after "Bearer " + token: str = auth_header.split(" ", 1)[1].strip() + return token + + +def get_id_token(event: dict) -> str: + """ + Extract ID token from Authorization header in Lambda event. + + This function extracts the bearer token from the authorization header, + removing the "Bearer" prefix if present. + + Parameters + ---------- + event : dict + Lambda event from API Gateway. + + Returns + ------- + str + The ID token without the "Bearer" prefix. + + Raises + ------ + ValueError + If authorization header is missing. + + Example + ------- + >>> event = {"headers": {"Authorization": "Bearer token123"}} + >>> get_id_token(event) + 'token123' + """ + auth_header = None + + if "authorization" in event["headers"]: + auth_header = event["headers"]["authorization"] + elif "Authorization" in event["headers"]: + auth_header = event["headers"]["Authorization"] + else: + raise ValueError("Missing authorization token.") + + # Remove bearer token prefix if present + return str(auth_header).removeprefix("Bearer ").removeprefix("bearer ").strip() diff --git a/lambda/utilities/exceptions.py b/lambda/utilities/exceptions.py index f8815da9e..c24fae60d 100644 --- a/lambda/utilities/exceptions.py +++ b/lambda/utilities/exceptions.py @@ -35,3 +35,8 @@ def __init__(self, detail: str = "Not Found"): class UnauthorizedException(HTTPException): def __init__(self, detail: str = "Unauthorized"): super().__init__(401, detail) # flake8: noqa + + +class ForbiddenException(HTTPException): + def __init__(self, detail: str = "Forbidden"): + super().__init__(403, detail) # flake8: noqa diff --git a/lambda/utilities/fastapi_factory.py b/lambda/utilities/fastapi_factory.py new file mode 100644 index 000000000..0bd05c0ba --- /dev/null +++ b/lambda/utilities/fastapi_factory.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating FastAPI applications with standard LISA configuration.""" + +from fastapi import FastAPI, Request +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from utilities.exceptions import ForbiddenException, HTTPException, NotFoundException, UnauthorizedException +from utilities.fastapi_middleware.aws_api_gateway_middleware import AWSAPIGatewayMiddleware +from utilities.fastapi_middleware.exception_handlers import generic_exception_handler +from utilities.fastapi_middleware.input_validation_middleware import InputValidationMiddleware +from utilities.fastapi_middleware.request_logging_middleware import RequestLoggingMiddleware +from utilities.fastapi_middleware.security_headers_middleware import SecurityHeadersMiddleware + + +def create_fastapi_app() -> FastAPI: + """ + Create a FastAPI application with standard LISA configuration. + + This factory function creates a FastAPI app with: + - Standard FastAPI settings (redirect_slashes, lifespan, docs) + - Input validation middleware (null bytes, request size, HTTP methods) + - AWS API Gateway middleware (extracts Lambda event context) + - Request logging middleware (audit trail with sanitized data) + - Security headers middleware (HSTS, X-Frame-Options, etc.) + - CORS middleware with permissive settings + - Request validation exception handler (422 errors) + - Generic exception handler (500 errors) + + Middleware execution order (IMPORTANT): + 1. InputValidationMiddleware - Validates input FIRST (security) + 2. AWSAPIGatewayMiddleware - Extracts AWS event context + 3. RequestLoggingMiddleware - Logs requests with sanitized data + 4. SecurityHeadersMiddleware - Adds security headers to responses + 5. CORSMiddleware - Handles CORS (last middleware) + + Returns: + FastAPI: Configured FastAPI application instance + + Example: + >>> from utilities.fastapi_factory import create_fastapi_app + >>> app = create_fastapi_app() + >>> # Add domain-specific exception handlers + >>> @app.exception_handler(MyCustomError) + >>> async def my_handler(request, exc): + >>> return JSONResponse(status_code=404, content={"error": str(exc)}) + """ + # Create FastAPI app with standard settings + app = FastAPI( + redirect_slashes=False, + lifespan="off", + docs_url="/docs", + openapi_url="/openapi.json", + ) + + # Add middleware in reverse order (last added = first executed) + # Middleware execution order: InputValidation -> AWSAPIGateway -> RequestLogging -> SecurityHeaders -> CORS + + # CORS middleware (executed last, added first) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Security headers middleware (adds HSTS, X-Frame-Options, etc.) + app.add_middleware(SecurityHeadersMiddleware) + + # Request logging middleware (logs all requests with sanitized data) + app.add_middleware(RequestLoggingMiddleware) + + # AWS API Gateway middleware (extracts Lambda event context) + app.add_middleware(AWSAPIGatewayMiddleware) + + # Input validation middleware (must be executed first for security) + app.add_middleware(InputValidationMiddleware) + + # Register standard exception handlers + + # HTTP exceptions (401, 403, 404, etc.) + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + """Handle custom HTTP exceptions and translate to appropriate status codes.""" + return JSONResponse(status_code=exc.http_status_code, content={"message": exc.message}) + + # Convenience aliases for specific HTTP exceptions (for direct import in tests) + @app.exception_handler(UnauthorizedException) + async def unauthorized_handler(request: Request, exc: UnauthorizedException) -> JSONResponse: + """Handle unauthorized exceptions and translate to a 401 error.""" + return JSONResponse(status_code=401, content={"message": exc.message}) + + @app.exception_handler(ForbiddenException) + async def forbidden_handler(request: Request, exc: ForbiddenException) -> JSONResponse: + """Handle forbidden exceptions and translate to a 403 error.""" + return JSONResponse(status_code=403, content={"message": exc.message}) + + @app.exception_handler(NotFoundException) + async def not_found_handler(request: Request, exc: NotFoundException) -> JSONResponse: + """Handle not found exceptions and translate to a 404 error.""" + return JSONResponse(status_code=404, content={"message": exc.message}) + + # Request validation errors (422) + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: + """Handle exception when request fails validation and translate to a 422 error.""" + return JSONResponse( + status_code=422, + content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"}, + ) + + # Generic exception handler (500) - must be registered last + @app.exception_handler(Exception) + async def handle_generic_exception(request: Request, exc: Exception) -> JSONResponse: + """Handle all unhandled exceptions - delegates to common handler.""" + return await generic_exception_handler(request, exc) + + return app diff --git a/lambda/utilities/fastapi_middleware/exception_handlers.py b/lambda/utilities/fastapi_middleware/exception_handlers.py new file mode 100644 index 000000000..b416446b3 --- /dev/null +++ b/lambda/utilities/fastapi_middleware/exception_handlers.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common exception handlers for FastAPI applications.""" + +import logging + +from fastapi import Request +from fastapi.responses import JSONResponse + +logger = logging.getLogger(__name__) + + +async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """ + Handle all unhandled exceptions. + + This handler catches any exceptions not handled by more specific handlers. + It logs detailed error information internally but returns a generic message + to the client to avoid exposing internal implementation details. + + Security Note: Never expose internal details (stack traces, file paths, etc.) + in error responses as they can aid attackers in reconnaissance. + + Args: + request: The FastAPI request object + exc: The exception that was raised + + Returns: + JSONResponse with 500 status code and generic error message + """ + # Log detailed error information for debugging + logger.error( + f"Unhandled exception in {request.method} {request.url.path}", + exc_info=exc, + extra={ + "method": request.method, + "path": request.url.path, + "exception_type": type(exc).__name__, + "exception_message": str(exc), + }, + ) + + # Return generic error message to client + return JSONResponse( + status_code=500, + content={ + "error": "Internal Server Error", + "message": "An unexpected error occurred while processing your request", + }, + ) diff --git a/lambda/utilities/fastapi_middleware/input_validation_middleware.py b/lambda/utilities/fastapi_middleware/input_validation_middleware.py new file mode 100644 index 000000000..43a4b92ad --- /dev/null +++ b/lambda/utilities/fastapi_middleware/input_validation_middleware.py @@ -0,0 +1,298 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Middleware for FastAPI that validates and sanitizes input to prevent security vulnerabilities.""" + +import html +import logging +import re + +from fastapi import status +from fastapi.responses import JSONResponse +from starlette.middleware.base import ASGIApp, BaseHTTPMiddleware, Request, RequestResponseEndpoint, Response + +logger = logging.getLogger(__name__) + +# Default maximum request size: 1MB +DEFAULT_MAX_REQUEST_SIZE = 1024 * 1024 + + +def sanitize_input(data: str) -> str: + """ + Sanitize string input by removing or escaping dangerous characters. + + This function: + - Escapes HTML/XML special characters to prevent XSS + - Removes script tags and their content + - Preserves legitimate special characters (hyphens, underscores, etc.) + + Args: + data: String to sanitize + + Returns: + Sanitized string safe for processing + """ + if not data: + return data + + # Remove script tags and their content (case-insensitive) + data = re.sub(r"]*>.*?", "", data, flags=re.IGNORECASE | re.DOTALL) + + # Escape HTML special characters to prevent XSS + # This preserves legitimate characters like hyphens, underscores, etc. + data = html.escape(data) + + return data + + +class InputValidationMiddleware(BaseHTTPMiddleware): + """ + Middleware that validates and sanitizes all incoming requests. + + This middleware provides security protections against: + - Null byte injection attacks + - Oversized payload attacks + - Special character injection + + It intercepts all requests before they reach the application handlers + and returns appropriate HTTP error codes for invalid input. + """ + + def __init__(self, app: ASGIApp, max_request_size: int = DEFAULT_MAX_REQUEST_SIZE) -> None: + """ + Initialize the input validation middleware. + + Args: + app: The ASGI application + max_request_size: Maximum allowed request body size in bytes (default: 1MB) + """ + super().__init__(app) + self.app = app + self.max_request_size = max_request_size + + def contains_null_bytes(self, data: str) -> bool: + """ + Check if a string contains null bytes. + + Null bytes (\\x00) can be used to bypass input validation or cause + unexpected behavior in string processing. + + Args: + data: String to check for null bytes + + Returns: + True if null bytes are found, False otherwise + """ + return "\x00" in data + + async def check_request_size(self, request: Request) -> JSONResponse | None: + """ + Validate that the request body size does not exceed the configured limit. + + Args: + request: The incoming HTTP request + + Returns: + JSONResponse with 413 status if size exceeds limit, None otherwise + """ + content_length = request.headers.get("content-length") + if content_length: + try: + size = int(content_length) + if size > self.max_request_size: + logger.warning( + f"Request size {size} bytes exceeds maximum {self.max_request_size} bytes", + extra={ + "request_size": size, + "max_size": self.max_request_size, + "path": request.url.path, + "method": request.method, + }, + ) + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content={ + "error": "Payload Too Large", + "message": ( + f"Request body size exceeds maximum allowed size " f"of {self.max_request_size} bytes" + ), + }, + ) + except ValueError: + # Invalid content-length header, let it pass and fail later if needed + logger.warning(f"Invalid content-length header: {content_length}") + + return None + + async def validate_query_params(self, request: Request) -> JSONResponse | None: + """ + Validate query parameters for null bytes. + + Args: + request: The incoming HTTP request + + Returns: + JSONResponse with 400 status if null bytes found, None otherwise + """ + for key, value in request.query_params.items(): + if self.contains_null_bytes(key) or self.contains_null_bytes(value): + logger.warning( + f"Null byte detected in query parameter: {key}", + extra={ + "parameter_name": key, + "path": request.url.path, + "method": request.method, + }, + ) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "error": "Bad Request", + "message": "Invalid characters detected in query parameters", + }, + ) + return None + + async def validate_path_params(self, request: Request) -> JSONResponse | None: + """ + Validate path parameters for null bytes. + + Args: + request: The incoming HTTP request + + Returns: + JSONResponse with 400 status if null bytes found, None otherwise + """ + path = str(request.url.path) + if self.contains_null_bytes(path): + logger.warning( + "Null byte detected in path", + extra={ + "path": path, + "method": request.method, + }, + ) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "error": "Bad Request", + "message": "Invalid characters detected in request path", + }, + ) + return None + + async def validate_request_body(self, request: Request) -> JSONResponse | None: + """ + Validate request body for null bytes. + + This reads the request body and checks for null bytes. If found, + returns an error response. Otherwise, the body is consumed and needs + to be restored for downstream handlers. + + Args: + request: The incoming HTTP request + + Returns: + JSONResponse with 400 status if null bytes found, None otherwise + """ + # Only check body for methods that typically have a body + if request.method in ("POST", "PUT", "PATCH"): + try: + body = await request.body() + if body: + # Check for null bytes in the raw body + if b"\x00" in body: + logger.warning( + "Null byte detected in request body", + extra={ + "path": request.url.path, + "method": request.method, + "body_size": len(body), + }, + ) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "error": "Bad Request", + "message": "Invalid characters detected in request body", + }, + ) + except Exception as e: + # If we can't read the body, let it pass and fail later with proper error handling + logger.warning(f"Error reading request body for validation: {e}") + + return None + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + Process the request through validation checks before passing to handlers. + + Validation order: + 1. HTTP method validation (returns 405 if invalid) + 2. Request size check (returns 413 if too large) + 3. Path parameter validation (returns 400 if null bytes found) + 4. Query parameter validation (returns 400 if null bytes found) + 5. Request body validation (returns 400 if null bytes found) + + Args: + request: The incoming HTTP request + call_next: The next middleware or handler in the chain + + Returns: + Response from validation or from the next handler + """ + # Validate HTTP method + # FastAPI will handle method validation at the route level, but we add + # this as a safety check for any routes that might not be properly configured + valid_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"} + if request.method not in valid_methods: + logger.warning( + f"Invalid HTTP method: {request.method}", + extra={ + "method": request.method, + "path": request.url.path, + }, + ) + return JSONResponse( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, + content={ + "error": "Method Not Allowed", + "message": f"HTTP method {request.method} is not allowed", + }, + headers={"Allow": ", ".join(sorted(valid_methods))}, + ) + + # Check request size + size_error = await self.check_request_size(request) + if size_error: + return size_error + + # Validate path parameters + path_error = await self.validate_path_params(request) + if path_error: + return path_error + + # Validate query parameters + query_error = await self.validate_query_params(request) + if query_error: + return query_error + + # Validate request body + body_error = await self.validate_request_body(request) + if body_error: + return body_error + + # All validations passed, proceed to next handler + response = await call_next(request) + return response diff --git a/lambda/utilities/fastapi_middleware/request_logging_middleware.py b/lambda/utilities/fastapi_middleware/request_logging_middleware.py new file mode 100644 index 000000000..105cfe76d --- /dev/null +++ b/lambda/utilities/fastapi_middleware/request_logging_middleware.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Middleware for logging all incoming requests to FastAPI applications.""" + +import json +import logging +import time +from typing import Any + +from starlette.middleware.base import BaseHTTPMiddleware, Request, RequestResponseEndpoint, Response +from utilities.header_sanitizer import sanitize_headers + +logger = logging.getLogger(__name__) + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + """ + Middleware that logs all incoming requests with sanitized data. + + This middleware provides: + - Automatic logging of all requests (method, path, headers, params) + - Header sanitization (redacts auth, replaces user-controlled headers) + - Request timing (duration in milliseconds) + - User context extraction (username, groups, auth type) + - Correlation IDs for request tracing + + Security features: + - Authorization headers are redacted + - User-controlled headers (x-forwarded-for) replaced with server values + - Real client IP extracted from API Gateway context + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + Process the request, log details, and pass to next handler. + + Args: + request: The incoming HTTP request + call_next: The next middleware or handler in the chain + + Returns: + Response from the next handler + """ + # Start timing + start_time = time.time() + + # Extract AWS event from request scope (set by AWSAPIGatewayMiddleware) + event = request.scope.get("aws.event", {}) + + # Build sanitized request data for logging + log_data = self._build_log_data(request, event) + + # Log the incoming request + logger.info( + f"Request: {request.method} {request.url.path}", + extra=log_data, + ) + + # Process the request + response = await call_next(request) + + # Calculate request duration + duration_ms = (time.time() - start_time) * 1000 + + # Log the response + logger.info( + f"Response: {request.method} {request.url.path} - {response.status_code} ({duration_ms:.2f}ms)", + extra={ + "method": request.method, + "path": request.url.path, + "status_code": response.status_code, + "duration_ms": duration_ms, + "request_id": event.get("requestContext", {}).get("requestId"), + }, + ) + + return response + + def _build_log_data(self, request: Request, event: dict[str, Any]) -> dict[str, Any]: + """ + Build sanitized log data from request and AWS event. + + Args: + request: The FastAPI request object + event: The AWS Lambda event (from API Gateway) + + Returns: + Dictionary with sanitized request data for logging + """ + # Extract request context + request_context = event.get("requestContext", {}) + authorizer = request_context.get("authorizer", {}) + identity = request_context.get("identity", {}) + + # Sanitize headers (redact auth, replace user-controlled headers) + raw_headers = dict(request.headers) + sanitized_headers = sanitize_headers(raw_headers, event) + + # Build log data + log_data = { + "method": request.method, + "path": request.url.path, + "query_params": dict(request.query_params), + "headers": sanitized_headers, + "request_id": request_context.get("requestId"), + "source_ip": identity.get("sourceIp"), # Real IP from API Gateway + "user_agent": identity.get("userAgent"), + "user": { + "username": authorizer.get("username"), + "groups": authorizer.get("groups", []), + "auth_type": authorizer.get("authType"), + }, + } + + # Add path parameters if present + if hasattr(request, "path_params") and request.path_params: + log_data["path_params"] = dict(request.path_params) + + return log_data + + def _sanitize_body(self, body: bytes) -> str: + """ + Sanitize request body for logging. + + Attempts to parse as JSON and redact sensitive fields. + If parsing fails, returns a placeholder. + + Args: + body: Raw request body bytes + + Returns: + Sanitized body as string + """ + if not body: + return "" + + try: + # Try to parse as JSON + body_json = json.loads(body) + + # Redact sensitive fields + sensitive_fields = ["password", "token", "secret", "apiKey", "api_key"] + for field in sensitive_fields: + if field in body_json: + body_json[field] = "" + + return json.dumps(body_json) + except (json.JSONDecodeError, UnicodeDecodeError): + # Not JSON or can't decode - return placeholder + return f"" diff --git a/lambda/utilities/fastapi_middleware/security_headers_middleware.py b/lambda/utilities/fastapi_middleware/security_headers_middleware.py new file mode 100644 index 000000000..6351da5e1 --- /dev/null +++ b/lambda/utilities/fastapi_middleware/security_headers_middleware.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Middleware for adding security headers to all FastAPI responses.""" + +from starlette.middleware.base import BaseHTTPMiddleware, Request, RequestResponseEndpoint, Response + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """ + Middleware that adds security headers to all HTTP responses. + + Security headers included: + - Strict-Transport-Security: Forces HTTPS connections + - X-Content-Type-Options: Prevents MIME sniffing attacks + - X-Frame-Options: Prevents clickjacking attacks + - Cache-Control: Prevents caching of sensitive data + - Pragma: Legacy cache control for HTTP/1.0 + - Content-Type: Ensures JSON responses are properly typed + + These headers protect against common web vulnerabilities and ensure + secure communication between client and server. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + Process the request and add security headers to the response. + + Args: + request: The incoming HTTP request + call_next: The next middleware or handler in the chain + + Returns: + Response with security headers added + """ + # Call the next handler to get the response + response = await call_next(request) + + # Add security headers to the response + # HSTS: Force HTTPS for 547 days (47304000 seconds) including subdomains + response.headers["Strict-Transport-Security"] = "max-age=47304000; includeSubDomains" + + # Prevent MIME sniffing (forces browser to respect Content-Type) + response.headers["X-Content-Type-Options"] = "nosniff" + + # Prevent clickjacking by disallowing iframe embedding + response.headers["X-Frame-Options"] = "DENY" + + # Prevent caching of sensitive data + response.headers["Cache-Control"] = "no-store, no-cache" + response.headers["Pragma"] = "no-cache" + + # Ensure Content-Type is set (FastAPI usually sets this, but we ensure it) + if "Content-Type" not in response.headers: + response.headers["Content-Type"] = "application/json" + + return response diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index e8a557539..4e92d02e1 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -132,7 +132,7 @@ def _extract_text_content(s3_object: dict) -> str: ---------- s3_object (dict): an S3 object containing a text file body. """ - return s3_object["Body"].read().decode("utf-8", errors="replace") + return s3_object["Body"].read().decode("utf-8", errors="replace") # type: ignore[no-any-return] def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: @@ -184,8 +184,11 @@ def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: ] # Use factory to chunk documents based on strategy - logger.info(f"Processing document with chunking strategy: {ingestion_job.chunk_strategy.type}") - doc_chunks = ChunkingStrategyFactory.chunk_documents(docs, ingestion_job.chunk_strategy) + chunk_strategy = ingestion_job.chunk_strategy + if chunk_strategy is None: + raise ValueError("Chunking strategy is required") + logger.info(f"Processing document with chunking strategy: {chunk_strategy.type}") + doc_chunks = ChunkingStrategyFactory.chunk_documents(docs, chunk_strategy) # Update part number of doc metadata for i, doc in enumerate(doc_chunks): diff --git a/lambda/utilities/header_sanitizer.py b/lambda/utilities/header_sanitizer.py new file mode 100644 index 000000000..f1ef5a34c --- /dev/null +++ b/lambda/utilities/header_sanitizer.py @@ -0,0 +1,157 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility for sanitizing HTTP headers before logging to prevent log injection attacks.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +# Security-critical headers that should be replaced with server-controlled values +SECURITY_CRITICAL_HEADERS = { + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-server", + "x-amzn-client-id", + "x-real-ip", + "forwarded", +} + + +def get_real_client_ip(event: dict[str, Any]) -> str: + """ + Extract the real client IP address from API Gateway event context. + + This function retrieves the actual source IP from the API Gateway request context, + which cannot be spoofed by the client. User-provided headers like x-forwarded-for + should never be trusted for security-critical operations. + + Args: + event: Lambda event from API Gateway containing requestContext + + Returns: + Real client IP address from API Gateway, or "unknown" if not available + """ + try: + # API Gateway provides the real source IP in requestContext.identity.sourceIp + # This value is set by AWS and cannot be manipulated by the client + source_ip: str | None = event.get("requestContext", {}).get("identity", {}).get("sourceIp") + if source_ip: + return source_ip + + # Fallback: check if this is a direct Lambda invocation (testing) + logger.warning("No sourceIp found in API Gateway event context") + return "unknown" + + except Exception as e: + logger.error(f"Error extracting real client IP: {e}") + return "unknown" + + +def sanitize_headers(headers: dict[str, Any], event: dict[str, Any]) -> dict[str, Any]: + """ + Sanitize HTTP headers by replacing user-controlled values with server-controlled values. + + This prevents attackers from manipulating security-critical headers in logs, + which could be used to hide their true source IP or manipulate audit trails. + + Args: + headers: Original HTTP headers from the request + event: Lambda event from API Gateway (used to extract real values) + + Returns: + Dictionary of sanitized headers with security-critical values replaced + + Example: + >>> headers = {"x-forwarded-for": "1.2.3.4, 5.6.7.8"} + >>> event = {"requestContext": {"identity": {"sourceIp": "9.10.11.12"}}} + >>> sanitized = sanitize_headers(headers, event) + >>> sanitized["x-forwarded-for"] + "9.10.11.12" + """ + if not headers: + return {} + + # Create a copy to avoid modifying the original + sanitized = dict(headers) + + # Get the real client IP from API Gateway + real_ip = get_real_client_ip(event) + + # Replace security-critical headers with server-controlled values + for header_name in SECURITY_CRITICAL_HEADERS: + # Check both lowercase and original case (HTTP headers are case-insensitive) + header_lower = header_name.lower() + + # Find the actual header key (may have different casing) + actual_key = None + for key in sanitized.keys(): + if key.lower() == header_lower: + actual_key = key + break + + if actual_key: + # Store original value for debugging (with clear marker) + original_value = sanitized[actual_key] + + # Replace with server-controlled value + if header_lower in ("x-forwarded-for", "x-real-ip"): + sanitized[actual_key] = real_ip + elif header_lower == "x-forwarded-host": + # Use the actual Host header from API Gateway + sanitized[actual_key] = event.get("requestContext", {}).get("domainName", "unknown") + elif header_lower == "x-forwarded-server": + # Use API Gateway stage + sanitized[actual_key] = event.get("requestContext", {}).get("stage", "unknown") + elif header_lower == "x-amzn-client-id": + # Use the validated request ID from API Gateway + sanitized[actual_key] = event.get("requestContext", {}).get("requestId", "unknown") + elif header_lower == "forwarded": + # Reconstruct Forwarded header with server values + sanitized[actual_key] = f"for={real_ip}" + + # Log the sanitization for security monitoring + if original_value != sanitized[actual_key]: + logger.debug( + f"Sanitized header {actual_key}: original={original_value}, sanitized={sanitized[actual_key]}" + ) + + return sanitized + + +def get_sanitized_headers_for_logging(event: dict[str, Any]) -> dict[str, Any]: + """ + Extract and sanitize headers from Lambda event for safe logging. + + This is a convenience function that extracts headers from the event + and sanitizes them in one step. + + Args: + event: Lambda event from API Gateway + + Returns: + Dictionary of sanitized headers safe for logging + + Example: + >>> event = { + ... "headers": {"x-forwarded-for": "1.2.3.4"}, + ... "requestContext": {"identity": {"sourceIp": "5.6.7.8"}} + ... } + >>> headers = get_sanitized_headers_for_logging(event) + >>> headers["x-forwarded-for"] + "5.6.7.8" + """ + headers = event.get("headers", {}) + return sanitize_headers(headers, event) diff --git a/lambda/utilities/healthcheck_validator.py b/lambda/utilities/healthcheck_validator.py new file mode 100644 index 000000000..f342de9fd --- /dev/null +++ b/lambda/utilities/healthcheck_validator.py @@ -0,0 +1,83 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validator for ECS healthcheck command format.""" + + +def validate_healthcheck_command(command: str | list[str]) -> None: + """ + Validate ECS healthcheck command format. + + This validation ensures the command format is compatible with ECS requirements + to prevent deployment failures. It does NOT restrict command content - admins + are trusted to configure their containers appropriately. + + Args: + command: Healthcheck command as string or array + + Raises: + ValueError: If command format is invalid for ECS + + Examples: + Valid formats: + - "curl -f http://localhost:8080/health" + - ["CMD-SHELL", "curl -f http://localhost:8080/health"] + - ["CMD", "curl", "-f", "http://localhost:8080/health"] + + Invalid formats: + - "" (empty string) + - [] (empty array) + - ["curl", "-f", "..."] (missing CMD/CMD-SHELL prefix) + """ + # Check if command is None + if command is None: + raise ValueError("Healthcheck command cannot be None") + + # Check if command is string + if isinstance(command, str): + if not command.strip(): + raise ValueError("Healthcheck command cannot be an empty string") + # String format is valid - ECS converts to CMD-SHELL + return + + # Check if command is list + if isinstance(command, list): + if len(command) == 0: + raise ValueError("Healthcheck command array cannot be empty") + + # Check first element is CMD or CMD-SHELL + if command[0] not in ["CMD", "CMD-SHELL"]: + raise ValueError( + f"Healthcheck array must start with 'CMD' or 'CMD-SHELL', got: '{command[0]}'. " + "Example: ['CMD-SHELL', 'curl -f http://localhost:8080/health']" + ) + + # Check there's at least one command after the prefix + if len(command) < 2: + raise ValueError( + f"Healthcheck array must contain a command after '{command[0]}'. " + "Example: ['CMD-SHELL', 'curl -f http://localhost:8080/health']" + ) + + # Check command part is not empty + if isinstance(command[1], str) and not command[1].strip(): + raise ValueError("Healthcheck command cannot be empty after CMD/CMD-SHELL prefix") + + return + + # Invalid type + raise ValueError( + f"Healthcheck command must be a string or array, got: {type(command).__name__}. " + "Example: ['CMD-SHELL', 'curl -f http://localhost:8080/health']" + ) diff --git a/lambda/utilities/input_validation.py b/lambda/utilities/input_validation.py new file mode 100644 index 000000000..463c88b19 --- /dev/null +++ b/lambda/utilities/input_validation.py @@ -0,0 +1,189 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Input validation utilities for Lambda functions.""" + +import functools +import logging +from collections.abc import Callable +from typing import Any, TypeVar + +from utilities.response_builder import generate_html_response + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + +# Default maximum request size: 1MB +DEFAULT_MAX_REQUEST_SIZE = 1024 * 1024 +# Max API Gateway size - use for image uploads / chat sessions +MAX_LARGE_REQUEST_SIZE = 10 * 1024 * 1024 + + +def contains_null_bytes(data: str) -> bool: + """ + Check if a string contains null bytes. + + Null bytes (\\x00) can be used to bypass input validation or cause + unexpected behavior in string processing. + + Args: + data: String to check for null bytes + + Returns: + True if null bytes are found, False otherwise + """ + return "\x00" in data + + +def validate_input(max_request_size: int = DEFAULT_MAX_REQUEST_SIZE) -> Callable[[F], F]: + """ + Decorator to validate Lambda event input before processing. + + This decorator provides security protections against: + - Null byte injection attacks + - Oversized payload attacks + - Invalid HTTP methods + + Args: + max_request_size: Maximum allowed request body size in bytes (default: 1MB) + + Returns: + Decorator function that wraps the Lambda handler + """ + + def decorator(f: F) -> F: + @functools.wraps(f) + def wrapper(event: dict, context: dict) -> dict[str, str | int | dict[str, str]]: + """ + Validate Lambda event input. + + Validation order: + 1. HTTP method validation (returns 405 if invalid) + 2. Request size check (returns 413 if too large) + 3. Path validation (returns 400 if null bytes found) + 4. Query parameter validation (returns 400 if null bytes found) + 5. Request body validation (returns 400 if null bytes found) + + Args: + event: Lambda event from API Gateway + context: Lambda context + + Returns: + Error response if validation fails, otherwise calls wrapped function + """ + # 1. Validate HTTP method + http_method = event.get("httpMethod", "") + valid_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"} + if http_method not in valid_methods: + logger.warning( + f"Invalid HTTP method: {http_method}", + extra={ + "method": http_method, + "path": event.get("path", ""), + }, + ) + return generate_html_response( + 405, + { + "error": "Method Not Allowed", + "message": f"HTTP method {http_method} is not allowed", + }, + ) + + # 2. Check request size + body = event.get("body", "") + if body: + body_size = len(body.encode("utf-8")) + if body_size > max_request_size: + logger.warning( + f"Request size {body_size} bytes exceeds maximum {max_request_size} bytes", + extra={ + "request_size": body_size, + "max_size": max_request_size, + "path": event.get("path", ""), + "method": http_method, + }, + ) + return generate_html_response( + 413, + { + "error": "Payload Too Large", + "message": f"Request body size exceeds maximum allowed size of {max_request_size} bytes", + }, + ) + + # 3. Validate path for null bytes + path = event.get("path", "") + if contains_null_bytes(path): + logger.warning( + "Null byte detected in path", + extra={ + "path": path, + "method": http_method, + }, + ) + return generate_html_response( + 400, + { + "error": "Bad Request", + "message": "Invalid characters detected in request path", + }, + ) + + # 4. Validate query parameters for null bytes + query_params = event.get("queryStringParameters") or {} + for key, value in query_params.items(): + if contains_null_bytes(key) or contains_null_bytes(str(value)): + logger.warning( + f"Null byte detected in query parameter: {key}", + extra={ + "parameter_name": key, + "path": path, + "method": http_method, + }, + ) + return generate_html_response( + 400, + { + "error": "Bad Request", + "message": "Invalid characters detected in query parameters", + }, + ) + + # 5. Validate request body for null bytes + if body and contains_null_bytes(body): + logger.warning( + "Null byte detected in request body", + extra={ + "path": path, + "method": http_method, + "body_size": body_size, + }, + ) + return generate_html_response( + 400, + { + "error": "Bad Request", + "message": "Invalid characters detected in request body", + }, + ) + + # All validations passed, call the wrapped function + result: dict[str, str | int | dict[str, str]] = f(event, context) + return result + + return wrapper # type: ignore [return-value] + + return decorator diff --git a/lambda/utilities/lambda_decorators.py b/lambda/utilities/lambda_decorators.py new file mode 100644 index 000000000..73f041de6 --- /dev/null +++ b/lambda/utilities/lambda_decorators.py @@ -0,0 +1,171 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda function decorators for API Gateway integration.""" + +import functools +import logging +from collections.abc import Callable +from contextvars import ContextVar +from typing import Any, overload + +from utilities.event_parser import sanitize_event_for_logging +from utilities.input_validation import DEFAULT_MAX_REQUEST_SIZE, validate_input +from utilities.response_builder import generate_exception_response, generate_html_response + +logger = logging.getLogger(__name__) + +# Context variable to store Lambda context across the request +ctx_context: ContextVar[Any] = ContextVar("lamdbacontext") + +# Type for Lambda handler functions - can return dict, list, or any JSON-serializable type +LambdaHandler = Callable[[dict[Any, Any], Any], Any] + + +@overload +def api_wrapper(_func: LambdaHandler) -> LambdaHandler: + """Overload for decorator without parentheses.""" + ... + + +@overload +def api_wrapper( + _func: None = None, + *, + max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, +) -> Callable[[LambdaHandler], LambdaHandler]: + """Overload for decorator with parameters.""" + ... + + +def api_wrapper( + _func: LambdaHandler | None = None, + *, + max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, +) -> LambdaHandler | Callable[[LambdaHandler], LambdaHandler]: + """ + Wrap Lambda function with comprehensive API Gateway integration. + + This decorator provides: + - Input validation (null bytes, request size, HTTP methods) + - Request logging with sanitized headers + - Exception handling with appropriate status codes + - Security headers in responses + + Can be used with or without parameters: + - @api_wrapper + - @api_wrapper() + - @api_wrapper(max_request_size=10 * 1024 * 1024) + + Parameters + ---------- + _func : LambdaHandler | None + The Lambda handler function (used when decorator is applied without parentheses). + max_request_size : int + Maximum allowed request body size in bytes (default: 1MB). + + Returns + ------- + LambdaHandler | Callable[[LambdaHandler], LambdaHandler] + The wrapped function with API Gateway integration. + + Example + ------- + >>> @api_wrapper + ... def get_user(event: dict, context: dict) -> dict: + ... user_id = event["pathParameters"]["userId"] + ... return {"userId": user_id, "name": "John"} + + >>> @api_wrapper(max_request_size=10 * 1024 * 1024) + ... def upload_image(event: dict, context: dict) -> dict: + ... # Handle large payload + ... return {"status": "uploaded"} + """ + + def decorator(f: LambdaHandler) -> LambdaHandler: + @functools.wraps(f) + @validate_input(max_request_size=max_request_size) + def wrapper(event: dict[Any, Any], context: Any) -> dict[Any, Any]: + """Execute Lambda handler with API Gateway integration.""" + ctx_context.set(context) + code_func_name = f.__name__ + lambda_func_name = getattr(context, "function_name", "unknown") + + # Log request with sanitized event data + sanitized_event = sanitize_event_for_logging(event) + logger.info(f"Lambda {lambda_func_name}({code_func_name}) invoked with {sanitized_event}") + + try: + result = f(event, context) + return generate_html_response(200, result) + except Exception as e: + return generate_exception_response(e) + + return wrapper + + # Handle both @api_wrapper and @api_wrapper() syntax + if _func is not None: + return decorator(_func) + return decorator + + +def authorization_wrapper(f: LambdaHandler) -> LambdaHandler: + """ + Wrap Lambda authorizer function. + + This decorator sets up the Lambda context for authorizer functions + without adding API Gateway response formatting. + + Parameters + ---------- + f : LambdaHandler + The Lambda authorizer function to wrap. + + Returns + ------- + LambdaHandler + The wrapped authorizer function. + + Example + ------- + >>> @authorization_wrapper + ... def authorizer(event: dict, context: dict) -> dict: + ... token = event["authorizationToken"] + ... return {"principalId": "user123", "policyDocument": {...}} + """ + + @functools.wraps(f) + def wrapper(event: dict[Any, Any], context: Any) -> Any: + """Execute Lambda authorizer with context setup.""" + ctx_context.set(context) + return f(event, context) + + return wrapper + + +def get_lambda_context() -> Any: + """ + Get the current Lambda context from context variable. + + Returns + ------- + Any + The Lambda context object. + + Raises + ------ + LookupError + If called outside of a Lambda execution context. + """ + return ctx_context.get() diff --git a/lambda/utilities/repository_types.py b/lambda/utilities/repository_types.py index 1800d951f..afa0c8d2f 100644 --- a/lambda/utilities/repository_types.py +++ b/lambda/utilities/repository_types.py @@ -15,7 +15,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict +from typing import Any class RepositoryType(str, Enum): @@ -24,11 +24,11 @@ class RepositoryType(str, Enum): BEDROCK_KB = "bedrock_knowledge_base" @classmethod - def get_type(cls, repository: Dict[str, Any]) -> RepositoryType: + def get_type(cls, repository: dict[str, Any]) -> RepositoryType: return RepositoryType(repository.get("type")) @classmethod - def is_type(cls, repository: Dict[str, Any], repo_type: RepositoryType) -> bool: + def is_type(cls, repository: dict[str, Any], repo_type: RepositoryType) -> bool: return repository.get("type") == repo_type def calculate_similarity_score(self, score: float) -> float: diff --git a/lambda/utilities/response_builder.py b/lambda/utilities/response_builder.py new file mode 100644 index 000000000..60804c0ca --- /dev/null +++ b/lambda/utilities/response_builder.py @@ -0,0 +1,176 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Response builders for API Gateway Lambda functions.""" + +import json +import logging +from datetime import datetime +from decimal import Decimal +from typing import Any + +logger = logging.getLogger(__name__) + + +class DecimalEncoder(json.JSONEncoder): + """JSON encoder that handles Decimal and datetime objects.""" + + def default(self, obj: Any) -> Any: + """ + Encode special types to JSON-serializable formats. + + Parameters + ---------- + obj : Any + Object to encode. + + Returns + ------- + Any + JSON-serializable representation. + """ + if isinstance(obj, Decimal): + return float(obj) + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +def generate_html_response(status_code: int, response_body: dict) -> dict[str, str | int | dict[str, str]]: + """ + Generate API Gateway response with security headers. + + This function creates a properly formatted API Gateway response with: + - JSON-encoded body + - Security headers (HSTS, X-Frame-Options, etc.) + - CORS headers + - Cache control headers + + Parameters + ---------- + status_code : int + HTTP status code (e.g., 200, 400, 500). + response_body : dict + Response body to be JSON-encoded. + + Returns + ------- + Dict[str, Union[str, int, Dict[str, str]]] + API Gateway response object. + + Example + ------- + >>> generate_html_response(200, {"userId": "123", "name": "John"}) + { + "statusCode": 200, + "body": '{"userId": "123", "name": "John"}', + "headers": {...} + } + """ + return { + "statusCode": status_code, + "body": json.dumps(response_body, cls=DecimalEncoder), + "headers": { + "Access-Control-Allow-Origin": "*", + "Content-Type": "application/json", + "Cache-Control": "no-store, no-cache", + "Pragma": "no-cache", + "Strict-Transport-Security": "max-age:47304000; includeSubDomains", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + }, + } + + +def generate_exception_response(e: Exception) -> dict[str, str | int | dict[str, str]]: + """ + Generate API Gateway error response from exception. + + This function maps exceptions to appropriate HTTP status codes and + generates user-friendly error messages while logging detailed errors + internally. + + Exception Mapping: + - ValidationError → 400 Bad Request + - AWS SDK exceptions → Status from response metadata + - Custom exceptions with http_status_code/status_code → Custom status + - Missing event parameters → 400 Bad Request + - All other exceptions → 500 Internal Server Error + + Parameters + ---------- + e : Exception + Exception that was caught. + + Returns + ------- + Dict[str, Union[str, int, Dict[str, str]]] + API Gateway error response. + + Example + ------- + >>> try: + ... raise ValueError("Invalid user ID") + ... except Exception as e: + ... response = generate_exception_response(e) + >>> response["statusCode"] + 500 + """ + status_code = 400 + error_message: str + + if type(e).__name__ == "ValidationError": + # User input validation error - return 400 with error message + error_message = str(e) + logger.exception(e) + elif hasattr(e, "response"): + # AWS SDK exception - extract status code and message + metadata = e.response.get("ResponseMetadata") + if metadata: + status_code = metadata.get("HTTPStatusCode", 400) + error_message = str(e) + logger.exception(e) + elif hasattr(e, "http_status_code"): + # Custom exception with http_status_code attribute + status_code = e.http_status_code + error_message = getattr(e, "message", str(e)) + logger.exception(e) + elif hasattr(e, "status_code"): + # Custom exception with status_code attribute (e.g., HTTPException) + status_code = e.status_code + error_message = getattr(e, "message", str(e)) + logger.exception(e) + else: + # Generic unhandled exception - return 500 with generic message + error_msg = str(e) + if error_msg in ["'requestContext'", "'pathParameters'", "'body'"]: + # Missing event parameter - this is a 400 error + status_code = 400 + error_message = f"Missing event parameter: {error_msg}" + else: + # Genuine server error - return 500 with generic message + status_code = 500 + error_message = "An unexpected error occurred while processing your request" + # Log detailed error for debugging + logger.error( + f"Unhandled exception: {type(e).__name__}: {error_msg}", + exc_info=e, + extra={ + "exception_type": type(e).__name__, + "exception_message": error_msg, + }, + ) + logger.exception(e) + + return generate_html_response(status_code, error_message) # type: ignore [arg-type] diff --git a/lambda/utilities/session_encryption.py b/lambda/utilities/session_encryption.py index 3b57ef8b7..1ec78ab64 100644 --- a/lambda/utilities/session_encryption.py +++ b/lambda/utilities/session_encryption.py @@ -19,7 +19,7 @@ import logging import os from decimal import Decimal -from typing import Any, Dict, Optional +from typing import Any import boto3 from botocore.exceptions import ClientError @@ -64,7 +64,7 @@ def _get_kms_key_arn() -> str: return key_arn -def _generate_data_key(key_arn: str, encryption_context: Optional[Dict[str, str]] = None) -> tuple[bytes, bytes]: +def _generate_data_key(key_arn: str, encryption_context: dict[str, str] | None = None) -> tuple[bytes, bytes]: """ Generate a data key from KMS. @@ -85,7 +85,7 @@ def _generate_data_key(key_arn: str, encryption_context: Optional[Dict[str, str] raise SessionEncryptionError(f"Failed to generate data key: {e}") -def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: Optional[Dict[str, str]] = None) -> bytes: +def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: dict[str, str] | None = None) -> bytes: """ Decrypt a data key using KMS. @@ -104,7 +104,7 @@ def _decrypt_data_key(encrypted_data_key: bytes, encryption_context: Optional[Di raise SessionEncryptionError(f"Failed to decrypt data key: {e}") -def _create_encryption_context(user_id: str, session_id: str) -> Dict[str, str]: +def _create_encryption_context(user_id: str, session_id: str) -> dict[str, str]: """ Create encryption context for KMS operations. @@ -226,7 +226,7 @@ def is_encrypted_data(data: str) -> bool: return False -def migrate_session_to_encrypted(session_data: Dict[str, Any], user_id: str, session_id: str) -> Dict[str, Any]: +def migrate_session_to_encrypted(session_data: dict[str, Any], user_id: str, session_id: str) -> dict[str, Any]: """ Migrate a session from unencrypted to encrypted format. @@ -264,7 +264,7 @@ def migrate_session_to_encrypted(session_data: Dict[str, Any], user_id: str, ses raise SessionEncryptionError(f"Failed to migrate session to encrypted: {e}") -def decrypt_session_fields(session_data: Dict[str, Any], user_id: str, session_id: str) -> Dict[str, Any]: +def decrypt_session_fields(session_data: dict[str, Any], user_id: str, session_id: str) -> dict[str, Any]: """ Decrypt encrypted fields in session data. diff --git a/lambda/utilities/time.py b/lambda/utilities/time.py index 2cd0d855a..730d9d4ca 100644 --- a/lambda/utilities/time.py +++ b/lambda/utilities/time.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime, timezone +from datetime import datetime, timezone, tzinfo -def now(tz=timezone.utc) -> int: +def now(tz: tzinfo = timezone.utc) -> int: """Return UTC epoch milliseconds.""" return int(datetime.now(tz).timestamp() * 1000) -def now_seconds(tz=timezone.utc) -> int: +def now_seconds(tz: tzinfo = timezone.utc) -> int: """Return UTC epoch seconds.""" return int(datetime.now(tz).timestamp()) -def iso_string(tz=timezone.utc) -> str: +def iso_string(tz: tzinfo = timezone.utc) -> str: """Return ISO datetime string with UTC offset.""" return datetime.now(tz).isoformat() diff --git a/lambda/utilities/validation.py b/lambda/utilities/validation.py index 7ab79c5cd..a48d4b633 100644 --- a/lambda/utilities/validation.py +++ b/lambda/utilities/validation.py @@ -14,7 +14,7 @@ """Validation utilities for Lambda functions.""" import logging -from typing import Any, List +from typing import Any import botocore.session @@ -78,7 +78,7 @@ def validate_instance_type(type: str) -> str: raise ValueError("Invalid EC2 instance type.") -def validate_all_fields_defined(fields: List[Any]) -> bool: +def validate_all_fields_defined(fields: list[Any]) -> bool: """Validate that all fields are non-null in the field list. Args: @@ -87,10 +87,10 @@ def validate_all_fields_defined(fields: List[Any]) -> bool: Returns: bool: True if all fields are non-null, False otherwise """ - return all((field is not None for field in fields)) + return all(field is not None for field in fields) -def validate_any_fields_defined(fields: List[Any]) -> bool: +def validate_any_fields_defined(fields: list[Any]) -> bool: """Validate that at least one field is non-null in the field list. Args: @@ -99,7 +99,7 @@ def validate_any_fields_defined(fields: List[Any]) -> bool: Returns: bool: True if at least one field is non-null, False otherwise """ - return any((field is not None for field in fields)) + return any(field is not None for field in fields) def safe_error_response(error: Exception) -> dict: diff --git a/lib/api-base/ecsCluster.ts b/lib/api-base/ecsCluster.ts index 71e95a245..c8c6672f6 100644 --- a/lib/api-base/ecsCluster.ts +++ b/lib/api-base/ecsCluster.ts @@ -20,6 +20,7 @@ import { AdjustmentType, AutoScalingGroup, BlockDeviceVolume, GroupMetrics, Moni import { LogGroup, RetentionDays } from 'aws-cdk-lib/aws-logs'; import { Metric } from 'aws-cdk-lib/aws-cloudwatch'; import { InstanceType, ISecurityGroup, Port, SecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { Alias } from 'aws-cdk-lib/aws-kms'; import { AmiHardwareType, AsgCapacityProvider, @@ -216,7 +217,7 @@ export class ECSCluster extends Construct { vpc: vpc.vpc, vpcSubnets: vpc.subnetSelection, instanceType: new InstanceType(ecsConfig.instanceType), - machineImage: EcsOptimizedImage.amazonLinux2(ecsConfig.amiHardwareType), + machineImage: EcsOptimizedImage.amazonLinux2023(ecsConfig.amiHardwareType), minCapacity: ecsConfig.autoScalingConfig.minCapacity, maxCapacity: ecsConfig.autoScalingConfig.maxCapacity, cooldown: Duration.seconds(ecsConfig.autoScalingConfig.cooldown), @@ -236,6 +237,15 @@ export class ECSCluster extends Construct { updatePolicy: UpdatePolicy.rollingUpdate({}) }); + // Enable SNS topic encryption for ECS lifecycle hooks + // AppSec Finding #5: SNS topics must use server-side encryption + // Uses AWS managed key (alias/aws/sns) for lifecycle hook drain notifications + const snsEncryptionKey = Alias.fromAliasName( + this, + createCdkId([config.deploymentName, config.deploymentStage, 'SnsKey']), + 'alias/aws/sns' + ); + const asgCapacityProvider = new AsgCapacityProvider(this, createCdkId([config.deploymentName, config.deploymentStage, 'AsgCapacityProvider']), { autoScalingGroup, // Managed scaling tracks cluster reservation to add/remove instances automatically @@ -247,6 +257,9 @@ export class ECSCluster extends Construct { // disable managed scaling because we are going to setup rules to do it enableManagedScaling: false, enableManagedTerminationProtection: false, + + // Encrypt SNS topic used for lifecycle hook notifications + topicEncryptionKey: snsEncryptionKey, }); cluster.addAsgCapacityProvider(asgCapacityProvider); @@ -270,7 +283,6 @@ export class ECSCluster extends Construct { ], evaluationPeriods: 5, adjustmentType: AdjustmentType.CHANGE_IN_CAPACITY, - cooldown: Duration.seconds(300) }); autoScalingGroup.scaleOnMetric(createCdkId(['ASG', identifier, 'ScaleOut']), { @@ -281,7 +293,6 @@ export class ECSCluster extends Construct { ], evaluationPeriods: 2, adjustmentType: AdjustmentType.CHANGE_IN_CAPACITY, - cooldown: Duration.seconds(120) }); // Tag Auto Scaling Group for schedule management @@ -385,13 +396,16 @@ export class ECSCluster extends Construct { asgSecurityGroup.addIngressRule(securityGroup, Port.allTcp()); // Add listener + // AppSec TLS Configuration: Use TLS 1.2/1.3 policy with forward secrecy (ECDHE cipher suites only) + // SslPolicy.TLS13_RES maps to ELBSecurityPolicy-TLS13-1-2-2021-06 + // This policy excludes RSA key exchange cipher suites to meet tlscheckerv2 compliance requirements const listenerProps: BaseApplicationListenerProps = { port: ecsConfig.loadBalancerConfig.sslCertIamArn ? 443 : 80, open: ecsConfig.internetFacing, certificates: ecsConfig.loadBalancerConfig.sslCertIamArn ? [{ certificateArn: ecsConfig.loadBalancerConfig.sslCertIamArn }] : undefined, - sslPolicy: ecsConfig.loadBalancerConfig.sslCertIamArn ? SslPolicy.RECOMMENDED_TLS : SslPolicy.RECOMMENDED, + sslPolicy: ecsConfig.loadBalancerConfig.sslCertIamArn ? SslPolicy.TLS13_RES : undefined, }; const listener = loadBalancer.addListener( @@ -581,7 +595,12 @@ export class ECSCluster extends Construct { circuitBreaker: !this.config.region.includes('iso') ? { rollback: true } : undefined, capacityProviderStrategies: [ { capacityProvider: this.asgCapacityProvider.capacityProviderName, weight: 1 } - ] + ], + // Speed up deployments by allowing more aggressive rollout + minHealthyPercent: 50, // Allow 50% of tasks to be replaced at once + maxHealthyPercent: 200, // Allow up to 2x desired count during deployment + // Reduce health check grace period for faster failure detection + healthCheckGracePeriod: Duration.seconds(60) }; const service = new Ec2Service(this, createCdkId([this.config.deploymentName, taskName, 'Ec2Svc']), serviceProps); diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index b9d23c144..658b1a942 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -176,7 +176,7 @@ export class FastApiContainer extends Construct { interval: 60, timeout: 30, healthyThresholdCount: 2, - unhealthyThresholdCount: 10 + unhealthyThresholdCount: 3 // Reduced from 10 to 3 for faster failure detection }, domainName: config.restApiConfig.domainName, sslCertIamArn: config.restApiConfig?.sslCertIamArn ?? null, diff --git a/lib/api-base/utils.ts b/lib/api-base/utils.ts index 9854fe484..8cce7228f 100644 --- a/lib/api-base/utils.ts +++ b/lib/api-base/utils.ts @@ -94,7 +94,10 @@ export function registerAPIEndpoint ( let handler; if (funcDef.existingFunction) { - handler = Function.fromFunctionArn(scope, functionId, funcDef.existingFunction); + handler = Function.fromFunctionAttributes(scope, functionId, { + functionArn: funcDef.existingFunction, + sameEnvironment: true, + }); // create a CFN L1 primitive because `handler.addPermission` doesn't behave as expected // https://stackoverflow.com/questions/71075361/aws-cdk-lambda-resource-based-policy-for-a-function-with-an-alias diff --git a/lib/chat/api/mcp.ts b/lib/chat/api/mcp.ts index 2c2d93ce9..1b0a711c2 100644 --- a/lib/chat/api/mcp.ts +++ b/lib/chat/api/mcp.ts @@ -119,7 +119,7 @@ export class McpApi extends Construct { // Create API Lambda functions const apis: PythonLambdaFunction[] = [ { - name: 'list', + name: 'list_mcp_servers', resource: 'mcp_server', description: 'Lists available mcp servers for user', path: 'mcp-server', diff --git a/lib/chat/api/prompt-template-api.ts b/lib/chat/api/prompt-template-api.ts index 284ce41c3..81d363b59 100644 --- a/lib/chat/api/prompt-template-api.ts +++ b/lib/chat/api/prompt-template-api.ts @@ -126,7 +126,7 @@ export class PromptTemplateApi extends Construct { environment, }, { - name: 'list', + name: 'list_prompt', resource: 'prompt_templates', description: 'Lists all available prompt templates', path: 'prompt-templates', diff --git a/lib/chat/api/session.ts b/lib/chat/api/session.ts index 1e8a777e2..88586905a 100644 --- a/lib/chat/api/session.ts +++ b/lib/chat/api/session.ts @@ -28,8 +28,6 @@ import { BaseProps } from '../../schema'; import { createLambdaRole } from '../../core/utils'; import { Vpc } from '../../networking/vpc'; import { LAMBDA_PATH } from '../../util'; -import { Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3'; -import { RemovalPolicy } from 'aws-cdk-lib'; /** * Properties for SessionApi Construct. @@ -115,27 +113,12 @@ export class SessionApi extends Construct { stringValue: sessionEncryptionKey.keyArn, }); - const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) + // Get Images S3 bucket name from API Base stack (created there for cross-stack access) + const imagesBucketName = StringParameter.valueForStringParameter( + this, + `${config.deploymentPrefix}/generatedImagesBucketName` ); - // Create Images S3 bucket - const imagesBucket = new Bucket(scope, 'GeneratedImagesBucket', { - removalPolicy: config.removalPolicy, - autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, - enforceSSL: true, - cors: [ - { - allowedMethods: [HttpMethods.GET, HttpMethods.POST], - allowedHeaders: ['*'], - allowedOrigins: ['*'], - exposedHeaders: ['Access-Control-Allow-Origin'], - }, - ], - serverAccessLogsBucket: bucketAccessLogsBucket, - serverAccessLogsPrefix: 'logs/generated-images-bucket/' - }); - const restApi = RestApi.fromRestApiAttributes(this, 'RestApi', { restApiId: restApiId, rootResourceId: rootResourceId, @@ -150,7 +133,7 @@ export class SessionApi extends Construct { const env = { SESSIONS_TABLE_NAME: sessionTable.tableName, SESSIONS_BY_USER_ID_INDEX_NAME: byUserIdIndex, - GENERATED_IMAGES_S3_BUCKET_NAME: imagesBucket.bucketName, + GENERATED_IMAGES_S3_BUCKET_NAME: imagesBucketName, MODEL_TABLE_NAME: modelTableName, CONFIG_TABLE_NAME: configTable.tableName, SESSION_ENCRYPTION_KEY_ARN: sessionEncryptionKey.keyArn, @@ -288,13 +271,42 @@ export class SessionApi extends Construct { ); if (f.method === 'POST' || f.method === 'PUT') { sessionTable.grantWriteData(lambdaFunction); - imagesBucket.grantReadWrite(lambdaFunction); + // Grant S3 read/write permissions for image/video operations + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:PutObject', 's3:GetObject'], + resources: [`arn:${config.partition}:s3:::${imagesBucketName}/*`] + }) + ); } else if (f.method === 'GET') { sessionTable.grantReadData(lambdaFunction); - imagesBucket.grantRead(lambdaFunction); + // Grant S3 read permissions + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:GetObject'], + resources: [`arn:${config.partition}:s3:::${imagesBucketName}/*`] + }) + ); } else if (f.method === 'DELETE') { sessionTable.grantReadWriteData(lambdaFunction); - imagesBucket.grantDelete(lambdaFunction); + // Grant S3 list permission on bucket for prefix-based listing + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:ListBucket'], + resources: [`arn:${config.partition}:s3:::${imagesBucketName}`] + }) + ); + // Grant S3 delete permissions on objects + lambdaRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:DeleteObject'], + resources: [`arn:${config.partition}:s3:::${imagesBucketName}/*`] + }) + ); } }); } diff --git a/lib/core/apiBaseConstruct.ts b/lib/core/apiBaseConstruct.ts index dc6f41661..df67a1ae4 100644 --- a/lib/core/apiBaseConstruct.ts +++ b/lib/core/apiBaseConstruct.ts @@ -20,11 +20,11 @@ import { Authorizer, Cors, EndpointType, RestApi, StageOptions } from 'aws-cdk-l import { AttributeType, BillingMode, ProjectionType, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { CustomAuthorizer } from '../api-base/authorizer'; -import { Duration, Stack, StackProps } from 'aws-cdk-lib'; +import { Duration, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; import { ITable, Table } from 'aws-cdk-lib/aws-dynamodb'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; -import { Code, Function, } from 'aws-cdk-lib/aws-lambda'; +import { Code, Function, IFunction, LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; @@ -42,6 +42,7 @@ import { LAMBDA_PATH } from '../util'; import { getPythonRuntime } from '../api-base/utils'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { EventBus } from 'aws-cdk-lib/aws-events'; +import { Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; export type LisaApiBaseProps = { vpc: Vpc; @@ -60,12 +61,45 @@ export class LisaApiBaseConstruct extends Construct { public readonly restApiUrl: string; public readonly tokenTable?: ITable; public readonly managementKeySecretName: string; + public readonly iamAuthSetupFn: IFunction; + public readonly imagesBucket: Bucket; constructor (scope: Stack, id: string, props: LisaApiBaseProps) { super(scope, id); const { config, vpc, securityGroups } = props; + // Get bucket access logs bucket + const bucketAccessLogsBucket = Bucket.fromBucketArn(scope, 'BucketAccessLogsBucket', + StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/bucket/bucket-access-logs`) + ); + + // Create Images S3 bucket for generated images and videos + // This is created in API Base stack so it's available to both Chat and Serve stacks + this.imagesBucket = new Bucket(scope, 'GeneratedImagesBucket', { + removalPolicy: config.removalPolicy, + autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, + enforceSSL: true, + cors: [ + { + allowedMethods: [HttpMethods.GET, HttpMethods.POST], + allowedHeaders: ['*'], + allowedOrigins: ['*'], + exposedHeaders: ['Access-Control-Allow-Origin'], + }, + ], + serverAccessLogsBucket: bucketAccessLogsBucket, + serverAccessLogsPrefix: 'logs/generated-images-bucket/', + encryption: BucketEncryption.S3_MANAGED + }); + + // Store bucket name in SSM for cross-stack access + new StringParameter(scope, 'GeneratedImagesBucketNameParameter', { + parameterName: `${config.deploymentPrefix}/generatedImagesBucketName`, + stringValue: this.imagesBucket.bucketName, + description: 'S3 bucket name for generated images and videos', + }); + // TokenTable is now managed in API Base so it's independent of Serve // Create the table - if it already exists from previous Serve deployment, // CloudFormation will handle the conflict. For new deployments, it will be created. @@ -107,6 +141,10 @@ export class LisaApiBaseConstruct extends Construct { const { managementKeySecretName } = this.createManagementKeySecret(scope, config, vpc, securityGroups); this.managementKeySecretName = managementKeySecretName; + // Create shared IAM auth setup Lambda for PGVector databases + // This Lambda is used by Serve, RAG, and vector_store_deployer stacks + this.iamAuthSetupFn = this.createIamAuthSetupLambda(scope, config, vpc, securityGroups); + const deployOptions: StageOptions = { stageName: config.deploymentStage, throttlingRateLimit: 100, @@ -218,4 +256,63 @@ export class LisaApiBaseConstruct extends Construct { return { managementKeySecretName }; } + + /** + * Creates a shared Lambda for IAM authentication setup on PGVector databases. + * This Lambda creates IAM database users and deletes bootstrap secrets. + * It's shared across Serve, RAG, and vector_store_deployer stacks. + */ + private createIamAuthSetupLambda (scope: Stack, config: Config, vpc: Vpc, securityGroups: ISecurityGroup[]): IFunction { + // Create IAM role for the Lambda + const iamAuthSetupRole = new Role(scope, 'IamAuthSetupRole', { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + managedPolicies: [ + ManagedPolicy.fromAwsManagedPolicyName('service-role/AWSLambdaVPCAccessExecutionRole'), + ], + }); + + // Grant permissions to read/delete secrets (specific secrets will be passed via event) + iamAuthSetupRole.addToPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['secretsmanager:GetSecretValue', 'secretsmanager:DeleteSecret'], + resources: [`arn:${config.partition}:secretsmanager:${config.region}:${config.accountNumber}:secret:*`], + })); + + // Get common layer for psycopg2 + const commonLayer = LayerVersion.fromLayerVersionArn( + scope, + 'IamAuthCommonLayer', + StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/layerVersion/common`), + ); + + const iamAuthSetupFn = new Function(scope, 'IamAuthSetupFn', { + functionName: createCdkId([config.deploymentName, config.deploymentStage, 'iam_auth_setup']), + runtime: getPythonRuntime(), + handler: 'utilities.db_setup_iam_auth.handler', + code: Code.fromAsset(config.lambdaPath || LAMBDA_PATH), + timeout: Duration.minutes(2), + memorySize: 256, + role: iamAuthSetupRole, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: [commonLayer], + }); + + // Store the IAM auth setup Lambda ARN in SSM for other stacks to use + new StringParameter(scope, 'IamAuthSetupFnArnParam', { + parameterName: `${config.deploymentPrefix}/iamAuthSetupFnArn`, + stringValue: iamAuthSetupFn.functionArn, + description: 'ARN of the shared IAM auth setup Lambda for PGVector databases', + }); + + // Store the IAM auth setup Lambda role ARN in SSM for granting secret permissions + new StringParameter(scope, 'IamAuthSetupRoleArnParam', { + parameterName: `${config.deploymentPrefix}/iamAuthSetupRoleArn`, + stringValue: iamAuthSetupRole.roleArn, + description: 'ARN of the IAM auth setup Lambda role for granting secret permissions', + }); + + return iamAuthSetupFn; + } } diff --git a/lib/core/coreConstruct.ts b/lib/core/coreConstruct.ts index 831ade7d3..2284efad5 100644 --- a/lib/core/coreConstruct.ts +++ b/lib/core/coreConstruct.ts @@ -22,7 +22,7 @@ import { BaseProps } from '../schema'; import { RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; import { COMMON_LAYER_PATH, FASTAPI_LAYER_PATH, AUTHORIZER_LAYER_PATH, CDK_LAYER_PATH } from '../util'; -import { Bucket } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketAccessControl, BucketEncryption, ObjectOwnership } from 'aws-cdk-lib/aws-s3'; import { getNodeRuntime } from '../api-base/utils'; export const ARCHITECTURE = lambda.Architecture.X86_64; @@ -47,6 +47,10 @@ export class CoreConstruct extends Construct { autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY, bucketName: ([config.deploymentName, config.accountNumber, config.deploymentStage, 'bucket', 'access', 'logs'].join('-')).toLowerCase(), enforceSSL: true, + encryption: BucketEncryption.S3_MANAGED, + blockPublicAccess: BlockPublicAccess.BLOCK_ALL, + objectOwnership: ObjectOwnership.BUCKET_OWNER_PREFERRED, + accessControl: BucketAccessControl.LOG_DELIVERY_WRITE, }); new StringParameter(scope, 'LISABucketAccessLogsBucket', { diff --git a/lib/core/layers/fastapi/requirements.txt b/lib/core/layers/fastapi/requirements.txt index 488f78a38..0a768869d 100644 --- a/lib/core/layers/fastapi/requirements.txt +++ b/lib/core/layers/fastapi/requirements.txt @@ -1,4 +1,4 @@ -# boto3==1.36.0 // Provided by Lambda +# boto3==1.40.76 // Provided by Lambda # requests==2.32.5 // provided by Common Layer fastapi==0.124.2 mangum==0.19.0 diff --git a/lib/docs/.vitepress/config.mts b/lib/docs/.vitepress/config.mts index dd4852bf0..0c5ff99a0 100644 --- a/lib/docs/.vitepress/config.mts +++ b/lib/docs/.vitepress/config.mts @@ -27,6 +27,7 @@ const navLinks = [ { text: 'What is LISA', link: '/admin/getting-started#what-is-lisa' }, { text: 'Major Features', link: '/admin/getting-started#major-features' }, { text: 'Key Features & Benefits', link: '/admin/getting-started#key-features-benefits' }, + { text: 'Access Control', link: '/admin/getting-started#access-control' }, ] }, { diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index cb9e6463d..3a593cd01 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -384,7 +384,7 @@ cp -r ~/.cache/prisma* lib/serve/rest-api/PRISMA_CACHE/ ``` **Important Notes:** -- The cache is platform-specific. Generate it on a system matching your Docker base image (e.g., for `python:3.13-slim` which is Debian-based, so you may want to use a Debian-based system) +- The cache is platform-specific. Generate it on a system matching your Docker base image (e.g., for `public.ecr.aws/docker/library/python:3.13-slim` which is Debian-based, so you may want to use a Debian-based system) - The `prisma version` command downloads binaries for your current platform - Both `prisma/` and `prisma-python/` directories are required for offline operation diff --git a/lib/docs/admin/getting-started.md b/lib/docs/admin/getting-started.md index 6f4b0cde5..65c155c2b 100644 --- a/lib/docs/admin/getting-started.md +++ b/lib/docs/admin/getting-started.md @@ -96,3 +96,54 @@ flexibility for different use cases. *The below screenshot showcases LISA’s Model Management page. It is filtered to display the Claude models configured with LISA, although they are hosted by the Amazon Bedrock service. Via LISA’s Model Management page, Administrators configure self-hosted and externally hosted third party (3P) models with LISA. LISA is compatible with over 100 externally hosted models via the LiteLLM proxy. Administrators do not need to worry about the 3P model provider’s unique API requirements since LiteLLM handles the standardization.* ![LISA Model Management](../assets/LISA_Model_Mgmt.png) + +# Access Control + +LISA Roles and Enterprise Groups control access to features and resources. + +## Roles + +`AdminGroup` and `UserGroup` properties in the configuration are used to control tiers of application access, not resource access. + +- **AdminGroup**: The IDP group that distinguishes which users have access to create and manage restricted resource configuration within the UI, including: + - Activating application features + - Configuring models via Model Management + - Configuring repos and Collections via RAG management + - MCP server management + - MCP Workbench code editor + +- **UserGroup** (optional): If provided, this is required when the IDP is used for multiple systems and you want to control which users in the IDP have access to LISA. + +- **API Management** (v6.1+): A new role that allows users to manage their API tokens within LISA, but does not grant full Admin privileges. + +## Groups + +Access to resources can be constrained by Enterprise Groups, including: + +- LISA models +- Prompt templates +- RAG repos +- RAG collections +- MCP Connections +- LISA MCP servers +- API tokens + +You can create or bring any number of Enterprise Groups in your IDP, which can then be used in LISA to lock down resources at creation/update. When you create/update any resource, you can assign 0, 1, or many Groups to that resource. + +### Example: Group-Based Access Control + +For example, let's say your IDP has the following groups: **Team Red**, **Team White**, and **Team Blue**. Below shows how you can use Groups to lock down access to Models, and then RAG repos and their Collections: + +**Models:** +- Model 1: Teams Red and White +- Model 2: none (Global) +- Model 3: Team Blue + +**RAG Repositories and Collections:** +- RAG Repo 1: Teams Red, White, Blue + - Collection A: Team Red + - Collection B: Team White + - Collection C: Teams White and Blue +- RAG Repo 2: none (Global) + - Collection X: Team Blue + - Collection Y: none (Global) diff --git a/lib/mcp/mcp-server-api.ts b/lib/mcp/mcp-server-api.ts index b39b4894c..401587f4e 100644 --- a/lib/mcp/mcp-server-api.ts +++ b/lib/mcp/mcp-server-api.ts @@ -31,7 +31,7 @@ import { McpServerDeployer } from './mcp-server-deployer'; import { CreateMcpServerStateMachine } from './state-machine/create-mcp-server'; import { DeleteMcpServerStateMachine } from './state-machine/delete-mcp-server'; import { UpdateMcpServerStateMachine } from './state-machine/update-mcp-server'; -import { Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; import { RemovalPolicy } from 'aws-cdk-lib'; type McpServerApiProps = { @@ -101,7 +101,8 @@ export class McpServerApi extends Construct { }, ], serverAccessLogsBucket: bucketAccessLogsBucket, - serverAccessLogsPrefix: 'logs/mcp-hosting-bucket/' + serverAccessLogsPrefix: 'logs/mcp-hosting-bucket/', + encryption: BucketEncryption.S3_MANAGED }); // Get reference to REST API first (will be reused) diff --git a/lib/models/docker-image-builder.ts b/lib/models/docker-image-builder.ts index 0928dcbb2..e320a873e 100644 --- a/lib/models/docker-image-builder.ts +++ b/lib/models/docker-image-builder.ts @@ -26,7 +26,7 @@ import { } from 'aws-cdk-lib/aws-iam'; import { Code, Function } from 'aws-cdk-lib/aws-lambda'; import { Duration, Stack } from 'aws-cdk-lib'; -import { Bucket } from 'aws-cdk-lib/aws-s3'; +import { BlockPublicAccess, Bucket, BucketEncryption } from 'aws-cdk-lib/aws-s3'; import { BucketDeployment, Source } from 'aws-cdk-lib/aws-s3-deployment'; import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { createCdkId } from '../core/utils'; @@ -60,7 +60,10 @@ export class DockerImageBuilder extends Construct { const ec2DockerBucket = new Bucket(this, createCdkId([stackName, 'docker-image-builder-ec2-bucket']), { enforceSSL: true, + encryption: BucketEncryption.S3_MANAGED, + blockPublicAccess: BlockPublicAccess.BLOCK_ALL, serverAccessLogsBucket: bucketAccessLogsBucket, + serverAccessLogsPrefix: 'logs/docker-image-builder-bucket/', }); const ecsModelPath = ECS_MODEL_PATH; new BucketDeployment(this, createCdkId([stackName, 'docker-image-builder-ec2-dplmnt']), { diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index 4f22bacb6..a9d4718ef 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -197,6 +197,29 @@ export class CreateModelStateMachine extends Construct { time: POLLING_TIMEOUT, }); + const pollModelReady = new LambdaInvoke(this, 'PollModelReady', { + lambdaFunction: new Function(this, 'PollModelReadyFunc', { + runtime: getPythonRuntime(), + handler: 'models.state_machine.create_model.handle_poll_model_ready', + code: Code.fromAsset(lambdaPath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc.vpc, + vpcSubnets: vpc.subnetSelection, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, + }); + + const pollModelReadyChoice = new Choice(this, 'PollModelReadyChoice'); + + const waitBeforePollingModelReady = new Wait(this, 'WaitBeforePollingModelReady', { + time: POLLING_TIMEOUT, + }); + const createSchedule = new LambdaInvoke(this, 'CreateSchedule', { lambdaFunction: new Function(this, 'CreateScheduleFunc', { runtime: getPythonRuntime(), @@ -299,10 +322,17 @@ export class CreateModelStateMachine extends Construct { }); pollCreateStackChoice .when(Condition.booleanEquals('$.continue_polling_stack', true), waitBeforePollingCreateStack) - .otherwise(createSchedule); + .otherwise(pollModelReady); waitBeforePollingCreateStack.next(pollCreateStack); - // Create schedule after stack is created + // Poll for model instances to be healthy before proceeding + pollModelReady.next(pollModelReadyChoice); + pollModelReadyChoice + .when(Condition.booleanEquals('$.continue_polling_capacity', true), waitBeforePollingModelReady) + .otherwise(createSchedule); + waitBeforePollingModelReady.next(pollModelReady); + + // Create schedule after model is ready createSchedule.next(addModelToLitellm); // Check for guardrails and add them if present diff --git a/lib/rag/ingestion/ingestion-image/Dockerfile b/lib/rag/ingestion/ingestion-image/Dockerfile index 0c7c18c05..9fad16343 100644 --- a/lib/rag/ingestion/ingestion-image/Dockerfile +++ b/lib/rag/ingestion/ingestion-image/Dockerfile @@ -1,6 +1,22 @@ ARG BASE_IMAGE=public.ecr.aws/lambda/python:3.13 FROM ${BASE_IMAGE} +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + ARG BUILD_DIR=build WORKDIR /workdir diff --git a/lib/rag/ingestion/ingestion-image/requirements.txt b/lib/rag/ingestion/ingestion-image/requirements.txt index e89238acc..0fd25a55d 100644 --- a/lib/rag/ingestion/ingestion-image/requirements.txt +++ b/lib/rag/ingestion/ingestion-image/requirements.txt @@ -4,11 +4,9 @@ # NumPy 2.x has pre-built wheels for Python 3.13 numpy>=2.1.0 -# AWS SDK - Version constrained by litellm[proxy]==1.80.9 in rest-api -# Standardized to boto3==1.36.0 for compatibility across all components -aioboto3==13.4.0 -aiobotocore==2.18.0 -boto3==1.36.0 +# AWS SDK - Version constrained by litellm[proxy]==1.81.3 in rest-api +# Standardized to boto3==1.40.76 for compatibility across all components +boto3==1.40.76 aiohttp==3.13.2 click==8.3.1 @@ -17,8 +15,8 @@ fastapi_utils==0.8.0 fastapi==0.124.2 gunicorn==23.0.0 langchain-community==0.4.1 -langchain-core==1.1.3 -langchain-text-splitters==1.0.0 +langchain-core==1.2.7 +langchain-text-splitters==1.1.0 loguru==0.7.3 mangum==0.19.0 opensearch-py==3.1.0 @@ -34,6 +32,6 @@ python-docx==1.2.0 requests-aws4auth==1.3.1 requests==2.32.5 text-generation==0.7.0 -# ASGI Server - Version constrained by litellm[proxy]==1.80.9 in rest-api +# ASGI Server - Version constrained by litellm[proxy]==1.81.3 in rest-api # Standardized to 0.38.0 for compatibility across all components uvicorn==0.38.0 diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index 0569dd077..a495cd203 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -3,9 +3,9 @@ # Core RAG packages # psycopg2-binary==2.9.11 // provided by Common Layer -langchain-text-splitters==1.0.0 +langchain-text-splitters==1.1.0 langchain-community==0.4.1 -langchain-core==1.1.3 +langchain-core==1.2.7 # Required by langchain-community # Python 3.13+ requires numpy>=2.1.0 numpy>=2.1.0 diff --git a/lib/rag/ragConstruct.ts b/lib/rag/ragConstruct.ts index c84075416..07018bc2b 100644 --- a/lib/rag/ragConstruct.ts +++ b/lib/rag/ragConstruct.ts @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -import { CfnOutput, Duration, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; +import { CfnOutput, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; -import { ISecurityGroup, IVpc, SubnetSelection } from 'aws-cdk-lib/aws-ec2'; -import { Code, Function, IFunction, ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; -import { Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3'; +import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2'; +import { ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; +import { Bucket, BucketEncryption, HttpMethods } from 'aws-cdk-lib/aws-s3'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import { AttributeType, BillingMode, StreamViewType, Table, TableEncryption } from 'aws-cdk-lib/aws-dynamodb'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; @@ -28,12 +28,12 @@ import { ARCHITECTURE } from '../core'; import { Layer } from '../core/layers'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; -import { APP_MANAGEMENT_KEY, BaseProps, Config, RDSConfig } from '../schema'; +import { APP_MANAGEMENT_KEY, BaseProps, Config } from '../schema'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; import { SecurityGroupFactory } from '../networking/vpc/security-group-factory'; import { Roles } from '../core/iam/roles'; import { VectorStoreCreatorStack as VectorStoreCreator } from './vector-store/vector-store-creator'; -import { AnyPrincipal, CfnServiceLinkedRole, Effect, IRole, PolicyDocument, PolicyStatement, Role, ServicePrincipal } from 'aws-cdk-lib/aws-iam'; +import { AnyPrincipal, CfnServiceLinkedRole, Effect, IRole, PolicyStatement, Role } from 'aws-cdk-lib/aws-iam'; import { IAMClient, ListRolesCommand } from '@aws-sdk/client-iam'; import { RagRepositoryConfig, RagRepositoryType } from '../schema'; import { Domain, EngineVersion, IDomain } from 'aws-cdk-lib/aws-opensearchservice'; @@ -41,13 +41,12 @@ import { ISecret, Secret } from 'aws-cdk-lib/aws-secretsmanager'; import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds'; import { LegacyIngestPipelineStateMachine } from './state_machine/legacy-ingest-pipeline'; import * as customResources from 'aws-cdk-lib/custom-resources'; -import DynamoDB from 'aws-sdk/clients/dynamodb'; +import { marshall } from '@aws-sdk/util-dynamodb'; import * as readlineSync from 'readline-sync'; -import { LAMBDA_PATH, RAG_LAYER_PATH } from '../util'; +import { RAG_LAYER_PATH } from '../util'; import { IngestionStack } from './ingestion/ingestion-stack'; import * as child_process from 'child_process'; import * as path from 'path'; -import { getPythonRuntime } from '../api-base/utils'; import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resources'; export type LisaRagProps = { @@ -107,7 +106,8 @@ export class LisaRagConstruct extends Construct { }, ], serverAccessLogsBucket: bucketAccessLogsBucket, - serverAccessLogsPrefix: 'logs/rag-bucket/' + serverAccessLogsPrefix: 'logs/rag-bucket/', + encryption: BucketEncryption.S3_MANAGED }); const ragTableName = createCdkId([config.deploymentName, 'RagDocumentTable']); @@ -363,14 +363,14 @@ export class LisaRagConstruct extends Construct { value: ragRepositoryConfigTable.tableArn }); - // Create SSM parameter for vector store table name so other stacks can optionally reference it + // Create SSM parameter for vector store table name so other stacks can optionally reference it. new StringParameter(scope, createCdkId(['RagVectorStoreTableName', 'Parameter']), { parameterName: `${config.deploymentPrefix}/ragVectorStoreTableName`, stringValue: ragRepositoryConfigTable.tableName, description: 'RAG Vector Store (Repository Config) DynamoDB table name', }); - // Create SSM parameter for collections table name so other stacks can optionally reference it + // Create SSM parameter for collections table name so other stacks can optionally reference it. new StringParameter(scope, createCdkId(['RagCollectionsTableName', 'Parameter']), { parameterName: `${config.deploymentPrefix}/ragCollectionsTableName`, stringValue: collectionsTable.tableName, @@ -537,42 +537,46 @@ export class LisaRagConstruct extends Construct { openSearchEndpointPs.node.addDependency(openSearchDomain); openSearchEndpointPs.grantRead(lambdaRole); } else if (ragConfig.type === RagRepositoryType.PGVECTOR && ragConfig.rdsConfig) { - let rdsPasswordSecret: ISecret; + // Determine authentication method - default to IAM auth (iamRdsAuth = false) + const useIamAuth = config.iamRdsAuth ?? false; + + let rdsSecret: ISecret; let rdsConnectionInfoPs: StringParameter; - // if dbHost and passwordSecretId are defined, then connect to DB with existing params - if (!!ragConfig.rdsConfig.dbHost && !!ragConfig.rdsConfig.passwordSecretId) { + let pgvector_db: DatabaseInstance | undefined; + + // if dbHost and passwordSecretId are defined, connect to existing DB + if (ragConfig.rdsConfig.dbHost && ragConfig.rdsConfig.passwordSecretId) { rdsConnectionInfoPs = new StringParameter(this.scope, createCdkId([connectionParamName, ragConfig.repositoryId, 'StringParameter']), { parameterName: `${config.deploymentPrefix}/${connectionParamName}/${ragConfig.repositoryId}`, stringValue: JSON.stringify({ - ...(config.iamRdsAuth ? {} : { - passwordSecretId: ragConfig.rdsConfig?.passwordSecretId - }), username: ragConfig.rdsConfig?.username, dbHost: ragConfig.rdsConfig?.dbHost, dbName: ragConfig.rdsConfig?.dbName, dbPort: ragConfig.rdsConfig?.dbPort, - type: RagRepositoryType.PGVECTOR + type: RagRepositoryType.PGVECTOR, + // Include passwordSecretId only when using password auth + ...(!useIamAuth ? { passwordSecretId: ragConfig.rdsConfig?.passwordSecretId } : {}) }), description: 'Connection info for LISA Serve PGVector database', }); - rdsPasswordSecret = Secret.fromSecretNameV2( + rdsSecret = Secret.fromSecretNameV2( this.scope, - createCdkId([config.deploymentName, ragConfig.repositoryId, 'RagRDSPwdSecret']), + createCdkId([config.deploymentName, ragConfig.repositoryId, 'RagRDSSecret']), ragConfig.rdsConfig.passwordSecretId, ); } else { const username = ragConfig.rdsConfig.username; const dbCreds = Credentials.fromGeneratedSecret(username); - const pgvector_db = new DatabaseInstance(this.scope, createCdkId([ragConfig.repositoryId, 'PGVectorDB']), { + pgvector_db = new DatabaseInstance(this.scope, createCdkId([ragConfig.repositoryId, 'PGVectorDB']), { engine: DatabaseInstanceEngine.POSTGRES, vpc: vpc.vpc, subnetGroup: vpc.subnetGroup, credentials: dbCreds, - iamAuthentication: true, + iamAuthentication: useIamAuth, // Enable IAM auth when iamRdsAuth is false securityGroups: [securityGroups.pgvector], removalPolicy: RemovalPolicy.DESTROY, }); - rdsPasswordSecret = pgvector_db.secret!; + rdsSecret = pgvector_db.secret!; rdsConnectionInfoPs = new StringParameter(this.scope, createCdkId([connectionParamName, ragConfig.repositoryId, 'StringParameter']), { parameterName: `${config.deploymentPrefix}/${connectionParamName}/${ragConfig.repositoryId}`, stringValue: JSON.stringify({ @@ -581,53 +585,98 @@ export class LisaRagConstruct extends Construct { dbPort: ragConfig.rdsConfig.dbPort, type: RagRepositoryType.PGVECTOR, username: username, - ...(config.iamRdsAuth ? {} : { passwordSecretId: rdsPasswordSecret.secretName }), + // Include passwordSecretId only when using password auth + ...(!useIamAuth ? { passwordSecretId: rdsSecret.secretName } : {}) }), description: 'Connection info for LISA Serve PGVector database', }); - if (config.iamRdsAuth) { - // grant the role permissions to connect as the IAM role itself - pgvector_db.grantConnect(lambdaRole, lambdaRole.roleName); + if (!useIamAuth) { + // Password auth: secret read access granted below (grantConnect requires IAM auth) } else { - // grant the role permissions to connect as the postgres user - pgvector_db.grantConnect(lambdaRole); + // IAM auth: manually grant rds-db:connect permission + // Note: We do NOT use pgvector_db.grantConnect() due to CDK bug #11851 + // The grantConnect method generates incorrect ARN format (uses rds: instead of rds-db:) + // Per AWS docs: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html + // The correct format is: arn:aws:rds-db:region:account-id:dbuser:DbiResourceId/db-user-name + lambdaRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['rds-db:connect'], + resources: [ + // Use wildcard for DbiResourceId since it's not available in CloudFormation + // Format: arn:aws:rds-db:region:account:dbuser:*/username + `arn:${config.partition}:rds-db:${config.region}:${config.accountNumber}:dbuser:*/${lambdaRole.roleName}` + ] + })); } + + // Update ragConfig with the endpoint address for use in AwsCustomResource + ragConfig.rdsConfig.dbHost = pgvector_db.dbInstanceEndpointAddress; } - if (config.iamRdsAuth) { - // Create the lambda for generating DB users for IAM auth - const createDbUserLambda = this.getIAMAuthLambda(config, ragConfig.repositoryId, ragConfig.rdsConfig!, rdsPasswordSecret, lambdaRole.roleName, vpc.vpc, [securityGroups.pgvector], vpc.subnetSelection); + if (!useIamAuth) { + // Password auth: grant secret read access + rdsSecret.grantRead(lambdaRole); + } else { + // Use the shared IAM auth setup Lambda from API Base stack + const iamAuthSetupFnArn = StringParameter.valueForStringParameter( + this.scope, + `${config.deploymentPrefix}/iamAuthSetupFnArn` + ); - const customResourceRole = new Role(this.scope, createCdkId(['CustomResourceRole', ragConfig.repositoryId]), { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - }); - createDbUserLambda.grantInvoke(customResourceRole); - - // run updateInstanceKmsConditionsLambda every deploy - new AwsCustomResource(this.scope, createCdkId([ragConfig.repositoryId, 'CreateDbUserCustomResource']), { - onCreate: { - service: 'Lambda', - action: 'invoke', - physicalResourceId: PhysicalResourceId.of(createCdkId([ragConfig.repositoryId, 'CreateDbUserCustomResource'])), - parameters: { - FunctionName: createDbUserLambda.functionName, - Payload: '{}' - }, - }, - onUpdate: { - service: 'Lambda', - action: 'invoke', - physicalResourceId: PhysicalResourceId.of(createCdkId([ragConfig.repositoryId, 'CreateDbUserCustomResource'])), - parameters: { - FunctionName: createDbUserLambda.functionName, - Payload: '{}' - }, + // Get the IAM auth setup Lambda role ARN from SSM to grant it permissions + const iamAuthSetupRoleArn = StringParameter.valueForStringParameter( + this.scope, + `${config.deploymentPrefix}/iamAuthSetupRoleArn` + ); + + // Import the IAM auth setup role to grant it secret permissions + const iamAuthSetupRole = Role.fromRoleArn( + this.scope, + createCdkId([ragConfig.repositoryId, 'IamAuthSetupRoleRef']), + iamAuthSetupRoleArn + ); + + // Grant the IAM auth setup Lambda role permission to read the bootstrap secret + rdsSecret.grantRead(iamAuthSetupRole); + + // Run the shared IAM auth setup Lambda on create and update + // Pass parameters via payload since the Lambda is shared + // Use Stack.of(this.scope).toJsonString() to properly resolve CDK tokens in the payload + const lambdaInvokeParams = { + service: 'Lambda', + action: 'invoke', + physicalResourceId: PhysicalResourceId.of(createCdkId([ragConfig.repositoryId, 'CreateDbUserCustomResource'])), + parameters: { + FunctionName: iamAuthSetupFnArn, + Payload: Stack.of(this.scope).toJsonString({ + secretArn: rdsSecret.secretArn, + dbHost: ragConfig.rdsConfig!.dbHost, + dbPort: ragConfig.rdsConfig!.dbPort, + dbName: ragConfig.rdsConfig!.dbName, + dbUser: ragConfig.rdsConfig!.username, + iamName: lambdaRole.roleName, + }) }, - role: customResourceRole + }; + + const createDbUserResource = new AwsCustomResource(this.scope, createCdkId([ragConfig.repositoryId, 'CreateDbUserCustomResource']), { + onCreate: lambdaInvokeParams, + onUpdate: lambdaInvokeParams, // Also run on updates to ensure IAM user is created + policy: customResources.AwsCustomResourcePolicy.fromStatements([ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['lambda:InvokeFunction'], + resources: [iamAuthSetupFnArn], + }) + ]), }); - } else { - rdsPasswordSecret.grantRead(lambdaRole); + + // Ensure the RDS instance is fully available before running IAM auth setup + // (only when we created a new RDS instance) + if (pgvector_db) { + createDbUserResource.node.addDependency(pgvector_db); + } } rdsConnectionInfoPs.grantRead(lambdaRole); @@ -639,8 +688,8 @@ export class LisaRagConstruct extends Construct { const createOrUpdateParameters = { TableName: ragRepositoryConfigTable.tableName, - Item: this.toDynamoDBItem({ - repositoryId: ragConfig.repositoryId, // Partition key value + Item: marshall({ + repositoryId: ragConfig.repositoryId, status: 'CREATE_COMPLETE', config: ragConfig, legacy: true @@ -665,7 +714,7 @@ export class LisaRagConstruct extends Construct { action: 'deleteItem', parameters: { TableName: ragRepositoryConfigTable.tableName, - Key: this.toDynamoDBItem({ repositoryId: ragConfig.repositoryId }), + Key: marshall({ repositoryId: ragConfig.repositoryId }), }, }, policy: customResources.AwsCustomResourcePolicy.fromSdkCalls({ @@ -721,79 +770,6 @@ export class LisaRagConstruct extends Construct { } } - getIAMAuthLambda (config: Config, repositoryId: string, rdsConfig: NonNullable, secret: ISecret, user: string, vpc: IVpc, securityGroups: ISecurityGroup[], vpcSubnets?: SubnetSelection): IFunction { - // Create the IAM role for updating the database to allow IAM authentication - const iamAuthLambdaRole = new Role(this.scope, createCdkId([repositoryId, 'IAMAuthLambdaRole']), { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - inlinePolicies: { - 'EC2NetworkInterfaces': new PolicyDocument({ - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: ['ec2:CreateNetworkInterface', 'ec2:DescribeNetworkInterfaces', 'ec2:DeleteNetworkInterface'], - resources: ['*'], - }), - ], - }), - }, - }); - - secret.grantRead(iamAuthLambdaRole); - - const commonLayer = this.getLambdaLayer(repositoryId, config); - const lambdaPath = config.lambdaPath || LAMBDA_PATH; - - return new Function(this.scope, createCdkId([repositoryId, 'CreateDbUserLambda']), { - runtime: getPythonRuntime(), - handler: 'utilities.db_setup_iam_auth.handler', - code: Code.fromAsset(lambdaPath), - timeout: Duration.minutes(2), - environment: { - SECRET_ARN: secret.secretArn, // ARN of the RDS secret - DB_HOST: rdsConfig.dbHost!, - DB_PORT: String(rdsConfig.dbPort), // Default PostgreSQL port - DB_NAME: rdsConfig.dbName, // Database name - DB_USER: rdsConfig.username, // Admin user for RDS - IAM_NAME: user, // IAM role for Lambda execution - }, - role: iamAuthLambdaRole, // Lambda execution role - layers: [commonLayer], - vpc, - securityGroups, - vpcSubnets - }); - } - - getLambdaLayer (repositoryId: string, config: Config): ILayerVersion { - return LayerVersion.fromLayerVersionArn( - this.scope, - createCdkId([repositoryId, 'CommonLayerVersion']), - StringParameter.valueForStringParameter(this.scope, `${config.deploymentPrefix}/layerVersion/common`), - ); - } - - toDynamoDBItem (obj: Record): DynamoDB.PutItemInputAttributeMap { - const dynamoItem: DynamoDB.PutItemInputAttributeMap = {}; - - for (const [key, value] of Object.entries(obj)) { - if (typeof value === 'string') { - dynamoItem[key] = { S: value }; - } else if (typeof value === 'number') { - dynamoItem[key] = { N: value.toString() }; - } else if (typeof value === 'boolean') { - dynamoItem[key] = { BOOL: value }; - } else if (Array.isArray(value)) { - dynamoItem[key] = { L: value.map((item) => this.toDynamoDBItem({ item })['item']) }; - } else if (typeof value === 'object' && value !== null) { - dynamoItem[key] = { M: this.toDynamoDBItem(value) }; - } else if (value === null) { - dynamoItem[key] = { NULL: true }; - } - } - - return dynamoItem; - } - /** * This method links the OpenSearch Service role to the service-linked role if it exists. * If the role doesn't exist, it will be created. diff --git a/lib/rag/vector-store/state_machine/create-store.ts b/lib/rag/vector-store/state_machine/create-store.ts index 185fb5ab7..2c78f227c 100644 --- a/lib/rag/vector-store/state_machine/create-store.ts +++ b/lib/rag/vector-store/state_machine/create-store.ts @@ -126,6 +126,15 @@ export class CreateStoreStateMachine extends Construct { }, }); + // Fail state to mark the state machine execution as failed + const failExecution = new sfn.Fail(this, 'FailExecution', { + cause: 'Vector store deployment failed', + error: 'DeploymentFailed', + }); + + // Chain failure status update to fail state + updateFailureStatus.next(failExecution); + // Check if this is a Bedrock KB repository to create default collections const skipCollectionCreation = new sfn.Pass(this, 'SkipCollectionCreation'); diff --git a/lib/rag/vector-store/state_machine/delete-store.ts b/lib/rag/vector-store/state_machine/delete-store.ts index 5d6419baa..9a21fb3bf 100644 --- a/lib/rag/vector-store/state_machine/delete-store.ts +++ b/lib/rag/vector-store/state_machine/delete-store.ts @@ -113,6 +113,15 @@ export class DeleteStoreStateMachine extends Construct { ':status': tasks.DynamoAttributeValue.fromString('$.checkResult.status'), }, }); + + // Fail state to mark the state machine execution as failed + const failExecution = new sfn.Fail(this, 'FailExecution', { + cause: 'Vector store deletion failed', + error: 'DeletionFailed', + }); + + // Chain failure status update to fail state + updateFailureStatus.next(failExecution); // Task to update the status of the vector store entry to 'COMPLETED' on successful deployment const updateDeleteStatus = new tasks.DynamoUpdateItem(this, 'UpdateDeleteStatus', { table: ragVectorStoreTable, diff --git a/lib/rag/vector-store/vector-store-creator.ts b/lib/rag/vector-store/vector-store-creator.ts index ffd4c3ce3..b5ea338f9 100644 --- a/lib/rag/vector-store/vector-store-creator.ts +++ b/lib/rag/vector-store/vector-store-creator.ts @@ -88,24 +88,66 @@ export class VectorStoreCreatorStack extends Construct { })); // IAM: manage roles created within the dynamic stacks and allow passing to services + // + // Security Strategy (Findings #1, #8, #13 - IAM Privilege Escalation Prevention): + // + // 1. Self-Targeting Prevention: Use ArnNotEquals condition to prevent the VectorStoreCreator + // role from modifying itself via AttachRolePolicy, DetachRolePolicy, PutRolePolicy, or + // DeleteRolePolicy actions. This prevents privilege escalation where the role could grant + // itself additional permissions. + // + // 2. Resource Pattern Restriction: Limit role creation and management actions to roles that + // follow the vector store naming pattern. This ensures the role can only manage roles + // created by the vector store deployer, not arbitrary IAM roles in the account. + // + // 3. CDK Bootstrap Role Assumption: AssumeRole is restricted to CDK bootstrap roles only, + // preventing assumption of roles with privilege escalation risks. + // + // Note: Tag-based conditions were considered but not used due to lack of support in all + // AWS regions. ARN pattern matching provides equivalent security with broader compatibility. + // + // Restrict permission mutation actions to prevent self-targeting (Security Finding #1, #8, #13) cdkRole.addToPolicy(new iam.PolicyStatement({ actions: [ - 'iam:CreateRole', - 'iam:DeleteRole', 'iam:AttachRolePolicy', 'iam:DetachRolePolicy', 'iam:PutRolePolicy', 'iam:DeleteRolePolicy', - 'iam:TagRole', - 'iam:UntagRole', + ], + resources: ['*'], + conditions: { + ArnNotEquals: { + // Prevent the role from modifying itself + 'iam:ResourceArn': cdkRole.roleArn + } + } + })); + + // IAM: manage roles created by vector store deployer (restricted to naming pattern) + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: [ + 'iam:CreateRole', + 'iam:DeleteRole', 'iam:GetRole', 'iam:GetRolePolicy', 'iam:ListRolePolicies', 'iam:ListAttachedRolePolicies', - 'iam:ListRoleTags', + 'iam:TagRole', + 'iam:UntagRole', 'iam:UpdateAssumeRolePolicy', - 'iam:ListRoles' + 'iam:ListRoleTags', + ], + resources: [ + // Roles created by vector store deployer follow this pattern: + // ${appName}-${deploymentName}-${deploymentStage}-vector-store-${repositoryId}-* + `arn:${config.partition}:iam::${config.accountNumber}:role/${config.appName}-${config.deploymentName}-${config.deploymentStage}-vector*`, + `arn:${config.partition}:iam::${config.accountNumber}:role/${config.deploymentName}-${config.appName}-${config.deploymentStage}-vector*`, ], + })); + + // IAM: ListRoles requires wildcard resource (read-only operation) + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: ['iam:ListRoles'], resources: ['*'], })); @@ -146,6 +188,7 @@ export class VectorStoreCreatorStack extends Construct { 'deploymentName', 'deploymentStage', 'deploymentPrefix', + 'iamRdsAuth', 'partition', 'region', 'removalPolicy', diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index 9fea711f8..7bcefc73e 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -43,6 +43,7 @@ export type SecurityGroups = { export enum ModelType { TEXTGEN = 'textgen', EMBEDDING = 'embedding', + VIDEOGEN = 'videogen', } /** @@ -403,12 +404,15 @@ export const VALID_INSTANCE_KEYS = Ec2Metadata.getValidInstanceKeys() as [string const ContainerHealthCheckConfigSchema = z.object({ command: z.array(z.string()).default(['CMD-SHELL', 'exit 0']).describe('The command to run for health checks'), interval: z.number().default(10).describe('The time interval between health checks, in seconds.'), - startPeriod: z.number().default(30).describe('The time to wait before starting the first health check, in seconds.'), + startPeriod: z.number().default(300).describe('The time to wait before starting the first health check, in seconds. Default 600s (10 min) to allow for large model loading.'), timeout: z.number().default(5).describe('The maximum time allowed for each health check to complete, in seconds'), - retries: z.number().default(2).describe('The number of times to retry a failed health check before considering the container as unhealthy.'), + retries: z.number().default(3).describe('The number of times to retry a failed health check before considering the container as unhealthy.'), }) .describe('Configuration for container health checks'); +export { ContainerHealthCheckConfigSchema }; +export type ContainerHealthCheckConfig = z.infer; + export const ImageTarballAsset = z.object({ path: z.string(), type: z.literal(EcsSourceType.TARBALL) @@ -467,7 +471,7 @@ export const ContainerConfigSchema = z.object({ export type ContainerConfig = z.infer; -const HealthCheckConfigSchema = z.object({ +export const LoadBalancerHealthCheckConfigSchema = z.object({ path: z.string().describe('Path for the health check.'), interval: z.number().default(30).describe('Interval in seconds between health checks.'), timeout: z.number().default(10).describe('Timeout in seconds for each health check.'), @@ -476,16 +480,18 @@ const HealthCheckConfigSchema = z.object({ }) .describe('Health check configuration for the load balancer.'); +export type LoadBalancerHealthCheckConfig = z.infer; + export const LoadBalancerConfigSchema = z.object({ sslCertIamArn: z.string().nullish().default(null).describe('SSL certificate IAM ARN for load balancer.'), - healthCheckConfig: HealthCheckConfigSchema, + healthCheckConfig: LoadBalancerHealthCheckConfigSchema, domainName: z.string().nullish().default(null).describe('Domain name to use instead of the load balancer\'s default DNS name.'), }) .describe('Configuration for load balancer settings.'); export const MetricConfigSchema = z.object({ - albMetricName: z.string().describe('Name of the ALB metric.'), - targetValue: z.number().describe('Target value for the metric.'), + albMetricName: z.string().default('RequestCountPerTarget').describe('Name of the ALB metric.'), + targetValue: z.number().default(30).describe('Target value for the metric.'), duration: z.number().default(60).describe('Duration in seconds for metric evaluation.'), estimatedInstanceWarmup: z.number().min(0).default(180).describe('Estimated warm-up time in seconds until a newly launched instance can send metrics to CloudWatch.'), }) @@ -625,6 +631,8 @@ export const EcsClusterConfigSchema = z autoScalingConfig: AutoScalingConfigSchema, loadBalancerConfig: LoadBalancerConfigSchema, localModelCode: z.string().default('/opt/model-code'), + containerMemoryBuffer: z.number().default(1024 * 2) + .describe('Memory in MiB to reserve for the host OS/ECS agent. Container gets (instance memory - buffer). Default: 2048 MiB'), modelHosting: z .string() .default('ecs') @@ -634,15 +642,16 @@ export const EcsClusterConfigSchema = z }) .refine( (data) => { - // 'textgen' type must have boolean streaming, 'embedding' type must have null streaming + // 'textgen' type must have boolean streaming, 'embedding' and 'videogen' types must have null streaming const isValidForTextgen = data.modelType === 'textgen' && typeof data.streaming === 'boolean'; const isValidForEmbedding = data.modelType === 'embedding' && data.streaming === null; + const isValidForVideogen = data.modelType === 'videogen' && data.streaming === null; - return isValidForTextgen || isValidForEmbedding; + return isValidForTextgen || isValidForEmbedding || isValidForVideogen; }, { message: `For 'textgen' models, 'streaming' must be true or false. - For 'embedding' models, 'streaming' must not be set.`, + For 'embedding' and 'videogen' models, 'streaming' must not be set.`, path: ['streaming'], }, ); @@ -722,12 +731,12 @@ const FastApiContainerConfigSchema = z.object({ }) .refine( (config) => { - return !config.dbHost && !config.passwordSecretId; + return !config.dbHost; }, { message: 'We do not allow using an existing DB for LiteLLM because of its requirement in internal model management ' + - 'APIs. Please do not define the dbHost or passwordSecretId fields for the FastAPI container DB config.', + 'APIs. Please do not define the dbHost field for the FastAPI container DB config.', }, ), }).describe('Configuration schema for REST API.'); @@ -870,7 +879,7 @@ export const RawConfigObject = z.object({ indexUrl: '', trustedHost: '', }).describe('Pypi configuration.'), - baseImage: z.string().default('python:3.13-slim').describe('Base image used for LISA serve components'), + baseImage: z.string().default('public.ecr.aws/docker/library/python:3.13-slim').describe('Base image used for LISA serve components'), nodejsImage: z.string().default('public.ecr.aws/lambda/nodejs:24').describe('Base image used for LISA NodeJS lambda deployments'), condaUrl: z.string().default('').describe('Conda URL configuration'), certificateAuthorityBundle: z.string().default('').describe('Certificate Authority Bundle file'), @@ -921,7 +930,8 @@ export const RawConfigObject = z.object({ bootstrapRolePrefix: z.string().optional().describe('Prefix for CDK bootstrap role names. Useful when roles have custom prefixes like My_User_Roles_. Leave empty for standard role names.'), litellmConfig: LiteLLMConfig, convertInlinePoliciesToManaged: z.boolean().optional().default(false).describe('Convert inline policies to managed policies'), - iamRdsAuth: z.boolean().optional().default(false).describe('Enable IAM authentication for RDS'), + iamRdsAuth: z.boolean().optional().default(false) + .describe('Enable IAM authentication for RDS. When true (default), IAM authentication is used and the bootstrap password is deleted after setup. When false, password-based authentication is used. WARNING: Switching from true to false after deployment is not supported - the master password is permanently deleted when IAM auth is enabled. This is a one-way migration.'), }); export const RawConfigSchema = RawConfigObject diff --git a/lib/schema/ragSchema.ts b/lib/schema/ragSchema.ts index 4f5b21b2c..06f700d47 100644 --- a/lib/schema/ragSchema.ts +++ b/lib/schema/ragSchema.ts @@ -148,12 +148,13 @@ export type OpenSearchConfig = export const RdsInstanceConfig = z.object({ username: z.string().default('postgres').describe('The username used for database connection.'), - passwordSecretId: z.string().optional().describe('The SecretsManager Secret ID that stores the existing database password.'), + passwordSecretId: z.string().optional().describe('The SecretsManager Secret ID that stores the existing database password. Only used when iamRdsAuth is false.'), dbHost: z.string().optional().describe('The database hostname for the existing database instance.'), dbName: z.string().default('postgres').describe('The name of the database for the database instance.'), dbPort: z.number().default(5432).describe('The port of the existing database instance or the port to be opened on the database instance.'), }).describe('Configuration schema for RDS Instances needed for LiteLLM scaling or PGVector RAG operations.\n \n ' + - 'The optional fields can be omitted to create a new database instance, otherwise fill in all fields to use an existing database instance.'); + 'The optional fields can be omitted to create a new database instance, otherwise fill in all fields to use an existing database instance. ' + + 'By default, IAM authentication is used. Set iamRdsAuth to false in config to use password-based authentication.'); export type RdsConfig = z.infer; @@ -163,7 +164,7 @@ export const RagRepositoryMetadata = MetadataSchema.extend({ customFields: z.record(z.string(), z.any()).optional().describe('Custom metadata fields for the repository.'), }); -const BaseRagRepositoryConfigSchema = z.object({ +export const BaseRagRepositoryConfigSchema = z.object({ repositoryId: z.string() .nonempty() .regex(/^[a-z0-9-]{3,20}/, 'Only lowercase alphanumeric characters and \'-\' are supported.') diff --git a/lib/serve/ecs-model/embedding/instructor/Dockerfile b/lib/serve/ecs-model/embedding/instructor/Dockerfile index d8443a5f1..bbde17aa5 100644 --- a/lib/serve/ecs-model/embedding/instructor/Dockerfile +++ b/lib/serve/ecs-model/embedding/instructor/Dockerfile @@ -1,6 +1,22 @@ -ARG BASE_IMAGE=python:3.13-slim +ARG BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim FROM ${BASE_IMAGE} +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + #### POINT TO NEW PYPI CONFIG ARG PYPI_INDEX_URL ARG PYPI_TRUSTED_HOST diff --git a/lib/serve/ecs-model/embedding/instructor/src/inference.py b/lib/serve/ecs-model/embedding/instructor/src/inference.py index a8dc65043..fd5d88c7a 100644 --- a/lib/serve/ecs-model/embedding/instructor/src/inference.py +++ b/lib/serve/ecs-model/embedding/instructor/src/inference.py @@ -13,7 +13,7 @@ # limitations under the License. """Inference handler.""" -from typing import Any, Dict +from typing import Any import torch from InstructorEmbedding import INSTRUCTOR @@ -36,7 +36,7 @@ def model_fn(model_dir: str) -> INSTRUCTOR: return model -def predict_fn(data: Dict[str, Any], model: INSTRUCTOR) -> Any: +def predict_fn(data: dict[str, Any], model: INSTRUCTOR) -> Any: """Get embeddings.""" instruction = data["instruction"] text = data["text"] diff --git a/lib/serve/ecs-model/embedding/tei/Dockerfile b/lib/serve/ecs-model/embedding/tei/Dockerfile index cbcd04397..1409cba4c 100644 --- a/lib/serve/ecs-model/embedding/tei/Dockerfile +++ b/lib/serve/ecs-model/embedding/tei/Dockerfile @@ -1,6 +1,22 @@ ARG BASE_IMAGE=ghcr.io/huggingface/text-embeddings-inference:latest FROM ${BASE_IMAGE} +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 diff --git a/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh b/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh index 3d0786aa3..0e6e7f9d6 100644 --- a/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh +++ b/lib/serve/ecs-model/embedding/tei/src/entrypoint.sh @@ -4,37 +4,46 @@ set -e # Environment variables for LISA deployment declare -a vars=("S3_BUCKET_MODELS" "LOCAL_MODEL_PATH" "MODEL_NAME" "S3_MOUNT_POINT" "THREADS") -# TEI Configuration Environment Variables (read natively by TEI): -# Performance & Concurrency: +# TEI Configuration Environment Variables +# Based on official TEI documentation: https://huggingface.co/docs/text-embeddings-inference/cli_arguments +# +# PERFORMANCE & BATCHING (Critical for throughput): # MAX_CONCURRENT_REQUESTS - Maximum concurrent requests (default: 512) -# MAX_BATCH_TOKENS - Maximum tokens per batch (default: 16384) -# MAX_BATCH_REQUESTS - Maximum requests per batch -# MAX_CLIENT_BATCH_SIZE - Maximum client batch size (default: 1024) -# TOKENIZATION_WORKERS - Number of tokenization workers +# MAX_BATCH_TOKENS - Maximum tokens per batch (default: 16384) **CRITICAL for GPU utilization** +# MAX_BATCH_REQUESTS - Maximum requests per batch (optional) +# MAX_CLIENT_BATCH_SIZE - Maximum inputs per client request (default: 32) +# TOKENIZATION_WORKERS - Number of tokenization workers (default: CPU cores) # -# Model Configuration: +# MODEL CONFIGURATION: # REVISION - Model revision/branch to use -# DTYPE - Data type for model weights (float16, float32, etc.) -# HUGGINGFACE_HUB_CACHE - Custom cache directory (default: /data) -# HF_API_TOKEN - Hugging Face API token for private models +# DTYPE - Data type for model weights (float16, float32) +# POOLING - Pooling method (cls, mean, splade, last-token) +# DEFAULT_PROMPT_NAME - Default prompt name from model config +# DEFAULT_PROMPT - Default prompt text to prepend +# DENSE_PATH - Path to Dense module for some models +# SERVED_MODEL_NAME - Model name for OpenAI-compatible endpoints # -# Features: -# AUTO_TRUNCATE - Enable automatic truncation (true/false, default: false) +# INPUT HANDLING: +# AUTO_TRUNCATE - Automatically truncate long inputs (true/false) # PAYLOAD_LIMIT - Maximum payload size in bytes (default: 2000000) # -# Network & Security: -# HOSTNAME - Server hostname (default: container hostname) -# PORT - Server port (default: 8080, overridden by --port) -# UDS_PATH - Unix domain socket path (default: /tmp/text-embeddings-inference-server) -# API_KEY - API key for authentication +# AUTHENTICATION & NETWORK: +# HF_TOKEN - Hugging Face API token for private models +# API_KEY - API key for request authorization +# HOSTNAME - Server hostname (default: 0.0.0.0) +# PORT - Server port (default: 3000, overridden to 8080) +# UDS_PATH - Unix domain socket path # CORS_ALLOW_ORIGIN - CORS origin configuration # -# Output & Observability: -# JSON_OUTPUT - Enable JSON output (true/false, overridden by --json-output) -# OTLP_ENDPOINT - OpenTelemetry endpoint for metrics +# OBSERVABILITY: +# JSON_OUTPUT - Enable JSON output (true/false) +# OTLP_ENDPOINT - OpenTelemetry endpoint for tracing +# OTLP_SERVICE_NAME - Service name for OpenTelemetry +# PROMETHEUS_PORT - Prometheus metrics port (default: 9000) +# DISABLE_SPANS - Disable tracing spans # -# Custom LISA Environment Variables: -# TEI_POOLING - Pooling method (mean, cls, max, mean_sqrt_len) - not available as native env var +# STORAGE: +# HUGGINGFACE_HUB_CACHE - Custom cache directory # Check the necessary environment variables for var in "${vars[@]}"; do @@ -55,38 +64,126 @@ mkdir -p ${LOCAL_MODEL_PATH} # Use rsync with S3_MOUNT_POINT ls ${S3_MOUNT_POINT}/${MODEL_NAME} | xargs -n1 -P${THREADS} -I% rsync -Pa --exclude "*.bin" ${S3_MOUNT_POINT}/${MODEL_NAME}/% ${LOCAL_MODEL_PATH}/ -# Build additional arguments for TEI (only for parameters not supported via env vars) +# Build CLI arguments from environment variables +# TEI reads some env vars natively, but we explicitly pass them as CLI args for clarity ADDITIONAL_ARGS="" -# Pooling configuration (not available as env var) -if [[ -n "${TEI_POOLING}" ]]; then - ADDITIONAL_ARGS+=" --pooling ${TEI_POOLING}" - echo "Using pooling method: ${TEI_POOLING}" +echo "Building TEI CLI arguments from environment variables..." + +# Performance & Batching +if [[ -n "${MAX_CONCURRENT_REQUESTS}" ]]; then + ADDITIONAL_ARGS+=" --max-concurrent-requests ${MAX_CONCURRENT_REQUESTS}" + echo " --max-concurrent-requests ${MAX_CONCURRENT_REQUESTS}" fi -# Start the webserver -# TEI natively reads these environment variables: -# - MAX_CONCURRENT_REQUESTS -# - MAX_BATCH_TOKENS -# - MAX_BATCH_REQUESTS -# - MAX_CLIENT_BATCH_SIZE -# - REVISION -# - TOKENIZATION_WORKERS -# - DTYPE -# - AUTO_TRUNCATE -# - PAYLOAD_LIMIT -# - HUGGINGFACE_HUB_CACHE -# - HF_API_TOKEN -# - HOSTNAME -# - PORT -# - UDS_PATH -# - API_KEY -# - JSON_OUTPUT -# - OTLP_ENDPOINT -# - CORS_ALLOW_ORIGIN +if [[ -n "${MAX_BATCH_TOKENS}" ]]; then + ADDITIONAL_ARGS+=" --max-batch-tokens ${MAX_BATCH_TOKENS}" + echo " --max-batch-tokens ${MAX_BATCH_TOKENS}" +fi + +if [[ -n "${MAX_BATCH_REQUESTS}" ]]; then + ADDITIONAL_ARGS+=" --max-batch-requests ${MAX_BATCH_REQUESTS}" + echo " --max-batch-requests ${MAX_BATCH_REQUESTS}" +fi + +if [[ -n "${MAX_CLIENT_BATCH_SIZE}" ]]; then + ADDITIONAL_ARGS+=" --max-client-batch-size ${MAX_CLIENT_BATCH_SIZE}" + echo " --max-client-batch-size ${MAX_CLIENT_BATCH_SIZE}" +fi + +if [[ -n "${TOKENIZATION_WORKERS}" ]]; then + ADDITIONAL_ARGS+=" --tokenization-workers ${TOKENIZATION_WORKERS}" + echo " --tokenization-workers ${TOKENIZATION_WORKERS}" +fi + +# Model Configuration +if [[ -n "${REVISION}" ]]; then + ADDITIONAL_ARGS+=" --revision ${REVISION}" + echo " --revision ${REVISION}" +fi + +if [[ -n "${DTYPE}" ]]; then + ADDITIONAL_ARGS+=" --dtype ${DTYPE}" + echo " --dtype ${DTYPE}" +fi + +if [[ -n "${POOLING}" ]]; then + ADDITIONAL_ARGS+=" --pooling ${POOLING}" + echo " --pooling ${POOLING}" +fi + +if [[ -n "${DEFAULT_PROMPT_NAME}" ]]; then + ADDITIONAL_ARGS+=" --default-prompt-name ${DEFAULT_PROMPT_NAME}" + echo " --default-prompt-name ${DEFAULT_PROMPT_NAME}" +fi + +if [[ -n "${DEFAULT_PROMPT}" ]]; then + ADDITIONAL_ARGS+=" --default-prompt \"${DEFAULT_PROMPT}\"" + echo " --default-prompt \"${DEFAULT_PROMPT}\"" +fi + +if [[ -n "${DENSE_PATH}" ]]; then + ADDITIONAL_ARGS+=" --dense-path ${DENSE_PATH}" + echo " --dense-path ${DENSE_PATH}" +fi +if [[ -n "${SERVED_MODEL_NAME}" ]]; then + ADDITIONAL_ARGS+=" --served-model-name ${SERVED_MODEL_NAME}" + echo " --served-model-name ${SERVED_MODEL_NAME}" +fi + +# Input Handling +if [[ "${AUTO_TRUNCATE}" == "true" ]]; then + ADDITIONAL_ARGS+=" --auto-truncate" + echo " --auto-truncate" +fi + +if [[ -n "${PAYLOAD_LIMIT}" ]]; then + ADDITIONAL_ARGS+=" --payload-limit ${PAYLOAD_LIMIT}" + echo " --payload-limit ${PAYLOAD_LIMIT}" +fi + +# Authentication +if [[ -n "${HF_TOKEN}" ]]; then + ADDITIONAL_ARGS+=" --hf-token ${HF_TOKEN}" + echo " --hf-token [REDACTED]" +fi + +if [[ -n "${API_KEY}" ]]; then + ADDITIONAL_ARGS+=" --api-key ${API_KEY}" + echo " --api-key [REDACTED]" +fi + +# Observability +if [[ -n "${OTLP_ENDPOINT}" ]]; then + ADDITIONAL_ARGS+=" --otlp-endpoint ${OTLP_ENDPOINT}" + echo " --otlp-endpoint ${OTLP_ENDPOINT}" +fi + +if [[ -n "${OTLP_SERVICE_NAME}" ]]; then + ADDITIONAL_ARGS+=" --otlp-service-name ${OTLP_SERVICE_NAME}" + echo " --otlp-service-name ${OTLP_SERVICE_NAME}" +fi + +if [[ -n "${PROMETHEUS_PORT}" ]]; then + ADDITIONAL_ARGS+=" --prometheus-port ${PROMETHEUS_PORT}" + echo " --prometheus-port ${PROMETHEUS_PORT}" +fi + +if [[ "${DISABLE_SPANS}" == "true" ]]; then + ADDITIONAL_ARGS+=" --disable-spans" + echo " --disable-spans" +fi + +# CORS +if [[ -n "${CORS_ALLOW_ORIGIN}" ]]; then + ADDITIONAL_ARGS+=" --cors-allow-origin ${CORS_ALLOW_ORIGIN}" + echo " --cors-allow-origin ${CORS_ALLOW_ORIGIN}" +fi + +# Start the webserver echo "Starting TEI with args: ${ADDITIONAL_ARGS}" echo "TEI environment variables:" -env | grep -E "^(MAX_CONCURRENT_REQUESTS|MAX_BATCH_TOKENS|MAX_BATCH_REQUESTS|MAX_CLIENT_BATCH_SIZE|REVISION|TOKENIZATION_WORKERS|DTYPE|AUTO_TRUNCATE|PAYLOAD_LIMIT|HUGGINGFACE_HUB_CACHE|HF_API_TOKEN|HOSTNAME|PORT|UDS_PATH|API_KEY|JSON_OUTPUT|OTLP_ENDPOINT|CORS_ALLOW_ORIGIN|TEI_POOLING)=" || echo "No TEI environment variables set" +env | grep -E "^(MAX_CONCURRENT_REQUESTS|MAX_BATCH_TOKENS|MAX_BATCH_REQUESTS|MAX_CLIENT_BATCH_SIZE|TOKENIZATION_WORKERS|REVISION|DTYPE|POOLING|DEFAULT_PROMPT|DENSE_PATH|SERVED_MODEL_NAME|AUTO_TRUNCATE|PAYLOAD_LIMIT|HF_TOKEN|API_KEY|OTLP_ENDPOINT|PROMETHEUS_PORT|CORS_ALLOW_ORIGIN)=" || echo "No TEI environment variables set" text-embeddings-router --model-id $LOCAL_MODEL_PATH --port 8080 --json-output ${ADDITIONAL_ARGS} diff --git a/lib/serve/ecs-model/textgen/tgi/Dockerfile b/lib/serve/ecs-model/textgen/tgi/Dockerfile index b272baa68..a6ca7a891 100644 --- a/lib/serve/ecs-model/textgen/tgi/Dockerfile +++ b/lib/serve/ecs-model/textgen/tgi/Dockerfile @@ -1,6 +1,22 @@ ARG BASE_IMAGE=ghcr.io/huggingface/text-generation-inference:latest FROM ${BASE_IMAGE} +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL RUN apt update -y && apt install -y wget rsync && \ diff --git a/lib/serve/ecs-model/vllm/Dockerfile b/lib/serve/ecs-model/vllm/Dockerfile index dedaa69c8..ad3d91d06 100644 --- a/lib/serve/ecs-model/vllm/Dockerfile +++ b/lib/serve/ecs-model/vllm/Dockerfile @@ -1,13 +1,36 @@ -ARG BASE_IMAGE=python:3.13-slim +ARG BASE_IMAGE=public.ecr.aws/deep-learning-containers/vllm:0.13-gpu-py312 FROM ${BASE_IMAGE} +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + ##### DOWNLOAD MOUNTPOINTS S3 ARG MOUNTS3_DEB_URL ARG MOUNTS3_DEB_SHA256 -RUN apt update -y && apt install -y wget rsync && \ - wget ${MOUNTS3_DEB_URL} && \ - apt install -y ./mount-s3.deb && \ - rm mount-s3.deb +RUN if command -v apt-get >/dev/null 2>&1; then \ + apt update -y && apt install -y wget rsync && \ + wget ${MOUNTS3_DEB_URL} && apt install -y ./mount-s3.deb && \ + rm mount-s3.deb && rm -rf /var/lib/apt/lists/*; \ + elif command -v yum >/dev/null 2>&1; then \ + yum install -y wget rsync && wget ${MOUNTS3_DEB_URL} && \ + yum install -y ./mount-s3.rpm && yum clean all && rm mount-s3.rpm; \ + elif command -v apk >/dev/null 2>&1; then \ + apk add --no-cache wget rsync && wget ${MOUNTS3_DEB_URL} && \ + apk add --allow-untrusted ./mount-s3.apk && rm mount-s3.apk; \ + fi COPY src/entrypoint.sh ./entrypoint.sh RUN chmod +x entrypoint.sh diff --git a/lib/serve/ecs-model/vllm/src/entrypoint.sh b/lib/serve/ecs-model/vllm/src/entrypoint.sh index ae834a7ef..c23a0452a 100644 --- a/lib/serve/ecs-model/vllm/src/entrypoint.sh +++ b/lib/serve/ecs-model/vllm/src/entrypoint.sh @@ -176,7 +176,71 @@ if [[ -n "${VLLM_TENSOR_PARALLEL_SIZE}" ]] && [[ ${VLLM_TENSOR_PARALLEL_SIZE} -g fi # Start the webserver -# vLLM natively reads VLLM_* environment variables for configuration +# vLLM reads some VLLM_* environment variables natively, but many require CLI args +# Map environment variables to CLI arguments for full control + +echo "Building vLLM CLI arguments from environment variables..." + +# GPU memory utilization (0.0-1.0) +if [[ -n "${VLLM_GPU_MEMORY_UTILIZATION}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --gpu-memory-utilization ${VLLM_GPU_MEMORY_UTILIZATION}" + echo " --gpu-memory-utilization ${VLLM_GPU_MEMORY_UTILIZATION}" +fi + +# Max model length (context window) +if [[ -n "${VLLM_MAX_MODEL_LEN}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --max-model-len ${VLLM_MAX_MODEL_LEN}" + echo " --max-model-len ${VLLM_MAX_MODEL_LEN}" +fi + +# Max number of batched tokens per iteration +if [[ -n "${VLLM_MAX_NUM_BATCHED_TOKENS}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --max-num-batched-tokens ${VLLM_MAX_NUM_BATCHED_TOKENS}" + echo " --max-num-batched-tokens ${VLLM_MAX_NUM_BATCHED_TOKENS}" +fi + +# Max number of sequences (concurrent requests) +if [[ -n "${VLLM_MAX_NUM_SEQS}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --max-num-seqs ${VLLM_MAX_NUM_SEQS}" + echo " --max-num-seqs ${VLLM_MAX_NUM_SEQS}" +fi + +# Enable prefix caching +if [[ "${VLLM_ENABLE_PREFIX_CACHING}" == "true" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --enable-prefix-caching" + echo " --enable-prefix-caching" +fi + +# Enable chunked prefill +if [[ "${VLLM_ENABLE_CHUNKED_PREFILL}" == "true" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --enable-chunked-prefill" + echo " --enable-chunked-prefill" +fi + +# Data type (auto, half, float16, bfloat16, float, float32) +if [[ -n "${VLLM_DTYPE}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --dtype ${VLLM_DTYPE}" + echo " --dtype ${VLLM_DTYPE}" +fi + +# Tensor parallel size (for multi-GPU) +if [[ -n "${VLLM_TENSOR_PARALLEL_SIZE}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --tensor-parallel-size ${VLLM_TENSOR_PARALLEL_SIZE}" + echo " --tensor-parallel-size ${VLLM_TENSOR_PARALLEL_SIZE}" +fi + +# Quantization method +if [[ -n "${VLLM_QUANTIZATION}" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --quantization ${VLLM_QUANTIZATION}" + echo " --quantization ${VLLM_QUANTIZATION}" +fi + +# Trust remote code (for custom models) +if [[ "${VLLM_TRUST_REMOTE_CODE}" == "true" ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --trust-remote-code" + echo " --trust-remote-code" +fi + echo "Starting vLLM with args: ${ADDITIONAL_ARGS}" echo "vLLM environment variables:" env | grep -E "^(VLLM_|MAX_TOTAL_TOKENS)=" || echo "No vLLM environment variables set" diff --git a/lib/serve/mcp-workbench/Dockerfile b/lib/serve/mcp-workbench/Dockerfile index 1e0efb480..ce52e7b3f 100644 --- a/lib/serve/mcp-workbench/Dockerfile +++ b/lib/serve/mcp-workbench/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=python:3.13-slim +ARG BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim FROM ${BASE_IMAGE} ARG RCLONE_VERSION=v1.71.0 @@ -15,12 +15,30 @@ ENV RCLONE_SOURCE=$RCLONE_SOURCE WORKDIR /workspace -RUN apt-get update && apt-get install -y \ - curl \ - fuse3 \ - unzip \ - xz-utils \ - && rm -rf /var/lib/apt/lists/* +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + +# Install dependencies +RUN if command -v apt-get >/dev/null 2>&1; then \ + apt-get update && apt-get install -y curl fuse3 unzip xz-utils && rm -rf /var/lib/apt/lists/*; \ + elif command -v yum >/dev/null 2>&1; then \ + yum install -y curl fuse3 unzip xz && yum clean all; \ + elif command -v apk >/dev/null 2>&1; then \ + apk add --no-cache curl fuse3 unzip xz; \ + fi # Install s6-overlay ADD $S6_OVERLAY_NOARCH_SOURCE /tmp/ diff --git a/lib/serve/mcp-workbench/pyproject.toml b/lib/serve/mcp-workbench/pyproject.toml index cfcd89b9e..de687c4f6 100644 --- a/lib/serve/mcp-workbench/pyproject.toml +++ b/lib/serve/mcp-workbench/pyproject.toml @@ -14,17 +14,15 @@ dependencies = [ "pyyaml>=6.0.2", "click==8.3.1", "starlette>=0.40.0,<0.51.0", - "uvicorn>=0.31.1,<0.39.0", - "aioboto3==13.4.0", - "aiobotocore==2.18.0", + "uvicorn>=0.31.1,<0.32.0", "aiohttp==3.13.2", - "boto3==1.36.0", + "boto3==1.40.76", "cryptography==46.0.3", - "gunicorn==23.0.0", - "pydantic==2.12.5", - "PyJWT==2.10.1", + "gunicorn>=23.0.0,<24.0.0", + "pydantic>=2.5.0,<3.0.0", + "PyJWT>=2.10.1,<3.0.0", "requests==2.32.5", - "fastapi==0.124.2", + "fastapi>=0.120.1", "fastapi_utils==0.8.0", "loguru==0.7.3" ] @@ -49,7 +47,7 @@ where = ["src"] [tool.pytest.ini_options] minversion = "7.0" addopts = "-ra -q --tb=short" -testpaths = ["tests"] +testpaths = ["../../test/mcp-workbench"] python_files = ["test_*.py", "*_test.py"] python_classes = ["Test*"] python_functions = ["test_*"] diff --git a/lib/serve/mcp-workbench/requirements.txt b/lib/serve/mcp-workbench/requirements.txt index 56d7f6eba..460f36cb6 100644 --- a/lib/serve/mcp-workbench/requirements.txt +++ b/lib/serve/mcp-workbench/requirements.txt @@ -1,2 +1,2 @@ ## Add additional requirements to this file -## boto3==1.36.0 +## boto3==1.40.76 diff --git a/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py b/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py index 31085baea..7780c7f5c 100644 --- a/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py +++ b/lib/serve/mcp-workbench/src/examples/sample_tools/calculator_tool.py @@ -23,6 +23,7 @@ Both methods allow you to create tools that can be called by AI models to perform specific tasks. """ +from collections.abc import Callable from typing import Annotated from mcpworkbench.core.base_tool import BaseTool @@ -45,7 +46,7 @@ class CalculatorTool(BaseTool): 4. Define the actual tool function with proper type annotations """ - def __init__(self): + def __init__(self) -> None: """ Initialize the tool with metadata. @@ -57,7 +58,7 @@ def __init__(self): name="calculator", description="Performs basic arithmetic operations (add, subtract, multiply, divide)" ) - async def execute(self): + async def execute(self) -> Callable: """ Return the callable function that implements the tool's functionality. @@ -71,7 +72,7 @@ async def calculate( operator: Annotated[str, "add, subtract, multiply, or divide"], left_operand: Annotated[float, "The first number"], right_operand: Annotated[float, "The second number"], - ): + ) -> dict[str, float | str]: """ Execute the calculator operation. diff --git a/lib/serve/mcp-workbench/src/examples/sample_tools/text_utils.py b/lib/serve/mcp-workbench/src/examples/sample_tools/text_utils.py index 81871c978..e662a043d 100644 --- a/lib/serve/mcp-workbench/src/examples/sample_tools/text_utils.py +++ b/lib/serve/mcp-workbench/src/examples/sample_tools/text_utils.py @@ -33,7 +33,7 @@ name="text_length", description="Count the number of characters in a text string", ) -async def count_characters(text: Annotated[str, "The text string to analyze"]): +async def count_characters(text: Annotated[str, "The text string to analyze"]) -> dict[str, int | str]: """Count the number of characters in the given text.""" return { "text": text, @@ -50,7 +50,7 @@ async def count_characters(text: Annotated[str, "The text string to analyze"]): def transform_text( text: Annotated[str, "The text string to transform"], transformation: Annotated[str, "Type of transformation: 'upper', 'lower', 'title', or 'capitalize'"], -): +) -> dict[str, str]: """Transform the given text according to the specified transformation.""" if transformation == "upper": result = text.upper() @@ -70,6 +70,6 @@ def transform_text( name="text_reverse", description="Reverse the characters in a text string", ) -def reverse_text(text: Annotated[str, "The text string to reverse"]): +def reverse_text(text: Annotated[str, "The text string to reverse"]) -> dict[str, str]: """Reverse the characters in the given text.""" return {"original": text, "reversed": text[::-1]} diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py b/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py index b97b76a52..72c8be790 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/adapters/tool_adapter.py @@ -17,7 +17,7 @@ import asyncio import logging from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any from ..core.base_tool import BaseTool from ..core.tool_discovery import ToolInfo, ToolType @@ -32,7 +32,7 @@ def __init__(self, tool_info: ToolInfo): self.tool_info = tool_info @abstractmethod - async def execute(self, arguments: Dict[str, Any]) -> Any: + async def execute(self, arguments: dict[str, Any]) -> Any: """Execute the tool with the given arguments.""" pass @@ -60,7 +60,7 @@ def __init__(self, tool_info: ToolInfo): super().__init__(tool_info) self.tool_instance: BaseTool = tool_info.tool_instance - async def execute(self, arguments: Dict[str, Any]) -> Any: + async def execute(self, arguments: dict[str, Any]) -> Any: """Execute the BaseTool instance.""" try: # Call the tool's execute method @@ -84,7 +84,7 @@ def __init__(self, tool_info: ToolInfo): super().__init__(tool_info) self.function = tool_info.tool_instance - async def execute(self, arguments: Dict[str, Any]) -> Any: + async def execute(self, arguments: dict[str, Any]) -> Any: """Execute the decorated function.""" try: # Check if the function is async diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/cli.py b/lib/serve/mcp-workbench/src/mcpworkbench/cli.py index 726e1a4ca..99eab6476 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/cli.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/cli.py @@ -18,7 +18,6 @@ import re import sys from pathlib import Path -from typing import Optional import click import yaml @@ -38,7 +37,7 @@ def load_config_from_file(config_path: str) -> dict: """Load configuration from YAML file.""" try: - with open(config_path, "r") as f: + with open(config_path) as f: return yaml.safe_load(f) or {} except FileNotFoundError: logger.error(f"Configuration file not found: {config_path}") @@ -79,16 +78,16 @@ def merge_config(file_config: dict, cli_overrides: dict) -> dict: @click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging") @click.option("--debug", is_flag=True, help="Enable debug logging") def main( - config: Optional[Path], - tools_dir: Optional[Path], - host: Optional[str], - port: Optional[int], - exit_route: Optional[str], - rescan_route: Optional[str], - cors_origins: Optional[str], + config: Path | None, + tools_dir: Path | None, + host: str | None, + port: int | None, + exit_route: str | None, + rescan_route: str | None, + cors_origins: str | None, verbose: bool, debug: bool, -): +) -> None: """MCP Workbench - A dynamic host for Python files used as MCP tools.""" # Set logging level @@ -106,14 +105,14 @@ def main( file_config = load_config_from_file(str(config)) # Prepare CLI overrides - cli_overrides = {} + cli_overrides: dict[str, str | list[str]] = {} if tools_dir: cli_overrides["tools_dir"] = str(tools_dir) if host: cli_overrides["host"] = host if port: - cli_overrides["port"] = port + cli_overrides["port"] = str(port) if exit_route: cli_overrides["exit_route"] = exit_route if rescan_route: @@ -122,8 +121,8 @@ def main( # Handle CORS origins if cors_origins: cleaned_origins = re.sub(r'^([\s"]+)?(.+?)([\s"]*)?$', r"\2", cors_origins) - origins = [origin.strip() for origin in cleaned_origins.split(",")] - cli_overrides["cors_origins"] = origins + origins_list: list[str] = [origin.strip() for origin in cleaned_origins.split(",")] + cli_overrides["cors_origins"] = origins_list # Merge configurations merged_config = merge_config(file_config, cli_overrides) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py b/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py index 51a209463..47280856b 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/config/models.py @@ -14,7 +14,6 @@ """Configuration models for MCP Workbench.""" -from typing import List, Optional from pydantic import BaseModel, Field @@ -22,11 +21,11 @@ class CORSConfig(BaseModel): """CORS configuration settings.""" - allow_origins: List[str] = Field(default=["*"], description="Allowed origins for CORS") - allow_methods: List[str] = Field(default=["GET", "POST", "OPTIONS"], description="Allowed HTTP methods") - allow_headers: List[str] = Field(default=["*"], description="Allowed headers") + allow_origins: list[str] = Field(default=["*"], description="Allowed origins for CORS") + allow_methods: list[str] = Field(default=["GET", "POST", "OPTIONS"], description="Allowed HTTP methods") + allow_headers: list[str] = Field(default=["*"], description="Allowed headers") allow_credentials: bool = Field(default=True, description="Allow credentials in CORS requests") - expose_headers: List[str] = Field(default=[], description="Headers to expose to the browser") + expose_headers: list[str] = Field(default=[], description="Headers to expose to the browser") max_age: int = Field(default=600, description="Maximum age for CORS preflight cache") @@ -41,8 +40,8 @@ class ServerConfig(BaseModel): tools_directory: str = Field(..., description="Directory containing tool files") # Management tool settings - exit_route_path: Optional[str] = Field(default=None, description="Enable exit_server MCP tool when set") - rescan_route_path: Optional[str] = Field(default=None, description="Enable rescan_tools MCP tool when set") + exit_route_path: str | None = Field(default=None, description="Enable exit_server MCP tool when set") + rescan_route_path: str | None = Field(default=None, description="Enable rescan_tools MCP tool when set") # CORS settings cors_settings: CORSConfig = Field(default_factory=CORSConfig, description="CORS configuration") @@ -83,8 +82,15 @@ def from_dict(cls, data: dict) -> "ServerConfig": if not isinstance(config_data["cors_settings"], dict): config_data["cors_settings"] = {} - # Set the origins - config_data["cors_settings"]["allow_origins"] = cors_origins + # Convert cors_origins to a list if it's a string (comma-separated) + if isinstance(cors_origins, str): + origins_list = [origin.strip() for origin in cors_origins.split(",") if origin.strip()] + config_data["cors_settings"]["allow_origins"] = origins_list + elif isinstance(cors_origins, list): + config_data["cors_settings"]["allow_origins"] = cors_origins + else: + # Fallback to default + config_data["cors_settings"]["allow_origins"] = ["*"] # Handle cors_settings if "cors_settings" in config_data and isinstance(config_data["cors_settings"], dict): diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py index c971b6f05..d10dcc85c 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/annotations.py @@ -14,11 +14,14 @@ """Annotations for function-based MCP tools.""" +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Dict +from typing import Any, cast, TypeVar +F = TypeVar("F", bound=Callable[..., Any]) -def mcp_tool(name: str, description: str): + +def mcp_tool(name: str, description: str) -> Callable[[F], F]: """ Decorator to mark a function as an MCP tool. @@ -30,14 +33,14 @@ def mcp_tool(name: str, description: str): The decorated function with MCP tool metadata """ - def decorator(func: Callable) -> Callable: + def decorator(func: F) -> F: # Store metadata as function attributes - func._mcp_tool_name = name - func._mcp_tool_description = description - func._is_mcp_tool = True + func._mcp_tool_name = name # type: ignore[attr-defined] + func._mcp_tool_description = description # type: ignore[attr-defined] + func._is_mcp_tool = True # type: ignore[attr-defined] @wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: # If the function is not already async, we need to handle it if hasattr(func, "__code__") and func.__code__.co_flags & 0x80: # CO_COROUTINE return await func(*args, **kwargs) @@ -45,22 +48,22 @@ async def wrapper(*args, **kwargs): return func(*args, **kwargs) # Copy metadata to wrapper - wrapper._mcp_tool_name = name - wrapper._mcp_tool_description = description - wrapper._is_mcp_tool = True - wrapper._original_func = func + wrapper._mcp_tool_name = name # type: ignore[attr-defined] + wrapper._mcp_tool_description = description # type: ignore[attr-defined] + wrapper._is_mcp_tool = True # type: ignore[attr-defined] + wrapper._original_func = func # type: ignore[attr-defined] - return wrapper + return cast(F, wrapper) return decorator def is_mcp_tool(func: Callable) -> bool: """Check if a function is marked as an MCP tool.""" - return hasattr(func, "_is_mcp_tool") and func._is_mcp_tool + return hasattr(func, "_is_mcp_tool") and getattr(func, "_is_mcp_tool", False) -def get_tool_metadata(func: Callable) -> Dict[str, Any]: +def get_tool_metadata(func: Callable) -> dict[str, Any]: """Get the MCP tool metadata from a decorated function.""" if not is_mcp_tool(func): raise ValueError("Function is not marked as an MCP tool") diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py index ccb3fc983..0e087d028 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/base_tool.py @@ -26,8 +26,9 @@ # limitations under the License."""Base tool class and related data structures.""" from abc import ABC, abstractmethod +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any from pydantic import BaseModel, Field @@ -49,15 +50,13 @@ class ToolInfo(BaseModel): module_name: str = Field(..., description="Python module name") # For class-based tools - class_name: Optional[str] = Field(default=None, description="Class name for class-based tools") + class_name: str | None = Field(default=None, description="Class name for class-based tools") # For function-based tools - function_name: Optional[str] = Field(default=None, description="Function name for function-based tools") + function_name: str | None = Field(default=None, description="Function name for function-based tools") # Tool instance or function reference (not serialized) - tool_instance: Optional[Union[Any, Callable]] = Field( - default=None, exclude=True, description="Tool instance or function" - ) + tool_instance: Any | Callable | None = Field(default=None, exclude=True, description="Tool instance or function") class BaseTool(ABC): diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py index 01159e0df..40013c341 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_discovery.py @@ -20,7 +20,7 @@ import logging import sys from pathlib import Path -from typing import Dict, List +from typing import Any from pydantic import BaseModel @@ -33,11 +33,11 @@ class RescanResult(BaseModel): """Result of a tool directory rescan.""" - tools_added: List[str] = [] - tools_updated: List[str] = [] - tools_removed: List[str] = [] + tools_added: list[str] = [] + tools_updated: list[str] = [] + tools_removed: list[str] = [] total_tools: int = 0 - errors: List[str] = [] + errors: list[str] = [] class ToolDiscovery: @@ -51,8 +51,8 @@ def __init__(self, tools_directory: str): tools_directory: Path to directory containing tool files """ self.tools_directory = Path(tools_directory) - self.loaded_modules: Dict[str, any] = {} - self.current_tools: Dict[str, ToolInfo] = {} + self.loaded_modules: dict[str, Any] = {} + self.current_tools: dict[str, ToolInfo] = {} if not self.tools_directory.exists(): raise ValueError(f"Tools directory does not exist: {tools_directory}") @@ -60,7 +60,7 @@ def __init__(self, tools_directory: str): if not self.tools_directory.is_dir(): raise ValueError(f"Tools directory is not a directory: {tools_directory}") - def discover_tools(self) -> List[ToolInfo]: + def discover_tools(self) -> list[ToolInfo]: """ Discover all tools in the tools directory. @@ -125,7 +125,7 @@ def rescan_tools(self) -> RescanResult: return result - def _reload_modules(self): + def _reload_modules(self) -> None: """Reload all previously loaded modules to pick up file changes.""" modules_to_reload = [] @@ -149,7 +149,7 @@ def _reload_modules(self): except KeyError: pass - def _discover_tools_in_file(self, file_path: Path) -> List[ToolInfo]: + def _discover_tools_in_file(self, file_path: Path) -> list[ToolInfo]: """ Discover tools in a single Python file. @@ -159,7 +159,7 @@ def _discover_tools_in_file(self, file_path: Path) -> List[ToolInfo]: Returns: List of tools found in the file """ - tools = [] + tools: list[ToolInfo] = [] try: # Create module name from file path @@ -192,7 +192,7 @@ def _discover_tools_in_file(self, file_path: Path) -> List[ToolInfo]: return tools - def _find_class_based_tools(self, module, file_path: Path, module_name: str) -> List[ToolInfo]: + def _find_class_based_tools(self, module: Any, file_path: Path, module_name: str) -> list[ToolInfo]: """Find BaseTool subclasses in the module.""" tools = [] @@ -218,7 +218,7 @@ def _find_class_based_tools(self, module, file_path: Path, module_name: str) -> instance = obj(name=tool_name, description=tool_description) else: # Custom constructor - try to instantiate with no args - instance = obj() + instance = obj() # type: ignore[call-arg] # Get tool metadata tool_name = getattr(instance, "name", name.lower()) @@ -242,7 +242,7 @@ def _find_class_based_tools(self, module, file_path: Path, module_name: str) -> return tools - def _find_function_based_tools(self, module, file_path: Path, module_name: str) -> List[ToolInfo]: + def _find_function_based_tools(self, module: Any, file_path: Path, module_name: str) -> list[ToolInfo]: """Find @mcp_tool decorated functions in the module.""" tools = [] diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py index 4541cfb2f..b04a8aa56 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/core/tool_registry.py @@ -16,7 +16,6 @@ import logging import threading -from typing import Dict, List, Optional from .base_tool import ToolInfo @@ -26,9 +25,9 @@ class ToolRegistry: """Thread-safe registry for managing discovered tools.""" - def __init__(self): + def __init__(self) -> None: """Initialize the tool registry.""" - self._tools: Dict[str, ToolInfo] = {} + self._tools: dict[str, ToolInfo] = {} self._lock = threading.RLock() def register_tool(self, tool_info: ToolInfo) -> None: @@ -42,7 +41,7 @@ def register_tool(self, tool_info: ToolInfo) -> None: self._tools[tool_info.name] = tool_info logger.info(f"Registered tool: {tool_info.name}") - def register_tools(self, tools: List[ToolInfo]) -> None: + def register_tools(self, tools: list[ToolInfo]) -> None: """ Register multiple tools in the registry. @@ -71,7 +70,7 @@ def unregister_tool(self, tool_name: str) -> bool: return True return False - def get_tool(self, tool_name: str) -> Optional[ToolInfo]: + def get_tool(self, tool_name: str) -> ToolInfo | None: """ Get a tool by name. @@ -84,7 +83,7 @@ def get_tool(self, tool_name: str) -> Optional[ToolInfo]: with self._lock: return self._tools.get(tool_name) - def list_tools(self) -> List[ToolInfo]: + def list_tools(self) -> list[ToolInfo]: """ Get a list of all registered tools. @@ -94,7 +93,7 @@ def list_tools(self) -> List[ToolInfo]: with self._lock: return list(self._tools.values()) - def list_tool_names(self) -> List[str]: + def list_tool_names(self) -> list[str]: """ Get a list of all registered tool names. @@ -110,7 +109,7 @@ def clear(self) -> None: self._tools.clear() logger.info("Cleared all tools from registry") - def update_registry(self, new_tools: List[ToolInfo]) -> None: + def update_registry(self, new_tools: list[ToolInfo]) -> None: """ Update the registry with a new set of tools. This replaces all existing tools. diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py index 4cee09f04..e78fd5ecc 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/auth.py @@ -19,7 +19,7 @@ from datetime import datetime from pathlib import Path from time import time -from typing import Any, Dict, Optional +from typing import Any import boto3 import jwt @@ -62,12 +62,13 @@ def is_idp_used() -> bool: raise RuntimeError("No crypto support for JWT.") -def get_oidc_metadata(cert_path: Optional[str] = None) -> Dict[str, Any]: +def get_oidc_metadata(cert_path: str | None = None) -> dict[str, Any]: """Get OIDC endpoints and metadata from authority.""" authority = os.environ.get("AUTHORITY") resp = requests.get(f"{authority}/.well-known/openid-configuration", verify=cert_path or True, timeout=30) resp.raise_for_status() - return resp.json() # type: ignore + result: dict[str, Any] = resp.json() + return result def get_jwks_client() -> jwt.PyJWKClient: @@ -87,11 +88,11 @@ def get_jwks_client() -> jwt.PyJWKClient: def id_token_is_valid( id_token: str, client_id: str, authority: str, jwks_client: jwt.PyJWKClient -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """Check whether an ID token is valid and return decoded data.""" try: signing_key = jwks_client.get_signing_key_from_jwt(id_token) - data: Dict[str, Any] = jwt.decode( + data: dict[str, Any] = jwt.decode( id_token, signing_key.key, algorithms=["RS256"], @@ -124,7 +125,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node -def get_authorization_token(headers: Dict[str, str], header_name: str = "Authorization") -> str: +def get_authorization_token(headers: dict[str, str], header_name: str = "Authorization") -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: return headers.get(header_name, "").removeprefix("Bearer").strip() @@ -132,10 +133,10 @@ def get_authorization_token(headers: Dict[str, str], header_name: str = "Authori class LoggingMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None): + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: super().__init__(app, dispatch) - async def dispatch(self, request, call_next): + async def dispatch(self, request: Request, call_next: Any) -> Response: response = await call_next(request) response.headers["Custom"] = "Example" return response @@ -144,22 +145,22 @@ async def dispatch(self, request, call_next): class OIDCHTTPBearer(BaseHTTPMiddleware): """OIDC based bearer token authenticator.""" - def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None): + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: super().__init__(app, dispatch) self._token_authorizer = ApiTokenAuthorizer() self._management_token_authorizer = ManagementTokenAuthorizer() self._jwks_client = get_jwks_client() - async def dispatch(self, request: Request, call_next) -> Response: + async def dispatch(self, request: Request, call_next: Any) -> Response: """Verify the provided bearer token or API Key. API Key will take precedence over the bearer token.""" if request.method == "OPTIONS": return await call_next(request) valid = False - if self._token_authorizer.is_valid_api_token(request.headers): + if self._token_authorizer.is_valid_api_token(dict(request.headers)): logger.info("looks like a valid api token") valid = True - elif self._management_token_authorizer.is_valid_api_token(request.headers): + elif self._management_token_authorizer.is_valid_api_token(dict(request.headers)): logger.info("looks like a valid mgmt token") valid = True else: @@ -205,7 +206,7 @@ def _get_token_info(self, token: str) -> Any: ddb_response = self._token_table.get_item(Key={"token": token}, ReturnConsumedCapacity="NONE") return ddb_response.get("Item", None) - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + def is_valid_api_token(self, headers: dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" for header_name in API_KEY_HEADER_NAMES: token = get_authorization_token(headers, header_name) @@ -248,7 +249,7 @@ def _refreshTokens(self) -> None: self._secret_tokens = secret_tokens self._last_run = current_time - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + def is_valid_api_token(self, headers: dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" self._refreshTokens() diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py index 2868c3f99..7ecb4eda3 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/mcp_server.py @@ -19,11 +19,12 @@ import logging import sys from datetime import datetime -from typing import Any, Dict, List +from typing import Any from fastmcp import FastMCP from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount, Route @@ -51,7 +52,7 @@ def __init__(self, config: ServerConfig, tool_discovery: ToolDiscovery, tool_reg self.config = config self.tool_discovery = tool_discovery self.tool_registry = tool_registry - self.registered_tools: Dict[str, Any] = {} + self.registered_tools: dict[str, Any] = {} # Create FastMCP application self.app = FastMCP("mcpworkbench") @@ -60,20 +61,20 @@ def __init__(self, config: ServerConfig, tool_discovery: ToolDiscovery, tool_reg # Register built-in management tools self._register_management_tools() - def _register_management_tools(self): + def _register_management_tools(self) -> None: """Register built-in management tools - now removed as they are HTTP routes.""" # Management functionality moved to HTTP GET endpoints pass - def _add_management_routes(self, app: Starlette): + def _add_management_routes(self, app: Starlette) -> None: if self.config.exit_route_path: - async def exit_endpoint(request): + async def exit_endpoint(request: Request) -> JSONResponse: """HTTP GET endpoint to gracefully shutdown the server.""" logger.info("Exit requested via HTTP endpoint") # Schedule shutdown after response is sent - async def delayed_shutdown(): + async def delayed_shutdown() -> None: await asyncio.sleep(1) logger.info("Shutting down server...") sys.exit(0) @@ -91,7 +92,7 @@ async def delayed_shutdown(): if self.config.rescan_route_path: - async def rescan_endpoint(request): + async def rescan_endpoint(request: Request) -> JSONResponse: """HTTP GET endpoint to rescan tools directory and reload tools.""" try: logger.info("Rescanning tools directory via HTTP...") @@ -134,12 +135,12 @@ async def rescan_endpoint(request): app.add_route(self.config.rescan_route_path, rescan_endpoint, methods=["GET"]) - def _create_starlette_app(self): + def _create_starlette_app(self) -> Starlette: """Create Starlette application with MCP and HTTP routes.""" mcp_app = self.app.http_app(path="/", transport="streamable-http", stateless_http=True) - async def health_check(request): + async def health_check(request: Request) -> JSONResponse: """Health check endpoint for Docker health checks.""" return JSONResponse({"status": "healthy", "service": "mcpworkbench"}) @@ -163,7 +164,7 @@ async def health_check(request): return Starlette(routes=routes, lifespan=mcp_app.lifespan) - async def _register_discovered_tools(self, tools: List[ToolInfo]): + async def _register_discovered_tools(self, tools: list[ToolInfo]) -> None: """Register discovered tools with FastMCP.""" for tool_info in tools: try: @@ -171,7 +172,7 @@ async def _register_discovered_tools(self, tools: List[ToolInfo]): except Exception as e: logger.error(f"Failed to register tool {tool_info.name}: {e}") - async def _register_single_tool(self, tool_info: ToolInfo): + async def _register_single_tool(self, tool_info: ToolInfo) -> None: """Register a single discovered tool with FastMCP.""" if tool_info.tool_type == ToolType.CLASS_BASED: await self._register_class_tool(tool_info) @@ -180,7 +181,7 @@ async def _register_single_tool(self, tool_info: ToolInfo): else: logger.error(f"Unknown tool type for {tool_info.name}: {tool_info.tool_type}") - async def _register_class_tool(self, tool_info: ToolInfo): + async def _register_class_tool(self, tool_info: ToolInfo) -> None: """Register a class-based tool with FastMCP.""" if not isinstance(tool_info.tool_instance, BaseTool): raise ValueError(f"Class tool {tool_info.name} instance must be a BaseTool") @@ -198,7 +199,7 @@ async def _register_class_tool(self, tool_info: ToolInfo): self.registered_tools[tool_info.name] = tool_info logger.debug(f"Registered class-based tool: {tool_info.name}") - async def _register_function_tool(self, tool_info: ToolInfo): + async def _register_function_tool(self, tool_info: ToolInfo) -> None: """Register a function-based tool with FastMCP.""" if not callable(tool_info.tool_instance): raise ValueError(f"Function tool {tool_info.name} instance must be callable") @@ -213,7 +214,7 @@ async def _register_function_tool(self, tool_info: ToolInfo): wrapper_func = function else: # Wrap sync function to be async - async def async_wrapper(**kwargs): + async def async_wrapper(**kwargs: Any) -> Any: return function(**kwargs) wrapper_func = async_wrapper @@ -227,7 +228,7 @@ async def async_wrapper(**kwargs): self.registered_tools[tool_info.name] = tool_info logger.debug(f"Registered function-based tool: {tool_info.name}") - async def discover_and_register_tools(self): + async def discover_and_register_tools(self) -> list[ToolInfo]: """Discover and register initial tools.""" logger.info("Discovering initial tools...") tools = self.tool_discovery.discover_tools() @@ -240,7 +241,7 @@ async def discover_and_register_tools(self): return tools - async def start(self): + async def start(self) -> None: """Start the server.""" # Discover and register tools await self.discover_and_register_tools() @@ -271,7 +272,7 @@ async def start(self): server = uvicorn.Server(config) await server.serve() - def run(self): + def run(self) -> None: """Run the server (blocking).""" # Use a more robust approach to handle event loops asyncio.run(self.start()) diff --git a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py index df53d5a16..603d026bf 100644 --- a/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py +++ b/lib/serve/mcp-workbench/src/mcpworkbench/server/middleware.py @@ -16,8 +16,9 @@ import logging import sys +from collections.abc import Callable from datetime import datetime -from typing import Callable +from typing import Any from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware @@ -34,7 +35,7 @@ class CORSMiddleware(StarletteCORSMiddleware): """CORS middleware wrapper for configuration compatibility.""" - def __init__(self, app, cors_config: CORSConfig): + def __init__(self, app: Any, cors_config: CORSConfig) -> None: super().__init__( app, allow_origins=cors_config.allow_origins, @@ -49,7 +50,7 @@ def __init__(self, app, cors_config: CORSConfig): class ExitRouteMiddleware(BaseHTTPMiddleware): """Middleware to handle application exit requests.""" - def __init__(self, app, exit_path: str): + def __init__(self, app: Any, exit_path: str) -> None: super().__init__(app) self.exit_path = exit_path.rstrip("/") @@ -76,7 +77,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Continue with normal request processing return await call_next(request) - async def _delayed_exit(self): + async def _delayed_exit(self) -> None: """Exit the application after a short delay.""" import asyncio # noqa: PLC0415 @@ -88,7 +89,7 @@ async def _delayed_exit(self): class RescanMiddleware(BaseHTTPMiddleware): """Middleware to handle tool rescanning requests.""" - def __init__(self, app, rescan_path: str, tool_discovery: ToolDiscovery, tool_registry: ToolRegistry): + def __init__(self, app: Any, rescan_path: str, tool_discovery: ToolDiscovery, tool_registry: ToolRegistry) -> None: super().__init__(app) self.rescan_path = rescan_path.rstrip("/") self.tool_discovery = tool_discovery diff --git a/lib/serve/mcp-workbench/test_install.py b/lib/serve/mcp-workbench/test_install.py index 089396ef1..77dac6ad6 100644 --- a/lib/serve/mcp-workbench/test_install.py +++ b/lib/serve/mcp-workbench/test_install.py @@ -20,12 +20,13 @@ import subprocess import sys +from typing import Any from mcpworkbench.core.annotations import mcp_tool from mcpworkbench.core.base_tool import BaseTool -def test_cli_available(): +def test_cli_available() -> bool: """Test that the CLI command is available.""" try: @@ -45,15 +46,15 @@ def test_cli_available(): return False -def test_basic_functionality(): +def test_basic_functionality() -> bool: """Test basic functionality works.""" try: class TestTool(BaseTool): - def __init__(self): + def __init__(self) -> None: super().__init__("test", "A test tool") - async def execute(self, **kwargs): + async def execute(self, **kwargs: Any) -> dict[str, str]: # type: ignore[override] return {"result": "test successful"} # Test tool instantiation @@ -64,7 +65,7 @@ async def execute(self, **kwargs): # Test annotation @mcp_tool(name="test_func", description="Test function") - def test_func(): + def test_func() -> str: return "annotated test successful" assert hasattr(test_func, "_is_mcp_tool") @@ -77,7 +78,7 @@ def test_func(): return False -def main(): +def main() -> bool: """Run all installation tests.""" print("Testing MCP Workbench installation...") print("=" * 50) diff --git a/lib/serve/mcpWorkbenchConstruct.ts b/lib/serve/mcpWorkbenchConstruct.ts index 7620d98b5..9afde469d 100644 --- a/lib/serve/mcpWorkbenchConstruct.ts +++ b/lib/serve/mcpWorkbenchConstruct.ts @@ -31,6 +31,7 @@ import * as events from 'aws-cdk-lib/aws-events'; import * as targets from 'aws-cdk-lib/aws-events-targets'; import { ECSCluster, ECSTasks } from '../api-base/ecsCluster'; import { Ec2Service } from 'aws-cdk-lib/aws-ecs'; +import { BucketEncryption } from 'aws-cdk-lib/aws-s3'; export type McpWorkbenchConstructProps = { restApiId: string; @@ -188,7 +189,8 @@ export class McpWorkbenchConstruct extends Construct { enforceSSL: true, serverAccessLogsBucket: bucketAccessLogsBucket, serverAccessLogsPrefix: 'logs/mcpworkbench-bucket/', - eventBridgeEnabled: true + eventBridgeEnabled: true, + encryption: BucketEncryption.S3_MANAGED }); } diff --git a/lib/serve/rest-api/.coveragerc b/lib/serve/rest-api/.coveragerc new file mode 100644 index 000000000..5d1b54764 --- /dev/null +++ b/lib/serve/rest-api/.coveragerc @@ -0,0 +1,18 @@ +[run] +omit = + # Exclude __init__ files + */src/*/__init__.py + */src/*/*/__init__.py + */src/*/*/*/__init__.py + */src/*/*/*/*/__init__.py + # Exclude FastAPI endpoint wrappers (thin wrappers around handlers) + */src/api/endpoints/v1/*.py + */src/api/endpoints/v2/*.py + # Exclude main application file (requires full integration test) + */src/main.py + # Exclude model adapters (require actual model endpoints) + */src/lisa_serve/ecs/textgen/*.py + */src/lisa_serve/ecs/embedding/*.py + */src/lisa_serve/base/*.py + # Exclude routes (thin wrappers) + */src/api/routes.py diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index 39d00791a..1da725049 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,15 +1,33 @@ -ARG BASE_IMAGE=python:3.13-slim +ARG BASE_IMAGE=public.ecr.aws/docker/library/python:3.13-slim FROM ${BASE_IMAGE} ARG PRISMA_CACHE_DIR=PRISMA_CACHE ENV PRISMA_CACHE_DIR=$PRISMA_CACHE_DIR +# Apply SSH security hardening - disable weak ciphers (3DES-CBC, etc.) +RUN mkdir -p /etc/ssh && \ + echo "" >> /etc/ssh/ssh_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/ssh_config && \ + echo "Host *" >> /etc/ssh/ssh_config && \ + echo " Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/ssh_config && \ + echo " MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/ssh_config && \ + echo " KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/ssh_config && \ + if [ -f /etc/ssh/sshd_config ]; then \ + echo "" >> /etc/ssh/sshd_config && \ + echo "# LISA Security Hardening - Disable weak ciphers" >> /etc/ssh/sshd_config && \ + echo "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com" >> /etc/ssh/sshd_config && \ + echo "MACs hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ + echo "KexAlgorithms curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group16-sha512,diffie-hellman-group18-sha512" >> /etc/ssh/sshd_config; \ + fi + # Install build dependencies for madoka package -RUN apt-get update && apt-get install -y \ - gcc \ - g++ \ - make \ - && rm -rf /var/lib/apt/lists/* +RUN if command -v apt-get >/dev/null 2>&1; then \ + apt-get update && apt-get install -y gcc g++ make procps && rm -rf /var/lib/apt/lists/*; \ + elif command -v yum >/dev/null 2>&1; then \ + yum install -y gcc gcc-c++ make procps-ng && yum clean all; \ + elif command -v apk >/dev/null 2>&1; then \ + apk add --no-cache gcc g++ make musl-dev procps; \ + fi # Copy LiteLLM config directly out of the LISA config.yaml file ARG LITELLM_CONFIG diff --git a/lib/serve/rest-api/src/api/endpoints/v1/models.py b/lib/serve/rest-api/src/api/endpoints/v1/models.py index 3bcb353c0..68f7ad1e4 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/models.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/models.py @@ -15,7 +15,6 @@ """Model information routes.""" import logging -from typing import List, Optional from fastapi import APIRouter, Query from fastapi.responses import JSONResponse @@ -54,7 +53,7 @@ async def describe_model( @router.get(f"/{RestApiResource.DESCRIBE_MODELS}") async def describe_models( - model_types: Optional[List[ModelType]] = Query( + model_types: list[ModelType] | None = Query( None, description="The types of models to list. If not provided, all types will be listed.", alias="modelTypes", @@ -71,7 +70,7 @@ async def describe_models( @router.get(f"/{RestApiResource.LIST_MODELS}") async def list_models( - model_types: Optional[List[ModelType]] = Query( + model_types: list[ModelType] | None = Query( None, description="The types of models to list. If not provided, all types will be listed.", alias="modelTypes", diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 16eddbfbb..2c0a8d103 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -14,20 +14,20 @@ """Model invocation routes.""" +import fnmatch import json import logging import os +import uuid from collections.abc import Iterator -from typing import Union import boto3 -import requests +from auth import Authorizer, extract_user_groups_from_jwt from fastapi import APIRouter, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse +from requests import request as requests_request from starlette.status import HTTP_401_UNAUTHORIZED - -from ....auth import Authorizer, extract_user_groups_from_jwt -from ....utils.guardrails import ( +from utils.guardrails import ( create_guardrail_json_response, create_guardrail_streaming_response, extract_guardrail_response, @@ -35,7 +35,7 @@ get_model_guardrails, is_guardrail_violation, ) -from ....utils.metrics import publish_metrics_event +from utils.metrics import publish_metrics_event # Local LiteLLM installation URL. By default, LiteLLM runs on port 4000. Change the port here if the # port was changed as part of the LiteLLM startup in entrypoint.sh @@ -68,6 +68,15 @@ "v1/audio/speech", "audio/transcriptions", "v1/audio/transcriptions", + # Video routes (using wildcards for IDs) + "videos", + "v1/videos", + "videos/*", + "v1/videos/*", + "videos/*/content", + "v1/videos/*/content", + "videos/*/remix", + "v1/videos/*/remix", # Health check routes "health", "health/readiness", @@ -85,12 +94,57 @@ LITELLM_KEY = os.environ["LITELLM_KEY"] secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) +s3_client = boto3.client("s3", region_name=os.environ["AWS_REGION"]) +s3_bucket_name = os.environ.get("GENERATED_IMAGES_S3_BUCKET_NAME", "") logger = logging.getLogger(__name__) router = APIRouter() +def _generate_presigned_video_url(key: str, content_type: str = "video/mp4") -> str: + """Generate a presigned URL for video content stored in S3.""" + url: str = s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": s3_bucket_name, + "Key": key, + "ResponseContentType": content_type, + "ResponseCacheControl": "no-cache", + "ResponseContentDisposition": "inline", + }, + ExpiresIn=3600, # URL expires in 1 hour + ) + return url + + +def is_openai_route(api_path: str) -> bool: + # First check for exact matches (most common case) + if api_path in OPENAI_ROUTES: + return True + + # Only check wildcard patterns if the path contains "video" (since only video routes have wildcards) + # This avoids expensive pattern matching for non-video routes + if "video" not in api_path: + return False + + wildcard_patterns = [pattern for pattern in OPENAI_ROUTES if "*" in pattern] + wildcard_patterns.sort(key=len, reverse=True) + + for route_pattern in wildcard_patterns: + if fnmatch.fnmatch(api_path, route_pattern): + # For patterns like "videos/*" (not "videos/*/something"), ensure we don't match + # paths with additional segments (e.g., "videos/123/content" should not match "videos/*") + if route_pattern.endswith("/*") and not route_pattern.endswith("/*/"): + pattern_segments = route_pattern.count("/") + path_segments = api_path.count("/") + if path_segments != pattern_segments: + continue + return True + + return False + + async def apply_guardrails_to_request(params: dict, model_id: str, jwt_data: dict) -> None: """ Apply guardrails to a chat completion request. @@ -130,7 +184,7 @@ async def apply_guardrails_to_request(params: dict, model_id: str, jwt_data: dic def handle_guardrail_violation_response( - response: requests.Response, model_id: str, params: dict, is_streaming: bool + response: Response, model_id: str, params: dict, is_streaming: bool ) -> Response | None: """ Handle guardrail violation errors in LiteLLM responses. @@ -179,7 +233,7 @@ def handle_guardrail_violation_response( return None -def generate_response(iterator: Iterator[Union[str, bytes]]) -> Iterator[str]: +def generate_response(iterator: Iterator[str | bytes]) -> Iterator[str]: """For streaming responses, generate strings instead of bytes objects so that clients recognize the LLM output.""" for line in iterator: if isinstance(line, bytes): @@ -188,7 +242,7 @@ def generate_response(iterator: Iterator[Union[str, bytes]]) -> Iterator[str]: yield f"{line}\n\n" -def generate_response_with_guardrail_handling(iterator: Iterator[Union[str, bytes]], model: str) -> Iterator[str]: +def generate_response_with_guardrail_handling(iterator: Iterator[str | bytes], model: str) -> Iterator[str]: """ Generate streaming responses with guardrail violation error handling. @@ -227,8 +281,7 @@ def generate_response_with_guardrail_handling(iterator: Iterator[Union[str, byte if guardrail_response: # Stream the guardrail response created = int(chunk_data.get("created", 0)) - for chunk in create_guardrail_streaming_response(guardrail_response, model, created): - yield chunk + yield from create_guardrail_streaming_response(guardrail_response, model, created) return # Stop streaming after guardrail response else: # Could not extract guardrail response, pass through the error @@ -261,10 +314,13 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: headers = dict(request.headers.items()) authorizer = Authorizer() - require_admin = api_path not in OPENAI_ROUTES + require_admin = not is_openai_route(api_path) jwt_data = await authorizer.authenticate_request(request) if not await authorizer.can_access(request, require_admin): - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, message="Not authenticated in litellm_passthrough") + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + message="Not authenticated in litellm_passthrough", + ) # At this point in the request, we have already validated auth with IdP or persistent token. By using LiteLLM for # model management, LiteLLM requires an admin key, and that forces all requests to require a key as well. To avoid @@ -274,19 +330,125 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: http_method = request.method if http_method == "GET" or http_method == "DELETE": - response = requests.request(method=http_method, url=litellm_path, headers=headers) - return JSONResponse(response.json(), status_code=response.status_code) - # not a GET or DELETE request, so expect a JSON payload as part of the request + + response = requests_request(method=http_method, url=litellm_path, headers=headers) + + # Check content type to handle binary responses (e.g., video content) + content_type = response.headers.get("content-type", "").lower() + + # If it's JSON, parse and return as JSON + if "application/json" in content_type or "text/json" in content_type: + try: + return JSONResponse(response.json(), status_code=response.status_code) + except (ValueError, json.JSONDecodeError): + # If JSON parsing fails, fall through to return raw content + pass + + # For video content, store in S3 and return presigned URL + if "video/" in content_type and "/content" in api_path and response.status_code == 200: + try: + # Extract video ID from path (e.g., videos/video_abc123/content -> video_abc123) + path_parts = api_path.split("/") + video_id = path_parts[-2] if len(path_parts) >= 2 else str(uuid.uuid4()) + + # Generate a unique S3 key for the video + file_extension = ".mp4" # Default to mp4 + if "video/webm" in content_type: + file_extension = ".webm" + elif "video/quicktime" in content_type: + file_extension = ".mov" + + s3_key = f"videos/{video_id}{file_extension}" + + # Upload video to S3 + s3_client.put_object( + Bucket=s3_bucket_name, + Key=s3_key, + Body=response.content, + ContentType=content_type, + ) + + # Generate presigned URL + presigned_url = _generate_presigned_video_url(s3_key) + + # Return JSON response with presigned URL + return JSONResponse( + { + "url": presigned_url, + "s3_key": s3_key, + "content_type": content_type, + }, + status_code=200, + ) + except Exception as e: + logger.error(f"Error storing video to S3: {e}") + # Fall through to return raw content if S3 storage fails + + # For other binary content (image, etc.) or non-JSON, return raw response + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + media_type=content_type if content_type else None, + ) + + # Check if request is multipart/form-data (used for video generation with image references) + content_type = request.headers.get("content-type", "").lower() + is_multipart = "multipart/form-data" in content_type + is_video_endpoint = "video" in api_path.lower() + + # Handle multipart/form-data requests (video generation with image references) + if is_multipart and is_video_endpoint: + try: + # Parse the form data + form = await request.form() + + # Build files dict for requests library + files = {} + data = {} + + for field_name, field_value in form.items(): + # Check if it's a file field + if hasattr(field_value, "read"): + # It's a file - read the content and prepare for upload + file_content = await field_value.read() + filename = getattr(field_value, "filename", "file") + content_type = getattr(field_value, "content_type", "application/octet-stream") + files[field_name] = (filename, file_content, content_type) + else: + # It's a regular form field + data[field_name] = field_value + + # Create new headers without Content-Type (requests library will set it with correct boundary) + # Use LITELLM_KEY instead of the user's token (consistent with rest of the code) + forward_headers = {"Authorization": f"Bearer {LITELLM_KEY}"} + + # Forward multipart request to LiteLLM + response = requests_request( + method=http_method, url=litellm_path, data=data, files=files, headers=forward_headers + ) + + if response.status_code != 200: + logger.error(f"LiteLLM error response: {response.text}") + + return JSONResponse(response.json(), status_code=response.status_code) + + except Exception as e: + logger.error(f"Error processing multipart request: {e}") + raise HTTPException(status_code=400, detail=f"Error processing multipart request: {str(e)}") + + # Handle JSON requests (default behavior) params = await request.json() # Apply guardrails for chat/completions requests if api_path in ["chat/completions", "v1/chat/completions"]: model_id = params.get("model") - if model_id: + if model_id and jwt_data: await apply_guardrails_to_request(params, model_id, jwt_data) if params.get("stream", False): # if a streaming request - response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers, stream=True) + + response = requests_request(method=http_method, url=litellm_path, json=params, headers=headers, stream=True) # Check for guardrail violations model_id = params.get("model", "") @@ -307,9 +469,13 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: status_code=response.status_code, ) else: - return StreamingResponse(generate_response(response.iter_lines()), status_code=response.status_code) + return StreamingResponse( + generate_response(response.iter_lines()), + status_code=response.status_code, + ) else: # not a streaming request - response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers) + + response = requests_request(method=http_method, url=litellm_path, json=params, headers=headers) # Check for guardrail violations model_id = params.get("model", "") diff --git a/lib/serve/rest-api/src/api/routes.py b/lib/serve/rest-api/src/api/routes.py index 08e052796..a2ff9289f 100644 --- a/lib/serve/rest-api/src/api/routes.py +++ b/lib/serve/rest-api/src/api/routes.py @@ -17,12 +17,11 @@ import logging import os +from api.endpoints.v2 import litellm_passthrough +from auth import Authorizer from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse -from ..auth import Authorizer -from .endpoints.v2 import litellm_passthrough - logger = logging.getLogger(__name__) router = APIRouter() diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index 6fee3708b..acd7840c4 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -22,7 +22,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import boto3 import jwt @@ -33,8 +33,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from loguru import logger from starlette.status import HTTP_401_UNAUTHORIZED - -from .utils.decorators import singleton +from utils.decorators import singleton TOKEN_EXPIRATION_NAME = "tokenExpiration" # nosec B105 TOKEN_TABLE_NAME = "TOKEN_TABLE_NAME" # nosec B105 @@ -71,12 +70,13 @@ def values(cls) -> list[str]: raise RuntimeError("No crypto support for JWT.") -def get_oidc_metadata(cert_path: Optional[str] = None) -> Dict[str, Any]: +def get_oidc_metadata(cert_path: str | None = None) -> dict[str, Any]: """Get OIDC endpoints and metadata from authority.""" authority = os.environ.get("AUTHORITY") resp = requests.get(f"{authority}/.well-known/openid-configuration", verify=cert_path or True, timeout=30) resp.raise_for_status() - return resp.json() # type: ignore + result: dict[str, Any] = resp.json() + return result def get_jwks_client() -> jwt.PyJWKClient: @@ -96,11 +96,11 @@ def get_jwks_client() -> jwt.PyJWKClient: def id_token_is_valid( id_token: str, client_id: str, authority: str, jwks_client: jwt.PyJWKClient -) -> Optional[Dict[str, Any]]: +) -> dict[str, Any] | None: """Check whether an ID token is valid and return decoded data.""" try: signing_key = jwks_client.get_signing_key_from_jwt(id_token) - data: Dict[str, Any] = jwt.decode( + data: dict[str, Any] = jwt.decode( id_token, signing_key.key, algorithms=["RS256"], @@ -133,7 +133,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node -def extract_user_groups_from_jwt(jwt_data: Optional[Dict[str, Any]]) -> list[str]: +def extract_user_groups_from_jwt(jwt_data: dict[str, Any] | None) -> list[str]: """ Extract user groups from JWT data using the JWT_GROUPS_PROP environment variable. @@ -160,7 +160,7 @@ def extract_user_groups_from_jwt(jwt_data: Optional[Dict[str, Any]]) -> list[str # Traverse the property path to find groups props = jwt_groups_property.split(".") - current_node = jwt_data + current_node: Any = jwt_data for prop in props: if isinstance(current_node, dict) and prop in current_node: @@ -171,13 +171,14 @@ def extract_user_groups_from_jwt(jwt_data: Optional[Dict[str, Any]]) -> list[str # current_node should now be the groups list if isinstance(current_node, list): - return current_node + groups: list[str] = current_node + return groups else: logger.warning(f"Expected list of groups but got {type(current_node)}") return [] -def get_authorization_token(headers: Dict[str, str], header_name: str = AuthHeaders.AUTHORIZATION) -> str: +def get_authorization_token(headers: dict[str, str], header_name: str = AuthHeaders.AUTHORIZATION) -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: return headers.get(header_name, "").removeprefix("Bearer").strip() @@ -187,19 +188,19 @@ def get_authorization_token(headers: Dict[str, str], header_name: str = AuthHead class OIDCHTTPBearer(HTTPBearer): """OIDC based bearer token authenticator.""" - def __init__(self, authority: Optional[str] = None, client_id: Optional[str] = None, **kwargs: Dict[str, Any]): + def __init__(self, authority: str | None = None, client_id: str | None = None, **kwargs: dict[str, Any]): super().__init__(**kwargs) self.authority = authority or os.environ.get("AUTHORITY", "") self.client_id = client_id or os.environ.get("CLIENT_ID", "") self.jwks_client = get_jwks_client() - async def id_token_is_valid(self, request: Request) -> Optional[Dict[str, Any]]: + async def id_token_is_valid(self, request: Request) -> dict[str, Any] | None: """Check whether an ID token is valid and return decoded data.""" http_auth_creds = await super().__call__(request) id_token = http_auth_creds.credentials try: signing_key = self.jwks_client.get_signing_key_from_jwt(id_token) - data: Dict[str, Any] = jwt.decode( + data: dict[str, Any] = jwt.decode( id_token, signing_key.key, algorithms=["RS256"], @@ -237,7 +238,7 @@ def _get_token_info(self, token_hash: str) -> Any: ddb_response = self._token_table.get_item(Key={"token": token_hash}, ReturnConsumedCapacity="NONE") return ddb_response.get("Item", None) - async def is_valid_api_token(self, headers: Dict[str, str]) -> Optional[Dict[str, Any]]: + async def is_valid_api_token(self, headers: dict[str, str]) -> dict[str, Any] | None: """Return token info if API Token from request headers is valid, else None.""" for header_name in AuthHeaders.values(): @@ -268,7 +269,8 @@ async def is_valid_api_token(self, headers: Dict[str, str]) -> Optional[Dict[str continue # Token is valid - return the token info - return token_info + result: dict[str, Any] = dict(token_info) + return result return None @@ -277,11 +279,11 @@ class ManagementTokenAuthorizer: """Class for checking Management tokens against a SecretsManager secret.""" def __init__(self) -> None: - self._cache = TTLCache(maxsize=1, ttl=300) + self._cache: TTLCache = TTLCache(maxsize=1, ttl=300) self._cache_lock = threading.RLock() self._local = threading.local() - def _get_secrets_client(self): + def _get_secrets_client(self) -> Any: """Get thread-local secrets manager client.""" if not hasattr(self._local, "secrets_manager"): self._local.secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) @@ -293,10 +295,11 @@ def get_management_tokens(self) -> list[str]: with self._cache_lock: if cache_key in self._cache: - return self._cache[cache_key] + cached_tokens: list[str] = self._cache[cache_key] + return cached_tokens logger.info("Updating management tokens cache") - secret_tokens = [] + secret_tokens: list[str] = [] secret_id = os.environ.get("MANAGEMENT_KEY_NAME") secrets_manager = self._get_secrets_client() @@ -315,7 +318,7 @@ def get_management_tokens(self) -> list[str]: return secret_tokens - async def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + async def is_valid_api_token(self, headers: dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" secret_tokens = await asyncio.to_thread(self.get_management_tokens) token = get_authorization_token(headers) @@ -337,11 +340,11 @@ def __init__(self) -> None: self.management_token_authorizer = ManagementTokenAuthorizer() self.oidc_authorizer = OIDCHTTPBearer(authority=self.authority, client_id=self.client_id) - async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: + async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: jwt_data = await self.authenticate_request(request) return jwt_data - async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]: + async def authenticate_request(self, request: Request) -> dict[str, Any] | None: """Authenticate request and return JWT data if valid, else None. Invalid requests throw an exception""" logger.trace(f"Authenticating request: {request.method} {request.url.path}") @@ -356,7 +359,7 @@ async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any # Then try management tokens logger.trace("Try Management Auth Token...") - if await self.management_token_authorizer.is_valid_api_token(request.headers): + if await self.management_token_authorizer.is_valid_api_token(dict(request.headers)): logger.trace("Valid Management token") return None @@ -383,9 +386,7 @@ def _log_access_attempt( else: logger.warning(log_msg) - async def can_access( - self, request: Request, require_admin: bool, jwt_data: Optional[Dict[str, Any]] = None - ) -> bool: + async def can_access(self, request: Request, require_admin: bool, jwt_data: dict[str, Any] | None = None) -> bool: """Return whether the user is authorized to access the endpoint.""" endpoint = f"{request.method} {request.url.path}" @@ -439,7 +440,7 @@ async def can_access( self._log_access_attempt(request, auth_method, user_id, endpoint, has_access, reason) return has_access - def _set_token_context(self, request: Request, token_info: Dict[str, Any]) -> None: + def _set_token_context(self, request: Request, token_info: dict[str, Any]) -> None: """Store token info in request state for later access.""" request.state.api_token_info = token_info request.state.username = token_info.get("username", "api-token") diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index 63180044f..b96de0a7c 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -48,7 +48,7 @@ if [ "${DEBUG}" = "true" ]; then GUNICORN_LOG_LEVEL="debug" PRISMA_LOG_LEVEL="info,query" else - LOG_LEVEL="INFO" + LOG_LEVEL="${LITELLM_LOG_LEVEL:-WARNING}" GUNICORN_LOG_LEVEL="info" PRISMA_LOG_LEVEL="warn" fi @@ -56,12 +56,40 @@ fi # Configure LiteLLM logging export LITELLM_LOG=${LOG_LEVEL} export LITELLM_JSON_LOGS=${LITELLM_JSON_LOGS:-false} -export LITELLM_DISABLE_HEALTH_CHECK_LOGS=${LITELLM_DISABLE_HEALTH_CHECK_LOGS:-false} +export LITELLM_DISABLE_HEALTH_CHECK_LOGS=${LITELLM_DISABLE_HEALTH_CHECK_LOGS:-true} # Configure Prisma logging export PRISMA_LOG_LEVEL=${PRISMA_LOG_LEVEL} +# Wait for database to be reachable before starting LiteLLM +# This prevents startup errors from race conditions +if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_PORT" ]; then + echo "🔍 Checking database connectivity..." + echo " - Host: $DATABASE_HOST" + echo " - Port: $DATABASE_PORT" + + MAX_RETRIES=30 + RETRY_INTERVAL=2 + retry_count=0 + + while [ $retry_count -lt $MAX_RETRIES ]; do + if timeout 5 bash -c "echo > /dev/tcp/$DATABASE_HOST/$DATABASE_PORT" 2>/dev/null; then + echo "✅ Database is reachable" + break + fi + retry_count=$((retry_count + 1)) + echo " - Waiting for database... (attempt $retry_count/$MAX_RETRIES)" + sleep $RETRY_INTERVAL + done + + if [ $retry_count -eq $MAX_RETRIES ]; then + echo "⚠️ Database not reachable after $MAX_RETRIES attempts, proceeding anyway..." + fi +fi + # Start LiteLLM in the background with better error handling +# Note: For IAM RDS authentication, LiteLLM handles token refresh natively +# when IAM_TOKEN_DB_AUTH=true is set (configured via CDK environment variables) echo "🚀 Starting LiteLLM server..." echo " - Config file: litellm_config.yaml" echo " - Port: 4000 (internal)" @@ -69,9 +97,19 @@ echo " - Database: Prisma with auto-push enabled" echo " - Debug mode: ${DEBUG:-false}" echo " - Log level: $LOG_LEVEL" echo " - Prisma log level: $PRISMA_LOG_LEVEL" +if [ "$IAM_TOKEN_DB_AUTH" = "true" ]; then + echo " - IAM Auth: enabled (tokens auto-refresh)" + echo " - Database User: $DATABASE_USER" +fi # Start LiteLLM and capture its PID -litellm -c litellm_config.yaml --use_prisma_db_push > litellm.log 2>&1 & +# Note: Transient DB connection errors may appear during IAM token refresh cycles +# These are expected with LiteLLM < 1.81 and the service recovers automatically +# Set LITELLM_LOG_LEVEL=INFO to see all logs, or DEBUG for verbose output +# Use --num_workers to increase parallelism for embedding requests +LITELLM_WORKERS=${LITELLM_WORKERS:-4} +echo " - LiteLLM workers: $LITELLM_WORKERS" +litellm -c litellm_config.yaml --use_prisma_db_push --num_workers "$LITELLM_WORKERS" > litellm.log 2>&1 & LITELLM_PID=$! echo " - LiteLLM PID: $LITELLM_PID" @@ -96,6 +134,9 @@ echo " - Workers: $THREADS" echo " - Timeout: 600 seconds" echo " - Log level: $GUNICORN_LOG_LEVEL" +# Set PYTHONPATH to include src directory so imports work correctly +export PYTHONPATH="/app/src:${PYTHONPATH:-}" + exec gunicorn -k uvicorn.workers.UvicornWorker -t 600 -w "$THREADS" -b "$HOST:$PORT" \ --log-level "$GUNICORN_LOG_LEVEL" \ "src.main:app" diff --git a/lib/serve/rest-api/src/handlers/embeddings.py b/lib/serve/rest-api/src/handlers/embeddings.py index 73f9adb36..f6fcdde43 100644 --- a/lib/serve/rest-api/src/handlers/embeddings.py +++ b/lib/serve/rest-api/src/handlers/embeddings.py @@ -14,17 +14,32 @@ """Embedding route handlers.""" import logging -from typing import Any, Dict +from typing import Any -from ..utils.request_utils import validate_and_prepare_llm_request -from ..utils.resources import RestApiResource +from utils.request_utils import RegistryProtocol, validate_and_prepare_llm_request +from utils.resources import RestApiResource logger = logging.getLogger(__name__) -async def handle_embeddings(request_data: Dict[str, Any]) -> Dict[str, Any]: - """Handle for embeddings endpoint.""" - model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.EMBEDDINGS) +async def handle_embeddings(request_data: dict[str, Any], registry: RegistryProtocol | None = None) -> dict[str, Any]: + """Handle for embeddings endpoint. + + Parameters + ---------- + request_data : dict[str, Any] + Request data + registry : RegistryProtocol | None + Optional registry for dependency injection (testing) + + Returns + ------- + dict[str, Any] + Embeddings response + """ + model, model_kwargs, text = await validate_and_prepare_llm_request( + request_data, RestApiResource.EMBEDDINGS, registry + ) response = await model.embed_query(text=text, model_kwargs=model_kwargs) return response.dict() # type: ignore diff --git a/lib/serve/rest-api/src/handlers/generation.py b/lib/serve/rest-api/src/handlers/generation.py index 313b3781f..a796c98b0 100644 --- a/lib/serve/rest-api/src/handlers/generation.py +++ b/lib/serve/rest-api/src/handlers/generation.py @@ -12,20 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generation route handlers.""" +"""Generation route handlers - refactored for testability.""" import json import logging -from typing import Any, AsyncGenerator, Dict, List, Tuple +from collections.abc import AsyncGenerator +from typing import Any -from ..utils.request_utils import handle_stream_exceptions, validate_and_prepare_llm_request -from ..utils.resources import RestApiResource +from services.text_processing import ( + map_openai_params_to_lisa, + parse_model_provider_from_string, + render_context_from_messages, +) +from utils.request_utils import ( + handle_stream_exceptions, + RegistryProtocol, + validate_and_prepare_llm_request, +) +from utils.resources import RestApiResource logger = logging.getLogger(__name__) -async def handle_generate(request_data: Dict[str, Any]) -> Dict[str, Any]: - """Handle for generate endpoint.""" - model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE) +async def handle_generate(request_data: dict[str, Any], registry: RegistryProtocol | None = None) -> dict[str, Any]: + """Handle for generate endpoint. + + Parameters + ---------- + request_data : dict[str, Any] + Request data + registry : RegistryProtocol | None + Optional registry for dependency injection (testing) + + Returns + ------- + dict[str, Any] + Generation response + """ + model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE, registry) try: response = await model.generate(text=text, model_kwargs=model_kwargs) return response.dict() # type: ignore @@ -35,55 +58,63 @@ async def handle_generate(request_data: Dict[str, Any]) -> Dict[str, Any]: @handle_stream_exceptions -async def handle_generate_stream(request_data: Dict[str, Any]) -> AsyncGenerator[str, None]: - """Handle for generate_stream endpoint.""" - model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE_STREAM) - async for response in model.generate_stream(text=text, model_kwargs=model_kwargs): - yield f"data:{json.dumps(response.dict(exclude_none=True))}\n\n" - +async def handle_generate_stream( + request_data: dict[str, Any], registry: RegistryProtocol | None = None +) -> AsyncGenerator[str]: + """Handle for generate_stream endpoint. -def render_context(messages_list: List[Dict[str, str]]) -> str: - """Provide context string for LLM from previous messages.""" - out_str = "\n\n".join([message["content"] for message in messages_list]) - return out_str + Parameters + ---------- + request_data : dict[str, Any] + Request data + registry : RegistryProtocol | None + Optional registry for dependency injection (testing) - -def parse_model_provider_names(model_string: str) -> Tuple[str, str]: - """Parse out the model name and its provider name from the combined name of the two. - - Format is assumed to be `${model_name} (${provider_name})` and neither of the model_name or provider_name have - a space in them. Requests using the OpenAI text generation APIs will require that model names follow this format. + Yields + ------ + str + Streaming response chunks """ - model_parts = model_string.split() - model_name = model_parts[0].strip() - provider = model_parts[1].replace("(", "").replace(")", "").strip() - return model_name, provider + model, model_kwargs, text = await validate_and_prepare_llm_request( + request_data, RestApiResource.GENERATE_STREAM, registry + ) + async for response in model.generate_stream(text=text, model_kwargs=model_kwargs): + yield f"data:{json.dumps(response.dict(exclude_none=True))}\n\n" @handle_stream_exceptions async def handle_openai_generate_stream( - request_data: Dict[str, Any], is_text_completion: bool = False -) -> AsyncGenerator[str, None]: - """Handle for openai_generate_stream endpoint.""" - # map OpenAI API settings (keys) with corresponding TGI model settings (values). Any unsupported options ignored. - request_mapping = { - "echo": "return_full_text", - "frequency_penalty": "repetition_penalty", - "max_tokens": "max_new_tokens", - "seed": "seed", - "stop": "stop_sequences", - "temperature": "temperature", - "top_p": "top_p", - } - mapped_kwargs = { - request_mapping[k]: request_data[k] for k in request_mapping if k in request_data and request_data[k] - } + request_data: dict[str, Any], is_text_completion: bool = False, registry: RegistryProtocol | None = None +) -> AsyncGenerator[str]: + """Handle for openai_generate_stream endpoint. + + Parameters + ---------- + request_data : dict[str, Any] + Request data + is_text_completion : bool + Whether this is a text completion request + registry : RegistryProtocol | None + Optional registry for dependency injection (testing) + Yields + ------ + str + Streaming response chunks + """ + # Map OpenAI parameters to LISA parameters + mapped_kwargs = map_openai_params_to_lisa(request_data) + + # Extract text based on completion type if is_text_completion: text = request_data["prompt"] # text is already a string else: - text = render_context(request_data["messages"]) # text must be converted from a list to a string - model_name, provider = parse_model_provider_names(request_data["model"]) + text = render_context_from_messages(request_data["messages"]) # convert list to string + + # Parse model and provider + model_name, provider = parse_model_provider_from_string(request_data["model"]) + + # Build LISA request lisa_request_data = { "modelName": model_name, "provider": provider, @@ -91,12 +122,20 @@ async def handle_openai_generate_stream( "streaming": request_data.get("stream", False), "modelKwargs": mapped_kwargs, } + model, model_kwargs, text = await validate_and_prepare_llm_request( - lisa_request_data, RestApiResource.GENERATE_STREAM + lisa_request_data, RestApiResource.GENERATE_STREAM, registry ) + async for response in model.openai_generate_stream( text=text, model_kwargs=model_kwargs, is_text_completion=is_text_completion ): yield f"data:{json.dumps(response.dict(exclude_none=True))}\n\n" + if is_text_completion: yield "data: [DONE]\n\n" + + +# Keep backward compatibility - these are now just aliases to the service functions +render_context = render_context_from_messages +parse_model_provider_names = parse_model_provider_from_string diff --git a/lib/serve/rest-api/src/handlers/models.py b/lib/serve/rest-api/src/handlers/models.py index ffdd95ed4..cc159e17e 100644 --- a/lib/serve/rest-api/src/handlers/models.py +++ b/lib/serve/rest-api/src/handlers/models.py @@ -12,114 +12,116 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Model route handlers.""" +"""Model route handlers - refactored for testability.""" import logging -import time -from collections import defaultdict -from typing import Any, DefaultDict, Dict, List +from typing import Any, DefaultDict from fastapi import HTTPException - -from ..utils.cache_manager import get_registered_models_cache -from ..utils.resources import ModelType +from services.model_service import ModelService +from utils.cache_manager import get_registered_models_cache +from utils.resources import ModelType logger = logging.getLogger(__name__) -async def handle_list_models(model_types: List[ModelType]) -> Dict[ModelType, Dict[str, List[str]]]: +def _get_model_service() -> ModelService: + """Factory function to create ModelService with current cache. + + This allows for dependency injection in tests. + """ + return ModelService(get_registered_models_cache()) + + +async def handle_list_models( + model_types: list[ModelType], model_service: ModelService | None = None +) -> dict[ModelType, dict[str, list[str]]]: """Handle for list_models endpoint. Parameters ---------- model_types : List[ModelType] - Model types to list. - - registered_models_cache : Dict[str, Dict[str, Any]] - Registered models cache. + Model types to list + model_service : ModelService | None + Optional model service for dependency injection (testing) Returns ------- Dict[ModelType, Dict[str, List[str]]] - List of model names by model type and model provider. + List of model names by model type and model provider """ - registered_models_cache = get_registered_models_cache() - response = {model_type: registered_models_cache[model_type] for model_type in model_types} + service = model_service or _get_model_service() + return service.list_models(model_types) - return response - -async def handle_openai_list_models() -> Dict[str, Any]: +async def handle_openai_list_models(model_service: ModelService | None = None) -> dict[str, Any]: """Handle for list_models endpoint. + Parameters + ---------- + model_service : ModelService | None + Optional model service for dependency injection (testing) + Returns ------- - Dict[str, Union[str, Any] - OpenAI-compatible response object to list Models. This only returns Text Generation models. + Dict[str, Any] + OpenAI-compatible response object to list Models """ - registered_models_cache = get_registered_models_cache() - - model_payload: List[Dict[str, Any]] = [] - for provider, models in registered_models_cache[ModelType.TEXTGEN].items(): - model_payload.extend( - {"id": f"{model} ({provider})", "object": "model", "created": int(time.time()), "owned_by": "LISA"} - for model in models - ) + service = model_service or _get_model_service() + return service.list_models_openai_format() - response = {"data": model_payload, "object": "list"} - return response - -async def handle_describe_model(provider: str, model_name: str) -> Dict[str, Any]: +async def handle_describe_model( + provider: str, model_name: str, model_service: ModelService | None = None +) -> dict[str, Any]: """Handle for describe_model endpoint. Parameters ---------- provider : str - Model provider name. - + Model provider name model_name : str - Model name. + Model name + model_service : ModelService | None + Optional model service for dependency injection (testing) Returns ------- Dict[str, Any] - Model metadata. + Model metadata + + Raises + ------ + HTTPException + If model metadata not found """ - model_key = f"{provider}.{model_name}" - registered_models_cache = get_registered_models_cache() - metadata = registered_models_cache["metadata"].get(model_key) + service = model_service or _get_model_service() + metadata = service.get_model_metadata(provider, model_name) + if not metadata: error_message = f"Metadata for provider {provider} and model {model_name} not found." logger.error(error_message, extra={"event": "handle_describe_model", "status": "ERROR"}) - raise HTTPException(status_code=404, message=error_message) + raise HTTPException(status_code=404, detail=error_message) - return metadata # type: ignore + return metadata -async def handle_describe_models(model_types: List[ModelType]) -> DefaultDict[str, DefaultDict[str, Dict[str, Any]]]: +async def handle_describe_models( + model_types: list[ModelType], model_service: ModelService | None = None +) -> DefaultDict[str, DefaultDict[str, dict[str, Any]]]: """Handle for describe_models endpoint. Parameters ---------- model_types : List[ModelType] - Model types to list. + Model types to list + model_service : ModelService | None + Optional model service for dependency injection (testing) Returns ------- DefaultDict[str, DefaultDict[str, Dict[str, Any]]] - Model metadata by model type, model provider, and model name. + Model metadata by model type, model provider, and model name """ - registered_models = await handle_list_models(model_types) - registered_models_cache = get_registered_models_cache() - response: DefaultDict[str, DefaultDict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) - - for model_type, providers in registered_models.items(): - response[model_type] = {} # type: ignore - providers = providers or {} - for provider, model_names in providers.items(): - response[model_type][provider] = [ - registered_models_cache["metadata"][f"{provider}.{model_name}"] for model_name in model_names - ] # type: ignore - - return response + service = model_service or _get_model_service() + return service.describe_models(model_types) diff --git a/lib/serve/rest-api/src/lisa_serve/base/base.py b/lib/serve/rest-api/src/lisa_serve/base/base.py index 1e6d1fdfb..964b8de99 100644 --- a/lib/serve/rest-api/src/lisa_serve/base/base.py +++ b/lib/serve/rest-api/src/lisa_serve/base/base.py @@ -15,7 +15,8 @@ """Base model adapters and responses.""" import re from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Dict, List, Optional +from collections.abc import AsyncGenerator +from typing import Any from pydantic import BaseModel, Field @@ -27,30 +28,30 @@ class EmbedQueryResponse(BaseModel): """Response for embed_query method.""" - embeddings: List[List[float]] = Field(..., description="Batch of text embeddings.") + embeddings: list[list[float]] = Field(..., description="Batch of text embeddings.") class GenerateResponse(BaseModel): """Response for generate method.""" generatedText: str = Field(..., description="Generated text.") - generatedTokens: Optional[int] = Field(..., description="Number of generated tokens.") - finishReason: Optional[str] = Field(None, description="Reason for finishing text generation.") + generatedTokens: int | None = Field(..., description="Number of generated tokens.") + finishReason: str | None = Field(None, description="Reason for finishing text generation.") class Token(BaseModel): """Token for generate_stream method.""" text: str = Field(..., description="Token text.") - special: Optional[bool] = Field(None, description="Whether token is a special token.") + special: bool | None = Field(None, description="Whether token is a special token.") class GenerateStreamResponse(BaseModel): """Response for generate_stream method.""" token: Token - generatedTokens: Optional[int] = Field(..., description="Number of generated tokens.") - finishReason: Optional[str] = Field(None, description="Reason for finishing text generation.") + generatedTokens: int | None = Field(..., description="Number of generated tokens.") + finishReason: str | None = Field(None, description="Reason for finishing text generation.") class OpenAIChatCompletionsDelta(BaseModel): @@ -66,7 +67,7 @@ class OpenAIChatCompletionsChoice(BaseModel): delta: OpenAIChatCompletionsDelta = Field( ..., description="A chat completion delta generated by streamed model responses." ) - finish_reason: Optional[str] = Field(..., description="The reason the model stopped generating tokens.") + finish_reason: str | None = Field(..., description="The reason the model stopped generating tokens.") index: int = Field(..., description="The index of the choice in the list of choices.") @@ -97,7 +98,7 @@ class OpenAICompletionsChoice(BaseModel): """Text choice object from Completions endpoint.""" text: str = Field(..., description="A chat completion delta generated by streamed model responses.") - finish_reason: Optional[str] = Field(..., description="The reason the model stopped generating tokens.") + finish_reason: str | None = Field(..., description="The reason the model stopped generating tokens.") index: int = Field(..., description="The index of the choice in the list of choices.") @@ -141,12 +142,12 @@ class EmbeddingModelAdapter(ABC): Endpoint URL. """ - def __init__(self, *, model_name: str, endpoint_url: Optional[str] = None) -> None: + def __init__(self, *, model_name: str, endpoint_url: str | None = None) -> None: self.model_name = model_name self.endpoint_url = endpoint_url @abstractmethod - def embed_query(self, *, text: str, model_kwargs: Dict[str, Any]) -> EmbedQueryResponse: + def embed_query(self, *, text: str, model_kwargs: dict[str, Any]) -> EmbedQueryResponse: """Embed query. Parameters @@ -177,12 +178,12 @@ class TextGenModelAdapter(ABC): Endpoint URL. """ - def __init__(self, *, model_name: str, endpoint_url: Optional[str] = None) -> None: + def __init__(self, *, model_name: str, endpoint_url: str | None = None) -> None: self.model_name = model_name self.endpoint_url = endpoint_url @abstractmethod - def generate(self, *, text: str, model_kwargs: Dict[str, Any]) -> GenerateResponse: + def generate(self, *, text: str, model_kwargs: dict[str, Any]) -> GenerateResponse: """Text generation. Parameters @@ -209,8 +210,8 @@ def generate_stream( self, *, text: str, - model_kwargs: Dict[str, Any], - ) -> AsyncGenerator[GenerateStreamResponse, None]: + model_kwargs: dict[str, Any], + ) -> AsyncGenerator[GenerateStreamResponse]: """Text generation with token streaming. Parameters diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py index b4b110808..2bd532b19 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py @@ -13,7 +13,7 @@ # limitations under the License. """Model adapter and kwargs validator for ECS embedding instructor model endpoints.""" -from typing import Any, Dict +from typing import Any from aiohttp import ClientSession from loguru import logger @@ -53,7 +53,7 @@ def __init__(self, *, model_name: str, endpoint_url: str) -> None: # PyTorch DLC has the endpoint at path /predictions/model self.endpoint_url = f"{self.endpoint_url.rstrip('/')}/predictions/model" # type: ignore - async def embed_query(self, *, text: str, model_kwargs: Dict[str, Any]) -> EmbedQueryResponse: # type: ignore + async def embed_query(self, *, text: str, model_kwargs: dict[str, Any]) -> EmbedQueryResponse: # type: ignore """Embed data. Parameters diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py index 6ca2c6858..df2e2f21f 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/tei.py @@ -13,7 +13,7 @@ # limitations under the License. """Model adapter and kwargs validator for ECS embedding instructor model endpoints.""" -from typing import Any, Dict, Union +from typing import Any from aiohttp import ClientSession from loguru import logger @@ -55,7 +55,7 @@ def __init__(self, *, model_name: str, endpoint_url: str) -> None: self.endpoint_url = endpoint_url.rstrip("/") - async def embed_query(self, *, text: Union[str, list[str]], model_kwargs: Dict[str, Any]) -> EmbedQueryResponse: # type: ignore # noqa: E501 + async def embed_query(self, *, text: str | list[str], model_kwargs: dict[str, Any]) -> EmbedQueryResponse: # type: ignore # noqa: E501 """Embed data. Parameters diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py index bcd92224e..29a4dcde6 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py @@ -15,7 +15,8 @@ """Model adapter and kwargs validator for ECS text generation TGI model endpoints.""" import time import uuid -from typing import Any, AsyncGenerator, Dict, List, Optional +from collections.abc import AsyncGenerator +from typing import Any from loguru import logger from pydantic import BaseModel, confloat, Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt @@ -82,15 +83,15 @@ class EcsTextGenTgiValidator(BaseModel): """ max_new_tokens: NonNegativeInt = 50 - top_k: Optional[NonNegativeInt] = None - top_p: Optional[confloat(gt=0.0, lt=1.0)] = None # type: ignore - typical_p: Optional[confloat(gt=0.0, lt=1.0)] = None # type: ignore - temperature: Optional[NonNegativeFloat] = None - repetition_penalty: Optional[PositiveFloat] = None + top_k: NonNegativeInt | None = None + top_p: confloat(gt=0.0, lt=1.0) | None = None # type: ignore + typical_p: confloat(gt=0.0, lt=1.0) | None = None # type: ignore + temperature: NonNegativeFloat | None = None + repetition_penalty: PositiveFloat | None = None return_full_text: bool = False - truncate: Optional[PositiveInt] = None - stop_sequences: List[str] = Field(default_factory=list) - seed: Optional[PositiveInt] = None + truncate: PositiveInt | None = None + stop_sequences: list[str] = Field(default_factory=list) + seed: PositiveInt | None = None do_sample: bool = False watermark: bool = False @@ -113,7 +114,7 @@ def __init__(self, *, model_name: str, endpoint_url: str) -> None: # Define client self.client = AsyncClient(endpoint_url, timeout=60) - async def generate(self, *, text: str, model_kwargs: Dict[str, Any]) -> GenerateResponse: # type: ignore + async def generate(self, *, text: str, model_kwargs: dict[str, Any]) -> GenerateResponse: # type: ignore """Text generation. Parameters @@ -143,8 +144,8 @@ async def generate(self, *, text: str, model_kwargs: Dict[str, Any]) -> Generate return response async def generate_stream( - self, *, text: str, model_kwargs: Dict[str, Any] - ) -> AsyncGenerator[GenerateStreamResponse, None]: + self, *, text: str, model_kwargs: dict[str, Any] + ) -> AsyncGenerator[GenerateStreamResponse]: """Text generation with token streaming. Parameters @@ -174,8 +175,8 @@ async def generate_stream( yield response async def openai_generate_stream( - self, *, text: str, model_kwargs: Dict[str, Any], is_text_completion: bool - ) -> AsyncGenerator[GenerateStreamResponse, None]: + self, *, text: str, model_kwargs: dict[str, Any], is_text_completion: bool + ) -> AsyncGenerator[GenerateStreamResponse]: """Text generation with token streaming, conforming to the OpenAI API specification. Parameters diff --git a/lib/serve/rest-api/src/lisa_serve/registry/index.py b/lib/serve/rest-api/src/lisa_serve/registry/index.py index 77e23af74..eb2dfc992 100644 --- a/lib/serve/rest-api/src/lisa_serve/registry/index.py +++ b/lib/serve/rest-api/src/lisa_serve/registry/index.py @@ -13,14 +13,14 @@ # limitations under the License. """Model registry.""" -from typing import Any, Dict +from typing import Any class ModelRegistry: """Registry for model providers.""" def __init__(self) -> None: - self.registry: Dict[str, Any] = {} + self.registry: dict[str, Any] = {} def register(self, *, provider: str, adapter: Any, validator: Any) -> None: """Register the adapter and validator for the model provider. @@ -38,7 +38,7 @@ def register(self, *, provider: str, adapter: Any, validator: Any) -> None: """ self.registry[provider] = {"adapter": adapter, "validator": validator} - def get_assets(self, provider: str) -> Dict[str, Any]: + def get_assets(self, provider: str) -> dict[str, Any]: """Get model registry entry.""" try: model_assets = self.registry[provider] diff --git a/lib/serve/rest-api/src/main.py b/lib/serve/rest-api/src/main.py index a0816e1bc..68c4a631a 100644 --- a/lib/serve/rest-api/src/main.py +++ b/lib/serve/rest-api/src/main.py @@ -16,21 +16,17 @@ import json import os import sys -import time from contextlib import asynccontextmanager -from typing import Any, Dict -from uuid import uuid4 -from aiobotocore.session import get_session -from fastapi import FastAPI, Request +import boto3 +from api.routes import router +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response +from lisa_serve.registry import registry from loguru import logger - -from .api.routes import router -from .lisa_serve.registry import registry -from .utils.cache_manager import set_registered_models_cache -from .utils.resources import ModelType, RestApiResource +from middleware import process_request_middleware +from services.model_registration import ModelRegistrationService +from utils.cache_manager import set_registered_models_cache logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") @@ -62,58 +58,20 @@ async def lifespan(app: FastAPI): # type: ignore task_logger = logger.bind(event=event) task_logger.debug("Start task", status="START") - new_models: Dict[str, Dict[str, Any]] = { - ModelType.EMBEDDING: {}, - ModelType.TEXTGEN: {}, - RestApiResource.EMBEDDINGS: {}, - RestApiResource.GENERATE: {}, - RestApiResource.GENERATE_STREAM: {}, - "metadata": {}, - "endpointUrls": {}, - } + # Create model registration service + registration_service = ModelRegistrationService(registry) + try: verify_path = os.getenv("SSL_CERT_FILE") or None - session = get_session() - async with session.create_client("ssm", region_name=os.environ["AWS_REGION"], verify=verify_path) as client: - response = await client.get_parameter(Name=os.environ["REGISTERED_MODELS_PS_NAME"]) + # Use synchronous boto3 client - this runs once at startup so async isn't needed + # This avoids aiobotocore dependency which has version conflicts with litellm's boto3 + ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], verify=verify_path) + response = ssm_client.get_parameter(Name=os.environ["REGISTERED_MODELS_PS_NAME"]) + registered_models = json.loads(response["Parameter"]["Value"]) - for model in registered_models: - provider = model["provider"] - # provider format is `modelHosting.modelType.inferenceContainer`, example: "ecs.textgen.tgi" - [_, _, inference_container] = provider.split(".") - model_name = model["modelName"] - model_type = model["modelType"] - - if inference_container not in ["tgi", "tei", "instructor"]: # stopgap for supporting new containers for v2 - continue # not implementing new providers inside the existing cache; cache is on deprecation path - - # Get default model kwargs - validator = registry.get_assets(provider)["validator"] - model_kwargs = validator().dict() - - # Get model endpoint URL - model_key = f"{provider}.{model_name}" - new_models["endpointUrls"][model_key] = model["endpointUrl"] - - # Get other model metadata to expose to endpoints - new_models["metadata"][model_key] = { - "provider": provider, - "modelName": model_name, - "modelType": model_type, - "modelKwargs": model_kwargs, - } - if "streaming" in model: - new_models["metadata"][model_key]["streaming"] = model["streaming"] - - # Make list of registered accessible either by ModelType and by RestApiResource - if model_type == ModelType.EMBEDDING: - new_models[RestApiResource.EMBEDDINGS].setdefault(provider, []).append(model_name) - new_models[ModelType.EMBEDDING].setdefault(provider, []).append(model_name) - elif model_type == ModelType.TEXTGEN: - new_models[RestApiResource.GENERATE].setdefault(provider, []).append(model_name) - new_models[ModelType.TEXTGEN].setdefault(provider, []).append(model_name) - if model["streaming"]: - new_models[RestApiResource.GENERATE_STREAM].setdefault(provider, []).append(model_name) + + # Register all models using the service + new_models = registration_service.register_models(registered_models) # Update the global cache set_registered_models_cache(new_models) @@ -144,38 +102,6 @@ async def lifespan(app: FastAPI): # type: ignore @app.middleware("http") -async def process_request(request: Request, call_next: Any) -> Any: +async def process_request(request, call_next): # type: ignore """Middleware for processing all HTTP requests.""" - event = "process_request" - request_id = str(uuid4()) # Unique ID for this request - tic = time.time() - - with logger.contextualize(request_id=request_id, endpoint=request.url.path): - try: - task_logger = logger.bind(event=event) - task_logger.debug("Start task", status="START") - - # Attempt to call the next request handler - response = await call_next(request) - - # If response is successful, log the finish status - duration = time.time() - tic - task_logger.debug(f"Finish task (took {duration:.2f} seconds)", status="FINISH") - - except Exception as e: - # In case of an exception, log the error and prepare a generic response - duration = time.time() - tic - task_logger.exception( - f"Error occurred during processing: {e} (took {duration:.2f} seconds)", - status="ERROR", - ) - response = JSONResponse( - status_code=500, - content={"detail": "Internal server error"}, - ) - - # Add the unique request ID to the response headers - if response is not None and isinstance(response, Response): - response.headers["X-Request-ID"] = request_id - - return response + return await process_request_middleware(request, call_next) diff --git a/lib/serve/rest-api/src/middleware/__init__.py b/lib/serve/rest-api/src/middleware/__init__.py new file mode 100644 index 000000000..05a3c56d3 --- /dev/null +++ b/lib/serve/rest-api/src/middleware/__init__.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Middleware modules.""" +from .request_middleware import process_request_middleware + +__all__ = ["process_request_middleware"] diff --git a/lib/serve/rest-api/src/middleware/request_middleware.py b/lib/serve/rest-api/src/middleware/request_middleware.py new file mode 100644 index 000000000..dddf03bb7 --- /dev/null +++ b/lib/serve/rest-api/src/middleware/request_middleware.py @@ -0,0 +1,73 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Request processing middleware.""" +import time +from collections.abc import Callable +from typing import Any +from uuid import uuid4 + +from fastapi import Request, Response +from fastapi.responses import JSONResponse +from loguru import logger + + +async def process_request_middleware(request: Request, call_next: Callable[[Request], Any]) -> Any: + """Middleware for processing all HTTP requests. + + Parameters + ---------- + request : Request + The incoming request + call_next : Callable + The next middleware or route handler + + Returns + ------- + Response + The response with added request ID header + """ + event = "process_request" + request_id = str(uuid4()) # Unique ID for this request + tic = time.time() + + with logger.contextualize(request_id=request_id, endpoint=request.url.path): + try: + task_logger = logger.bind(event=event) + task_logger.debug("Start task", status="START") + + # Attempt to call the next request handler + response = await call_next(request) + + # If response is successful, log the finish status + duration = time.time() - tic + task_logger.debug(f"Finish task (took {duration:.2f} seconds)", status="FINISH") + + except Exception as e: + # In case of an exception, log the error and prepare a generic response + duration = time.time() - tic + task_logger.exception( + f"Error occurred during processing: {e} (took {duration:.2f} seconds)", + status="ERROR", + ) + response = JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) + + # Add the unique request ID to the response headers + if response is not None and isinstance(response, Response): + response.headers["X-Request-ID"] = request_id + + return response diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index 291a4c3ca..29b54bfc8 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -1,29 +1,32 @@ -# AWS SDK - Version constrained by litellm[proxy]==1.80.9 -# litellm requires boto3==1.36.0, which requires aioboto3==13.4.0 and aiobotocore==2.18.0 -aioboto3==13.4.0 -aiobotocore==2.18.0 -boto3==1.36.0 +# AWS SDK Dependencies +# boto3 version pinned by litellm[proxy]==1.81.3 for RDS IAM token refresh +boto3==1.40.76 + +# OpenTelemetry - Optional for LiteLLM Weave integration, silences import warnings +opentelemetry-api>=1.20.0 +opentelemetry-sdk>=1.20.0 aiohttp==3.13.2 backoff==2.2.1 cachetools==6.2.2 click==8.3.1 cryptography==46.0.3 -fastapi==0.124.2 +fastapi>=0.120.1 fastapi_utils==0.8.0 -gunicorn==23.0.0 +gunicorn>=23.0.0,<24.0.0 -# LiteLLM - Constrains boto3==1.36.0 and uvicorn<0.32.0 -litellm[proxy]==1.80.9 +# LiteLLM - Upgraded to 1.81.3 for RDS IAM token refresh fix (PR #18795) +# Fixes: "All connection attempts failed" errors every 15 minutes with IAM auth +litellm[proxy]==1.81.3 loguru==0.7.3 -pydantic==2.12.5 -PyJWT==2.10.1 +pydantic>=2.5.0,<3.0.0 +PyJWT>=2.10.1,<3.0.0 text-generation==0.7.0 prisma==0.15.0 -pynacl==1.6.1 +pynacl>=1.5.0,<2.0.0 starlette>=0.40.0,<0.51.0 -# ASGI Server - Version constrained by litellm[proxy]==1.80.9 -# litellm requires uvicorn>=0.31.1,<0.39.0 -uvicorn>=0.31.1,<0.39.0 +# ASGI Server - Version constrained by litellm[proxy]==1.81.3 +# litellm requires uvicorn>=0.31.1,<0.32.0 +uvicorn>=0.31.1,<0.32.0 diff --git a/lib/serve/rest-api/src/services/__init__.py b/lib/serve/rest-api/src/services/__init__.py new file mode 100644 index 000000000..382e3705c --- /dev/null +++ b/lib/serve/rest-api/src/services/__init__.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service layer for REST API business logic.""" diff --git a/lib/serve/rest-api/src/services/model_registration.py b/lib/serve/rest-api/src/services/model_registration.py new file mode 100644 index 000000000..3a66650c3 --- /dev/null +++ b/lib/serve/rest-api/src/services/model_registration.py @@ -0,0 +1,157 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model registration service.""" +from typing import Any, Protocol + +from utils.resources import ModelType, RestApiResource + + +class RegistryProtocol(Protocol): + """Protocol for model registry.""" + + def get_assets(self, provider: str) -> dict[str, Any]: + """Get model assets for a provider.""" + ... + + +class ModelRegistrationService: + """Service for registering models from configuration.""" + + # Supported inference containers + SUPPORTED_CONTAINERS = ["tgi", "tei", "instructor"] + + def __init__(self, registry: RegistryProtocol): + """Initialize the service. + + Parameters + ---------- + registry : RegistryProtocol + The model registry to use for getting validators + """ + self.registry = registry + + def create_empty_cache(self) -> dict[str, dict[str, Any]]: + """Create an empty model cache structure. + + Returns + ------- + dict[str, dict[str, Any]] + Empty cache with all required keys + """ + return { + ModelType.EMBEDDING: {}, + ModelType.TEXTGEN: {}, + RestApiResource.EMBEDDINGS: {}, + RestApiResource.GENERATE: {}, + RestApiResource.GENERATE_STREAM: {}, + "metadata": {}, + "endpointUrls": {}, + } + + def is_supported_container(self, inference_container: str) -> bool: + """Check if inference container is supported. + + Parameters + ---------- + inference_container : str + The inference container name + + Returns + ------- + bool + True if supported, False otherwise + """ + return inference_container in self.SUPPORTED_CONTAINERS + + def register_model(self, model: dict[str, Any], cache: dict[str, dict[str, Any]]) -> None: + """Register a single model into the cache. + + Parameters + ---------- + model : dict[str, Any] + Model configuration with keys: provider, modelName, modelType, endpointUrl, streaming + cache : dict[str, dict[str, Any]] + The cache to update + """ + provider = model["provider"] + model_name = model["modelName"] + model_type = model["modelType"] + + # provider format is `modelHosting.modelType.inferenceContainer` + # example: "ecs.textgen.tgi" + parts = provider.split(".") + if len(parts) != 3: + return # Invalid provider format + + inference_container = parts[2] + + # Skip unsupported containers + if not self.is_supported_container(inference_container): + return + + # Get default model kwargs from validator + validator = self.registry.get_assets(provider)["validator"] + model_kwargs = validator().dict() + + # Build model key + model_key = f"{provider}.{model_name}" + + # Store endpoint URL + cache["endpointUrls"][model_key] = model["endpointUrl"] + + # Store metadata + cache["metadata"][model_key] = { + "provider": provider, + "modelName": model_name, + "modelType": model_type, + "modelKwargs": model_kwargs, + } + if "streaming" in model: + cache["metadata"][model_key]["streaming"] = model["streaming"] + + # Register by model type and resource + if model_type == ModelType.EMBEDDING: + cache[RestApiResource.EMBEDDINGS].setdefault(provider, []).append(model_name) + cache[ModelType.EMBEDDING].setdefault(provider, []).append(model_name) + elif model_type == ModelType.TEXTGEN: + cache[RestApiResource.GENERATE].setdefault(provider, []).append(model_name) + cache[ModelType.TEXTGEN].setdefault(provider, []).append(model_name) + if model.get("streaming", False): + cache[RestApiResource.GENERATE_STREAM].setdefault(provider, []).append(model_name) + + def register_models(self, models: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + """Register multiple models. + + Parameters + ---------- + models : list[dict[str, Any]] + List of model configurations + + Returns + ------- + dict[str, dict[str, Any]] + The populated cache + """ + cache = self.create_empty_cache() + + for model in models: + try: + self.register_model(model, cache) + except Exception: # nosec B112 + # Skip models that fail to register - this is intentional + # to allow partial registration when some models are misconfigured + continue + + return cache diff --git a/lib/serve/rest-api/src/services/model_service.py b/lib/serve/rest-api/src/services/model_service.py new file mode 100644 index 000000000..daef3f9ba --- /dev/null +++ b/lib/serve/rest-api/src/services/model_service.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service for model operations - follows Single Responsibility Principle.""" + +import time +from collections import defaultdict +from typing import Any, DefaultDict + +from utils.resources import ModelType + + +class ModelService: + """Service class for model-related operations. + + This class encapsulates all model listing and description logic, + making it easy to test without external dependencies. + """ + + def __init__(self, models_cache: dict[str, Any]): + """Initialize with models cache. + + Parameters + ---------- + models_cache : dict + The registered models cache + """ + self.models_cache = models_cache + + def list_models(self, model_types: list[ModelType]) -> dict[ModelType, dict[str, list[str]]]: + """List models by type. + + Parameters + ---------- + model_types : List[ModelType] + Model types to list + + Returns + ------- + Dict[ModelType, Dict[str, List[str]]] + List of model names by model type and provider + """ + return {model_type: self.models_cache.get(model_type, {}) for model_type in model_types} + + def list_models_openai_format(self) -> dict[str, Any]: + """List models in OpenAI-compatible format. + + Returns + ------- + Dict[str, Any] + OpenAI-compatible response with text generation models + """ + textgen_models = self.models_cache.get(ModelType.TEXTGEN, {}) + + model_payload: list[dict[str, Any]] = [] + for provider, models in textgen_models.items(): + model_payload.extend( + {"id": f"{model} ({provider})", "object": "model", "created": int(time.time()), "owned_by": "LISA"} + for model in models + ) + + return {"data": model_payload, "object": "list"} + + def get_model_metadata(self, provider: str, model_name: str) -> dict[str, Any] | None: + """Get metadata for a specific model. + + Parameters + ---------- + provider : str + Model provider name + model_name : str + Model name + + Returns + ------- + Dict[str, Any] | None + Model metadata or None if not found + """ + model_key = f"{provider}.{model_name}" + metadata_cache = self.models_cache.get("metadata", {}) + result = metadata_cache.get(model_key) + return result if result is not None else None + + def describe_models(self, model_types: list[ModelType]) -> DefaultDict[str, DefaultDict[str, dict[str, Any]]]: + """Get detailed metadata for models by type. + + Parameters + ---------- + model_types : List[ModelType] + Model types to describe + + Returns + ------- + DefaultDict[str, DefaultDict[str, Dict[str, Any]]] + Model metadata by type, provider, and name + """ + registered_models = self.list_models(model_types) + metadata_cache = self.models_cache.get("metadata", {}) + response: DefaultDict[str, DefaultDict[str, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) + + for model_type, providers in registered_models.items(): + response[model_type] = {} # type: ignore + providers = providers or {} + for provider, model_names in providers.items(): + response[model_type][provider] = [ + metadata_cache[f"{provider}.{model_name}"] + for model_name in model_names + if f"{provider}.{model_name}" in metadata_cache + ] # type: ignore + + return response diff --git a/lib/serve/rest-api/src/services/text_processing.py b/lib/serve/rest-api/src/services/text_processing.py new file mode 100644 index 000000000..ff670b864 --- /dev/null +++ b/lib/serve/rest-api/src/services/text_processing.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text processing utilities - pure functions for easy testing.""" + + +def render_context_from_messages(messages_list: list[dict[str, str]]) -> str: + """Render context string from message list. + + Pure function that converts a list of messages into a single context string. + + Parameters + ---------- + messages_list : List[Dict[str, str]] + List of messages with 'content' field + + Returns + ------- + str + Concatenated message content + """ + return "\n\n".join([message["content"] for message in messages_list]) + + +def parse_model_provider_from_string(model_string: str) -> tuple[str, str]: + """Parse model name and provider from combined string. + + Pure function that extracts model and provider from format: "model_name (provider_name)" + + Parameters + ---------- + model_string : str + Combined model string in format "model_name (provider_name)" + + Returns + ------- + Tuple[str, str] + Model name and provider name + + Raises + ------ + ValueError + If string format is invalid + """ + if not model_string or "(" not in model_string or ")" not in model_string: + raise ValueError(f"Invalid model string format: {model_string}") + + model_parts = model_string.split() + if len(model_parts) < 2: + raise ValueError(f"Invalid model string format: {model_string}") + + model_name = model_parts[0].strip() + provider = model_parts[1].replace("(", "").replace(")", "").strip() + + if not model_name or not provider: + raise ValueError(f"Invalid model string format: {model_string}") + + return model_name, provider + + +def map_openai_params_to_lisa(request_data: dict) -> dict: + """Map OpenAI API parameters to LISA parameters. + + Pure function that transforms OpenAI request format to LISA format. + + Parameters + ---------- + request_data : dict + OpenAI-format request data + + Returns + ------- + dict + Mapped parameters for LISA + """ + # Mapping of OpenAI params to TGI/LISA params + param_mapping = { + "echo": "return_full_text", + "frequency_penalty": "repetition_penalty", + "max_tokens": "max_new_tokens", + "seed": "seed", + "stop": "stop_sequences", + "temperature": "temperature", + "top_p": "top_p", + } + + return { + param_mapping[k]: request_data[k] for k in param_mapping if k in request_data and request_data[k] is not None + } diff --git a/lib/serve/rest-api/src/utils/cache_manager.py b/lib/serve/rest-api/src/utils/cache_manager.py index 3c94bbace..9f3b87660 100644 --- a/lib/serve/rest-api/src/utils/cache_manager.py +++ b/lib/serve/rest-api/src/utils/cache_manager.py @@ -14,7 +14,7 @@ """Model Cache Utilities.""" import threading -from typing import Any, Dict, Optional, Tuple +from typing import Any from .resources import ModelType, RestApiResource @@ -23,7 +23,7 @@ # - RestApiResource keys (EMBEDDINGS, GENERATE, GENERATE_STREAM) contain models by endpoint. # - 'metadata' contains detailed information about each model. # - 'endpointUrls' contains the URLs for model instantiation. -REGISTERED_MODELS_CACHE: Dict[str, Dict[str, Any]] = { +REGISTERED_MODELS_CACHE: dict[str, dict[str, Any]] = { ModelType.EMBEDDING: {}, ModelType.TEXTGEN: {}, RestApiResource.EMBEDDINGS: {}, @@ -32,32 +32,32 @@ "metadata": {}, "endpointUrls": {}, } -MODEL_ASSETS_CACHE: Dict[str, Tuple[Any, Any]] = {} +MODEL_ASSETS_CACHE: dict[str, tuple[Any, Any]] = {} # Thread locks for cache operations _REGISTERED_MODELS_LOCK = threading.RLock() _MODEL_ASSETS_LOCK = threading.RLock() -def get_registered_models_cache() -> Dict[str, Dict[str, Any]]: +def get_registered_models_cache() -> dict[str, dict[str, Any]]: """Get the cache containing the registered models.""" with _REGISTERED_MODELS_LOCK: return REGISTERED_MODELS_CACHE.copy() -def get_model_assets(model_key: str) -> Optional[Tuple[Any, Any]]: +def get_model_assets(model_key: str) -> tuple[Any, Any] | None: """Get the cache belonging to the model assets.""" with _MODEL_ASSETS_LOCK: return MODEL_ASSETS_CACHE.get(model_key) -def cache_model_assets(key: str, model_assets: Tuple[Any, Any]) -> None: +def cache_model_assets(key: str, model_assets: tuple[Any, Any]) -> None: """Cache the specified model assets for the specified key.""" with _MODEL_ASSETS_LOCK: MODEL_ASSETS_CACHE[key] = model_assets -def set_registered_models_cache(models: Dict[str, Dict[str, Any]]) -> None: +def set_registered_models_cache(models: dict[str, dict[str, Any]]) -> None: """Set the registered model cache to the specified models value.""" with _REGISTERED_MODELS_LOCK: global REGISTERED_MODELS_CACHE diff --git a/lib/serve/rest-api/src/utils/decorators.py b/lib/serve/rest-api/src/utils/decorators.py index a550aa44c..659335bb3 100644 --- a/lib/serve/rest-api/src/utils/decorators.py +++ b/lib/serve/rest-api/src/utils/decorators.py @@ -13,14 +13,15 @@ # limitations under the License. """Utility decorators.""" -from typing import Any, Callable, cast, Dict, TypeVar +from collections.abc import Callable +from typing import Any, cast, TypeVar T = TypeVar("T") def singleton(cls: type[T]) -> Callable[..., T]: """Singleton decorator.""" - instances: Dict[type, Any] = {} + instances: dict[type, Any] = {} def get_instance(*args: Any, **kwargs: Any) -> T: if cls not in instances: diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index 6217cf3d6..1ae021f10 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -16,12 +16,44 @@ import json import os -from typing import Tuple import boto3 import click import yaml -from rds_auth import generate_auth_token, get_lambda_role_name + + +def _is_embedding_model(model: dict) -> bool: + """Check if a model is an embedding model based on naming conventions.""" + model_name = model.get("modelName", "").lower() + model_id = model.get("modelId", "").lower() + return "embed" in model_name or "embed" in model_id + + +def _build_model_config(model: dict) -> dict: + """Build LiteLLM model configuration for a registered model.""" + model_name = model["modelName"] + is_embedding = _is_embedding_model(model) + + # Use hosted_vllm provider for embedding models to avoid encoding_format issues + # LiteLLM 1.80+ has issues with openai/ provider sending invalid encoding_format to vLLM + if is_embedding: + provider_prefix = "hosted_vllm" + else: + provider_prefix = "openai" + + litellm_params = { + "model": f"{provider_prefix}/{model_name}", + "api_base": model["endpointUrl"] + "/v1", # Local containers require the /v1 for OpenAI API routing. + } + + # For embedding models, also add drop_params as a safety measure + if is_embedding: + litellm_params["drop_params"] = True + + return { + "model_name": model["modelId"], # Use user-provided name if one given, otherwise it is the model name. + "litellm_params": litellm_params, + } @click.command() @@ -30,7 +62,7 @@ def generate_config(filepath: str) -> None: """Read LiteLLM configuration and rewrite it with LISA-deployed model information.""" ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) - with open(filepath, "r") as fp: + with open(filepath) as fp: config_contents = yaml.safe_load(fp) # Get and load registered models from ParameterStore param_response = ssm_client.get_parameter(Name=os.environ["REGISTERED_MODELS_PS_NAME"]) @@ -55,6 +87,10 @@ def generate_config(filepath: str) -> None: { "drop_params": True, # drop unrecognized param instead of failing the request on it "request_timeout": 600, + # Performance optimizations for embeddings + "num_retries": 2, # Reduce retries for faster failure detection + "retry_after": 1, # Shorter retry delay + "embedding_cache": True, # Enable embedding caching (if Redis configured) } ) @@ -62,21 +98,34 @@ def generate_config(filepath: str) -> None: db_param_response = ssm_client.get_parameter(Name=os.environ["LITELLM_DB_INFO_PS_NAME"]) db_params = json.loads(db_param_response["Parameter"]["Value"]) - username, password = get_database_credentials(db_params) - connection_str = ( - f"postgresql://{username}:{password}@{db_params['dbHost']}:{db_params['dbPort']}" f"/{db_params['dbName']}" - ) + # Check if using IAM auth - either via environment variable (preferred) or SSM parameter + # IAM_TOKEN_DB_AUTH is set by CDK when iamRdsAuth=true + use_iam_auth = os.environ.get("IAM_TOKEN_DB_AUTH", "").lower() == "true" or "passwordSecretId" not in db_params if "general_settings" not in config_contents: config_contents["general_settings"] = {} - config_contents["general_settings"].update( - { - "store_model_in_db": True, - "database_url": connection_str, - "master_key": config_contents["db_key"], - } - ) + if use_iam_auth: + config_contents["general_settings"].update( + { + "store_model_in_db": True, + "master_key": config_contents["db_key"], + } + ) + else: + # Password auth: build connection string with password from Secrets Manager + username, password = get_database_credentials(db_params) + connection_str = ( + f"postgresql://{username}:{password}@{db_params['dbHost']}:{db_params['dbPort']}" f"/{db_params['dbName']}" + ) + + config_contents["general_settings"].update( + { + "store_model_in_db": True, + "database_url": connection_str, + "master_key": config_contents["db_key"], + } + ) print(f"Generated config_contents file: \n{json.dumps(config_contents, indent=2)}") @@ -85,17 +134,22 @@ def generate_config(filepath: str) -> None: yaml.safe_dump(config_contents, fp) -def get_database_credentials(db_params: dict[str, str]) -> Tuple: - """Get database password from Secrets Manager or using IAM auth.""" - - if "passwordSecretId" in db_params: - secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) +def get_database_credentials(db_params: dict[str, str]) -> tuple: + """Get database credentials using password auth from Secrets Manager.""" + secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) + try: secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) - secret = json.loads(secret_response["SecretString"]) - return (db_params["username"], secret["password"]) - else: - iam_name = get_lambda_role_name() - return (iam_name, generate_auth_token(db_params["dbHost"], db_params["dbPort"], iam_name)) + except secrets_client.exceptions.ResourceNotFoundException: + raise RuntimeError( + f"Database password secret '{db_params['passwordSecretId']}' not found. " + "This typically occurs when switching from IAM authentication (iamRdsAuth=true) " + "back to password authentication (iamRdsAuth=false). The master password is " + "permanently deleted when IAM auth is enabled. To resolve this, either: " + "1) Set iamRdsAuth=true in your config, or " + "2) Recreate the database by deleting and redeploying the stack." + ) + secret = json.loads(secret_response["SecretString"]) + return (db_params["username"], secret["password"]) if __name__ == "__main__": diff --git a/lib/serve/rest-api/src/utils/guardrails.py b/lib/serve/rest-api/src/utils/guardrails.py index 1d6eb042a..d44b33475 100644 --- a/lib/serve/rest-api/src/utils/guardrails.py +++ b/lib/serve/rest-api/src/utils/guardrails.py @@ -18,14 +18,14 @@ import os import re from collections.abc import Iterator -from typing import Any, Dict, List, Optional +from typing import Any import boto3 from fastapi.responses import JSONResponse from loguru import logger -async def get_model_guardrails(model_id: str) -> List[Dict[str, Any]]: +async def get_model_guardrails(model_id: str) -> list[dict[str, Any]]: """ Query the guardrails DynamoDB table for guardrails associated with a model. @@ -52,14 +52,14 @@ async def get_model_guardrails(model_id: str) -> List[Dict[str, Any]]: guardrails = response.get("Items", []) logger.debug(f"Found {len(guardrails)} guardrails for model {model_id}") - return guardrails + return guardrails # type: ignore[no-any-return] except Exception as e: logger.error(f"Error fetching guardrails for model {model_id}: {e}") return [] -def get_applicable_guardrails(user_groups: List[str], guardrails: List[Dict[str, Any]], model_id: str) -> List[str]: +def get_applicable_guardrails(user_groups: list[str], guardrails: list[dict[str, Any]], model_id: str) -> list[str]: """ Determine which guardrails apply to a user based on group membership. @@ -131,7 +131,7 @@ def is_guardrail_violation(error_msg: str) -> bool: return "Violated guardrail policy" in error_msg -def extract_guardrail_response(error_msg: str) -> Optional[str]: +def extract_guardrail_response(error_msg: str) -> str | None: """ Extract the bedrock_guardrail_response from an error message. diff --git a/lib/serve/rest-api/src/utils/metrics.py b/lib/serve/rest-api/src/utils/metrics.py index f6b43b967..417a5e95f 100644 --- a/lib/serve/rest-api/src/utils/metrics.py +++ b/lib/serve/rest-api/src/utils/metrics.py @@ -21,10 +21,9 @@ from datetime import datetime import boto3 +from auth import get_user_context from fastapi import Request -from ..auth import get_user_context - logger = logging.getLogger(__name__) sqs_client = boto3.client("sqs", region_name=os.environ["AWS_REGION"]) diff --git a/lib/serve/rest-api/src/utils/rds_auth.py b/lib/serve/rest-api/src/utils/rds_auth.py index dd777b93a..e3adb07ab 100644 --- a/lib/serve/rest-api/src/utils/rds_auth.py +++ b/lib/serve/rest-api/src/utils/rds_auth.py @@ -11,19 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""RDS authentication utilities.""" + import os from typing import cast -from urllib.parse import quote_plus import boto3 -def generate_auth_token(host: str, port: str, user: str) -> str: - rds = boto3.client("rds", region_name=os.environ["AWS_REGION"]) - token = rds.generate_db_auth_token(DBHostname=host, Port=port, DBUsername=user) - return quote_plus(token) - - def _get_lambda_role_arn() -> str: """Get the ARN of the Lambda execution role. @@ -34,7 +30,7 @@ def _get_lambda_role_arn() -> str: """ sts = boto3.client("sts", region_name=os.environ["AWS_REGION"]) identity = sts.get_caller_identity() - return cast(str, identity["Arn"]) # This will include the role name + return cast(str, identity["Arn"]) def get_lambda_role_name() -> str: @@ -47,4 +43,4 @@ def get_lambda_role_name() -> str: """ arn = _get_lambda_role_arn() parts = arn.split(":assumed-role/")[1].split("/") - return parts[0] # This is the role name + return parts[0] diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index 9b5fc2fc7..b4aa86094 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -17,13 +17,34 @@ import os import sys import traceback -from typing import Any, AsyncGenerator, Callable, Dict, Tuple +from collections.abc import AsyncGenerator, Callable +from typing import Any, Protocol from loguru import logger +from utils.cache_manager import cache_model_assets, get_model_assets, get_registered_models_cache +from utils.resources import RestApiResource + + +class RegistryProtocol(Protocol): + """Protocol for model registry - allows dependency injection.""" + + def get_assets(self, provider: str) -> dict[str, Any]: + """Get model assets for a provider.""" + ... + + +def _get_default_registry() -> RegistryProtocol: + """Lazy import of registry to avoid import-time dependencies. + + This function is only called at runtime, not at import time, + allowing tests to mock the registry without importing lisa_serve. + """ + # Import here to avoid circular dependencies and allow test mocking + # This is intentionally not at module level + from lisa_serve.registry import registry # noqa: PLC0415 + + return registry -from ..lisa_serve.registry import registry -from .cache_manager import cache_model_assets, get_model_assets, get_registered_models_cache -from .resources import RestApiResource logger.remove() logger_level = os.environ.get("LOG_LEVEL", "INFO") @@ -48,7 +69,7 @@ ) -async def validate_model(request_data: Dict[str, Any], resource: RestApiResource) -> None: +async def validate_model(request_data: dict[str, Any], resource: RestApiResource) -> None: """Validate that the selected model is registered and supported for the specified resource. Parameters @@ -91,13 +112,17 @@ async def validate_model(request_data: Dict[str, Any], resource: RestApiResource raise ValueError(message) -async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, Any]: +async def get_model_and_validator( + request_data: dict[str, Any], registry: RegistryProtocol | None = None +) -> tuple[Any, Any]: """Get model and model kwargs validator. Parameters ---------- request_data : Dict[str, Any] Request data. + registry : RegistryProtocol | None + Optional registry for dependency injection (testing). Returns ------- @@ -112,6 +137,9 @@ async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, An model_assets = get_model_assets(model_key) if not model_assets: # If not cached, retrieve model assets from registry + if registry is None: + registry = _get_default_registry() + registry_assets = registry.get_assets(provider) adapter = registry_assets["adapter"] validator = registry_assets["validator"] @@ -134,8 +162,8 @@ async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, An async def validate_and_prepare_llm_request( - request_data: Dict[str, Any], resource: RestApiResource -) -> Tuple[Any, Any, str]: + request_data: dict[str, Any], resource: RestApiResource, registry: RegistryProtocol | None = None +) -> tuple[Any, Any, str]: """Validate and prepare data for LLM (Language Model) requests. Parameters @@ -146,6 +174,9 @@ async def validate_and_prepare_llm_request( resource : RestApiResource REST API resource. + registry : RegistryProtocol | None + Optional registry for dependency injection (testing). + Returns ------- Tuple @@ -159,7 +190,7 @@ async def validate_and_prepare_llm_request( await validate_model(request_data, resource) # Instantiate the model and get the model kwargs validator - model, validator = await get_model_and_validator(request_data) + model, validator = await get_model_and_validator(request_data, registry) # Verify model kwargs model_kwargs = validator(**request_data["modelKwargs"]) @@ -174,8 +205,8 @@ async def validate_and_prepare_llm_request( def handle_stream_exceptions( - func: Callable[..., AsyncGenerator[str, None]], -) -> Callable[..., AsyncGenerator[str, None]]: + func: Callable[..., AsyncGenerator[str]], +) -> Callable[..., AsyncGenerator[str]]: """Decorate a streaming function to handle exceptions gracefully. This decorator catches any exceptions raised during the execution of a streaming function @@ -200,7 +231,7 @@ def handle_stream_exceptions( The items yielded by the original function, or a JSON-formatted error message in case of an exception. """ - async def wrapper(*args: Any, **kwargs: Any) -> AsyncGenerator[str, None]: + async def wrapper(*args: Any, **kwargs: Any) -> AsyncGenerator[str]: try: async for item in func(*args, **kwargs): yield item diff --git a/lib/serve/rest-api/src/utils/resources.py b/lib/serve/rest-api/src/utils/resources.py index c8929ae7a..17934b00a 100644 --- a/lib/serve/rest-api/src/utils/resources.py +++ b/lib/serve/rest-api/src/utils/resources.py @@ -14,7 +14,7 @@ """REST API resources.""" from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any from pydantic import BaseModel, Field @@ -43,6 +43,7 @@ class ModelType(str, Enum): EMBEDDING = "embedding" TEXTGEN = "textgen" + VIDEOGEN = "videogen" class _BaseModelRequest(BaseModel): @@ -50,8 +51,8 @@ class _BaseModelRequest(BaseModel): provider: str = Field(..., description="The backend provider for the model.") modelName: str = Field(..., description="The model name.") - text: Union[str, list[str]] = Field(..., description="The input text(s) to be processed by the model.") - modelKwargs: Dict[str, Any] = Field(default={}, description="Arguments to the model.") + text: str | list[str] = Field(..., description="The input text(s) to be processed by the model.") + modelKwargs: dict[str, Any] = Field(default={}, description="Arguments to the model.") class EmbeddingsRequest(_BaseModelRequest): @@ -72,13 +73,13 @@ class OpenAIChatCompletionsRequest(BaseModel): Additional documentation at https://platform.openai.com/docs/api-reference/chat/create """ - messages: List[Dict[str, str]] = Field(..., description="A list of messages comprising the conversation so far.") + messages: list[dict[str, str]] = Field(..., description="A list of messages comprising the conversation so far.") model: str = Field(..., description="ID of the model to use.") - frequency_penalty: Optional[float] = Field(None, description="Penalty to add for text repetition.") - logit_bias: Optional[Dict[Any, Any]] = Field( + frequency_penalty: float | None = Field(None, description="Penalty to add for text repetition.") + logit_bias: dict[Any, Any] | None = Field( None, description="Modify the likelihood of specified tokens appearing in the completion." ) - logprobs: Optional[bool] = Field( + logprobs: bool | None = Field( False, description=" ".join( [ @@ -88,7 +89,7 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - top_logprobs: Optional[int] = Field( + top_logprobs: int | None = Field( None, description=" ".join( [ @@ -99,8 +100,8 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - max_tokens: Optional[int] = Field(50, description="Maximum number of generated tokens.") - n: Optional[int] = Field( + max_tokens: int | None = Field(50, description="Maximum number of generated tokens.") + n: int | None = Field( 1, description=" ".join( [ @@ -109,7 +110,7 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - presence_penalty: Optional[float] = Field( + presence_penalty: float | None = Field( 0, description=" ".join( [ @@ -118,11 +119,11 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - seed: Optional[int] = Field(None, description="Random sampling seed.") - stop: Optional[List[str]] = Field( + seed: int | None = Field(None, description="Random sampling seed.") + stop: list[str] | None = Field( default_factory=list, description="Stop generating tokens if a member of `stop` is generated." ) - stream: Optional[bool] = Field( + stream: bool | None = Field( False, description=" ".join( [ @@ -132,7 +133,7 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - top_p: Optional[float] = Field( + top_p: float | None = Field( None, description=" ".join( [ @@ -141,7 +142,7 @@ class OpenAIChatCompletionsRequest(BaseModel): ] ), ) - temperature: Optional[float] = Field(None, description="Value used to divide the logits distribution.") + temperature: float | None = Field(None, description="Value used to divide the logits distribution.") class OpenAICompletionsRequest(BaseModel): @@ -160,7 +161,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - best_of: Optional[int] = Field( + best_of: int | None = Field( 1, description=" ".join( [ @@ -170,9 +171,9 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - echo: Optional[bool] = Field(False, description="Whether to prepend the prompt to the generated text.") - frequency_penalty: Optional[float] = Field(None, description="Penalty to add for text repetition.") - logit_bias: Optional[Dict[Any, Any]] = Field( + echo: bool | None = Field(False, description="Whether to prepend the prompt to the generated text.") + frequency_penalty: float | None = Field(None, description="Penalty to add for text repetition.") + logit_bias: dict[Any, Any] | None = Field( None, description=" ".join( [ @@ -181,7 +182,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - logprobs: Optional[int] = Field( + logprobs: int | None = Field( None, description=" ".join( [ @@ -194,10 +195,10 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - max_tokens: Optional[int] = Field( + max_tokens: int | None = Field( 50, description="The maximum number of tokens that can be generated in the completion." ) - n: Optional[int] = Field( + n: int | None = Field( 1, description=" ".join( [ @@ -206,7 +207,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - presence_penalty: Optional[float] = Field( + presence_penalty: float | None = Field( 0, description=" ".join( [ @@ -215,11 +216,11 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - seed: Optional[int] = Field(None, description="Random sampling seed.") - stop: Optional[Any] = Field( + seed: int | None = Field(None, description="Random sampling seed.") + stop: Any | None = Field( default_factory=list, description="Stop generating tokens if a member of `stop` is generated." ) - stream: Optional[bool] = Field( + stream: bool | None = Field( False, description=" ".join( [ @@ -229,7 +230,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - suffix: Optional[str] = Field( + suffix: str | None = Field( None, description=" ".join( [ @@ -239,8 +240,8 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - temperature: Optional[float] = Field(1.0, description="Value used to divide the logits distribution.") - top_p: Optional[float] = Field( + temperature: float | None = Field(1.0, description="Value used to divide the logits distribution.") + top_p: float | None = Field( None, description=" ".join( [ diff --git a/lib/serve/serveApplicationConstruct.ts b/lib/serve/serveApplicationConstruct.ts index 72b6665b7..4006a8862 100644 --- a/lib/serve/serveApplicationConstruct.ts +++ b/lib/serve/serveApplicationConstruct.ts @@ -18,29 +18,25 @@ import { ITable, Table } from 'aws-cdk-lib/aws-dynamodb'; import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; -import { Code, Function, IFunction, ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda'; import { FastApiContainer } from '../api-base/fastApiContainer'; import { ECSCluster } from '../api-base/ecsCluster'; import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; -import { APP_MANAGEMENT_KEY, BaseProps, Config } from '../schema'; +import { APP_MANAGEMENT_KEY, BaseProps } from '../schema'; import { Effect, Policy, - PolicyDocument, PolicyStatement, - Role, - ServicePrincipal, } from 'aws-cdk-lib/aws-iam'; -import { HostedRotation, ISecret } from 'aws-cdk-lib/aws-secretsmanager'; +import { HostedRotation } from 'aws-cdk-lib/aws-secretsmanager'; import { SecurityGroupEnum } from '../core/iam/SecurityGroups'; import { SecurityGroupFactory } from '../networking/vpc/security-group-factory'; -import { LAMBDA_PATH, REST_API_PATH } from '../util'; -import { AwsCustomResource, PhysicalResourceId } from 'aws-cdk-lib/custom-resources'; -import { getPythonRuntime } from '../api-base/utils'; +import { REST_API_PATH } from '../util'; +import { AwsCustomResource, AwsCustomResourcePolicy, PhysicalResourceId } from 'aws-cdk-lib/custom-resources'; import { ISecurityGroup, Port } from 'aws-cdk-lib/aws-ec2'; import { ECSTasks } from '../api-base/ecsCluster'; import { GuardrailsTable } from '../models/guardrails-table'; +import { Role } from 'aws-cdk-lib/aws-iam'; export type LisaServeApplicationProps = { vpc: Vpc; @@ -70,6 +66,9 @@ export class LisaServeApplicationConstruct extends Construct { super(scope, id); const { config, vpc, securityGroups } = props; + // Determine authentication method - default to IAM auth (iamRdsAuth = false) + const useIamAuth = config.iamRdsAuth ?? false; + // TokenTable is now created in API Base, reference it from SSM parameter // API Base stack must be deployed before Serve stack (dependency is set in stages.ts) const tokenTableNameParameter = StringParameter.fromStringParameterName( @@ -128,6 +127,8 @@ export class LisaServeApplicationConstruct extends Construct { } const username = config.restApiConfig.rdsConfig.username; + + // Create credentials for database setup const dbCreds = Credentials.fromGeneratedSecret(username); // DB is a Single AZ instance for cost + inability to make non-Aurora multi-AZ cluster in CDK @@ -138,15 +139,22 @@ export class LisaServeApplicationConstruct extends Construct { vpc: vpc.vpc, subnetGroup: vpc.subnetGroup, credentials: dbCreds, - iamAuthentication: true, + iamAuthentication: useIamAuth, // Enable IAM auth when iamRdsAuth is true + databaseName: config.restApiConfig.rdsConfig.dbName, // Specify database name to match config securityGroups: [litellmDbSg], removalPolicy: config.removalPolicy, }); - const litellmDbPasswordSecret = litellmDb.secret!; + // Secret is used for password auth or for IAM user bootstrap + const litellmDbSecret = litellmDb.secret!; + + // Add rotation policy for the database password secret (only if using password auth) + if (!useIamAuth) { + // WARNING: If switching from IAM auth (iamRdsAuth=true) back to password auth (iamRdsAuth=false), + // the deployment will fail because the master password secret was deleted during IAM auth setup. + // This is a one-way migration - once IAM auth is enabled, you cannot switch back to password auth + // without recreating the database. - // Add rotation policy for the database password secret (only if not using IAM auth) - if (!config.iamRdsAuth) { // Allow the rotation Lambda to connect to the database securityGroups.forEach((sg) => { litellmDbSg.addIngressRule( @@ -156,7 +164,7 @@ export class LisaServeApplicationConstruct extends Construct { ); }); - litellmDbPasswordSecret.addRotationSchedule('DatabasePasswordRotationSchedule', { + litellmDbSecret.addRotationSchedule('DatabasePasswordRotationSchedule', { automaticallyAfter: Duration.days(30), hostedRotation: HostedRotation.postgreSqlSingleUser({ functionName: `${config.deploymentName}-Litellm-Rotation-Function`, @@ -174,18 +182,10 @@ export class LisaServeApplicationConstruct extends Construct { dbHost: litellmDb.dbInstanceEndpointAddress, dbName: config.restApiConfig.rdsConfig.dbName, dbPort: config.restApiConfig.rdsConfig.dbPort, - // only include passwordSecretId if authenticating with username/password - ...(config.iamRdsAuth ? {} : { passwordSecretId: litellmDbPasswordSecret.secretName }) + // Include passwordSecretId only when using password auth + ...(!useIamAuth ? { passwordSecretId: litellmDbSecret.secretName } : {}) }), }); - console.log('storing llmdbconninfop', JSON.stringify({ - username: username, - dbHost: litellmDb.dbInstanceEndpointAddress, - dbName: config.restApiConfig.rdsConfig.dbName, - dbPort: config.restApiConfig.rdsConfig.dbPort, - // only include passwordSecretId if authenticating with username/password - ...(config.iamRdsAuth ? {} : { passwordSecretId: litellmDbPasswordSecret.secretName }) - })); // update the rdsConfig with the endpoint address config.restApiConfig.rdsConfig.dbHost = litellmDb.dbInstanceEndpointAddress; @@ -208,40 +208,100 @@ export class LisaServeApplicationConstruct extends Construct { if (serveRole) { // Grant access to REST API task role only litellmDbConnectionInfoPs.grantRead(serveRole); - if (config.iamRdsAuth) { - litellmDb.grantConnect(serveRole, serveRole.roleName); - // Create the lambda for generating DB users for IAM auth - const createDbUserLambda = this.getIAMAuthLambda(scope, config, litellmDbPasswordSecret, serveRole.roleName, vpc, [litellmDbSg]); + if (!useIamAuth) { + // Password auth: grant secret read access only (grantConnect requires IAM auth) + litellmDbSecret.grantRead(serveRole); + } else { + // IAM auth: manually grant rds-db:connect permission + // Note: We do NOT use litellmDb.grantConnect() due to CDK bug #11851 + // The grantConnect method generates incorrect ARN format (uses rds: instead of rds-db:) + // Per AWS docs: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html + // The correct format is: arn:aws:rds-db:region:account-id:dbuser:DbiResourceId/db-user-name + serveRole.addToPrincipalPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['rds-db:connect'], + resources: [ + // Use wildcard for DbiResourceId since it's not available in CloudFormation + // Format: arn:aws:rds-db:region:account:dbuser:*/username + `arn:${config.partition}:rds-db:${config.region}:${config.accountNumber}:dbuser:*/${serveRole.roleName}` + ] + })); + + // Use the shared IAM auth setup Lambda from API Base stack + const iamAuthSetupFnArn = StringParameter.valueForStringParameter( + scope, + `${config.deploymentPrefix}/iamAuthSetupFnArn` + ); - const customResourceRole = new Role(scope, 'LISAServeCustomResourceRole', { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - }); - createDbUserLambda.grantInvoke(customResourceRole); - - // run updateInstanceKmsConditionsLambda every deploy - new AwsCustomResource(scope, 'LISAServeCreateDbUserCustomResource', { - onCreate: { - service: 'Lambda', - action: 'invoke', - physicalResourceId: PhysicalResourceId.of('LISAServeCreateDbUserCustomResource'), - parameters: { - FunctionName: createDbUserLambda.functionName, - Payload: '{}' - }, + // Get the IAM auth setup Lambda role ARN from SSM to grant it permissions + const iamAuthSetupRoleArn = StringParameter.valueForStringParameter( + scope, + `${config.deploymentPrefix}/iamAuthSetupRoleArn` + ); + + // Import the IAM auth setup role to grant it secret permissions + const iamAuthSetupRole = Role.fromRoleArn( + scope, + 'IamAuthSetupRoleRef', + iamAuthSetupRoleArn + ); + + // Grant the IAM auth setup Lambda role permission to read the bootstrap secret + litellmDbSecret.grantRead(iamAuthSetupRole); + + // Run the shared IAM auth setup Lambda on create and update + // This runs when switching to IAM auth or updating the configuration + // Pass parameters via payload since the Lambda is shared + // Use Stack.of(scope).toJsonString() to properly resolve CDK tokens in the payload + const lambdaInvokeParams = { + service: 'Lambda', + action: 'invoke', + physicalResourceId: PhysicalResourceId.of('LISAServeCreateDbUserCustomResource'), + parameters: { + FunctionName: iamAuthSetupFnArn, + Payload: Stack.of(scope).toJsonString({ + secretArn: litellmDbSecret.secretArn, + dbHost: config.restApiConfig.rdsConfig.dbHost, + dbPort: config.restApiConfig.rdsConfig.dbPort, + dbName: config.restApiConfig.rdsConfig.dbName, + dbUser: config.restApiConfig.rdsConfig.username, + iamName: serveRole.roleName, + }) }, - role: customResourceRole + }; + + const createDbUserResource = new AwsCustomResource(scope, 'LISAServeCreateDbUserCustomResource', { + onCreate: lambdaInvokeParams, + onUpdate: lambdaInvokeParams, // Also run on updates to ensure IAM user is created + policy: AwsCustomResourcePolicy.fromStatements([ + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['lambda:InvokeFunction'], + resources: [iamAuthSetupFnArn], + }) + ]), }); - } else { - litellmDb.grantConnect(serveRole); - litellmDbPasswordSecret.grantRead(serveRole); + + // Ensure the RDS instance is fully available before running IAM auth setup + createDbUserResource.node.addDependency(litellmDb); + + // Ensure the ECS service waits for IAM user setup to complete + restApi.node.addDependency(createDbUserResource); } + this.modelsPs.grantRead(serveRole); } // Use the guardrails table name from the construct we just created const guardrailsTableName = this.guardrailsTable.tableName; + // Get generated images bucket name for video/image content storage + const imagesBucketName = StringParameter.valueForStringParameter( + scope, + `${config.deploymentPrefix}/generatedImagesBucketName` + ); + // Add parameter as container environment variable for both RestAPI and RagAPI const container = restApi.apiCluster.containers[ECSTasks.REST]; if (container) { @@ -249,12 +309,23 @@ export class LisaServeApplicationConstruct extends Construct { container.addEnvironment('REGISTERED_MODELS_PS_NAME', this.modelsPs.parameterName); container.addEnvironment('LITELLM_DB_INFO_PS_NAME', litellmDbConnectionInfoPs.parameterName); container.addEnvironment('GUARDRAILS_TABLE_NAME', guardrailsTableName); + container.addEnvironment('GENERATED_IMAGES_S3_BUCKET_NAME', imagesBucketName); // Add metrics queue URL if provided if (props.metricsQueueUrl) { // Get the queue URL from SSM parameter const queueUrl = StringParameter.valueForStringParameter(scope, props.metricsQueueUrl); container.addEnvironment('USAGE_METRICS_QUEUE_URL', queueUrl); } + + // Add IAM auth environment variables for LiteLLM's native token refresh + // When these are set, LiteLLM automatically generates and refreshes IAM auth tokens + if (useIamAuth && serveRole) { + container.addEnvironment('IAM_TOKEN_DB_AUTH', 'true'); + container.addEnvironment('DATABASE_HOST', litellmDb.dbInstanceEndpointAddress); + container.addEnvironment('DATABASE_NAME', config.restApiConfig.rdsConfig.dbName); + container.addEnvironment('DATABASE_PORT', config.restApiConfig.rdsConfig.dbPort.toString()); + container.addEnvironment('DATABASE_USER', serveRole.roleName); + } } restApi.node.addDependency(this.modelsPs); restApi.node.addDependency(litellmDbConnectionInfoPs); @@ -316,6 +387,15 @@ export class LisaServeApplicationConstruct extends Construct { restRole.attachInlinePolicy(invocation_permissions); restRole.attachInlinePolicy(guardrails_permissions); + // Grant S3 bucket permissions for video/image content storage + restRole.addToPrincipalPolicy( + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:PutObject', 's3:GetObject', 's3:DeleteObject'], + resources: [`arn:${config.partition}:s3:::${imagesBucketName}/*`] + }) + ); + // Grant SQS send permissions if metrics queue URL is provided if (props.metricsQueueUrl) { // Get the queue name from SSM parameter @@ -343,56 +423,4 @@ export class LisaServeApplicationConstruct extends Construct { } }; - getIAMAuthLambda (scope: Stack, config: Config, secret: ISecret, user: string, vpc: Vpc, securityGroups: ISecurityGroup[]): IFunction { - // Create the IAM role for updating the database to allow IAM authentication - const iamAuthLambdaRole = new Role(scope, createCdkId(['LISAServe', 'IAMAuthLambdaRole']), { - assumedBy: new ServicePrincipal('lambda.amazonaws.com'), - inlinePolicies: { - 'EC2NetworkInterfaces': new PolicyDocument({ - statements: [ - new PolicyStatement({ - effect: Effect.ALLOW, - actions: ['ec2:CreateNetworkInterface', 'ec2:DescribeNetworkInterfaces', 'ec2:DeleteNetworkInterface'], - resources: ['*'], - }), - ], - }), - } - }); - - secret.grantRead(iamAuthLambdaRole); - - const commonLayer = this.getLambdaLayer(scope, config); - const lambdaPath = config.lambdaPath || LAMBDA_PATH; - - // Create the Lambda function that will create the database user - return new Function(scope, 'LISAServeCreateDbUserLambda', { - runtime: getPythonRuntime(), - handler: 'utilities.db_setup_iam_auth.handler', - code: Code.fromAsset(lambdaPath), - timeout: Duration.minutes(2), - environment: { - SECRET_ARN: secret.secretArn, // ARN of the RDS secret - DB_HOST: config.restApiConfig.rdsConfig.dbHost!, - DB_PORT: String(config.restApiConfig.rdsConfig.dbPort), // Default PostgreSQL port - DB_NAME: config.restApiConfig.rdsConfig.dbName, // Database name - DB_USER: config.restApiConfig.rdsConfig.username, // Admin user for RDS - IAM_NAME: user, // IAM role for Lambda execution - }, - role: iamAuthLambdaRole, // Lambda execution role - layers: [commonLayer], - vpc: vpc.vpc, - vpcSubnets: vpc.subnetSelection, - securityGroups: securityGroups, - }); - } - - getLambdaLayer (scope: Stack, config: Config): ILayerVersion { - return LayerVersion.fromLayerVersionArn( - scope, - 'LISAServeCommonLayerVersion', - StringParameter.valueForStringParameter(scope, `${config.deploymentPrefix}/layerVersion/common`), - ); - } - } diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index ffa3f25d6..22467670f 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -27,7 +27,7 @@ "@fortawesome/free-solid-svg-icons": "^7.1.0", "@fortawesome/react-fontawesome": "^3.1.1", "@langchain/core": "^1.1.4", - "@langchain/openai": "^1.1.3", + "@langchain/openai": "1.2.0", "@microsoft/fetch-event-source": "^2.0.1", "@reduxjs/toolkit": "^2.11.1", "@swc/core": "^1.15.3", @@ -94,6 +94,8 @@ "jsdom": "^27.3.0", "linkify-it": "^5.0.0", "markdown-it": "^14.1.0", + "patch-package": "^8.0.1", + "postinstall-postinstall": "^2.1.0", "prettier": "^3.7.4", "redux-mock-store": "^1.5.5", "uuid": "^13.0.0", diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index faecd3916..9a0056f6c 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -78,7 +78,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { external: false, href: '/mcp-connections', } as ButtonDropdownProps.Item] : []) - ]; + ].sort((a,b) => a.text.localeCompare(b.text)); return ( void; onRefresh: () => void; disableCreate?: boolean; + isFetching?: boolean; }; export function ApiTokenActions ({ @@ -36,6 +38,7 @@ export function ApiTokenActions ({ setCreateWizardVisible, onRefresh, disableCreate = false, + isFetching = false, }: ApiTokenActionsProps): ReactElement { const dispatch = useAppDispatch(); const notificationService = useNotificationService(dispatch); @@ -74,12 +77,11 @@ export function ApiTokenActions ({ return ( - + ariaLabel='Refresh tokens' + /> + ariaLabel='Refresh jobs' + /> } > diff --git a/lib/user-interface/react/src/components/chatbot/components/Message.tsx b/lib/user-interface/react/src/components/chatbot/components/Message.tsx index ae24d6965..8de856431 100644 --- a/lib/user-interface/react/src/components/chatbot/components/Message.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/Message.tsx @@ -35,7 +35,7 @@ import 'katex/dist/katex.min.css'; import styles from './Message.module.css'; import { MessageContent } from '@langchain/core/messages'; -import { base64ToBlob, fetchImage, getDisplayableMessage, messageContainsImage } from '@/components/utils'; +import { base64ToBlob, fetchImage, getDisplayableMessage, messageContainsImage, messageContainsVideo } from '@/components/utils'; import React, { useEffect, useState, useMemo } from 'react'; import { IChatConfiguration } from '@/shared/model/chat.configurations.model'; import { downloadFile } from '@/shared/util/downloader'; @@ -61,17 +61,31 @@ type MessageProps = { chatConfiguration: IChatConfiguration; showUsage?: boolean; onMermaidRenderComplete?: () => void; + onVideoLoadComplete?: () => void; + retryResponse?: () => Promise + errorState?: boolean; }; -export const Message = React.memo(({ message, isRunning, showMetadata, isStreaming, markdownDisplay, setUserPrompt, setChatConfiguration, handleSendGenerateRequest, chatConfiguration, callingToolName, showUsage = false, onMermaidRenderComplete }: MessageProps) => { +export const Message = React.memo(({ message, isRunning, showMetadata, isStreaming, markdownDisplay, setUserPrompt, setChatConfiguration, handleSendGenerateRequest, chatConfiguration, callingToolName, showUsage = false, onMermaidRenderComplete, onVideoLoadComplete, retryResponse, errorState }: MessageProps) => { const currentUser = useAppSelector(selectCurrentUsername); const ragCitations = !isStreaming && message?.metadata?.ragDocuments ? message?.metadata.ragDocuments : undefined; const [resend, setResend] = useState(false); const [showImageViewer, setShowImageViewer] = useState(false); const [selectedImage, setSelectedImage] = useState(undefined); const [selectedMetadata, setSelectedMetadata] = useState(undefined); + const [reasoningExpanded, setReasoningExpanded] = useState(true); const { colorScheme } = useContext(ColorSchemeContext); const isDarkMode = colorScheme === Mode.Dark; + const hasMessageContent = message?.content && typeof message.content === 'string' && message.content.trim() && message.content.trim() !== '\u00A0'; + + // Auto-expand reasoning when it first appears, then auto-collapse when message content starts arriving + useEffect(() => { + if (hasMessageContent) { + setReasoningExpanded(false); + } else if (!hasMessageContent && message?.reasoningContent) { + setReasoningExpanded(true); + } + }, [hasMessageContent, message?.reasoningContent]); useEffect(() => { if (resend) { @@ -258,10 +272,11 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami { if (e.detail.id === 'download-image') { @@ -269,6 +284,8 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami await fetchImage(item.image_url.url) : base64ToBlob(item.image_url.url.split(',')[1], 'image/png'); downloadFile(URL.createObjectURL(file), `${metadata?.imageGenerationParams?.prompt}.png`); + } else if (e.detail.id === 'share-image') { + navigator.clipboard.writeText(item.image_url.url); } else if (e.detail.id === 'copy-image') { const copy = new ClipboardItem({ 'image/png': item.image_url.url.startsWith('https://') ? @@ -293,6 +310,51 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami }} /> ; + } else if (item.type === 'video_url' && item.video_url?.url) { + const videoId = item.video_url.video_id; + return ( +
+ + { + if (e.detail.id === 'download-video') { + const videoUrl = item.video_url.url; + const videoBlob = await fetch(videoUrl).then((r) => r.blob()); + const filename = `${metadata?.videoGenerationParams?.prompt || 'video'}.mp4`; + downloadFile(URL.createObjectURL(videoBlob), filename); + } else if (e.detail.id === 'share-video') { + navigator.clipboard.writeText(item.video_url.url); + } else if (e.detail.id === 'remix-video' && videoId) { + // Call the remix endpoint to create a new variation + setUserPrompt(`Remix video: ${metadata?.videoGenerationParams?.prompt ?? ''}`); + // Store the video_id for the remix call + setChatConfiguration( + merge({}, chatConfiguration, { + sessionConfiguration: { + remixVideoId: videoId + } + }) + ); + setResend(true); + } + }} + /> +
+ ); } return null; }); @@ -314,9 +376,9 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami return ( (message.type === MessageTypes.HUMAN || message.type === MessageTypes.AI || message.type === MessageTypes.TOOL) && -
+
- {(isRunning && !callingToolName) && ( + {(isRunning && !callingToolName && !message?.metadata?.videoGeneration) && ( : undefined} > - + Generating response @@ -356,7 +418,7 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami )} - {message?.type === 'ai' && !isRunning && !callingToolName && message?.content && ( + {message?.type === 'ai' && !isRunning && !callingToolName && (message?.content || message?.reasoningContent) && ( : undefined} > - {renderContent(message.content, message.metadata)} + {message?.reasoningContent && chatConfiguration.sessionConfiguration.showReasoningContent && ( + + { + setReasoningExpanded(detail.expanded); + }} + > + + + + +
{message.reasoningContent}
+
+
+
+ { + if (detail.id === 'copy-reasoning') { + navigator.clipboard.writeText(message.reasoningContent || ''); + } + }} + ariaLabel='Copy reasoning content' + dropdownExpandToViewport + items={[ + { + type: 'icon-button', + id: 'copy-reasoning', + iconName: 'copy', + text: 'Copy Reasoning', + popoverFeedback: ( + + Reasoning copied + + ) + } + ]} + variant='icon' + /> +
+
+
+
+ )} + {message?.content && (typeof message.content === 'string' ? (message.content.trim() && message.content.trim() !== '\u00A0') : true) && renderContent(message.content, message.metadata)} {showMetadata && !isStreaming && }
- {!isStreaming && !messageContainsImage(message.content) &&
@@ -430,9 +538,13 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami
- ['copy'].includes(detail.id) && - navigator.clipboard.writeText(getDisplayableMessage(message.content)) + onItemClick={async ({ detail }) => { + if (detail.id === 'copy'){ + navigator.clipboard.writeText(getDisplayableMessage(message.content)); + } else if (detail.id === 'retry'){ + await retryResponse(); + } + } } ariaLabel='Chat actions' dropdownExpandToViewport @@ -447,7 +559,19 @@ export const Message = React.memo(({ message, isRunning, showMetadata, isStreami Input copied ) - } + }, + ...(errorState ? [ + { + type: 'icon-button' as const, + id: 'retry' as const, + iconName: 'refresh' as const, + text: 'Retry Message' as const, + popoverFeedback: ( + + Retrying Message + + ) + }] : []) ]} variant='icon' /> diff --git a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx index 4586b31c9..66497e30b 100644 --- a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx @@ -31,6 +31,7 @@ import { IChatConfiguration } from '@/shared/model/chat.configurations.model'; import { IModel, ModelType } from '@/shared/model/model-management.model'; import { IConfiguration } from '@/shared/model/configuration.model'; import { LisaChatSession } from '@/components/types'; +import { ModelFeatures } from '@/components/types'; export type SessionConfigurationProps = { title?: string; @@ -93,7 +94,16 @@ export const SessionConfiguration = ({ }; }); + const reasoningEffortOptions = [ + { value: 'none', label: 'None' }, + { value: 'minimal', label: 'Minimal' }, + { value: 'low', label: 'Low' }, + { value: 'medium', label: 'Medium' }, + { value: 'high', label: 'High' }, + { value: 'xhigh', label: 'X-High' }, + ]; const isImageModel = selectedModel?.modelType === ModelType.imagegen; + const isVideoModel = selectedModel?.modelType === ModelType.videogen; return ( - + updateSessionConfiguration('streaming', detail.checked)} checked={chatConfiguration.sessionConfiguration.streaming} @@ -126,7 +136,7 @@ export const SessionConfiguration = ({ > Show Message Metadata } - {systemConfig && systemConfig.configuration.enabledComponents.editChatHistoryBuffer && !isImageModel && !modelOnly && + {systemConfig && systemConfig.configuration.enabledComponents.editChatHistoryBuffer && !isImageModel && !isVideoModel && !modelOnly && } + {selectedModel?.features?.find((feature) => feature.name === ModelFeatures.REASONING) && + + { + updateSessionConfiguration('videoGenerationArgs', { + ...chatConfiguration.sessionConfiguration.videoGenerationArgs, + seconds: detail.selectedOption.value, + }); + }} + options={[ + { label: '4 seconds', value: '4' }, + { label: '8 seconds', value: '8' }, + { label: '12 seconds', value: '12' }, + ]} + /> + + + & BaseModelConf 'modelType': detail.selectedOption.value, }; - // turn off streaming for embedded models - if (fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen) { + // enable streaming by default for textgen models + if (fields.modelType === ModelType.textgen) { + fields['streaming'] = true; + } else if (fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen || fields.modelType === ModelType.videogen) { fields['streaming'] = false; } // turn off summarization and image input for embedded and imagegen models - if ((fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen)) { + if ((fields.modelType === ModelType.embedding || fields.modelType === ModelType.imagegen || fields.modelType === ModelType.videogen)) { fields['features'] = props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION && feature.name !== ModelFeatures.IMAGE_INPUT && feature.name !== ModelFeatures.TOOL_CALLS); } @@ -132,6 +143,7 @@ export function BaseModelConfig (props: FormProps & BaseModelConf options={[ { label: 'TEXTGEN', value: ModelType.textgen }, { label: 'IMAGEGEN', value: ModelType.imagegen }, + { label: 'VIDEOGEN', value: ModelType.videogen }, { label: 'EMBEDDING', value: ModelType.embedding }, ]} disabled={props.isEdit} @@ -193,7 +205,7 @@ export function BaseModelConfig (props: FormProps & BaseModelConf props.setFields({'streaming': detail.checked}) } onBlur={() => props.touchFields(['streaming'])} - disabled={isEmbeddingModel || isImageModel} + disabled={isEmbeddingModel || isImageModel || isVideoModel} checked={props.item.streaming} /> @@ -210,11 +222,29 @@ export function BaseModelConfig (props: FormProps & BaseModelConf props.setFields({'features': props.item.features.filter((feature) => feature.name !== ModelFeatures.TOOL_CALLS)}); } }} - disabled={isEmbeddingModel || isImageModel} + disabled={isEmbeddingModel || isImageModel || isVideoModel} onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.TOOL_CALLS) !== undefined} /> + + { + if (detail.checked && props.item.features.find((feature) => feature.name === ModelFeatures.REASONING) === undefined) { + props.setFields({'features': props.item.features.concat({name: ModelFeatures.REASONING, overview: ''})}); + } else if (!detail.checked && props.item.features.find((feature) => feature.name === ModelFeatures.REASONING) !== undefined) { + props.setFields({'features': props.item.features.filter((feature) => feature.name !== ModelFeatures.REASONING)}); + } + }} + disabled={isEmbeddingModel || isImageModel || isVideoModel} + onBlur={() => props.touchFields(['features'])} + checked={props.item.features.find((feature) => feature.name === ModelFeatures.REASONING) !== undefined} + /> + & BaseModelConf props.setFields({'features': props.item.features.filter((feature) => feature.name !== ModelFeatures.IMAGE_INPUT)}); } }} - disabled={isEmbeddingModel || isImageModel} + disabled={isEmbeddingModel || isImageModel || isVideoModel} onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.IMAGE_INPUT) !== undefined} /> @@ -247,7 +277,7 @@ export function BaseModelConfig (props: FormProps & BaseModelConf props.setFields({'features': props.item.features.filter((feature) => feature.name !== ModelFeatures.SUMMARIZATION)}); } }} - disabled={isEmbeddingModel || isImageModel} + disabled={isEmbeddingModel || isImageModel || isVideoModel} onBlur={() => props.touchFields(['features'])} checked={props.item.features.find((feature) => feature.name === ModelFeatures.SUMMARIZATION) !== undefined} /> diff --git a/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx b/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx index 80395f1f4..1e74417a2 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx @@ -16,7 +16,7 @@ import _ from 'lodash'; import { Modal, Wizard } from '@cloudscape-design/components'; -import { IModel, IModelRequest, ModelRequestSchema } from '../../../shared/model/model-management.model'; +import { IModel, IModelRequest, ModelRequestSchema, ModelRequestBaseSchema } from '../../../shared/model/model-management.model'; import { ReactElement, useEffect, useMemo, useState } from 'react'; import { scrollToInvalid, useValidationReducer } from '../../../shared/validation'; import { BaseModelConfig } from './BaseModelConfig'; @@ -67,7 +67,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { deleteScheduleMutation, { isSuccess: isScheduleDeleteSuccess, isError: isScheduleDeleteError, error: scheduleDeleteError, isLoading: isScheduleDeleting, reset: resetScheduleDelete }, ] = useDeleteScheduleMutation(); - const initialForm = ModelRequestSchema.partial().parse({}); + const initialForm = ModelRequestBaseSchema.partial().parse({}); const dispatch = useAppDispatch(); const notificationService = useNotificationService(dispatch); diff --git a/lib/user-interface/react/src/components/prompt-templates-library/PromptTemplateModal.tsx b/lib/user-interface/react/src/components/prompt-templates-library/PromptTemplateModal.tsx index a6e043b4c..97edeba50 100644 --- a/lib/user-interface/react/src/components/prompt-templates-library/PromptTemplateModal.tsx +++ b/lib/user-interface/react/src/components/prompt-templates-library/PromptTemplateModal.tsx @@ -82,8 +82,10 @@ export const PromptTemplateModal = ({ } }, [showModal, dispatch]); + const modalTestId = 'prompt-template-modal'; return ( { setShowModal(false); setUserPrompt(''); @@ -104,6 +106,7 @@ export const PromptTemplateModal = ({ Cancel + ariaLabel='Refresh prompt templates' + /> {PromptTemplatesActionButton(dispatch, notificationService, props, {isUserAdmin, username})} + {RepositoryActionButton(dispatch, notificationService, props)}