diff --git a/src/memos/mem_scheduler/base_mixins/__init__.py b/src/memos/mem_scheduler/base_mixins/__init__.py new file mode 100644 index 000000000..7e01cffc0 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/__init__.py @@ -0,0 +1,10 @@ +from .memory_ops import BaseSchedulerMemoryMixin +from .queue_ops import BaseSchedulerQueueMixin +from .web_log_ops import BaseSchedulerWebLogMixin + + +__all__ = [ + "BaseSchedulerMemoryMixin", + "BaseSchedulerQueueMixin", + "BaseSchedulerWebLogMixin", +] diff --git a/src/memos/mem_scheduler/base_mixins/memory_ops.py b/src/memos/mem_scheduler/base_mixins/memory_ops.py new file mode 100644 index 000000000..87f284898 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/memory_ops.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.activation.kv import KVCacheMemory +from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory +from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE + + +if TYPE_CHECKING: + from memos.types.general_types import MemCubeID, UserID + + +logger = get_logger(__name__) + + +class BaseSchedulerMemoryMixin: + def transform_working_memories_to_monitors( + self, query_keywords, memories: list[TextualMemoryItem] + ) -> list[MemoryMonitorItem]: + result = [] + mem_length = len(memories) + for idx, mem in enumerate(memories): + text_mem = mem.memory + mem_key = transform_name_to_key(name=text_mem) + + keywords_score = 0 + if query_keywords and text_mem: + for keyword, count in query_keywords.items(): + keyword_count = text_mem.count(keyword) + if keyword_count > 0: + keywords_score += keyword_count * count + logger.debug( + "Matched keyword '%s' %s times, added %s to keywords_score", + keyword, + keyword_count, + keywords_score, + ) + + sorting_score = mem_length - idx + + mem_monitor = MemoryMonitorItem( + memory_text=text_mem, + tree_memory_item=mem, + tree_memory_item_mapping_key=mem_key, + sorting_score=sorting_score, + keywords_score=keywords_score, + recording_count=1, + ) + result.append(mem_monitor) + + logger.info("Transformed %s memories to monitors", len(result)) + return result + + def replace_working_memory( + self, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + ) -> None | list[TextualMemoryItem]: + text_mem_base = mem_cube.text_mem + if isinstance(text_mem_base, TreeTextMemory): + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + + query_history = query_db_manager.obj.get_queries_with_timesort() + + original_count = len(original_memory) + filtered_original_memory = [] + for origin_mem in original_memory: + if "mode:fast" not in origin_mem.metadata.tags: + filtered_original_memory.append(origin_mem) + else: + logger.debug( + "Filtered out memory - ID: %s, Tags: %s", + getattr(origin_mem, "id", "unknown"), + origin_mem.metadata.tags, + ) + filtered_count = original_count - len(filtered_original_memory) + remaining_count = len(filtered_original_memory) + + 
logger.info( + "Filtering complete. Removed %s memories with tag 'mode:fast'. Remaining memories: %s", + filtered_count, + remaining_count, + ) + original_memory = filtered_original_memory + + memories_with_new_order, rerank_success_flag = ( + self.retriever.process_and_rerank_memories( + queries=query_history, + original_memory=original_memory, + new_memory=new_memory, + top_k=self.top_k, + ) + ) + + logger.info("Filtering memories based on query history: %s queries", len(query_history)) + filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, + memories=memories_with_new_order, + ) + + if filter_success_flag: + logger.info( + "Memory filtering completed successfully. Filtered from %s to %s memories", + len(memories_with_new_order), + len(filtered_memories), + ) + memories_with_new_order = filtered_memories + else: + logger.warning( + "Memory filtering failed - keeping all memories as fallback. Original count: %s", + len(memories_with_new_order), + ) + + query_keywords = query_db_manager.obj.get_keywords_collections() + logger.info( + "Processing %s memories with %s query keywords", + len(memories_with_new_order), + len(query_keywords), + ) + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=memories_with_new_order, + ) + + if not rerank_success_flag: + for one in new_working_memory_monitors: + one.sorting_score = 0 + + logger.info("update %s working_memory_monitors", len(new_working_memory_monitors)) + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ + mem_cube_id + ].obj.get_sorted_mem_monitors(reverse=True) + new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] + + text_mem_base.replace_working_memory(memories=new_working_memories) + + logger.info( + "The working memory has been replaced with %s new memories.", + len(memories_with_new_order), + ) + self.log_working_memory_replacement( + original_memory=original_memory, + new_memory=new_working_memories, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + elif isinstance(text_mem_base, NaiveTextMemory): + logger.info( + "NaiveTextMemory: Updating working memory monitors with %s candidates.", + len(new_memory), + ) + + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + query_keywords = query_db_manager.obj.get_keywords_collections() + + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=new_memory, + ) + + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + memories_with_new_order = new_memory + else: + logger.error("memory_base is not supported") + memories_with_new_order = new_memory + + return memories_with_new_order + + def update_activation_memory( + self, + new_memories: list[str | TextualMemoryItem], + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + ) -> None: + if len(new_memories) == 0: + logger.error("update_activation_memory: new_memory is empty.") + return + if isinstance(new_memories[0], TextualMemoryItem): + 
new_text_memories = [mem.memory for mem in new_memories] + elif isinstance(new_memories[0], str): + new_text_memories = new_memories + else: + logger.error("Not Implemented.") + return + + try: + if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): + act_mem: VLLMKVCacheMemory = mem_cube.act_mem + elif isinstance(mem_cube.act_mem, KVCacheMemory): + act_mem = mem_cube.act_mem + else: + logger.error("Not Implemented.") + return + + new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( + memory_text="".join( + [ + f"{i + 1}. {sentence.strip()}\n" + for i, sentence in enumerate(new_text_memories) + if sentence.strip() + ] + ) + ) + + original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() + original_text_memories = [] + if len(original_cache_items) > 0: + pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] + original_text_memories = pre_cache_item.records.text_memories + original_composed_text_memory = pre_cache_item.records.composed_text_memory + if original_composed_text_memory == new_text_memory: + logger.warning( + "Skipping memory update - new composition matches existing cache: %s", + new_text_memory[:50] + "..." + if len(new_text_memory) > 50 + else new_text_memory, + ) + return + act_mem.delete_all() + + cache_item = act_mem.extract(new_text_memory) + cache_item.records.text_memories = new_text_memories + cache_item.records.timestamp = get_utc_now() + + act_mem.add([cache_item]) + act_mem.dump(self.act_mem_dump_path) + + self.log_activation_memory_update( + original_text_memories=original_text_memories, + new_text_memories=new_text_memories, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + + except Exception as e: + logger.error("MOS-based activation memory update failed: %s", e, exc_info=True) + + def update_activation_memory_periodically( + self, + interval_seconds: int, + label: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + ): + try: + if ( + self.monitor.last_activation_mem_update_time == datetime.min + or self.monitor.timed_trigger( + last_time=self.monitor.last_activation_mem_update_time, + interval_seconds=interval_seconds, + ) + ): + logger.info( + "Updating activation memory for user %s and mem_cube %s", + user_id, + mem_cube_id, + ) + + if ( + user_id not in self.monitor.working_memory_monitors + or mem_cube_id not in self.monitor.working_memory_monitors[user_id] + or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) + == 0 + ): + logger.warning( + "No memories found in working_memory_monitors, activation memory update is skipped" + ) + return + + self.monitor.update_activation_memory_monitors( + user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube + ) + + activation_db_manager = self.monitor.activation_memory_monitors[user_id][ + mem_cube_id + ] + activation_db_manager.sync_with_orm() + new_activation_memories = [ + m.memory_text for m in activation_db_manager.obj.memories + ] + + logger.info( + "Collected %s new memory entries for processing", + len(new_activation_memories), + ) + for i, memory in enumerate(new_activation_memories[:5], 1): + logger.info( + "Part of New Activation Memories | %s/%s: %s", + i, + len(new_activation_memories), + memory[:20], + ) + + self.update_activation_memory( + new_memories=new_activation_memories, + label=label, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + self.monitor.last_activation_mem_update_time = get_utc_now() + + logger.debug( + "Activation memory 
update completed at %s", + self.monitor.last_activation_mem_update_time, + ) + + else: + logger.info( + "Skipping update - %s second interval not yet reached. Last update time is %s and now is %s", + interval_seconds, + self.monitor.last_activation_mem_update_time, + get_utc_now(), + ) + except Exception as e: + logger.error("Error in update_activation_memory_periodically: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py new file mode 100644 index 000000000..e5709ff36 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import multiprocessing +import time + +from contextlib import suppress +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from memos.context.context import ( + ContextThread, + RequestContext, + get_current_context, + get_current_trace_id, + set_request_context, +) +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import STARTUP_BY_PROCESS +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import TaskPriorityLevel +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from collections.abc import Callable + + +class BaseSchedulerQueueMixin: + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + if isinstance(messages, ScheduleMessageItem): + messages = [messages] + + if not messages: + return + + current_trace_id = get_current_trace_id() + + immediate_msgs: list[ScheduleMessageItem] = [] + queued_msgs: list[ScheduleMessageItem] = [] + + for msg in messages: + if current_trace_id: + msg.trace_id = current_trace_id + + with suppress(Exception): + self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) + + if getattr(msg, "timestamp", None) is None: + msg.timestamp = get_utc_now() + + if self.status_tracker: + try: + self.status_tracker.task_submitted( + task_id=msg.item_id, + user_id=msg.user_id, + task_type=msg.label, + mem_cube_id=msg.mem_cube_id, + business_task_id=msg.task_id, + ) + except Exception: + logger.warning("status_tracker.task_submitted failed", exc_info=True) + + if self.disabled_handlers and msg.label in self.disabled_handlers: + logger.info("Skipping disabled handler: %s - %s", msg.label, msg.content) + continue + + task_priority = self.orchestrator.get_task_priority(task_label=msg.label) + if task_priority == TaskPriorityLevel.LEVEL_1: + immediate_msgs.append(msg) + else: + queued_msgs.append(msg) + + if immediate_msgs: + for m in immediate_msgs: + emit_monitor_event( + "enqueue", + m, + { + "enqueue_ts": to_iso(getattr(m, "timestamp", None)), + "event_duration_ms": 0, + "total_duration_ms": 0, + }, + ) + + for m in immediate_msgs: + try: + now = time.time() + enqueue_ts_obj = getattr(m, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + 
object.__setattr__(m, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + m, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), + "queue_wait_ms": queue_wait_ms, + "event_duration_ms": queue_wait_ms, + "total_duration_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label) + except Exception: + logger.debug("Failed to emit dequeue for immediate task", exc_info=True) + + user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs) + for user_id, cube_groups in user_cube_groups.items(): + for mem_cube_id, user_cube_msgs in cube_groups.items(): + label_groups: dict[str, list[ScheduleMessageItem]] = {} + for m in user_cube_msgs: + label_groups.setdefault(m.label, []).append(m) + + for label, msgs_by_label in label_groups.items(): + handler = self.dispatcher.handlers.get( + label, self.dispatcher._default_message_handler + ) + self.dispatcher.execute_task( + user_id=user_id, + mem_cube_id=mem_cube_id, + task_label=label, + msgs=msgs_by_label, + handler_call_back=handler, + ) + + if queued_msgs: + self.memos_message_queue.submit_messages(messages=queued_msgs) + + def _message_consumer(self) -> None: + while self._running: + try: + if self.enable_parallel_dispatch and self.dispatcher: + running_tasks = self.dispatcher.get_running_task_count() + if running_tasks >= self.dispatcher.max_workers: + time.sleep(self._consume_interval) + continue + + messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) + + if messages: + now = time.time() + for msg in messages: + prev_context = get_current_context() + try: + msg_context = RequestContext( + trace_id=msg.trace_id, + user_name=msg.user_name, + ) + set_request_context(msg_context) + + enqueue_ts_obj = getattr(msg, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + object.__setattr__(msg, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + msg, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": datetime.fromtimestamp( + now, tz=timezone.utc + ).isoformat(), + "queue_wait_ms": queue_wait_ms, + "event_duration_ms": queue_wait_ms, + "total_duration_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) + finally: + set_request_context(prev_context) + try: + with suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) + + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error("Error dispatching messages: %s", e) + + time.sleep(self._consume_interval) + + except Exception as e: + if "No messages available in Redis queue" not in str(e): + logger.error("Unexpected error in message consumer: %s", e, exc_info=True) + time.sleep(self._consume_interval) + + def _monitor_loop(self): + while self._running: + try: + q_sizes = self.memos_message_queue.qsize() + + if not isinstance(q_sizes, dict): + continue + + for stream_key, queue_length in q_sizes.items(): + if stream_key == "total_size": + continue + + parts = stream_key.split(":") + if len(parts) >= 3: + user_id = parts[-3] + self.metrics.update_queue_length(queue_length, user_id) + else: + if ":" not in stream_key: 
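+                            # Fallback for unexpected key formats (e.g. legacy or
+                            # testing): the key does not follow the usual
+                            # ...:{user_id}:{mem_cube_id}:{task_label} layout, so a
+                            # bare, colon-free key is treated as the user_id itself.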
+ self.metrics.update_queue_length(queue_length, stream_key) + + except Exception as e: + logger.error("Error in metrics monitor loop: %s", e, exc_info=True) + + time.sleep(15) + + def start(self) -> None: + if self.enable_parallel_dispatch: + logger.info( + "Initializing dispatcher thread pool with %s workers", + self.thread_pool_max_workers, + ) + + self.start_consumer() + self.start_background_monitor() + + def start_background_monitor(self): + if self._monitor_thread and self._monitor_thread.is_alive(): + return + self._monitor_thread = ContextThread( + target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor" + ) + self._monitor_thread.start() + logger.info("Scheduler metrics monitor thread started.") + + def start_consumer(self) -> None: + if self._running: + logger.warning("Memory Scheduler consumer is already running") + return + + self._running = True + + if self.scheduler_startup_mode == STARTUP_BY_PROCESS: + self._consumer_process = multiprocessing.Process( + target=self._message_consumer, + daemon=True, + name="MessageConsumerProcess", + ) + self._consumer_process.start() + logger.info("Message consumer process started") + else: + self._consumer_thread = ContextThread( + target=self._message_consumer, + daemon=True, + name="MessageConsumerThread", + ) + self._consumer_thread.start() + logger.info("Message consumer thread started") + + def stop_consumer(self) -> None: + if not self._running: + logger.warning("Memory Scheduler consumer is not running") + return + + self._running = False + + if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process: + if self._consumer_process.is_alive(): + self._consumer_process.join(timeout=5.0) + if self._consumer_process.is_alive(): + logger.warning("Consumer process did not stop gracefully, terminating...") + self._consumer_process.terminate() + self._consumer_process.join(timeout=2.0) + if self._consumer_process.is_alive(): + logger.error("Consumer process could not be terminated") + else: + logger.info("Consumer process terminated") + else: + logger.info("Consumer process stopped") + self._consumer_process = None + elif self._consumer_thread and self._consumer_thread.is_alive(): + self._consumer_thread.join(timeout=5.0) + if self._consumer_thread.is_alive(): + logger.warning("Consumer thread did not stop gracefully") + else: + logger.info("Consumer thread stopped") + self._consumer_thread = None + + logger.info("Memory Scheduler consumer stopped") + + def stop(self) -> None: + if not self._running: + logger.warning("Memory Scheduler is not running") + return + + self.stop_consumer() + + if self._monitor_thread: + self._monitor_thread.join(timeout=2.0) + + if self.dispatcher: + logger.info("Shutting down dispatcher...") + self.dispatcher.shutdown() + + if self.dispatcher_monitor: + logger.info("Shutting down monitor...") + self.dispatcher_monitor.stop() + + @property + def handlers(self) -> dict[str, Callable]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty handlers dict") + return {} + + return self.dispatcher.handlers + + def register_handlers( + self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] + ) -> None: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, cannot register handlers") + return + + self.dispatcher.register_handlers(handlers) + + def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, cannot unregister handlers") + 
return dict.fromkeys(labels, False) + + return self.dispatcher.unregister_handlers(labels) + + def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty tasks dict") + return {} + + running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) + + result = {} + for task_id, task_item in running_tasks.items(): + result[task_id] = { + "item_id": task_item.item_id, + "user_id": task_item.user_id, + "mem_cube_id": task_item.mem_cube_id, + "task_info": task_item.task_info, + "task_name": task_item.task_name, + "start_time": task_item.start_time, + "end_time": task_item.end_time, + "status": task_item.status, + "result": task_item.result, + "error_message": task_item.error_message, + "messages": task_item.messages, + } + + return result + + def get_tasks_status(self): + return self.task_schedule_monitor.get_tasks_status() + + def print_tasks_status(self, tasks_status: dict | None = None) -> None: + self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status) + + def _gather_queue_stats(self) -> dict: + memos_message_queue = self.memos_message_queue.memos_message_queue + stats: dict[str, int | float | str] = {} + stats["use_redis_queue"] = bool(self.use_redis_queue) + if not self.use_redis_queue: + try: + stats["qsize"] = int(memos_message_queue.qsize()) + except Exception: + stats["qsize"] = -1 + try: + stats["unfinished_tasks"] = int( + getattr(memos_message_queue, "unfinished_tasks", 0) or 0 + ) + except Exception: + stats["unfinished_tasks"] = -1 + stats["maxsize"] = int(self.max_internal_message_queue_size) + try: + maxsize = int(self.max_internal_message_queue_size) or 1 + qsize = int(stats.get("qsize", 0)) + stats["utilization"] = min(1.0, max(0.0, qsize / maxsize)) + except Exception: + stats["utilization"] = 0.0 + try: + d_stats = self.dispatcher.stats() + stats.update( + { + "running": int(d_stats.get("running", 0)), + "inflight": int(d_stats.get("inflight", 0)), + "handlers": int(d_stats.get("handlers", 0)), + } + ) + except Exception: + stats.update({"running": 0, "inflight": 0, "handlers": 0}) + return stats diff --git a/src/memos/mem_scheduler/base_mixins/web_log_ops.py b/src/memos/mem_scheduler/base_mixins/web_log_ops.py new file mode 100644 index 000000000..64b5348d3 --- /dev/null +++ b/src/memos/mem_scheduler/base_mixins/web_log_ops.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) + + +logger = get_logger(__name__) + + +class BaseSchedulerWebLogMixin: + def _submit_web_logs( + self, + messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem], + additional_log_info: str | None = None, + ) -> None: + if isinstance(messages, ScheduleLogForWebItem): + messages = [messages] + + for message in messages: + if self.rabbitmq_config is None: + return + try: + logger.info( + "[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish %s", + message.model_dump_json(indent=2), + ) + self.rabbitmq_publish_message(message=message.to_dict()) + logger.info( + "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched item_id=%s task_id=%s label=%s", + message.item_id, + message.task_id, + message.label, + ) + except Exception as 
e: + logger.error( + "[DIAGNOSTIC] base_scheduler._submit_web_logs failed: %s", + e, + exc_info=True, + ) + + logger.debug( + "%s submitted. %s in queue. additional_log_info: %s", + len(messages), + self._web_log_message_queue.qsize(), + additional_log_info, + ) + + def get_web_log_messages(self) -> list[dict]: + raw_items: list[ScheduleLogForWebItem] = [] + while True: + try: + raw_items.append(self._web_log_message_queue.get_nowait()) + except Exception: + break + + def _map_label(label: str) -> str: + mapping = { + QUERY_TASK_LABEL: "addMessage", + ANSWER_TASK_LABEL: "addMessage", + ADD_TASK_LABEL: "addMemory", + MEM_UPDATE_TASK_LABEL: "updateMemory", + MEM_ORGANIZE_TASK_LABEL: "mergeMemory", + MEM_ARCHIVE_TASK_LABEL: "archiveMemory", + } + return mapping.get(label, label) + + def _normalize_item(item: ScheduleLogForWebItem) -> dict: + data = item.to_dict() + data["label"] = _map_label(data.get("label")) + memcube_content = getattr(item, "memcube_log_content", None) or [] + metadata = getattr(item, "metadata", None) or [] + + memcube_name = getattr(item, "memcube_name", None) + if not memcube_name and hasattr(self, "_map_memcube_name"): + memcube_name = self._map_memcube_name(item.mem_cube_id) + data["memcube_name"] = memcube_name + + memory_len = getattr(item, "memory_len", None) + if memory_len is None: + if data["label"] == "mergeMemory": + memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"]) + elif memcube_content: + memory_len = len(memcube_content) + else: + memory_len = 1 if item.log_content else 0 + + data["memcube_log_content"] = memcube_content + data["memory_len"] = memory_len + + def _with_memory_time(meta: dict) -> dict: + enriched = dict(meta) + if "memory_time" not in enriched: + enriched["memory_time"] = enriched.get("updated_at") or enriched.get( + "update_at" + ) + return enriched + + data["metadata"] = [_with_memory_time(m) for m in metadata] + data["log_title"] = "" + return data + + return [_normalize_item(it) for it in raw_items] diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 5ab524128..2cb104343 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,29 +1,18 @@ -import multiprocessing +from __future__ import annotations + import os import threading -import time -from collections.abc import Callable -from contextlib import suppress -from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Union - -from sqlalchemy.engine import Engine +from typing import TYPE_CHECKING from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig -from memos.context.context import ( - ContextThread, - RequestContext, - get_current_context, - get_current_trace_id, - set_request_context, -) -from memos.llms.base import BaseLLM from memos.log import get_logger -from memos.mem_cube.base import BaseMemCube -from memos.mem_cube.general import GeneralMemCube -from memos.mem_feedback.simple_feedback import SimpleMemFeedback +from memos.mem_scheduler.base_mixins import ( + BaseSchedulerMemoryMixin, + BaseSchedulerQueueMixin, + BaseSchedulerWebLogMixin, +) from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule @@ -42,58 +31,43 @@ DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, - 
STARTUP_BY_PROCESS, TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.schemas.message_schemas import ( - ScheduleLogForWebItem, - ScheduleMessageItem, -) -from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem -from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, - MEM_ARCHIVE_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - QUERY_TASK_LABEL, - TaskPriorityLevel, -) from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils import metrics -from memos.mem_scheduler.utils.db_utils import get_utc_now -from memos.mem_scheduler.utils.filter_utils import ( - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube -from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule -from memos.memories.activation.kv import KVCacheMemory -from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory -from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher -from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE -from memos.types.general_types import ( - MemCubeID, - UserID, -) if TYPE_CHECKING: import redis + from sqlalchemy.engine import Engine + + from memos.llms.base import BaseLLM + from memos.mem_cube.base import BaseMemCube + from memos.mem_feedback.simple_feedback import SimpleMemFeedback + from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem + from memos.memories.textual.tree import TreeTextMemory + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.http_bge import HTTPBGEReranker + from memos.types.general_types import MemCubeID, UserID logger = get_logger(__name__) -class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule): +class BaseScheduler( + RabbitMQSchedulerModule, + RedisSchedulerModule, + SchedulerLoggerModule, + BaseSchedulerWebLogMixin, + BaseSchedulerMemoryMixin, + BaseSchedulerQueueMixin, +): """Base class for all mem_scheduler.""" def __init__(self, config: BaseSchedulerConfig): @@ -219,7 +193,7 @@ def initialize_modules( process_llm: BaseLLM | None = None, db_engine: Engine | None = None, mem_reader=None, - redis_client: Union["redis.Redis", None] = None, + redis_client: redis.Redis | None = None, ): if process_llm is None: process_llm = chat_llm @@ -391,929 +365,4 @@ def mem_cubes(self, value: dict[str, BaseMemCube]) -> None: f"Failed to initialize current_mem_cube from mem_cubes: {e}", exc_info=True ) - def transform_working_memories_to_monitors( - self, query_keywords, memories: list[TextualMemoryItem] - ) -> list[MemoryMonitorItem]: - """ - Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects - with importance scores based on keyword matching. - - Args: - memories: List of TextualMemoryItem objects to be transformed. 
- - Returns: - List of MemoryMonitorItem objects with computed importance scores. - """ - - result = [] - mem_length = len(memories) - for idx, mem in enumerate(memories): - text_mem = mem.memory - mem_key = transform_name_to_key(name=text_mem) - - # Calculate importance score based on keyword matches - keywords_score = 0 - if query_keywords and text_mem: - for keyword, count in query_keywords.items(): - keyword_count = text_mem.count(keyword) - if keyword_count > 0: - keywords_score += keyword_count * count - logger.debug( - f"Matched keyword '{keyword}' {keyword_count} times, added {keywords_score} to keywords_score" - ) - - # rank score - sorting_score = mem_length - idx - - mem_monitor = MemoryMonitorItem( - memory_text=text_mem, - tree_memory_item=mem, - tree_memory_item_mapping_key=mem_key, - sorting_score=sorting_score, - keywords_score=keywords_score, - recording_count=1, - ) - result.append(mem_monitor) - - logger.info(f"Transformed {len(result)} memories to monitors") - return result - - def replace_working_memory( - self, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - original_memory: list[TextualMemoryItem], - new_memory: list[TextualMemoryItem], - ) -> None | list[TextualMemoryItem]: - """Replace working memory with new memories after reranking.""" - text_mem_base = mem_cube.text_mem - if isinstance(text_mem_base, TreeTextMemory): - text_mem_base: TreeTextMemory = text_mem_base - - # process rerank memories with llm - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - # Sync with database to get latest query history - query_db_manager.sync_with_orm() - - query_history = query_db_manager.obj.get_queries_with_timesort() - - original_count = len(original_memory) - # Filter out memories tagged with "mode:fast" - filtered_original_memory = [] - for origin_mem in original_memory: - if "mode:fast" not in origin_mem.metadata.tags: - filtered_original_memory.append(origin_mem) - else: - logger.debug( - f"Filtered out memory - ID: {getattr(origin_mem, 'id', 'unknown')}, Tags: {origin_mem.metadata.tags}" - ) - # Calculate statistics - filtered_count = original_count - len(filtered_original_memory) - remaining_count = len(filtered_original_memory) - - logger.info( - f"Filtering complete. Removed {filtered_count} memories with tag 'mode:fast'. Remaining memories: {remaining_count}" - ) - original_memory = filtered_original_memory - - memories_with_new_order, rerank_success_flag = ( - self.retriever.process_and_rerank_memories( - queries=query_history, - original_memory=original_memory, - new_memory=new_memory, - top_k=self.top_k, - ) - ) - - # Filter completely unrelated memories according to query_history - logger.info(f"Filtering memories based on query history: {len(query_history)} queries") - filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories( - query_history=query_history, - memories=memories_with_new_order, - ) - - if filter_success_flag: - logger.info( - f"Memory filtering completed successfully. " - f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories" - ) - memories_with_new_order = filtered_memories - else: - logger.warning( - "Memory filtering failed - keeping all memories as fallback. 
" - f"Original count: {len(memories_with_new_order)}" - ) - - # Update working memory monitors - query_keywords = query_db_manager.obj.get_keywords_collections() - logger.info( - f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" - ) - new_working_memory_monitors = self.transform_working_memories_to_monitors( - query_keywords=query_keywords, - memories=memories_with_new_order, - ) - - if not rerank_success_flag: - for one in new_working_memory_monitors: - one.sorting_score = 0 - - logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors") - self.monitor.update_working_memory_monitors( - new_working_memory_monitors=new_working_memory_monitors, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - - mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ - mem_cube_id - ].obj.get_sorted_mem_monitors(reverse=True) - new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] - - text_mem_base.replace_working_memory(memories=new_working_memories) - - logger.info( - f"The working memory has been replaced with {len(memories_with_new_order)} new memories." - ) - self.log_working_memory_replacement( - original_memory=original_memory, - new_memory=new_working_memories, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - log_func_callback=self._submit_web_logs, - ) - elif isinstance(text_mem_base, NaiveTextMemory): - # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up - logger.info( - f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates." - ) - - # Use query keywords if available, otherwise just basic monitoring - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - query_db_manager.sync_with_orm() - query_keywords = query_db_manager.obj.get_keywords_collections() - - new_working_memory_monitors = self.transform_working_memories_to_monitors( - query_keywords=query_keywords, - memories=new_memory, - ) - - self.monitor.update_working_memory_monitors( - new_working_memory_monitors=new_working_memory_monitors, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - memories_with_new_order = new_memory - else: - logger.error("memory_base is not supported") - memories_with_new_order = new_memory - - return memories_with_new_order - - def update_activation_memory( - self, - new_memories: list[str | TextualMemoryItem], - label: str, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - ) -> None: - """ - Update activation memory by extracting KVCacheItems from new_memory (list of str), - add them to a KVCacheMemory instance, and dump to disk. - """ - if len(new_memories) == 0: - logger.error("update_activation_memory: new_memory is empty.") - return - if isinstance(new_memories[0], TextualMemoryItem): - new_text_memories = [mem.memory for mem in new_memories] - elif isinstance(new_memories[0], str): - new_text_memories = new_memories - else: - logger.error("Not Implemented.") - return - - try: - if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): - act_mem: VLLMKVCacheMemory = mem_cube.act_mem - elif isinstance(mem_cube.act_mem, KVCacheMemory): - act_mem: KVCacheMemory = mem_cube.act_mem - else: - logger.error("Not Implemented.") - return - - new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( - memory_text="".join( - [ - f"{i + 1}. 
{sentence.strip()}\n" - for i, sentence in enumerate(new_text_memories) - if sentence.strip() # Skip empty strings - ] - ) - ) - - # huggingface or vllm kv cache - original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all() - original_text_memories = [] - if len(original_cache_items) > 0: - pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] - original_text_memories = pre_cache_item.records.text_memories - original_composed_text_memory = pre_cache_item.records.composed_text_memory - if original_composed_text_memory == new_text_memory: - logger.warning( - "Skipping memory update - new composition matches existing cache: %s", - new_text_memory[:50] + "..." - if len(new_text_memory) > 50 - else new_text_memory, - ) - return - act_mem.delete_all() - - cache_item = act_mem.extract(new_text_memory) - cache_item.records.text_memories = new_text_memories - cache_item.records.timestamp = get_utc_now() - - act_mem.add([cache_item]) - act_mem.dump(self.act_mem_dump_path) - - self.log_activation_memory_update( - original_text_memories=original_text_memories, - new_text_memories=new_text_memories, - label=label, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - log_func_callback=self._submit_web_logs, - ) - - except Exception as e: - logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True) - # Re-raise the exception if it's critical for the operation - # For now, we'll continue execution but this should be reviewed - - def update_activation_memory_periodically( - self, - interval_seconds: int, - label: str, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - ): - try: - if ( - self.monitor.last_activation_mem_update_time == datetime.min - or self.monitor.timed_trigger( - last_time=self.monitor.last_activation_mem_update_time, - interval_seconds=interval_seconds, - ) - ): - logger.info( - f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}" - ) - - if ( - user_id not in self.monitor.working_memory_monitors - or mem_cube_id not in self.monitor.working_memory_monitors[user_id] - or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) - == 0 - ): - logger.warning( - "No memories found in working_memory_monitors, activation memory update is skipped" - ) - return - - self.monitor.update_activation_memory_monitors( - user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube - ) - - # Sync with database to get latest activation memories - activation_db_manager = self.monitor.activation_memory_monitors[user_id][ - mem_cube_id - ] - activation_db_manager.sync_with_orm() - new_activation_memories = [ - m.memory_text for m in activation_db_manager.obj.memories - ] - - logger.info( - f"Collected {len(new_activation_memories)} new memory entries for processing" - ) - # Print the content of each new activation memory - for i, memory in enumerate(new_activation_memories[:5], 1): - logger.info( - f"Part of New Activation Memorires | {i}/{len(new_activation_memories)}: {memory[:20]}" - ) - - self.update_activation_memory( - new_memories=new_activation_memories, - label=label, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - - self.monitor.last_activation_mem_update_time = get_utc_now() - - logger.debug( - f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" - ) - - else: - logger.info( - f"Skipping update - {interval_seconds} second interval not yet reached. 
" - f"Last update time is {self.monitor.last_activation_mem_update_time} and now is " - f"{get_utc_now()}" - ) - except Exception as e: - logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - - def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit messages for processing, with priority-aware dispatch. - - - LEVEL_1 tasks dispatch immediately to the appropriate handler. - - Lower-priority tasks are enqueued via the configured message queue. - """ - if isinstance(messages, ScheduleMessageItem): - messages = [messages] - - if not messages: - return - - current_trace_id = get_current_trace_id() - - immediate_msgs: list[ScheduleMessageItem] = [] - queued_msgs: list[ScheduleMessageItem] = [] - - for msg in messages: - # propagate request trace_id when available so monitor logs align with request logs - if current_trace_id: - msg.trace_id = current_trace_id - - # basic metrics and status tracking - with suppress(Exception): - self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) - - # ensure timestamp exists for monitoring - if getattr(msg, "timestamp", None) is None: - msg.timestamp = get_utc_now() - - if self.status_tracker: - try: - self.status_tracker.task_submitted( - task_id=msg.item_id, - user_id=msg.user_id, - task_type=msg.label, - mem_cube_id=msg.mem_cube_id, - business_task_id=msg.task_id, - ) - except Exception: - logger.warning("status_tracker.task_submitted failed", exc_info=True) - - # honor disabled handlers - if self.disabled_handlers and msg.label in self.disabled_handlers: - logger.info(f"Skipping disabled handler: {msg.label} - {msg.content}") - continue - - # decide priority path - task_priority = self.orchestrator.get_task_priority(task_label=msg.label) - if task_priority == TaskPriorityLevel.LEVEL_1: - immediate_msgs.append(msg) - else: - queued_msgs.append(msg) - - # Dispatch high-priority tasks immediately - if immediate_msgs: - # emit enqueue events for consistency - for m in immediate_msgs: - emit_monitor_event( - "enqueue", - m, - { - "enqueue_ts": to_iso(getattr(m, "timestamp", None)), - "event_duration_ms": 0, - "total_duration_ms": 0, - }, - ) - - # simulate dequeue for immediately dispatched messages so monitor logs stay complete - for m in immediate_msgs: - try: - now = time.time() - enqueue_ts_obj = getattr(m, "timestamp", None) - enqueue_epoch = None - if isinstance(enqueue_ts_obj, int | float): - enqueue_epoch = float(enqueue_ts_obj) - elif hasattr(enqueue_ts_obj, "timestamp"): - dt = enqueue_ts_obj - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - enqueue_epoch = dt.timestamp() - - queue_wait_ms = None - if enqueue_epoch is not None: - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - - object.__setattr__(m, "_dequeue_ts", now) - emit_monitor_event( - "dequeue", - m, - { - "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), - "queue_wait_ms": queue_wait_ms, - "event_duration_ms": queue_wait_ms, - "total_duration_ms": queue_wait_ms, - }, - ) - self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label) - except Exception: - logger.debug("Failed to emit dequeue for immediate task", exc_info=True) - - user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs) - for user_id, cube_groups in user_cube_groups.items(): - for mem_cube_id, user_cube_msgs in cube_groups.items(): - label_groups: dict[str, list[ScheduleMessageItem]] = {} - for m in user_cube_msgs: - 
label_groups.setdefault(m.label, []).append(m) - - for label, msgs_by_label in label_groups.items(): - handler = self.dispatcher.handlers.get( - label, self.dispatcher._default_message_handler - ) - self.dispatcher.execute_task( - user_id=user_id, - mem_cube_id=mem_cube_id, - task_label=label, - msgs=msgs_by_label, - handler_call_back=handler, - ) - - # Enqueue lower-priority tasks - if queued_msgs: - self.memos_message_queue.submit_messages(messages=queued_msgs) - - def _submit_web_logs( - self, - messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem], - additional_log_info: str | None = None, - ) -> None: - """Submit log messages to the web log queue and optionally to RabbitMQ. - - Args: - messages: Single log message or list of log messages - """ - if isinstance(messages, ScheduleLogForWebItem): - messages = [messages] # transform single message to list - - for message in messages: - if self.rabbitmq_config is None: - return - try: - # Always call publish; the publisher now caches when offline and flushes after reconnect - logger.info( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}" - ) - self.rabbitmq_publish_message(message=message.to_dict()) - logger.info( - "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched " - "item_id=%s task_id=%s label=%s", - message.item_id, - message.task_id, - message.label, - ) - except Exception as e: - logger.error( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True - ) - - logger.debug( - f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}" - ) - - def get_web_log_messages(self) -> list[dict]: - """ - Retrieve structured log messages from the queue and return JSON-serializable dicts. 
- """ - raw_items: list[ScheduleLogForWebItem] = [] - while True: - try: - raw_items.append(self._web_log_message_queue.get_nowait()) - except Exception: - break - - def _map_label(label: str) -> str: - mapping = { - QUERY_TASK_LABEL: "addMessage", - ANSWER_TASK_LABEL: "addMessage", - ADD_TASK_LABEL: "addMemory", - MEM_UPDATE_TASK_LABEL: "updateMemory", - MEM_ORGANIZE_TASK_LABEL: "mergeMemory", - MEM_ARCHIVE_TASK_LABEL: "archiveMemory", - } - return mapping.get(label, label) - - def _normalize_item(item: ScheduleLogForWebItem) -> dict: - data = item.to_dict() - data["label"] = _map_label(data.get("label")) - memcube_content = getattr(item, "memcube_log_content", None) or [] - metadata = getattr(item, "metadata", None) or [] - - memcube_name = getattr(item, "memcube_name", None) - if not memcube_name and hasattr(self, "_map_memcube_name"): - memcube_name = self._map_memcube_name(item.mem_cube_id) - data["memcube_name"] = memcube_name - - memory_len = getattr(item, "memory_len", None) - if memory_len is None: - if data["label"] == "mergeMemory": - memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"]) - elif memcube_content: - memory_len = len(memcube_content) - else: - memory_len = 1 if item.log_content else 0 - - data["memcube_log_content"] = memcube_content - data["memory_len"] = memory_len - - def _with_memory_time(meta: dict) -> dict: - enriched = dict(meta) - if "memory_time" not in enriched: - enriched["memory_time"] = enriched.get("updated_at") or enriched.get( - "update_at" - ) - return enriched - - data["metadata"] = [_with_memory_time(m) for m in metadata] - data["log_title"] = "" - return data - - return [_normalize_item(it) for it in raw_items] - - def _message_consumer(self) -> None: - """ - Continuously checks the queue for messages and dispatches them. - - Runs in a dedicated thread to process messages at regular intervals. - For Redis queue, this method starts the Redis listener. 
- """ - - # Original local queue logic - while self._running: # Use a running flag for graceful shutdown - try: - # Check dispatcher thread pool status to avoid overloading - if self.enable_parallel_dispatch and self.dispatcher: - running_tasks = self.dispatcher.get_running_task_count() - if running_tasks >= self.dispatcher.max_workers: - # Thread pool is full, wait and retry - time.sleep(self._consume_interval) - continue - - # Get messages in batches based on consume_batch setting - - messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) - - if messages: - now = time.time() - for msg in messages: - prev_context = get_current_context() - try: - # Set context for this message - msg_context = RequestContext( - trace_id=msg.trace_id, - user_name=msg.user_name, - ) - set_request_context(msg_context) - - enqueue_ts_obj = getattr(msg, "timestamp", None) - enqueue_epoch = None - if isinstance(enqueue_ts_obj, int | float): - enqueue_epoch = float(enqueue_ts_obj) - elif hasattr(enqueue_ts_obj, "timestamp"): - dt = enqueue_ts_obj - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - enqueue_epoch = dt.timestamp() - - queue_wait_ms = None - if enqueue_epoch is not None: - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - - # Avoid pydantic field enforcement by using object.__setattr__ - object.__setattr__(msg, "_dequeue_ts", now) - emit_monitor_event( - "dequeue", - msg, - { - "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp( - now, tz=timezone.utc - ).isoformat(), - "queue_wait_ms": queue_wait_ms, - "event_duration_ms": queue_wait_ms, - "total_duration_ms": queue_wait_ms, - }, - ) - self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) - finally: - # Restore the prior context of the consumer thread - set_request_context(prev_context) - try: - import contextlib - - with contextlib.suppress(Exception): - if messages: - self.dispatcher.on_messages_enqueued(messages) - - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed - - except Exception as e: - # Don't log error for "No messages available in Redis queue" as it's expected - if "No messages available in Redis queue" not in str(e): - logger.error(f"Unexpected error in message consumer: {e!s}", exc_info=True) - time.sleep(self._consume_interval) # Prevent tight error loops - - def _monitor_loop(self): - while self._running: - try: - q_sizes = self.memos_message_queue.qsize() - - if not isinstance(q_sizes, dict): - continue - - for stream_key, queue_length in q_sizes.items(): - # Skip aggregate keys like 'total_size' - if stream_key == "total_size": - continue - - # Key format: ...:{user_id}:{mem_cube_id}:{task_label} - # We want to extract user_id, which is the 3rd component from the end. - parts = stream_key.split(":") - if len(parts) >= 3: - user_id = parts[-3] - self.metrics.update_queue_length(queue_length, user_id) - else: - # Fallback for unexpected key formats (e.g. legacy or testing) - # Try to use the key itself if it looks like a user_id (no colons) - # or just log a warning? 
- # For now, let's assume if it's not total_size and short, it might be a direct user_id key - # (though that shouldn't happen with current queue implementations) - if ":" not in stream_key: - self.metrics.update_queue_length(queue_length, stream_key) - - except Exception as e: - logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) - - time.sleep(15) # 每 15 秒采样一次 - - def start(self) -> None: - """ - Start the message consumer thread/process and initialize dispatcher resources. - - Initializes and starts: - 1. Message consumer thread or process (based on startup_mode) - 2. Dispatcher thread pool (if parallel dispatch enabled) - """ - # Initialize dispatcher resources - if self.enable_parallel_dispatch: - logger.info( - f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers" - ) - - self.start_consumer() - self.start_background_monitor() - - def start_background_monitor(self): - if self._monitor_thread and self._monitor_thread.is_alive(): - return - self._monitor_thread = ContextThread( - target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor" - ) - self._monitor_thread.start() - logger.info("Scheduler metrics monitor thread started.") - - def start_consumer(self) -> None: - """ - Start only the message consumer thread/process. - - This method can be used to restart the consumer after it has been stopped - with stop_consumer(), without affecting other scheduler components. - """ - if self._running: - logger.warning("Memory Scheduler consumer is already running") - return - - # Start consumer based on startup mode - self._running = True - - if self.scheduler_startup_mode == STARTUP_BY_PROCESS: - # Start consumer process - self._consumer_process = multiprocessing.Process( - target=self._message_consumer, - daemon=True, - name="MessageConsumerProcess", - ) - self._consumer_process.start() - logger.info("Message consumer process started") - else: - # Default to thread mode - self._consumer_thread = ContextThread( - target=self._message_consumer, - daemon=True, - name="MessageConsumerThread", - ) - self._consumer_thread.start() - logger.info("Message consumer thread started") - - def stop_consumer(self) -> None: - """Stop only the message consumer thread/process gracefully. - - This method stops the consumer without affecting other components like - dispatcher or monitors. Useful when you want to pause message processing - while keeping other scheduler components running. 
- """ - if not self._running: - logger.warning("Memory Scheduler consumer is not running") - return - - # Signal consumer thread/process to stop - self._running = False - - # Wait for consumer thread or process - if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process: - if self._consumer_process.is_alive(): - self._consumer_process.join(timeout=5.0) - if self._consumer_process.is_alive(): - logger.warning("Consumer process did not stop gracefully, terminating...") - self._consumer_process.terminate() - self._consumer_process.join(timeout=2.0) - if self._consumer_process.is_alive(): - logger.error("Consumer process could not be terminated") - else: - logger.info("Consumer process terminated") - else: - logger.info("Consumer process stopped") - self._consumer_process = None - elif self._consumer_thread and self._consumer_thread.is_alive(): - self._consumer_thread.join(timeout=5.0) - if self._consumer_thread.is_alive(): - logger.warning("Consumer thread did not stop gracefully") - else: - logger.info("Consumer thread stopped") - self._consumer_thread = None - - logger.info("Memory Scheduler consumer stopped") - - def stop(self) -> None: - """Stop all scheduler components gracefully. - - 1. Stops message consumer thread/process - 2. Shuts down dispatcher thread pool - 3. Cleans up resources - """ - if not self._running: - logger.warning("Memory Scheduler is not running") - return - - # Stop consumer first - self.stop_consumer() - - if self._monitor_thread: - self._monitor_thread.join(timeout=2.0) - - # Shutdown dispatcher - if self.dispatcher: - logger.info("Shutting down dispatcher...") - self.dispatcher.shutdown() - - # Shutdown dispatcher_monitor - if self.dispatcher_monitor: - logger.info("Shutting down monitor...") - self.dispatcher_monitor.stop() - - @property - def handlers(self) -> dict[str, Callable]: - """ - Access the dispatcher's handlers dictionary. - - Returns: - dict[str, Callable]: Dictionary mapping labels to handler functions - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, returning empty handlers dict") - return {} - - return self.dispatcher.handlers - - def register_handlers( - self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]] - ) -> None: - """ - Bulk register multiple handlers from a dictionary. - - Args: - handlers: Dictionary mapping labels to handler functions - Format: {label: handler_callable} - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, cannot register handlers") - return - - self.dispatcher.register_handlers(handlers) - - def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: - """ - Unregister handlers from the dispatcher by their labels. 
- - Args: - labels: List of labels to unregister handlers for - - Returns: - dict[str, bool]: Dictionary mapping each label to whether it was successfully unregistered - """ - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, cannot unregister handlers") - return dict.fromkeys(labels, False) - - return self.dispatcher.unregister_handlers(labels) - - def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: - if not self.dispatcher: - logger.warning("Dispatcher is not initialized, returning empty tasks dict") - return {} - - running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) - - # Convert RunningTaskItem objects to dictionaries for easier consumption - result = {} - for task_id, task_item in running_tasks.items(): - result[task_id] = { - "item_id": task_item.item_id, - "user_id": task_item.user_id, - "mem_cube_id": task_item.mem_cube_id, - "task_info": task_item.task_info, - "task_name": task_item.task_name, - "start_time": task_item.start_time, - "end_time": task_item.end_time, - "status": task_item.status, - "result": task_item.result, - "error_message": task_item.error_message, - "messages": task_item.messages, - } - - return result - - def get_tasks_status(self): - """Delegate status collection to TaskScheduleMonitor.""" - return self.task_schedule_monitor.get_tasks_status() - - def print_tasks_status(self, tasks_status: dict | None = None) -> None: - """Delegate pretty printing to TaskScheduleMonitor.""" - self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status) - - def _gather_queue_stats(self) -> dict: - """Collect queue/dispatcher stats for reporting.""" - memos_message_queue = self.memos_message_queue.memos_message_queue - stats: dict[str, int | float | str] = {} - stats["use_redis_queue"] = bool(self.use_redis_queue) - # local queue metrics - if not self.use_redis_queue: - try: - stats["qsize"] = int(memos_message_queue.qsize()) - except Exception: - stats["qsize"] = -1 - # unfinished_tasks if available - try: - stats["unfinished_tasks"] = int( - getattr(memos_message_queue, "unfinished_tasks", 0) or 0 - ) - except Exception: - stats["unfinished_tasks"] = -1 - stats["maxsize"] = int(self.max_internal_message_queue_size) - try: - maxsize = int(self.max_internal_message_queue_size) or 1 - qsize = int(stats.get("qsize", 0)) - stats["utilization"] = min(1.0, max(0.0, qsize / maxsize)) - except Exception: - stats["utilization"] = 0.0 - # dispatcher stats - try: - d_stats = self.dispatcher.stats() - stats.update( - { - "running": int(d_stats.get("running", 0)), - "inflight": int(d_stats.get("inflight", 0)), - "handlers": int(d_stats.get("handlers", 0)), - } - ) - except Exception: - stats.update({"running": 0, "inflight": 0, "handlers": 0}) - return stats + # Methods moved to mixins in mem_scheduler.base_mixins. 
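Note for reviewers: the queue, metrics, and web-log methods removed above now live in the new mem_scheduler.base_mixins package. The sketch below shows how the pieces are presumably meant to compose; the class body and the method-to-mixin mapping in the comments are inferred from the mixin module names (memory_ops, queue_ops, web_log_ops) rather than spelled out in this diff, so treat it as illustrative only:

    from memos.mem_scheduler.base_mixins import (
        BaseSchedulerMemoryMixin,
        BaseSchedulerQueueMixin,
        BaseSchedulerWebLogMixin,
    )


    class BaseScheduler(
        BaseSchedulerMemoryMixin,  # working-memory transforms / replacement (inferred)
        BaseSchedulerQueueMixin,   # consumer lifecycle, monitor loop, queue stats (inferred)
        BaseSchedulerWebLogMixin,  # event-log creation and submission (inferred)
    ):
        # Cooperative mixins: each mixin only touches shared attributes such as
        # self.dispatcher, self.monitor, self.metrics, and self.mem_cube, which
        # the concrete scheduler is expected to initialize in __init__.
        ...

Assuming the moved methods keep their original signatures, callers are unaffected: scheduler.start(), scheduler.stop_consumer(), and scheduler.get_running_tasks() resolve through the MRO exactly as they did when defined directly on BaseScheduler.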
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 74e50a514..1fc3317d8 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,49 +1,16 @@ -import concurrent.futures -import contextlib -import json -import traceback +from __future__ import annotations -from memos.configs.mem_scheduler import GeneralSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem -from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, - DEFAULT_MAX_QUERY_KEY_WORDS, - LONG_TERM_MEMORY_TYPE, - MEM_FEEDBACK_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_READ_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - NOT_APPLICABLE_TYPE, - PREF_ADD_TASK_LABEL, - QUERY_TASK_LABEL, - USER_INPUT_TYPE, -) -from memos.mem_scheduler.utils.filter_utils import ( - is_all_chinese, - is_all_english, - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import ( - group_messages_by_user_and_mem_cube, - is_cloud_env, -) -from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.preference import PreferenceTextMemory -from memos.memories.textual.tree import TreeTextMemory -from memos.types import ( - MemCubeID, - UserID, -) +from typing import TYPE_CHECKING -logger = get_logger(__name__) +if TYPE_CHECKING: + from memos.configs.mem_scheduler import GeneralSchedulerConfig +from memos.mem_scheduler.base_scheduler import BaseScheduler +from memos.mem_scheduler.handlers import ( + SchedulerHandlerContext, + SchedulerHandlerRegistry, + SchedulerHandlerServices, +) class GeneralScheduler(BaseScheduler): @@ -53,1447 +20,29 @@ def __init__(self, config: GeneralSchedulerConfig): self.query_key_words_limit = self.config.get("query_key_words_limit", 20) - # register handlers - handlers = { - QUERY_TASK_LABEL: self._query_message_consumer, - ANSWER_TASK_LABEL: self._answer_message_consumer, - MEM_UPDATE_TASK_LABEL: self._memory_update_consumer, - ADD_TASK_LABEL: self._add_message_consumer, - MEM_READ_TASK_LABEL: self._mem_read_message_consumer, - MEM_ORGANIZE_TASK_LABEL: self._mem_reorganize_message_consumer, - PREF_ADD_TASK_LABEL: self._pref_add_message_consumer, - MEM_FEEDBACK_TASK_LABEL: self._mem_feedback_message_consumer, - } - self.dispatcher.register_handlers(handlers) - - def long_memory_update_process( - self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] - ): - mem_cube = self.mem_cube - - # update query monitors - for msg in messages: - self.monitor.register_query_monitor_if_not_exists( - user_id=user_id, mem_cube_id=mem_cube_id - ) - - query = msg.content - query_keywords = self.monitor.extract_query_keywords(query=query) - logger.info( - f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}' - ) - - if len(query_keywords) == 0: - stripped_query = query.strip() - # Determine measurement method based on language - if is_all_english(stripped_query): - words = stripped_query.split() # Word count for English - elif is_all_chinese(stripped_query): - words = stripped_query # Character count for Chinese - else: - logger.debug( - 
f"Mixed-language memory, using character count: {stripped_query[:50]}..." - ) - words = stripped_query # Default to character count - - query_keywords = list(set(words[: self.query_key_words_limit])) - logger.error( - f"Keyword extraction failed for query '{query}' (user_id={user_id}). Using fallback keywords: {query_keywords[:10]}... (truncated)", - exc_info=True, - ) - - item = QueryMonitorItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - query_text=query, - keywords=query_keywords, - max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, - ) - - query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] - query_db_manager.obj.put(item=item) - # Sync with database after adding new item - query_db_manager.sync_with_orm() - logger.debug( - f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}" - ) - - queries = [msg.content for msg in messages] - - # recall - cur_working_memory, new_candidates = self.process_session_turn( - queries=queries, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - top_k=self.top_k, - ) - logger.info( - # Build the candidate preview string outside the f-string to avoid backslashes in expression - f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} " - f"new candidate memories for user_id={user_id}: " - + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates])) - ) - - # rerank - new_order_working_memory = self.replace_working_memory( - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - original_memory=cur_working_memory, - new_memory=new_candidates, - ) - logger.debug( - f"[long_memory_update_process] Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" - ) - - old_memory_texts = "\n- " + "\n- ".join( - [f"{one.id}: {one.memory}" for one in cur_working_memory] - ) - new_memory_texts = "\n- " + "\n- ".join( - [f"{one.id}: {one.memory}" for one in new_order_working_memory] - ) - - logger.info( - f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': " - f"Scheduler replaced working memory based on query history {queries}. " - f"Old working memory ({len(cur_working_memory)} items): {old_memory_texts}. " - f"New working memory ({len(new_order_working_memory)} items): {new_memory_texts}." 
- ) - - # update activation memories - logger.debug( - f"Activation memory update {'enabled' if self.enable_activation_memory else 'disabled'} " - f"(interval: {self.monitor.act_mem_update_interval}s)" - ) - if self.enable_activation_memory: - self.update_activation_memory_periodically( - interval_seconds=self.monitor.act_mem_update_interval, - label=QUERY_TASK_LABEL, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - ) - - def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") - # Process the query in a session turn - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=ADD_TASK_LABEL) - try: - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - # Process each message in the batch - for msg in batch: - prepared_add_items, prepared_update_items_with_original = ( - self.log_add_messages(msg=msg) - ) - logger.info( - f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}" - ) - # Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default - cloud_env = is_cloud_env() - - if cloud_env: - self.send_add_log_messages_to_cloud_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - else: - self.send_add_log_messages_to_local_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - - except Exception as e: - logger.error(f"Error: {e}", exc_info=True) - - def _memory_update_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") - - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - # Process the whole batch once; no need to iterate per message - self.long_memory_update_process( - user_id=user_id, mem_cube_id=mem_cube_id, messages=batch - ) - - def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - """ - Process and handle query trigger messages from the queue. 
- - Args: - messages: List of query messages to process - """ - logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.") - - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=QUERY_TASK_LABEL) - - mem_update_messages = [] - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - for msg in batch: - try: - event = self.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=[ - { - "content": f"[User] {msg.content}", - "ref_id": msg.item_id, - "role": "user", - } - ], - metadata=[], - memory_len=1, - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self._submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for query") - # Re-submit the message with label changed to mem_update - update_msg = ScheduleMessageItem( - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - label=MEM_UPDATE_TASK_LABEL, - content=msg.content, - session_id=msg.session_id, - user_name=msg.user_name, - info=msg.info, - task_id=msg.task_id, - ) - mem_update_messages.append(update_msg) - - self.submit_messages(messages=mem_update_messages) - - def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - """ - Process and handle answer trigger messages from the queue. - - Args: - messages: List of answer messages to process - """ - logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=ANSWER_TASK_LABEL) - - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - try: - for msg in batch: - event = self.create_event_log( - label="addMessage", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=NOT_APPLICABLE_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=[ - { - "content": f"[Assistant] {msg.content}", - "ref_id": msg.item_id, - "role": "assistant", - } - ], - metadata=[], - memory_len=1, - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - self._submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for answer") - - def log_add_messages(self, msg: ScheduleMessageItem): - try: - userinput_memory_ids = json.loads(msg.content) - except Exception as e: - logger.error(f"Error: {e}. 
Content: {msg.content}", exc_info=True) - userinput_memory_ids = [] - - # Prepare data for both logging paths, fetching original content for updates - prepared_add_items = [] - prepared_update_items_with_original = [] - missing_ids: list[str] = [] - - for memory_id in userinput_memory_ids: - try: - # This mem_item represents the NEW content that was just added/processed - mem_item: TextualMemoryItem | None = None - mem_item = self.mem_cube.text_mem.get( - memory_id=memory_id, user_name=msg.mem_cube_id - ) - if mem_item is None: - raise ValueError(f"Memory {memory_id} not found after retries") - # Check if a memory with the same key already exists (determining if it's an update) - key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( - name=mem_item.memory - ) - exists = False - original_content = None - original_item_id = None - - # Only check graph_store if a key exists and the text_mem has a graph_store - if key and hasattr(self.mem_cube.text_mem, "graph_store"): - candidates = self.mem_cube.text_mem.graph_store.get_by_metadata( - [ - {"field": "key", "op": "=", "value": key}, - { - "field": "memory_type", - "op": "=", - "value": mem_item.metadata.memory_type, - }, - ] - ) - if candidates: - exists = True - original_item_id = candidates[0] - # Crucial step: Fetch the original content for updates - # This `get` is for the *existing* memory that will be updated - original_mem_item = self.mem_cube.text_mem.get( - memory_id=original_item_id, user_name=msg.mem_cube_id - ) - original_content = original_mem_item.memory - - if exists: - prepared_update_items_with_original.append( - { - "new_item": mem_item, - "original_content": original_content, - "original_item_id": original_item_id, - } - ) - else: - prepared_add_items.append(mem_item) - - except Exception: - missing_ids.append(memory_id) - logger.debug( - f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation." - ) - - if missing_ids: - content_preview = ( - msg.content[:200] + "..." - if isinstance(msg.content, str) and len(msg.content) > 200 - else msg.content - ) - logger.warning( - "Missing TextualMemoryItem(s) during add log preparation. " - "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s", - missing_ids, - msg.user_id, - msg.mem_cube_id, - msg.task_id, - msg.item_id, - getattr(msg, "redis_message_id", ""), - msg.label, - getattr(msg, "stream_key", ""), - content_preview, - ) - - if not prepared_add_items and not prepared_update_items_with_original: - logger.warning( - "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. 
" - "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s", - msg.user_id, - msg.mem_cube_id, - msg.task_id, - msg.item_id, - getattr(msg, "redis_message_id", ""), - msg.label, - getattr(msg, "stream_key", ""), - missing_ids, - ) - return prepared_add_items, prepared_update_items_with_original - - def send_add_log_messages_to_local_env( - self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original - ): - # Existing: Playground/Default Logging - # Reconstruct add_content/add_meta/update_content/update_meta from prepared_items - # This ensures existing logging path continues to work with pre-existing data structures - add_content_legacy: list[dict] = [] - add_meta_legacy: list[dict] = [] - update_content_legacy: list[dict] = [] - update_meta_legacy: list[dict] = [] - - for item in prepared_add_items: - key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) - add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) - add_meta_legacy.append( - { - "ref_id": item.id, - "id": item.id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - - for item_data in prepared_update_items_with_original: - item = item_data["new_item"] - key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) - update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) - update_meta_legacy.append( - { - "ref_id": item.id, - "id": item.id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - - events = [] - if add_content_legacy: - event = self.create_event_log( - label="addMemory", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=add_content_legacy, - metadata=add_meta_legacy, - memory_len=len(add_content_legacy), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - events.append(event) - if update_content_legacy: - event = self.create_event_log( - label="updateMemory", - from_memory_type=LONG_TERM_MEMORY_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=update_content_legacy, - metadata=update_meta_legacy, - memory_len=len(update_content_legacy), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.task_id = msg.task_id - events.append(event) - logger.info(f"send_add_log_messages_to_local_env: {len(events)}") - if events: - self._submit_web_logs(events, additional_log_info="send_add_log_messages_to_cloud_env") - - def send_add_log_messages_to_cloud_env( - self, msg: ScheduleMessageItem, prepared_add_items, prepared_update_items_with_original - ): - """ - Cloud logging path for add/update events. 
- """ - kb_log_content: list[dict] = [] - info = msg.info or {} - - # Process added items - for item in prepared_add_items: - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages"), - "operation": "ADD", - "memory_id": item.id, - "content": item.memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - - # Process updated items - for item_data in prepared_update_items_with_original: - item = item_data["new_item"] - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages"), - "operation": "UPDATE", - "memory_id": item.id, - "content": item.memory, - "original_content": item_data.get("original_content"), - "source_doc_id": source_doc_id, - } - ) - - if kb_log_content: - logger.info( - f"[DIAGNOSTIC] general_scheduler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {msg.user_id}, mem_cube_id: {msg.mem_cube_id}, task_id: {msg.task_id}. KB content: {json.dumps(kb_log_content, indent=2)}" - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(msg.mem_cube_id), - ) - event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes." 
- event.task_id = msg.task_id - self._submit_web_logs([event]) - - def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - try: - if not messages: - return - message = messages[0] - mem_cube = self.mem_cube - - user_id = message.user_id - mem_cube_id = message.mem_cube_id - content = message.content - - try: - feedback_data = json.loads(content) if isinstance(content, str) else content - if not isinstance(feedback_data, dict): - logger.error( - f"Failed to decode feedback_data or it is not a dict: {feedback_data}" - ) - return - except json.JSONDecodeError: - logger.error(f"Invalid JSON content for feedback message: {content}", exc_info=True) - return - - task_id = feedback_data.get("task_id") or message.task_id - feedback_result = self.feedback_server.process_feedback( - user_id=user_id, - user_name=mem_cube_id, - session_id=feedback_data.get("session_id"), - chat_history=feedback_data.get("history", []), - retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), - feedback_content=feedback_data.get("feedback_content"), - feedback_time=feedback_data.get("feedback_time"), - task_id=task_id, - info=feedback_data.get("info", None), - ) - - logger.info( - f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - cloud_env = is_cloud_env() - if cloud_env: - record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} - add_records = record.get("add") if isinstance(record, dict) else [] - update_records = record.get("update") if isinstance(record, dict) else [] - - def _extract_fields(mem_item): - mem_id = ( - getattr(mem_item, "id", None) - if not isinstance(mem_item, dict) - else mem_item.get("id") - ) - mem_memory = ( - getattr(mem_item, "memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("memory") or mem_item.get("text") - ) - if mem_memory is None and isinstance(mem_item, dict): - mem_memory = mem_item.get("text") - original_content = ( - getattr(mem_item, "origin_memory", None) - if not isinstance(mem_item, dict) - else mem_item.get("origin_memory") - or mem_item.get("old_memory") - or mem_item.get("original_content") - ) - source_doc_id = None - if isinstance(mem_item, dict): - source_doc_id = mem_item.get("source_doc_id", None) - - return mem_id, mem_memory, original_content, source_doc_id - - kb_log_content: list[dict] = [] - - for mem_item in add_records or []: - mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "ADD", - "memory_id": mem_id, - "content": mem_memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - for mem_item in update_records or []: - mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) - if mem_id and mem_memory: - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": "Feedback", - "operation": "UPDATE", - "memory_id": mem_id, - "content": mem_memory, - "original_content": original_content, - "source_doc_id": source_doc_id, - } - ) - else: - logger.warning( - "Skipping malformed feedback update item. 
user_id=%s mem_cube_id=%s task_id=%s item=%s", - user_id, - mem_cube_id, - task_id, - mem_item, - stack_info=True, - ) - - logger.info(f"[Feedback Scheduler] kb_log_content: {kb_log_content!s}") - if kb_log_content: - logger.info( - "[DIAGNOSTIC] general_scheduler._mem_feedback_message_consumer: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s", - user_id, - mem_cube_id, - task_id, - len(kb_log_content), - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = ( - f"Knowledge Base Memory Update: {len(kb_log_content)} changes." - ) - event.task_id = task_id - self._submit_web_logs([event]) - else: - logger.warning( - "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", - user_id, - mem_cube_id, - task_id, - stack_info=True, - ) - else: - logger.info( - "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)", - cloud_env, - ) - - except Exception as e: - logger.error(f"Error processing feedbackMemory message: {e}", exc_info=True) - - def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info( - f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}" - ) - logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.mem_cube - if mem_cube is None: - logger.error( - f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing", - stack_info=True, - ) - return - - content = message.content - user_name = message.user_name - info = message.info or {} - chat_history = message.chat_history - - # Parse the memory IDs from content - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return - - logger.info( - f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" - ) - - # Get the text memory from the mem_cube - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") - return - - # Use mem_reader to process the memories - self._process_memories_with_reader( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - text_mem=text_mem, - user_name=user_name, - custom_tags=info.get("custom_tags", None), - task_id=message.task_id, - info=info, - chat_history=chat_history, - ) - - logger.info( - f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - except Exception as e: - logger.error(f"Error processing mem_read message: {e}", stack_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", stack_info=True) - - def _process_memories_with_reader( - self, - mem_ids: list[str], - user_id: str, - mem_cube_id: str, - 
text_mem: TreeTextMemory, - user_name: str, - custom_tags: list[str] | None = None, - task_id: str | None = None, - info: dict | None = None, - chat_history: list | None = None, - ) -> None: - logger.info( - f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}" - ) - """ - Process memories using mem_reader for enhanced memory processing. - - Args: - mem_ids: List of memory IDs to process - user_id: User ID - mem_cube_id: Memory cube ID - text_mem: Text memory instance - custom_tags: Optional list of custom tags for memory processing - """ - kb_log_content: list[dict] = [] - try: - # Get the mem_reader from the parent MOSCore - if not hasattr(self, "mem_reader") or self.mem_reader is None: - logger.warning( - "mem_reader not available in scheduler, skipping enhanced processing" - ) - return - - # Get the original memory items - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning( - f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}" - ) - continue - - if not memory_items: - logger.warning("No valid memory items found for processing") - return - - # parse working_binding ids from the *original* memory_items (the raw items created in /add) - # these still carry metadata.background with "[working_binding:...]" so we can know - # which WorkingMemory clones should be cleaned up later. - from memos.memories.textual.tree_text_memory.organize.manager import ( - extract_working_binding_ids, - ) - - bindings_to_delete = extract_working_binding_ids(memory_items) - logger.info( - f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" - ) - - # Use mem_reader to process the memories - logger.info(f"Processing {len(memory_items)} memories with mem_reader") - - # Extract memories using mem_reader - try: - processed_memories = self.mem_reader.fine_transfer_simple_mem( - memory_items, - type="chat", - custom_tags=custom_tags, - user_name=user_name, - chat_history=chat_history, - ) - except Exception as e: - logger.warning(f"{e}: Fail to transfer mem: {memory_items}") - processed_memories = [] - - if processed_memories and len(processed_memories) > 0: - # Flatten the results (mem_reader returns list of lists) - flattened_memories = [] - for memory_list in processed_memories: - flattened_memories.extend(memory_list) - - logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories") - - # Add the enhanced memories back to the memory system - if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) - logger.info( - f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" - ) - - # Mark merged_from memories as archived when provided in memory metadata - if self.mem_reader.graph_db: - for memory in flattened_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - self.mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - f"[Scheduler] Archived merged_from memory: {old_id}" - ) - except Exception as e: - logger.warning( - f"[Scheduler] Failed to archive merged_from memory {old_id}: {e}" - ) - else: - # Check if any 
memory has merged_from but graph_db is unavailable - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in flattened_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." - ) - - # LOGGING BLOCK START - # This block is replicated from _add_message_consumer to ensure consistent logging - cloud_env = is_cloud_env() - if cloud_env: - # New: Knowledge Base Logging (Cloud Service) - kb_log_content = [] - for item in flattened_memories: - metadata = getattr(item, "metadata", None) - file_ids = getattr(metadata, "file_ids", None) if metadata else None - source_doc_id = ( - file_ids[0] if isinstance(file_ids, list) and file_ids else None - ) - kb_log_content.append( - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": info.get("trigger_source", "Messages") - if info - else "Messages", - "operation": "ADD", - "memory_id": item.id, - "content": item.memory, - "original_content": None, - "source_doc_id": source_doc_id, - } - ) - if kb_log_content: - logger.info( - f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}. KB content: {json.dumps(kb_log_content, indent=2)}" - ) - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = ( - f"Knowledge Base Memory Update: {len(kb_log_content)} changes." - ) - event.task_id = task_id - self._submit_web_logs([event]) - else: - # Existing: Playground/Default Logging - add_content_legacy: list[dict] = [] - add_meta_legacy: list[dict] = [] - for item_id, item in zip( - enhanced_mem_ids, flattened_memories, strict=False - ): - key = getattr(item.metadata, "key", None) or transform_name_to_key( - name=item.memory - ) - add_content_legacy.append( - {"content": f"{key}: {item.memory}", "ref_id": item_id} - ) - add_meta_legacy.append( - { - "ref_id": item_id, - "id": item_id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - if add_content_legacy: - event = self.create_event_log( - label="addMemory", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=add_content_legacy, - metadata=add_meta_legacy, - memory_len=len(add_content_legacy), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.task_id = task_id - self._submit_web_logs([event]) - # LOGGING BLOCK END - else: - logger.info("No enhanced memories generated by mem_reader") - else: - logger.info("mem_reader returned no processed memories") - - # build full delete list: - # - original raw mem_ids (temporary fast memories) - # - any bound working memories referenced by the enhanced memories - delete_ids = list(mem_ids) - if bindings_to_delete: - delete_ids.extend(list(bindings_to_delete)) - # deduplicate - delete_ids = list(dict.fromkeys(delete_ids)) - if 
delete_ids: - try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" - ) - except Exception as e: - logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}") - else: - logger.info("No mem_ids to delete (nothing to cleanup)") - - text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) - logger.info("Remove and Refresh Memories") - logger.debug(f"Finished add {user_id} memory: {mem_ids}") - - except Exception as exc: - logger.error( - f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True - ) - with contextlib.suppress(Exception): - cloud_env = is_cloud_env() - if cloud_env: - if not kb_log_content: - trigger_source = ( - info.get("trigger_source", "Messages") if info else "Messages" - ) - kb_log_content = [ - { - "log_source": "KNOWLEDGE_BASE_LOG", - "trigger_source": trigger_source, - "operation": "ADD", - "memory_id": mem_id, - "content": None, - "original_content": None, - "source_doc_id": None, - } - for mem_id in mem_ids - ] - event = self.create_event_log( - label="knowledgeBaseUpdate", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=self.mem_cube, - memcube_log_content=kb_log_content, - metadata=None, - memory_len=len(kb_log_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.log_content = f"Knowledge Base Memory Update failed: {exc!s}" - event.task_id = task_id - event.status = "failed" - self._submit_web_logs([event]) - - def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = self.mem_cube - if mem_cube is None: - logger.warning( - f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing" - ) - return - content = message.content - user_name = message.user_name - - # Parse the memory IDs from content - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return - - logger.info( - f"Processing mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" - ) - - # Get the text memory from the mem_cube - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") - return - - # Use mem_reader to process the memories - self._process_memories_with_reorganize( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - text_mem=text_mem, - user_name=user_name, - ) - - with contextlib.suppress(Exception): - mem_items: list[TextualMemoryItem] = [] - for mid in mem_ids: - with contextlib.suppress(Exception): - mem_items.append(text_mem.get(mid, user_name=user_name)) - if len(mem_items) > 1: - keys: list[str] = [] - memcube_content: list[dict] = [] - meta: list[dict] = [] - merged_target_ids: set[str] = set() - with contextlib.suppress(Exception): - if hasattr(text_mem, "graph_store"): - for mid in mem_ids: - edges = text_mem.graph_store.get_edges( - mid, type="MERGED_TO", direction="OUT" - ) - for edge in edges: - target = ( - edge.get("to") or edge.get("dst") or edge.get("target") - ) - if target: - merged_target_ids.add(target) - for item in mem_items: - key = getattr( - 
getattr(item, "metadata", {}), "key", None - ) or transform_name_to_key(getattr(item, "memory", "")) - keys.append(key) - memcube_content.append( - {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} - ) - meta.append( - { - "ref_id": item.id, - "id": item.id, - "key": key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } - ) - combined_key = keys[0] if keys else "" - post_ref_id = None - post_meta = { - "ref_id": None, - "id": None, - "key": None, - "memory": None, - "memory_type": None, - "status": None, - "confidence": None, - "tags": None, - "updated_at": None, - } - if merged_target_ids: - post_ref_id = next(iter(merged_target_ids)) - with contextlib.suppress(Exception): - merged_item = text_mem.get(post_ref_id, user_name=user_name) - combined_key = ( - getattr(getattr(merged_item, "metadata", {}), "key", None) - or combined_key - ) - post_meta = { - "ref_id": post_ref_id, - "id": post_ref_id, - "key": getattr( - getattr(merged_item, "metadata", {}), "key", None - ), - "memory": getattr(merged_item, "memory", None), - "memory_type": getattr( - getattr(merged_item, "metadata", {}), "memory_type", None - ), - "status": getattr( - getattr(merged_item, "metadata", {}), "status", None - ), - "confidence": getattr( - getattr(merged_item, "metadata", {}), "confidence", None - ), - "tags": getattr( - getattr(merged_item, "metadata", {}), "tags", None - ), - "updated_at": getattr( - getattr(merged_item, "metadata", {}), "updated_at", None - ) - or getattr( - getattr(merged_item, "metadata", {}), "update_at", None - ), - } - if not post_ref_id: - import hashlib - - post_ref_id = f"merge-{hashlib.md5(''.join(sorted(mem_ids)).encode()).hexdigest()}" - post_meta["ref_id"] = post_ref_id - post_meta["id"] = post_ref_id - if not post_meta.get("key"): - post_meta["key"] = combined_key - if not keys: - keys = [item.id for item in mem_items] - memcube_content.append( - { - "content": combined_key if combined_key else "(no key)", - "ref_id": post_ref_id, - "type": "postMerge", - } - ) - meta.append(post_meta) - event = self.create_event_log( - label="mergeMemory", - from_memory_type=LONG_TERM_MEMORY_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=memcube_content, - metadata=meta, - memory_len=len(keys), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - self._submit_web_logs([event]) - - logger.info( - f"Successfully processed mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - except Exception as e: - logger.error(f"Error processing mem_reorganize message: {e}", exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", exc_info=True) - - def _process_memories_with_reorganize( - self, - mem_ids: list[str], - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - text_mem: TreeTextMemory, - user_name: str, - ) -> None: - """ - Process memories using mem_reorganize for enhanced memory processing. 
- - Args: - mem_ids: List of memory IDs to process - user_id: User ID - mem_cube_id: Memory cube ID - mem_cube: Memory cube instance - text_mem: Text memory instance - """ - try: - # Get the mem_reader from the parent MOSCore - if not hasattr(self, "mem_reader") or self.mem_reader is None: - logger.warning( - "mem_reader not available in scheduler, skipping enhanced processing" - ) - return - - # Get the original memory items - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") - continue - - if not memory_items: - logger.warning("No valid memory items found for processing") - return - - # Use mem_reader to process the memories - logger.info(f"Processing {len(memory_items)} memories with mem_reader") - text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) - logger.info("Remove and Refresh Memories") - logger.debug(f"Finished add {user_id} memory: {mem_ids}") - - except Exception: - logger.error( - f"Error in _process_memories_with_reorganize: {traceback.format_exc()}", - exc_info=True, - ) - - def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.") - - def process_message(message: ScheduleMessageItem): - try: - mem_cube = self.mem_cube - if mem_cube is None: - logger.warning( - f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing" - ) - return - - user_id = message.user_id - session_id = message.session_id - mem_cube_id = message.mem_cube_id - content = message.content - messages_list = json.loads(content) - info = message.info or {} - - logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") - - # Get the preference memory from the mem_cube - pref_mem = mem_cube.pref_mem - if pref_mem is None: - logger.warning( - f"Preference memory not initialized for mem_cube_id={mem_cube_id}, " - f"skipping pref_add processing" - ) - return - if not isinstance(pref_mem, PreferenceTextMemory): - logger.error( - f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} " - f"for mem_cube_id={mem_cube_id}" - ) - return - - # Use pref_mem.get_memory to process the memories - pref_memories = pref_mem.get_memory( - messages_list, - type="chat", - info={ - **info, - "user_id": user_id, - "session_id": session_id, - "mem_cube_id": mem_cube_id, - }, - ) - # Add pref_mem to vector db - pref_ids = pref_mem.add(pref_memories) - - logger.info( - f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" - ) - - except Exception as e: - logger.error(f"Error processing pref_add message: {e}", exc_info=True) - - with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", exc_info=True) - - def process_session_turn( - self, - queries: str | list[str], - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - top_k: int = 10, - ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: - """ - Process a dialog turn: - - If q_list reaches window size, trigger retrieval; - - Immediately 
switch to the new memory if retrieval is triggered. - """ - - text_mem_base = mem_cube.text_mem - if not isinstance(text_mem_base, TreeTextMemory): - if isinstance(text_mem_base, NaiveTextMemory): - logger.debug( - f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search." - ) - # Treat NaiveTextMemory similar to TreeTextMemory but with simpler logic - # We will perform retrieval to get "working memory" candidates for activation memory - # But we won't have a distinct "current working memory" - cur_working_memory = [] - else: - logger.warning( - f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} " - f"for mem_cube_id={mem_cube_id}, user_id={user_id}. " - f"text_mem_base value: {text_mem_base}" - ) - return [], [] - else: - cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory( - user_name=mem_cube_id - ) - cur_working_memory = cur_working_memory[:top_k] - - logger.info( - f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - - text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] - intent_result = self.monitor.detect_intent( - q_list=queries, text_working_memory=text_working_memory - ) - - time_trigger_flag = False - if self.monitor.timed_trigger( - last_time=self.monitor.last_query_consume_time, - interval_seconds=self.monitor.query_trigger_interval, - ): - time_trigger_flag = True - - if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): - logger.info( - f"[process_session_turn] Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}" - ) - return - elif (not intent_result["trigger_retrieval"]) and time_trigger_flag: - logger.info( - f"[process_session_turn] Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) - intent_result["trigger_retrieval"] = True - intent_result["missing_evidences"] = queries - else: - logger.info( - f"[process_session_turn] Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. 
" - f"Missing evidences: {intent_result['missing_evidences']}" - ) - - missing_evidences = intent_result["missing_evidences"] - num_evidence = len(missing_evidences) - k_per_evidence = max(1, top_k // max(1, num_evidence)) - new_candidates = [] - for item in missing_evidences: - logger.info( - f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" - ) - - search_args = {} - if isinstance(text_mem_base, NaiveTextMemory): - # NaiveTextMemory doesn't support complex search args usually, but let's see - # self.retriever.search calls mem_cube.text_mem.search - # NaiveTextMemory.search takes query and top_k - # SchedulerRetriever.search handles method dispatch - # For NaiveTextMemory, we might need to bypass retriever or extend it - # But let's try calling naive memory directly if retriever fails or doesn't support it - try: - results = text_mem_base.search(query=item, top_k=k_per_evidence) - except Exception as e: - logger.warning(f"NaiveTextMemory search failed: {e}") - results = [] - else: - results: list[TextualMemoryItem] = self.retriever.search( - query=item, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - top_k=k_per_evidence, - method=self.search_method, - search_args=search_args, - ) - - logger.info( - f"[process_session_turn] Search results for missing evidence '{item}': " - + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in results])) - ) - new_candidates.extend(results) - return cur_working_memory, new_candidates + services = SchedulerHandlerServices( + validate_messages=self.validate_schedule_messages, + submit_messages=self.submit_messages, + create_event_log=self.create_event_log, + submit_web_logs=self._submit_web_logs, + map_memcube_name=self._map_memcube_name, + update_activation_memory_periodically=self.update_activation_memory_periodically, + replace_working_memory=self.replace_working_memory, + transform_working_memories_to_monitors=self.transform_working_memories_to_monitors, + log_working_memory_replacement=self.log_working_memory_replacement, + ) + ctx = SchedulerHandlerContext( + get_mem_cube=lambda: self.mem_cube, + get_monitor=lambda: self.monitor, + get_retriever=lambda: self.retriever, + get_mem_reader=lambda: self.mem_reader, + get_feedback_server=lambda: self.feedback_server, + get_search_method=lambda: self.search_method, + get_top_k=lambda: self.top_k, + get_enable_activation_memory=lambda: self.enable_activation_memory, + get_query_key_words_limit=lambda: self.query_key_words_limit, + services=services, + ) + + self._handler_registry = SchedulerHandlerRegistry(ctx) + self.register_handlers(self._handler_registry.build_dispatch_map()) diff --git a/src/memos/mem_scheduler/handlers/__init__.py b/src/memos/mem_scheduler/handlers/__init__.py new file mode 100644 index 000000000..75c56791a --- /dev/null +++ b/src/memos/mem_scheduler/handlers/__init__.py @@ -0,0 +1,9 @@ +from .context import SchedulerHandlerContext, SchedulerHandlerServices +from .registry import SchedulerHandlerRegistry + + +__all__ = [ + "SchedulerHandlerContext", + "SchedulerHandlerRegistry", + "SchedulerHandlerServices", +] diff --git a/src/memos/mem_scheduler/handlers/add_handler.py b/src/memos/mem_scheduler/handlers/add_handler.py new file mode 100644 index 000000000..5d1a8d3e0 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/add_handler.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import json + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from 
memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + LONG_TERM_MEMORY_TYPE, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env + + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + + +logger = get_logger(__name__) + + +class AddMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=ADD_TASK_LABEL) + try: + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + + for msg in batch: + prepared_add_items, prepared_update_items_with_original = ( + self.log_add_messages(msg=msg) + ) + logger.info( + "prepared_add_items: %s;\n prepared_update_items_with_original: %s", + prepared_add_items, + prepared_update_items_with_original, + ) + cloud_env = is_cloud_env() + + if cloud_env: + self.send_add_log_messages_to_cloud_env( + msg, prepared_add_items, prepared_update_items_with_original + ) + else: + self.send_add_log_messages_to_local_env( + msg, prepared_add_items, prepared_update_items_with_original + ) + + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + + def log_add_messages(self, msg: ScheduleMessageItem): + try: + userinput_memory_ids = json.loads(msg.content) + except Exception as e: + logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) + userinput_memory_ids = [] + + prepared_add_items = [] + prepared_update_items_with_original = [] + missing_ids: list[str] = [] + + mem_cube = self.ctx.get_mem_cube() + + for memory_id in userinput_memory_ids: + try: + mem_item: TextualMemoryItem | None = None + mem_item = mem_cube.text_mem.get(memory_id=memory_id, user_name=msg.mem_cube_id) + if mem_item is None: + raise ValueError(f"Memory {memory_id} not found after retries") + key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( + name=mem_item.memory + ) + exists = False + original_content = None + original_item_id = None + + if key and hasattr(mem_cube.text_mem, "graph_store"): + candidates = mem_cube.text_mem.graph_store.get_by_metadata( + [ + {"field": "key", "op": "=", "value": key}, + { + "field": "memory_type", + "op": "=", + "value": mem_item.metadata.memory_type, + }, + ] + ) + if candidates: + exists = True + original_item_id = candidates[0] + original_mem_item = mem_cube.text_mem.get( + memory_id=original_item_id, user_name=msg.mem_cube_id + ) + original_content = original_mem_item.memory + + if exists: + prepared_update_items_with_original.append( + { + "new_item": mem_item, + "original_content": original_content, + "original_item_id": original_item_id, + } + ) + else: + prepared_add_items.append(mem_item) + + except Exception: + missing_ids.append(memory_id) + logger.debug( + "This MemoryItem %s has already been deleted or an error occurred during preparation.", + memory_id, + ) + + if missing_ids: + content_preview = ( + msg.content[:200] + "..." 
+ if isinstance(msg.content, str) and len(msg.content) > 200 + else msg.content + ) + logger.warning( + "Missing TextualMemoryItem(s) during add log preparation. " + "memory_ids=%s user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s content_preview=%s", + missing_ids, + msg.user_id, + msg.mem_cube_id, + msg.task_id, + msg.item_id, + getattr(msg, "redis_message_id", ""), + msg.label, + getattr(msg, "stream_key", ""), + content_preview, + ) + + if not prepared_add_items and not prepared_update_items_with_original: + logger.warning( + "No add/update items prepared; skipping addMemory/knowledgeBaseUpdate logs. " + "user_id=%s mem_cube_id=%s task_id=%s item_id=%s redis_msg_id=%s label=%s stream_key=%s missing_ids=%s", + msg.user_id, + msg.mem_cube_id, + msg.task_id, + msg.item_id, + getattr(msg, "redis_message_id", ""), + msg.label, + getattr(msg, "stream_key", ""), + missing_ids, + ) + return prepared_add_items, prepared_update_items_with_original + + def send_add_log_messages_to_local_env( + self, + msg: ScheduleMessageItem, + prepared_add_items, + prepared_update_items_with_original, + ) -> None: + add_content_legacy: list[dict] = [] + add_meta_legacy: list[dict] = [] + update_content_legacy: list[dict] = [] + update_meta_legacy: list[dict] = [] + + for item in prepared_add_items: + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + add_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) + add_meta_legacy.append( + { + "ref_id": item.id, + "id": item.id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + + for item_data in prepared_update_items_with_original: + item = item_data["new_item"] + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + update_content_legacy.append({"content": f"{key}: {item.memory}", "ref_id": item.id}) + update_meta_legacy.append( + { + "ref_id": item.id, + "id": item.id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + + events = [] + if add_content_legacy: + event = self.ctx.services.create_event_log( + label="addMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=add_content_legacy, + metadata=add_meta_legacy, + memory_len=len(add_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + events.append(event) + if update_content_legacy: + event = self.ctx.services.create_event_log( + label="updateMemory", + from_memory_type=LONG_TERM_MEMORY_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=update_content_legacy, + metadata=update_meta_legacy, + memory_len=len(update_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + 
+            events.append(event)
+        logger.info("send_add_log_messages_to_local_env: submitting %s event(s)", len(events))
+        if events:
+            self.ctx.services.submit_web_logs(
+                events, additional_log_info="send_add_log_messages_to_local_env"
+            )
+
+    def send_add_log_messages_to_cloud_env(
+        self,
+        msg: ScheduleMessageItem,
+        prepared_add_items,
+        prepared_update_items_with_original,
+    ) -> None:
+        kb_log_content: list[dict] = []
+        info = msg.info or {}
+
+        for item in prepared_add_items:
+            metadata = getattr(item, "metadata", None)
+            file_ids = getattr(metadata, "file_ids", None) if metadata else None
+            source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+            kb_log_content.append(
+                {
+                    "log_source": "KNOWLEDGE_BASE_LOG",
+                    "trigger_source": info.get("trigger_source", "Messages"),
+                    "operation": "ADD",
+                    "memory_id": item.id,
+                    "content": item.memory,
+                    "original_content": None,
+                    "source_doc_id": source_doc_id,
+                }
+            )
+
+        for item_data in prepared_update_items_with_original:
+            item = item_data["new_item"]
+            metadata = getattr(item, "metadata", None)
+            file_ids = getattr(metadata, "file_ids", None) if metadata else None
+            source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
+            kb_log_content.append(
+                {
+                    "log_source": "KNOWLEDGE_BASE_LOG",
+                    "trigger_source": info.get("trigger_source", "Messages"),
+                    "operation": "UPDATE",
+                    "memory_id": item.id,
+                    "content": item.memory,
+                    "original_content": item_data.get("original_content"),
+                    "source_doc_id": source_doc_id,
+                }
+            )
+
+        if kb_log_content:
+            logger.info(
+                "[DIAGNOSTIC] add_handler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s",
+                msg.user_id,
+                msg.mem_cube_id,
+                msg.task_id,
+                json.dumps(kb_log_content, indent=2),
+            )
+            event = self.ctx.services.create_event_log(
+                label="knowledgeBaseUpdate",
+                from_memory_type=USER_INPUT_TYPE,
+                to_memory_type=LONG_TERM_MEMORY_TYPE,
+                user_id=msg.user_id,
+                mem_cube_id=msg.mem_cube_id,
+                mem_cube=self.ctx.get_mem_cube(),
+                memcube_log_content=kb_log_content,
+                metadata=None,
+                memory_len=len(kb_log_content),
+                memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id),
+            )
+            event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
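For the cloud path, each change becomes a flat KNOWLEDGE_BASE_LOG record, and the first entry of a memory's file_ids (when present) is surfaced as source_doc_id. A standalone sketch of that record shape, mirroring the dicts built above; the helper name is illustrative, not part of this patch:

    def kb_log_entry(memory_id: str, content: str, *, operation: str = "ADD",
                     original_content: str | None = None,
                     file_ids: list[str] | None = None,
                     trigger_source: str = "Messages") -> dict:
        # Only a list-typed, non-empty file_ids yields a source_doc_id.
        source_doc_id = file_ids[0] if isinstance(file_ids, list) and file_ids else None
        return {
            "log_source": "KNOWLEDGE_BASE_LOG",
            "trigger_source": trigger_source,
            "operation": operation,
            "memory_id": memory_id,
            "content": content,
            "original_content": original_content,
            "source_doc_id": source_doc_id,
        }

    print(kb_log_entry("m1", "Tenant ACME uses SSO", file_ids=["doc-42"]))
    print(kb_log_entry("m2", "Rewritten fact", operation="UPDATE",
                       original_content="Old fact"))

ADD records carry original_content=None, while UPDATE records keep the pre-update text so the web log can show the change.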
+ event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/handlers/answer_handler.py b/src/memos/mem_scheduler/handlers/answer_handler.py new file mode 100644 index 000000000..9ec4086a4 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/answer_handler.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + NOT_APPLICABLE_TYPE, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class AnswerMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=ANSWER_TASK_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + try: + for msg in batch: + event = self.ctx.services.create_event_log( + label="addMessage", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=NOT_APPLICABLE_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=[ + { + "content": f"[Assistant] {msg.content}", + "ref_id": msg.item_id, + "role": "assistant", + } + ], + metadata=[], + memory_len=1, + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) + except Exception: + logger.exception("Failed to record addMessage log for answer") diff --git a/src/memos/mem_scheduler/handlers/base.py b/src/memos/mem_scheduler/handlers/base.py new file mode 100644 index 000000000..e04add7d7 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/base.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from memos.mem_scheduler.handlers.context import SchedulerHandlerContext + + +class BaseSchedulerHandler: + def __init__(self, ctx: SchedulerHandlerContext) -> None: + self.ctx = ctx diff --git a/src/memos/mem_scheduler/handlers/context.py b/src/memos/mem_scheduler/handlers/context.py new file mode 100644 index 000000000..d5c1ea9af --- /dev/null +++ b/src/memos/mem_scheduler/handlers/context.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + from memos.memories.textual.item import TextualMemoryItem + + +@dataclass(frozen=True) +class SchedulerHandlerServices: + validate_messages: Callable[[list[ScheduleMessageItem], str], None] + submit_messages: Callable[[list[ScheduleMessageItem]], None] + create_event_log: Callable[..., Any] + submit_web_logs: Callable[..., None] + map_memcube_name: Callable[[str], str] + update_activation_memory_periodically: Callable[..., None] + replace_working_memory: Callable[ + 
[str, str, Any, list[TextualMemoryItem], list[TextualMemoryItem]], + list[TextualMemoryItem] | None, + ] + transform_working_memories_to_monitors: Callable[..., list[MemoryMonitorItem]] + log_working_memory_replacement: Callable[..., None] + + +@dataclass(frozen=True) +class SchedulerHandlerContext: + get_mem_cube: Callable[[], Any] + get_monitor: Callable[[], Any] + get_retriever: Callable[[], Any] + get_mem_reader: Callable[[], Any] + get_feedback_server: Callable[[], Any] + get_search_method: Callable[[], str] + get_top_k: Callable[[], int] + get_enable_activation_memory: Callable[[], bool] + get_query_key_words_limit: Callable[[], int] + services: SchedulerHandlerServices diff --git a/src/memos/mem_scheduler/handlers/feedback_handler.py b/src/memos/mem_scheduler/handlers/feedback_handler.py new file mode 100644 index 000000000..cf52470dd --- /dev/null +++ b/src/memos/mem_scheduler/handlers/feedback_handler.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import json + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, USER_INPUT_TYPE +from memos.mem_scheduler.utils.misc_utils import is_cloud_env + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class FeedbackMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + try: + if not messages: + return + message = messages[0] + mem_cube = self.ctx.get_mem_cube() + + user_id = message.user_id + mem_cube_id = message.mem_cube_id + content = message.content + + try: + feedback_data = json.loads(content) if isinstance(content, str) else content + if not isinstance(feedback_data, dict): + logger.error( + "Failed to decode feedback_data or it is not a dict: %s", feedback_data + ) + return + except json.JSONDecodeError: + logger.error( + "Invalid JSON content for feedback message: %s", content, exc_info=True + ) + return + + task_id = feedback_data.get("task_id") or message.task_id + feedback_result = self.ctx.get_feedback_server().process_feedback( + user_id=user_id, + user_name=mem_cube_id, + session_id=feedback_data.get("session_id"), + chat_history=feedback_data.get("history", []), + retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []), + feedback_content=feedback_data.get("feedback_content"), + feedback_time=feedback_data.get("feedback_time"), + task_id=task_id, + info=feedback_data.get("info", None), + ) + + logger.info( + "Successfully processed feedback for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + cloud_env = is_cloud_env() + if cloud_env: + record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} + add_records = record.get("add") if isinstance(record, dict) else [] + update_records = record.get("update") if isinstance(record, dict) else [] + + def _extract_fields(mem_item): + mem_id = ( + getattr(mem_item, "id", None) + if not isinstance(mem_item, dict) + else mem_item.get("id") + ) + mem_memory = ( + getattr(mem_item, "memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("memory") or mem_item.get("text") + ) + if mem_memory is None and isinstance(mem_item, dict): + mem_memory = mem_item.get("text") + original_content = ( + getattr(mem_item, "origin_memory", None) + if not isinstance(mem_item, dict) + else mem_item.get("origin_memory") + or 
mem_item.get("old_memory") + or mem_item.get("original_content") + ) + source_doc_id = None + if isinstance(mem_item, dict): + source_doc_id = mem_item.get("source_doc_id", None) + + return mem_id, mem_memory, original_content, source_doc_id + + kb_log_content: list[dict] = [] + + for mem_item in add_records or []: + mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "ADD", + "memory_id": mem_id, + "content": mem_memory, + "original_content": None, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + for mem_item in update_records or []: + mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item) + if mem_id and mem_memory: + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": "Feedback", + "operation": "UPDATE", + "memory_id": mem_id, + "content": mem_memory, + "original_content": original_content, + "source_doc_id": source_doc_id, + } + ) + else: + logger.warning( + "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s", + user_id, + mem_cube_id, + task_id, + mem_item, + stack_info=True, + ) + + logger.info("[Feedback Scheduler] kb_log_content: %s", kb_log_content) + if kb_log_content: + logger.info( + "[DIAGNOSTIC] feedback_handler: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s", + user_id, + mem_cube_id, + task_id, + len(kb_log_content), + ) + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=kb_log_content, + metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = ( + f"Knowledge Base Memory Update: {len(kb_log_content)} changes." + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + logger.warning( + "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s", + user_id, + mem_cube_id, + task_id, + stack_info=True, + ) + else: + logger.info( + "Skipping web log for feedback. 
Not in a cloud environment (is_cloud_env=%s)", + cloud_env, + ) + + except Exception as e: + logger.error("Error processing feedbackMemory message: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/handlers/mem_read_handler.py b/src/memos/mem_scheduler/handlers/mem_read_handler.py new file mode 100644 index 000000000..76789f113 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/mem_read_handler.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import json +import traceback + +from typing import TYPE_CHECKING + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.task_schemas import ( + LONG_TERM_MEMORY_TYPE, + MEM_READ_TASK_LABEL, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.mem_scheduler.utils.misc_utils import is_cloud_env +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +class MemReadMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info( + "[DIAGNOSTIC] mem_read_handler called. Received messages: %s", + [msg.model_dump_json(indent=2) for msg in messages], + ) + logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.ctx.get_mem_cube() + if mem_cube is None: + logger.error( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + user_id, + mem_cube_id, + stack_info=True, + ) + return + + content = message.content + user_name = message.user_name + info = message.info or {} + chat_history = message.chat_history + + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + "Processing mem_read for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) + + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return + + self._process_memories_with_reader( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + text_mem=text_mem, + user_name=user_name, + custom_tags=info.get("custom_tags", None), + task_id=message.task_id, + info=info, + chat_history=chat_history, + ) + + logger.info( + "Successfully processed mem_read for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_read message: %s", e, stack_info=True) + + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, stack_info=True) + + def _process_memories_with_reader( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + text_mem: TreeTextMemory, + user_name: str, + custom_tags: list[str] | None = None, + task_id: str | None = None, + info: dict | None = None, + chat_history: list | None = None, + ) -> None: + logger.info( + "[DIAGNOSTIC] 
mem_read_handler._process_memories_with_reader called. mem_ids: %s, user_id: %s, mem_cube_id: %s, task_id: %s", + mem_ids, + user_id, + mem_cube_id, + task_id, + ) + kb_log_content: list[dict] = [] + try: + mem_reader = self.ctx.get_mem_reader() + if mem_reader is None: + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) + return + + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id, user_name=user_name) + memory_items.append(memory_item) + except Exception as e: + logger.warning( + "[_process_memories_with_reader] Failed to get memory %s: %s", mem_id, e + ) + continue + + if not memory_items: + logger.warning("No valid memory items found for processing") + return + + from memos.memories.textual.tree_text_memory.organize.manager import ( + extract_working_binding_ids, + ) + + bindings_to_delete = extract_working_binding_ids(memory_items) + logger.info( + "Extracted %s working_binding ids to cleanup: %s", + len(bindings_to_delete), + list(bindings_to_delete), + ) + + logger.info("Processing %s memories with mem_reader", len(memory_items)) + + try: + processed_memories = mem_reader.fine_transfer_simple_mem( + memory_items, + type="chat", + custom_tags=custom_tags, + user_name=user_name, + chat_history=chat_history, + ) + except Exception as e: + logger.warning("%s: Fail to transfer mem: %s", e, memory_items) + processed_memories = [] + + if processed_memories and len(processed_memories) > 0: + flattened_memories = [] + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + + logger.info("mem_reader processed %s enhanced memories", len(flattened_memories)) + + if flattened_memories: + enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) + logger.info( + "Added %s enhanced memories: %s", + len(enhanced_mem_ids), + enhanced_mem_ids, + ) + + if mem_reader.graph_db: + for memory in flattened_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), {"status": "archived"}, user_name=user_name + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in flattened_memories + ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) + + cloud_env = is_cloud_env() + if cloud_env: + kb_log_content = [] + for item in flattened_memories: + metadata = getattr(item, "metadata", None) + file_ids = getattr(metadata, "file_ids", None) if metadata else None + source_doc_id = ( + file_ids[0] if isinstance(file_ids, list) and file_ids else None + ) + kb_log_content.append( + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": info.get("trigger_source", "Messages") + if info + else "Messages", + "operation": "ADD", + "memory_id": item.id, + "content": item.memory, + "original_content": None, + "source_doc_id": source_doc_id, + } + ) + if kb_log_content: + logger.info( + "[DIAGNOSTIC] mem_read_handler: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. 
KB content: %s", + user_id, + mem_cube_id, + task_id, + json.dumps(kb_log_content, indent=2), + ) + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=kb_log_content, + metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = ( + f"Knowledge Base Memory Update: {len(kb_log_content)} changes." + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + add_content_legacy: list[dict] = [] + add_meta_legacy: list[dict] = [] + for item_id, item in zip( + enhanced_mem_ids, flattened_memories, strict=False + ): + key = getattr(item.metadata, "key", None) or transform_name_to_key( + name=item.memory + ) + add_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + add_meta_legacy.append( + { + "ref_id": item_id, + "id": item_id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + if add_content_legacy: + event = self.ctx.services.create_event_log( + label="addMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=add_content_legacy, + metadata=add_meta_legacy, + memory_len=len(add_content_legacy), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.task_id = task_id + self.ctx.services.submit_web_logs([event]) + else: + logger.info("No enhanced memories generated by mem_reader") + else: + logger.info("mem_reader returned no processed memories") + + delete_ids = list(mem_ids) + if bindings_to_delete: + delete_ids.extend(list(bindings_to_delete)) + delete_ids = list(dict.fromkeys(delete_ids)) + if delete_ids: + try: + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name + ) + except Exception as e: + logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) + else: + logger.info("No mem_ids to delete (nothing to cleanup)") + + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) + logger.info("Remove and Refresh Memories") + logger.debug("Finished add %s memory: %s", user_id, mem_ids) + + except Exception as exc: + logger.error( + "Error in _process_memories_with_reader: %s", + traceback.format_exc(), + exc_info=True, + ) + with contextlib.suppress(Exception): + cloud_env = is_cloud_env() + if cloud_env: + if not kb_log_content: + trigger_source = ( + info.get("trigger_source", "Messages") if info else "Messages" + ) + kb_log_content = [ + { + "log_source": "KNOWLEDGE_BASE_LOG", + "trigger_source": trigger_source, + "operation": "ADD", + "memory_id": mem_id, + "content": None, + "original_content": None, + "source_doc_id": None, + } + for mem_id in mem_ids + ] + event = self.ctx.services.create_event_log( + label="knowledgeBaseUpdate", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=kb_log_content, + 
metadata=None, + memory_len=len(kb_log_content), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + event.log_content = f"Knowledge Base Memory Update failed: {exc!s}" + event.task_id = task_id + event.status = "failed" + self.ctx.services.submit_web_logs([event]) diff --git a/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py new file mode 100644 index 000000000..d437ebbd6 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/mem_reorganize_handler.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import json +import traceback + +from typing import TYPE_CHECKING + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.task_schemas import LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_TASK_LABEL +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + + +class MemReorganizeMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + mem_cube_id = message.mem_cube_id + mem_cube = self.ctx.get_mem_cube() + if mem_cube is None: + logger.warning( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + user_id, + mem_cube_id, + ) + return + content = message.content + user_name = message.user_name + + mem_ids = json.loads(content) if isinstance(content, str) else content + if not mem_ids: + return + + logger.info( + "Processing mem_reorganize for user_id=%s, mem_cube_id=%s, mem_ids=%s", + user_id, + mem_cube_id, + mem_ids, + ) + + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error("Expected TreeTextMemory but got %s", type(text_mem).__name__) + return + + self._process_memories_with_reorganize( + mem_ids=mem_ids, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + text_mem=text_mem, + user_name=user_name, + ) + + with contextlib.suppress(Exception): + mem_items: list[TextualMemoryItem] = [] + for mid in mem_ids: + with contextlib.suppress(Exception): + mem_items.append(text_mem.get(mid, user_name=user_name)) + if len(mem_items) > 1: + keys: list[str] = [] + memcube_content: list[dict] = [] + meta: list[dict] = [] + merged_target_ids: set[str] = set() + with contextlib.suppress(Exception): + if hasattr(text_mem, "graph_store"): + for mid in mem_ids: + edges = text_mem.graph_store.get_edges( + mid, type="MERGED_TO", direction="OUT" + ) + for edge in edges: + target = ( + edge.get("to") or edge.get("dst") or edge.get("target") + ) + if target: + merged_target_ids.add(target) + for item in mem_items: + key = getattr( + getattr(item, "metadata", {}), "key", None + ) or transform_name_to_key(getattr(item, "memory", "")) + keys.append(key) + memcube_content.append( + {"content": key or "(no key)", "ref_id": item.id, "type": "merged"} + ) + meta.append( + { + "ref_id": item.id, + "id": item.id, + "key": key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + 
"status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + ) + combined_key = keys[0] if keys else "" + post_ref_id = None + post_meta = { + "ref_id": None, + "id": None, + "key": None, + "memory": None, + "memory_type": None, + "status": None, + "confidence": None, + "tags": None, + "updated_at": None, + } + if merged_target_ids: + post_ref_id = next(iter(merged_target_ids)) + with contextlib.suppress(Exception): + merged_item = text_mem.get(post_ref_id, user_name=user_name) + combined_key = ( + getattr(getattr(merged_item, "metadata", {}), "key", None) + or combined_key + ) + post_meta = { + "ref_id": post_ref_id, + "id": post_ref_id, + "key": getattr( + getattr(merged_item, "metadata", {}), "key", None + ), + "memory": getattr(merged_item, "memory", None), + "memory_type": getattr( + getattr(merged_item, "metadata", {}), "memory_type", None + ), + "status": getattr( + getattr(merged_item, "metadata", {}), "status", None + ), + "confidence": getattr( + getattr(merged_item, "metadata", {}), "confidence", None + ), + "tags": getattr( + getattr(merged_item, "metadata", {}), "tags", None + ), + "updated_at": getattr( + getattr(merged_item, "metadata", {}), "updated_at", None + ) + or getattr( + getattr(merged_item, "metadata", {}), "update_at", None + ), + } + if not post_ref_id: + import hashlib + + post_ref_id = ( + "merge-" + + hashlib.md5("".join(sorted(mem_ids)).encode()).hexdigest() + ) + post_meta["ref_id"] = post_ref_id + post_meta["id"] = post_ref_id + if not post_meta.get("key"): + post_meta["key"] = combined_key + if not keys: + keys = [item.id for item in mem_items] + memcube_content.append( + { + "content": combined_key if combined_key else "(no key)", + "ref_id": post_ref_id, + "type": "postMerge", + } + ) + meta.append(post_meta) + event = self.ctx.services.create_event_log( + label="mergeMemory", + from_memory_type=LONG_TERM_MEMORY_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + memcube_log_content=memcube_content, + metadata=meta, + memory_len=len(keys), + memcube_name=self.ctx.services.map_memcube_name(mem_cube_id), + ) + self.ctx.services.submit_web_logs([event]) + + logger.info( + "Successfully processed mem_reorganize for user_id=%s, mem_cube_id=%s", + user_id, + mem_cube_id, + ) + + except Exception as e: + logger.error("Error processing mem_reorganize message: %s", e, exc_info=True) + + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, exc_info=True) + + def _process_memories_with_reorganize( + self, + mem_ids: list[str], + user_id: str, + mem_cube_id: str, + mem_cube, + text_mem: TreeTextMemory, + user_name: str, + ) -> None: + try: + mem_reader = self.ctx.get_mem_reader() + if mem_reader is None: + logger.warning( + "mem_reader not available in scheduler, skipping enhanced processing" + ) + return + + memory_items = [] + for mem_id in mem_ids: + try: + memory_item = text_mem.get(mem_id, user_name=user_name) + memory_items.append(memory_item) + except Exception as e: + logger.warning( + "Failed to get memory %s: %s|%s", mem_id, e, traceback.format_exc() + ) + continue + + if not memory_items: + 
logger.warning("No valid memory items found for processing") + return + + logger.info("Processing %s memories with mem_reader", len(memory_items)) + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) + logger.info("Remove and Refresh Memories") + logger.debug("Finished add %s memory: %s", user_id, mem_ids) + + except Exception: + logger.error( + "Error in _process_memories_with_reorganize: %s", + traceback.format_exc(), + exc_info=True, + ) diff --git a/src/memos/mem_scheduler/handlers/memory_update_handler.py b/src/memos/mem_scheduler/handlers/memory_update_handler.py new file mode 100644 index 000000000..0d3d1719e --- /dev/null +++ b/src/memos/mem_scheduler/handlers/memory_update_handler.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem +from memos.mem_scheduler.schemas.task_schemas import ( + DEFAULT_MAX_QUERY_KEY_WORDS, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) +from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.tree import TreeTextMemory + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + from memos.memories.textual.item import TextualMemoryItem + from memos.types import MemCubeID, UserID + + +class MemoryUpdateHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") + + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + self.long_memory_update_process( + user_id=user_id, mem_cube_id=mem_cube_id, messages=batch + ) + + def long_memory_update_process( + self, + user_id: str, + mem_cube_id: str, + messages: list[ScheduleMessageItem], + ) -> None: + mem_cube = self.ctx.get_mem_cube() + monitor = self.ctx.get_monitor() + + query_key_words_limit = self.ctx.get_query_key_words_limit() + + for msg in messages: + monitor.register_query_monitor_if_not_exists(user_id=user_id, mem_cube_id=mem_cube_id) + + query = msg.content + query_keywords = monitor.extract_query_keywords(query=query) + logger.info( + 'Extracted keywords "%s" from query "%s" for user_id=%s', + query_keywords, + query, + user_id, + ) + + if len(query_keywords) == 0: + stripped_query = query.strip() + if is_all_english(stripped_query): + words = stripped_query.split() + elif is_all_chinese(stripped_query): + words = stripped_query + else: + logger.debug( + "Mixed-language memory, using character count: %s...", + stripped_query[:50], + ) + words = stripped_query + + query_keywords = list(set(words[:query_key_words_limit])) + logger.error( + "Keyword extraction failed for query '%s' (user_id=%s). Using fallback keywords: %s... 
(truncated)", + query, + user_id, + query_keywords[:10], + exc_info=True, + ) + + item = QueryMonitorItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + query_text=query, + keywords=query_keywords, + max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, + ) + + query_db_manager = monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.obj.put(item=item) + query_db_manager.sync_with_orm() + logger.debug( + "Queries in monitor for user_id=%s, mem_cube_id=%s: %s", + user_id, + mem_cube_id, + query_db_manager.obj.get_queries_with_timesort(), + ) + + queries = [msg.content for msg in messages] + + cur_working_memory, new_candidates = self.process_session_turn( + queries=queries, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=self.ctx.get_top_k(), + ) + logger.info( + "[long_memory_update_process] Processed %s queries %s and retrieved %s new candidate memories for user_id=%s: " + + ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates])), + len(queries), + queries, + len(new_candidates), + user_id, + ) + + new_order_working_memory = self.ctx.services.replace_working_memory( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + logger.debug( + "[long_memory_update_process] Final working memory size: %s memories for user_id=%s", + len(new_order_working_memory), + user_id, + ) + + old_memory_texts = "\n- " + "\n- ".join( + [f"{one.id}: {one.memory}" for one in cur_working_memory] + ) + new_memory_texts = "\n- " + "\n- ".join( + [f"{one.id}: {one.memory}" for one in new_order_working_memory] + ) + + logger.info( + "[long_memory_update_process] For user_id='%s', mem_cube_id='%s': " + "Scheduler replaced working memory based on query history %s. " + "Old working memory (%s items): %s. " + "New working memory (%s items): %s.", + user_id, + mem_cube_id, + queries, + len(cur_working_memory), + old_memory_texts, + len(new_order_working_memory), + new_memory_texts, + ) + + logger.debug( + "Activation memory update %s (interval: %ss)", + "enabled" if self.ctx.get_enable_activation_memory() else "disabled", + monitor.act_mem_update_interval, + ) + if self.ctx.get_enable_activation_memory(): + self.ctx.services.update_activation_memory_periodically( + interval_seconds=monitor.act_mem_update_interval, + label=QUERY_TASK_LABEL, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + def process_session_turn( + self, + queries: str | list[str], + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube, + top_k: int = 10, + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: + text_mem_base = mem_cube.text_mem + if not isinstance(text_mem_base, TreeTextMemory): + if isinstance(text_mem_base, NaiveTextMemory): + logger.debug( + "NaiveTextMemory used for mem_cube_id=%s, processing session turn with simple search.", + mem_cube_id, + ) + cur_working_memory = [] + else: + logger.warning( + "Not implemented! Expected TreeTextMemory but got %s for mem_cube_id=%s, user_id=%s. 
text_mem_base value: %s",
+                    type(text_mem_base).__name__,
+                    mem_cube_id,
+                    user_id,
+                    text_mem_base,
+                )
+                return [], []
+        else:
+            cur_working_memory = text_mem_base.get_working_memory(user_name=mem_cube_id)
+            cur_working_memory = cur_working_memory[:top_k]
+
+        logger.info(
+            "[process_session_turn] Processing %s queries for user_id=%s, mem_cube_id=%s",
+            len(queries),
+            user_id,
+            mem_cube_id,
+        )
+
+        text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
+        monitor = self.ctx.get_monitor()
+        intent_result = monitor.detect_intent(
+            q_list=queries, text_working_memory=text_working_memory
+        )
+
+        time_trigger_flag = False
+        if monitor.timed_trigger(
+            last_time=monitor.last_query_consume_time,
+            interval_seconds=monitor.query_trigger_interval,
+        ):
+            time_trigger_flag = True
+
+        if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag):
+            logger.info(
+                "[process_session_turn] Query schedule not triggered for user_id=%s, mem_cube_id=%s. Intent_result: %s",
+                user_id,
+                mem_cube_id,
+                intent_result,
+            )
+            # Return an empty candidate list instead of bare None so callers
+            # that unpack the result do not raise a TypeError.
+            return cur_working_memory, []
+        if (not intent_result["trigger_retrieval"]) and time_trigger_flag:
+            logger.info(
+                "[process_session_turn] Query schedule forced to trigger due to time ticker for user_id=%s, mem_cube_id=%s",
+                user_id,
+                mem_cube_id,
+            )
+            intent_result["trigger_retrieval"] = True
+            intent_result["missing_evidences"] = queries
+        else:
+            logger.info(
+                "[process_session_turn] Query schedule triggered for user_id=%s, mem_cube_id=%s. Missing evidences: %s",
+                user_id,
+                mem_cube_id,
+                intent_result["missing_evidences"],
+            )
+
+        missing_evidences = intent_result["missing_evidences"]
+        num_evidence = len(missing_evidences)
+        k_per_evidence = max(1, top_k // max(1, num_evidence))
+        new_candidates: list[TextualMemoryItem] = []
+        retriever = self.ctx.get_retriever()
+        search_method = self.ctx.get_search_method()
+
+        for item in missing_evidences:
+            logger.info(
+                "[process_session_turn] Searching for missing evidence: '%s' with top_k=%s for user_id=%s",
+                item,
+                k_per_evidence,
+                user_id,
+            )
+
+            search_args = {}
+            if isinstance(text_mem_base, NaiveTextMemory):
+                try:
+                    results = text_mem_base.search(query=item, top_k=k_per_evidence)
+                except Exception as e:
+                    logger.warning("NaiveTextMemory search failed: %s", e)
+                    results = []
+            else:
+                results = retriever.search(
+                    query=item,
+                    user_id=user_id,
+                    mem_cube_id=mem_cube_id,
+                    mem_cube=mem_cube,
+                    top_k=k_per_evidence,
+                    method=search_method,
+                    search_args=search_args,
+                )
+
+            logger.info(
+                "[process_session_turn] Search results for missing evidence '%s': \n- %s",
+                item,
+                "\n- ".join([f"{one.id}: {one.memory}" for one in results]),
+            )
+            new_candidates.extend(results)
+        return cur_working_memory, new_candidates
diff --git a/src/memos/mem_scheduler/handlers/pref_add_handler.py b/src/memos/mem_scheduler/handlers/pref_add_handler.py
new file mode 100644
index 000000000..4d17b0847
--- /dev/null
+++ b/src/memos/mem_scheduler/handlers/pref_add_handler.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import concurrent.futures
+import json
+
+from typing import TYPE_CHECKING
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_scheduler.handlers.base import BaseSchedulerHandler
+from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL
+from memos.memories.textual.preference import PreferenceTextMemory
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
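PrefAddMessageHandler below expects msg.content to carry a JSON-encoded list of chat turns. A small round-trip sketch under that assumption; the role/content turn shape is assumed here, not defined by this patch:

    import json

    chat_turns = [
        {"role": "user", "content": "Please keep answers short."},
        {"role": "assistant", "content": "Understood."},
    ]

    # Producer side: the scheduler serializes the turns into the message content.
    content = json.dumps(chat_turns)

    # Consumer side (as in process_message below): decode before handing the
    # turns to PreferenceTextMemory.get_memory.
    messages_list = json.loads(content)
    assert messages_list == chat_turns
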
+class PrefAddMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + mem_cube = self.ctx.get_mem_cube() + if mem_cube is None: + logger.warning( + "mem_cube is None for user_id=%s, mem_cube_id=%s, skipping processing", + message.user_id, + message.mem_cube_id, + ) + return + + user_id = message.user_id + session_id = message.session_id + mem_cube_id = message.mem_cube_id + content = message.content + messages_list = json.loads(content) + info = message.info or {} + + logger.info( + "Processing pref_add for user_id=%s, mem_cube_id=%s", user_id, mem_cube_id + ) + + pref_mem = mem_cube.pref_mem + if pref_mem is None: + logger.warning( + "Preference memory not initialized for mem_cube_id=%s, skipping pref_add processing", + mem_cube_id, + ) + return + if not isinstance(pref_mem, PreferenceTextMemory): + logger.error( + "Expected PreferenceTextMemory but got %s for mem_cube_id=%s", + type(pref_mem).__name__, + mem_cube_id, + ) + return + + pref_memories = pref_mem.get_memory( + messages_list, + type="chat", + info={ + **info, + "user_id": user_id, + "session_id": session_id, + "mem_cube_id": mem_cube_id, + }, + ) + pref_ids = pref_mem.add(pref_memories) + + logger.info( + "Successfully processed and add preferences for user_id=%s, mem_cube_id=%s, pref_ids=%s", + user_id, + mem_cube_id, + pref_ids, + ) + + except Exception as e: + logger.error("Error processing pref_add message: %s", e, exc_info=True) + + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error("Thread task failed: %s", e, exc_info=True) diff --git a/src/memos/mem_scheduler/handlers/query_handler.py b/src/memos/mem_scheduler/handlers/query_handler.py new file mode 100644 index 000000000..4d3a09368 --- /dev/null +++ b/src/memos/mem_scheduler/handlers/query_handler.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.handlers.base import BaseSchedulerHandler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + MEM_UPDATE_TASK_LABEL, + NOT_APPLICABLE_TYPE, + QUERY_TASK_LABEL, + USER_INPUT_TYPE, +) +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +logger = get_logger(__name__) + + +class QueryMessageHandler(BaseSchedulerHandler): + def handle(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.") + + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.ctx.services.validate_messages(messages=messages, label=QUERY_TASK_LABEL) + + mem_update_messages: list[ScheduleMessageItem] = [] + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + + for msg in batch: + try: + event = self.ctx.services.create_event_log( + label="addMessage", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=NOT_APPLICABLE_TYPE, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=self.ctx.get_mem_cube(), + memcube_log_content=[ + { + "content": f"[User] 
{msg.content}", + "ref_id": msg.item_id, + "role": "user", + } + ], + metadata=[], + memory_len=1, + memcube_name=self.ctx.services.map_memcube_name(msg.mem_cube_id), + ) + event.task_id = msg.task_id + self.ctx.services.submit_web_logs([event]) + except Exception: + logger.exception("Failed to record addMessage log for query") + + update_msg = ScheduleMessageItem( + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + label=MEM_UPDATE_TASK_LABEL, + content=msg.content, + session_id=msg.session_id, + user_name=msg.user_name, + info=msg.info, + task_id=msg.task_id, + ) + mem_update_messages.append(update_msg) + + self.ctx.services.submit_messages(messages=mem_update_messages) diff --git a/src/memos/mem_scheduler/handlers/registry.py b/src/memos/mem_scheduler/handlers/registry.py new file mode 100644 index 000000000..2a62aa57f --- /dev/null +++ b/src/memos/mem_scheduler/handlers/registry.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from memos.mem_scheduler.handlers.context import SchedulerHandlerContext + +from memos.mem_scheduler.handlers.add_handler import AddMessageHandler +from memos.mem_scheduler.handlers.answer_handler import AnswerMessageHandler +from memos.mem_scheduler.handlers.feedback_handler import FeedbackMessageHandler +from memos.mem_scheduler.handlers.mem_read_handler import MemReadMessageHandler +from memos.mem_scheduler.handlers.mem_reorganize_handler import MemReorganizeMessageHandler +from memos.mem_scheduler.handlers.memory_update_handler import MemoryUpdateHandler +from memos.mem_scheduler.handlers.pref_add_handler import PrefAddMessageHandler +from memos.mem_scheduler.handlers.query_handler import QueryMessageHandler +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_FEEDBACK_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_READ_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + PREF_ADD_TASK_LABEL, + QUERY_TASK_LABEL, +) + + +class SchedulerHandlerRegistry: + def __init__(self, ctx: SchedulerHandlerContext) -> None: + self.query = QueryMessageHandler(ctx) + self.answer = AnswerMessageHandler(ctx) + self.add = AddMessageHandler(ctx) + self.memory_update = MemoryUpdateHandler(ctx) + self.mem_feedback = FeedbackMessageHandler(ctx) + self.mem_read = MemReadMessageHandler(ctx) + self.mem_reorganize = MemReorganizeMessageHandler(ctx) + self.pref_add = PrefAddMessageHandler(ctx) + + def build_dispatch_map(self) -> dict[str, Callable]: + return { + QUERY_TASK_LABEL: self.query.handle, + ANSWER_TASK_LABEL: self.answer.handle, + MEM_UPDATE_TASK_LABEL: self.memory_update.handle, + ADD_TASK_LABEL: self.add.handle, + MEM_READ_TASK_LABEL: self.mem_read.handle, + MEM_ORGANIZE_TASK_LABEL: self.mem_reorganize.handle, + PREF_ADD_TASK_LABEL: self.pref_add.handle, + MEM_FEEDBACK_TASK_LABEL: self.mem_feedback.handle, + } diff --git a/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py new file mode 100644 index 000000000..98125c13b --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/enhancement_pipeline.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import time + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + DEFAULT_SCHEDULER_RETRIEVER_RETRIES, +) +from memos.mem_scheduler.utils.misc_utils import 
extract_json_obj, extract_list_items_in_answer +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata +from memos.types.general_types import FINE_STRATEGY, FineStrategy + + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from collections.abc import Callable + + +class EnhancementPipeline: + def __init__(self, process_llm, config, build_prompt: Callable[..., str]): + self.process_llm = process_llm + self.config = config + self.build_prompt = build_prompt + self.batch_size: int | None = getattr( + config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE + ) + self.retries: int = getattr( + config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES + ) + + def evaluate_memory_answer_ability( + self, query: str, memory_texts: list[str], top_k: int | None = None + ) -> bool: + limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts + prompt = self.build_prompt( + template_name="memory_answer_ability_evaluation", + query=query, + memory_list="\n".join([f"- {memory}" for memory in limited_memories]) + if limited_memories + else "No memories available", + ) + + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + try: + result = extract_json_obj(response) + + if "result" in result: + logger.info( + "Answerability: result=%s; reason=%s; evaluated=%s", + result["result"], + result.get("reason", "n/a"), + len(limited_memories), + ) + return result["result"] + logger.warning("Answerability: invalid LLM JSON structure; payload=%s", result) + return False + + except Exception as e: + logger.error("Answerability: parse failed; err=%s; raw=%s...", e, str(response)[:200]) + return False + + def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str: + if len(query_history) == 1: + query_history = query_history[0] + else: + query_history = ( + [f"[{i}] {query}" for i, query in enumerate(query_history)] + if len(query_history) > 1 + else query_history[0] + ) + if FINE_STRATEGY == FineStrategy.REWRITE: + text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_rewrite_enhancement" + else: + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_recreate_enhancement" + return self.build_prompt( + prompt_name, + query_history=query_history, + memories=text_memories, + ) + + def _process_enhancement_batch( + self, + batch_index: int, + query_history: list[str], + memories: list[TextualMemoryItem], + retries: int, + ) -> tuple[list[TextualMemoryItem], bool]: + attempt = 0 + text_memories = [one.memory for one in memories] + + prompt = self._build_enhancement_prompt( + query_history=query_history, batch_texts=text_memories + ) + + llm_response = None + while attempt <= max(0, retries) + 1: + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + processed_text_memories = extract_list_items_in_answer(llm_response) + if len(processed_text_memories) > 0: + enhanced_memories = [] + user_id = memories[0].metadata.user_id + if FINE_STRATEGY == FineStrategy.RECREATE: + for new_mem in processed_text_memories: + enhanced_memories.append( + TextualMemoryItem( + memory=new_mem, + metadata=TextualMemoryMetadata( + user_id=user_id, memory_type="LongTermMemory" + ), + ) + ) + elif FINE_STRATEGY == FineStrategy.REWRITE: + + def _parse_index_and_text(s: str) -> tuple[int | None, str]: + import re + + s = (s or "").strip() + m = 
re.match(r"^\s*\[(\d+)\]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + return None, s + + idx_to_original = dict(enumerate(memories)) + for j, item in enumerate(processed_text_memories): + idx, new_text = _parse_index_and_text(item) + if idx is not None and idx in idx_to_original: + orig = idx_to_original[idx] + else: + orig = memories[j] if j < len(memories) else None + if not orig: + continue + enhanced_memories.append( + TextualMemoryItem( + id=orig.id, + memory=new_text, + metadata=orig.metadata, + ) + ) + else: + logger.error("Fine search strategy %s not exists", FINE_STRATEGY) + + logger.info( + "[enhance_memories_with_query] done | Strategy=%s | prompt=%s | llm_response=%s", + FINE_STRATEGY, + prompt, + llm_response, + ) + return enhanced_memories, True + raise ValueError( + "Fail to run memory enhancement; retry " + f"{attempt}/{max(1, retries) + 1}; " + f"processed_text_memories: {processed_text_memories}" + ) + except Exception as e: + attempt += 1 + time.sleep(1) + logger.debug( + "[enhance_memories_with_query][batch=%s] retry %s/%s failed: %s", + batch_index, + attempt, + max(1, retries) + 1, + e, + ) + logger.error( + "Fail to run memory enhancement; prompt: %s;\n llm_response: %s", + prompt, + llm_response, + exc_info=True, + ) + return memories, False + + @staticmethod + def _split_batches( + memories: list[TextualMemoryItem], batch_size: int + ) -> list[tuple[int, int, list[TextualMemoryItem]]]: + batches: list[tuple[int, int, list[TextualMemoryItem]]] = [] + start = 0 + n = len(memories) + while start < n: + end = min(start + batch_size, n) + batches.append((start, end, memories[start:end])) + start = end + return batches + + def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]: + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)]) + + prompt = self.build_prompt( + template_name="enlarge_recall", + query=query, + memories_inline=text_memories, + ) + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + json_result: dict = extract_json_obj(llm_response) + + logger.info( + "[recall_for_missing_memories] done | prompt=%s | llm_response=%s", + prompt, + llm_response, + ) + + hint = json_result.get("hint", "") + if len(hint) == 0: + return hint, False + return hint, json_result.get("trigger_recall", False) + + def enhance_memories_with_query( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> tuple[list[TextualMemoryItem], bool]: + if not memories: + logger.warning("[Enhance] skipped (no memories to process)") + return memories, True + + batch_size = self.batch_size + retries = self.retries + num_of_memories = len(memories) + try: + if batch_size is None or num_of_memories <= batch_size: + enhanced_memories, success_flag = self._process_enhancement_batch( + batch_index=0, + query_history=query_history, + memories=memories, + retries=retries, + ) + + all_success = success_flag + else: + batches = self._split_batches(memories=memories, batch_size=batch_size) + + all_success = True + failed_batches = 0 + from concurrent.futures import as_completed + + from memos.context.context import ContextThreadPoolExecutor + + with ContextThreadPoolExecutor(max_workers=len(batches)) as executor: + future_map = { + executor.submit( + self._process_enhancement_batch, bi, query_history, texts, retries + ): (bi, s, e) + for bi, (s, e, texts) in 
enumerate(batches) + } + enhanced_memories = [] + for fut in as_completed(future_map): + _bi, _s, _e = future_map[fut] + + batch_memories, ok = fut.result() + enhanced_memories.extend(batch_memories) + if not ok: + all_success = False + failed_batches += 1 + logger.info( + "[Enhance] multi-batch done | batches=%s | enhanced=%s | failed_batches=%s | success=%s", + len(batches), + len(enhanced_memories), + failed_batches, + all_success, + ) + + except Exception as e: + logger.error("[Enhance] fatal error: %s", e, exc_info=True) + all_success = False + enhanced_memories = memories + + if len(enhanced_memories) == 0: + enhanced_memories = [] + logger.error("[Enhance] fatal error: enhanced_memories is empty", exc_info=True) + return enhanced_memories, all_success diff --git a/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py new file mode 100644 index 000000000..315f821a9 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/filter_pipeline.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.mem_scheduler.memory_manage_modules.memory_filter import MemoryFilter + + +if TYPE_CHECKING: + from memos.memories.textual.tree import TextualMemoryItem + + +class FilterPipeline: + def __init__(self, process_llm, config): + self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) + + def filter_unrelated_memories( + self, query_history: list[str], memories: list[TextualMemoryItem] + ) -> tuple[list[TextualMemoryItem], bool]: + return self.memory_filter.filter_unrelated_memories(query_history, memories) + + def filter_redundant_memories( + self, query_history: list[str], memories: list[TextualMemoryItem] + ) -> tuple[list[TextualMemoryItem], bool]: + return self.memory_filter.filter_redundant_memories(query_history, memories) + + def filter_unrelated_and_redundant_memories( + self, query_history: list[str], memories: list[TextualMemoryItem] + ) -> tuple[list[TextualMemoryItem], bool]: + return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories) diff --git a/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py new file mode 100644 index 000000000..0e347df6a --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/rerank_pipeline.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.utils.filter_utils import ( + filter_too_short_memories, + filter_vector_based_similar_memories, + transform_name_to_key, +) +from memos.mem_scheduler.utils.misc_utils import extract_json_obj + + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem + + +logger = get_logger(__name__) + + +class RerankPipeline: + def __init__( + self, + process_llm, + similarity_threshold: float, + min_length_threshold: int, + build_prompt, + ): + self.process_llm = process_llm + self.filter_similarity_threshold = similarity_threshold + self.filter_min_length_threshold = min_length_threshold + self.build_prompt = build_prompt + + def rerank_memories( + self, queries: list[str], original_memories: list[str], top_k: int + ) -> tuple[list[str], bool]: + logger.info("Starting memory reranking for %s memories", len(original_memories)) + + prompt = self.build_prompt( + "memory_reranking", + queries=[f"[0] {queries[0]}"], + current_order=[f"[{i}] 
{mem}" for i, mem in enumerate(original_memories)], + ) + logger.debug("Generated reranking prompt: %s...", prompt[:200]) + + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug("Received LLM response: %s...", response[:200]) + + try: + response = extract_json_obj(response) + new_order = response["new_order"][:top_k] + text_memories_with_new_order = [original_memories[idx] for idx in new_order] + logger.info( + "Successfully reranked memories. Returning top %s items; Ranking reasoning: %s", + len(text_memories_with_new_order), + response["reasoning"], + ) + success_flag = True + except Exception as e: + logger.error( + "Failed to rerank memories with LLM. Exception: %s. Raw response: %s ", + e, + response, + exc_info=True, + ) + text_memories_with_new_order = original_memories[:top_k] + success_flag = False + return text_memories_with_new_order, success_flag + + def process_and_rerank_memories( + self, + queries: list[str], + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + top_k: int = 10, + ) -> tuple[list[TextualMemoryItem], bool]: + combined_memory = original_memory + new_memory + + memory_map = { + transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory + } + + combined_text_memory = [m.memory for m in combined_memory] + + filtered_combined_text_memory = filter_vector_based_similar_memories( + text_memories=combined_text_memory, + similarity_threshold=self.filter_similarity_threshold, + ) + + filtered_combined_text_memory = filter_too_short_memories( + text_memories=filtered_combined_text_memory, + min_length_threshold=self.filter_min_length_threshold, + ) + + unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) + + text_memories_with_new_order, success_flag = self.rerank_memories( + queries=queries, + original_memories=unique_memory, + top_k=top_k, + ) + + memories_with_new_order = [] + for text in text_memories_with_new_order: + normalized_text = transform_name_to_key(name=text) + if normalized_text in memory_map: + memories_with_new_order.append(memory_map[normalized_text]) + else: + logger.warning( + "Memory text not found in memory map. 
text: %s;\nKeys of memory_map: %s", + text, + memory_map.keys(), + ) + + return memories_with_new_order, success_flag diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index f205766f0..3e849f470 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,457 +1,99 @@ -import time +from __future__ import annotations -from concurrent.futures import as_completed +from typing import TYPE_CHECKING -from memos.configs.mem_scheduler import BaseSchedulerConfig -from memos.context.context import ContextThreadPoolExecutor -from memos.llms.base import BaseLLM from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.schemas.general_schemas import ( - DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, - DEFAULT_SCHEDULER_RETRIEVER_RETRIES, - TreeTextMemory_FINE_SEARCH_METHOD, - TreeTextMemory_SEARCH_METHOD, -) -from memos.mem_scheduler.utils.filter_utils import ( - filter_too_short_memories, - filter_vector_based_similar_memories, - transform_name_to_key, -) -from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer -from memos.memories.textual.item import TextualMemoryMetadata -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.types.general_types import ( - FINE_STRATEGY, - FineStrategy, - SearchMode, -) +from memos.mem_scheduler.memory_manage_modules.enhancement_pipeline import EnhancementPipeline +from memos.mem_scheduler.memory_manage_modules.filter_pipeline import FilterPipeline +from memos.mem_scheduler.memory_manage_modules.rerank_pipeline import RerankPipeline +from memos.mem_scheduler.memory_manage_modules.search_pipeline import SearchPipeline -# Extract JSON response -from .memory_filter import MemoryFilter + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) class SchedulerRetriever(BaseSchedulerModule): - def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): + def __init__(self, process_llm, config): super().__init__() - # hyper-parameters self.filter_similarity_threshold = 0.75 self.filter_min_length_threshold = 6 - self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) self.process_llm = process_llm self.config = config - # Configure enhancement batching & retries from config with safe defaults - self.batch_size: int | None = getattr( - config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE + self.search_pipeline = SearchPipeline() + self.enhancement_pipeline = EnhancementPipeline( + process_llm=process_llm, + config=config, + build_prompt=self.build_prompt, ) - self.retries: int = getattr( - config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES + self.rerank_pipeline = RerankPipeline( + process_llm=process_llm, + similarity_threshold=self.filter_similarity_threshold, + min_length_threshold=self.filter_min_length_threshold, + build_prompt=self.build_prompt, ) + self.filter_pipeline = FilterPipeline(process_llm=process_llm, config=config) + self.memory_filter = self.filter_pipeline.memory_filter def evaluate_memory_answer_ability( self, query: str, memory_texts: list[str], top_k: int | None = None ) -> bool: - limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts - # 
Build prompt using the template - prompt = self.build_prompt( - template_name="memory_answer_ability_evaluation", - query=query, - memory_list="\n".join([f"- {memory}" for memory in limited_memories]) - if limited_memories - else "No memories available", - ) - - # Use the process LLM to generate response - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - - try: - result = extract_json_obj(response) - - # Validate response structure - if "result" in result: - logger.info( - f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}" - ) - return result["result"] - else: - logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}") - return False - - except Exception as e: - logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...") - # Fallback: return False if we can't determine answer ability - return False - - # ---------------------- Enhancement helpers ---------------------- - def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str: - if len(query_history) == 1: - query_history = query_history[0] - else: - query_history = ( - [f"[{i}] {query}" for i, query in enumerate(query_history)] - if len(query_history) > 1 - else query_history[0] - ) - # Include numbering for rewrite mode to help LLM reference original memory IDs - if FINE_STRATEGY == FineStrategy.REWRITE: - text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)]) - prompt_name = "memory_rewrite_enhancement" - else: - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) - prompt_name = "memory_recreate_enhancement" - return self.build_prompt( - prompt_name, - query_history=query_history, - memories=text_memories, - ) - - def _process_enhancement_batch( - self, - batch_index: int, - query_history: list[str], - memories: list[TextualMemoryItem], - retries: int, - ) -> tuple[list[TextualMemoryItem], bool]: - attempt = 0 - text_memories = [one.memory for one in memories] - - prompt = self._build_enhancement_prompt( - query_history=query_history, batch_texts=text_memories - ) - - llm_response = None - while attempt <= max(0, retries) + 1: - try: - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - processed_text_memories = extract_list_items_in_answer(llm_response) - if len(processed_text_memories) > 0: - # create new - enhanced_memories = [] - user_id = memories[0].metadata.user_id - if FINE_STRATEGY == FineStrategy.RECREATE: - for new_mem in processed_text_memories: - enhanced_memories.append( - TextualMemoryItem( - memory=new_mem, - metadata=TextualMemoryMetadata( - user_id=user_id, memory_type="LongTermMemory" - ), # TODO add memory_type - ) - ) - elif FINE_STRATEGY == FineStrategy.REWRITE: - # Parse index from each processed line and rewrite corresponding original memory - def _parse_index_and_text(s: str) -> tuple[int | None, str]: - import re - - s = (s or "").strip() - # Preferred: [index] text - m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s) - if m: - return int(m.group(1)), m.group(2).strip() - # Fallback: index: text or index - text - m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s) - if m: - return int(m.group(1)), m.group(2).strip() - return None, s - - idx_to_original = dict(enumerate(memories)) - for j, item in enumerate(processed_text_memories): - idx, new_text = _parse_index_and_text(item) - if idx is not None and idx in idx_to_original: - orig = idx_to_original[idx] - else: - # 
Fallback: align by order if index missing/invalid - orig = memories[j] if j < len(memories) else None - if not orig: - continue - enhanced_memories.append( - TextualMemoryItem( - id=orig.id, - memory=new_text, - metadata=orig.metadata, - ) - ) - else: - logger.error(f"Fine search strategy {FINE_STRATEGY} not exists") - - logger.info( - f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}" - ) - return enhanced_memories, True - else: - raise ValueError( - f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}" - ) - except Exception as e: - attempt += 1 - time.sleep(1) - logger.debug( - f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" - ) - logger.error( - f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}", - exc_info=True, - ) - return memories, False - - @staticmethod - def _split_batches( - memories: list[TextualMemoryItem], batch_size: int - ) -> list[tuple[int, int, list[TextualMemoryItem]]]: - batches: list[tuple[int, int, list[TextualMemoryItem]]] = [] - start = 0 - n = len(memories) - while start < n: - end = min(start + batch_size, n) - batches.append((start, end, memories[start:end])) - start = end - return batches - - def recall_for_missing_memories( - self, - query: str, - memories: list[str], - ) -> tuple[str, bool]: - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)]) - - prompt = self.build_prompt( - template_name="enlarge_recall", + return self.enhancement_pipeline.evaluate_memory_answer_ability( query=query, - memories_inline=text_memories, - ) - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - - json_result: dict = extract_json_obj(llm_response) - - logger.info( - f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}" + memory_texts=memory_texts, + top_k=top_k, ) - hint = json_result.get("hint", "") - if len(hint) == 0: - return hint, False - return hint, json_result.get("trigger_recall", False) - def search( self, query: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube, top_k: int, - method: str = TreeTextMemory_SEARCH_METHOD, + method: str, search_args: dict | None = None, ) -> list[TextualMemoryItem]: - """Search in text memory with the given query. 
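
The rewrite strategy keys each LLM output line back to its source memory by a leading index, with positional order as the fallback seen just above. A standalone sketch of that parsing step, using the same two regex forms as `_parse_index_and_text`; the sample lines are made up for illustration:

```python
import re


def parse_index_and_text(s: str) -> tuple[int | None, str]:
    """Extract a leading memory index from an LLM output line, if present."""
    s = (s or "").strip()
    # Preferred form: "[3] rewritten memory text"
    m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
    if m:
        return int(m.group(1)), m.group(2).strip()
    # Fallback forms: "3: text", "3 - text", "3) text"
    m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
    if m:
        return int(m.group(1)), m.group(2).strip()
    # No index found: the caller aligns by position instead
    return None, s


for line in ["[0] Alice moved to Berlin.", "1: She works remotely.", "free-form line"]:
    print(parse_index_and_text(line))
# (0, 'Alice moved to Berlin.')  (1, 'She works remotely.')  (None, 'free-form line')
```
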
- - Args: - query: The search query string - top_k: Number of top results to return - method: Search method to use - - Returns: - Search results or None if not implemented - """ - text_mem_base = mem_cube.text_mem - # Normalize default for mutable argument - search_args = search_args or {} - try: - if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: - assert isinstance(text_mem_base, TreeTextMemory) - session_id = search_args.get("session_id", "default_session") - target_session_id = session_id - search_priority = ( - {"session_id": target_session_id} if "session_id" in search_args else None - ) - search_filter = search_args.get("filter") - search_source = search_args.get("source") - plugin = bool(search_source is not None and search_source == "plugin") - user_name = search_args.get("user_name", mem_cube_id) - internet_search = search_args.get("internet_search", False) - chat_history = search_args.get("chat_history") - search_tool_memory = search_args.get("search_tool_memory", False) - tool_mem_top_k = search_args.get("tool_mem_top_k", 6) - playground_search_goal_parser = search_args.get( - "playground_search_goal_parser", False - ) - - info = search_args.get( - "info", - { - "user_id": user_id, - "session_id": target_session_id, - "chat_history": chat_history, - }, - ) - - results_long_term = mem_cube.text_mem.search( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - manual_close_internet=not internet_search, - memory_type="LongTermMemory", - search_filter=search_filter, - search_priority=search_priority, - info=info, - plugin=plugin, - search_tool_memory=search_tool_memory, - tool_mem_top_k=tool_mem_top_k, - playground_search_goal_parser=playground_search_goal_parser, - ) - - results_user = mem_cube.text_mem.search( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - manual_close_internet=not internet_search, - memory_type="UserMemory", - search_filter=search_filter, - search_priority=search_priority, - info=info, - plugin=plugin, - search_tool_memory=search_tool_memory, - tool_mem_top_k=tool_mem_top_k, - playground_search_goal_parser=playground_search_goal_parser, - ) - results = results_long_term + results_user - else: - raise NotImplementedError(str(type(text_mem_base))) - except Exception as e: - logger.error(f"Fail to search. The exeption is {e}.", exc_info=True) - results = [] - return results + return self.search_pipeline.search( + query=query, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=top_k, + method=method, + search_args=search_args, + ) def enhance_memories_with_query( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - """ - Enhance memories by adding context and making connections to better answer queries. 
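
The deleted fast-search path above (now `SearchPipeline.search`) runs the same query twice, once for `LongTermMemory` and once for `UserMemory`, and concatenates the two result lists. A condensed sketch of that fan-out; `StubTextMem` and its return format are invented stand-ins for a `TreeTextMemory` backend:

```python
def fan_out_search(text_mem, query: str, user_name: str, top_k: int) -> list:
    # One search call per memory scope; results are concatenated rather than
    # interleaved, so any final ordering is left to downstream reranking.
    results = []
    for memory_type in ("LongTermMemory", "UserMemory"):
        results.extend(
            text_mem.search(query=query, user_name=user_name, top_k=top_k,
                            memory_type=memory_type)
        )
    return results


class StubTextMem:
    def search(self, query, user_name, top_k, memory_type):
        return [f"{memory_type} hit {i} for {query!r}" for i in range(top_k)]


print(fan_out_search(StubTextMem(), "where does Alice work", "cube-1", top_k=2))
```
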
- - Args: - query_history: List of user queries in chronological order - memories: List of memory items to enhance - - Returns: - Tuple of (enhanced_memories, success_flag) - """ - if not memories: - logger.warning("[Enhance] ⚠️ skipped (no memories to process)") - return memories, True - - batch_size = self.batch_size - retries = self.retries - num_of_memories = len(memories) - try: - # no parallel - if batch_size is None or num_of_memories <= batch_size: - # Single batch path with retry - enhanced_memories, success_flag = self._process_enhancement_batch( - batch_index=0, - query_history=query_history, - memories=memories, - retries=retries, - ) - - all_success = success_flag - else: - # parallel running batches - # Split into batches preserving order - batches = self._split_batches(memories=memories, batch_size=batch_size) - - # Process batches concurrently - all_success = True - failed_batches = 0 - with ContextThreadPoolExecutor(max_workers=len(batches)) as executor: - future_map = { - executor.submit( - self._process_enhancement_batch, bi, query_history, texts, retries - ): (bi, s, e) - for bi, (s, e, texts) in enumerate(batches) - } - enhanced_memories = [] - for fut in as_completed(future_map): - bi, s, e = future_map[fut] - - batch_memories, ok = fut.result() - enhanced_memories.extend(batch_memories) - if not ok: - all_success = False - failed_batches += 1 - logger.info( - f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |" - f" failed_batches={failed_batches} | success={all_success}" - ) - - except Exception as e: - logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True) - all_success = False - enhanced_memories = memories + ) -> tuple[list[TextualMemoryItem], bool]: + return self.enhancement_pipeline.enhance_memories_with_query( + query_history=query_history, + memories=memories, + ) - if len(enhanced_memories) == 0: - enhanced_memories = [] - logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True) - return enhanced_memories, all_success + def recall_for_missing_memories(self, query: str, memories: list[str]) -> tuple[str, bool]: + return self.enhancement_pipeline.recall_for_missing_memories( + query=query, + memories=memories, + ) def rerank_memories( self, queries: list[str], original_memories: list[str], top_k: int - ) -> (list[str], bool): - """ - Rerank memories based on relevance to given queries using LLM. 
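
`enhance_memories_with_query` (deleted here, moved to the enhancement pipeline) splits oversized inputs into ordered batches and fans them out on a thread pool, keeping a failed batch's original memories so one bad LLM call cannot drop data. A rough stdlib equivalent, with `ThreadPoolExecutor` standing in for `ContextThreadPoolExecutor` and `enhance_batch` for the per-batch LLM call:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def enhance_batch(batch: list[str]) -> tuple[list[str], bool]:
    # Stand-in for _process_enhancement_batch: returns (memories, success).
    return [m.upper() for m in batch], True


def enhance_all(memories: list[str], batch_size: int) -> tuple[list[str], bool]:
    if not memories:
        return [], True
    batches = [memories[i:i + batch_size] for i in range(0, len(memories), batch_size)]
    enhanced: list[str] = []
    all_ok = True
    with ThreadPoolExecutor(max_workers=len(batches)) as pool:
        futures = {pool.submit(enhance_batch, b): i for i, b in enumerate(batches)}
        for fut in as_completed(futures):
            batch_result, ok = fut.result()
            # As in the original, results arrive in completion order, so the
            # cross-batch ordering of memories is not guaranteed to be stable.
            enhanced.extend(batch_result)
            all_ok = all_ok and ok
    return enhanced, all_ok


print(enhance_all(["a", "b", "c"], batch_size=2))
```
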
- - Args: - queries: List of query strings to determine relevance - original_memories: List of memory strings to be reranked - top_k: Number of top memories to return after reranking - - Returns: - List of reranked memory strings (length <= top_k) - - Note: - If LLM reranking fails, falls back to original order (truncated to top_k) - """ - - logger.info(f"Starting memory reranking for {len(original_memories)} memories") - - # Build LLM prompt for memory reranking - prompt = self.build_prompt( - "memory_reranking", - queries=[f"[0] {queries[0]}"], - current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], + ) -> tuple[list[str], bool]: + return self.rerank_pipeline.rerank_memories( + queries=queries, + original_memories=original_memories, + top_k=top_k, ) - logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars - - # Get LLM response - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars - - try: - # Parse JSON response - response = extract_json_obj(response) - new_order = response["new_order"][:top_k] - text_memories_with_new_order = [original_memories[idx] for idx in new_order] - logger.info( - f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;" - f"Ranking reasoning: {response['reasoning']}" - ) - success_flag = True - except Exception as e: - logger.error( - f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ", - exc_info=True, - ) - text_memories_with_new_order = original_memories[:top_k] - success_flag = False - return text_memories_with_new_order, success_flag def process_and_rerank_memories( self, @@ -459,89 +101,40 @@ def process_and_rerank_memories( original_memory: list[TextualMemoryItem], new_memory: list[TextualMemoryItem], top_k: int = 10, - ) -> list[TextualMemoryItem] | None: - """ - Process and rerank memory items by combining original and new memories, - applying filters, and then reranking based on relevance to queries. 
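
`rerank_memories` (moved to `RerankPipeline` above) trusts the LLM only as far as a parseable `new_order` index list: it truncates to `top_k` and falls back to the incoming order when parsing fails. A minimal sketch of that contract, with `json.loads` standing in for `extract_json_obj` and a fabricated response payload:

```python
import json


def apply_rerank(raw_response: str, memories: list[str], top_k: int) -> tuple[list[str], bool]:
    try:
        payload = json.loads(raw_response)      # stand-in for extract_json_obj
        order = payload["new_order"][:top_k]    # indices into `memories`
        return [memories[i] for i in order], True
    except Exception:
        # Any parse or index failure degrades to the original order, truncated.
        return memories[:top_k], False


mems = ["pasta recipe", "Berlin trip", "project deadline"]
print(apply_rerank('{"new_order": [2, 0], "reasoning": "deadline first"}', mems, 2))
# (['project deadline', 'pasta recipe'], True)
print(apply_rerank("not json", mems, 2))
# (['pasta recipe', 'Berlin trip'], False)
```
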
- - Args: - queries: List of query strings to rerank memories against - original_memory: List of original TextualMemoryItem objects - new_memory: List of new TextualMemoryItem objects to merge - top_k: Maximum number of memories to return after reranking - - Returns: - List of reranked TextualMemoryItem objects, or None if processing fails - """ - # Combine original and new memories into a single list - combined_memory = original_memory + new_memory - - # Create a mapping from normalized text to memory objects - memory_map = { - transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory - } - - # Extract normalized text representations from all memory items - combined_text_memory = [m.memory for m in combined_memory] - - # Apply similarity filter to remove overly similar memories - filtered_combined_text_memory = filter_vector_based_similar_memories( - text_memories=combined_text_memory, - similarity_threshold=self.filter_similarity_threshold, - ) - - # Apply length filter to remove memories that are too short - filtered_combined_text_memory = filter_too_short_memories( - text_memories=filtered_combined_text_memory, - min_length_threshold=self.filter_min_length_threshold, - ) - - # Ensure uniqueness of memory texts using dictionary keys (preserves order) - unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) - - # Rerank the filtered memories based on relevance to the queries - text_memories_with_new_order, success_flag = self.rerank_memories( + ) -> tuple[list[TextualMemoryItem], bool]: + return self.rerank_pipeline.process_and_rerank_memories( queries=queries, - original_memories=unique_memory, + original_memory=original_memory, + new_memory=new_memory, top_k=top_k, ) - # Map reranked text entries back to their original memory objects - memories_with_new_order = [] - for text in text_memories_with_new_order: - normalized_text = transform_name_to_key(name=text) - if normalized_text in memory_map: # Ensure correct key matching - memories_with_new_order.append(memory_map[normalized_text]) - else: - logger.warning( - f"Memory text not found in memory map. text: {text};\n" - f"Keys of memory_map: {memory_map.keys()}" - ) - - return memories_with_new_order, success_flag - def filter_unrelated_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - return self.memory_filter.filter_unrelated_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_unrelated_memories( + query_history=query_history, + memories=memories, + ) def filter_redundant_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - return self.memory_filter.filter_redundant_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_redundant_memories( + query_history=query_history, + memories=memories, + ) def filter_unrelated_and_redundant_memories( self, query_history: list[str], memories: list[TextualMemoryItem], - ) -> (list[TextualMemoryItem], bool): - """ - Filter out both unrelated and redundant memories using LLM analysis. - - This method delegates to the MemoryFilter class. 
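
`process_and_rerank_memories` (moved to `RerankPipeline`) works at the text level but must return memory objects: it keys every object by a normalized form of its text, filters and dedups the strings (`dict.fromkeys` keeps first occurrence and order), then maps the reranked strings back through that map, warning on misses. A small sketch with a lowercase/whitespace-collapsing stand-in for `transform_name_to_key`:

```python
def normalize(name: str) -> str:
    # Stand-in for transform_name_to_key.
    return " ".join(name.lower().split())


def map_texts_back(objects: list[dict], reranked_texts: list[str]) -> list[dict]:
    memory_map = {normalize(o["memory"]): o for o in objects}
    unique_texts = list(dict.fromkeys(reranked_texts))  # order-preserving dedup
    return [memory_map[normalize(t)] for t in unique_texts if normalize(t) in memory_map]


objs = [{"id": 1, "memory": "Alice lives in Berlin"},
        {"id": 2, "memory": "Bob likes tea"}]
print(map_texts_back(objs, ["bob likes tea", "Alice lives in  Berlin", "bob likes tea"]))
# -> the id=2 then id=1 objects, in reranked order, with the duplicate dropped
```
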
- """ - return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories) + ) -> tuple[list[TextualMemoryItem], bool]: + return self.filter_pipeline.filter_unrelated_and_redundant_memories( + query_history=query_history, + memories=memories, + ) diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py new file mode 100644 index 000000000..a346622c5 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import ( + TreeTextMemory_FINE_SEARCH_METHOD, + TreeTextMemory_SEARCH_METHOD, +) +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.types.general_types import SearchMode + + +logger = get_logger(__name__) + + +class SearchPipeline: + def search( + self, + query: str, + user_id: str, + mem_cube_id: str, + mem_cube, + top_k: int, + method: str = TreeTextMemory_SEARCH_METHOD, + search_args: dict | None = None, + ) -> list[TextualMemoryItem]: + text_mem_base = mem_cube.text_mem + search_args = search_args or {} + try: + if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: + assert isinstance(text_mem_base, TreeTextMemory) + session_id = search_args.get("session_id", "default_session") + target_session_id = session_id + search_priority = ( + {"session_id": target_session_id} if "session_id" in search_args else None + ) + search_filter = search_args.get("filter") + search_source = search_args.get("source") + plugin = bool(search_source is not None and search_source == "plugin") + user_name = search_args.get("user_name", mem_cube_id) + internet_search = search_args.get("internet_search", False) + chat_history = search_args.get("chat_history") + search_tool_memory = search_args.get("search_tool_memory", False) + tool_mem_top_k = search_args.get("tool_mem_top_k", 6) + playground_search_goal_parser = search_args.get( + "playground_search_goal_parser", False + ) + + info = search_args.get( + "info", + { + "user_id": user_id, + "session_id": target_session_id, + "chat_history": chat_history, + }, + ) + + results_long_term = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="LongTermMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + results_user = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="UserMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + results = results_long_term + results_user + else: + raise NotImplementedError(str(type(text_mem_base))) + except Exception as e: + logger.error("Fail to search. 
The exception is %s.", e, exc_info=True) + results = [] + return results diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 497d19ac6..e535d6f73 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -19,6 +19,7 @@ from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.search import build_search_context, search_text_memories from memos.types import ( MemCubeID, SearchMode, @@ -104,29 +105,13 @@ def search_memories( mem_cube: NaiveMemCube, mode: SearchMode, ): - """Fine search memories function copied from server_router to avoid circular import""" - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter - - # Create MemCube and perform search - search_results = mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, + """Shared text-memory search via centralized search service.""" + return search_text_memories( + text_mem=mem_cube.text_mem, + search_req=search_req, + user_context=user_context, mode=mode, - manual_close_internet=not search_req.internet_search, - search_filter=search_filter, - search_priority=search_priority, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, ) - return search_results def mix_search_memories( self, @@ -157,19 +142,13 @@ def mix_search_memories( ] # Get mem_cube for fast search - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter + search_ctx = build_search_context(search_req=search_req) + search_priority = search_ctx.search_priority + search_filter = search_ctx.search_filter # Rerank Memories - reranker expects TextualMemoryItem objects - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } + info = search_ctx.info raw_retrieved_memories = self.searcher.retrieve( query=search_req.query, diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index d570dccdd..dc7d86752 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -718,10 +718,10 @@ def _claim_pending_messages( justid=False, ) if len(claimed_result) == 2: - next_id, claimed = claimed_result - deleted_ids = [] + _next_id, claimed = claimed_result + _deleted_ids = [] elif len(claimed_result) == 3: - next_id, claimed, deleted_ids = claimed_result + _next_id, claimed, _deleted_ids = claimed_result else: raise ValueError( f"Unexpected xautoclaim response length: {len(claimed_result)}" @@ -745,10 +745,10 @@ def _claim_pending_messages( justid=False, ) if len(claimed_result) == 2: - next_id, claimed = claimed_result - deleted_ids = [] + _next_id, claimed = claimed_result + _deleted_ids = [] elif len(claimed_result) == 3: - next_id, claimed, deleted_ids = claimed_result + _next_id, 
claimed, _deleted_ids = claimed_result else: raise ValueError( f"Unexpected xautoclaim response length: {len(claimed_result)}" diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index bd026a51d..5058ca805 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -25,6 +25,7 @@ PREF_ADD_TASK_LABEL, ) from memos.multi_mem_cube.views import MemCubeView +from memos.search import search_text_memories from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, @@ -455,31 +456,11 @@ def _fast_search( Returns: List of search results """ - target_session_id = search_req.session_id or "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter or None - plugin = bool(search_req.source is not None and search_req.source == "plugin") - - search_results = self.naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, + search_results = search_text_memories( + text_mem=self.naive_mem_cube.text_mem, + search_req=search_req, + user_context=user_context, mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - memory_type=search_req.search_memory_type, - search_filter=search_filter, - search_priority=search_priority, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - plugin=plugin, - search_tool_memory=search_req.search_tool_memory, - tool_mem_top_k=search_req.tool_mem_top_k, - include_skill_memory=search_req.include_skill_memory, - skill_mem_top_k=search_req.skill_mem_top_k, - dedup=search_req.dedup, ) formatted_memories = [ diff --git a/src/memos/search/__init__.py b/src/memos/search/__init__.py new file mode 100644 index 000000000..71388c62b --- /dev/null +++ b/src/memos/search/__init__.py @@ -0,0 +1,4 @@ +from .search_service import SearchContext, build_search_context, search_text_memories + + +__all__ = ["SearchContext", "build_search_context", "search_text_memories"] diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py new file mode 100644 index 000000000..79c9a43e5 --- /dev/null +++ b/src/memos/search/search_service.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from memos.api.product_models import APISearchRequest + from memos.types import SearchMode, UserContext + + +@dataclass(frozen=True) +class SearchContext: + target_session_id: str + search_priority: dict[str, Any] | None + search_filter: dict[str, Any] | None + info: dict[str, Any] + plugin: bool + + +def build_search_context( + search_req: APISearchRequest, +) -> SearchContext: + target_session_id = search_req.session_id or "default_session" + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + return SearchContext( + target_session_id=target_session_id, + search_priority=search_priority, + search_filter=search_req.filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + plugin=bool(search_req.source is not None and search_req.source == "plugin"), + ) + + +def search_text_memories( + text_mem: Any, + search_req: APISearchRequest, + user_context: UserContext, + mode: SearchMode, +) -> list[Any]: + 
""" + Shared text-memory search logic for API and scheduler paths. + """ + ctx = build_search_context(search_req=search_req) + return text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=mode, + manual_close_internet=not search_req.internet_search, + memory_type=search_req.search_memory_type, + search_filter=ctx.search_filter, + search_priority=ctx.search_priority, + info=ctx.info, + plugin=ctx.plugin, + search_tool_memory=search_req.search_tool_memory, + tool_mem_top_k=search_req.tool_mem_top_k, + include_skill_memory=search_req.include_skill_memory, + skill_mem_top_k=search_req.skill_mem_top_k, + dedup=search_req.dedup, + )