From e1d8a4917bbc9fecbbbe392e222e6d4eae8cd7d4 Mon Sep 17 00:00:00 2001 From: bedanley Date: Wed, 17 Sep 2025 10:35:50 -0600 Subject: [PATCH 01/14] Set endpointUrl on ECS Cluster (#397) --- ecs_model_deployer/src/lib/ecsCluster.ts | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 45b95cd95..280676e6a 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -124,9 +124,9 @@ export class ECSCluster extends Construct { // we want to set these based on the task created but currently the ECSCluster for model // will only create one task, so grab these values during creation so we can set the properties // on this class - let container; - let taskRole; - let endpointUrl; + let container: ContainerDefinition | undefined; + let taskRole: IRole | undefined; + let endpointUrl: string | undefined; Object.entries(ecsConfig.tasks).forEach(([, taskDefinition]) => { const environment = taskDefinition.environment; @@ -201,7 +201,7 @@ export class ECSCluster extends Construct { } const roleId = identifier; - const taskRole = taskRoleName ? + taskRole = taskRoleName ? Role.fromRoleName(this, createCdkId([config.deploymentName, roleId]), taskRoleName) : this.createTaskRole(config.deploymentName ?? '', config.deploymentPrefix, roleId); @@ -231,7 +231,7 @@ export class ECSCluster extends Construct { : undefined; const image = CodeFactory.createImage(taskDefinition.containerConfig.image, this, identifier, ecsConfig.buildArgs); - const container = ec2TaskDefinition.addContainer(createCdkId([identifier, 'Container']), { + container = ec2TaskDefinition.addContainer(createCdkId([identifier, 'Container']), { containerName: createCdkId([config.deploymentName, identifier], 32, 2), image, environment, @@ -322,15 +322,20 @@ export class ECSCluster extends Construct { const domain = loadBalancer.loadBalancerDnsName; endpointUrl = `${protocol}://${domain}`; + }); - new CfnOutput(this, 'modelEndpointurl', { - key: 'modelEndpointUrl', - value: this.endpointUrl, - }); + // Validate endpointUrl is set before creating output + if (!endpointUrl) { + throw new Error('Failed to create endpoint URL - no tasks configured'); + } + + new CfnOutput(this, 'modelEndpointurl', { + key: 'modelEndpointUrl', + value: endpointUrl, }); // Update - this.endpointUrl = endpointUrl!; + this.endpointUrl = endpointUrl; this.container = container!; this.taskRole = taskRole!; } From 70f677df80da9eb578682d9e025d99c68a261577 Mon Sep 17 00:00:00 2001 From: bedanley Date: Fri, 19 Sep 2025 11:33:18 -0600 Subject: [PATCH 02/14] Feature/adc build (#389) * Add ADC build and docs --- Makefile | 7 +- bin/build-assets | 7 + bin/build-images | 143 +++++++++++++++ bin/build-lambdas | 47 +++++ bin/copy-deps.sh | 79 -------- bin/package-lambda-layer | 151 ++++++++++----- lib/docs/admin/deploy.md | 183 +++++++++++++------ lib/rag/ingestion/ingestion-image/Dockerfile | 3 +- lib/rag/layer/requirements.txt | 9 +- lib/serve/rest-api/Dockerfile | 2 +- package.json | 5 +- requirements-dev.txt | 4 +- 12 files changed, 449 insertions(+), 191 deletions(-) create mode 100755 bin/build-assets create mode 100755 bin/build-images create mode 100755 bin/build-lambdas delete mode 100755 bin/copy-deps.sh diff --git a/Makefile b/Makefile index a12fb4dfa..84af21cd8 100644 --- a/Makefile +++ b/Makefile @@ -143,9 +143,9 @@ else endif -## Set up Python interpreter environment +## Set up Python interpreter environment to match LISA deployed version createPythonEnvironment: - python3 -m venv .venv + python3.11 -m venv .venv @printf ">>> New virtual environment created. To activate run: 'source .venv/bin/activate'" @@ -269,6 +269,9 @@ listStacks: buildNpmModules: npm run build +buildArchive: + BUILD_ASSETS=true npm run build + define print_config @printf "\n \ DEPLOYING $(STACK) STACK APP INFRASTRUCTURE \n \ diff --git a/bin/build-assets b/bin/build-assets new file mode 100755 index 000000000..4a214ae4a --- /dev/null +++ b/bin/build-assets @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +ROOT=$(pwd) + +./bin/build-lambdas +./bin/build-images --export diff --git a/bin/build-images b/bin/build-images new file mode 100755 index 000000000..f79d0adb7 --- /dev/null +++ b/bin/build-images @@ -0,0 +1,143 @@ +#!/bin/bash + +set -e + +ROOT=$(pwd) +OUTPUT_DIR=$ROOT/dist/images +DOCKER_CMD=$(command -v finch >/dev/null 2>&1 && echo "finch" || echo "docker") + +# Parse command line arguments +UPLOAD=false +EXPORT=false +for arg in "$@"; do + case $arg in + --upload) + UPLOAD=true + shift + ;; + --export) + EXPORT=true + mkdir -p $OUTPUT_DIR + shift + ;; + esac +done + +# Default LISA_VERSION if not set +LISA_VERSION=${LISA_VERSION:-$(cat ./VERSION 2>/dev/null || echo "latest")} + +# ECR configuration +ACCOUNT=${AWS_ACCOUNT:-""} +REGION=${AWS_REGION:-"us-east-1"} +DOMAIN=${AWS_DOMAIN:-"amazonaws.com"} +ECR_BASE_URL=$ACCOUNT.dkr.ecr.$REGION.$DOMAIN + +# Function to build a single image +build_image() { + local dockerfile_path="$1" + local repository_name="$2" + local image_tag="$3" + local build_context_path="$4" + shift 4 + local build_args=("$@") + + echo "Building image: $repository_name:$image_tag" + echo "Context: $build_context_path" + + # Construct docker build command + local docker_cmd="$DOCKER_CMD build" + + # Add build args + for arg in "${build_args[@]}"; do + docker_cmd="$docker_cmd --build-arg $arg" + done + + # Add dockerfile, tag, and context + docker_cmd="$docker_cmd -f $build_context_path/$dockerfile_path -t $repository_name:$image_tag $build_context_path" + + echo "Executing: $docker_cmd" + eval "$docker_cmd" + echo "Successfully built $repository_name:$image_tag" + + # Upload to ECR if --upload flag is set + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + local ecr_tag="$ECR_BASE_URL/$repository_name:$image_tag" + echo "Tagging for ECR: $ecr_tag" + $DOCKER_CMD tag "$repository_name:$image_tag" "$ecr_tag" + echo "Pushing to ECR: $ecr_tag" + $DOCKER_CMD push "$ecr_tag" + echo "Successfully pushed $ecr_tag" + fi + + # Export image if --export flag is set + if [[ "$EXPORT" == "true" ]]; then + local export_file="$OUTPUT_DIR/${repository_name}_${image_tag}.tar" + echo "Exporting image to: $export_file" + $DOCKER_CMD save "$repository_name:$image_tag" -o "$export_file" + echo "Successfully exported $export_file" + fi + echo "" +} + +# Function to login to ECR +ecr_login() { + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + echo "Logging into ECR..." + aws ecr get-login-password --region $REGION | $DOCKER_CMD login --username AWS --password-stdin $ACCOUNT.dkr.ecr.$REGION.$DOMAIN + echo "ECR login successful" + echo "" + fi +} + +# Main function to build all images +build_all_images() { + echo "Starting Docker image builds..." + echo "LISA_VERSION: $LISA_VERSION" + if [[ "$UPLOAD" == "true" && -n "$ACCOUNT" ]]; then + echo "ECR_BASE_URL: $ECR_BASE_URL" + echo "Upload: Enabled" + else + echo "Upload: Disabled" + fi + echo "" + + ecr_login + + # lisa-rest-api + python3 scripts/cache-tiktoken-for-offline.py ./lib/serve/rest-api/TIKTOKEN_CACHE + build_image "Dockerfile" "lisa-rest-api" "$LISA_VERSION" "./lib/serve/rest-api" \ + "NODE_ENV=production" \ + "LITELLM_CONFIG=\"db_key: sk-a8814208-0388-480c-9fc7-fea59607ca38\"" \ + "BASE_IMAGE=python:3.11" + + # lisa-batch-ingestion + RAG_DIR="./lib/rag/ingestion/ingestion-image" + BUILD_DIR="${RAG_DIR}/build" + mkdir -p "$BUILD_DIR" + rsync -av --exclude='__pycache__' ./lambda/ "$BUILD_DIR/" + rsync -av --exclude='__pycache__' ./lisa-sdk/lisapy/ "$BUILD_DIR/lisapy/" + build_image "Dockerfile" "lisa-batch-ingestion" "$LISA_VERSION" "$RAG_DIR" "NODE_ENV=production" + + # lisa-tei + build_image "Dockerfile" "lisa-tei" "latest" "./lib/serve/ecs-model/embedding/tei" \ + "NODE_ENV=production" \ + "BASE_IMAGE=ghcr.io/huggingface/text-embeddings-inference:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + # lisa-tgi + build_image "Dockerfile" "lisa-tgi" "latest" "./lib/serve/ecs-model/textgen/tgi" \ + "NODE_ENV=production" \ + "BASE_IMAGE=ghcr.io/huggingface/text-generation-inference:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + # lisa-vllm + build_image "Dockerfile" "lisa-vllm" "latest" "./lib/serve/ecs-model/vllm" \ + "NODE_ENV=production" \ + "BASE_IMAGE=vllm/vllm-openai:latest" \ + "MOUNTS3_DEB_URL=https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb" + + echo "All images built successfully!" +} + +# Run the build +build_all_images diff --git a/bin/build-lambdas b/bin/build-lambdas new file mode 100755 index 000000000..ecfa139a7 --- /dev/null +++ b/bin/build-lambdas @@ -0,0 +1,47 @@ +#!/bin/bash +set -e + +ROOT=$(pwd) +OUTPUT_DIR=$ROOT/dist/layers +mkdir -p $OUTPUT_DIR + +PYPI_URL=${PYPI_URL:-https://pypi.org/simple/} +source .venv/bin/activate + +build_layer() { + local package_name=$1 + local source_path=$2 + local pre_build_cmd=$3 + echo "Building Lambda Layer $package_name from $source_path..." + + if [ -n "$pre_build_cmd" ]; then + eval "$pre_build_cmd" + fi + + cd $source_path + $ROOT/bin/package-lambda-layer --src . --output "$package_name.zip" --pypi $PYPI_URL --layer + mv ./build/"$package_name.zip" $OUTPUT_DIR/ + rm -rf ./build + cd $ROOT +} + +build_lambda() { + local package_name=$1 + local source_path=$2 + echo "Building Lambda $package_name from $source_path..." + cd "$source_path" + $ROOT/bin/package-lambda-layer --src . --output "$package_name.zip" --pypi $PYPI_URL + mv ./build/"$package_name.zip" $OUTPUT_DIR/ + rm -rf ./build + cd $ROOT +} + +echo "Building Python Lambda Layers..." +build_layer "AimlAdcLisaCommonLayer" "./lib/core/layers/common" +build_layer "AimlAdcLisaAuthLayer" "./lib/core/layers/authorizer" +build_layer "AimlAdcLisaFastApiLayer" "./lib/core/layers/fastapi" +build_layer "AimlAdcLisaRag" "./lib/rag/layer" "python3 scripts/cache-tiktoken-for-offline.py ./lib/rag/layer/TIKTOKEN_CACHE" +build_layer "AimlAdcLisaSdk" "./lisa-sdk" + +echo "Building Python Lambdas..." +build_lambda "AimlAdcLisaLambda" "./lambda" diff --git a/bin/copy-deps.sh b/bin/copy-deps.sh deleted file mode 100755 index c972a7247..000000000 --- a/bin/copy-deps.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -function install_python_deps() { - local input_path=$1 - local output_path=$2 - local package=$3 - - echo "Installing Python dependencies for $package" - mkdir -p "${output_path}" - if ! pip install -r ${input_path}/requirements.txt --target $output_path --platform manylinux2014_x86_64 --only-binary=:all: --no-deps --no-cache-dir; then - echo "Failed to install Python dependencies for ${package}" - exit 1 - fi - - echo "${package} dependencies installed successfully" - rsync -a "${input_path}/" "${output_path}" - - echo "Optimizing ${package}" - find $output_path -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null - find $output_path -type d -name "*.dist-info" -exec rm -rf {} + 2>/dev/null - find $output_path -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null - find $output_path -type f -name "*.pyc" -delete - find $output_path -type f -name "*.pyo" -delete - find $output_path -type f -name "*.so" -exec strip {} + 2>/dev/null -} - -function setup_python_dist(){ - cd dist - - # Define the layers - PYTHON_VERSION="3.11" - DIST="." - OUTPUT_DIR="python/lib/${PYTHON_VERSION}/site-packages" - - # Create a virtual environment for isolation - python -m venv .venv - source .venv/bin/activate - - # # Install dependencies for each lambda layer - layers=("authorizer" "common" "fastapi") - layers_path="../lib/core/layers" - layers_output="${DIST}/lambdaLayer" - for layer in "${layers[@]}"; do - # ./package-lambda-layer --src="${layers_path}/${layer}" --build="./dist/$layer" --output="${layers_output}/${layer}" - layer_path="${layers_path}/${layer}" - layer_output="${layers_output}/${layer}/${OUTPUT_DIR}" - install_python_deps $layer_path $layer_output $layer - done - - # Install rag layer - rag_path="../lib/rag/layer" - rag_output="${DIST}/rag/${OUTPUT_DIR}" - rag_package="rag" - install_python_deps $rag_path $rag_output $rag_package - - # Install lisa-sdk dependencies - sdk_path="../lisa-sdk" - sdk_output="${DIST}/lisa-sdk/${OUTPUT_DIR}" - sdk_package="lisa-sdk" - install_python_deps $sdk_path $sdk_output $sdk_package - - # Deactivate virtual environment - deactivate - rm -rf .venv - echo "All Python dependencies installed successfully" - cd - -} - -function copy_dist() { - mkdir -p dist/ecs_model_deployer && rsync -av ecs_model_deployer/dist/ dist/ecs_model_deployer/ && cp ecs_model_deployer/Dockerfile dist/ecs_model_deployer/ - mkdir -p dist/vector_store_deployer && rsync -av vector_store_deployer/dist/ dist/vector_store_deployer/ && cp vector_store_deployer/Dockerfile dist/vector_store_deployer/ - mkdir -p dist/lisa-web && rsync -av lib/user-interface/react/dist/ dist/lisa-web - mkdir -p dist/docs && rsync -av lib/docs/dist/ dist/docs - cp VERSION dist/ -} - -mkdir -p dist -# setup_python_dist -copy_dist diff --git a/bin/package-lambda-layer b/bin/package-lambda-layer index 18591b8ca..16cc05c1a 100755 --- a/bin/package-lambda-layer +++ b/bin/package-lambda-layer @@ -4,76 +4,141 @@ set -e SRC=src OUTPUT=Lambda.zip EXCLUDE_PACKAGES="" -SRC_ROOT=$PWD -BUILD_DIR=$SRC_ROOT/build +BUILD_DIR=$PWD/build +IS_LAYER=0 +TMP_DIR=$BUILD_DIR/tmp/ +PYPI_URL= # Parse named parameters while [ $# -gt 0 ]; do - case "$1" in - --src=*) - SRC="${1#*=}" - ;; - --output=*) - OUTPUT="${1#*=}" - ;; - --build=*) - BUILD_DIR="${1#*=}" - ;; - --exclude=*) - EXCLUDE_PACKAGES="${1#*=}" - ;; - *) - echo "Unknown parameter: $1" - echo "Usage: $0 --src= --output= --exclude=" - exit 1 - ;; - esac + if [[ $1 == *"="* ]]; then + # Handle --param=value style + param="${1%%=*}" + value="${1#*=}" + + case "$param" in + --src) + SRC="$value" + ;; + --output) + OUTPUT="$value" + ;; + --build) + BUILD_DIR="$value" + ;; + --exclude) + EXCLUDE_PACKAGES="$value" + ;; + --pypi) + PYPI_URL="$value" + ;; + --layer) + IS_LAYER=1 + ;; + *) + echo "Unknown parameter: $param" + echo "Usage: $0 --src --output --exclude --layer" + exit 1 + ;; + esac + else + # Handle --param value style + case "$1" in + --src) + shift + SRC="$1" + ;; + --output) + shift + OUTPUT="$1" + ;; + --build) + shift + BUILD_DIR="$1" + TMP_DIR=$BUILD_DIR/tmp/python/ + ;; + --exclude) + shift + EXCLUDE_PACKAGES="$1" + ;; + --pypi) + shift + PYPI_URL="$1" + ;; + --layer) + IS_LAYER=1 + ;; + *) + echo "Unknown parameter: $1" + echo "Usage: $0 --src --output --exclude " + exit 1 + ;; + esac + fi shift done +echo "Starting" +if [ $IS_LAYER -eq 1 ]; then + TMP_DIR=$BUILD_DIR/tmp/python/ +fi + +if [ -z "$PYPI_URL" ]; then + echo "Must supply PYPI_URL via --pypi" + exit 1 +fi + +# Extract IP from PYPI_URL for trusted host +TRUSTED_HOST=$(echo $PYPI_URL | sed 's|http://||' | sed 's|/.*||') + +# Print parameters for debugging +echo "Source directory: $SRC" +echo "Output file: $OUTPUT" +echo "Build directory: $BUILD_DIR" +echo "Temp directory: $TMP_DIR" + + install_requirements() { - echo "installing requirements" - rm -rf "$BUILD_DIR" - mkdir -p "$BUILD_DIR/python" - python3 -m pip install "$SRC_ROOT" --target "${BUILD_DIR}/python" + echo "Installing requirements" + rm -rf "$TMP_DIR" + mkdir -p "$TMP_DIR" + if [ -f "$SRC/requirements.txt" ]; then + echo "Installing requirements from $SRC/requirements.txt" + echo "Using python version $(python3 --version)" + python3 -m pip install -r "$SRC/requirements.txt" --force-reinstall --no-cache-dir --target "$TMP_DIR" --index-url $PYPI_URL --trusted-host $TRUSTED_HOST + else + echo "No requirements.txt found in $SRC" + fi } build_package() { - echo "building package" + echo "Building package" if [ -d "$SRC" ]; then - cp -r "$SRC"/* "${BUILD_DIR}/python/" - fi -} - -copy_configuration() { - echo "copying configuration" - if [ -d "configuration/Packaging" ]; then - cp -a configuration/Packaging "$BUILD_DIR" + rsync -av --exclude='build' --exclude='.hatch' --exclude='.venv' "$SRC/" "$TMP_DIR/" fi } package_artifacts() { - echo "packaging" + echo "Packaging" if [ -n "$EXCLUDE_PACKAGES" ]; then echo "Removing excluded packages: $EXCLUDE_PACKAGES" for pkg in ${EXCLUDE_PACKAGES//,/ }; do echo "Removing $pkg" - rm -rf ${BUILD_DIR}/python/${pkg} - rm -rf ${BUILD_DIR}/python/${pkg}-* + rm -rf ${TMP_DIR}/${pkg} + rm -rf ${TMP_DIR}/${pkg}-* # Also remove egg-info directories - find "${BUILD_DIR}/python" -type d -name "${pkg}*egg-info" -exec rm -rf {} + + find "$TMP_DIR" -type d -name "${pkg}*egg-info" -exec rm -rf {} + done fi # AWS Lambda recommends to exclude __pycache__: https://docs.aws.amazon.com/lambda/latest/dg/python-package.html#python-package-pycache - find "${BUILD_DIR}/python" -depth -name __pycache__ -exec rm -rf {} \; - cd "${BUILD_DIR}" - zip "${BUILD_DIR}/${OUTPUT}" ./python -r - rm -rf "${BUILD_DIR}/python" + find "${TMP_DIR}" -depth -name __pycache__ -exec rm -rf {} \; + cd "${BUILD_DIR}/tmp/" + zip -r "${BUILD_DIR}/${OUTPUT}" . + rm -rf "${BUILD_DIR}/tmp" } install_requirements build_package -copy_configuration package_artifacts diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index 33765d25a..e7096ca54 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -195,77 +195,144 @@ make bootstrap ``` ## ADC Region Deployment Tips -If you are deploying LISA into an ADC region with limited access to dependencies, we recommend that you build LISA in a -commercial region first, and then bring it up into your ADC region to deploy. First, do the npm and pip installs on a -computer with access to the dependencies. Then bundle it up with the libraries included and move into the ADC region. -Some properties will need to be set in the deployment file pointing to the built artifacts. From there the deployment -process is the same. - -### Using pre-built resources - -A default configuration will build the necessary containers, lambda layers, and production optimized -web application at build time. In the event that you would like to use pre-built resources due to -network connectivity reasons or other concerns with the environment where you'll be deploying LISA -you can do so. - -- For ECS containers (Models, APIs, etc) you can modify the `containerConfig` block of - the corresponding entry in `config.yaml`. For container images you can provide a path to a directory - from which a docker container will be built (default), a path to a tarball, an ECR repository arn and - optional tag, or a public registry path. - - We provide immediate support for HuggingFace TGI and TEI containers and for vLLM containers. The `example_config.yaml` - file provides examples for TGI and TEI, and the only difference for using vLLM is to change the - `inferenceContainer`, `baseImage`, and `path` options, as indicated in the snippet below. All other options can - remain the same as the model definition examples we have for the TGI or TEI models. vLLM can also support embedding - models in this way, so all you need to do is refer to the embedding model artifacts and remove the `streaming` field - to deploy the embedding model. - - vLLM has support for the OpenAI Embeddings API, but model support for it is limited because the feature is new. Currently, - the only supported embedding model with vLLM is [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct), - but this list is expected to grow over time as vLLM updates. - ```yaml - ecsModels: - - modelName: your-model-name - inferenceContainer: tgi - baseImage: ghcr.io/huggingface/text-generation-inference:2.0.1 - ``` -- If you are deploying the LISA Chat User Interface you can optionally specify the path to the pre-built - website assets using the top level `webAppAssetsPath` parameter in `config.yaml`. Specifying this path - (typically `lib/user-interface/react/dist`) will avoid using a container to build and bundle the assets - at CDK build time. -- For the lambda layers you can specify the path to a local zip archive of the layer code by including - the optional `lambdaLayerAssets` block in `config.yaml` similar to the following: +Amazon Dedicated Cloud (ADC) regions are isolated AWS environments designed for government customers' most sensitive workloads. These regions have restricted internet access and limited external dependencies, requiring special deployment considerations for LISA. -``` +There are two deployment approaches for ADC regions: + +1. **Pre-built Resources (Recommended)**: Build all components in a commercial region, then transfer to ADC +2. **In-Region Building**: Configure LISA to use ADC-accessible repositories for building components + +### Approach 1: Pre-built Resources (Recommended) + +This approach builds all necessary components in a commercial region with full internet access, then transfers them to the ADC region. + +#### Step 1: Build Components in Commercial Region + +1. Set up LISA in a commercial AWS region with internet access +2. Build all components: + ```bash + make buildArchive + ``` + This generates: + - Lambda function zip files in `./dist/layers/*.zip` + - Docker images exported as `./dist/images/*.tar` files + +#### Step 2: Transfer to ADC Region + +1. Upload Docker images to ECR in your ADC region: + ```bash + # Load and tag images + docker load -i lisa-rest-api.tar + docker tag lisa-rest-api:latest .dkr.ecr..amazonaws.com/lisa-rest-api:latest + + # Push to ADC ECR + aws ecr get-login-password --region | docker login --username AWS --password-stdin .dkr.ecr..amazonaws.com + docker push .dkr.ecr..amazonaws.com/lisa-rest-api:latest + ``` + You'll want to repeat this for lisa-batch-ingestion, as well as any of the LISA base model hosting containers (lisa-vllm, lisa-tgi, lisa-tei) + +2. Transfer built artifacts to ADC environment + +#### Step 3: Configure LISA for Pre-built Resources + +Update your `config-custom.yaml` in the ADC region: + +```yaml +# Lambda layers from pre-built archives lambdaLayerAssets: - authorizerLayerPath: lib/core/layers/authorizer_layer.zip - commonLayerPath: lib/core/layers/common_layer.zip - fastapiLayerPath: /path/to/fastapi_layer.zip - sdkLayerPath: lib/rag/layers/sdk_layer.zip + authorizerLayerPath: './dist/layers/AimlAdcLisaAuthLayer.zip' + commonLayerPath: './dist/layers/AimlAdcLisaCommonLayer.zip' + fastapiLayerPath: './dist/layers/AimlAdcLisaFastApiLayer.zip' + ragLayerPath: './dist/layers/AimlAdcLisaRag.zip' + sdkLayerPath: './dist/layers/AimlAdcLisaSdk.zip' + +# Lambda functions +lambdaPath: './dist/layers/AimlAdcLisaLambda.zip' + +# Pre-built web assets +webAppAssetsPath: './dist/lisa-web' +documentsPath: './dist/docs' +ecsModelDeployerPath: './dist/ecs_model_deployer' +vectorStoreDeployerPath: './dist/vector_store_deployer' + +# Container images from ECR +batchIngestionConfig: + type: external + code: .dkr.ecr..amazonaws.com/lisa-batch-ingestion:latest + +restApiConfig: + imageConfig: + type: external + code: .dkr.ecr..amazonaws.com/lisa-rest-api:latest ``` -### Deploying in ADC region -Now that we have everything setup we are ready to deploy. -```bash -make deploy -``` +### Approach 2: In-Region Building -By default, all stacks will be deployed but a particular stack can be deployed by providing the `STACK` argument to the `deploy` target. +This approach configures LISA to build components using repositories accessible from within the ADC region. -```bash -make deploy STACK=LisaServe -``` +#### Prerequisites +- ADC-accessible package repositories (PyPI mirror, npm registry, container registry) +- ADC-accessible container registries +- Network connectivity to required build dependencies -Available stacks can be listed by running: +#### Configuration -```bash -make listStacks +Update your `config-custom.yaml` to point to ADC-accessible repositories: + +```yaml +# Configure pip to use ADC-accessible PyPI mirror +pipConfig: + indexUrl: https://your-adc-pypi-mirror.com/simple + trustedHost: your-adc-pypi-mirror.com + +# Configure npm to use ADC-accessible registry +npmConfig: + registry: https://your-adc-npm-registry.com + +# Use ADC-accessible base images for LISA-Serve and Batch Ingestion +baseImage: /python:3.11 ``` +You'll also want any model hosting base containers available, e.g. vllm/vllm-openai:latest and ghcr.io/huggingface/text-embeddings-inference:latest + +To utilize the prebuilt hosting model containers with self-hosted models, select `type: ecr` in the Model Deployment > Container Configs. -After the `deploy` command is run, you should see many docker build outputs and eventually a CDK progress bar. The deployment should take about 10-15 minutes and will produce a single cloud formation output for the websocket URL. +### Deployment Steps -You can test the deployment with the integration test: +Once your configuration is complete: + +1. Bootstrap CDK (if not already done): + ```bash + make bootstrap + ``` + +2. Deploy LISA: + ```bash + make deploy + ``` + +3. Deploy specific stacks if needed: + ```bash + make deploy STACK=LisaServe + ``` + +4. List available stacks: + ```bash + make listStacks + ``` + +### Testing Your Deployment + +After deployment completes (10-15 minutes), test with: ```bash -pytest lisa-sdk/tests --url --verify | false +pytest lisa-sdk/tests --url --verify ``` + +### Troubleshooting ADC Deployments + +- **Build failures**: Ensure all dependencies are accessible from ADC region +- **Container pull errors**: Verify ECR repositories exist and have correct permissions +- **Lambda deployment issues**: Check that lambda zip files are properly formatted and accessible +- **Network connectivity**: Confirm VPC configuration allows required outbound connections diff --git a/lib/rag/ingestion/ingestion-image/Dockerfile b/lib/rag/ingestion/ingestion-image/Dockerfile index 22ae46154..984fc3aa7 100644 --- a/lib/rag/ingestion/ingestion-image/Dockerfile +++ b/lib/rag/ingestion/ingestion-image/Dockerfile @@ -1,4 +1,5 @@ -FROM public.ecr.aws/lambda/python:3.11 +ARG BASE_IMAGE=public.ecr.aws/lambda/python:3.11 +FROM ${BASE_IMAGE} ARG BUILD_DIR=build diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index 36d98af3e..e2ef8b0a7 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -6,8 +6,9 @@ langchain-community==0.3.9 langchain-openai==0.2.11 opensearch-py==2.6.0 pgvector==0.2.5 -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.10 pypdf==6.0.0 -lxml==5.1.0 -python-docx==1.1.0 -requests-aws4auth==1.2.3 +lxml==5.3.0 +python-docx==1.1.2 +requests-aws4auth==1.3.1 +tiktoken==0.9.0 diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index ceb768db6..1831629a6 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE +ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} # Copy LiteLLM config directly out of the LISA config.yaml file diff --git a/package.json b/package.json index 99dcedbfe..d807c3007 100644 --- a/package.json +++ b/package.json @@ -33,10 +33,11 @@ "cypress" ], "scripts": { - "build": "tsc && npm run build -ws", + "build": "if [ \"$BUILD_ASSETS\" = \"true\" ]; then npm run build:assets; fi", + "build:assets": "./bin/build-assets", "deploy": "tsx ./bin/lisa.ts", "copy-dist": "cp VERSION ./dist/", - "clean": "npm run clean -ws && rm -rf dist node_modules cdk.out build", + "clean": "npm run clean -ws && rm -rf dist node_modules cdk.out build lib/rag/layer/TIKTOKEN_CACHE lib/serve/rest-api/TIKTOKEN_CACHE", "watch": "tsc -w", "test": "jest", "cdk": "cdk", diff --git a/requirements-dev.txt b/requirements-dev.txt index 4533e6e59..ed9330934 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,13 +16,15 @@ pypdf==6.0.0 langchain-community==0.3.9 langchain-openai==0.2.11 langchain==0.3.9 +--only-binary=pyarrow,lxml,psycopg2-binary +pyarrow # Testing pytest==8.3.2 pytest-cov==4.1.0 moto[all]==5.0.3 coverage==7.4.4 -lxml==5.1.0 +lxml==5.3.0 opensearch-py==2.8.0 requests_aws4auth==1.3.1 PyJWT==2.8.0 From 0842dc4534573b2243b0e8bfb0ef97bcad7c8562 Mon Sep 17 00:00:00 2001 From: Bear Danley Date: Fri, 19 Sep 2025 18:36:42 +0000 Subject: [PATCH 03/14] Update package.json build --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index d807c3007..132923776 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,7 @@ "cypress" ], "scripts": { - "build": "if [ \"$BUILD_ASSETS\" = \"true\" ]; then npm run build:assets; fi", + "build": "tsc && npm run build -ws && if [ \"$BUILD_ASSETS\" = \"true\" ]; then npm run build:assets; fi", "build:assets": "./bin/build-assets", "deploy": "tsx ./bin/lisa.ts", "copy-dist": "cp VERSION ./dist/", From 31b56a3b3a8ea358a886eaa418ebc56ab5412fb8 Mon Sep 17 00:00:00 2001 From: bedanley Date: Tue, 23 Sep 2025 15:47:04 -0600 Subject: [PATCH 04/14] Consolidate Embedding Client --- Makefile | 5 +- lambda/authorizer/lambda_functions.py | 4 +- lambda/models/state_machine/create_model.py | 8 +- lambda/repository/embeddings.py | 97 ++- lambda/repository/lambda_functions.py | 113 ++- .../repository/pipeline_ingest_documents.py | 12 +- lambda/utilities/auth.py | 19 + lambda/utilities/constants.py | 1 + lambda/utilities/file_processing.py | 19 +- lambda/utilities/repository_types.py | 15 +- lambda/utilities/vector_store.py | 2 +- lib/api-base/fastApiContainer.ts | 1 + lib/core/layers/authorizer/requirements.txt | 3 +- lib/core/layers/common/requirements.txt | 4 +- lib/core/layers/fastapi/requirements.txt | 2 +- lib/rag/api/repository.ts | 10 + lib/rag/layer/requirements.txt | 2 +- .../src/api/endpoints/v1/embeddings.py | 2 +- .../src/api/endpoints/v1/generation.py | 8 +- .../rest-api/src/api/endpoints/v1/models.py | 8 +- .../api/endpoints/v2/litellm_passthrough.py | 70 +- lib/serve/rest-api/src/api/routes.py | 31 +- lib/serve/rest-api/src/auth.py | 268 +++++-- lib/serve/rest-api/src/entrypoint.sh | 2 + lib/serve/rest-api/src/handlers/generation.py | 9 +- lib/serve/rest-api/src/handlers/models.py | 4 +- lib/serve/rest-api/src/lisa_serve/__init__.py | 2 +- .../lisa_serve/ecs/embedding/instructor.py | 28 +- .../src/lisa_serve/ecs/textgen/tgi.py | 4 +- .../rest-api/src/lisa_serve/registry/index.py | 2 +- lib/serve/rest-api/src/requirements.txt | 1 + lib/serve/rest-api/src/utils/cache_manager.py | 20 +- lib/serve/rest-api/src/utils/decorators.py | 30 + .../src/utils/generate_litellm_config.py | 6 +- lib/serve/rest-api/src/utils/request_utils.py | 22 +- lib/serve/rest-api/src/utils/resources.py | 2 +- .../react/src/components/chatbot/Chat.tsx | 11 +- .../chatbot/utils/messageBuilder.utils.tsx | 16 +- .../ModelManagementActions.tsx | 8 +- .../ModelManagementComponent.tsx | 1 + .../create-model/BaseModelConfig.tsx | 2 +- .../shared/model/model-management.model.ts | 4 +- lisa-sdk/lisapy/langchain.py | 2 +- lisa-sdk/lisapy/model.py | 94 ++- lisa-sdk/lisapy/repository.py | 109 +++ lisa-sdk/lisapy/types.py | 51 +- lisa-sdk/pyproject.toml | 6 +- .../tests/test_langchain_management_key.py | 57 ++ package-lock.json | 8 + requirements-dev.txt | 1 + test/cdk/stacks/nag.test.ts | 86 --- test/lambda/test_file_processing.py | 55 ++ test/lambda/test_pipeline_ingest_documents.py | 6 +- test/lambda/test_repository_lambda.py | 360 ++++------ test/lambda/test_similarity_functions.py | 203 ++++++ test/python/README.md | 126 ++++ test/python/integration-setup-test.py | 665 ++++++++++++++++++ test/python/integration-setup-test.sh | 153 ++++ 58 files changed, 2256 insertions(+), 604 deletions(-) create mode 100644 lib/serve/rest-api/src/utils/decorators.py create mode 100644 lisa-sdk/tests/test_langchain_management_key.py delete mode 100644 test/cdk/stacks/nag.test.ts create mode 100644 test/lambda/test_similarity_functions.py create mode 100644 test/python/README.md create mode 100644 test/python/integration-setup-test.py create mode 100755 test/python/integration-setup-test.sh diff --git a/Makefile b/Makefile index 84af21cd8..3e9d265ac 100644 --- a/Makefile +++ b/Makefile @@ -206,7 +206,10 @@ modelCheck: fi; \ echo "Converting and uploading safetensors for model: $$MODEL_ID"; \ tgiImage=$$(yq -r '[.ecsModels[] | select(.inferenceContainer == "tgi") | .baseImage] | first' $(PROJECT_DIR)/config-custom.yaml); \ - echo $$tgiImage; \ + if [ "$$tgiImage" = "null" ] || [ -z "$$tgiImage" ]; then \ + tgiImage="ghcr.io/huggingface/text-generation-inference:latest"; \ + fi; \ + echo "Using TGI image: $$tgiImage"; \ $(PROJECT_DIR)/scripts/convert-and-upload-model.sh -m $$MODEL_ID -s $(MODEL_BUCKET) -a $$access_token -t $$tgiImage -d $$localModelDir; \ fi; \ fi; \ diff --git a/lambda/authorizer/lambda_functions.py b/lambda/authorizer/lambda_functions.py index ed6716e89..db1aca7b9 100644 --- a/lambda/authorizer/lambda_functions.py +++ b/lambda/authorizer/lambda_functions.py @@ -18,7 +18,6 @@ import os import ssl from datetime import datetime -from functools import cache from typing import Any, Dict import boto3 @@ -26,6 +25,7 @@ import jwt import requests from botocore.exceptions import ClientError +from cachetools import cached, TTLCache from utilities.common_functions import authorization_wrapper, get_id_token, get_property_path, retry_config logger = logging.getLogger(__name__) @@ -203,7 +203,7 @@ def find_jwt_username(jwt_data: dict[str, str]) -> str: return username -@cache +@cached(cache=TTLCache(maxsize=1, ttl=300)) def get_management_tokens() -> list[str]: """Return secret management tokens if they exist.""" secret_tokens: list[str] = [] diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 769acbaa7..0af97fead 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -269,11 +269,17 @@ def camelize_object(o): # type: ignore[no-untyped-def] stack_name = payload.get("stackName", None) if not stack_name: + # Log the full payload for debugging + logger.error(f"ECS Model Deployer response: {payload}") + error_message = payload.get("errorMessage", "Unknown error") + error_type = payload.get("errorType", "Unknown error type") + raise StackFailedToCreateException( json.dumps( { - "error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.", + "error": f"Failed to create Model CloudFormation Stack. {error_type}: {error_message}", "event": event, + "deployer_response": payload, } ) ) diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py index 9f4863baf..a1fbe7faa 100644 --- a/lambda/repository/embeddings.py +++ b/lambda/repository/embeddings.py @@ -14,13 +14,15 @@ import logging import os -from typing import Any, List +from typing import List import boto3 import requests from lisapy.langchain import LisaOpenAIEmbeddings -from utilities.common_functions import get_cert_path, retry_config -from utilities.validation import ValidationError +from pydantic import BaseModel, field_validator +from utilities.auth import get_management_key +from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config +from utilities.validation import validate_model_name, ValidationError logger = logging.getLogger(__name__) ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) @@ -30,39 +32,47 @@ lisa_api_endpoint = "" -class PipelineEmbeddings: +class RagEmbeddings(BaseModel): """ - Handles document embeddings for pipeline processing using management credentials. - - This class provides methods to embed both single queries and batches of documents - using the LISA API with management-level authentication. + Handles document embeddings through LiteLLM using management credentials. """ model_name: str - - def __init__(self, model_name: str) -> None: + token: str + lisa_api_endpoint: str + base_url: str + cert_path: str | bool + + @field_validator("model_name") + @classmethod + def validate_model_name(cls, v: str) -> str: + validate_model_name(v) + return v + + def __init__(self, model_name: str, id_token: str | None = None, **data) -> None: + # Prepare initialization data + init_data = {"model_name": model_name, **data} try: - self.model_name = model_name - # Get the management key secret name from SSM Parameter Store - secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) - secret_name = secret_name_param["Parameter"]["Value"] - - # Get the management token from Secrets Manager using the secret name - secret_response = secrets_client.get_secret_value(SecretId=secret_name) - self.token = secret_response["SecretString"] - - # Get the API endpoint from SSM - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" - - # Get certificate path for SSL verification - self.cert_path = get_cert_path(iam_client) - + # Use management token if id_token is not provided + if id_token is None: + logger.info("Using management key for ingestion") + init_data["token"] = get_management_key() + else: + init_data["token"] = id_token + + init_data["lisa_api_endpoint"] = get_rest_api_container_endpoint() + init_data["base_url"] = get_rest_api_container_endpoint() + init_data["cert_path"] = get_cert_path(iam_client) + + super().__init__(**init_data) logger.info("Successfully initialized pipeline embeddings") except Exception: logger.error("Failed to initialize pipeline embeddings", exc_info=True) raise + class Config: + arbitrary_types_allowed = True + def embed_documents(self, texts: List[str]) -> List[List[float]]: """ Generate embeddings for a list of documents. @@ -88,14 +98,13 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: response = requests.post( url, json=request_data, - headers={"Authorization": self.token, "Content-Type": "application/json"}, + headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}, verify=self.cert_path, # Use proper SSL verification timeout=300, # 5 minute timeout ) if response.status_code != 200: logger.error(f"Embedding request failed with status {response.status_code}") - logger.error(f"Response content: {response.text}") raise Exception(f"Embedding request failed with status {response.status_code}") result = response.json() @@ -150,40 +159,22 @@ def embed_query(self, text: str) -> List[float]: return self.embed_documents([text])[0] -def get_embeddings_pipeline(model_name: str) -> Any: +def get_openai_embeddings(model_name: str, id_token: str | None = None) -> LisaOpenAIEmbeddings: """ - Get embeddings for pipeline requests using management token. + Initialize and return an embeddings client for the specified model. Do not use for embedding documents since OpenAI + client does not use the provided model for embedding. Args: model_name: Name of the embedding model to use - - Raises: - ValidationError: If model name is invalid - Exception: If API request fails - """ - logger.info("Starting pipeline embeddings request") - - return PipelineEmbeddings(model_name=model_name) - - -def get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: - """ - Initialize and return an embeddings client for the specified model. - - Args: - model_name: Name of the embedding model to use - id_token: Authentication token for API access + id_token: Authentication token for API access. If not provided, uses management token. Returns: LisaOpenAIEmbeddings: Configured embeddings client """ - global lisa_api_endpoint - - if not lisa_api_endpoint: - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] + if id_token is None: + id_token = get_management_key() - base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" + base_url = get_rest_api_container_endpoint() cert_path = get_cert_path(iam_client) embedding = LisaOpenAIEmbeddings( diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index 9422b81a9..a20dc0f43 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -23,7 +23,7 @@ from boto3.dynamodb.types import TypeSerializer from botocore.config import Config from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, RagDocument -from repository.embeddings import get_embeddings +from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.rag_document_repo import RagDocumentRepository @@ -40,7 +40,6 @@ region_name = os.environ["AWS_REGION"] session = boto3.Session() ssm_client = boto3.client("ssm", region_name, config=retry_config) -secrets_client = boto3.client("secretsmanager", region_name, config=retry_config) iam_client = boto3.client("iam", region_name, config=retry_config) step_functions_client = boto3.client("stepfunctions", region_name, config=retry_config) ddb_client = boto3.client("dynamodb", region_name, config=retry_config) @@ -109,6 +108,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: - queryStringParameters.query: Search query text - queryStringParameters.repositoryType: Type of repository - queryStringParameters.topK (optional): Number of results to return (default: 3) + - queryStringParameters.score (optional): Include similarity scores (default: false) context (dict): The Lambda context object Returns: @@ -122,6 +122,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: model_name = query_string_params["modelName"] query = query_string_params["query"] top_k = query_string_params.get("topK", 3) + include_score = query_string_params.get("score", "false").lower() == "true" repository_id = event["pathParameters"]["repositoryId"] repository = vs_repo.find_repository_by_id(repository_id) @@ -139,7 +140,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: repository_id=repository_id, ) else: - embeddings = get_embeddings(model_name=model_name, id_token=id_token) + embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) # empty vector stores do not have an initialize index. Return empty docs @@ -148,11 +149,11 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: ): logger.info(f"Index {model_name} does not exist. Returning empty docs.") else: - results = vs.similarity_search( - query, - k=top_k, + docs = ( + _similarity_search_with_score(vs, query, top_k, repository) + if include_score + else _similarity_search(vs, query, top_k) ) - docs = [{"page_content": r.page_content, "metadata": r.metadata} for r in results] doc_content = [ { "Document": { @@ -536,6 +537,48 @@ def delete(event: dict, context: dict) -> Any: return {"status": "success", "executionArn": response["executionArn"]} +@api_wrapper +@admin_only +def delete_index(event: dict, context: dict) -> None: + """ + Clear the vector store for the specified repository and model. + + Args: + event (dict): The Lambda event object containing path parameters + context (dict): The Lambda context object + """ + path_params = event.get("pathParameters", {}) or {} + repository_id = path_params.get("repositoryId", None) + if not repository_id: + raise ValidationError("repositoryId is required") + model_name = path_params.get("modelName", None) + if not model_name: + raise ValidationError("modelName is required") + + repository = vs_repo.find_repository_by_id(repository_id=repository_id) + id_token = get_id_token(event) + embeddings = RagEmbeddings(model_name=model_name, id_token=id_token) + vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) + + try: + if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH): + if vs.client.indices.exists(index=model_name): + vs.client.indices.delete(index=model_name) + logger.info(f"Deleted OpenSearch index: {model_name}") + else: + logger.info(f"OpenSearch index {model_name} does not exist") + elif RepositoryType.is_type(repository, RepositoryType.PGVECTOR): + # For PGVector, delete all documents in the collection + vs.delete_collection() + logger.info(f"Deleted PGVector collection: {model_name}") + else: + logger.error(f"Unsupported repository type: {repository.get('type')}") + return {"status": "error", "message": "Repository is not supported"} + except Exception as e: + logger.error(f"Failed to clear vector store: {e}") + return {"status": "error", "message": str(e)} + + def _remove_legacy(repository_id: str) -> None: registered_repositories = ssm_client.get_parameter(Name=os.environ["REGISTERED_REPOSITORIES_PS"]) registered_repositories = json.loads(registered_repositories["Parameter"]["Value"]) @@ -549,3 +592,59 @@ def _remove_legacy(repository_id: str) -> None: Type="String", Overwrite=True, ) + + +def _similarity_search(vs, query: str, top_k: int) -> list[dict[str, Any]]: + """Perform similarity search without scores. + + Args: + vs: Vector store instance + query: Search query string + top_k: Number of top results to return + + Returns: + List of documents with page_content and metadata + """ + results = vs.similarity_search_with_score( + query, + k=top_k, + ) + + return [{"page_content": doc.page_content, "metadata": doc.metadata} for doc, score in results] + + +def _similarity_search_with_score(vs, query: str, top_k: int, repository: dict) -> list[dict[str, Any]]: + """Perform similarity search with normalized scores. + + Args: + vs: Vector store instance + query: Search query string + top_k: Number of top results to return + repository: Repository configuration dict + + Returns: + List of documents with page_content, metadata, and similarity_score + """ + results = vs.similarity_search_with_score( + query, + k=top_k, + ) + docs = [] + for i, (doc, score) in enumerate(results): + similarity_score = RepositoryType.get_type(repository=repository).calculate_similarity_score(score) + logger.info( + f"Result {i + 1}: Raw Score={score:.4f}, Similarity={similarity_score:.4f}, " + + f"Content: {doc.page_content[:200]}..." + ) + logger.info(f"Result {i + 1} metadata: {doc.metadata}") + docs.append( + { + "page_content": doc.page_content, + "metadata": {**doc.metadata, "similarity_score": similarity_score}, + } + ) + + if results and max(score for _, score in results) < 0.3: + logger.warning(f"All similarity < 0.3 for query '{query}' - possible embedding model mismatch") + + return docs diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 83598ef9a..2b82080e2 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -21,7 +21,7 @@ import boto3 from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, IngestionType, RagDocument -from repository.embeddings import get_embeddings_pipeline +from repository.embeddings import RagEmbeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.rag_document_repo import RagDocumentRepository @@ -48,10 +48,12 @@ def pipeline_ingest(job: IngestionJob) -> None: + texts = [] + metadatas = [] + all_ids = [] try: # chunk and save chunks in vector store repository = vs_repo.find_repository_by_id(job.repository_id) - all_ids = [] if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): ingest_document_to_kb( s3_client=s3, @@ -104,15 +106,15 @@ def pipeline_ingest(job: IngestionJob) -> None: logging.info(f"Successfully ingested document {job.s3_path} ({len(all_ids)} chunks) into {job.collection_id}") except Exception as e: ingestion_job_repository.update_status(job, IngestionStatus.INGESTION_FAILED) - error_msg = f"Failed to process document: {str(e)}" logger.error(error_msg, exc_info=True) + logger.error(f"Job: {job.model_dump_json(indent=2)}") raise Exception(error_msg) def remove_document_from_vectorstore(doc: RagDocument) -> None: # Delete from the Vector Store - embeddings = get_embeddings_pipeline(model_name=doc.collection_id) + embeddings = RagEmbeddings(model_name=doc.collection_id) vector_store = get_vector_store_client( doc.repository_id, index=doc.collection_id, @@ -280,7 +282,7 @@ def store_chunks_in_vectorstore( texts: List[str], metadatas: List[Dict], repository_id: str, embedding_model: str ) -> List[str]: """Store document chunks in vector store.""" - embeddings = get_embeddings_pipeline(model_name=embedding_model) + embeddings = RagEmbeddings(model_name=embedding_model) vs = get_vector_store_client( repository_id, index=embedding_model, diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py index c6f5f5bd6..d98466eb0 100644 --- a/lambda/utilities/auth.py +++ b/lambda/utilities/auth.py @@ -16,11 +16,23 @@ from functools import wraps from typing import Any, Callable, Dict +import boto3 +from botocore.config import Config from utilities.common_functions import get_groups from utilities.exceptions import HTTPException logger = logging.getLogger(__name__) +retry_config = Config( + retries={ + "max_attempts": 3, + "mode": "standard", + }, +) + +secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + def get_username(event: dict) -> str: """Get the username from the event.""" @@ -46,3 +58,10 @@ def wrapper(event: Dict[str, Any], context: Dict[str, Any], *args: Any, **kwargs return func(event, context, *args, **kwargs) return wrapper + + +def get_management_key() -> str: + secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) + secret_name = secret_name_param["Parameter"]["Value"] + secret_response = secrets_client.get_secret_value(SecretId=secret_name) + return secret_response["SecretString"] diff --git a/lambda/utilities/constants.py b/lambda/utilities/constants.py index 0dd0b48f4..a263d03ff 100644 --- a/lambda/utilities/constants.py +++ b/lambda/utilities/constants.py @@ -16,3 +16,4 @@ PDF_FILE = "pdf" TEXT_FILE = "txt" DOCX_FILE = "docx" +RICH_TEXT_FILE = "rtf" diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index f2f9c6857..db04f85ce 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -27,7 +27,7 @@ from models.domain_objects import ChunkingStrategyType, IngestionJob from pypdf import PdfReader from pypdf.errors import PdfReadError -from utilities.constants import DOCX_FILE, PDF_FILE, TEXT_FILE +from utilities.constants import DOCX_FILE, PDF_FILE, RICH_TEXT_FILE, TEXT_FILE from utilities.exceptions import RagUploadException logger = logging.getLogger(__name__) @@ -47,12 +47,13 @@ def _extract_text_by_content_type(content_type: str, s3_object: dict) -> str: extraction_functions = { PDF_FILE: _extract_pdf_content, DOCX_FILE: _extract_docx_content, - TEXT_FILE: lambda obj: obj["Body"].read(), + TEXT_FILE: _extract_text_content, + RICH_TEXT_FILE: _extract_text_content, } extraction_function = extraction_functions.get(content_type) if extraction_function: - return str(extraction_function(s3_object)) + return extraction_function(s3_object) else: logger.error(f"File has unsupported content type: {content_type}") raise RagUploadException("Unsupported file type") @@ -126,6 +127,18 @@ def _extract_docx_content(s3_object: dict) -> str: return output +def _extract_text_content(s3_object: dict) -> str: + """ + Extracts text content from an S3 object. Decode as + utf-8 to properly read special characters + + Parameters + ---------- + s3_object (dict): an S3 object containing a text file body. + """ + return s3_object["Body"].read().decode("utf-8", errors="replace") + + def generate_chunks(ingestion_job: IngestionJob) -> list[Document]: """Generate chunks from an ingestion job. diff --git a/lambda/utilities/repository_types.py b/lambda/utilities/repository_types.py index 2e9b054aa..0ec62cdd5 100644 --- a/lambda/utilities/repository_types.py +++ b/lambda/utilities/repository_types.py @@ -21,6 +21,19 @@ class RepositoryType(str, Enum): OPENSEARCH = "opensearch" BEDROCK_KB = "bedrock_knowledge_base" + @classmethod + def get_type(cls, repository: Dict[str, Any]) -> "RepositoryType": + return RepositoryType(repository.get("type")) + @classmethod def is_type(cls, repository: Dict[str, Any], repo_type: "RepositoryType") -> bool: - return repository.get("type") == repo_type.value + return repository.get("type") == repo_type + + def calculate_similarity_score(self, score: float) -> float: + # Convert cosine distance to similarity for PGVector + # PGVector returns cosine distance (0-2 range, lower = more similar) + # Convert to similarity (0-1 range, higher = more similar) + if self == RepositoryType.PGVECTOR: + return max(0.0, 1.0 - (score / 2.0)) + else: + return score diff --git a/lambda/utilities/vector_store.py b/lambda/utilities/vector_store.py index 46370f400..b7a6cfdb9 100644 --- a/lambda/utilities/vector_store.py +++ b/lambda/utilities/vector_store.py @@ -61,7 +61,7 @@ def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddin return OpenSearchVectorSearch( opensearch_url=opensearch_endpoint, - index_name=index.lower(), + index_name=index, embedding_function=embeddings, http_auth=auth, timeout=300, diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index 6bdc6276b..5dc802bcb 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -92,6 +92,7 @@ export class FastApiContainer extends Construct { AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM THREADS: Ec2Metadata.get('m5.large').vCpus.toString(), LITELLM_KEY: config.litellmConfig.db_key, + OPENAI_API_KEY: 'dummy-key', // pragma: allowlist secret - Required for OpenAI compatible self-hosted models through LiteLLM TIKTOKEN_CACHE_DIR: '/app/TIKTOKEN_CACHE' }; diff --git a/lib/core/layers/authorizer/requirements.txt b/lib/core/layers/authorizer/requirements.txt index fb0ff6550..5bbe17394 100644 --- a/lib/core/layers/authorizer/requirements.txt +++ b/lib/core/layers/authorizer/requirements.txt @@ -1,4 +1,5 @@ # urllib3<2 // Provided by Lambda -requests==2.32.4 +# cachetools==5.5.0 // provided by Common Layer +# requests==2.32.4 // provided by Common Layer cryptography==44.0.1 PyJWT==2.9.0 diff --git a/lib/core/layers/common/requirements.txt b/lib/core/layers/common/requirements.txt index 7488d54a0..c8fa2bf4f 100644 --- a/lib/core/layers/common/requirements.txt +++ b/lib/core/layers/common/requirements.txt @@ -1,4 +1,6 @@ # boto3>=1.34.131 // Provided by Lambda # botocore>=1.34.131 // Provided by Lambda # urllib3<2 // Provided by Lambda -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.10 +cachetools==5.5.0 +requests==2.32.4 diff --git a/lib/core/layers/fastapi/requirements.txt b/lib/core/layers/fastapi/requirements.txt index 11af70d39..0f576f8da 100644 --- a/lib/core/layers/fastapi/requirements.txt +++ b/lib/core/layers/fastapi/requirements.txt @@ -1,5 +1,5 @@ # boto3==1.34.131 // Provided by Lambda +# requests==2.32.4 // provided by Common Layer fastapi==0.111.0 mangum==0.17.0 pydantic==2.8.2 -requests==2.32.4 diff --git a/lib/rag/api/repository.ts b/lib/rag/api/repository.ts index 4a8e6a360..abbe6014a 100644 --- a/lib/rag/api/repository.ts +++ b/lib/rag/api/repository.ts @@ -125,6 +125,16 @@ export class RepositoryApi extends Construct { ...baseEnvironment, }, }, + { + name: 'delete_index', + resource: 'repository', + description: 'Delete an index within a repository', + path: 'repository/{repositoryId}/index/{modelName}', + method: 'DELETE', + environment: { + ...baseEnvironment, + }, + }, { name: 'similarity_search', resource: 'repository', diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index e2ef8b0a7..f0173c448 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -1,12 +1,12 @@ # boto3>=1.34.131 // Provided by Lambda # botocore>=1.34.131 // Provided by Lambda # urllib3<2 // Provided by Lambda +# psycopg2-binary==2.9.10 // provided by Common Layer langchain==0.3.9 langchain-community==0.3.9 langchain-openai==0.2.11 opensearch-py==2.6.0 pgvector==0.2.5 -psycopg2-binary==2.9.10 pypdf==6.0.0 lxml==5.3.0 python-docx==1.1.2 diff --git a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py index 163e8303b..454fd1ced 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py @@ -27,7 +27,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.EMBEDDINGS.value}") +@router.post(f"/{RestApiResource.EMBEDDINGS}") async def embeddings(request: EmbeddingsRequest) -> JSONResponse: """Text embeddings.""" response = await handle_embeddings(request.dict()) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/generation.py b/lib/serve/rest-api/src/api/endpoints/v1/generation.py index 6413035c8..d80ff6f14 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/generation.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/generation.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.GENERATE.value}") +@router.post(f"/{RestApiResource.GENERATE}") async def generate(request: GenerateRequest) -> JSONResponse: """Text generation.""" response = await handle_generate(request.dict()) @@ -41,7 +41,7 @@ async def generate(request: GenerateRequest) -> JSONResponse: return JSONResponse(content=response, status_code=200) -@router.post(f"/{RestApiResource.GENERATE_STREAM.value}") +@router.post(f"/{RestApiResource.GENERATE_STREAM}") async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -50,7 +50,7 @@ async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: ) -@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS.value}") +@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS}") async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -59,7 +59,7 @@ async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsR ) -@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS.value}") +@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS}") async def openai_completion_generate_stream(request: OpenAICompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( diff --git a/lib/serve/rest-api/src/api/endpoints/v1/models.py b/lib/serve/rest-api/src/api/endpoints/v1/models.py index e3d374552..3bcb353c0 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/models.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/models.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.get(f"/{RestApiResource.DESCRIBE_MODEL.value}") +@router.get(f"/{RestApiResource.DESCRIBE_MODEL}") async def describe_model( provider: str = Query( None, @@ -52,7 +52,7 @@ async def describe_model( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.DESCRIBE_MODELS.value}") +@router.get(f"/{RestApiResource.DESCRIBE_MODELS}") async def describe_models( model_types: Optional[List[ModelType]] = Query( None, @@ -69,7 +69,7 @@ async def describe_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.LIST_MODELS.value}") +@router.get(f"/{RestApiResource.LIST_MODELS}") async def list_models( model_types: Optional[List[ModelType]] = Query( None, @@ -86,7 +86,7 @@ async def list_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS.value}") +@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS}") async def openai_list_models() -> JSONResponse: """List models for OpenAI Compatibility. Only returns TEXTGEN models.""" response = await handle_openai_list_models() diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 880eb2b60..1341b3469 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -25,7 +25,7 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.status import HTTP_401_UNAUTHORIZED -from ....auth import get_authorization_token, get_jwks_client, id_token_is_valid, is_idp_used, is_user_in_group +from ....auth import Authorizer # Local LiteLLM installation URL. By default, LiteLLM runs on port 4000. Change the port here if the # port was changed as part of the LiteLLM startup in entrypoint.sh @@ -102,40 +102,10 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: litellm_path = f"{LITELLM_URL}/{api_path}" headers = dict(request.headers.items()) - if not is_valid_management_token(headers): - # If not handling an OpenAI request, we will also check if the user is an Admin user before allowing the - # request, otherwise, we will block it. This prevents non-admins from invoking model management APIs - # directly. If LISA Serve is deployed without an IdP configuration, we cannot determine who is an admin - # user, so all API routes will default to being openly accessible. - if is_idp_used(): - client_id = os.environ.get("CLIENT_ID", "") - authority = os.environ.get("AUTHORITY", "") - admin_group = os.environ.get("ADMIN_GROUP", "") - user_group = os.environ.get("USER_GROUP", "") - jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "") - - id_token = get_authorization_token(headers=headers, header_name="Authorization") - jwks_client = get_jwks_client() - if jwt_data := id_token_is_valid( - id_token=id_token, authority=authority, client_id=client_id, jwks_client=jwks_client - ): - if user_group != "" and not is_user_in_group( - jwt_data=jwt_data, group=user_group, jwt_groups_property=jwt_groups_property - ): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) - if api_path not in OPENAI_ROUTES: - if not is_user_in_group( - jwt_data=jwt_data, group=admin_group, jwt_groups_property=jwt_groups_property - ): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) - else: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough" - ) + authorizer = Authorizer() + require_admin = api_path not in OPENAI_ROUTES + if not await authorizer.can_access(request, require_admin): + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated in litellm_passthrough") # At this point in the request, we have already validated auth with IdP or persistent token. By using LiteLLM for # model management, LiteLLM requires an admin key, and that forces all requests to require a key as well. To avoid @@ -154,32 +124,6 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: return StreamingResponse(generate_response(response.iter_lines()), status_code=response.status_code) else: # not a streaming request response = requests.request(method=http_method, url=litellm_path, json=params, headers=headers) + if response.status_code != 200: + logger.error(f"LiteLLM error response: {response.text}") return JSONResponse(response.json(), status_code=response.status_code) - - -def refresh_management_tokens() -> list[str]: - """Return secret management tokens if they exist.""" - secret_tokens = [] - - try: - secret_tokens.append( - secrets_manager.get_secret_value(SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT")[ - "SecretString" - ] - ) - secret_tokens.append( - secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSPREVIOUS" - )["SecretString"] - ) - except Exception: - logger.info(f"No previous secret version for {os.environ.get('MANAGEMENT_KEY_NAME')}") - - return secret_tokens - - -def is_valid_management_token(headers: dict[str, str]) -> bool: - """Return if API Token from request headers is valid if found.""" - secret_tokens = refresh_management_tokens() - token = get_authorization_token(headers=headers, header_name="Authorization").strip() - return token in secret_tokens diff --git a/lib/serve/rest-api/src/api/routes.py b/lib/serve/rest-api/src/api/routes.py index ca79631ae..08e052796 100644 --- a/lib/serve/rest-api/src/api/routes.py +++ b/lib/serve/rest-api/src/api/routes.py @@ -20,20 +20,20 @@ from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse -from ..auth import OIDCHTTPBearer +from ..auth import Authorizer from .endpoints.v2 import litellm_passthrough logger = logging.getLogger(__name__) router = APIRouter() -if os.getenv("USE_AUTH", "true").lower() == "false": - dependencies = [] - logger.info("Auth disabled") -else: - security = OIDCHTTPBearer() - dependencies = [Depends(security)] +dependencies = [] +if os.getenv("USE_AUTH", "true").lower() == "true": logger.info("Auth enabled") + security = Authorizer() + dependencies = [Depends(security)] +else: + logger.info("Auth disabled") router.include_router( litellm_passthrough.router, prefix="/v2/serve", tags=["litellm_passthrough"], dependencies=dependencies @@ -46,6 +46,17 @@ async def health_check() -> JSONResponse: This needs to match the path in the config.yaml file. """ - content = {"status": "OK"} - - return JSONResponse(content=content, status_code=200) + try: + # Basic health verification - check if required environment variables are set + required_vars = ["AWS_REGION", "LOG_LEVEL"] + missing_vars = [var for var in required_vars if not os.getenv(var)] + + if missing_vars: + content = {"status": "UNHEALTHY", "missing_env_vars": missing_vars} + return JSONResponse(content=content, status_code=503) + + content = {"status": "OK"} + return JSONResponse(content=content, status_code=200) + except Exception as e: + content = {"status": "UNHEALTHY", "error": str(e)} + return JSONResponse(content=content, status_code=503) diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index c83d86392..7cc88ba5e 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -13,27 +13,28 @@ # limitations under the License. """Authentication for FastAPI app.""" +import asyncio import os import ssl import sys +import threading from datetime import datetime +from enum import Enum from pathlib import Path -from time import time from typing import Any, Dict, Optional import boto3 import jwt import requests +from cachetools import TTLCache +from cachetools.keys import hashkey from fastapi import HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from loguru import logger from starlette.status import HTTP_401_UNAUTHORIZED -# The following are field names, not passwords or tokens -API_KEY_HEADER_NAMES = [ - "Authorization", # OpenAI Bearer token format, collides with IdP, but that's okay for this use case - "Api-Key", # pragma: allowlist secret # Azure key format, can be used with Continue IDE plugin -] +from .utils.decorators import singleton + TOKEN_EXPIRATION_NAME = "tokenExpiration" # nosec B105 TOKEN_TABLE_NAME = "TOKEN_TABLE_NAME" # nosec B105 USE_AUTH = "USE_AUTH" @@ -51,6 +52,19 @@ ) +# The following are field names, not passwords or tokens +class AuthHeaders(str, Enum): + """API key header names.""" + + AUTHORIZATION = "Authorization" # OpenAI Bearer token format, collides with IdP, but that's okay for this use case + API_KEY = "Api-Key" # pragma: allowlist secret # Azure key format, can be used with Continue IDE plugin + + @classmethod + def values(cls) -> list[str]: + """Return list of header values.""" + return list(cls) + + def is_idp_used() -> bool: """Get if the identity provider is being used based on environment variable.""" return os.environ.get(USE_AUTH, "false").lower() == "true" @@ -88,6 +102,7 @@ def id_token_is_valid( id_token: str, client_id: str, authority: str, jwks_client: jwt.PyJWKClient ) -> Optional[Dict[str, Any]]: """Check whether an ID token is valid and return decoded data.""" + logger.info(f"Auth Token: {id_token}") try: signing_key = jwks_client.get_signing_key_from_jwt(id_token) data: Dict[str, Any] = jwt.decode( @@ -106,13 +121,13 @@ def id_token_is_valid( }, ) return data - except jwt.exceptions.PyJWTError as e: + except (jwt.exceptions.PyJWTError, jwt.exceptions.DecodeError) as e: logger.exception(e) return None def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: str) -> bool: - """Check if the user is an admin.""" + """Check if the user is in group.""" props = jwt_groups_property.split(".") current_node = jwt_data for prop in props: @@ -123,7 +138,7 @@ def is_user_in_group(jwt_data: dict[str, Any], group: str, jwt_groups_property: return group in current_node -def get_authorization_token(headers: Dict[str, str], header_name: str) -> str: +def get_authorization_token(headers: Dict[str, str], header_name: str = AuthHeaders.AUTHORIZATION) -> str: """Get Bearer token from Authorization headers if it exists.""" if header_name in headers: return headers.get(header_name, "").removeprefix("Bearer").strip() @@ -133,29 +148,38 @@ def get_authorization_token(headers: Dict[str, str], header_name: str) -> str: class OIDCHTTPBearer(HTTPBearer): """OIDC based bearer token authenticator.""" - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, authority: Optional[str] = None, client_id: Optional[str] = None, **kwargs: Dict[str, Any]): super().__init__(**kwargs) - self._token_authorizer = ApiTokenAuthorizer() - self._management_token_authorizer = ManagementTokenAuthorizer() + self.authority = authority or os.environ.get("AUTHORITY", "") + self.client_id = client_id or os.environ.get("CLIENT_ID", "") + self.jwks_client = get_jwks_client() - self._jwks_client = get_jwks_client() - - async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: - """Verify the provided bearer token or API Key. API Key will take precedence over the bearer token.""" - if self._token_authorizer.is_valid_api_token(request.headers): - return None # valid API token, not continuing with OIDC auth - elif self._management_token_authorizer.is_valid_api_token(request.headers): - logger.info("looks like a valid mgmt token") - return None # valid management token, not continuing with OIDC auth + async def id_token_is_valid(self, request: Request) -> Optional[Dict[str, Any]]: + """Check whether an ID token is valid and return decoded data.""" http_auth_creds = await super().__call__(request) - if not id_token_is_valid( - id_token=http_auth_creds.credentials, - authority=os.environ["AUTHORITY"], - client_id=os.environ["CLIENT_ID"], - jwks_client=self._jwks_client, - ): - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return http_auth_creds + id_token = http_auth_creds.credentials + logger.info(f"Auth Token: {id_token}") + try: + signing_key = self.jwks_client.get_signing_key_from_jwt(id_token) + data: Dict[str, Any] = jwt.decode( + id_token, + signing_key.key, + algorithms=["RS256"], + issuer=self.authority, + audience=self.client_id, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True, + "verify_iss": True, + }, + ) + return data + except (jwt.exceptions.PyJWTError, jwt.exceptions.DecodeError) as e: + logger.exception(e) + return None class ApiTokenAuthorizer: @@ -175,12 +199,15 @@ def _get_token_info(self, token: str) -> Any: ddb_response = self._token_table.get_item(Key={"token": token}, ReturnConsumedCapacity="NONE") return ddb_response.get("Item", None) - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + async def is_valid_api_token(self, headers: Dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" - for header_name in API_KEY_HEADER_NAMES: - token = headers.get(header_name, "").removeprefix("Bearer").strip() + + for header_name in AuthHeaders.values(): + token = get_authorization_token(headers, header_name) + + logger.info(f"API Auth Token: {token}") if token: - token_info = self._get_token_info(token) + token_info = await asyncio.to_thread(self._get_token_info, token) if token_info: token_expiration = int(token_info.get(TOKEN_EXPIRATION_NAME, datetime.max.timestamp())) current_time = int(datetime.now().timestamp()) @@ -193,33 +220,154 @@ class ManagementTokenAuthorizer: """Class for checking Management tokens against a SecretsManager secret.""" def __init__(self) -> None: - self._secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) - self._secret_tokens: list[str] = [] - self._last_run = 0 - - def _refreshTokens(self) -> None: - """Refresh secret management tokens.""" - current_time = int(time()) - if current_time - (self._last_run or 0) > 3600: - secret_tokens = [] + self._cache = TTLCache(maxsize=1, ttl=300) + self._cache_lock = threading.RLock() + self._local = threading.local() + + def _get_secrets_client(self): + """Get thread-local secrets manager client.""" + if not hasattr(self._local, "secrets_manager"): + self._local.secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) + return self._local.secrets_manager + + def get_management_tokens(self) -> list[str]: + """Return secret management tokens if they exist.""" + cache_key = hashkey() + + with self._cache_lock: + if cache_key in self._cache: + return self._cache[cache_key] + + logger.info("Updating management tokens cache") + secret_tokens = [] + secret_id = os.environ.get("MANAGEMENT_KEY_NAME") + secrets_manager = self._get_secrets_client() + + try: + secret_tokens.append( + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"] + ) secret_tokens.append( - self._secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSCURRENT" - )["SecretString"] + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"] ) - try: - secret_tokens.append( - self._secrets_manager.get_secret_value( - SecretId=os.environ.get("MANAGEMENT_KEY_NAME"), VersionStage="AWSPREVIOUS" - )["SecretString"] - ) - except Exception: - logger.info(f"No previous secret version for {os.environ.get('MANAGEMENT_KEY_NAME')}") - self._secret_tokens = secret_tokens - self._last_run = current_time - - def is_valid_api_token(self, headers: Dict[str, str]) -> bool: + except Exception: + logger.info(f"No previous secret version for {secret_id}") + + with self._cache_lock: + self._cache[cache_key] = secret_tokens + + return secret_tokens + + async def is_valid_api_token(self, headers: Dict[str, str]) -> bool: """Return if API Token from request headers is valid if found.""" - self._refreshTokens() - token = headers.get("Authorization", "").strip() - return token in self._secret_tokens + secret_tokens = await asyncio.to_thread(self.get_management_tokens) + token = get_authorization_token(headers) + return token in secret_tokens + + +@singleton +class Authorizer: + """Composite authenticator that tries multiple authentication methods in order.""" + + def __init__(self) -> None: + self.client_id = os.environ.get("CLIENT_ID", "") + self.authority = os.environ.get("AUTHORITY", "") + self.admin_group = os.environ.get("ADMIN_GROUP", "") + self.user_group = os.environ.get("USER_GROUP", "") + self.jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "") + self.use_idp = is_idp_used() + + self.token_authorizer = ApiTokenAuthorizer() + self.management_token_authorizer = ManagementTokenAuthorizer() + self.oidc_authorizer = OIDCHTTPBearer(authority=self.authority, client_id=self.client_id) + + async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: + jwt_data = await self.authenticate_request(request) + return jwt_data + + async def authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]: + """Authenticate request and return JWT data if OIDC, None if API/management token.""" + if not self.use_idp: + return None + + # First try API tokens + logger.info("Try API Auth Token...") + if await self.token_authorizer.is_valid_api_token(request.headers): + logger.info("Valid API token") + return None + + # Then try management tokens + logger.info("Try Management Auth Token...") + if await self.management_token_authorizer.is_valid_api_token(request.headers): + logger.info("Valid Management token") + return None + + # Finally try OIDC Bearer tokens + logger.info("Try OIDC Auth Token...") + jwt_data = await self.oidc_authorizer.id_token_is_valid(request) + if jwt_data: + logger.info("Valid OIDC token") + return jwt_data + + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated") + + def _log_access_attempt( + self, request: Request, auth_method: str, user_id: str, endpoint: str, success: bool, reason: str = "" + ) -> None: + """Centralized logging for all authentication attempts.""" + status = "SUCCESS" if success else "FAILED" + log_msg = f"AUTH {status}: user={user_id} method={auth_method} endpoint={endpoint}" + if reason: + log_msg += f" reason={reason}" + + if success: + logger.info(log_msg) + else: + logger.warning(log_msg) + + async def can_access( + self, request: Request, require_admin: bool, jwt_data: Optional[Dict[str, Any]] = None + ) -> bool: + """Return whether the user is authorized to access the endpoint.""" + endpoint = f"{request.method} {request.url.path}" + auth_method = "NO_IDP" + user_id = "anonymous" + has_access = False + reason = "" + + if not self.use_idp: + auth_method = "NO_IDP" + user_id = "anonymous" + has_access = True + reason = "IDP disabled" + else: + # Use provided JWT data or authenticate request + if jwt_data is None: + jwt_data = await self.authenticate_request(request) + + if not jwt_data: + auth_method = "API_TOKEN" + user_id = "api_user" + has_access = True + reason = "Valid API/Management token" + else: + auth_method = "OIDC" + user_id = jwt_data.get("sub", jwt_data.get("username", "unknown")) + + # If user is admin, always allow access + if is_user_in_group(jwt_data, self.admin_group, self.jwt_groups_property): + has_access = True + reason = "Admin user" + # If admin is required but user is not admin, deny access + elif require_admin: + has_access = False + reason = "Admin required" + # For non-admin requests, check user group + else: + has_access = self.user_group == "" or is_user_in_group( + jwt_data=jwt_data, group=self.user_group, jwt_groups_property=self.jwt_groups_property + ) + reason = "Valid user group" if has_access else "Invalid user group" + + self._log_access_attempt(request, auth_method, user_id, endpoint, has_access, reason) + return has_access diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index 15fe25eda..590fc4f78 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -20,6 +20,8 @@ python ./src/utils/generate_litellm_config.py -f litellm_config.yaml # src/api/endpoints/v2/litellm_passthrough.py for the LiteLLM URI litellm -c litellm_config.yaml & +# Validate THREADS variable with default value +THREADS=${THREADS:-4} echo "Starting Gunicorn with $THREADS workers..." # Start Gunicorn with Uvicorn workers. diff --git a/lib/serve/rest-api/src/handlers/generation.py b/lib/serve/rest-api/src/handlers/generation.py index bf35adb86..313b3781f 100644 --- a/lib/serve/rest-api/src/handlers/generation.py +++ b/lib/serve/rest-api/src/handlers/generation.py @@ -26,9 +26,12 @@ async def handle_generate(request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle for generate endpoint.""" model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE) - response = await model.generate(text=text, model_kwargs=model_kwargs) - - return response.dict() # type: ignore + try: + response = await model.generate(text=text, model_kwargs=model_kwargs) + return response.dict() # type: ignore + except Exception as e: + logger.error(f"Model generation failed: {e}") + raise @handle_stream_exceptions diff --git a/lib/serve/rest-api/src/handlers/models.py b/lib/serve/rest-api/src/handlers/models.py index 7955c1f83..240dddc02 100644 --- a/lib/serve/rest-api/src/handlers/models.py +++ b/lib/serve/rest-api/src/handlers/models.py @@ -115,10 +115,10 @@ async def handle_describe_models(model_types: List[ModelType]) -> DefaultDict[st response: DefaultDict[str, DefaultDict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) for model_type, providers in registered_models.items(): - response[model_type.value] = {} # type: ignore + response[model_type] = {} # type: ignore providers = providers or {} for provider, model_names in providers.items(): - response[model_type.value][provider] = [ + response[model_type][provider] = [ registered_models_cache["metadata"][f"{provider}.{model_name}"] for model_name in model_names ] # type: ignore diff --git a/lib/serve/rest-api/src/lisa_serve/__init__.py b/lib/serve/rest-api/src/lisa_serve/__init__.py index 51c90fc4e..c25b6dd2e 100644 --- a/lib/serve/rest-api/src/lisa_serve/__init__.py +++ b/lib/serve/rest-api/src/lisa_serve/__init__.py @@ -30,7 +30,7 @@ { "sink": sys.stdout, "format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | " - "{request_id} | {message}", + "{extra[request_id]} | {message}", "level": logger_level.upper(), } ] diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py index 849d9fc7a..b4b110808 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/embedding/instructor.py @@ -76,18 +76,22 @@ async def embed_query(self, *, text: str, model_kwargs: Dict[str, Any]) -> Embed "text": text, } - async with ClientSession() as session: - async with session.post(self.endpoint_url, json=payload) as server_response: - server_response.raise_for_status() - server_response_json = await server_response.json() - - response = EmbedQueryResponse(embeddings=server_response_json) - - logger.debug( - f"Received: {escape_curly_brackets(response.json())}", - extra={"event": f"{self.__class__.__name__}:embed_query"}, - ) - return response + try: + async with ClientSession() as session: + async with session.post(self.endpoint_url, json=payload) as server_response: + server_response.raise_for_status() + server_response_json = await server_response.json() + + response = EmbedQueryResponse(embeddings=server_response_json) + + logger.debug( + f"Received: {escape_curly_brackets(response.json())}", + extra={"event": f"{self.__class__.__name__}:embed_query"}, + ) + return response + except Exception as e: + logger.error(f"Embedding request failed: {e}") + raise # Register the model diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py index a1415be92..bcd92224e 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py @@ -195,10 +195,12 @@ async def openai_generate_stream( AsyncGenerator[GenerateStreamResponse, None] Text generation model response with streaming. """ - request = {"prompt": text, **model_kwargs} + # Generate static values once before streaming resp_id = str(uuid.uuid4()) fingerprint = str(uuid.uuid4()) created = int(time.time()) + + request = {"prompt": text, **model_kwargs} if is_text_completion: response_class = OpenAICompletionsResponse else: diff --git a/lib/serve/rest-api/src/lisa_serve/registry/index.py b/lib/serve/rest-api/src/lisa_serve/registry/index.py index 5b0a56f8f..77e23af74 100644 --- a/lib/serve/rest-api/src/lisa_serve/registry/index.py +++ b/lib/serve/rest-api/src/lisa_serve/registry/index.py @@ -45,6 +45,6 @@ def get_assets(self, provider: str) -> Dict[str, Any]: except KeyError: raise KeyError( f"Model provider '{provider}' not found in registry. Available providers: " - f"{', '.join(list(self.registry))}" + f"{', '.join(self.registry)}" ) return model_assets # type: ignore diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index 8230b4bab..2f6c32c48 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -2,6 +2,7 @@ aioboto3>=12.0.0,<15.0.0 aiobotocore>=2.11.0,<3.0.0 aiohttp==3.12.14 boto3>=1.34.0,<1.37.0 +cachetools==5.5.0 click==8.1.7 cryptography>=43.0.1,<44.0.0 fastapi==0.115.11 diff --git a/lib/serve/rest-api/src/utils/cache_manager.py b/lib/serve/rest-api/src/utils/cache_manager.py index 44ff5d749..3c94bbace 100644 --- a/lib/serve/rest-api/src/utils/cache_manager.py +++ b/lib/serve/rest-api/src/utils/cache_manager.py @@ -13,6 +13,7 @@ # limitations under the License. """Model Cache Utilities.""" +import threading from typing import Any, Dict, Optional, Tuple from .resources import ModelType, RestApiResource @@ -33,24 +34,31 @@ } MODEL_ASSETS_CACHE: Dict[str, Tuple[Any, Any]] = {} +# Thread locks for cache operations +_REGISTERED_MODELS_LOCK = threading.RLock() +_MODEL_ASSETS_LOCK = threading.RLock() + def get_registered_models_cache() -> Dict[str, Dict[str, Any]]: """Get the cache containing the registered models.""" - return REGISTERED_MODELS_CACHE + with _REGISTERED_MODELS_LOCK: + return REGISTERED_MODELS_CACHE.copy() def get_model_assets(model_key: str) -> Optional[Tuple[Any, Any]]: """Get the cache belonging to the model assets.""" - return MODEL_ASSETS_CACHE.get(model_key, None) + with _MODEL_ASSETS_LOCK: + return MODEL_ASSETS_CACHE.get(model_key) def cache_model_assets(key: str, model_assets: Tuple[Any, Any]) -> None: """Cache the specified model assets for the specified key.""" - global MODEL_ASSETS_CACHE - MODEL_ASSETS_CACHE[key] = model_assets + with _MODEL_ASSETS_LOCK: + MODEL_ASSETS_CACHE[key] = model_assets def set_registered_models_cache(models: Dict[str, Dict[str, Any]]) -> None: """Set the registered model cache to the specified models value.""" - global REGISTERED_MODELS_CACHE - REGISTERED_MODELS_CACHE = models + with _REGISTERED_MODELS_LOCK: + global REGISTERED_MODELS_CACHE + REGISTERED_MODELS_CACHE = models diff --git a/lib/serve/rest-api/src/utils/decorators.py b/lib/serve/rest-api/src/utils/decorators.py new file mode 100644 index 000000000..a550aa44c --- /dev/null +++ b/lib/serve/rest-api/src/utils/decorators.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility decorators.""" +from typing import Any, Callable, cast, Dict, TypeVar + +T = TypeVar("T") + + +def singleton(cls: type[T]) -> Callable[..., T]: + """Singleton decorator.""" + instances: Dict[type, Any] = {} + + def get_instance(*args: Any, **kwargs: Any) -> T: + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return cast(T, instances[cls]) + + return get_instance diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index 0aa046aff..6eca0a700 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -23,14 +23,13 @@ import yaml from rds_auth import generate_auth_token, get_lambda_role_name -ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) -secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) - @click.command() @click.option("-f", "--filepath", type=click.Path(exists=True, file_okay=True, dir_okay=False, writable=True)) def generate_config(filepath: str) -> None: """Read LiteLLM configuration and rewrite it with LISA-deployed model information.""" + ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"]) + with open(filepath, "r") as fp: config_contents = yaml.safe_load(fp) # Get and load registered models from ParameterStore @@ -92,6 +91,7 @@ def get_database_credentials(db_params: dict[str, str]) -> Tuple: """Get database password from Secrets Manager or using IAM auth.""" if "passwordSecretId" in db_params: + secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) secret_response = secrets_client.get_secret_value(SecretId=db_params["passwordSecretId"]) secret = json.loads(secret_response["SecretString"]) return (db_params["username"], secret["password"]) diff --git a/lib/serve/rest-api/src/utils/request_utils.py b/lib/serve/rest-api/src/utils/request_utils.py index a611d0fa1..cc4123e61 100644 --- a/lib/serve/rest-api/src/utils/request_utils.py +++ b/lib/serve/rest-api/src/utils/request_utils.py @@ -78,12 +78,17 @@ async def validate_model(request_data: Dict[str, Any], resource: RestApiResource registered_models_cache = get_registered_models_cache() supported_models = registered_models_cache[resource][provider] if model_name not in supported_models: + # Sanitize inputs for logging to prevent log injection + safe_model_name = str(model_name).replace("\n", "").replace("\r", "") + safe_resource = str(resource).replace("\n", "").replace("\r", "") + safe_supported = str(supported_models).replace("\n", "").replace("\r", "") + message = ( - f"Provider does not support model {model_name} for endpoint " - f"/{resource}, expected one of: {supported_models}" + f"Provider does not support model {safe_model_name} for endpoint " + f"/{safe_resource}, expected one of: {safe_supported}" ) logger.error(message, extra={"event": event, "status": "ERROR"}) - raise Exception(message) + raise ValueError(message) async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, Any]: @@ -113,7 +118,10 @@ async def get_model_and_validator(request_data: Dict[str, Any]) -> Tuple[Any, An # Retrieve model endpoint URL registered_models_cache = get_registered_models_cache() - endpoint_url = registered_models_cache["endpointUrls"][model_key] + try: + endpoint_url = registered_models_cache["endpointUrls"][model_key] + except KeyError: + raise KeyError(f"Model endpoint URL not found for {model_key}") # Instantiate the model model = adapter(model_name=model_name, endpoint_url=endpoint_url) @@ -158,7 +166,11 @@ async def validate_and_prepare_llm_request( task_logger.debug("Finish task", status="FINISH") - return model, model_kwargs.dict(), request_data["text"] + text = request_data.get("text") + if text is None: + raise ValueError("Missing required field: text") + + return model, model_kwargs.dict(), text def handle_stream_exceptions( diff --git a/lib/serve/rest-api/src/utils/resources.py b/lib/serve/rest-api/src/utils/resources.py index 888863922..c8929ae7a 100644 --- a/lib/serve/rest-api/src/utils/resources.py +++ b/lib/serve/rest-api/src/utils/resources.py @@ -170,7 +170,7 @@ class OpenAICompletionsRequest(BaseModel): ] ), ) - echo: Optional[int] = Field(False, description="Whether to prepend the prompt to the generated text.") + echo: Optional[bool] = Field(False, description="Whether to prepend the prompt to the generated text.") frequency_penalty: Optional[float] = Field(None, description="Penalty to add for text repetition.") logit_bias: Optional[Dict[Any, Any]] = Field( None, diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index 0c35973f1..523cb6d28 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -504,21 +504,26 @@ export default function Chat ({ sessionId }) { })); } + // Fetch RAG documents once if needed + let ragDocs = null; + if (useRag && !isImageGenerationMode) { + ragDocs = await fetchRelevantDocuments(userPrompt); + } + // Use extracted message builder utilities const messageContent = await buildMessageContent({ isImageGenerationMode, fileContext, useRag, userPrompt, - fetchRelevantDocuments, + ragDocs, }); const messageMetadata = await buildMessageMetadata({ isImageGenerationMode, useRag, - userPrompt, chatConfiguration, - fetchRelevantDocuments, + ragDocs, }); messages.push(new LisaChatMessage({ diff --git a/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx b/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx index a08966f15..de6b0da9a 100644 --- a/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx +++ b/lib/user-interface/react/src/components/chatbot/utils/messageBuilder.utils.tsx @@ -21,7 +21,7 @@ export type MessageContentParams = { fileContext: string; useRag: boolean; userPrompt: string; - fetchRelevantDocuments?: (query: string) => Promise; + ragDocs?: any; }; export const buildMessageContent = async ({ @@ -29,7 +29,7 @@ export const buildMessageContent = async ({ fileContext, useRag, userPrompt, - fetchRelevantDocuments, + ragDocs, }: MessageContentParams) => { if (isImageGenerationMode) { return userPrompt; @@ -43,8 +43,7 @@ export const buildMessageContent = async ({ ]; } - if (useRag && fetchRelevantDocuments) { - const ragDocs = await fetchRelevantDocuments(userPrompt); + if (useRag && ragDocs) { return [ { type: 'text', text: 'File context: ' + formatDocumentsAsString(ragDocs.data?.docs) }, { type: 'text', text: userPrompt }, @@ -64,15 +63,13 @@ export const buildMessageContent = async ({ export const buildMessageMetadata = async ({ isImageGenerationMode, useRag, - userPrompt, chatConfiguration, - fetchRelevantDocuments, + ragDocs, }: { isImageGenerationMode: boolean; useRag: boolean; - userPrompt: string; chatConfiguration: any; - fetchRelevantDocuments?: (query: string) => Promise; + ragDocs?: any; }) => { const metadata: any = {}; @@ -81,8 +78,7 @@ export const buildMessageMetadata = async ({ metadata.imageGenerationSettings = chatConfiguration.sessionConfiguration.imageGenerationArgs; } - if (useRag && !isImageGenerationMode && fetchRelevantDocuments) { - const ragDocs = await fetchRelevantDocuments(userPrompt); + if (useRag && !isImageGenerationMode && ragDocs) { metadata.ragContext = formatDocumentsAsString(ragDocs.data?.docs, true); metadata.ragDocuments = formatDocumentTitlesAsString(ragDocs.data?.docs); } diff --git a/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx b/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx index d01841811..d287c09cf 100644 --- a/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx +++ b/lib/user-interface/react/src/components/model-management/ModelManagementActions.tsx @@ -20,10 +20,7 @@ import { useAppDispatch, useAppSelector } from '@/config/store'; import { IModel, ModelStatus } from '@/shared/model/model-management.model'; import { useNotificationService } from '@/shared/util/hooks'; import { INotificationService } from '@/shared/notification/notification.service'; -import { - modelManagementApi, - useDeleteModelMutation, useUpdateModelMutation, -} from '@/shared/reducers/model-management.reducer'; +import { useDeleteModelMutation, useUpdateModelMutation} from '@/shared/reducers/model-management.reducer'; import { MutationTrigger } from '@reduxjs/toolkit/dist/query/react/buildHooks'; import { Action, ThunkDispatch } from '@reduxjs/toolkit'; import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; @@ -38,6 +35,7 @@ export type ModelActionProps = { updateConfigMutation?: any; currentDefaultModel?: string; currentConfig?: any; + refetch?: () => void; }; function ModelActions (props: ModelActionProps): ReactElement { @@ -49,7 +47,7 @@ function ModelActions (props: ModelActionProps): ReactElement {