From 5e1ab83c9ace517918ea633fbd851e79d75697a1 Mon Sep 17 00:00:00 2001 From: WMC001 <46217886+WMC001@users.noreply.github.com> Date: Fri, 17 Apr 2026 10:11:19 +0800 Subject: [PATCH 001/156] Update CODEOWNERS --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a3a924020..64d98fbcf 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # These owners will be the default owners for everything in the repo -* @Phinease @WMC001 +* @WMC001 @Dallas98 From d8f1b99f66f782f93aa181e6f252ec0d751a4361 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Sat, 18 Apr 2026 10:48:46 +0800 Subject: [PATCH 002/156] =?UTF-8?q?=F0=9F=94=A7=20Bump=20APP=5FVERSION=20t?= =?UTF-8?q?o=20v2.0.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/consts/const.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/consts/const.py b/backend/consts/const.py index 223a1d00b..bccb91ccd 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -345,4 +345,4 @@ class VectorDatabaseType(str, Enum): # APP Version -APP_VERSION = "v2.0.1" +APP_VERSION = "v2.0.2" From 045a94894ab3ec5200821eacc3c674b391555c7f Mon Sep 17 00:00:00 2001 From: panyehong <91180085+YehongPan@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:48:13 +0800 Subject: [PATCH 003/156] =?UTF-8?q?=E2=9C=A8=20Feat:=20Personal=20file=20u?= =?UTF-8?q?ploads=20support=20permission=20isolation.=20#2836=20(#2837)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Specification Detail] 1. Login authentication verification has been added to all file retrieval-related interfaces. 2. Files uploaded by a user are stored in a dedicated per-user directory in MinIO (attachments/{user_id}/), and permission checks are performed on every access.
--- backend/agents/create_agent_info.py | 6 +- backend/apps/file_management_app.py | 280 +++++--- backend/services/file_management_service.py | 158 +++- .../services/tool_configuration_service.py | 8 +- sdk/nexent/core/agents/nexent_agent.py | 12 +- sdk/nexent/core/tools/analyze_image_tool.py | 18 +- .../core/tools/analyze_text_file_tool.py | 18 +- sdk/nexent/multi_modal/load_save_object.py | 48 +- test/backend/agents/test_create_agent_info.py | 109 ++- test/backend/app/test_file_management_app.py | 675 +++++++++++++++--- .../services/test_file_management_service.py | 428 ++++++++++- .../test_tool_configuration_service.py | 26 +- test/sdk/core/agents/test_nexent_agent.py | 265 +++++-- .../sdk/core/tools/test_analyze_image_tool.py | 82 ++- .../core/tools/test_analyze_text_file_tool.py | 109 +++ test/sdk/multi_modal/test_load_save_object.py | 319 ++++++++- 16 files changed, 2288 insertions(+), 273 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index ea3ba24e8..e0fce0f47 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -9,7 +9,7 @@ from nexent.core.agents.agent_model import AgentRunInfo, ModelConfig, AgentConfig, ToolConfig, ExternalA2AAgentConfig from nexent.memory.memory_service import search_memory_in_levels -from services.file_management_service import get_llm_model +from services.file_management_service import get_llm_model, validate_urls_access from services.vectordatabase_service import ( ElasticSearchService, get_vector_db_core, @@ -479,12 +479,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = { "llm_model": get_llm_model(tenant_id=tenant_id), "storage_client": minio_client, - "data_process_service_url": DATA_PROCESS_SERVICE + "data_process_service_url": DATA_PROCESS_SERVICE, + "validate_url_access": lambda urls: validate_urls_access(urls, user_id) } elif tool_config.class_name == "AnalyzeImageTool": 
tool_config.metadata = { "vlm_model": get_vlm_model(tenant_id=tenant_id), "storage_client": minio_client, + "validate_url_access": lambda urls: validate_urls_access(urls, user_id) } tool_config_list.append(tool_config) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 50224c952..b8e1ce711 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -14,7 +14,8 @@ from consts.model import ProcessParams from services.file_management_service import upload_to_minio, upload_files_impl, \ get_file_url_impl, get_file_stream_impl, delete_file_impl, list_files_impl, \ - resolve_preview_file, get_preview_stream + resolve_preview_file, get_preview_stream, check_file_access, check_file_access_batch +from utils.auth_utils import get_current_user_id from utils.file_management_utils import trigger_data_process logger = logging.getLogger("file_management_app") @@ -91,27 +92,36 @@ async def upload_files( folder: str = Form( "attachments", description="Storage folder path for MinIO (optional)"), index_name: Optional[str] = Form( - None, description="Knowledge base index for conflict resolution") + None, description="Knowledge base index for conflict resolution"), + authorization: Optional[str] = Header(None, alias="Authorization") ): - if not file: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, - detail="No files in the request") + try: + if not file: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + detail="No files in the request") - errors, uploaded_file_paths, uploaded_filenames = await upload_files_impl(destination, file, folder, index_name) + user_id, tenant_id = get_current_user_id(authorization) + errors, uploaded_file_paths, uploaded_filenames = await upload_files_impl(destination, file, folder, index_name, user_id) - if uploaded_file_paths: - return JSONResponse( - status_code=HTTPStatus.OK, - content={ - "message": f"Files uploaded successfully to {destination}, ready for 
processing.", - "uploaded_filenames": uploaded_filenames, - "uploaded_file_paths": uploaded_file_paths, - "errors": errors - } - ) - else: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, - detail="No valid files uploaded") + if uploaded_file_paths: + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": f"Files uploaded successfully to {destination}, ready for processing.", + "uploaded_filenames": uploaded_filenames, + "uploaded_file_paths": uploaded_file_paths, + "errors": errors + } + ) + else: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + detail="No valid files uploaded") + except HTTPException: + raise + except Exception as e: + logger.error(f"File upload error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="File upload error.") @file_management_config_router.post("/process") @@ -169,10 +179,15 @@ async def get_storage_file( ), ), expires: int = Query(3600, description="URL validity period (seconds)"), - filename: Optional[str] = Query(None, description="Original filename for download (optional)") + filename: Optional[str] = Query(None, description="Original filename for download (optional)"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Get information, download link, or file stream for a single file + Get information, download link, or file stream for a single file. 
+ + Access control: + - knowledge_base/*: All authenticated users can access + - attachments/{user_id}/*: Only the owner (user_id) can access - **object_name**: File object name - **download**: Download mode: ignore (default, return file info), stream (return file stream), redirect (redirect to download URL) @@ -182,25 +197,29 @@ async def get_storage_file( Returns file information, download link, or file content """ try: + user_id, tenant_id = get_current_user_id(authorization) + + if not check_file_access(object_name, user_id): + logger.warning(f"[get_storage_file] Access denied: object_name={object_name}, user_id={user_id}") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="You don't have permission to access this file" + ) + logger.info(f"[get_storage_file] Route matched! object_name={object_name}, download={download}, filename={filename}") if download == "redirect": - # return a redirect download URL result = await get_file_url_impl(object_name=object_name, expires=expires) return RedirectResponse(url=result["url"]) elif download == "stream": - # return a readable file stream file_stream, content_type = await get_file_stream_impl(object_name=object_name) logger.info(f"Streaming file: object_name={object_name}, content_type={content_type}") - - # Use provided filename or extract from object_name + download_filename = filename if not download_filename: - # Extract filename from object_name (get the last part after the last slash) download_filename = object_name.split("/")[-1] if "/" in object_name else object_name - - # Build Content-Disposition header with proper encoding for non-ASCII characters + content_disposition = build_content_disposition_header(download_filename) - + return StreamingResponse( file_stream, media_type=content_type, @@ -211,7 +230,6 @@ async def get_storage_file( } ) elif download == "base64": - # Return base64 encoded file content (primarily for images) file_stream, content_type = await 
get_file_stream_impl(object_name=object_name) try: data = file_stream.read() @@ -233,13 +251,13 @@ async def get_storage_file( }, ) else: - # return file metadata return await get_file_url_impl(object_name=object_name, expires=expires) + except HTTPException: + raise except Exception as e: logger.error(f"Failed to get file: object_name={object_name}, error={str(e)}") raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"Failed to get file information: {str(e)}" + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Failed to get file." ) @@ -248,17 +266,45 @@ async def get_storage_file( async def storage_upload_files( files: List[UploadFile] = File(..., description="List of files to upload"), folder: str = Form( - "attachments", description="Storage folder path (optional)") + "attachments", description="Storage folder path (optional)"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Upload one or more files to MinIO storage + Upload one or more files to MinIO storage. - **files**: List of files to upload - **folder**: Storage folder path (optional, defaults to 'attachments') + Use 'knowledge_base' for shared files accessible by all users. + Other folders (like 'attachments') will be isolated by user_id. 
Returns upload results including file information and access URLs """ - results = await upload_to_minio(files=files, folder=folder) + try: + user_id, tenant_id = get_current_user_id(authorization) + + if folder == "knowledge_base": + actual_folder = "knowledge_base" + else: + if user_id: + actual_folder = f"attachments/{user_id}" + else: + actual_folder = folder or "attachments" + + results = await upload_to_minio(files=files, folder=actual_folder, user_id=user_id) + + return { + "message": f"Processed {len(results)} files", + "success_count": sum(1 for r in results if r.get("success", False)), + "failed_count": sum(1 for r in results if not r.get("success", False)), + "results": results + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Storage upload error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Storage upload error." + ) # Return upload results for all files return { @@ -274,10 +320,16 @@ async def get_storage_files( prefix: str = Query("", description="File prefix filter"), limit: int = Query(100, description="Maximum number of files to return"), include_urls: bool = Query( - True, description="Whether to include presigned URLs") + True, description="Whether to include presigned URLs"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Get list of files from MinIO storage + Get list of files from MinIO storage. 
+ + Access control: + - Returns only files the user has permission to access: + - knowledge_base/*: All authenticated users can access + - attachments/{user_id}/*: Only the owner's files - **prefix**: File prefix filter (optional) - **limit**: Maximum number of files to return (default 100) @@ -286,8 +338,22 @@ async def get_storage_files( Returns file list and metadata """ try: + user_id, tenant_id = get_current_user_id(authorization) files = await list_files_impl(prefix, limit) - # Remove URLs if not needed + + if user_id: + filtered_files = [ + f for f in files + if f.get("key") and check_file_access(f.get("key"), user_id) + ] + else: + filtered_files = [ + f for f in files + if f.get("key") and f.get("key", "").startswith("knowledge_base/") + ] + + files = filtered_files + if not include_urls: for file in files: if "url" in file: @@ -297,10 +363,12 @@ async def get_storage_files( "total": len(files), "files": files } + except HTTPException: + raise except Exception as e: + logger.error(f"Get storage files error: {str(e)}") raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"Failed to get file list: {str(e)}" + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Get storage files error." ) @@ -481,7 +549,7 @@ async def download_datamate_file( # Build Content-Disposition header with proper encoding for non-ASCII characters content_disposition = build_content_disposition_header(download_filename) - + return StreamingResponse( iter([response.content]), media_type=content_type, @@ -507,25 +575,41 @@ async def download_datamate_file( @file_management_config_router.delete("/storage/{object_name:path}") async def remove_storage_file( - object_name: str = PathParam(..., description="File object name to delete") + object_name: str = PathParam(..., description="File object name to delete"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Delete file from MinIO storage + Delete file from MinIO storage. 
+ + Access control: + - knowledge_base/*: Only allow deletion (admin operation) + - attachments/{user_id}/*: Only the owner (user_id) can delete - **object_name**: File object name to delete Returns deletion operation result """ try: + user_id, tenant_id = get_current_user_id(authorization) + + if not check_file_access(object_name, user_id): + logger.warning(f"[remove_storage_file] Access denied: object_name={object_name}, user_id={user_id}") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="You don't have permission to delete this file" + ) + await delete_file_impl(object_name=object_name) return { "success": True, "message": f"File {object_name} successfully deleted" } + except HTTPException: + raise except Exception as e: + logger.error(f"Remove storage file error: {str(e)}") raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"Failed to delete file: {str(e)}" + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Remove storage file error." ) @@ -533,57 +617,82 @@ async def remove_storage_file( async def get_storage_file_batch_urls( request_data: dict = Body(..., description="JSON containing list of file object names"), - expires: int = Query(3600, description="URL validity period (seconds)") + expires: int = Query(3600, description="URL validity period (seconds)"), + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Batch get download URLs for multiple files (JSON request) + Batch get download URLs for multiple files (JSON request). 
+ + Access control: + - knowledge_base/*: All authenticated users can access + - attachments/{user_id}/*: Only the owner (user_id) can access - **request_data**: JSON request body containing object_names list - **expires**: URL validity period in seconds (default 3600) Returns URL and status information for each file """ - # Extract object_names from request body - object_names = request_data.get("object_names", []) - if not object_names or not isinstance(object_names, list): - raise HTTPException( - status_code=400, detail="Request body must contain object_names array") + try: + user_id, tenant_id = get_current_user_id(authorization) - results = [] + object_names = request_data.get("object_names", []) + if not object_names or not isinstance(object_names, list): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail="Request body must contain object_names array") - for object_name in object_names: - try: - # Get file URL - result = get_file_url_impl( - object_name=object_name, expires=expires) - results.append({ - "object_name": object_name, - "success": result["success"], - "url": result.get("url"), - "error": result.get("error") - }) - except Exception as e: - results.append({ - "object_name": object_name, - "success": False, - "error": str(e) - }) + results = [] - return { - "total": len(results), - "success_count": sum(1 for r in results if r.get("success", False)), - "failed_count": sum(1 for r in results if not r.get("success", False)), - "results": results - } + for object_name in object_names: + if not check_file_access(object_name, user_id): + results.append({ + "object_name": object_name, + "success": False, + "error": "Access denied" + }) + continue + + try: + result = get_file_url_impl(object_name=object_name, expires=expires) + results.append({ + "object_name": object_name, + "success": result["success"], + "url": result.get("url"), + "error": result.get("error") + }) + except Exception as e: + results.append({ + "object_name": 
object_name, + "success": False, + "error": str(e) + }) + + return { + "total": len(results), + "success_count": sum(1 for r in results if r.get("success", False)), + "failed_count": sum(1 for r in results if not r.get("success", False)), + "results": results + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Batch URLs error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Batch URLs error." + ) @file_management_config_router.get("/preview/{object_name:path}") async def preview_file( object_name: str = PathParam(..., description="File object name to preview"), filename: Annotated[Optional[str], Query(description="Original filename for display (optional)")] = None, range_header: Annotated[Optional[str], Header(alias="range")] = None, + authorization: Optional[str] = Header(None, alias="Authorization") ): """ - Preview file inline in browser + Preview file inline in browser. + + Access control: + - knowledge_base/*: All authenticated users can access + - attachments/{user_id}/*: Only the owner (user_id) can access - **object_name**: File object name in storage - **filename**: Original filename for Content-Disposition header (optional) @@ -592,6 +701,15 @@ async def preview_file( Returns 206 Partial Content when a valid Range header is present. 
""" try: + user_id, tenant_id = get_current_user_id(authorization) + + if not check_file_access(object_name, user_id): + logger.warning(f"[preview_file] Access denied: object_name={object_name}, user_id={user_id}") + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="You don't have permission to access this file" + ) + actual_name, content_type, total_size = await resolve_preview_file(object_name=object_name) except FileTooLargeException as e: logger.warning(f"[preview_file] File too large: object_name={object_name}, error={str(e)}") @@ -608,13 +726,15 @@ async def preview_file( except UnsupportedFileTypeException as e: logger.error(f"[preview_file] Unsupported file type: object_name={object_name}, error={str(e)}") raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.BAD_REQUEST, detail=f"File format not supported for preview: {str(e)}" ) + except HTTPException: + raise except Exception as e: logger.error(f"[preview_file] Unexpected error: object_name={object_name}, error={str(e)}") raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Failed to preview file" ) diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index d73c91c72..e47d199d6 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -4,7 +4,7 @@ import os from io import BytesIO from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import httpx from fastapi import UploadFile @@ -50,7 +50,117 @@ logger = logging.getLogger("file_management_service") -async def upload_files_impl(destination: str, file: List[UploadFile], folder: str = None, index_name: Optional[str] = None) -> tuple: +def check_file_access(object_name: str, user_id: Optional[str]) -> bool: + """ + Check if user has permission to access the file. 
+ + Access rules: + - knowledge_base/*: All authenticated users can access + - attachments/{user_id}/*: Only the owner (user_id) can access + - preview/*: Accessible if the original file is accessible + + Args: + object_name: File object name in storage + user_id: Current user ID + + Returns: + True if access is allowed, False otherwise + """ + if not user_id: + return False + + if object_name.startswith("knowledge_base/"): + # Knowledge base files: all authenticated users can access + return True + + # Check if file is in user's attachments folder + # Pattern: attachments/{user_id}/* + if object_name.startswith(f"attachments/{user_id}/"): + return True + + # For backward compatibility, allow access to files in root attachments folder + # Pattern: attachments/{filename} (no user_id subfolder) + if object_name.startswith("attachments/") and "/" not in object_name.replace("attachments/", "", 1): + # Old format: attachments/filename (no subdirectory) + # Allow access for backward compatibility + return True + + return False + + +def check_file_access_batch(object_names: List[str], user_id: Optional[str]) -> Dict[str, bool]: + """ + Batch check file access permissions. + + Args: + object_names: List of file object names + user_id: Current user ID + + Returns: + Dict mapping object_name to access permission (True/False) + """ + return {obj_name: check_file_access(obj_name, user_id) for obj_name in object_names} + + +def validate_s3_url_access(object_name: str, user_id: Optional[str]) -> None: + """ + Validate if user has permission to access the S3 URL. 
+ + Args: + object_name: File object name in storage (extracted from S3 URL) + user_id: Current user ID + + Raises: + PermissionError: If user doesn't have permission to access the file + """ + if not user_id: + raise PermissionError("User authentication required to access files") + + if not check_file_access(object_name, user_id): + logger.warning(f"[validate_s3_url_access] Access denied: object_name={object_name}, user_id={user_id}") + raise PermissionError(f"Access denied: You don't have permission to access this file ({object_name})") + + +def validate_urls_access(urls: List[str], user_id: Optional[str]) -> None: + """ + Validate if user has permission to access the given URLs. + + Supports S3 URLs (s3://bucket/key or /bucket/key format). + + Args: + urls: List of URLs to validate (S3, HTTP, or HTTPS) + user_id: Current user ID + + Raises: + PermissionError: If user doesn't have permission to access any of the files + """ + if not urls: + return + + from sdk.nexent.multi_modal.utils import parse_s3_url + + for url in urls: + if not url: + continue + + # Only validate S3 URLs (MinIO storage) + # HTTP/HTTPS URLs are external resources and are not subject to MinIO access control + if url.startswith("s3://"): + try: + _, object_name = parse_s3_url(url) + validate_s3_url_access(object_name, user_id) + except ValueError as e: + logger.warning(f"[validate_urls_access] Failed to parse S3 URL: {url}, error: {e}") + raise PermissionError(f"Invalid S3 URL format: {url}") + elif url.startswith("/") and not url.startswith("//"): + # Handle /bucket/key format (absolute path style) + parts = url.strip("/").split("/", 1) + if len(parts) == 2: + bucket, object_name = parts + validate_s3_url_access(object_name, user_id) + + +async def upload_files_impl(destination: str, file: List[UploadFile], folder: str = None, index_name: Optional[str] = None, user_id: Optional[str] = None) -> tuple: """ Upload files to local storage or MinIO based on destination. 
@@ -58,6 +168,8 @@ async def upload_files_impl(destination: str, file: List[UploadFile], folder: st destination: "local" or "minio" file: List of UploadFile objects folder: Folder name for MinIO uploads + index_name: Knowledge base index for conflict resolution + user_id: User ID for attachment path isolation Returns: tuple: (errors, uploaded_file_paths, uploaded_filenames) @@ -84,7 +196,20 @@ async def upload_files_impl(destination: str, file: List[UploadFile], folder: st errors.append(f"Failed to save file: {f.filename}") elif destination == "minio": - minio_results = await upload_to_minio(files=file, folder=folder) + # Determine actual folder path based on file type + # knowledge_base: accessible by all authenticated users + # other folders (attachments): user-isolated path (attachments/{user_id}/...) + if folder == "knowledge_base": + actual_folder = "knowledge_base" + else: + # User isolation for personal attachments + if user_id: + actual_folder = f"attachments/{user_id}" + else: + # Fallback to old behavior if no user_id provided + actual_folder = folder or "attachments" + + minio_results = await upload_to_minio(files=file, folder=actual_folder) for result in minio_results: if result.get("success"): uploaded_filenames.append(result.get("file_name")) @@ -137,8 +262,18 @@ def make_unique_names(original_names: List[str], taken_lower: set) -> List[str]: return errors, uploaded_file_paths, uploaded_filenames -async def upload_to_minio(files: List[UploadFile], folder: str) -> List[dict]: - """Helper function to upload files to MinIO and return results.""" +async def upload_to_minio(files: List[UploadFile], folder: str, user_id: Optional[str] = None) -> List[dict]: + """ + Helper function to upload files to MinIO and return results. 
+ + Args: + files: List of files to upload + folder: Storage folder path (will be prefixed with user_id if user_id is provided for attachments) + user_id: User ID for attachment path isolation + + Returns: + List of upload results + """ results = [] for f in files: try: @@ -148,11 +283,22 @@ async def upload_to_minio(files: List[UploadFile], folder: str) -> List[dict]: # Convert file content to BytesIO object file_obj = BytesIO(file_content) + # Determine actual folder path + # knowledge_base: no user isolation + # other folders: append user_id to path for isolation + if folder == "knowledge_base": + actual_folder = "knowledge_base" + else: + if user_id: + actual_folder = f"attachments/{user_id}" + else: + actual_folder = folder or "attachments" + # Upload file result = upload_fileobj( file_obj=file_obj, file_name=f.filename or "", - prefix=folder + prefix=actual_folder ) # Reset file pointer for potential re-reading diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index e3a4cfa4f..f36902edb 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -35,7 +35,7 @@ update_tool_table_from_scan_tool_list, ) from mcpadapt.smolagents_adapter import _sanitize_function_name -from services.file_management_service import get_llm_model +from services.file_management_service import get_llm_model, validate_urls_access from services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core from database.client import minio_client from services.image_service import get_vlm_model @@ -740,7 +740,8 @@ def _validate_local_tool( params = { **instantiation_params, 'vlm_model': image_to_text_model, - 'storage_client': minio_client + 'storage_client': minio_client, + 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) } tool_instance = tool_class(**params) elif tool_name == "analyze_text_file": @@ -752,7 +753,8 @@ def 
_validate_local_tool( **instantiation_params, 'llm_model': long_text_to_text_model, 'storage_client': minio_client, - "data_process_service_url": DATA_PROCESS_SERVICE + "data_process_service_url": DATA_PROCESS_SERVICE, + 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) } tool_instance = tool_class(**params) else: diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 6ba851a02..3674e05a6 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -94,15 +94,25 @@ def create_local_tool(self, tool_config: ToolConfig): tools_obj.rerank_model = tool_config.metadata.get( "rerank_model", None) if tool_config.metadata else None elif class_name == "AnalyzeTextFileTool": + # Extract validate_url_access from metadata if it's callable + validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None + if validate_url_access is not None and not callable(validate_url_access): + validate_url_access = None tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), storage_client=tool_config.metadata.get("storage_client", []), data_process_service_url=tool_config.metadata.get("data_process_service_url", []), + validate_url_access=validate_url_access, **params) elif class_name == "AnalyzeImageTool": + # Extract validate_url_access from metadata if it's callable + validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None + if validate_url_access is not None and not callable(validate_url_access): + validate_url_access = None tools_obj = tool_class(observer=self.observer, vlm_model=tool_config.metadata.get("vlm_model", []), storage_client=tool_config.metadata.get("storage_client", []), + validate_url_access=validate_url_access, **params) else: tools_obj = tool_class(**params) @@ -225,7 +235,7 @@ def create_single_agent(self, agent_config: AgentConfig): 
try: # Create internal managed agents recursively managed_agents_list = [ - self.create_single_agent(sub_agent_config) + self.create_single_agent(sub_agent_config) for sub_agent_config in agent_config.managed_agents ] except Exception as e: diff --git a/sdk/nexent/core/tools/analyze_image_tool.py b/sdk/nexent/core/tools/analyze_image_tool.py index 84adeb484..3851a896b 100644 --- a/sdk/nexent/core/tools/analyze_image_tool.py +++ b/sdk/nexent/core/tools/analyze_image_tool.py @@ -58,6 +58,9 @@ class AnalyzeImageTool(Tool): }, "storage_client": { "description": "Storage client for downloading files" + }, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" } } output_type = "array" @@ -77,6 +80,10 @@ def __init__( storage_client: MinIOStorageClient = Field( description="Storage client for downloading files from S3 URLs、HTTP URLs、HTTPS URLs.", default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, exclude=True) ): super().__init__() @@ -87,8 +94,15 @@ def __init__( # Determine if the language is Chinese for internationalization self._is_chinese = bool(observer and observer.lang == "zh") - # Create LoadSaveObjectManager with the storage client - self.mm = LoadSaveObjectManager(storage_client=self.storage_client) + # Create LoadSaveObjectManager with the storage client and validation callback + # Ensure validate_url_access is callable before passing to LoadSaveObjectManager + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) # Dynamically apply the load_object decorator to forward method self.forward = self.mm.load_object( diff --git a/sdk/nexent/core/tools/analyze_text_file_tool.py 
b/sdk/nexent/core/tools/analyze_text_file_tool.py index faba2153d..49b9a10ca 100644 --- a/sdk/nexent/core/tools/analyze_text_file_tool.py +++ b/sdk/nexent/core/tools/analyze_text_file_tool.py @@ -56,6 +56,9 @@ class AnalyzeTextFileTool(Tool): }, "llm_model": { "description": "The LLM model to use" + }, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" } } output_type = "array" @@ -81,6 +84,10 @@ def __init__( llm_model: str = Field( description="The LLM model to use", default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, exclude=True) ): super().__init__() @@ -88,7 +95,16 @@ def __init__( self.observer = observer self.llm_model = llm_model self.data_process_service_url = data_process_service_url - self.mm = LoadSaveObjectManager(storage_client=self.storage_client) + + # Create LoadSaveObjectManager with the storage client and validation callback + # Ensure validate_url_access is callable before passing to LoadSaveObjectManager + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) self.time_out = 60 * 5 self.running_prompt_zh = "正在分析文件..." 
diff --git a/sdk/nexent/multi_modal/load_save_object.py b/sdk/nexent/multi_modal/load_save_object.py index 4bc391036..929ea571e 100644 --- a/sdk/nexent/multi_modal/load_save_object.py +++ b/sdk/nexent/multi_modal/load_save_object.py @@ -2,7 +2,7 @@ import inspect import logging from io import BytesIO -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional import requests from .utils import ( @@ -20,14 +20,24 @@ class LoadSaveObjectManager: """ Provide load/save decorators that operate on a specific storage client. - + The manager can be instantiated with a storage client and exposes decorator factories for `load_object` and `save_object`. A default module-level manager is also provided for backwards compatibility with existing helper functions. """ - def __init__(self, storage_client: Any): + def __init__(self, storage_client: Any, validate_url_access: callable = None): + """ + Initialize LoadSaveObjectManager. + + Args: + storage_client: Storage client for S3 operations + validate_url_access: Optional callback function to validate URL access permissions. + The callback receives a list of URLs and should raise + PermissionError if access is denied. 
+ """ self._storage_client = storage_client + self._validate_url_access = validate_url_access def _get_client(self) -> Any: """ @@ -122,6 +132,11 @@ def load_object( def decorator(func: Callable): @functools.wraps(func) def wrapper(*args, **kwargs): + # Find the tool instance (self) from bound args + tool_instance = None + if args: + tool_instance = args[0] + def _transform_single_value(param_name: str, value: Any, transformer: Optional[Callable[[bytes], Any]]) -> Any: if isinstance(value, str): @@ -167,6 +182,31 @@ def _process_value(param_name: str, value: Any, bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() + # Collect all URLs to validate before downloading + all_urls_to_validate: List[str] = [] + for i, param_name in enumerate(input_names): + if param_name not in bound_args.arguments: + continue + + original_data = bound_args.arguments[param_name] + if original_data is None: + continue + + if isinstance(original_data, (list, tuple)): + all_urls_to_validate.extend([url for url in original_data if isinstance(url, str) and is_url(url)]) + elif isinstance(original_data, str) and is_url(original_data): + all_urls_to_validate.append(original_data) + + # Validate URL access before downloading any files + if all_urls_to_validate and self._validate_url_access is not None and callable(self._validate_url_access): + try: + self._validate_url_access(all_urls_to_validate) + except PermissionError: + raise + except Exception as e: + logger.error(f"[load_object] URL validation failed: {e}") + raise PermissionError(f"URL access validation failed: {e}") + for i, param_name in enumerate(input_names): if param_name not in bound_args.arguments: continue @@ -293,4 +333,4 @@ def wrapper(*args, **kwargs): return wrapper - return decorator \ No newline at end of file + return decorator diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index a0183d59e..ff2655e19 100644 --- 
a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -115,6 +115,7 @@ def _create_stub_module(name: str, **attrs): sys.modules['services.file_management_service'] = _create_stub_module( "services.file_management_service", get_llm_model=MagicMock(return_value="stub_llm_model"), + validate_urls_access=MagicMock(), ) sys.modules['services.tool_configuration_service'] = _create_stub_module( "services.tool_configuration_service", @@ -650,10 +651,11 @@ async def test_create_tool_config_list_with_analyze_image_tool(self): assert len(result) == 1 assert result[0] is mock_tool_instance mock_get_vlm_model.assert_called_once_with(tenant_id="tenant_1") - assert mock_tool_instance.metadata == { - "vlm_model": "mock_vlm_model", - "storage_client": mock_minio_client - } + # Verify metadata includes validate_url_access lambda + assert "vlm_model" in mock_tool_instance.metadata + assert "storage_client" in mock_tool_instance.metadata + assert "validate_url_access" in mock_tool_instance.metadata + assert callable(mock_tool_instance.metadata["validate_url_access"]) @pytest.mark.asyncio async def test_create_tool_config_list_with_analyze_text_file_tool(self): @@ -686,11 +688,12 @@ async def test_create_tool_config_list_with_analyze_text_file_tool(self): assert len(result) == 1 assert result[0] is mock_tool_instance mock_get_llm_model.assert_called_once_with(tenant_id="tenant_1") - assert mock_tool_instance.metadata == { - "llm_model": "mock_llm_model", - "storage_client": mock_minio_client, - "data_process_service_url": consts_const.DATA_PROCESS_SERVICE, - } + # Verify metadata includes validate_url_access lambda + assert "llm_model" in mock_tool_instance.metadata + assert "storage_client" in mock_tool_instance.metadata + assert "data_process_service_url" in mock_tool_instance.metadata + assert "validate_url_access" in mock_tool_instance.metadata + assert callable(mock_tool_instance.metadata["validate_url_access"]) @pytest.mark.asyncio 
async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self): @@ -1181,6 +1184,94 @@ async def test_create_tool_config_list_with_datamate_tool_no_rerank(self): assert len(result) == 1 assert result[0] is mock_tool_instance + @pytest.mark.asyncio + async def test_create_tool_config_list_analyze_image_tool_validate_url_access(self): + """ + Test that AnalyzeImageTool receives validate_url_access callback that + properly calls validate_urls_access with user_id. + """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "AnalyzeImageTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vlm_model') as mock_get_vlm_model, \ + patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock), \ + patch('backend.agents.create_agent_info.validate_urls_access') as mock_validate: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "AnalyzeImageTool", + "name": "analyze_image", + "description": "Analyze image tool", + "inputs": "string", + "output_type": "string", + "params": [], + "source": "local", + "usage": None + } + ] + mock_get_vlm_model.return_value = "mock_vlm_model" + + result = await create_tool_config_list("agent_1", "tenant_1", "user_123") + + assert len(result) == 1 + assert "validate_url_access" in result[0].metadata + assert callable(result[0].metadata["validate_url_access"]) + + # Test that the callback properly wraps validate_urls_access + mock_validate.reset_mock() + test_urls = ["s3://bucket/image.jpg"] + result[0].metadata["validate_url_access"](test_urls) + mock_validate.assert_called_once_with(test_urls, "user_123") + + @pytest.mark.asyncio + async def 
test_create_tool_config_list_analyze_text_file_tool_validate_url_access(self): + """ + Test that AnalyzeTextFileTool receives validate_url_access callback that + properly calls validate_urls_access with user_id. + """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "AnalyzeTextFileTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_llm_model') as mock_get_llm_model, \ + patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock), \ + patch('backend.agents.create_agent_info.validate_urls_access') as mock_validate: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "AnalyzeTextFileTool", + "name": "analyze_text_file", + "description": "Analyze text file tool", + "inputs": "array", + "output_type": "array", + "params": [], + "source": "local", + "usage": None + } + ] + mock_get_llm_model.return_value = "mock_llm_model" + + result = await create_tool_config_list("agent_1", "tenant_1", "user_456") + + assert len(result) == 1 + assert "validate_url_access" in result[0].metadata + assert callable(result[0].metadata["validate_url_access"]) + + # Test that the callback properly wraps validate_urls_access + mock_validate.reset_mock() + test_urls = ["s3://bucket/document.pdf"] + result[0].metadata["validate_url_access"](test_urls) + mock_validate.assert_called_once_with(test_urls, "user_456") + class TestCreateAgentConfig: """Tests for the create_agent_config function""" diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index 1a192db62..fc33db8fb 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ 
-7,7 +7,7 @@ import sys import types -from typing import Any, AsyncGenerator, List +from typing import Any, AsyncGenerator, Dict, List import pytest from unittest.mock import AsyncMock, MagicMock @@ -32,10 +32,10 @@ sfms_stub = types.ModuleType("services.file_management_service") -async def _stub_upload_to_minio(files, folder): +async def _stub_upload_to_minio(files, folder, user_id=None): return [] -async def _stub_upload_files_impl(destination, file, folder, index_name): +async def _stub_upload_files_impl(destination, file, folder, index_name, user_id=None): return [], [], [] async def _stub_get_file_url_impl(object_name: str, expires: int): @@ -48,9 +48,25 @@ async def _stub_delete_file_impl(object_name: str): return {"success": True} async def _stub_list_files_impl(prefix: str, limit: int | None = None): - files = [{"name": "a.txt", "url": "http://u"}] + files = [{"name": "a.txt", "url": "http://u", "key": "knowledge_base/a.txt"}] return files[:limit] if limit else files +def _stub_check_file_access(object_name: str, user_id: str) -> bool: + """Stub for check_file_access - allows access by default for testing.""" + if object_name.startswith("attachments/"): + # attachments/{user_id}/*: only owner can access + if user_id: + expected_prefix = f"attachments/{user_id}" + return object_name.startswith(expected_prefix) + return False + # knowledge_base/*: all authenticated users can access + return object_name.startswith("knowledge_base/") + + +def _stub_check_file_access_batch(object_names: List[str], user_id: str) -> Dict[str, bool]: + """Stub for check_file_access_batch - returns dict of object_name -> allowed.""" + return {name: _stub_check_file_access(name, user_id) for name in object_names} + async def _stub_preprocess_files_generator(*_: Any, **__: Any) -> AsyncGenerator[str, None]: yield "data: {\"type\": \"progress\", \"progress\": 0}\n\n" yield "data: {\"type\": \"complete\", \"progress\": 100}\n\n" @@ -72,19 +88,27 @@ def 
_stub_get_preview_stream(actual_object_name, start=None, end=None): sfms_stub.delete_file_impl = _stub_delete_file_impl sfms_stub.list_files_impl = _stub_list_files_impl sfms_stub.preprocess_files_generator = _stub_preprocess_files_generator +sfms_stub.check_file_access = _stub_check_file_access +sfms_stub.check_file_access_batch = _stub_check_file_access_batch sys.modules["services.file_management_service"] = sfms_stub setattr(services_pkg, "file_management_service", sfms_stub) -# Stub utils.auth_utils.get_current_user_info +# Stub utils.auth_utils.get_current_user_id (the function actually used in the app) utils_pkg = types.ModuleType("utils") utils_pkg.__path__ = [] sys.modules.setdefault("utils", utils_pkg) auth_utils_stub = types.ModuleType("utils.auth_utils") -def _stub_get_current_user_info(authorization, request): - return ("user1", "tenant1", "en") -auth_utils_stub.get_current_user_info = _stub_get_current_user_info + +def _stub_get_current_user_id(authorization): + """Stub for get_current_user_id - returns user_id and tenant_id tuple.""" + if authorization is None or (isinstance(authorization, str) and not authorization.strip()): + # Return None user_id when no auth (simulates real behavior in speed mode disabled) + return (None, "tenant1") + return ("user1", "tenant1") + +auth_utils_stub.get_current_user_id = _stub_get_current_user_id sys.modules["utils.auth_utils"] = auth_utils_stub setattr(utils_pkg, "auth_utils", auth_utils_stub) @@ -143,6 +167,11 @@ def make_upload_file(filename: str, content: bytes = b"data"): return f +# Mock authorization header for tests +MOCK_AUTH = "Bearer mock_token" +MOCK_AUTH_NONE = None + + # --- Tests --- @pytest.mark.asyncio @@ -154,13 +183,14 @@ async def test_options_route_ok(): @pytest.mark.asyncio async def test_upload_files_success(monkeypatch): - async def fake_upload_impl(dest, files, folder, index_name): + async def fake_upload_impl(dest, files, folder, index_name, user_id=None): return [], ["/abs/path1"], 
["a.txt"] monkeypatch.setattr(file_management_app, "upload_files_impl", fake_upload_impl) result = await file_management_app.upload_files( - file=[make_upload_file("a.txt")], destination="local", folder="attachments", index_name=None + file=[make_upload_file("a.txt")], destination="local", folder="attachments", index_name=None, + authorization=MOCK_AUTH ) assert result.status_code == 200 content = result.body.decode() @@ -171,23 +201,42 @@ async def fake_upload_impl(dest, files, folder, index_name): @pytest.mark.asyncio async def test_upload_files_no_files_bad_request(): with pytest.raises(Exception) as ei: - await file_management_app.upload_files(file=[], destination="local", folder="attachments", index_name=None) + await file_management_app.upload_files( + file=[], destination="local", folder="attachments", index_name=None, + authorization=MOCK_AUTH + ) assert "No files in the request" in str(ei.value) @pytest.mark.asyncio async def test_upload_files_no_valid_files_uploaded(monkeypatch): - async def fake_upload_impl(dest, files, folder, index_name): + async def fake_upload_impl(dest, files, folder, index_name, user_id=None): return ["err"], [], [] monkeypatch.setattr(file_management_app, "upload_files_impl", fake_upload_impl) with pytest.raises(Exception) as ei: await file_management_app.upload_files( - file=[make_upload_file("x.txt")], destination="minio", folder="attachments", index_name=None + file=[make_upload_file("x.txt")], destination="minio", folder="attachments", index_name=None, + authorization=MOCK_AUTH ) assert "No valid files uploaded" in str(ei.value) +@pytest.mark.asyncio +async def test_upload_files_internal_error(monkeypatch): + """Test upload_files with internal error returns 500.""" + async def fake_upload_impl(dest, files, folder, index_name, user_id=None): + raise RuntimeError("Storage failed") + + monkeypatch.setattr(file_management_app, "upload_files_impl", fake_upload_impl) + with pytest.raises(Exception) as ei: + await 
file_management_app.upload_files( + file=[make_upload_file("a.txt")], destination="local", folder="attachments", index_name=None, + authorization=MOCK_AUTH + ) + assert "File upload error" in str(ei.value) + + @pytest.mark.asyncio async def test_process_files_success(monkeypatch): async def fake_trigger(files, params): @@ -239,9 +288,78 @@ async def fake_trigger(files, params): assert "boom" in str(ei.value) +# --- storage_upload_files tests --- + +@pytest.mark.asyncio +async def test_storage_upload_files_knowledge_base_folder(monkeypatch): + """Test storage_upload_files with knowledge_base folder (shared, no user isolation).""" + async def fake_upload(files, folder, user_id=None): + return [{"success": True, "file_name": "shared.pdf", "key": f"{folder}/shared.pdf"}] + + monkeypatch.setattr(file_management_app, "upload_to_minio", fake_upload) + + f1 = make_upload_file("shared.pdf") + result = await file_management_app.storage_upload_files( + files=[f1], + folder="knowledge_base", + authorization=MOCK_AUTH + ) + assert result["message"].startswith("Processed 1") + assert result["success_count"] == 1 + assert result["failed_count"] == 0 + + +@pytest.mark.asyncio +async def test_storage_upload_files_attachments_folder_user_isolation(monkeypatch): + """Test storage_upload_files with attachments folder uses user_id for isolation.""" + captured_params = {} + + async def fake_upload(files, folder, user_id=None): + captured_params["folder"] = folder + captured_params["user_id"] = user_id + return [{"success": True, "file_name": "private.txt"}] + + monkeypatch.setattr(file_management_app, "upload_to_minio", fake_upload) + + f1 = make_upload_file("private.txt") + result = await file_management_app.storage_upload_files( + files=[f1], + folder="attachments", + authorization=MOCK_AUTH + ) + # Folder should be prefixed with user_id + assert captured_params["folder"] == "attachments/user1" + assert captured_params["user_id"] == "user1" + assert result["success_count"] == 1 + + 
+@pytest.mark.asyncio +async def test_storage_upload_files_attachments_no_auth_uses_raw_folder(monkeypatch): + """Test storage_upload_files without auth uses raw folder name.""" + captured_params = {} + + async def fake_upload(files, folder, user_id=None): + captured_params["folder"] = folder + captured_params["user_id"] = user_id + return [{"success": True, "file_name": "test.txt"}] + + monkeypatch.setattr(file_management_app, "upload_to_minio", fake_upload) + + f1 = make_upload_file("test.txt") + result = await file_management_app.storage_upload_files( + files=[f1], + folder="attachments", + authorization=MOCK_AUTH_NONE + ) + # Without user_id, folder should be raw value + assert captured_params["folder"] == "attachments" + assert captured_params["user_id"] is None + assert result["success_count"] == 1 + + @pytest.mark.asyncio async def test_storage_upload_files_counts(monkeypatch): - async def fake_upload(files, folder): + async def fake_upload(files, folder, user_id=None): return [ {"success": True, "file_name": "a.txt"}, {"success": False, "file_name": "b.txt", "error": "x"}, @@ -250,29 +368,104 @@ async def fake_upload(files, folder): monkeypatch.setattr(file_management_app, "upload_to_minio", fake_upload) f1 = make_upload_file("a.txt") f2 = make_upload_file("b.txt") - result = await file_management_app.storage_upload_files(files=[f1, f2], folder="attachments") + result = await file_management_app.storage_upload_files( + files=[f1, f2], + folder="attachments", + authorization=MOCK_AUTH + ) assert result["message"].startswith("Processed 2") assert result["success_count"] == 1 assert result["failed_count"] == 1 assert len(result["results"]) == 2 +@pytest.mark.asyncio +async def test_storage_upload_files_internal_error(monkeypatch): + """Test storage_upload_files with internal error returns 500.""" + async def fake_upload(files, folder, user_id=None): + raise RuntimeError("MinIO connection failed") + + monkeypatch.setattr(file_management_app, "upload_to_minio", 
fake_upload) + f1 = make_upload_file("a.txt") + + with pytest.raises(Exception) as ei: + await file_management_app.storage_upload_files( + files=[f1], + folder="attachments", + authorization=MOCK_AUTH + ) + assert "Storage upload error" in str(ei.value) + + +# --- get_storage_files tests --- + @pytest.mark.asyncio async def test_get_storage_files_include_and_strip_urls(monkeypatch): async def fake_list(prefix, limit): - return [{"name": "a", "url": "http://u"}, {"name": "b"}] + return [ + {"name": "a", "url": "http://u", "key": "knowledge_base/a.txt"}, + {"name": "b", "key": "attachments/user1/b.txt"} + ] monkeypatch.setattr(file_management_app, "list_files_impl", fake_list) # include URLs - out1 = await file_management_app.get_storage_files(prefix="", limit=10, include_urls=True) + out1 = await file_management_app.get_storage_files( + prefix="", limit=10, include_urls=True, authorization=MOCK_AUTH + ) assert out1["total"] == 2 assert out1["files"][0]["url"] == "http://u" # strip URLs - out2 = await file_management_app.get_storage_files(prefix="", limit=10, include_urls=False) + out2 = await file_management_app.get_storage_files( + prefix="", limit=10, include_urls=False, authorization=MOCK_AUTH + ) assert out2["total"] == 2 assert "url" not in out2["files"][0] +@pytest.mark.asyncio +async def test_get_storage_files_with_user_id_filters_by_access(monkeypatch): + """Test that get_storage_files filters files based on user access control.""" + async def fake_list(prefix, limit): + return [ + {"name": "a", "key": "knowledge_base/shared.txt"}, + {"name": "b", "key": "attachments/user1/mine.txt"}, + {"name": "c", "key": "attachments/user2/theirs.txt"}, # Should be filtered out + {"name": "d", "key": "attachments/another_user/private.txt"}, # Should be filtered out + ] + + monkeypatch.setattr(file_management_app, "list_files_impl", fake_list) + + out = await file_management_app.get_storage_files( + prefix="", limit=10, include_urls=False, authorization=MOCK_AUTH + ) + # 
user1 can access knowledge_base and attachments/user1 + keys = [f["key"] for f in out["files"]] + assert "knowledge_base/shared.txt" in keys + assert "attachments/user1/mine.txt" in keys + assert "attachments/user2/theirs.txt" not in keys + assert "attachments/another_user/private.txt" not in keys + + +@pytest.mark.asyncio +async def test_get_storage_files_no_auth_only_knowledge_base(monkeypatch): + """Test that unauthenticated requests only see knowledge_base files.""" + async def fake_list(prefix, limit): + return [ + {"name": "a", "key": "knowledge_base/shared.txt"}, + {"name": "b", "key": "attachments/user1/mine.txt"}, + ] + + monkeypatch.setattr(file_management_app, "list_files_impl", fake_list) + + out = await file_management_app.get_storage_files( + prefix="", limit=10, include_urls=False, authorization=MOCK_AUTH_NONE + ) + # Without auth, only knowledge_base files should be visible + keys = [f["key"] for f in out["files"]] + assert "knowledge_base/shared.txt" in keys + assert "attachments/user1/mine.txt" not in keys + + @pytest.mark.asyncio async def test_get_storage_files_error(monkeypatch): async def boom(prefix, limit): @@ -280,17 +473,27 @@ async def boom(prefix, limit): monkeypatch.setattr(file_management_app, "list_files_impl", boom) with pytest.raises(Exception) as ei: - await file_management_app.get_storage_files(prefix="p", limit=1, include_urls=True) - assert "Failed to get file list" in str(ei.value) + await file_management_app.get_storage_files( + prefix="p", limit=1, include_urls=True, authorization=MOCK_AUTH + ) + assert "Failed to get file list" in str(ei.value) or "Get storage files error" in str(ei.value) +# --- get_storage_file tests --- + @pytest.mark.asyncio async def test_get_storage_file_redirect(monkeypatch): async def fake_get_url(object_name, expires): return {"success": True, "url": "http://example.com/a"} monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get_url) - resp = await 
file_management_app.get_storage_file(object_name="a.txt", download="redirect", expires=60, filename="a.txt") + resp = await file_management_app.get_storage_file( + object_name="knowledge_base/a.txt", + download="redirect", + expires=60, + filename="a.txt", + authorization=MOCK_AUTH + ) # Starlette RedirectResponse defaults to 307 assert 300 <= resp.status_code < 400 assert resp.headers["location"] == "http://example.com/a" @@ -304,7 +507,13 @@ async def gen(): return gen(), "text/plain" monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) - resp = await file_management_app.get_storage_file(object_name="a.txt", download="stream", expires=60, filename="a.txt") + resp = await file_management_app.get_storage_file( + object_name="attachments/user1/a.txt", + download="stream", + expires=60, + filename="a.txt", + authorization=MOCK_AUTH + ) assert resp.headers["content-type"].startswith("text/plain") assert resp.media_type == "text/plain" # Content-Disposition should be "attachment" not "inline", and filename should be extracted from object_name @@ -331,10 +540,11 @@ def read(self): monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) resp = await file_management_app.get_storage_file( - object_name="attachments/img.png", + object_name="attachments/user1/img.png", download="base64", expires=60, filename=None, + authorization=MOCK_AUTH ) assert resp.status_code == 200 @@ -357,21 +567,29 @@ def read(self): with pytest.raises(Exception) as exc_info: await file_management_app.get_storage_file( - object_name="attachments/img.png", + object_name="attachments/user1/img.png", download="base64", expires=60, filename=None, + authorization=MOCK_AUTH ) assert "Failed to read file content for base64 encoding" in str(exc_info.value) + @pytest.mark.asyncio async def test_get_storage_file_metadata(monkeypatch): async def fake_get_url(object_name, expires): return {"success": True, "url": "http://example.com/x"} 
monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get_url) - result = await file_management_app.get_storage_file(object_name="x", download="ignore", expires=10, filename="x.txt") + result = await file_management_app.get_storage_file( + object_name="knowledge_base/x", + download="ignore", + expires=10, + filename="x.txt", + authorization=MOCK_AUTH + ) assert result["url"] == "http://example.com/x" @@ -382,20 +600,90 @@ async def boom_url(object_name, expires): monkeypatch.setattr(file_management_app, "get_file_url_impl", boom_url) with pytest.raises(Exception) as ei: - await file_management_app.get_storage_file(object_name="x", download="ignore", expires=1, filename="x.txt") - assert "Failed to get file information" in str(ei.value) + await file_management_app.get_storage_file( + object_name="knowledge_base/x", + download="ignore", + expires=1, + filename="x.txt", + authorization=MOCK_AUTH + ) + assert "Failed to get file information" in str(ei.value) or "Failed to get file" in str(ei.value) + + +@pytest.mark.asyncio +async def test_get_storage_file_access_denied_for_attachments(monkeypatch): + """Test that access to other user's attachments is forbidden.""" + def fake_check_access(object_name, user_id): + if object_name.startswith("attachments/"): + expected_prefix = f"attachments/{user_id}" + return object_name.startswith(expected_prefix) + return object_name.startswith("knowledge_base/") + + monkeypatch.setattr(file_management_app, "check_file_access", fake_check_access) + + with pytest.raises(Exception) as ei: + await file_management_app.get_storage_file( + object_name="attachments/other_user/file.txt", + download="ignore", + expires=60, + filename="file.txt", + authorization=MOCK_AUTH + ) + assert "permission" in str(ei.value).lower() or "forbidden" in str(ei.value).lower() +@pytest.mark.asyncio +async def test_get_storage_file_allows_knowledge_base_access(monkeypatch): + """Test that knowledge_base files are accessible to all authenticated 
users.""" + async def fake_get_url(object_name, expires): + return {"success": True, "url": "http://example.com/shared"} + + monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get_url) + + result = await file_management_app.get_storage_file( + object_name="knowledge_base/shared.pdf", + download="redirect", + expires=60, + filename="shared.pdf", + authorization=MOCK_AUTH + ) + assert result.headers["location"] == "http://example.com/shared" + + +# --- remove_storage_file tests --- + @pytest.mark.asyncio async def test_remove_storage_file_success(monkeypatch): async def ok_delete(object_name): return {"success": True} monkeypatch.setattr(file_management_app, "delete_file_impl", ok_delete) - result = await file_management_app.remove_storage_file(object_name="x") + result = await file_management_app.remove_storage_file( + object_name="attachments/user1/x", + authorization=MOCK_AUTH + ) assert result["success"] is True +@pytest.mark.asyncio +async def test_remove_storage_file_access_denied(monkeypatch): + """Test that deletion of other user's file is forbidden.""" + def fake_check_access(object_name, user_id): + if object_name.startswith("attachments/"): + expected_prefix = f"attachments/{user_id}" + return object_name.startswith(expected_prefix) + return object_name.startswith("knowledge_base/") + + monkeypatch.setattr(file_management_app, "check_file_access", fake_check_access) + + with pytest.raises(Exception) as ei: + await file_management_app.remove_storage_file( + object_name="attachments/other_user/file.txt", + authorization=MOCK_AUTH + ) + assert "permission" in str(ei.value).lower() or "forbidden" in str(ei.value).lower() + + @pytest.mark.asyncio async def test_remove_storage_file_error(monkeypatch): async def boom_delete(object_name): @@ -403,14 +691,21 @@ async def boom_delete(object_name): monkeypatch.setattr(file_management_app, "delete_file_impl", boom_delete) with pytest.raises(Exception) as ei: - await 
file_management_app.remove_storage_file(object_name="x") - assert "Failed to delete file" in str(ei.value) + await file_management_app.remove_storage_file( + object_name="attachments/user1/x", + authorization=MOCK_AUTH + ) + assert "Failed to delete file" in str(ei.value) or "Remove storage file error" in str(ei.value) + +# --- get_storage_file_batch_urls tests --- @pytest.mark.asyncio async def test_get_storage_file_batch_urls_validation_error(): with pytest.raises(Exception) as ei: - await file_management_app.get_storage_file_batch_urls(request_data={}, expires=10) + await file_management_app.get_storage_file_batch_urls( + request_data={}, expires=10, authorization=MOCK_AUTH + ) assert "object_names" in str(ei.value) @@ -418,17 +713,64 @@ async def test_get_storage_file_batch_urls_validation_error(): async def test_get_storage_file_batch_urls_mixed(monkeypatch): def fake_get(object_name, expires): # Synchronous stub to match non-awaited usage in implementation - if object_name == "ok": + if object_name == "knowledge_base/ok.txt": return {"success": True, "url": "http://u"} raise RuntimeError("bad") monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get) out = await file_management_app.get_storage_file_batch_urls( - request_data={"object_names": ["ok", "bad"]}, expires=5 + request_data={"object_names": ["knowledge_base/ok.txt", "knowledge_base/bad.txt"]}, expires=5, authorization=MOCK_AUTH ) assert out["total"] == 2 assert out["success_count"] == 1 - assert any(item["object_name"] == "bad" and item["success"] is False for item in out["results"]) + assert any(item["object_name"] == "knowledge_base/bad.txt" and item["success"] is False for item in out["results"]) + + +@pytest.mark.asyncio +async def test_get_storage_file_batch_urls_all_denied(monkeypatch): + """Test batch URLs when all files are denied access.""" + def fake_check_access(object_name, user_id): + return False # Deny all access + + def fake_get(object_name, expires): + return 
{"success": True, "url": "http://u"} + + monkeypatch.setattr(file_management_app, "check_file_access", fake_check_access) + monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get) + + out = await file_management_app.get_storage_file_batch_urls( + request_data={"object_names": ["knowledge_base/file1.txt", "knowledge_base/file2.txt"]}, + expires=5, + authorization=MOCK_AUTH + ) + assert out["total"] == 2 + assert out["success_count"] == 0 + assert out["failed_count"] == 2 + assert all(item["success"] is False and item["error"] == "Access denied" for item in out["results"]) + + +@pytest.mark.asyncio +async def test_get_storage_file_batch_urls_error(monkeypatch): + """Test batch URLs with internal error returns error in results, not exception.""" + def fake_check_access(object_name, user_id): + return True + + def fake_get(object_name, expires): + raise RuntimeError("Internal error") + + monkeypatch.setattr(file_management_app, "check_file_access", fake_check_access) + monkeypatch.setattr(file_management_app, "get_file_url_impl", fake_get) + + out = await file_management_app.get_storage_file_batch_urls( + request_data={"object_names": ["knowledge_base/file1.txt"]}, + expires=5, + authorization=MOCK_AUTH + ) + # Error should be captured in results, not raised + assert out["total"] == 1 + assert out["success_count"] == 0 + assert out["failed_count"] == 1 + assert "Internal error" in out["results"][0]["error"] # --- Tests for build_content_disposition_header --- @@ -501,6 +843,31 @@ def boom(_value: str, safe: str = "") -> str: assert 'attachment' not in result +def test_build_content_disposition_header_empty_filename(): + """Test build_content_disposition_header with empty/None filename""" + result = file_management_app.build_content_disposition_header(None) + assert 'attachment; filename="download"' in result + + +def test_build_content_disposition_header_sanitizes_control_chars(): + """Test that control characters are removed from filename""" + result = 
file_management_app.build_content_disposition_header("test\x00file.pdf") + assert 'testfile.pdf' in result + + +def test_build_content_disposition_header_sanitizes_backslash(): + """Test that backslash is replaced with underscore""" + result = file_management_app.build_content_disposition_header("test\\file.pdf") + assert '_' in result + assert '\\' not in result + + +def test_build_content_disposition_header_sanitizes_leading_dots(): + """Test that leading dots are removed (Windows restriction)""" + result = file_management_app.build_content_disposition_header(".hidden.pdf") + assert '.hidden.pdf' not in result or result == 'attachment; filename="hidden.pdf"' + + # --- Tests for get_storage_file with filename parameter --- @pytest.mark.asyncio @@ -513,10 +880,11 @@ async def gen(): monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) resp = await file_management_app.get_storage_file( - object_name="attachments/file.pdf", - download="stream", + object_name="attachments/user1/file.pdf", + download="stream", expires=60, - filename="原始文件名.pdf" + filename="原始文件名.pdf", + authorization=MOCK_AUTH ) assert resp.media_type == "application/pdf" content_disposition = resp.headers.get("content-disposition", "") @@ -533,10 +901,11 @@ async def gen(): monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) resp = await file_management_app.get_storage_file( - object_name="attachments/test.txt", - download="stream", + object_name="attachments/user1/test.txt", + download="stream", expires=60, - filename=None + filename=None, + authorization=MOCK_AUTH ) assert resp.media_type == "text/plain" content_disposition = resp.headers.get("content-disposition", "") @@ -552,12 +921,13 @@ async def fake_get_stream(object_name): monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) with pytest.raises(Exception) as ei: await file_management_app.get_storage_file( - object_name="test.txt", - download="stream", + 
object_name="attachments/user1/test.txt", + download="stream", expires=60, - filename="test.txt" + filename="test.txt", + authorization=MOCK_AUTH ) - assert "Failed to get file information" in str(ei.value) + assert "Failed to get file information" in str(ei.value) or "Failed to get file" in str(ei.value) # --- Tests for download_datamate_file --- @@ -577,7 +947,7 @@ async def test_download_datamate_file_with_url(monkeypatch): mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + resp = await file_management_app.download_datamate_file( url="http://example.com/api/data-management/datasets/123/files/456/download", base_url=None, @@ -606,7 +976,7 @@ async def test_download_datamate_file_with_parts(monkeypatch): mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + resp = await file_management_app.download_datamate_file( url=None, base_url="http://example.com", @@ -632,7 +1002,7 @@ async def test_download_datamate_file_404_error(monkeypatch): mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + with pytest.raises(Exception) as ei: await file_management_app.download_datamate_file( url="http://example.com/api/data-management/datasets/123/files/456/download", @@ -649,14 +1019,14 @@ async def test_download_datamate_file_404_error(monkeypatch): async def test_download_datamate_file_http_error(monkeypatch): """Test download_datamate_file with HTTP error""" import httpx - + mock_client = MagicMock() mock_client.get = AsyncMock(side_effect=httpx.HTTPError("Network error")) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + with pytest.raises(Exception) as ei: await file_management_app.download_datamate_file( 
url="http://example.com/api/data-management/datasets/123/files/456/download", @@ -699,7 +1069,7 @@ async def test_download_datamate_file_extract_filename_from_content_disposition( mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + resp = await file_management_app.download_datamate_file( url="http://example.com/api/data-management/datasets/123/files/456/download", base_url=None, @@ -727,7 +1097,7 @@ async def test_download_datamate_file_extract_filename_from_url(monkeypatch): mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + resp = await file_management_app.download_datamate_file( url="http://example.com/api/data-management/datasets/123/files/456/download", base_url=None, @@ -760,7 +1130,7 @@ async def fake_httpx_get(url, headers=None, follow_redirects=True): mock_client.__aexit__ = AsyncMock(return_value=None) monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) - + await file_management_app.download_datamate_file( url="http://example.com/api/data-management/datasets/123/files/456/download", base_url=None, @@ -798,6 +1168,28 @@ def fail_normalize(_url: str): assert "Failed to download file: boom" in str(exc.value) +@pytest.mark.asyncio +async def test_download_datamate_file_internal_error(monkeypatch): + """Test download_datamate_file with internal unexpected error.""" + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + + with pytest.raises(Exception) as exc: + await file_management_app.download_datamate_file( + url="http://example.com/api/data-management/datasets/123/files/456/download", + base_url=None, + dataset_id=None, + file_id=None, + 
filename=None, + authorization=None, + ) + assert "Failed to download file" in str(exc.value) + + # --- Tests for _normalize_datamate_download_url --- def test_normalize_datamate_download_url_valid(): @@ -808,7 +1200,7 @@ def test_normalize_datamate_download_url_valid(): def test_normalize_datamate_download_url_adds_scheme(): - """URLs without scheme should default to https://""" + """URLs without scheme should default to http://""" url = "example.com/api/data-management/datasets/123/files/456/download" result = file_management_app._normalize_datamate_download_url(url) assert result.startswith("http://example.com") @@ -848,7 +1240,7 @@ def test_build_datamate_url_from_parts_with_api(): def test_build_datamate_url_from_parts_without_scheme(): - """base_url without scheme should default to https://""" + """base_url without scheme should default to http://""" result = file_management_app._build_datamate_url_from_parts( "example.com", "123", @@ -929,6 +1321,28 @@ def test_build_datamate_url_from_parts_empty_base_url(): assert "base_url is required" in str(ei.value) +# --- Tests for _ensure_http_scheme --- + +def test_ensure_http_scheme_empty(): + """Test _ensure_http_scheme with empty URL raises error""" + with pytest.raises(Exception) as ei: + file_management_app._ensure_http_scheme("") + assert "URL cannot be empty" in str(ei.value) + + +def test_ensure_http_scheme_invalid_scheme(): + """Test _ensure_http_scheme with invalid scheme raises error""" + with pytest.raises(Exception) as ei: + file_management_app._ensure_http_scheme("ftp://example.com/file") + assert "http:// or https://" in str(ei.value) + + +def test_ensure_http_scheme_double_slash(): + """Test _ensure_http_scheme with // prefix""" + result = file_management_app._ensure_http_scheme("//example.com/file") + assert result.startswith("http://") + + # --- Tests for preview_file endpoint --- def _make_mock_stream(content: bytes = b"content"): @@ -944,14 +1358,15 @@ async def 
test_preview_file_pdf_success(monkeypatch): """PDF file: 200 response with inline disposition, Accept-Ranges, ETag.""" mock_stream = _make_mock_stream(b"PDF content") monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("documents/test.pdf", "application/pdf", 2048))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 2048))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=mock_stream)) resp = await file_management_app.preview_file( - object_name="documents/test.pdf", + object_name="knowledge_base/test.pdf", filename="test.pdf", range_header=None, + authorization=MOCK_AUTH ) assert resp.media_type == "application/pdf" @@ -962,7 +1377,7 @@ async def test_preview_file_pdf_success(monkeypatch): assert resp.headers.get("accept-ranges") == "bytes" assert resp.headers.get("content-length") == "2048" assert resp.headers.get("cache-control") == "public, max-age=3600" - assert "documents/test.pdf" in resp.headers.get("etag", "") + assert "knowledge_base/test.pdf" in resp.headers.get("etag", "") assert resp.background is not None await resp.background() mock_stream.close.assert_called_once() @@ -972,14 +1387,15 @@ async def test_preview_file_pdf_success(monkeypatch): async def test_preview_file_image_success(monkeypatch): """Image file: 200 response with correct content type.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("images/photo.png", "image/png", 512))) + AsyncMock(return_value=("knowledge_base/photo.png", "image/png", 512))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream(b"PNG data"))) resp = await file_management_app.preview_file( - object_name="images/photo.png", + object_name="knowledge_base/photo.png", filename="photo.png", range_header=None, + authorization=MOCK_AUTH ) assert resp.media_type == "image/png" @@ -990,14 +1406,15 @@ async def 
test_preview_file_image_success(monkeypatch): async def test_preview_file_text_success(monkeypatch): """Text file: 200 response with correct content type.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("files/readme.txt", "text/plain", 128))) + AsyncMock(return_value=("knowledge_base/readme.txt", "text/plain", 128))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream(b"Hello World"))) resp = await file_management_app.preview_file( - object_name="files/readme.txt", + object_name="knowledge_base/readme.txt", filename="readme.txt", range_header=None, + authorization=MOCK_AUTH ) assert resp.media_type == "text/plain" @@ -1008,14 +1425,15 @@ async def test_preview_file_text_success(monkeypatch): async def test_preview_file_without_filename_extracts_from_path(monkeypatch): """No filename parameter: extracts name from the last path segment.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("folder/subfolder/document.pdf", "application/pdf", 1024))) + AsyncMock(return_value=("knowledge_base/subfolder/document.pdf", "application/pdf", 1024))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream())) resp = await file_management_app.preview_file( - object_name="folder/subfolder/document.pdf", + object_name="knowledge_base/subfolder/document.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "document.pdf" in resp.headers.get("content-disposition", "") @@ -1025,14 +1443,15 @@ async def test_preview_file_without_filename_extracts_from_path(monkeypatch): async def test_preview_file_chinese_filename(monkeypatch): """Chinese filename: RFC 5987 UTF-8 encoded in Content-Disposition.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("documents/test.pdf", "application/pdf", 1024))) + 
AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 1024))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream())) resp = await file_management_app.preview_file( - object_name="documents/test.pdf", + object_name="knowledge_base/test.pdf", filename="测试文档.pdf", range_header=None, + authorization=MOCK_AUTH ) cd = resp.headers.get("content-disposition", "") @@ -1044,14 +1463,15 @@ async def test_preview_file_chinese_filename(monkeypatch): async def test_preview_file_simple_object_name_without_slash(monkeypatch): """Object name without slash: uses it directly as display filename.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("simple.pdf", "application/pdf", 256))) + AsyncMock(return_value=("knowledge_base/simple.pdf", "application/pdf", 256))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream())) resp = await file_management_app.preview_file( - object_name="simple.pdf", + object_name="knowledge_base/simple.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "simple.pdf" in resp.headers.get("content-disposition", "") @@ -1061,20 +1481,61 @@ async def test_preview_file_simple_object_name_without_slash(monkeypatch): async def test_preview_file_office_converted_to_pdf(monkeypatch): """Office document: resolve returns PDF path; response is application/pdf.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("preview/converted/report_abc.pdf", "application/pdf", 8192))) + AsyncMock(return_value=("knowledge_base/converted/report_abc.pdf", "application/pdf", 8192))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream(b"Converted PDF"))) resp = await file_management_app.preview_file( - object_name="documents/report.docx", + object_name="knowledge_base/report.docx", filename="report.docx", 
range_header=None, + authorization=MOCK_AUTH ) assert resp.media_type == "application/pdf" assert "inline" in resp.headers.get("content-disposition", "") +@pytest.mark.asyncio +async def test_preview_file_access_denied(monkeypatch): + """Test preview_file access denied for other user's attachments.""" + def fake_check_access(object_name, user_id): + if object_name.startswith("attachments/"): + expected_prefix = f"attachments/{user_id}" + return object_name.startswith(expected_prefix) + return object_name.startswith("knowledge_base/") + + monkeypatch.setattr(file_management_app, "check_file_access", fake_check_access) + + with pytest.raises(Exception) as ei: + await file_management_app.preview_file( + object_name="attachments/other_user/file.pdf", + filename=None, + range_header=None, + authorization=MOCK_AUTH + ) + assert "permission" in str(ei.value).lower() or "forbidden" in str(ei.value).lower() + + +@pytest.mark.asyncio +async def test_preview_file_allows_knowledge_base(monkeypatch): + """Test preview_file allows knowledge_base files.""" + monkeypatch.setattr(file_management_app, "resolve_preview_file", + AsyncMock(return_value=("knowledge_base/shared.pdf", "application/pdf", 1024))) + monkeypatch.setattr(file_management_app, "get_preview_stream", + MagicMock(return_value=_make_mock_stream())) + + resp = await file_management_app.preview_file( + object_name="knowledge_base/shared.pdf", + filename=None, + range_header=None, + authorization=MOCK_AUTH + ) + + assert resp.status_code == 200 + assert resp.media_type == "application/pdf" + + # --- Range request tests --- @pytest.mark.asyncio @@ -1082,14 +1543,15 @@ async def test_preview_file_range_request_returns_206(monkeypatch): """Valid Range header: 206 with Content-Range and correct Content-Length.""" mock_stream = _make_mock_stream(b"partial chunk") monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 10000))) + 
AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 10000))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=mock_stream)) resp = await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header="bytes=0-4095", + authorization=MOCK_AUTH ) assert resp.status_code == 206 @@ -1105,14 +1567,15 @@ async def test_preview_file_range_request_returns_206(monkeypatch): async def test_preview_file_range_suffix_form(monkeypatch): """Suffix range (bytes=-N): 206 with correct Content-Range.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 10000))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 10000))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream(b"tail chunk"))) resp = await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header="bytes=-500", + authorization=MOCK_AUTH ) assert resp.status_code == 206 @@ -1124,14 +1587,15 @@ async def test_preview_file_range_suffix_form(monkeypatch): async def test_preview_file_range_open_ended(monkeypatch): """Open-ended range (bytes=N-): 206 reaching end of file.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 1000))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 1000))) monkeypatch.setattr(file_management_app, "get_preview_stream", MagicMock(return_value=_make_mock_stream(b"tail"))) resp = await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header="bytes=500-", + authorization=MOCK_AUTH ) assert resp.status_code == 206 @@ -1144,13 +1608,14 @@ async def 
test_preview_file_empty_file_returns_200_without_stream(monkeypatch): """Empty file: return 200 with zero content length and no stream fetch.""" mock_get_stream = MagicMock() monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/empty.txt", "text/plain", 0))) + AsyncMock(return_value=("knowledge_base/empty.txt", "text/plain", 0))) monkeypatch.setattr(file_management_app, "get_preview_stream", mock_get_stream) resp = await file_management_app.preview_file( - object_name="docs/empty.txt", + object_name="knowledge_base/empty.txt", filename="empty.txt", range_header=None, + authorization=MOCK_AUTH ) assert resp.status_code == 200 @@ -1164,13 +1629,14 @@ async def test_preview_file_empty_file_ignores_range_and_returns_200(monkeypatch """Empty file with Range header: still return 200 empty response.""" mock_get_stream = MagicMock() monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/empty.txt", "text/plain", 0))) + AsyncMock(return_value=("knowledge_base/empty.txt", "text/plain", 0))) monkeypatch.setattr(file_management_app, "get_preview_stream", mock_get_stream) resp = await file_management_app.preview_file( - object_name="docs/empty.txt", + object_name="knowledge_base/empty.txt", filename="empty.txt", range_header="bytes=0-10", + authorization=MOCK_AUTH ) assert resp.status_code == 200 @@ -1182,12 +1648,13 @@ async def test_preview_file_empty_file_ignores_range_and_returns_200(monkeypatch async def test_preview_file_invalid_range_returns_416(monkeypatch): """Out-of-bounds Range: 416 with Content-Range: bytes */total.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 10000))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 10000))) resp = await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, 
range_header="bytes=20000-30000", + authorization=MOCK_AUTH ) assert resp.status_code == 416 @@ -1198,12 +1665,13 @@ async def test_preview_file_invalid_range_returns_416(monkeypatch): async def test_preview_file_malformed_range_returns_416(monkeypatch): """Malformed Range header: 416.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 1000))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 1000))) resp = await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header="invalid-range", + authorization=MOCK_AUTH ) assert resp.status_code == 416 @@ -1213,7 +1681,7 @@ async def test_preview_file_malformed_range_returns_416(monkeypatch): @pytest.mark.asyncio async def test_preview_file_too_large_error(monkeypatch): - """FileTooLargeException from resolve_preview_file → HTTP 413.""" + """FileTooLargeException from resolve_preview_file -> HTTP 413.""" _FileTooLargeException = sys.modules["consts.exceptions"].FileTooLargeException async def fake_resolve(object_name): @@ -1223,16 +1691,17 @@ async def fake_resolve(object_name): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="files/huge.pdf", + object_name="knowledge_base/huge.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "100 MB" in str(ei.value) @pytest.mark.asyncio async def test_preview_file_not_found_from_resolve(monkeypatch): - """NotFoundException from resolve_preview_file → HTTP 404.""" + """NotFoundException from resolve_preview_file -> HTTP 404.""" _NotFoundException = sys.modules["consts.exceptions"].NotFoundException async def fake_resolve(object_name): @@ -1242,20 +1711,21 @@ async def fake_resolve(object_name): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="missing/file.pdf", + 
object_name="knowledge_base/missing/file.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "File not found" in str(ei.value) @pytest.mark.asyncio async def test_preview_file_not_found_from_stream(monkeypatch): - """NotFoundException from get_preview_stream → HTTP 404.""" + """NotFoundException from get_preview_stream -> HTTP 404.""" not_found_exception = sys.modules["consts.exceptions"].NotFoundException monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 1024))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 1024))) def fake_stream(actual_name, start=None, end=None): raise not_found_exception("File not found during streaming") @@ -1264,9 +1734,10 @@ def fake_stream(actual_name, start=None, end=None): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "File not found" in str(ei.value) @@ -1275,7 +1746,7 @@ def fake_stream(actual_name, start=None, end=None): async def test_preview_file_unexpected_error_from_stream(monkeypatch): """Unexpected exception from get_preview_stream should map to HTTP 500.""" monkeypatch.setattr(file_management_app, "resolve_preview_file", - AsyncMock(return_value=("docs/test.pdf", "application/pdf", 1024))) + AsyncMock(return_value=("knowledge_base/test.pdf", "application/pdf", 1024))) def fake_stream(actual_name, start=None, end=None): raise RuntimeError("stream broken") @@ -1284,16 +1755,17 @@ def fake_stream(actual_name, start=None, end=None): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="docs/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "Failed to preview file" in str(ei.value) @pytest.mark.asyncio async def 
test_preview_file_unsupported_format_error(monkeypatch): - """UnsupportedFileTypeException from resolve_preview_file → HTTP 400.""" + """UnsupportedFileTypeException from resolve_preview_file -> HTTP 400.""" _UnsupportedFileTypeException = sys.modules["consts.exceptions"].UnsupportedFileTypeException async def fake_resolve(object_name): @@ -1303,16 +1775,17 @@ async def fake_resolve(object_name): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="files/archive.zip", + object_name="knowledge_base/archive.zip", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "not supported for preview" in str(ei.value) @pytest.mark.asyncio async def test_preview_file_internal_error(monkeypatch): - """Unexpected exception from resolve_preview_file → HTTP 500.""" + """Unexpected exception from resolve_preview_file -> HTTP 500.""" async def fake_resolve(object_name): raise Exception("Internal server error") @@ -1320,9 +1793,10 @@ async def fake_resolve(object_name): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="files/test.pdf", + object_name="knowledge_base/test.pdf", filename=None, range_header=None, + authorization=MOCK_AUTH ) assert "Failed to preview file" in str(ei.value) assert "Internal server error" not in str(ei.value) @@ -1330,7 +1804,7 @@ async def fake_resolve(object_name): @pytest.mark.asyncio async def test_preview_file_office_conversion_error(monkeypatch): - """OfficeConversionException (subclass of Exception) → HTTP 500.""" + """OfficeConversionException (subclass of Exception) -> HTTP 500.""" _OfficeConversionException = sys.modules["consts.exceptions"].OfficeConversionException async def fake_resolve(object_name): @@ -1340,9 +1814,10 @@ async def fake_resolve(object_name): with pytest.raises(Exception) as ei: await file_management_app.preview_file( - object_name="files/report.docx", + object_name="knowledge_base/report.docx", filename=None, range_header=None, + 
authorization=MOCK_AUTH ) assert "Failed to preview file" in str(ei.value) @@ -1407,3 +1882,7 @@ def test_missing_dash_returns_none(self): def test_zero_size_file_returns_none(self): """Empty files do not support satisfiable ranges.""" assert file_management_app._parse_range_header("bytes=0-10", 0) is None + + def test_negative_start_returns_none(self): + """Negative start values are invalid.""" + assert file_management_app._parse_range_header("bytes=-10-20", 1000) is None diff --git a/test/backend/services/test_file_management_service.py b/test/backend/services/test_file_management_service.py index cc54e804f..e5bb34f17 100644 --- a/test/backend/services/test_file_management_service.py +++ b/test/backend/services/test_file_management_service.py @@ -717,6 +717,410 @@ async def test_list_files_impl_with_limit(self): mock_list.assert_called_once_with(prefix="folder/") +class TestCheckFileAccess: + """Test cases for check_file_access function""" + + def test_check_file_access_no_user_id_returns_false(self): + """Access denied when user_id is None or empty""" + from backend.services.file_management_service import check_file_access + + assert check_file_access("knowledge_base/file.txt", None) is False + assert check_file_access("attachments/user123/file.txt", "") is False + assert check_file_access("any/path.txt", None) is False + + def test_check_file_access_knowledge_base_allows_access(self): + """All authenticated users can access knowledge_base files""" + from backend.services.file_management_service import check_file_access + + assert check_file_access("knowledge_base/file.txt", "user123") is True + assert check_file_access("knowledge_base/subfolder/doc.pdf", "user456") is True + assert check_file_access("knowledge_base/", "any_user") is True + + def test_check_file_access_user_attachment_allows_owner(self): + """Users can access files in their own attachments folder""" + from backend.services.file_management_service import check_file_access + + assert 
check_file_access("attachments/user123/file.txt", "user123") is True + assert check_file_access("attachments/user123/subfolder/doc.pdf", "user123") is True + + def test_check_file_access_user_attachment_denies_others(self): + """Users cannot access files in other users' attachments folders""" + from backend.services.file_management_service import check_file_access + + assert check_file_access("attachments/user123/file.txt", "user456") is False + assert check_file_access("attachments/other/file.txt", "user123") is False + + def test_check_file_access_backward_compatibility_root_attachments(self): + """Old format attachments/filename (no subdirectory) allows access for backward compatibility""" + from backend.services.file_management_service import check_file_access + + assert check_file_access("attachments/file.txt", "any_user") is True + assert check_file_access("attachments/document.pdf", "any_user") is True + + def test_check_file_access_deep_attachments_denies_non_matching_user(self): + """Deeply nested attachments/other/user/file paths should deny non-matching users""" + from backend.services.file_management_service import check_file_access + + # Pattern: attachments/{user_id}/{filename} where user_id matches + assert check_file_access("attachments/user123/document.docx", "user123") is True + # Pattern: attachments/otheruser/{filename} - user123 is neither "otheruser" nor matching + assert check_file_access("attachments/otheruser/document.docx", "user123") is False + + def test_check_file_access_denies_arbitrary_paths(self): + """Arbitrary paths outside knowledge_base and attachments are denied""" + from backend.services.file_management_service import check_file_access + + assert check_file_access("private/file.txt", "user123") is False + assert check_file_access("system/config.json", "user123") is False + assert check_file_access("preview/file.pdf", "user123") is False + + +class TestCheckFileAccessBatch: + """Test cases for check_file_access_batch function""" 
+ + def test_check_file_access_batch_empty_list(self): + """Empty list returns empty dict""" + from backend.services.file_management_service import check_file_access_batch + + result = check_file_access_batch([], "user123") + assert result == {} + + def test_check_file_access_batch_mixed_permissions(self): + """Batch returns dict with correct permissions for each object""" + from backend.services.file_management_service import check_file_access_batch + + object_names = [ + "knowledge_base/file.txt", + "attachments/user123/doc.pdf", + "attachments/other/doc.pdf", + "private/file.txt" + ] + result = check_file_access_batch(object_names, "user123") + + assert result["knowledge_base/file.txt"] is True + assert result["attachments/user123/doc.pdf"] is True + assert result["attachments/other/doc.pdf"] is False + assert result["private/file.txt"] is False + + +class TestValidateS3UrlAccess: + """Test cases for validate_s3_url_access function""" + + def test_validate_s3_url_access_no_user_id_raises_permission_error(self): + """PermissionError raised when user_id is None or empty""" + from backend.services.file_management_service import validate_s3_url_access + + with pytest.raises(PermissionError) as exc_info: + validate_s3_url_access("knowledge_base/file.txt", None) + assert "User authentication required" in str(exc_info.value) + + with pytest.raises(PermissionError) as exc_info: + validate_s3_url_access("knowledge_base/file.txt", "") + assert "User authentication required" in str(exc_info.value) + + def test_validate_s3_url_access_valid_access_no_exception(self): + """No exception raised when user has valid access""" + from backend.services.file_management_service import validate_s3_url_access + + # Should not raise + validate_s3_url_access("knowledge_base/file.txt", "user123") + validate_s3_url_access("attachments/user123/file.txt", "user123") + + def test_validate_s3_url_access_invalid_access_raises_permission_error(self): + """PermissionError raised when user doesn't 
have access""" + from backend.services.file_management_service import validate_s3_url_access + + with pytest.raises(PermissionError) as exc_info: + validate_s3_url_access("attachments/other/file.txt", "user123") + assert "Access denied" in str(exc_info.value) + assert "you don't have permission" in str(exc_info.value).lower() + + +class TestValidateUrlsAccess: + """Test cases for validate_urls_access function""" + + def test_validate_urls_access_empty_list_no_exception(self): + """Empty list returns without exception""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise + validate_urls_access([], "user123") + + def test_validate_urls_access_none_urls_skipped(self): + """None or empty strings in list are skipped""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise + validate_urls_access([None, "", "knowledge_base/file.txt"], "user123") + + def test_validate_urls_access_http_https_urls_not_validated(self): + """HTTP/HTTPS URLs are external resources and not subject to MinIO access control""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise even for inaccessible-looking URLs + validate_urls_access([ + "https://example.com/file.pdf", + "http://other.com/doc.docx" + ], "user123") + + def test_validate_urls_access_s3_url_valid_access_no_exception(self): + """S3 URL with valid access doesn't raise""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise + validate_urls_access(["s3://bucket/knowledge_base/file.txt"], "user123") + + def test_validate_urls_access_s3_url_invalid_access_raises(self): + """S3 URL with invalid access raises PermissionError""" + from backend.services.file_management_service import validate_urls_access + + with pytest.raises(PermissionError) as exc_info: + validate_urls_access(["s3://bucket/attachments/other/file.txt"], "user123") + assert "Access denied" in 
str(exc_info.value) + + def test_validate_urls_access_invalid_s3_url_format_raises(self): + """Invalid S3 URL format raises PermissionError""" + from backend.services.file_management_service import validate_urls_access + + # Missing bucket/key format + with pytest.raises(PermissionError) as exc_info: + validate_urls_access(["s3://only-bucket"], "user123") + assert "Invalid S3 URL format" in str(exc_info.value) + + def test_validate_urls_access_bucket_key_format_valid(self): + """Path-style URL /bucket/key format with valid access doesn't raise""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise + validate_urls_access(["/bucket/knowledge_base/file.txt"], "user123") + + def test_validate_urls_access_bucket_key_format_invalid_access(self): + """Path-style URL /bucket/key format with invalid access raises""" + from backend.services.file_management_service import validate_urls_access + + with pytest.raises(PermissionError) as exc_info: + validate_urls_access(["/bucket/attachments/other/file.txt"], "user123") + assert "Access denied" in str(exc_info.value) + + def test_validate_urls_access_bucket_key_format_trailing_slash(self): + """Path-style URL with only bucket (no key) is skipped or handled gracefully""" + from backend.services.file_management_service import validate_urls_access + + # Single slash bucket - no key + validate_urls_access(["//bucket"], "user123") # Starts with // + + def test_validate_urls_access_mixed_s3_and_external(self): + """Mixed S3 and external URLs - S3 URLs are validated, others skipped""" + from backend.services.file_management_service import validate_urls_access + + # Should not raise - S3 URL is valid, HTTPS is external + validate_urls_access([ + "https://external.com/file.pdf", + "s3://bucket/knowledge_base/file.txt" + ], "user123") + + # Should raise - S3 URL is invalid + with pytest.raises(PermissionError): + validate_urls_access([ + "https://external.com/file.pdf", + 
"s3://bucket/attachments/other/file.txt" + ], "user123") + + +class TestUploadFilesImplMinioFolderLogic: + """Test cases for MinIO folder logic in upload_files_impl (lines 199-212)""" + + @pytest.mark.asyncio + async def test_upload_files_impl_minio_knowledge_base_folder(self): + """When folder is 'knowledge_base', uses 'knowledge_base' without user isolation""" + from backend.services.file_management_service import upload_files_impl + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=[ + {"success": True, "file_name": "test.txt", "object_name": "knowledge_base/test.txt"} + ])) as mock_upload: + errors, uploaded_paths, uploaded_names = await upload_files_impl( + destination="minio", file=[mock_file], folder="knowledge_base", user_id="user123") + + assert errors == [] + # Verify knowledge_base was passed without user_id prefix + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["folder"] == "knowledge_base" + + @pytest.mark.asyncio + async def test_upload_files_impl_minio_user_isolation_with_user_id(self): + """When folder is not knowledge_base and user_id provided, uses attachments/{user_id}""" + from backend.services.file_management_service import upload_files_impl + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=[ + {"success": True, "file_name": "test.txt", "object_name": "attachments/user123/test.txt"} + ])) as mock_upload: + errors, uploaded_paths, uploaded_names = await upload_files_impl( + destination="minio", file=[mock_file], folder="documents", user_id="user123") + + assert errors == [] + # Verify user_id was used to 
construct attachments/{user_id} + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["folder"] == "attachments/user123" + + @pytest.mark.asyncio + async def test_upload_files_impl_minio_fallback_without_user_id(self): + """When folder is not knowledge_base and no user_id, falls back to folder or 'attachments'""" + from backend.services.file_management_service import upload_files_impl + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + # With folder provided but no user_id + with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=[ + {"success": True, "file_name": "test.txt", "object_name": "custom_folder/test.txt"} + ])) as mock_upload: + errors, uploaded_paths, uploaded_names = await upload_files_impl( + destination="minio", file=[mock_file], folder="custom_folder", user_id=None) + + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["folder"] == "custom_folder" + + @pytest.mark.asyncio + async def test_upload_files_impl_minio_fallback_none_folder(self): + """When folder is None and no user_id, falls back to 'attachments'""" + from backend.services.file_management_service import upload_files_impl + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=[ + {"success": True, "file_name": "test.txt", "object_name": "attachments/test.txt"} + ])) as mock_upload: + errors, uploaded_paths, uploaded_names = await upload_files_impl( + destination="minio", file=[mock_file], folder=None, user_id=None) + + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["folder"] == "attachments" + + +class 
TestUploadToMinioFolderLogic: + """Test cases for MinIO folder logic in upload_to_minio (lines 265-296)""" + + @pytest.mark.asyncio + async def test_upload_to_minio_knowledge_base_folder(self): + """When folder is 'knowledge_base', uses 'knowledge_base' without user isolation""" + from backend.services.file_management_service import upload_to_minio + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ + "success": True, "file_name": "test.txt", "object_name": "knowledge_base/test.txt" + })) as mock_upload: + results = await upload_to_minio(files=[mock_file], folder="knowledge_base", user_id="user123") + + assert len(results) == 1 + assert results[0]["success"] is True + # Verify knowledge_base was passed without user_id prefix + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["prefix"] == "knowledge_base" + + @pytest.mark.asyncio + async def test_upload_to_minio_user_isolation_with_user_id(self): + """When folder is not knowledge_base and user_id provided, uses attachments/{user_id}""" + from backend.services.file_management_service import upload_to_minio + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ + "success": True, "file_name": "test.txt", "object_name": "attachments/user456/test.txt" + })) as mock_upload: + results = await upload_to_minio(files=[mock_file], folder="documents", user_id="user456") + + assert len(results) == 1 + assert results[0]["success"] is True + # Verify user_id was used to construct attachments/{user_id} + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert 
call_kwargs["prefix"] == "attachments/user456" + + @pytest.mark.asyncio + async def test_upload_to_minio_fallback_without_user_id(self): + """When folder is not knowledge_base and no user_id, uses folder as-is""" + from backend.services.file_management_service import upload_to_minio + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ + "success": True, "file_name": "test.txt", "object_name": "my_folder/test.txt" + })) as mock_upload: + results = await upload_to_minio(files=[mock_file], folder="my_folder", user_id=None) + + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["prefix"] == "my_folder" + + @pytest.mark.asyncio + async def test_upload_to_minio_fallback_none_folder(self): + """When folder is None and no user_id, falls back to 'attachments'""" + from backend.services.file_management_service import upload_to_minio + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") + mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ + "success": True, "file_name": "test.txt", "object_name": "attachments/test.txt" + })) as mock_upload: + results = await upload_to_minio(files=[mock_file], folder=None, user_id=None) + + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["prefix"] == "attachments" + + @pytest.mark.asyncio + async def test_upload_to_minio_attachments_folder_with_user_id(self): + """Attachments folder with user_id uses attachments/{user_id} path""" + from backend.services.file_management_service import upload_to_minio + + mock_file = MagicMock() + mock_file.filename = "test.txt" + mock_file.read = AsyncMock(return_value=b"test content") 
+ mock_file.seek = AsyncMock() + + with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ + "success": True, "file_name": "test.txt", "object_name": "attachments/abc123/test.txt" + })) as mock_upload: + results = await upload_to_minio(files=[mock_file], folder="attachments", user_id="abc123") + + mock_upload.assert_called_once() + call_kwargs = mock_upload.call_args[1] + assert call_kwargs["prefix"] == "attachments/abc123" + + class TestEdgeCasesAndErrorHandling: """Test cases for edge cases and error handling scenarios""" @@ -860,7 +1264,7 @@ async def test_upload_files_impl_no_semaphore_for_minio(self): @pytest.mark.asyncio async def test_upload_to_minio_with_none_folder(self): - """Test upload_to_minio with None folder""" + """Test upload_to_minio with None folder falls back to 'attachments'""" # Create mock UploadFile mock_file = MagicMock() mock_file.filename = "test.txt" @@ -868,23 +1272,23 @@ async def test_upload_to_minio_with_none_folder(self): mock_file.seek = AsyncMock() with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ - "success": True, "file_name": "test.txt", "object_name": "test.txt" + "success": True, "file_name": "test.txt", "object_name": "attachments/test.txt" })) as mock_upload: - # Execute with None folder - results = await upload_to_minio(files=[mock_file], folder=None) + # Execute with None folder - should fall back to 'attachments' + results = await upload_to_minio(files=[mock_file], folder=None, user_id=None) # Assertions assert len(results) == 1 assert results[0]["success"] is True assert results[0]["file_name"] == "test.txt" mock_upload.assert_called_once() - # Verify that None was passed as prefix + # Verify that 'attachments' was passed as prefix (fallback when folder is None) call_args = mock_upload.call_args - assert call_args[1]["prefix"] is None + assert call_args[1]["prefix"] == "attachments" @pytest.mark.asyncio async def 
test_upload_to_minio_with_empty_folder(self): - """Test upload_to_minio with empty folder string""" + """Test upload_to_minio with empty folder string falls back to 'attachments'""" # Create mock UploadFile mock_file = MagicMock() mock_file.filename = "test.txt" @@ -892,19 +1296,19 @@ async def test_upload_to_minio_with_empty_folder(self): mock_file.seek = AsyncMock() with patch('backend.services.file_management_service.upload_fileobj', MagicMock(return_value={ - "success": True, "file_name": "test.txt", "object_name": "test.txt" + "success": True, "file_name": "test.txt", "object_name": "attachments/test.txt" })) as mock_upload: - # Execute with empty folder - results = await upload_to_minio(files=[mock_file], folder="") + # Execute with empty folder - empty string is falsy, falls back to 'attachments' + results = await upload_to_minio(files=[mock_file], folder="", user_id=None) # Assertions assert len(results) == 1 assert results[0]["success"] is True assert results[0]["file_name"] == "test.txt" mock_upload.assert_called_once() - # Verify that empty string was passed as prefix + # Verify that 'attachments' was passed as prefix (fallback when folder is empty/falsy) call_args = mock_upload.call_args - assert call_args[1]["prefix"] == "" + assert call_args[1]["prefix"] == "attachments" class TestGetLlmModel: diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 7dedc9dba..8cc749d33 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -2476,11 +2476,12 @@ def test_validate_local_tool_analyze_image_success(self, mock_signature, mock_ge assert result == "analyze image result" mock_get_vlm_model.assert_called_once_with(tenant_id="tenant1") - mock_tool_class.assert_called_once_with( - prompt="describe", - vlm_model="mock_vlm_model", - storage_client=mock_minio_client - ) + mock_tool_class.assert_called_once() 
+ call_kwargs = mock_tool_class.call_args.kwargs + assert 'vlm_model' in call_kwargs + assert 'storage_client' in call_kwargs + assert 'validate_url_access' in call_kwargs + assert callable(call_kwargs['validate_url_access']) mock_tool_instance.forward.assert_called_once_with(image="bytes") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @@ -2791,13 +2792,14 @@ def test_validate_local_tool_analyze_text_file_success(self, mock_minio_client, mock_get_class.assert_called_once_with("analyze_text_file") # Verify analyze_text_file specific parameters were passed - expected_params = { - "param": "config", - "llm_model": mock_llm_model, - "storage_client": mock_minio_client, - "data_process_service_url": "http://data-process-service", - } - mock_tool_class.assert_called_once_with(**expected_params) + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args.kwargs + assert 'llm_model' in call_kwargs + assert 'storage_client' in call_kwargs + assert 'data_process_service_url' in call_kwargs + assert call_kwargs['data_process_service_url'] == "http://data-process-service" + assert 'validate_url_access' in call_kwargs + assert callable(call_kwargs['validate_url_access']) mock_tool_instance.forward.assert_called_once_with(input="test input") # Verify service calls diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 474fa8baa..76001dc6c 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -918,6 +918,7 @@ def test_create_local_tool_analyze_text_file_tool(nexent_agent_instance): llm_model="llm_model_obj", storage_client="storage_client_obj", data_process_service_url="DATA_PROCESS_SERVICE", + validate_url_access=None, prompt="describe this", ) assert result == mock_analyze_tool_instance @@ -958,46 +959,7 @@ def test_create_local_tool_analyze_image_tool(nexent_agent_instance): observer=nexent_agent_instance.observer, 
vlm_model="vlm_model_obj", storage_client="storage_client_obj", - prompt="describe this", - ) - assert result == mock_analyze_tool_instance - - -def test_create_local_tool_analyze_image_tool(nexent_agent_instance): - """Test AnalyzeImageTool creation injects observer and metadata.""" - mock_analyze_tool_class = MagicMock() - mock_analyze_tool_instance = MagicMock() - mock_analyze_tool_class.return_value = mock_analyze_tool_instance - - tool_config = ToolConfig( - class_name="AnalyzeImageTool", - name="analyze_image", - description="desc", - inputs="{}", - output_type="string", - params={"prompt": "describe this"}, - source="local", - metadata={ - "vlm_model": "vlm_model_obj", - "storage_client": "storage_client_obj", - }, - ) - - original_value = nexent_agent.__dict__.get("AnalyzeImageTool") - nexent_agent.__dict__["AnalyzeImageTool"] = mock_analyze_tool_class - - try: - result = nexent_agent_instance.create_local_tool(tool_config) - finally: - if original_value is not None: - nexent_agent.__dict__["AnalyzeImageTool"] = original_value - elif "AnalyzeImageTool" in nexent_agent.__dict__: - del nexent_agent.__dict__["AnalyzeImageTool"] - - mock_analyze_tool_class.assert_called_once_with( - observer=nexent_agent_instance.observer, - vlm_model="vlm_model_obj", - storage_client="storage_client_obj", + validate_url_access=None, prompt="describe this", ) assert result == mock_analyze_tool_instance @@ -2317,6 +2279,229 @@ def test_create_local_tool_analyze_image(self, nexent_agent_instance): assert call_kwargs["param1"] == "value1" assert result == mock_tool_instance + def test_create_local_tool_analyze_text_file_with_validate_url_access_none(self, nexent_agent_instance): + """Test AnalyzeTextFileTool creation with validate_url_access not in metadata (None).""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="AnalyzeTextFileTool", + name="analyze_text", + 
description="desc", + inputs="{}", + output_type="string", + params={"prompt": "describe this"}, + source="local", + metadata={ + "llm_model": ["gpt-4"], + "storage_client": "storage", + "data_process_service_url": "http://service.com" + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeTextFileTool") + nexent_agent.__dict__["AnalyzeTextFileTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeTextFileTool"] = original_value + elif "AnalyzeTextFileTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeTextFileTool"] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] is None + + def test_create_local_tool_analyze_text_file_with_validate_url_access_callable(self, nexent_agent_instance): + """Test AnalyzeTextFileTool creation with validate_url_access as callable.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + def mock_validate_func(url): + return True + + tool_config = ToolConfig( + class_name="AnalyzeTextFileTool", + name="analyze_text", + description="desc", + inputs="{}", + output_type="string", + params={"prompt": "describe this"}, + source="local", + metadata={ + "llm_model": ["gpt-4"], + "storage_client": "storage", + "data_process_service_url": "http://service.com", + "validate_url_access": mock_validate_func + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeTextFileTool") + nexent_agent.__dict__["AnalyzeTextFileTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeTextFileTool"] = original_value + elif "AnalyzeTextFileTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeTextFileTool"] + + 
mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] == mock_validate_func + + def test_create_local_tool_analyze_text_file_with_validate_url_access_not_callable(self, nexent_agent_instance): + """Test AnalyzeTextFileTool creation with non-callable validate_url_access (should be None).""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="AnalyzeTextFileTool", + name="analyze_text", + description="desc", + inputs="{}", + output_type="string", + params={"prompt": "describe this"}, + source="local", + metadata={ + "llm_model": ["gpt-4"], + "storage_client": "storage", + "data_process_service_url": "http://service.com", + "validate_url_access": "not_a_callable_string" + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeTextFileTool") + nexent_agent.__dict__["AnalyzeTextFileTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeTextFileTool"] = original_value + elif "AnalyzeTextFileTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeTextFileTool"] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] is None + + def test_create_local_tool_analyze_image_with_validate_url_access_none(self, nexent_agent_instance): + """Test AnalyzeImageTool creation with validate_url_access not in metadata (None).""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="AnalyzeImageTool", + name="analyze_image", + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1"}, + source="local", + metadata={ + "vlm_model": ["gpt-4-vision"], + 
"storage_client": "storage" + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeImageTool") + nexent_agent.__dict__["AnalyzeImageTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeImageTool"] = original_value + elif "AnalyzeImageTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeImageTool"] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] is None + + def test_create_local_tool_analyze_image_with_validate_url_access_callable(self, nexent_agent_instance): + """Test AnalyzeImageTool creation with validate_url_access as callable.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + def mock_validate_func(url): + return True + + tool_config = ToolConfig( + class_name="AnalyzeImageTool", + name="analyze_image", + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1"}, + source="local", + metadata={ + "vlm_model": ["gpt-4-vision"], + "storage_client": "storage", + "validate_url_access": mock_validate_func + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeImageTool") + nexent_agent.__dict__["AnalyzeImageTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeImageTool"] = original_value + elif "AnalyzeImageTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeImageTool"] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] == mock_validate_func + + def test_create_local_tool_analyze_image_with_validate_url_access_not_callable(self, nexent_agent_instance): + """Test AnalyzeImageTool creation with non-callable 
validate_url_access (should be None).""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="AnalyzeImageTool", + name="analyze_image", + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1"}, + source="local", + metadata={ + "vlm_model": ["gpt-4-vision"], + "storage_client": "storage", + "validate_url_access": 12345 + } + ) + + original_value = nexent_agent.__dict__.get("AnalyzeImageTool") + nexent_agent.__dict__["AnalyzeImageTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["AnalyzeImageTool"] = original_value + elif "AnalyzeImageTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["AnalyzeImageTool"] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["validate_url_access"] is None + class TestCreateLocalToolClassNotFound: """Tests for create_local_tool when class is not found.""" diff --git a/test/sdk/core/tools/test_analyze_image_tool.py b/test/sdk/core/tools/test_analyze_image_tool.py index c83f99fa0..a8598a8ad 100644 --- a/test/sdk/core/tools/test_analyze_image_tool.py +++ b/test/sdk/core/tools/test_analyze_image_tool.py @@ -239,7 +239,87 @@ def test_load_save_object_manager_created(self, mock_vlm_model, mock_storage_cli ) mock_manager_class.assert_called_once_with( - storage_client=mock_storage_client) + storage_client=mock_storage_client, + validate_url_access=None + ) + + def test_load_save_object_manager_with_validate_url_access_callable( + self, mock_vlm_model, mock_storage_client + ): + """Test that callable validate_url_access is passed to LoadSaveObjectManager.""" + with patch('sdk.nexent.core.tools.analyze_image_tool.LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + 
mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + validate_callback = MagicMock() + + tool = AnalyzeImageTool( + observer=MagicMock(), + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + validate_url_access=validate_callback, + ) + + mock_manager_class.assert_called_once_with( + storage_client=mock_storage_client, + validate_url_access=validate_callback + ) + + def test_load_save_object_manager_validate_url_access_not_callable( + self, mock_vlm_model, mock_storage_client + ): + """Test that non-callable validate_url_access is converted to None.""" + with patch('sdk.nexent.core.tools.analyze_image_tool.LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + tool = AnalyzeImageTool( + observer=MagicMock(), + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + validate_url_access="not_a_callable", + ) + + mock_manager_class.assert_called_once_with( + storage_client=mock_storage_client, + validate_url_access=None + ) + + def test_load_save_object_manager_validate_url_access_lambda( + self, mock_vlm_model, mock_storage_client + ): + """Test that lambda validate_url_access is passed to LoadSaveObjectManager.""" + with patch('sdk.nexent.core.tools.analyze_image_tool.LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + validate_callback = lambda url: True + + tool = AnalyzeImageTool( + observer=MagicMock(), + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + validate_url_access=validate_callback, + ) + + mock_manager_class.assert_called_once_with( + storage_client=mock_storage_client, + validate_url_access=validate_callback + ) + + def 
test_init_param_descriptions_has_validate_url_access(self, mock_vlm_model, mock_storage_client): + """Test that init_param_descriptions includes validate_url_access.""" + tool = AnalyzeImageTool( + observer=MagicMock(), + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + assert "validate_url_access" in tool.init_param_descriptions + assert "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" in tool.init_param_descriptions["validate_url_access"]["description"] def test_observer_add_message_called(self, tool, mock_vlm_model, mock_prompt_loader): """Test that observer.add_message is called with running prompt.""" diff --git a/test/sdk/core/tools/test_analyze_text_file_tool.py b/test/sdk/core/tools/test_analyze_text_file_tool.py index 03646387c..2b3461ec5 100644 --- a/test/sdk/core/tools/test_analyze_text_file_tool.py +++ b/test/sdk/core/tools/test_analyze_text_file_tool.py @@ -171,3 +171,112 @@ def test_analyze_file_defaults_to_english(self, tool, llm_model, monkeypatch): assert result == ("analysis", 0) mock_get_template.assert_called_once_with( template_type="analyze_file", language="en") + + +class TestAnalyzeTextFileToolValidateUrlAccess: + """Test cases for validate_url_access parameter in AnalyzeTextFileTool.""" + + def test_load_save_object_manager_created_with_validate_url_access_none( + self, observer_en, llm_model + ): + """Test that LoadSaveObjectManager is called with validate_url_access=None by default.""" + with patch.object(module, 'LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + tool = AnalyzeTextFileTool( + storage_client=MagicMock(), + observer=observer_en, + data_process_service_url="http://data-process", + llm_model=llm_model, + ) + + mock_manager_class.assert_called_once_with( + storage_client=tool.storage_client, + 
validate_url_access=None + ) + + def test_load_save_object_manager_with_validate_url_access_callable( + self, observer_en, llm_model + ): + """Test that callable validate_url_access is passed to LoadSaveObjectManager.""" + with patch.object(module, 'LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + validate_callback = MagicMock() + + tool = AnalyzeTextFileTool( + storage_client=MagicMock(), + observer=observer_en, + data_process_service_url="http://data-process", + llm_model=llm_model, + validate_url_access=validate_callback, + ) + + mock_manager_class.assert_called_once_with( + storage_client=tool.storage_client, + validate_url_access=validate_callback + ) + + def test_load_save_object_manager_validate_url_access_not_callable( + self, observer_en, llm_model + ): + """Test that non-callable validate_url_access is converted to None.""" + with patch.object(module, 'LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + tool = AnalyzeTextFileTool( + storage_client=MagicMock(), + observer=observer_en, + data_process_service_url="http://data-process", + llm_model=llm_model, + validate_url_access="not_a_callable", + ) + + mock_manager_class.assert_called_once_with( + storage_client=tool.storage_client, + validate_url_access=None + ) + + def test_load_save_object_manager_validate_url_access_lambda( + self, observer_en, llm_model + ): + """Test that lambda validate_url_access is passed to LoadSaveObjectManager.""" + with patch.object(module, 'LoadSaveObjectManager') as mock_manager_class: + mock_manager_instance = MagicMock() + mock_manager_class.return_value = mock_manager_instance + mock_manager_instance.load_object.return_value = lambda x: x + + 
validate_callback = lambda url: True + + tool = AnalyzeTextFileTool( + storage_client=MagicMock(), + observer=observer_en, + data_process_service_url="http://data-process", + llm_model=llm_model, + validate_url_access=validate_callback, + ) + + mock_manager_class.assert_called_once_with( + storage_client=tool.storage_client, + validate_url_access=validate_callback + ) + + def test_init_param_descriptions_has_validate_url_access( + self, observer_en, llm_model + ): + """Test that init_param_descriptions includes validate_url_access.""" + tool = AnalyzeTextFileTool( + storage_client=MagicMock(), + observer=observer_en, + data_process_service_url="http://data-process", + llm_model=llm_model, + ) + + assert "validate_url_access" in tool.init_param_descriptions + assert "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" in tool.init_param_descriptions["validate_url_access"]["description"] diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py index e65da1daf..92425791c 100644 --- a/test/sdk/multi_modal/test_load_save_object.py +++ b/test/sdk/multi_modal/test_load_save_object.py @@ -7,10 +7,10 @@ from sdk.nexent.multi_modal import load_save_object as lso -def make_manager(client: Any = None) -> lso.LoadSaveObjectManager: +def make_manager(client: Any = None, validate_url_access: Any = None) -> lso.LoadSaveObjectManager: if client is None: client = object() - return lso.LoadSaveObjectManager(storage_client=client) + return lso.LoadSaveObjectManager(storage_client=client, validate_url_access=validate_url_access) def test_get_client_returns_configured_storage(): @@ -441,3 +441,318 @@ async def handler(): result = await handler() assert result == "s3://bucket/object" upload_mock.assert_called_once() + + +# ============================================================================ +# Tests for new code coverage (lines 29-40, 135-139, 185-209) +# 
============================================================================ + + +def test_init_stores_validate_url_access(): + """Test that __init__ (lines 29-40) stores the validate_url_access callback.""" + def my_validator(urls): + pass + + manager = make_manager(validate_url_access=my_validator) + assert manager._validate_url_access is my_validator + + +def test_init_validate_url_access_defaults_to_none(): + """Test that validate_url_access defaults to None when not provided.""" + manager = make_manager() + assert manager._validate_url_access is None + + +def test_load_object_with_validate_url_access_success(monkeypatch): + """Test load_object (lines 185-209): a non-raising validator is invoked and the download proceeds.""" + validate_mock = MagicMock() + manager = make_manager(validate_url_access=validate_mock) + download_mock = MagicMock(return_value=b"file-bytes") + + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + result = handler("https://example.com/img.png") + + validate_mock.assert_called_once() + assert result == b"file-bytes" + + +def test_load_object_validates_urls_before_download(monkeypatch): + """Test that URL validation happens before downloading (lines 200-208).""" + def my_validator(urls): + assert "https://example.com/img.png" in urls + raise PermissionError("Access denied") + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + with pytest.raises(PermissionError, match="Access denied"): + handler("https://example.com/img.png") + + download_mock.assert_not_called() + + +def 
test_load_object_validates_urls_with_other_exception(monkeypatch): + """Test that non-PermissionError exceptions from validator raise PermissionError (lines 206-208).""" + def my_validator(urls): + raise ValueError("Some validation error") + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + with pytest.raises(PermissionError, match="URL access validation failed"): + handler("https://example.com/img.png") + + +def test_load_object_collects_urls_from_list(monkeypatch): + """Test that URLs are collected from list arguments (lines 195-198).""" + collected_urls = [] + + def my_validator(urls): + collected_urls.extend(urls) + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["images"]) + def handler(images): + return images + + result = handler(["https://example.com/a.png", "https://example.com/b.png"]) + + assert len(collected_urls) == 2 + assert "https://example.com/a.png" in collected_urls + assert "https://example.com/b.png" in collected_urls + + +def test_load_object_collects_urls_from_tuple(monkeypatch): + """Test that URLs are collected from tuple arguments (lines 195-198).""" + collected_urls = [] + + def my_validator(urls): + collected_urls.extend(urls) + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", 
lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["images"]) + def handler(images): + return images + + result = handler(("https://a.com/1.png", "https://b.com/2.png")) + + assert len(collected_urls) == 2 + assert "https://a.com/1.png" in collected_urls + assert "https://b.com/2.png" in collected_urls + + +def test_load_object_collects_urls_from_multiple_params(monkeypatch): + """Test URL collection across multiple parameters (lines 186-198).""" + collected_urls = [] + + def my_validator(urls): + collected_urls.extend(urls) + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image", "mask"]) + def handler(image, mask): + return image, mask + + result = handler("https://example.com/img.png", "https://example.com/mask.png") + + assert len(collected_urls) == 2 + assert "https://example.com/img.png" in collected_urls + assert "https://example.com/mask.png" in collected_urls + + +def test_load_object_no_validation_when_callback_none(monkeypatch): + """Test that validation is skipped when validate_url_access is None (line 201).""" + manager = make_manager(validate_url_access=None) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + result = handler("https://example.com/img.png") + assert result == b"file-bytes" + + +def test_load_object_no_validation_when_not_callable(monkeypatch): + """Test that validation is skipped when validate_url_access is not callable (line 201).""" + manager = 
make_manager(validate_url_access="not-a-callable") + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + result = handler("https://example.com/img.png") + assert result == b"file-bytes" + + +def test_load_object_with_validate_url_access_and_s3_url(monkeypatch): + """Test URL validation with S3 URLs (lines 186-198).""" + collected_urls = [] + + def my_validator(urls): + collected_urls.extend(urls) + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "s3" if url.startswith("s3://") else None) + + @manager.load_object(input_names=["file"]) + def handler(file): + return file + + result = handler("s3://bucket/path/to/file.bin") + + assert len(collected_urls) == 1 + assert "s3://bucket/path/to/file.bin" in collected_urls + + +def test_load_object_tool_instance_from_bound_args(monkeypatch): + """Test load_object extracts tool instance from bound args (lines 135-139).""" + manager = make_manager() + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + class ToolWithMethod: + @manager.load_object(input_names=["image"]) + def process(self, image): + return image + + tool = ToolWithMethod() + result = tool.process("https://example.com/img.png") + + download_mock.assert_called_once() + assert result == b"file-bytes" + + +def test_load_object_validates_empty_url_list(monkeypatch): + """Test that empty collections don't trigger validation (line 195).""" + 
validate_called = False + + def my_validator(urls): + nonlocal validate_called + validate_called = True + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + @manager.load_object(input_names=["items"]) + def handler(items): + return items + + result = handler([]) + + assert not validate_called + assert result == [] + + +def test_load_object_validation_called_with_duplicates(monkeypatch): + """Test that duplicate URLs are all included in validation (lines 195-198).""" + collected_urls = [] + + def my_validator(urls): + collected_urls.extend(urls) + + manager = make_manager(validate_url_access=my_validator) + download_mock = MagicMock(side_effect=[b"a", b"b"]) + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["images"]) + def handler(images): + return images + + result = handler(["https://example.com/same.png", "https://example.com/same.png"]) + + assert len(collected_urls) == 2 + assert collected_urls.count("https://example.com/same.png") == 2 + + +def test_download_file_unsupported_url_type_raises(monkeypatch): + """Test that download_file_from_url returns None for an unsupported URL type instead of raising (line 90).""" + class _Response: + def __init__(self): + self.content = b"binary" + + def raise_for_status(self): + return None + + monkeypatch.setattr(lso.requests, "get", lambda url, timeout: _Response()) + manager = make_manager() + + result = manager.download_file_from_url("ftp://example.com/file.png", url_type="ftp") + assert result is None + + +def test_load_object_transformer_returns_none_raises_error(monkeypatch): + """Test that a download result of None raises ValueError 'Failed to download file from URL' (line 147-148).""" + def transformer(_data: bytes): + return None + + manager = make_manager() + monkeypatch.setattr( + manager, 
"download_file_from_url", + MagicMock(return_value=None) + ) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["image"], input_data_transformer=[transformer]) + def handler(image): + return image + + with pytest.raises(ValueError, match="Failed to download file from URL"): + handler("https://example.com/test.png") + + +def test_process_value_handles_none_in_list(monkeypatch): + """Test that None values in lists are handled correctly (line 170).""" + manager = make_manager() + download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + monkeypatch.setattr(lso, "is_url", lambda url: "https" if url.startswith("https://") else None) + + @manager.load_object(input_names=["items"]) + def handler(items): + return items + + result = handler([None, "https://example.com/img.png"]) + + assert result[0] is None + assert result[1] == b"file-bytes" From 98a95098645696888a8198f9f1a59136d7c6162c Mon Sep 17 00:00:00 2001 From: xuyaqi Date: Thu, 23 Apr 2026 16:09:37 +0800 Subject: [PATCH 004/156] =?UTF-8?q?=E2=9C=A8=20Feat:=20Add=20presigned=20U?= =?UTF-8?q?RL=20support=20for=20external=20MCP=20tool=20file=20access=20an?= =?UTF-8?q?d=20improve=20agent=20execution=20flow=20(#2839)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复调用多模态工具导致502Bad Gateway问题 * Bugfix: Add tooltip to tab labels in ToolManagement and SkillManagement Made-with: Cursor * Feat: Add presigned URL support for external MCP tool file access and improve agent execution flow * 使用已有的types,而非重复定义 * 针对用户上传的文件进行去重处理,限制文件最大个数 * Update frontend/types/chat.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update frontend/types/chat.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * 新增测试用例 * 修复单元测试 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- 
backend/agents/create_agent_info.py | 104 +++++++++++++-- backend/apps/file_management_app.py | 6 +- backend/consts/model.py | 8 +- backend/database/attachment_db.py | 43 ++++-- backend/database/client.py | 4 +- .../managed_system_prompt_template_en.yaml | 22 ++- .../managed_system_prompt_template_zh.yaml | 21 ++- .../manager_system_prompt_template_en.yaml | 22 ++- .../manager_system_prompt_template_zh.yaml | 22 ++- backend/prompts/utils/prompt_generate_en.yaml | 17 +-- backend/prompts/utils/prompt_generate_zh.yaml | 17 +-- .../[locale]/chat/internal/chatInterface.tsx | 39 ++++-- frontend/lib/chat/chatAttachmentUtils.ts | 21 ++- frontend/lib/chatMessageExtractor.ts | 1 + frontend/types/chat.ts | 10 ++ test/backend/agents/test_create_agent_info.py | 96 +++++++++++++- test/backend/database/test_attachment_db.py | 125 +++++++++++++++--- 17 files changed, 487 insertions(+), 91 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index e0fce0f47..0696cab34 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -554,17 +554,101 @@ async def prepare_prompt_templates( return prompt_templates -async def join_minio_file_description_to_query(minio_files, query): +async def join_minio_file_description_to_query( + minio_files, + query, + history=None, + max_files: int = 50, + max_chars: int = 10000, +): + """ + Join MinIO file descriptions to the user query. + + This function formats uploaded file information into a structured description + that includes both S3 URL (for internal tools) and Download URL (for external MCP tools). + It processes files from both the current message and historical messages. + + De-duplication is performed using the file URL as the unique key. A maximum + file count and total character limit are enforced to prevent prompt bloat. 
+ + Args: + minio_files: List of file info dicts from current message upload + query: Original user query + history: Optional list of historical message dicts, each may contain minio_files + max_files: Maximum number of files to include (default 50) + max_chars: Maximum total characters for file descriptions (default 10000) + + Returns: + Modified query with file descriptions appended + """ final_query = query + seen_urls: set[str] = set() + all_files: list[dict] = [] + + # Collect files from current message first (higher priority) if minio_files and isinstance(minio_files, list): - file_descriptions = [] for file in minio_files: - if isinstance(file, dict) and "url" in file and file["url"] and "name" in file and file["name"]: - file_descriptions.append(f"File name: {file['name']}, S3 URL: s3:/{file['url']}") + if isinstance(file, dict) and file.get("url") and file.get("name"): + url = file["url"] + if url not in seen_urls: + seen_urls.add(url) + all_files.append(file) + + # Collect files from historical messages (lower priority, already-deduped) + if history and isinstance(history, list): + for msg in history: + if isinstance(msg, dict) and msg.get("minio_files"): + for file in msg["minio_files"]: + if isinstance(file, dict) and file.get("url") and file.get("name"): + url = file["url"] + if url not in seen_urls: + seen_urls.add(url) + all_files.append(file) + + # Enforce file count limit (keeps the first-collected files; log before slicing so the original count is reported) + if len(all_files) > max_files: + logger.debug(f"File list truncated from {len(all_files)} to {max_files} files") + all_files = all_files[:max_files] + + if all_files: + file_descriptions: list[str] = [] + # Calculate fixed overhead that is added only once + prefix = "User uploaded files. 
The file information is as follows:\n" + suffix = f"\n\nUser wants to answer questions based on the information in the above files: {query}" + fixed_overhead = len(prefix) + len(suffix) + + for i, file in enumerate(all_files): + s3_url = f"s3:/{file['url']}" + presigned_url = file.get("presigned_url", "") + + # Build description with both URLs + if presigned_url: + desc = ( + f"File name: {file['name']}\n" + f"- S3 URL: {s3_url} [permanent, for internal tools like analyze_text_file]\n" + f"- Download URL: {presigned_url} [temporary (expires in 24h), for external MCP tools]" + ) + else: + desc = f"File name: {file['name']}, S3 URL: {s3_url} [permanent]" + + # Calculate total length if we include this description + # Each description after the first adds 2 chars for \n\n separator + separator_chars = 2 if i > 0 else 0 + total_len = sum(len(d) for d in file_descriptions) + len(desc) + separator_chars + fixed_overhead + + # Check if adding this description would exceed the character limit + if total_len > max_chars: + logger.debug( + f"File descriptions truncated at {len(file_descriptions)} files " + f"to stay within {max_chars} character limit" + ) + break + + file_descriptions.append(desc) + if file_descriptions: - final_query = "User uploaded files. 
The file information is as follows:\n" - final_query += "\n".join(file_descriptions) + "\n\n" - final_query += f"User wants to answer questions based on the information in the above files: {query}" + final_query = prefix + "\n\n".join(file_descriptions) + suffix + return final_query @@ -619,7 +703,11 @@ async def create_agent_run_info( version_no = 0 logger.info(f"Agent {agent_id} has no published version, using draft version 0") - final_query = await join_minio_file_description_to_query(minio_files=minio_files, query=query) + final_query = await join_minio_file_description_to_query( + minio_files=minio_files, + query=query, + history=history + ) model_list = await create_model_config_list(tenant_id) create_config_kwargs = { "agent_id": agent_id, diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index b8e1ce711..578277b6d 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -178,7 +178,7 @@ async def get_storage_file( "'base64' (return base64-encoded content for images)." 
), ), - expires: int = Query(3600, description="URL validity period (seconds)"), + expires: int = Query(86400, description="URL validity period (seconds)"), filename: Optional[str] = Query(None, description="Original filename for download (optional)"), authorization: Optional[str] = Header(None, alias="Authorization") ): @@ -191,7 +191,7 @@ async def get_storage_file( - **object_name**: File object name - **download**: Download mode: ignore (default, return file info), stream (return file stream), redirect (redirect to download URL) - - **expires**: URL validity period in seconds (default 3600) + - **expires**: URL validity period in seconds (default 86400 = 24 hours) - **filename**: Original filename for download (optional, if not provided, will use object_name) Returns file information, download link, or file content @@ -628,7 +628,7 @@ async def get_storage_file_batch_urls( - attachments/{user_id}/*: Only the owner (user_id) can access - **request_data**: JSON request body containing object_names list - - **expires**: URL validity period in seconds (default 3600) + - **expires**: URL validity period in seconds (default 86400 = 24 hours) Returns URL and status information for each file """ diff --git a/backend/consts/model.py b/backend/consts/model.py index 91cf7d1b6..05e6426b2 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -128,10 +128,16 @@ class GlobalConfig(BaseModel): # Request models +class HistoryItem(BaseModel): + role: str + content: str + minio_files: Optional[List[Dict[str, Any]]] = None + + class AgentRequest(BaseModel): query: str conversation_id: Optional[int] = None - history: Optional[List[Dict]] = None + history: Optional[List[HistoryItem]] = None # Complete list of attachment information minio_files: Optional[List[Dict[str, Any]]] = None agent_id: Optional[int] = None diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 1faabac23..c0efbade2 100644 --- a/backend/database/attachment_db.py 
+++ b/backend/database/attachment_db.py @@ -28,7 +28,13 @@ def generate_object_name(file_name: str, prefix: str = "attachments") -> str: return f"{prefix}/{timestamp}_{unique_id}{ext}" -def upload_file(file_path: str, object_name: Optional[str] = None, bucket: Optional[str] = None) -> Dict[str, Any]: +def upload_file( + file_path: str, + object_name: Optional[str] = None, + bucket: Optional[str] = None, + generate_presigned_url: bool = True, + presigned_url_expires: int = 86400 +) -> Dict[str, Any]: """ Upload local file to MinIO @@ -36,6 +42,8 @@ def upload_file(file_path: str, object_name: Optional[str] = None, bucket: Optio file_path: Local file path object_name: Object name, if not specified will be auto-generated bucket: Bucket name, if not specified will use default bucket + generate_presigned_url: Whether to generate presigned URL for external access (default True) + presigned_url_expires: Expiration time in seconds for presigned URL (default 86400 = 24 hours) Returns: Dict[str, Any]: Upload result, containing success flag, URL and error message (if any) @@ -55,6 +63,12 @@ def upload_file(file_path: str, object_name: Optional[str] = None, bucket: Optio if success: response["url"] = result + # Generate presigned URL for external access if requested + if generate_presigned_url: + presigned_result = get_file_url(object_name, bucket, presigned_url_expires) + if presigned_result.get("success"): + response["presigned_url"] = presigned_result["url"] + response["presigned_url_expires_in"] = presigned_url_expires else: response["error"] = result @@ -65,7 +79,9 @@ def upload_fileobj( file_obj: BinaryIO, file_name: str, bucket: Optional[str] = None, - prefix: str = "attachments" + prefix: str = "attachments", + generate_presigned_url: bool = True, + presigned_url_expires: int = 86400 ) -> Dict[str, Any]: """ Upload file object to MinIO @@ -75,6 +91,8 @@ def upload_fileobj( file_name: File name bucket: Bucket name, if not specified will use default bucket prefix: Object 
name prefix, default is "attachments" + generate_presigned_url: Whether to generate presigned URL for external access (default True) + presigned_url_expires: Expiration time in seconds for presigned URL (default 86400 = 24 hours) Returns: Dict[str, Any]: Upload result, containing success flag, URL and error message (if any) @@ -89,19 +107,28 @@ def upload_fileobj( file_obj.seek(0, os.SEEK_END) file_size = file_obj.tell() - # Reset to original position - file_obj.seek(current_pos) + # Seek to beginning for upload + file_obj.seek(0) # Upload file success, result = minio_client.upload_fileobj( file_obj, object_name, bucket) + # Restore original position + file_obj.seek(current_pos) + # Build response response = {"success": success, "object_name": object_name, "file_name": file_name, "file_size": file_size, "content_type": get_content_type(file_name), "upload_time": datetime.now().isoformat()} if success: response["url"] = result + # Generate presigned URL for external access if requested + if generate_presigned_url: + presigned_result = get_file_url(object_name, bucket, presigned_url_expires) + if presigned_result.get("success"): + response["presigned_url"] = presigned_result["url"] + response["presigned_url_expires_in"] = presigned_url_expires else: response["error"] = result @@ -134,14 +161,14 @@ def download_file(object_name: str, file_path: str, bucket: Optional[str] = None return response -def get_file_url(object_name: str, bucket: Optional[str] = None, expires: int = 3600) -> Dict[str, Any]: +def get_file_url(object_name: str, bucket: Optional[str] = None, expires: int = 86400) -> Dict[str, Any]: """ Get presigned URL for file Args: object_name: Object name bucket: Bucket name, if not specified will use default bucket - expires: URL expiration time in seconds + expires: URL expiration time in seconds (default 86400 = 24 hours) Returns: Dict[str, Any]: Result containing success flag, URL and error message (if any) @@ -223,8 +250,8 @@ def list_files(prefix: str = 
"", bucket: Optional[str] = None) -> List[Dict[str, for file in files: file["content_type"] = get_content_type(file["key"]) - # Get presigned URL (valid for 1 hour) - success, url = minio_client.get_file_url(file["key"], bucket, 3600) + # Get presigned URL (valid for 24 hours) + success, url = minio_client.get_file_url(file["key"], bucket, 86400) if success: file["url"] = url diff --git a/backend/database/client.py b/backend/database/client.py index 9b0b97a52..407f7b032 100644 --- a/backend/database/client.py +++ b/backend/database/client.py @@ -158,14 +158,14 @@ def download_file(self, object_name: str, file_path: str, bucket: Optional[str] self._ensure_initialized() return self._storage_client.download_file(object_name, file_path, bucket) - def get_file_url(self, object_name: str, bucket: Optional[str] = None, expires: int = 3600) -> Tuple[bool, str]: + def get_file_url(self, object_name: str, bucket: Optional[str] = None, expires: int = 86400) -> Tuple[bool, str]: """ Get presigned URL for file Args: object_name: Object name bucket: Bucket name, if not specified use default bucket - expires: URL expiration time in seconds + expires: URL expiration time in seconds (default 86400 = 24 hours) Returns: Tuple[bool, str]: (Success status, Presigned URL or error message) diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml index 167be1f2b..d8103e5ae 100644 --- a/backend/prompts/managed_system_prompt_template_en.yaml +++ b/backend/prompts/managed_system_prompt_template_en.yaml @@ -48,7 +48,7 @@ system_prompt: |- Ethical Guidelines: Refuse hate speech, discriminatory content, and any requests that violate universal values. ### Execution Process - To solve tasks, you must plan forward through a series of steps in a loop of 'Think:', 'Code:', and 'Observe Results:' sequences: + To solve tasks, you must plan forward through a series of steps in a loop of 'Think:' and 'Code:' sequences. 
**IMPORTANT: You must NOT output 'Observe Results:' before code execution. Observation results can ONLY be generated after code execution.** 1. Think: - Determine which tools need to be used to obtain information or take action @@ -63,9 +63,7 @@ system_prompt: |- - Call tools correctly according to format specifications - To distinguish between code execution and displaying user code, use 'code' for executing code and 'code' for displaying code - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code. - - 3. Observe Results: - - View code execution results + - **IMPORTANT**: After code execution, the system will return content with "Observation:" marker (this is the real execution result). Please continue your next thinking based on these real results. **Do NOT fabricate observation results before code execution.** After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. @@ -96,15 +94,31 @@ system_prompt: |- {%- if tools and tools.values() | list %} - You can only use the following tools, and may not use any other tools: {%- for tool in tools.values() %} + {%- if tool.source == 'mcp' %} + - [MCP] {{ tool.name }}: {{ tool.description }} + Accepts input: {{tool.inputs}} + Returns output type: {{tool.output_type}} + {%- else %} - {{ tool.name }}: {{ tool.description }} Accepts input: {{tool.inputs}} Returns output type: {{tool.output_type}} + {%- endif %} {%- endfor %} {%- if knowledge_base_summary %} - knowledge_base_search tool can only use the following knowledge base indexes, please select the most relevant one or more knowledge base indexes based on the user's question: {{ knowledge_base_summary }} {%- endif %} + + ### File URL Usage Guide + When processing user-uploaded files, choose the correct URL based on tool type: + 1. 
**Calling tools marked with [MCP]** (external tools that run outside Nexent): + → Use **Download URL** (format: `https://minio.example.com/...?token=xxx`) + Reason: MCP tools run on external services and cannot access internal S3 storage + 2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image): + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) + Reason: Internal tools run inside Nexent and can directly access MinIO storage + {%- else %} - No tools are currently available {%- endif %} diff --git a/backend/prompts/managed_system_prompt_template_zh.yaml b/backend/prompts/managed_system_prompt_template_zh.yaml index c42d61c66..53be7d18f 100644 --- a/backend/prompts/managed_system_prompt_template_zh.yaml +++ b/backend/prompts/managed_system_prompt_template_zh.yaml @@ -113,7 +113,7 @@ system_prompt: |- {%- endif %} ### 执行流程 - 要解决任务,你必须通过一系列步骤向前规划,以'思考:'、'代码:'和'观察结果:'序列的循环进行: + 要解决任务,你必须通过一系列步骤向前规划,以'思考:'、'代码:'序列循环进行。**注意:禁止在代码执行前输出'观察结果:',观察结果只能由代码执行后产生。** 1. 思考: - 确定需要使用哪些工具获取信息或行动 @@ -128,9 +128,7 @@ system_prompt: |- - 根据格式规范正确调用工具 - 考虑到代码执行与展示用户代码的区别,使用'代码'表达运行代码,使用'代码'表达展示代码 - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。 - - 3. 观察结果: - - 查看代码执行结果 + - **重要**:代码执行后,系统会返回 "Observation:" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 @@ -161,9 +159,15 @@ system_prompt: |- {%- if tools and tools.values() | list %} - 你只能使用以下工具,不得使用任何其他工具: {%- for tool in tools.values() %} + {%- if tool.source == 'mcp' %} + - [MCP] {{ tool.name }}: {{ tool.description }} + 接受输入: {{tool.inputs}} + 返回输出类型: {{tool.output_type}} + {%- else %} - {{ tool.name }}: {{ tool.description }} 接受输入: {{tool.inputs}} 返回输出类型: {{tool.output_type}} + {%- endif %} {%- endfor %} {%- if knowledge_base_summary %} @@ -172,6 +176,15 @@ system_prompt: |- {%- endif %} + ### 文件链接使用指南 + 当处理用户上传的文件时,请根据工具类型选择正确的 URL: + 1. 
**调用标记为 [MCP] 的工具**(外部工具,运行在 Nexent 之外): + → 使用 **Download URL**(格式:`https://minio.example.com/...?token=xxx`) + 原因:MCP 工具运行在外部服务,无法访问内部 S3 存储 + 2. **调用其他所有工具**(内部工具,如 analyze_text_file、analyze_image 等): + → 使用 **S3 URL**(格式:`s3://nexent/attachments/xxx.pdf`) + 原因:内部工具运行在 Nexent 内部,可以直接访问 MinIO 存储 + {%- else %} - 当前没有可用的工具 {%- endif %} diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml index 28e6cb2b1..50cfbc411 100644 --- a/backend/prompts/manager_system_prompt_template_en.yaml +++ b/backend/prompts/manager_system_prompt_template_en.yaml @@ -48,7 +48,7 @@ system_prompt: |- Ethical Guidelines: Refuse hate speech, discriminatory content, and any requests that violate universal values. ### Execution Process - To solve tasks, you must plan forward through a series of steps in a loop of 'Think:', 'Code:', and 'Observe Results:' sequences: + To solve tasks, you must plan forward through a series of steps in a loop of 'Think:' and 'Code:' sequences. **IMPORTANT: You must NOT output 'Observe Results:' before code execution. Observation results can ONLY be generated after code execution.** 1. Think: - Analyze current task status and progress @@ -64,10 +64,7 @@ system_prompt: |- - Correctly call tools or agents to solve problems - To distinguish between code execution and displaying user code, use 'code' for executing code and 'code' for displaying code - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code. - - 3. Observe Results: - - View code execution results - - Decide on next action based on results + - **IMPORTANT**: After code execution, the system will return content with "Observation:" marker (this is the real execution result). Please continue your next thinking based on these real results. 
**Do NOT fabricate observation results before code execution.** After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. @@ -99,15 +96,30 @@ system_prompt: |- {%- if tools and tools.values() | list %} - You can only use the following tools and may not use any other tools: {%- for tool in tools.values() %} + {%- if tool.source == 'mcp' %} + - [MCP] {{ tool.name }}: {{ tool.description }} + Accepts input: {{tool.inputs}} + Returns output type: {{tool.output_type}} + {%- else %} - {{ tool.name }}: {{ tool.description }} Accepts input: {{tool.inputs}} Returns output type: {{tool.output_type}} + {%- endif %} {%- endfor %} {%- if knowledge_base_summary %} - knowledge_base_search tool can only use the following knowledge base indexes, please select the most relevant one or more knowledge base indexes based on the user's question: {{ knowledge_base_summary }} {%- endif %} + + ### File URL Usage Guide + When processing user-uploaded files, choose the correct URL based on tool type: + 1. **Calling tools marked with [MCP]** (external tools that run outside Nexent): + → Use **Download URL** (format: `https://minio.example.com/...?token=xxx`) + Reason: MCP tools run on external services and cannot access internal S3 storage + 2. 
**Calling all other tools** (internal tools like analyze_text_file, analyze_image): + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) + Reason: Internal tools run inside Nexent and can directly access MinIO storage {%- else %} - No tools are currently available {%- endif %} diff --git a/backend/prompts/manager_system_prompt_template_zh.yaml b/backend/prompts/manager_system_prompt_template_zh.yaml index 015b74450..3c7144cad 100644 --- a/backend/prompts/manager_system_prompt_template_zh.yaml +++ b/backend/prompts/manager_system_prompt_template_zh.yaml @@ -111,7 +111,7 @@ system_prompt: |- {%- endif %} ### 执行流程 - 要解决任务,你必须通过一系列步骤向前规划,以'思考:'、'代码:'和'观察结果:'序列的循环进行: + 要解决任务,你必须通过一系列步骤向前规划,以'思考:'和'代码:'序列循环进行。**注意:禁止在代码执行前输出'观察结果:',观察结果只能由代码执行后产生。** 1. 思考: - 分析当前任务状态和进展 @@ -127,10 +127,7 @@ system_prompt: |- - 正确调用工具或助手解决问题 - 考虑到代码执行与展示用户代码的区别,使用'代码'表达运行代码,使用'代码'表达展示代码 - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。 - - 3. 观察结果: - - 查看代码执行结果 - - 根据结果决定下一步行动 + - **重要**:代码执行后,系统会返回 "Observation:" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 @@ -162,15 +159,30 @@ system_prompt: |- {%- if tools and tools.values() | list %} - 你只能使用以下工具,不得使用任何其他工具: {%- for tool in tools.values() %} + {%- if tool.source == 'mcp' %} + - [MCP] {{ tool.name }}: {{ tool.description }} + 接受输入: {{tool.inputs}} + 返回输出类型: {{tool.output_type}} + {%- else %} - {{ tool.name }}: {{ tool.description }} 接受输入: {{tool.inputs}} 返回输出类型: {{tool.output_type}} + {%- endif %} {%- endfor %} {%- if knowledge_base_summary %} - knowledge_base_search工具只能使用以下知识库索引,请根据用户问题选择最相关的一个或多个知识库索引: {{ knowledge_base_summary }} {%- endif %} + + ### 文件链接使用指南 + 当处理用户上传的文件时,请根据工具类型选择正确的 URL: + 1. **调用标记为 [MCP] 的工具**(外部工具,运行在 Nexent 之外): + → 使用 **Download URL**(格式:`https://minio.example.com/...?token=xxx`) + 原因:MCP 工具运行在外部服务,无法访问内部 S3 存储 + 2. 
**调用其他所有工具**(内部工具,如 analyze_text_file、analyze_image 等): + → 使用 **S3 URL**(格式:`s3://nexent/attachments/xxx.pdf`) + 原因:内部工具运行在 Nexent 内部,可以直接访问 MinIO 存储 {%- else %} - 当前没有可用的工具 {%- endif %} diff --git a/backend/prompts/utils/prompt_generate_en.yaml b/backend/prompts/utils/prompt_generate_en.yaml index 596bb2cb9..c54e2ee88 100644 --- a/backend/prompts/utils/prompt_generate_en.yaml +++ b/backend/prompts/utils/prompt_generate_en.yaml @@ -43,7 +43,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- 3. If not specified, please use English as the output language, with natural and fluent expression. ### Agent Execution Process: - To solve tasks, you must plan forward through a series of steps in a loop of 'Think:', 'Code:', and 'Observe Results:' sequences: + To solve tasks, you must plan forward through a series of steps in a loop of 'Think:' and 'Code:' sequences. **IMPORTANT: You must NOT output 'Observe Results:' before code execution. Observation results can ONLY be generated after code execution.** 1. Think: - Determine which tools/assistants need to be used to obtain information or take action @@ -55,9 +55,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- - Call tools/assistants correctly according to format specifications - To distinguish between code execution and displaying user code, use 'code' for executing code and 'code' for displaying code - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code. - - 3. Observe Results: - - View code execution results + - **IMPORTANT**: After code execution, the system will return content with "Observation:" marker (this is the real execution result). Please continue your next thinking based on these real results. **Do NOT fabricate observation results before code execution.** After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. 
@@ -82,7 +80,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- knowledge_info = knowledge_base_search(query="Oriental Pearl Tower introduction", index_names=["local_knowledge_base1", "local_knowledge_base2"]) print(knowledge_info) - Observe Results: No results found for query "Oriental Pearl Tower introduction". The search results are insufficient to support an answer. + # System returns Observation: No relevant results found Think: Since no relevant information was found in the local knowledge base, I need to use the web_search tool to query network information. Code: @@ -90,7 +88,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- web_info = web_search(query="Oriental Pearl Tower introduction") print(web_info) - Observe Results: The Oriental Pearl TV Tower is located in Lujiazui, Pudong New Area, Shanghai, China... + # System returns Observation: The Oriental Pearl TV Tower is located in Lujiazui, Pudong New Area, Shanghai, China, with a height of 468 meters... Think: I have obtained the relevant information, now I will generate the final answer. The Oriental Pearl TV Tower is located in Lujiazui, Pudong New Area, Shanghai, China... @@ -105,7 +103,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- itinerary_result = travel_planning_assistant(task="Help me plan tomorrow's trip from Shanghai to Beijing") print(itinerary_result) - Observe Results: Tomorrow's trip planning from Shanghai to Beijing, including transportation, accommodation, attractions, etc. + # System returns Observation: Trip plan completed: High-speed train G2, departs 8:00, arrives Beijing South Station at 11:30; Hotel near Wangfujing; Recommended attractions: Tiananmen, Forbidden City, Great Wall... Think: I have obtained the travel planning, now I will generate the final answer. Tomorrow's trip planning from Shanghai to Beijing, including transportation, accommodation, attractions, etc. 
@@ -120,7 +118,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- weather_data = weather_api(city="Beijing") print(weather_data) - Observe Results: {"temperature": 25, "humidity": "60%", "condition": "sunny"} + # System returns Observation: {"city": "Beijing", "temperature": 25, "humidity": 60, "condition": "sunny"} Think: Now I have weather data, let the analysis assistant help me analyze this data. Code: @@ -128,7 +126,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- analysis_result = data_analysis_assistant(task="Analyze today's weather data: temperature 25 degrees, humidity 60%, sunny") print(analysis_result) - Observe Results: Today's weather is suitable, temperature is moderate, humidity is normal, suitable for outdoor activities. + # System returns Observation: Based on weather data analysis, today is suitable for outdoor activities. Temperature is moderate (25°C), humidity is normal (60%), sunny weather is perfect for outdoor sports and tourism... Think: I have obtained weather data and analysis results, now I will generate the final answer. Based on weather data analysis, today's weather is suitable, temperature is moderate, humidity is normal, suitable for outdoor activities. @@ -158,7 +156,6 @@ FEW_SHOTS_SYSTEM_PROMPT: |- right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - Observe Results: The Python quick sort code. Think: I have obtained the Python quick sort code, now I will generate the final answer. The Python quick sort code is as follows: diff --git a/backend/prompts/utils/prompt_generate_zh.yaml b/backend/prompts/utils/prompt_generate_zh.yaml index e48b97204..8c19d138e 100644 --- a/backend/prompts/utils/prompt_generate_zh.yaml +++ b/backend/prompts/utils/prompt_generate_zh.yaml @@ -42,7 +42,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- 3.若未指定语言,请使用中文输出,语言表达要自然流畅。 ### Agent的执行流程: - 要解决任务,Agent必须通过一系列步骤向前规划,以'思考:'、'代码:'和'观察结果:'序列的循环进行: + 要解决任务,Agent必须通过一系列步骤向前规划,以'思考:'和'代码:'序列循环进行。**注意:禁止在代码执行前输出'观察结果:',观察结果只能由代码执行后产生。** 1. 
思考: - 确定需要使用哪些工具/助手获取信息或行动 @@ -54,9 +54,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- - 根据格式规范正确调用工具/助手 - 考虑到代码执行与展示用户代码的区别,使用'代码'表达运行代码,使用'代码'表达展示代码 - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。 - - 3. 观察结果: - - 查看代码执行结果 + - **重要**:代码执行后,系统会返回 "Observation:" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。 在思考结束后,当Agent认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 @@ -81,7 +79,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- knowledge_info = knowledge_base_search(query="东方明珠 介绍", index_names=["本地知识库1", "本地知识库2"]) print(knowledge_info) - 观察结果:未找到查询"东方明珠 介绍"的结果。检索结果难以支撑回答。 + # 系统返回 Observation: 未找到相关结果 思考:从本地知识库中没有找到相关信息,我需要使用web_search工具查询网络信息。 代码: @@ -89,7 +87,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- web_info = web_search(query="东方明珠 介绍") print(web_info) - 观察结果:东方明珠广播电视塔位于中国上海市浦东新区陆家嘴... + # 系统返回 Observation: 东方明珠广播电视塔位于中国上海市浦东新区陆家嘴,塔高468米,是中国著名的地标建筑之一... 思考:我已经获得了有关信息,现在我将生成最终回答。 东方明珠广播电视塔位于中国上海市浦东新区陆家嘴... @@ -104,7 +102,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- itinerary_result = travel_planning_assistant(task="帮我规划明天从上海出发去北京的行程") print(itinerary_result) - 观察结果:明天从上海出发去北京的行程规划,包括交通、住宿、景点等。 + # 系统返回 Observation: 行程规划已完成,包括:高铁G2,8:00出发,11:30到达北京南站;酒店预订于王府井附近;景点推荐:天安门、故宫、长城... 思考:我已经获得了出行规划,现在我将生成最终回答。 明天从上海出发去北京的行程规划,包括交通、住宿、景点等。 @@ -119,7 +117,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- weather_data = weather_api(city="北京") print(weather_data) - 观察结果:{"temperature": 25, "humidity": 60%, "condition": "晴天"} + # 系统返回 Observation: {"city": "北京", "temperature": 25, "humidity": 60, "condition": "晴天"} 思考:现在我有天气数据了,让分析助手帮我分析这些数据。 代码: @@ -127,7 +125,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- analysis_result = data_analysis_assistant(task="分析今天的天气数据:温度25度,湿度60%,晴天") print(analysis_result) - 观察结果:今天天气适宜,温度适中,湿度正常,适合户外活动。 + # 系统返回 Observation: 根据天气数据分析,今天天气适宜外出活动,温度适中(25℃),湿度正常(60%),晴天适合户外运动和旅游... 
思考:我已经获得了天气数据和分析结果,现在我将生成最终回答。 根据天气数据分析,今天天气适宜,温度适中,湿度正常,适合户外活动。 @@ -155,7 +153,6 @@ FEW_SHOTS_SYSTEM_PROMPT: |- right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - 观察结果:快速排序的python代码。 思考:我已经获得了快速排序的python代码,现在我将生成最终回答。 快速排序的python代码如下: diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index c6166e3f4..5cac4e472 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -30,7 +30,7 @@ import { createMessageAttachments, cleanupAttachmentUrls, } from "@/lib/chat/chatAttachmentUtils"; -import { ConversationListItem, ApiConversationDetail } from "@/types/chat"; +import { ConversationListItem, ApiConversationDetail, HistoryItem } from "@/types/chat"; import { ChatMessageType } from "@/types/chat"; import { handleStreamResponse } from "@/app/chat/streaming/chatStreamHandler"; import { @@ -258,6 +258,7 @@ export function ChatInterface() { // Handle file upload let uploadedFileUrls: Record = {}; let objectNames: Record = {}; // Add object name mapping + let presignedUrls: Record = {}; // Store presigned URLs for external MCP tool access if (attachments.length > 0) { // Show loading state @@ -267,13 +268,16 @@ export function ChatInterface() { const uploadResult = await uploadAttachments(attachments, t); uploadedFileUrls = uploadResult.uploadedFileUrls; objectNames = uploadResult.objectNames; // Get object name mapping + presignedUrls = uploadResult.presignedUrls; // Get presigned URLs for external access } // Use preprocessing function to create message attachments const messageAttachments = createMessageAttachments( attachments, uploadedFileUrls, - fileUrls + fileUrls, + objectNames, + presignedUrls ); // Create user message object @@ -434,13 +438,29 @@ export function ChatInterface() { conversation_id: id, history: currentMessages .filter((msg) => msg.id !== userMessage.id) - .map((msg) => 
({ - role: msg.role, - content: - msg.role === ROLE_ASSISTANT - ? msg.finalAnswer?.trim() || msg.content || "" - : msg.content || "", - })), + .map((msg) => { + const historyItem: HistoryItem = { + role: msg.role, + content: + msg.role === ROLE_ASSISTANT + ? msg.finalAnswer?.trim() || msg.content || "" + : msg.content || "", + }; + // Include attachment info for historical messages so the agent + // can reference files from previous turns + if (msg.attachments && msg.attachments.length > 0) { + historyItem.minio_files = msg.attachments.map((attachment) => ({ + object_name: attachment.object_name || "", + name: attachment.name, + type: attachment.type, + size: attachment.size, + url: attachment.url || "", + presigned_url: attachment.presigned_url || "", + description: attachment.description || "", + })); + } + return historyItem; + }), minio_files: messageAttachments.length > 0 ? messageAttachments.map((attachment) => { @@ -456,6 +476,7 @@ export function ChatInterface() { type: attachment.type, size: attachment.size, url: uploadedFileUrls[attachment.name] || attachment.url, + presigned_url: presignedUrls[attachment.name] || "", description: description, }; }) diff --git a/frontend/lib/chat/chatAttachmentUtils.ts b/frontend/lib/chat/chatAttachmentUtils.ts index fc442521a..c85615b4e 100644 --- a/frontend/lib/chat/chatAttachmentUtils.ts +++ b/frontend/lib/chat/chatAttachmentUtils.ts @@ -1,7 +1,7 @@ import type { Dispatch, SetStateAction } from "react"; import { conversationService } from "@/services/conversationService"; import { storageService } from "@/services/storageService"; -import { FilePreview } from "@/types/chat"; +import type { FileAttachment, FilePreview } from "@/types/chat"; import log from "@/lib/logger"; /** @@ -40,10 +40,11 @@ export const uploadAttachments = async ( ): Promise<{ uploadedFileUrls: Record; objectNames: Record; + presignedUrls: Record; error?: string; }> => { if (attachments.length === 0) { - return { uploadedFileUrls: {}, objectNames: 
{} }; + return { uploadedFileUrls: {}, objectNames: {}, presignedUrls: {} }; } try { @@ -53,22 +54,28 @@ export const uploadAttachments = async ( const uploadedFileUrls: Record = {}; const objectNames: Record = {}; + const presignedUrls: Record = {}; if (uploadResult.success_count > 0) { uploadResult.results.forEach((result) => { if (result.success) { uploadedFileUrls[result.file_name] = result.url; objectNames[result.file_name] = result.object_name; + // Store presigned URL for external MCP tool access + if (result.presigned_url) { + presignedUrls[result.file_name] = result.presigned_url; + } } }); } - return { uploadedFileUrls, objectNames }; + return { uploadedFileUrls, objectNames, presignedUrls }; } catch (error) { log.error(t("chatPreprocess.fileUploadFailed"), error); return { uploadedFileUrls: {}, objectNames: {}, + presignedUrls: {}, error: error instanceof Error ? error.message : String(error), }; } @@ -80,8 +87,10 @@ export const uploadAttachments = async ( export const createMessageAttachments = ( attachments: FilePreview[], uploadedFileUrls: Record, - fileUrls: Record -): { type: string; name: string; size: number; url?: string }[] => { + fileUrls: Record, + objectNames?: Record, + presignedUrls?: Record +): FileAttachment[] => { return attachments.map((attachment) => ({ type: attachment.type, name: attachment.file.name, @@ -91,6 +100,8 @@ export const createMessageAttachments = ( (attachment.type === "image" ? 
attachment.previewUrl : fileUrls[attachment.id]), + object_name: objectNames?.[attachment.file.name], + presigned_url: presignedUrls?.[attachment.file.name], })); }; diff --git a/frontend/lib/chatMessageExtractor.ts b/frontend/lib/chatMessageExtractor.ts index 906ba59d8..ddc12b2bf 100644 --- a/frontend/lib/chatMessageExtractor.ts +++ b/frontend/lib/chatMessageExtractor.ts @@ -268,6 +268,7 @@ export function extractUserMsgFromResponse( size: item.size || 0, object_name: item.object_name, url: item.url, + presigned_url: item.presigned_url, // Preserve presigned_url for MCP tool access description: item.description, }; }); diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index 423faa325..a1381af1f 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -83,6 +83,7 @@ export interface FileAttachment { size: number url?: string object_name?: string + presigned_url?: string // Temporary URL for external tools (e.g., MCP); expires after a configurable period (24 hours by default) description?: string } @@ -227,9 +228,17 @@ export interface MinioFileItem { size: number object_name?: string url?: string + presigned_url?: string // Temporary URL for external tools (e.g., MCP), default 24h validity description?: string } +// History item for API request payload +export interface HistoryItem { + role: string; + content: string; + minio_files?: MinioFileItem[]; +} + export interface ApiMessage { role: "user" | "assistant" message: ApiMessageItem[] @@ -323,6 +332,7 @@ export interface StorageUploadResult { content_type: string; upload_time: string; url: string; + presigned_url?: string; error?: string; }[]; } \ No newline at end of file diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index ff2655e19..41e81a87f 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -2513,7 +2513,7 @@ async def test_create_agent_run_info_success(self): # 
Verify that other functions were called correctly mock_join_query.assert_called_once_with( - minio_files=[], query="test query") + minio_files=[], query="test query", history=[]) mock_create_models.assert_called_once_with("tenant_1") mock_create_agent.assert_called_once_with( agent_id="agent_1", @@ -3018,7 +3018,7 @@ async def test_join_minio_file_description_to_query_with_files(self): result = await join_minio_file_description_to_query(minio_files, query) - expected = "User uploaded files. The file information is as follows:\nFile name: 1.pdf, S3 URL: s3://nexent/1.pdf\nFile name: 2.pdf, S3 URL: s3://nexent/2.pdf\n\nUser wants to answer questions based on the information in the above files: test query" + expected = "User uploaded files. The file information is as follows:\nFile name: 1.pdf, S3 URL: s3://nexent/1.pdf [permanent]\n\nFile name: 2.pdf, S3 URL: s3://nexent/2.pdf [permanent]\n\nUser wants to answer questions based on the information in the above files: test query" assert result == expected @pytest.mark.asyncio @@ -3054,6 +3054,98 @@ async def test_join_minio_file_description_to_query_no_descriptions(self): assert result == "test query" + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_deduplication_current(self): + """Test that duplicate files in current message are de-duplicated by URL""" + minio_files = [ + {"url": "/nexent/1.pdf", "name": "1.pdf"}, + {"url": "/nexent/1.pdf", "name": "1.pdf"}, # Duplicate URL + {"url": "/nexent/2.pdf", "name": "2.pdf"}, + ] + query = "test query" + + result = await join_minio_file_description_to_query(minio_files, query) + + # Count occurrences of "File name: 1.pdf" which should appear exactly once + assert result.count("File name: 1.pdf") == 1 + assert result.count("File name: 2.pdf") == 1 + # Total file description blocks should be 2, not 3 + assert result.count("S3 URL:") == 2 + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_deduplication_history(self): + 
"""Test that files in history are de-duplicated against current message""" + minio_files = [{"url": "/nexent/1.pdf", "name": "1.pdf"}] + history = [ + {"minio_files": [{"url": "/nexent/1.pdf", "name": "1.pdf"}]}, # Same URL as current + {"minio_files": [{"url": "/nexent/2.pdf", "name": "2.pdf"}]}, + ] + query = "test query" + + result = await join_minio_file_description_to_query(minio_files, query, history) + + # Count occurrences of "File name:" which should appear exactly once for each unique file + assert result.count("File name: 1.pdf") == 1 + assert result.count("File name: 2.pdf") == 1 + # Total file description blocks should be 2, not 3 + assert result.count("S3 URL:") == 2 + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_max_files(self): + """Test that file list is truncated when exceeding max_files limit""" + minio_files = [ + {"url": f"/nexent/file_{i}.pdf", "name": f"file_{i}.pdf"} + for i in range(10) + ] + query = "test query" + + result = await join_minio_file_description_to_query(minio_files, query, max_files=5) + + for i in range(5): + assert f"file_{i}.pdf" in result + for i in range(5, 10): + assert f"file_{i}.pdf" not in result + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_max_chars(self): + """Test that file descriptions are truncated when exceeding max_chars limit""" + # Each file description is roughly 72 chars + # With prefix (~56) and suffix (~100), fixed overhead is ~156 chars + # Setting max_chars=100 should prevent ANY file from being included + # (since even one file needs ~72 + 156 = 228 chars) + minio_files = [ + {"url": f"/nexent/file_{i}.pdf", "name": f"file_{i}.pdf"} + for i in range(10) + ] + query = "test query" + + # Very small limit - should result in no files being included + result = await join_minio_file_description_to_query(minio_files, query, max_chars=100) + assert result == "test query" + + # Reasonable limit - should include some files + # With 500 chars, we 
can fit: 500 - 156 = 344 available chars + # Each file is ~72 chars, so we can fit ~4 files + result = await join_minio_file_description_to_query(minio_files, query, max_chars=500) + # Should include at least some files but not all 10 + assert "file_0.pdf" in result + assert result.count("File name:") < 10 + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_current_files_priority(self): + """Test that current message files appear before history files when deduping""" + minio_files = [{"url": "/nexent/1.pdf", "name": "current_1.pdf"}] + history = [ + {"minio_files": [{"url": "/nexent/2.pdf", "name": "history_2.pdf"}]}, + ] + query = "test query" + + result = await join_minio_file_description_to_query(minio_files, query, history) + + pos_current = result.find("current_1.pdf") + pos_history = result.find("history_2.pdf") + assert pos_current < pos_history, "Current message files should appear before history files" + class TestPreparePromptTemplates: """Tests for the prepare_prompt_templates function""" diff --git a/test/backend/database/test_attachment_db.py b/test/backend/database/test_attachment_db.py index 8b1998cf1..00263a57c 100644 --- a/test/backend/database/test_attachment_db.py +++ b/test/backend/database/test_attachment_db.py @@ -127,15 +127,17 @@ def test_generate_object_name_format(self): class TestUploadFile: """Test cases for upload_file function""" + @patch('backend.database.attachment_db.get_file_url') @patch('backend.database.attachment_db.os.path.exists') @patch('backend.database.attachment_db.os.path.getsize') @patch('backend.database.attachment_db.os.path.basename') - def test_upload_file_success(self, mock_basename, mock_getsize, mock_exists): - """Test successful file upload""" + def test_upload_file_success(self, mock_basename, mock_getsize, mock_exists, mock_get_file_url): + """Test successful file upload with presigned URL""" mock_basename.return_value = 'test.txt' mock_exists.return_value = True 
mock_getsize.return_value = 1024 minio_client_mock.upload_file.return_value = (True, '/bucket/attachments/test.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url?signature=xxx'} result = upload_file('/path/to/test.txt', 'attachments/test.txt', 'bucket') @@ -144,6 +146,9 @@ def test_upload_file_success(self, mock_basename, mock_getsize, mock_exists): assert result['file_name'] == 'test.txt' assert result['file_size'] == 1024 assert 'url' in result + assert 'presigned_url' in result + assert result['presigned_url'] == 'http://minio:9000/presigned-url?signature=xxx' + assert result['presigned_url_expires_in'] == 86400 assert 'upload_time' in result minio_client_mock.upload_file.assert_called_once_with( '/path/to/test.txt', 'attachments/test.txt', 'bucket' @@ -153,13 +158,15 @@ def test_upload_file_success(self, mock_basename, mock_getsize, mock_exists): @patch('backend.database.attachment_db.os.path.getsize') @patch('backend.database.attachment_db.os.path.basename') @patch('backend.database.attachment_db.generate_object_name') - def test_upload_file_auto_generate_object_name(self, mock_generate, mock_basename, mock_getsize, mock_exists): + @patch('backend.database.attachment_db.get_file_url') + def test_upload_file_auto_generate_object_name(self, mock_get_file_url, mock_generate, mock_basename, mock_getsize, mock_exists): """Test upload_file auto-generates object name when not provided""" mock_basename.return_value = 'test.txt' mock_exists.return_value = True mock_getsize.return_value = 1024 mock_generate.return_value = 'attachments/20240101120000_abc123.txt' minio_client_mock.upload_file.return_value = (True, '/bucket/attachments/20240101120000_abc123.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url'} result = upload_file('/path/to/test.txt', None, 'bucket') @@ -183,30 +190,76 @@ def test_upload_file_failure(self, mock_basename, mock_getsize, mock_exists): assert 
result['success'] is False assert result['error'] == 'Upload failed' assert 'url' not in result + assert 'presigned_url' not in result @patch('backend.database.attachment_db.os.path.exists') @patch('backend.database.attachment_db.os.path.getsize') @patch('backend.database.attachment_db.os.path.basename') - def test_upload_file_nonexistent_file(self, mock_basename, mock_getsize, mock_exists): - """Test upload_file with nonexistent file""" + @patch('backend.database.attachment_db.get_file_url') + def test_upload_file_without_presigned_url(self, mock_get_file_url, mock_basename, mock_getsize, mock_exists): + """Test upload_file when generate_presigned_url is False""" mock_basename.return_value = 'test.txt' - mock_exists.return_value = False - mock_getsize.return_value = 0 + mock_exists.return_value = True + mock_getsize.return_value = 1024 minio_client_mock.upload_file.return_value = (True, '/bucket/attachments/test.txt') - result = upload_file('/path/to/nonexistent.txt', 'attachments/test.txt', 'bucket') + result = upload_file('/path/to/test.txt', 'attachments/test.txt', 'bucket', generate_presigned_url=False) + + assert result['success'] is True + assert 'url' in result + assert 'presigned_url' not in result + mock_get_file_url.assert_not_called() + + @patch('backend.database.attachment_db.os.path.exists') + @patch('backend.database.attachment_db.os.path.getsize') + @patch('backend.database.attachment_db.os.path.basename') + @patch('backend.database.attachment_db.get_file_url') + def test_upload_file_custom_presigned_url_expires(self, mock_get_file_url, mock_basename, mock_getsize, mock_exists): + """Test upload_file with custom presigned URL expiration""" + mock_basename.return_value = 'test.txt' + mock_exists.return_value = True + mock_getsize.return_value = 1024 + minio_client_mock.upload_file.return_value = (True, '/bucket/attachments/test.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url'} + + result = 
upload_file('/path/to/test.txt', 'attachments/test.txt', 'bucket', presigned_url_expires=7200) + + assert result['success'] is True + assert result['presigned_url_expires_in'] == 7200 + mock_get_file_url.assert_called_once_with('attachments/test.txt', 'bucket', 7200) + + @patch('backend.database.attachment_db.get_file_url') + @patch('backend.database.attachment_db.os.path.exists') + @patch('backend.database.attachment_db.os.path.getsize') + @patch('backend.database.attachment_db.os.path.basename') + def test_upload_file_nonexistent_file(self, mock_basename, mock_getsize, mock_exists, mock_get_file_url): + """Test upload_file handles nonexistent local file gracefully""" + mock_basename.return_value = 'missing.txt' + mock_exists.return_value = False + mock_getsize.return_value = 1024 + minio_client_mock.upload_file.return_value = (True, '/bucket/attachments/missing.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url'} + + result = upload_file('/path/to/missing.txt', 'attachments/missing.txt', 'bucket') + assert result['success'] is True assert result['file_size'] == 0 + assert result['file_name'] == 'missing.txt' + assert 'url' in result + assert 'presigned_url' in result + mock_getsize.assert_not_called() class TestUploadFileobj: """Test cases for upload_fileobj function""" @patch('backend.database.attachment_db.generate_object_name') - def test_upload_fileobj_success(self, mock_generate): - """Test successful file object upload""" + @patch('backend.database.attachment_db.get_file_url') + def test_upload_fileobj_success(self, mock_get_file_url, mock_generate): + """Test successful file object upload with presigned URL""" mock_generate.return_value = 'attachments/20240101120000_abc123.txt' minio_client_mock.upload_fileobj.return_value = (True, '/bucket/attachments/20240101120000_abc123.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url?signature=xxx'} file_obj = 
BytesIO(b'test data') result = upload_fileobj(file_obj, 'test.txt', 'bucket', 'attachments') @@ -216,9 +269,13 @@ def test_upload_fileobj_success(self, mock_generate): assert result['file_name'] == 'test.txt' assert result['file_size'] == len(b'test data') assert 'url' in result + assert 'presigned_url' in result + assert result['presigned_url'] == 'http://minio:9000/presigned-url?signature=xxx' + assert result['presigned_url_expires_in'] == 86400 assert 'upload_time' in result mock_generate.assert_called_once_with('test.txt', prefix='attachments') minio_client_mock.upload_fileobj.assert_called_once() + mock_get_file_url.assert_called_once() @patch('backend.database.attachment_db.generate_object_name') def test_upload_fileobj_failure(self, mock_generate): @@ -232,22 +289,60 @@ def test_upload_fileobj_failure(self, mock_generate): assert result['success'] is False assert result['error'] == 'Upload failed' assert 'url' not in result + assert 'presigned_url' not in result @patch('backend.database.attachment_db.generate_object_name') - def test_upload_fileobj_preserves_file_position(self, mock_generate): - """Test upload_fileobj preserves original file position""" + @patch('backend.database.attachment_db.get_file_url') + def test_upload_fileobj_preserves_file_position(self, mock_get_file_url, mock_generate): + """Test upload_fileobj reads full content and preserves original file position""" mock_generate.return_value = 'attachments/test.txt' - minio_client_mock.upload_fileobj.return_value = (True, '/bucket/attachments/test.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url'} - file_obj = BytesIO(b'test data') + file_obj = BytesIO(b'full test data content') original_pos = 4 file_obj.seek(original_pos) + captured_data = {} + def capture_upload(file_obj_arg, object_name, bucket): + captured_data['content'] = file_obj_arg.read() + return (True, '/bucket/attachments/test.txt') + minio_client_mock.upload_fileobj.side_effect = 
capture_upload + result = upload_fileobj(file_obj, 'test.txt', 'bucket') - # File position should be restored + assert captured_data['content'] == b'full test data content' assert file_obj.tell() == original_pos + @patch('backend.database.attachment_db.generate_object_name') + @patch('backend.database.attachment_db.get_file_url') + def test_upload_fileobj_without_presigned_url(self, mock_get_file_url, mock_generate): + """Test upload_fileobj when generate_presigned_url is False""" + mock_generate.return_value = 'attachments/test.txt' + minio_client_mock.upload_fileobj.return_value = (True, '/bucket/attachments/test.txt') + + file_obj = BytesIO(b'test data') + result = upload_fileobj(file_obj, 'test.txt', 'bucket', generate_presigned_url=False) + + assert result['success'] is True + assert 'url' in result + assert 'presigned_url' not in result + mock_get_file_url.assert_not_called() + + @patch('backend.database.attachment_db.generate_object_name') + @patch('backend.database.attachment_db.get_file_url') + def test_upload_fileobj_custom_presigned_url_expires(self, mock_get_file_url, mock_generate): + """Test upload_fileobj with custom presigned URL expiration""" + mock_generate.return_value = 'attachments/test.txt' + minio_client_mock.upload_fileobj.return_value = (True, '/bucket/attachments/test.txt') + mock_get_file_url.return_value = {'success': True, 'url': 'http://minio:9000/presigned-url'} + + file_obj = BytesIO(b'test data') + result = upload_fileobj(file_obj, 'test.txt', 'bucket', presigned_url_expires=7200) + + assert result['success'] is True + assert result['presigned_url_expires_in'] == 7200 + mock_get_file_url.assert_called_once_with('attachments/test.txt', 'bucket', 7200) + class TestDownloadFile: """Test cases for download_file function""" From 1d0ee63cae2609078b192850e692e5bc76bf7ee7 Mon Sep 17 00:00:00 2001 From: hhhhsc701 <56435672+hhhhsc701@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:41:42 +0800 Subject: [PATCH 005/156] Refactor OAuth 
implementation and enhance account linking features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * openspec初始化 * oauth spec开发结果 * oauth 单元测试 * oauth 重定向修复 * oauth 重定向修复 * oauth 重定向修复 * oauth 抽象实现 * gde provider * gde provider * enhance unlink_account logic to check for password authentication before unlinking * refactor OAuthAccountsSection to load enabled providers and improve account unlinking logic * add OAuth linking functionality with state management and error handling * refactor OAuth account deletion logic to use direct deletion and update related tests * update GDE OAuth configuration to use environment variables for URLs and client IDs * add SSL verification configuration for OAuth requests and update context handling * remove hardcoded OAuth credentials from const.py and update .env.example * remove avatar_url references from user info handling and update email fallback logic * refactor user identity handling in OAuth account unlinking logic * update OAuthAccountsSection to simplify display logic for linked accounts * refactor OAuth user binding logic to check for existing accounts before creating new users * 删除冗余文件 * 删除冗余文件 * add user OAuth account table and update trigger for third-party logins * 修复单元测试 * 删除冗余代码 * k8s同步oauth配置 * 软删除时需添加delete_flag="Y"的筛选条件 * 用户删除的时候将oauth表中delete_flag设置为Y * 优化import * 移除无用的rebind_oauth_account函数调用,并在用户已绑定其他账户时抛出OAuthLinkError * clean code * 补充ut * 补充单元测试 --- backend/apps/config_app.py | 7 +- backend/apps/oauth_app.py | 290 ++++++ backend/consts/const.py | 6 + backend/consts/error_code.py | 24 + backend/consts/exceptions.py | 59 +- backend/consts/model.py | 46 + backend/consts/oauth_providers.py | 113 +++ backend/database/db_models.py | 26 + backend/database/oauth_account_db.py | 220 +++++ backend/pyproject.toml | 2 + backend/services/oauth_service.py | 343 +++++++ backend/services/user_service.py | 10 +- backend/utils/auth_utils.py | 86 +- docker/.env.example | 17 + 
docker/init.sql | 54 ++ .../v2.0.3_0430_add_user_oauth_account_t.sql | 52 ++ .../users/components/UserProfileComp.tsx | 6 + frontend/components/auth/loginModal.tsx | 68 +- .../settings/OAuthAccountsSection.tsx | 143 +++ frontend/hooks/auth/useAuthenticationUI.ts | 19 +- frontend/public/locales/en/common.json | 10 + frontend/public/locales/zh/common.json | 10 + frontend/server.js | 19 +- frontend/services/api.ts | 7 + frontend/services/oauthService.ts | 69 ++ .../charts/nexent-common/files/init.sql | 53 ++ .../nexent-common/templates/configmap.yaml | 18 + .../nexent/charts/nexent-common/values.yaml | 12 + test/backend/app/test_oauth_app.py | 767 ++++++++++++++++ .../backend/database/test_oauth_account_db.py | 360 ++++++++ test/backend/services/test_oauth_service.py | 844 ++++++++++++++++++ 31 files changed, 3724 insertions(+), 36 deletions(-) create mode 100644 backend/apps/oauth_app.py create mode 100644 backend/consts/oauth_providers.py create mode 100644 backend/database/oauth_account_db.py create mode 100644 backend/services/oauth_service.py create mode 100644 docker/sql/v2.0.3_0430_add_user_oauth_account_t.sql create mode 100644 frontend/components/settings/OAuthAccountsSection.tsx create mode 100644 frontend/services/oauthService.ts create mode 100644 test/backend/app/test_oauth_app.py create mode 100644 test/backend/database/test_oauth_account_db.py create mode 100644 test/backend/services/test_oauth_service.py diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index fc6267555..c89a1a1a2 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -7,11 +7,14 @@ from apps.vectordatabase_app import router as vectordatabase_router from apps.dify_app import router as dify_router from apps.idata_app import router as idata_router -from apps.file_management_app import file_management_config_router as file_manager_router +from apps.file_management_app import ( + file_management_config_router as file_manager_router, +) from 
apps.image_app import router as proxy_router from apps.knowledge_summary_app import router as summary_router from apps.mock_user_management_app import router as mock_user_management_router from apps.model_managment_app import router as model_manager_router +from apps.oauth_app import router as oauth_router from apps.prompt_app import router as prompt_router from apps.remote_mcp_app import router as remote_mcp_router from apps.skill_app import router as skill_router @@ -53,6 +56,8 @@ logger.info("Normal mode - using real user management router") app.include_router(user_management_router) +app.include_router(oauth_router) + app.include_router(summary_router) app.include_router(prompt_router) app.include_router(skill_router) diff --git a/backend/apps/oauth_app.py b/backend/apps/oauth_app.py new file mode 100644 index 000000000..bda69f935 --- /dev/null +++ b/backend/apps/oauth_app.py @@ -0,0 +1,290 @@ +import logging + +from fastapi import APIRouter, Header, HTTPException +from fastapi.responses import JSONResponse, RedirectResponse +from http import HTTPStatus +from typing import Optional + +from consts.exceptions import OAuthLinkError, OAuthProviderError, UnauthorizedError +from consts.oauth_providers import get_all_provider_definitions +from database.oauth_account_db import get_oauth_account_by_provider +from services.oauth_service import ( + create_or_update_oauth_account, + ensure_user_tenant_exists, + exchange_code_for_provider_token, + get_authorize_url, + get_enabled_providers, + get_provider_user_info, + list_linked_accounts, + unlink_account, parse_state, +) +from utils.auth_utils import ( + calculate_expires_at, + generate_session_jwt, + get_current_user_id, get_supabase_admin_client, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/user/oauth", tags=["oauth"]) + + +@router.get("/providers") +async def get_providers(): + providers = get_enabled_providers() + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": 
"success", "data": providers}, + ) + + +@router.get("/authorize") +async def authorize(provider: str): + try: + url = get_authorize_url(provider) + return RedirectResponse(url=url, status_code=HTTPStatus.FOUND) + except OAuthProviderError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"OAuth authorize failed: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="OAuth authorization failed", + ) + + +@router.get("/link") +async def link(provider: str, authorization: Optional[str] = Header(None)): + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + + try: + user_id, _ = get_current_user_id(authorization) + url = get_authorize_url(provider, link_user_id=user_id) + return RedirectResponse(url=url, status_code=HTTPStatus.FOUND) + except UnauthorizedError: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + except OAuthProviderError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"OAuth link failed: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="OAuth link failed", + ) + + +@router.get("/callback") +async def callback( + provider: str, + code: str = "", + state: str = "", + error: Optional[str] = None, + error_description: Optional[str] = None, +): + if error: + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={ + "message": "OAuth provider returned an error", + "data": { + "oauth_error": error, + "oauth_error_description": error_description or "Unknown error", + }, + }, + ) + + if not code: + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={ + "message": "No authorization code received", + "data": { + "oauth_error": "no_code", + "oauth_error_description": "No authorization code received", + }, + }, + ) + + if provider not in 
get_all_provider_definitions(): + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={ + "message": "Unsupported OAuth provider", + "data": { + "oauth_error": "unsupported_provider", + "oauth_error_description": f"Provider '{provider}' is not supported", + }, + }, + ) + + state_info = parse_state(state) + link_user_id = state_info.get("link_user_id", "") + + try: + token_data = exchange_code_for_provider_token(provider, code) + provider_access_token = token_data["access_token"] + + user_info = get_provider_user_info( + provider, + provider_access_token, + openid=token_data.get("openid", ""), + ) + + provider_user_id = user_info["id"] + email = user_info["email"] + username = user_info["username"] + + if link_user_id: + supabase_user_id = link_user_id + else: + # First check if this OAuth account is already bound to a user + existing_binding = get_oauth_account_by_provider(provider, provider_user_id) + if existing_binding: + supabase_user_id = existing_binding["user_id"] + else: + # No binding found, search/create user by email in Supabase + admin_client = get_supabase_admin_client() + if not admin_client: + raise RuntimeError("Supabase admin client not available") + + supabase_user_id = None + page = 1 + while True: + users_resp = admin_client.auth.admin.list_users( + page=page, per_page=100 + ) + users = users_resp if len(users_resp) > 0 else [] + if not users: + break + for u in users: + if u.email and u.email.lower() == email.lower(): + supabase_user_id = u.id + break + if supabase_user_id: + break + if len(users) < 100: + break + page += 1 + + if not supabase_user_id: + if not email: + email = f"{provider}_{provider_user_id}@oauth.nexent" + create_resp = admin_client.auth.admin.create_user( + { + "email": email, + "email_confirm": True, + "user_metadata": { + "full_name": username, + "provider": provider, + }, + } + ) + supabase_user_id = create_resp.user.id + + ensure_user_tenant_exists(user_id=supabase_user_id, email=email) + + 
create_or_update_oauth_account( + user_id=supabase_user_id, + provider=provider, + provider_user_id=provider_user_id, + email=email, + username=username, + ) + + expiry_seconds = 3600 + jwt_token = generate_session_jwt(supabase_user_id, expires_in=expiry_seconds) + expires_at = calculate_expires_at(jwt_token) + + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": "OAuth login successful", + "data": { + "user": { + "id": str(supabase_user_id), + "email": email, + }, + "session": { + "access_token": jwt_token, + "refresh_token": "", + "expires_at": expires_at, + "expires_in_seconds": expiry_seconds, + }, + }, + }, + ) + + except Exception as e: + logger.error(f"OAuth callback failed for provider={provider}: {e}") + return JSONResponse( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + content={ + "message": "OAuth login failed", + "data": { + "oauth_error": "callback_failed", + "oauth_error_description": "OAuth login failed", + }, + }, + ) + + +@router.get("/accounts") +async def get_accounts(authorization: Optional[str] = Header(None)): + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + + try: + user_id, _ = get_current_user_id(authorization) + accounts = list_linked_accounts(user_id) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": accounts}, + ) + except UnauthorizedError: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + except Exception as e: + logger.error(f"Failed to get OAuth accounts: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Failed to get OAuth accounts", + ) + + +@router.delete("/accounts/{provider}") +async def delete_account(provider: str, authorization: Optional[str] = Header(None)): + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + + try: + user_id, _ = get_current_user_id(authorization) + + 
has_password_auth = False + + admin_client = get_supabase_admin_client() + if admin_client: + try: + user_resp = admin_client.auth.admin.get_user_by_id(user_id) + user_metadata = getattr(user_resp.user, "user_metadata", {}) or {} + signup_provider = user_metadata.get("provider", "email") + has_password_auth = signup_provider == "email" + except Exception as e: + logger.warning(f"Failed to check user identities for {user_id}: {e}") + + unlink_account(user_id, provider, has_password_auth=has_password_auth) + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": "success", + "data": {"provider": provider, "unlinked": True}, + }, + ) + except OAuthLinkError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except UnauthorizedError: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Not logged in") + except Exception as e: + logger.error(f"Failed to unlink OAuth account: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Failed to unlink OAuth account", + ) diff --git a/backend/consts/const.py b/backend/consts/const.py index bccb91ccd..796db4987 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -69,6 +69,12 @@ class VectorDatabaseType(str, Enum): SUPABASE_JWT_SECRET = os.getenv('SUPABASE_JWT_SECRET') or os.getenv('JWT_SECRET', '') +# OAuth Configuration +OAUTH_CALLBACK_BASE_URL = os.getenv("OAUTH_CALLBACK_BASE_URL", "") +OAUTH_SSL_VERIFY = os.getenv("OAUTH_SSL_VERIFY", "true").lower() == "true" +OAUTH_CA_BUNDLE = os.getenv("OAUTH_CA_BUNDLE", "") + + # ===== To be migrated to frontend configuration ===== # Email Configuration IMAP_SERVER = os.getenv('IMAP_SERVER') diff --git a/backend/consts/error_code.py b/backend/consts/error_code.py index 072243de4..4b4792e47 100644 --- a/backend/consts/error_code.py +++ b/backend/consts/error_code.py @@ -142,6 +142,20 @@ class ErrorCode(Enum): PROFILE_USER_ALREADY_EXISTS = "110103" # User already exists 
PROFILE_INVALID_CREDENTIALS = "110104" # Invalid credentials + # ==================== 16 OAuth / 第三方登录 ==================== + # 01 - Provider + OAUTH_PROVIDER_NOT_CONFIGURED = "160101" # OAuth provider not configured + OAUTH_PROVIDER_DISABLED = "160102" # OAuth provider disabled + OAUTH_PROVIDER_UNSUPPORTED = "160103" # OAuth provider not supported + OAUTH_PROVIDER_ERROR = "160104" # OAuth provider returned an error + + # 02 - Account Linking + OAUTH_LINK_FAILED = "160201" # Failed to link OAuth account + OAUTH_UNLINK_FAILED = "160202" # Failed to unlink OAuth account + OAUTH_UNLINK_LAST_METHOD = "160203" # Cannot unlink last auth method + OAUTH_ACCOUNT_NOT_FOUND = "160204" # OAuth account link not found + OAUTH_ACCOUNT_ALREADY_LINKED = "160205" # OAuth account already linked + # ==================== 12 TenantResource / 租户资源 ==================== # 01 - Tenant TENANT_NOT_FOUND = "120101" # Tenant not found @@ -237,4 +251,14 @@ class ErrorCode(Enum): ErrorCode.IDATA_CONNECTION_ERROR: 502, ErrorCode.IDATA_RESPONSE_ERROR: 502, ErrorCode.IDATA_RATE_LIMIT: 429, + # OAuth (module 16) + ErrorCode.OAUTH_PROVIDER_NOT_CONFIGURED: 400, + ErrorCode.OAUTH_PROVIDER_DISABLED: 400, + ErrorCode.OAUTH_PROVIDER_UNSUPPORTED: 400, + ErrorCode.OAUTH_PROVIDER_ERROR: 502, + ErrorCode.OAUTH_LINK_FAILED: 500, + ErrorCode.OAUTH_UNLINK_FAILED: 500, + ErrorCode.OAUTH_UNLINK_LAST_METHOD: 400, + ErrorCode.OAUTH_ACCOUNT_NOT_FOUND: 404, + ErrorCode.OAUTH_ACCOUNT_ALREADY_LINKED: 409, } diff --git a/backend/consts/exceptions.py b/backend/consts/exceptions.py index 074b4a5b0..9481ebab2 100644 --- a/backend/consts/exceptions.py +++ b/backend/consts/exceptions.py @@ -6,13 +6,13 @@ 1. New Framework (with ErrorCode): from consts.error_code import ErrorCode from consts.exceptions import AppException - + raise AppException(ErrorCode.COMMON_VALIDATION_ERROR, "Validation failed") raise AppException(ErrorCode.MCP_CONNECTION_FAILED, "Connection timeout", details={"host": "localhost"}) 2. 
Legacy Framework (simple exceptions): from consts.exceptions import ValidationError, NotFoundException, MCPConnectionError - + raise ValidationError("Tenant name cannot be empty") raise NotFoundException("Tenant 123 not found") raise MCPConnectionError("MCP connection failed") @@ -26,6 +26,7 @@ # ==================== New Framework: AppException with ErrorCode ==================== + class AppException(Exception): """ Base application exception with ErrorCode. @@ -35,7 +36,9 @@ class AppException(Exception): raise AppException(ErrorCode.MCP_CONNECTION_FAILED, "Timeout", details={"host": "x"}) """ - def __init__(self, error_code: ErrorCode, message: str = None, details: dict = None): + def __init__( + self, error_code: ErrorCode, message: str = None, details: dict = None + ): self.error_code = error_code self.message = message or ErrorMessage.get_message(error_code) self.details = details or {} @@ -43,9 +46,11 @@ def __init__(self, error_code: ErrorCode, message: str = None, details: dict = N def to_dict(self) -> dict: return { - "code": str(self.error_code.value), # Keep as string to preserve leading zeros + "code": str( + self.error_code.value + ), # Keep as string to preserve leading zeros "message": self.message, - "details": self.details if self.details else None + "details": self.details if self.details else None, } @property @@ -70,133 +75,172 @@ def raise_error(error_code: ErrorCode, message: str = None, details: dict = None # These do NOT require ErrorCode - they are simple Exception subclasses. # Exception handler will infer ErrorCode from class name. 
+ class AgentRunException(Exception): """Exception raised when agent run fails.""" + pass class LimitExceededError(Exception): """Raised when an outer platform calling too frequently""" + pass class UnauthorizedError(Exception): """Raised when a user from outer platform is unauthorized.""" + pass class SignatureValidationError(Exception): """Raised when X-Signature header is missing or does not match the expected HMAC value.""" + pass class MemoryPreparationException(Exception): """Raised when memory preprocessing or retrieval fails prior to agent run.""" + pass class MCPConnectionError(Exception): """Raised when MCP connection fails.""" + pass class MCPNameIllegal(Exception): """Raised when MCP name is illegal.""" + pass class NoInviteCodeException(Exception): """Raised when invite code is not found.""" + pass class IncorrectInviteCodeException(Exception): """Raised when invite code is incorrect.""" + pass class OfficeConversionException(Exception): """Raised when Office-to-PDF conversion via data-process service fails.""" + pass class UnsupportedFileTypeException(Exception): """Raised when a file type is not supported for the requested operation.""" + pass class FileTooLargeException(Exception): """Raised when a file exceeds the maximum allowed size for the requested operation.""" + pass class UserRegistrationException(Exception): """Raised when user registration fails.""" + pass class TimeoutException(Exception): """Raised when timeout occurs.""" + pass class ValidationError(Exception): """Raised when validation fails.""" + pass class NotFoundException(Exception): """Raised when not found exception occurs.""" + pass class MEConnectionException(Exception): """Raised when ME connection fails.""" + pass class VoiceServiceException(Exception): """Raised when voice service fails.""" + pass class STTConnectionException(Exception): """Raised when STT service connection fails.""" + pass class TTSConnectionException(Exception): """Raised when TTS service connection 
fails.""" + pass class VoiceConfigException(Exception): """Raised when voice configuration is invalid.""" + pass class ToolExecutionException(Exception): """Raised when mcp tool execution failed.""" + pass class MCPContainerError(Exception): """Raised when MCP container operation fails.""" + pass class DuplicateError(Exception): """Raised when a duplicate resource already exists.""" + pass class DataMateConnectionError(Exception): """Raised when DataMate connection fails or URL is not configured.""" + pass class SkillException(Exception): """Raised when skill operations fail.""" + + pass + + +class OAuthProviderError(Exception): + """Raised when OAuth provider configuration is invalid or provider returns an error.""" + + pass + + +class OAuthLinkError(Exception): + """Raised when linking or unlinking an OAuth account fails.""" + pass @@ -251,5 +295,10 @@ class UnsupportedOperationError(Exception): DifyServiceException = Exception # Generic fallback ExternalAPIError = Exception # Generic fallback +# OAuth aliases +OAuthProviderNotConfiguredError = OAuthProviderError +OAuthProviderDisabledError = OAuthProviderError +OAuthAccountNotFoundError = NotFoundException + # Signature aliases # SignatureValidationError already defined above diff --git a/backend/consts/model.py b/backend/consts/model.py index 05e6426b2..59653a257 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -52,6 +52,52 @@ class UserDeleteRequest(BaseModel): new_owner_id: Optional[str] = None +class OAuthProviderDefinition(BaseModel): + name: str + display_name: str + icon: str + + authorize_url: str + authorize_method: str = "GET" + authorize_params: Dict[str, str] = {} + authorize_fragment: str = "" + authorize_param_map: Dict[str, str] = { + "client_id": "client_id", + "redirect_uri": "redirect_uri", + "scope": "scope", + "state": "state", + } + encode_redirect_uri: bool = False + + token_url: str + token_method: str = "POST" + token_params_map: Dict[str, str] = { + "client_id": 
"client_id", + "client_secret": "client_secret", + "code": "code", + "grant_type": "grant_type", + } + token_extra_params: Dict[str, str] = {} + token_error_key: Optional[str] = None + token_error_message_key: Optional[str] = None + token_response_id_key: Optional[str] = None + + userinfo_url: str + userinfo_auth_scheme: str = "Bearer" + userinfo_params: Dict[str, str] = {} + userinfo_field_map: Dict[str, str] = { + "id": "id", + "email": "email", + "username": "login", + } + userinfo_needs_email_fetch: bool = False + userinfo_email_url: Optional[str] = None + + client_id_env: str + client_secret_env: str + enabled_check: Optional[str] = None + + # Response models for model management class ModelResponse(BaseModel): code: int = 200 diff --git a/backend/consts/oauth_providers.py b/backend/consts/oauth_providers.py new file mode 100644 index 000000000..2dd01f0d6 --- /dev/null +++ b/backend/consts/oauth_providers.py @@ -0,0 +1,113 @@ +import os +from typing import Dict + +from consts.model import OAuthProviderDefinition + +GITHUB_PROVIDER = OAuthProviderDefinition( + name="github", + display_name="GitHub", + icon="github", + authorize_url="https://github.com/login/oauth/authorize", + authorize_params={"scope": "read:user user:email"}, + token_url="https://github.com/login/oauth/access_token", + token_error_key="error", + token_error_message_key="error_description", + userinfo_url="https://api.github.com/user", + userinfo_field_map={ + "id": "id", + "email": "email", + "username": "login", + }, + userinfo_needs_email_fetch=True, + userinfo_email_url="https://api.github.com/user/emails", + client_id_env="GITHUB_OAUTH_CLIENT_ID", + client_secret_env="GITHUB_OAUTH_CLIENT_SECRET", +) + +GDE_PROVIDER = OAuthProviderDefinition( + name="gde", + display_name="Gde", + icon="gde", + authorize_url=f"{os.getenv('GDE_URL')}/dspcas/oauth2.0/authorize", + authorize_param_map={"client_id": "client_id", "redirect_uri": "redirect_uri"}, + 
token_url=f"{os.getenv('GDE_URL')}/dspcas/v2/oauth2.0/accessToken", + token_params_map={ + "client_id": "client_id", + "client_secret": "secret", + "code": "code", + "grant_type": "grant_type", + "redirect_uri": "redirect_uri", + }, + token_error_key="errorCode", + token_error_message_key="errorMessage", + userinfo_url=f"{os.getenv('GDE_URL')}/dspcas/oauth2.0/profile", + userinfo_params={"access_token": "{access_token}"}, + userinfo_field_map={"id": "attributes.userId", "username": "id"}, + client_id_env="GDE_OAUTH_CLIENT_ID", + client_secret_env="GDE_OAUTH_CLIENT_SECRET", +) + +WECHAT_PROVIDER = OAuthProviderDefinition( + name="wechat", + display_name="WeChat", + icon="wechat", + authorize_url="https://open.weixin.qq.com/connect/qrconnect", + authorize_params={"response_type": "code", "scope": "snsapi_login"}, + authorize_fragment="#wechat_redirect", + authorize_param_map={ + "client_id": "appid", + "redirect_uri": "redirect_uri", + "scope": "scope", + "state": "state", + }, + encode_redirect_uri=True, + token_url="https://api.weixin.qq.com/sns/oauth2/access_token", + token_method="GET", + token_params_map={ + "client_id": "appid", + "client_secret": "secret", + "code": "code", + "grant_type": "grant_type", + }, + token_error_key="errcode", + token_error_message_key="errmsg", + token_response_id_key="openid", + userinfo_url="https://api.weixin.qq.com/sns/userinfo", + userinfo_auth_scheme="", + userinfo_params={"openid": "{openid}"}, + userinfo_field_map={ + "id": "openid", + "email": "", + "username": "nickname", + }, + client_id_env="WECHAT_OAUTH_APP_ID", + client_secret_env="WECHAT_OAUTH_APP_SECRET", + enabled_check="ENABLE_WECHAT_OAUTH", +) + +OAUTH_PROVIDER_REGISTRY: Dict[str, OAuthProviderDefinition] = { + "github": GITHUB_PROVIDER, + "wechat": WECHAT_PROVIDER, + "gde": GDE_PROVIDER, +} + + +def get_provider_definition(provider: str) -> OAuthProviderDefinition: + return OAUTH_PROVIDER_REGISTRY[provider] + + +def is_provider_enabled(definition: 
OAuthProviderDefinition) -> bool: + if definition.enabled_check: + return os.getenv(definition.enabled_check, "false").lower() in ( + "true", + "1", + "yes", + ) + + client_id = os.getenv(definition.client_id_env, "") + client_secret = os.getenv(definition.client_secret_env, "") + return bool(client_id and client_secret) + + +def get_all_provider_definitions() -> Dict[str, OAuthProviderDefinition]: + return dict(OAUTH_PROVIDER_REGISTRY) diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 3741dd559..88c0874ee 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -523,6 +523,32 @@ class UserTokenUsageLog(TableBase): meta_data = Column(JSONB, doc="Additional metadata for this usage log entry, stored as JSON") +class UserOAuthAccount(TableBase): + __tablename__ = "user_oauth_account_t" + __table_args__ = ( + UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"), + {"schema": SCHEMA}, + ) + + oauth_account_id = Column( + Integer, + Sequence("user_oauth_account_t_oauth_account_id_seq", schema=SCHEMA), + primary_key=True, + nullable=False, + doc="OAuth account ID, primary key", + ) + user_id = Column(String(100), nullable=False, doc="Supabase user UUID") + provider = Column( + String(30), nullable=False, doc="OAuth provider name: github, wechat" + ) + provider_user_id = Column( + String(200), nullable=False, doc="User ID from the OAuth provider" + ) + provider_email = Column(String(255), doc="Email address from the OAuth provider") + provider_username = Column(String(200), doc="Display name from the OAuth provider") + tenant_id = Column(String(100), doc="Tenant ID at time of linking") + + class SkillInfo(TableBase): """ Skill information table - stores skill metadata and content. 
diff --git a/backend/database/oauth_account_db.py b/backend/database/oauth_account_db.py new file mode 100644 index 000000000..3b798f738 --- /dev/null +++ b/backend/database/oauth_account_db.py @@ -0,0 +1,220 @@ +""" +Database operations for OAuth account management +""" + +import logging +from typing import Any, Dict, List, Optional + +from database.client import as_dict, get_db_session +from database.db_models import UserOAuthAccount + +logger = logging.getLogger(__name__) + + +def insert_oauth_account( + user_id: str, + provider: str, + provider_user_id: str, + provider_email: Optional[str] = None, + provider_username: Optional[str] = None, + tenant_id: Optional[str] = None, +) -> Dict[str, Any]: + with get_db_session() as session: + account = UserOAuthAccount( + user_id=user_id, + provider=provider, + provider_user_id=provider_user_id, + provider_email=provider_email, + provider_username=provider_username, + tenant_id=tenant_id, + created_by=user_id, + updated_by=user_id, + ) + session.add(account) + session.flush() + return as_dict(account) + + +def get_oauth_account_by_provider( + provider: str, provider_user_id: str +) -> Optional[Dict[str, Any]]: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.provider == provider, + UserOAuthAccount.provider_user_id == provider_user_id, + UserOAuthAccount.delete_flag == "N", + ) + .first() + ) + return as_dict(result) if result else None + + +def get_soft_deleted_oauth_account( + provider: str, provider_user_id: str +) -> Optional[Dict[str, Any]]: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.provider == provider, + UserOAuthAccount.provider_user_id == provider_user_id, + UserOAuthAccount.delete_flag == "Y", + ) + .first() + ) + return as_dict(result) if result else None + + +def list_oauth_accounts_by_user_id(user_id: str) -> List[Dict[str, Any]]: + with get_db_session() as session: + results 
= ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.user_id == user_id, + UserOAuthAccount.delete_flag == "N", + ) + .all() + ) + return [as_dict(r) for r in results] + + +def rebind_oauth_account( + provider: str, + provider_user_id: str, + new_user_id: str, + provider_email: Optional[str] = None, + provider_username: Optional[str] = None, + tenant_id: Optional[str] = None, +) -> bool: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.provider == provider, + UserOAuthAccount.provider_user_id == provider_user_id, + UserOAuthAccount.delete_flag == "N", + ) + .first() + ) + if not result: + return False + + result.user_id = new_user_id + result.updated_by = new_user_id + if provider_email is not None: + result.provider_email = provider_email + if provider_username is not None: + result.provider_username = provider_username + if tenant_id is not None: + result.tenant_id = tenant_id + + return True + + +def update_oauth_account_tokens( + provider: str, + provider_user_id: str, + provider_username: Optional[str] = None, +) -> bool: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.provider == provider, + UserOAuthAccount.provider_user_id == provider_user_id, + UserOAuthAccount.delete_flag == "N", + ) + .first() + ) + if not result: + return False + + if provider_username is not None: + result.provider_username = provider_username + + return True + + +def delete_oauth_account(user_id: str, provider: str) -> bool: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.user_id == user_id, + UserOAuthAccount.provider == provider, + UserOAuthAccount.delete_flag == "N", + ) + .first() + ) + if not result: + return False + + result.delete_flag = "Y" + result.updated_by = user_id + return True + + +def reactivate_oauth_account( + provider: str, + provider_user_id: str, + user_id: 
str, + provider_email: Optional[str] = None, + provider_username: Optional[str] = None, + tenant_id: Optional[str] = None, +) -> bool: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.provider == provider, + UserOAuthAccount.provider_user_id == provider_user_id, + UserOAuthAccount.delete_flag == "Y", + ) + .first() + ) + if not result: + return False + + result.delete_flag = "N" + result.user_id = user_id + result.updated_by = user_id + if provider_email is not None: + result.provider_email = provider_email + if provider_username is not None: + result.provider_username = provider_username + if tenant_id is not None: + result.tenant_id = tenant_id + + return True + + +def count_oauth_accounts_by_user_id(user_id: str) -> int: + with get_db_session() as session: + return ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.user_id == user_id, + UserOAuthAccount.delete_flag == "N", + ) + .count() + ) + + +def soft_delete_all_oauth_accounts_by_user_id(user_id: str, deleted_by: str) -> int: + with get_db_session() as session: + result = ( + session.query(UserOAuthAccount) + .filter( + UserOAuthAccount.user_id == user_id, + UserOAuthAccount.delete_flag == "N", + ) + .all() + ) + count = 0 + for account in result: + account.delete_flag = "Y" + account.updated_by = deleted_by + count += 1 + return count \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 04b94589c..3ac17b7be 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -6,6 +6,8 @@ dependencies = [ "uvicorn>=0.34.0", "fastapi>=0.115.12", "aiohttp>=3.8.0", + "authlib>=1.3.0", + "cryptography>=42.0.0", "psycopg2-binary==2.9.10", "PyJWT>=2.8.0", "sqlalchemy~=2.0.37", diff --git a/backend/services/oauth_service.py b/backend/services/oauth_service.py new file mode 100644 index 000000000..0083ad9ec --- /dev/null +++ b/backend/services/oauth_service.py @@ -0,0 +1,343 @@ +import json +import 
logging +import os +import secrets +import ssl +import urllib.request +from typing import Any, Dict, List, Optional +from urllib.parse import urlencode, quote + +from consts.const import ( + DEFAULT_TENANT_ID, + OAUTH_CALLBACK_BASE_URL, + OAUTH_SSL_VERIFY, + OAUTH_CA_BUNDLE, +) +from consts.exceptions import OAuthLinkError, OAuthProviderError +from consts.oauth_providers import ( + get_all_provider_definitions, + get_provider_definition, + is_provider_enabled, +) +from database.oauth_account_db import ( + count_oauth_accounts_by_user_id, + delete_oauth_account, + get_oauth_account_by_provider, + get_soft_deleted_oauth_account, + insert_oauth_account, + list_oauth_accounts_by_user_id, + reactivate_oauth_account, + update_oauth_account_tokens, +) +from database.user_tenant_db import get_user_tenant_by_user_id, insert_user_tenant + +logger = logging.getLogger(__name__) + + +def _build_ssl_context() -> ssl.SSLContext: + if OAUTH_CA_BUNDLE and os.path.isfile(OAUTH_CA_BUNDLE): + return ssl.create_default_context(cafile=OAUTH_CA_BUNDLE) + if not OAUTH_SSL_VERIFY: + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + return ssl.create_default_context() + + +_SSL_CTX = _build_ssl_context() + + +def parse_state(state: str) -> Dict[str, str]: + parts = state.split(":", 2) + if len(parts) >= 2: + return { + "provider": parts[0], + "token": parts[1], + "link_user_id": parts[2] if len(parts) > 2 else "", + } + return {"provider": state, "token": "", "link_user_id": ""} + + +def _resolve_field(data: dict, field_path: str) -> Any: + if "." 
not in field_path: + return data.get(field_path) + parts = field_path.split(".") + current = data + for part in parts: + if isinstance(current, dict): + current = current.get(part) + else: + return None + return current + + +def get_supported_providers() -> set: + return set(get_all_provider_definitions().keys()) + + +def get_enabled_providers() -> List[Dict[str, str]]: + providers = [] + for name, definition in get_all_provider_definitions().items(): + if is_provider_enabled(definition): + providers.append( + { + "name": definition.name, + "display_name": definition.display_name, + "icon": definition.icon, + "enabled": True, + } + ) + return providers + + +def get_authorize_url(provider: str, link_user_id: str = "") -> str: + try: + definition = get_provider_definition(provider) + except KeyError: + raise OAuthProviderError(f"Unsupported OAuth provider: {provider}") + + if not is_provider_enabled(definition): + raise OAuthProviderError(f"OAuth provider '{provider}' is not configured") + + callback_url = ( + f"{OAUTH_CALLBACK_BASE_URL}/api/user/oauth/callback?provider={provider}" + ) + random_token = secrets.token_urlsafe(32) + if link_user_id: + state = f"{provider}:{random_token}:{link_user_id}" + else: + state = f"{provider}:{random_token}" + + client_id = os.getenv(definition.client_id_env, "") + redirect_uri = ( + quote(callback_url, safe="") if definition.encode_redirect_uri else callback_url + ) + + params = dict(definition.authorize_params) + param_map = definition.authorize_param_map + params[param_map.get("client_id", "client_id")] = client_id + params[param_map.get("redirect_uri", "redirect_uri")] = redirect_uri + params[param_map.get("state", "state")] = state + + url = f"{definition.authorize_url}?{urlencode(params)}" + if definition.authorize_fragment: + url += definition.authorize_fragment + return url + + +def _http_post_json(url: str, data: dict, headers: Optional[dict] = None) -> dict: + req_data = json.dumps(data).encode("utf-8") + req_headers = 
{"Content-Type": "application/json", "Accept": "application/json"} + if headers: + req_headers.update(headers) + req = urllib.request.Request(url, data=req_data, headers=req_headers, method="POST") + with urllib.request.urlopen(req, timeout=15, context=_SSL_CTX) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def _http_get_json(url: str, headers: Optional[dict] = None) -> dict: + req = urllib.request.Request(url, headers=headers or {}) + with urllib.request.urlopen(req, timeout=15, context=_SSL_CTX) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def exchange_code_for_provider_token(provider: str, code: str) -> Dict[str, Any]: + try: + definition = get_provider_definition(provider) + except KeyError: + raise OAuthProviderError(f"Unsupported provider: {provider}") + + client_id = os.getenv(definition.client_id_env, "") + client_secret = os.getenv(definition.client_secret_env, "") + callback_url = ( + f"{OAUTH_CALLBACK_BASE_URL}/api/user/oauth/callback?provider={provider}" + ) + redirect_uri = ( + quote(callback_url, safe="") if definition.encode_redirect_uri else callback_url + ) + + param_map = definition.token_params_map + + result: Dict[str, Any] = {"access_token": ""} + + if definition.token_method.upper() == "POST": + body = dict(definition.token_extra_params) + body[param_map.get("client_id", "client_id")] = client_id + body[param_map.get("client_secret", "client_secret")] = client_secret + body[param_map.get("code", "code")] = code + body.setdefault(param_map.get("grant_type", "grant_type"), "authorization_code") + if param_map.get("redirect_uri", "") == "redirect_uri": + body["redirect_uri"] = redirect_uri + + resp = _http_post_json(definition.token_url, data=body) + else: + params = dict(definition.token_extra_params) + params[param_map.get("client_id", "client_id")] = client_id + params[param_map.get("client_secret", "client_secret")] = client_secret + params[param_map.get("code", "code")] = code + 
params[param_map.get("grant_type", "grant_type")] = "authorization_code" + if param_map.get("redirect_uri", "") == "redirect_uri": + params["redirect_uri"] = redirect_uri + + resp = _http_get_json(f"{definition.token_url}?{urlencode(params)}") + + if definition.token_error_key and definition.token_error_key in resp: + err_msg = resp.get( + definition.token_error_message_key, str(resp[definition.token_error_key]) + ) + raise OAuthProviderError(f"{provider} token exchange failed: {err_msg}") + + result["access_token"] = resp["access_token"] + if definition.token_response_id_key: + result["openid"] = resp.get(definition.token_response_id_key, "") + + return result + + +def get_provider_user_info( + provider: str, access_token: str, **kwargs: Any +) -> Dict[str, Any]: + try: + definition = get_provider_definition(provider) + except KeyError: + raise OAuthProviderError(f"Unsupported provider: {provider}") + + headers: Dict[str, str] = {"Accept": "application/json"} + if definition.userinfo_auth_scheme and access_token: + headers["Authorization"] = f"{definition.userinfo_auth_scheme} {access_token}" + + url_params = {} + for key, value in definition.userinfo_params.items(): + resolved = value.format( + openid=kwargs.get("openid", ""), access_token=access_token + ) + url_params[key] = resolved + + query = urlencode(url_params) if url_params else "" + separator = ( + "&" if "?" in definition.userinfo_url and query else ("?" 
if query else "") + ) + url = f"{definition.userinfo_url}{separator}{query}" + + user_resp = _http_get_json(url, headers=headers) + + field_map = definition.userinfo_field_map + result = {} + for our_key, provider_key in field_map.items(): + if provider_key: + result[our_key] = _resolve_field(user_resp, provider_key) or "" + else: + result[our_key] = "" + result["id"] = str(result.get("id", "")) + + if definition.userinfo_needs_email_fetch and not result.get("email"): + try: + emails_resp = _http_get_json( + definition.userinfo_email_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + if isinstance(emails_resp, list) and emails_resp: + primary = next( + (e for e in emails_resp if e.get("primary")), + emails_resp[0], + ) + result["email"] = primary.get("email", "") + except Exception: + logger.warning(f"Failed to fetch {provider} user emails") + + if result.get("email", "") == "": + result["email"] = f"{result['username']}@nexent.com" + + return result + + +def create_or_update_oauth_account( + user_id: str, + provider: str, + provider_user_id: str, + email: Optional[str] = None, + username: Optional[str] = None, + tenant_id: Optional[str] = None, +) -> Dict[str, Any]: + existing = get_oauth_account_by_provider(provider, provider_user_id) + + if existing: + if existing.get("user_id") != user_id: + raise OAuthLinkError( + f"This {provider} account is already bound to another user" + ) + else: + update_oauth_account_tokens( + provider=provider, + provider_user_id=provider_user_id, + provider_username=username, + ) + updated = get_oauth_account_by_provider(provider, provider_user_id) + return updated if updated else existing + + soft_deleted = get_soft_deleted_oauth_account(provider, provider_user_id) + if soft_deleted: + reactivate_oauth_account( + provider=provider, + provider_user_id=provider_user_id, + user_id=user_id, + provider_email=email, + provider_username=username, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + ) + reactivated = 
get_oauth_account_by_provider(provider, provider_user_id) + return reactivated if reactivated else {"provider": provider, "provider_user_id": provider_user_id, "user_id": user_id} + + return insert_oauth_account( + user_id=user_id, + provider=provider, + provider_user_id=provider_user_id, + provider_email=email, + provider_username=username, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + ) + + +def ensure_user_tenant_exists(user_id: str, email: str) -> Dict[str, Any]: + existing = get_user_tenant_by_user_id(user_id) + if existing: + return existing + + insert_user_tenant( + user_id=user_id, + tenant_id=DEFAULT_TENANT_ID, + user_role="USER", + user_email=email, + ) + logger.info(f"Created user_tenant for new OAuth user {user_id}") + result = get_user_tenant_by_user_id(user_id) + return result if result else {"user_id": user_id, "tenant_id": DEFAULT_TENANT_ID} + + +def list_linked_accounts(user_id: str) -> List[Dict[str, Any]]: + accounts = list_oauth_accounts_by_user_id(user_id) + result = [] + for acct in accounts: + result.append( + { + "provider": acct["provider"], + "provider_username": acct.get("provider_username"), + "provider_email": acct.get("provider_email"), + "linked_at": str(acct.get("create_time", "")), + } + ) + return result + + +def unlink_account( + user_id: str, provider: str, has_password_auth: bool = False +) -> bool: + oauth_count = count_oauth_accounts_by_user_id(user_id) + if oauth_count <= 1 and not has_password_auth: + raise OAuthLinkError("Cannot unlink the last authentication method") + + success = delete_oauth_account(user_id, provider) + if not success: + raise OAuthLinkError(f"No linked {provider} account found") + return True diff --git a/backend/services/user_service.py b/backend/services/user_service.py index ceb471844..6f4edcb1a 100644 --- a/backend/services/user_service.py +++ b/backend/services/user_service.py @@ -11,6 +11,7 @@ from database.group_db import remove_user_from_all_groups from database.memory_config_db import 
soft_delete_all_configs_by_user_id from database.conversation_db import soft_delete_all_conversations_by_user +from database.oauth_account_db import soft_delete_all_oauth_accounts_by_user_id from utils.auth_utils import get_supabase_admin_client from utils.memory_utils import build_memory_config @@ -174,7 +175,14 @@ async def delete_user_and_cleanup(user_id: str, tenant_id: str) -> None: except Exception as e: logger.error(f"Failed clearing memory for user {user_id}: {e}") - # 5) Delete from Supabase + # 5) Soft-delete OAuth account bindings + try: + deleted_oauth = soft_delete_all_oauth_accounts_by_user_id(user_id, user_id) + logger.debug(f"\t{deleted_oauth} OAuth account bindings deleted.") + except Exception as e: + logger.error(f"Failed deleting OAuth accounts for user {user_id}: {e}") + + # 6) Delete from Supabase try: admin_client = get_supabase_admin_client() if admin_client and hasattr(admin_client.auth, "admin"): diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index 7b40576e2..543d49693 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -42,7 +42,9 @@ TIMESTAMP_VALIDITY_WINDOW = 5 * 60 -def calculate_hmac_signature(secret_key: str, access_key: str, timestamp: str, body: str) -> str: +def calculate_hmac_signature( + secret_key: str, access_key: str, timestamp: str, body: str +) -> str: """ Calculate HMAC-SHA256 signature for AK/SK authentication. 
@@ -84,7 +86,9 @@ def get_aksk_config(tenant_id: str) -> Tuple[str, str]: raise UnauthorizedError("AK/SK authentication is not configured") -def verify_aksk_signature(access_key: str, timestamp: str, signature: str, body: str, tenant_id: str = None) -> bool: +def verify_aksk_signature( + access_key: str, timestamp: str, signature: str, body: str, tenant_id: str = None +) -> bool: """Verify AK/SK signature; returns False instead of raising on mismatch.""" tenant = tenant_id or DEFAULT_TENANT_ID try: @@ -99,13 +103,17 @@ def verify_aksk_signature(access_key: str, timestamp: str, signature: str, body: return hmac.compare_digest(expected_sig, signature) -def validate_aksk_authentication(headers: Dict[str, str], body: str, tenant_id: str = None) -> bool: +def validate_aksk_authentication( + headers: Dict[str, str], body: str, tenant_id: str = None +) -> bool: """ Validate AK/SK authentication. Returns True when valid, otherwise raises domain exceptions. """ - from consts.exceptions import SignatureValidationError # imported lazily for test-time stubbing + from consts.exceptions import ( + SignatureValidationError, + ) # imported lazily for test-time stubbing try: access_key, ts, sig = extract_aksk_headers(headers) @@ -129,6 +137,7 @@ def validate_aksk_authentication(headers: Dict[str, str], body: str, tenant_id: logger.exception("Unexpected error during AK/SK authentication") raise UnauthorizedError("Authentication failed") from exc + # --------------------------------------------------------------------------- # Bearer Token (API Key) authentication # --------------------------------------------------------------------------- @@ -151,7 +160,11 @@ def validate_bearer_token(authorization: Optional[str]) -> Tuple[bool, Optional[ return False, None # Extract token from "Bearer " format - token = authorization.replace("Bearer ", "") if authorization.startswith("Bearer ") else authorization + token = ( + authorization.replace("Bearer ", "") + if 
authorization.startswith("Bearer ") + else authorization + ) if not token: logger.warning("Empty bearer token") @@ -161,7 +174,9 @@ def validate_bearer_token(authorization: Optional[str]) -> Tuple[bool, Optional[ try: token_info = get_token_by_access_key(token) if token_info and token_info.get("delete_flag") != "Y": - logger.debug(f"Token validated successfully for user {token_info.get('user_id')}") + logger.debug( + f"Token validated successfully for user {token_info.get('user_id')}" + ) return True, token_info else: logger.warning(f"Invalid or inactive token: {token[:20]}...") @@ -202,12 +217,14 @@ def get_user_and_tenant_by_access_key(access_key: str) -> Dict[str, str]: tenant_id = user_tenant_record["tenant_id"] else: tenant_id = DEFAULT_TENANT_ID - logger.warning(f"No tenant relationship found for user {user_id}, using default tenant") + logger.warning( + f"No tenant relationship found for user {user_id}, using default tenant" + ) return { "user_id": user_id, "tenant_id": tenant_id, - "token_id": token_info.get("token_id") + "token_id": token_info.get("token_id"), } @@ -245,8 +262,9 @@ def get_jwt_expiry_seconds(token: str) -> int: # 10 years in seconds return 10 * 365 * 24 * 60 * 60 # Ensure token is pure JWT, remove possible Bearer prefix - jwt_token = token.replace( - "Bearer ", "") if token.startswith("Bearer ") else token + jwt_token = ( + token.replace("Bearer ", "") if token.startswith("Bearer ") else token + ) # If debug expiration time is set, return directly for quick debugging if DEBUG_JWT_EXPIRE_SECONDS > 0: @@ -300,13 +318,18 @@ def _extract_user_id_from_jwt_token(authorization: str) -> Optional[str]: UnauthorizedError: If token is invalid, expired, or signature verification fails """ if not SUPABASE_JWT_SECRET: - logging.error("SUPABASE_JWT_SECRET (or JWT_SECRET) is not configured; cannot verify JWT") + logging.error( + "SUPABASE_JWT_SECRET (or JWT_SECRET) is not configured; cannot verify JWT" + ) raise UnauthorizedError("JWT verification is not 
configured") try: # Format authorization header - token = authorization.replace("Bearer ", "") if authorization.startswith( - "Bearer ") else authorization + token = ( + authorization.replace("Bearer ", "") + if authorization.startswith("Bearer ") + else authorization + ) # Decode and verify JWT (signature + expiration) # verify_aud=False: allow tokens with aud claim (e.g. test JWT, Supabase) without strict audience check @@ -349,12 +372,13 @@ def get_current_user_id(authorization: Optional[str] = None) -> tuple[str, str]: """ # In speed mode, allow unauthenticated access with default user for demo/dev if IS_SPEED_MODE: - logging.debug( - "Speed mode detected - returning default user ID and tenant ID") + logging.debug("Speed mode detected - returning default user ID and tenant ID") return DEFAULT_USER_ID, DEFAULT_TENANT_ID # In normal mode, missing auth header means unauthorized - return 401, not default user - if authorization is None or (isinstance(authorization, str) and not authorization.strip()): + if authorization is None or ( + isinstance(authorization, str) and not authorization.strip() + ): raise UnauthorizedError("No authorization header provided") try: @@ -363,13 +387,14 @@ def get_current_user_id(authorization: Optional[str] = None) -> tuple[str, str]: raise UnauthorizedError("Invalid or expired authentication token") user_tenant_record = get_user_tenant_by_user_id(user_id) - if user_tenant_record and user_tenant_record.get('tenant_id'): - tenant_id = user_tenant_record['tenant_id'] + if user_tenant_record and user_tenant_record.get("tenant_id"): + tenant_id = user_tenant_record["tenant_id"] logging.debug(f"Found tenant ID for user {user_id}: {tenant_id}") else: tenant_id = DEFAULT_TENANT_ID logging.warning( - f"No tenant relationship found for user {user_id}, using default tenant") + f"No tenant relationship found for user {user_id}, using default tenant" + ) return user_id, tenant_id @@ -393,8 +418,8 @@ def get_user_language(request: Request = None) 
-> str: # Read language setting from cookie if request: try: - if hasattr(request, 'cookies') and request.cookies: - cookie_locale = request.cookies.get('NEXT_LOCALE') + if hasattr(request, "cookies") and request.cookies: + cookie_locale = request.cookies.get("NEXT_LOCALE") if cookie_locale and cookie_locale in [LANGUAGE["ZH"], LANGUAGE["EN"]]: return cookie_locale except (AttributeError, TypeError) as e: @@ -407,6 +432,7 @@ def get_user_language(request: Request = None) -> str: # Simple JWT helpers for tests and tooling # --------------------------------------------------------------------------- + def generate_test_jwt(user_id: str, expires_in: int = 3600) -> str: """ Generate a simple unsigned JWT for testing purposes (HS256 with dummy secret) @@ -423,7 +449,23 @@ def generate_test_jwt(user_id: str, expires_in: int = 3600) -> str: return jwt.encode(payload, MOCK_JWT_SECRET_KEY, algorithm="HS256") -def get_current_user_info(authorization: Optional[str] = None, request: Request = None) -> tuple[str, str, str]: +def generate_session_jwt(user_id: str, expires_in: int = 3600) -> str: + """Generate a signed JWT compatible with the existing auth verification flow.""" + now = int(time.time()) + payload = { + "sub": user_id, + "role": "authenticated", + "aud": "authenticated", + "iat": now, + "exp": now + expires_in, + "iss": SUPABASE_URL, + } + return jwt.encode(payload, SUPABASE_JWT_SECRET, algorithm="HS256") + + +def get_current_user_info( + authorization: Optional[str] = None, request: Request = None +) -> tuple[str, str, str]: """ Get current user information, including user ID, tenant ID, and language preference diff --git a/docker/.env.example b/docker/.env.example index a8ec6dedb..888609e04 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -163,3 +163,20 @@ LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 # Market Backend Address MARKET_BACKEND=http://60.204.251.153:8010 + +# ===== OAuth Configuration ===== +# GitHub OAuth - get credentials from 
https://github.com/settings/developers +GITHUB_OAUTH_CLIENT_ID= +GITHUB_OAUTH_CLIENT_SECRET= +# GDE OAuth +GDE_URL= +GDE_OAUTH_CLIENT_ID= +GDE_OAUTH_CLIENT_SECRET= +# WeChat OAuth (set ENABLE_WECHAT_OAUTH=true to enable) +ENABLE_WECHAT_OAUTH=false +WECHAT_OAUTH_APP_ID= +WECHAT_OAUTH_APP_SECRET= +# Base URL for OAuth callback (e.g., http://localhost:3000 for local dev) +OAUTH_SSL_VERIFY=true +OAUTH_CA_BUNDLE= +OAUTH_CALLBACK_BASE_URL=http://localhost:3000 diff --git a/docker/init.sql b/docker/init.sql index 6ca77f731..f4f10cc31 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -1517,3 +1517,57 @@ COMMENT ON COLUMN nexent.ag_a2a_artifact_t.parts IS 'Artifact parts following A2 COMMENT ON COLUMN nexent.ag_a2a_artifact_t.meta_data IS 'Artifact metadata'; COMMENT ON COLUMN nexent.ag_a2a_artifact_t.extensions IS 'Extension URI list'; COMMENT ON COLUMN nexent.ag_a2a_artifact_t.create_time IS 'Artifact creation timestamp'; + + +-- Create user OAuth account table for third-party login (GitHub, WeChat, etc.) 
+CREATE TABLE IF NOT EXISTS nexent.user_oauth_account_t ( + oauth_account_id SERIAL PRIMARY KEY, + user_id VARCHAR(100) NOT NULL, + provider VARCHAR(30) NOT NULL, + provider_user_id VARCHAR(200) NOT NULL, + provider_email VARCHAR(255), + provider_username VARCHAR(200), + tenant_id VARCHAR(100), + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag CHAR(1) DEFAULT 'N', + CONSTRAINT uq_oauth_provider_user UNIQUE (provider, provider_user_id) +); + +ALTER TABLE nexent.user_oauth_account_t OWNER TO "root"; + +-- Create a function to update the update_time column +CREATE OR REPLACE FUNCTION update_user_oauth_account_t_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create a trigger to call the function before each update +CREATE TRIGGER update_user_oauth_account_t_update_time_trigger +BEFORE UPDATE ON nexent.user_oauth_account_t +FOR EACH ROW +EXECUTE FUNCTION update_user_oauth_account_t_update_time(); + +-- Add comments +COMMENT ON TABLE nexent.user_oauth_account_t IS 'User OAuth account table - third-party login bindings'; +COMMENT ON COLUMN nexent.user_oauth_account_t.oauth_account_id IS 'OAuth account ID, primary key'; +COMMENT ON COLUMN nexent.user_oauth_account_t.user_id IS 'Nexent user ID (Supabase UUID)'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider IS 'OAuth provider name: github, wechat'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_user_id IS 'User ID from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_email IS 'Email from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_username IS 'Display name from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.tenant_id IS 'Tenant ID at time of linking'; +COMMENT ON COLUMN nexent.user_oauth_account_t.create_time IS 
'Creation time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.user_oauth_account_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.user_oauth_account_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +-- Create index for user_id queries +CREATE INDEX IF NOT EXISTS idx_user_oauth_account_t_user_id +ON nexent.user_oauth_account_t (user_id); diff --git a/docker/sql/v2.0.3_0430_add_user_oauth_account_t.sql b/docker/sql/v2.0.3_0430_add_user_oauth_account_t.sql new file mode 100644 index 000000000..18ca52dc3 --- /dev/null +++ b/docker/sql/v2.0.3_0430_add_user_oauth_account_t.sql @@ -0,0 +1,52 @@ +-- Create user OAuth account table for third-party login (GitHub, WeChat, etc.) +CREATE TABLE IF NOT EXISTS nexent.user_oauth_account_t ( + oauth_account_id SERIAL PRIMARY KEY, + user_id VARCHAR(100) NOT NULL, + provider VARCHAR(30) NOT NULL, + provider_user_id VARCHAR(200) NOT NULL, + provider_email VARCHAR(255), + provider_username VARCHAR(200), + tenant_id VARCHAR(100), + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag CHAR(1) DEFAULT 'N', + CONSTRAINT uq_oauth_provider_user UNIQUE (provider, provider_user_id) +); + +ALTER TABLE nexent.user_oauth_account_t OWNER TO "root"; + +-- Create a function to update the update_time column +CREATE OR REPLACE FUNCTION update_user_oauth_account_t_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create a trigger to call the function before each update +CREATE TRIGGER update_user_oauth_account_t_update_time_trigger +BEFORE UPDATE ON nexent.user_oauth_account_t +FOR EACH ROW +EXECUTE FUNCTION update_user_oauth_account_t_update_time(); + +-- Add comments +COMMENT ON TABLE 
nexent.user_oauth_account_t IS 'User OAuth account table - third-party login bindings'; +COMMENT ON COLUMN nexent.user_oauth_account_t.oauth_account_id IS 'OAuth account ID, primary key'; +COMMENT ON COLUMN nexent.user_oauth_account_t.user_id IS 'Nexent user ID (Supabase UUID)'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider IS 'OAuth provider name: github, wechat'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_user_id IS 'User ID from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_email IS 'Email from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_username IS 'Display name from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.tenant_id IS 'Tenant ID at time of linking'; +COMMENT ON COLUMN nexent.user_oauth_account_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.user_oauth_account_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.user_oauth_account_t.delete_flag IS 'Whether it is deleted. 
Optional values: Y/N'; + +-- Create index for user_id queries +CREATE INDEX IF NOT EXISTS idx_user_oauth_account_t_user_id +ON nexent.user_oauth_account_t (user_id); diff --git a/frontend/app/[locale]/users/components/UserProfileComp.tsx b/frontend/app/[locale]/users/components/UserProfileComp.tsx index 2a66bd89e..c1dbb403b 100644 --- a/frontend/app/[locale]/users/components/UserProfileComp.tsx +++ b/frontend/app/[locale]/users/components/UserProfileComp.tsx @@ -35,6 +35,7 @@ import { useAuthenticationContext } from "@/components/providers/AuthenticationP import { useGroupList } from "@/hooks/group/useGroupList"; import { useMemo } from "react"; import { DeleteAccountModal } from "@/components/auth/DeleteAccountModal"; +import { OAuthAccountsSection } from "@/components/settings/OAuthAccountsSection"; import log from "@/lib/logger"; import { getUserTokens, @@ -587,6 +588,11 @@ export default function UserProfileComp() { loading={isLoading} disabled={isAdminOrSuperAdmin} /> + + {/* OAuth Linked Accounts */} +
+ +
); } diff --git a/frontend/components/auth/loginModal.tsx b/frontend/components/auth/loginModal.tsx index 0c219bb3d..ba7ea9ff2 100644 --- a/frontend/components/auth/loginModal.tsx +++ b/frontend/components/auth/loginModal.tsx @@ -1,18 +1,53 @@ "use client"; -import { useState } from "react"; +import { useState, useEffect } from "react"; import { useTranslation } from "react-i18next"; -import { Modal, Form, Input, Button, Typography, Space } from "antd"; -import { UserRound, LockKeyhole } from "lucide-react"; -import { usePathname, useRouter } from "next/navigation"; +import { Modal, Form, Input, Button, Typography, Space, Divider, Alert } from "antd"; +import { UserRound, LockKeyhole, Github, Link2 } from "lucide-react"; +import { usePathname, useRouter, useSearchParams } from "next/navigation"; import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { getEffectiveRoutePath } from "@/lib/auth"; +import { oauthService } from "@/services/oauthService"; import log from "@/lib/logger"; const { Text } = Typography; +const providerIconMap: Record = { + github: , +}; + +function OAuthLoginButtons() { + const { t } = useTranslation("common"); + const [providers, setProviders] = useState>([]); + + useEffect(() => { + oauthService.getEnabledProviders().then((p) => setProviders(p)); + }, []); + + if (providers.length === 0) return null; + + return ( +
+ {t("auth.oauthDivider") || "or"} +
+ {providers.map((provider) => ( + + ))} +
+
+ ); +} + /** * LoginModal Component * Handles user authentication through a modal interface @@ -32,14 +67,26 @@ export function LoginModal() { const router = useRouter(); const pathname = usePathname(); + const searchParams = useSearchParams(); const [form] = Form.useForm(); const [isLoading, setIsLoading] = useState(false); const [emailError, setEmailError] = useState(""); const [passwordError, setPasswordError] = useState(false); + const [oauthError, setOauthError] = useState(null); + + useEffect(() => { + const error = searchParams.get("oauth_error"); + const description = searchParams.get("oauth_error_description"); + if (error) { + setOauthError(description || error); + router.replace("/"); + } + }, [searchParams, router]); const resetForm = () => { setEmailError(""); setPasswordError(false); + setOauthError(null); form.resetFields(); }; @@ -188,6 +235,16 @@ export function LoginModal() { className="mt-6" autoComplete="off" > + {oauthError && ( + setOauthError(null)} + className="mb-4" + /> + )} {/* Email input field */} + {/* OAuth login section */} + + {/* Registration link section (hidden when opened from session expired flow) */}
diff --git a/frontend/components/settings/OAuthAccountsSection.tsx b/frontend/components/settings/OAuthAccountsSection.tsx new file mode 100644 index 000000000..9baf08377 --- /dev/null +++ b/frontend/components/settings/OAuthAccountsSection.tsx @@ -0,0 +1,143 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { useTranslation } from "react-i18next"; +import { Button, Card, Modal, message } from "antd"; +import { Github, Unlink, Link2, Plus } from "lucide-react"; + +import { + oauthService, + type OAuthAccount, + type OAuthProvider, +} from "@/services/oauthService"; + +const providerIcons: Record = { + github: , +}; + +interface ProviderRow { + name: string; + display_name: string; + linked: boolean; + account?: OAuthAccount; +} + +export function OAuthAccountsSection() { + const { t } = useTranslation("common"); + const [accounts, setAccounts] = useState([]); + const [enabledProviders, setEnabledProviders] = useState([]); + const [loading, setLoading] = useState(false); + const [unlinkTarget, setUnlinkTarget] = useState(null); + + useEffect(() => { + loadData(); + }, []); + + const loadData = async () => { + setLoading(true); + const [linked, providers] = await Promise.all([ + oauthService.getLinkedAccounts(), + oauthService.getEnabledProviders(), + ]); + setAccounts(linked); + setEnabledProviders(providers); + setLoading(false); + }; + + const handleUnlink = async () => { + if (!unlinkTarget) return; + + try { + const success = await oauthService.unlinkAccount(unlinkTarget.provider); + if (success) { + message.success(t("auth.unlinkSuccess")); + await loadData(); + } else { + message.error(t("auth.unlinkFailed")); + } + } finally { + setUnlinkTarget(null); + } + }; + + const accountMap = new Map(accounts.map((a) => [a.provider, a])); + const rows: ProviderRow[] = enabledProviders.map((p) => { + const account = accountMap.get(p.name); + return { + name: p.name, + display_name: p.display_name, + linked: !!account, + account: account, + }; + 
}); + + return ( + {t("auth.linkedAccounts")}} + loading={loading} + className="mt-4" + > + {rows.length === 0 ? ( +
+ {t("auth.noLinkedAccounts")} +
+ ) : ( +
+ {rows.map((row) => ( +
+
+
+ {providerIcons[row.name] || } +
+
+
+ {row.display_name} +
+
+ {row.linked + ? row.account!.provider_username || row.account!.provider_email || "-" + : t("auth.noLinkedAccounts")} +
+
+
+ {row.linked ? ( + + ) : ( + + )} +
+ ))} +
+ )} + + setUnlinkTarget(null)} + okText={t("auth.confirm")} + cancelText={t("auth.cancel")} + okButtonProps={{ danger: true }} + > +

{t("auth.unlinkConfirm", { provider: unlinkTarget?.provider || "" })}

+
+
+ ); +} diff --git a/frontend/hooks/auth/useAuthenticationUI.ts b/frontend/hooks/auth/useAuthenticationUI.ts index 8891790e6..cb0cbade0 100644 --- a/frontend/hooks/auth/useAuthenticationUI.ts +++ b/frontend/hooks/auth/useAuthenticationUI.ts @@ -1,7 +1,7 @@ "use client"; import { useState, useCallback, useRef, useEffect } from "react"; -import { useRouter, usePathname } from "next/navigation"; +import { useRouter, usePathname, useSearchParams } from "next/navigation"; import { useTranslation } from "react-i18next"; import { useDeployment } from "@/components/providers/deploymentProvider"; @@ -27,6 +27,7 @@ export function useAuthenticationUI({ }): AuthenticationUIReturn { const router = useRouter(); const pathname = usePathname(); + const searchParams = useSearchParams(); const { t } = useTranslation("common"); const { isSpeedMode } = useDeployment(); @@ -108,7 +109,23 @@ export function useAuthenticationUI({ }; }, [isSpeedMode, setIsSessionExpiredModalOpen]); + // Auto-open login modal when returning from a failed OAuth redirect + useEffect(() => { + if (isSpeedMode) return; + if (isAuthChecking) return; + if (isAuthenticated) { + const oauthError = searchParams.get("oauth_error"); + if (oauthError) { + router.replace("/"); + } + return; + } + const oauthError = searchParams.get("oauth_error"); + if (oauthError && !isLoginModalOpen) { + setIsLoginModalOpen(true); + } + }, [searchParams, isAuthChecking, isAuthenticated, isSpeedMode, isLoginModalOpen, router]); // Route guard for unauthenticated users - check when pathname changes useEffect(() => { diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 70bde3339..39d40f1d0 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -953,6 +953,16 @@ "auth.logoutSuccess": "You have successfully logged out", "auth.logoutFailed": "Logout failed, please try again", "auth.accessDenied": "You do not have permission to access this page", + 
"auth.oauthDivider": "or continue with", + "auth.oauthLogin": "{{provider}} Login", + "auth.oauthLoginFailed": "Third-party login failed: {{error}}", + "auth.linkedAccounts": "Linked Accounts", + "auth.unlinkAccount": "Unlink", + "auth.unlinkConfirm": "Are you sure you want to unlink this {{provider}} account? You will need to use another login method.", + "auth.unlinkSuccess": "Account unlinked successfully", + "auth.unlinkFailed": "Failed to unlink account", + "auth.noLinkedAccounts": "No third-party accounts linked", + "auth.linkAccount": "Link Account", "auth.revoke": "Delete Account", "auth.confirmRevoke": "Delete Account", "auth.confirmRevokePrompt": "Are you sure you want to delete your account? This action cannot be undone!", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 0de8ccdb8..ef54e835c 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -955,6 +955,16 @@ "auth.logoutSuccess": "您已成功退出登录", "auth.logoutFailed": "退出失败,请重试", "auth.accessDenied": "您没有权限访问此页面", + "auth.oauthDivider": "或使用第三方登录", + "auth.oauthLogin": "{{provider}} 登录", + "auth.oauthLoginFailed": "第三方登录失败:{{error}}", + "auth.linkedAccounts": "已绑定的账号", + "auth.unlinkAccount": "解绑", + "auth.unlinkConfirm": "确定要解绑此 {{provider}} 账号吗?您将需要使用其他登录方式。", + "auth.unlinkSuccess": "账号解绑成功", + "auth.unlinkFailed": "账号解绑失败", + "auth.noLinkedAccounts": "未绑定第三方账号", + "auth.linkAccount": "绑定账号", "auth.revoke": "删除账号", "auth.confirmRevoke": "确认删除账号", "auth.confirmRevokePrompt": "确定要彻底删除当前账号吗?此操作不可恢复!", diff --git a/frontend/server.js b/frontend/server.js index 8f620944c..28ca40b3a 100644 --- a/frontend/server.js +++ b/frontend/server.js @@ -118,6 +118,8 @@ const AUTH_INTERCEPT_ENDPOINTS = new Set([ "/api/user/refresh_token", "/api/user/logout", "/api/user/revoke", + "/api/user/oauth/callback", + "/api/user/oauth/link", ]); function collectRequestBody(req) { @@ -192,11 +194,16 @@ function forwardAuthRequest(req, 
res, targetUrl) { if (isLogout || isRevoke) { clearAuthCookies(res); } else if (data.data && data.data.session) { - // Extract tokens, set cookies, strip tokens from response const session = data.data.session; setAuthCookies(res, session); - // Remove sensitive tokens from the response body sent to browser + const isOAuthCallback = req.parsedPathname === "/api/user/oauth/callback"; + if (isOAuthCallback) { + res.writeHead(302, { Location: "/chat" }); + res.end(); + return; + } + const sanitized = { ...data }; sanitized.data = { ...data.data }; sanitized.data.session = { @@ -204,6 +211,14 @@ function forwardAuthRequest(req, res, targetUrl) { expires_in_seconds: session.expires_in_seconds, }; finalBody = Buffer.from(JSON.stringify(sanitized)); + } else if (req.parsedPathname === "/api/user/oauth/callback" && data.data && data.data.oauth_error) { + const errorParams = new URLSearchParams({ + oauth_error: data.data.oauth_error, + oauth_error_description: data.data.oauth_error_description || "", + }); + res.writeHead(302, { Location: `/?${errorParams.toString()}` }); + res.end(); + return; } } } catch { diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 5cec1b488..8eba7dbd2 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -20,6 +20,13 @@ export const API_ENDPOINTS = { tokens: `${API_BASE_URL}/user/tokens`, deleteToken: (tokenId: number) => `${API_BASE_URL}/user/tokens/${tokenId}`, }, + oauth: { + providers: `${API_BASE_URL}/user/oauth/providers`, + authorize: `${API_BASE_URL}/user/oauth/authorize`, + link: `${API_BASE_URL}/user/oauth/link`, + accounts: `${API_BASE_URL}/user/oauth/accounts`, + unlink: (provider: string) => `${API_BASE_URL}/user/oauth/accounts/${provider}`, + }, conversation: { list: `${API_BASE_URL}/conversation/list`, create: `${API_BASE_URL}/conversation/create`, diff --git a/frontend/services/oauthService.ts b/frontend/services/oauthService.ts new file mode 100644 index 000000000..ba9b05bed --- /dev/null +++ 
b/frontend/services/oauthService.ts @@ -0,0 +1,69 @@ +import { API_ENDPOINTS } from "@/services/api"; +import { fetchWithAuth } from "@/lib/auth"; +import log from "@/lib/logger"; + +export interface OAuthProvider { + name: string; + display_name: string; + icon: string; + enabled: boolean; +} + +export interface OAuthAccount { + provider: string; + provider_username: string | null; + provider_email: string | null; + linked_at: string | null; +} + +export const oauthService = { + getEnabledProviders: async (): Promise => { + try { + const response = await fetch(API_ENDPOINTS.oauth.providers); + if (!response.ok) { + log.warn("Failed to fetch OAuth providers"); + return []; + } + const data = await response.json(); + return data.data || []; + } catch (error) { + log.error("Failed to fetch OAuth providers:", error); + return []; + } + }, + + startOAuthLogin: (provider: string): void => { + window.location.href = `${API_ENDPOINTS.oauth.authorize}?provider=${provider}`; + }, + + startOAuthLink: (provider: string): void => { + window.location.href = `${API_ENDPOINTS.oauth.link}?provider=${provider}`; + }, + + getLinkedAccounts: async (): Promise => { + try { + const response = await fetchWithAuth(API_ENDPOINTS.oauth.accounts); + if (!response.ok) { + log.warn("Failed to fetch linked OAuth accounts"); + return []; + } + const data = await response.json(); + return data.data || []; + } catch (error) { + log.error("Failed to fetch linked OAuth accounts:", error); + return []; + } + }, + + unlinkAccount: async (provider: string): Promise => { + try { + const response = await fetchWithAuth(API_ENDPOINTS.oauth.unlink(provider), { + method: "DELETE", + }); + return response.ok; + } catch (error) { + log.error(`Failed to unlink ${provider} account:`, error); + return false; + } + }, +}; diff --git a/k8s/helm/nexent/charts/nexent-common/files/init.sql b/k8s/helm/nexent/charts/nexent-common/files/init.sql index e209caa41..0538e99e1 100644 --- 
a/k8s/helm/nexent/charts/nexent-common/files/init.sql +++ b/k8s/helm/nexent/charts/nexent-common/files/init.sql @@ -1390,3 +1390,56 @@ COMMENT ON COLUMN "ag_a2a_artifact_t".meta_data IS 'Artifact metadata'; COMMENT ON COLUMN "ag_a2a_artifact_t".extensions IS 'Extension URI list'; COMMENT ON COLUMN "ag_a2a_artifact_t".create_time IS 'Artifact creation timestamp'; + +-- Create user OAuth account table for third-party login (GitHub, WeChat, etc.) +CREATE TABLE IF NOT EXISTS nexent.user_oauth_account_t ( + oauth_account_id SERIAL PRIMARY KEY, + user_id VARCHAR(100) NOT NULL, + provider VARCHAR(30) NOT NULL, + provider_user_id VARCHAR(200) NOT NULL, + provider_email VARCHAR(255), + provider_username VARCHAR(200), + tenant_id VARCHAR(100), + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT NOW(), + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag CHAR(1) DEFAULT 'N', + CONSTRAINT uq_oauth_provider_user UNIQUE (provider, provider_user_id) +); + +ALTER TABLE nexent.user_oauth_account_t OWNER TO "root"; + +-- Create a function to update the update_time column +CREATE OR REPLACE FUNCTION update_user_oauth_account_t_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create a trigger to call the function before each update +CREATE TRIGGER update_user_oauth_account_t_update_time_trigger +BEFORE UPDATE ON nexent.user_oauth_account_t +FOR EACH ROW +EXECUTE FUNCTION update_user_oauth_account_t_update_time(); + +-- Add comments +COMMENT ON TABLE nexent.user_oauth_account_t IS 'User OAuth account table - third-party login bindings'; +COMMENT ON COLUMN nexent.user_oauth_account_t.oauth_account_id IS 'OAuth account ID, primary key'; +COMMENT ON COLUMN nexent.user_oauth_account_t.user_id IS 'Nexent user ID (Supabase UUID)'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider IS 'OAuth provider name: github, wechat'; +COMMENT ON 
COLUMN nexent.user_oauth_account_t.provider_user_id IS 'User ID from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_email IS 'Email from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.provider_username IS 'Display name from the OAuth provider'; +COMMENT ON COLUMN nexent.user_oauth_account_t.tenant_id IS 'Tenant ID at time of linking'; +COMMENT ON COLUMN nexent.user_oauth_account_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.user_oauth_account_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.user_oauth_account_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.user_oauth_account_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +-- Create index for user_id queries +CREATE INDEX IF NOT EXISTS idx_user_oauth_account_t_user_id +ON nexent.user_oauth_account_t (user_id); \ No newline at end of file diff --git a/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml b/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml index 474945954..9a1ca1282 100644 --- a/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml +++ b/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml @@ -124,3 +124,21 @@ data: # Kubernetes Deployment Mode IS_DEPLOYED_BY_KUBERNETES: {{ .Values.config.isDeployedByKubernetes | quote }} KUBERNETES_NAMESPACE: {{ .Values.global.namespace | quote }} + + + # ===== OAuth Configuration ===== + # GitHub OAuth - get credentials from https://github.com/settings/developers + GITHUB_OAUTH_CLIENT_ID: {{ .Values.config.oauth.githubClientId | quote }} + GITHUB_OAUTH_CLIENT_SECRET: {{ .Values.config.oauth.githubClientSecret | quote }} + # GDE OAuth + GDE_URL: {{ .Values.config.oauth.gdeUrl | quote }} + GDE_OAUTH_CLIENT_ID: {{ .Values.config.oauth.gdeClientId | quote }} + GDE_OAUTH_CLIENT_SECRET: {{ .Values.config.oauth.gdeClientSecret | quote }} + # WeChat OAuth (set 
ENABLE_WECHAT_OAUTH=true to enable) + ENABLE_WECHAT_OAUTH: {{ .Values.config.oauth.enableWechat | quote }} + WECHAT_OAUTH_APP_ID: {{ .Values.config.oauth.wechatClientId | quote }} + WECHAT_OAUTH_APP_SECRET: {{ .Values.config.oauth.wechatClientSecret | quote }} + # SSL verification, optional CA bundle, and base URL for OAuth callback (e.g., http://localhost:3000 for local dev) + OAUTH_SSL_VERIFY: {{ .Values.config.oauth.sslVerify | quote }} + OAUTH_CA_BUNDLE: {{ .Values.config.oauth.caBundle | quote }} + OAUTH_CALLBACK_BASE_URL: {{ .Values.config.oauth.callbackBaseUrl | quote }} diff --git a/k8s/helm/nexent/charts/nexent-common/values.yaml b/k8s/helm/nexent/charts/nexent-common/values.yaml index 331e2f896..951468390 100644 --- a/k8s/helm/nexent/charts/nexent-common/values.yaml +++ b/k8s/helm/nexent/charts/nexent-common/values.yaml @@ -106,6 +106,18 @@ config: telemetrySampleRate: "1.0" slowRequestThresholdSeconds: "5.0" slowTokenRateThreshold: "10.0" + oauth: + githubClientId: "" + githubClientSecret: "" + enableWechat: "false" + wechatClientId: "" + wechatClientSecret: "" + gdeUrl: "" + gdeClientId: "" + gdeClientSecret: "" + sslVerify: "true" + caBundle: "" + callbackBaseUrl: "http://localhost:3000" # Secrets used by common templates secrets: diff --git a/test/backend/app/test_oauth_app.py b/test/backend/app/test_oauth_app.py new file mode 100644 index 000000000..758ab75d2 --- /dev/null +++ b/test/backend/app/test_oauth_app.py @@ -0,0 +1,767 @@ +import sys +import os +import unittest +from unittest.mock import patch, MagicMock + +test_dir = os.path.dirname(__file__) +backend_dir = os.path.abspath(os.path.join(test_dir, "../../../backend")) +sys.path.insert(0, backend_dir) + +sys.modules["boto3"] = MagicMock() + +consts_mock = MagicMock() +consts_mock.const = MagicMock() +consts_mock.const.GITHUB_OAUTH_CLIENT_ID = "test_id" +consts_mock.const.GITHUB_OAUTH_CLIENT_SECRET = "test_secret" +consts_mock.const.ENABLE_WECHAT_OAUTH = False +consts_mock.const.OAUTH_CALLBACK_BASE_URL = "http://localhost:3000" 
+consts_mock.const.SUPABASE_URL = "http://supabase.test" +consts_mock.const.DEFAULT_TENANT_ID = "default" +sys.modules["consts"] = consts_mock +sys.modules["consts.const"] = consts_mock.const + +sys.modules["consts.model"] = MagicMock() + +oauth_providers_mock = MagicMock() +oauth_providers_mock.get_all_provider_definitions.return_value = { + "github": MagicMock(), + "wechat": MagicMock(), +} +sys.modules["consts.oauth_providers"] = oauth_providers_mock + + +class _OAuthProviderError(Exception): + pass + + +class _OAuthLinkError(Exception): + pass + + +class _UnauthorizedError(Exception): + pass + + +exceptions_mock = MagicMock() +exceptions_mock.OAuthProviderError = _OAuthProviderError +exceptions_mock.OAuthLinkError = _OAuthLinkError +exceptions_mock.UnauthorizedError = _UnauthorizedError +sys.modules["consts.exceptions"] = exceptions_mock + +sys.modules["database"] = MagicMock() +database_oauth_mock = MagicMock() +database_oauth_mock.get_oauth_account_by_provider = MagicMock(return_value=None) +database_oauth_mock.get_soft_deleted_oauth_account = MagicMock(return_value=None) +sys.modules["database.oauth_account_db"] = database_oauth_mock +sys.modules["database.user_tenant_db"] = MagicMock() +sys.modules["database.client"] = MagicMock() +sys.modules["database.db_models"] = MagicMock() +sys.modules["backend.database"] = MagicMock() +sys.modules["backend.database.client"] = MagicMock() +sys.modules["backend.database.db_models"] = MagicMock() +sys.modules["utils"] = MagicMock() +sys.modules["utils.token_encryption"] = MagicMock() +sys.modules["utils.config_utils"] = MagicMock() + +auth_utils_mock = MagicMock() +auth_utils_mock.get_current_user_id = MagicMock(return_value=("user-1", "t-1")) +auth_utils_mock.get_jwt_expiry_seconds = MagicMock(return_value=3600) +auth_utils_mock.calculate_expires_at = MagicMock(return_value=1735689600) +auth_utils_mock.get_supabase_admin_client = MagicMock() +auth_utils_mock.generate_session_jwt = 
MagicMock(return_value="eyJ.mock.jwt.token") +sys.modules["utils.auth_utils"] = auth_utils_mock + +oauth_service_mock = MagicMock() +oauth_service_mock.parse_state = MagicMock( + return_value={"provider": "github", "token": "tok", "link_user_id": ""} +) +sys.modules["services"] = MagicMock() +sys.modules["services.oauth_service"] = oauth_service_mock + +nexent_mock = MagicMock() +sys.modules["nexent"] = nexent_mock +sys.modules["nexent.storage"] = MagicMock() +sys.modules["nexent.storage.storage_client_factory"] = MagicMock() +sys.modules["nexent.storage.minio_config"] = MagicMock() + +storage_client_mock = MagicMock() +minio_mock = MagicMock() +minio_mock._ensure_bucket_exists = MagicMock() +minio_mock.client = MagicMock() +patch( + "nexent.storage.storage_client_factory.create_storage_client_from_config", + return_value=storage_client_mock, +).start() +patch( + "nexent.storage.minio_config.MinIOStorageConfig.validate", lambda self: None +).start() +patch("database.client.MinioClient", return_value=minio_mock).start() +patch("database.client.MinioClient", return_value=minio_mock).start() +patch("database.client.minio_client", minio_mock).start() + +from fastapi.testclient import TestClient +from fastapi import FastAPI +from http import HTTPStatus + +from apps.oauth_app import router + +app = FastAPI() +app.include_router(router) +client = TestClient(app) + + +class TestGetProviders(unittest.TestCase): + def test_returns_provider_list(self): + oauth_service_mock.get_enabled_providers.return_value = [ + { + "name": "github", + "display_name": "GitHub", + "icon": "github", + "enabled": True, + } + ] + + response = client.get("/user/oauth/providers") + + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["message"], "success") + self.assertEqual(len(data["data"]), 1) + self.assertEqual(data["data"][0]["name"], "github") + + def test_returns_empty_list(self): + oauth_service_mock.get_enabled_providers.return_value = 
[] + + response = client.get("/user/oauth/providers") + + self.assertEqual(response.status_code, HTTPStatus.OK) + self.assertEqual(response.json()["data"], []) + + +class TestAuthorize(unittest.TestCase): + def test_redirects_to_provider(self): + oauth_service_mock.get_authorize_url.return_value = ( + "https://github.com/login/oauth/authorize?client_id=test_id" + ) + + response = client.get( + "/user/oauth/authorize?provider=github", follow_redirects=False + ) + + self.assertEqual(response.status_code, HTTPStatus.FOUND) + self.assertIn("github.com", response.headers["location"]) + + def test_returns_400_for_unsupported_provider(self): + oauth_service_mock.get_authorize_url.side_effect = _OAuthProviderError( + "Unsupported" + ) + + response = client.get("/user/oauth/authorize?provider=google") + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + + oauth_service_mock.get_authorize_url.side_effect = None + + def test_returns_500_on_unexpected_error(self): + oauth_service_mock.get_authorize_url.side_effect = Exception("Unexpected") + + response = client.get("/user/oauth/authorize?provider=github") + + self.assertEqual(response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + + oauth_service_mock.get_authorize_url.side_effect = None + + +class TestLink(unittest.TestCase): + def test_redirects_to_provider_with_link_user_id(self): + oauth_service_mock.reset_mock() + oauth_service_mock.get_authorize_url.return_value = ( + "https://github.com/login/oauth/authorize?client_id=test_id&state=github:token:user-1" + ) + + response = client.get( + "/user/oauth/link?provider=github", + headers={"Authorization": "Bearer valid_token"}, + follow_redirects=False, + ) + + self.assertEqual(response.status_code, HTTPStatus.FOUND) + self.assertIn("github.com", response.headers["location"]) + oauth_service_mock.get_authorize_url.assert_called_once_with("github", link_user_id="user-1") + + def test_returns_401_without_auth(self): + response = 
client.get("/user/oauth/link?provider=github") + + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + + @patch("apps.oauth_app.get_current_user_id") + def test_returns_401_for_invalid_token(self, mock_get_user): + mock_get_user.side_effect = _UnauthorizedError("Invalid token") + + response = client.get( + "/user/oauth/link?provider=github", + headers={"Authorization": "Bearer invalid"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + mock_get_user.side_effect = None + + def test_returns_400_for_unsupported_provider(self): + oauth_service_mock.get_authorize_url.side_effect = _OAuthProviderError( + "Unsupported provider" + ) + + response = client.get( + "/user/oauth/link?provider=google", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + oauth_service_mock.get_authorize_url.side_effect = None + + def test_returns_500_on_unexpected_error(self): + oauth_service_mock.get_authorize_url.side_effect = Exception("Unexpected") + + response = client.get( + "/user/oauth/link?provider=github", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + oauth_service_mock.get_authorize_url.side_effect = None + + +class TestCallback(unittest.TestCase): + def test_returns_error_when_provider_error(self): + response = client.get( + "/user/oauth/callback?provider=github&error=access_denied&error_description=User+cancelled" + ) + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "access_denied") + + def test_returns_error_when_no_code(self): + response = client.get("/user/oauth/callback?provider=github") + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "no_code") + + def test_returns_error_for_unsupported_provider(self): + 
response = client.get("/user/oauth/callback?provider=google&code=abc123") + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "unsupported_provider") + + def test_success_returns_session_data(self): + oauth_service_mock.reset_mock() + oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token_123", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "12345", + "email": "octocat@github.com", + "username": "octocat", + } + + mock_existing_user = MagicMock() + mock_existing_user.id = "user-uuid-123" + mock_existing_user.email = "octocat@github.com" + + mock_users_resp = MagicMock() + mock_users_resp.users = [mock_existing_user] + + mock_admin_client = MagicMock() + mock_admin_client.auth.admin.list_users.return_value = mock_users_resp + + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + auth_utils_mock.generate_session_jwt.return_value = "eyJ.mock.jwt.token" + + response = client.get("/user/oauth/callback?provider=github&code=valid_code") + + if response.status_code != HTTPStatus.OK: + print("Response:", response.json()) + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertIn("session", data["data"]) + self.assertEqual(data["data"]["user"]["email"], "octocat@github.com") + self.assertEqual( + data["data"]["session"]["access_token"], + "eyJ.mock.jwt.token", + ) + self.assertEqual(data["data"]["session"]["expires_in_seconds"], 3600) + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + def test_success_creates_new_user_when_not_found(self): + oauth_service_mock.reset_mock() + 
oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token_456", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "67890", + "email": "newuser@github.com", + "username": "newuser", + } + + mock_empty_resp = MagicMock() + mock_empty_resp.users = [] + + mock_new_user = MagicMock() + mock_new_user.id = "new-uuid-456" + + mock_admin_client = MagicMock() + mock_admin_client.auth.admin.list_users.return_value = mock_empty_resp + mock_admin_client.auth.admin.create_user.return_value = MagicMock( + user=mock_new_user + ) + + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + auth_utils_mock.generate_session_jwt.return_value = "eyJ.new.jwt.token" + + response = client.get("/user/oauth/callback?provider=github&code=new_code") + + if response.status_code != HTTPStatus.OK: + print("Response:", response.json()) + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["data"]["user"]["email"], "newuser@github.com") + mock_admin_client.auth.admin.create_user.assert_called_once() + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + def test_returns_500_on_token_exchange_failure(self): + oauth_service_mock.exchange_code_for_provider_token.side_effect = Exception( + "Token exchange failed" + ) + + response = client.get("/user/oauth/callback?provider=github&code=bad_code") + + self.assertEqual(response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "callback_failed") + + oauth_service_mock.exchange_code_for_provider_token.side_effect = None + + def test_returns_500_on_exception(self): + 
oauth_service_mock.exchange_code_for_provider_token.side_effect = Exception( + "Network error" + ) + + response = client.get("/user/oauth/callback?provider=github&code=crash_code") + + self.assertEqual(response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + data = response.json() + self.assertEqual(data["data"]["oauth_error"], "callback_failed") + + oauth_service_mock.exchange_code_for_provider_token.side_effect = None + + def test_success_with_link_user_id_binding(self): + """Callback with link_user_id should bind OAuth to that user directly.""" + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + oauth_service_mock.parse_state.return_value = { + "provider": "github", + "token": "tok", + "link_user_id": "existing-user-uuid", + } + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "12345", + "email": "octocat@github.com", + "username": "octocat", + } + oauth_service_mock.ensure_user_tenant_exists.return_value = { + "user_id": "existing-user-uuid", + "tenant_id": "t-1", + } + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "existing-user-uuid", + } + auth_utils_mock.generate_session_jwt.return_value = "eyJ.bind.jwt" + + response = client.get( + "/user/oauth/callback?provider=github&code=bind_code&state=github:tok:existing-user-uuid" + ) + + if response.status_code != HTTPStatus.OK: + print("Response:", response.json()) + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["data"]["user"]["id"], "existing-user-uuid") + self.assertEqual(data["data"]["user"]["email"], "octocat@github.com") + + # Should NOT call database lookup when link_user_id is present + database_oauth_mock.get_oauth_account_by_provider.assert_not_called() + + # Should bind to the specified user + 
oauth_service_mock.create_or_update_oauth_account.assert_called_once_with( + user_id="existing-user-uuid", + provider="github", + provider_user_id="12345", + email="octocat@github.com", + username="octocat", + ) + + def test_success_with_already_bound_oauth_account(self): + """Callback with existing binding should use that user_id without Supabase lookup.""" + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + auth_utils_mock.reset_mock() + auth_utils_mock.get_current_user_id.return_value = ("user-1", "t-1") + auth_utils_mock.get_jwt_expiry_seconds.return_value = 3600 + auth_utils_mock.calculate_expires_at.return_value = 1735689600 + auth_utils_mock.generate_session_jwt.return_value = "eyJ.bound.jwt" + oauth_service_mock.parse_state.return_value = { + "provider": "github", + "token": "tok", + "link_user_id": "", + } + database_oauth_mock.get_oauth_account_by_provider.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "bound-user-uuid", + } + oauth_service_mock.exchange_code_for_provider_token.return_value = { + "access_token": "ghu_provider_token", + } + oauth_service_mock.get_provider_user_info.return_value = { + "id": "12345", + "email": "octocat@github.com", + "username": "octocat", + } + oauth_service_mock.ensure_user_tenant_exists.return_value = { + "user_id": "bound-user-uuid", + "tenant_id": "t-1", + } + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "bound-user-uuid", + } + + response = client.get( + "/user/oauth/callback?provider=github&code=login_code&state=github:tok" + ) + + if response.status_code != HTTPStatus.OK: + print("Response:", response.json()) + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["data"]["user"]["id"], "bound-user-uuid") + + auth_utils_mock.get_supabase_admin_client.assert_not_called() + 
oauth_service_mock.create_or_update_oauth_account.assert_called_once() + + +class TestGetAccounts(unittest.TestCase): + def test_returns_accounts_with_auth(self): + oauth_service_mock.list_linked_accounts.return_value = [ + { + "provider": "github", + "provider_username": "octocat", + "linked_at": "2025-01-01", + } + ] + + response = client.get( + "/user/oauth/accounts", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(len(data["data"]), 1) + + def test_returns_401_without_auth(self): + response = client.get("/user/oauth/accounts") + + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + + @patch("apps.oauth_app.get_current_user_id") + def test_returns_401_for_invalid_token(self, mock_get_user): + mock_get_user.side_effect = _UnauthorizedError("Invalid token") + + response = client.get( + "/user/oauth/accounts", + headers={"Authorization": "Bearer invalid"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + + mock_get_user.side_effect = None + + +class TestDeleteAccount(unittest.TestCase): + def setUp(self): + mock_identity = MagicMock() + mock_identity.provider = "email" + + mock_user = MagicMock() + mock_user.identities = [mock_identity] + mock_user.app_metadata = MagicMock() + mock_user.app_metadata.get = MagicMock(return_value="email") + + mock_user_resp = MagicMock() + mock_user_resp.user = mock_user + + mock_admin = MagicMock() + mock_admin.auth.admin.get_user_by_id.return_value = mock_user_resp + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin + oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 2 + + def test_unlinks_successfully(self): + oauth_service_mock.unlink_account.reset_mock() + oauth_service_mock.unlink_account.return_value = True + + response = client.delete( + "/user/oauth/accounts/github", + headers={"Authorization": "Bearer valid_token"}, + ) + + 
self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertTrue(data["data"]["unlinked"]) + oauth_service_mock.unlink_account.assert_called_once() + + def test_returns_401_without_auth(self): + response = client.delete("/user/oauth/accounts/github") + + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + + @patch("apps.oauth_app.get_current_user_id") + def test_returns_400_when_last_account(self, mock_get_user): + mock_get_user.return_value = ("user-1", "t-1") + oauth_service_mock.unlink_account.side_effect = _OAuthLinkError( + "Cannot unlink last" + ) + + response = client.delete( + "/user/oauth/accounts/github", + headers={"Authorization": "Bearer valid"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + + oauth_service_mock.unlink_account.side_effect = None + + +class TestCallbackPagination(unittest.TestCase): + def test_finds_user_on_second_page(self): + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + auth_utils_mock.reset_mock() + auth_utils_mock.get_current_user_id.return_value = ("user-1", "t-1") + auth_utils_mock.get_jwt_expiry_seconds.return_value = 3600 + auth_utils_mock.calculate_expires_at.return_value = 1735689600 + auth_utils_mock.generate_session_jwt.return_value = "eyJ.page2.jwt" + oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = {"access_token": "ghu_token"} + oauth_service_mock.get_provider_user_info.return_value = { + "id": "12345", + "email": "page2user@github.com", + "username": "page2user", + } + oauth_service_mock.ensure_user_tenant_exists.return_value = {"user_id": "page2-uuid", "tenant_id": "t-1"} + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": 
"github", + "provider_user_id": "12345", + "user_id": "page2-uuid", + } + + mock_page1_user = MagicMock() + mock_page1_user.id = "user-page1" + mock_page1_user.email = "other@github.com" + mock_page2_user = MagicMock() + mock_page2_user.id = "page2-uuid" + mock_page2_user.email = "page2user@github.com" + + mock_page1_resp = MagicMock() + mock_page1_resp.users = [mock_page1_user] + mock_page1_resp.__len__ = lambda self: 1 + + mock_page2_resp = MagicMock() + mock_page2_resp.users = [mock_page2_user] + mock_page2_resp.__len__ = lambda self: 1 + + mock_admin_client = MagicMock() + mock_admin_client.auth.admin.list_users.side_effect = [mock_page1_resp, mock_page2_resp] + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + + response = client.get("/user/oauth/callback?provider=github&code=page2_code&state=github:tok") + + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertEqual(data["data"]["user"]["email"], "page2user@github.com") + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + def test_stops_pagination_when_less_than_100_users(self): + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + auth_utils_mock.reset_mock() + auth_utils_mock.get_current_user_id.return_value = ("user-1", "t-1") + auth_utils_mock.get_jwt_expiry_seconds.return_value = 3600 + auth_utils_mock.calculate_expires_at.return_value = 1735689600 + auth_utils_mock.generate_session_jwt.return_value = "eyJ.new.jwt" + oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = {"access_token": "ghu_token"} + oauth_service_mock.get_provider_user_info.return_value = { + "id": "67890", + "email": "newuser@github.com", + "username": "newuser", + } + 
oauth_service_mock.ensure_user_tenant_exists.return_value = {"user_id": "new-uuid", "tenant_id": "t-1"} + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "67890", + "user_id": "new-uuid", + } + + mock_empty_resp = MagicMock() + mock_empty_resp.users = [] + mock_empty_resp.__len__ = lambda self: 0 + + mock_new_user = MagicMock() + mock_new_user.id = "new-uuid" + + mock_admin_client = MagicMock() + mock_admin_client.auth.admin.list_users.return_value = mock_empty_resp + mock_admin_client.auth.admin.create_user.return_value = MagicMock(user=mock_new_user) + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + + response = client.get("/user/oauth/callback?provider=github&code=short_page_code&state=github:tok") + + self.assertEqual(response.status_code, HTTPStatus.OK) + mock_admin_client.auth.admin.list_users.assert_called_once() + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + +class TestCallbackEmailFallback(unittest.TestCase): + def test_creates_user_with_oauth_fallback_email(self): + oauth_service_mock.reset_mock() + database_oauth_mock.reset_mock() + auth_utils_mock.reset_mock() + auth_utils_mock.get_current_user_id.return_value = ("user-1", "t-1") + auth_utils_mock.get_jwt_expiry_seconds.return_value = 3600 + auth_utils_mock.calculate_expires_at.return_value = 1735689600 + auth_utils_mock.generate_session_jwt.return_value = "eyJ.noemail.jwt" + oauth_service_mock.parse_state.return_value = {"provider": "github", "token": "tok", "link_user_id": ""} + database_oauth_mock.get_oauth_account_by_provider.return_value = None + database_oauth_mock.get_soft_deleted_oauth_account.return_value = None + oauth_service_mock.exchange_code_for_provider_token.return_value = {"access_token": "ghu_token"} + oauth_service_mock.get_provider_user_info.return_value = { + "id": "99999", + "email": "", + "username": "noemail_user", + } + 
oauth_service_mock.ensure_user_tenant_exists.return_value = {"user_id": "noemail-uuid", "tenant_id": "t-1"} + oauth_service_mock.create_or_update_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "99999", + "user_id": "noemail-uuid", + } + + mock_empty_resp = MagicMock() + mock_empty_resp.users = [] + mock_empty_resp.__len__ = lambda self: 0 + + mock_new_user = MagicMock() + mock_new_user.id = "noemail-uuid" + + mock_admin_client = MagicMock() + mock_admin_client.auth.admin.list_users.return_value = mock_empty_resp + mock_admin_client.auth.admin.create_user.return_value = MagicMock(user=mock_new_user) + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin_client + + response = client.get("/user/oauth/callback?provider=github&code=noemail_code&state=github:tok") + + self.assertEqual(response.status_code, HTTPStatus.OK) + data = response.json() + self.assertIn("@oauth.nexent", data["data"]["user"]["email"]) + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + +class TestDeleteAccountMetadata(unittest.TestCase): + def test_handles_get_user_exception_gracefully(self): + oauth_service_mock.reset_mock() + oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 2 + oauth_service_mock.unlink_account.return_value = True + + mock_admin = MagicMock() + mock_admin.auth.admin.get_user_by_id.side_effect = Exception("User lookup failed") + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin + + response = client.delete( + "/user/oauth/accounts/github", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.OK) + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + def test_unlinks_with_password_auth_detected(self): + oauth_service_mock.reset_mock() + oauth_service_mock.count_oauth_accounts_by_user_id.return_value = 1 + oauth_service_mock.unlink_account.return_value = True + + mock_identity = MagicMock() + 
mock_identity.provider = "email" + + mock_user = MagicMock() + mock_user.identities = [mock_identity] + mock_user.app_metadata = MagicMock() + mock_user.app_metadata.get = MagicMock(return_value="email") + + mock_user_resp = MagicMock() + mock_user_resp.user = mock_user + + mock_admin = MagicMock() + mock_admin.auth.admin.get_user_by_id.return_value = mock_user_resp + auth_utils_mock.get_supabase_admin_client.return_value = mock_admin + + response = client.delete( + "/user/oauth/accounts/github", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.OK) + + auth_utils_mock.get_supabase_admin_client.return_value = MagicMock() + + +class TestGetAccountsServiceError(unittest.TestCase): + def test_returns_500_on_service_error(self): + oauth_service_mock.list_linked_accounts.side_effect = Exception("Database error") + + response = client.get( + "/user/oauth/accounts", + headers={"Authorization": "Bearer valid_token"}, + ) + + self.assertEqual(response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR) + + oauth_service_mock.list_linked_accounts.side_effect = None + + +if __name__ == "__main__": + unittest.main() diff --git a/test/backend/database/test_oauth_account_db.py b/test/backend/database/test_oauth_account_db.py new file mode 100644 index 000000000..0b883be19 --- /dev/null +++ b/test/backend/database/test_oauth_account_db.py @@ -0,0 +1,360 @@ +import sys +import os +import unittest +from unittest.mock import MagicMock + +test_dir = os.path.dirname(__file__) +backend_dir = os.path.abspath(os.path.join(test_dir, "../../../backend")) +sys.path.insert(0, backend_dir) + +consts_mock = MagicMock() +consts_mock.const = MagicMock() +consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" +consts_mock.const.MINIO_ACCESS_KEY = "test" +consts_mock.const.MINIO_SECRET_KEY = "test" +consts_mock.const.MINIO_REGION = "us-east-1" +consts_mock.const.MINIO_DEFAULT_BUCKET = "test" +consts_mock.const.POSTGRES_HOST = "localhost" 
+consts_mock.const.POSTGRES_USER = "test" +consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test" +consts_mock.const.POSTGRES_DB = "test" +consts_mock.const.POSTGRES_PORT = 5432 +consts_mock.const.DEFAULT_TENANT_ID = "default-tenant" +sys.modules["consts"] = consts_mock +sys.modules["consts.const"] = consts_mock.const + +sys.modules["consts.exceptions"] = MagicMock() +sys.modules["boto3"] = MagicMock() + +sqlalchemy_mock = MagicMock() +sys.modules["sqlalchemy"] = sqlalchemy_mock +sys.modules["sqlalchemy.exc"] = sqlalchemy_mock.exc +sys.modules["sqlalchemy.orm"] = MagicMock() +sys.modules["sqlalchemy.dialects"] = MagicMock() +sys.modules["sqlalchemy.dialects.postgresql"] = MagicMock() + +mock_get_db_session = MagicMock() +mock_as_dict = MagicMock() + +client_mock = MagicMock() +client_mock.get_db_session = mock_get_db_session +client_mock.as_dict = mock_as_dict +client_mock.MinioClient = MagicMock() +client_mock.PostgresClient = MagicMock() +client_mock.db_client = MagicMock() +client_mock.filter_property = MagicMock() +sys.modules["database.client"] = client_mock + +db_models_mock = MagicMock() +db_models_mock.UserOAuthAccount = MagicMock() +db_models_mock.TableBase = MagicMock() +sys.modules["database.db_models"] = db_models_mock + +from database.oauth_account_db import ( + count_oauth_accounts_by_user_id, + delete_oauth_account, + get_oauth_account_by_provider, + get_soft_deleted_oauth_account, + insert_oauth_account, + list_oauth_accounts_by_user_id, + reactivate_oauth_account, + rebind_oauth_account, + soft_delete_all_oauth_accounts_by_user_id, + update_oauth_account_tokens, +) + + +def _make_mock_session(): + session = MagicMock() + query_mock = MagicMock() + filter_mock = MagicMock() + session.query.return_value = query_mock + query_mock.filter.return_value = filter_mock + + mock_get_db_session.return_value.__enter__ = MagicMock(return_value=session) + mock_get_db_session.return_value.__exit__ = MagicMock(return_value=False) + return session, query_mock, 
filter_mock + + +class TestInsertOAuthAccount(unittest.TestCase): + def test_insert_and_return_dict(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + session.add = MagicMock() + session.flush = MagicMock() + client_mock.as_dict.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "user-1", + } + + result = insert_oauth_account( + user_id="user-1", + provider="github", + provider_user_id="12345", + provider_email="test@github.com", + ) + + session.add.assert_called_once() + session.flush.assert_called_once() + self.assertEqual(result["provider"], "github") + + +class TestGetOAuthAccountByProvider(unittest.TestCase): + def test_returns_dict_when_found(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + filter_mock.first.return_value = mock_account + client_mock.as_dict.return_value = { + "provider": "github", + "provider_user_id": "12345", + } + + result = get_oauth_account_by_provider("github", "12345") + + self.assertIsNotNone(result) + self.assertEqual(result["provider"], "github") + + def test_returns_none_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = get_oauth_account_by_provider("github", "nonexistent") + + self.assertIsNone(result) + + +class TestListOAuthAccountsByUserId(unittest.TestCase): + def test_returns_list_of_dicts(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + filter_mock.all.return_value = [mock_account] + client_mock.as_dict.return_value = {"provider": "github", "user_id": "user-1"} + + result = list_oauth_accounts_by_user_id("user-1") + + self.assertEqual(len(result), 1) + + def test_returns_empty_list(self): + session, query, filter_mock = _make_mock_session() + filter_mock.all.return_value = [] + + result = list_oauth_accounts_by_user_id("user-1") + + self.assertEqual(len(result), 0) + + +class 
TestUpdateOAuthAccountTokens(unittest.TestCase): + def test_updates_and_returns_true(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + filter_mock.first.return_value = mock_account + + result = update_oauth_account_tokens( + provider="github", + provider_user_id="12345", + provider_username="new_name", + ) + + self.assertTrue(result) + self.assertEqual(mock_account.provider_username, "new_name") + + def test_returns_false_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = update_oauth_account_tokens("github", "nonexistent") + + self.assertFalse(result) + + def test_skips_none_fields(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + filter_mock.first.return_value = mock_account + + update_oauth_account_tokens("github", "12345") + + +class TestDeleteOAuthAccount(unittest.TestCase): + def test_soft_deletes_and_returns_true(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + filter_mock.first.return_value = mock_account + + result = delete_oauth_account("user-1", "github") + + self.assertTrue(result) + self.assertEqual(mock_account.delete_flag, "Y") + + def test_returns_false_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = delete_oauth_account("user-1", "github") + + self.assertFalse(result) + + +class TestReactivateOAuthAccount(unittest.TestCase): + def test_reactivates_and_returns_true(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + mock_account.delete_flag = "Y" + filter_mock.first.return_value = mock_account + + result = reactivate_oauth_account( + provider="github", + provider_user_id="12345", + user_id="user-2", + provider_email="new@email.com", + provider_username="newname", + ) + + self.assertTrue(result) + 
self.assertEqual(mock_account.delete_flag, "N") + self.assertEqual(mock_account.user_id, "user-2") + + def test_returns_false_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = reactivate_oauth_account("github", "12345", "user-1") + + self.assertFalse(result) + + +class TestCountOAuthAccountsByUserId(unittest.TestCase): + def test_returns_correct_count(self): + session, query, filter_mock = _make_mock_session() + filter_mock.count.return_value = 3 + + result = count_oauth_accounts_by_user_id("user-1") + + self.assertEqual(result, 3) + + def test_returns_zero_when_no_accounts(self): + session, query, filter_mock = _make_mock_session() + filter_mock.count.return_value = 0 + + result = count_oauth_accounts_by_user_id("user-1") + + self.assertEqual(result, 0) + + +class TestGetSoftDeletedOAuthAccount(unittest.TestCase): + def test_returns_dict_when_soft_deleted_found(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + mock_account.delete_flag = "Y" + filter_mock.first.return_value = mock_account + client_mock.as_dict.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "user-1", + "delete_flag": "Y", + } + + result = get_soft_deleted_oauth_account("github", "12345") + + self.assertIsNotNone(result) + self.assertEqual(result["delete_flag"], "Y") + self.assertEqual(result["provider"], "github") + + def test_returns_none_when_not_soft_deleted(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = get_soft_deleted_oauth_account("github", "12345") + + self.assertIsNone(result) + + def test_returns_none_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = get_soft_deleted_oauth_account("github", "nonexistent") + + self.assertIsNone(result) + + +class TestRebindOAuthAccount(unittest.TestCase): + def 
test_rebinds_to_new_user(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + mock_account.delete_flag = "N" + filter_mock.first.return_value = mock_account + client_mock.as_dict.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "new-user", + } + + result = rebind_oauth_account( + provider="github", + provider_user_id="12345", + new_user_id="new-user", + provider_email="new@email.com", + provider_username="newname", + ) + + self.assertTrue(result) + self.assertEqual(mock_account.user_id, "new-user") + self.assertEqual(mock_account.provider_email, "new@email.com") + self.assertEqual(mock_account.provider_username, "newname") + self.assertEqual(mock_account.updated_by, "new-user") + + def test_rebinds_keeps_existing_email_when_none_provided(self): + session, query, filter_mock = _make_mock_session() + mock_account = MagicMock() + mock_account.delete_flag = "N" + mock_account.provider_email = "existing@email.com" + mock_account.provider_username = "existingname" + filter_mock.first.return_value = mock_account + client_mock.as_dict.return_value = {"provider": "github", "user_id": "new-user"} + + result = rebind_oauth_account( + provider="github", + provider_user_id="12345", + new_user_id="new-user", + ) + + self.assertTrue(result) + self.assertEqual(mock_account.provider_email, "existing@email.com") + + def test_returns_false_when_not_found(self): + session, query, filter_mock = _make_mock_session() + filter_mock.first.return_value = None + + result = rebind_oauth_account("github", "nonexistent", "new-user") + + self.assertFalse(result) + + +class TestSoftDeleteAllOAuthAccountsByUserId(unittest.TestCase): + def test_soft_deletes_all_accounts(self): + session, query, filter_mock = _make_mock_session() + mock_account1 = MagicMock() + mock_account1.delete_flag = "N" + mock_account2 = MagicMock() + mock_account2.delete_flag = "N" + filter_mock.all.return_value = [mock_account1, mock_account2] + + 
result = soft_delete_all_oauth_accounts_by_user_id("user-1", deleted_by="admin") + + self.assertEqual(result, 2) + self.assertEqual(mock_account1.delete_flag, "Y") + self.assertEqual(mock_account2.delete_flag, "Y") + self.assertEqual(mock_account1.updated_by, "admin") + self.assertEqual(mock_account2.updated_by, "admin") + + def test_returns_zero_when_no_accounts(self): + session, query, filter_mock = _make_mock_session() + filter_mock.all.return_value = [] + + result = soft_delete_all_oauth_accounts_by_user_id("user-1", "admin") + + self.assertEqual(result, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/backend/services/test_oauth_service.py b/test/backend/services/test_oauth_service.py new file mode 100644 index 000000000..c974c1b5b --- /dev/null +++ b/test/backend/services/test_oauth_service.py @@ -0,0 +1,844 @@ +import sys +import os +import unittest +from unittest.mock import MagicMock, patch + +test_dir = os.path.dirname(__file__) +backend_dir = os.path.abspath(os.path.join(test_dir, "../../../backend")) +sys.path.insert(0, backend_dir) + +consts_mock = MagicMock() +consts_mock.const = MagicMock() +consts_mock.const.DEFAULT_TENANT_ID = "default-tenant-id" +consts_mock.const.OAUTH_CALLBACK_BASE_URL = "http://localhost:3000" +consts_mock.const.OAUTH_SSL_VERIFY = True +consts_mock.const.OAUTH_CA_BUNDLE = "" +sys.modules["consts"] = consts_mock +sys.modules["consts.const"] = consts_mock.const + + +class _OAuthProviderError(Exception): + pass + + +class _OAuthLinkError(Exception): + pass + + +exceptions_mock = MagicMock() +exceptions_mock.OAuthProviderError = _OAuthProviderError +exceptions_mock.OAuthLinkError = _OAuthLinkError +sys.modules["consts.exceptions"] = exceptions_mock + +oauth_account_db_mock = MagicMock() +sys.modules["database.oauth_account_db"] = oauth_account_db_mock + +db_pkg = MagicMock() +db_pkg.oauth_account_db = oauth_account_db_mock +sys.modules["database"] = db_pkg + +user_tenant_db_mock = MagicMock() 
+sys.modules["database.user_tenant_db"] = user_tenant_db_mock +db_pkg.user_tenant_db = user_tenant_db_mock + +model_mock = MagicMock() + + +class _FakeOAuthProviderDefinition: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __repr__(self): + return f"FakeDef({self.name})" + + +model_mock.OAuthProviderDefinition = _FakeOAuthProviderDefinition +sys.modules["consts.model"] = model_mock + +GITHUB_DEF = _FakeOAuthProviderDefinition( + name="github", + display_name="GitHub", + icon="github", + authorize_url="https://github.com/login/oauth/authorize", + authorize_method="GET", + authorize_params={"scope": "read:user user:email"}, + authorize_fragment="", + authorize_param_map={ + "client_id": "client_id", + "redirect_uri": "redirect_uri", + "scope": "scope", + "state": "state", + }, + encode_redirect_uri=False, + token_url="https://github.com/login/oauth/access_token", + token_method="POST", + token_params_map={ + "client_id": "client_id", + "client_secret": "client_secret", + "code": "code", + "grant_type": "grant_type", + "redirect_uri": "redirect_uri", + }, + token_extra_params={}, + token_error_key="error", + token_error_message_key="error_description", + token_response_id_key=None, + userinfo_url="https://api.github.com/user", + userinfo_auth_scheme="Bearer", + userinfo_params={}, + userinfo_field_map={ + "id": "id", + "email": "email", + "username": "login", + }, + userinfo_needs_email_fetch=True, + userinfo_email_url="https://api.github.com/user/emails", + client_id_env="GITHUB_OAUTH_CLIENT_ID", + client_secret_env="GITHUB_OAUTH_CLIENT_SECRET", + enabled_check=None, +) + +WECHAT_DEF = _FakeOAuthProviderDefinition( + name="wechat", + display_name="WeChat", + icon="wechat", + authorize_url="https://open.weixin.qq.com/connect/qrconnect", + authorize_method="GET", + authorize_params={"response_type": "code", "scope": "snsapi_login"}, + authorize_fragment="#wechat_redirect", + authorize_param_map={ + "client_id": "appid", + 
"redirect_uri": "redirect_uri", + "scope": "scope", + "state": "state", + }, + encode_redirect_uri=True, + token_url="https://api.weixin.qq.com/sns/oauth2/access_token", + token_method="GET", + token_params_map={ + "client_id": "appid", + "client_secret": "secret", + "code": "code", + "grant_type": "grant_type", + }, + token_extra_params={}, + token_error_key="errcode", + token_error_message_key="errmsg", + token_response_id_key="openid", + userinfo_url="https://api.weixin.qq.com/sns/userinfo", + userinfo_auth_scheme="", + userinfo_params={"openid": "{openid}"}, + userinfo_field_map={ + "id": "openid", + "email": "", + "username": "nickname", + }, + userinfo_needs_email_fetch=False, + userinfo_email_url=None, + client_id_env="WECHAT_OAUTH_APP_ID", + client_secret_env="WECHAT_OAUTH_APP_SECRET", + enabled_check="ENABLE_WECHAT_OAUTH", +) + +GDE_DEF = _FakeOAuthProviderDefinition( + name="gde", + display_name="Gde", + icon="gde", + authorize_url="https://gde.test/dspcas/oauth2.0/authorize", + authorize_method="GET", + authorize_params={}, + authorize_fragment="", + authorize_param_map={"client_id": "client_id", "redirect_uri": "redirect_uri"}, + encode_redirect_uri=False, + token_url="https://gde.test/dspcas/v2/oauth2.0/accessToken", + token_method="POST", + token_params_map={ + "client_id": "client_id", + "client_secret": "secret", + "code": "code", + "grant_type": "grant_type", + "redirect_uri": "redirect_uri", + }, + token_extra_params={}, + token_error_key="errorCode", + token_error_message_key="errorMessage", + token_response_id_key=None, + userinfo_url="https://gde.test/dspcas/oauth2.0/profile", + userinfo_auth_scheme="Bearer", + userinfo_params={"access_token": "{access_token}"}, + userinfo_field_map={"id": "attributes.userId", "email": "", "username": "id"}, + userinfo_needs_email_fetch=False, + userinfo_email_url=None, + client_id_env="GDE_OAUTH_CLIENT_ID", + client_secret_env="GDE_OAUTH_CLIENT_SECRET", + enabled_check=None, +) + +oauth_providers_mock = 
MagicMock() +oauth_providers_mock.OAUTH_PROVIDER_REGISTRY = { + "github": GITHUB_DEF, + "wechat": WECHAT_DEF, + "gde": GDE_DEF, +} + + +def _get_provider_definition(provider): + if provider in oauth_providers_mock.OAUTH_PROVIDER_REGISTRY: + return oauth_providers_mock.OAUTH_PROVIDER_REGISTRY[provider] + raise KeyError(provider) + + +def _is_provider_enabled(definition): + if definition.enabled_check: + return os.getenv(definition.enabled_check, "false").lower() in ( + "true", + "1", + "yes", + ) + client_id = os.getenv(definition.client_id_env, "") + client_secret = os.getenv(definition.client_secret_env, "") + return bool(client_id and client_secret) + + +def _get_all_provider_definitions(): + return dict(oauth_providers_mock.OAUTH_PROVIDER_REGISTRY) + + +oauth_providers_mock.get_provider_definition = _get_provider_definition +oauth_providers_mock.is_provider_enabled = _is_provider_enabled +oauth_providers_mock.get_all_provider_definitions = _get_all_provider_definitions +oauth_providers_mock.GITHUB_PROVIDER = GITHUB_DEF +oauth_providers_mock.WECHAT_PROVIDER = WECHAT_DEF +sys.modules["consts.oauth_providers"] = oauth_providers_mock + +import services.oauth_service as oauth_service_module +from services.oauth_service import ( + create_or_update_oauth_account, + ensure_user_tenant_exists, + exchange_code_for_provider_token, + get_authorize_url, + get_enabled_providers, + get_provider_user_info, + get_supported_providers, + list_linked_accounts, + parse_state, + unlink_account, + _resolve_field, + _build_ssl_context, +) + + +class TestParseState(unittest.TestCase): + def test_parses_full_state_with_link_user_id(self): + result = parse_state("github:random_token:user-123") + self.assertEqual(result["provider"], "github") + self.assertEqual(result["token"], "random_token") + self.assertEqual(result["link_user_id"], "user-123") + + def test_parses_state_without_link_user_id(self): + result = parse_state("github:random_token") + self.assertEqual(result["provider"], 
"github") + self.assertEqual(result["token"], "random_token") + self.assertEqual(result["link_user_id"], "") + + def test_parses_minimal_state(self): + result = parse_state("github") + self.assertEqual(result["provider"], "github") + self.assertEqual(result["token"], "") + self.assertEqual(result["link_user_id"], "") + + +class TestResolveField(unittest.TestCase): + def test_resolves_simple_field(self): + data = {"id": "12345", "email": "test@example.com"} + result = _resolve_field(data, "id") + self.assertEqual(result, "12345") + + def test_resolves_nested_field(self): + data = {"attributes": {"userId": "abc"}} + result = _resolve_field(data, "attributes.userId") + self.assertEqual(result, "abc") + + def test_returns_none_for_missing_field(self): + data = {"id": "12345"} + result = _resolve_field(data, "email") + self.assertIsNone(result) + + def test_returns_none_for_missing_nested_field(self): + data = {"attributes": {"name": "test"}} + result = _resolve_field(data, "attributes.userId") + self.assertIsNone(result) + + +class TestBuildSSLContext(unittest.TestCase): + def test_returns_default_context_when_verify_enabled(self): + ctx = _build_ssl_context() + self.assertEqual(ctx.verify_mode, 2) + + def test_returns_no_verify_context_when_disabled(self): + with patch.object(oauth_service_module, "OAUTH_SSL_VERIFY", False): + ctx = _build_ssl_context() + self.assertEqual(ctx.verify_mode, 0) + self.assertEqual(ctx.check_hostname, False) + + +class TestGetSupportedProviders(unittest.TestCase): + def test_supported_providers_set(self): + providers = get_supported_providers() + self.assertEqual(providers, {"github", "wechat", "gde"}) + + +class TestGetEnabledProviders(unittest.TestCase): + def test_returns_github_when_configured(self): + with patch.dict( + os.environ, + {"GITHUB_OAUTH_CLIENT_ID": "id", "GITHUB_OAUTH_CLIENT_SECRET": "secret"}, + clear=False, + ): + providers = get_enabled_providers() + + self.assertEqual(len(providers), 1) + 
self.assertEqual(providers[0]["name"], "github") + self.assertTrue(providers[0]["enabled"]) + + def test_returns_empty_when_nothing_configured(self): + env = { + k: "" + for k in [ + "GITHUB_OAUTH_CLIENT_ID", + "GITHUB_OAUTH_CLIENT_SECRET", + "WECHAT_OAUTH_APP_ID", + "WECHAT_OAUTH_APP_SECRET", + ] + } + env["ENABLE_WECHAT_OAUTH"] = "false" + with patch.dict(os.environ, env, clear=False): + providers = get_enabled_providers() + + self.assertEqual(len(providers), 0) + + def test_returns_both_when_all_configured(self): + env = { + "GITHUB_OAUTH_CLIENT_ID": "id", + "GITHUB_OAUTH_CLIENT_SECRET": "secret", + "ENABLE_WECHAT_OAUTH": "true", + "WECHAT_OAUTH_APP_ID": "wx_id", + "WECHAT_OAUTH_APP_SECRET": "wx_secret", + } + with patch.dict(os.environ, env, clear=False): + providers = get_enabled_providers() + + self.assertEqual(len(providers), 2) + names = [p["name"] for p in providers] + self.assertIn("github", names) + self.assertIn("wechat", names) + + +class TestGetAuthorizeUrl(unittest.TestCase): + def test_returns_github_authorize_url(self): + with patch.dict( + os.environ, + { + "GITHUB_OAUTH_CLIENT_ID": "gh_test_id", + "GITHUB_OAUTH_CLIENT_SECRET": "gh_test_secret", + }, + clear=False, + ): + url = get_authorize_url("github") + + self.assertIn("github.com/login/oauth/authorize", url) + self.assertIn("client_id=gh_test_id", url) + self.assertIn("redirect_uri=", url) + self.assertIn("state=github", url) + + def test_returns_github_authorize_url_with_link_user_id(self): + with patch.dict( + os.environ, + { + "GITHUB_OAUTH_CLIENT_ID": "gh_test_id", + "GITHUB_OAUTH_CLIENT_SECRET": "gh_test_secret", + }, + clear=False, + ): + url = get_authorize_url("github", link_user_id="user-123") + + self.assertIn("github.com/login/oauth/authorize", url) + self.assertIn("user-123", url) + + def test_returns_wechat_authorize_url(self): + env = { + "WECHAT_OAUTH_APP_ID": "wx_test_id", + "WECHAT_OAUTH_APP_SECRET": "wx_test_secret", + "ENABLE_WECHAT_OAUTH": "true", + } + with 
patch.dict(os.environ, env, clear=False): + url = get_authorize_url("wechat") + + self.assertIn("open.weixin.qq.com/connect/qrconnect", url) + self.assertIn("appid=wx_test_id", url) + self.assertTrue(url.endswith("#wechat_redirect")) + + def test_unsupported_provider_raises(self): + with self.assertRaises(_OAuthProviderError): + get_authorize_url("google") + + def test_unconfigured_provider_raises(self): + with patch.dict( + os.environ, + {"GITHUB_OAUTH_CLIENT_ID": "", "GITHUB_OAUTH_CLIENT_SECRET": ""}, + clear=False, + ): + with self.assertRaises(_OAuthProviderError): + get_authorize_url("github") + + +class TestExchangeCodeForProviderToken(unittest.TestCase): + def test_raises_for_unsupported_provider(self): + with self.assertRaises(_OAuthProviderError): + exchange_code_for_provider_token("google", "code123") + + +class TestGetProviderUserInfo(unittest.TestCase): + def test_raises_for_unsupported_provider(self): + with self.assertRaises(_OAuthProviderError): + get_provider_user_info("google", "token123") + + +class TestCreateOrUpdateOAuthAccount(unittest.TestCase): + def test_creates_new_account_when_none_exists(self): + oauth_account_db_mock.reset_mock() + oauth_account_db_mock.get_oauth_account_by_provider.return_value = None + oauth_account_db_mock.get_soft_deleted_oauth_account.return_value = None + oauth_account_db_mock.insert_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "12345", + } + + result = create_or_update_oauth_account( + user_id="user-1", + provider="github", + provider_user_id="12345", + email="octo@github.com", + ) + + oauth_account_db_mock.insert_oauth_account.assert_called_once() + self.assertEqual(result["provider"], "github") + + def test_reactivates_soft_deleted_account(self): + oauth_account_db_mock.reset_mock() + oauth_account_db_mock.get_oauth_account_by_provider.side_effect = [ + None, + {"provider": "github", "provider_user_id": "12345", "user_id": "user-1"}, + ] + 
oauth_account_db_mock.get_soft_deleted_oauth_account.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "user-1", + "delete_flag": "Y", + } + oauth_account_db_mock.reactivate_oauth_account.return_value = True + + result = create_or_update_oauth_account( + user_id="user-1", + provider="github", + provider_user_id="12345", + email="octo@github.com", + username="octocat", + ) + + oauth_account_db_mock.reactivate_oauth_account.assert_called_once_with( + provider="github", + provider_user_id="12345", + user_id="user-1", + provider_email="octo@github.com", + provider_username="octocat", + tenant_id="default-tenant-id", + ) + oauth_account_db_mock.insert_oauth_account.assert_not_called() + self.assertEqual(result["user_id"], "user-1") + + def test_updates_existing_account(self): + oauth_account_db_mock.reset_mock() + oauth_account_db_mock.get_oauth_account_by_provider.side_effect = [ + {"provider": "github", "provider_user_id": "12345", "user_id": "user-1"}, + { + "provider": "github", + "provider_user_id": "12345", + "user_id": "user-1", + "updated": True, + }, + ] + + result = create_or_update_oauth_account( + user_id="user-1", + provider="github", + provider_user_id="12345", + username="new_name", + ) + + oauth_account_db_mock.update_oauth_account_tokens.assert_called_once() + self.assertTrue(result.get("updated")) + + def test_raises_when_already_bound_to_other_user(self): + oauth_account_db_mock.reset_mock() + oauth_account_db_mock.get_oauth_account_by_provider.return_value = { + "provider": "github", + "provider_user_id": "12345", + "user_id": "old-user", + } + + with self.assertRaises(_OAuthLinkError): + create_or_update_oauth_account( + user_id="new-user", + provider="github", + provider_user_id="12345", + email="octo@github.com", + username="octocat", + ) + + oauth_account_db_mock.update_oauth_account_tokens.assert_not_called() + oauth_account_db_mock.insert_oauth_account.assert_not_called() + + +class 
TestEnsureUserTenantExists(unittest.TestCase): + def test_returns_existing_tenant(self): + user_tenant_db_mock.get_user_tenant_by_user_id.reset_mock() + user_tenant_db_mock.insert_user_tenant.reset_mock() + user_tenant_db_mock.get_user_tenant_by_user_id.side_effect = None + user_tenant_db_mock.get_user_tenant_by_user_id.return_value = { + "user_id": "user-1", + "tenant_id": "t-1", + } + + result = ensure_user_tenant_exists("user-1", "test@example.com") + + self.assertEqual(result["tenant_id"], "t-1") + user_tenant_db_mock.insert_user_tenant.assert_not_called() + + def test_creates_tenant_when_missing(self): + user_tenant_db_mock.get_user_tenant_by_user_id.reset_mock() + user_tenant_db_mock.insert_user_tenant.reset_mock() + user_tenant_db_mock.get_user_tenant_by_user_id.side_effect = [ + None, + {"user_id": "user-1", "tenant_id": "default-tenant-id"}, + ] + + result = ensure_user_tenant_exists("user-1", "test@example.com") + + user_tenant_db_mock.insert_user_tenant.assert_called_once() + self.assertEqual(result["tenant_id"], "default-tenant-id") + + user_tenant_db_mock.get_user_tenant_by_user_id.side_effect = None + user_tenant_db_mock.get_user_tenant_by_user_id.return_value = { + "user_id": "user-1", + "tenant_id": "t-1", + } + + +class TestListLinkedAccounts(unittest.TestCase): + def test_transforms_db_results(self): + oauth_account_db_mock.list_oauth_accounts_by_user_id.return_value = [ + { + "provider": "github", + "provider_username": "octocat", + "provider_email": "octo@github.com", + "create_time": "2025-01-01T00:00:00", + } + ] + + result = list_linked_accounts("user-1") + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["provider"], "github") + self.assertEqual(result[0]["provider_username"], "octocat") + self.assertIn("linked_at", result[0]) + + def test_returns_empty_list(self): + oauth_account_db_mock.list_oauth_accounts_by_user_id.return_value = [] + + result = list_linked_accounts("user-1") + + self.assertEqual(len(result), 0) + + +class 
TestUnlinkAccount(unittest.TestCase): + def test_success_with_multiple_accounts(self): + oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 2 + oauth_account_db_mock.delete_oauth_account.return_value = True + + result = unlink_account("user-1", "github") + + self.assertTrue(result) + + def test_raises_when_last_account_no_password(self): + oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 1 + + with self.assertRaises(_OAuthLinkError): + unlink_account("user-1", "github") + + def test_allows_last_unlink_when_has_password(self): + oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 1 + oauth_account_db_mock.delete_oauth_account.return_value = True + + result = unlink_account("user-1", "github", has_password_auth=True) + + self.assertTrue(result) + + def test_raises_when_account_not_found(self): + oauth_account_db_mock.count_oauth_accounts_by_user_id.return_value = 2 + oauth_account_db_mock.delete_oauth_account.return_value = False + + with self.assertRaises(_OAuthLinkError): + unlink_account("user-1", "github") + + +class TestHTTPHelpers(unittest.TestCase): + def test_http_post_json_returns_parsed_response(self): + mock_response = MagicMock() + mock_response.read.return_value = b'{"access_token": "test_token"}' + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_response) + mock_cm.__exit__ = MagicMock(return_value=False) + with patch("urllib.request.urlopen", return_value=mock_cm): + import services.oauth_service as svc + result = svc._http_post_json("https://test.com/token", {"code": "abc"}) + self.assertEqual(result["access_token"], "test_token") + + def test_http_get_json_returns_parsed_response(self): + mock_response = MagicMock() + mock_response.read.return_value = b'{"id": "12345", "login": "octocat"}' + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_response) + mock_cm.__exit__ = MagicMock(return_value=False) + with patch("urllib.request.urlopen", 
return_value=mock_cm): + import services.oauth_service as svc + result = svc._http_get_json("https://test.com/user") + self.assertEqual(result["id"], "12345") + + def test_http_post_json_merges_headers(self): + mock_response = MagicMock() + mock_response.read.return_value = b'{"result": "ok"}' + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_response) + mock_cm.__exit__ = MagicMock(return_value=False) + with patch("urllib.request.urlopen", return_value=mock_cm) as mock_urlopen: + import services.oauth_service as svc + svc._http_post_json("https://test.com/token", {"code": "abc"}, headers={"X-Custom": "value"}) + self.assertTrue(mock_urlopen.called) + + def test_http_get_json_with_headers(self): + mock_response = MagicMock() + mock_response.read.return_value = b'{"result": "ok"}' + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_response) + mock_cm.__exit__ = MagicMock(return_value=False) + with patch("urllib.request.urlopen", return_value=mock_cm): + import services.oauth_service as svc + result = svc._http_get_json("https://test.com/user", headers={"Authorization": "Bearer token"}) + self.assertEqual(result["result"], "ok") + + +class TestGetProviderUserInfoEdgeCases(unittest.TestCase): + def test_returns_email_from_primary_in_emails_list(self): + mock_user_resp = MagicMock() + mock_user_resp.read.return_value = b'{"id": "12345", "login": "octocat"}' + mock_emails_resp = MagicMock() + mock_emails_resp.read.return_value = b'[{"email": "secondary@github.com", "primary": false}, {"email": "primary@github.com", "primary": true}]' + + mock_cm1 = MagicMock() + mock_cm1.__enter__ = MagicMock(return_value=mock_user_resp) + mock_cm1.__exit__ = MagicMock(return_value=False) + mock_cm2 = MagicMock() + mock_cm2.__enter__ = MagicMock(return_value=mock_emails_resp) + mock_cm2.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", side_effect=[mock_cm1, mock_cm2]): + env = { + "GITHUB_OAUTH_CLIENT_ID": 
"id", + "GITHUB_OAUTH_CLIENT_SECRET": "secret", + } + with patch.dict(os.environ, env, clear=False): + result = get_provider_user_info("github", "test_token") + + self.assertEqual(result["email"], "primary@github.com") + + def test_returns_first_email_when_no_primary(self): + mock_user_resp = MagicMock() + mock_user_resp.read.return_value = b'{"id": "12345", "login": "octocat"}' + mock_emails_resp = MagicMock() + mock_emails_resp.read.return_value = b'[{"email": "first@github.com"}]' + + mock_cm1 = MagicMock() + mock_cm1.__enter__ = MagicMock(return_value=mock_user_resp) + mock_cm1.__exit__ = MagicMock(return_value=False) + mock_cm2 = MagicMock() + mock_cm2.__enter__ = MagicMock(return_value=mock_emails_resp) + mock_cm2.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", side_effect=[mock_cm1, mock_cm2]): + env = { + "GITHUB_OAUTH_CLIENT_ID": "id", + "GITHUB_OAUTH_CLIENT_SECRET": "secret", + } + with patch.dict(os.environ, env, clear=False): + result = get_provider_user_info("github", "test_token") + + self.assertEqual(result["email"], "first@github.com") + + def test_fallback_email_when_no_email_found(self): + mock_user_resp = MagicMock() + mock_user_resp.read.return_value = b'{"id": "12345", "login": "testuser"}' + mock_emails_resp = MagicMock() + mock_emails_resp.read.return_value = b'[]' + + mock_cm1 = MagicMock() + mock_cm1.__enter__ = MagicMock(return_value=mock_user_resp) + mock_cm1.__exit__ = MagicMock(return_value=False) + mock_cm2 = MagicMock() + mock_cm2.__enter__ = MagicMock(return_value=mock_emails_resp) + mock_cm2.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", side_effect=[mock_cm1, mock_cm2]): + env = { + "GITHUB_OAUTH_CLIENT_ID": "id", + "GITHUB_OAUTH_CLIENT_SECRET": "secret", + } + with patch.dict(os.environ, env, clear=False): + result = get_provider_user_info("github", "test_token") + + self.assertEqual(result["email"], "testuser@nexent.com") + + def 
test_wechat_does_not_fetch_emails(self): + mock_user_resp = MagicMock() + mock_user_resp.read.return_value = b'{"openid": "wx123", "nickname": "wechat_user"}' + + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_user_resp) + mock_cm.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_cm): + env = { + "ENABLE_WECHAT_OAUTH": "true", + "WECHAT_OAUTH_APP_ID": "id", + "WECHAT_OAUTH_APP_SECRET": "secret", + } + with patch.dict(os.environ, env, clear=False): + result = get_provider_user_info("wechat", "test_token", openid="wx123") + + self.assertEqual(result["id"], "wx123") + self.assertEqual(result["username"], "wechat_user") + + def test_resolves_nested_field_path(self): + mock_user_resp = MagicMock() + mock_user_resp.read.return_value = b'{"attributes": {"userId": "nested123"}, "id": "testuser"}' + + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_user_resp) + mock_cm.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_cm): + env = { + "GDE_URL": "https://gde.test", + "GDE_OAUTH_CLIENT_ID": "id", + "GDE_OAUTH_CLIENT_SECRET": "secret", + } + with patch.dict(os.environ, env, clear=False): + result = get_provider_user_info("gde", "test_token") + + self.assertEqual(result["id"], "nested123") + + +class TestExchangeCodeForProviderTokenWithMock(unittest.TestCase): + def test_exchange_with_post_method(self): + mock_token_resp = MagicMock() + mock_token_resp.read.return_value = b'{"access_token": "gh_token_123"}' + + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_token_resp) + mock_cm.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_cm): + env = { + "GITHUB_OAUTH_CLIENT_ID": "test_id", + "GITHUB_OAUTH_CLIENT_SECRET": "test_secret", + } + with patch.dict(os.environ, env, clear=False): + result = exchange_code_for_provider_token("github", "code123") + + 
self.assertEqual(result["access_token"], "gh_token_123") + + def test_exchange_with_get_method(self): + mock_token_resp = MagicMock() + mock_token_resp.read.return_value = b'{"access_token": "wx_token_456", "openid": "wx_openid"}' + + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_token_resp) + mock_cm.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_cm): + env = { + "ENABLE_WECHAT_OAUTH": "true", + "WECHAT_OAUTH_APP_ID": "wx_id", + "WECHAT_OAUTH_APP_SECRET": "wx_secret", + } + with patch.dict(os.environ, env, clear=False): + result = exchange_code_for_provider_token("wechat", "code456") + + self.assertEqual(result["access_token"], "wx_token_456") + self.assertEqual(result["openid"], "wx_openid") + + def test_raises_on_provider_error_response(self): + mock_token_resp = MagicMock() + mock_token_resp.read.return_value = b'{"errcode": 40001, "errmsg": "invalid code"}' + + mock_cm = MagicMock() + mock_cm.__enter__ = MagicMock(return_value=mock_token_resp) + mock_cm.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_cm): + env = { + "ENABLE_WECHAT_OAUTH": "true", + "WECHAT_OAUTH_APP_ID": "wx_id", + "WECHAT_OAUTH_APP_SECRET": "wx_secret", + } + with patch.dict(os.environ, env, clear=False): + with self.assertRaises(_OAuthProviderError): + exchange_code_for_provider_token("wechat", "bad_code") + + +class TestGetAuthorizeUrlEdgeCases(unittest.TestCase): + def test_includes_authorize_params(self): + env = { + "GITHUB_OAUTH_CLIENT_ID": "gh_test_id", + "GITHUB_OAUTH_CLIENT_SECRET": "gh_test_secret", + } + with patch.dict(os.environ, env, clear=False): + url = get_authorize_url("github") + + self.assertIn("scope=", url) + + def test_wechat_includes_fragment(self): + env = { + "ENABLE_WECHAT_OAUTH": "true", + "WECHAT_OAUTH_APP_ID": "wx_test_id", + "WECHAT_OAUTH_APP_SECRET": "wx_test_secret", + } + with patch.dict(os.environ, env, clear=False): + url = 
get_authorize_url("wechat") + + self.assertTrue(url.endswith("#wechat_redirect")) + + def test_includes_state_token(self): + env = { + "GITHUB_OAUTH_CLIENT_ID": "gh_test_id", + "GITHUB_OAUTH_CLIENT_SECRET": "gh_test_secret", + } + with patch.dict(os.environ, env, clear=False): + url = get_authorize_url("github") + + self.assertIn("state=github", url) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 4d9327715f4448007f8d328108d388e671ba640d Mon Sep 17 00:00:00 2001 From: Wenbo Zhang <122337639+Stockton11@users.noreply.github.com> Date: Sat, 25 Apr 2026 14:35:00 +0800 Subject: [PATCH 006/156] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20File=20preview:=20?= =?UTF-8?q?Change=20the=20preview=20style=20of=20txt=20and=20merge=20the?= =?UTF-8?q?=20preview=20of=20unsuploaded=20files=20(#2840)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implementing virtual scrolling with react-virtuoso and support automatic line wrapping * merge unuploaded file preview logic * bug fix --- .../[locale]/chat/components/chatInput.tsx | 241 +-------------- frontend/components/ui/PdfViewer.tsx | 2 - frontend/components/ui/filePreviewDrawer.tsx | 284 ++++++++++++------ frontend/public/locales/en/common.json | 1 + frontend/public/locales/zh/common.json | 1 + frontend/types/chat.ts | 19 +- 6 files changed, 229 insertions(+), 319 deletions(-) diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx index 9b175c8cd..e83341738 100644 --- a/frontend/app/[locale]/chat/components/chatInput.tsx +++ b/frontend/app/[locale]/chat/components/chatInput.tsx @@ -18,6 +18,7 @@ import { Input } from "@/components/ui/input"; import { Button } from "antd"; import { Tooltip } from "@/components/ui/tooltip"; import { Textarea } from "@/components/ui/textarea"; +import { FilePreviewDrawer } from "@/components/ui/filePreviewDrawer"; import { conversationService } from 
"@/services/conversationService"; import { useConfig } from "@/hooks/useConfig"; import { extractColorsFromUri } from "@/lib/avatar"; @@ -27,192 +28,6 @@ import { FilePreview } from "@/types/chat"; import { ChatAgentSelector } from "./chatAgentSelector"; -// Image viewer component -function ImageViewer({ - src, - alt, - onClose, -}: { - src: string; - alt: string; - onClose: () => void; -}) { - const { t } = useTranslation("common"); - return ( -
-
e.stopPropagation()} - > - {alt} - -
-
- ); -} - -// File preview component -function FileViewer({ file, onClose }: { file: File; onClose: () => void }) { - const [content, setContent] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); - const fileType = file.type; - const extension = getFileExtension(file.name); - const { t } = useTranslation("common"); - - // Read file content - useEffect(() => { - setLoading(true); - setError(null); - - const readTextFile = () => { - const reader = new FileReader(); - - reader.onload = (event) => { - if (event.target?.result) { - setContent(event.target.result as string); - setLoading(false); - } - }; - - reader.onerror = () => { - setError(t("chatInput.cannotReadFileContent")); - setLoading(false); - }; - - reader.readAsText(file); - }; - - const readBinaryFile = () => { - const objectUrl = URL.createObjectURL(file); - setContent(objectUrl); - setLoading(false); - - return () => { - URL.revokeObjectURL(objectUrl); - }; - }; - - // Select the appropriate read method based on the file type - if (isTextFile(fileType, extension)) { - readTextFile(); - } else { - return readBinaryFile(); - } - }, [file, fileType, extension, t]); - - // Determine if it is a text file - const isTextFile = (type: string, ext: string) => { - return chatConfig.textTypes.includes(type) || chatConfig.textExtensions.includes(ext); - }; - - // Render file content - const renderFileContent = () => { - if (loading) { - return ( -
- {t("chatInput.loadingFileContent")} -
- ); - } - - if (error) { - return
{error}
; - } - - if (content === null) { - return ( -
- {t("chatInput.cannotPreviewFileType")} -
- ); - } - - if (fileType.startsWith("image/")) { - return ( -
- {file.name} -
- ); - } - - if (fileType === "application/pdf" || extension === "pdf") { - return ( -