-
Notifications
You must be signed in to change notification settings - Fork 508
Feat/memory update #912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Feat/memory update #912
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -120,6 +120,7 @@ venv.bak/ | |
|
|
||
| .vscode | ||
| .idea | ||
| .cursor | ||
|
|
||
| # custom | ||
| *.pkl | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,8 @@ | |
| from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping | ||
| from ms_agent.memory.memory_manager import SharedMemoryManager | ||
| from ms_agent.rag.base import RAG | ||
| from ms_agent.session import ContextAssembler, SessionLog | ||
| from ms_agent.session.strategies import SummaryCompactor, ToolOutputPruner | ||
| from ms_agent.rag.utils import rag_mapping | ||
| from ms_agent.tools import ToolManager | ||
| from ms_agent.utils import async_retry, read_history, save_history | ||
|
|
@@ -107,9 +109,11 @@ def __init__( | |
| self.tool_manager: Optional[ToolManager] = None | ||
| self.memory_tools: List[Memory] = [] | ||
| self.rag: Optional[RAG] = None | ||
| self.knowledge_search: Optional[SirschmunkSearch] = None | ||
| self.knowledge_search: Optional[SirchmunkSearch] = None | ||
| self.llm: Optional[LLM] = None | ||
| self.runtime: Optional[Runtime] = None | ||
| self.session_log: Optional[SessionLog] = None | ||
| self.context_assembler: Optional[ContextAssembler] = None | ||
| self.max_chat_round: int = 0 | ||
| self.load_cache = kwargs.get('load_cache', False) | ||
| self.config.load_cache = self.load_cache | ||
|
|
@@ -733,6 +737,11 @@ async def do_skill(self, | |
| async def load_memory(self): | ||
| """Initialize and append memory tool instances based on the configuration provided in the global config. | ||
|
|
||
| For ``unified_memory``, this also: | ||
| - Passes the agent's LLM instance to the orchestrator | ||
| - Registers the ``memory`` / ``memory_read`` tools into ToolManager | ||
| - Injects memory-usage guidance into the system prompt | ||
|
|
||
| Raises: | ||
| AssertionError: If a specified memory type in the config does not exist in memory_mapping. | ||
| """ | ||
|
|
@@ -747,6 +756,40 @@ async def load_memory(self): | |
| self.config, mem_instance_type) | ||
| self.memory_tools.append(shared_memory) | ||
|
|
||
| # Wire unified_memory into the tool system | ||
| if mem_instance_type == 'unified_memory': | ||
| await self._register_memory_tool(shared_memory) | ||
|
|
||
async def _register_memory_tool(self, orchestrator):
    """Expose a unified-memory orchestrator to the agent as a tool.

    Nothing happens when ``orchestrator`` has no ``get_tool_schemas``
    attribute. Otherwise the agent's LLM (and session log, when one is
    already set) are handed to the orchestrator, a ``MemoryTool`` wrapping
    it is registered with the tool manager, and ``MEMORY_USAGE_PROMPT`` is
    appended to the system prompt unless the prompt already contains the
    text ``Long-term Memory``.

    Args:
        orchestrator: The memory instance to wire into the tool system.
    """
    from ms_agent.memory.unified.memory_tool import MemoryTool, MEMORY_USAGE_PROMPT

    if not hasattr(orchestrator, 'get_tool_schemas'):
        return

    # Give the orchestrator what it needs for consolidation / extraction.
    llm = self.llm
    if llm is not None:
        orchestrator.set_llm(llm)
        orchestrator.init_update_queue()

    # NOTE(review): run_loop awaits load_memory() before it calls
    # _init_session_log(), so self.session_log looks like it is still None
    # on that path and this assignment would be skipped - confirm.
    session_log = self.session_log
    if session_log is not None and hasattr(orchestrator, '_session_log'):
        orchestrator._session_log = session_log

    # Make the memory tool callable through the agent's tool system.
    manager = self.tool_manager
    if manager is not None:
        manager.register_tool(MemoryTool(self.config, orchestrator))
        await manager.reindex_tool()
        logger.info('[unified_memory] Memory tool registered')

    # Append usage guidance to the system prompt (at most once).
    has_system_prompt = (
        hasattr(self.config, 'prompt')
        and hasattr(self.config.prompt, 'system'))
    if not has_system_prompt:
        return
    system_prompt = self.config.prompt.system or ''
    if 'Long-term Memory' in system_prompt:
        return
    OmegaConf.update(
        self.config,
        'prompt.system',
        system_prompt + '\n\n' + MEMORY_USAGE_PROMPT,
        merge=True)
|
|
||
| async def prepare_rag(self): | ||
| """Load and initialize the RAG component from the config.""" | ||
| if hasattr(self.config, 'rag'): | ||
|
|
@@ -770,19 +813,135 @@ async def prepare_knowledge_search(self): | |
| self.config) | ||
|
|
||
async def condense_memory(self, messages: List[Message]) -> List[Message]:
    """Inject long-term memory context into messages.

    .. deprecated::
        Historically this also ran context compressors. Compression is
        now handled by :class:`ContextAssembler` before this method is
        called. This method only performs memory *injection* (adding
        ``<long-term-memory>`` blocks, etc.) and is kept as a thin alias
        of :meth:`inject_memory` so existing callers keep working.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: Whatever :meth:`inject_memory` returns for
            ``messages``.
    """
    # Single implementation lives in inject_memory; the two methods used to
    # carry identical copies of the same loop.
    return await self.inject_memory(messages)


async def inject_memory(self, messages: List[Message]) -> List[Message]:
    """Pass the message list through every configured memory tool.

    Each entry of ``self.memory_tools`` is awaited in order and receives
    the output of the previous one. This method itself never trims or
    compresses messages; what a tool does with the list (e.g. adding a
    MEMORY.md snapshot or facts) is up to that tool.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: The message list returned by the last memory tool,
            or ``messages`` unchanged when no memory tool is configured.
    """
    for memory_tool in self.memory_tools:
        messages = await memory_tool.run(messages)
    return messages
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inject_memory 与 condense_memory 实现相同 |
||
|
|
||
def _init_session_log(self) -> None:
    """Set up ``self.session_log`` and ``self.context_assembler``.

    Reads the optional ``session_log`` and ``compaction`` sections of
    ``self.config``. Each feature counts as enabled when its section or
    its ``enabled`` flag is absent. With session logging disabled nothing
    is created. With compaction disabled the assembler is built with no
    strategies and an empty config.
    """
    cfg = self.config
    session_cfg = getattr(cfg, 'session_log', None)
    if session_cfg and not getattr(session_cfg, 'enabled', True):
        return

    def _session_opt(key):
        # Optional value from the ``session_log`` section, None if unset.
        return getattr(session_cfg, key, None) if session_cfg else None

    # NOTE(review): the ``session_log`` YAML options read here
    # (``enabled``, ``dir``, ``session_key``) still lack a documented
    # configuration example.
    session_dir = _session_opt('dir')
    if session_dir is None:
        # Default location: <output_dir>/sessions
        session_dir = os.path.join(
            getattr(cfg, 'output_dir', 'output'), 'sessions')
    self.session_log = SessionLog(
        session_dir, session_key=_session_opt('session_key'))

    # NOTE(review): the ``compaction`` YAML section likewise lacks a
    # documented configuration example.
    compaction_cfg = getattr(cfg, 'compaction', None)
    if compaction_cfg and not getattr(compaction_cfg, 'enabled', True):
        # Compaction switched off: assembler without any strategy.
        self.context_assembler = ContextAssembler(
            session_log=self.session_log, strategies=[], config={},
        )
        return

    self.context_assembler = ContextAssembler(
        session_log=self.session_log,
        strategies=self._build_compaction_strategies(compaction_cfg),
        config=self._build_assembler_config(compaction_cfg, session_cfg),
        memory_flush_callback=self._make_memory_flush_callback(),
    )
|
|
||
def _build_compaction_strategies(self, compaction_cfg):
    """Build the strategy list from YAML ``compaction.strategies``.

    Args:
        compaction_cfg: The ``compaction`` config section, or a falsy
            value when the section is absent.

    Returns:
        list: One strategy instance per enabled, recognised entry of
            ``compaction_cfg.strategies`` (possibly empty). When
            ``strategies`` is missing or null, the default pair
            ``[ToolOutputPruner(), SummaryCompactor(llm=self.llm)]``.
    """
    configured = (
        getattr(compaction_cfg, 'strategies', None)
        if compaction_cfg else None)
    # A key that is present but null (``strategies:`` with no value) is
    # treated like a missing key instead of failing on iteration.
    if configured is None:
        return [ToolOutputPruner(), SummaryCompactor(llm=self.llm)]

    strategies = []
    for s_cfg in configured:
        name = getattr(s_cfg, 'name', '')
        if not getattr(s_cfg, 'enabled', True):
            continue
        if name == 'tool_output_pruner':
            strategies.append(ToolOutputPruner())
        elif name == 'summary_compactor':
            strategies.append(SummaryCompactor(llm=self.llm))
        else:
            logger.warning(f"Unknown compaction strategy: {name}")
    return strategies
|
|
||
def _build_assembler_config(self, compaction_cfg, session_cfg):
    """Merge compaction params from ``compaction`` and ``session_log``.

    Precedence, lowest to highest: built-in defaults, the ``session_log``
    section, the ``compaction`` section, and - for ``prune_protect``
    only - a ``tool_output_pruner`` entry in ``compaction.strategies``.

    Args:
        compaction_cfg: The ``compaction`` config section or a falsy value.
        session_cfg: The ``session_log`` config section or a falsy value.

    Returns:
        Dict[str, Any]: A dict that always contains ``context_limit``,
            ``reserved_buffer`` and ``prune_protect``.
    """
    defaults = {
        'context_limit': 128000,
        'reserved_buffer': 20000,
        'prune_protect': 40000,
    }
    config: Dict[str, Any] = {}

    if session_cfg:
        for key in ('context_limit', 'reserved_buffer', 'prune_protect'):
            val = getattr(session_cfg, key, None)
            if val is not None:
                config[key] = val

    if compaction_cfg:
        for key in ('context_limit', 'reserved_buffer'):
            val = getattr(compaction_cfg, key, None)
            if val is not None:
                config[key] = val
        # ``strategies`` may be present but null; treat that as empty
        # instead of failing on iteration.
        for s_cfg in getattr(compaction_cfg, 'strategies', None) or ():
            if getattr(s_cfg, 'name', '') == 'tool_output_pruner':
                pp = getattr(s_cfg, 'prune_protect', None)
                if pp is not None:
                    config['prune_protect'] = pp

    # Fill in whatever neither section supplied.
    for key, fallback in defaults.items():
        config.setdefault(key, fallback)
    return config
|
|
||
def _make_memory_flush_callback(self):
    """Create a callback that flushes memory before context compaction.

    Returns:
        Callable: ``_flush(discarded_messages)``. For every entry of
            ``self.memory_tools`` that has a ``flush`` attribute, the
            discarded message dicts are converted to ``Message`` objects
            and passed to ``flush``. Inside a running event loop the flush
            is scheduled as a task; otherwise it is run to completion with
            ``asyncio.run``.
    """
    import asyncio

    # The event loop only keeps weak references to tasks, so scheduled
    # flushes are held here until they finish to avoid them being
    # garbage-collected mid-execution.
    pending_tasks = set()

    def _flush(discarded_messages):
        flushable = [
            tool for tool in self.memory_tools if hasattr(tool, 'flush')
        ]
        if not flushable:
            return

        from ms_agent.llm.utils import Message as _Msg

        for orchestrator in flushable:
            # Fresh Message objects per tool so tools never share state.
            msgs = [
                _Msg(
                    role=m.get('role', 'user'),
                    content=m.get('content', ''),
                    tool_calls=m.get('tool_calls'),
                ) for m in discarded_messages
            ]
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: flush synchronously.
                asyncio.run(orchestrator.flush(msgs))
            else:
                task = loop.create_task(orchestrator.flush(msgs))
                pending_tasks.add(task)
                task.add_done_callback(pending_tasks.discard)

    return _flush
|
|
||
| def log_output(self, content: Union[str, list]): | ||
| """ | ||
| Log formatted output with a tag prefix. | ||
|
|
@@ -1089,6 +1248,31 @@ def save_history(self, messages: List[Message], **kwargs): | |
| save_history( | ||
| self.output_dir, task=self.tag, config=config, messages=messages) | ||
|
|
||
| @staticmethod | ||
| def _msg_to_dict(msg: Message) -> Dict[str, Any]: | ||
| """Convert a Message to a plain dict for SessionLog. | ||
|
|
||
| Preserves ``prompt_tokens`` and ``completion_tokens`` individually | ||
| so that :class:`ContextAssembler` strategies can leverage API-reported | ||
| usage data for accurate overflow detection. | ||
| """ | ||
| d: Dict[str, Any] = {'role': msg.role, 'content': msg.content or ''} | ||
| if msg.tool_calls: | ||
| d['tool_calls'] = msg.tool_calls | ||
| if hasattr(msg, 'tool_call_id') and msg.tool_call_id: | ||
| d['tool_call_id'] = msg.tool_call_id | ||
| if hasattr(msg, 'name') and msg.name: | ||
| d['name'] = msg.name | ||
| prompt_tokens = int(getattr(msg, 'prompt_tokens', 0) or 0) | ||
| completion_tokens = int(getattr(msg, 'completion_tokens', 0) or 0) | ||
| if prompt_tokens: | ||
| d['prompt_tokens'] = prompt_tokens | ||
| if completion_tokens: | ||
| d['completion_tokens'] = completion_tokens | ||
| if prompt_tokens or completion_tokens: | ||
| d['tokens'] = prompt_tokens + completion_tokens | ||
| return d | ||
|
|
||
| async def run_loop(self, messages: Union[List[Message], str], | ||
| **kwargs) -> AsyncGenerator[Any, Any]: | ||
| """ | ||
|
|
@@ -1112,13 +1296,28 @@ async def run_loop(self, messages: Union[List[Message], str], | |
| await self.load_memory() | ||
| await self.prepare_rag() | ||
| await self.prepare_knowledge_search() | ||
| self._init_session_log() | ||
| self.runtime.tag = self.tag | ||
|
|
||
| if messages is None: | ||
| messages = self.query | ||
|
|
||
| # Load history and restore state | ||
| self.config, self.runtime, messages = self.read_history(messages) | ||
| if self.session_log is not None: | ||
| restored = self.session_log.get_all_messages() | ||
| if restored and self.load_cache: | ||
| from ms_agent.llm.utils import Message as _Msg | ||
| messages = [_Msg( | ||
| role=m.get('role', 'user'), | ||
| content=m.get('content', ''), | ||
| tool_calls=m.get('tool_calls'), | ||
| ) for m in restored] | ||
|
Comment on lines
+1310
to
+1314
Contributor
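There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When restoring messages from the session log, the `tool_call_id` and `name` fields are dropped, so restored tool messages lose them; pass them through as well: messages = [_Msg(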
role=m.get('role', 'user'),
content=m.get('content', ''),
tool_calls=m.get('tool_calls'),
tool_call_id=m.get('tool_call_id'),
name=m.get('name'),
) for m in restored] |
||
| else: | ||
| self.config, self.runtime, messages = self.read_history( | ||
| messages) | ||
| else: | ||
| self.config, self.runtime, messages = self.read_history( | ||
| messages) | ||
|
|
||
| if self.runtime.round == 0: | ||
| # New task: create standardized messages first | ||
|
|
@@ -1137,14 +1336,30 @@ async def run_loop(self, messages: Union[List[Message], str], | |
| await self.do_rag(messages) | ||
| await self.on_task_begin(messages) | ||
|
|
||
| # Seed SessionLog with initial messages | ||
| if self.session_log is not None: | ||
| for msg in messages: | ||
| self.session_log.append(self._msg_to_dict(msg)) | ||
|
|
||
| for message in messages: | ||
| if message.role != 'system': | ||
| self.log_output('[' + message.role + ']:') | ||
| self.log_output(message.content) | ||
| while not self.runtime.should_stop: | ||
| # Rebuild context view from SessionLog (non-destructive compression) | ||
| if self.context_assembler is not None and self.runtime.round > 0: | ||
| messages = self.context_assembler.assemble() | ||
|
|
||
| pre_step_len = len(messages) | ||
| async for messages in self.step(messages): | ||
| yield messages | ||
| self.runtime.round += 1 | ||
|
|
||
| # Append new messages to SessionLog | ||
| if self.session_log is not None: | ||
| for msg in messages[pre_step_len:]: | ||
| self.session_log.append(self._msg_to_dict(msg)) | ||
|
|
||
| # save memory and history | ||
| await self.add_memory( | ||
| messages, add_type='add_after_step', **kwargs) | ||
|
|
@@ -1153,13 +1368,17 @@ async def run_loop(self, messages: Union[List[Message], str], | |
| # +1 means the next round the assistant may give a conclusion | ||
| if self.runtime.round >= self.max_chat_round + 1: | ||
| if not self.runtime.should_stop: | ||
| messages.append( | ||
| Message( | ||
| role='assistant', | ||
| content= | ||
| f'Task {messages[1].content} was cutted off, because ' | ||
| f'max round({self.max_chat_round}) exceeded.', | ||
| )) | ||
| cutoff_msg = Message( | ||
| role='assistant', | ||
| content= | ||
| f'Task {messages[1].content} was cutted off, because ' | ||
| f'max round({self.max_chat_round}) exceeded.', | ||
| ) | ||
| messages.append(cutoff_msg) | ||
| if self.session_log is not None: | ||
| self.session_log.append( | ||
| self._msg_to_dict(cutoff_msg)) | ||
| self.save_history(messages) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. save_history和session_log是否有功能重叠,能否直接合并? |
||
| self.runtime.should_stop = True | ||
| yield messages | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| """Unified Memory — a protocol-driven, backend-pluggable memory system. | ||
|
|
||
| Register ``unified_memory`` in ``memory_mapping`` to use this system. | ||
|
|
||
| Architecture:: | ||
|
|
||
| Orchestrator --delegates-to--> MemoryBackend (Protocol) | ||
| | | ||
| +----------------+----------------+ | ||
| v v v | ||
| FileBasedBackend ReMeBackend MempalaceBackend ... | ||
| (built-in) (adapter) (adapter) | ||
|
|
||
| Switch backends via YAML config:: | ||
|
|
||
| storage: | ||
| backend: "file" # or "reme", "mempalace", "mem0", "byterover", "supermemory" | ||
| """ | ||
| from .config import MemoryConfig | ||
| from .orchestrator import MemoryOrchestrator | ||
| from .protocols import ( | ||
| BaseMemoryBackend, | ||
| MemoryBackend, | ||
| MemoryEntry, | ||
| MemoryEvent, | ||
| MemoryEventBus, | ||
| MemoryNamespace, | ||
| ) | ||
| from .registry import backend_registry | ||
|
|
||
| # Import backends so they self-register | ||
| from .backends import file_based as _fb # noqa: F401 | ||
|
|
||
| __all__ = [ | ||
| "MemoryConfig", | ||
| "MemoryOrchestrator", | ||
| # Layer 2 — primary contract | ||
| "MemoryBackend", | ||
| "BaseMemoryBackend", | ||
| "backend_registry", | ||
| # Layer 1 — data structures | ||
| "MemoryEntry", | ||
| "MemoryEvent", | ||
| "MemoryEventBus", | ||
| "MemoryNamespace", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llm_agent.py
Lines 1296-1299
await self.load_memory()
...
self._init_session_log()
load_memory先于init_session_log被调用,因此load_memory时self.session_log里的值是否未被初始化