From 439ed49e6d6fff5c63ed3576bbb7eaa3c1c915b9 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 17:50:35 +0800 Subject: [PATCH 1/7] hotfix:hotfix --- src/memos/api/product_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 30df150ea..f7f0304c7 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -201,8 +201,8 @@ class APIADDRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) - async_mode: Literal["async", "sync"] = Field( - "async", description="Whether to add memory in async mode" + async_mode: Literal["async", "sync"] | None = Field( + None, description="Whether to add memory in async mode" ) From b0e4afe8c2512fedc235514220960ddc2fc80103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 11:40:51 +0800 Subject: [PATCH 2/7] feat: enhance APIADDRequest with custom_tags, info, and is_feedback fields - Add custom_tags field for user-defined tags (e.g., ['Travel', 'family']) that can be used as search filters - Add info field for additional metadata (agent_id, app_id, source_type, etc.) with all keys usable as search filters - Add is_feedback field to indicate if the request represents user feedback - Reorganize fields with category comments for better readability - Mark async_mode as required with default value 'async' - Mark mem_cube_id, memory_content, doc_path, and source as deprecated - Enhance field descriptions for better API documentation --- src/memos/api/product_models.py | 112 ++++++++++++++++++++++++++++---- 1 file changed, 101 insertions(+), 11 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index cb72011a3..35dc2d4ab 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -199,22 +199,112 @@ class APISearchRequest(BaseRequest): class APIADDRequest(BaseRequest): """Request model for creating memories.""" + # ==== Basic identifiers ==== user_id: str = Field(None, description="User ID") - mem_cube_id: str | None = Field(None, description="Cube ID") + session_id: str | None = Field( + None, + description="Session ID. If not provided, a default session will be used.", + ) + # ==== Single-cube writing (Deprecated) ==== + mem_cube_id: str | None = Field( + None, + description="(Deprecated) Target cube ID for this add request (optional for developer API).", + ) + + # ==== Multi-cube writing ==== writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube add" ) - messages: list[MessageDict] | None = Field(None, description="List of messages to store.") - memory_content: str | None = Field(None, description="Memory content to store") - doc_path: str | None = Field(None, description="Path to document to store") - source: str | None = Field(None, description="Source of the memory") - chat_history: list[MessageDict] | None = Field(None, description="Chat history") - session_id: str | None = Field(None, description="Session id") - operation: list[PermissionDict] | None = Field( - None, description="operation ids for multi cubes" + + # ==== Async control ==== + async_mode: Literal["async", "sync"] = Field( + "async", + description=( + "Whether to add memory in async mode. " + "Use 'async' to enqueue background add (non-blocking), " + "or 'sync' to add memories in the current call. " + "Default: 'async'." + ), ) - async_mode: Literal["async", "sync"] | None = Field( - None, description="Whether to add memory in async mode" + + # ==== Business tags & info ==== + custom_tags: list[str] | None = Field( + None, + description=( + "Custom tags for this add request, e.g. ['Travel', 'family']. " + "These tags can be used as filters in search." + ), + ) + + info: dict[str, str] | None = Field( + None, + description=( + "Additional metadata for the add request. " + "All keys can be used as filters in search. " + "Example: " + "{'agent_id': 'xxxxxx', " + "'app_id': 'xxxx', " + "'source_type': 'web', " + "'source_url': 'https://www.baidu.com', " + "'source_content': '西湖是杭州最著名的景点'}." + ), + ) + + # ==== Input content ==== + messages: list[MessageDict] | None = Field( + None, + description=( + "List of messages to store. Supports: " + "- system / user / assistant messages with 'content' and 'chat_time'; " + "- tool messages including: " + " * tool_description (name, description, parameters), " + " * tool_input (call_id, name, argument), " + " * raw tool messages where content is str or list[str], " + " * tool_output with structured output items " + " (input_text / input_image / input_file, etc.). " + "Also supports pure input items when there is no dialog." + ), + ) + + # pure input (no role) + # e.g. {\"type\": \"input_text\", \"text\": \"你好\"} etc。 + # If there is no dialog, higher-level code can wrap raw inputs into MessageDict. + + # ==== Chat history ==== + chat_history: list[MessageDict] | None = Field( + None, + description=( + "Historical chat messages used internally by algorithms. " + "If None, internal stored history will be used; " + "if provided (even an empty list), this value will be used as-is." + ), + ) + + # ==== Feedback flag ==== + is_feedback: bool = Field( + False, + description=("Whether this request represents user feedback. Default: False."), + ) + + # ==== Backward compatibility fields (will delete later) ==== + memory_content: str | None = Field( + None, + description="(Deprecated) Plain memory content to store. Prefer using `messages`.", + ) + doc_path: str | None = Field( + None, + description="(Deprecated / internal) Path to document to store.", + ) + source: str | None = Field( + None, + description=( + "(Deprecated) Simple source tag of the memory. " + "Prefer using `info.source_type` / `info.source_url`." + ), + ) + operation: list[PermissionDict] | None = Field( + None, + description=("(Internal) Operation definitions for multi-cube write permissions."), ) From f5f82580fcca907940c3ed06e0106684eabb8de7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 15:43:37 +0800 Subject: [PATCH 3/7] refactor(api): enhance APISearchRequest model with improved structure and documentation - Reorganize fields with clear section comments (Basic inputs, Cube scoping, Search mode, etc.) - Add comprehensive field descriptions for better API documentation - Add new 'filter' field for structured filter conditions with support for logical operators, comparisons, and string operations - Add 'pref_top_k' and 'include_preference' fields for preference memory handling - Mark 'mem_cube_id' as deprecated, recommend 'readable_cube_ids' for multi-cube search - Make 'user_id' required field (was optional) - Add validation constraints (e.g., top_k >= 1) - Improve backward compatibility notes and internal field descriptions - Add 'threshold' field for internal similarity threshold control --- src/memos/api/product_models.py | 120 ++++++++++++++++++++++++++++---- 1 file changed, 106 insertions(+), 14 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 35dc2d4ab..ad44194fa 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1,7 +1,7 @@ import os import uuid -from typing import Generic, Literal, TypeVar +from typing import Any, Generic, Literal, TypeVar from pydantic import BaseModel, Field @@ -175,25 +175,117 @@ class SearchRequest(BaseRequest): class APISearchRequest(BaseRequest): """Request model for searching memories.""" - query: str = Field(..., description="Search query") - user_id: str = Field(None, description="User ID") - mem_cube_id: str | None = Field(None, description="Cube ID to search in") + # ==== Basic inputs ==== + query: str = Field( + ..., + description=("User search query"), + ) + user_id: str = Field(..., description="User ID") + + # ==== Cube scoping ==== + mem_cube_id: str | None = Field( + None, + description=( + "(Deprecated) Single cube ID to search in. " + "Prefer `readable_cube_ids` for multi-cube search." + ), + ) readable_cube_ids: list[str] | None = Field( - None, description="List of cube IDs user can read for multi-cube search" + None, + description=( + "List of cube IDs that are readable for this request. " + "Required for algorithm-facing API; optional for developer-facing API." + ), ) + + # ==== Search mode ==== mode: SearchMode = Field( - os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture" + os.getenv("SEARCH_MODE", SearchMode.FAST), + description="Search mode: FAST, FINE, or MIXTURE.", ) - internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") - top_k: int = Field(10, description="Number of results to return") - chat_history: list[MessageDict] | None = Field(None, description="Chat history") - session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + + session_id: str | None = Field( + None, + description=( + "Session ID used as a soft signal to prioritize more relevant memories. " + "Only used for weighting, not as a hard filter." + ), + ) + + # ==== Result control ==== + top_k: int = Field( + 10, + ge=1, + description="Number of textual memories to retrieve (top-K). Default: 10.", + ) + + pref_top_k: int = Field( + 6, + ge=0, + description="Number of preference memories to retrieve (top-K). Default: 6.", + ) + + include_preference: bool = Field( + True, + description=( + "Whether to retrieve preference memories along with general memories. " + "If enabled, the system will automatically recall user preferences " + "relevant to the query. Default: True." + ), + ) + + # ==== Filter conditions ==== + filter: dict[str, Any] | None = Field( + None, + description=( + "Structured filter conditions for searching memories. " + "Supports logical operators: AND, OR, NOT; " + "comparison: E (==), NE (!=), GT (>), LT (<), GTE (>=), LTE (<=); " + "arithmetic: +, -, *, /, %, **; " + "set: IN, CONTAINS, ICONTAINS; " + "string: LIKE. " + "This nested dict will be converted into an internal expression tree." + ), + ) + + # ==== Extended capabilities ==== + internet_search: bool = Field( + False, + description=( + "Whether to enable internet search in addition to memory search. " + "Primarily used by internal algorithms. Default: False." + ), + ) + + # Inner user, not supported in API yet + threshold: float | None = Field( + None, + description=( + "Internal similarity threshold for searching plaintext memories. " + "If None, default thresholds will be applied." + ), + ) + + # ==== Context ==== + chat_history: list[MessageDict] | None = Field( + None, + description=( + "Historical chat messages used internally by algorithms. " + "If None, internal stored history may be used; " + "if provided (even an empty list), this value will be used as-is." + ), + ) + + # ==== Backward compatibility ==== + moscube: bool = Field( + False, + description="(Deprecated / internal) Whether to use legacy MemOSCube path.", + ) + operation: list[PermissionDict] | None = Field( - None, description="operation ids for multi cubes" + None, + description="(Internal) Operation definitions for multi-cube read permissions.", ) - include_preference: bool = Field(True, description="Whether to handle preference memory") - pref_top_k: int = Field(6, description="Number of preference results to return") class APIADDRequest(BaseRequest): From 39a7b34018c675f3cd5835d99a7cb058bcf60c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 17:00:33 +0800 Subject: [PATCH 4/7] test: add routers api --- tests/api/test_product_router.py | 450 +++++++++++++++++++++++++++++++ tests/api/test_server_router.py | 445 ++++++++++++++++++++++++++++++ 2 files changed, 895 insertions(+) create mode 100644 tests/api/test_product_router.py create mode 100644 tests/api/test_server_router.py diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py new file mode 100644 index 000000000..9ed67d037 --- /dev/null +++ b/tests/api/test_product_router.py @@ -0,0 +1,450 @@ +""" +Unit tests for product_router input/output format validation. + +This module tests that the product_router endpoints correctly validate +input request formats and return properly formatted responses. +""" + +# Mock sklearn before importing any memos modules to avoid import errors +import importlib.util +import sys + +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from fastapi.testclient import TestClient + +# Patch the MOS_PRODUCT_INSTANCE directly after import +import memos.api.routers.product_router as pr_module + + +# Create a proper mock module with __spec__ +sklearn_mock = MagicMock() +sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) +sys.modules["sklearn"] = sklearn_mock + +sklearn_fe_mock = MagicMock() +sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) +sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock + +sklearn_fet_mock = MagicMock() +sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) +sklearn_fet_mock.TfidfVectorizer = MagicMock() +sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock + +# Mock sklearn.metrics as well +sklearn_metrics_mock = MagicMock() +sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) +sklearn_metrics_mock.roc_curve = MagicMock() +sys.modules["sklearn.metrics"] = sklearn_metrics_mock + + +# Create mock instance +_mock_mos_instance = Mock() + +pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance +pr_module.get_mos_product_instance = lambda: _mock_mos_instance + +# Mock MOSProduct class before importing to prevent initialization +with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): + # Import after patching + from memos.api import product_api + + +@pytest.fixture(scope="module") +def mock_mos_product_instance(): + """Mock get_mos_product_instance for all tests.""" + # Ensure the mock is set + pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance + pr_module.get_mos_product_instance = lambda: _mock_mos_instance + yield product_api.app, _mock_mos_instance + + +@pytest.fixture +def client(mock_mos_product_instance): + """Create test client for product_api.""" + app, _ = mock_mos_product_instance + return TestClient(app) + + +@pytest.fixture +def mock_mos_product(mock_mos_product_instance): + """Get the mocked MOSProduct instance.""" + _, mock_instance = mock_mos_product_instance + # Ensure get_mos_product_instance returns this mock + import memos.api.routers.product_router as pr_module + + pr_module.get_mos_product_instance = lambda: mock_instance + pr_module.MOS_PRODUCT_INSTANCE = mock_instance + return mock_instance + + +@pytest.fixture(autouse=True) +def setup_mock_mos_product(mock_mos_product): + """Set up default return values for MOSProduct methods.""" + # Set up default return values for methods + mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} + mock_mos_product.add.return_value = None + mock_mos_product.chat.return_value = ("test response", []) + mock_mos_product.chat_with_references.return_value = iter( + ['data: {"type": "content", "data": "test"}\n\n'] + ) + # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list) + default_memory_result = [{"cube_id": "test_cube", "memories": []}] + mock_mos_product.get_all.return_value = default_memory_result + mock_mos_product.get_subgraph.return_value = default_memory_result + mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"] + # Ensure get_mos_product_instance returns the mock + import memos.api.routers.product_router as pr_module + + pr_module.get_mos_product_instance = lambda: mock_mos_product + + +class TestProductRouterSearch: + """Test /search endpoint input/output format.""" + + def test_search_valid_input_output(self, mock_mos_product, client): + """Test search endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + "top_k": 10, + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + + # Verify MOSProduct.search was called with correct parameters + mock_mos_product.search.assert_called_once() + call_kwargs = mock_mos_product.search.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_search_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test search endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/search", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_search_response_format(self, mock_mos_product, client): + """Test search endpoint returns SearchResponse format.""" + mock_mos_product.search.return_value = { + "text_mem": [{"cube_id": "test_cube", "memories": []}], + "act_mem": [], + "para_mem": [], + } + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Search completed successfully" + assert isinstance(data["data"], dict) + assert "text_mem" in data["data"] + + +class TestProductRouterAdd: + """Test /add endpoint input/output format.""" + + def test_add_valid_input_output(self, mock_mos_product, client): + """Test add endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert data["data"] is None # SimpleResponse has None data + + # Verify MOSProduct.add was called with correct parameters + mock_mos_product.add.assert_called_once() + call_kwargs = mock_mos_product.add.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["memory_content"] == "test memory content" + + def test_add_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test add endpoint with missing required field.""" + request_data = { + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_add_response_format(self, mock_mos_product, client): + """Test add endpoint returns SimpleResponse format.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memory created successfully" + assert data["data"] is None + + +class TestProductRouterChatComplete: + """Test /chat/complete endpoint input/output format.""" + + def test_chat_complete_valid_input_output(self, mock_mos_product, client): + """Test chat/complete endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "message" in data + assert "data" in data + assert isinstance(data["data"], dict) + assert "response" in data["data"] + assert "references" in data["data"] + + # Verify MOSProduct.chat was called with correct parameters + mock_mos_product.chat.assert_called_once() + call_kwargs = mock_mos_product.chat.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test chat/complete endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_chat_complete_response_format(self, mock_mos_product, client): + """Test chat/complete endpoint returns correct format.""" + mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}]) + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Chat completed successfully" + assert isinstance(data["data"]["response"], str) + assert isinstance(data["data"]["references"], list) + + +class TestProductRouterChat: + """Test /chat endpoint input/output format (SSE stream).""" + + def test_chat_valid_input_output(self, mock_mos_product, client): + """Test chat endpoint with valid input returns SSE stream.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat", json=request_data) + + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + # Verify MOSProduct.chat_with_references was called + mock_mos_product.chat_with_references.assert_called_once() + call_kwargs = mock_mos_product.chat_with_references.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["query"] == "test query" + + def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test chat endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + +class TestProductRouterSuggestions: + """Test /suggestions endpoint input/output format.""" + + def test_suggestions_valid_input_output(self, mock_mos_product, client): + """Test suggestions endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "zh", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + assert "query" in data["data"] + + # Verify MOSProduct.get_suggestion_query was called + mock_mos_product.get_suggestion_query.assert_called_once() + call_kwargs = mock_mos_product.get_suggestion_query.call_args[1] + assert call_kwargs["user_id"] == "test_user" + + def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test suggestions endpoint with missing required field.""" + request_data = { + "mem_cube_id": "test_cube", + } + + response = client.post("/product/suggestions", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_suggestions_response_format(self, mock_mos_product, client): + """Test suggestions endpoint returns SuggestionResponse format.""" + mock_mos_product.get_suggestion_query.return_value = [ + "suggestion1", + "suggestion2", + "suggestion3", + ] + + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "en", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Suggestions retrieved successfully" + assert isinstance(data["data"], dict) + assert isinstance(data["data"]["query"], list) + + +class TestProductRouterGetAll: + """Test /get_all endpoint input/output format.""" + + def test_get_all_valid_input_output(self, mock_mos_product, client): + """Test get_all endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + # Verify MOSProduct.get_all was called + mock_mos_product.get_all.assert_called_once() + call_kwargs = mock_mos_product.get_all.call_args[1] + assert call_kwargs["user_id"] == "test_user" + assert call_kwargs["memory_type"] == "text_mem" + + def test_get_all_with_search_query(self, mock_mos_product, client): + """Test get_all endpoint with search_query uses get_subgraph.""" + # Reset mock call counts + mock_mos_product.get_all.reset_mock() + mock_mos_product.get_subgraph.reset_mock() + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "search_query": "test query", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + # Verify get_subgraph was called instead of get_all + mock_mos_product.get_subgraph.assert_called_once() + mock_mos_product.get_all.assert_not_called() + + def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client): + """Test get_all endpoint with missing required field.""" + request_data = { + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_get_all_response_format(self, mock_mos_product, client): + """Test get_all endpoint returns MemoryResponse format.""" + mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}] + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memories retrieved successfully" + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py new file mode 100644 index 000000000..a4bb198e0 --- /dev/null +++ b/tests/api/test_server_router.py @@ -0,0 +1,445 @@ +""" +Unit tests for server_router input/output format validation. + +This module tests that the server_router endpoints correctly validate +input request formats and return properly formatted responses. +""" + +# Mock sklearn before importing any memos modules to avoid import errors +import importlib.util +import sys + +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from fastapi.testclient import TestClient + +from memos.api.product_models import ( + APIADDRequest, + APIChatCompleteRequest, + APISearchRequest, + MemoryResponse, + SearchResponse, + SuggestionResponse, +) + + +# Create a proper mock module with __spec__ +sklearn_mock = MagicMock() +sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) +sys.modules["sklearn"] = sklearn_mock + +sklearn_fe_mock = MagicMock() +sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) +sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock + +sklearn_metrics_mock = MagicMock() +sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) +sys.modules["sklearn.metrics"] = sklearn_metrics_mock + +sklearn_fet_mock = MagicMock() +sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) +sklearn_fet_mock.TfidfVectorizer = MagicMock() +sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock + + +@pytest.fixture(scope="module") +def mock_init_server(): + """Mock init_server before importing server_api.""" + # Create mock components + mock_components = { + "graph_db": Mock(), + "mem_reader": Mock(), + "llm": Mock(), + "embedder": Mock(), + "reranker": Mock(), + "internet_retriever": Mock(), + "memory_manager": Mock(), + "default_cube_config": Mock(), + "mos_server": Mock(), + "mem_scheduler": Mock(), + "naive_mem_cube": Mock(), + "searcher": Mock(), + "api_module": Mock(), + "vector_db": None, + "pref_extractor": None, + "pref_adder": None, + "pref_retriever": None, + "pref_mem": None, + "online_bot": None, + } + + with patch("memos.api.handlers.init_server", return_value=mock_components): + # Import after patching + from memos.api import server_api + + yield server_api.app + + +@pytest.fixture +def client(mock_init_server): + """Create test client for server_api.""" + return TestClient(mock_init_server) + + +@pytest.fixture +def mock_handlers(): + """Mock all handlers used by server_router.""" + with ( + patch("memos.api.routers.server_router.search_handler") as mock_search, + patch("memos.api.routers.server_router.add_handler") as mock_add, + patch("memos.api.routers.server_router.chat_handler") as mock_chat, + patch("memos.api.routers.server_router.handlers.suggestion_handler") as mock_suggestion, + patch("memos.api.routers.server_router.handlers.memory_handler") as mock_memory, + ): + # Set up default return values + mock_search.handle_search_memories.return_value = SearchResponse( + message="Search completed successfully", + data={"text_mem": [], "act_mem": [], "para_mem": []}, + ) + + mock_add.handle_add_memories.return_value = MemoryResponse( + message="Memory added successfully", data=[] + ) + + mock_chat.handle_chat_complete.return_value = { + "message": "Chat completed successfully", + "data": {"response": "test response", "references": []}, + } + + mock_suggestion.handle_get_suggestion_queries.return_value = SuggestionResponse( + message="Suggestions retrieved successfully", data={"query": ["suggestion1"]} + ) + + mock_memory.handle_get_all_memories.return_value = MemoryResponse( + message="Memories retrieved successfully", data=[] + ) + + mock_memory.handle_get_subgraph.return_value = MemoryResponse( + message="Memories retrieved successfully", data=[] + ) + + yield { + "search": mock_search, + "add": mock_add, + "chat": mock_chat, + "suggestion": mock_suggestion, + "memory": mock_memory, + } + + +class TestServerRouterSearch: + """Test /search endpoint input/output format.""" + + def test_search_valid_input_output(self, mock_handlers, client): + """Test search endpoint with valid input returns correct output format.""" + request_data = { + "query": "test query", + "user_id": "test_user", + "mem_cube_id": "test_cube", + "top_k": 10, + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], dict) + + # Verify handler was called with correct request type + mock_handlers["search"].handle_search_memories.assert_called_once() + call_args = mock_handlers["search"].handle_search_memories.call_args[0][0] + assert isinstance(call_args, APISearchRequest) + assert call_args.query == "test query" + assert call_args.user_id == "test_user" + + def test_search_invalid_input_missing_query(self, mock_handlers, client): + """Test search endpoint with missing required field.""" + request_data = { + "user_id": "test_user", + } + + response = client.post("/product/search", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_search_response_format(self, mock_handlers, client): + """Test search endpoint returns SearchResponse format.""" + mock_handlers["search"].handle_search_memories.return_value = SearchResponse( + message="Search completed successfully", + data={ + "text_mem": [{"cube_id": "test_cube", "memories": []}], + "act_mem": [], + "para_mem": [], + }, + ) + + request_data = { + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/search", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Search completed successfully" + assert isinstance(data["data"], dict) + assert "text_mem" in data["data"] + + +class TestServerRouterAdd: + """Test /add endpoint input/output format.""" + + def test_add_valid_input_output(self, mock_handlers, client): + """Test add endpoint with valid input returns correct output format.""" + request_data = { + "mem_cube_id": "test_cube", + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + # Verify handler was called with correct request type + mock_handlers["add"].handle_add_memories.assert_called_once() + call_args = mock_handlers["add"].handle_add_memories.call_args[0][0] + assert isinstance(call_args, APIADDRequest) + assert call_args.mem_cube_id == "test_cube" + assert call_args.user_id == "test_user" + + def test_add_invalid_input_missing_cube_id(self, mock_handlers, client): + """Test add endpoint with missing required field.""" + request_data = { + "user_id": "test_user", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_add_response_format(self, mock_handlers, client): + """Test add endpoint returns MemoryResponse format.""" + mock_handlers["add"].handle_add_memories.return_value = MemoryResponse( + message="Memory added successfully", + data=[{"cube_id": "test_cube", "memories": []}], + ) + + request_data = { + "mem_cube_id": "test_cube", + "memory_content": "test memory content", + } + + response = client.post("/product/add", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memory added successfully" + assert isinstance(data["data"], list) + + +class TestServerRouterChatComplete: + """Test /chat/complete endpoint input/output format.""" + + def test_chat_complete_valid_input_output(self, mock_handlers, client): + """Test chat/complete endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "query": "test query", + "mem_cube_id": "test_cube", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "message" in data + assert "data" in data + assert isinstance(data["data"], dict) + assert "response" in data["data"] + assert "references" in data["data"] + + # Verify handler was called with correct request type + mock_handlers["chat"].handle_chat_complete.assert_called_once() + call_args = mock_handlers["chat"].handle_chat_complete.call_args[0][0] + assert isinstance(call_args, APIChatCompleteRequest) + assert call_args.user_id == "test_user" + assert call_args.query == "test query" + + def test_chat_complete_invalid_input_missing_user_id(self, mock_handlers, client): + """Test chat/complete endpoint with missing required field.""" + request_data = { + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_chat_complete_response_format(self, mock_handlers, client): + """Test chat/complete endpoint returns correct format.""" + mock_handlers["chat"].handle_chat_complete.return_value = { + "message": "Chat completed successfully", + "data": {"response": "test response", "references": [{"id": "ref1"}]}, + } + + request_data = { + "user_id": "test_user", + "query": "test query", + } + + response = client.post("/product/chat/complete", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Chat completed successfully" + assert isinstance(data["data"]["response"], str) + assert isinstance(data["data"]["references"], list) + + +class TestServerRouterSuggestions: + """Test /suggestions endpoint input/output format.""" + + def test_suggestions_valid_input_output(self, mock_handlers, client): + """Test suggestions endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "zh", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + + # Verify handler was called + mock_handlers["suggestion"].handle_get_suggestion_queries.assert_called_once() + + def test_suggestions_invalid_input_missing_user_id(self, mock_handlers, client): + """Test suggestions endpoint with missing required field.""" + request_data = { + "mem_cube_id": "test_cube", + } + + response = client.post("/product/suggestions", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_suggestions_response_format(self, mock_handlers, client): + """Test suggestions endpoint returns SuggestionResponse format.""" + mock_handlers["suggestion"].handle_get_suggestion_queries.return_value = SuggestionResponse( + message="Suggestions retrieved successfully", + data={"query": ["suggestion1", "suggestion2"]}, + ) + + request_data = { + "user_id": "test_user", + "mem_cube_id": "test_cube", + "language": "en", + } + + response = client.post("/product/suggestions", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Suggestions retrieved successfully" + assert isinstance(data["data"], dict) + assert "query" in data["data"] + + +class TestServerRouterGetAll: + """Test /get_all endpoint input/output format.""" + + def test_get_all_valid_input_output(self, mock_handlers, client): + """Test get_all endpoint with valid input returns correct output format.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + + # Validate response structure + assert "code" in data + assert "message" in data + assert "data" in data + assert data["code"] == 200 + assert isinstance(data["data"], list) + + def test_get_all_with_search_query(self, mock_handlers, client): + """Test get_all endpoint with search_query uses subgraph handler.""" + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "search_query": "test query", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + # Verify subgraph handler was called + mock_handlers["memory"].handle_get_subgraph.assert_called_once() + + def test_get_all_invalid_input_missing_user_id(self, mock_handlers, client): + """Test get_all endpoint with missing required field.""" + request_data = { + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + # Should return validation error + assert response.status_code == 422 + + def test_get_all_response_format(self, mock_handlers, client): + """Test get_all endpoint returns MemoryResponse format.""" + mock_handlers["memory"].handle_get_all_memories.return_value = MemoryResponse( + message="Memories retrieved successfully", + data=[{"cube_id": "test_cube", "memories": []}], + ) + + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Memories retrieved successfully" + assert isinstance(data["data"], list) From 9e6cc60a9f3b003c2ec17332d9e48d60e0b50971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 17:39:49 +0800 Subject: [PATCH 5/7] fix: Fixed the compatibility issue in the product router. --- src/memos/api/config.py | 4 +++- src/memos/api/product_models.py | 8 ++++++++ src/memos/api/routers/product_router.py | 3 ++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a276fa63d..b90df51b2 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -328,7 +328,9 @@ def get_memreader_config() -> dict[str, Any]: "top_p": 0.95, "top_k": 20, "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), - "api_base": os.getenv("MEMRADER_API_BASE"), + # Default to OpenAI base URL when env var is not provided to satisfy pydantic + # validation requirements during tests/import. + "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"), "remove_think_prefix": True, "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, }, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index ff930666f..191b219e4 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -93,6 +93,9 @@ class ChatRequest(BaseRequest): temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + moscube: bool = Field( + False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" + ) class ChatCompleteRequest(BaseRequest): @@ -116,6 +119,11 @@ class ChatCompleteRequest(BaseRequest): top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + base_prompt: str | None = Field(None, description="(Deprecated) Base prompt alias") + moscube: bool = Field( + False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" + ) + class UserCreate(BaseRequest): user_name: str | None = Field(None, description="Name of the user") diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 2f6c5c317..ccacee816 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -297,7 +297,8 @@ def chat_complete(chat_req: ChatCompleteRequest): history=chat_req.history, internet_search=chat_req.internet_search, moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt, + base_prompt=chat_req.base_prompt or chat_req.system_prompt, + # will deprecate base_prompt in the future top_k=chat_req.top_k, threshold=chat_req.threshold, session_id=chat_req.session_id, From 27ee05f95c58ac38df34739b36bb6ed5724945e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 18:12:46 +0800 Subject: [PATCH 6/7] fix: tests unpass --- tests/api/test_server_router.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index a4bb198e0..5f34a2048 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -68,6 +68,7 @@ def mock_init_server(): "pref_retriever": None, "pref_mem": None, "online_bot": None, + "chat_llms": Mock(), } with patch("memos.api.handlers.init_server", return_value=mock_components): @@ -184,6 +185,7 @@ def test_search_response_format(self, mock_handlers, client): request_data = { "query": "test query", + "user_id": "test_user_id", "mem_cube_id": "test_cube", } @@ -226,18 +228,6 @@ def test_add_valid_input_output(self, mock_handlers, client): assert call_args.mem_cube_id == "test_cube" assert call_args.user_id == "test_user" - def test_add_invalid_input_missing_cube_id(self, mock_handlers, client): - """Test add endpoint with missing required field.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - # Should return validation error - assert response.status_code == 422 - def test_add_response_format(self, mock_handlers, client): """Test add endpoint returns MemoryResponse format.""" mock_handlers["add"].handle_add_memories.return_value = MemoryResponse( From 1c03794cac974fe255276467ba89c66889916967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Sat, 22 Nov 2025 23:15:52 +0800 Subject: [PATCH 7/7] fix: test_api bug --- tests/api/test_product_router.py | 34 +++----------------------------- tests/api/test_server_router.py | 27 +++---------------------- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py index 9ed67d037..857b290c5 100644 --- a/tests/api/test_product_router.py +++ b/tests/api/test_product_router.py @@ -5,50 +5,22 @@ input request formats and return properly formatted responses. """ -# Mock sklearn before importing any memos modules to avoid import errors -import importlib.util -import sys - -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from fastapi.testclient import TestClient # Patch the MOS_PRODUCT_INSTANCE directly after import +# Patch MOS_PRODUCT_INSTANCE and MOSProduct so we can test the FastAPI router +# without initializing the full MemOS product stack. import memos.api.routers.product_router as pr_module -# Create a proper mock module with __spec__ -sklearn_mock = MagicMock() -sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) -sys.modules["sklearn"] = sklearn_mock - -sklearn_fe_mock = MagicMock() -sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) -sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock - -sklearn_fet_mock = MagicMock() -sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) -sklearn_fet_mock.TfidfVectorizer = MagicMock() -sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock - -# Mock sklearn.metrics as well -sklearn_metrics_mock = MagicMock() -sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) -sklearn_metrics_mock.roc_curve = MagicMock() -sys.modules["sklearn.metrics"] = sklearn_metrics_mock - - -# Create mock instance _mock_mos_instance = Mock() - pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance pr_module.get_mos_product_instance = lambda: _mock_mos_instance - -# Mock MOSProduct class before importing to prevent initialization with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): - # Import after patching from memos.api import product_api diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 5f34a2048..853a271f6 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -5,11 +5,7 @@ input request formats and return properly formatted responses. """ -# Mock sklearn before importing any memos modules to avoid import errors -import importlib.util -import sys - -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest @@ -25,25 +21,8 @@ ) -# Create a proper mock module with __spec__ -sklearn_mock = MagicMock() -sklearn_mock.__spec__ = importlib.util.spec_from_loader("sklearn", None) -sys.modules["sklearn"] = sklearn_mock - -sklearn_fe_mock = MagicMock() -sklearn_fe_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction", None) -sys.modules["sklearn.feature_extraction"] = sklearn_fe_mock - -sklearn_metrics_mock = MagicMock() -sklearn_metrics_mock.__spec__ = importlib.util.spec_from_loader("sklearn.metrics", None) -sys.modules["sklearn.metrics"] = sklearn_metrics_mock - -sklearn_fet_mock = MagicMock() -sklearn_fet_mock.__spec__ = importlib.util.spec_from_loader("sklearn.feature_extraction.text", None) -sklearn_fet_mock.TfidfVectorizer = MagicMock() -sys.modules["sklearn.feature_extraction.text"] = sklearn_fet_mock - - +# Patch init_server so we can import server_api without starting the full MemOS stack, +# and keep sklearn and other core dependencies untouched for other tests. @pytest.fixture(scope="module") def mock_init_server(): """Mock init_server before importing server_api."""