diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py
index 17bfd3993..7d8cf2897 100644
--- a/examples/mem_scheduler/memos_w_scheduler.py
+++ b/examples/mem_scheduler/memos_w_scheduler.py
@@ -1,29 +1,28 @@
+import re
import shutil
import sys
+from datetime import datetime
from pathlib import Path
from queue import Queue
+
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
-from datetime import datetime
-import re
-
from memos.configs.mem_scheduler import AuthConfig
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.main import MOS
+from memos.mem_scheduler.general_scheduler import GeneralScheduler
from memos.mem_scheduler.schemas.general_schemas import (
- QUERY_LABEL,
- ANSWER_LABEL,
ADD_LABEL,
+ ANSWER_LABEL,
+ MEM_ARCHIVE_LABEL,
MEM_ORGANIZE_LABEL,
MEM_UPDATE_LABEL,
- MEM_ARCHIVE_LABEL,
- NOT_APPLICABLE_TYPE,
+ QUERY_LABEL,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
-from memos.mem_scheduler.general_scheduler import GeneralScheduler
FILE_PATH = Path(__file__).absolute()
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index 8540a67ec..2f40f1c91 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -7,6 +7,7 @@
import asyncio
import json
+import re
import traceback
from collections.abc import Generator
@@ -32,7 +33,6 @@
from memos.mem_scheduler.schemas.general_schemas import (
ANSWER_LABEL,
QUERY_LABEL,
- SearchMode,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.templates.mos_prompts import (
@@ -53,6 +53,7 @@ class ChatHandler(BaseHandler):
def __init__(
self,
dependencies: HandlerDependencies,
+ chat_llms: dict[str, Any],
search_handler=None,
add_handler=None,
online_bot=None,
@@ -62,6 +63,7 @@ def __init__(
Args:
dependencies: HandlerDependencies instance
+ chat_llms: Dictionary mapping model names to LLM instances
search_handler: Optional SearchHandler instance (created if not provided)
add_handler: Optional AddHandler instance (created if not provided)
online_bot: Optional DingDing bot function for notifications
@@ -80,6 +82,7 @@ def __init__(
add_handler = AddHandler(dependencies)
+ self.chat_llms = chat_llms
self.search_handler = search_handler
self.add_handler = add_handler
self.online_bot = online_bot
@@ -105,21 +108,19 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
HTTPException: If chat fails
"""
try:
- import time
-
- time_start = time.time()
-
# Step 1: Search for relevant memories
search_req = APISearchRequest(
+ query=chat_req.query,
user_id=chat_req.user_id,
mem_cube_id=chat_req.mem_cube_id,
- query=chat_req.query,
- top_k=chat_req.top_k or 10,
- session_id=chat_req.session_id,
- mode=SearchMode.FAST,
+ mode=chat_req.mode,
internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
+ top_k=chat_req.top_k,
chat_history=chat_req.history,
+ session_id=chat_req.session_id,
+ include_preference=chat_req.include_preference,
+ pref_top_k=chat_req.pref_top_k,
+ filter=chat_req.filter,
)
search_response = self.search_handler.handle_search_memories(search_req)
@@ -137,7 +138,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
)
# Step 2: Build system prompt
- system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt)
+ system_prompt = self._build_system_prompt(
+ filtered_memories, search_response.data["pref_string"], chat_req.system_prompt
+ )
# Prepare message history
history_info = chat_req.history[-20:] if chat_req.history else []
@@ -150,28 +153,33 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
self.logger.info("Starting to generate complete response...")
# Step 3: Generate complete response from LLM
- response = self.llm.generate(current_messages)
-
- time_end = time.time()
+ if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms:
+ return {
+                    "message": f"Model {chat_req.model_name_or_path} is not supported; choose from {list(self.chat_llms.keys())}"
+ }
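+            # Fall back to the first configured chat model when no model is requested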
+ model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
+ response = self.chat_llms[model].generate(current_messages, model_name_or_path=model)
+
+            # Step 4: Add the conversation to memory asynchronously after answering
+ if chat_req.add_message_on_answer:
+ self._start_add_to_memory(
+ user_id=chat_req.user_id,
+ cube_id=chat_req.mem_cube_id,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=response,
+ async_mode="async",
+ )
- # Step 4: Start post-chat processing asynchronously
- self._start_post_chat_processing(
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- session_id=chat_req.session_id or "default_session",
- query=chat_req.query,
- full_response=response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=0.0,
- current_messages=current_messages,
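+            # Split optional <think>...</think> reasoning out of the raw model response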
+            match = re.search(r"<think>([\s\S]*?)</think>", response)
+            reasoning_text = match.group(1) if match else None
+            final_text = (
+                re.sub(r"<think>[\s\S]*?</think>", "", response, count=1) if match else response
)
- # Return the complete response
return {
"message": "Chat completed successfully",
- "data": {"response": response, "references": filtered_memories},
+ "data": {"response": final_text, "reasoning": reasoning_text},
}
except ValueError as err:
@@ -186,6 +194,150 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse:
This implementation directly uses search_handler and add_handler.
+ Args:
+ chat_req: Chat stream request
+
+ Returns:
+ StreamingResponse with SSE formatted chat stream
+
+ Raises:
+ HTTPException: If stream initialization fails
+ """
+ try:
+
+ def generate_chat_response() -> Generator[str, None, None]:
+ """Generate chat response as SSE stream."""
+ try:
+ search_req = APISearchRequest(
+ query=chat_req.query,
+ user_id=chat_req.user_id,
+ mem_cube_id=chat_req.mem_cube_id,
+ mode=chat_req.mode,
+ internet_search=chat_req.internet_search,
+ top_k=chat_req.top_k,
+ chat_history=chat_req.history,
+ session_id=chat_req.session_id,
+ include_preference=chat_req.include_preference,
+ pref_top_k=chat_req.pref_top_k,
+ filter=chat_req.filter,
+ )
+
+ search_response = self.search_handler.handle_search_memories(search_req)
+
+ self._send_message_to_scheduler(
+ user_id=chat_req.user_id,
+ mem_cube_id=chat_req.mem_cube_id,
+ query=chat_req.query,
+ label=QUERY_LABEL,
+ )
+ # Extract memories from search results
+ memories_list = []
+ if search_response.data and search_response.data.get("text_mem"):
+ text_mem_results = search_response.data["text_mem"]
+ if text_mem_results and text_mem_results[0].get("memories"):
+ memories_list = text_mem_results[0]["memories"]
+
+ # Filter memories by threshold
+ filtered_memories = self._filter_memories_by_threshold(memories_list)
+
+ # Step 2: Build system prompt with memories
+ system_prompt = self._build_system_prompt(
+ filtered_memories,
+ search_response.data["pref_string"],
+ chat_req.system_prompt,
+ )
+
+ # Prepare messages
+ history_info = chat_req.history[-20:] if chat_req.history else []
+ current_messages = [
+ {"role": "system", "content": system_prompt},
+ *history_info,
+ {"role": "user", "content": chat_req.query},
+ ]
+
+ self.logger.info(
+ f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, "
+ f"current_system_prompt: {system_prompt}"
+ )
+
+ # Step 3: Generate streaming response from LLM
+ if (
+ chat_req.model_name_or_path
+ and chat_req.model_name_or_path not in self.chat_llms
+ ):
+ return {
+                            "message": f"Model {chat_req.model_name_or_path} is not supported; choose from {list(self.chat_llms.keys())}"
+ }
+ model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
+ response_stream = self.chat_llms[model].generate_stream(
+ current_messages, model_name_or_path=model
+ )
+
+ # Stream the response
+ buffer = ""
+ full_response = ""
+ in_think = False
+
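+                    # Chunks equal to the literal <think>/</think> markers toggle reasoning mode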
+ for chunk in response_stream:
+                        if chunk == "<think>":
+ in_think = True
+ continue
+                        if chunk == "</think>":
+ in_think = False
+ continue
+
+ if in_think:
+ chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+ continue
+
+ buffer += chunk
+ full_response += chunk
+
+ chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+
+ current_messages.append({"role": "assistant", "content": full_response})
+ if chat_req.add_message_on_answer:
+ self._start_add_to_memory(
+ user_id=chat_req.user_id,
+ cube_id=chat_req.mem_cube_id,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=full_response,
+ async_mode="async",
+ )
+
+ except Exception as e:
+ self.logger.error(f"Error in chat stream: {e}", exc_info=True)
+ error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
+ yield error_data
+
+ return StreamingResponse(
+ generate_chat_response(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "Content-Type": "text/event-stream",
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Headers": "*",
+ "Access-Control-Allow-Methods": "*",
+ },
+ )
+
+ except ValueError as err:
+ raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+ except Exception as err:
+ self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
+ raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
+ def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse:
+ """
+        Chat with MemOS via Server-Sent Events (SSE) stream for the playground UI.
+
+        Compared with handle_chat_stream, this variant also emits reference,
+        preference-markdown, status, and timing events to the client.
+
Args:
chat_req: Chat stream request
@@ -208,15 +360,17 @@ def generate_chat_response() -> Generator[str, None, None]:
yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n"
search_req = APISearchRequest(
+ query=chat_req.query,
user_id=chat_req.user_id,
mem_cube_id=chat_req.mem_cube_id,
- query=chat_req.query,
- top_k=20,
- session_id=chat_req.session_id,
- mode=SearchMode.FAST,
- internet_search=chat_req.internet_search, # TODO this param is not worked at fine mode
- moscube=chat_req.moscube,
+ mode=chat_req.mode,
+ internet_search=chat_req.internet_search,
+ top_k=chat_req.top_k,
chat_history=chat_req.history,
+ session_id=chat_req.session_id,
+ include_preference=chat_req.include_preference,
+ pref_top_k=chat_req.pref_top_k,
+ filter=chat_req.filter,
)
search_response = self.search_handler.handle_search_memories(search_req)
@@ -240,10 +394,23 @@ def generate_chat_response() -> Generator[str, None, None]:
# Prepare reference data
reference = prepare_reference_data(filtered_memories)
+                    # Get internet references from the searched memories
+ internet_reference = self._get_internet_reference(
+ search_response.data.get("text_mem")[0]["memories"]
+ )
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
+ # Prepare preference markdown string
+ if chat_req.include_preference:
+ pref_md_string = self._build_pref_md_string_for_playground(
+ search_response.data["pref_mem"][0].get("memories", [])
+ )
+ yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"
+
# Step 2: Build system prompt with memories
- system_prompt = self._build_enhance_system_prompt(filtered_memories)
+ system_prompt = self._build_enhance_system_prompt(
+ filtered_memories, search_response.data["pref_string"]
+ )
# Prepare messages
history_info = chat_req.history[-20:] if chat_req.history else []
@@ -261,14 +428,34 @@ def generate_chat_response() -> Generator[str, None, None]:
yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n"
# Step 3: Generate streaming response from LLM
- response_stream = self.llm.generate_stream(current_messages)
+ if (
+ chat_req.model_name_or_path
+ and chat_req.model_name_or_path not in self.chat_llms
+ ):
+ return {
+                            "message": f"Model {chat_req.model_name_or_path} is not supported; choose from {list(self.chat_llms.keys())}"
+ }
+ model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
+ response_stream = self.chat_llms[model].generate_stream(
+ current_messages, model_name_or_path=model
+ )
# Stream the response
buffer = ""
full_response = ""
+ in_think = False
for chunk in response_stream:
-                    if chunk in ["<think>", "</think>"]:
+                    if chunk == "<think>":
+                        in_think = True
+                        continue
+                    if chunk == "</think>":
+ in_think = False
+ continue
+
+ if in_think:
+ chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
continue
buffer += chunk
@@ -291,6 +478,9 @@ def generate_chat_response() -> Generator[str, None, None]:
chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
yield chunk_data
+ # Yield internet reference after text response
+ yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"
+
# Calculate timing
time_end = time.time()
speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1)
@@ -306,7 +496,6 @@ def generate_chat_response() -> Generator[str, None, None]:
yield f"data: {json.dumps({'type': 'end'})}\n\n"
- # Step 4: Add conversation to memory asynchronously
self._start_post_chat_processing(
user_id=chat_req.user_id,
cube_id=chat_req.mem_cube_id,
@@ -320,6 +509,15 @@ def generate_chat_response() -> Generator[str, None, None]:
current_messages=current_messages,
)
+ self._start_add_to_memory(
+ user_id=chat_req.user_id,
+ cube_id=chat_req.mem_cube_id,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=full_response,
+ async_mode="sync",
+ )
+
except Exception as e:
self.logger.error(f"Error in chat stream: {e}", exc_info=True)
error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
@@ -344,9 +542,62 @@ def generate_chat_response() -> Generator[str, None, None]:
self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+ def _get_internet_reference(
+        self, search_response: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
+ """Get internet reference from search response."""
+ unique_set = set()
+ result = []
+
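+        # Deduplicate internet_info payloads by their canonical JSON representation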
+ for item in search_response:
+ meta = item.get("metadata", {})
+ if meta.get("source") == "web" and meta.get("internet_info"):
+ info = meta.get("internet_info")
+ key = json.dumps(info, sort_keys=True)
+ if key not in unique_set:
+ unique_set.add(key)
+ result.append(info)
+ return result
+
+    def _build_pref_md_string_for_playground(self, pref_mem_list: list[Any]) -> str:
+ """Build preference markdown string for playground."""
+ explicit = []
+ implicit = []
+ for pref_mem in pref_mem_list:
+ if pref_mem["metadata"]["preference_type"] == "explicit":
+ explicit.append(
+ {
+ "content": pref_mem["preference"],
+ "reasoning": pref_mem["metadata"]["reasoning"],
+ }
+ )
+ elif pref_mem["metadata"]["preference_type"] == "implicit":
+ implicit.append(
+ {
+ "content": pref_mem["preference"],
+ "reasoning": pref_mem["metadata"]["reasoning"],
+ }
+ )
+
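+        # The markdown labels below are Chinese playground UI strings:
+        # 显性偏好 = explicit preference, 隐性偏好 = implicit preference,
+        # 抽取内容 = extracted content, 抽取理由 = extraction reasoning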
+ explicit_md = "\n\n".join(
+ [
+ f"显性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: {pref['reasoning']}"
+ for i, pref in enumerate(explicit)
+ ]
+ )
+ implicit_md = "\n\n".join(
+ [
+ f"隐性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: {pref['reasoning']}"
+ for i, pref in enumerate(implicit)
+ ]
+ )
+
+ return f"{explicit_md}\n\n{implicit_md}"
+
def _build_system_prompt(
self,
memories: list | None = None,
+ pref_string: str | None = None,
base_prompt: str | None = None,
**kwargs,
) -> str:
@@ -366,6 +617,8 @@ def _build_system_prompt(
text_memory = memory.get("memory", "")
memory_list.append(f"{i}. {text_memory}")
memory_context = "\n".join(memory_list)
+ if pref_string:
+ memory_context += f"\n\n{pref_string}"
if "{memories}" in base_prompt:
return base_prompt.format(memories=memory_context)
@@ -378,6 +631,7 @@ def _build_system_prompt(
def _build_enhance_system_prompt(
self,
memories_list: list,
+ pref_string: str = "",
tone: str = "friendly",
verbosity: str = "mid",
) -> str:
@@ -386,6 +640,7 @@ def _build_enhance_system_prompt(
Args:
memories_list: List of memory items
+ pref_string: Preference string
tone: Tone of the prompt
verbosity: Verbosity level
@@ -407,6 +662,7 @@ def _build_enhance_system_prompt(
+ mem_block_p
+ "\n## OuterMemory (ordered)\n"
+ mem_block_o
+ + f"\n\n{pref_string}"
)
def _format_mem_block(
@@ -608,6 +864,36 @@ def _send_message_to_scheduler(
except Exception as e:
self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True)
+ async def _add_conversation_to_memory(
+ self,
+ user_id: str,
+ cube_id: str,
+ session_id: str,
+ query: str,
+ clean_response: str,
+ async_mode: Literal["async", "sync"] = "sync",
+ ) -> None:
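+        """Wrap the query/response pair in an APIADDRequest and hand it to the add handler."""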
+ add_req = APIADDRequest(
+ user_id=user_id,
+ mem_cube_id=cube_id,
+ session_id=session_id,
+ messages=[
+ {
+ "role": "user",
+ "content": query,
+ "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
+ },
+ {
+ "role": "assistant",
+ "content": clean_response,
+ "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
+ },
+ ],
+ async_mode=async_mode,
+ )
+
+ self.add_handler.handle_add_memories(add_req)
+
async def _post_chat_processing(
self,
user_id: str,
@@ -701,28 +987,6 @@ async def _post_chat_processing(
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL
)
- # Add conversation to memory using add handler
- add_req = APIADDRequest(
- user_id=user_id,
- mem_cube_id=cube_id,
- session_id=session_id,
- messages=[
- {
- "role": "user",
- "content": query,
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- {
- "role": "assistant",
- "content": clean_response, # Store clean text without reference markers
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- ],
- async_mode="sync", # set suync for playground
- )
-
- self.add_handler.handle_add_memories(add_req)
-
self.logger.info(f"Post-chat processing completed for user {user_id}")
except Exception as e:
@@ -822,3 +1086,65 @@ def run_async_in_thread():
daemon=True,
)
thread.start()
+
+ def _start_add_to_memory(
+ self,
+ user_id: str,
+ cube_id: str,
+ session_id: str,
+ query: str,
+ full_response: str,
+ async_mode: Literal["async", "sync"] = "sync",
+ ) -> None:
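+        """Persist the query/response pair to memory without blocking the caller.
+
+        Reuses the running event loop when one is active; otherwise runs in a
+        daemon thread with its own loop.
+        """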
+ def run_async_in_thread():
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ clean_response, _ = self._extract_references_from_response(full_response)
+ loop.run_until_complete(
+ self._add_conversation_to_memory(
+ user_id=user_id,
+ cube_id=cube_id,
+ session_id=session_id,
+ query=query,
+ clean_response=clean_response,
+ async_mode=async_mode,
+ )
+ )
+ finally:
+ loop.close()
+ except Exception as e:
+ self.logger.error(
+ f"Error in thread-based add to memory for user {user_id}: {e}",
+ exc_info=True,
+ )
+
+ try:
+ asyncio.get_running_loop()
+ clean_response, _ = self._extract_references_from_response(full_response)
+ task = asyncio.create_task(
+ self._add_conversation_to_memory(
+ user_id=user_id,
+ cube_id=cube_id,
+ session_id=session_id,
+ query=query,
+ clean_response=clean_response,
+ async_mode=async_mode,
+ )
+ )
+ task.add_done_callback(
+ lambda t: self.logger.error(
+ f"Error in background add to memory for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
+ )
+ except RuntimeError:
+ thread = ContextThread(
+ target=run_async_in_thread,
+ name=f"AddToMemory-{user_id}",
+ daemon=True,
+ )
+ thread.start()
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 89e61e79d..3ef1d529d 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -11,6 +11,7 @@
from memos.api.config import APIConfig
from memos.api.handlers.config_builders import (
+ build_chat_llm_config,
build_embedder_config,
build_graph_db_config,
build_internet_retriever_config,
@@ -77,6 +78,38 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]:
}
+def _init_chat_llms(chat_llm_configs: list[dict]) -> dict[str, Any]:
+ """
+ Initialize chat language models from configuration.
+
+ Args:
+ chat_llm_configs: List of chat LLM configuration dictionaries
+
+ Returns:
+ Dictionary mapping model names to initialized LLM instances
+ """
+
+ def _list_models(client):
+ try:
+ models = (
+ [model.id for model in client.models.list().data]
+ if client.models.list().data
+ else client.models.list().models
+ )
+ except Exception as e:
+ logger.error(f"Error listing models: {e}")
+ models = []
+ return models
+
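+    # One shared LLM client per config entry, registered under every model name it advertises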
+    model_name_instance_mapping = {}
+    for cfg in chat_llm_configs:
+        llm = LLMFactory.from_config(cfg["config_class"])
+        if cfg["support_models"]:
+            for model_name in cfg["support_models"]:
+                model_name_instance_mapping[model_name] = llm
+    return model_name_instance_mapping
+
+
def init_server() -> dict[str, Any]:
"""
Initialize all server components and configurations.
@@ -104,6 +137,7 @@ def init_server() -> dict[str, Any]:
# Build component configurations
graph_db_config = build_graph_db_config()
llm_config = build_llm_config()
+ chat_llm_config = build_chat_llm_config()
embedder_config = build_embedder_config()
mem_reader_config = build_mem_reader_config()
reranker_config = build_reranker_config()
@@ -123,6 +157,7 @@ def init_server() -> dict[str, Any]:
else None
)
llm = LLMFactory.from_config(llm_config)
+ chat_llms = _init_chat_llms(chat_llm_config)
embedder = EmbedderFactory.from_config(embedder_config)
mem_reader = MemReaderFactory.from_config(mem_reader_config)
reranker = RerankerFactory.from_config(reranker_config)
@@ -130,6 +165,8 @@ def init_server() -> dict[str, Any]:
internet_retriever_config, embedder=embedder
)
+ # Initialize chat llms
+
logger.debug("Core components instantiated")
# Initialize memory manager
@@ -234,7 +271,6 @@ def init_server() -> dict[str, Any]:
tree_mem: TreeTextMemory = naive_mem_cube.text_mem
searcher: Searcher = tree_mem.get_searcher(
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
- moscube=False,
)
logger.debug("Searcher created")
@@ -276,6 +312,7 @@ def init_server() -> dict[str, Any]:
"graph_db": graph_db,
"mem_reader": mem_reader,
"llm": llm,
+ "chat_llms": chat_llms,
"embedder": embedder,
"reranker": reranker,
"internet_retriever": internet_retriever,
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index 9f510add0..4a83700d0 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -6,6 +6,7 @@
a configuration dictionary using the appropriate ConfigFactory.
"""
+import json
import os
from typing import Any
@@ -81,6 +82,32 @@ def build_llm_config() -> dict[str, Any]:
)
+def build_chat_llm_config() -> list[dict[str, Any]]:
+ """
+ Build chat LLM configuration.
+
+ Returns:
+        List of validated chat LLM configuration entries
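+
+    The CHAT_MODEL_LIST environment variable is expected to hold a JSON list of
+    per-backend configs, e.g.
+    [{"backend": "openai", "api_key": "...", "api_base": "https://api.openai.com/v1",
+      "model_name_or_path": "gpt-4o", "support_models": ["gpt-4o", "gpt-4o-mini"]}]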
+ """
+ configs = json.loads(os.getenv("CHAT_MODEL_LIST"))
+ return [
+ {
+ "config_class": LLMConfigFactory.model_validate(
+ {
+ "backend": cfg.get("backend", "openai"),
+ "config": (
+ {k: v for k, v in cfg.items() if k not in ["backend", "support_models"]}
+ )
+ if cfg
+ else APIConfig.get_openai_config(),
+ }
+ ),
+ "support_models": cfg.get("support_models", None),
+ }
+ for cfg in configs
+ ]
+
+
def build_embedder_config() -> dict[str, Any]:
"""
Build embedder configuration.
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index 85f339f3f..c47a3cf83 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -6,7 +6,14 @@
from typing import Any, Literal
-from memos.api.product_models import MemoryResponse
+from memos.api.handlers.formatters_handler import format_memory_item
+from memos.api.product_models import (
+ DeleteMemoryRequest,
+ DeleteMemoryResponse,
+ GetMemoryRequest,
+ GetMemoryResponse,
+ MemoryResponse,
+)
from memos.log import get_logger
from memos.mem_os.utils.format_utils import (
convert_graph_to_tree_forworkmem,
@@ -149,3 +156,37 @@ def handle_get_subgraph(
except Exception as e:
logger.error(f"Failed to get subgraph: {e}", exc_info=True)
raise
+
+
+def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse:
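+    """Return all text memories for the cube plus preference memories matching the user/cube filter."""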
+ # TODO: Implement get memory with filter
+ memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]
+ filter_params: dict[str, Any] = {}
+ if get_mem_req.user_id is not None:
+ filter_params["user_id"] = get_mem_req.user_id
+ if get_mem_req.mem_cube_id is not None:
+ filter_params["mem_cube_id"] = get_mem_req.mem_cube_id
+ preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params)
+ return GetMemoryResponse(
+ message="Memories retrieved successfully",
+ data={
+ "text_mem": memories,
+ "pref_mem": [format_memory_item(mem) for mem in preferences],
+ },
+ )
+
+
+def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any):
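+    """Delete the given memory IDs from both the text and preference memory stores."""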
+ try:
+ naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids)
+ naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
+ except Exception as e:
+ logger.error(f"Failed to delete memories: {e}", exc_info=True)
+ return DeleteMemoryResponse(
+ message="Failed to delete memories",
+            data={"status": "failure"},
+ )
+ return DeleteMemoryResponse(
+ message="Memories deleted successfully",
+ data={"status": "success"},
+ )
diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py
index 8d3c6dc70..32b312f8a 100644
--- a/src/memos/api/handlers/scheduler_handler.py
+++ b/src/memos/api/handlers/scheduler_handler.py
@@ -22,7 +22,7 @@
def handle_scheduler_status(
- user_name: str | None = None,
+ mem_cube_id: str | None = None,
mem_scheduler: Any | None = None,
instance_id: str = "",
) -> dict[str, Any]:
@@ -43,9 +43,9 @@ def handle_scheduler_status(
HTTPException: If status retrieval fails
"""
try:
- if user_name:
+ if mem_cube_id:
running = mem_scheduler.dispatcher.get_running_tasks(
- lambda task: getattr(task, "mem_cube_id", None) == user_name
+ lambda task: getattr(task, "mem_cube_id", None) == mem_cube_id
)
tasks_iter = to_iter(running)
running_count = len(tasks_iter)
@@ -53,7 +53,7 @@ def handle_scheduler_status(
"message": "ok",
"data": {
"scope": "user",
- "user_name": user_name,
+ "mem_cube_id": mem_cube_id,
"running_tasks": running_count,
"timestamp": time.time(),
"instance_id": instance_id,
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index cb72011a3..3c5fb3bc4 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -1,7 +1,6 @@
-import os
import uuid
-from typing import Generic, Literal, TypeVar
+from typing import Any, Generic, Literal, TypeVar
from pydantic import BaseModel, Field
@@ -37,7 +36,7 @@ class UserRegisterRequest(BaseRequest):
interests: str | None = Field(None, description="User interests")
-class GetMemoryRequest(BaseRequest):
+class GetMemoryPlaygroundRequest(BaseRequest):
"""Request model for getting memories."""
user_id: str = Field(..., description="User ID")
@@ -80,9 +79,20 @@ class ChatRequest(BaseRequest):
None, description="List of cube IDs user can write for multi-cube chat"
)
history: list[MessageDict] | None = Field(None, description="Chat history")
+ mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
internet_search: bool = Field(True, description="Whether to use internet search")
- moscube: bool = Field(False, description="Whether to use MemOSCube")
+ system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
+ top_k: int = Field(10, description="Number of results to return")
+ threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
+ include_preference: bool = Field(True, description="Whether to handle preference memory")
+ pref_top_k: int = Field(6, description="Number of preference results to return")
+ filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
+ model_name_or_path: str | None = Field(None, description="Model name to use for chat")
+ max_tokens: int | None = Field(None, description="Max tokens to generate")
+ temperature: float | None = Field(None, description="Temperature for sampling")
+ top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
+ add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
class ChatCompleteRequest(BaseRequest):
@@ -93,11 +103,18 @@ class ChatCompleteRequest(BaseRequest):
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(False, description="Whether to use internet search")
- moscube: bool = Field(False, description="Whether to use MemOSCube")
- base_prompt: str | None = Field(None, description="Base prompt to use for chat")
+ system_prompt: str | None = Field(None, description="Base prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
+ include_preference: bool = Field(True, description="Whether to handle preference memory")
+ pref_top_k: int = Field(6, description="Number of preference results to return")
+ filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
+ model_name_or_path: str | None = Field(None, description="Model name to use for chat")
+ max_tokens: int | None = Field(None, description="Max tokens to generate")
+ temperature: float | None = Field(None, description="Temperature for sampling")
+ top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
+ add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
class UserCreate(BaseRequest):
@@ -129,6 +146,10 @@ class SuggestionResponse(BaseResponse[list]):
data: dict[str, list[str]] | None = Field(None, description="Response data")
+class AddStatusResponse(BaseResponse[dict]):
+ """Response model for add status operations."""
+
+
class ConfigResponse(BaseResponse[None]):
"""Response model for configuration endpoint."""
@@ -141,6 +162,14 @@ class ChatResponse(BaseResponse[str]):
"""Response model for chat operations."""
+class GetMemoryResponse(BaseResponse[dict]):
+ """Response model for getting memories."""
+
+
+class DeleteMemoryResponse(BaseResponse[dict]):
+ """Response model for deleting memories."""
+
+
class UserResponse(BaseResponse[dict]):
"""Response model for user operations."""
@@ -181,11 +210,8 @@ class APISearchRequest(BaseRequest):
readable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can read for multi-cube search"
)
- mode: SearchMode = Field(
- os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture"
- )
+ mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
internet_search: bool = Field(False, description="Whether to use internet search")
- moscube: bool = Field(False, description="Whether to use MemOSCube")
top_k: int = Field(10, description="Number of results to return")
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
@@ -194,6 +220,7 @@ class APISearchRequest(BaseRequest):
)
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")
+ filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
class APIADDRequest(BaseRequest):
@@ -213,8 +240,13 @@ class APIADDRequest(BaseRequest):
operation: list[PermissionDict] | None = Field(
None, description="operation ids for multi cubes"
)
- async_mode: Literal["async", "sync"] | None = Field(
- None, description="Whether to add memory in async mode"
+ async_mode: Literal["async", "sync"] = Field(
+ "async", description="Whether to add memory in async mode"
+ )
+ custom_tags: list[str] | None = Field(None, description="Custom tags for the memory")
+ info: dict[str, str] | None = Field(None, description="Additional information for the memory")
+ is_feedback: bool = Field(
+        False, description="Whether the memory is user feedback from the knowledge base service"
)
@@ -232,13 +264,43 @@ class APIChatCompleteRequest(BaseRequest):
)
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(False, description="Whether to use internet search")
- moscube: bool = Field(True, description="Whether to use MemOSCube")
- base_prompt: str | None = Field(None, description="Base prompt to use for chat")
+ system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
+ mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(
"default_session", description="Session ID for soft-filtering memories"
)
+ include_preference: bool = Field(True, description="Whether to handle preference memory")
+ pref_top_k: int = Field(6, description="Number of preference results to return")
+ filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
+ model_name_or_path: str | None = Field(None, description="Model name to use for chat")
+ max_tokens: int | None = Field(None, description="Max tokens to generate")
+ temperature: float | None = Field(None, description="Temperature for sampling")
+ top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
+ add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
+
+
+class AddStatusRequest(BaseRequest):
+ """Request model for checking add status."""
+
+ mem_cube_id: str = Field(..., description="Cube ID")
+ user_id: str | None = Field(None, description="User ID")
+ session_id: str | None = Field(None, description="Session ID")
+
+
+class GetMemoryRequest(BaseRequest):
+ """Request model for getting memories."""
+
+ mem_cube_id: str = Field(..., description="Cube ID")
+ user_id: str | None = Field(None, description="User ID")
+ include_preference: bool = Field(True, description="Whether to handle preference memory")
+
+
+class DeleteMemoryRequest(BaseRequest):
+ """Request model for deleting memories."""
+
+ memory_ids: list[str] = Field(..., description="Memory IDs")
class SuggestionRequest(BaseRequest):
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
index 75b614cf4..2f6c5c317 100644
--- a/src/memos/api/routers/product_router.py
+++ b/src/memos/api/routers/product_router.py
@@ -10,7 +10,7 @@
BaseResponse,
ChatCompleteRequest,
ChatRequest,
- GetMemoryRequest,
+ GetMemoryPlaygroundRequest,
MemoryCreateRequest,
MemoryResponse,
SearchRequest,
@@ -159,7 +159,7 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
-def get_all_memories(memory_req: GetMemoryRequest):
+def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
"""Get all memories for a specific user."""
try:
mos_product = get_mos_product_instance()
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index b3b517305..0067d6e2f 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -23,11 +23,17 @@
from memos.api.handlers.chat_handler import ChatHandler
from memos.api.handlers.search_handler import SearchHandler
from memos.api.product_models import (
+ AddStatusRequest,
+ AddStatusResponse,
APIADDRequest,
APIChatCompleteRequest,
APISearchRequest,
ChatRequest,
+ DeleteMemoryRequest,
+ DeleteMemoryResponse,
+ GetMemoryPlaygroundRequest,
GetMemoryRequest,
+ GetMemoryResponse,
MemoryResponse,
SearchResponse,
SuggestionRequest,
@@ -54,7 +60,11 @@
search_handler = SearchHandler(dependencies)
add_handler = AddHandler(dependencies)
chat_handler = ChatHandler(
- dependencies, search_handler, add_handler, online_bot=components.get("online_bot")
+ dependencies,
+ components["chat_llms"],
+ search_handler,
+ add_handler,
+ online_bot=components.get("online_bot"),
)
# Extract commonly used components for function-based handlers
@@ -99,11 +109,15 @@ def add_memories(add_req: APIADDRequest):
# =============================================================================
-@router.get("/scheduler/status", summary="Get scheduler running status")
-def scheduler_status(user_name: str | None = None):
+@router.get(
+ "/scheduler/status", summary="Get scheduler running status", response_model=AddStatusResponse
+)
+def scheduler_status(add_status_req: AddStatusRequest):
"""Get scheduler running status."""
return handlers.scheduler_handler.handle_scheduler_status(
- user_name=user_name,
+ mem_cube_id=add_status_req.mem_cube_id,
+ user_id=add_status_req.user_id,
+ session_id=add_status_req.session_id,
mem_scheduler=mem_scheduler,
instance_id=INSTANCE_ID,
)
@@ -155,8 +169,8 @@ def chat_complete(chat_req: APIChatCompleteRequest):
return chat_handler.handle_chat_complete(chat_req)
-@router.post("/chat", summary="Chat with MemOS")
-def chat(chat_req: ChatRequest):
+@router.post("/chat/stream", summary="Chat with MemOS")
+def chat_stream(chat_req: ChatRequest):
"""
Chat with MemOS for a specific user. Returns SSE stream.
@@ -166,6 +180,17 @@ def chat(chat_req: ChatRequest):
return chat_handler.handle_chat_stream(chat_req)
+@router.post("/chat/stream/playground", summary="Chat with MemOS playground")
+def chat_stream_playground(chat_req: ChatRequest):
+ """
+ Chat with MemOS for a specific user. Returns SSE stream.
+
+ This endpoint uses the class-based ChatHandler which internally
+ composes SearchHandler and AddHandler for a clean architecture.
+ """
+ return chat_handler.handle_chat_stream_playground(chat_req)
+
+
# =============================================================================
# Suggestion API Endpoints
# =============================================================================
@@ -188,12 +213,12 @@ def get_suggestion_queries(suggestion_req: SuggestionRequest):
# =============================================================================
-# Memory Retrieval API Endpoints
+# Memory Retrieval & Delete API Endpoints
# =============================================================================
@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
-def get_all_memories(memory_req: GetMemoryRequest):
+def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
"""
Get all memories or subgraph for a specific user.
@@ -219,3 +244,20 @@ def get_all_memories(memory_req: GetMemoryRequest):
memory_type=memory_req.memory_type or "text_mem",
naive_mem_cube=naive_mem_cube,
)
+
+
+@router.post("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse)
+def get_memories(memory_req: GetMemoryRequest):
+ return handlers.memory_handler.handle_get_memories(
+ get_mem_req=memory_req,
+ naive_mem_cube=naive_mem_cube,
+ )
+
+
+@router.post(
+ "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse
+)
+def delete_memories(memory_req: DeleteMemoryRequest):
+ return handlers.memory_handler.handle_delete_memories(
+ delete_mem_req=memory_req, naive_mem_cube=naive_mem_cube
+ )
diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py
index d69a0a0fc..70217b896 100644
--- a/src/memos/configs/llm.py
+++ b/src/memos/configs/llm.py
@@ -9,14 +9,17 @@ class BaseLLMConfig(BaseConfig):
"""Base configuration class for LLMs."""
model_name_or_path: str = Field(..., description="Model name or path")
- temperature: float = Field(default=0.8, description="Temperature for sampling")
- max_tokens: int = Field(default=1024, description="Maximum number of tokens to generate")
- top_p: float = Field(default=0.9, description="Top-p sampling parameter")
+ temperature: float = Field(default=0.7, description="Temperature for sampling")
+ max_tokens: int = Field(default=8192, description="Maximum number of tokens to generate")
+ top_p: float = Field(default=0.95, description="Top-p sampling parameter")
top_k: int = Field(default=50, description="Top-k sampling parameter")
remove_think_prefix: bool = Field(
default=False,
description="Remove content within think tags from the generated text",
)
+ default_headers: dict[str, Any] | None = Field(
+ default=None, description="Default headers for LLM requests"
+ )
class OpenAILLMConfig(BaseLLMConfig):
@@ -27,6 +30,18 @@ class OpenAILLMConfig(BaseLLMConfig):
extra_body: Any = Field(default=None, description="extra body")
+class OpenAIResponsesLLMConfig(BaseLLMConfig):
+ api_key: str = Field(..., description="API key for OpenAI")
+ api_base: str = Field(
+ default="https://api.openai.com/v1", description="Base URL for OpenAI responses API"
+ )
+ extra_body: Any = Field(default=None, description="extra body")
+ enable_thinking: bool = Field(
+ default=False,
+        description="Enable reasoning outputs from the model",
+ )
+
+
class QwenLLMConfig(BaseLLMConfig):
api_key: str = Field(..., description="API key for DashScope (Qwen)")
api_base: str = Field(
@@ -34,7 +49,6 @@ class QwenLLMConfig(BaseLLMConfig):
description="Base URL for Qwen OpenAI-compatible API",
)
extra_body: Any = Field(default=None, description="extra body")
- model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'")
class DeepSeekLLMConfig(BaseLLMConfig):
@@ -44,9 +58,6 @@ class DeepSeekLLMConfig(BaseLLMConfig):
description="Base URL for DeepSeek OpenAI-compatible API",
)
extra_body: Any = Field(default=None, description="Extra options for API")
- model_name_or_path: str = Field(
- ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'"
- )
class AzureLLMConfig(BaseLLMConfig):
@@ -61,11 +72,27 @@ class AzureLLMConfig(BaseLLMConfig):
api_key: str = Field(..., description="API key for Azure OpenAI")
+class AzureResponsesLLMConfig(BaseLLMConfig):
+ base_url: str = Field(
+ default="https://api.openai.azure.com/",
+ description="Base URL for Azure OpenAI API",
+ )
+ api_version: str = Field(
+ default="2024-03-01-preview",
+ description="API version for Azure OpenAI",
+ )
+ api_key: str = Field(..., description="API key for Azure OpenAI")
+
+
class OllamaLLMConfig(BaseLLMConfig):
api_base: str = Field(
default="http://localhost:11434",
description="Base URL for Ollama API",
)
+ enable_thinking: bool = Field(
+ default=False,
+ description="Enable reasoning outputs from Ollama",
+ )
class HFLLMConfig(BaseLLMConfig):
@@ -85,6 +112,10 @@ class VLLMLLMConfig(BaseLLMConfig):
default="http://localhost:8088/v1",
description="Base URL for vLLM API",
)
+ enable_thinking: bool = Field(
+ default=False,
+ description="Enable reasoning outputs from vLLM",
+ )
class LLMConfigFactory(BaseConfig):
@@ -102,6 +133,7 @@ class LLMConfigFactory(BaseConfig):
"huggingface_singleton": HFLLMConfig, # Add singleton support
"qwen": QwenLLMConfig,
"deepseek": DeepSeekLLMConfig,
+ "openai_new": OpenAIResponsesLLMConfig,
}
@field_validator("backend")
diff --git a/src/memos/llms/deepseek.py b/src/memos/llms/deepseek.py
index f5ee4842b..a90f8eb31 100644
--- a/src/memos/llms/deepseek.py
+++ b/src/memos/llms/deepseek.py
@@ -1,10 +1,6 @@
-from collections.abc import Generator
-
from memos.configs.llm import DeepSeekLLMConfig
from memos.llms.openai import OpenAILLM
-from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
-from memos.types import MessageList
logger = get_logger(__name__)
@@ -15,40 +11,3 @@ class DeepSeekLLM(OpenAILLM):
def __init__(self, config: DeepSeekLLMConfig):
super().__init__(config)
-
- def generate(self, messages: MessageList) -> str:
- """Generate a response from DeepSeek."""
- response = self.client.chat.completions.create(
- model=self.config.model_name_or_path,
- messages=messages,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
- extra_body=self.config.extra_body,
- )
- logger.info(f"Response from DeepSeek: {response.model_dump_json()}")
- response_content = response.choices[0].message.content
- if self.config.remove_think_prefix:
- return remove_thinking_tags(response_content)
- else:
- return response_content
-
- def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
- """Stream response from DeepSeek."""
- response = self.client.chat.completions.create(
- model=self.config.model_name_or_path,
- messages=messages,
- stream=True,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
- extra_body=self.config.extra_body,
- )
- # Streaming chunks of text
- for chunk in response:
- delta = chunk.choices[0].delta
- if hasattr(delta, "reasoning_content") and delta.reasoning_content:
- yield delta.reasoning_content
-
- if hasattr(delta, "content") and delta.content:
- yield delta.content
diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py
index 8589d7750..8f4da662f 100644
--- a/src/memos/llms/factory.py
+++ b/src/memos/llms/factory.py
@@ -7,6 +7,7 @@
from memos.llms.hf_singleton import HFSingletonLLM
from memos.llms.ollama import OllamaLLM
from memos.llms.openai import AzureLLM, OpenAILLM
+from memos.llms.openai_new import OpenAIResponsesLLM
from memos.llms.qwen import QwenLLM
from memos.llms.vllm import VLLMLLM
from memos.memos_tools.singleton import singleton_factory
@@ -24,6 +25,7 @@ class LLMFactory(BaseLLM):
"vllm": VLLMLLM,
"qwen": QwenLLM,
"deepseek": DeepSeekLLM,
+ "openai_new": OpenAIResponsesLLM,
}
@classmethod
diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index be0d1d95f..d46db7c9e 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -54,7 +54,9 @@ def __init__(self, config: HFLLMConfig):
processors.append(TopPLogitsWarper(self.config.top_p))
self.logits_processors = LogitsProcessorList(processors)
- def generate(self, messages: MessageList, past_key_values: DynamicCache | None = None):
+ def generate(
+ self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs
+ ):
"""
Generate a response from the model. If past_key_values is provided, use cache-augmented generation.
Args:
@@ -68,12 +70,12 @@ def generate(self, messages: MessageList, past_key_values: DynamicCache | None =
)
logger.info(f"HFLLM prompt: {prompt}")
if past_key_values is None:
- return self._generate_full(prompt)
+ return self._generate_full(prompt, **kwargs)
else:
- return self._generate_with_cache(prompt, past_key_values)
+ return self._generate_with_cache(prompt, past_key_values, **kwargs)
def generate_stream(
- self, messages: MessageList, past_key_values: DynamicCache | None = None
+ self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs
) -> Generator[str, None, None]:
"""
Generate a streaming response from the model.
@@ -92,7 +94,7 @@ def generate_stream(
else:
yield from self._generate_with_cache_stream(prompt, past_key_values)
- def _generate_full(self, prompt: str) -> str:
+ def _generate_full(self, prompt: str, **kwargs) -> str:
"""
Generate output from scratch using the full prompt.
Args:
@@ -102,13 +104,13 @@ def _generate_full(self, prompt: str) -> str:
"""
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
gen_kwargs = {
- "max_new_tokens": getattr(self.config, "max_tokens", 128),
+ "max_new_tokens": kwargs.get("max_tokens", self.config.max_tokens),
"do_sample": getattr(self.config, "do_sample", True),
}
if self.config.do_sample:
- gen_kwargs["temperature"] = self.config.temperature
- gen_kwargs["top_k"] = self.config.top_k
- gen_kwargs["top_p"] = self.config.top_p
+ gen_kwargs["temperature"] = kwargs.get("temperature", self.config.temperature)
+ gen_kwargs["top_k"] = kwargs.get("top_k", self.config.top_k)
+ gen_kwargs["top_p"] = kwargs.get("top_p", self.config.top_p)
gen_ids = self.model.generate(
**inputs,
**gen_kwargs,
@@ -125,7 +127,7 @@ def _generate_full(self, prompt: str) -> str:
else response
)
- def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
+ def _generate_full_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]:
"""
Generate output from scratch using the full prompt with streaming.
Args:
@@ -138,7 +140,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
# Get generation parameters
- max_new_tokens = getattr(self.config, "max_tokens", 128)
+ max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens)
remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
# Manual streaming generation
@@ -192,7 +194,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]:
else:
yield new_token_text
- def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
+ def _generate_with_cache(self, query: str, kv: DynamicCache, **kwargs) -> str:
"""
Generate output incrementally using an existing KV cache.
Args:
@@ -209,7 +211,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
logits, kv = self._prefill(query_ids, kv)
next_token = self._select_next_token(logits)
generated = [next_token]
- for _ in range(getattr(self.config, "max_tokens", 128) - 1):
+ for _ in range(kwargs.get("max_tokens", self.config.max_tokens) - 1):
if self._should_stop(next_token):
break
logits, kv = self._prefill(next_token, kv)
@@ -228,7 +230,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
)
def _generate_with_cache_stream(
- self, query: str, kv: DynamicCache
+ self, query: str, kv: DynamicCache, **kwargs
) -> Generator[str, None, None]:
"""
Generate output incrementally using an existing KV cache with streaming.
@@ -242,7 +244,7 @@ def _generate_with_cache_stream(
query, return_tensors="pt", add_special_tokens=False
).input_ids.to(self.model.device)
- max_new_tokens = getattr(self.config, "max_tokens", 128)
+ max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens)
remove_think_prefix = getattr(self.config, "remove_think_prefix", False)
# Initial forward pass
diff --git a/src/memos/llms/ollama.py b/src/memos/llms/ollama.py
index 050b7a253..bd92f9625 100644
--- a/src/memos/llms/ollama.py
+++ b/src/memos/llms/ollama.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any
-from ollama import Client
+from ollama import Client, Message
from memos.configs.llm import OllamaLLMConfig
from memos.llms.base import BaseLLM
@@ -54,7 +54,7 @@ def _ensure_model_exists(self):
except Exception as e:
logger.warning(f"Could not verify model existence: {e}")
- def generate(self, messages: MessageList) -> Any:
+ def generate(self, messages: MessageList, **kwargs) -> Any:
"""
Generate a response from Ollama LLM.
@@ -68,19 +68,68 @@ def generate(self, messages: MessageList) -> Any:
model=self.config.model_name_or_path,
messages=messages,
options={
- "temperature": self.config.temperature,
- "num_predict": self.config.max_tokens,
- "top_p": self.config.top_p,
- "top_k": self.config.top_k,
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "top_k": kwargs.get("top_k", self.config.top_k),
},
+ think=self.config.enable_thinking,
+ tools=kwargs.get("tools"),
)
logger.info(f"Raw response from Ollama: {response.model_dump_json()}")
-
- str_response = response["message"]["content"] or ""
+ tool_calls = getattr(response.message, "tool_calls", None)
+ if isinstance(tool_calls, list) and len(tool_calls) > 0:
+ return self.tool_call_parser(tool_calls)
+
+ str_thinking = (
+            f"<think>{response.message.thinking}</think>"
+            if getattr(response.message, "thinking", None)
+ else ""
+ )
+        str_response = response.message.content or ""
if self.config.remove_think_prefix:
return remove_thinking_tags(str_response)
else:
- return str_response
+ return str_thinking + str_response
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
- raise NotImplementedError
+ if kwargs.get("tools"):
+            logger.info("Streaming API does not support tools")
+ return
+
+ response = self.client.chat(
+ model=kwargs.get("model_name_or_path", self.config.model_name_or_path),
+ messages=messages,
+ options={
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "top_k": kwargs.get("top_k", self.config.top_k),
+ },
+ think=self.config.enable_thinking,
+ stream=True,
+ )
+ # Streaming chunks of text
+ reasoning_started = False
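+        # Emit literal <think>/</think> markers around reasoning chunks so callers can split them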
+ for chunk in response:
+ if hasattr(chunk.message, "thinking") and chunk.message.thinking:
+ if not reasoning_started and not self.config.remove_think_prefix:
+                    yield "<think>"
+ reasoning_started = True
+ yield chunk.message.thinking
+
+ if hasattr(chunk.message, "content") and chunk.message.content:
+ if reasoning_started and not self.config.remove_think_prefix:
+                    yield "</think>"
+ reasoning_started = False
+ yield chunk.message.content
+
+ def tool_call_parser(self, tool_calls: list[Message.ToolCall]) -> list[dict]:
+        """Parse tool calls from an Ollama response."""
+ return [
+ {
+ "function_name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ }
+ for tool_call in tool_calls
+ ]
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index da55ae593..9b348adcf 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -1,12 +1,12 @@
-import hashlib
import json
-import time
from collections.abc import Generator
-from typing import ClassVar
import openai
+from openai._types import NOT_GIVEN
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
+
from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
@@ -19,84 +19,57 @@
class OpenAILLM(BaseLLM):
- """OpenAI LLM class with singleton pattern."""
-
- _instances: ClassVar[dict] = {} # Class variable to store instances
-
- def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM":
- config_hash = cls._get_config_hash(config)
-
- if config_hash not in cls._instances:
- logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}")
- instance = super().__new__(cls)
- cls._instances[config_hash] = instance
- else:
- logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}")
-
- return cls._instances[config_hash]
+ """OpenAI LLM class via openai.chat.completions.create."""
def __init__(self, config: OpenAILLMConfig):
- # Avoid duplicate initialization
- if hasattr(self, "_initialized"):
- return
-
self.config = config
- self.client = openai.Client(api_key=config.api_key, base_url=config.api_base)
- self._initialized = True
+ self.client = openai.Client(
+ api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
+ )
logger.info("OpenAI LLM instance initialized")
- @classmethod
- def _get_config_hash(cls, config: OpenAILLMConfig) -> str:
- """Generate hash value of configuration"""
- config_dict = config.model_dump()
- config_str = json.dumps(config_dict, sort_keys=True)
- return hashlib.md5(config_str.encode()).hexdigest()
-
- @classmethod
- def clear_cache(cls):
- """Clear all cached instances"""
- cls._instances.clear()
- logger.info("OpenAI LLM instance cache cleared")
-
- @timed(log=True, log_prefix="model_timed_openai")
+ @timed(log=True, log_prefix="OpenAI LLM")
def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
- temperature = kwargs.get("temperature", self.config.temperature)
- max_tokens = kwargs.get("max_tokens", self.config.max_tokens)
- top_p = kwargs.get("top_p", self.config.top_p)
- start_time = time.time()
- logger.info(f"openai model request start, model_name: {self.config.model_name_or_path}")
-
response = self.client.chat.completions.create(
- model=self.config.model_name_or_path,
+ model=kwargs.get("model_name_or_path", self.config.model_name_or_path),
messages=messages,
- extra_body=self.config.extra_body,
- temperature=temperature,
- max_tokens=max_tokens,
- top_p=top_p,
- )
-
- end_time = time.time()
- logger.info(
- f"openai model request end, time_cost: {end_time - start_time:.0f} ms, response from OpenAI: {response.model_dump_json()}"
+ temperature=kwargs.get("temperature", self.config.temperature),
+ max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ tools=kwargs.get("tools", NOT_GIVEN),
)
+ logger.info(f"Response from OpenAI: {response.model_dump_json()}")
+ tool_calls = getattr(response.choices[0].message, "tool_calls", None)
+ if isinstance(tool_calls, list) and len(tool_calls) > 0:
+ return self.tool_call_parser(tool_calls)
response_content = response.choices[0].message.content
+ reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
+ if isinstance(reasoning_content, str) and reasoning_content:
+ reasoning_content = f"{reasoning_content}"
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
- else:
- return response_content
+ if reasoning_content:
+ return reasoning_content + response_content
+ return response_content
@timed(log=True, log_prefix="OpenAI LLM")
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
"""Stream response from OpenAI LLM with optional reasoning support."""
+ if kwargs.get("tools"):
+ logger.info("stream api not support tools")
+ return
+
response = self.client.chat.completions.create(
model=self.config.model_name_or_path,
messages=messages,
stream=True,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
- extra_body=self.config.extra_body,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ tools=kwargs.get("tools", NOT_GIVEN),
)
reasoning_started = False
@@ -104,7 +77,7 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non
for chunk in response:
delta = chunk.choices[0].delta
- # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen)
+ # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
if not reasoning_started and not self.config.remove_think_prefix:
yield ""
@@ -120,63 +93,44 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non
if reasoning_started and not self.config.remove_think_prefix:
yield ""
+ def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]:
+ """Parse tool calls from OpenAI response."""
+ return [
+ {
+ "tool_call_id": tool_call.id,
+ "function_name": tool_call.function.name,
+ "arguments": json.loads(tool_call.function.arguments),
+ }
+ for tool_call in tool_calls
+ ]
+
class AzureLLM(BaseLLM):
"""Azure OpenAI LLM class with singleton pattern."""
- _instances: ClassVar[dict] = {} # Class variable to store instances
-
- def __new__(cls, config: AzureLLMConfig):
- # Generate hash value of config as cache key
- config_hash = cls._get_config_hash(config)
-
- if config_hash not in cls._instances:
- logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}")
- instance = super().__new__(cls)
- cls._instances[config_hash] = instance
- else:
- logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}")
-
- return cls._instances[config_hash]
-
def __init__(self, config: AzureLLMConfig):
- # Avoid duplicate initialization
- if hasattr(self, "_initialized"):
- return
-
self.config = config
self.client = openai.AzureOpenAI(
azure_endpoint=config.base_url,
api_version=config.api_version,
api_key=config.api_key,
)
- self._initialized = True
logger.info("Azure LLM instance initialized")
- @classmethod
- def _get_config_hash(cls, config: AzureLLMConfig) -> str:
- """Generate hash value of configuration"""
- # Convert config to dict and sort to ensure consistency
- config_dict = config.model_dump()
- config_str = json.dumps(config_dict, sort_keys=True)
- return hashlib.md5(config_str.encode()).hexdigest()
-
- @classmethod
- def clear_cache(cls):
- """Clear all cached instances"""
- cls._instances.clear()
- logger.info("Azure LLM instance cache cleared")
-
- def generate(self, messages: MessageList) -> str:
+ def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from Azure OpenAI LLM."""
response = self.client.chat.completions.create(
model=self.config.model_name_or_path,
messages=messages,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ tools=kwargs.get("tools", NOT_GIVEN),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
)
logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}")
+ if response.choices[0].message.tool_calls:
+ return self.tool_call_parser(response.choices[0].message.tool_calls)
response_content = response.choices[0].message.content
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
@@ -184,4 +138,49 @@ def generate(self, messages: MessageList) -> str:
return response_content
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
- raise NotImplementedError
+ """Stream response from Azure OpenAI LLM with optional reasoning support."""
+ if kwargs.get("tools"):
+ logger.info("stream api not support tools")
+ return
+
+ response = self.client.chat.completions.create(
+ model=self.config.model_name_or_path,
+ messages=messages,
+ stream=True,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ max_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ )
+
+ reasoning_started = False
+
+ for chunk in response:
+ delta = chunk.choices[0].delta
+
+ # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek)
+ if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+ if not reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = True
+ yield delta.reasoning_content
+ elif hasattr(delta, "content") and delta.content:
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = False
+ yield delta.content
+
+ # Ensure we close the block if not already done
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+
+ def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]:
+ """Parse tool calls from OpenAI response."""
+ return [
+ {
+ "tool_call_id": tool_call.id,
+ "function_name": tool_call.function.name,
+ "arguments": json.loads(tool_call.function.arguments),
+ }
+ for tool_call in tool_calls
+ ]
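Usage sketch for the reworked OpenAILLM.generate: without tools it returns a string (reasoning, when the backend exposes reasoning_content, is wrapped in <think>...</think>); when a tools list is supplied and the model calls a tool, it returns the list produced by tool_call_parser. The tool schema and credentials below are placeholders.

    from memos.configs.llm import OpenAILLMConfig
    from memos.llms.openai import OpenAILLM

    llm = OpenAILLM(OpenAILLMConfig(model_name_or_path="gpt-4o-mini", api_key="sk-..."))

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # illustrative tool, not part of MemOS
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]

    result = llm.generate([{"role": "user", "content": "Weather in Paris?"}], tools=tools)
    if isinstance(result, list):
        # Tool-call path: arguments were already json.loads-ed by tool_call_parser.
        for call in result:
            print(call["tool_call_id"], call["function_name"], call["arguments"])
    else:
        # Plain-text path, possibly prefixed with <think>...</think>.
        print(result)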
diff --git a/src/memos/llms/openai_new.py b/src/memos/llms/openai_new.py
new file mode 100644
index 000000000..766a17fda
--- /dev/null
+++ b/src/memos/llms/openai_new.py
@@ -0,0 +1,198 @@
+import json
+
+from collections.abc import Generator
+
+import openai
+
+from openai._types import NOT_GIVEN
+from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
+from openai.types.responses.response_reasoning_item import ResponseReasoningItem
+
+from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig
+from memos.llms.base import BaseLLM
+from memos.llms.utils import remove_thinking_tags
+from memos.log import get_logger
+from memos.types import MessageList
+from memos.utils import timed
+
+
+logger = get_logger(__name__)
+
+
+class OpenAIResponsesLLM(BaseLLM):
+ def __init__(self, config: OpenAILLMConfig):
+ self.config = config
+ self.client = openai.Client(
+ api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
+ )
+
+ @timed(log=True, log_prefix="OpenAI Responses LLM")
+ def generate(self, messages: MessageList, **kwargs) -> str:
+ response = self.client.responses.create(
+ model=kwargs.get("model_name_or_path", self.config.model_name_or_path),
+ input=messages,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ reasoning={"effort": "low", "summary": "auto"}
+ if self.config.enable_thinking
+ else NOT_GIVEN,
+ tools=kwargs.get("tools", NOT_GIVEN),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ )
+ tool_call_outputs = [
+ item for item in response.output if isinstance(item, ResponseFunctionToolCall)
+ ]
+ if tool_call_outputs:
+ return self.tool_call_parser(tool_call_outputs)
+
+ output_text = getattr(response, "output_text", "")
+ output_reasoning = [
+ item for item in response.output if isinstance(item, ResponseReasoningItem)
+ ]
+ summary = output_reasoning[0].summary if output_reasoning else None
+
+ if self.config.remove_think_prefix:
+ return remove_thinking_tags(output_text)
+ if summary:
+ return f"{summary[0].text}" + output_text
+ return output_text
+
+ @timed(log=True, log_prefix="OpenAI Responses LLM")
+ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+ if kwargs.get("tools"):
+ logger.info("stream api not support tools")
+ return
+
+ stream = self.client.responses.create(
+ model=kwargs.get("model_name_or_path", self.config.model_name_or_path),
+ input=messages,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ reasoning={"effort": "low", "summary": "auto"}
+ if self.config.enable_thinking
+ else NOT_GIVEN,
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ stream=True,
+ )
+
+ reasoning_started = False
+
+ for event in stream:
+ event_type = getattr(event, "type", "")
+ if event_type in (
+ "response.reasoning.delta",
+ "response.reasoning_summary_text.delta",
+ ) and hasattr(event, "delta"):
+ if not self.config.remove_think_prefix:
+ if not reasoning_started:
+ yield ""
+ reasoning_started = True
+ yield event.delta
+ elif event_type == "response.output_text.delta" and hasattr(event, "delta"):
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = False
+ yield event.delta
+
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+
+ def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]:
+ """Parse tool calls from OpenAI response."""
+ return [
+ {
+ "tool_call_id": tool_call.call_id,
+ "function_name": tool_call.name,
+ "arguments": json.loads(tool_call.arguments),
+ }
+ for tool_call in tool_calls
+ ]
+
+
+class AzureResponsesLLM(BaseLLM):
+ def __init__(self, config: AzureLLMConfig):
+ self.config = config
+ self.client = openai.AzureOpenAI(
+ azure_endpoint=config.base_url,
+ api_version=config.api_version,
+ api_key=config.api_key,
+ )
+
+ def generate(self, messages: MessageList, **kwargs) -> str:
+ response = self.client.responses.create(
+ model=self.config.model_name_or_path,
+ input=messages,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ tools=kwargs.get("tools", NOT_GIVEN),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ reasoning={"effort": "low", "summary": "auto"}
+ if self.config.enable_thinking
+ else NOT_GIVEN,
+ )
+
+ output_text = getattr(response, "output_text", "")
+ output_reasoning = [
+ item for item in response.output if isinstance(item, ResponseReasoningItem)
+ ]
+ summary = output_reasoning[0].summary if output_reasoning else None
+
+ if self.config.remove_think_prefix:
+ return remove_thinking_tags(output_text)
+ if summary:
+ return f"{summary[0].text}" + output_text
+ return output_text
+
+ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+ if kwargs.get("tools"):
+ logger.info("stream api not support tools")
+ return
+
+ stream = self.client.responses.create(
+ model=self.config.model_name_or_path,
+ input=messages,
+ temperature=kwargs.get("temperature", self.config.temperature),
+ top_p=kwargs.get("top_p", self.config.top_p),
+ max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens),
+ extra_body=kwargs.get("extra_body", self.config.extra_body),
+ stream=True,
+ reasoning={"effort": "low", "summary": "auto"}
+ if self.config.enable_thinking
+ else NOT_GIVEN,
+ )
+
+ reasoning_started = False
+
+ for event in stream:
+ event_type = getattr(event, "type", "")
+ if event_type in (
+ "response.reasoning.delta",
+ "response.reasoning_summary_text.delta",
+ ) and hasattr(event, "delta"):
+ if not self.config.remove_think_prefix:
+ if not reasoning_started:
+ yield ""
+ reasoning_started = True
+ yield event.delta
+ elif event_type == "response.output_text.delta" and hasattr(event, "delta"):
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = False
+ yield event.delta
+
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+
+ def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]:
+ """Parse tool calls from OpenAI response."""
+ return [
+ {
+ "tool_call_id": tool_call.call_id,
+ "function_name": tool_call.name,
+ "arguments": json.loads(tool_call.arguments),
+ }
+ for tool_call in tool_calls
+ ]
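Sketch of driving the new Responses-API wrapper; enable_thinking is assumed to be an OpenAILLMConfig field, and the reasoning summary (when present) is returned wrapped in <think>...</think>. Model name and key are placeholders.

    from memos.configs.llm import OpenAILLMConfig
    from memos.llms.openai_new import OpenAIResponsesLLM

    config = OpenAILLMConfig(
        model_name_or_path="gpt-4o-mini",  # placeholder model
        api_key="sk-...",
        enable_thinking=True,       # maps to reasoning={"effort": "low", "summary": "auto"}
        remove_think_prefix=False,  # keep the <think>...</think> summary prefix
    )
    llm = OpenAIResponsesLLM(config)

    print(llm.generate([{"role": "user", "content": "One-line summary of MemOS?"}]))

    for delta in llm.generate_stream([{"role": "user", "content": "Now stream it."}]):
        print(delta, end="", flush=True)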
diff --git a/src/memos/llms/qwen.py b/src/memos/llms/qwen.py
index a47fcdf36..d54e23c7f 100644
--- a/src/memos/llms/qwen.py
+++ b/src/memos/llms/qwen.py
@@ -1,10 +1,6 @@
-from collections.abc import Generator
-
from memos.configs.llm import QwenLLMConfig
from memos.llms.openai import OpenAILLM
-from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
-from memos.types import MessageList
logger = get_logger(__name__)
@@ -15,49 +11,3 @@ class QwenLLM(OpenAILLM):
def __init__(self, config: QwenLLMConfig):
super().__init__(config)
-
- def generate(self, messages: MessageList) -> str:
- """Generate a response from Qwen LLM."""
- response = self.client.chat.completions.create(
- model=self.config.model_name_or_path,
- messages=messages,
- extra_body=self.config.extra_body,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
- )
- logger.info(f"Response from Qwen: {response.model_dump_json()}")
- response_content = response.choices[0].message.content
- if self.config.remove_think_prefix:
- return remove_thinking_tags(response_content)
- else:
- return response_content
-
- def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
- """Stream response from Qwen LLM."""
- response = self.client.chat.completions.create(
- model=self.config.model_name_or_path,
- messages=messages,
- stream=True,
- temperature=self.config.temperature,
- max_tokens=self.config.max_tokens,
- top_p=self.config.top_p,
- extra_body=self.config.extra_body,
- )
-
- reasoning_started = False
- for chunk in response:
- delta = chunk.choices[0].delta
-
- # Some models may have separate `reasoning_content` vs `content`
- # For Qwen (DashScope), likely only `content` is used
- if hasattr(delta, "reasoning_content") and delta.reasoning_content:
- if not reasoning_started and not self.config.remove_think_prefix:
- yield ""
- reasoning_started = True
- yield delta.reasoning_content
- elif hasattr(delta, "content") and delta.content:
- if reasoning_started and not self.config.remove_think_prefix:
- yield ""
- reasoning_started = False
- yield delta.content
diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py
index c3750bb4b..1cf8d4f39 100644
--- a/src/memos/llms/vllm.py
+++ b/src/memos/llms/vllm.py
@@ -1,5 +1,11 @@
+import json
+
from typing import Any, cast
+import openai
+
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
+
from memos.configs.llm import VLLMLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
@@ -27,10 +33,10 @@ def __init__(self, config: VLLMLLMConfig):
if not api_key:
api_key = "dummy"
- import openai
-
self.client = openai.Client(
- api_key=api_key, base_url=getattr(self.config, "api_base", "http://localhost:8088/v1")
+ api_key=api_key,
+ base_url=getattr(self.config, "api_base", "http://localhost:8088/v1"),
+ default_headers=self.config.default_headers,
)
def build_vllm_kv_cache(self, messages: Any) -> str:
@@ -85,36 +91,54 @@ def build_vllm_kv_cache(self, messages: Any) -> str:
return prompt
- def generate(self, messages: list[MessageDict]) -> str:
+ def generate(self, messages: list[MessageDict], **kwargs) -> str:
"""
Generate a response from the model.
"""
if self.client:
- return self._generate_with_api_client(messages)
+ return self._generate_with_api_client(messages, **kwargs)
else:
raise RuntimeError("API client is not available")
- def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
+ def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str:
"""
- Generate response using vLLM API client.
+ Generate a response using the vLLM API client. See https://docs.vllm.ai/en/latest/features/reasoning_outputs/ for details on reasoning outputs.
"""
if self.client:
completion_kwargs = {
- "model": self.config.model_name_or_path,
+ "model": kwargs.get("model_name_or_path", self.config.model_name_or_path),
"messages": messages,
- "temperature": float(getattr(self.config, "temperature", 0.8)),
- "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
- "top_p": float(getattr(self.config, "top_p", 0.9)),
- "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "extra_body": {
+ "chat_template_kwargs": {
+ "enable_thinking": kwargs.get(
+ "enable_thinking", self.config.enable_thinking
+ )
+ }
+ },
}
+ if kwargs.get("tools"):
+ completion_kwargs["tools"] = kwargs.get("tools")
+ completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto")
response = self.client.chat.completions.create(**completion_kwargs)
+
+ if response.choices[0].message.tool_calls:
+ return self.tool_call_parser(response.choices[0].message.tool_calls)
+
+ reasoning_content = (
+ f"{response.choices[0].message.reasoning}"
+ if hasattr(response.choices[0].message, "reasoning")
+ else ""
+ )
response_text = response.choices[0].message.content or ""
logger.info(f"VLLM API response: {response_text}")
return (
remove_thinking_tags(response_text)
if getattr(self.config, "remove_think_prefix", False)
- else response_text
+ else reasoning_content + response_text
)
else:
raise RuntimeError("API client is not available")
@@ -130,26 +154,59 @@ def _messages_to_prompt(self, messages: list[MessageDict]) -> str:
prompt_parts.append(f"{role.capitalize()}: {content}")
return "\n".join(prompt_parts)
- def generate_stream(self, messages: list[MessageDict]):
+ def generate_stream(self, messages: list[MessageDict], **kwargs):
"""
Generate a response from the model using streaming.
Yields content chunks as they are received.
"""
+ if kwargs.get("tools"):
+ logger.info("stream api not support tools")
+ return
+
if self.client:
completion_kwargs = {
"model": self.config.model_name_or_path,
"messages": messages,
- "temperature": float(getattr(self.config, "temperature", 0.8)),
- "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
- "top_p": float(getattr(self.config, "top_p", 0.9)),
- "stream": True, # Enable streaming
- "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
+ "temperature": kwargs.get("temperature", self.config.temperature),
+ "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+ "top_p": kwargs.get("top_p", self.config.top_p),
+ "stream": True,
+ "extra_body": {
+ "chat_template_kwargs": {
+ "enable_thinking": kwargs.get(
+ "enable_thinking", self.config.enable_thinking
+ )
+ }
+ },
}
stream = self.client.chat.completions.create(**completion_kwargs)
+
+ reasoning_started = False
for chunk in stream:
- content = chunk.choices[0].delta.content
- if content:
- yield content
+ delta = chunk.choices[0].delta
+ if hasattr(delta, "reasoning") and delta.reasoning:
+ if not reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = True
+ yield delta.reasoning
+
+ if hasattr(delta, "content") and delta.content:
+ if reasoning_started and not self.config.remove_think_prefix:
+ yield ""
+ reasoning_started = False
+ yield delta.content
+
else:
raise RuntimeError("API client is not available")
+
+ def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]:
+ """Parse tool calls from OpenAI response."""
+ return [
+ {
+ "tool_call_id": tool_call.id,
+ "function_name": tool_call.function.name,
+ "arguments": json.loads(tool_call.function.arguments),
+ }
+ for tool_call in tool_calls
+ ]
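Sketch of the per-call overrides now accepted by the vLLM wrapper (class name assumed to be VLLMLLM); api_base is assumed to be a config field (the client falls back to http://localhost:8088/v1), and the model name is a placeholder.

    from memos.configs.llm import VLLMLLMConfig
    from memos.llms.vllm import VLLMLLM

    llm = VLLMLLM(
        VLLMLLMConfig(
            model_name_or_path="Qwen/Qwen3-8B",   # placeholder model
            api_base="http://localhost:8088/v1",  # assumed field; matches the default above
        )
    )

    answer = llm.generate(
        [{"role": "user", "content": "Explain KV-cache reuse in one sentence."}],
        temperature=0.2,        # overrides config.temperature for this call only
        enable_thinking=True,   # forwarded via extra_body["chat_template_kwargs"]
    )
    print(answer)  # reasoning, if any, arrives wrapped in <think>...</think>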
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 63b87157c..a53e19191 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -152,7 +152,6 @@ def init_mem_cube(
if searcher is None:
self.searcher: Searcher = self.text_mem.get_searcher(
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
- moscube=False,
)
else:
self.searcher = searcher
@@ -577,12 +576,12 @@ def get_web_log_messages(self) -> list[dict]:
def _map_label(label: str) -> str:
from memos.mem_scheduler.schemas.general_schemas import (
- QUERY_LABEL,
- ANSWER_LABEL,
ADD_LABEL,
- MEM_UPDATE_LABEL,
- MEM_ORGANIZE_LABEL,
+ ANSWER_LABEL,
MEM_ARCHIVE_LABEL,
+ MEM_ORGANIZE_LABEL,
+ MEM_UPDATE_LABEL,
+ QUERY_LABEL,
)
mapping = {
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 3859c9e6f..7da531a7f 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -1,3 +1,5 @@
+import hashlib
+
from collections.abc import Callable
from memos.log import get_logger
@@ -6,13 +8,13 @@
from memos.mem_scheduler.schemas.general_schemas import (
ACTIVATION_MEMORY_TYPE,
ADD_LABEL,
+ MEM_ARCHIVE_LABEL,
+ MEM_UPDATE_LABEL,
NOT_INITIALIZED,
PARAMETER_MEMORY_TYPE,
TEXT_MEMORY_TYPE,
USER_INPUT_TYPE,
WORKING_MEMORY_TYPE,
- MEM_UPDATE_LABEL,
- MEM_ARCHIVE_LABEL,
)
from memos.mem_scheduler.schemas.message_schemas import (
ScheduleLogForWebItem,
@@ -23,7 +25,6 @@
)
from memos.mem_scheduler.utils.misc_utils import log_exceptions
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
-import hashlib
logger = get_logger(__name__)
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index eeca890a9..e0d18dc72 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -12,14 +12,14 @@
ADD_LABEL,
ANSWER_LABEL,
DEFAULT_MAX_QUERY_KEY_WORDS,
+ LONG_TERM_MEMORY_TYPE,
MEM_ORGANIZE_LABEL,
MEM_READ_LABEL,
+ NOT_APPLICABLE_TYPE,
PREF_ADD_LABEL,
QUERY_LABEL,
- WORKING_MEMORY_TYPE,
USER_INPUT_TYPE,
- NOT_APPLICABLE_TYPE,
- LONG_TERM_MEMORY_TYPE,
+ WORKING_MEMORY_TYPE,
MemCubeID,
UserID,
)
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index 21b2d63f0..f6e9b86fe 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -69,7 +69,6 @@ def submit_memory_history_async_task(
"session_id": session_id,
"top_k": search_req.top_k,
"internet_search": search_req.internet_search,
- "moscube": search_req.moscube,
"chat_history": search_req.chat_history,
},
"user_context": {"mem_cube_id": user_context.mem_cube_id},
@@ -112,7 +111,6 @@ def search_memories(
top_k=search_req.top_k,
mode=mode,
manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
search_filter=search_filter,
info={
"user_id": search_req.user_id,
@@ -154,7 +152,6 @@ def mix_search_memories(
top_k=search_req.top_k,
mode=SearchMode.FAST,
manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
search_filter=search_filter,
info=info,
)
diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py
index 5f85aa907..6e196e23a 100644
--- a/src/memos/memories/textual/preference.py
+++ b/src/memos/memories/textual/preference.py
@@ -190,7 +190,7 @@ def get_with_collection_name(
return None
return TextualMemoryItem(
id=res.id,
- memory=res.payload.get("dialog_str", ""),
+ memory=res.memory,
metadata=PreferenceTextualMemoryMetadata(**res.payload),
)
except Exception as e:
@@ -225,7 +225,7 @@ def get_by_ids_with_collection_name(
return [
TextualMemoryItem(
id=memo.id,
- memory=memo.payload.get("dialog_str", ""),
+ memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in res
@@ -248,19 +248,43 @@ def get_all(self) -> list[TextualMemoryItem]:
all_memories[collection_name] = [
TextualMemoryItem(
id=memo.id,
- memory=memo.payload.get("dialog_str", ""),
+ memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in items
]
return all_memories
+ def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]:
+ """Get memories by filter.
+ Args:
+ filter (dict[str, Any]): Filter criteria.
+ Returns:
+ list[TextualMemoryItem]: List of memories that match the filter.
+ """
+ collection_list = self.vector_db.config.collection_name
+ all_db_items = []
+ for collection_name in collection_list:
+ db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter)
+ all_db_items.extend(db_items)
+ memories = [
+ TextualMemoryItem(
+ id=memo.id,
+ memory=memo.memory,
+ metadata=PreferenceTextualMemoryMetadata(**memo.payload),
+ )
+ for memo in all_db_items
+ ]
+ return memories
+
def delete(self, memory_ids: list[str]) -> None:
"""Delete memories.
Args:
memory_ids (list[str]): List of memory IDs to delete.
"""
- raise NotImplementedError
+ collection_list = self.vector_db.config.collection_name
+ for collection_name in collection_list:
+ self.vector_db.delete(collection_name, memory_ids)
def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None:
"""Delete memories by their IDs and collection name.
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 1b2355bc8..27c33029c 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -129,7 +129,6 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int
def get_searcher(
self,
manual_close_internet: bool = False,
- moscube: bool = False,
):
if (self.internet_retriever is not None) and manual_close_internet:
logger.warning(
@@ -141,7 +140,6 @@ def get_searcher(
self.embedder,
self.reranker,
internet_retriever=None,
- moscube=moscube,
)
else:
searcher = Searcher(
@@ -150,7 +148,6 @@ def get_searcher(
self.embedder,
self.reranker,
internet_retriever=self.internet_retriever,
- moscube=moscube,
)
return searcher
@@ -162,7 +159,6 @@ def search(
mode: str = "fast",
memory_type: str = "All",
manual_close_internet: bool = True,
- moscube: bool = False,
search_filter: dict | None = None,
user_name: str | None = None,
) -> list[TextualMemoryItem]:
@@ -179,7 +175,6 @@ def search(
memory_type (str): Type restriction for search.
['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config.
- moscube (bool): whether you use moscube to answer questions
search_filter (dict, optional): Optional metadata filters for search results.
- Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
- Values are exact-match conditions.
@@ -196,7 +191,6 @@ def search(
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=None,
- moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
)
@@ -208,7 +202,6 @@ def search(
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
- moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
)
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
index 31b914776..042ed837e 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
@@ -200,9 +200,11 @@ def _process_result(
"""Process one Bocha search result into TextualMemoryItem."""
title = result.get("name", "")
content = result.get("summary", "") or result.get("snippet", "")
- summary = result.get("snippet", "")
+ summary = result.get("summary", "") or result.get("snippet", "")
url = result.get("url", "")
publish_time = result.get("datePublished", "")
+ site_name = result.get("siteName", "")
+ site_icon = result.get("siteIcon")
if publish_time:
try:
@@ -229,5 +231,12 @@ def _process_result(
read_item_i.metadata.memory_type = "OuterMemory"
read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else []
read_item_i.metadata.visibility = "public"
+ read_item_i.metadata.internet_info = {
+ "title": title,
+ "url": url,
+ "site_name": site_name,
+ "site_icon": site_icon,
+ "summary": summary,
+ }
memory_items.append(read_item_i)
return memory_items
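With internet_info populated, downstream code can read web provenance straight from item metadata; memory_items below stands for the list returned by the Bocha retriever.

    for item in memory_items:
        info = item.metadata.internet_info
        print(f'{info["title"]} ({info["site_name"]}) -> {info["url"]}')
        print(info["summary"])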
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 933ef5af1..26ae1a723 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -41,7 +41,6 @@ def __init__(
reranker: BaseReranker,
bm25_retriever: EnhancedBM25 | None = None,
internet_retriever: None = None,
- moscube: bool = False,
search_strategy: dict | None = None,
manual_close_internet: bool = True,
):
@@ -56,7 +55,6 @@ def __init__(
# Create internet retriever from config if provided
self.internet_retriever = internet_retriever
- self.moscube = moscube
self.vec_cot = search_strategy.get("cot", False) if search_strategy else False
self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False
self.manual_close_internet = manual_close_internet
@@ -297,17 +295,6 @@ def _retrieve_paths(
user_name,
)
)
- if self.moscube:
- tasks.append(
- executor.submit(
- self._retrieve_from_memcubes,
- query,
- parsed_goal,
- query_embedding,
- top_k,
- "memos_cube01",
- )
- )
results = []
for t in tasks:
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index f34cad1ef..2055615d2 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -232,7 +232,6 @@ def _fast_search(
top_k=search_req.top_k,
mode=SearchMode.FAST,
manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
search_filter=search_filter,
info={
"user_id": search_req.user_id,
@@ -287,7 +286,6 @@ def _fine_search(
top_k=search_req.top_k,
mode=SearchMode.FINE,
manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
search_filter=search_filter,
info=info,
)
diff --git a/src/memos/types/__init__.py b/src/memos/types/__init__.py
new file mode 100644
index 000000000..dd1b98305
--- /dev/null
+++ b/src/memos/types/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa: F403, F401
+
+from .types import *
diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py
new file mode 100644
index 000000000..4a08a9f24
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/__init__.py
@@ -0,0 +1,15 @@
+# ruff: noqa: F403, F401
+
+from .chat_completion_assistant_message_param import *
+from .chat_completion_content_part_image_param import *
+from .chat_completion_content_part_input_audio_param import *
+from .chat_completion_content_part_param import *
+from .chat_completion_content_part_refusal_param import *
+from .chat_completion_content_part_text_param import *
+from .chat_completion_message_custom_tool_call_param import *
+from .chat_completion_message_function_tool_call_param import *
+from .chat_completion_message_param import *
+from .chat_completion_message_tool_call_union_param import *
+from .chat_completion_system_message_param import *
+from .chat_completion_tool_message_param import *
+from .chat_completion_user_message_param import *
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
new file mode 100644
index 000000000..a742de3a9
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
@@ -0,0 +1,55 @@
+# ruff: noqa: TC001, TC003
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Literal, TypeAlias
+
+from typing_extensions import Required, TypedDict
+
+from .chat_completion_content_part_refusal_param import ChatCompletionContentPartRefusalParam
+from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam
+from .chat_completion_message_tool_call_union_param import ChatCompletionMessageToolCallUnionParam
+
+
+__all__ = ["Audio", "ChatCompletionAssistantMessageParam", "ContentArrayOfContentPart"]
+
+
+class Audio(TypedDict, total=False):
+ id: Required[str]
+ """Unique identifier for a previous audio response from the model."""
+
+
+ContentArrayOfContentPart: TypeAlias = (
+ ChatCompletionContentPartTextParam | ChatCompletionContentPartRefusalParam
+)
+
+
+class ChatCompletionAssistantMessageParam(TypedDict, total=False):
+ role: Required[Literal["assistant"]]
+ """The role of the messages author, in this case `assistant`."""
+
+ audio: Audio | None
+ """
+ Data about a previous audio response from the model.
+ [Learn more](https://platform.openai.com/docs/guides/audio).
+ """
+
+ content: str | Iterable[ContentArrayOfContentPart] | None
+ """The contents of the assistant message.
+
+ Required unless `tool_calls` or `function_call` is specified.
+ """
+
+ refusal: str | None
+ """The refusal message by the assistant."""
+
+ tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam]
+ """The tool calls generated by the model, such as function calls."""
+
+ chat_time: str | None
+ """Optional timestamp for the message, format is not
+ restricted, it can be any vague or precise time string."""
+
+ message_id: str | None
+ """Optional unique identifier for the message"""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py
new file mode 100644
index 000000000..6718bd91e
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionContentPartImageParam", "ImageURL"]
+
+
+class ImageURL(TypedDict, total=False):
+ url: Required[str]
+ """Either a URL of the image or the base64 encoded image data."""
+
+ detail: Literal["auto", "low", "high"]
+ """Specifies the detail level of the image.
+
+ Learn more in the
+ [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding).
+ """
+
+
+class ChatCompletionContentPartImageParam(TypedDict, total=False):
+ image_url: Required[ImageURL]
+
+ type: Required[Literal["image_url"]]
+ """The type of the content part."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py
new file mode 100644
index 000000000..e7cfa4504
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionContentPartInputAudioParam", "InputAudio"]
+
+
+class InputAudio(TypedDict, total=False):
+ data: Required[str]
+ """Base64 encoded audio data."""
+
+ format: Required[Literal["wav", "mp3"]]
+ """The format of the encoded audio data. Currently supports "wav" and "mp3"."""
+
+
+class ChatCompletionContentPartInputAudioParam(TypedDict, total=False):
+ input_audio: Required[InputAudio]
+
+ type: Required[Literal["input_audio"]]
+ """The type of the content part. Always `input_audio`."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py
new file mode 100644
index 000000000..a5e740791
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import Literal, TypeAlias
+
+from typing_extensions import Required, TypedDict
+
+from .chat_completion_content_part_image_param import ChatCompletionContentPartImageParam
+from .chat_completion_content_part_input_audio_param import ChatCompletionContentPartInputAudioParam
+from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam
+
+
+__all__ = ["ChatCompletionContentPartParam", "File", "FileFile"]
+
+
+class FileFile(TypedDict, total=False):
+ file_data: str
+ """
+ The base64 encoded file data, used when passing the file to the model as a
+ string.
+ """
+
+ file_id: str
+ """The ID of an uploaded file to use as input."""
+
+ filename: str
+ """The name of the file, used when passing the file to the model as a string."""
+
+
+class File(TypedDict, total=False):
+ file: Required[FileFile]
+
+ type: Required[Literal["file"]]
+ """The type of the content part. Always `file`."""
+
+
+ChatCompletionContentPartParam: TypeAlias = (
+ ChatCompletionContentPartTextParam
+ | ChatCompletionContentPartImageParam
+ | ChatCompletionContentPartInputAudioParam
+ | File
+)
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py
new file mode 100644
index 000000000..fc87e9e1a
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionContentPartRefusalParam"]
+
+
+class ChatCompletionContentPartRefusalParam(TypedDict, total=False):
+ refusal: Required[str]
+ """The refusal message generated by the model."""
+
+ type: Required[Literal["refusal"]]
+ """The type of the content part."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py
new file mode 100644
index 000000000..f43de0eff
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionContentPartTextParam"]
+
+
+class ChatCompletionContentPartTextParam(TypedDict, total=False):
+ text: Required[str]
+ """The text content."""
+
+ type: Required[Literal["text"]]
+ """The type of the content part."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py
new file mode 100644
index 000000000..bc7a22edb
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionMessageCustomToolCallParam", "Custom"]
+
+
+class Custom(TypedDict, total=False):
+ input: Required[str]
+ """The input for the custom tool call generated by the model."""
+
+ name: Required[str]
+ """The name of the custom tool to call."""
+
+
+class ChatCompletionMessageCustomToolCallParam(TypedDict, total=False):
+ id: Required[str]
+ """The ID of the tool call."""
+
+ custom: Required[Custom]
+ """The custom tool that the model called."""
+
+ type: Required[Literal["custom"]]
+ """The type of the tool. Always `custom`."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py
new file mode 100644
index 000000000..56341d94a
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+
+__all__ = ["ChatCompletionMessageFunctionToolCallParam", "Function"]
+
+
+class Function(TypedDict, total=False):
+ arguments: Required[str]
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: Required[str]
+ """The name of the function to call."""
+
+
+class ChatCompletionMessageFunctionToolCallParam(TypedDict, total=False):
+ id: Required[str]
+ """The ID of the tool call."""
+
+ function: Required[Function]
+ """The function that the model called."""
+
+ type: Required[Literal["function"]]
+ """The type of the tool. Currently, only `function` is supported."""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py
new file mode 100644
index 000000000..06a624297
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import TypeAlias
+
+from .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
+from .chat_completion_system_message_param import ChatCompletionSystemMessageParam
+from .chat_completion_tool_message_param import ChatCompletionToolMessageParam
+from .chat_completion_user_message_param import ChatCompletionUserMessageParam
+
+
+__all__ = ["ChatCompletionMessageParam"]
+
+ChatCompletionMessageParam: TypeAlias = (
+ ChatCompletionSystemMessageParam
+ | ChatCompletionUserMessageParam
+ | ChatCompletionAssistantMessageParam
+ | ChatCompletionToolMessageParam
+)
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py
new file mode 100644
index 000000000..28bb880cf
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+from typing import TypeAlias
+
+from .chat_completion_message_custom_tool_call_param import ChatCompletionMessageCustomToolCallParam
+from .chat_completion_message_function_tool_call_param import (
+ ChatCompletionMessageFunctionToolCallParam,
+)
+
+
+__all__ = ["ChatCompletionMessageToolCallUnionParam"]
+
+ChatCompletionMessageToolCallUnionParam: TypeAlias = (
+ ChatCompletionMessageFunctionToolCallParam | ChatCompletionMessageCustomToolCallParam
+)
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
new file mode 100644
index 000000000..7faa90e2e
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
@@ -0,0 +1,35 @@
+# ruff: noqa: TC001, TC003
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam
+
+
+__all__ = ["ChatCompletionSystemMessageParam"]
+
+
+class ChatCompletionSystemMessageParam(TypedDict, total=False):
+ content: Required[str | Iterable[ChatCompletionContentPartTextParam]]
+ """The contents of the system message."""
+
+ role: Required[Literal["system"]]
+ """The role of the messages author, in this case `system`."""
+
+ name: str
+ """An optional name for the participant.
+
+ Provides the model information to differentiate between participants of the same
+ role.
+ """
+
+ chat_time: str | None
+ """Optional timestamp for the message, format is not
+ restricted, it can be any vague or precise time string."""
+
+ message_id: str | None
+ """Optional unique identifier for the message"""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
new file mode 100644
index 000000000..c03220915
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
@@ -0,0 +1,31 @@
+# ruff: noqa: TC001, TC003
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+from .chat_completion_content_part_param import ChatCompletionContentPartParam
+
+
+__all__ = ["ChatCompletionToolMessageParam"]
+
+
+class ChatCompletionToolMessageParam(TypedDict, total=False):
+ content: Required[str | Iterable[ChatCompletionContentPartParam]]
+ """The contents of the tool message."""
+
+ role: Required[Literal["tool"]]
+ """The role of the messages author, in this case `tool`."""
+
+ tool_call_id: Required[str]
+ """Tool call that this message is responding to."""
+
+ chat_time: str | None
+ """Optional timestamp for the message, format is not
+ restricted, it can be any vague or precise time string."""
+
+ message_id: str | None
+ """Optional unique identifier for the message"""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
new file mode 100644
index 000000000..2c2a1f23f
--- /dev/null
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
@@ -0,0 +1,35 @@
+# ruff: noqa: TC001, TC003
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Literal
+
+from typing_extensions import Required, TypedDict
+
+from .chat_completion_content_part_param import ChatCompletionContentPartParam
+
+
+__all__ = ["ChatCompletionUserMessageParam"]
+
+
+class ChatCompletionUserMessageParam(TypedDict, total=False):
+ content: Required[str | Iterable[ChatCompletionContentPartParam]]
+ """The contents of the user message."""
+
+ role: Required[Literal["user"]]
+ """The role of the messages author, in this case `user`."""
+
+ name: str
+ """An optional name for the participant.
+
+ Provides the model information to differentiate between participants of the same
+ role.
+ """
+
+ chat_time: str | None
+ """Optional timestamp for the message, format is not
+ restricted, it can be any vague or precise time string."""
+
+ message_id: str | None
+ """Optional unique identifier for the message"""
diff --git a/src/memos/types.py b/src/memos/types/types.py
similarity index 82%
rename from src/memos/types.py
rename to src/memos/types/types.py
index 635fabccc..b8efc6208 100644
--- a/src/memos/types.py
+++ b/src/memos/types/types.py
@@ -14,6 +14,23 @@
from memos.memories.parametric.item import ParametricMemoryItem
from memos.memories.textual.item import TextualMemoryItem
+from .openai_chat_completion_types import (
+ ChatCompletionContentPartTextParam,
+ ChatCompletionMessageParam,
+ File,
+)
+
+
+__all__ = [
+ "ChatHistory",
+ "MOSSearchResult",
+ "MessageDict",
+ "MessageList",
+ "MessageRole",
+ "Permission",
+ "PermissionDict",
+ "UserContext",
+]
# ─── Message Types ──────────────────────────────────────────────────────────────
@@ -32,8 +49,16 @@ class MessageDict(TypedDict, total=False):
message_id: str | None # Optional unique identifier for the message
+RawMessageDict: TypeAlias = ChatCompletionContentPartTextParam | File
+
+
# Message collections
-MessageList: TypeAlias = list[MessageDict]
+MessageList: TypeAlias = list[ChatCompletionMessageParam]
+RawMessageList: TypeAlias = list[RawMessageDict]
+
+
+# Messages Type
+MessagesType: TypeAlias = str | MessageList | RawMessageList
# Chat history structure
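With MessageList now aliased to the vendored ChatCompletionMessageParam union, message dicts can carry the extra chat_time and message_id keys; a minimal sketch:

    from memos.types import MessageList

    messages: MessageList = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "What did we talk about yesterday?",
            "chat_time": "yesterday evening",  # free-form timestamp is allowed
            "message_id": "msg-001",
        },
    ]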
diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py
index a977a4004..6562c9a95 100644
--- a/tests/configs/test_llm.py
+++ b/tests/configs/test_llm.py
@@ -19,7 +19,14 @@ def test_base_llm_config():
required_fields=[
"model_name_or_path",
],
- optional_fields=["temperature", "max_tokens", "top_p", "top_k", "remove_think_prefix"],
+ optional_fields=[
+ "temperature",
+ "max_tokens",
+ "top_p",
+ "top_k",
+ "remove_think_prefix",
+ "default_headers",
+ ],
)
check_config_instantiation_valid(
@@ -48,6 +55,7 @@ def test_openai_llm_config():
"api_base",
"remove_think_prefix",
"extra_body",
+ "default_headers",
],
)
@@ -79,6 +87,8 @@ def test_ollama_llm_config():
"top_k",
"remove_think_prefix",
"api_base",
+ "default_headers",
+ "enable_thinking",
],
)
@@ -111,6 +121,7 @@ def test_hf_llm_config():
"do_sample",
"remove_think_prefix",
"add_generation_prompt",
+ "default_headers",
],
)
diff --git a/tests/llms/test_deepseek.py b/tests/llms/test_deepseek.py
index 75c1ead5f..11be66887 100644
--- a/tests/llms/test_deepseek.py
+++ b/tests/llms/test_deepseek.py
@@ -12,12 +12,14 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self):
"""Test DeepSeekLLM generate method with and without tag removal."""
# Simulated full content including tag
- full_content = "Thinking in progress...Hello from DeepSeek!"
+ full_content = "Hello from DeepSeek!"
+ reasoning_content = "Thinking in progress..."
# Mock response object
mock_response = MagicMock()
mock_response.model_dump_json.return_value = '{"mock": "true"}'
mock_response.choices[0].message.content = full_content
+ mock_response.choices[0].message.reasoning_content = reasoning_content
# Config with think prefix preserved
config_with_think = DeepSeekLLMConfig.model_validate(
@@ -35,7 +37,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self):
llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response)
output_with_think = llm_with_think.generate([{"role": "user", "content": "Hello"}])
- self.assertEqual(output_with_think, full_content)
+ self.assertEqual(output_with_think, f"<think>{reasoning_content}</think>{full_content}")
# Config with think tag removed
config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True})
@@ -43,7 +45,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self):
llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response)
output_without_think = llm_without_think.generate([{"role": "user", "content": "Hello"}])
- self.assertEqual(output_without_think, "Hello from DeepSeek!")
+ self.assertEqual(output_without_think, full_content)
def test_deepseek_llm_generate_stream(self):
"""Test DeepSeekLLM generate_stream with reasoning_content and content chunks."""
@@ -84,5 +86,5 @@ def make_chunk(delta_dict):
self.assertIn("Analyzing...", full_output)
self.assertIn("Hello, DeepSeek!", full_output)
- self.assertTrue(full_output.startswith("Analyzing..."))
+ self.assertTrue(full_output.startswith("<think>"))
self.assertTrue(full_output.endswith("DeepSeek!"))
diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py
index 47002a21f..9ed252f37 100644
--- a/tests/llms/test_ollama.py
+++ b/tests/llms/test_ollama.py
@@ -1,5 +1,6 @@
import unittest
+from types import SimpleNamespace
from unittest.mock import MagicMock
from memos.configs.llm import LLMConfigFactory, OllamaLLMConfig
@@ -12,15 +13,15 @@ def test_llm_factory_with_mocked_ollama_backend(self):
"""Test LLMFactory with mocked Ollama backend."""
mock_chat = MagicMock()
mock_response = MagicMock()
- mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}'
- mock_response.__getitem__.side_effect = lambda key: {
- "message": {
- "role": "assistant",
- "content": "Hello! How are you? I'm here to help and smile!",
- "images": None,
- "tool_calls": None,
- }
- }[key]
+ mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!", "thinking":"Analyzing your request...","images":null,"tool_calls":null}}'
+
+ mock_response.message = SimpleNamespace(
+ role="assistant",
+ content="Hello! How are you? I'm here to help and smile!",
+ thinking="Analyzing your request...",
+ images=None,
+ tool_calls=None,
+ )
mock_chat.return_value = mock_response
config = LLMConfigFactory.model_validate(
@@ -32,6 +33,7 @@ def test_llm_factory_with_mocked_ollama_backend(self):
"max_tokens": 1024,
"top_p": 0.9,
"top_k": 50,
+ "enable_thinking": True,
},
}
)
@@ -42,21 +44,23 @@ def test_llm_factory_with_mocked_ollama_backend(self):
]
response = llm.generate(messages)
- self.assertEqual(response, "Hello! How are you? I'm here to help and smile!")
+ self.assertEqual(
+ response,
+ "Analyzing your request...Hello! How are you? I'm here to help and smile!",
+ )
def test_ollama_llm_with_mocked_backend(self):
"""Test OllamaLLM with mocked backend."""
mock_chat = MagicMock()
mock_response = MagicMock()
- mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}'
- mock_response.__getitem__.side_effect = lambda key: {
- "message": {
- "role": "assistant",
- "content": "Hello! How are you? I'm here to help and smile!",
- "images": None,
- "tool_calls": None,
- }
- }[key]
+ mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","thinking":"Analyzing your request...","images":null,"tool_calls":null}}'
+ mock_response.message = SimpleNamespace(
+ role="assistant",
+ content="Hello! How are you? I'm here to help and smile!",
+ thinking="Analyzing your request...",
+ images=None,
+ tool_calls=None,
+ )
mock_chat.return_value = mock_response
config = OllamaLLMConfig(
@@ -73,4 +77,7 @@ def test_ollama_llm_with_mocked_backend(self):
]
response = ollama.generate(messages)
- self.assertEqual(response, "Hello! How are you? I'm here to help and smile!")
+ self.assertEqual(
+ response,
+ "Analyzing your request...Hello! How are you? I'm here to help and smile!",
+ )
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
index dff57c058..ba5b52df4 100644
--- a/tests/llms/test_openai.py
+++ b/tests/llms/test_openai.py
@@ -14,6 +14,7 @@ def test_llm_factory_with_mocked_openai_backend(self):
mock_response = MagicMock()
mock_response.model_dump_json.return_value = '{"id":"chatcmpl-BWoqIrvOeWdnFVZQUFzCcdVEpJ166","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! I\'m an AI language model created by OpenAI. I\'m here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?","role":"assistant"}}],"created":1747161634,"model":"gpt-4o-2024-08-06","object":"chat.completion"}'
mock_response.choices[0].message.content = "Hello! I'm an AI language model created by OpenAI. I'm here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?" # fmt: skip
+ mock_response.choices[0].message.reasoning_content = None
mock_chat_completions_create.return_value = mock_response
config = LLMConfigFactory.model_validate(
diff --git a/tests/llms/test_qwen.py b/tests/llms/test_qwen.py
index 90f31e47f..71a4c75dd 100644
--- a/tests/llms/test_qwen.py
+++ b/tests/llms/test_qwen.py
@@ -12,12 +12,14 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self):
"""Test QwenLLM non-streaming response generation with and without prefix removal."""
# Simulated full response content with tag
- full_content = "Analyzing your request...Hello, world!"
+ full_content = "Hello from DeepSeek!"
+ reasoning_content = "Thinking in progress..."
# Prepare the mock response object with expected structure
mock_response = MagicMock()
mock_response.model_dump_json.return_value = '{"mocked": "true"}'
mock_response.choices[0].message.content = full_content
+ mock_response.choices[0].message.reasoning_content = reasoning_content
# Create config with remove_think_prefix = False
config_with_think = QwenLLMConfig.model_validate(
@@ -37,7 +39,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self):
llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response)
response_with_think = llm_with_think.generate([{"role": "user", "content": "Hi"}])
- self.assertEqual(response_with_think, full_content)
+ self.assertEqual(response_with_think, f"<think>{reasoning_content}</think>{full_content}")
# Create config with remove_think_prefix = True
config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True})
@@ -47,7 +49,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self):
llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response)
response_without_think = llm_without_think.generate([{"role": "user", "content": "Hi"}])
- self.assertEqual(response_without_think, "Hello, world!")
+ self.assertEqual(response_without_think, full_content)
self.assertNotIn("", response_without_think)
def test_qwen_llm_generate_stream(self):