diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index b734e913d..a01f14566 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -756,8 +756,24 @@ async def load_memory(self): shared_memory = await SharedMemoryManager.get_shared_memory( self.config, mem_instance_type) + + ignore_roles = getattr(_memory, 'ignore_roles', []) + shared_memory.should_early_add_after_task = ( + 'assistant' in ignore_roles and 'tool' in ignore_roles) + shared_memory.early_add_after_task_done = False + self.memory_tools.append(shared_memory) + def _schedule_add_memory_after_task(self, messages, timestamp=None): + + def _add_memory(): + asyncio.run( + self.add_memory( + messages, add_type='add_after_task', timestamp=timestamp)) + + loop = asyncio.get_running_loop() + loop.run_in_executor(None, _add_memory) + async def prepare_rag(self): """Load and initialize the RAG component from the config.""" if hasattr(self.config, 'rag'): @@ -886,7 +902,6 @@ async def step( messages = deepcopy(messages) messages = self._append_task_notifications(messages) if (not self.load_cache) or messages[-1].role != 'assistant': - messages = await self.condense_memory(messages) await self.on_generate_response(messages) tools = await self.tool_manager.get_tools() @@ -1093,27 +1108,40 @@ def _get_run_memory_info(self, memory_config: DictConfig): async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: - tools_num = len(self.memory_tools) if self.memory_tools else 0 - - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for tool, (_, memory_config) in zip(self.memory_tools, + self.config.memory.items()): + timestamp = kwargs.get('timestamp', '') if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( memory_config) + should_early = getattr(tool, 'should_early_add_after_task', + False) + early_done = getattr(tool, 'early_add_after_task_done', + False) + + if timestamp == 'early': + if not (should_early and not early_done): + # pass memory tool.run + continue + tool.early_add_after_task_done = True + else: + if early_done: + # pass memory tool.run + continue + else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( memory_config) - if idx < tools_num: - if any(v is not None + if not any(v is not None for v in [user_id, agent_id, run_id, memory_type]): - await self.memory_tools[idx].add( - messages, - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - memory_type=memory_type, - ) + continue + await tool.add( + messages, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + memory_type=memory_type) def save_history(self, messages: List[Message], **kwargs): """ @@ -1220,6 +1248,10 @@ async def run_loop(self, messages: Union[List[Message], str], role='user', content='\n'.join(notifications))) if self._skill_runtime: self._skill_runtime.maybe_refresh_system_prompt(messages) + messages = await self.condense_memory(messages) + # If assistant and tool content can be ignored, add memory earlier to reduce running time. + self._schedule_add_memory_after_task( + messages, timestamp='early') async for messages in self.step(messages): messages = self._apply_pending_rollback(messages) yield messages @@ -1247,13 +1279,8 @@ async def run_loop(self, messages: Union[List[Message], str], await self.cleanup_tools() yield messages - def _add_memory(): - asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + self._schedule_add_memory_after_task(messages) - loop = asyncio.get_running_loop() - loop.run_in_executor(None, _add_memory) except Exception as e: import traceback diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index d84826de4..3f76a1bc8 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -588,8 +588,29 @@ def _init_memory_obj(self): f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.' ) raise - + import mem0.vector_stores.milvus capture_event_origin = mem0.memory.main.capture_event + update_origin = mem0.vector_stores.milvus.MilvusDB.update + + @wraps(update_origin) + def update(self, vector_id=None, vector=None, payload=None): + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + if vector is None: + res = self.client.get( + collection_name=self.collection_name, ids=[vector_id]) + if res: + vector = res[0]['vectors'] + + schema = {'id': vector_id, 'vectors': vector, 'metadata': payload} + self.client.upsert( + collection_name=self.collection_name, data=schema) @wraps(capture_event_origin) def patched_capture_event(event_name, @@ -597,6 +618,7 @@ def patched_capture_event(event_name, additional_data=None): pass + mem0.vector_stores.milvus.MilvusDB.update = update mem0.memory.main.capture_event = partial(patched_capture_event, ) # emb config diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 5d0a10210..4d31e2f92 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from omegaconf import DictConfig from typing import Dict +from omegaconf import DictConfig, OmegaConf + from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER @@ -17,10 +18,15 @@ class SharedMemoryManager: async def get_shared_memory(cls, config: DictConfig, mem_instance_type: str) -> Memory: """Get or create a shared memory instance based on configuration.""" - user_id: str = getattr(config, 'user_id', DEFAULT_USER) - path: str = getattr(config, 'path', DEFAULT_OUTPUT_DIR) - - key = f'{mem_instance_type}_{user_id}_{path}' + user_id: str = getattr( + getattr(config.memory, mem_instance_type, OmegaConf.create({})), + 'user_id', DEFAULT_USER) + path: str = getattr( + getattr(config.memory, mem_instance_type, OmegaConf.create({})), + 'path', DEFAULT_OUTPUT_DIR) + llm_str: str = getattr(config.llm, 'model', 'default_model') + + key = f'{mem_instance_type}_{user_id}_{llm_str}_{path}' if key not in cls._instances: logger.info(f'Creating new shared memory instance for key: {key}')