diff --git a/.github/workflows/code.end-to-end-test.nightly.yml b/.github/workflows/code.end-to-end-test.nightly.yml index 213e0afdb..b55035000 100644 --- a/.github/workflows/code.end-to-end-test.nightly.yml +++ b/.github/workflows/code.end-to-end-test.nightly.yml @@ -36,8 +36,6 @@ jobs: cache: 'npm' - name: Install base dependencies run: npm ci - - name: Install Cypress deps - run: npm ci --prefix cypress - name: Run Cypress E2E Suite env: TEST_ACCOUNT_PASSWORD: ${{ secrets.TEST_ACCOUNT_PASSWORD }} diff --git a/.gitignore b/.gitignore index c1a2773e8..6ce6efa21 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,8 @@ lib/rag/ingestion/ingestion-image/build .DS_Store *.iml *.code-workspace +.cursor +memory-bank/ # Coverage Statistic Folders coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b2e813c9f..1c1b26b4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,14 +12,14 @@ repos: files: config.yaml - repo: https://github.com/PyCQA/bandit - rev: '1.7.5' + rev: '1.7.10' hooks: - id: bandit args: [--recursive, -c=pyproject.toml] - additional_dependencies: ['bandit[toml]'] + additional_dependencies: ['bandit[toml]', 'pbr'] - repo: https://github.com/Yelp/detect-secrets - rev: v1.4.0 + rev: v1.5.0 hooks: - id: detect-secrets exclude: (?x)^( @@ -27,7 +27,7 @@ repos: )$ - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v5.0.0 hooks: - id: check-json - id: check-yaml @@ -41,7 +41,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell entry: codespell @@ -49,7 +49,7 @@ repos: pass_filenames: false - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: isort (python) @@ -59,13 +59,14 @@ repos: hooks: - id: black -- repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.3' +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: 'v0.8.4' hooks: - id: ruff args: - --exit-non-zero-on-fix - --per-file-ignores=test/**/*.py:E402 + exclude: \.ipynb$ - repo: https://github.com/pycqa/flake8 rev: '7.1.1' @@ -85,7 +86,7 @@ repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.6.1' + rev: 'v1.13.0' hooks: - id: mypy verbose: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d84d65c3..323fbe16c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,46 @@ +# v5.2.0 +## Key Features +### Model Context Protocol (MCP) Enhancements +- **Connection Validation**: Real-time connection testing with detailed feedback on server connectivity during connection creation/edit +- **Enhanced Debugging**: Improved error handling and connection status reporting for MCP servers + +### Session Management Improvements +- **Time-Based Session Grouping**: Sessions are now automatically organized into time-based groups based on updated date (Last Day, Last 7 Days, Last Month, Last 3 Months, Older) +- **Session ID Removal**: Removed session ID from prompt input for cleaner user interface + +### RAG (Retrieval-Augmented Generation) Improvements +#### Document Processing +- **Document Chunk Processing Fixes**: Resolved issues with document chunk processing and ingestion +- **Document Library Pagination**: Added pagination support for the Document Library to handle large numbers of documents efficiently + +#### Vector Store Configuration +- **Default Embedding Model Support**: Added ability to define a default embedding model when creating or updating vector stores +- **IAM Permissions Optimization**: Trimmed vector store IAM permissions to follow the principle of least privilege +- **Container Configuration**: Added container override configuration for batch ingestion processes + +#### Batch Ingestion +- **Container Configuration**: Added support for container override configuration in batch ingestion jobs +- **Max Batch Jobs Setting**: Implemented dynamic maximum batch jobs limit +- **Ingestion Rules Updates**: Automatic updates to ingestion rules when Lambda functions are updated + +### Model Management Improvements +- **Base Container Configuration**: Added support for using prebuilt model containers, instead of building during model deployment + +### UI/UX Enhancements +- **General UI Improvements**: Various user interface enhancements to improve usability +- **Updated Default System Prompt**: Updated LISAs default system prompt to take advantage of new rendering capabilities. Pairing this prompt with new UI components supports the display of: + - Inline-Code + - Mathematic equations using LaTex syntax + - Mermaid Diagrams. These diagrams can also be copied and downloaded as images + +## Acknowledgements +* @bedanley +* @estohlmann +* @dustins +* @jmharold + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v5.1.0...v5.2.0 + # v5.1.0 ## Key Features ### Model Management Enhancements diff --git a/VERSION b/VERSION index 831446cbd..91ff57278 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -5.1.0 +5.2.0 diff --git a/lambda/mcp_server/lambda_functions.py b/lambda/mcp_server/lambda_functions.py index 6afb00906..ffa58b4bd 100644 --- a/lambda/mcp_server/lambda_functions.py +++ b/lambda/mcp_server/lambda_functions.py @@ -21,7 +21,8 @@ import boto3 from boto3.dynamodb.conditions import Attr, Key -from utilities.common_functions import api_wrapper, get_item, get_username, is_admin, retry_config +from utilities.auth import get_username, is_admin +from utilities.common_functions import api_wrapper, get_item, retry_config from .models import McpServerModel, McpServerStatus diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 6d7ee2e58..3f47912c7 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -432,7 +432,9 @@ class RagDocument(BaseModel): def __init__(self, **data: Any) -> None: super().__init__(**data) self.pk = self.createPartitionKey(self.repository_id, self.collection_id) - self.chunks = len(self.subdocs) + # Only calculate chunks if not explicitly provided in data (for new documents) + if "chunks" not in data: + self.chunks = len(self.subdocs) @staticmethod def createPartitionKey(repository_id: str, collection_id: str) -> str: diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index a04f08882..034a89dd0 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -24,7 +24,8 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from mangum import Mangum -from utilities.common_functions import get_groups, is_admin, retry_config +from utilities.auth import is_admin +from utilities.common_functions import get_groups, retry_config from utilities.fastapi_middleware.aws_api_gateway_middleware import AWSAPIGatewayMiddleware from .domain_objects import ( diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index 3c97baa2c..07e6d0a82 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -117,6 +117,49 @@ def handle_start_copy_docker_image(event: Dict[str, Any], context: Any) -> Dict[ image_path = get_container_path(request.inferenceContainer) output_dict["containerConfig"]["image"]["path"] = image_path + # Check if image type is ECR - skip building docker image if it already exists + if request.containerConfig and request.containerConfig.image.type == "ecr": + logger.info(f"ECR image detected for model {event.get('modelId')}, verifying image accessibility") + # Verify the ECR image is accessible + try: + # Extract repository name and tag from the base image + base_image = request.containerConfig.image.baseImage + if ":" in base_image: + repository_name, image_tag = base_image.rsplit(":", 1) + else: + repository_name = base_image + image_tag = "latest" + + # Remove registry URL if present to get just the repository name + if "/" in repository_name: + repository_name = repository_name.split("/")[-1] + + # Verify image exists in ECR + ecrClient.describe_images(repositoryName=repository_name, imageIds=[{"imageTag": image_tag}]) + + logger.info(f"ECR image {base_image} verified successfully") + output_dict["image_info"] = { + "image_tag": image_tag, + "image_uri": repository_name, + "image_type": "ecr", + "remaining_polls": 0, + "image_status": "prebuilt", + } + return output_dict + + except ecrClient.exceptions.ImageNotFoundException: + error_msg = f"ECR image {base_image} not found. Please ensure the image exists and is accessible." + logger.error(error_msg) + raise Exception(error_msg) + except ecrClient.exceptions.RepositoryNotFoundException: + error_msg = ( + f"ECR repository {repository_name} not found. Please ensure the repository exists and is accessible." + ) + logger.error(error_msg) + raise Exception(error_msg) + + # For non-ECR images, proceed with the normal docker image building process + logger.info(f"Invoking image build for model {event.get('modelId')}") response = lambdaClient.invoke( FunctionName=os.environ["DOCKER_IMAGE_BUILDER_FN_ARN"], Payload=json.dumps( @@ -130,6 +173,7 @@ def handle_start_copy_docker_image(event: Dict[str, Any], context: Any) -> Dict[ payload = response["Payload"].read() output_dict["image_info"] = json.loads(payload) output_dict["image_info"]["remaining_polls"] = 30 + output_dict["image_info"]["image_status"] = "building" return output_dict @@ -138,14 +182,22 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D output_dict = deepcopy(event) try: + # Use the appropriate repository name based on image type + repository_name = ( + event["image_info"]["image_uri"] + if event["image_info"].get("image_type") == "ecr" + else os.environ["ECR_REPOSITORY_NAME"] + ) ecrClient.describe_images( - repositoryName=os.environ["ECR_REPOSITORY_NAME"], imageIds=[{"imageTag": event["image_info"]["image_tag"]}] + repositoryName=repository_name, imageIds=[{"imageTag": event["image_info"]["image_tag"]}] ) except ecrClient.exceptions.ImageNotFoundException: output_dict["continue_polling_docker"] = True output_dict["image_info"]["remaining_polls"] -= 1 if output_dict["image_info"]["remaining_polls"] <= 0: - ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) + # Only terminate EC2 instance if one exists (not for pre-existing ECR images) + if "instance_id" in event["image_info"]: + ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) raise MaxPollsExceededException( json.dumps( { @@ -157,7 +209,9 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D return output_dict output_dict["continue_polling_docker"] = False - ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) + # Only terminate EC2 instance if one exists (not for pre-existing ECR images) + if "instance_id" in event["image_info"]: + ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) return output_dict @@ -178,11 +232,32 @@ def camelize_object(o): # type: ignore[no-untyped-def] prepared_event = camelize_object(event) prepared_event["containerConfig"]["environment"] = event["containerConfig"]["environment"] - prepared_event["containerConfig"]["image"] = { - "repositoryArn": os.environ["ECR_REPOSITORY_ARN"], - "tag": event["image_info"]["image_tag"], - "type": "ecr", - } + + # Handle ECR images differently - use the existing ECR image instead of the built one + if event["image_info"].get("image_type") == "ecr": + # For pre-existing ECR images, construct the ARN using the image repository + account_id = os.environ.get("AWS_ACCOUNT_ID", "") + if not account_id: + # Try to get account ID from the existing ECR repository ARN + ecr_repo_arn = os.environ.get("ECR_REPOSITORY_ARN", "") + if ecr_repo_arn: + account_id = ecr_repo_arn.split(":")[4] + + repository_arn = ( + f"arn:aws:ecr:{os.environ['AWS_REGION']}:{account_id}:repository/{event['image_info']['image_uri']}" + ) + prepared_event["containerConfig"]["image"] = { + "repositoryArn": repository_arn, + "tag": event["image_info"]["image_tag"], + "type": "ecr", + } + else: + # For built images, use the default ECR repository + prepared_event["containerConfig"]["image"] = { + "repositoryArn": os.environ["ECR_REPOSITORY_ARN"], + "tag": event["image_info"]["image_tag"], + "type": "ecr", + } response = lambdaClient.invoke( FunctionName=os.environ["ECS_MODEL_DEPLOYER_FN_ARN"], diff --git a/lambda/prompt_templates/lambda_functions.py b/lambda/prompt_templates/lambda_functions.py index 8affa65a0..68fc278e9 100644 --- a/lambda/prompt_templates/lambda_functions.py +++ b/lambda/prompt_templates/lambda_functions.py @@ -22,7 +22,8 @@ import boto3 from boto3.dynamodb.conditions import Attr, Key -from utilities.common_functions import api_wrapper, get_groups, get_item, get_username, is_admin, retry_config +from utilities.auth import get_username, is_admin +from utilities.common_functions import api_wrapper, get_groups, get_item, retry_config from .models import PromptTemplateModel diff --git a/lambda/repository/embeddings.py b/lambda/repository/embeddings.py new file mode 100644 index 000000000..9f4863baf --- /dev/null +++ b/lambda/repository/embeddings.py @@ -0,0 +1,192 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, List + +import boto3 +import requests +from lisapy.langchain import LisaOpenAIEmbeddings +from utilities.common_functions import get_cert_path, retry_config +from utilities.validation import ValidationError + +logger = logging.getLogger(__name__) +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) +secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) +iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config) + +lisa_api_endpoint = "" + + +class PipelineEmbeddings: + """ + Handles document embeddings for pipeline processing using management credentials. + + This class provides methods to embed both single queries and batches of documents + using the LISA API with management-level authentication. + """ + + model_name: str + + def __init__(self, model_name: str) -> None: + try: + self.model_name = model_name + # Get the management key secret name from SSM Parameter Store + secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) + secret_name = secret_name_param["Parameter"]["Value"] + + # Get the management token from Secrets Manager using the secret name + secret_response = secrets_client.get_secret_value(SecretId=secret_name) + self.token = secret_response["SecretString"] + + # Get the API endpoint from SSM + lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) + self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" + + # Get certificate path for SSL verification + self.cert_path = get_cert_path(iam_client) + + logger.info("Successfully initialized pipeline embeddings") + except Exception: + logger.error("Failed to initialize pipeline embeddings", exc_info=True) + raise + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of documents. + + Args: + texts: List of text strings to embed + + Returns: + List of embedding vectors + + Raises: + ValidationError: If input texts are invalid + Exception: If embedding request fails + """ + if not texts: + raise ValidationError("No texts provided for embedding") + + logger.info(f"Embedding {len(texts)} documents") + try: + url = f"{self.base_url}/embeddings" + request_data = {"input": texts, "model": self.model_name} + + response = requests.post( + url, + json=request_data, + headers={"Authorization": self.token, "Content-Type": "application/json"}, + verify=self.cert_path, # Use proper SSL verification + timeout=300, # 5 minute timeout + ) + + if response.status_code != 200: + logger.error(f"Embedding request failed with status {response.status_code}") + logger.error(f"Response content: {response.text}") + raise Exception(f"Embedding request failed with status {response.status_code}") + + result = response.json() + logger.debug(f"API Response: {result}") # Log the full response for debugging + + # Handle different response formats + embeddings = [] + if isinstance(result, dict): + if "data" in result: + # OpenAI-style format + for item in result["data"]: + if isinstance(item, dict) and "embedding" in item: + embeddings.append(item["embedding"]) + else: + embeddings.append(item) # Assume the item itself is the embedding + else: + # Try to find embeddings in the response + for key in ["embeddings", "embedding", "vectors", "vector"]: + if key in result: + embeddings = result[key] + break + elif isinstance(result, list): + # Direct list format + embeddings = result + + if not embeddings: + logger.error(f"Could not find embeddings in response: {result}") + raise Exception("No embeddings found in API response") + + if len(embeddings) != len(texts): + logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})") + raise Exception("Number of embeddings does not match number of input texts") + + logger.info(f"Successfully embedded {len(texts)} documents") + return embeddings + + except requests.Timeout: + logger.error("Embedding request timed out") + raise Exception("Embedding request timed out after 5 minutes") + except requests.RequestException as e: + logger.error(f"Request failed: {str(e)}", exc_info=True) + raise + except Exception as e: + logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True) + raise + + def embed_query(self, text: str) -> List[float]: + if not text or not isinstance(text, str): + raise ValidationError("Invalid query text") + + logger.info("Embedding single query text") + return self.embed_documents([text])[0] + + +def get_embeddings_pipeline(model_name: str) -> Any: + """ + Get embeddings for pipeline requests using management token. + + Args: + model_name: Name of the embedding model to use + + Raises: + ValidationError: If model name is invalid + Exception: If API request fails + """ + logger.info("Starting pipeline embeddings request") + + return PipelineEmbeddings(model_name=model_name) + + +def get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: + """ + Initialize and return an embeddings client for the specified model. + + Args: + model_name: Name of the embedding model to use + id_token: Authentication token for API access + + Returns: + LisaOpenAIEmbeddings: Configured embeddings client + """ + global lisa_api_endpoint + + if not lisa_api_endpoint: + lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) + lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] + + base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" + cert_path = get_cert_path(iam_client) + + embedding = LisaOpenAIEmbeddings( + lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path + ) + return embedding diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index ecf58132c..9422b81a9 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -16,31 +16,23 @@ import json import logging import os -from typing import Any, cast, Dict, List +import urllib.parse +from typing import Any, cast, Dict, List, Optional import boto3 -import requests from boto3.dynamodb.types import TypeSerializer from botocore.config import Config -from lisapy.langchain import LisaOpenAIEmbeddings from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, RagDocument +from repository.embeddings import get_embeddings from repository.ingestion_job_repo import IngestionJobRepository from repository.ingestion_service import DocumentIngestionService from repository.rag_document_repo import RagDocumentRepository from repository.vector_store_repo import VectorStoreRepository -from utilities.bedrock_kb import is_bedrock_kb_repository, retrieve_documents -from utilities.common_functions import ( - admin_only, - api_wrapper, - get_cert_path, - get_groups, - get_id_token, - get_username, - is_admin, - retry_config, - user_has_group_access, -) +from utilities.auth import admin_only, get_username, is_admin +from utilities.bedrock_kb import retrieve_documents +from utilities.common_functions import api_wrapper, get_groups, get_id_token, retry_config, user_has_group_access from utilities.exceptions import HTTPException +from utilities.repository_types import RepositoryType from utilities.validation import ValidationError from utilities.vector_store import get_vector_store_client @@ -64,177 +56,12 @@ signature_version="s3v4", ), ) -lisa_api_endpoint = "" doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"], os.environ["RAG_SUB_DOCUMENT_TABLE"]) vs_repo = VectorStoreRepository() ingestion_service = DocumentIngestionService() ingestion_job_repository = IngestionJobRepository() -def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: - """ - Initialize and return an embeddings client for the specified model. - - Args: - model_name: Name of the embedding model to use - id_token: Authentication token for API access - - Returns: - LisaOpenAIEmbeddings: Configured embeddings client - """ - global lisa_api_endpoint - - if not lisa_api_endpoint: - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] - - base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" - cert_path = get_cert_path(iam_client) - - embedding = LisaOpenAIEmbeddings( - lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path - ) - return embedding - - # Create embeddings client that matches LisaOpenAIEmbeddings interface - - -class PipelineEmbeddings: - """ - Handles document embeddings for pipeline processing using management credentials. - - This class provides methods to embed both single queries and batches of documents - using the LISA API with management-level authentication. - """ - - model_name: str - - def __init__(self, model_name: str) -> None: - try: - self.model_name = model_name - # Get the management key secret name from SSM Parameter Store - secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) - secret_name = secret_name_param["Parameter"]["Value"] - - # Get the management token from Secrets Manager using the secret name - secret_response = secrets_client.get_secret_value(SecretId=secret_name) - self.token = secret_response["SecretString"] - - # Get the API endpoint from SSM - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" - - # Get certificate path for SSL verification - self.cert_path = get_cert_path(iam_client) - - logger.info("Successfully initialized pipeline embeddings") - except Exception: - logger.error("Failed to initialize pipeline embeddings", exc_info=True) - raise - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ - Generate embeddings for a list of documents. - - Args: - texts: List of text strings to embed - - Returns: - List of embedding vectors - - Raises: - ValidationError: If input texts are invalid - Exception: If embedding request fails - """ - if not texts: - raise ValidationError("No texts provided for embedding") - - logger.info(f"Embedding {len(texts)} documents") - try: - url = f"{self.base_url}/embeddings" - request_data = {"input": texts, "model": self.model_name} - - response = requests.post( - url, - json=request_data, - headers={"Authorization": self.token, "Content-Type": "application/json"}, - verify=self.cert_path, # Use proper SSL verification - timeout=300, # 5 minute timeout - ) - - if response.status_code != 200: - logger.error(f"Embedding request failed with status {response.status_code}") - logger.error(f"Response content: {response.text}") - raise Exception(f"Embedding request failed with status {response.status_code}") - - result = response.json() - logger.debug(f"API Response: {result}") # Log the full response for debugging - - # Handle different response formats - embeddings = [] - if isinstance(result, dict): - if "data" in result: - # OpenAI-style format - for item in result["data"]: - if isinstance(item, dict) and "embedding" in item: - embeddings.append(item["embedding"]) - else: - embeddings.append(item) # Assume the item itself is the embedding - else: - # Try to find embeddings in the response - for key in ["embeddings", "embedding", "vectors", "vector"]: - if key in result: - embeddings = result[key] - break - elif isinstance(result, list): - # Direct list format - embeddings = result - - if not embeddings: - logger.error(f"Could not find embeddings in response: {result}") - raise Exception("No embeddings found in API response") - - if len(embeddings) != len(texts): - logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})") - raise Exception("Number of embeddings does not match number of input texts") - - logger.info(f"Successfully embedded {len(texts)} documents") - return embeddings - - except requests.Timeout: - logger.error("Embedding request timed out") - raise Exception("Embedding request timed out after 5 minutes") - except requests.RequestException as e: - logger.error(f"Request failed: {str(e)}", exc_info=True) - raise - except Exception as e: - logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True) - raise - - def embed_query(self, text: str) -> List[float]: - if not text or not isinstance(text, str): - raise ValidationError("Invalid query text") - - logger.info("Embedding single query text") - return self.embed_documents([text])[0] - - -def get_embeddings_pipeline(model_name: str) -> Any: - """ - Get embeddings for pipeline requests using management token. - - Args: - model_name: Name of the embedding model to use - - Raises: - ValidationError: If model name is invalid - Exception: If API request fails - """ - logger.info("Starting pipeline embeddings request") - - return PipelineEmbeddings(model_name=model_name) - - @api_wrapper def list_all(event: dict, context: dict) -> List[Dict[str, Any]]: """ @@ -303,7 +130,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: id_token = get_id_token(event) docs: List[Dict[str, Any]] = [] - if is_bedrock_kb_repository(repository): + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): docs = retrieve_documents( bedrock_runtime_client=bedrock_client, repository=repository, @@ -312,14 +139,20 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]: repository_id=repository_id, ) else: - embeddings = _get_embeddings(model_name=model_name, id_token=id_token) + embeddings = get_embeddings(model_name=model_name, id_token=id_token) vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings) - results = vs.similarity_search( - query, - k=top_k, - ) - docs = [{"page_content": r.page_content, "metadata": r.metadata} for r in results] + # empty vector stores do not have an initialize index. Return empty docs + if RepositoryType.is_type(repository, RepositoryType.OPENSEARCH) and not vs.client.indices.exists( + index=model_name + ): + logger.info(f"Index {model_name} does not exist. Returning empty docs.") + else: + results = vs.similarity_search( + query, + k=top_k, + ) + docs = [{"page_content": r.page_content, "metadata": r.metadata} for r in results] doc_content = [ { "Document": { @@ -554,7 +387,7 @@ def presigned_url(event: dict, context: dict) -> dict: @api_wrapper -def list_docs(event: dict, context: dict) -> dict[str, list[dict] | str | None]: +def list_docs(event: dict, context: dict) -> dict[str, Any]: """List all documents for a given repository/collection. Args: @@ -564,8 +397,7 @@ def list_docs(event: dict, context: dict) -> dict[str, list[dict] | str | None]: context (dict): The Lambda context object Returns: - Tuple list[RagDocument], dict[lastEvaluatedKey]: A list of RagDocument objects representing all documents - in the specified collection and the last evaluated key for pagination + Dict containing documents, pagination info, and metadata Raises: KeyError: If collectionId is not provided in queryStringParameters @@ -576,12 +408,44 @@ def list_docs(event: dict, context: dict) -> dict[str, list[dict] | str | None]: query_string_params = event.get("queryStringParameters", {}) or {} collection_id = query_string_params.get("collectionId") - last_evaluated = query_string_params.get("lastEvaluated") + last_evaluated: Optional[dict[str, Optional[str]]] = None + + if "lastEvaluatedKeyPk" in query_string_params: + last_evaluated = { + "pk": ( + urllib.parse.unquote(query_string_params["lastEvaluatedKeyPk"]) + if "lastEvaluatedKeyPk" in query_string_params + else None + ), + "document_id": ( + urllib.parse.unquote(query_string_params["lastEvaluatedKeyDocumentId"]) + if "lastEvaluatedKeyDocumentId" in query_string_params + else None + ), + "repository_id": ( + urllib.parse.unquote(query_string_params["lastEvaluatedKeyRepositoryId"]) + if "lastEvaluatedKeyRepositoryId" in query_string_params + else None + ), + } + + page_size = int(query_string_params.get("pageSize", "10")) + + if page_size < 1: + page_size = 1 + elif page_size > 100: # Cap at 100 to prevent abuse + page_size = 100 - docs, last_evaluated = doc_repo.list_all( - repository_id=repository_id, collection_id=collection_id, last_evaluated_key=last_evaluated + docs, last_evaluated, total_documents = doc_repo.list_all( + repository_id=repository_id, collection_id=collection_id, last_evaluated_key=last_evaluated, limit=page_size ) - return {"documents": [doc.model_dump() for doc in docs], "lastEvaluated": last_evaluated} + return { + "documents": [doc.model_dump() for doc in docs], + "lastEvaluated": last_evaluated, + "totalDocuments": total_documents, + "hasNextPage": last_evaluated is not None, + "hasPreviousPage": "lastEvaluated" in query_string_params, + } @api_wrapper diff --git a/lambda/repository/pipeline_delete_documents.py b/lambda/repository/pipeline_delete_documents.py index 452ea7835..307305412 100644 --- a/lambda/repository/pipeline_delete_documents.py +++ b/lambda/repository/pipeline_delete_documents.py @@ -19,12 +19,13 @@ import boto3 from models.domain_objects import IngestionJob, IngestionStatus, IngestionType from repository.ingestion_job_repo import IngestionJobRepository +from repository.ingestion_service import DocumentIngestionService from repository.pipeline_ingest_documents import remove_document_from_vectorstore +from repository.rag_document_repo import RagDocumentRepository from repository.vector_store_repo import VectorStoreRepository -from utilities.bedrock_kb import delete_document_from_kb, is_bedrock_kb_repository +from utilities.bedrock_kb import delete_document_from_kb from utilities.common_functions import retry_config - -from .lambda_functions import DocumentIngestionService, RagDocumentRepository +from utilities.repository_types import RepositoryType ingestion_service = DocumentIngestionService() ingestion_job_repository = IngestionJobRepository() @@ -47,7 +48,7 @@ def pipeline_delete(job: IngestionJob) -> None: if rag_document: # Actually remove from vector store repository = vs_repo.find_repository_by_id(job.repository_id) - if is_bedrock_kb_repository(repository): + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): delete_document_from_kb( s3_client=s3, bedrock_agent_client=bedrock_agent, diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index d53a7b0b6..83598ef9a 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -21,16 +21,18 @@ import boto3 from models.domain_objects import FixedChunkingStrategy, IngestionJob, IngestionStatus, IngestionType, RagDocument +from repository.embeddings import get_embeddings_pipeline from repository.ingestion_job_repo import IngestionJobRepository -from repository.lambda_functions import RagDocumentRepository +from repository.ingestion_service import DocumentIngestionService +from repository.rag_document_repo import RagDocumentRepository from repository.vector_store_repo import VectorStoreRepository -from utilities.bedrock_kb import ingest_document_to_kb, is_bedrock_kb_repository -from utilities.common_functions import get_username, retry_config +from utilities.auth import get_username +from utilities.bedrock_kb import ingest_document_to_kb +from utilities.common_functions import retry_config from utilities.file_processing import generate_chunks +from utilities.repository_types import RepositoryType from utilities.vector_store import get_vector_store_client -from .lambda_functions import DocumentIngestionService, get_embeddings_pipeline - dynamodb = boto3.resource("dynamodb", region_name=os.environ["AWS_REGION"], config=retry_config) ingestion_job_table = dynamodb.Table(os.environ["LISA_INGESTION_JOB_TABLE_NAME"]) ingestion_service = DocumentIngestionService() @@ -50,7 +52,7 @@ def pipeline_ingest(job: IngestionJob) -> None: # chunk and save chunks in vector store repository = vs_repo.find_repository_by_id(job.repository_id) all_ids = [] - if is_bedrock_kb_repository(repository): + if RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): ingest_document_to_kb( s3_client=s3, bedrock_agent_client=bedrock_agent, @@ -70,7 +72,7 @@ def pipeline_ingest(job: IngestionJob) -> None: if prev_job: ingestion_job_repository.update_status(prev_job, IngestionStatus.DELETE_IN_PROGRESS) - if not is_bedrock_kb_repository(repository): + if not RepositoryType.is_type(repository, RepositoryType.BEDROCK_KB): remove_document_from_vectorstore(rag_document) rag_document_repository.delete_by_id(rag_document.document_id) diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py index bf3c34bd8..8dce81e07 100644 --- a/lambda/repository/rag_document_repo.py +++ b/lambda/repository/rag_document_repo.py @@ -212,7 +212,7 @@ def list_all( last_evaluated_key: Optional[dict] = None, limit: int = 100, join_docs: bool = False, - ) -> tuple[list[RagDocument], Optional[dict]]: + ) -> tuple[list[RagDocument], Optional[dict], int]: """List all documents in a collection. Args: @@ -252,12 +252,37 @@ def list_all( subdocs = self._get_subdoc_ids(self.find_subdocs_by_id(doc.document_id)) doc.subdocs = subdocs - return docs, next_key + total_documents = self.count_documents(repository_id=repository_id, collection_id=collection_id) + + return docs, next_key, total_documents except ClientError as e: logging.error(f"Error listing documents: {e.response['Error']['Message']}") raise + def count_documents(self, repository_id: str, collection_id: Optional[str] = None) -> int: + """Count total documents in a repository/collection. + Args: + repository_id: Repository ID + collection_id?: Collection ID + Returns: + Total number of documents + """ + count = 0 + # Count all rag documents using repo id only + if not collection_id: + response = self.doc_table.query( + IndexName="repository_index", + KeyConditionExpression=Key("repository_id").eq(repository_id), + Select="COUNT", + ) + count = response.get("Count", 0) + else: + pk = RagDocument.createPartitionKey(repository_id, collection_id) + response = self.doc_table.query(KeyConditionExpression=Key("pk").eq(pk), Select="COUNT") + count = response.get("Count", 0) + return count + def find_subdocs_by_id(self, document_id: str) -> list[RagSubDocument]: """Query subdocuments using GSI. @@ -321,7 +346,9 @@ def delete_s3_docs(self, repository_id: str, docs: list[RagDocument]) -> list[st for pipeline in repo.get("pipelines", []) } removed_source: list[str] = [ - doc.source for doc in docs if doc.ingestion_type != IngestionType.AUTO or pipelines.get(doc.collection_id) + doc.source + for doc in docs + if doc and (doc.ingestion_type != IngestionType.AUTO or pipelines.get(doc.collection_id)) ] for source in removed_source: logging.info(f"Removing S3 doc: {source}") diff --git a/lambda/requirements.txt b/lambda/requirements.txt new file mode 100644 index 000000000..ad8fd1357 --- /dev/null +++ b/lambda/requirements.txt @@ -0,0 +1 @@ +# All required dependencies are pulled in via layer diff --git a/lambda/session/lambda_functions.py b/lambda/session/lambda_functions.py index aad661d24..68fd4648e 100644 --- a/lambda/session/lambda_functions.py +++ b/lambda/session/lambda_functions.py @@ -26,7 +26,8 @@ import boto3 import create_env_variables # noqa: F401 from botocore.exceptions import ClientError -from utilities.common_functions import api_wrapper, get_groups, get_session_id, get_username, retry_config +from utilities.auth import get_username +from utilities.common_functions import api_wrapper, get_groups, get_session_id, retry_config from utilities.encoders import convert_decimal logger = logging.getLogger(__name__) @@ -205,6 +206,9 @@ def _map_session(session: dict) -> Dict[str, Any]: "firstHumanMessage": _find_first_human_message(session), "startTime": session.get("startTime", None), "createTime": session.get("createTime", None), + "lastUpdated": session.get( + "lastUpdated", session.get("startTime", None) + ), # Fallback to startTime for backward compatibility } @@ -370,9 +374,9 @@ def rename_session(event: dict, context: dict) -> dict: table.update_item( Key={"sessionId": session_id, "userId": user_id}, - UpdateExpression="SET #name = :name", - ExpressionAttributeNames={"#name": "name"}, - ExpressionAttributeValues={":name": body.get("name")}, + UpdateExpression="SET #name = :name, #lastUpdated = :lastUpdated", + ExpressionAttributeNames={"#name": "name", "#lastUpdated": "lastUpdated"}, + ExpressionAttributeValues={":name": body.get("name"), ":lastUpdated": datetime.now().isoformat()}, ) return {"statusCode": 200, "body": json.dumps({"message": "Session name updated successfully"})} except ValueError as e: @@ -429,13 +433,15 @@ def put_session(event: dict, context: dict) -> dict: table.update_item( Key={"sessionId": session_id, "userId": user_id}, UpdateExpression="SET #history = :history, #name = :name, #configuration = :configuration, " - + "#startTime = :startTime, #createTime = if_not_exists(#createTime, :createTime)", + + "#startTime = :startTime, #createTime = if_not_exists(#createTime, :createTime), " + + "#lastUpdated = :lastUpdated", ExpressionAttributeNames={ "#history": "history", "#name": "name", "#configuration": "configuration", "#startTime": "startTime", "#createTime": "createTime", + "#lastUpdated": "lastUpdated", }, ExpressionAttributeValues={ ":history": messages, @@ -443,6 +449,7 @@ def put_session(event: dict, context: dict) -> dict: ":configuration": configuration, ":startTime": datetime.now().isoformat(), ":createTime": datetime.now().isoformat(), + ":lastUpdated": datetime.now().isoformat(), }, ReturnValues="UPDATED_NEW", ) diff --git a/lambda/user_preferences/lambda_functions.py b/lambda/user_preferences/lambda_functions.py index de55a29b2..cae6e3e09 100644 --- a/lambda/user_preferences/lambda_functions.py +++ b/lambda/user_preferences/lambda_functions.py @@ -21,7 +21,8 @@ import boto3 from boto3.dynamodb.conditions import Key -from utilities.common_functions import api_wrapper, get_item, get_username, retry_config +from utilities.auth import get_username +from utilities.common_functions import api_wrapper, get_item, retry_config from .models import UserPreferencesModel diff --git a/lambda/utilities/auth.py b/lambda/utilities/auth.py new file mode 100644 index 000000000..c6f5f5bd6 --- /dev/null +++ b/lambda/utilities/auth.py @@ -0,0 +1,48 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from functools import wraps +from typing import Any, Callable, Dict + +from utilities.common_functions import get_groups +from utilities.exceptions import HTTPException + +logger = logging.getLogger(__name__) + + +def get_username(event: dict) -> str: + """Get the username from the event.""" + username: str = event.get("requestContext", {}).get("authorizer", {}).get("username", "system") + return username + + +def is_admin(event: dict) -> bool: + """Get admin status from event.""" + admin_group = os.environ.get("ADMIN_GROUP", "") + groups = get_groups(event) + logger.info(f"User groups: {groups} and admin: {admin_group}") + return admin_group in groups + + +def admin_only(func: Callable) -> Callable: + """Annotation to wrap is_admin""" + + @wraps(func) + def wrapper(event: Dict[str, Any], context: Dict[str, Any], *args: Any, **kwargs: Any) -> Any: + if not is_admin(event): + raise HTTPException(status_code=403, message="User does not have permission to access this repository") + return func(event, context, *args, **kwargs) + + return wrapper diff --git a/lambda/utilities/bedrock_kb.py b/lambda/utilities/bedrock_kb.py index b5164a526..d80e2e4dd 100644 --- a/lambda/utilities/bedrock_kb.py +++ b/lambda/utilities/bedrock_kb.py @@ -23,13 +23,6 @@ import os from typing import Any, Dict, List -BEDROCK_KB_TYPE = "bedrock_knowledge_base" - - -def is_bedrock_kb_repository(repository: Dict[str, Any]) -> Any: - """Return True if the repository is a Bedrock Knowledge Base.""" - return bool(repository.get("type", "") == BEDROCK_KB_TYPE) - def retrieve_documents( bedrock_runtime_client: Any, diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index 43746bc13..86651850a 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -21,12 +21,11 @@ import tempfile from contextvars import ContextVar from decimal import Decimal -from functools import cache, wraps +from functools import cache from typing import Any, Callable, cast, Dict, List, TypeVar, Union import boto3 from botocore.config import Config -from utilities.exceptions import HTTPException from . import create_env_variables # noqa type: ignore @@ -365,26 +364,6 @@ def get_username(event: dict) -> str: return username -def is_admin(event: dict) -> bool: - """Get admin status from event.""" - admin_group = os.environ.get("ADMIN_GROUP", "") - groups = get_groups(event) - logger.info(f"User groups: {groups} and admin: {admin_group}") - return admin_group in groups - - -def admin_only(func: Callable) -> Callable: - """Annotation to wrap is_admin""" - - @wraps(func) - def wrapper(event: Dict[str, Any], context: Dict[str, Any], *args: Any, **kwargs: Any) -> Any: - if not is_admin(event): - raise HTTPException(status_code=403, message="User does not have permission to access this repository") - return func(event, context, *args, **kwargs) - - return wrapper - - def get_session_id(event: dict) -> str: """Get the session ID from the event.""" session_id: str = event.get("pathParameters", {}).get("sessionId") diff --git a/lambda/utilities/repository_types.py b/lambda/utilities/repository_types.py new file mode 100644 index 000000000..2e9b054aa --- /dev/null +++ b/lambda/utilities/repository_types.py @@ -0,0 +1,26 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Any, Dict + + +class RepositoryType(str, Enum): + PGVECTOR = "pgvector" + OPENSEARCH = "opensearch" + BEDROCK_KB = "bedrock_knowledge_base" + + @classmethod + def is_type(cls, repository: Dict[str, Any], repo_type: "RepositoryType") -> bool: + return repository.get("type") == repo_type.value diff --git a/lambda/utilities/vector_store.py b/lambda/utilities/vector_store.py index 1aab1ba1b..46370f400 100644 --- a/lambda/utilities/vector_store.py +++ b/lambda/utilities/vector_store.py @@ -26,6 +26,7 @@ from requests_aws4auth import AWS4Auth from utilities.common_functions import get_lambda_role_name, retry_config from utilities.rds_auth import generate_auth_token +from utilities.repository_types import RepositoryType from . import create_env_variables # noqa type: ignore @@ -44,7 +45,7 @@ def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddin prefix = os.environ.get("REGISTERED_REPOSITORIES_PS_PREFIX") connection_info = ssm_client.get_parameter(Name=f"{prefix}{repository_id}") connection_info = json.loads(connection_info["Parameter"]["Value"]) - if connection_info.get("type") == "opensearch": + if RepositoryType.is_type(connection_info, RepositoryType.OPENSEARCH): service = "es" credentials = session.get_credentials() @@ -60,7 +61,7 @@ def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddin return OpenSearchVectorSearch( opensearch_url=opensearch_endpoint, - index_name=index, + index_name=index.lower(), embedding_function=embeddings, http_auth=auth, timeout=300, @@ -69,7 +70,7 @@ def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddin connection_class=RequestsHttpConnection, ) - elif connection_info.get("type") == "pgvector": + elif RepositoryType.is_type(connection_info, RepositoryType.PGVECTOR): if "passwordSecretId" in connection_info: # provides backwards compatibility to non-iam authenticated vector stores secrets_response = secretsmanager_client.get_secret_value(SecretId=connection_info.get("passwordSecretId")) diff --git a/lib/core/layers/authorizer/requirements.txt b/lib/core/layers/authorizer/requirements.txt index d58c98049..fb0ff6550 100644 --- a/lib/core/layers/authorizer/requirements.txt +++ b/lib/core/layers/authorizer/requirements.txt @@ -1,4 +1,4 @@ # urllib3<2 // Provided by Lambda -requests==2.32.3 +requests==2.32.4 cryptography==44.0.1 PyJWT==2.9.0 diff --git a/lib/core/layers/index.ts b/lib/core/layers/index.ts index 08bcab457..d04ba2804 100644 --- a/lib/core/layers/index.ts +++ b/lib/core/layers/index.ts @@ -38,6 +38,7 @@ type LayerProps = { slimDeployment?: boolean; removePackages?: string[]; assetPath?: string; + afterBundle?: (inputDir: string, outputDir: string) => string[]; } & BaseProps; /** @@ -55,7 +56,7 @@ export class Layer extends Construct { constructor (scope: Construct, id: string, props: LayerProps) { super(scope, id); - const { assetPath, config, path: layerPath, description, architecture } = props; + const { assetPath, config, path: layerPath, description, architecture, afterBundle } = props; if (!fs.existsSync(`${layerPath}/requirements.txt`)) { throw new Error(`requirements.txt not found in ${layerPath}`); @@ -80,12 +81,19 @@ export class Layer extends Construct { removalPolicy: config.removalPolicy, bundling: { platform: architecture.dockerPlatform, - commandHooks: packagesExists ? { + commandHooks: (packagesExists || afterBundle) ? { beforeBundling (inputDir: string, outputDir: string): string[] { - return [`touch ${outputDir}/requirements.txt`]; + return [`mkdir -p ${outputDir}/python && touch ${outputDir}/python/requirements.txt`]; }, afterBundling (inputDir: string, outputDir: string): string[] { - return [`cp -r ${inputDir}/packages/* ${outputDir}/python/`]; + const commands = []; + if (packagesExists) { + commands.push(`cp -r ${inputDir}/packages/* ${outputDir}/python/`); + } + if (afterBundle) { + commands.push(...afterBundle(inputDir, outputDir)); + } + return commands; }, } : undefined }, diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index d91b14fc5..a81fc00db 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -387,7 +387,10 @@ export class ModelsApi extends Construct { new PolicyStatement({ effect: Effect.ALLOW, actions: [ - 'ecr:DescribeImages' + 'ecr:DescribeImages', + 'ecr:DescribeRepositories', + 'ecr:GetRepositoryPolicy', + 'ecr:ListImages' ], resources: ['*'] }), diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index 183027b58..cb5de9ddf 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -74,6 +74,7 @@ export class CreateModelStateMachine extends Construct { MANAGEMENT_KEY_NAME: managementKeyName, RESTAPI_SSL_CERT_ARN: config.restApiConfig?.sslCertIamArn ?? '', LITELLM_CONFIG_OBJ: JSON.stringify(config.litellmConfig), + AWS_ACCOUNT_ID: config.accountNumber, }; const setModelToCreating = new LambdaInvoke(this, 'SetModelToCreating', { @@ -212,14 +213,25 @@ export class CreateModelStateMachine extends Construct { const successState = new Succeed(this, 'CreateSuccess'); const failState = new Fail(this, 'CreateFailed'); + // Check if image is pre-existing ECR image + const checkImageTypeChoice = new Choice(this, 'CheckImageTypeChoice'); + // State Machine definition setModelToCreating.next(createModelInfraChoice); createModelInfraChoice .when(Condition.booleanEquals('$.create_infra', true), startCopyDockerImage) .otherwise(addModelToLitellm); + // Check if we need to poll for docker image or skip directly to stack creation + startCopyDockerImage.next(checkImageTypeChoice); + startCopyDockerImage.addCatch(handleFailureState, { // fail if ECR image verification fails + errors: ['States.TaskFailed'], + }); + checkImageTypeChoice + .when(Condition.stringEquals('$.image_info.image_status', 'prebuilt'), startCreateStack) + .otherwise(pollDockerImageAvailable); + // poll ECR image copy status loop - startCopyDockerImage.next(pollDockerImageAvailable); pollDockerImageAvailable.next(pollDockerImageChoice); pollDockerImageAvailable.addCatch(handleFailureState, { // fail if exception thrown from code errors: ['MaxPollsExceededException'], diff --git a/lib/rag/ingestion/ingestion-image/requirements.txt b/lib/rag/ingestion/ingestion-image/requirements.txt index 59b0c0369..97cd2f20d 100644 --- a/lib/rag/ingestion/ingestion-image/requirements.txt +++ b/lib/rag/ingestion/ingestion-image/requirements.txt @@ -3,10 +3,10 @@ # urllib3<2 // Provided by Lambda aioboto3==12.3.0 aiobotocore==2.11.2 -aiohttp==3.10.11 +aiohttp==3.12.14 boto3==1.34.34 click==8.1.7 -cryptography==43.0.3 +cryptography==44.0.1 fastapi_utils==0.7.0 fastapi==0.115.11 gunicorn==23.0.0 @@ -22,10 +22,10 @@ psycopg2-binary==2.9.9 pydantic==2.8.2 PyJWT==2.9.0 pynacl==1.5.0 -pypdf==4.3.1 +pypdf==6.0.0 lxml==5.1.0 python-docx==1.1.0 requests-aws4auth==1.2.3 -requests==2.32.3 +requests==2.32.4 text-generation==0.7.0 uvicorn==0.29.0 diff --git a/lib/rag/ingestion/ingestion-job-construct.ts b/lib/rag/ingestion/ingestion-job-construct.ts index d02357116..37625cdb4 100644 --- a/lib/rag/ingestion/ingestion-job-construct.ts +++ b/lib/rag/ingestion/ingestion-job-construct.ts @@ -22,14 +22,13 @@ */ import { Duration, Size, StackProps } from 'aws-cdk-lib'; import { Construct } from 'constructs'; -import { BaseProps } from '../../schema'; +import { BaseProps, EcsSourceType } from '../../schema'; import * as logs from 'aws-cdk-lib/aws-logs'; import * as iam from 'aws-cdk-lib/aws-iam'; import * as batch from 'aws-cdk-lib/aws-batch'; import * as ecs from 'aws-cdk-lib/aws-ecs'; import * as dynamodb from 'aws-cdk-lib/aws-dynamodb'; import * as lambda from 'aws-cdk-lib/aws-lambda'; -import { DockerImageAsset } from 'aws-cdk-lib/aws-ecr-assets'; import { Vpc } from '../../networking/vpc'; import path from 'path'; import { ILayerVersion } from 'aws-cdk-lib/aws-lambda'; @@ -37,6 +36,7 @@ import { getDefaultRuntime } from '../../api-base/utils'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import * as fs from 'fs'; import * as crypto from 'crypto'; +import { BATCH_INGESTION_PATH, CodeFactory } from '../../util'; // Props interface for the IngestionJobConstruct export type IngestionJobConstructProps = StackProps & BaseProps & { @@ -47,6 +47,19 @@ export type IngestionJobConstructProps = StackProps & BaseProps & { }; export class IngestionJobConstruct extends Construct { + private getMaxCpus (vpc: Vpc): number { + // Calculate maxvCpus based on available IPs in subnets to prevent IP exhaustion + // Each task uses 2 vCPUs, so maxvCpus = available_ips * 2 vCPUs per task + const availableIps = vpc.subnetSelection?.subnets?.reduce((total, subnet) => { + // Each subnet reserves 5 IPs (network, broadcast, gateway, DNS, future use) + const subnetSize = Math.pow(2, 32 - parseInt(subnet.ipv4CidrBlock.split('/')[1])); + return total + Math.max(0, subnetSize - 5); + }, 0) || 64; // Default to 64 if calculation fails + + const maxTasks = Math.min(availableIps, 256); // Cap at 256 for reasonable limits + return maxTasks * 2; // Each task uses 2 vCPUs + } + constructor (scope: Construct, id: string, props: IngestionJobConstructProps) { super(scope, id); @@ -84,10 +97,12 @@ export class IngestionJobConstruct extends Construct { }); // AWS Batch Fargate compute environment for running ingestion jobs + const maxvCpus = this.getMaxCpus(vpc); const computeEnv = new batch.FargateComputeEnvironment(this, 'IngestionJobFargateEnv', { computeEnvironmentName: `${config.deploymentName}-${config.deploymentStage}-ingestion-job-${hash}`, vpc: vpc.vpc, - + vpcSubnets: vpc.subnetSelection, + maxvCpus: maxvCpus, }); // AWS Batch job queue that uses the Fargate compute environment @@ -103,11 +118,10 @@ export class IngestionJobConstruct extends Construct { baseEnvironment['LISA_INGESTION_JOB_QUEUE_NAME'] = jobQueue.jobQueueName; // Set up build directory for Docker image - const ingestionImageRoot = path.join(__dirname, 'ingestion-image'); const buildDirName = 'build'; - const buildDir = path.join(ingestionImageRoot, buildDirName); + const buildDir = path.join(BATCH_INGESTION_PATH, buildDirName); - fs.mkdirSync(buildDir, {recursive: true}); + fs.mkdirSync(buildDir, { recursive: true }); const copyOptions = { recursive: true, @@ -136,12 +150,16 @@ export class IngestionJobConstruct extends Construct { }); } - // Build Docker image for batch jobs - const dockerImageAsset = new DockerImageAsset(this, 'IngestionJobImage', { - directory: ingestionImageRoot, + const imageConfig = config.batchIngestionConfig || { + baseImage: config.baseImage, + path: BATCH_INGESTION_PATH, + type: EcsSourceType.ASSET, buildArgs: { 'BUILD_DIR': buildDirName }, + }; + const image = CodeFactory.createImage(imageConfig, this, 'BatchIngestionContainer', { + 'BUILD_DIR': buildDirName }); // AWS Batch job definition specifying container configuration @@ -149,7 +167,7 @@ export class IngestionJobConstruct extends Construct { jobDefinitionName: `${config.deploymentName}-${config.deploymentStage}-ingestion-job-${hash}`, container: new batch.EcsFargateContainerDefinition(this, 'IngestionJobContainer', { environment: baseEnvironment, - image: ecs.ContainerImage.fromDockerImageAsset(dockerImageAsset), + image, memory: Size.mebibytes(4096), cpu: 2, command: ['-m', 'repository.pipeline_ingestion', 'Ref::ACTION', 'Ref::DOCUMENT_ID'], @@ -185,8 +203,9 @@ export class IngestionJobConstruct extends Construct { layers: layers, role: lambdaRole }); + const scheduleParameterName = `${config.deploymentPrefix}/ingestion/ingest/schedule`; new StringParameter(this, 'IngestionJobScheduleLambdaArn', { - parameterName: `${config.deploymentPrefix}/ingestion/ingest/schedule`, + parameterName: scheduleParameterName, stringValue: handlePipelineIngestScheduleLambda.functionArn }); handlePipelineIngestScheduleLambda.addPermission('AllowEventBridgeInvoke', { @@ -207,8 +226,9 @@ export class IngestionJobConstruct extends Construct { layers: layers, role: lambdaRole }); + const eventParameterName = `${config.deploymentPrefix}/ingestion/ingest/event`; new StringParameter(this, 'IngestionJobEventLambdaArn', { - parameterName: `${config.deploymentPrefix}/ingestion/ingest/event`, + parameterName: eventParameterName, stringValue: handlePipelineIngestEvent.functionArn }); handlePipelineIngestEvent.addPermission('AllowEventBridgeInvoke', { @@ -229,8 +249,9 @@ export class IngestionJobConstruct extends Construct { layers: layers, role: lambdaRole }); + const deleteParameterName = `${config.deploymentPrefix}/ingestion/delete/event`; new StringParameter(this, 'DeletionJobEventLambdaArn', { - parameterName: `${config.deploymentPrefix}/ingestion/delete/event`, + parameterName: deleteParameterName, stringValue: handlePipelineDeleteEvent.functionArn }); diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index 6128d3d98..36d98af3e 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -7,7 +7,7 @@ langchain-openai==0.2.11 opensearch-py==2.6.0 pgvector==0.2.5 psycopg2-binary==2.9.9 -pypdf==4.3.1 +pypdf==6.0.0 lxml==5.1.0 python-docx==1.1.0 requests-aws4auth==1.2.3 diff --git a/lib/rag/ragConstruct.ts b/lib/rag/ragConstruct.ts index e5c05ef26..fa35491ff 100644 --- a/lib/rag/ragConstruct.ts +++ b/lib/rag/ragConstruct.ts @@ -220,11 +220,15 @@ export class LisaRagConstruct extends Construct { const ragLambdaLayer = new Layer(scope, 'RagLayer', { config: config, path: RAG_LAYER_PATH, - description: 'Lambad dependencies for RAG API', + description: 'Lambda dependencies for RAG API', architecture: ARCHITECTURE, autoUpgrade: true, assetPath: config.lambdaLayerAssets?.ragLayerPath, + afterBundle: (inputDir: string, outputDir: string) => [ + `cp -r ${inputDir}/TIKTOKEN_CACHE/* ${outputDir}/TIKTOKEN_CACHE/` + ], }); + new StringParameter(scope, createCdkId([config.deploymentName, config.deploymentStage, 'RagLayer']), { parameterName: `${config.deploymentPrefix}/layerVersion/rag`, stringValue: ragLambdaLayer.layer.layerVersionArn diff --git a/lib/rag/state_machine/legacy-ingest-pipeline.ts b/lib/rag/state_machine/legacy-ingest-pipeline.ts index cc1ff514e..09379a7bd 100644 --- a/lib/rag/state_machine/legacy-ingest-pipeline.ts +++ b/lib/rag/state_machine/legacy-ingest-pipeline.ts @@ -104,7 +104,7 @@ export class LegacyIngestPipelineStateMachine extends Construct { const ingestionLambda = lambda.Function.fromFunctionArn(this, createCdkId(['IngestionScheduleLambda', hash]), ingestionLambdaArn.stringValue); // Create daily cron trigger with input template - new Rule(this, createCdkId(['DailyIngestRule', hash]), { + const dailyRule = new Rule(this, createCdkId(['DailyIngestRule', hash]), { ruleName: `${config.deploymentName}-${config.deploymentStage}-LegacyDailyIngestRule-${hash}`, schedule: Schedule.cron({ minute: '0', @@ -128,8 +128,12 @@ export class LegacyIngestPipelineStateMachine extends Construct { }) })] }); + + // Ensure rule is created after Lambda function parameter is available + dailyRule.node.addDependency(ingestionLambdaArn); } else if (pipelineConfig.trigger === 'event') { const ingestionLambdaArn = StringParameter.fromStringParameterName(this, createCdkId(['IngestionChangeEventLambdaStringParameter', hash]), `${config.deploymentPrefix}/ingestion/ingest/event`); + const ingestionLambda = lambda.Function.fromFunctionArn(this, createCdkId(['IngestionIngestEventLambda', hash]), ingestionLambdaArn.stringValue); // Create S3 event trigger with complete event pattern and transform input @@ -154,8 +158,8 @@ export class LegacyIngestPipelineStateMachine extends Construct { detail }; - new Rule(this, createCdkId(['S3EventIngestRule', hash]), { - ruleName: `${config.deploymentName}-${config.deploymentStage}-LegacyS3EventIngestRule-${hash}`, + const s3EventRule = new Rule(this, createCdkId(['S3EventIngestRule', hash]), { + ruleName: `${config.deploymentName}-${config.deploymentStage}-LegacyS3EventIngestRule-${hash}}`, eventPattern, targets: [new LambdaFunction(ingestionLambda, { event: RuleTargetInput.fromObject({ @@ -176,10 +180,14 @@ export class LegacyIngestPipelineStateMachine extends Construct { }) })] }); + + // Ensure rule is created after Lambda function parameter is available + s3EventRule.node.addDependency(ingestionLambdaArn); } if (pipelineConfig.autoRemove) { const deletionLambdaArn = StringParameter.fromStringParameterName(this, createCdkId(['IngestionDeleteEventLambdaStringParameter', hash]), `${config.deploymentPrefix}/ingestion/delete/event`); + const deletionLambda = lambda.Function.fromFunctionArn(this, createCdkId(['IngestionDeleteEventLambda', hash]), deletionLambdaArn.stringValue); console.log('Creating autodelete rule...'); @@ -216,7 +224,7 @@ export class LegacyIngestPipelineStateMachine extends Construct { detail }; - new Rule(this, createCdkId(['S3EventDeleteRule', hash]), { + const s3DeleteRule = new Rule(this, createCdkId(['S3EventDeleteRule', hash]), { ruleName: `${config.deploymentName}-${config.deploymentStage}-LegacyS3EventDeleteRule-${hash}`, eventPattern, targets: [new LambdaFunction(deletionLambda, { @@ -238,6 +246,9 @@ export class LegacyIngestPipelineStateMachine extends Construct { }) })] }); + + // Ensure rule is created after Lambda function parameter is available + s3DeleteRule.node.addDependency(deletionLambdaArn); } // Grant the execution role permissions to access specified S3 bucket/prefix diff --git a/lib/rag/vector-store/state_machine/delete-store.ts b/lib/rag/vector-store/state_machine/delete-store.ts index 38135d15c..cda13d22d 100644 --- a/lib/rag/vector-store/state_machine/delete-store.ts +++ b/lib/rag/vector-store/state_machine/delete-store.ts @@ -149,6 +149,11 @@ export class DeleteStoreStateMachine extends Construct { role: executionRole, }); + // Allow the Step Functions role to invoke the cleanup lambda + if (role) { + cleanupDocsFunc.grantInvoke(role); + } + const hasMoreDocs = new Choice(this, 'HasMoreDocs') .when(Condition.isNotNull('$.lastEvaluated'), new LambdaInvoke(this, 'CleanupRepositoryDocsRetry', { lambdaFunction: cleanupDocsFunc, diff --git a/lib/rag/vector-store/vector-store-creator.ts b/lib/rag/vector-store/vector-store-creator.ts index 93e62e1ac..089150d48 100644 --- a/lib/rag/vector-store/vector-store-creator.ts +++ b/lib/rag/vector-store/vector-store-creator.ts @@ -56,6 +56,17 @@ export class VectorStoreCreatorStack extends Construct { ], }); + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: [ + 's3:*', + 'ec2:*', + 'rds:*', + 'opensearch:*', + 'ssm:*', + ], + resources: ['*'] + }));// Additional CloudFormation permissions that might be needed + const lambdaExecutionRole = iam.Role.fromRoleArn( this, `${Roles.RAG_LAMBDA_EXECUTION_ROLE}-VectorStore`, @@ -65,25 +76,65 @@ export class VectorStoreCreatorStack extends Construct { ), ); - // Add permissions to create resources that will be in the dynamic stacks + // IAM: service-linked role creation for required services + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: ['iam:CreateServiceLinkedRole'], + resources: ['*'], + conditions: { + StringEquals: { + 'iam:AWSServiceName': ['opensearchservice.amazonaws.com', 'rds.amazonaws.com'] + } + } + })); + + // IAM: manage roles created within the dynamic stacks and allow passing to services cdkRole.addToPolicy(new iam.PolicyStatement({ actions: [ - 's3:*', - 'ec2:*', - 'rds:*', - 'opensearch:*', - 'ssm:*', - 'iam:*' + 'iam:CreateRole', + 'iam:DeleteRole', + 'iam:AttachRolePolicy', + 'iam:DetachRolePolicy', + 'iam:PutRolePolicy', + 'iam:DeleteRolePolicy', + 'iam:TagRole', + 'iam:UntagRole', + 'iam:GetRole', + 'iam:GetRolePolicy', + 'iam:ListRolePolicies', + 'iam:ListAttachedRolePolicies', + 'iam:ListRoleTags', + 'iam:UpdateAssumeRolePolicy', + 'iam:ListRoles' ], resources: ['*'], })); + // IAM: assume CDK bootstrap roles for deployment + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: ['iam:AssumeRole'], + resources: [ + `arn:${config.partition}:iam::${config.accountNumber}:role/cdk-*-deploy-role-${config.accountNumber}-${config.region}`, + `arn:${config.partition}:iam::${config.accountNumber}:role/cdk-hnb659fds-deploy-role-${config.accountNumber}-${config.region}` + ], + })); + + + cdkRole.addToPolicy(new iam.PolicyStatement({ + actions: ['iam:PassRole'], + resources: ['*'], + conditions: { + StringEquals: { + 'iam:PassedToService': [ + 'cloudformation.amazonaws.com', + 'lambda.amazonaws.com', + 'events.amazonaws.com' + ] + } + } + })); + const stateMachineRole = new iam.Role(this, createCdkId([config.deploymentName, config.deploymentStage, 'StateMachineRole']), { assumedBy: new iam.ServicePrincipal('states.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AWSStepFunctionsFullAccess'), - iam.ManagedPolicy.fromAwsManagedPolicyName('AWSCloudFormationFullAccess'), - ], }); vectorStoreTable.grantReadWriteData(stateMachineRole); @@ -125,6 +176,23 @@ export class VectorStoreCreatorStack extends Construct { securityGroups: [props.vpc.securityGroups.lambdaSg], }); + // Allow the state machine to invoke the deployer Lambda + this.vectorStoreCreatorFn.grantInvoke(stateMachineRole); + + // Minimal policies for state machine role + stateMachineRole.addToPolicy(new iam.PolicyStatement({ + actions: ['lambda:InvokeFunction'], + resources: [this.vectorStoreCreatorFn.functionArn], + })); + stateMachineRole.addToPolicy(new iam.PolicyStatement({ + actions: ['cloudformation:DescribeStacks', 'cloudformation:DeleteStack'], + resources: ['*'], + })); + stateMachineRole.addToPolicy(new iam.PolicyStatement({ + actions: ['dynamodb:PutItem', 'dynamodb:UpdateItem', 'dynamodb:GetItem', 'dynamodb:DeleteItem'], + resources: [vectorStoreTable.tableArn], + })); + new CreateStoreStateMachine(this, 'CreateVectorStoreStateMachine', { config: props.config, executionRole: lambdaExecutionRole, diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index a7b91d1b8..e5a5176aa 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -784,7 +784,8 @@ export const RawConfigObject = z.object({ region: z.string().describe('AWS region for deployment.'), partition: z.string().default('aws').describe('AWS partition for deployment.'), domain: z.string().default('amazonaws.com').describe('AWS domain for deployment'), - restApiConfig: FastApiContainerConfigSchema, + restApiConfig: FastApiContainerConfigSchema.describe('Image override for Rest API'), + batchIngestionConfig: ImageAssetSchema.optional().describe('Image override for Batch Ingestion'), vpcId: z.string().optional().describe('VPC ID for the application. (e.g. vpc-0123456789abcdef)'), subnets: z.array(z.object({ subnetId: z.string().startsWith('subnet-'), diff --git a/lib/schema/ragSchema.ts b/lib/schema/ragSchema.ts index c4986ba59..72b68c761 100644 --- a/lib/schema/ragSchema.ts +++ b/lib/schema/ragSchema.ts @@ -97,6 +97,7 @@ export const RagRepositoryConfigSchema = z .regex(/^(?!-).*(? litellm_config.yaml - # Copy the source code into the container COPY src/ ./src @@ -31,6 +26,11 @@ COPY TIKTOKEN_CACHE ./TIKTOKEN_CACHE # Generate the prisma binary RUN prisma generate +# Copy LiteLLM config directly to container, it will be updated at runtime +# with LISA-hosted models. This filename is expected in the entrypoint.sh file, so do not modify +# the filename unless you modify it in the entrypoint.sh file too. +RUN echo "$LITELLM_CONFIG" > litellm_config.yaml + # Make entrypoint.sh executable RUN chmod +x src/entrypoint.sh diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index 8e8c0bb6a..15fe25eda 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -5,6 +5,9 @@ set -e HOST="0.0.0.0" PORT="8080" +# Prisma client is generated during build +echo "Prisma client already generated during build" + # Update LiteLLM config that was already copied from config.yaml with runtime-deployed models. # Depends on SSM Parameter for registered models. echo "Configuring and starting LiteLLM" diff --git a/lib/serve/rest-api/src/requirements.txt b/lib/serve/rest-api/src/requirements.txt index c35af2e51..8230b4bab 100644 --- a/lib/serve/rest-api/src/requirements.txt +++ b/lib/serve/rest-api/src/requirements.txt @@ -1,9 +1,9 @@ -aioboto3==12.3.0 -aiobotocore==2.11.2 -aiohttp==3.10.11 -boto3==1.34.34 +aioboto3>=12.0.0,<15.0.0 +aiobotocore>=2.11.0,<3.0.0 +aiohttp==3.12.14 +boto3>=1.34.0,<1.37.0 click==8.1.7 -cryptography==43.0.3 +cryptography>=43.0.1,<44.0.0 fastapi==0.115.11 fastapi_utils==0.7.0 gunicorn==23.0.0 diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 30c24e1d7..a3bea78f8 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "5.1.0", + "version": "5.2.0", "type": "module", "scripts": { "dev": "vite", @@ -33,6 +33,7 @@ "langchain": "^0.3.15", "lodash": "^4.17.21", "luxon": "^3.5.0", + "mermaid": "^11.10.1", "react": "^18.3.1", "react-dom": "^18.3.1", "react-json-view-lite": "^0.9.8", @@ -40,11 +41,13 @@ "react-oidc-context": "^2.4.0", "react-redux": "^8.1.3", "react-router-dom": "^6.29.0", - "react-syntax-highlighter": "^15.6.1", + "react-syntax-highlighter": "^15.6.6", "react-textarea-autosize": "^8.5.7", "redux-persist": "^6.0.0", "regenerator-runtime": "^0.14.1", + "rehype-mathjax": "^7.1.0", "remark-breaks": "^4.0.0", + "remark-math": "^6.0.0", "tailwindcss": "^3.4.17", "typescript": "~5.1.6", "unraw": "^3.0.0", diff --git a/lib/user-interface/react/src/components/Topbar.tsx b/lib/user-interface/react/src/components/Topbar.tsx index 57637a222..191f0bff9 100644 --- a/lib/user-interface/react/src/components/Topbar.tsx +++ b/lib/user-interface/react/src/components/Topbar.tsx @@ -118,7 +118,7 @@ function Topbar ({ configs }: TopbarProps): ReactElement { disableUtilityCollapse: false, external: false, onClick: () => { - navigate('/ai-assistant'); + navigate('/'); }, }, ...( diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx index fa675ca85..4c21dec41 100644 --- a/lib/user-interface/react/src/components/chatbot/Chat.tsx +++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx @@ -24,7 +24,6 @@ import { ButtonGroup, Checkbox, Grid, PromptInput, - TextContent, Icon, } from '@cloudscape-design/components'; import StatusIndicator from '@cloudscape-design/components/status-indicator'; @@ -34,7 +33,8 @@ import { LisaAttachImageResponse, LisaChatMessage, LisaChatSession, - MessageTypes + MessageTypes, + ModelFeatures } from '../types'; import RagControls from './components/RagOptions'; import { ContextUploadModal, RagUploadModal } from './components/FileUploadModals'; @@ -117,6 +117,8 @@ export default function Chat ({ sessionId }) { const [useRag, setUseRag] = useState(false); const [openAiTools, setOpenAiTools] = useState(undefined); const [preferences, setPreferences] = useState(undefined); + const [modelFilterValue, setModelFilterValue] = useState(''); + const [hasUserInteractedWithModel, setHasUserInteractedWithModel] = useState(false); // Ref to track if we're processing tool calls to prevent infinite loops const isProcessingToolCalls = useRef(false); @@ -153,7 +155,6 @@ export default function Chat ({ sessionId }) { const { session, setSession, - internalSessionId, setInternalSessionId, loadingSession, chatConfiguration, @@ -170,6 +171,26 @@ export default function Chat ({ sessionId }) { setChatConfiguration ); + // Set default model if none is selected, default model is configured, and user hasn't interacted + useEffect(() => { + if (!selectedModel && !hasUserInteractedWithModel && config?.configuration?.global?.defaultModel && allModels) { + const defaultModelId = config.configuration.global.defaultModel; + handleModelChange(defaultModelId, selectedModel, setSelectedModel); + } + }, [selectedModel, hasUserInteractedWithModel, config?.configuration?.global?.defaultModel, allModels, handleModelChange, setSelectedModel]); + + // Wrapper for handleModelChange that tracks user interaction + const handleUserModelChange = (value: string) => { + setHasUserInteractedWithModel(true); + setModelFilterValue(value); + handleModelChange(value, selectedModel, setSelectedModel); + }; + + // Update filter value when selected model changes + useEffect(() => { + setModelFilterValue(selectedModel?.modelId ?? ''); + }, [selectedModel]); + const { memory, setMemory, metadata } = useMemory( session, chatConfiguration, @@ -425,10 +446,10 @@ export default function Chat ({ sessionId }) { }, [sessionHealth]); useEffect(() => { - if (bottomRef) { - bottomRef?.current.scrollIntoView({ behavior: 'smooth' }); + if (bottomRef.current) { + bottomRef.current.scrollIntoView({ behavior: 'smooth' }); } - }, [session.history.length]); + }, [session.history.length, isStreaming, isRunning, generateResponse]); // Reset tool call counter when session changes useEffect(() => { @@ -593,6 +614,9 @@ export default function Chat ({ sessionId }) { visible={modals.sessionConfiguration} setVisible={(show) => show ? openModal('sessionConfiguration') : closeModal('sessionConfiguration')} systemConfig={config} + session={session} + updateSession={updateSession} + ragConfig={ragConfig} /> )} -
+
{session.history.map((message, idx) => ( No models available.
} filteringType='auto' - value={selectedModel?.modelId ?? ''} + value={modelFilterValue} enteredTextLabel={(text) => `Use: "${text}"`} - onChange={({ detail: { value } }) => handleModelChange(value, selectedModel, setSelectedModel)} + onChange={({ detail: { value } }) => handleUserModelChange(value)} options={modelsOptions} ref={modelSelectRef} /> @@ -725,7 +749,7 @@ export default function Chat ({ sessionId }) { - - - -
- Session ID: {internalSessionId} -
-
-
- {enabledServers && enabledServers.length > 0 ? ( + + {enabledServers && enabledServers.length > 0 && selectedModel?.features?.filter((feature) => feature.name === ModelFeatures.TOOL_CALLS)?.length && true ? ( {enabledServers.length} MCP Servers - {openAiTools?.length || 0} tools - ) : (
)} + ) + : !selectedModel || !enabledServers || enabledServers.length === 0 ? (
) + : ( + This model does not have Tool Calling enabled + )} {isConnected ? 'Connected' : 'Disconnected'} diff --git a/lib/user-interface/react/src/components/chatbot/components/MermaidDiagram.tsx b/lib/user-interface/react/src/components/chatbot/components/MermaidDiagram.tsx new file mode 100644 index 000000000..7c6a51bf9 --- /dev/null +++ b/lib/user-interface/react/src/components/chatbot/components/MermaidDiagram.tsx @@ -0,0 +1,229 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import React, { useEffect, useRef, useState, useCallback } from 'react'; +import mermaid from 'mermaid'; +import { ButtonGroup, StatusIndicator } from '@cloudscape-design/components'; +import { downloadSvgAsPng } from '../../../shared/util/downloader'; + +type MermaidDiagramProps = { + chart: string; + id?: string; + isStreaming?: boolean; +}; + +const MermaidDiagram: React.FC = React.memo(({ chart, id, isStreaming }) => { + const containerRef = useRef(null); + const [error, setError] = useState(''); + const [svg, setSvg] = useState(''); + const [isLoading, setIsLoading] = useState(true); + const mermaidInitialized = useRef(false); + const lastRenderedChart = useRef(''); + + // Initialize Mermaid once + useEffect(() => { + if (!mermaidInitialized.current) { + mermaid.initialize({ + startOnLoad: false, + theme: 'dark', + securityLevel: 'loose', + fontFamily: 'Arial, sans-serif', + suppressErrorRendering: true, + fontSize: 14, + flowchart: { + useMaxWidth: true, + htmlLabels: true, + }, + sequence: { + useMaxWidth: true, + wrap: true, + }, + gantt: { + useMaxWidth: true, + }, + }); + mermaidInitialized.current = true; + } + }, []); + + + // Render the diagram once + useEffect(() => { + const renderDiagram = async () => { + // Skip rendering if we've already rendered this exact chart + if (lastRenderedChart.current === chart && svg) { + return; + } + + if (!chart.trim()) { + setError('Empty chart content'); + setIsLoading(false); + return; + } + + // Don't render during streaming or if syntax appears incomplete + if (isStreaming) { + setIsLoading(true); + return; + } + + setIsLoading(true); + setError(''); + + try { + const diagramId = id || `mermaid-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const { svg: renderedSvg } = await mermaid.render(diagramId, chart); + setSvg(renderedSvg); + lastRenderedChart.current = chart; + } catch (err) { + console.error('Mermaid rendering error:', err); + setError(`Failed to render diagram: ${err instanceof Error ? err.message : 'Unknown error'}`); + } finally { + setIsLoading(false); + } + }; + + renderDiagram(); + }, [chart, id, svg, isStreaming]); + + const copyToClipboard = useCallback(async (content: string) => { + try { + await navigator.clipboard.writeText(content); + } catch (err) { + console.error('Failed to copy to clipboard:', err); + } + }, []); + + + const handleButtonClick = useCallback(({ detail }: { detail: { id: string } }) => { + if (detail.id === 'copy-code') { + copyToClipboard(chart); + } else if (detail.id === 'download-png') { + // Find the SVG element in the container + const svgElement = containerRef.current?.querySelector('svg'); + if (svgElement) { + downloadSvgAsPng(svgElement, 'mermaid-diagram.png'); + } + } + }, [chart, copyToClipboard]); + + // Error state - show original code + if (error) { + return ( +
+ Mermaid Error: {error} +
+ + Show diagram source + +
+                        {chart}
+                    
+
+
+ ); + } + + // Loading state + if (isLoading || !svg) { + return ( +
+ Rendering Mermaid diagram... +
+ ); + } + + const buttonItems = [ + { + type: 'icon-button' as const, + id: 'copy-code', + iconName: 'file' as const, + text: 'Copy Mermaid Code', + popoverFeedback: ( + + Mermaid code copied + + ) + }, + { + type: 'icon-button' as const, + id: 'download-png', + iconName: 'download' as const, + text: 'Download as PNG', + popoverFeedback: ( + + PNG downloaded + + ) + } + ]; + + return ( +
+
+ +
+
+
+ ); +}); + +MermaidDiagram.displayName = 'MermaidDiagram'; + +export default MermaidDiagram; diff --git a/lib/user-interface/react/src/components/chatbot/components/Message.tsx b/lib/user-interface/react/src/components/chatbot/components/Message.tsx index 3fdffff26..f47cb0761 100644 --- a/lib/user-interface/react/src/components/chatbot/components/Message.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/Message.tsx @@ -28,13 +28,17 @@ import { selectCurrentUsername } from '@/shared/reducers/user.reducer'; import ChatBubble from '@cloudscape-design/chat-components/chat-bubble'; import Avatar from '@cloudscape-design/chat-components/avatar'; import remarkBreaks from 'remark-breaks'; +import remarkMath from 'remark-math'; +import rehypeMathjax from 'rehype-mathjax'; import { MessageContent } from '@langchain/core/messages'; import { base64ToBlob, fetchImage, getDisplayableMessage, messageContainsImage } from '@/components/utils'; -import React, { useEffect, useState } from 'react'; +import { useEffect, useState, useMemo } from 'react'; import { IChatConfiguration } from '@/shared/model/chat.configurations.model'; import { downloadFile } from '@/shared/util/downloader'; import Link from '@cloudscape-design/components/link'; import ImageViewer from '@/components/chatbot/components/ImageViewer'; +import MermaidDiagram from '@/components/chatbot/components/MermaidDiagram'; +import UsageInfo from '@/components/chatbot/components/UsageInfo'; import { merge } from 'lodash'; type MessageProps = { @@ -48,9 +52,10 @@ type MessageProps = { handleSendGenerateRequest: () => void; setUserPrompt: (state: string) => void; chatConfiguration: IChatConfiguration; + showUsage?: boolean; }; -export default function Message ({ message, isRunning, showMetadata, isStreaming, markdownDisplay, setUserPrompt, setChatConfiguration, handleSendGenerateRequest, chatConfiguration, callingToolName }: MessageProps) { +export default function Message ({ message, isRunning, showMetadata, isStreaming, markdownDisplay, setUserPrompt, setChatConfiguration, handleSendGenerateRequest, chatConfiguration, callingToolName, showUsage = false }: MessageProps) { const currentUser = useAppSelector(selectCurrentUsername); const ragCitations = !isStreaming && message?.metadata?.ragDocuments ? message?.metadata.ragDocuments : undefined; const [resend, setResend] = useState(false); @@ -66,6 +71,146 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming // eslint-disable-next-line react-hooks/exhaustive-deps }, [resend]); + // Memoize the ReactMarkdown components to prevent re-creation on every render + const markdownComponents = useMemo(() => ({ + code ({ className, children, ...props }: any) { + const match = /language-(\w+)/.exec(className || ''); + const codeString = String(children).replace(/\n$/, ''); + + const CodeBlockWithCopyButton = ({ language, code }: { language: string, code: string }) => { + return ( +
+
+ + navigator.clipboard.writeText(code) + } + ariaLabel='Chat actions' + dropdownExpandToViewport + items={[ + { + type: 'icon-button', + id: 'copy code', + iconName: 'copy', + text: 'Copy Code', + popoverFeedback: ( + + Code copied + + ) + } + ]} + variant='icon' + /> +
+ + {code} + +
+ ); + }; + const CodeBlockWithoutLanguage = ({ code }: { code: string }) => { + return ( +
+
+ + navigator.clipboard.writeText(code) + } + ariaLabel='Chat actions' + dropdownExpandToViewport + items={[ + { + type: 'icon-button', + id: 'copy code', + iconName: 'copy', + text: 'Copy Code', + popoverFeedback: ( + + Code copied + + ) + } + ]} + variant='icon' + /> +
+
+                            
+                                {code}
+                            
+                        
+
+ ); + }; + // Check if this is inline code by examining the props + const isInlineCode = !props.node || props.node.position?.start?.line === props.node.position?.end?.line; + + if (isInlineCode) { + return ( + + {children} + + ); + } + return match ? ( + match[1] === 'mermaid' ? ( + + ) : ( + + ) + ) : ( + + ); + }, + }), [isStreaming]); // Include isStreaming so the component can access it + const renderContent = (messageType: string, content: MessageContent, metadata?: LisaChatMessageMetadata) => { if (Array.isArray(content)) { return content.map((item, index) => { @@ -128,81 +273,10 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming
{markdownDisplay ? ( { - - - return ( -
-
- - navigator.clipboard.writeText(code) - } - ariaLabel='Chat actions' - dropdownExpandToViewport - items={[ - { - type: 'icon-button', - id: 'copy code', - iconName: 'copy', - text: 'Copy Code', - popoverFeedback: ( - - Code copied - - ) - } - ]} - variant='icon' - /> -
- - {code} - -
- ); - }; - - return match ? ( - - ) : ( - - {children} - - ); - }, - ul ({ ...props }: any) { - return
    ; - }, - ol ({ ...props }: any) { - return
      ; - }, - li ({ ...props }: any) { - return
    1. ; - }, - }} + components={markdownComponents} /> ) : (
      {getDisplayableMessage(content, message.type === MessageTypes.AI ? ragCitations : undefined)}
      @@ -227,6 +301,7 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming tooltipText='Generative AI assistant' /> } + actions={showUsage ? : undefined} > Generating response @@ -246,6 +321,7 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming tooltipText='Generative AI assistant' /> } + actions={showUsage ? : undefined} > 🔨Calling {callingToolName} tool 🔨 @@ -266,11 +342,19 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming tooltipText='Generative AI assistant' /> } + actions={showUsage ? : undefined} > {renderContent(message.type, message.content, message.metadata)} - {showMetadata && !isStreaming && - - } + {showMetadata && !isStreaming && + + + } {!isStreaming && !messageContainsImage(message.content) &&
      @@ -300,21 +384,47 @@ export default function Message ({ message, isRunning, showMetadata, isStreaming )} {message?.type === 'human' && ( - +
      + + } + > +
      + {renderContent(message.type, message.content)} +
      +
      + + ['copy'].includes(detail.id) && + navigator.clipboard.writeText(getDisplayableMessage(message.content)) + } + ariaLabel='Chat actions' + dropdownExpandToViewport + items={[ + { + type: 'icon-button', + id: 'copy', + iconName: 'copy', + text: 'Copy Input', + popoverFeedback: ( + + Input copied + + ) + } + ]} + variant='icon' /> - } - > -
      - {renderContent(message.type, message.content)}
      - + )} {message?.type === MessageTypes.TOOL && ( diff --git a/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx b/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx index e7c1c8b05..7674ccce9 100644 --- a/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/RagOptions.tsx @@ -15,7 +15,7 @@ */ import { Autosuggest, Grid, SpaceBetween } from '@cloudscape-design/components'; -import { useEffect, useMemo, useState } from 'react'; +import { useEffect, useMemo, useState, useRef } from 'react'; import { useGetAllModelsQuery } from '@/shared/reducers/model-management.reducer'; import { IModel, ModelStatus, ModelType } from '@/shared/model/model-management.model'; import { useListRagRepositoriesQuery } from '@/shared/reducers/rag.reducer'; @@ -35,27 +35,109 @@ type RagControlProps = { export default function RagControls ({isRunning, setUseRag, setRagConfig, ragConfig }: RagControlProps) { const { data: repositories, isFetching: isLoadingRepositories } = useListRagRepositoriesQuery(undefined, {refetchOnMountOrArgChange: true}); - const [selectedEmbeddingOption, setSelectedEmbeddingOption] = useState(undefined); - const [selectedRepositoryOption, setSelectedRepositoryOption] = useState(undefined); const { data: allModels, isFetching: isFetchingModels } = useGetAllModelsQuery(undefined, {refetchOnMountOrArgChange: 5, selectFromResult: (state) => ({ isFetching: state.isFetching, data: (state.data || []).filter((model) => model.modelType === ModelType.embedding && model.status === ModelStatus.InService), })}); + + const [userHasSelectedModel, setUserHasSelectedModel] = useState(false); + + const lastRepositoryIdRef = useRef(undefined); + + const selectedRepositoryOption = ragConfig?.repositoryId ?? ''; + const selectedEmbeddingOption = ragConfig?.embeddingModel?.modelId ?? ''; + const embeddingOptions = useMemo(() => { - return allModels?.map((model) => ({value: model.modelId})) || []; - }, [allModels]); + if (!allModels || !selectedRepositoryOption) return []; + + const repository = repositories?.find((repo) => repo.repositoryId === selectedRepositoryOption); + const defaultModelId = repository?.embeddingModelId; + + return allModels.map((model) => ({ + value: model.modelId, + label: model.modelId + (model.modelId === defaultModelId ? ' (default)' : '') + })); + }, [allModels, repositories, selectedRepositoryOption]); useEffect(() => { setUseRag(!!selectedEmbeddingOption && !!selectedRepositoryOption); - // setUseRag is never going to change as it's just a setState function - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [selectedRepositoryOption, selectedEmbeddingOption]); + }, [selectedRepositoryOption, selectedEmbeddingOption, setUseRag]); + // Effect for handling repository changes and auto-selection useEffect(() => { - setSelectedEmbeddingOption(ragConfig?.embeddingModel?.modelId ?? undefined); - setSelectedRepositoryOption(ragConfig?.repositoryId ?? undefined); - }, [ragConfig]); + const currentRepositoryId = ragConfig?.repositoryId; + const repositoryHasChanged = currentRepositoryId !== lastRepositoryIdRef.current; + + // Update tracking and reset user selection flag when repository changes + if (repositoryHasChanged) { + lastRepositoryIdRef.current = currentRepositoryId; + setUserHasSelectedModel(false); + } + + // Auto-select default model when repository changes or no model is set + if (currentRepositoryId && repositories && allModels) { + const repository = repositories.find((repo) => repo.repositoryId === currentRepositoryId); + + if (repository?.embeddingModelId) { + const defaultModel = allModels.find((model) => model.modelId === repository.embeddingModelId); + + if (defaultModel) { + const shouldAutoSwitch = repositoryHasChanged || + (!ragConfig?.embeddingModel && !userHasSelectedModel); + + if (shouldAutoSwitch) { + setRagConfig((config) => ({ + ...config, + embeddingModel: defaultModel, + })); + } + } + } + } + }, [ragConfig?.repositoryId, ragConfig?.embeddingModel, repositories, allModels, userHasSelectedModel, setRagConfig]); + + const handleRepositoryChange = ({ detail }) => { + const newRepositoryId = detail.value; + setUserHasSelectedModel(false); // Reset when repository changes + + if (newRepositoryId) { + const repository = repositories?.find((repo) => repo.repositoryId === newRepositoryId); + setRagConfig((config) => ({ + ...config, + repositoryId: newRepositoryId, + repositoryType: repository?.type || 'unknown', + embeddingModel: undefined, // Clear current model so useEffect can set default + })); + } else { + setRagConfig((config) => ({ + ...config, + repositoryId: undefined, + repositoryType: undefined, + embeddingModel: undefined, + })); + } + }; + + const handleModelChange = ({ detail }) => { + const newModelId = detail.value; + setUserHasSelectedModel(true); // Mark that user has made an explicit choice + + if (newModelId) { + const model = allModels.find((model) => model.modelId === newModelId); + if (model) { + setRagConfig((config) => ({ + ...config, + embeddingModel: model, + })); + } + } else { + setRagConfig((config) => ({ + ...config, + embeddingModel: undefined, + })); + } + }; return ( @@ -72,17 +154,13 @@ export default function RagControls ({isRunning, setUseRag, setRagConfig, ragCon placeholder='Select a RAG Repository' empty={
      No repositories available.
      } filteringType='auto' - value={selectedRepositoryOption ?? ''} + value={selectedRepositoryOption} enteredTextLabel={(text) => `Use: "${text}"`} - onChange={({ detail }) => { - setSelectedRepositoryOption(detail.value); - setRagConfig((config) => ({ - ...config, - repositoryId: detail.value, - repositoryType: detail.value, - })); - }} - options={repositories?.map((repository) => ({value: repository.repositoryId, label: repository?.repositoryName?.length ? repository?.repositoryName : repository.repositoryId})) || []} + onChange={handleRepositoryChange} + options={repositories?.map((repository) => ({ + value: repository.repositoryId, + label: repository?.repositoryName?.length ? repository?.repositoryName : repository.repositoryId + })) || []} /> No embedding models available.
      } filteringType='auto' - value={selectedEmbeddingOption ?? ''} + value={selectedEmbeddingOption} enteredTextLabel={(text) => `Use: "${text}"`} - onChange={({ detail }) => { - setSelectedEmbeddingOption(detail.value); - - const model = allModels.find((model) => model.modelId === detail.value); - if (model) { - setRagConfig((config) => ({ - ...config, - embeddingModel: model, - })); - } - }} + onChange={handleModelChange} options={embeddingOptions} /> diff --git a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx index b5967dc09..2b4750cc4 100644 --- a/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/SessionConfiguration.tsx @@ -31,6 +31,7 @@ import Toggle from '@cloudscape-design/components/toggle'; import { IChatConfiguration } from '@/shared/model/chat.configurations.model'; import { IModel, ModelType } from '@/shared/model/model-management.model'; import { IConfiguration } from '@/shared/model/configuration.model'; +import { LisaChatSession } from '@/components/types'; export type SessionConfigurationProps = { title?: string; @@ -41,7 +42,10 @@ export type SessionConfigurationProps = { selectedModel: IModel; isRunning: boolean; systemConfig: IConfiguration; - modelOnly?: boolean + modelOnly?: boolean; + session?: LisaChatSession; + updateSession?: (session: LisaChatSession) => void; + ragConfig?: any; }; export default function SessionConfiguration ({ @@ -53,16 +57,33 @@ export default function SessionConfiguration ({ visible, setVisible, systemConfig, - modelOnly = false + modelOnly = false, + session, + updateSession, + ragConfig }: SessionConfigurationProps) { // Defaults based on https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig // Default stop sequences based on User/Assistant instruction prompting for Falcon, Mistral, etc. const updateSessionConfiguration = (property: string, value: any): void => { - setChatConfiguration({ + const updatedConfiguration = { ...chatConfiguration, sessionConfiguration: { ...chatConfiguration.sessionConfiguration, [property]: value }, - }); + }; + + setChatConfiguration(updatedConfiguration); + + // Immediately persist the configuration to the session if available + if (session && updateSession) { + updateSession({ + ...session, + configuration: { + ...updatedConfiguration, + selectedModel: selectedModel, + ragConfig: ragConfig + } + }); + } }; const oneThroughTenOptions = [...Array(10).keys()].map((i) => { @@ -325,10 +346,14 @@ export default function SessionConfiguration ({ updateSessionConfiguration('modelArgs', { - ...chatConfiguration.sessionConfiguration.modelArgs, - stop: chatConfiguration.sessionConfiguration.modelArgs.stop.concat(''), - })} + onAddButtonClick={() => { + if (chatConfiguration.sessionConfiguration.modelArgs.stop.length < 4) { + updateSessionConfiguration('modelArgs', { + ...chatConfiguration.sessionConfiguration.modelArgs, + stop: chatConfiguration.sessionConfiguration.modelArgs.stop.concat(''), + }); + } + }} removeButtonText='Remove' onRemoveButtonClick={(event) => updateSessionConfiguration('modelArgs', { diff --git a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx index fdb763296..bc8e879b7 100644 --- a/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx +++ b/lib/user-interface/react/src/components/chatbot/components/Sessions.tsx @@ -17,9 +17,10 @@ import SpaceBetween from '@cloudscape-design/components/space-between'; import Link from '@cloudscape-design/components/link'; import Header from '@cloudscape-design/components/header'; -import { ButtonDropdown, Grid, Input, Popover, Modal, FormField } from '@cloudscape-design/components'; +import ExpandableSection from '@cloudscape-design/components/expandable-section'; +import { ButtonDropdown, Input, Popover, Modal, FormField, Grid } from '@cloudscape-design/components'; import Button from '@cloudscape-design/components/button'; -import { useCollection } from '@cloudscape-design/collection-hooks'; + import { useLazyGetConfigurationQuery } from '@/shared/reducers/configuration.reducer'; import { sessionApi, @@ -43,6 +44,8 @@ import JSZip from 'jszip'; import { downloadFile } from '@/shared/util/downloader'; import { setConfirmationModal } from '@/shared/reducers/modal.reducer'; + + export function Sessions ({ newSession }) { const dispatch = useAppDispatch(); const notificationService = useNotificationService(dispatch); @@ -77,7 +80,8 @@ export function Sessions ({ newSession }) { const [renameModalVisible, setRenameModalVisible] = useState(false); const [sessionToRename, setSessionToRename] = useState(null); const [newSessionName, setNewSessionName] = useState(''); - const { data: sessions } = useListSessionsQuery(null, { refetchOnMountOrArgChange: 5 }); + const [sessionBeingDeleted, setSessionBeingDeleted] = useState(null); + const { data: sessions, isLoading: isSessionsLoading } = useListSessionsQuery(null, { refetchOnMountOrArgChange: 5 }); // Filter sessions based on search query const filteredSessions = useMemo(() => { @@ -88,16 +92,47 @@ export function Sessions ({ newSession }) { .filter((session) => getSessionDisplay(session).toLowerCase().includes(searchQuery.toLowerCase())); }, [sessions, searchQuery]); - const { items } = useCollection(filteredSessions, { - sorting: { - defaultState: { - sortingColumn: { - sortingField: 'StartTime', - }, - isDescending: true, - }, - }, - }); + // Group and sort sessions by time periods + const groupedSessions = useMemo(() => { + const now = new Date(); + const groups = { + 'Last Day': [] as LisaChatSession[], + 'Last 7 Days': [] as LisaChatSession[], + 'Last Month': [] as LisaChatSession[], + 'Last 3 Months': [] as LisaChatSession[], + 'Older': [] as LisaChatSession[] + }; + + filteredSessions.forEach((session) => { + // Use lastUpdated if available, otherwise fallback to startTime for backward compatibility + const lastUpdated = session.lastUpdated || session.startTime; + const sessionDate = new Date(lastUpdated); + const diffInDays = (now.getTime() - sessionDate.getTime()) / (1000 * 60 * 60 * 24); + + if (diffInDays <= 1) { + groups['Last Day'].push(session); + } else if (diffInDays <= 7) { + groups['Last 7 Days'].push(session); + } else if (diffInDays <= 30) { + groups['Last Month'].push(session); + } else if (diffInDays <= 90) { + groups['Last 3 Months'].push(session); + } else { + groups['Older'].push(session); + } + }); + + // Sort sessions within each group by lastUpdated (most recent first) + Object.keys(groups).forEach((key) => { + groups[key as keyof typeof groups].sort((a, b) => { + const aTime = new Date(a.lastUpdated || a.startTime).getTime(); + const bTime = new Date(b.lastUpdated || b.startTime).getTime(); + return bTime - aTime; // Descending order (newest first) + }); + }); + + return groups; + }, [filteredSessions]); useEffect(() => { if (!auth.isLoading && auth.isAuthenticated) { @@ -112,10 +147,18 @@ export function Sessions ({ newSession }) { useEffect(() => { if (!isDeleteByIdLoading && isDeleteByIdSuccess) { notificationService.generateNotification('Successfully deleted session', 'success'); - navigate('ai-assistant'); - newSession(); + // Only reload if we are deleting the current session or there is no current session (/ai-assistant with no session ID) + if (sessionBeingDeleted === currentSessionId || !currentSessionId) { + newSession(); + } + + // Reset the tracking state + setSessionBeingDeleted(null); } else if (!isDeleteByIdLoading && isDeleteByIdError) { notificationService.generateNotification(`Error deleting session: ${deleteByIdError.data?.message ?? deleteByIdError.data}`, 'error'); + + // Reset the tracking state on error too + setSessionBeingDeleted(null); } // eslint-disable-next-line react-hooks/exhaustive-deps }, [isDeleteByIdSuccess, isDeleteByIdError, deleteByIdError, isDeleteByIdLoading]); @@ -180,7 +223,7 @@ export function Sessions ({ newSession }) { setSearchQuery(detail.value)} - placeholder='Search sessions by message content...' + placeholder='Search sessions by name...' clearAriaLabel='Clear search' type='search' /> @@ -237,85 +280,123 @@ export function Sessions ({ newSession }) { History -
      - [{ colspan: 10 }, { colspan: 2 }])}> - {items.map((item) => ( - - - navigate(`ai-assistant/${item.sessionId}`)}> - - {getSessionDisplay(item, 40)} - - - - - { - if (e.detail.id === 'delete-session') { - dispatch( - setConfirmationModal({ - action: 'Delete', - resourceName: 'Session', - onConfirm: () => deleteById(item.sessionId), - description: `This will delete the Session: ${item.sessionId}.` - }) - ); - } else if (e.detail.id === 'download-session') { - getSessionById(item.sessionId).then((resp) => { - const sess: LisaChatSession = resp.data; - const file = new Blob([JSON.stringify(sess, null, 2)], { type: 'application/json' }); - downloadFile(URL.createObjectURL(file), `${sess.sessionId}.json`); - }); - } else if (e.detail.id === 'export-images') { - getSessionById(item.sessionId).then(async (resp) => { - const sess: LisaChatSession = resp.data; - const images = sess.history.filter((msg) => msg.type === 'ai' && messageContainsImage(msg.content)) - .flatMap((msg) => { - return msg.content.map((contentItem) => { - if (contentItem.type === 'image_url') { - return contentItem.image_url.url; - } - }); - }); - - if (images.length === 0) { - notificationService.generateNotification('No images found to export', 'info'); - } else { - const zip = new JSZip(); - const imagePromises = images.map(async (imageUrl, index) => { - try { - const blob = await fetchImage(imageUrl); - zip.file(`image_${index + 1}.png`, blob, { binary: true }); - } catch (error) { - console.error(`Error processing image ${index + 1}:`, error); - } - }); - - // Wait for all images to be processed - await Promise.all(imagePromises); - const content = await zip.generateAsync({ type: 'blob' }); - downloadFile(URL.createObjectURL(content), `${sess.sessionId}-images.zip`); - } - }); - } else if (e.detail.id === 'rename-session') { - handleRenameSession(item); - } - }} - /> - - - ))} - -
      + + {isSessionsLoading && ( + + + Loading sessions... + Please wait while we fetch your session history + + + )} + + {!isSessionsLoading && ( + + {(() => { + const timeGroups = Object.entries(groupedSessions); + + return timeGroups.map(([timeGroup, sessions]) => { + if (sessions.length === 0) return null; + + return ( + + + {sessions.map((item) => ( + + + + navigate(`ai-assistant/${item.sessionId}`)}> + + {getSessionDisplay(item, 40)} + + + + + { + if (e.detail.id === 'delete-session') { + dispatch( + setConfirmationModal({ + action: 'Delete', + resourceName: 'Session', + onConfirm: () => { + setSessionBeingDeleted(item.sessionId); + deleteById(item.sessionId); + }, + description: `This will delete the Session: ${item.sessionId}.` + }) + ); + } else if (e.detail.id === 'download-session') { + getSessionById(item.sessionId).then((resp) => { + const sess: LisaChatSession = resp.data; + const file = new Blob([JSON.stringify(sess, null, 2)], { type: 'application/json' }); + downloadFile(URL.createObjectURL(file), `${sess.sessionId}.json`); + }); + } else if (e.detail.id === 'export-images') { + getSessionById(item.sessionId).then(async (resp) => { + const sess: LisaChatSession = resp.data; + const images = sess.history.filter((msg) => msg.type === 'ai' && messageContainsImage(msg.content)) + .flatMap((msg) => { + if (Array.isArray(msg.content)) { + return msg.content.map((contentItem) => { + if (contentItem.type === 'image_url') { + return contentItem.image_url.url; + } + }); + } + return []; + }); + + if (images.length === 0) { + notificationService.generateNotification('No images found to export', 'info'); + } else { + const zip = new JSZip(); + const imagePromises = images.map(async (imageUrl, index) => { + try { + const blob = await fetchImage(imageUrl); + zip.file(`image_${index + 1}.png`, blob, { binary: true }); + } catch (error) { + console.error(`Error processing image ${index + 1}:`, error); + } + }); + + // Wait for all images to be processed + await Promise.all(imagePromises); + const content = await zip.generateAsync({ type: 'blob' }); + downloadFile(URL.createObjectURL(content), `${sess.sessionId}-images.zip`); + } + }); + } else if (e.detail.id === 'rename-session') { + handleRenameSession(item); + } + }} + /> + + + + ))} + + + ); + }); + })()} + + )} {/* Rename Session Modal */} + {showTokens && usage.completionTokens && ( + Tokens + {usage.promptTokens} + {usage.completionTokens} + + )} + {showResponseTime && usage.responseTime !== undefined && usage.responseTime !== null && ( + Response + {usage.responseTime.toFixed(2)}s + + )} + + ); +} diff --git a/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx b/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx index 024aec960..1bcd6c20e 100644 --- a/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx +++ b/lib/user-interface/react/src/components/chatbot/hooks/chat.hooks.tsx @@ -77,6 +77,7 @@ export const useChatGeneration = ({ const generateResponse = async (params: GenerateLLMRequestParams) => { setIsRunning(true); stopRequested.current = false; + const startTime = performance.now(); // Start client timer try { // Handle image generation mode specifically if (isImageGenerationMode) { @@ -114,6 +115,9 @@ export const useChatGeneration = ({ type: 'image_url' })); + // Calculate response time + const responseTime = (performance.now() - startTime) / 1000; + // Save the response to the chat history setSession((prev) => ({ ...prev, @@ -124,6 +128,9 @@ export const useChatGeneration = ({ ...metadata, imageGeneration: true, imageGenerationParams: imageGenParams + }, + usage: { + responseTime: parseFloat(responseTime.toFixed(2)) } })], })); @@ -343,6 +350,27 @@ export const useChatGeneration = ({ }); } + // Calculate response time and update the final message with usage info + const responseTime = (performance.now() - startTime) / 1000; + setSession((prev) => { + const lastMessage = prev.history[prev.history.length - 1]; + if (lastMessage?.type === MessageTypes.AI) { + return { + ...prev, + history: [...prev.history.slice(0, -1), + new LisaChatMessage({ + ...lastMessage, + usage: { + ...lastMessage.usage, + responseTime: parseFloat(responseTime.toFixed(2)) + } + }) + ], + }; + } + return prev; + }); + await memory.saveContext({ input: params.input }, { output: resp.join('') }); setIsStreaming(false); } catch (exception) { @@ -355,10 +383,24 @@ export const useChatGeneration = ({ } else { const response = await llmClient.invoke(messages, { tools: modelSupportsTools ? openAiTools : undefined }); const content = response.content as string; + const usage = response.response_metadata.tokenUsage; + + // Calculate response time + const responseTime = (performance.now() - startTime) / 1000; + await memory.saveContext({ input: params.input }, { output: content }); setSession((prev) => ({ ...prev, - history: [...prev.history, new LisaChatMessage({ type: 'ai', content, metadata, toolCalls: [...(response.tool_calls ?? [])] })], + history: [...prev.history, new LisaChatMessage({ + type: 'ai', + content, + metadata, + toolCalls: [...(response.tool_calls ?? [])], + usage: { + ...usage, + responseTime: parseFloat(responseTime.toFixed(2)) + } + })], })); } } diff --git a/lib/user-interface/react/src/components/chatbot/hooks/useSession.hooks.tsx b/lib/user-interface/react/src/components/chatbot/hooks/useSession.hooks.tsx index f587ef649..33ab70023 100644 --- a/lib/user-interface/react/src/components/chatbot/hooks/useSession.hooks.tsx +++ b/lib/user-interface/react/src/components/chatbot/hooks/useSession.hooks.tsx @@ -14,7 +14,7 @@ limitations under the License. */ -import { useContext, useEffect, useState } from 'react'; +import { useEffect, useState } from 'react'; import { useAuth } from 'react-oidc-context'; import { v4 as uuidv4 } from 'uuid'; import { LisaChatSession } from '@/components/types'; @@ -23,15 +23,10 @@ import { RagConfig } from '../components/RagOptions'; import { IModel } from '@/shared/model/model-management.model'; import { useAppDispatch } from '@/config/store'; import { setBreadcrumbs } from '@/shared/reducers/breadcrumbs.reducer'; -import ConfigurationContext from '@/shared/configuration.provider'; -import { IConfiguration } from '@/shared/model/configuration.model'; -import { useGetAllModelsQuery } from '@/shared/reducers/model-management.reducer'; export const useSession = (sessionId: string, getSessionById: any) => { const dispatch = useAppDispatch(); const auth = useAuth(); - const config: IConfiguration = useContext(ConfigurationContext); - const { data: allModels } = useGetAllModelsQuery(); const [session, setSession] = useState({ history: [], @@ -44,7 +39,6 @@ export const useSession = (sessionId: string, getSessionById: any) => { const [chatConfiguration, setChatConfiguration] = useState(baseConfig); const [selectedModel, setSelectedModel] = useState(); const [ragConfig, setRagConfig] = useState({} as RagConfig); - const [hasUserInteractedWithModel, setHasUserInteractedWithModel] = useState(false); useEffect(() => { // always hide breadcrumbs @@ -69,10 +63,6 @@ export const useSession = (sessionId: string, getSessionById: any) => { setSession(sess); setChatConfiguration(sess.configuration ?? baseConfig); setSelectedModel(sess.configuration?.selectedModel ?? undefined); - // If session has a pre-selected model, consider it as user interaction - if (sess.configuration?.selectedModel) { - setHasUserInteractedWithModel(true); - } setRagConfig(sess.configuration?.ragConfig ?? {} as RagConfig); setLoadingSession(false); }); @@ -90,21 +80,9 @@ export const useSession = (sessionId: string, getSessionById: any) => { } }, [sessionId, dispatch, auth.user?.profile.sub, getSessionById]); - // Set default model if none is selected, default model is configured, and user hasn't interacted with model selection - useEffect(() => { - if (!selectedModel && !hasUserInteractedWithModel && config?.configuration?.global?.defaultModel && allModels) { - const defaultModel = allModels.find((model) => model.modelId === config.configuration.global.defaultModel); - if (defaultModel) { - setSelectedModel(defaultModel); - } - } - }, [selectedModel, hasUserInteractedWithModel, config?.configuration?.global?.defaultModel, allModels]); - // Wrapper function to track user interaction with model selection - const handleSetSelectedModel = (model: IModel | undefined) => { - setHasUserInteractedWithModel(true); - setSelectedModel(model); - }; + + return { session, @@ -115,7 +93,7 @@ export const useSession = (sessionId: string, getSessionById: any) => { chatConfiguration, setChatConfiguration, selectedModel, - setSelectedModel: handleSetSelectedModel, + setSelectedModel, ragConfig, setRagConfig, }; diff --git a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx index 37256e881..4c4ee7abb 100644 --- a/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx +++ b/lib/user-interface/react/src/components/configuration/ActivatedUserComponents.tsx @@ -20,12 +20,12 @@ import { SetFieldsFunction } from '../../shared/validation'; const ragOptions = { uploadRagDocs: 'Allow document upload from Chat', - showRagLibrary: 'Show Document Library', editNumOfRagDocument: 'Edit number of referenced documents', }; const libraryOptions = { modelLibrary: 'Show Model Library', + showRagLibrary: 'Show Document Library', showPromptTemplateLibrary: 'Show Prompt Template Library' }; @@ -41,7 +41,6 @@ const advancedOptions = { viewMetaData: 'View chat meta-data', deleteSessionHistory: 'Delete Session History', editChatHistoryBuffer: 'Edit chat history buffer', - showPromptTemplateLibrary: 'Show Prompt Template Library', enableModelComparisonUtility: 'Enable Model Comparison Utility' }; diff --git a/lib/user-interface/react/src/components/configuration/RepositoryTableConfig.tsx b/lib/user-interface/react/src/components/configuration/RepositoryTableConfig.tsx index 27e90a695..041664624 100644 --- a/lib/user-interface/react/src/components/configuration/RepositoryTableConfig.tsx +++ b/lib/user-interface/react/src/components/configuration/RepositoryTableConfig.tsx @@ -57,6 +57,13 @@ export function getTableDefinition ({ sortingField: 'type', visible: true, }, + { + id: 'embeddingModelId', + header: 'Default Embedding Model', + cell: (e) => e.embeddingModelId ?? '-', + sortingField: 'type', + visible: true, + }, { id: 'allowedGroups', header: 'Allowed Groups', diff --git a/lib/user-interface/react/src/components/configuration/createRepository/RepositoryConfigForm.tsx b/lib/user-interface/react/src/components/configuration/createRepository/RepositoryConfigForm.tsx index c38151b1b..5328ab2a3 100644 --- a/lib/user-interface/react/src/components/configuration/createRepository/RepositoryConfigForm.tsx +++ b/lib/user-interface/react/src/components/configuration/createRepository/RepositoryConfigForm.tsx @@ -14,12 +14,12 @@ limitations under the License. */ -import React, { ReactElement } from 'react'; +import React, { ReactElement, useMemo, useState } from 'react'; import { FormProps } from '../../../shared/form/form-props'; import FormField from '@cloudscape-design/components/form-field'; import Input from '@cloudscape-design/components/input'; import Select from '@cloudscape-design/components/select'; -import { SpaceBetween } from '@cloudscape-design/components'; +import { Autosuggest, SpaceBetween } from '@cloudscape-design/components'; import { OpenSearchNewClusterConfig, RagRepositoryConfig, @@ -33,6 +33,8 @@ import { ArrayInputField } from '../../../shared/form/array-input'; import { RdsConfigForm } from './RdsConfigForm'; import { OpenSearchConfigForm } from './OpenSearchConfigForm'; import { BedrockKnowledgeBaseConfigForm } from './BedrockKnowledgeBaseConfigForm'; +import { useGetAllModelsQuery } from '@/shared/reducers/model-management.reducer'; +import { ModelStatus, ModelType } from '@/shared/model/model-management.model'; export type RepositoryConfigProps = { isEdit: boolean @@ -41,6 +43,15 @@ export type RepositoryConfigProps = { export function RepositoryConfigForm (props: FormProps & RepositoryConfigProps): ReactElement { const { item, touchFields, setFields, formErrors, isEdit } = props; const shape = RagRepositoryConfigSchema.innerType().shape; + const { data: embeddingModels, isFetching: isFetchingEmbeddingModels } = useGetAllModelsQuery(undefined, {refetchOnMountOrArgChange: 5, + selectFromResult: (state) => ({ + isFetching: state.isFetching, + data: (state.data || []).filter((model) => model.modelType === ModelType.embedding && model.status === ModelStatus.InService), + })}); + const embeddingOptions = useMemo(() => { + return embeddingModels?.map((model) => ({value: model.modelId})) || []; + }, [embeddingModels]); + const [selectedEmbeddingOption, setSelectedEmbeddingOption] = useState(undefined); return ( & Re setFields({ 'repositoryName': detail.value }); }} placeholder='Postgres RAG' /> + + No embedding models available.
      } + filteringType='auto' + value={selectedEmbeddingOption ?? ''} + enteredTextLabel={(text) => `Use: "${text}"`} + onChange={({ detail }) => { + setSelectedEmbeddingOption(detail.value); + setFields({ 'embeddingModelId': detail.value }); + }} + options={embeddingOptions} + /> + diff --git a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx index efb2d1618..14c1fd823 100644 --- a/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx +++ b/lib/user-interface/react/src/components/document-library/DocumentLibraryComponent.tsx @@ -15,7 +15,7 @@ */ import * as React from 'react'; -import { ReactElement } from 'react'; +import { ReactElement, useState } from 'react'; import { Button, ButtonDropdownProps, @@ -61,15 +61,39 @@ function disabledDeleteReason (selectedItems: ReadonlyArray) { } export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryComponentProps): ReactElement { - const { data: allDocs, isFetching } = useListRagDocumentsQuery({ repositoryId }, { refetchOnMountOrArgChange: 5 }); const [deleteMutation, { isLoading: isDeleteLoading }] = useDeleteRagDocumentsMutation(); + const [currentPage, setCurrentPage] = useState(1); + const [lastEvaluatedKey, setLastEvaluatedKey] = useState<{ + pk: string; + document_id: string; + repository_id: string; + } | null>(null); + const [pageHistory, setPageHistory] = useState>([]); + const currentUser = useAppSelector(selectCurrentUsername); const isAdmin = useAppSelector(selectCurrentUserIsAdmin); const [preferences, setPreferences] = useLocalStorage('DocumentRagPreferences', DEFAULT_PREFERENCES); const dispatch = useAppDispatch(); - const { items, actions, filteredItemsCount, collectionProps, filterProps, paginationProps } = useCollection( + const { data: paginatedDocs, isFetching } = useListRagDocumentsQuery( + { + repositoryId, + lastEvaluatedKey: lastEvaluatedKey || undefined, + pageSize: preferences.pageSize + }, + { refetchOnMountOrArgChange: 5 } + ); + + const allDocs = paginatedDocs?.documents || []; + const totalDocuments = paginatedDocs?.totalDocuments || 0; + const hasNextPage = paginatedDocs?.hasNextPage || false; + + const { items, actions, filteredItemsCount, collectionProps, filterProps } = useCollection( allDocs ?? [], { filtering: { empty: ( @@ -156,11 +180,7 @@ export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryCompo } header={
      @@ -190,7 +210,38 @@ export function DocumentLibraryComponent ({ repositoryId }: DocumentLibraryCompo
      } pagination={ - + { + if (hasNextPage && paginatedDocs?.lastEvaluated) { + // Add current key to history before moving to next page + setPageHistory([...pageHistory, lastEvaluatedKey]); + setLastEvaluatedKey(paginatedDocs.lastEvaluated); + // Update current page to reflect the navigation + setCurrentPage((prev) => prev + 1); + } + }} + onPreviousPageClick={() => { + if (pageHistory.length > 0) { + // Go back one page by popping from history + const previousKey = pageHistory[pageHistory.length - 1]; + setPageHistory(pageHistory.slice(0, -1)); + setLastEvaluatedKey(previousKey); + // Update current page to reflect the navigation + setCurrentPage((prev) => prev - 1); + } else { + // If no history, go to first page + setLastEvaluatedKey(null); + setCurrentPage(1); + } + }} + ariaLabels={{ + nextPageLabel: 'Next page', + previousPageLabel: 'Previous page', + pageLabel: (pageNumber) => `Page ${pageNumber} of ${Math.ceil(totalDocuments / (preferences.pageSize || 10))}`, + }} + /> } preferences={ (''); + const [isTestingConnection, setIsTestingConnection] = useState(false); + + // Test connection using useMcp + const { + state: connectionState, + tools, + } = useMcp({ + url: testConnectionUrl, + clientName: state.form.clientConfig?.name || 'Test Client', + clientConfig: state.form.clientConfig || {}, + customHeaders: state.form.customHeaders?.reduce((r,{key,value}) => (r[key] = value,r), {}), + autoReconnect: false, + autoRetry: false, + debug: false, + }); + + const testConnection = () => { + if (state.form.url && String(state.form.url).trim()) { + setIsTestingConnection(true); + setTestConnectionUrl(String(state.form.url).trim()); + + // Add a timeout to prevent infinite loading + setTimeout(() => { + if (isTestingConnection) { + setIsTestingConnection(false); + } + }, 30000); // 30 second timeout + } + }; + + // Reset test connection state when URL changes + useEffect(() => { + if (testConnectionUrl !== state.form.url) { + setTestConnectionUrl(''); + setIsTestingConnection(false); + } + }, [state.form.url, testConnectionUrl]); + + // Reset testing state when connection completes + useEffect(() => { + if (testConnectionUrl && (connectionState === 'ready' || connectionState === 'failed')) { + setIsTestingConnection(false); + } + }, [connectionState, testConnectionUrl]); + // create success notification useEffect(() => { if (isCreatingSuccess || isUpdatingSuccess) { const verb = isCreatingSuccess ? 'created' : 'updated'; const data = isCreatingSuccess ? createData : updateData; notificationService.generateNotification(`Successfully ${verb} MCP Connection: ${data.name}`, 'success'); - navigate('/mcp-connections'); + navigate(`/mcp-connections/${data.id}`); } }, [isCreatingSuccess, isUpdatingSuccess, notificationService, createData, updateData, navigate]); @@ -174,11 +224,47 @@ export function McpServerForm (props: McpServerFormProps) { placeholder='Enter MCP connection description' />
      - touchFields(['url'])} onChange={({ detail }) => { - setFields({ 'url': detail.value }); - }} - disabled={disabled} - placeholder='Enter MCP server URL' /> + + touchFields(['url'])} onChange={({ detail }) => { + setFields({ 'url': detail.value }); + }} + disabled={disabled} + placeholder='Enter MCP server URL' /> + + + {testConnectionUrl && ( + + + {connectionState === 'ready' ? 'Connection successful' : + connectionState === 'failed' ? 'Connection failed' : + connectionState === 'discovering' ? 'Discovering server...' : + connectionState === 'authenticating' ? 'Authenticating...' : + connectionState === 'connecting' || connectionState === 'loading' ? 'Connecting...' : + 'Connection failed'} + + {connectionState === 'ready' && tools && ( + + Available tools: {tools.length} + + )} + {connectionState === 'failed' && ( + + Unable to connect to the MCP server. Please check the URL and try again. + + )} + + )} {isUserAdmin && diff --git a/lib/user-interface/react/src/components/model-management/components/ModelComparisonComponents.tsx b/lib/user-interface/react/src/components/model-management/components/ModelComparisonComponents.tsx index 63ef4697b..bc9a0109e 100644 --- a/lib/user-interface/react/src/components/model-management/components/ModelComparisonComponents.tsx +++ b/lib/user-interface/react/src/components/model-management/components/ModelComparisonComponents.tsx @@ -14,7 +14,7 @@ limitations under the License. */ -import { ReactElement, memo } from 'react'; +import { ReactElement, memo, useCallback, useRef } from 'react'; import { Box, SpaceBetween, @@ -27,17 +27,17 @@ import { SelectProps, PromptInput } from '@cloudscape-design/components'; -import { IModel } from '../../../shared/model/model-management.model'; -import { ComparisonResponse, ModelSelection } from '../hooks/useModelComparison.hook'; +import { IModel } from '@/shared/model/model-management.model'; +import { ComparisonResponse, ModelSelection } from '@/components/model-management/hooks/useModelComparison.hook'; import { MODEL_COMPARISON_CONFIG, UI_CONFIG, PLACEHOLDERS, ARIA_LABELS -} from '../config/modelComparison.config'; -import { LisaChatMessage, MessageTypes } from '../../types'; -import Message from '../../chatbot/components/Message'; -import { IChatConfiguration } from '../../../shared/model/chat.configurations.model'; +} from '@/components/model-management/config/modelComparison.config'; +import { LisaChatMessage, MessageTypes } from '@/components/types'; +import Message from '@/components/chatbot/components/Message'; +import { IChatConfiguration } from '@/shared/model/chat.configurations.model'; import { downloadFile } from '@/shared/util/downloader'; type ModelSelectionSectionProps = { @@ -110,15 +110,68 @@ type PromptInputSectionProps = { prompt: string; onPromptChange: (value: string) => void; onCompare: () => void; + onStopComparison: () => void; canCompare: boolean; + shouldShowStopButton: boolean; }; export const PromptInputSection = memo(function PromptInputSection ({ prompt, onPromptChange, onCompare, - canCompare + onStopComparison, + canCompare, + shouldShowStopButton }: PromptInputSectionProps): ReactElement { + // Ref to track if we're processing a keyboard event + const isKeyboardEventRef = useRef(false); + + // Handle stop functionality similar to Chat.tsx + const handleStop = useCallback(() => { + onStopComparison(); + }, [onStopComparison]); + + // Custom action handler that only allows stop on button clicks + const handleAction = useCallback(() => { + // If this is a keyboard event, don't process it here (it's handled in handleKeyPress) + if (isKeyboardEventRef.current) { + return; + } + + if (shouldShowStopButton) { + // Only allow stop action on button clicks (not keyboard events) + handleStop(); + } else { + // Normal send functionality - allow both button clicks and Enter key + if (prompt.length > 0 && canCompare) { + onCompare(); + } + } + }, [shouldShowStopButton, handleStop, prompt.length, canCompare, onCompare]); + + // Handle Enter key press + const handleKeyPress = useCallback((event: any) => { + if (event.detail.key === 'Enter' && !event.detail.shiftKey) { + event.preventDefault(); + isKeyboardEventRef.current = true; + + // Handle the action directly for keyboard events + if (shouldShowStopButton) { + // Do nothing for stop button when Enter is pressed + } else { + // Normal send functionality for Enter key + if (prompt.length > 0 && canCompare) { + onCompare(); + } + } + + // Reset the flag after a short delay + setTimeout(() => { + isKeyboardEventRef.current = false; + }, 100); + } + }, [shouldShowStopButton, prompt.length, canCompare, onCompare]); + return ( Prompt @@ -126,17 +179,19 @@ export const PromptInputSection = memo(function PromptInputSection ({ value={prompt} onChange={({ detail }) => onPromptChange(detail.value)} placeholder={PLACEHOLDERS.PROMPT_INPUT} - actionButtonIconName='send' - actionButtonAriaLabel={ARIA_LABELS.SEND_PROMPT} - onAction={onCompare} - actionButtonDisabled={!canCompare} + actionButtonIconName={shouldShowStopButton ? 'status-negative' : 'send'} + actionButtonAriaLabel={shouldShowStopButton ? 'Stop comparison' : ARIA_LABELS.SEND_PROMPT} + onAction={handleAction} + onKeyDown={handleKeyPress} + maxRows={4} + minRows={2} + spellcheck={true} + disabled={!canCompare && !shouldShowStopButton} /> ); }); - - type ComparisonResultsProps = { prompt: string; responses: ComparisonResponse[]; @@ -172,13 +227,14 @@ export const ComparisonResults = memo(function ComparisonResults ({ content: response.loading ? '' : response.error ? `Error: ${response.error}` : response.response, metadata: { modelName: modelName - } + }, + usage: response.usage, }); }); // Dummy functions for Message component (not used in comparison context) - const handleSendGenerateRequest = () => {}; - const setUserPrompt = () => {}; + const handleSendGenerateRequest = () => { }; + const setUserPrompt = () => { }; const handleDownloadResults = (): void => { const results = responses.map((response) => { @@ -188,7 +244,8 @@ export const ComparisonResults = memo(function ComparisonResults ({ modelName: model?.modelName || response.modelId, response: response.response, error: response.error, - loading: response.loading + loading: response.loading, + usage: response.usage }; }); @@ -207,15 +264,13 @@ export const ComparisonResults = memo(function ComparisonResults ({ return ( - - + }>Comparison Results}> {/* Display user prompt */} @@ -224,13 +279,14 @@ export const ComparisonResults = memo(function ComparisonResults ({ message={userMessage} isRunning={false} callingToolName='' - showMetadata={false} + showMetadata={chatConfiguration.sessionConfiguration.showMetadata} isStreaming={false} markdownDisplay={markdownDisplay} setChatConfiguration={setChatConfiguration} handleSendGenerateRequest={handleSendGenerateRequest} setUserPrompt={setUserPrompt} chatConfiguration={chatConfiguration} + showUsage={true} /> )} @@ -246,13 +302,14 @@ export const ComparisonResults = memo(function ComparisonResults ({ message={aiMessages[index]} isRunning={response.loading} callingToolName='' - showMetadata={false} - isStreaming={false} + showMetadata={chatConfiguration.sessionConfiguration.showMetadata} + isStreaming={response.streaming} markdownDisplay={markdownDisplay} setChatConfiguration={setChatConfiguration} handleSendGenerateRequest={handleSendGenerateRequest} setUserPrompt={setUserPrompt} chatConfiguration={chatConfiguration} + showUsage={true} /> {response.error && ( diff --git a/lib/user-interface/react/src/components/model-management/config/modelComparison.config.ts b/lib/user-interface/react/src/components/model-management/config/modelComparison.config.ts index b497816a1..3c4edb033 100644 --- a/lib/user-interface/react/src/components/model-management/config/modelComparison.config.ts +++ b/lib/user-interface/react/src/components/model-management/config/modelComparison.config.ts @@ -15,12 +15,13 @@ */ import { vscDarkPlus } from 'react-syntax-highlighter/dist/esm/styles/prism'; +import { SYSTEM_PROMPT } from '@/shared/constants/systemPrompt'; export const MODEL_COMPARISON_CONFIG = { MAX_MODELS: 4, MIN_MODELS: 2, DEFAULT_MAX_TOKENS: 2000, - DEFAULT_SYSTEM_MESSAGE: 'You are a helpful AI assistant. Provide clear, concise, and accurate responses.', + DEFAULT_SYSTEM_MESSAGE: SYSTEM_PROMPT, RETRY_ATTEMPTS: 3, TIMEOUT_MS: 30000, } as const; diff --git a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx index 900af125b..3b20905a7 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx @@ -19,9 +19,10 @@ import { FormProps} from '../../../shared/form/form-props'; import FormField from '@cloudscape-design/components/form-field'; import Input from '@cloudscape-design/components/input'; import { IContainerConfig } from '../../../shared/model/model-management.model'; -import { Button, Grid, Header, Icon, SpaceBetween } from '@cloudscape-design/components'; +import { Button, Grid, Header, Icon, Select, SpaceBetween } from '@cloudscape-design/components'; import Container from '@cloudscape-design/components/container'; import { EnvironmentVariables } from '../../../shared/form/environment-variables'; +import { EcsSourceType } from '../../../../../../schema'; type ContainerConfigProps = FormProps & { isEdit: boolean; @@ -61,15 +62,17 @@ export function ContainerConfig (props: ContainerConfigProps) : ReactElement { }} /> - - +