diff --git a/VERSION b/VERSION index 09b254e90..5fe607230 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -6.0.0 +6.0.1 diff --git a/cypress/src/support/adminHelpers.ts b/cypress/src/support/adminHelpers.ts index b4c696f8f..a70af5128 100644 --- a/cypress/src/support/adminHelpers.ts +++ b/cypress/src/support/adminHelpers.ts @@ -47,7 +47,6 @@ export function expandAdminMenu () { 'Model Management', 'RAG Management', 'MCP Management', - 'MCP Workbench', ]); }); } diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 1340d9e3a..7ebf3119c 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -18,7 +18,7 @@ import boto3 import botocore.session -from fastapi import FastAPI, Path, Request +from fastapi import FastAPI, HTTPException, Path, Request from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -61,6 +61,21 @@ stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config) +def get_admin_status_and_groups(request: Request) -> tuple[bool, list[str]]: + admin_status = False + user_groups = [] + + if "aws.event" in request.scope: + event = request.scope["aws.event"] + try: + user_groups = get_groups(event) + admin_status = is_admin(event) + except Exception: + user_groups = [] + admin_status = False + return admin_status, user_groups + + @app.exception_handler(ModelNotFoundError) async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> JSONResponse: """Handle exception when model cannot be found and translate to a 404 error.""" @@ -87,8 +102,11 @@ async def user_error_handler( @app.post(path="", include_in_schema=False) @app.post(path="/") -async def create_model(create_request: CreateModelRequest) -> CreateModelResponse: +async def create_model(create_request: CreateModelRequest, request: Request) -> CreateModelResponse: """Endpoint to create a model.""" + admin_status, _ = get_admin_status_and_groups(request) + if not admin_status: + raise HTTPException(status_code=403, detail="User does not have permission to create models.") create_handler = CreateModelHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, @@ -109,18 +127,7 @@ async def list_models(request: Request) -> ListModelsResponse: guardrails_table_resource=guardrails_table, ) - user_groups = [] - admin_status = False - - if "aws.event" in request.scope: - event = request.scope["aws.event"] - try: - user_groups = get_groups(event) - admin_status = is_admin(event) - except Exception: - user_groups = [] - admin_status = False - + admin_status, user_groups = get_admin_status_and_groups(request) return list_handler(user_groups=user_groups, is_admin=admin_status) @@ -136,18 +143,7 @@ async def get_model( guardrails_table_resource=guardrails_table, ) - user_groups = [] - admin_status = False - - if "aws.event" in request.scope: - event = request.scope["aws.event"] - try: - user_groups = get_groups(event) - admin_status = is_admin(event) - except Exception: - user_groups = [] - admin_status = False - + admin_status, user_groups = get_admin_status_and_groups(request) return get_handler(model_id=model_id, user_groups=user_groups, is_admin=admin_status) @@ -155,8 +151,12 @@ async def get_model( async def update_model( model_id: Annotated[str, Path(title="The unique model ID of the model to update")], update_request: UpdateModelRequest, + request: Request, ) -> UpdateModelResponse: """Endpoint to update a model.""" + admin_status, _ = get_admin_status_and_groups(request) + if not admin_status: + raise HTTPException(status_code=403, detail="User does not have permission to update models.") update_handler = UpdateModelHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, @@ -171,6 +171,9 @@ async def delete_model( model_id: Annotated[str, Path(title="The unique model ID of the model to delete")], request: Request ) -> DeleteModelResponse: """Endpoint to delete a model.""" + admin_status, _ = get_admin_status_and_groups(request) + if not admin_status: + raise HTTPException(status_code=403, detail="User does not have permission to delete models.") delete_handler = DeleteModelHandler( autoscaling_client=autoscaling, stepfunctions_client=stepfunctions, diff --git a/lib/api-base/fastApiContainer.ts b/lib/api-base/fastApiContainer.ts index 0ffca9d5a..b9d23c144 100644 --- a/lib/api-base/fastApiContainer.ts +++ b/lib/api-base/fastApiContainer.ts @@ -81,8 +81,8 @@ export class FastApiContainer extends Construct { }; // Add build config overrides if provided - if (config.restApiConfig.buildConfig?.NODEENV_CACHE_DIR) { - buildArgs.NODEENV_CACHE_DIR = config.restApiConfig.buildConfig.NODEENV_CACHE_DIR; + if (config.restApiConfig.buildConfig?.PRISMA_CACHE_DIR) { + buildArgs.PRISMA_CACHE_DIR = config.restApiConfig.buildConfig.PRISMA_CACHE_DIR; } // Add MCP Workbench build config overrides if provided diff --git a/lib/docs/admin/deploy.md b/lib/docs/admin/deploy.md index 888d15fa0..928db8044 100644 --- a/lib/docs/admin/deploy.md +++ b/lib/docs/admin/deploy.md @@ -294,10 +294,10 @@ npmConfig: # Use ADC-accessible base images for LISA-Serve and Batch Ingestion baseImage: /python:3.11 -# Configure offline build dependencies for REST API (nodeenv for prisma-client-py) +# Configure offline build dependencies for REST API (prisma-client-py dependencies) restApiConfig: buildConfig: - NODEENV_CACHE_DIR: "./nodeenv-cache" # Path relative to lib/serve/rest-api/ + PRISMA_CACHE_DIR: "./PRISMA_CACHE" # Path relative to lib/serve/rest-api/ # Configure offline build dependencies for MCP Workbench (S6 Overlay and rclone) mcpWorkbenchBuildConfig: @@ -311,12 +311,33 @@ You'll also want any model hosting base containers available, e.g. vllm/vllm-ope For environments without internet access during Docker builds, you can pre-cache required dependencies: -**REST API nodeenv cache** (required by prisma-client-py): +**REST API Prisma cache** (required by prisma-client-py): + +The `prisma-client-py` package requires platform-specific binaries and a Node.js environment to function. When Prisma runs for the first time, it downloads these dependencies to `~/.cache/prisma/` and `~/.cache/prisma-python/`. For offline deployments, you need to pre-populate this cache. + +Below is an example workflow using an Amazon Linux 2023 instance with Python 3.12: + ```bash -# Create the cache directory in the REST API build context -python -m nodeenv lib/serve/rest-api/nodeenv-cache +# Ensure Pip is up-to-date +pip3 install --upgrade pip + +# Install Prisma Python package +pip3 install prisma + +# Trigger Prisma to download all required binaries and create its Node.js environment +# This populates ~/.cache/prisma/ and ~/.cache/prisma-python/ +prisma version + +# Copy the complete Prisma cache to your build context +# The wildcard captures both 'prisma' and 'prisma-python' directories +cp -r ~/.cache/prisma* lib/serve/rest-api/PRISMA_CACHE/ ``` +**Important Notes:** +- The cache is platform-specific. Generate it on a system matching your Docker base image (e.g., for `python:3.13-slim` which is Debian-based, so you may want to use a Debian-based system) +- The `prisma version` command downloads binaries for your current platform +- Both `prisma/` and `prisma-python/` directories are required for offline operation + **MCP Workbench dependencies** (S6 Overlay and rclone): ```bash # Download S6 Overlay files diff --git a/lib/schema/configSchema.ts b/lib/schema/configSchema.ts index ca3e91d2e..a10ea561e 100644 --- a/lib/schema/configSchema.ts +++ b/lib/schema/configSchema.ts @@ -711,7 +711,7 @@ const FastApiContainerConfigSchema = z.object({ sslCertIamArn: z.string().nullish().default(null).describe('ARN of the self-signed cert to be used throughout the system'), imageConfig: ImageAssetSchema.optional().describe('Override image configuration for ECS FastAPI Containers'), buildConfig: z.object({ - NODEENV_CACHE_DIR: z.string().optional().describe('Override with a path relative to the build directory for a pre-cached nodeenv directory. Defaults to NODEENV_CACHE. For offline environments, populate using: python -m nodeenv PATH') + PRISMA_CACHE_DIR: z.string().optional().describe('Override with a path relative to the build directory for a pre-cached prisma directory. Defaults to PRISMA_CACHE.') }).default({}), rdsConfig: RdsInstanceConfig .default({ diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index aef78ed7e..738647a44 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,8 +1,8 @@ ARG BASE_IMAGE=python:3.11 FROM ${BASE_IMAGE} -ARG NODEENV_CACHE_DIR=NODEENV_CACHE -ENV NODEENV_CACHE_DIR=$NODEENV_CACHE_DIR +ARG PRISMA_CACHE_DIR=PRISMA_CACHE +ENV PRISMA_CACHE_DIR=$PRISMA_CACHE_DIR # Install build dependencies for madoka package RUN apt-get update && apt-get install -y \ @@ -28,13 +28,21 @@ WORKDIR /app COPY src/requirements.txt . RUN pip install --no-cache-dir --upgrade -r requirements.txt -# Copy nodeenv cache directory (always exists, may be empty or populated) -COPY ${NODEENV_CACHE_DIR} /tmp/nodeenv-cache/ +# Copy prisma cache directory (always exists, may be empty or populated) +COPY ${PRISMA_CACHE_DIR} /tmp/prisma-cache/ -# Pre-cache nodeenv for prisma-client-py +# Pre-cache prisma for prisma-client-py # If the copied directory has content, use it (for offline environments) # Otherwise, download it during build (requires internet) -RUN python -m prisma version +RUN mkdir -p /root/.cache && \ + if [ -d "/tmp/prisma-cache" ] && [ -n "$(ls /tmp/prisma-cache 2>/dev/null)" ]; then \ + echo "Using pre-cached Prisma dependencies from host" && \ + cp -r /tmp/prisma-cache/prisma* /root/.cache && \ + rm -rf /tmp/prisma-cache; \ + else \ + echo "Fetching Prisma Dependencies (requires internet)" && \ + prisma version; \ + fi # Copy the source code into the container COPY src/ ./src diff --git a/lib/serve/rest-api/NODEENV_CACHE/.gitkeep b/lib/serve/rest-api/NODEENV_CACHE/.gitkeep deleted file mode 100644 index b098f0c35..000000000 --- a/lib/serve/rest-api/NODEENV_CACHE/.gitkeep +++ /dev/null @@ -1,2 +0,0 @@ -# Placeholder to ensure NODEENV_CACHE directory exists in build context -# For offline builds, populate this directory using: python -m nodeenv NODEENV_CACHE diff --git a/lib/serve/rest-api/PRISMA_CACHE/.gitkeep b/lib/serve/rest-api/PRISMA_CACHE/.gitkeep new file mode 100644 index 000000000..ad102e091 --- /dev/null +++ b/lib/serve/rest-api/PRISMA_CACHE/.gitkeep @@ -0,0 +1,5 @@ +# Placeholder to ensure PRISMA_CACHE directory exists in build context +# For offline builds, populate this directory by copying the complete Prisma cache: +# 1. Install prisma: pip3 install prisma +# 2. Generate cache: prisma version +# 3. Copy cache: cp -r ~/.cache/prisma* lib/serve/rest-api/PRISMA_CACHE/ diff --git a/lib/user-interface/react/package.json b/lib/user-interface/react/package.json index 5db2b053d..3b1b428b1 100644 --- a/lib/user-interface/react/package.json +++ b/lib/user-interface/react/package.json @@ -1,7 +1,7 @@ { "name": "lisa-web", "private": true, - "version": "6.0.0", + "version": "6.0.1", "type": "module", "scripts": { "dev": "vite", diff --git a/lib/user-interface/react/src/shared/model/model-management.model.ts b/lib/user-interface/react/src/shared/model/model-management.model.ts index 73033ef31..ff76b88e3 100644 --- a/lib/user-interface/react/src/shared/model/model-management.model.ts +++ b/lib/user-interface/react/src/shared/model/model-management.model.ts @@ -210,7 +210,7 @@ export const metricConfigSchema = z.object({ }); export const loadBalancerHealthCheckConfigSchema = z.object({ - path: z.string().default('/status'), + path: z.string().default('/health'), interval: z.number().default(60), timeout: z.number().default(30), healthyThresholdCount: z.number().default(2), diff --git a/lib/user-interface/react/src/shared/reducers/configuration.reducer.ts b/lib/user-interface/react/src/shared/reducers/configuration.reducer.ts index 72ad7e724..53b81044c 100644 --- a/lib/user-interface/react/src/shared/reducers/configuration.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/configuration.reducer.ts @@ -23,7 +23,7 @@ export const configurationApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['configuration'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ getConfiguration: builder.query({ query: (configScope) => ({ diff --git a/lib/user-interface/react/src/shared/reducers/mcp-server.reducer.ts b/lib/user-interface/react/src/shared/reducers/mcp-server.reducer.ts index e8c324a7b..6986fb8bb 100644 --- a/lib/user-interface/react/src/shared/reducers/mcp-server.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/mcp-server.reducer.ts @@ -68,7 +68,7 @@ export const mcpServerApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['mcpServers'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ createMcpServer: builder.mutation({ query: (mcpServer) => ({ diff --git a/lib/user-interface/react/src/shared/reducers/mcp-tools.reducer.ts b/lib/user-interface/react/src/shared/reducers/mcp-tools.reducer.ts index 5d41461b1..2a64f5282 100644 --- a/lib/user-interface/react/src/shared/reducers/mcp-tools.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/mcp-tools.reducer.ts @@ -31,7 +31,7 @@ export const mcpToolsApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['mcpTools'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ listMcpTools: builder.query({ query: () => ({ diff --git a/lib/user-interface/react/src/shared/reducers/model-management.reducer.ts b/lib/user-interface/react/src/shared/reducers/model-management.reducer.ts index 55de56c4d..80f70b291 100644 --- a/lib/user-interface/react/src/shared/reducers/model-management.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/model-management.reducer.ts @@ -23,7 +23,7 @@ export const modelManagementApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['models'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ getAllModels: builder.query({ query: () => ({ diff --git a/lib/user-interface/react/src/shared/reducers/prompt-templates.reducer.ts b/lib/user-interface/react/src/shared/reducers/prompt-templates.reducer.ts index 26afbcf47..58ebe58cd 100644 --- a/lib/user-interface/react/src/shared/reducers/prompt-templates.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/prompt-templates.reducer.ts @@ -54,7 +54,7 @@ export const promptTemplateApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['promptTemplates'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ createPromptTemplate: builder.mutation({ query: (promptTemplate) => ({ diff --git a/lib/user-interface/react/src/shared/reducers/rag.reducer.ts b/lib/user-interface/react/src/shared/reducers/rag.reducer.ts index 3e4b68a5f..36fb77993 100644 --- a/lib/user-interface/react/src/shared/reducers/rag.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/rag.reducer.ts @@ -177,7 +177,7 @@ export const ragApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['repositories', 'docs', 'repository-status', 'jobs', 'collections'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ listRagRepositories: builder.query({ query: () => ({ diff --git a/lib/user-interface/react/src/shared/reducers/session.reducer.ts b/lib/user-interface/react/src/shared/reducers/session.reducer.ts index a0276fbc0..035835e45 100644 --- a/lib/user-interface/react/src/shared/reducers/session.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/session.reducer.ts @@ -29,7 +29,7 @@ export const sessionApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['sessions'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ getSessionById: builder.query({ query: (sessionId: string) => ({ diff --git a/lib/user-interface/react/src/shared/reducers/user-preferences.reducer.ts b/lib/user-interface/react/src/shared/reducers/user-preferences.reducer.ts index 8440dcc66..f43b8f0e2 100644 --- a/lib/user-interface/react/src/shared/reducers/user-preferences.reducer.ts +++ b/lib/user-interface/react/src/shared/reducers/user-preferences.reducer.ts @@ -50,7 +50,7 @@ export const userPreferencesApi = createApi({ baseQuery: lisaBaseQuery(), tagTypes: ['user-preferences'], refetchOnFocus: true, - refetchOnReconnect: true, + refetchOnMountOrArgChange: true, endpoints: (builder) => ({ updateUserPreferences: builder.mutation({ query: (userPreferences) => ({ diff --git a/lisa-sdk/pyproject.toml b/lisa-sdk/pyproject.toml index 15fd0b9c2..4a2dc6d16 100644 --- a/lisa-sdk/pyproject.toml +++ b/lisa-sdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "lisapy" -version = "6.0.0" +version = "6.0.1" description = "A simple SDK to help you interact with LISA. LISA is an LLM hosting solution for AWS dedicated clouds or ADCs." readme = "README.md" requires-python = ">=3.11" @@ -15,7 +15,7 @@ dependencies = [ [tool.poetry] name = "lisapy" -version = "6.0.0" +version = "6.0.1" description = "A simple SDK to help you interact with LISA. LISA is an LLM hosting solution for AWS dedicated clouds or ADCs." authors = ["Steve Goley "] readme = "README.md" diff --git a/package.json b/package.json index f2c4cec0d..81b0781d6 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@awslabs/lisa", - "version": "6.0.0", + "version": "6.0.1", "description": "A scalable infrastructure-as-code solution for self-hosting and orchestrating LLM inference with RAG capabilities, providing low-latency access to generative AI and embedding models across multiple providers.", "homepage": "https://awslabs.github.io/LISA/", "license": "Apache-2.0", diff --git a/test/lambda/test_models_lambda.py b/test/lambda/test_models_lambda.py index f115a643f..05fdabe20 100644 --- a/test/lambda/test_models_lambda.py +++ b/test/lambda/test_models_lambda.py @@ -20,6 +20,7 @@ import boto3 import pytest +from fastapi import HTTPException, Request from models.domain_objects import ( AutoScalingConfig, AutoScalingInstanceConfig, @@ -28,6 +29,7 @@ ContainerHealthCheckConfig, CreateModelRequest, CreateModelResponse, + DeleteModelResponse, InferenceContainer, LISAModel, LoadBalancerConfig, @@ -36,6 +38,7 @@ ModelStatus, ModelType, UpdateModelRequest, + UpdateModelResponse, ) from models.exception import InvalidStateTransitionError, ModelAlreadyExistsError, ModelNotFoundError from models.handler.base_handler import BaseApiHandler @@ -45,6 +48,17 @@ from models.handler.list_models_handler import ListModelsHandler from models.handler.update_model_handler import UpdateModelHandler from models.handler.utils import to_lisa_model +from models.lambda_functions import ( + app, + create_model, + delete_model, + get_admin_status_and_groups, + get_instances, + model_not_found_handler, + update_model, + user_error_handler, + validation_exception_handler, +) from moto import mock_aws # Set mock AWS credentials @@ -555,7 +569,6 @@ async def test_exception_handlers(): """Test exception handlers.""" from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError - from models.lambda_functions import model_not_found_handler, user_error_handler, validation_exception_handler # Setup mock request request = MagicMock() @@ -592,7 +605,6 @@ async def test_exception_handlers(): async def test_fastapi_endpoints(sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client): """Test FastAPI endpoints.""" from fastapi.testclient import TestClient - from models.lambda_functions import app # Create test client client = TestClient(app) @@ -604,7 +616,11 @@ async def test_fastapi_endpoints(sample_model, model_table, mock_autoscaling_cli "models.lambda_functions.UpdateModelHandler" ) as mock_update_handler, patch( "models.lambda_functions.DeleteModelHandler" - ) as mock_delete_handler: + ) as mock_delete_handler, patch( + "models.lambda_functions.get_admin_status_and_groups" + ) as mock_get_admin_status: + # Mock admin status - return admin=True for all operations in this test + mock_get_admin_status.return_value = (True, []) # Setup handler mocks create_handler_instance = MagicMock() @@ -721,7 +737,6 @@ async def test_fastapi_endpoints(sample_model, model_table, mock_autoscaling_cli @pytest.mark.asyncio async def test_get_instances(): """Test get_instances endpoint.""" - from models.lambda_functions import get_instances # Mock the shape_for method to return a mock with enum attribute mock_shape = MagicMock() @@ -738,3 +753,242 @@ async def test_get_instances(): assert "t2.micro" in result assert "t3.small" in result assert "m5.large" in result + + +@pytest.fixture +def admin_event(): + """Create an AWS event with admin user.""" + return { + "requestContext": { + "authorizer": { + "username": "admin-user", + "groups": json.dumps(["admin-group"]), + } + } + } + + +@pytest.fixture +def non_admin_event(): + """Create an AWS event with non-admin user.""" + return { + "requestContext": { + "authorizer": { + "username": "regular-user", + "groups": json.dumps(["user-group"]), + } + } + } + + +@pytest.mark.asyncio +async def test_get_admin_status_and_groups(): + """Test the get_admin_status_and_groups helper function.""" + + # Test with admin event + admin_event = { + "requestContext": { + "authorizer": { + "username": "admin-user", + "groups": json.dumps(["admin-group"]), + } + } + } + + mock_request = MagicMock(spec=Request) + mock_request.scope = {"aws.event": admin_event} + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups: + mock_is_admin.return_value = True + mock_get_groups.return_value = ["admin-group"] + + admin_status, user_groups = get_admin_status_and_groups(mock_request) + assert admin_status is True + assert user_groups == ["admin-group"] + + # Test with non-admin event + non_admin_event = { + "requestContext": { + "authorizer": { + "username": "regular-user", + "groups": json.dumps(["user-group"]), + } + } + } + + mock_request.scope = {"aws.event": non_admin_event} + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups: + mock_is_admin.return_value = False + mock_get_groups.return_value = ["user-group"] + + admin_status, user_groups = get_admin_status_and_groups(mock_request) + assert admin_status is False + assert user_groups == ["user-group"] + + # Test with no event in scope + mock_request.scope = {} + admin_status, user_groups = get_admin_status_and_groups(mock_request) + assert admin_status is False + assert user_groups == [] + + +@pytest.mark.asyncio +async def test_create_model_admin_required( + sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event +): + """Test that create_model endpoint requires admin access.""" + + # Test non-admin cannot create + mock_request = MagicMock(spec=Request) + mock_request.scope = {"aws.event": non_admin_event} + + create_request = CreateModelRequest( + modelId="test-model", modelName="test-model", modelType=ModelType.TEXTGEN, streaming=True + ) + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups: + mock_is_admin.return_value = False + mock_get_groups.return_value = [] + + with pytest.raises(HTTPException) as exc_info: + await create_model(create_request, mock_request) + assert exc_info.value.status_code == 403 + assert "User does not have permission to create models" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_update_model_admin_required( + sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event +): + """Test that update_model endpoint requires admin access.""" + + # Test non-admin cannot update + mock_request = MagicMock(spec=Request) + mock_request.scope = {"aws.event": non_admin_event} + + update_request = UpdateModelRequest(streaming=False) + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups: + mock_is_admin.return_value = False + mock_get_groups.return_value = [] + + with pytest.raises(HTTPException) as exc_info: + await update_model("test-model", update_request, mock_request) + assert exc_info.value.status_code == 403 + assert "User does not have permission to update models" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_delete_model_admin_required( + sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event, non_admin_event +): + """Test that delete_model endpoint requires admin access.""" + + # Test non-admin cannot delete + mock_request = MagicMock(spec=Request) + mock_request.scope = {"aws.event": non_admin_event} + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups: + mock_is_admin.return_value = False + mock_get_groups.return_value = [] + + with pytest.raises(HTTPException) as exc_info: + await delete_model("test-model", mock_request) + assert exc_info.value.status_code == 403 + assert "User does not have permission to delete models" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_create_update_delete_admin_allowed( + sample_model, model_table, mock_autoscaling_client, mock_stepfunctions_client, admin_event +): + """Test that admin users can successfully create, update, and delete models.""" + + mock_request = MagicMock(spec=Request) + mock_request.scope = {"aws.event": admin_event} + + with patch("models.lambda_functions.is_admin") as mock_is_admin, patch( + "models.lambda_functions.get_groups" + ) as mock_get_groups, patch("models.lambda_functions.CreateModelHandler") as mock_create_handler, patch( + "models.lambda_functions.UpdateModelHandler" + ) as mock_update_handler, patch( + "models.lambda_functions.DeleteModelHandler" + ) as mock_delete_handler: + mock_is_admin.return_value = True + mock_get_groups.return_value = ["admin-group"] + + # Mock create handler + create_handler_instance = MagicMock() + create_model_response = CreateModelResponse( + model=LISAModel( + modelId="new-model", + modelName="new-model-name", + modelType=ModelType.TEXTGEN, + status=ModelStatus.CREATING, + streaming=True, + ) + ) + create_handler_instance.return_value = create_model_response + mock_create_handler.return_value = create_handler_instance + + # Test admin can create + create_request = CreateModelRequest( + modelId="test-model", modelName="test-model", modelType=ModelType.TEXTGEN, streaming=True + ) + response = await create_model(create_request, mock_request) + assert isinstance(response, CreateModelResponse) + assert response.model.modelId == "new-model" + + # Mock update handler + model_table.put_item(Item=sample_model) + update_handler_instance = MagicMock() + update_model_response = UpdateModelResponse( + model=LISAModel( + modelId="test-model", + modelName="gpt-3.5-turbo", + modelType=ModelType.TEXTGEN, + status=ModelStatus.IN_SERVICE, + streaming=False, + features=[{"name": "test-feature", "overview": "This is a test feature"}], + ) + ) + update_handler_instance.return_value = update_model_response + mock_update_handler.return_value = update_handler_instance + + # Test admin can update + update_request = UpdateModelRequest(streaming=False) + response = await update_model("test-model", update_request, mock_request) + assert isinstance(response, UpdateModelResponse) + assert response.model.modelId == "test-model" + + # Mock delete handler + delete_handler_instance = MagicMock() + delete_model_response = DeleteModelResponse( + model=LISAModel( + modelId="test-model", + modelName="gpt-3.5-turbo", + modelType=ModelType.TEXTGEN, + status=ModelStatus.DELETING, + streaming=True, + features=[{"name": "test-feature", "overview": "This is a test feature"}], + ) + ) + delete_handler_instance.return_value = delete_model_response + mock_delete_handler.return_value = delete_handler_instance + + # Test admin can delete + response = await delete_model("test-model", mock_request) + assert isinstance(response, DeleteModelResponse) + assert response.model.modelId == "test-model" + assert response.model.status == ModelStatus.DELETING