diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py new file mode 100644 index 000000000..6a9405456 --- /dev/null +++ b/examples/mem_agent/deepsearch_example.py @@ -0,0 +1,191 @@ +""" +DeepSearch Agent Usage Examples - Simplified Version + +This example demonstrates simplified initialization of DeepSearchMemAgent without +external config builders, using APIConfig methods directly. +""" + +import os + +from typing import Any + +from memos.api.config import APIConfig +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_agent import MemAgentConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent +from memos.mem_agent.factory import MemAgentFactory +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory + + +logger = get_logger(__name__) + + +def build_minimal_components(): + """ + Build minimal components for DeepSearchMemAgent with simplified configuration. + + This function creates all necessary components using APIConfig methods, + similar to config_builders.py but inline for easier customization. 
+ """ + logger.info("Initializing simplified MemOS components...") + + # Build component configurations using APIConfig methods (like config_builders.py) + + # Graph DB configuration - using APIConfig.get_nebular_config() + graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() + graph_db_backend_map = { + "polardb": APIConfig.get_polardb_config(), + } + graph_db_config = GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + # LLM configuration - using APIConfig.get_openai_config() + llm_config = LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + # Embedder configuration - using APIConfig.get_embedder_config() + embedder_config = EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + # Memory reader configuration - using APIConfig.get_product_default_config() + mem_reader_config = MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] + ) + + # Reranker configuration - using APIConfig.get_reranker_config() + reranker_config = RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + # Internet retriever configuration - using APIConfig.get_internet_config() + internet_retriever_config = InternetRetrieverConfigFactory.model_validate( + APIConfig.get_internet_config() + ) + + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder + ) + + logger.debug("Core components instantiated") + + # Get default cube configuration like 
component_init.py + default_cube_config = APIConfig.get_default_cube_config() + + # Get default memory size from cube config (like component_init.py) + def get_memory_size_from_config(cube_config): + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + memory_size = get_memory_size_from_config(default_cube_config) + is_reorganize = getattr(default_cube_config.text_mem.config, "reorganize", False) + + # Initialize memory manager with config from APIConfig + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=memory_size, + is_reorganize=is_reorganize, + ) + text_memory_config = default_cube_config.text_mem.config + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + memory_manager=memory_manager, + config=text_memory_config, + internet_retriever=internet_retriever, + ) + + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=None, # Simplified: no preference memory + act_mem=None, + para_mem=None, + ) + + return { + "llm": llm, + "naive_mem_cube": naive_mem_cube, + "embedder": embedder, + "graph_db": graph_db, + "mem_reader": mem_reader, + } + + +def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]: + # Build necessary components with simplified setup + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Create configuration Factory with simplified config + agent_config_factory = MemAgentConfigFactory( + backend="deep_search", + config={ + "agent_name": "SimplifiedDeepSearchAgent", + "description": "Simplified intelligent agent for deep search", + "max_iterations": 3, # Maximum number of iterations + "timeout": 60, # Timeout in seconds + }, + ) + + # Create Agent using Factory + # Pass text_mem as memory_retriever, it provides search method + deep_search_agent = 
MemAgentFactory.from_config( + config_factory=agent_config_factory, llm=llm, memory_retriever=naive_mem_cube.text_mem + ) + + logger.info("✓ DeepSearchMemAgent created successfully") + logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def main(): + agent_factory, components_factory = factory_initialization() + results = agent_factory.run( + "Caroline met up with friends, family, and mentors in early July 2023.", + user_id="locomo_exp_user_0_speaker_b_ct-1118", + ) + print(results) + + +if __name__ == "__main__": + main() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py new file mode 100644 index 000000000..7cb623899 --- /dev/null +++ b/src/memos/configs/mem_agent.py @@ -0,0 +1,54 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAgentConfig(BaseConfig): + """Base configuration class for agents.""" + + agent_name: str = Field(..., description="Name of the agent") + description: str | None = Field(default=None, description="Description of the agent") + + +class SimpleAgentConfig(BaseAgentConfig): + """Simple agent configuration class.""" + + max_iterations: int = Field( + default=10, description="Maximum number of iterations for the agent" + ) + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + + +class DeepSearchAgentConfig(BaseAgentConfig): + """Deep search agent configuration class.""" + + max_iterations: int = Field(default=3, description="Maximum number of iterations for the agent") + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + + +class MemAgentConfigFactory(BaseConfig): + """Factory class for creating agent configurations.""" + + backend: str 
= Field(..., description="Backend for agent") + config: dict[str, Any] = Field(..., description="Configuration for the agent backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "simple": SimpleAgentConfig, + "deep_search": DeepSearchAgentConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "MemAgentConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py new file mode 100644 index 000000000..daa5f075b --- /dev/null +++ b/src/memos/mem_agent/base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from memos.configs.mem_agent import BaseAgentConfig + + +class BaseMemAgent(ABC): + """ + Base class for all agents. + """ + + def __init__(self, config: BaseAgentConfig): + """Initialize the BaseMemAgent with the given configuration.""" + self.config = config + + @abstractmethod + def run(self, input: str) -> str: + """ + Run the agent. + """ diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py new file mode 100644 index 000000000..5a070c6ad --- /dev/null +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -0,0 +1,375 @@ +""" +Deep Search Agent implementation for MemOS. + +This module implements a sophisticated deep search agent that performs iterative +query refinement and memory retrieval to provide comprehensive answers. 
+""" + +import json +import re + +from typing import TYPE_CHECKING, Any + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_agent.base import BaseMemAgent +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory +from memos.templates.mem_agent_prompts import ( + FINAL_GENERATION_PROMPT, + QUERY_REWRITE_PROMPT, + REFLECTION_PROMPT, +) + + +if TYPE_CHECKING: + from memos.types import MessageList + + +class JSONResponseParser: + """Elegant JSON response parser for LLM outputs""" + + @staticmethod + def parse(response: str) -> dict[str, Any]: + """Parse JSON response from LLM output with fallback strategies""" + # Clean response text by removing code block markers + cleaned = re.sub(r"^```(?:json)?\s*\n?|```\s*$", "", response.strip(), flags=re.IGNORECASE) + + # Try parsing with multiple strategies + for text in [cleaned, re.search(r"\{.*\}", cleaned, re.DOTALL)]: + if not text: + continue + try: + return json.loads(text if isinstance(text, str) else text.group()) + except json.JSONDecodeError: + continue + + raise ValueError(f"Cannot parse JSON response: {response[:100]}...") + + +logger = get_logger(__name__) + + +class QueryRewriter(BaseMemAgent): + """Specialized agent for rewriting queries based on conversation history""" + + def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): + self.llm = llm + self.name = name + + def run(self, query: str, history: list[str] | None = None) -> str: + """Rewrite query to be standalone and more searchable""" + history = history or [] + history_context = self._format_history(history) + + prompt = QUERY_REWRITE_PROMPT.format(history=history_context, query=query) + messages = [{"role": "user", "content": prompt}] + try: + response = self.llm.generate(messages) + logger.info(f"[{self.name}] Rewritten query: {response.strip()}") + return response.strip() + except Exception 
as e: + logger.error(f"[{self.name}] Query rewrite failed: {e}") + return query + + def _format_history(self, history: list[str]) -> str: + """Format conversation history for prompt context""" + if not history: + return "No previous conversation" + return "\n".join(f"- {msg}" for msg in history[-5:]) + + +class ReflectionAgent: + """Specialized agent for analyzing information sufficiency""" + + def __init__(self, llm: BaseLLM, name: str = "Reflector"): + self.llm = llm + self.name = name + + def run(self, query: str, context: list[str]) -> dict[str, Any]: + """Analyze whether retrieved context is sufficient to answer the query""" + context_summary = self._format_context(context) + prompt = REFLECTION_PROMPT.format(query=query, context=context_summary) + + try: + response = self.llm.generate([{"role": "user", "content": prompt}]) + logger.info(f"[{self.name}] Reflection response: {response}") + + result = JSONResponseParser.parse(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result}") + return result + + except Exception as e: + logger.error(f"[{self.name}] Reflection analysis failed: {e}") + return self._fallback_response() + + def _format_context(self, context: list[str]) -> str: + """Format context strings for analysis with length limits""" + return "\n".join( + f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10] + ) + + def _fallback_response(self) -> dict[str, Any]: + """Return safe fallback when reflection fails""" + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [], + } + + +class DeepSearchMemAgent(BaseMemAgent): + """ + Main orchestrator agent implementing the deep search pipeline. + + This agent coordinates multiple sub-agents to perform iterative query refinement, + memory retrieval, and information synthesis as shown in the architecture diagram. 
+ """ + + def __init__( + self, + llm: BaseLLM, + memory_retriever: TreeTextMemory | None = None, + config: DeepSearchAgentConfig | None = None, + ): + """ + Initialize DeepSearchMemAgent. + + Args: + llm: Language model for query rewriting and response generation + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + config: Configuration for deep search behavior + """ + self.config = config or DeepSearchAgentConfig() + self.max_iterations = self.config.max_iterations + self.timeout = self.config.timeout + self.llm: BaseLLM = llm + self.query_rewriter: QueryRewriter = QueryRewriter(llm, "QueryRewriter") + self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") + self.memory_retriever = memory_retriever + + def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: + """ + Main execution method implementing the deep search pipeline. + + Args: + query: User query string + **kwargs: Additional arguments (history, user_id, etc.) + Returns: + Comprehensive response string + """ + if not self.llm: + raise RuntimeError("LLM not initialized.") + + history = kwargs.get("history", []) + user_id = kwargs.get("user_id") + generated_answer = kwargs.get("generated_answer") + + # Step 1: Query Rewriting + current_query = self.query_rewriter.run(query, history) + + accumulated_context = [] + accumulated_memories = [] + search_keywords = [] # Can be extended with keyword extraction + + # Step 2: Iterative Search and Reflection Loop + for iteration in range(self.max_iterations): + logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") + + search_results = self._perform_memory_search( + current_query, keywords=search_keywords, user_id=user_id, history=history + ) + + if search_results: + context_batch = [self._extract_context_from_memory(mem) for mem in search_results] + accumulated_context.extend(context_batch) + accumulated_memories.extend(search_results) + + reflection_result = self.reflector.run(current_query, 
context_batch) + status = reflection_result.get("status", "sufficient") + reasoning = reflection_result.get("reasoning", "") + + logger.info(f"Reflection status: {status} - {reasoning}") + + if status == "sufficient": + logger.info("Sufficient information collected") + break + elif status == "needs_raw": + logger.info("Need original sources, retrieving raw content") + break + elif status == "missing_info": + missing_entities = reflection_result.get("missing_entities", []) + logger.info(f"Missing information: {missing_entities}") + current_query = reflection_result.get("new_search_query") + if not current_query: + refined_query = self._refine_query_for_missing_info( + current_query, missing_entities + ) + current_query = refined_query + logger.info(f"Refined query: {current_query}") + else: + logger.warning(f"No search results for iteration {iteration + 1}") + if iteration == 0: + current_query = query + else: + break + + if not generated_answer: + return self._remove_duplicate_memories(accumulated_memories) + else: + return self._generate_final_answer( + query, accumulated_memories, accumulated_context, "", history + ) + + def _remove_duplicate_memories( + self, memories: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Remove duplicate memories based on memory content. + + Args: + memories: List of TextualMemoryItem objects to deduplicate + + Returns: + List of unique TextualMemoryItem objects (first occurrence kept) + """ + seen = set() + return [ + memory + for memory in memories + if (content := getattr(memory, "memory", "").strip()) + and content not in seen + and not seen.add(content) + ] + + def _generate_final_answer( + self, + original_query: str, + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", + history: list[str] | None = None, + sources: list[str] | None = None, + ) -> str: + """ + Generate the final answer. 
+ """ + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified", + ) + messages: MessageList = [{"role": "user", "content": prompt}] + response = self.llm.generate(messages) + return response.strip() + + def _perform_memory_search( + self, + query: str, + keywords: list[str] | None = None, + user_id: str | None = None, + history: list[str] | None = None, + top_k: int = 10, + ) -> list[TextualMemoryItem]: + """ + Perform memory search using the configured retriever. + + Args: + query: Search query + keywords: Additional keywords for search + user_id: User identifier + top_k: Number of results to retrieve + + Returns: + List of retrieved memory items + """ + if not self.memory_retriever: + logger.warning("Memory retriever not configured, returning empty results") + return [] + + try: + # Use the memory retriever interface + # This is a placeholder - actual implementation depends on the retriever interface + search_query = query + if keywords and len(keywords) > 1: + search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords + + # Assuming the retriever has a search method similar to TreeTextMemory + results = self.memory_retriever.search( + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id, + info={"history": history}, + ) + + return results if isinstance(results, list) else [] + + except Exception as e: + logger.error(f"Error performing memory search: {e}") + return [] + + def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: + """Extract readable context from a memory item.""" + if hasattr(memory_item, "memory"): + return str(memory_item.memory) + elif hasattr(memory_item, "content"): + return str(memory_item.content) + else: + return str(memory_item) + + def 
_refine_query_for_missing_info(self, query: str, missing_entities: list[str]) -> str: + """Refine the query to search for missing information.""" + if not missing_entities: + return query + + # Simple refinement strategy - append missing entities + entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities + refined_query = f"{query} {entities_str}" + + return refined_query + + def _generate_final_answer( + self, + original_query: str, + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", + ) -> str: + """ + Generate the final comprehensive answer. + + Args: + original_query: Original user query + search_results: All retrieved memory items + context: Extracted context strings + missing_info: Information about missing data + + Returns: + Final answer string + """ + # Prepare context for the prompt + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context + sources = ( + f"Retrieved {len(search_results)} memory items" + if search_results + else "No specific sources" + ) + + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified", + ) + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"Error generating final answer: {e}") + return f"I apologize, but I encountered an error while processing your query: {original_query}. Please try again." 
# === file: src/memos/mem_agent/factory.py (new) ===
from typing import Any, ClassVar

from memos.configs.mem_agent import MemAgentConfigFactory
from memos.mem_agent.base import BaseMemAgent
from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent


class MemAgentFactory:
    """Factory class for creating MemAgent instances."""

    # Maps backend identifiers to their concrete agent classes.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "deep_search": DeepSearchMemAgent,
    }

    @classmethod
    def from_config(
        cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Any | None = None
    ) -> BaseMemAgent:
        """
        Create a MemAgent instance from configuration.

        Args:
            config_factory: Configuration factory for the agent
            llm: Language model instance
            memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem)

        Returns:
            Initialized MemAgent instance

        Raises:
            ValueError: if the configured backend has no registered agent class.
        """
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        mem_agent_class = cls.backend_to_class[backend]
        return mem_agent_class(
            llm=llm, memory_retriever=memory_retriever, config=config_factory.config
        )


# === file: src/memos/templates/mem_agent_prompts.py (new) ===
QUERY_REWRITE_PROMPT = """
You are a query rewriting specialist. Your task is to rewrite user queries to be more standalone and searchable.

Given the conversation history and current user query, rewrite the query to:
1. Be self-contained and independent of conversation context
2. Include relevant context from history when necessary
3. Maintain the original intent and scope
4. Use clear, specific terminology

Conversation History:
{history}

Current Query: {query}

Rewritten Query:"""

REFLECTION_PROMPT = """
You are an information sufficiency analyst. Evaluate whether the retrieved context is sufficient to answer the user's query.

Query: {query}
Retrieved Context:
{context}

Analyze the context and determine the next step. Return your response in JSON format with the following structure:
{{
    "status": "sufficient|missing_info|needs_raw",
    "reasoning": "Brief explanation of your decision",
    "missing_entities": ["entity1", "entity2"],
    "new_search_query": "new search query",
}}

Status definitions:
- "sufficient": Context fully answers the query
- "missing_info": Key information is missing (e.g., specific dates, locations, details)
- "needs_raw": Content is relevant but too summarized/vague, need original sources

Field "new_search_query": when status is "missing_info", provide a better follow-up search query to retrieve the missing information; otherwise use an empty string.

Response:"""

KEYWORD_EXTRACTION_PROMPT = """
Analyze the user query and extract key search terms and identify optimal data sources.

Query: {query}

Extract:
1. Key search terms and concepts
2. Important entities (people, places, dates, etc.)
3. Suggested data sources or memory types to search

Return response in JSON format:
{{
    "keywords": ["keyword1", "keyword2"],
    "entities": ["entity1", "entity2"],
    "search_strategy": "Brief strategy description"
}}

Response:"""


FINAL_GENERATION_PROMPT = """
You are a comprehensive information synthesizer. Generate a complete answer based on the retrieved information.

User Query: {query}
Search Sources: {sources}
Retrieved Information:
{context}

Missing Information (if any): {missing_info}

Instructions:
1. Synthesize all relevant information to answer the query comprehensively
2. If information is missing, acknowledge gaps and suggest next steps
3. Maintain accuracy and cite sources when possible
4. Provide a well-structured, coherent response
5. Use natural, conversational tone

Response:"""