diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 17bfd3993..7d8cf2897 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -1,29 +1,28 @@ +import re import shutil import sys +from datetime import datetime from pathlib import Path from queue import Queue + from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig -from datetime import datetime -import re - from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS +from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( - QUERY_LABEL, - ANSWER_LABEL, ADD_LABEL, + ANSWER_LABEL, + MEM_ARCHIVE_LABEL, MEM_ORGANIZE_LABEL, MEM_UPDATE_LABEL, - MEM_ARCHIVE_LABEL, - NOT_APPLICABLE_TYPE, + QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.mem_scheduler.general_scheduler import GeneralScheduler FILE_PATH = Path(__file__).absolute() diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 8540a67ec..2f40f1c91 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -7,6 +7,7 @@ import asyncio import json +import re import traceback from collections.abc import Generator @@ -32,7 +33,6 @@ from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - SearchMode, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.templates.mos_prompts import ( @@ -53,6 +53,7 @@ class ChatHandler(BaseHandler): def __init__( self, dependencies: HandlerDependencies, + chat_llms: dict[str, Any], search_handler=None, add_handler=None, online_bot=None, 
@@ -62,6 +63,7 @@ def __init__( Args: dependencies: HandlerDependencies instance + chat_llms: Dictionary mapping model names to LLM instances search_handler: Optional SearchHandler instance (created if not provided) add_handler: Optional AddHandler instance (created if not provided) online_bot: Optional DingDing bot function for notifications @@ -80,6 +82,7 @@ def __init__( add_handler = AddHandler(dependencies) + self.chat_llms = chat_llms self.search_handler = search_handler self.add_handler = add_handler self.online_bot = online_bot @@ -105,21 +108,19 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An HTTPException: If chat fails """ try: - import time - - time_start = time.time() - # Step 1: Search for relevant memories search_req = APISearchRequest( + query=chat_req.query, user_id=chat_req.user_id, mem_cube_id=chat_req.mem_cube_id, - query=chat_req.query, - top_k=chat_req.top_k or 10, - session_id=chat_req.session_id, - mode=SearchMode.FAST, + mode=chat_req.mode, internet_search=chat_req.internet_search, - moscube=chat_req.moscube, + top_k=chat_req.top_k, chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, ) search_response = self.search_handler.handle_search_memories(search_req) @@ -137,7 +138,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An ) # Step 2: Build system prompt - system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt) + system_prompt = self._build_system_prompt( + filtered_memories, search_response.data["pref_string"], chat_req.system_prompt + ) # Prepare message history history_info = chat_req.history[-20:] if chat_req.history else [] @@ -150,28 +153,33 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An self.logger.info("Starting to generate complete response...") # Step 3: Generate complete 
response from LLM - response = self.llm.generate(current_messages) - - time_end = time.time() + if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms: + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response = self.chat_llms[model].generate(current_messages, model_name_or_path=model) + + # Step 4: start add after chat asynchronously + if chat_req.add_message_on_answer: + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=response, + async_mode="async", + ) - # Step 4: Start post-chat processing asynchronously - self._start_post_chat_processing( - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - session_id=chat_req.session_id or "default_session", - query=chat_req.query, - full_response=response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=0.0, - current_messages=current_messages, + match = re.search(r"([\s\S]*?)", response) + reasoning_text = match.group(1) if match else None + final_text = ( + re.sub(r"[\s\S]*?", "", response, count=1) if match else response ) - # Return the complete response return { "message": "Chat completed successfully", - "data": {"response": response, "references": filtered_memories}, + "data": {"response": final_text, "reasoning": reasoning_text}, } except ValueError as err: @@ -186,6 +194,150 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: This implementation directly uses search_handler and add_handler. 
+ Args: + chat_req: Chat stream request + + Returns: + StreamingResponse with SSE formatted chat stream + + Raises: + HTTPException: If stream initialization fails + """ + try: + + def generate_chat_response() -> Generator[str, None, None]: + """Generate chat response as SSE stream.""" + try: + search_req = APISearchRequest( + query=chat_req.query, + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + mode=chat_req.mode, + internet_search=chat_req.internet_search, + top_k=chat_req.top_k, + chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + self._send_message_to_scheduler( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + label=QUERY_LABEL, + ) + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold(memories_list) + + # Step 2: Build system prompt with memories + system_prompt = self._build_system_prompt( + filtered_memories, + search_response.data["pref_string"], + chat_req.system_prompt, + ) + + # Prepare messages + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info( + f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"current_system_prompt: {system_prompt}" + ) + + # Step 3: Generate streaming response from LLM + if ( + chat_req.model_name_or_path + and chat_req.model_name_or_path not in 
self.chat_llms + ): + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response_stream = self.chat_llms[model].generate_stream( + current_messages, model_name_or_path=model + ) + + # Stream the response + buffer = "" + full_response = "" + in_think = False + + for chunk in response_stream: + if chunk == "": + in_think = True + continue + if chunk == "": + in_think = False + continue + + if in_think: + chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + continue + + buffer += chunk + full_response += chunk + + chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + + current_messages.append({"role": "assistant", "content": full_response}) + if chat_req.add_message_on_answer: + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + async_mode="async", + ) + + except Exception as e: + self.logger.error(f"Error in chat stream: {e}", exc_info=True) + error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" + yield error_data + + return StreamingResponse( + generate_chat_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*", + }, + ) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + raise HTTPException(status_code=500, 
detail=str(traceback.format_exc())) from err + + def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse: + """ + Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. + + This implementation directly uses search_handler and add_handler. + Args: chat_req: Chat stream request @@ -208,15 +360,17 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" search_req = APISearchRequest( + query=chat_req.query, user_id=chat_req.user_id, mem_cube_id=chat_req.mem_cube_id, - query=chat_req.query, - top_k=20, - session_id=chat_req.session_id, - mode=SearchMode.FAST, - internet_search=chat_req.internet_search, # TODO this param is not worked at fine mode - moscube=chat_req.moscube, + mode=chat_req.mode, + internet_search=chat_req.internet_search, + top_k=chat_req.top_k, chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, ) search_response = self.search_handler.handle_search_memories(search_req) @@ -240,10 +394,23 @@ def generate_chat_response() -> Generator[str, None, None]: # Prepare reference data reference = prepare_reference_data(filtered_memories) + # get internet reference + internet_reference = self._get_internet_reference( + search_response.data.get("text_mem")[0]["memories"] + ) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + # Prepare preference markdown string + if chat_req.include_preference: + pref_md_string = self._build_pref_md_string_for_playground( + search_response.data["pref_mem"][0].get("memories", []) + ) + yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n" + # Step 2: Build system prompt with memories - system_prompt = self._build_enhance_system_prompt(filtered_memories) + system_prompt = self._build_enhance_system_prompt( + filtered_memories, 
search_response.data["pref_string"] + ) # Prepare messages history_info = chat_req.history[-20:] if chat_req.history else [] @@ -261,14 +428,34 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" # Step 3: Generate streaming response from LLM - response_stream = self.llm.generate_stream(current_messages) + if ( + chat_req.model_name_or_path + and chat_req.model_name_or_path not in self.chat_llms + ): + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response_stream = self.chat_llms[model].generate_stream( + current_messages, model_name_or_path=model + ) # Stream the response buffer = "" full_response = "" + in_think = False for chunk in response_stream: - if chunk in ["", ""]: + if chunk == "": + in_think = True + continue + if chunk == "": + in_think = False + continue + + if in_think: + chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data continue buffer += chunk @@ -291,6 +478,9 @@ def generate_chat_response() -> Generator[str, None, None]: chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" yield chunk_data + # Yield internet reference after text response + yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" + # Calculate timing time_end = time.time() speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) @@ -306,7 +496,6 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'end'})}\n\n" - # Step 4: Add conversation to memory asynchronously self._start_post_chat_processing( user_id=chat_req.user_id, cube_id=chat_req.mem_cube_id, @@ -320,6 +509,15 @@ def generate_chat_response() -> Generator[str, None, None]: 
current_messages=current_messages, ) + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + async_mode="sync", + ) + except Exception as e: self.logger.error(f"Error in chat stream: {e}", exc_info=True) error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" @@ -344,9 +542,62 @@ def generate_chat_response() -> Generator[str, None, None]: self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def _get_internet_reference( + self, search_response: list[dict[str, any]] + ) -> list[dict[str, any]]: + """Get internet reference from search response.""" + unique_set = set() + result = [] + + for item in search_response: + meta = item.get("metadata", {}) + if meta.get("source") == "web" and meta.get("internet_info"): + info = meta.get("internet_info") + key = json.dumps(info, sort_keys=True) + if key not in unique_set: + unique_set.add(key) + result.append(info) + return result + + def _build_pref_md_string_for_playground(self, pref_mem_list: list[any]) -> str: + """Build preference markdown string for playground.""" + explicit = [] + implicit = [] + for pref_mem in pref_mem_list: + if pref_mem["metadata"]["preference_type"] == "explicit": + explicit.append( + { + "content": pref_mem["preference"], + "reasoning": pref_mem["metadata"]["reasoning"], + } + ) + elif pref_mem["metadata"]["preference_type"] == "implicit": + implicit.append( + { + "content": pref_mem["preference"], + "reasoning": pref_mem["metadata"]["reasoning"], + } + ) + + explicit_md = "\n\n".join( + [ + f"显性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: {pref['reasoning']}" + for i, pref in enumerate(explicit) + ] + ) + implicit_md = "\n\n".join( + [ + f"隐性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: 
async def _add_conversation_to_memory(
    self,
    user_id: str,
    cube_id: str,
    session_id: str,
    query: str,
    clean_response: str,
    async_mode: Literal["async", "sync"] = "sync",
) -> None:
    """
    Store one user/assistant exchange through the add handler.

    The coroutine contains no ``await``: ``handle_add_memories`` is invoked
    synchronously once the request object has been built.

    Args:
        user_id: Value passed as ``user_id`` of the add request.
        cube_id: Value passed as ``mem_cube_id`` of the add request.
        session_id: Value passed as ``session_id`` of the add request.
        query: Content of the ``user`` message.
        clean_response: Content of the ``assistant`` message.
        async_mode: Forwarded unchanged as ``async_mode`` of the add request.
    """
    time_format = "%Y-%m-%d %H:%M:%S"

    # Each message is stamped separately, user message first.
    user_message = {
        "role": "user",
        "content": query,
        "chat_time": datetime.now().strftime(time_format),
    }
    assistant_message = {
        "role": "assistant",
        "content": clean_response,
        "chat_time": datetime.now().strftime(time_format),
    }

    request = APIADDRequest(
        user_id=user_id,
        mem_cube_id=cube_id,
        session_id=session_id,
        messages=[user_message, assistant_message],
        async_mode=async_mode,
    )
    self.add_handler.handle_add_memories(request)
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL ) - # Add conversation to memory using add handler - add_req = APIADDRequest( - user_id=user_id, - mem_cube_id=cube_id, - session_id=session_id, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - async_mode="sync", # set suync for playground - ) - - self.add_handler.handle_add_memories(add_req) - self.logger.info(f"Post-chat processing completed for user {user_id}") except Exception as e: @@ -822,3 +1086,65 @@ def run_async_in_thread(): daemon=True, ) thread.start() + + def _start_add_to_memory( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + async_mode: Literal["async", "sync"] = "sync", + ) -> None: + def run_async_in_thread(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + clean_response, _ = self._extract_references_from_response(full_response) + loop.run_until_complete( + self._add_conversation_to_memory( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + clean_response=clean_response, + async_mode=async_mode, + ) + ) + finally: + loop.close() + except Exception as e: + self.logger.error( + f"Error in thread-based add to memory for user {user_id}: {e}", + exc_info=True, + ) + + try: + asyncio.get_running_loop() + clean_response, _ = self._extract_references_from_response(full_response) + task = asyncio.create_task( + self._add_conversation_to_memory( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + clean_response=clean_response, + async_mode=async_mode, + ) + ) + task.add_done_callback( + lambda t: self.logger.error( + f"Error in background add to memory for user {user_id}: {t.exception()}", + 
exc_info=True, + ) + if t.exception() + else None + ) + except RuntimeError: + thread = ContextThread( + target=run_async_in_thread, + name=f"AddToMemory-{user_id}", + daemon=True, + ) + thread.start() diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 89e61e79d..3ef1d529d 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -11,6 +11,7 @@ from memos.api.config import APIConfig from memos.api.handlers.config_builders import ( + build_chat_llm_config, build_embedder_config, build_graph_db_config, build_internet_retriever_config, @@ -77,6 +78,38 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]: } +def _init_chat_llms(chat_llm_configs: list[dict]) -> dict[str, Any]: + """ + Initialize chat language models from configuration. + + Args: + chat_llm_configs: List of chat LLM configuration dictionaries + + Returns: + Dictionary mapping model names to initialized LLM instances + """ + + def _list_models(client): + try: + models = ( + [model.id for model in client.models.list().data] + if client.models.list().data + else client.models.list().models + ) + except Exception as e: + logger.error(f"Error listing models: {e}") + models = [] + return models + + model_name_instrance_maping = {} + for cfg in chat_llm_configs: + llm = LLMFactory.from_config(cfg["config_class"]) + if cfg["support_models"]: + for model_name in cfg["support_models"]: + model_name_instrance_maping[model_name] = llm + return model_name_instrance_maping + + def init_server() -> dict[str, Any]: """ Initialize all server components and configurations. 
def build_chat_llm_config() -> list[dict[str, Any]]:
    """
    Build the chat LLM configurations from the ``CHAT_MODEL_LIST`` environment variable.

    ``CHAT_MODEL_LIST`` is parsed as a JSON array of objects. For each object,
    ``backend`` (default ``"openai"``) selects the LLM backend, ``support_models``
    is carried over as-is, and every other key is passed through as the backend
    config. An empty object falls back to ``APIConfig.get_openai_config()``.

    Returns:
        One dict per entry with the keys ``config_class`` (the result of
        ``LLMConfigFactory.model_validate``) and ``support_models`` (the
        entry's value, or ``None`` when absent). An empty list is returned
        when ``CHAT_MODEL_LIST`` is unset or blank.

    Raises:
        json.JSONDecodeError: If ``CHAT_MODEL_LIST`` is set but is not valid JSON.
    """
    raw_model_list = os.getenv("CHAT_MODEL_LIST")
    # json.loads(None) raises TypeError; an unset or blank variable simply
    # means that no chat models are configured.
    if raw_model_list is None or not raw_model_list.strip():
        return []

    # Keys that steer this builder and are not part of the backend config.
    reserved_keys = ("backend", "support_models")

    chat_llm_configs: list[dict[str, Any]] = []
    for cfg in json.loads(raw_model_list):
        if cfg:
            backend_config = {k: v for k, v in cfg.items() if k not in reserved_keys}
        else:
            backend_config = APIConfig.get_openai_config()
        chat_llm_configs.append(
            {
                "config_class": LLMConfigFactory.model_validate(
                    {
                        "backend": cfg.get("backend", "openai"),
                        "config": backend_config,
                    }
                ),
                "support_models": cfg.get("support_models", None),
            }
        )
    return chat_llm_configs
def handle_delete_memories(
    delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any
) -> DeleteMemoryResponse:
    """
    Delete the requested memory IDs from the cube's text and preference memories.

    The same ``memory_ids`` list is passed first to ``text_mem.delete`` and then
    to ``pref_mem.delete``. Any exception is logged and reported in the response
    instead of being raised; if the second call fails, the first deletion has
    already been carried out.

    Args:
        delete_mem_req: Request carrying the ``memory_ids`` to delete.
        naive_mem_cube: Object exposing ``text_mem`` and ``pref_mem``, each with
            a ``delete`` method.

    Returns:
        A ``DeleteMemoryResponse`` whose ``data`` is ``{"status": "success"}``
        or ``{"status": "failure"}``.
    """
    try:
        naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids)
        naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
    except Exception as e:
        logger.error(f"Failed to delete memories: {e}", exc_info=True)
        # The response model is declared as BaseResponse[dict] and the success
        # branch returns a dict; the failure payload uses the same shape rather
        # than a bare string.
        # NOTE(review): BaseResponse is not visible here - confirm that a str
        # payload was indeed rejected by validation.
        return DeleteMemoryResponse(
            message="Failed to delete memories",
            data={"status": "failure"},
        )
    return DeleteMemoryResponse(
        message="Memories deleted successfully",
        data={"status": "success"},
    )
UserRegisterRequest(BaseRequest): interests: str | None = Field(None, description="User interests") -class GetMemoryRequest(BaseRequest): +class GetMemoryPlaygroundRequest(BaseRequest): """Request model for getting memories.""" user_id: str = Field(..., description="User ID") @@ -80,9 +79,20 @@ class ChatRequest(BaseRequest): None, description="List of cube IDs user can write for multi-cube chat" ) history: list[MessageDict] | None = Field(None, description="Chat history") + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(True, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") + system_prompt: str | None = Field(None, description="Base system prompt to use for chat") + top_k: int = Field(10, description="Number of results to return") + threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + include_preference: bool = Field(True, description="Whether to handle preference memory") + pref_top_k: int = Field(6, description="Number of preference results to return") + filter: dict[str, Any] | None = Field(None, description="Filter for the memory") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") + max_tokens: int | None = Field(None, description="Max tokens to generate") + temperature: float | None = Field(None, description="Temperature for sampling") + top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") + add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") class ChatCompleteRequest(BaseRequest): @@ -93,11 +103,18 @@ class ChatCompleteRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") history: list[MessageDict] | None = Field(None, description="Chat history") 
internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") - base_prompt: str | None = Field(None, description="Base prompt to use for chat") + system_prompt: str | None = Field(None, description="Base prompt to use for chat") top_k: int = Field(10, description="Number of results to return") threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + include_preference: bool = Field(True, description="Whether to handle preference memory") + pref_top_k: int = Field(6, description="Number of preference results to return") + filter: dict[str, Any] | None = Field(None, description="Filter for the memory") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") + max_tokens: int | None = Field(None, description="Max tokens to generate") + temperature: float | None = Field(None, description="Temperature for sampling") + top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") + add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") class UserCreate(BaseRequest): @@ -129,6 +146,10 @@ class SuggestionResponse(BaseResponse[list]): data: dict[str, list[str]] | None = Field(None, description="Response data") +class AddStatusResponse(BaseResponse[dict]): + """Response model for add status operations.""" + + class ConfigResponse(BaseResponse[None]): """Response model for configuration endpoint.""" @@ -141,6 +162,14 @@ class ChatResponse(BaseResponse[str]): """Response model for chat operations.""" +class GetMemoryResponse(BaseResponse[dict]): + """Response model for getting memories.""" + + +class DeleteMemoryResponse(BaseResponse[dict]): + """Response model for deleting memories.""" + + class UserResponse(BaseResponse[dict]): """Response model for user operations.""" @@ -181,11 
+210,8 @@ class APISearchRequest(BaseRequest): readable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can read for multi-cube search" ) - mode: SearchMode = Field( - os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture" - ) + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") chat_history: list[MessageDict] | None = Field(None, description="Chat history") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") @@ -194,6 +220,7 @@ class APISearchRequest(BaseRequest): ) include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") + filter: dict[str, Any] | None = Field(None, description="Filter for the memory") class APIADDRequest(BaseRequest): @@ -213,8 +240,13 @@ class APIADDRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) - async_mode: Literal["async", "sync"] | None = Field( - None, description="Whether to add memory in async mode" + async_mode: Literal["async", "sync"] = Field( + "async", description="Whether to add memory in async mode" + ) + custom_tags: list[str] | None = Field(None, description="Custom tags for the memory") + info: dict[str, str] | None = Field(None, description="Additional information for the memory") + is_feedback: bool = Field( + False, description="Whether the user feedback in knowladge base service" ) @@ -232,13 +264,43 @@ class APIChatCompleteRequest(BaseRequest): ) history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(False, 
description="Whether to use internet search") - moscube: bool = Field(True, description="Whether to use MemOSCube") - base_prompt: str | None = Field(None, description="Base prompt to use for chat") + system_prompt: str | None = Field(None, description="Base system prompt to use for chat") + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") top_k: int = Field(10, description="Number of results to return") threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field( "default_session", description="Session ID for soft-filtering memories" ) + include_preference: bool = Field(True, description="Whether to handle preference memory") + pref_top_k: int = Field(6, description="Number of preference results to return") + filter: dict[str, Any] | None = Field(None, description="Filter for the memory") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") + max_tokens: int | None = Field(None, description="Max tokens to generate") + temperature: float | None = Field(None, description="Temperature for sampling") + top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") + add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + + +class AddStatusRequest(BaseRequest): + """Request model for checking add status.""" + + mem_cube_id: str = Field(..., description="Cube ID") + user_id: str | None = Field(None, description="User ID") + session_id: str | None = Field(None, description="Session ID") + + +class GetMemoryRequest(BaseRequest): + """Request model for getting memories.""" + + mem_cube_id: str = Field(..., description="Cube ID") + user_id: str | None = Field(None, description="User ID") + include_preference: bool = Field(True, description="Whether to handle preference memory") + + +class DeleteMemoryRequest(BaseRequest): + """Request model for deleting memories.""" + + memory_ids: 
list[str] = Field(..., description="Memory IDs") class SuggestionRequest(BaseRequest): diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 75b614cf4..2f6c5c317 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -10,7 +10,7 @@ BaseResponse, ChatCompleteRequest, ChatRequest, - GetMemoryRequest, + GetMemoryPlaygroundRequest, MemoryCreateRequest, MemoryResponse, SearchRequest, @@ -159,7 +159,7 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest): @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryRequest): +def get_all_memories(memory_req: GetMemoryPlaygroundRequest): """Get all memories for a specific user.""" try: mos_product = get_mos_product_instance() diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index b3b517305..0067d6e2f 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -23,11 +23,17 @@ from memos.api.handlers.chat_handler import ChatHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( + AddStatusRequest, + AddStatusResponse, APIADDRequest, APIChatCompleteRequest, APISearchRequest, ChatRequest, + DeleteMemoryRequest, + DeleteMemoryResponse, + GetMemoryPlaygroundRequest, GetMemoryRequest, + GetMemoryResponse, MemoryResponse, SearchResponse, SuggestionRequest, @@ -54,7 +60,11 @@ search_handler = SearchHandler(dependencies) add_handler = AddHandler(dependencies) chat_handler = ChatHandler( - dependencies, search_handler, add_handler, online_bot=components.get("online_bot") + dependencies, + components["chat_llms"], + search_handler, + add_handler, + online_bot=components.get("online_bot"), ) # Extract commonly used components for function-based handlers @@ -99,11 +109,15 @@ def add_memories(add_req: APIADDRequest): # 
============================================================================= -@router.get("/scheduler/status", summary="Get scheduler running status") -def scheduler_status(user_name: str | None = None): +@router.get( + "/scheduler/status", summary="Get scheduler running status", response_model=AddStatusResponse +) +def scheduler_status(add_status_req: AddStatusRequest): """Get scheduler running status.""" return handlers.scheduler_handler.handle_scheduler_status( - user_name=user_name, + mem_cube_id=add_status_req.mem_cube_id, + user_id=add_status_req.user_id, + session_id=add_status_req.session_id, mem_scheduler=mem_scheduler, instance_id=INSTANCE_ID, ) @@ -155,8 +169,8 @@ def chat_complete(chat_req: APIChatCompleteRequest): return chat_handler.handle_chat_complete(chat_req) -@router.post("/chat", summary="Chat with MemOS") -def chat(chat_req: ChatRequest): +@router.post("/chat/stream", summary="Chat with MemOS") +def chat_stream(chat_req: ChatRequest): """ Chat with MemOS for a specific user. Returns SSE stream. @@ -166,6 +180,17 @@ def chat(chat_req: ChatRequest): return chat_handler.handle_chat_stream(chat_req) +@router.post("/chat/stream/playground", summary="Chat with MemOS playground") +def chat_stream_playground(chat_req: ChatRequest): + """ + Chat with MemOS for a specific user. Returns SSE stream. + + This endpoint uses the class-based ChatHandler which internally + composes SearchHandler and AddHandler for a clean architecture. 
+ """ + return chat_handler.handle_chat_stream_playground(chat_req) + + # ============================================================================= # Suggestion API Endpoints # ============================================================================= @@ -188,12 +213,12 @@ def get_suggestion_queries(suggestion_req: SuggestionRequest): # ============================================================================= -# Memory Retrieval API Endpoints +# Memory Retrieval Delete API Endpoints # ============================================================================= @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryRequest): +def get_all_memories(memory_req: GetMemoryPlaygroundRequest): """ Get all memories or subgraph for a specific user. @@ -219,3 +244,20 @@ def get_all_memories(memory_req: GetMemoryRequest): memory_type=memory_req.memory_type or "text_mem", naive_mem_cube=naive_mem_cube, ) + + +@router.post("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) +def get_memories(memory_req: GetMemoryRequest): + return handlers.memory_handler.handle_get_memories( + get_mem_req=memory_req, + naive_mem_cube=naive_mem_cube, + ) + + +@router.post( + "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse +) +def delete_memories(memory_req: DeleteMemoryRequest): + return handlers.memory_handler.handle_delete_memories( + delete_mem_req=memory_req, naive_mem_cube=naive_mem_cube + ) diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index d69a0a0fc..70217b896 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -9,14 +9,17 @@ class BaseLLMConfig(BaseConfig): """Base configuration class for LLMs.""" model_name_or_path: str = Field(..., description="Model name or path") - temperature: float = Field(default=0.8, description="Temperature for sampling") - max_tokens: int = Field(default=1024, 
description="Maximum number of tokens to generate") - top_p: float = Field(default=0.9, description="Top-p sampling parameter") + temperature: float = Field(default=0.7, description="Temperature for sampling") + max_tokens: int = Field(default=8192, description="Maximum number of tokens to generate") + top_p: float = Field(default=0.95, description="Top-p sampling parameter") top_k: int = Field(default=50, description="Top-k sampling parameter") remove_think_prefix: bool = Field( default=False, description="Remove content within think tags from the generated text", ) + default_headers: dict[str, Any] | None = Field( + default=None, description="Default headers for LLM requests" + ) class OpenAILLMConfig(BaseLLMConfig): @@ -27,6 +30,18 @@ class OpenAILLMConfig(BaseLLMConfig): extra_body: Any = Field(default=None, description="extra body") +class OpenAIResponsesLLMConfig(BaseLLMConfig): + api_key: str = Field(..., description="API key for OpenAI") + api_base: str = Field( + default="https://api.openai.com/v1", description="Base URL for OpenAI responses API" + ) + extra_body: Any = Field(default=None, description="extra body") + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from vLLM", + ) + + class QwenLLMConfig(BaseLLMConfig): api_key: str = Field(..., description="API key for DashScope (Qwen)") api_base: str = Field( @@ -34,7 +49,6 @@ class QwenLLMConfig(BaseLLMConfig): description="Base URL for Qwen OpenAI-compatible API", ) extra_body: Any = Field(default=None, description="extra body") - model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'") class DeepSeekLLMConfig(BaseLLMConfig): @@ -44,9 +58,6 @@ class DeepSeekLLMConfig(BaseLLMConfig): description="Base URL for DeepSeek OpenAI-compatible API", ) extra_body: Any = Field(default=None, description="Extra options for API") - model_name_or_path: str = Field( - ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'" - ) class 
AzureLLMConfig(BaseLLMConfig): @@ -61,11 +72,27 @@ class AzureLLMConfig(BaseLLMConfig): api_key: str = Field(..., description="API key for Azure OpenAI") +class AzureResponsesLLMConfig(BaseLLMConfig): + base_url: str = Field( + default="https://api.openai.azure.com/", + description="Base URL for Azure OpenAI API", + ) + api_version: str = Field( + default="2024-03-01-preview", + description="API version for Azure OpenAI", + ) + api_key: str = Field(..., description="API key for Azure OpenAI") + + class OllamaLLMConfig(BaseLLMConfig): api_base: str = Field( default="http://localhost:11434", description="Base URL for Ollama API", ) + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from Ollama", + ) class HFLLMConfig(BaseLLMConfig): @@ -85,6 +112,10 @@ class VLLMLLMConfig(BaseLLMConfig): default="http://localhost:8088/v1", description="Base URL for vLLM API", ) + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from vLLM", + ) class LLMConfigFactory(BaseConfig): @@ -102,6 +133,7 @@ class LLMConfigFactory(BaseConfig): "huggingface_singleton": HFLLMConfig, # Add singleton support "qwen": QwenLLMConfig, "deepseek": DeepSeekLLMConfig, + "openai_new": OpenAIResponsesLLMConfig, } @field_validator("backend") diff --git a/src/memos/llms/deepseek.py b/src/memos/llms/deepseek.py index f5ee4842b..a90f8eb31 100644 --- a/src/memos/llms/deepseek.py +++ b/src/memos/llms/deepseek.py @@ -1,10 +1,6 @@ -from collections.abc import Generator - from memos.configs.llm import DeepSeekLLMConfig from memos.llms.openai import OpenAILLM -from memos.llms.utils import remove_thinking_tags from memos.log import get_logger -from memos.types import MessageList logger = get_logger(__name__) @@ -15,40 +11,3 @@ class DeepSeekLLM(OpenAILLM): def __init__(self, config: DeepSeekLLMConfig): super().__init__(config) - - def generate(self, messages: MessageList) -> str: - """Generate a response from DeepSeek.""" - response = 
self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - logger.info(f"Response from DeepSeek: {response.model_dump_json()}") - response_content = response.choices[0].message.content - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - else: - return response_content - - def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - """Stream response from DeepSeek.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - # Streaming chunks of text - for chunk in response: - delta = chunk.choices[0].delta - if hasattr(delta, "reasoning_content") and delta.reasoning_content: - yield delta.reasoning_content - - if hasattr(delta, "content") and delta.content: - yield delta.content diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py index 8589d7750..8f4da662f 100644 --- a/src/memos/llms/factory.py +++ b/src/memos/llms/factory.py @@ -7,6 +7,7 @@ from memos.llms.hf_singleton import HFSingletonLLM from memos.llms.ollama import OllamaLLM from memos.llms.openai import AzureLLM, OpenAILLM +from memos.llms.openai_new import OpenAIResponsesLLM from memos.llms.qwen import QwenLLM from memos.llms.vllm import VLLMLLM from memos.memos_tools.singleton import singleton_factory @@ -24,6 +25,7 @@ class LLMFactory(BaseLLM): "vllm": VLLMLLM, "qwen": QwenLLM, "deepseek": DeepSeekLLM, + "openai_new": OpenAIResponsesLLM, } @classmethod diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py index be0d1d95f..d46db7c9e 100644 --- a/src/memos/llms/hf.py +++ b/src/memos/llms/hf.py @@ -54,7 +54,9 @@ def __init__(self, 
config: HFLLMConfig): processors.append(TopPLogitsWarper(self.config.top_p)) self.logits_processors = LogitsProcessorList(processors) - def generate(self, messages: MessageList, past_key_values: DynamicCache | None = None): + def generate( + self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs + ): """ Generate a response from the model. If past_key_values is provided, use cache-augmented generation. Args: @@ -68,12 +70,12 @@ def generate(self, messages: MessageList, past_key_values: DynamicCache | None = ) logger.info(f"HFLLM prompt: {prompt}") if past_key_values is None: - return self._generate_full(prompt) + return self._generate_full(prompt, **kwargs) else: - return self._generate_with_cache(prompt, past_key_values) + return self._generate_with_cache(prompt, past_key_values, **kwargs) def generate_stream( - self, messages: MessageList, past_key_values: DynamicCache | None = None + self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs ) -> Generator[str, None, None]: """ Generate a streaming response from the model. @@ -92,7 +94,7 @@ def generate_stream( else: yield from self._generate_with_cache_stream(prompt, past_key_values) - def _generate_full(self, prompt: str) -> str: + def _generate_full(self, prompt: str, **kwargs) -> str: """ Generate output from scratch using the full prompt. 
Args: @@ -102,13 +104,13 @@ def _generate_full(self, prompt: str) -> str: """ inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device) gen_kwargs = { - "max_new_tokens": getattr(self.config, "max_tokens", 128), + "max_new_tokens": kwargs.get("max_tokens", self.config.max_tokens), "do_sample": getattr(self.config, "do_sample", True), } if self.config.do_sample: - gen_kwargs["temperature"] = self.config.temperature - gen_kwargs["top_k"] = self.config.top_k - gen_kwargs["top_p"] = self.config.top_p + gen_kwargs["temperature"] = kwargs.get("temperature", self.config.temperature) + gen_kwargs["top_k"] = kwargs.get("top_k", self.config.top_k) + gen_kwargs["top_p"] = kwargs.get("top_p", self.config.top_p) gen_ids = self.model.generate( **inputs, **gen_kwargs, @@ -125,7 +127,7 @@ def _generate_full(self, prompt: str) -> str: else response ) - def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: + def _generate_full_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]: """ Generate output from scratch using the full prompt with streaming. Args: @@ -138,7 +140,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device) # Get generation parameters - max_new_tokens = getattr(self.config, "max_tokens", 128) + max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens) remove_think_prefix = getattr(self.config, "remove_think_prefix", False) # Manual streaming generation @@ -192,7 +194,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: else: yield new_token_text - def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: + def _generate_with_cache(self, query: str, kv: DynamicCache, **kwargs) -> str: """ Generate output incrementally using an existing KV cache. 
Args: @@ -209,7 +211,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: logits, kv = self._prefill(query_ids, kv) next_token = self._select_next_token(logits) generated = [next_token] - for _ in range(getattr(self.config, "max_tokens", 128) - 1): + for _ in range(kwargs.get("max_tokens", self.config.max_tokens) - 1): if self._should_stop(next_token): break logits, kv = self._prefill(next_token, kv) @@ -228,7 +230,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: ) def _generate_with_cache_stream( - self, query: str, kv: DynamicCache + self, query: str, kv: DynamicCache, **kwargs ) -> Generator[str, None, None]: """ Generate output incrementally using an existing KV cache with streaming. @@ -242,7 +244,7 @@ def _generate_with_cache_stream( query, return_tensors="pt", add_special_tokens=False ).input_ids.to(self.model.device) - max_new_tokens = getattr(self.config, "max_tokens", 128) + max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens) remove_think_prefix = getattr(self.config, "remove_think_prefix", False) # Initial forward pass diff --git a/src/memos/llms/ollama.py b/src/memos/llms/ollama.py index 050b7a253..bd92f9625 100644 --- a/src/memos/llms/ollama.py +++ b/src/memos/llms/ollama.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any -from ollama import Client +from ollama import Client, Message from memos.configs.llm import OllamaLLMConfig from memos.llms.base import BaseLLM @@ -54,7 +54,7 @@ def _ensure_model_exists(self): except Exception as e: logger.warning(f"Could not verify model existence: {e}") - def generate(self, messages: MessageList) -> Any: + def generate(self, messages: MessageList, **kwargs) -> Any: """ Generate a response from Ollama LLM. 
@@ -68,19 +68,68 @@ def generate(self, messages: MessageList) -> Any: model=self.config.model_name_or_path, messages=messages, options={ - "temperature": self.config.temperature, - "num_predict": self.config.max_tokens, - "top_p": self.config.top_p, - "top_k": self.config.top_k, + "temperature": kwargs.get("temperature", self.config.temperature), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "top_k": kwargs.get("top_k", self.config.top_k), }, + think=self.config.enable_thinking, + tools=kwargs.get("tools"), ) logger.info(f"Raw response from Ollama: {response.model_dump_json()}") - - str_response = response["message"]["content"] or "" + tool_calls = getattr(response.message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) + + str_thinking = ( + f"{response.message.thinking}" + if hasattr(response.message, "thinking") + else "" + ) + str_response = response.message.content if self.config.remove_think_prefix: return remove_thinking_tags(str_response) else: - return str_response + return str_thinking + str_response def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - raise NotImplementedError + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + response = self.client.chat( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + messages=messages, + options={ + "temperature": kwargs.get("temperature", self.config.temperature), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "top_k": kwargs.get("top_k", self.config.top_k), + }, + think=self.config.enable_thinking, + stream=True, + ) + # Streaming chunks of text + reasoning_started = False + for chunk in response: + if hasattr(chunk.message, "thinking") and chunk.message.thinking: + if not reasoning_started and not 
self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield chunk.message.thinking + + if hasattr(chunk.message, "content") and chunk.message.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield chunk.message.content + + def tool_call_parser(self, tool_calls: list[Message.ToolCall]) -> list[dict]: + """Parse tool calls from Ollama response.""" + return [ + { + "function_name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index da55ae593..9b348adcf 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -1,12 +1,12 @@ -import hashlib import json -import time from collections.abc import Generator -from typing import ClassVar import openai +from openai._types import NOT_GIVEN +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig from memos.llms.base import BaseLLM from memos.llms.utils import remove_thinking_tags @@ -19,84 +19,57 @@ class OpenAILLM(BaseLLM): - """OpenAI LLM class with singleton pattern.""" - - _instances: ClassVar[dict] = {} # Class variable to store instances - - def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM": - config_hash = cls._get_config_hash(config) - - if config_hash not in cls._instances: - logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}") - instance = super().__new__(cls) - cls._instances[config_hash] = instance - else: - logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}") - - return cls._instances[config_hash] + """OpenAI LLM class via openai.chat.completions.create.""" def __init__(self, config: OpenAILLMConfig): - # Avoid duplicate initialization - if hasattr(self, "_initialized"): - return - self.config = config - self.client = 
openai.Client(api_key=config.api_key, base_url=config.api_base) - self._initialized = True + self.client = openai.Client( + api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers + ) logger.info("OpenAI LLM instance initialized") - @classmethod - def _get_config_hash(cls, config: OpenAILLMConfig) -> str: - """Generate hash value of configuration""" - config_dict = config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True) - return hashlib.md5(config_str.encode()).hexdigest() - - @classmethod - def clear_cache(cls): - """Clear all cached instances""" - cls._instances.clear() - logger.info("OpenAI LLM instance cache cleared") - - @timed(log=True, log_prefix="model_timed_openai") + @timed(log=True, log_prefix="OpenAI LLM") def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from OpenAI LLM, optionally overriding generation params.""" - temperature = kwargs.get("temperature", self.config.temperature) - max_tokens = kwargs.get("max_tokens", self.config.max_tokens) - top_p = kwargs.get("top_p", self.config.top_p) - start_time = time.time() - logger.info(f"openai model request start, model_name: {self.config.model_name_or_path}") - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), messages=messages, - extra_body=self.config.extra_body, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - ) - - end_time = time.time() - logger.info( - f"openai model request end, time_cost: {end_time - start_time:.0f} ms, response from OpenAI: {response.model_dump_json()}" + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", self.config.extra_body), + tools=kwargs.get("tools", NOT_GIVEN), ) + logger.info(f"Response from OpenAI: 
{response.model_dump_json()}") + tool_calls = getattr(response.choices[0].message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) response_content = response.choices[0].message.content + reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) + if isinstance(reasoning_content, str) and reasoning_content: + reasoning_content = f"{reasoning_content}" if self.config.remove_think_prefix: return remove_thinking_tags(response_content) - else: - return response_content + if reasoning_content: + return reasoning_content + response_content + return response_content @timed(log=True, log_prefix="OpenAI LLM") def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from OpenAI LLM with optional reasoning support.""" + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", self.config.extra_body), + tools=kwargs.get("tools", NOT_GIVEN), ) reasoning_started = False @@ -104,7 +77,7 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non for chunk in response: delta = chunk.choices[0].delta - # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen) + # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek) if hasattr(delta, "reasoning_content") and delta.reasoning_content: if not reasoning_started and not 
self.config.remove_think_prefix: yield "" @@ -120,63 +93,44 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non if reasoning_started and not self.config.remove_think_prefix: yield "" + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] + class AzureLLM(BaseLLM): """Azure OpenAI LLM class with singleton pattern.""" - _instances: ClassVar[dict] = {} # Class variable to store instances - - def __new__(cls, config: AzureLLMConfig): - # Generate hash value of config as cache key - config_hash = cls._get_config_hash(config) - - if config_hash not in cls._instances: - logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}") - instance = super().__new__(cls) - cls._instances[config_hash] = instance - else: - logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}") - - return cls._instances[config_hash] - def __init__(self, config: AzureLLMConfig): - # Avoid duplicate initialization - if hasattr(self, "_initialized"): - return - self.config = config self.client = openai.AzureOpenAI( azure_endpoint=config.base_url, api_version=config.api_version, api_key=config.api_key, ) - self._initialized = True logger.info("Azure LLM instance initialized") - @classmethod - def _get_config_hash(cls, config: AzureLLMConfig) -> str: - """Generate hash value of configuration""" - # Convert config to dict and sort to ensure consistency - config_dict = config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True) - return hashlib.md5(config_str.encode()).hexdigest() - - @classmethod - def clear_cache(cls): - """Clear all cached instances""" - cls._instances.clear() - logger.info("Azure LLM instance cache cleared") - - def 
generate(self, messages: MessageList) -> str: + def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from Azure OpenAI LLM.""" response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), ) logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}") + if response.choices[0].message.tool_calls: + return self.tool_call_parser(response.choices[0].message.tool_calls) response_content = response.choices[0].message.content if self.config.remove_think_prefix: return remove_thinking_tags(response_content) @@ -184,4 +138,49 @@ def generate(self, messages: MessageList) -> str: return response_content def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - raise NotImplementedError + """Stream response from Azure OpenAI LLM with optional reasoning support.""" + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + response = self.client.chat.completions.create( + model=self.config.model_name_or_path, + messages=messages, + stream=True, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", self.config.extra_body), + ) + + reasoning_started = False + + for chunk in response: + delta = chunk.choices[0].delta + + # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek) + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + if not 
reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield delta.reasoning_content + elif hasattr(delta, "content") and delta.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield delta.content + + # Ensure we close the block if not already done + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/openai_new.py b/src/memos/llms/openai_new.py new file mode 100644 index 000000000..766a17fda --- /dev/null +++ b/src/memos/llms/openai_new.py @@ -0,0 +1,198 @@ +import json + +from collections.abc import Generator + +import openai + +from openai._types import NOT_GIVEN +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_reasoning_item import ResponseReasoningItem + +from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig +from memos.llms.base import BaseLLM +from memos.llms.utils import remove_thinking_tags +from memos.log import get_logger +from memos.types import MessageList +from memos.utils import timed + + +logger = get_logger(__name__) + + +class OpenAIResponsesLLM(BaseLLM): + def __init__(self, config: OpenAILLMConfig): + self.config = config + self.client = openai.Client( + api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers + ) + + @timed(log=True, log_prefix="OpenAI Responses LLM") + def generate(self, messages: MessageList, **kwargs) -> str: + response = self.client.responses.create( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + 
input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), + ) + tool_call_outputs = [ + item for item in response.output if isinstance(item, ResponseFunctionToolCall) + ] + if tool_call_outputs: + return self.tool_call_parser(tool_call_outputs) + + output_text = getattr(response, "output_text", "") + output_reasoning = [ + item for item in response.output if isinstance(item, ResponseReasoningItem) + ] + summary = output_reasoning[0].summary + + if self.config.remove_think_prefix: + return remove_thinking_tags(output_text) + if summary: + return f"{summary[0].text}" + output_text + return output_text + + @timed(log=True, log_prefix="OpenAI Responses LLM") + def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + stream = self.client.responses.create( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + extra_body=kwargs.get("extra_body", self.config.extra_body), + stream=True, + ) + + reasoning_started = False + + for event in stream: + event_type = getattr(event, "type", "") + if event_type in ( + "response.reasoning.delta", + "response.reasoning_summary_text.delta", + ) and hasattr(event, "delta"): + if not self.config.remove_think_prefix: + if not reasoning_started: + yield "" + 
reasoning_started = True + yield event.delta + elif event_type == "response.output_text.delta" and hasattr(event, "delta"): + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield event.delta + + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.call_id, + "function_name": tool_call.name, + "arguments": json.loads(tool_call.arguments), + } + for tool_call in tool_calls + ] + + +class AzureResponsesLLM(BaseLLM): + def __init__(self, config: AzureLLMConfig): + self.config = config + self.client = openai.AzureOpenAI( + azure_endpoint=config.base_url, + api_version=config.api_version, + api_key=config.api_key, + ) + + def generate(self, messages: MessageList, **kwargs) -> str: + response = self.client.responses.create( + model=self.config.model_name_or_path, + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + ) + + output_text = getattr(response, "output_text", "") + output_reasoning = [ + item for item in response.output if isinstance(item, ResponseReasoningItem) + ] + summary = output_reasoning[0].summary + + if self.config.remove_think_prefix: + return remove_thinking_tags(output_text) + if summary: + return f"{summary[0].text}" + output_text + return output_text + + def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + stream = 
self.client.responses.create( + model=self.config.model_name_or_path, + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + extra_body=kwargs.get("extra_body", self.config.extra_body), + stream=True, + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + ) + + reasoning_started = False + + for event in stream: + event_type = getattr(event, "type", "") + if event_type in ( + "response.reasoning.delta", + "response.reasoning_summary_text.delta", + ) and hasattr(event, "delta"): + if not self.config.remove_think_prefix: + if not reasoning_started: + yield "" + reasoning_started = True + yield event.delta + elif event_type == "response.output_text.delta" and hasattr(event, "delta"): + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield event.delta + + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.call_id, + "function_name": tool_call.name, + "arguments": json.loads(tool_call.arguments), + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/qwen.py b/src/memos/llms/qwen.py index a47fcdf36..d54e23c7f 100644 --- a/src/memos/llms/qwen.py +++ b/src/memos/llms/qwen.py @@ -1,10 +1,6 @@ -from collections.abc import Generator - from memos.configs.llm import QwenLLMConfig from memos.llms.openai import OpenAILLM -from memos.llms.utils import remove_thinking_tags from memos.log import get_logger -from memos.types import MessageList logger = get_logger(__name__) @@ -15,49 +11,3 @@ class QwenLLM(OpenAILLM): def __init__(self, config: QwenLLMConfig): super().__init__(config) - - def generate(self, messages: MessageList) 
-> str: - """Generate a response from Qwen LLM.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - extra_body=self.config.extra_body, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - ) - logger.info(f"Response from Qwen: {response.model_dump_json()}") - response_content = response.choices[0].message.content - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - else: - return response_content - - def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - """Stream response from Qwen LLM.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - - reasoning_started = False - for chunk in response: - delta = chunk.choices[0].delta - - # Some models may have separate `reasoning_content` vs `content` - # For Qwen (DashScope), likely only `content` is used - if hasattr(delta, "reasoning_content") and delta.reasoning_content: - if not reasoning_started and not self.config.remove_think_prefix: - yield "" - reasoning_started = True - yield delta.reasoning_content - elif hasattr(delta, "content") and delta.content: - if reasoning_started and not self.config.remove_think_prefix: - yield "" - reasoning_started = False - yield delta.content diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py index c3750bb4b..1cf8d4f39 100644 --- a/src/memos/llms/vllm.py +++ b/src/memos/llms/vllm.py @@ -1,5 +1,11 @@ +import json + from typing import Any, cast +import openai + +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + from memos.configs.llm import VLLMLLMConfig from memos.llms.base import BaseLLM from memos.llms.utils import 
remove_thinking_tags @@ -27,10 +33,10 @@ def __init__(self, config: VLLMLLMConfig): if not api_key: api_key = "dummy" - import openai - self.client = openai.Client( - api_key=api_key, base_url=getattr(self.config, "api_base", "http://localhost:8088/v1") + api_key=api_key, + base_url=getattr(self.config, "api_base", "http://localhost:8088/v1"), + default_headers=self.config.default_headers, ) def build_vllm_kv_cache(self, messages: Any) -> str: @@ -85,36 +91,54 @@ def build_vllm_kv_cache(self, messages: Any) -> str: return prompt - def generate(self, messages: list[MessageDict]) -> str: + def generate(self, messages: list[MessageDict], **kwargs) -> str: """ Generate a response from the model. """ if self.client: - return self._generate_with_api_client(messages) + return self._generate_with_api_client(messages, **kwargs) else: raise RuntimeError("API client is not available") - def _generate_with_api_client(self, messages: list[MessageDict]) -> str: + def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str: """ - Generate response using vLLM API client. + Generate response using vLLM API client. 
detail view https://docs.vllm.ai/en/latest/features/reasoning_outputs/ """ if self.client: completion_kwargs = { - "model": self.config.model_name_or_path, + "model": kwargs.get("model_name_or_path", self.config.model_name_or_path), "messages": messages, - "temperature": float(getattr(self.config, "temperature", 0.8)), - "max_tokens": int(getattr(self.config, "max_tokens", 1024)), - "top_p": float(getattr(self.config, "top_p", 0.9)), - "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "extra_body": { + "chat_template_kwargs": { + "enable_thinking": kwargs.get( + "enable_thinking", self.config.enable_thinking + ) + } + }, } + if kwargs.get("tools"): + completion_kwargs["tools"] = kwargs.get("tools") + completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto") response = self.client.chat.completions.create(**completion_kwargs) + + if response.choices[0].message.tool_calls: + return self.tool_call_parser(response.choices[0].message.tool_calls) + + reasoning_content = ( + f"{response.choices[0].message.reasoning}" + if hasattr(response.choices[0].message, "reasoning") + else "" + ) response_text = response.choices[0].message.content or "" logger.info(f"VLLM API response: {response_text}") return ( remove_thinking_tags(response_text) if getattr(self.config, "remove_think_prefix", False) - else response_text + else reasoning_content + response_text ) else: raise RuntimeError("API client is not available") @@ -130,26 +154,59 @@ def _messages_to_prompt(self, messages: list[MessageDict]) -> str: prompt_parts.append(f"{role.capitalize()}: {content}") return "\n".join(prompt_parts) - def generate_stream(self, messages: list[MessageDict]): + def generate_stream(self, messages: list[MessageDict], **kwargs): """ Generate a response from the model using streaming. 
Yields content chunks as they are received. """ + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + if self.client: completion_kwargs = { "model": self.config.model_name_or_path, "messages": messages, - "temperature": float(getattr(self.config, "temperature", 0.8)), - "max_tokens": int(getattr(self.config, "max_tokens", 1024)), - "top_p": float(getattr(self.config, "top_p", 0.9)), - "stream": True, # Enable streaming - "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": True, + "extra_body": { + "chat_template_kwargs": { + "enable_thinking": kwargs.get( + "enable_thinking", self.config.enable_thinking + ) + } + }, } stream = self.client.chat.completions.create(**completion_kwargs) + + reasoning_started = False for chunk in stream: - content = chunk.choices[0].delta.content - if content: - yield content + delta = chunk.choices[0].delta + if hasattr(delta, "reasoning") and delta.reasoning: + if not reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield delta.reasoning + + if hasattr(delta, "content") and delta.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield delta.content + else: raise RuntimeError("API client is not available") + + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 63b87157c..a53e19191 100644 --- 
a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -152,7 +152,6 @@ def init_mem_cube( if searcher is None: self.searcher: Searcher = self.text_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, ) else: self.searcher = searcher @@ -577,12 +576,12 @@ def get_web_log_messages(self) -> list[dict]: def _map_label(label: str) -> str: from memos.mem_scheduler.schemas.general_schemas import ( - QUERY_LABEL, - ANSWER_LABEL, ADD_LABEL, - MEM_UPDATE_LABEL, - MEM_ORGANIZE_LABEL, + ANSWER_LABEL, MEM_ARCHIVE_LABEL, + MEM_ORGANIZE_LABEL, + MEM_UPDATE_LABEL, + QUERY_LABEL, ) mapping = { diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 3859c9e6f..7da531a7f 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -1,3 +1,5 @@ +import hashlib + from collections.abc import Callable from memos.log import get_logger @@ -6,13 +8,13 @@ from memos.mem_scheduler.schemas.general_schemas import ( ACTIVATION_MEMORY_TYPE, ADD_LABEL, + MEM_ARCHIVE_LABEL, + MEM_UPDATE_LABEL, NOT_INITIALIZED, PARAMETER_MEMORY_TYPE, TEXT_MEMORY_TYPE, USER_INPUT_TYPE, WORKING_MEMORY_TYPE, - MEM_UPDATE_LABEL, - MEM_ARCHIVE_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, @@ -23,7 +25,6 @@ ) from memos.mem_scheduler.utils.misc_utils import log_exceptions from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -import hashlib logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index eeca890a9..e0d18dc72 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -12,14 +12,14 @@ ADD_LABEL, ANSWER_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, + 
LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_LABEL, MEM_READ_LABEL, + NOT_APPLICABLE_TYPE, PREF_ADD_LABEL, QUERY_LABEL, - WORKING_MEMORY_TYPE, USER_INPUT_TYPE, - NOT_APPLICABLE_TYPE, - LONG_TERM_MEMORY_TYPE, + WORKING_MEMORY_TYPE, MemCubeID, UserID, ) diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 21b2d63f0..f6e9b86fe 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -69,7 +69,6 @@ def submit_memory_history_async_task( "session_id": session_id, "top_k": search_req.top_k, "internet_search": search_req.internet_search, - "moscube": search_req.moscube, "chat_history": search_req.chat_history, }, "user_context": {"mem_cube_id": user_context.mem_cube_id}, @@ -112,7 +111,6 @@ def search_memories( top_k=search_req.top_k, mode=mode, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info={ "user_id": search_req.user_id, @@ -154,7 +152,6 @@ def mix_search_memories( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info=info, ) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 5f85aa907..6e196e23a 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -190,7 +190,7 @@ def get_with_collection_name( return None return TextualMemoryItem( id=res.id, - memory=res.payload.get("dialog_str", ""), + memory=res.memory, metadata=PreferenceTextualMemoryMetadata(**res.payload), ) except Exception as e: @@ -225,7 +225,7 @@ def get_by_ids_with_collection_name( return [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in res @@ -248,19 +248,43 @@ def get_all(self) -> 
list[TextualMemoryItem]: all_memories[collection_name] = [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in items ] return all_memories + def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]: + """Get memories by filter. + Args: + filter (dict[str, Any]): Filter criteria. + Returns: + list[TextualMemoryItem]: List of memories that match the filter. + """ + collection_list = self.vector_db.config.collection_name + all_db_items = [] + for collection_name in collection_list: + db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) + all_db_items.extend(db_items) + memories = [ + TextualMemoryItem( + id=memo.id, + memory=memo.memory, + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in all_db_items + ] + return memories + def delete(self, memory_ids: list[str]) -> None: """Delete memories. Args: memory_ids (list[str]): List of memory IDs to delete. """ - raise NotImplementedError + collection_list = self.vector_db.config.collection_name + for collection_name in collection_list: + self.vector_db.delete(collection_name, memory_ids) def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: """Delete memories by their IDs and collection name. 
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 1b2355bc8..27c33029c 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -129,7 +129,6 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int def get_searcher( self, manual_close_internet: bool = False, - moscube: bool = False, ): if (self.internet_retriever is not None) and manual_close_internet: logger.warning( @@ -141,7 +140,6 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=None, - moscube=moscube, ) else: searcher = Searcher( @@ -150,7 +148,6 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=self.internet_retriever, - moscube=moscube, ) return searcher @@ -162,7 +159,6 @@ def search( mode: str = "fast", memory_type: str = "All", manual_close_internet: bool = True, - moscube: bool = False, search_filter: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: @@ -179,7 +175,6 @@ def search( memory_type (str): Type restriction for search. ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config. - moscube (bool): whether you use moscube to answer questions search_filter (dict, optional): Optional metadata filters for search results. - Keys correspond to memory metadata fields (e.g., "user_id", "session_id"). - Values are exact-match conditions. 
@@ -196,7 +191,6 @@ def search( self.reranker, bm25_retriever=self.bm25_retriever, internet_retriever=None, - moscube=moscube, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, ) @@ -208,7 +202,6 @@ def search( self.reranker, bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, - moscube=moscube, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py index 31b914776..042ed837e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py @@ -200,9 +200,11 @@ def _process_result( """Process one Bocha search result into TextualMemoryItem.""" title = result.get("name", "") content = result.get("summary", "") or result.get("snippet", "") - summary = result.get("snippet", "") + summary = result.get("summary", "") or result.get("snippet", "") url = result.get("url", "") publish_time = result.get("datePublished", "") + site_name = result.get("siteName", "") + site_icon = result.get("siteIcon") if publish_time: try: @@ -229,5 +231,12 @@ def _process_result( read_item_i.metadata.memory_type = "OuterMemory" read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else [] read_item_i.metadata.visibility = "public" + read_item_i.metadata.internet_info = { + "title": title, + "url": url, + "site_name": site_name, + "site_icon": site_icon, + "summary": summary, + } memory_items.append(read_item_i) return memory_items diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 933ef5af1..26ae1a723 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py 
@@ -41,7 +41,6 @@ def __init__( reranker: BaseReranker, bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, - moscube: bool = False, search_strategy: dict | None = None, manual_close_internet: bool = True, ): @@ -56,7 +55,6 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever - self.moscube = moscube self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self.manual_close_internet = manual_close_internet @@ -297,17 +295,6 @@ def _retrieve_paths( user_name, ) ) - if self.moscube: - tasks.append( - executor.submit( - self._retrieve_from_memcubes, - query, - parsed_goal, - query_embedding, - top_k, - "memos_cube01", - ) - ) results = [] for t in tasks: diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index f34cad1ef..2055615d2 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -232,7 +232,6 @@ def _fast_search( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info={ "user_id": search_req.user_id, @@ -287,7 +286,6 @@ def _fine_search( top_k=search_req.top_k, mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info=info, ) diff --git a/src/memos/types/__init__.py b/src/memos/types/__init__.py new file mode 100644 index 000000000..dd1b98305 --- /dev/null +++ b/src/memos/types/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F403, F401 + +from .types import * diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py new file mode 100644 index 000000000..4a08a9f24 --- /dev/null +++ 
b/src/memos/types/openai_chat_completion_types/__init__.py @@ -0,0 +1,15 @@ +# ruff: noqa: F403, F401 + +from .chat_completion_assistant_message_param import * +from .chat_completion_content_part_image_param import * +from .chat_completion_content_part_input_audio_param import * +from .chat_completion_content_part_param import * +from .chat_completion_content_part_refusal_param import * +from .chat_completion_content_part_text_param import * +from .chat_completion_message_custom_tool_call_param import * +from .chat_completion_message_function_tool_call_param import * +from .chat_completion_message_param import * +from .chat_completion_message_tool_call_union_param import * +from .chat_completion_system_message_param import * +from .chat_completion_tool_message_param import * +from .chat_completion_user_message_param import * diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py new file mode 100644 index 000000000..a742de3a9 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -0,0 +1,55 @@ +# ruff: noqa: TC001, TC003 + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Literal, TypeAlias + +from typing_extensions import Required, TypedDict + +from .chat_completion_content_part_refusal_param import ChatCompletionContentPartRefusalParam +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam +from .chat_completion_message_tool_call_union_param import ChatCompletionMessageToolCallUnionParam + + +__all__ = ["Audio", "ChatCompletionAssistantMessageParam", "ContentArrayOfContentPart"] + + +class Audio(TypedDict, total=False): + id: Required[str] + """Unique identifier for a previous audio response from the model.""" + + +ContentArrayOfContentPart: TypeAlias = ( + ChatCompletionContentPartTextParam | 
ChatCompletionContentPartRefusalParam +) + + +class ChatCompletionAssistantMessageParam(TypedDict, total=False): + role: Required[Literal["assistant"]] + """The role of the messages author, in this case `assistant`.""" + + audio: Audio | None + """ + Data about a previous audio response from the model. + [Learn more](https://platform.openai.com/docs/guides/audio). + """ + + content: str | Iterable[ContentArrayOfContentPart] | None + """The contents of the assistant message. + + Required unless `tool_calls` or `function_call` is specified. + """ + + refusal: str | None + """The refusal message by the assistant.""" + + tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam] + """The tool calls generated by the model, such as function calls.""" + + chat_time: str | None + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: str | None + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py new file mode 100644 index 000000000..6718bd91e --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionContentPartImageParam", "ImageURL"] + + +class ImageURL(TypedDict, total=False): + url: Required[str] + """Either a URL of the image or the base64 encoded image data.""" + + detail: Literal["auto", "low", "high"] + """Specifies the detail level of the image. + + Learn more in the + [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding). 
+ """ + + +class ChatCompletionContentPartImageParam(TypedDict, total=False): + image_url: Required[ImageURL] + + type: Required[Literal["image_url"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py new file mode 100644 index 000000000..e7cfa4504 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionContentPartInputAudioParam", "InputAudio"] + + +class InputAudio(TypedDict, total=False): + data: Required[str] + """Base64 encoded audio data.""" + + format: Required[Literal["wav", "mp3"]] + """The format of the encoded audio data. Currently supports "wav" and "mp3".""" + + +class ChatCompletionContentPartInputAudioParam(TypedDict, total=False): + input_audio: Required[InputAudio] + + type: Required[Literal["input_audio"]] + """The type of the content part. 
Always `input_audio`.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py new file mode 100644 index 000000000..a5e740791 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Literal, TypeAlias + +from typing_extensions import Required, TypedDict + +from .chat_completion_content_part_image_param import ChatCompletionContentPartImageParam +from .chat_completion_content_part_input_audio_param import ChatCompletionContentPartInputAudioParam +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam + + +__all__ = ["ChatCompletionContentPartParam", "File", "FileFile"] + + +class FileFile(TypedDict, total=False): + file_data: str + """ + The base64 encoded file data, used when passing the file to the model as a + string. + """ + + file_id: str + """The ID of an uploaded file to use as input.""" + + filename: str + """The name of the file, used when passing the file to the model as a string.""" + + +class File(TypedDict, total=False): + file: Required[FileFile] + + type: Required[Literal["file"]] + """The type of the content part. 
Always `file`.""" + + +ChatCompletionContentPartParam: TypeAlias = ( + ChatCompletionContentPartTextParam + | ChatCompletionContentPartImageParam + | ChatCompletionContentPartInputAudioParam + | File +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py new file mode 100644 index 000000000..fc87e9e1a --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionContentPartRefusalParam"] + + +class ChatCompletionContentPartRefusalParam(TypedDict, total=False): + refusal: Required[str] + """The refusal message generated by the model.""" + + type: Required[Literal["refusal"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py new file mode 100644 index 000000000..f43de0eff --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionContentPartTextParam"] + + +class ChatCompletionContentPartTextParam(TypedDict, total=False): + text: Required[str] + """The text content.""" + + type: Required[Literal["text"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py new file mode 100644 index 000000000..bc7a22edb --- /dev/null +++ 
b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionMessageCustomToolCallParam", "Custom"] + + +class Custom(TypedDict, total=False): + input: Required[str] + """The input for the custom tool call generated by the model.""" + + name: Required[str] + """The name of the custom tool to call.""" + + +class ChatCompletionMessageCustomToolCallParam(TypedDict, total=False): + id: Required[str] + """The ID of the tool call.""" + + custom: Required[Custom] + """The custom tool that the model called.""" + + type: Required[Literal["custom"]] + """The type of the tool. Always `custom`.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py new file mode 100644 index 000000000..56341d94a --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Literal + +from typing_extensions import Required, TypedDict + + +__all__ = ["ChatCompletionMessageFunctionToolCallParam", "Function"] + + +class Function(TypedDict, total=False): + arguments: Required[str] + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. 
+ """ + + name: Required[str] + """The name of the function to call.""" + + +class ChatCompletionMessageFunctionToolCallParam(TypedDict, total=False): + id: Required[str] + """The ID of the tool call.""" + + function: Required[Function] + """The function that the model called.""" + + type: Required[Literal["function"]] + """The type of the tool. Currently, only `function` is supported.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py new file mode 100644 index 000000000..06a624297 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TypeAlias + +from .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam +from .chat_completion_system_message_param import ChatCompletionSystemMessageParam +from .chat_completion_tool_message_param import ChatCompletionToolMessageParam +from .chat_completion_user_message_param import ChatCompletionUserMessageParam + + +__all__ = ["ChatCompletionMessageParam"] + +ChatCompletionMessageParam: TypeAlias = ( + ChatCompletionSystemMessageParam + | ChatCompletionUserMessageParam + | ChatCompletionAssistantMessageParam + | ChatCompletionToolMessageParam +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py new file mode 100644 index 000000000..28bb880cf --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TypeAlias + +from .chat_completion_message_custom_tool_call_param import ChatCompletionMessageCustomToolCallParam +from .chat_completion_message_function_tool_call_param import ( + 
ChatCompletionMessageFunctionToolCallParam, +) + + +__all__ = ["ChatCompletionMessageToolCallUnionParam"] + +ChatCompletionMessageToolCallUnionParam: TypeAlias = ( + ChatCompletionMessageFunctionToolCallParam | ChatCompletionMessageCustomToolCallParam +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py new file mode 100644 index 000000000..7faa90e2e --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -0,0 +1,35 @@ +# ruff: noqa: TC001, TC003 + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import Required, TypedDict + +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam + + +__all__ = ["ChatCompletionSystemMessageParam"] + + +class ChatCompletionSystemMessageParam(TypedDict, total=False): + content: Required[str | Iterable[ChatCompletionContentPartTextParam]] + """The contents of the system message.""" + + role: Required[Literal["system"]] + """The role of the messages author, in this case `system`.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the same + role. 
+ """ + + chat_time: str | None + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: str | None + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py new file mode 100644 index 000000000..c03220915 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -0,0 +1,31 @@ +# ruff: noqa: TC001, TC003 + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import Required, TypedDict + +from .chat_completion_content_part_param import ChatCompletionContentPartParam + + +__all__ = ["ChatCompletionToolMessageParam"] + + +class ChatCompletionToolMessageParam(TypedDict, total=False): + content: Required[str | Iterable[ChatCompletionContentPartParam]] + """The contents of the tool message.""" + + role: Required[Literal["tool"]] + """The role of the messages author, in this case `tool`.""" + + tool_call_id: Required[str] + """Tool call that this message is responding to.""" + + chat_time: str | None + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: str | None + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py new file mode 100644 index 000000000..2c2a1f23f --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -0,0 +1,35 @@ +# ruff: noqa: TC001, TC003 + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import 
Required, TypedDict + +from .chat_completion_content_part_param import ChatCompletionContentPartParam + + +__all__ = ["ChatCompletionUserMessageParam"] + + +class ChatCompletionUserMessageParam(TypedDict, total=False): + content: Required[str | Iterable[ChatCompletionContentPartParam]] + """The contents of the user message.""" + + role: Required[Literal["user"]] + """The role of the messages author, in this case `user`.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the same + role. + """ + + chat_time: str | None + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: str | None + """Optional unique identifier for the message""" diff --git a/src/memos/types.py b/src/memos/types/types.py similarity index 82% rename from src/memos/types.py rename to src/memos/types/types.py index 635fabccc..b8efc6208 100644 --- a/src/memos/types.py +++ b/src/memos/types/types.py @@ -14,6 +14,23 @@ from memos.memories.parametric.item import ParametricMemoryItem from memos.memories.textual.item import TextualMemoryItem +from .openai_chat_completion_types import ( + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + File, +) + + +__all__ = [ + "ChatHistory", + "MOSSearchResult", + "MessageDict", + "MessageList", + "MessageRole", + "Permission", + "PermissionDict", + "UserContext", +] # ─── Message Types ────────────────────────────────────────────────────────────── @@ -32,8 +49,16 @@ class MessageDict(TypedDict, total=False): message_id: str | None # Optional unique identifier for the message +RawMessageDict: TypeAlias = ChatCompletionContentPartTextParam | File + + # Message collections -MessageList: TypeAlias = list[MessageDict] +MessageList: TypeAlias = list[ChatCompletionMessageParam] +RawMessageList: TypeAlias = list[RawMessageDict] + + +# Messages Type +MessagesType: TypeAlias = str | MessageList | 
RawMessageList # Chat history structure diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py index a977a4004..6562c9a95 100644 --- a/tests/configs/test_llm.py +++ b/tests/configs/test_llm.py @@ -19,7 +19,14 @@ def test_base_llm_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["temperature", "max_tokens", "top_p", "top_k", "remove_think_prefix"], + optional_fields=[ + "temperature", + "max_tokens", + "top_p", + "top_k", + "remove_think_prefix", + "default_headers", + ], ) check_config_instantiation_valid( @@ -48,6 +55,7 @@ def test_openai_llm_config(): "api_base", "remove_think_prefix", "extra_body", + "default_headers", ], ) @@ -79,6 +87,8 @@ def test_ollama_llm_config(): "top_k", "remove_think_prefix", "api_base", + "default_headers", + "enable_thinking", ], ) @@ -111,6 +121,7 @@ def test_hf_llm_config(): "do_sample", "remove_think_prefix", "add_generation_prompt", + "default_headers", ], ) diff --git a/tests/llms/test_deepseek.py b/tests/llms/test_deepseek.py index 75c1ead5f..11be66887 100644 --- a/tests/llms/test_deepseek.py +++ b/tests/llms/test_deepseek.py @@ -12,12 +12,14 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): """Test DeepSeekLLM generate method with and without tag removal.""" # Simulated full content including tag - full_content = "Thinking in progress...Hello from DeepSeek!" + full_content = "Hello from DeepSeek!" + reasoning_content = "Thinking in progress..." 
# Mock response object mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"mock": "true"}' mock_response.choices[0].message.content = full_content + mock_response.choices[0].message.reasoning_content = reasoning_content # Config with think prefix preserved config_with_think = DeepSeekLLMConfig.model_validate( @@ -35,7 +37,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response) output_with_think = llm_with_think.generate([{"role": "user", "content": "Hello"}]) - self.assertEqual(output_with_think, full_content) + self.assertEqual(output_with_think, f"{reasoning_content}{full_content}") # Config with think tag removed config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True}) @@ -43,7 +45,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response) output_without_think = llm_without_think.generate([{"role": "user", "content": "Hello"}]) - self.assertEqual(output_without_think, "Hello from DeepSeek!") + self.assertEqual(output_without_think, full_content) def test_deepseek_llm_generate_stream(self): """Test DeepSeekLLM generate_stream with reasoning_content and content chunks.""" @@ -84,5 +86,5 @@ def make_chunk(delta_dict): self.assertIn("Analyzing...", full_output) self.assertIn("Hello, DeepSeek!", full_output) - self.assertTrue(full_output.startswith("Analyzing...")) + self.assertTrue(full_output.startswith("<think>")) self.assertTrue(full_output.endswith("DeepSeek!")) diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py index 47002a21f..9ed252f37 100644 --- a/tests/llms/test_ollama.py +++ b/tests/llms/test_ollama.py @@ -1,5 +1,6 @@ import unittest +from types import SimpleNamespace from unittest.mock import MagicMock from memos.configs.llm import LLMConfigFactory, OllamaLLMConfig @@ -12,15 +13,15 @@ def
test_llm_factory_with_mocked_ollama_backend(self): """Test LLMFactory with mocked Ollama backend.""" mock_chat = MagicMock() mock_response = MagicMock() - mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}' - mock_response.__getitem__.side_effect = lambda key: { - "message": { - "role": "assistant", - "content": "Hello! How are you? I'm here to help and smile!", - "images": None, - "tool_calls": None, - } - }[key] + mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!", "thinking":"Analyzing your request...","images":null,"tool_calls":null}}' + + mock_response.message = SimpleNamespace( + role="assistant", + content="Hello! How are you? I'm here to help and smile!", + thinking="Analyzing your request...", + images=None, + tool_calls=None, + ) mock_chat.return_value = mock_response config = LLMConfigFactory.model_validate( @@ -32,6 +33,7 @@ def test_llm_factory_with_mocked_ollama_backend(self): "max_tokens": 1024, "top_p": 0.9, "top_k": 50, + "enable_thinking": True, }, } ) @@ -42,21 +44,23 @@ def test_llm_factory_with_mocked_ollama_backend(self): ] response = llm.generate(messages) - self.assertEqual(response, "Hello! How are you? I'm here to help and smile!") + self.assertEqual( + response, + "Analyzing your request...Hello! How are you? 
I'm here to help and smile!", + ) def test_ollama_llm_with_mocked_backend(self): """Test OllamaLLM with mocked backend.""" mock_chat = MagicMock() mock_response = MagicMock() - mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}' - mock_response.__getitem__.side_effect = lambda key: { - "message": { - "role": "assistant", - "content": "Hello! How are you? I'm here to help and smile!", - "images": None, - "tool_calls": None, - } - }[key] + mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","thinking":"Analyzing your request...","images":null,"tool_calls":null}}' + mock_response.message = SimpleNamespace( + role="assistant", + content="Hello! How are you? I'm here to help and smile!", + thinking="Analyzing your request...", + images=None, + tool_calls=None, + ) mock_chat.return_value = mock_response config = OllamaLLMConfig( @@ -73,4 +77,7 @@ def test_ollama_llm_with_mocked_backend(self): ] response = ollama.generate(messages) - self.assertEqual(response, "Hello! How are you? I'm here to help and smile!") + self.assertEqual( + response, + "Analyzing your request...Hello! How are you? 
I'm here to help and smile!", + ) diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index dff57c058..ba5b52df4 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -14,6 +14,7 @@ def test_llm_factory_with_mocked_openai_backend(self): mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"id":"chatcmpl-BWoqIrvOeWdnFVZQUFzCcdVEpJ166","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! I\'m an AI language model created by OpenAI. I\'m here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?","role":"assistant"}}],"created":1747161634,"model":"gpt-4o-2024-08-06","object":"chat.completion"}' mock_response.choices[0].message.content = "Hello! I'm an AI language model created by OpenAI. I'm here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?" # fmt: skip + mock_response.choices[0].message.reasoning_content = None mock_chat_completions_create.return_value = mock_response config = LLMConfigFactory.model_validate( diff --git a/tests/llms/test_qwen.py b/tests/llms/test_qwen.py index 90f31e47f..71a4c75dd 100644 --- a/tests/llms/test_qwen.py +++ b/tests/llms/test_qwen.py @@ -12,12 +12,14 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): """Test QwenLLM non-streaming response generation with and without prefix removal.""" # Simulated full response content with tag - full_content = "Analyzing your request...Hello, world!" + full_content = "Hello from DeepSeek!" + reasoning_content = "Thinking in progress..." 
# Prepare the mock response object with expected structure mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"mocked": "true"}' mock_response.choices[0].message.content = full_content + mock_response.choices[0].message.reasoning_content = reasoning_content # Create config with remove_think_prefix = False config_with_think = QwenLLMConfig.model_validate( @@ -37,7 +39,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response) response_with_think = llm_with_think.generate([{"role": "user", "content": "Hi"}]) - self.assertEqual(response_with_think, full_content) + self.assertEqual(response_with_think, f"{reasoning_content}{full_content}") # Create config with remove_think_prefix = True config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True}) @@ -47,7 +49,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response) response_without_think = llm_without_think.generate([{"role": "user", "content": "Hi"}]) - self.assertEqual(response_without_think, "Hello, world!") + self.assertEqual(response_without_think, full_content) self.assertNotIn("<think>", response_without_think) def test_qwen_llm_generate_stream(self):