From 74c5a5ef2fb48075bac823e9df382925dd579049 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 11 Sep 2025 11:18:40 +0800 Subject: [PATCH 01/62] add preference text memory --- .../textual/prefer_text_memory/assemble.py | 25 ++++ .../textual/prefer_text_memory/builders.py | 49 ++++++ .../textual/prefer_text_memory/config.py | 139 ++++++++++++++++++ .../textual/prefer_text_memory/factory.py | 73 +++++++++ .../textual/prefer_text_memory/retrievers.py | 25 ++++ .../textual/prefer_text_memory/strategies.py | 18 +++ .../textual/prefer_text_memory/updater.py | 23 +++ src/memos/memories/textual/preference.py | 120 +++++++++++++++ 8 files changed, 472 insertions(+) create mode 100644 src/memos/memories/textual/prefer_text_memory/assemble.py create mode 100644 src/memos/memories/textual/prefer_text_memory/builders.py create mode 100644 src/memos/memories/textual/prefer_text_memory/config.py create mode 100644 src/memos/memories/textual/prefer_text_memory/factory.py create mode 100644 src/memos/memories/textual/prefer_text_memory/retrievers.py create mode 100644 src/memos/memories/textual/prefer_text_memory/strategies.py create mode 100644 src/memos/memories/textual/prefer_text_memory/updater.py create mode 100644 src/memos/memories/textual/preference.py diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py new file mode 100644 index 000000000..845c55e24 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from memos.memories.textual.item import TextualMemoryItem + + +class BaseAssembler(ABC): + """Abstract base class for assemblers.""" + @abstractmethod + def __init__(self): + """Initialize the assembler.""" + + @abstractmethod + def assemble(self, memories: list[TextualMemoryItem]) -> str: + """Assemble memories into a single memory.""" + + +class NaiveAssembler(BaseAssembler): + """Naive assembler.""" + def 
__init__(self): + """Initialize the naive assembler.""" + super().__init__() + + def assemble(self, memories: list[TextualMemoryItem]) -> str: + """Assemble memories into a single memory.""" + pass + diff --git a/src/memos/memories/textual/prefer_text_memory/builders.py b/src/memos/memories/textual/prefer_text_memory/builders.py new file mode 100644 index 000000000..b2830f4d3 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/builders.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from typing import Optional +from memos.llms.base import BaseLLM + +class BaseBuilder(ABC): + """ + Abstract base class for memory builders. + + Each builder implements a specific build strategy for creating + procedural memory content from task trajectories. + """ + + @abstractmethod + def __init__(self, llm_provider: Optional[BaseLLM] = None): + """ + Initialize the memory builder. + + Args: + llm_provider: LLM provider for script generation (required for some strategies) + """ + self.llm_provider = llm_provider + + @abstractmethod + def build(self, task_description: str, trajectory: str) -> str: + """ + Build memory content from task description and trajectory. 
+ + Args: + task_description: Natural language description of the task + trajectory: Detailed step-by-step trajectory of task execution + + Returns: + Memory content string formatted according to the build strategy + + Raises: + RuntimeError: If memory building fails + """ + pass + + +class NaiveBuilder(BaseBuilder): + """Naive memory builder.""" + def __init__(self, llm_provider: Optional[BaseLLM] = None): + """Initialize the naive memory builder.""" + super().__init__(llm_provider) + + def build(self, task_description: str, trajectory: str) -> str: + """Build memory content from task description and trajectory.""" + pass \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py new file mode 100644 index 000000000..fc1e3bfb4 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -0,0 +1,139 @@ +from memos.configs.base import BaseConfig +from typing import Any, ClassVar +from pydantic import Field +from pydantic import field_validator, model_validator + + + +class BaseBuilderConfig(BaseConfig): + """Base configuration class for Builder.""" + + +class NaiveBuilderConfig(BaseBuilderConfig): + """Configuration for Naive Builder.""" + + + +class BuilderConfigFactory(BaseConfig): + """Factory class for creating Builder configurations.""" + + backend: str = Field(..., description="Backend for Builder") + config: dict[str, Any] = Field(..., description="Configuration for the Builder backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveBuilderConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "BuilderConfigFactory": + config_class = self.backend_to_class[self.backend] 
+ self.config = config_class(**self.config) + return self + + +class BaseRetrieverConfig(BaseConfig): + """Base configuration class for Retriever.""" + + +class NaiveRetrieverConfig(BaseRetrieverConfig): + """Configuration for Naive Retriever.""" + + +class RetrieverConfigFactory(BaseConfig): + """Factory class for creating Retriever configurations.""" + + backend: str = Field(..., description="Backend for Retriever") + config: dict[str, Any] = Field(..., description="Configuration for the Retriever backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveRetrieverConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "RetrieverConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + + +class BaseUpdaterConfig(BaseConfig): + """Base configuration class for Updater.""" + + +class NaiveUpdaterConfig(BaseUpdaterConfig): + """Configuration for Naive Updater.""" + + +class UpdaterConfigFactory(BaseConfig): + """Factory class for creating Updater configurations.""" + + backend: str = Field(..., description="Backend for Updater") + config: dict[str, Any] = Field(..., description="Configuration for the Updater backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveUpdaterConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "UpdaterConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + 
return self + + +class BaseAssemblerConfig(BaseConfig): + """Base configuration class for Assembler.""" + + +class NaiveAssemblerConfig(BaseAssemblerConfig): + """Configuration for Naive Assembler.""" + + +class AssemblerConfigFactory(BaseConfig): + """Factory class for creating Assembler configurations.""" + + backend: str = Field(..., description="Backend for Assembler") + config: dict[str, Any] = Field(..., description="Configuration for the Assembler backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveAssemblerConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AssemblerConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py new file mode 100644 index 000000000..612dc1194 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -0,0 +1,73 @@ +from typing import Any, ClassVar + +from memos.memories.textual.prefer_text_memory.builders import BaseBuilder, NaiveBuilder +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever, NaiveRetriever +from memos.memories.textual.prefer_text_memory.updater import BaseUpdater, NaiveUpdater +from memos.memories.textual.prefer_text_memory.assemble import BaseAssembler, NaiveAssembler +from memos.memories.textual.prefer_text_memory.config import BuilderConfigFactory, RetrieverConfigFactory, UpdaterConfigFactory, AssemblerConfigFactory + + +class BuilderFactory(BaseBuilder): + """Factory class for creating Builder instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": 
NaiveBuilder, + } + + @classmethod + def from_config(cls, config_factory: BuilderConfigFactory) -> BaseBuilder: + """Create a Builder instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + builder_class = cls.backend_to_class[backend] + return builder_class(config_factory.config) + +class RetrieverFactory(BaseRetriever): + """Factory class for creating Retriever instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveRetriever, + } + + @classmethod + def from_config(cls, config_factory: RetrieverConfigFactory) -> BaseRetriever: + """Create a Retriever instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + retriever_class = cls.backend_to_class[backend] + return retriever_class(config_factory.config) + +class UpdaterFactory(BaseUpdater): + """Factory class for creating Updater instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveUpdater, + } + + @classmethod + def from_config(cls, config_factory: UpdaterConfigFactory) -> BaseUpdater: + """Create a Updater instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + updater_class = cls.backend_to_class[backend] + return updater_class(config_factory.config) + +class AssemblerFactory(BaseAssembler): + """Factory class for creating Assembler instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveAssembler, + } + + @classmethod + def from_config(cls, config_factory: AssemblerConfigFactory) -> BaseAssembler: + """Create a Assembler instance from a configuration factory.""" + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + 
assembler_class = cls.backend_to_class[backend] + return assembler_class(config_factory.config) + \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py new file mode 100644 index 000000000..a6c3bba1c --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + + +class BaseRetriever(ABC): + """Abstract base class for retrievers.""" + + @abstractmethod + def __init__(self): + """Initialize the retriever.""" + + + @abstractmethod + def retrieve(self, task_description: str, k: int, threshold: float) -> List[Dict[str, Any]]: + """Retrieve memories from the retriever.""" + +class NaiveRetriever(BaseRetriever): + """Naive retriever.""" + def __init__(self): + """Initialize the naive retriever.""" + super().__init__() + + def retrieve(self, task_description: str, k: int, threshold: float) -> List[Dict[str, Any]]: + """Retrieve memories from the naive retriever.""" + pass \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/strategies.py b/src/memos/memories/textual/prefer_text_memory/strategies.py new file mode 100644 index 000000000..11c8f44e0 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/strategies.py @@ -0,0 +1,18 @@ +from enum import Enum + +class BuildStrategy(Enum): + """Strategy for building memories.""" + NAIVE = "naive" + +class RetrieveStrategy(Enum): + """Strategy for retrieving memories.""" + NAIVE = "naive" + +class UpdateStrategy(Enum): + """Strategy for updating memories.""" + NAIVE = "naive" + +class AssembleStrategy(Enum): + """Strategy for assembling memories.""" + NAIVE = "naive" + diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py new file mode 100644 index 000000000..71d043e89 --- /dev/null +++ 
b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from memos.memories.textual.item import TextualMemoryItem + + +class BaseUpdater(ABC): + """Abstract base class for updaters.""" + + @abstractmethod + def __init__(self): + """Initialize the updater.""" + + +class NaiveUpdater(BaseUpdater): + """Naive updater.""" + def __init__(self): + """Initialize the naive updater.""" + super().__init__() + + def update(self, memories: list[TextualMemoryItem]) -> None: + """Update the memory.""" + pass + + diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py new file mode 100644 index 000000000..96c4e830c --- /dev/null +++ b/src/memos/memories/textual/preference.py @@ -0,0 +1,120 @@ +import json +import os +import shutil +import tempfile + +from datetime import datetime +from pathlib import Path +from typing import Any + +from memos.configs.memory import GeneralTextMemoryConfig +from memos.embedders.factory import EmbedderFactory, OllamaEmbedder, ArkEmbedder, SenTranEmbedder, UniversalAPIEmbedder +from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM +from memos.log import get_logger +from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageList +from memos.llms.base import BaseLLM +from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory + + +class PreferenceTextMemory(BaseTextMemory): + """Preference textual memory implementation for storing and retrieving memories.""" + + def __init__(self, config: GeneralTextMemoryConfig): + """Initialize memory with the given configuration.""" + self.config: GeneralTextMemoryConfig = config + self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( + config.extractor_llm + ) + self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db) + self.embedder: OllamaEmbedder | ArkEmbedder | 
SenTranEmbedder | UniversalAPIEmbedder = \ + EmbedderFactory.from_config(config.embedder) + + + def build_memory(): + """Build memory from the original dialogs. (Initialize memory)""" + pass + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + """Extract memories based on the messages. + Args: + messages (MessageList): The messages to extract memories from. + Returns: + list[TextualMemoryItem]: List of extracted memory items. + """ + pass + + def get_prompt(self, memories: list[TextualMemoryItem]) -> str: + """Get the prompt for the memory. + Args: + memories (list[TextualMemoryItem]): The memories to get the prompt for. + Returns: + str: The prompt for the memory. + """ + pass + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. + """ + pass + + def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: + """Update a memory by memory_id.""" + pass + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + pass + + def get(self, memory_id: str) -> TextualMemoryItem: + """Get a memory by its ID. + Args: + memory_id (str): The ID of the memory to retrieve. + Returns: + TextualMemoryItem: The memory with the given ID. + """ + pass + + def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: + """Get memories by their IDs. + Args: + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs. + """ + pass + + def get_all(self) -> list[TextualMemoryItem]: + """Get all memories. 
+ Returns: + list[TextualMemoryItem]: List of all memories. + """ + pass + + def delete(self, memory_ids: list[str]) -> None: + """Delete memories. + Args: + memory_ids (list[str]): List of memory IDs to delete. + """ + pass + + def delete_all(self) -> None: + """Delete all memories.""" + pass + + def drop( + self, + ) -> None: + """Drop all databases.""" + pass \ No newline at end of file From b8a739d13aa74927da66ac4e716214b01f794f3a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 11 Sep 2025 17:01:13 +0800 Subject: [PATCH 02/62] finish milvus support --- src/memos/vec_dbs/milvus.py | 286 ++++++++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 src/memos/vec_dbs/milvus.py diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py new file mode 100644 index 000000000..a04c5aeff --- /dev/null +++ b/src/memos/vec_dbs/milvus.py @@ -0,0 +1,286 @@ +from typing import Any + +from memos.configs.vec_db import MilvusVecDBConfig +from memos.dependency import require_python_package +from memos.log import get_logger +from memos.vec_dbs.base import BaseVecDB +from memos.vec_dbs.item import VecDBItem + + +logger = get_logger(__name__) + + +class MilvusVecDB(BaseVecDB): + """Milvus vector database implementation.""" + + @require_python_package( + import_name="pymilvus", + install_command="pip install -U pymilvus", + install_link="https://milvus.io/docs/install-pymilvus.md", + ) + def __init__(self, config: MilvusVecDBConfig): + """Initialize the Milvus vector database and the collection.""" + from pymilvus import MilvusClient + + self.config = config + + # Create Milvus client + self.client = MilvusClient(uri=self.config.uri) + self.create_collection() + + def create_collection(self) -> None: + """Create a new collection with specified parameters.""" + for collection_name in self.config.collection_name: + if self.collection_exists(collection_name): + logger.warning( + f"Collection '{collection_name}' already exists. 
Skipping creation." + ) + continue + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + id_type="string", # Use string ID type, align with VecDBItem id type + ) + + logger.info( + f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions." + ) + + def list_collections(self) -> list[str]: + """List all collections.""" + return self.client.list_collections() + + def delete_collection(self, name: str) -> None: + """Delete a collection.""" + self.client.drop_collection(name) + + def collection_exists(self, name: str) -> bool: + """Check if a collection exists.""" + return self.client.has_collection(collection_name=name) + + def search( + self, query_vector: list[float], collection_name: str, top_k: int, filter: dict[str, Any] | None = None + ) -> list[VecDBItem]: + """ + Search for similar items in the database. + + Args: + query_vector: Single vector to search + collection_name: Name of the collection to search + top_k: Number of results to return + filter: Payload filters + + Returns: + List of search results with distance scores and payloads. 
+ """ + # Convert filter to Milvus expression + expr = self._dict_to_expr(filter) if filter else None + + results = self.client.search( + collection_name=collection_name, + data=[query_vector], + limit=top_k, + filter=expr, + ) + + items = [] + for hit in results[0]: + # Extract payload from hit + payload = {} + for key, value in hit.get("entity", {}).items(): + if key not in ["id", "vector"]: + payload[key] = value + + items.append(VecDBItem( + id=str(hit["id"]), + vector=hit.get("entity", {}).get("vector"), + payload=payload, + score=float(hit["distance"]), + )) + + logger.info(f"Milvus search completed with {len(items)} results.") + return items + + def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: + """Convert a dictionary filter to a Milvus expression string.""" + conditions = [] + for field, value in filter_dict.items(): + if isinstance(value, str): + conditions.append(f'{field} == "{value}"') + else: + conditions.append(f"{field} == {value}") + return " and ".join(conditions) + + def _get_metric_type(self) -> str: + """Get the metric type for search.""" + metric_map = { + "cosine": "COSINE", + "euclidean": "L2", + "dot": "IP", + } + return metric_map.get(self.config.distance_metric, "L2") + + def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + """Get a single item by ID.""" + results = self.client.get( + collection_name=collection_name, + ids=[id], + ) + + if not results: + return None + + entity = results[0] + payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + + return VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + """Get multiple items by their IDs.""" + results = self.client.get( + collection_name=collection_name, + ids=ids, + ) + + if not results: + return [] + + items = [] + for entity in results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + 
items.append(VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + )) + + return items + + def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100) -> list[VecDBItem]: + """ + Retrieve all items that match the given filter criteria. + + Args: + filter: Payload filters to match against stored items + scroll_limit: Maximum number of items to retrieve per scroll request + + Returns: + List of items including vectors and payload that match the filter + """ + expr = self._dict_to_expr(filter) if filter else None + + results = self.client.query( + collection_name=collection_name, + filter=expr, + limit=scroll_limit, + ) + + items = [] + for entity in results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + items.append(VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + )) + + logger.info(f"Milvus retrieve by filter completed with {len(items)} results.") + return items + + def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + """Retrieve all items in the vector database.""" + return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) + + def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> int: + """Count items in the database, optionally with filter.""" + if filter: + # If there's a filter, use query method + expr = self._dict_to_expr(filter) + results = self.client.query( + collection_name=collection_name, + filter=expr, + output_fields=["id"], + ) + return len(results) + else: + # For counting all items, use get_collection_stats for accurate count + stats = self.client.get_collection_stats(collection_name) + # Extract row count from stats - stats is a dict, not a list + return int(stats.get("row_count", 0)) + + def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add data to the vector database. 
+ + Args: + data: List of VecDBItem objects or dictionaries containing: + - 'id': unique identifier + - 'vector': embedding vector + - 'payload': additional fields for filtering/retrieval + """ + entities = [] + for item in data: + if isinstance(item, dict): + item = item.copy() + item = VecDBItem.from_dict(item) + + # Prepare entity data + entity = { + "id": item.id, + "vector": item.vector, + } + # Add payload fields + if item.payload: + entity.update(item.payload) + + entities.append(entity) + + # Use upsert to be safe (insert or update) + self.client.upsert( + collection_name=collection_name, + data=entities, + ) + + def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None: + """Update an item in the vector database.""" + if isinstance(data, dict): + data = data.copy() + data = VecDBItem.from_dict(data) + + # Use upsert for updates + self.upsert(collection_name, [data]) + + def ensure_payload_indexes(self, fields: list[str]) -> None: + """ + Create payload indexes for specified fields in the collection. + This is idempotent: it will skip if index already exists. + + Args: + fields (list[str]): List of field names to index (as keyword). + """ + # Note: Milvus doesn't have the same concept of payload indexes as Qdrant + # Field indexes are created automatically for scalar fields + logger.info(f"Milvus automatically indexes scalar fields: {fields}") + + def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + """ + Add or update data in the vector database. + + If an item with the same ID exists, it will be updated. + Otherwise, it will be added as a new item. 
+ """ + # Reuse add method since it already uses upsert + self.add(collection_name, data) + + def delete(self, collection_name: str, ids: list[str]) -> None: + """Delete items from the vector database.""" + self.client.delete( + collection_name=collection_name, + ids=ids, + ) \ No newline at end of file From 89b8d4bfc82728e6d4e6d2e91d96498c6d595f96 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 12 Sep 2025 18:02:17 +0800 Subject: [PATCH 03/62] add new builder --- src/memos/configs/memory.py | 45 ++ src/memos/configs/vec_db.py | 7 + src/memos/memories/factory.py | 2 + .../textual/prefer_text_memory/assemble.py | 19 +- .../textual/prefer_text_memory/builders.py | 488 +++++++++++++++++- .../textual/prefer_text_memory/clustering.py | 122 +++++ .../textual/prefer_text_memory/config.py | 4 + .../textual/prefer_text_memory/factory.py | 16 +- .../textual/prefer_text_memory/retrievers.py | 21 +- .../textual/prefer_text_memory/updater.py | 28 +- src/memos/memories/textual/preference.py | 92 +++- src/memos/templates/prefer_complete_prompt.py | 200 +++++++ src/memos/vec_dbs/factory.py | 2 + 13 files changed, 977 insertions(+), 69 deletions(-) create mode 100644 src/memos/memories/textual/prefer_text_memory/clustering.py create mode 100644 src/memos/templates/prefer_complete_prompt.py diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 1eea6deaf..a83a77b46 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -10,6 +10,11 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from memos.exceptions import ConfigurationError +from memos.memories.textual.prefer_text_memory.config import ( + BuilderConfigFactory, + RetrieverConfigFactory, + UpdaterConfigFactory, + AssemblerConfigFactory) # ─── 1. 
Global Base Memory Config ───────────────────────────────────────────── @@ -180,6 +185,45 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ) +class PreferenceTextMemoryConfig(BaseTextMemoryConfig): + """Preference memory configuration class.""" + + extractor_llm: LLMConfigFactory = Field( + ..., + default_factory=LLMConfigFactory, + description="LLM configuration for the memory extractor", + ) + vector_db: VectorDBConfigFactory = Field( + ..., + default_factory=VectorDBConfigFactory, + description="Vector database configuration for the memory storage", + ) + embedder: EmbedderConfigFactory = Field( + ..., + default_factory=EmbedderConfigFactory, + description="Embedder configuration for the memory embedding", + ) + builder: BuilderConfigFactory = Field( + ..., + default_factory=BuilderConfigFactory, + description="Builder configuration for the memory building", + ) + retriever: RetrieverConfigFactory = Field( + ..., + default_factory=RetrieverConfigFactory, + description="Retriever configuration for the memory retrieving", + ) + updater: UpdaterConfigFactory = Field( + ..., + default_factory=UpdaterConfigFactory, + description="Updater configuration for the memory updating", + ) + assembler: AssemblerConfigFactory = Field( + ..., + default_factory=AssemblerConfigFactory, + description="Assembler configuration for the memory assembling", + ) + # ─── 3. 
Global Memory Config Factory ────────────────────────────────────────── @@ -193,6 +237,7 @@ class MemoryConfigFactory(BaseConfig): "naive_text": NaiveTextMemoryConfig, "general_text": GeneralTextMemoryConfig, "tree_text": TreeTextMemoryConfig, + "preference_text": PreferenceTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache "lora": LoRAMemoryConfig, diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index b43298d9b..2ffc5abf9 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -39,6 +39,12 @@ def set_default_path(self): return self +class MilvusVecDBConfig(BaseVecDBConfig): + """Configuration for Milvus vector database.""" + + uri: str = Field(..., description="URI for Milvus connection") + collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") + class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" @@ -47,6 +53,7 @@ class VectorDBConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDBConfig, + "milvus": MilvusVecDBConfig, } @field_validator("backend") diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index 9fdc67c53..e53cc9569 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -11,6 +11,7 @@ from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.preference import PreferenceTextMemory class MemoryFactory(BaseMemory): @@ -20,6 +21,7 @@ class MemoryFactory(BaseMemory): "naive_text": NaiveTextMemory, "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, + "preference_text": PreferenceTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git 
a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 845c55e24..52332d1fd 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -5,21 +5,24 @@ class BaseAssembler(ABC): """Abstract base class for assemblers.""" @abstractmethod - def __init__(self): + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the assembler.""" @abstractmethod - def assemble(self, memories: list[TextualMemoryItem]) -> str: - """Assemble memories into a single memory.""" + def assemble(self, query: str, memories: list[TextualMemoryItem]) -> str: + """Assemble query and memories into a single memory.""" class NaiveAssembler(BaseAssembler): """Naive assembler.""" - def __init__(self): + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive assembler.""" - super().__init__() + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db - def assemble(self, memories: list[TextualMemoryItem]) -> str: - """Assemble memories into a single memory.""" - pass + def assemble(self, query: str, memories: list[TextualMemoryItem]) -> str: + """Assemble query and memories into a single memory.""" + return f"Query: {query}\nMemories: {memories}" diff --git a/src/memos/memories/textual/prefer_text_memory/builders.py b/src/memos/memories/textual/prefer_text_memory/builders.py index b2830f4d3..6896bd298 100644 --- a/src/memos/memories/textual/prefer_text_memory/builders.py +++ b/src/memos/memories/textual/prefer_text_memory/builders.py @@ -1,33 +1,51 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Dict, List, Any, Tuple +import uuid +import json +from datetime import datetime + from memos.llms.base import BaseLLM +from memos.types import ChatHistory, 
MessageList +from memos.embedders.base import BaseEmbedder +from memos.vec_dbs.base import BaseVecDB +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, + NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, + NAIVE_TOPIC_INFO_EXTRACT_PROMPT +) +from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer +from memos.memories.textual.prefer_text_memory.chunk_merging import ChunkMergingManager, NaiveChunkMerger + class BaseBuilder(ABC): """ Abstract base class for memory builders. Each builder implements a specific build strategy for creating - procedural memory content from task trajectories. + preference memory content from chat history. """ @abstractmethod - def __init__(self, llm_provider: Optional[BaseLLM] = None): + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """ Initialize the memory builder. Args: llm_provider: LLM provider for script generation (required for some strategies) + embedder: Embedder for vector operations + vector_db: Vector database for storage """ - self.llm_provider = llm_provider @abstractmethod - def build(self, task_description: str, trajectory: str) -> str: + def build(self, history: ChatHistory) -> str: """ - Build memory content from task description and trajectory. + Build memory content from chat history. Args: - task_description: Natural language description of the task - trajectory: Detailed step-by-step trajectory of task execution + history: The chat history to build memory from. 
Returns: Memory content string formatted according to the build strategy @@ -35,15 +53,457 @@ def build(self, task_description: str, trajectory: str) -> str: Raises: RuntimeError: If memory building fails """ - pass + class NaiveBuilder(BaseBuilder): """Naive memory builder.""" - def __init__(self, llm_provider: Optional[BaseLLM] = None): + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive memory builder.""" - super().__init__(llm_provider) + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + + + + def build(self, history: ChatHistory) -> str: + """Build memory content from chat history following the preference extraction pipeline.""" + + # Initialize clustering and chunk merging managers + clusterer = HDBSCANClusterer() + chunk_merging_manager = ChunkMergingManager(NaiveChunkMerger()) + + # Step 1: Build QA pairs from chat history + qa_pairs = self._build_qa_pairs(history.chat_history) + + # Step 2: Process each QA pair + basic_infos = [] + explicit_preferences = [] + topic_infos = [] + + for qa_pair in qa_pairs: + # Extract basic info + basic_info = self._extract_basic_info(qa_pair) + basic_infos.append(basic_info) + + # Extract topic information + topic_info = self._extract_topic_info(qa_pair) + if topic_info: + topic_infos.append(topic_info) + + # Extract explicit preference from + explicit_pref = self._extract_explicit_preference(qa_pair) + if explicit_pref: + explicit_preferences.append(explicit_pref) + + # Step 3: Generate embeddings + dialogue_vectors = self._generate_dialogue_vectors(basic_infos) + topic_vectors = self._generate_topic_vectors(topic_infos) - def build(self, task_description: str, trajectory: str) -> str: - """Build memory content from task description and trajectory.""" - pass \ No newline at end of file + whole_infos = self._concat_infos(basic_infos, explicit_preferences, topic_infos, dialogue_vectors, 
topic_vectors) + + # Step 4: Perform clustering + implicit_clusters = self._implicit_cluster(clusterer, whole_infos) + topic_clusters = self._topic_cluster(clusterer, whole_infos) + + # Step 5: Extract implicit preferences + implicit_preferences = self._extract_implicit_preferences(implicit_clusters) + + # Step 6: Extract topic preferences + topic_preferences = self._extract_topic_preferences(topic_clusters) + + # Step 7: Handle chunk merging for long chunks (optional) + merged_preferences = self._handle_chunk_merging( + explicit_preferences, implicit_preferences, topic_preferences, chunk_merging_manager + ) + + # Step 8: Extract user preferences + user_preferences = self._extract_user_preferences(topic_preferences) + + # Step 9: Store all preferences in memory + self._store_preferences( + explicit_preferences, + implicit_preferences, + topic_preferences, + user_preferences, + history.user_id + ) + + # Return summary of built memory + return self._generate_memory_summary( + explicit_preferences, + implicit_preferences, + topic_preferences, + basic_infos, + user_preferences + ) + + def _build_qa_pairs(self, chat_history: MessageList) -> List[MessageList]: + """Build QA pairs from chat history.""" + qa_pairs = [] + current_qa_pair = [] + + for message in chat_history: + role = message["role"] + + if role == "user": + # If we have a complete QA pair, save it + if len(current_qa_pair) >= 2: # At least question + answer + qa_pairs.append(current_qa_pair) + + # Start new QA pair + current_qa_pair = [message] + + elif role == "assistant": + # Add answer to current QA pair + current_qa_pair.append(message) + + # Don't forget the last QA pair if it exists and is complete + if len(current_qa_pair) >= 2: + qa_pairs.append(current_qa_pair) + + return qa_pairs + + def _extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: + """Extract basic information from a QA pair (no LLM needed).""" + basic_info = { + "dialog_id": str(uuid.uuid4()), + "dialog_msgs": qa_pair, + 
"dialog_str": "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]), + "created_at": datetime.now().isoformat() + } + + return basic_info + + def _extract_topic_info(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: + """Extract topic information from a QA pair.""" + prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + return result + except Exception: + return response + + def _extract_explicit_preference(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: + """Extract explicit preference from a QA pair (LLM-1).""" + + prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + return result + except Exception: + return response + + def _generate_dialogue_vectors(self, basic_infos: List[Dict[str, Any]]) -> List[List[float]]: + """Generate embeddings for dialogue segments.""" + if not self.embedder or not basic_infos: + return [] + + texts = [info.get("dialog_segment_str", "") for info in basic_infos] + return self.embedder.embed(texts) + + def _generate_topic_vectors(self, topic_infos: List[Dict[str, Any]]) -> List[List[float]]: + """Generate embeddings for topic information.""" + if not self.embedder: + return [] + + texts = [f"{info.get('topic_name', '')} {info.get('topic_description', '')}" for info in topic_infos] + return self.embedder.embed(texts) + + def _concat_infos( + self, + basic_infos: List[Dict[str, Any]], + explicit_preferences: List[Dict[str, Any]], + topic_infos: List[Dict[str, Any]], + dialogue_vectors: List[List[float]], + topic_vectors: List[List[float]]) -> List[Dict[str, Any]]: + """Concatenate infos.""" + whole_infos = [{**bsc, **ep, **ti, "dialog_vector": dv, "topic_vector": tv} + for bsc, ep, ti, dv, tv in zip(basic_infos, 
explicit_preferences, topic_infos, dialogue_vectors, topic_vectors)] + return whole_infos + + def _implicit_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Implicit cluster.""" + # Extract vectors for clustering + vectors = [info.get("dialog_vector") for info in whole_infos] + + # Perform clustering + cluster_results = clusterer.cluster(vectors) + + # Map cluster results back to original data + for cluster in cluster_results: + original_infos = [] + for item in cluster["items"]: + index = item["index"] + original_info = whole_infos[index] + original_infos.append({ + "dialog_id": original_info.get("dialog_id"), + "dialog_msgs": original_info.get("dialog_msgs"), + "dialog_str": original_info.get("dialog_str"), + "created_at": original_info.get("created_at") + }) + cluster["original_data"] = original_infos + + return cluster_results + + def _topic_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Topic cluster.""" + # Extract vectors for clustering + vectors = [info.get("topic_vector") for info in whole_infos] + + # Perform clustering + cluster_results = clusterer.cluster(vectors) + + # Map cluster results back to original data + for cluster in cluster_results: + original_infos = [] + for item in cluster["items"]: + index = item["index"] + original_info = whole_infos[index] + original_infos.append({ + "dialog_id": original_info.get("dialog_id"), + "dialog_msgs": original_info.get("dialog_msgs"), + "dialog_str": original_info.get("dialog_str"), + "created_at": original_info.get("created_at") + }) + cluster["original_data"] = original_infos + + return cluster_results + + + def _handle_chunk_merging(self, explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + chunk_merging_manager) -> Dict[str, Any]: + """Handle chunk merging for long chunks (optional step).""" + merged_results = { + 
"explicit_preferences": explicit_prefs, + "implicit_preferences": implicit_prefs, + "topic_preferences": topic_prefs, + "merged_chunks": [] + } + + # Check for long chunks that need merging + all_preferences = explicit_prefs + implicit_prefs + topic_prefs + + for pref in all_preferences: + content = pref.get("preference", "") + if len(content) > 4096: # Long chunk threshold + # Split the long chunk + chunks = chunk_merging_manager.process_long_chunks([{ + "content": content, + "id": pref.get("id", ""), + "preference": pref + }]) + + if len(chunks) > 1: + # Extract preferences from split chunks + chunk_preferences = [chunk.get("preference", {}) for chunk in chunks] + + # Merge preferences using LLM + merged_pref = chunk_merging_manager.merge_chunk_items( + chunk_preferences, self.llm_provider + ) + + if merged_pref: + merged_results["merged_chunks"].append({ + "original_preference": pref, + "merged_preference": merged_pref, + "chunk_count": len(chunks) + }) + + return merged_results + + def _extract_implicit_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract implicit preferences from clusters.""" + + implicit_preferences = [] + + for cluster in clusters: + # Get dialogue segments in this cluster + qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) + + prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + if result.get("implicit_preference"): + result["id"] = str(uuid.uuid4()) + result["created_at"] = datetime.now().isoformat() + implicit_preferences.append(result) + except Exception: + continue + + return implicit_preferences + + def _extract_topic_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract topic preferences from clusters.""" + + topic_preferences = [] + + for cluster in clusters: + # Get topic infos in this 
cluster + cluster_topics = [cluster["items"][i] for i in range(len(cluster["items"])) + if i < len(cluster["items"])] + + prompt = f""" + {NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT} + + 主题聚类信息: + - 聚类ID: {cluster.get('cluster_id', '')} + - 聚类大小: {cluster.get('size', 0)} + + 相关主题信息: + {json.dumps(cluster_topics, ensure_ascii=False, indent=2)} + + 请提取主题偏好,返回JSON格式: + {{ + "topic_preference": "主题偏好描述", + "confidence": 0.8, + "cluster_id": "{cluster.get('cluster_id', '')}" + }} + """ + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + if result.get("topic_preference"): + result["id"] = str(uuid.uuid4()) + result["created_at"] = datetime.now().isoformat() + topic_preferences.append(result) + except Exception: + continue + + return topic_preferences + + def _extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract user-level preferences.""" + + prompt = f""" + {NAIVE_USER_PREFERENCE_EXTRACT_PROMPT} + + 主题偏好信息: + {json.dumps(topic_preferences, ensure_ascii=False, indent=2)} + + 相关对话片段: + {json.dumps(topic_preferences, ensure_ascii=False, indent=2)} + + 请提取用户偏好,返回JSON格式: + {{ + "user_preference": "用户偏好描述", + "confidence": 0.8 + }} + """ + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + if result.get("user_preference"): + result["id"] = str(uuid.uuid4()) + result["created_at"] = datetime.now().isoformat() + return [result] + except Exception: + pass + + return [] + + def _store_preferences(self, explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + user_prefs: List[Dict[str, Any]], + user_id: str): + """Store all preferences in memory.""" + + # Convert to TextualMemoryItem and store + all_memories = [] + + # Store explicit preferences + for pref in explicit_prefs: + memory_item = TextualMemoryItem( + 
memory=pref.get("preference", ""), + metadata=TextualMemoryMetadata( + user_id=user_id, + type="explicit_preference", + confidence=pref.get("confidence", 0.5), + source="conversation", + tags=["explicit", "preference"], + updated_at=pref.get("created_at", datetime.now().isoformat()) + ) + ) + all_memories.append(memory_item) + + # Store implicit preferences + for pref in implicit_prefs: + memory_item = TextualMemoryItem( + memory=pref.get("implicit_preference", ""), + metadata=TextualMemoryMetadata( + user_id=user_id, + type="implicit_preference", + confidence=pref.get("confidence", 0.5), + source="conversation", + tags=["implicit", "preference", pref.get("cluster_id", "")], + updated_at=pref.get("created_at", datetime.now().isoformat()) + ) + ) + all_memories.append(memory_item) + + # Store topic preferences + for pref in topic_prefs: + memory_item = TextualMemoryItem( + memory=pref.get("topic_preference", ""), + metadata=TextualMemoryMetadata( + user_id=user_id, + type="topic_preference", + confidence=pref.get("confidence", 0.5), + source="conversation", + tags=["topic", "preference", pref.get("cluster_id", "")], + updated_at=pref.get("created_at", datetime.now().isoformat()) + ) + ) + all_memories.append(memory_item) + + # Store user preferences + for pref in user_prefs: + memory_item = TextualMemoryItem( + memory=pref.get("user_preference", ""), + metadata=TextualMemoryMetadata( + user_id=user_id, + type="user_preference", + confidence=pref.get("confidence", 0.5), + source="conversation", + tags=["user", "preference"], + updated_at=pref.get("created_at", datetime.now().isoformat()) + ) + ) + all_memories.append(memory_item) + + # Store in vector database + if all_memories: + self.vector_db.add(all_memories) + + def _generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + user_prefs: List[Dict[str, Any]]) -> str: + """Generate a summary of the built memory.""" + summary = 
{ + "memory_build_summary": { + "explicit_preferences_count": len(explicit_prefs), + "implicit_preferences_count": len(implicit_prefs), + "topic_preferences_count": len(topic_prefs), + "user_preferences_count": len(user_prefs), + "total_memories": len(explicit_prefs) + len(implicit_prefs) + len(topic_prefs) + len(user_prefs), + "build_timestamp": datetime.now().isoformat() + } + } + + return json.dumps(summary, ensure_ascii=False, indent=2) diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py new file mode 100644 index 000000000..f15afc81f --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -0,0 +1,122 @@ +""" +Clustering module for preference memory extraction. + +This module provides abstract clustering functionality that can be used +by different types of clustering (implicit, topic, etc.). +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Tuple +import numpy as np +from datetime import datetime + + +class BaseClusterer(ABC): + """Abstract base class for clustering algorithms.""" + + @abstractmethod + def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: + """ + Perform clustering on the given vectors. + + Args: + vectors: List of vectors to cluster + **kwargs: Additional clustering parameters + + Returns: + List of cluster information dictionaries + """ + pass + + +class HDBSCANClusterer(BaseClusterer): + """HDBSCAN-based clustering implementation.""" + + def __init__(self, min_cluster_size: int = 2, min_samples: int = 1): + """ + Initialize HDBSCAN clusterer. 
"""
Clustering module for preference memory extraction.

Provides an abstract clustering interface plus an HDBSCAN-backed
implementation (with a single-cluster fallback when scikit-learn is absent).
"""

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List

import numpy as np


class BaseClusterer(ABC):
    """Abstract base class for clustering algorithms."""

    @abstractmethod
    def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]:
        """Cluster *vectors* and return a list of cluster-info dicts."""


class HDBSCANClusterer(BaseClusterer):
    """Clusterer backed by scikit-learn's HDBSCAN."""

    def __init__(self, min_cluster_size: int = 2, min_samples: int = 1):
        """Remember the HDBSCAN density parameters.

        Args:
            min_cluster_size: Minimum size of a cluster.
            min_samples: Minimum number of samples in a neighborhood.
        """
        self.min_cluster_size = min_cluster_size
        self.min_samples = min_samples

    def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]:
        """Cluster *vectors* with HDBSCAN.

        Fewer than two vectors yields []. Noise points (label -1) are
        dropped. Each returned dict carries cluster_id, the medoid index
        ("center_index"), the centroid, size, timestamp, and the member
        items. Without scikit-learn, everything lands in one fallback
        cluster centered on the first vector.
        """
        if not vectors or len(vectors) < 2:
            return []

        try:
            from sklearn.cluster import HDBSCAN
        except ImportError:
            # Fallback: a single cluster holding every input vector.
            return [
                {
                    "cluster_id": 0,
                    "center_index": 0,
                    "center_vector": vectors[0] if vectors else [],
                    "size": len(vectors),
                    "created_at": datetime.now().isoformat(),
                    "items": [
                        {"index": i, "vector": vec} for i, vec in enumerate(vectors)
                    ],
                }
            ]

        labels = HDBSCAN(
            min_cluster_size=self.min_cluster_size, min_samples=self.min_samples
        ).fit_predict(np.array(vectors))

        # Group member records by label, preserving original indices.
        grouped: Dict[Any, List[Dict[str, Any]]] = {}
        for idx, label in enumerate(labels):
            grouped.setdefault(label, []).append({"vector": vectors[idx], "index": idx})

        results: List[Dict[str, Any]] = []
        for label, members in grouped.items():
            if label == -1:  # noise points are not clusters
                continue
            centroid = np.mean([m["vector"] for m in members], axis=0)
            # Medoid: the member closest to the geometric centroid.
            dists = [np.linalg.norm(np.array(m["vector"]) - centroid) for m in members]
            medoid_index = members[int(np.argmin(dists))]["index"]
            results.append(
                {
                    "cluster_id": label,
                    "center_index": medoid_index,
                    "center_vector": centroid.tolist(),
                    "size": len(members),
                    "created_at": datetime.now().isoformat(),
                    "items": members,
                }
            )
        return results
builder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) class RetrieverFactory(BaseRetriever): """Factory class for creating Retriever instances.""" @@ -31,13 +31,13 @@ class RetrieverFactory(BaseRetriever): } @classmethod - def from_config(cls, config_factory: RetrieverConfigFactory) -> BaseRetriever: + def from_config(cls, config_factory: RetrieverConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseRetriever: """Create a Retriever instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") retriever_class = cls.backend_to_class[backend] - return retriever_class(config_factory.config) + return retriever_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) class UpdaterFactory(BaseUpdater): """Factory class for creating Updater instances.""" @@ -47,13 +47,13 @@ class UpdaterFactory(BaseUpdater): } @classmethod - def from_config(cls, config_factory: UpdaterConfigFactory) -> BaseUpdater: + def from_config(cls, config_factory: UpdaterConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseUpdater: """Create a Updater instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") updater_class = cls.backend_to_class[backend] - return updater_class(config_factory.config) + return updater_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) class AssemblerFactory(BaseAssembler): """Factory class for creating Assembler instances.""" @@ -63,11 +63,11 @@ class AssemblerFactory(BaseAssembler): } @classmethod - def from_config(cls, config_factory: AssemblerConfigFactory) -> BaseAssembler: + def from_config(cls, config_factory: AssemblerConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseAssembler: """Create a Assembler instance from a configuration 
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from memos.memories.textual.item import TextualMemoryItem


class BaseRetriever(ABC):
    """Abstract base class for preference-memory retrievers."""

    @abstractmethod
    def __init__(self, llm_provider=None, embedder=None, vector_db=None):
        """Initialize the retriever with optional shared components."""

    @abstractmethod
    def retrieve(
        self, query: str, top_k: int, info: "dict[str, Any]"
    ) -> "list[TextualMemoryItem]":
        """Return up to *top_k* memories relevant to *query*."""


class NaiveRetriever(BaseRetriever):
    """Naive retriever: embed the query, delegate to the vector store."""

    def __init__(self, llm_provider=None, embedder=None, vector_db=None):
        """Store the shared components handed in by the parent memory."""
        super().__init__(llm_provider, embedder, vector_db)
        self.llm_provider = llm_provider
        self.vector_db = vector_db
        self.embedder = embedder

    def retrieve(
        self, query: str, top_k: int, info: "dict[str, Any]"
    ) -> "list[TextualMemoryItem]":
        """Embed *query* and run a vector search.

        BUGFIX: embedder.embed() takes a *list* of texts and returns one
        vector per text (every builder call site passes a list); the old code
        passed the bare string and handed a list-of-vectors to the DB.
        """
        query_embedding = self.embedder.embed([query])[0]
        # NOTE(review): assumes vector_db.search(vector, top_k, info) — confirm
        # this matches the BaseVecDB search signature.
        return self.vector_db.search(query_embedding, top_k, info)
+ """ pass diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 96c4e830c..0d318e230 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -7,23 +7,24 @@ from pathlib import Path from typing import Any -from memos.configs.memory import GeneralTextMemoryConfig +from memos.configs.memory import PreferenceTextMemoryConfig from memos.embedders.factory import EmbedderFactory, OllamaEmbedder, ArkEmbedder, SenTranEmbedder, UniversalAPIEmbedder from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem -from memos.types import MessageList +from memos.types import ChatHistory, MessageList from memos.llms.base import BaseLLM from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory +from memos.memories.textual.prefer_text_memory.factory import BuilderFactory, RetrieverFactory, UpdaterFactory, AssemblerFactory class PreferenceTextMemory(BaseTextMemory): """Preference textual memory implementation for storing and retrieving memories.""" - def __init__(self, config: GeneralTextMemoryConfig): + def __init__(self, config: PreferenceTextMemoryConfig): """Initialize memory with the given configuration.""" - self.config: GeneralTextMemoryConfig = config + self.config: PreferenceTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm ) @@ -31,26 +32,74 @@ def __init__(self, config: GeneralTextMemoryConfig): self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = \ EmbedderFactory.from_config(config.embedder) + self.builder = BuilderFactory.from_config( + config.builder, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db + ) + self.retriever = RetrieverFactory.from_config( + 
config.retriever, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db + ) + self.updater = UpdaterFactory.from_config( + config.updater, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db + ) + self.assembler = AssemblerFactory.from_config( + config.assembler, + llm_provider=self.extractor_llm, + embedder=self.embedder, + vector_db=self.vector_db + ) - def build_memory(): - """Build memory from the original dialogs. (Initialize memory)""" - pass - - def extract(self, messages: MessageList) -> list[TextualMemoryItem]: - """Extract memories based on the messages. + def build_memory(self, history: ChatHistory) -> None: + """Build memory from the original dialogs. (Initialize memory) + Args: - messages (MessageList): The messages to extract memories from. + history: The chat history to build memory from. + Returns: - list[TextualMemoryItem]: List of extracted memory items. + Memory content string formatted according to the build strategy """ - pass + return self.builder.build(history) + + def update_memory(self, new_memory: TextualMemoryItem | dict[str, Any]) -> None: + """Update a memory by new memory.""" + self.updater.update(new_memory) - def get_prompt(self, memories: list[TextualMemoryItem]) -> str: - """Get the prompt for the memory. + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + + def get_prompt(self, query: str, memories: list[TextualMemoryItem]) -> str: + """Construct the prompt for the query with memories. + Args: + query (str): The query to get the prompt for. 
memories (list[TextualMemoryItem]): The memories to get the prompt for. Returns: - str: The prompt for the memory. + str: The prompt for the query with memories. + """ + return self.assembler.assemble(query, memories) + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + """Extract memories based on the messages. + Args: + messages (MessageList): The messages to extract memories from. + Returns: + list[TextualMemoryItem]: List of extracted memory items. """ pass @@ -66,17 +115,6 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) """Update a memory by memory_id.""" pass - def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: - """Search for memories based on a query. - Args: - query (str): The query to search for. - top_k (int): The number of top results to return. - info (dict): Leave a record of memory consumption. - Returns: - list[TextualMemoryItem]: List of matching memories. - """ - pass - def get(self, memory_id: str) -> TextualMemoryItem: """Get a memory by its ID. Args: diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py new file mode 100644 index 000000000..e004b8da9 --- /dev/null +++ b/src/memos/templates/prefer_complete_prompt.py @@ -0,0 +1,200 @@ + + + +NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are an information extraction assistant. You will be given a QA pair (user question Q and assistant answer A). +Please extract the user's explicit preferences from the Q and implicit preferences from the A, and output JSON strictly according to the requirements. + +# Extraction Rules +Explicit preferences (explicit_preferences): Extract only from the user's Q. 
Including but not limited to: +- Role descriptions (e.g., "You are a history teacher") +- Style constraints (e.g., "humorous style", "academic style") +- Format requirements (e.g., "table", "Markdown") +- Length limitations (e.g., "within 100 words") +- Language requirements (e.g., "write in English") +- Safety compliance requirements (e.g., "don't involve sensitive content") +- Quality standards (e.g., "be concise and clear") + +# Output Format +{ + "explicit_preferences": { + "role": "", + "style": "", + "format": "", + "length": "", + "language": "", + "safety": "", + "quality": "" + } +} + +# Notes +If there is no information for a certain item, please leave an empty string "". +Only output JSON, no explanations. + +# Conversation Content +{qa_pair} +""" + + +NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference extraction expert. You will be given multiple user QA pairs (questions and answers). +Your task is to extract the user's **implicit preferences** from these QA pairs. + +# Definitions: +1. **Explicit Preferences**: Constraints explicitly stated by the user in their questions, such as: + - Role requirements (e.g., "act as a teacher") + - Style preferences (e.g., "be humorous", "be formal") + - Format requirements (e.g., "use bullet points", "create a table") + - Length constraints (e.g., "keep it short", "be detailed") + - Language requirements (e.g., "write in English") + - Safety guidelines (e.g., "avoid sensitive topics") + +2. 
**Implicit Preferences**: Patterns that are NOT explicitly stated but consistently appear across multiple QA pairs: + - Recurring themes or topics the user frequently asks about + - Consistent communication style preferences + - Repeated information depth requirements + - Common response format expectations + - Underlying values or priorities + +# Extraction Rules: +- Focus on patterns that appear across MULTIPLE QA pairs, not single occurrences +- Look for consistent behaviors, not one-time requests +- Extract only implicit preferences, do not repeat explicit ones +- Use concise language, avoid redundant words +- Each preference should be distinct and non-overlapping + +# Output Format: +{ + "implicit_preferences": [ + "preference 1", + "preference 2", + "preference 3" + ] +} + +# Notes: +- If no clear implicit preferences are found, return an empty array [] +- Only output JSON, no explanations +- Focus on meaningful patterns, not trivial observations + +# QA Pairs: +{qa_pairs} +""" + + +NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT = """ +You are a topic cluster analysis expert. You will be given a topic cluster containing multiple QA conversations with their preferences, topic names, and descriptions. +Your task is to analyze all information in this cluster and extract three key components, then output in strict JSON format. + +# Extraction Targets: + +1. **Cluster Name (cluster_name)**: + - Use concise words (2-6 words) to summarize the core theme of this cluster + - Should be more abstract and comprehensive than individual conversation topics + - Examples: "Machine Learning Basics", "Creative Writing", "Health & Wellness", "Career Development" + +2. **Cluster Description (cluster_description)**: + - Provide a brief summary of the main content and scope of this cluster theme + - Should be more specific than the cluster name + - Describe what topics and areas are covered within this cluster + +3. 
**Cluster Preferences (cluster_preferences)**: + - Combine and summarize common preferences from all explicit and implicit preferences in the cluster + - Focus on shared patterns across all QA pairs in this topic area + - Identify recurring user preferences specific to this topic domain + +# Output Format: +{ + "cluster_name": "", + "cluster_description": "", + "cluster_preferences": "" +} + +# Notes: +- If any field has no clear information, leave it as empty string "" +- Only output JSON, no explanations +- Focus on meaningful patterns that represent the cluster as a whole +- Cluster name should be broader than individual topic names + +# Cluster Information: +{cluster_info} +""" + + +NAIVE_USER_PREFERENCE_EXTRACT_PROMPT = """ +You are an advanced information integration assistant. You will be given a user's preference list across different topic clusters, where each cluster contains: + +- cluster_name: The name of the topic cluster +- cluster_preferences: Natural language description of preferences in that cluster + +Your task is to extract the user's **highest-level common preferences** by focusing on these three key dimensions: + +# Analysis Dimensions: + +1. **Content Preferences**: What types of information, topics, or knowledge styles the user tends to prefer + - Subject matter interests and expertise areas + - Information depth and complexity preferences + - Knowledge domain preferences + +2. **Interaction Style Preferences**: How the user prefers information to be presented, structured, or delivered + - Communication format preferences (formal vs. casual) + - Information organization preferences (structured vs. narrative) + - Response style preferences (concise vs. detailed) + +3. 
**Value Orientations**: Core values or principles reflected in the user's information choices, processing, or practices + - Underlying priorities and decision-making patterns + - Quality standards and expectations + - Ethical or professional principles + +# Requirements: +- Synthesize common patterns across ALL topic clusters, not individual cluster details +- Express in natural language, highlighting the user's overall preference characteristics +- Summarize into a coherent paragraph that flows smoothly +- Avoid bullet points and don't repeat specific cluster examples or operational details +- Focus on high-level patterns that transcend individual topics + +# Output Format: +{ + "user_overall_preferences": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" +} + +# Notes: +- If no clear patterns emerge, describe the user as having diverse or varied preferences +- Focus on meaningful patterns, not trivial observations +- Only output JSON, no explanations + +# Cluster Information: +{cluster_info} +""" + + +NAIVE_TOPIC_INFO_EXTRACT_PROMPT = """ +You are a topic extraction assistant. You will be given a QA pair (user question Q and assistant answer A). +Please extract the main topic name and topic description from this conversation and output in JSON format. 
+ +# Extraction Rules +- Topic Name (topic_name): + - Use concise words to summarize the core topic of the conversation + - Keep it between 2-6 words + - Examples: "Science Fiction Writing", "Technical Documentation", "Healthy Diet", "Career Advice", "Python Programming" + +- Topic Description (topic_description): + - Provide a brief 1-2 sentence summary of what the conversation is about + - Be more specific than the topic name but keep it under 50 words + - Focus on the main content and key points discussed + +# Output Format +{ + "topic_name": "", + "topic_description": "" +} + +# Notes +- If no clear topic can be identified, leave topic_name as empty string "" +- If the conversation is too brief or unclear, leave topic_description as empty string "" +- Only output JSON, no explanations + +# Conversation Content +{qa_pair} +""" diff --git a/src/memos/vec_dbs/factory.py b/src/memos/vec_dbs/factory.py index 8df22d14d..d8545027c 100644 --- a/src/memos/vec_dbs/factory.py +++ b/src/memos/vec_dbs/factory.py @@ -3,6 +3,7 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.vec_dbs.base import BaseVecDB from memos.vec_dbs.qdrant import QdrantVecDB +from memos.vec_dbs.milvus import MilvusVecDB class VecDBFactory(BaseVecDB): @@ -10,6 +11,7 @@ class VecDBFactory(BaseVecDB): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDB, + "milvus": MilvusVecDB, } @classmethod From 41a487ad31d8ff39aed40ed1d4692e520ba9e0fe Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 16 Sep 2025 10:06:39 +0800 Subject: [PATCH 04/62] finish prefer textual memory base level --- src/memos/configs/vec_db.py | 1 + src/memos/memories/textual/item.py | 57 +++ .../textual/prefer_text_memory/__init__.py | 0 .../textual/prefer_text_memory/assemble.py | 32 +- .../textual/prefer_text_memory/builders.py | 449 ++---------------- .../textual/prefer_text_memory/clustering.py | 11 +- .../textual/prefer_text_memory/naive_op.py | 439 +++++++++++++++++ 
.../textual/prefer_text_memory/retrievers.py | 33 +- .../textual/prefer_text_memory/updater.py | 113 ++++- src/memos/memories/textual/preference.py | 65 ++- src/memos/templates/prefer_complete_prompt.py | 78 ++- src/memos/vec_dbs/milvus.py | 58 ++- 12 files changed, 869 insertions(+), 467 deletions(-) create mode 100644 src/memos/memories/textual/prefer_text_memory/__init__.py create mode 100644 src/memos/memories/textual/prefer_text_memory/naive_op.py diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index 2ffc5abf9..457a09a21 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -44,6 +44,7 @@ class MilvusVecDBConfig(BaseVecDBConfig): uri: str = Field(..., description="URI for Milvus connection") collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") + max_length: int = Field(default=65535, description="Maximum length for string fields (varChar type)") class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 6b6e70fdf..490fc1a72 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -104,6 +104,62 @@ class SearchedTreeNodeTextualMemoryMetadata(TreeNodeTextualMemoryMetadata): ) +class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): + """Metadata for preference memory item.""" + + preference_type: Literal["explicit_preference", "implicit_preference", "topic_preference", "user_preference"] = Field( + default="explicit_preference", description="Type of preference." + ) + dialog_id: str | None = Field( + default=None, description="ID of the dialog." + ) + dialog_msgs: list[str] | None = Field( + default=None, description="Messages of the dialog." + ) + dialog_str: str | None = Field( + default=None, description="String of the dialog." 
+ ) + dialog_vector: list[float] | None = Field( + default=None, description="Vector of the dialog." + ) + topic_name: str | None = Field( + default=None, description="Name of the topic." + ) + topic_description: str | None = Field( + default=None, description="Description of the topic." + ) + topic_vector: list[float] | None = Field( + default=None, description="Vector of the topic." + ) + explicit_preference: str | None = Field( + default=None, description="Explicit preference." + ) + center_dialog: str | None = Field( + default=None, description="String of the center dialog." + ) + center_vector: list[float] | None = Field( + default=None, description="Vector of the center dialog." + ) + created_at: str | None = Field( + default=None, description="Timestamp of the dialog." + ) + implicit_preference: str | None = Field( + default=None, description="Implicit preference." + ) + topic_cluster_name: str | None = Field( + default=None, description="Name of the topic cluster." + ) + topic_cluster_description: str | None = Field( + default=None, description="Description of the topic cluster." + ) + topic_preferences: str | None = Field( + default=None, description="Preferences of the topic." + ) + user_preferences: str | None = Field( + default=None, description="Preferences of the user." + ) + + class TextualMemoryItem(BaseModel): """Represents a single memory item in the textual memory. 
@@ -117,6 +173,7 @@ class TextualMemoryItem(BaseModel): TextualMemoryMetadata | TreeNodeTextualMemoryMetadata | SearchedTreeNodeTextualMemoryMetadata + | PreferenceTextualMemoryMetadata ) = Field(default_factory=TextualMemoryMetadata) model_config = ConfigDict(extra="forbid") diff --git a/src/memos/memories/textual/prefer_text_memory/__init__.py b/src/memos/memories/textual/prefer_text_memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 52332d1fd..ac7762046 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from memos.memories.textual.item import TextualMemoryItem - +from memos.memories.textual.prefer_text_memory.naive_op import NaiveOp class BaseAssembler(ABC): """Abstract base class for assemblers.""" @@ -9,8 +9,15 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the assembler.""" @abstractmethod - def assemble(self, query: str, memories: list[TextualMemoryItem]) -> str: - """Assemble query and memories into a single memory.""" + def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: + """Assemble query and memories into a single memory. + Args: + query: The query to assemble. + memories: The memories to assemble. + assemble_strategy: The strategy to assemble the memories. option: [raw, semi, full] + Returns: + The assembled prompt. 
+ """ class NaiveAssembler(BaseAssembler): @@ -22,7 +29,22 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def assemble(self, query: str, memories: list[TextualMemoryItem]) -> str: + def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: """Assemble query and memories into a single memory.""" - return f"Query: {query}\nMemories: {memories}" + + explicit_prefs = [{"dialog_str": memory.metadata.dialog_str, "explicit_preference": memory.metadata.explicit_preference} for memory in memories if memory.metadata.preference_type == "explicit_preference"] + implicit_prefs = [{"center_dialog_str": memory.metadata.center_dialog, "implicit_preference": memory.metadata.implicit_preference} for memory in memories if memory.metadata.preference_type == "implicit_preference"] + topic_prefs = [{"center_dialog_str": memory.metadata.center_dialog, "topic_preferences": memory.metadata.topic_preferences} for memory in memories if memory.metadata.preference_type == "topic_preference"] + user_prefs = [{"user_preferences": memory.metadata.user_preferences} for memory in memories if memory.metadata.preference_type == "user_preference"] + + naive_op = NaiveOp(self.llm_provider, self.embedder, self.vector_db) + + if assemble_strategy == "raw": + return memories + elif assemble_strategy == "semi": + return f"Query: {query}\n\n In addition to the above Query, you can refer to the following preference below memories. \n\nMemories: {memories}. \n\nWhen encountering conflicts, prioritize following the query." 
+ elif assemble_strategy == "full": + return naive_op.preference_integration(query, explicit_prefs, implicit_prefs, topic_prefs, user_prefs) + else: + raise ValueError(f"Invalid assemble strategy: {assemble_strategy}") diff --git a/src/memos/memories/textual/prefer_text_memory/builders.py b/src/memos/memories/textual/prefer_text_memory/builders.py index 6896bd298..d4dd7bfaf 100644 --- a/src/memos/memories/textual/prefer_text_memory/builders.py +++ b/src/memos/memories/textual/prefer_text_memory/builders.py @@ -5,19 +5,16 @@ from datetime import datetime from memos.llms.base import BaseLLM +from memos.memories.textual.prefer_text_memory import naive_op from memos.types import ChatHistory, MessageList from memos.embedders.base import BaseEmbedder from memos.vec_dbs.base import BaseVecDB from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata -from memos.templates.prefer_complete_prompt import ( - NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, - NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, - NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_INFO_EXTRACT_PROMPT -) +from memos.vec_dbs.item import VecDBItem + from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer -from memos.memories.textual.prefer_text_memory.chunk_merging import ChunkMergingManager, NaiveChunkMerger +from memos.memories.textual.prefer_text_memory.naive_op import NaiveOp + class BaseBuilder(ABC): @@ -70,12 +67,12 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): def build(self, history: ChatHistory) -> str: """Build memory content from chat history following the preference extraction pipeline.""" - # Initialize clustering and chunk merging managers + # Initialize clustering manager clusterer = HDBSCANClusterer() - chunk_merging_manager = ChunkMergingManager(NaiveChunkMerger()) + naive_op = NaiveOp(self.llm_provider, self.embedder, self.vector_db) # Step 1: Build QA pairs from chat history - 
qa_pairs = self._build_qa_pairs(history.chat_history) + qa_pairs = naive_op.build_qa_pairs(history.chat_history) # Step 2: Process each QA pair basic_infos = [] @@ -84,426 +81,52 @@ def build(self, history: ChatHistory) -> str: for qa_pair in qa_pairs: # Extract basic info - basic_info = self._extract_basic_info(qa_pair) + basic_info = naive_op.extract_basic_info(qa_pair) basic_infos.append(basic_info) # Extract topic information - topic_info = self._extract_topic_info(qa_pair) + topic_info = naive_op.extract_topic_info(qa_pair) if topic_info: topic_infos.append(topic_info) # Extract explicit preference from - explicit_pref = self._extract_explicit_preference(qa_pair) + explicit_pref = naive_op.extract_explicit_preference(qa_pair) if explicit_pref: explicit_preferences.append(explicit_pref) # Step 3: Generate embeddings - dialogue_vectors = self._generate_dialogue_vectors(basic_infos) - topic_vectors = self._generate_topic_vectors(topic_infos) + dialogue_vectors = naive_op.generate_dialogue_vectors(basic_infos) + topic_vectors = naive_op.generate_topic_vectors(topic_infos) - whole_infos = self._concat_infos(basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors) + whole_infos = naive_op.concat_infos(basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors) # Step 4: Perform clustering - implicit_clusters = self._implicit_cluster(clusterer, whole_infos) - topic_clusters = self._topic_cluster(clusterer, whole_infos) + implicit_clusters = naive_op.implicit_cluster(clusterer, whole_infos) + topic_clusters = naive_op.topic_cluster(clusterer, whole_infos) # Step 5: Extract implicit preferences - implicit_preferences = self._extract_implicit_preferences(implicit_clusters) + implicit_clusters = naive_op.extract_implicit_preferences(implicit_clusters) # Step 6: Extract topic preferences - topic_preferences = self._extract_topic_preferences(topic_clusters) - - # Step 7: Handle chunk merging for long chunks (optional) - 
merged_preferences = self._handle_chunk_merging( - explicit_preferences, implicit_preferences, topic_preferences, chunk_merging_manager + topic_clusters = naive_op.extract_topic_preferences(topic_clusters) + + # Step 7: Extract user preferences + user_preferences = naive_op.extract_user_preferences(topic_clusters) + + # Step 8: Store all preferences in memory + naive_op.store_preferences( + explicit_prefs=whole_infos, + implicit_prefs=implicit_clusters, + topic_prefs=topic_clusters, + user_prefs=user_preferences, + user_id=history.user_id, ) - # Step 8: Extract user preferences - user_preferences = self._extract_user_preferences(topic_preferences) - - # Step 9: Store all preferences in memory - self._store_preferences( - explicit_preferences, - implicit_preferences, - topic_preferences, - user_preferences, - history.user_id + # Step 9: Return summary of built memory + return naive_op.generate_memory_summary( + explicit_prefs=whole_infos, + implicit_prefs=implicit_clusters, + topic_prefs=topic_clusters, + user_prefs=user_preferences, ) - - # Return summary of built memory - return self._generate_memory_summary( - explicit_preferences, - implicit_preferences, - topic_preferences, - basic_infos, - user_preferences - ) - - def _build_qa_pairs(self, chat_history: MessageList) -> List[MessageList]: - """Build QA pairs from chat history.""" - qa_pairs = [] - current_qa_pair = [] - - for message in chat_history: - role = message["role"] - - if role == "user": - # If we have a complete QA pair, save it - if len(current_qa_pair) >= 2: # At least question + answer - qa_pairs.append(current_qa_pair) - - # Start new QA pair - current_qa_pair = [message] - - elif role == "assistant": - # Add answer to current QA pair - current_qa_pair.append(message) - - # Don't forget the last QA pair if it exists and is complete - if len(current_qa_pair) >= 2: - qa_pairs.append(current_qa_pair) - - return qa_pairs - - def _extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: - 
"""Extract basic information from a QA pair (no LLM needed).""" - basic_info = { - "dialog_id": str(uuid.uuid4()), - "dialog_msgs": qa_pair, - "dialog_str": "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]), - "created_at": datetime.now().isoformat() - } - - return basic_info - - def _extract_topic_info(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: - """Extract topic information from a QA pair.""" - prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result - except Exception: - return response - - def _extract_explicit_preference(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: - """Extract explicit preference from a QA pair (LLM-1).""" - - prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result - except Exception: - return response - - def _generate_dialogue_vectors(self, basic_infos: List[Dict[str, Any]]) -> List[List[float]]: - """Generate embeddings for dialogue segments.""" - if not self.embedder or not basic_infos: - return [] - - texts = [info.get("dialog_segment_str", "") for info in basic_infos] - return self.embedder.embed(texts) - - def _generate_topic_vectors(self, topic_infos: List[Dict[str, Any]]) -> List[List[float]]: - """Generate embeddings for topic information.""" - if not self.embedder: - return [] - - texts = [f"{info.get('topic_name', '')} {info.get('topic_description', '')}" for info in topic_infos] - return self.embedder.embed(texts) - - def _concat_infos( - self, - basic_infos: List[Dict[str, Any]], - explicit_preferences: List[Dict[str, Any]], - topic_infos: List[Dict[str, Any]], - dialogue_vectors: List[List[float]], - topic_vectors: List[List[float]]) -> List[Dict[str, Any]]: - """Concatenate 
infos.""" - whole_infos = [{**bsc, **ep, **ti, "dialog_vector": dv, "topic_vector": tv} - for bsc, ep, ti, dv, tv in zip(basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors)] - return whole_infos - - def _implicit_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Implicit cluster.""" - # Extract vectors for clustering - vectors = [info.get("dialog_vector") for info in whole_infos] - - # Perform clustering - cluster_results = clusterer.cluster(vectors) - - # Map cluster results back to original data - for cluster in cluster_results: - original_infos = [] - for item in cluster["items"]: - index = item["index"] - original_info = whole_infos[index] - original_infos.append({ - "dialog_id": original_info.get("dialog_id"), - "dialog_msgs": original_info.get("dialog_msgs"), - "dialog_str": original_info.get("dialog_str"), - "created_at": original_info.get("created_at") - }) - cluster["original_data"] = original_infos - - return cluster_results - - def _topic_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Topic cluster.""" - # Extract vectors for clustering - vectors = [info.get("topic_vector") for info in whole_infos] - - # Perform clustering - cluster_results = clusterer.cluster(vectors) - - # Map cluster results back to original data - for cluster in cluster_results: - original_infos = [] - for item in cluster["items"]: - index = item["index"] - original_info = whole_infos[index] - original_infos.append({ - "dialog_id": original_info.get("dialog_id"), - "dialog_msgs": original_info.get("dialog_msgs"), - "dialog_str": original_info.get("dialog_str"), - "created_at": original_info.get("created_at") - }) - cluster["original_data"] = original_infos - - return cluster_results - - - def _handle_chunk_merging(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - 
chunk_merging_manager) -> Dict[str, Any]: - """Handle chunk merging for long chunks (optional step).""" - merged_results = { - "explicit_preferences": explicit_prefs, - "implicit_preferences": implicit_prefs, - "topic_preferences": topic_prefs, - "merged_chunks": [] - } - - # Check for long chunks that need merging - all_preferences = explicit_prefs + implicit_prefs + topic_prefs - - for pref in all_preferences: - content = pref.get("preference", "") - if len(content) > 4096: # Long chunk threshold - # Split the long chunk - chunks = chunk_merging_manager.process_long_chunks([{ - "content": content, - "id": pref.get("id", ""), - "preference": pref - }]) - - if len(chunks) > 1: - # Extract preferences from split chunks - chunk_preferences = [chunk.get("preference", {}) for chunk in chunks] - - # Merge preferences using LLM - merged_pref = chunk_merging_manager.merge_chunk_items( - chunk_preferences, self.llm_provider - ) - - if merged_pref: - merged_results["merged_chunks"].append({ - "original_preference": pref, - "merged_preference": merged_pref, - "chunk_count": len(chunks) - }) - - return merged_results - - def _extract_implicit_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Extract implicit preferences from clusters.""" - - implicit_preferences = [] - - for cluster in clusters: - # Get dialogue segments in this cluster - qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) - - prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - if result.get("implicit_preference"): - result["id"] = str(uuid.uuid4()) - result["created_at"] = datetime.now().isoformat() - implicit_preferences.append(result) - except Exception: - continue - - return implicit_preferences - - def _extract_topic_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - 
"""Extract topic preferences from clusters.""" - - topic_preferences = [] - - for cluster in clusters: - # Get topic infos in this cluster - cluster_topics = [cluster["items"][i] for i in range(len(cluster["items"])) - if i < len(cluster["items"])] - - prompt = f""" - {NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT} - - 主题聚类信息: - - 聚类ID: {cluster.get('cluster_id', '')} - - 聚类大小: {cluster.get('size', 0)} - - 相关主题信息: - {json.dumps(cluster_topics, ensure_ascii=False, indent=2)} - - 请提取主题偏好,返回JSON格式: - {{ - "topic_preference": "主题偏好描述", - "confidence": 0.8, - "cluster_id": "{cluster.get('cluster_id', '')}" - }} - """ - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - if result.get("topic_preference"): - result["id"] = str(uuid.uuid4()) - result["created_at"] = datetime.now().isoformat() - topic_preferences.append(result) - except Exception: - continue - - return topic_preferences - - def _extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Extract user-level preferences.""" - - prompt = f""" - {NAIVE_USER_PREFERENCE_EXTRACT_PROMPT} - - 主题偏好信息: - {json.dumps(topic_preferences, ensure_ascii=False, indent=2)} - - 相关对话片段: - {json.dumps(topic_preferences, ensure_ascii=False, indent=2)} - - 请提取用户偏好,返回JSON格式: - {{ - "user_preference": "用户偏好描述", - "confidence": 0.8 - }} - """ - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - if result.get("user_preference"): - result["id"] = str(uuid.uuid4()) - result["created_at"] = datetime.now().isoformat() - return [result] - except Exception: - pass - - return [] - - def _store_preferences(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]], - user_id: str): - """Store all preferences in memory.""" - - # Convert to TextualMemoryItem and store - 
all_memories = [] - - # Store explicit preferences - for pref in explicit_prefs: - memory_item = TextualMemoryItem( - memory=pref.get("preference", ""), - metadata=TextualMemoryMetadata( - user_id=user_id, - type="explicit_preference", - confidence=pref.get("confidence", 0.5), - source="conversation", - tags=["explicit", "preference"], - updated_at=pref.get("created_at", datetime.now().isoformat()) - ) - ) - all_memories.append(memory_item) - - # Store implicit preferences - for pref in implicit_prefs: - memory_item = TextualMemoryItem( - memory=pref.get("implicit_preference", ""), - metadata=TextualMemoryMetadata( - user_id=user_id, - type="implicit_preference", - confidence=pref.get("confidence", 0.5), - source="conversation", - tags=["implicit", "preference", pref.get("cluster_id", "")], - updated_at=pref.get("created_at", datetime.now().isoformat()) - ) - ) - all_memories.append(memory_item) - - # Store topic preferences - for pref in topic_prefs: - memory_item = TextualMemoryItem( - memory=pref.get("topic_preference", ""), - metadata=TextualMemoryMetadata( - user_id=user_id, - type="topic_preference", - confidence=pref.get("confidence", 0.5), - source="conversation", - tags=["topic", "preference", pref.get("cluster_id", "")], - updated_at=pref.get("created_at", datetime.now().isoformat()) - ) - ) - all_memories.append(memory_item) - - # Store user preferences - for pref in user_prefs: - memory_item = TextualMemoryItem( - memory=pref.get("user_preference", ""), - metadata=TextualMemoryMetadata( - user_id=user_id, - type="user_preference", - confidence=pref.get("confidence", 0.5), - source="conversation", - tags=["user", "preference"], - updated_at=pref.get("created_at", datetime.now().isoformat()) - ) - ) - all_memories.append(memory_item) - - # Store in vector database - if all_memories: - self.vector_db.add(all_memories) - - def _generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: 
List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]]) -> str: - """Generate a summary of the built memory.""" - summary = { - "memory_build_summary": { - "explicit_preferences_count": len(explicit_prefs), - "implicit_preferences_count": len(implicit_prefs), - "topic_preferences_count": len(topic_prefs), - "user_preferences_count": len(user_prefs), - "total_memories": len(explicit_prefs) + len(implicit_prefs) + len(topic_prefs) + len(user_prefs), - "build_timestamp": datetime.now().isoformat() - } - } - - return json.dumps(summary, ensure_ascii=False, indent=2) + \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py index f15afc81f..7870a9e55 100644 --- a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -9,6 +9,7 @@ from typing import List, Dict, Any, Optional, Tuple import numpy as np from datetime import datetime +import uuid class BaseClusterer(ABC): @@ -82,8 +83,8 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: # Convert to cluster information cluster_info = [] - for cluster_id, items_in_cluster in clusters.items(): - if cluster_id == -1: # Skip noise points + for cluster_label, items_in_cluster in clusters.items(): + if cluster_label == -1: # Skip noise points continue # Calculate cluster center @@ -95,7 +96,8 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: center_index = items_in_cluster[closest_idx]["index"] cluster_info.append({ - "cluster_id": cluster_id, + "cluster_id": str(uuid.uuid4()), + "cluster_label": cluster_label, "center_index": center_index, "center_vector": center_vector.tolist(), "size": len(items_in_cluster), @@ -108,7 +110,8 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: except ImportError: # Fallback: create single cluster return [{ - 
"cluster_id": 0, + "cluster_id": str(uuid.uuid4()), + "cluster_label": -1, "center_index": 0, "center_vector": vectors[0] if vectors else [], "size": len(vectors), diff --git a/src/memos/memories/textual/prefer_text_memory/naive_op.py b/src/memos/memories/textual/prefer_text_memory/naive_op.py new file mode 100644 index 000000000..08c4da61c --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/naive_op.py @@ -0,0 +1,439 @@ +from memos.types import MessageList +from typing import List, Dict, Any, Optional +import uuid +import json +from datetime import datetime + +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, + NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, + NAIVE_TOPIC_INFO_EXTRACT_PROMPT, + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, + NAIVE_PREFERENCE_INTEGRATION_PROMPT +) +from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer +from memos.vec_dbs.item import VecDBItem + + +class NaiveOp: + """Naive operation.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the naive operation.""" + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + + def build_qa_pairs(self, chat_history: MessageList) -> List[MessageList]: + """Build QA pairs from chat history.""" + qa_pairs = [] + current_qa_pair = [] + + for message in chat_history: + role = message["role"] + + if role == "user": + # If we have a complete QA pair, save it + if len(current_qa_pair) >= 2: # At least question + answer + qa_pairs.append(current_qa_pair) + + # Start new QA pair + current_qa_pair = [message] + + elif role == "assistant": + # Add answer to current QA pair + current_qa_pair.append(message) + + # Don't forget the last QA pair if it exists and is complete + if len(current_qa_pair) >= 2: + qa_pairs.append(current_qa_pair) + + return qa_pairs + + def 
extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: + """Extract basic information from a QA pair (no LLM needed).""" + basic_info = { + "dialog_id": str(uuid.uuid4()), + "dialog_msgs": qa_pair, + "dialog_str": "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]), + "created_at": datetime.now().isoformat() + } + + return basic_info + + def extract_topic_info(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: + """Extract topic information from a QA pair.""" + # Convert qa_pair to string format + qa_pair_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]) + prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + return result + except Exception: + return response + + def extract_explicit_preference(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: + """Extract explicit preference from a QA pair (LLM-1).""" + # Convert qa_pair to string format + qa_pair_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]) + prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + return result + except Exception: + return response + + def generate_dialogue_vectors(self, basic_infos: List[Dict[str, Any]]) -> List[List[float]]: + """Generate embeddings for dialogue segments.""" + if not self.embedder or not basic_infos: + return [] + + texts = [info.get("dialog_str", "") for info in basic_infos] + embdeddings = self.embedder.embed(texts) + return [{"dialog_vector": embedding} for embedding in embdeddings] + + def generate_topic_vectors(self, topic_infos: List[Dict[str, Any]]) -> List[List[float]]: + """Generate embeddings for topic information.""" + if not self.embedder: + return [] + + texts = [f"{info.get('topic_name', 
'')} {info.get('topic_description', '')}" for info in topic_infos] + embdeddings = self.embedder.embed(texts) + return [{"topic_vector": embedding} for embedding in embdeddings] + + def concat_infos( + self, + basic_infos: List[Dict[str, Any]] = None, + explicit_preferences: List[Dict[str, Any]] = None, + topic_infos: List[Dict[str, Any]] = None, + dialogue_vectors: List[Dict[str, Any]] = None, + topic_vectors: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + """Concatenate infos - only merge if not None.""" + # Get all non-None lists + lists_to_zip = [] + for lst in [basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors]: + if lst is not None: + lists_to_zip.append(lst) + + if not lists_to_zip: + return [] + + # Use the first list to determine length + length = len(lists_to_zip[0]) + + whole_infos = [] + for i in range(length): + merged_dict = {} + + # Only merge if not None + if basic_infos is not None and i < len(basic_infos): + merged_dict.update(basic_infos[i]) + if explicit_preferences is not None and i < len(explicit_preferences): + merged_dict.update(explicit_preferences[i]) + if topic_infos is not None and i < len(topic_infos): + merged_dict.update(topic_infos[i]) + if dialogue_vectors is not None and i < len(dialogue_vectors): + merged_dict.update(dialogue_vectors[i]) + if topic_vectors is not None and i < len(topic_vectors): + merged_dict.update(topic_vectors[i]) + + whole_infos.append(merged_dict) + + return whole_infos + + def implicit_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Implicit cluster.""" + # Extract vectors for clustering + vectors = [info.get("dialog_vector") for info in whole_infos] + if not vectors: + return [] + # Perform clustering + cluster_results = clusterer.cluster(vectors) + + # Map cluster results back to original data + for cluster in cluster_results: + cluster["center_dialog"] = whole_infos[cluster["center_index"]]["dialog_msgs"] + 
cluster["center_dialog_str"] = whole_infos[cluster["center_index"]]["dialog_str"] + original_infos = [] + for item in cluster["items"]: + index = item["index"] + original_info = whole_infos[index] + original_infos.append({ + "dialog_id": original_info.get("dialog_id"), + "dialog_msgs": original_info.get("dialog_msgs"), + "dialog_str": original_info.get("dialog_str"), + "created_at": original_info.get("created_at") + }) + cluster["original_data"] = original_infos + + return cluster_results + + def topic_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Topic cluster.""" + # Extract vectors for clustering + vectors = [info.get("topic_vector") for info in whole_infos] + if not vectors: + return [] + # Perform clustering + cluster_results = clusterer.cluster(vectors) + + # Map cluster results back to original data + for cluster in cluster_results: + cluster["center_dialog"] = whole_infos[cluster["center_index"]]["dialog_str"] + original_infos = [] + for item in cluster["items"]: + index = item["index"] + original_info = whole_infos[index] + original_infos.append({ + "dialog_id": original_info.get("dialog_id"), + "dialog_msgs": original_info.get("dialog_msgs"), + "dialog_str": original_info.get("dialog_str"), + "created_at": original_info.get("created_at") + }) + cluster["original_data"] = original_infos + + return cluster_results + + + def extract_implicit_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract implicit preferences from clusters.""" + if not clusters: + return [] + for cluster in clusters: + # Get dialogue segments in this cluster + qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) + + prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + if result.get("implicit_preference"): + 
cluster["implicit_preference"] = result + except Exception as e: + print(e) + cluster["implicit_preference"] = "" + + return clusters + + def extract_topic_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract topic preferences from clusters.""" + if not clusters: + return [] + for cluster in clusters: + # Get dialogue segments in this cluster + qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) + + prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + cluster["topic_cluster_name"] = result.get("topic_cluster_name", "") + cluster["topic_cluster_description"] = result.get("topic_cluster_description", "") + cluster["topic_preferences"] = result.get("topic_preferences", "") + except Exception as e: + print(e) + cluster["topic_cluster_name"] = "" + cluster["topic_cluster_description"] = "" + cluster["topic_preferences"] = "" + + return clusters + + def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract user-level preferences.""" + if not topic_preferences: + return [] + cluster_infos = [{ + "topic_cluster_name": cluster["topic_cluster_name"], + "topic_cluster_description": cluster["topic_cluster_description"], + "topic_preferences": cluster["topic_preferences"]} + for cluster in topic_preferences] + prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace("{cluster_info}", json.dumps(cluster_infos, ensure_ascii=False, indent=2)) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + + if result.get("user_preferences"): + return result + except Exception as e: + print(e) + return "" + + def judge_update_or_add(self, old_msg: MessageList, new_msg: MessageList) -> bool: + """Judge if the new message expresses the same core content as the old 
message.""" + # Convert messages to string format for comparison + old_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in old_msg]) + new_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in new_msg]) + + # Use the template prompt with placeholders + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_str).replace("{new_information}", new_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + response = result.get("is_same", False) + return response if isinstance(response, bool) else response == "true" + except Exception as e: + print(f"Error in judge_update_or_add: {e}") + # Fallback to simple string comparison + return old_str == new_str + + def preference_integration(self, query: str, + explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + user_prefs: List[Dict[str, Any]]) -> str: + """Integrate preferences.""" + explicit_prefs_str = json.dumps(explicit_prefs, ensure_ascii=False, indent=2) + implicit_prefs_str = json.dumps(implicit_prefs, ensure_ascii=False, indent=2) + topic_prefs_str = json.dumps(topic_prefs, ensure_ascii=False, indent=2) + user_prefs_str = json.dumps(user_prefs, ensure_ascii=False, indent=2) + + prompt = NAIVE_PREFERENCE_INTEGRATION_PROMPT.format( + query_preference=query, + explicit_preference=explicit_prefs_str, + implicit_preference=implicit_prefs_str, + topic_preference=topic_prefs_str, + user_preference=user_prefs_str + ) + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + return result["final_prompt"] + except Exception as e: + print(f"Error in preference_integration: {e}") + return "" + + + def store_preferences(self, explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + user_prefs: List[Dict[str, Any]], + user_id: str): + 
"""Store all preferences in memory.""" + + # Convert to VecDBItem format and store in separate collections + + # Store explicit preferences + if explicit_prefs: + explicit_memories = [] + for pref in explicit_prefs: + # Create VecDBItem directly using existing embedding + vec_db_item = VecDBItem( + id=pref.get("dialog_id", ""), + vector=pref.get("dialog_vector", []), + payload={ + "dialog_id": pref.get("dialog_id", ""), + "dialog_msgs": pref.get("dialog_msgs", []), + "dialog_str": pref.get("dialog_str", ""), + "dialog_vector": pref.get("dialog_vector", []), + "created_at": pref.get("created_at", datetime.now().isoformat()), + "topic_name": pref.get("topic_name", ""), + "topic_description": pref.get("topic_description", ""), + "topic_vector": pref.get("topic_vector", []), + "user_id": user_id, + "preference_type": "explicit_preference" + } + ) + explicit_memories.append(vec_db_item) + + # Store in explicit_preference collection + self.vector_db.add("explicit_preference", explicit_memories) + + # Store implicit preferences + if implicit_prefs: + implicit_memories = [] + for pref in implicit_prefs: + # Create VecDBItem directly using existing embedding + vec_db_item = VecDBItem( + id=pref.get("cluster_id", ""), + vector=pref.get("center_vector", []), + payload={ + "cluster_id": pref.get("cluster_id", ""), + "center_dialog": pref.get("center_dialog", ""), + "center_vector": pref.get("center_vector", []), + "implicit_preference": pref.get("implicit_preference", ""), + "created_at": pref.get("created_at", datetime.now().isoformat()), + "user_id": user_id, + "preference_type": "implicit_preference" + } + ) + implicit_memories.append(vec_db_item) + + # Store in implicit_preference collection + self.vector_db.add("implicit_preference", implicit_memories) + + # Store topic preferences + if topic_prefs: + topic_memories = [] + for pref in topic_prefs: + # Create VecDBItem directly using existing embedding + vec_db_item = VecDBItem( + id=pref.get("cluster_id", ""), + 
vector=pref.get("center_vector", []), + payload={ + "cluster_id": pref.get("cluster_id", ""), + "center_dialog": pref.get("center_dialog", ""), + "center_vector": pref.get("center_vector", []), + "topic_cluster_name": pref.get("topic_cluster_name", ""), + "topic_cluster_description": pref.get("topic_cluster_description", ""), + "topic_preferences": pref.get("topic_preferences", ""), + "created_at": pref.get("created_at", datetime.now().isoformat()), + "user_id": user_id, + "preference_type": "topic_preference" + } + ) + topic_memories.append(vec_db_item) + + # Store in topic_preference collection + self.vector_db.add("topic_preference", topic_memories) + + # Store user preferences + if user_prefs: + user_memories = [] + for pref in user_prefs: + # Create VecDBItem with zero vector (user preferences don't need vector search) + # Use zero vector to satisfy Milvus collection dimension requirements + # Get embedding dimension from embedder config + embedding_dim = getattr(self.embedder.config, 'embedding_dims', 768) # Default to 768 if not available + zero_vector = [0.0] * embedding_dim + vec_db_item = VecDBItem( + id=user_id, + vector=zero_vector, + payload={ + "user_id": user_id, + "user_preferences": pref.get("user_preferences", ""), + "created_at": datetime.now().isoformat(), + "preference_type": "user_preference" + } + ) + user_memories.append(vec_db_item) + + # Store in user_preference collection + self.vector_db.add("user_preference", user_memories) + + def generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]], + implicit_prefs: List[Dict[str, Any]], + topic_prefs: List[Dict[str, Any]], + user_prefs: List[Dict[str, Any]]) -> str: + """Generate a summary of the built memory.""" + summary = { + "memory_build_summary": { + "explicit_preferences_count": len(explicit_prefs), + "implicit_preferences_count": len(implicit_prefs), + "topic_preferences_count": len(topic_prefs), + "user_preferences_count": len(user_prefs), + "build_timestamp": 
datetime.now().isoformat() + } + } + + return json.dumps(summary, ensure_ascii=False, indent=2) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index ea2f259ca..471f7345b 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any -from memos.memories.textual.item import TextualMemoryItem -from memos.vec_dbs.base import BaseVecDB -from memos.embedders.base import BaseEmbedder +from typing import Any, Dict +from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata class BaseRetriever(ABC): @@ -14,7 +12,7 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): @abstractmethod - def retrieve(self, query: str, top_k: int, info: dict[str, Any]) -> list[TextualMemoryItem]: + def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[TextualMemoryItem]: """Retrieve memories from the retriever.""" class NaiveRetriever(BaseRetriever): @@ -22,11 +20,28 @@ class NaiveRetriever(BaseRetriever): def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive retriever.""" super().__init__(llm_provider, embedder, vector_db) - self.llm_provider = llm_provider self.vector_db = vector_db self.embedder = embedder - def retrieve(self, query: str, top_k: int, info: dict[str, Any]) -> list[TextualMemoryItem]: + def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" - query_embedding = self.embedder.embed(query) - return self.vector_db.search(query_embedding, top_k, info) \ No newline at end of file + query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings + query_embedding = query_embeddings[0] # Get the first (and only) 
embedding + explicit_prefs = self.vector_db.search(query_embedding, "explicit_preference", top_k, info) + implicit_prefs = self.vector_db.search(query_embedding, "implicit_preference", top_k, info) + topic_prefs = self.vector_db.search(query_embedding, "topic_preference", top_k, info) + + # Get user preferences directly (no vector search needed since there's only one per user) + user_id = info.get("user_id") if info else None + if user_id: + user_prefs = self.vector_db.get_by_filter("user_preference", {"user_id": user_id}) + else: + user_prefs = [] + + explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs] + implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs] + topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs] + user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preferences", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs] + + return explicit_prefs + implicit_prefs + topic_prefs + user_prefs + diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index a16884ec0..d54810b39 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -1,6 +1,10 @@ from abc import ABC, abstractmethod from typing import Any -from memos.memories.textual.item import TextualMemoryItem +from datetime import datetime +from memos.types import MessageList +from memos.vec_dbs.item import VecDBItem +from memos.memories.textual.prefer_text_memory.naive_op import NaiveOp +from 
memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer class BaseUpdater(ABC): @@ -11,10 +15,10 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the updater.""" @abstractmethod - def update(self, new_memory: TextualMemoryItem | dict[str, Any], *args, **kwargs) -> None: - """Update the memory. + def update(self, new_dialog: MessageList, *args, **kwargs) -> None: + """Update the dialog. Args: - new_memory (TextualMemoryItem | dict[str, Any]): The new memory to update. + new_dialog (MessageList): The new dialog to update. *args: Additional positional arguments. **kwargs: Additional keyword arguments. """ @@ -29,13 +33,106 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def update(self, new_memory: TextualMemoryItem | dict[str, Any], *args, **kwargs) -> None: - """Update the memory. + def update(self, new_dialog: MessageList, *args, **kwargs) -> None: + """Update the dialog to the vector db, fast update. Args: - new_memory (TextualMemoryItem | dict[str, Any]): The new memory to update. + new_dialog (MessageList): The new dialog to update. *args: Additional positional arguments. **kwargs: Additional keyword arguments. 
""" - pass + naive_op = NaiveOp(self.llm_provider, self.embedder, self.vector_db) + basic_info = naive_op.extract_basic_info(new_dialog) + topic_info = naive_op.extract_topic_info(new_dialog) + explicit_pref = naive_op.extract_explicit_preference(new_dialog) + dialogue_vectors = naive_op.generate_dialogue_vectors([basic_info]) + concat_info = naive_op.concat_infos([basic_info], [explicit_pref], [topic_info], dialogue_vectors)[0] + + vec_db_item = VecDBItem( + id=concat_info.get("dialog_id", ""), + vector=concat_info.get("dialog_vector", []), + payload={ + "dialog_id": concat_info.get("dialog_id", ""), + "dialog_msgs": concat_info.get("dialog_msgs", []), + "dialog_str": concat_info.get("dialog_str", ""), + "created_at": concat_info.get("created_at", datetime.now().isoformat()), + "user_id": concat_info.get("user_id", ""), + "type": "explicit_preference" + } + ) + + # retrieve the dialog + dialog_vector = dialogue_vectors[0]["dialog_vector"] # This is already a List[float] + dialog_items = self.vector_db.search(dialog_vector, "explicit_preference", top_k=1) + + # Extract dialog_str from retrieved items + if dialog_items: + # Get the first (most similar) item + retrieved_item = dialog_items[0] + # Extract dialog_str from payload + old_msgs = retrieved_item.payload.get("dialog_msgs", "") + is_same = naive_op.judge_update_or_add(old_msgs, new_dialog) + if is_same: + # Extract ID from the retrieved item + item_id = retrieved_item.id + self.vector_db.update("explicit_preference", item_id, vec_db_item) + return + + self.vector_db.add("explicit_preference", vec_db_item) + + + def slow_update(self): + """Retrieve all dialog info from the expicit preference collection, + and reconstruct the implicit preference collection, topic collection and user preference collection. 
+ """ + clusterer = HDBSCANClusterer() + naive_op = NaiveOp(self.llm_provider, self.embedder, self.vector_db) + all_data = self.vector_db.get_all("explicit_preference") + + user_id = all_data[0].payload.get("user_id", "") + + # Convert VecDBItem list to whole_infos format + whole_infos = [item.payload for item in all_data] + + # Perform clustering + implicit_clusters = naive_op.implicit_cluster(clusterer, whole_infos) + topic_clusters = naive_op.topic_cluster(clusterer, whole_infos) + + # Extract implicit preferences + implicit_clusters = naive_op.extract_implicit_preferences(implicit_clusters) + + # Extract topic preferences + topic_clusters = naive_op.extract_topic_preferences(topic_clusters) + + # Extract user preferences + user_preferences = naive_op.extract_user_preferences(topic_clusters) + + # refresh the implicit preference collection, topic collection and user preference collection + self.vector_db.delete_collection("implicit_preference") + self.vector_db.delete_collection("topic_preference") + self.vector_db.delete_collection("user_preference") + + self.vector_db.create_collection_by_name("implicit_preference") + self.vector_db.create_collection_by_name("topic_preference") + self.vector_db.create_collection_by_name("user_preference") + + + # Store all preferences in memory + naive_op.store_preferences( + explicit_prefs=whole_infos, + implicit_prefs=implicit_clusters, + topic_prefs=topic_clusters, + user_prefs=user_preferences, + user_id=user_id, + ) + + # Return summary of built memory + return naive_op.generate_memory_summary( + explicit_prefs=whole_infos, + implicit_prefs=implicit_clusters, + topic_prefs=topic_clusters, + user_prefs=user_preferences, + ) + + diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 0d318e230..e5260536a 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -15,7 +15,8 @@ from memos.memories.textual.item import TextualMemoryItem 
from memos.types import ChatHistory, MessageList from memos.llms.base import BaseLLM -from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory +from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory, MilvusVecDB +from memos.vec_dbs.item import VecDBItem from memos.memories.textual.prefer_text_memory.factory import BuilderFactory, RetrieverFactory, UpdaterFactory, AssemblerFactory @@ -28,7 +29,7 @@ def __init__(self, config: PreferenceTextMemoryConfig): self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm ) - self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db) + self.vector_db: MilvusVecDB | QdrantVecDB = VecDBFactory.from_config(config.vector_db) self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = \ EmbedderFactory.from_config(config.embedder) @@ -57,7 +58,7 @@ def __init__(self, config: PreferenceTextMemoryConfig): vector_db=self.vector_db ) - def build_memory(self, history: ChatHistory) -> None: + def build_preferences(self, history: ChatHistory) -> None: """Build memory from the original dialogs. (Initialize memory) Args: @@ -68,9 +69,21 @@ def build_memory(self, history: ChatHistory) -> None: """ return self.builder.build(history) - def update_memory(self, new_memory: TextualMemoryItem | dict[str, Any]) -> None: - """Update a memory by new memory.""" - self.updater.update(new_memory) + def update_preferences(self, new_dialog: MessageList) -> None: + """Update a memory by new dialog. + Args: + new_dialog (MessageList): The new dialog to update. + """ + self.updater.update(new_dialog) + + def search_preferences(self, query: str, top_k: int, info=None) -> list[TextualMemoryItem]: + """Search for preferences based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. 
+ """ + return self.retriever.retrieve(query, top_k, info) def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. @@ -94,6 +107,24 @@ def get_prompt(self, query: str, memories: list[TextualMemoryItem]) -> str: """ return self.assembler.assemble(query, memories) + def load(self, dir: str) -> None: + """Load memories from the specified directory. + Args: + dir (str): The directory containing the memory files. + """ + # For preference memory, we don't need to load from files + # as the data is stored in the vector database + pass + + def dump(self, dir: str) -> None: + """Dump memories to the specified directory. + Args: + dir (str): The directory where the memory files will be saved. + """ + # For preference memory, we don't need to dump to files + # as the data is stored in the vector database + pass + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """Extract memories based on the messages. Args: @@ -109,7 +140,27 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: Args: memories: List of TextualMemoryItem objects or dictionaries to add. 
""" - pass + if self.config.backend == "naive": + pass + else: + memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] + + # Memory encode + embed_memories = self.embedder.embed([m.memory for m in memory_items]) + + # Create vector db items + vec_db_items = [] + for item, emb in zip(memory_items, embed_memories, strict=True): + vec_db_items.append( + VecDBItem( + id=item.id, + payload=item.model_dump(), + vector=emb, + ) + ) + + # Add to vector db + self.vector_db.add(vec_db_items) def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: """Update a memory by memory_id.""" diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index e004b8da9..e2d960400 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -106,9 +106,9 @@ # Output Format: { - "cluster_name": "", - "cluster_description": "", - "cluster_preferences": "" + "topic_cluster_name": "", + "topic_cluster_description": "", + "topic_preferences": "" } # Notes: @@ -125,8 +125,9 @@ NAIVE_USER_PREFERENCE_EXTRACT_PROMPT = """ You are an advanced information integration assistant. 
You will be given a user's preference list across different topic clusters, where each cluster contains: -- cluster_name: The name of the topic cluster -- cluster_preferences: Natural language description of preferences in that cluster +- topic_cluster_name: The name of the topic cluster +- topic_cluster_description: The description of the topic cluster +- topic_preferences: Natural language description of preferences in that cluster Your task is to extract the user's **highest-level common preferences** by focusing on these three key dimensions: @@ -156,7 +157,7 @@ # Output Format: { - "user_overall_preferences": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" + "user_preferences": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" } # Notes: @@ -198,3 +199,68 @@ # Conversation Content {qa_pair} """ + + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ +You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. +Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: + +- Core content is consistent, that is, the essence of the question, goal or core concept to be solved is the same, it counts as "same". +- Different expressions, different examples, but the core meaning is consistent, also counts as "same". +- If the question goals, concepts involved or solution ideas are different, it counts as "different". 
+ +Please output JSON format: +{ + "is_same": true/false, + "reasoning": "Briefly explain the judgment basis, highlighting whether the core content is consistent" +} + +**Old Information:** +{old_information} + +**New Information:** +{new_information} +""" + + +NAIVE_PREFERENCE_INTEGRATION_PROMPT = """ +You are a preference integration expert. Your task is to integrate preference constraints from different sources and generate a final prompt that can be directly input to an LLM. Please note the following information sources and their priority levels (from high to low): + +Sources: +1. Current query preferences: Constraints explicitly stated in the current user question +2. Related dialogue preferences: Preference references from Q&A pairs related to the current query +3. Related topic preferences: Preference references from topics related to the current query +4. User preferences: Common preference references from the user's historical conversations + +Priority: Current query preferences > Related dialogue preferences > Related topic preferences > User preferences > Implicit preferences + +Requirements: +- If conflicts exist between preferences, strictly follow the priority order, with higher priority preferences overriding lower priority ones. +- Generate a comprehensive prompt that includes all integrated preferences and constraints. +- The final prompt should be ready to be input directly to an LLM for answering the user's query. +- Keep the integrated preferences specific and actionable. +- Ensure the prompt is clear, structured, and contains all necessary context and constraints. 
+ +Please generate the final integrated prompt based on the input, strictly resolve conflicts by priority, and output in JSON format as follows: +{{ + "final_prompt": "Complete prompt ready for LLM input, including query, context, and all integrated preferences", + "conflict_handling": ["Conflict resolution explanation 1", "Conflict resolution explanation 2", "..."], + "preference_summary": "Summary of all integrated preferences and constraints" +}} + +# Current query +{query_preference} + +# Related dialogue preferences +{explicit_preference} + +# Implicit preferences +{implicit_preference} + +# Related topic preferences +{topic_preference} + +# User preferences +{user_preference} + +""" \ No newline at end of file diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index a04c5aeff..7b529cdc1 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -42,12 +42,28 @@ def create_collection(self) -> None: dimension=self.config.vector_dimension, metric_type=self._get_metric_type(), id_type="string", # Use string ID type, align with VecDBItem id type + max_length=self.config.max_length, # Use max_length from config ) logger.info( f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions." ) + def create_collection_by_name(self, collection_name: str) -> None: + """Create a new collection with specified parameters.""" + if self.collection_exists(collection_name): + logger.warning( + f"Collection '{collection_name}' already exists. Skipping creation." 
+ ) + return + + self.client.create_collection( + collection_name=collection_name, + dimension=self.config.vector_dimension, + metric_type=self._get_metric_type(), + id_type="string", # Use string ID type, align with VecDBItem id type + ) + def list_collections(self) -> list[str]: """List all collections.""" return self.client.list_collections() @@ -164,34 +180,46 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100) -> list[VecDBItem]: """ - Retrieve all items that match the given filter criteria. + Retrieve all items that match the given filter criteria using query_iterator. Args: filter: Payload filters to match against stored items - scroll_limit: Maximum number of items to retrieve per scroll request + scroll_limit: Maximum number of items to retrieve per batch (batch_size) Returns: List of items including vectors and payload that match the filter """ expr = self._dict_to_expr(filter) if filter else None + all_items = [] - results = self.client.query( + # Use query_iterator for efficient pagination + iterator = self.client.query_iterator( collection_name=collection_name, filter=expr, - limit=scroll_limit, + batch_size=scroll_limit, ) + + # Iterate through all batches + while True: + batch_results = iterator.next() + + if not batch_results: + break + + # Convert batch results to VecDBItem objects + for entity in batch_results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + all_items.append(VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + )) + + # Close the iterator + iterator.close() - items = [] - for entity in results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} - items.append(VecDBItem( - id=entity["id"], - vector=entity.get("vector"), - payload=payload, - )) - - logger.info(f"Milvus retrieve by filter completed with {len(items)} results.") - 
return items + logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") + return all_items def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: """Retrieve all items in the vector database.""" From da601d81706c1c9b4c4a74d250b8756d6213097d Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 17 Sep 2025 21:04:56 +0800 Subject: [PATCH 05/62] modify code struct --- src/memos/configs/mem_cube.py | 16 + src/memos/configs/mem_os.py | 4 + src/memos/configs/memory.py | 16 +- src/memos/mem_cube/general.py | 43 ++- src/memos/mem_os/core.py | 44 ++- .../textual/prefer_text_memory/adder.py | 106 ++++++ .../textual/prefer_text_memory/builders.py | 132 ------- .../textual/prefer_text_memory/clustering.py | 171 +++++---- .../textual/prefer_text_memory/config.py | 53 ++- .../textual/prefer_text_memory/extractor.py | 201 +++++++++++ .../textual/prefer_text_memory/factory.py | 39 +- .../textual/prefer_text_memory/spliter.py | 103 ++++++ .../textual/prefer_text_memory/updater.py | 339 ++++++++++++++---- .../textual/prefer_text_memory/utils.py | 14 + src/memos/memories/textual/preference.py | 61 +--- src/memos/templates/prefer_complete_prompt.py | 12 +- src/memos/types.py | 1 + 17 files changed, 995 insertions(+), 360 deletions(-) create mode 100644 src/memos/memories/textual/prefer_text_memory/adder.py delete mode 100644 src/memos/memories/textual/prefer_text_memory/builders.py create mode 100644 src/memos/memories/textual/prefer_text_memory/extractor.py create mode 100644 src/memos/memories/textual/prefer_text_memory/spliter.py create mode 100644 src/memos/memories/textual/prefer_text_memory/utils.py diff --git a/src/memos/configs/mem_cube.py b/src/memos/configs/mem_cube.py index b9868fa99..afca52a26 100644 --- a/src/memos/configs/mem_cube.py +++ b/src/memos/configs/mem_cube.py @@ -54,6 +54,11 @@ class GeneralMemCubeConfig(BaseMemCubeConfig): default_factory=MemoryConfigFactory, description="Configuration for the parametric 
memory", ) + pref_mem: MemoryConfigFactory = Field( + ..., + default_factory=MemoryConfigFactory, + description="Configuration for the preference memory", + ) @field_validator("text_mem") @classmethod @@ -87,3 +92,14 @@ def validate_para_mem(cls, para_mem: MemoryConfigFactory) -> MemoryConfigFactory f"GeneralMemCubeConfig requires para_mem backend to be one of {allowed_backends}, got '{para_mem.backend}'" ) return para_mem + + @field_validator("pref_mem") + @classmethod + def validate_pref_mem(cls, pref_mem: MemoryConfigFactory) -> MemoryConfigFactory: + """Validate the pref_mem field.""" + allowed_backends = ["naive_pref", "uninitialized"] + if pref_mem.backend not in allowed_backends: + raise ConfigurationError( + f"GeneralMemCubeConfig requires pref_mem backend to be one of {allowed_backends}, got '{pref_mem.backend}'" + ) + return pref_mem \ No newline at end of file diff --git a/src/memos/configs/mem_os.py b/src/memos/configs/mem_os.py index 0645fce44..549e55792 100644 --- a/src/memos/configs/mem_os.py +++ b/src/memos/configs/mem_os.py @@ -58,6 +58,10 @@ class MOSConfig(BaseConfig): default=False, description="Enable parametric memory for the MemChat", ) + enable_preference_memory: bool = Field( + default=False, + description="Enable preference memory for the MemChat", + ) enable_mem_scheduler: bool = Field( default=False, description="Enable memory scheduler for automated memory management", diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index a83a77b46..6ee76ffae 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -11,7 +11,8 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.exceptions import ConfigurationError from memos.memories.textual.prefer_text_memory.config import ( - BuilderConfigFactory, + AdderConfigFactory, + ExtractorConfigFactory, RetrieverConfigFactory, UpdaterConfigFactory, AssemblerConfigFactory) @@ -203,10 +204,15 @@ class PreferenceTextMemoryConfig(BaseTextMemoryConfig): 
default_factory=EmbedderConfigFactory, description="Embedder configuration for the memory embedding", ) - builder: BuilderConfigFactory = Field( + extractor: ExtractorConfigFactory = Field( ..., - default_factory=BuilderConfigFactory, - description="Builder configuration for the memory building", + default_factory=ExtractorConfigFactory, + description="Extractor configuration for the memory extracting", + ) + adder: AdderConfigFactory = Field( + ..., + default_factory=AdderConfigFactory, + description="Adder configuration for the memory adding", ) retriever: RetrieverConfigFactory = Field( ..., @@ -237,7 +243,7 @@ class MemoryConfigFactory(BaseConfig): "naive_text": NaiveTextMemoryConfig, "general_text": GeneralTextMemoryConfig, "tree_text": TreeTextMemoryConfig, - "preference_text": PreferenceTextMemoryConfig, + "pref_text": PreferenceTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache "lora": LoRAMemoryConfig, diff --git a/src/memos/mem_cube/general.py b/src/memos/mem_cube/general.py index 7217c354b..0fa3b26e2 100644 --- a/src/memos/mem_cube/general.py +++ b/src/memos/mem_cube/general.py @@ -38,16 +38,21 @@ def __init__(self, config: GeneralMemCubeConfig): if config.para_mem.backend != "uninitialized" else None ) + self._pref_mem: BaseTextMemory | None = ( + MemoryFactory.from_config(config.pref_mem) + if config.pref_mem.backend != "uninitialized" + else None + ) def load( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. 
- Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: @@ -58,7 +63,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -73,17 +78,21 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. 
- Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -95,7 +104,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -110,12 +119,16 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @staticmethod def init_from_dir( dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a MemCube directory. @@ -146,7 +159,7 @@ def init_from_dir( def init_from_remote_repo( cube_id: str, base_url: str = "https://huggingface.co/datasets", - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a remote repository. @@ -205,3 +218,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. 
Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value \ No newline at end of file diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index a201e22cc..9eece178b 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -579,6 +579,7 @@ def search( "text_mem": [], "act_mem": [], "para_mem": [], + "pref_mem": [], } if install_cube_ids is None: install_cube_ids = user_cube_ids @@ -615,6 +616,32 @@ def search( logger.info( f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" ) + + if ( + (mem_cube_id in install_cube_ids) + and (mem_cube.pref_mem is not None) + and self.config.enable_preference_memory + ): + memories = mem_cube.pref_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + mode=mode, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, + moscube=moscube, + ) + result["pref_mem"].append({"cube_id": mem_cube_id, "memories": memories}) + logger.info( + f"🧠 [Memory] Searched preferences from {mem_cube_id}:\n{self._str_memories(memories)}\n" + ) + search_time_end = time.time() + logger.info( + f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return result def add( @@ -699,6 +726,21 @@ def add( ) self.mem_scheduler.submit_messages(messages=[message_item]) + if ( + (messages is not None) + and self.config.enable_preference_memory + and self.mem_cubes[mem_cube_id].pref_mem + ): + messages_list = [messages] + pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( + messages_list, + msg_type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}) + preferences = 
self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) + logger.info( + f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {preferences}" + ) + # user profile if ( (memory_content is not None) @@ -965,7 +1007,7 @@ def load( load_dir: str, user_id: str | None = None, mem_cube_id: str | None = None, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump the MemCube to a dictionary. Args: diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py new file mode 100644 index 000000000..c3d91ca05 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -0,0 +1,106 @@ +from abc import ABC, abstractmethod +from typing import Any +import json +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed +from memos.memories.textual.item import TextualMemoryItem +from memos.vec_dbs.item import VecDBItem +from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT +from memos.types import MessageList + + +class BaseAdder(ABC): + """Abstract base class for adders.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the adder.""" + + @abstractmethod + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwargs) -> list[str]: + """Add the instruct preference memories. + Args: + memories (list[TextualMemoryItem | dict[str, Any]]): The memories to add. + **kwargs: Additional keyword arguments. + Returns: + list[str]: List of added memory IDs. 
+ """ + +class NaiveAdder(BaseAdder): + """Naive adder.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the naive adder.""" + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + + def _judge_update_or_add(self, old_msg: MessageList, new_msg: MessageList) -> bool: + """Judge if the new message expresses the same core content as the old message.""" + # Convert messages to string format for comparison + old_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in old_msg]) + new_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in new_msg]) + + # Use the template prompt with placeholders + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_str).replace("{new_information}", new_str) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + result = json.loads(response) + response = result.get("is_same", False) + return response if isinstance(response, bool) else response == "true" + except Exception as e: + print(f"Error in judge_update_or_add: {e}") + # Fallback to simple string comparison + return old_str == new_str + + def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: + """Process a single memory and return its ID if added successfully.""" + try: + payload = {**memory.metadata.model_dump()} + vec_db_item = VecDBItem( + id=memory.id, + vector=memory.metadata.dialog_vector, + payload=payload + ) + + recall = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1)[0] + if not recall or (recall.score is not None and recall.score < 0.5): + self.vector_db.update("explicit_preference", memory.id, vec_db_item) + return memory.id + + old_msg_str = recall.payload.get("dialog_str", "") + new_msg_str = memory.metadata.dialog_str + is_same = self._judge_update_or_add(old_msg_str, new_msg_str) + if is_same: + 
self.vector_db.delete("explicit_preference", [recall.id]) + self.vector_db.update("explicit_preference", memory.id, vec_db_item) + return memory.id + + except Exception as e: + print(f"Error processing memory {memory.id}: {e}") + return None + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], max_workers: int = 10, *args, **kwargs) -> list[str]: + """Add the instruct preference memories using thread pool for acceleration.""" + if not memories: + return [] + + added_ids = [] + with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: + future_to_memory = { + executor.submit(self._process_single_memory, memory): memory + for memory in memories + } + + for future in as_completed(future_to_memory): + try: + memory_id = future.result() + if memory_id: + added_ids.append(memory_id) + except Exception as e: + memory = future_to_memory[future] + print(f"Error processing memory {memory.id}: {e}") + continue + + return added_ids diff --git a/src/memos/memories/textual/prefer_text_memory/builders.py b/src/memos/memories/textual/prefer_text_memory/builders.py deleted file mode 100644 index d4dd7bfaf..000000000 --- a/src/memos/memories/textual/prefer_text_memory/builders.py +++ /dev/null @@ -1,132 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Dict, List, Any, Tuple -import uuid -import json -from datetime import datetime - -from memos.llms.base import BaseLLM -from memos.memories.textual.prefer_text_memory import naive_op -from memos.types import ChatHistory, MessageList -from memos.embedders.base import BaseEmbedder -from memos.vec_dbs.base import BaseVecDB -from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata -from memos.vec_dbs.item import VecDBItem - -from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer -from memos.memories.textual.prefer_text_memory.naive_op import NaiveOp - - - -class BaseBuilder(ABC): - """ - Abstract base class for memory 
builders. - - Each builder implements a specific build strategy for creating - preference memory content from chat history. - """ - - @abstractmethod - def __init__(self, llm_provider=None, embedder=None, vector_db=None): - """ - Initialize the memory builder. - - Args: - llm_provider: LLM provider for script generation (required for some strategies) - embedder: Embedder for vector operations - vector_db: Vector database for storage - """ - - @abstractmethod - def build(self, history: ChatHistory) -> str: - """ - Build memory content from chat history. - - Args: - history: The chat history to build memory from. - - Returns: - Memory content string formatted according to the build strategy - - Raises: - RuntimeError: If memory building fails - """ - - - -class NaiveBuilder(BaseBuilder): - """Naive memory builder.""" - def __init__(self, llm_provider=None, embedder=None, vector_db=None): - """Initialize the naive memory builder.""" - super().__init__(llm_provider, embedder, vector_db) - self.llm_provider = llm_provider - self.embedder = embedder - self.vector_db = vector_db - - - - def build(self, history: ChatHistory) -> str: - """Build memory content from chat history following the preference extraction pipeline.""" - - # Initialize clustering manager - clusterer = HDBSCANClusterer() - naive_op = NaiveOp(self.llm_provider, self.embedder, self.vector_db) - - # Step 1: Build QA pairs from chat history - qa_pairs = naive_op.build_qa_pairs(history.chat_history) - - # Step 2: Process each QA pair - basic_infos = [] - explicit_preferences = [] - topic_infos = [] - - for qa_pair in qa_pairs: - # Extract basic info - basic_info = naive_op.extract_basic_info(qa_pair) - basic_infos.append(basic_info) - - # Extract topic information - topic_info = naive_op.extract_topic_info(qa_pair) - if topic_info: - topic_infos.append(topic_info) - - # Extract explicit preference from - explicit_pref = naive_op.extract_explicit_preference(qa_pair) - if explicit_pref: - 
explicit_preferences.append(explicit_pref) - - # Step 3: Generate embeddings - dialogue_vectors = naive_op.generate_dialogue_vectors(basic_infos) - topic_vectors = naive_op.generate_topic_vectors(topic_infos) - - whole_infos = naive_op.concat_infos(basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors) - - # Step 4: Perform clustering - implicit_clusters = naive_op.implicit_cluster(clusterer, whole_infos) - topic_clusters = naive_op.topic_cluster(clusterer, whole_infos) - - # Step 5: Extract implicit preferences - implicit_clusters = naive_op.extract_implicit_preferences(implicit_clusters) - - # Step 6: Extract topic preferences - topic_clusters = naive_op.extract_topic_preferences(topic_clusters) - - # Step 7: Extract user preferences - user_preferences = naive_op.extract_user_preferences(topic_clusters) - - # Step 8: Store all preferences in memory - naive_op.store_preferences( - explicit_prefs=whole_infos, - implicit_prefs=implicit_clusters, - topic_prefs=topic_clusters, - user_prefs=user_preferences, - user_id=history.user_id, - ) - - # Step 9: Return summary of built memory - return naive_op.generate_memory_summary( - explicit_prefs=whole_infos, - implicit_prefs=implicit_clusters, - topic_prefs=topic_clusters, - user_prefs=user_preferences, - ) - \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py index 7870a9e55..0e54cd95b 100644 --- a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -1,22 +1,27 @@ -""" -Clustering module for preference memory extraction. - -This module provides abstract clustering functionality that can be used -by different types of clustering (implicit, topic, etc.). 
-""" - from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Tuple +from pydantic import BaseModel +from typing import List, Dict, Any import numpy as np from datetime import datetime import uuid +from sklearn.cluster import HDBSCAN +from sklearn.neighbors import NearestNeighbors +class ClusterResult(BaseModel): + cluster_id: str + cluster_label: int + center_index: int + center_vector: List[float] + size: int + created_at: str + items: List[Dict[str, Any]] + class BaseClusterer(ABC): """Abstract base class for clustering algorithms.""" @abstractmethod - def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: + def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: """ Perform clustering on the given vectors. @@ -29,22 +34,29 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: """ pass - class HDBSCANClusterer(BaseClusterer): """HDBSCAN-based clustering implementation.""" - def __init__(self, min_cluster_size: int = 2, min_samples: int = 1): + def __init__(self, min_cluster_size: int = 2, min_samples: int = 1, metric: str = "cosine"): """ Initialize HDBSCAN clusterer. Args: min_cluster_size: Minimum size of a cluster min_samples: Minimum number of samples in a neighborhood + metric: Metric for clustering """ self.min_cluster_size = min_cluster_size self.min_samples = min_samples + self.metric = metric + + self.clusterer = HDBSCAN( + min_cluster_size=self.min_cluster_size, + min_samples=self.min_samples, + metric=self.metric + ) - def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: + def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: """ Perform HDBSCAN clustering on the given vectors. 
@@ -56,70 +68,83 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[Dict[str, Any]]: List of cluster information dictionaries """ if not vectors or len(vectors) < 2: - return [] + return [] + + # Convert to numpy array + vectors_array = np.array(vectors) + + # Perform clustering + cluster_labels = self.clusterer.fit_predict(vectors_array) - try: - from sklearn.cluster import HDBSCAN + # Group vectors by cluster + clusters = {} + for i, label in enumerate(cluster_labels): + if label not in clusters: + clusters[label] = [] + clusters[label].append({ + "vector": vectors[i], # origin sample vector + "index": i # origin sample index + }) + + # Convert to cluster information + cluster_info = [] + for cluster_label, items_in_cluster in clusters.items(): + if cluster_label == -1: # Skip noise points + continue - # Convert to numpy array - vectors_array = np.array(vectors) + # Calculate cluster geometric center + geometric_center_vector = np.mean([item["vector"] for item in items_in_cluster], axis=0) - # Perform clustering - clusterer = HDBSCAN( - min_cluster_size=self.min_cluster_size, - min_samples=self.min_samples - ) - cluster_labels = clusterer.fit_predict(vectors_array) + # Find the vector closest to the geometric center using sklearn + cluster_vectors = [item["vector"] for item in items_in_cluster] + nn = NearestNeighbors(n_neighbors=1, metric=self.metric) + nn.fit(cluster_vectors) + _, indices = nn.kneighbors([geometric_center_vector.tolist()]) + closest_idx = indices[0][0] + center_index = items_in_cluster[closest_idx]["index"] + center_vector = cluster_vectors[closest_idx]["vector"] - # Group vectors by cluster - clusters = {} - for i, label in enumerate(cluster_labels): - if label not in clusters: - clusters[label] = [] - clusters[label].append({ - "vector": vectors[i], # origin sample vector - "index": i # origin sample index - }) - - # Convert to cluster information - cluster_info = [] - for cluster_label, items_in_cluster in clusters.items(): - 
if cluster_label == -1: # Skip noise points - continue - - # Calculate cluster center - center_vector = np.mean([item["vector"] for item in items_in_cluster], axis=0) - - # Find the vector closest to the geometric center - distances = [np.linalg.norm(np.array(item["vector"]) - center_vector) for item in items_in_cluster] - closest_idx = np.argmin(distances) - center_index = items_in_cluster[closest_idx]["index"] - - cluster_info.append({ - "cluster_id": str(uuid.uuid4()), - "cluster_label": cluster_label, - "center_index": center_index, - "center_vector": center_vector.tolist(), - "size": len(items_in_cluster), - "created_at": datetime.now().isoformat(), - "items": items_in_cluster # list of origin sample index and vector in this cluster - }) + cluster_info.append(ClusterResult( + cluster_id=str(uuid.uuid4()), + cluster_label=cluster_label, # cluseter label: -1(noise) 0 1 2 3 ... + center_index=center_index, # center sample origin index + center_vector=center_vector, + size=len(items_in_cluster), # size of the cluster + created_at=datetime.now().isoformat(), + items=items_in_cluster # list of origin sample index and vector in this cluster + )) + + return cluster_info + + def search_knn_by_center_embeddings(self, center_emb: List[float], vectors: List[List[float]], top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for k-nearest neighbors to the center embedding using sklearn. 
+ + Args: + center_emb: Center embedding vector + vectors: List of vectors to search in + top_k: Number of nearest neighbors to return (maybe include the center embedding itself) - return cluster_info + Returns: + List of dictionaries containing index, distance, and vector for each neighbor + """ + if not vectors: + return [] - except ImportError: - # Fallback: create single cluster - return [{ - "cluster_id": str(uuid.uuid4()), - "cluster_label": -1, - "center_index": 0, - "center_vector": vectors[0] if vectors else [], - "size": len(vectors), - "created_at": datetime.now().isoformat(), - "items": [{"index": i, "vector": vectors[i]} - for i in range(len(vectors))] # list of origin sample index and vector in this cluster - }] - - - - + # Initialize NearestNeighbors with the same metric + nn = NearestNeighbors(n_neighbors=min(top_k, len(vectors)), metric=self.metric) + nn.fit(vectors) + + # Search for nearest neighbors + distances, indices = nn.kneighbors([center_emb]) + + # Format results + results = [] + for dist, idx in zip(distances[0], indices[0]): + results.append({ + "index": int(idx), # neighbor original index + "distance": float(dist), # neighbor distance + "vector": vectors[idx] # neighbor vector + }) + + return results diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py index ede422e15..fa6067528 100644 --- a/src/memos/memories/textual/prefer_text_memory/config.py +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -8,24 +8,24 @@ -class BaseBuilderConfig(BaseConfig): - """Base configuration class for Builder.""" +class BaseAdderConfig(BaseConfig): + """Base configuration class for Adder.""" -class NaiveBuilderConfig(BaseBuilderConfig): - """Configuration for Naive Builder.""" +class NaiveAdderConfig(BaseAdderConfig): + """Configuration for Naive Adder.""" # No additional config needed since components are passed from parent -class BuilderConfigFactory(BaseConfig): - 
"""Factory class for creating Builder configurations.""" +class AdderConfigFactory(BaseConfig): + """Factory class for creating Adder configurations.""" - backend: str = Field(..., description="Backend for Builder") - config: dict[str, Any] = Field(..., description="Configuration for the Builder backend") + backend: str = Field(..., description="Backend for Adder") + config: dict[str, Any] = Field(..., description="Configuration for the Adder backend") backend_to_class: ClassVar[dict[str, Any]] = { - "naive": NaiveBuilderConfig, + "naive": NaiveAdderConfig, } @field_validator("backend") @@ -37,14 +37,45 @@ def validate_backend(cls, backend: str) -> str: return backend @model_validator(mode="after") - def create_config(self) -> "BuilderConfigFactory": + def create_config(self) -> "AdderConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self +class BaseExtractorConfig(BaseConfig): + """Base configuration class for Extractor.""" + + +class NaiveExtractorConfig(BaseExtractorConfig): + """Configuration for Naive Extractor.""" + +class ExtractorConfigFactory(BaseConfig): + """Factory class for creating Extractor configurations.""" + + backend: str = Field(..., description="Backend for Extractor") + config: dict[str, Any] = Field(..., description="Configuration for the Extractor backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveExtractorConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "ExtractorConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + class BaseRetrieverConfig(BaseConfig): - """Base configuration class for Retriever.""" + """Base 
class BaseExtractor(ABC):
    """Abstract base class for preference extractors."""

    @abstractmethod
    def __init__(self, llm_provider=None, embedder=None, vector_db=None):
        """Initialize the extractor with its collaborators."""


class NaiveExtractor(BaseExtractor):
    """LLM-driven extractor that mines preference signals from QA pairs."""

    def __init__(self, llm_provider=None, embedder=None, vector_db=None):
        """Initialize the extractor.

        Args:
            llm_provider: LLM used for all prompt-based extraction calls.
            embedder: Embedder used to vectorize dialogs and topics.
            vector_db: Vector store (kept for interface symmetry; not used directly here).
        """
        super().__init__(llm_provider, embedder, vector_db)
        self.llm_provider = llm_provider
        self.embedder = embedder
        self.vector_db = vector_db
        self.splitter = Splitter()

    def _generate_and_parse(self, prompt: str):
        """Send *prompt* to the LLM and best-effort parse the reply as JSON.

        Returns the parsed object, or the raw reply string when it is not valid
        JSON. Errors from the LLM call itself propagate to the caller.
        BUG FIX: the original `try: response = generate(...); json.loads(...)
        except: return response` returned an unbound `response` (NameError)
        whenever `generate` raised.
        """
        response = self.llm_provider.generate([{"role": "user", "content": prompt}])
        try:
            return json.loads(response)
        except (json.JSONDecodeError, TypeError, ValueError):
            return response

    def extract_basic_info(self, qa_pair: "MessageList") -> Dict[str, Any]:
        """Build the non-LLM skeleton record for a QA pair."""
        return {
            "dialog_id": str(uuid.uuid4()),
            "dialog_msgs": qa_pair,
            "dialog_str": convert_messages_to_string(qa_pair),
            "created_at": datetime.now().isoformat(),
        }

    def extract_topic_info(self, qa_pair) -> Dict[str, Any]:
        """Extract topic name/description for a QA pair via the LLM.

        Accepts a message list or an already-rendered string.
        """
        # BUG FIX: `isinstance(x, MessageList)` against a typing alias can raise
        # TypeError at runtime; testing for `str` covers the same two cases safely.
        qa_pair_str = qa_pair if isinstance(qa_pair, str) else convert_messages_to_string(qa_pair)
        prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str)
        return self._generate_and_parse(prompt)

    def extract_explicit_preference(self, qa_pair) -> Dict[str, Any]:
        """Extract explicit preferences stated in a QA pair via the LLM."""
        qa_pair_str = qa_pair if isinstance(qa_pair, str) else convert_messages_to_string(qa_pair)
        prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str)
        return self._generate_and_parse(prompt)

    def extract_implicit_preferences(self, qa_pairs):
        """Extract implicit preferences shared by a cluster of QA pairs.

        Returns the parsed dict when the LLM reports an implicit preference,
        otherwise an empty string (kept for backward compatibility with callers
        that treat falsy results as "nothing extracted").
        """
        qa_pairs_str = (
            "\n\n".join(qa_pairs)
            if qa_pairs and isinstance(qa_pairs[0], str)
            else convert_messages_to_string(qa_pairs)
        )
        prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str)
        try:
            result = self._generate_and_parse(prompt)
            if isinstance(result, dict) and result.get("implicit_preference"):
                return result
        except Exception as e:
            print(f"Error processing cluster: {qa_pairs}\n{e}")
        return ""

    def extract_topic_preferences(self, qa_pairs):
        """Extract a topic-level preference summary for a cluster of QA pairs.

        Returns the parsed dict when the LLM names a topic cluster, otherwise
        an empty string (see extract_implicit_preferences).
        """
        qa_pairs_str = (
            "\n\n".join(qa_pairs)
            if qa_pairs and isinstance(qa_pairs[0], str)
            else convert_messages_to_string(qa_pairs)
        )
        prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str)
        try:
            result = self._generate_and_parse(prompt)
            if isinstance(result, dict) and result.get("topic_cluster_name"):
                return result
        except Exception as e:
            print(f"Error processing cluster: {qa_pairs}\n{e}")
        return ""

    def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]):
        """Synthesize user-level preferences from per-topic cluster summaries.

        Returns [] for empty input; otherwise the parsed dict when the LLM
        reports user preferences, else an empty string.
        """
        if not topic_preferences:
            return []
        # NOTE(review): assumes each entry carries "topic_cluster_name"/
        # "topic_cluster_description"/"topic_preferences" keys — confirm callers
        # actually pass dicts of that shape.
        cluster_infos = [{
            "topic_cluster_name": cluster["topic_cluster_name"],
            "topic_cluster_description": cluster["topic_cluster_description"],
            "topic_preferences": cluster["topic_preferences"]}
            for cluster in topic_preferences]
        prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace(
            "{cluster_info}", json.dumps(cluster_infos, ensure_ascii=False, indent=2))
        try:
            result = self._generate_and_parse(prompt)
            if isinstance(result, dict) and result.get("user_preferences"):
                return result
        except Exception as e:
            print(f"Error processing user preferences: {topic_preferences}\n{e}")
        return ""

    def concat_infos(
            self,
            basic_infos: List[Dict[str, Any]] = None,
            explicit_preferences: List[Dict[str, Any]] = None,
            topic_infos: List[Dict[str, Any]] = None,
            dialogue_vectors: List[Dict[str, Any]] = None,
            topic_vectors: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Positionally merge the non-None info lists into one dict per index.

        The length of the first non-None list determines how many records are
        produced; shorter lists simply contribute nothing past their end.
        """
        sources = [basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors]
        present = [lst for lst in sources if lst is not None]
        if not present:
            return []

        length = len(present[0])
        whole_infos = []
        for i in range(length):
            merged_dict = {}
            for lst in sources:
                if lst is not None and i < len(lst):
                    merged_dict.update(lst[i])
            whole_infos.append(merged_dict)
        return whole_infos

    def _process_single_chunk(self, chunk, msg_type: str, info: Dict[str, Any]):
        """Run the full extraction pipeline on one chunk and build a memory item."""
        basic_info = self.extract_basic_info(chunk)
        topic_info = self.extract_topic_info(chunk)
        explicit_pref = self.extract_explicit_preference(chunk)

        vector_info = {
            "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0],
            "topic_vector": self.embedder.embed(
                [topic_info["topic_name"] + topic_info["topic_description"]])[0],
        }
        extract_info = {**basic_info, **topic_info, **explicit_pref, **vector_info, **info}

        metadata = PreferenceTextualMemoryMetadata(
            type=msg_type, preference_type="explicit_preference", **extract_info)
        return TextualMemoryItem(
            id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata)

    def extract(self, messages: "MessageList", msg_type: str, info: Dict[str, Any],
                max_workers: int = 10) -> list:
        """Extract preference memories from *messages*.

        Chunks the conversation, then processes chunks concurrently; chunks
        whose processing fails are logged and skipped.
        """
        chunks = self.splitter.split_chunks(messages)
        if not chunks:
            return []

        memories = []
        with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor:
            future_to_chunk = {
                executor.submit(self._process_single_chunk, chunk, msg_type, info): chunk
                for chunk in chunks
            }
            for future in as_completed(future_to_chunk):
                try:
                    memories.append(future.result())
                except Exception as e:
                    print(f"Error processing chunk: {future_to_chunk[future]}\n{e}")
        return memories
class AdderFactory(BaseAdder):
    """Factory class for creating Adder instances from configuration."""

    # Registered Adder backends.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "naive": NaiveAdder,
    }

    @classmethod
    def from_config(cls, config_factory: "AdderConfigFactory",
                    llm_provider=None, embedder=None, vector_db=None) -> "BaseAdder":
        """Create an Adder instance from a configuration factory.

        Raises:
            ValueError: If the configured backend is not registered.
        """
        adder_class = cls.backend_to_class.get(config_factory.backend)
        if adder_class is None:
            raise ValueError(f"Invalid backend: {config_factory.backend}")
        return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db)


class ExtractorFactory(BaseExtractor):
    """Factory class for creating Extractor instances from configuration."""

    # Registered Extractor backends.
    backend_to_class: ClassVar[dict[str, Any]] = {
        "naive": NaiveExtractor,
    }

    @classmethod
    def from_config(cls, config_factory: "ExtractorConfigFactory",
                    llm_provider=None, embedder=None, vector_db=None) -> "BaseExtractor":
        """Create an Extractor instance from a configuration factory.

        Raises:
            ValueError: If the configured backend is not registered.
        """
        extractor_class = cls.backend_to_class.get(config_factory.backend)
        if extractor_class is None:
            raise ValueError(f"Invalid backend: {config_factory.backend}")
        return extractor_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db)
class Splitter:
    """Split chat histories (or parsed documents) into overlapping chunks."""

    def __init__(self, lookback_turns: int = 1,
                 chunk_size: int = 256,
                 chunk_overlap: int = 128,
                 min_sentences_per_chunk: int = 1,
                 tokenizer: str = "gpt2",
                 parser_backend: str = "markitdown",
                 chunker_backend: str = "sentence"):
        """Initialize the splitter and its parser/chunker backends.

        Args:
            lookback_turns: How many preceding QA turns each chat chunk keeps.
            chunk_size: Target token count per document chunk.
            chunk_overlap: Token overlap between consecutive document chunks.
            min_sentences_per_chunk: Lower bound on sentences per chunk.
            tokenizer: Tokenizer (or token counter) name for the chunker.
            parser_backend: Backend name for document parsing.
            chunker_backend: Backend name for text chunking.
        """
        self.lookback_turns = lookback_turns
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_sentences_per_chunk = min_sentences_per_chunk
        self.tokenizer = tokenizer
        self.chunker_backend = chunker_backend
        self.parser_backend = parser_backend

        # Initialize parser.
        parser_config = ParserConfigFactory.model_validate(
            {
                "backend": self.parser_backend,
                "config": {},
            }
        )
        self.parser = ParserFactory.from_config(parser_config)

        # Initialize chunker.
        chunker_config = ChunkerConfigFactory.model_validate(
            {
                "backend": self.chunker_backend,
                "config": {
                    "tokenizer_or_token_counter": self.tokenizer,
                    "chunk_size": self.chunk_size,
                    "chunk_overlap": self.chunk_overlap,
                    "min_sentences_per_chunk": self.min_sentences_per_chunk,
                },
            }
        )
        self.chunker = ChunkerFactory.from_config(chunker_config)

    def split_chunks(self, data):
        """Split a message list (chat) or a string (document path/text) into chunks.

        Args:
            data: MessageList for chat history, or a string for document input.

        Returns:
            List of MessageList chunks, or list of text chunks.
        """
        # BUG FIX: was `isinstance(data, MessageList)`, which can raise TypeError
        # when MessageList is a typing alias; test for `str` instead.
        if isinstance(data, str):
            # Parse and chunk the string data using the pre-initialized components.
            text = self.parser.parse(data)
            return [chunk.text for chunk in self.chunker.chunk(text)]

        # Build QA pairs from chat history and window them by lookback_turns.
        pairs = self.build_qa_pairs(data)
        chunks = []
        for i in range(len(pairs)):
            # Take up to `lookback_turns` pairs ending at pair i.
            start_idx = max(0, i + 1 - self.lookback_turns)
            chunk_messages = []
            for pair in pairs[start_idx:i + 1]:
                chunk_messages.extend(pair)
            chunks.append(chunk_messages)
        return chunks

    def build_qa_pairs(self, chat_history) -> list:
        """Group a chat history into [user..., assistant] QA pairs.

        Assistant messages with no pending user message are dropped; a trailing
        unanswered user message is dropped as well.
        """
        qa_pairs = []
        current_qa_pair = []
        for message in chat_history:
            if message["role"] == "user":
                current_qa_pair.append(message)
            elif message["role"] == "assistant":
                if not current_qa_pair:
                    continue
                current_qa_pair.append(message)
                qa_pairs.append(current_qa_pair.copy())
                current_qa_pair = []  # reset for the next pair
        return qa_pairs

    def recursive_split_merge(self):
        """Placeholder for a future recursive split/merge strategy.

        BUG FIX: the original definition omitted ``self`` and would raise
        TypeError if ever called as an instance method.
        """
class BaseUpdater(ABC):
    """Abstract base class for updaters.

    NOTE(review): the patched file also declares an abstract ``update``;
    NaiveUpdater no longer defines one, which would make it uninstantiable —
    confirm whether ``update`` is (re)added in a later patch.
    """

    @abstractmethod
    def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None):
        """Initialize the updater."""


class NaiveUpdater(BaseUpdater):
    """Cluster-based updater that rebuilds derived preference collections."""

    def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None):
        """Initialize the naive updater with its collaborators."""
        super().__init__(llm_provider, embedder, vector_db, extractor)
        self.llm_provider = llm_provider
        self.embedder = embedder
        self.vector_db = vector_db
        self.extractor = extractor
        self.clusterer = HDBSCANClusterer()

    def _implicit_cluster(self, informations: List[Dict[str, Any]]):
        """Cluster records by their dialog embeddings."""
        vectors = [info.get("dialog_vector") for info in informations]
        if not vectors:
            return []
        return self.clusterer.cluster(vectors)

    def _topic_cluster(self, informations: List[Dict[str, Any]]):
        """Cluster records by their topic embeddings."""
        vectors = [info.get("topic_vector") for info in informations]
        if not vectors:
            return []
        return self.clusterer.cluster(vectors)

    def _create_cluster_extract_input(self, cluster_results, informations: List[Dict[str, Any]],
                                      input_type: str, k: int = 5) -> Dict[str, List[str]]:
        """Map each cluster id to the dialog strings used for extraction.

        Args:
            cluster_results: Clusters produced by the clusterer.
            informations: Per-record payload dicts (indexed by original index).
            input_type: "original" (all members) or "knn" (k nearest to center).
            k: Neighbor count for the "knn" mode.

        Raises:
            ValueError: For an unknown *input_type*.
        """
        result = {}
        if not cluster_results:
            return result

        if input_type == "original":
            # Use every member's dialog string.
            for cluster in cluster_results:
                cluster_dialogs = []
                for item in cluster.items:
                    # item is {"index": original_index, "vector": ...}
                    cluster_dialogs.append(informations[item["index"]].get("dialog_str", ""))
                if cluster_dialogs:
                    result[cluster.cluster_id] = cluster_dialogs

        elif input_type == "knn":
            # Use only the k members nearest to the cluster center.
            for cluster in cluster_results:
                item_vectors = [item["vector"] for item in cluster.items]
                if not item_vectors:
                    result[cluster.cluster_id] = []
                    continue

                knn_results = self.clusterer.search_knn_by_center_embeddings(
                    center_emb=cluster.center_vector,
                    vectors=item_vectors,
                    top_k=min(k, len(item_vectors)),
                )
                cluster_dialogs = []
                for knn_item in knn_results:
                    # knn_item["index"] indexes item_vectors; map back to the
                    # original record index before reading the payload.
                    original_index = cluster.items[knn_item["index"]]["index"]
                    cluster_dialogs.append(informations[original_index].get("dialog_str", ""))
                if cluster_dialogs:
                    result[cluster.cluster_id] = cluster_dialogs

        else:
            raise ValueError(f"Invalid input type: {input_type}")

        return result

    def _process_single_implicit_cluster(self, cluster_id: str,
                                         cluster_dialogs: List[str]) -> Dict[str, Any]:
        """Run implicit-preference extraction for one cluster; never raises."""
        try:
            result = self.extractor.extract_implicit_preferences(cluster_dialogs)
            # FIX: key renamed from the misspelled "implicit_exract_result";
            # both producer and consumers live in this class.
            return {"cluster_id": cluster_id, "implicit_extract_result": result}
        except Exception as e:
            print(f"Error processing implicit cluster {cluster_id}: {e}")
            return {"cluster_id": cluster_id, "implicit_extract_result": None}

    def _extract_implicit_preferences(self, implicit_extract_inputs: Dict[str, List[str]],
                                      max_workers: int = 10) -> List[Dict[str, Any]]:
        """Extract implicit preferences per cluster, concurrently."""
        if not implicit_extract_inputs:
            return []

        results = []
        with ThreadPoolExecutor(max_workers=min(max_workers, len(implicit_extract_inputs))) as executor:
            future_to_cluster = {
                executor.submit(self._process_single_implicit_cluster, cluster_id, cluster_dialogs): cluster_id
                for cluster_id, cluster_dialogs in implicit_extract_inputs.items()
            }
            for future in as_completed(future_to_cluster):
                try:
                    result = future.result()
                    if result["implicit_extract_result"] is not None:
                        results.append(result)
                except Exception as e:
                    print(f"Error processing implicit cluster {future_to_cluster[future]}: {e}")
        return results

    def _process_single_topic_cluster(self, cluster_id: str,
                                      cluster_dialogs: List[str]) -> Dict[str, Any]:
        """Run topic-preference extraction for one cluster; never raises."""
        try:
            result = self.extractor.extract_topic_preferences(cluster_dialogs)
            # FIX: key renamed from the misspelled "topic_exract_result".
            return {"cluster_id": cluster_id, "topic_extract_result": result}
        except Exception as e:
            print(f"Error processing topic cluster {cluster_id}: {e}")
            return {"cluster_id": cluster_id, "topic_extract_result": None}

    def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]],
                                   max_workers: int = 10) -> List[Dict[str, Any]]:
        """Extract topic preferences per cluster, concurrently."""
        if not topic_extract_inputs:
            return []

        results = []
        with ThreadPoolExecutor(max_workers=min(max_workers, len(topic_extract_inputs))) as executor:
            future_to_cluster = {
                executor.submit(self._process_single_topic_cluster, cluster_id, cluster_dialogs): cluster_id
                for cluster_id, cluster_dialogs in topic_extract_inputs.items()
            }
            for future in as_completed(future_to_cluster):
                try:
                    result = future.result()
                    if result["topic_extract_result"] is not None:
                        results.append(result)
                except Exception as e:
                    print(f"Error processing topic cluster {future_to_cluster[future]}: {e}")
        return results

    def _store_preferences(self,
                           implicit_prefs: List[Dict[str, Any]],
                           topic_prefs: List[Dict[str, Any]],
                           user_prefs: List[Dict[str, Any]],
                           user_id: str,
                           explicit_prefs: List[Dict[str, Any]] = None):
        """Write implicit/topic/user preference records to their collections.

        NOTE(review): callers pass raw cluster objects for implicit/topic
        prefs while this method uses the dict ``.get`` API — confirm the
        element type upstream.
        """
        if implicit_prefs:
            implicit_memories = []
            for pref in implicit_prefs:
                vec_db_item = VecDBItem(
                    id=pref.get("cluster_id", ""),
                    vector=pref.get("center_vector", []),
                    payload={
                        "cluster_id": pref.get("cluster_id", ""),
                        "center_dialog": pref.get("center_dialog", ""),
                        "center_vector": pref.get("center_vector", []),
                        "implicit_preference": pref.get("implicit_preference", ""),
                        "created_at": pref.get("created_at", datetime.now().isoformat()),
                        "user_id": user_id,
                        "preference_type": "implicit_preference",
                    },
                )
                implicit_memories.append(vec_db_item)
            self.vector_db.add("implicit_preference", implicit_memories)

        if topic_prefs:
            topic_memories = []
            for pref in topic_prefs:
                vec_db_item = VecDBItem(
                    id=pref.get("cluster_id", ""),
                    vector=pref.get("center_vector", []),
                    payload={
                        "cluster_id": pref.get("cluster_id", ""),
                        "center_dialog": pref.get("center_dialog", ""),
                        "center_vector": pref.get("center_vector", []),
                        "topic_cluster_name": pref.get("topic_cluster_name", ""),
                        "topic_cluster_description": pref.get("topic_cluster_description", ""),
                        "topic_preferences": pref.get("topic_preferences", ""),
                        "created_at": pref.get("created_at", datetime.now().isoformat()),
                        "user_id": user_id,
                        "preference_type": "topic_preference",
                    },
                )
                topic_memories.append(vec_db_item)
            self.vector_db.add("topic_preference", topic_memories)

        if user_prefs:
            user_memories = []
            for pref in user_prefs:
                # User preferences are not vector-searched; store a zero vector
                # sized to the embedder so the collection schema is satisfied.
                embedding_dim = getattr(self.embedder.config, 'embedding_dims', 768)
                zero_vector = [0.0] * embedding_dim
                vec_db_item = VecDBItem(
                    id=user_id,
                    vector=zero_vector,
                    payload={
                        "user_id": user_id,
                        "user_preferences": pref.get("user_preferences", ""),
                        "created_at": datetime.now().isoformat(),
                        "preference_type": "user_preference",
                    },
                )
                user_memories.append(vec_db_item)
            self.vector_db.add("user_preference", user_memories)

    def _generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]],
                                 implicit_prefs: List[Dict[str, Any]],
                                 topic_prefs: List[Dict[str, Any]],
                                 user_prefs: List[Dict[str, Any]]) -> str:
        """Return a JSON summary of how many preferences were built."""
        summary = {
            "memory_build_summary": {
                "explicit_preferences_count": len(explicit_prefs),
                "implicit_preferences_count": len(implicit_prefs),
                "topic_preferences_count": len(topic_prefs),
                "user_preferences_count": len(user_prefs),
                "build_timestamp": datetime.now().isoformat(),
            }
        }
        return json.dumps(summary, ensure_ascii=False, indent=2)

    def slow_update(self):
        """Rebuild the implicit/topic/user preference collections from the
        explicit-preference collection.

        Returns:
            A JSON summary string of the rebuilt memory.
        """
        # Refresh the derived collections.
        self.vector_db.delete_collection("implicit_preference")
        self.vector_db.delete_collection("topic_preference")
        self.vector_db.delete_collection("user_preference")
        self.vector_db.create_collection_by_name("implicit_preference")
        self.vector_db.create_collection_by_name("topic_preference")
        self.vector_db.create_collection_by_name("user_preference")

        all_data = self.vector_db.get_all("explicit_preference")
        # BUG FIX: guard against an empty collection (all_data[0] raised
        # IndexError before).
        user_id = all_data[0].payload.get("user_id", "") if all_data else ""
        informations = [item.payload for item in all_data]

        # Perform clustering.
        implicit_clusters = self._implicit_cluster(informations)
        topic_clusters = self._topic_cluster(informations)

        # Build per-cluster extraction inputs.
        implicit_extract_inputs = self._create_cluster_extract_input(implicit_clusters, informations, "original")
        topic_extract_inputs = self._create_cluster_extract_input(topic_clusters, informations, "original")

        # BUG FIX: these were assigned only inside `if` blocks and could be
        # referenced while unbound when the inputs were empty.
        implicit_cluster_info = self._extract_implicit_preferences(implicit_extract_inputs) if implicit_extract_inputs else []
        topic_cluster_info = self._extract_topic_preferences(topic_extract_inputs) if topic_extract_inputs else []

        # NOTE(review): extract_user_preferences expects entries keyed by
        # "topic_cluster_name"/..., but _extract_topic_preferences returns
        # {"cluster_id", "topic_extract_result"} records — confirm the contract.
        user_preferences = self.extractor.extract_user_preferences(topic_cluster_info)

        # Store all preferences in memory.
        self._store_preferences(
            explicit_prefs=informations,
            implicit_prefs=implicit_clusters,
            topic_prefs=topic_clusters,
            user_prefs=user_preferences,
            user_id=user_id,
        )

        # Return a summary of the rebuilt memory.
        return self._generate_memory_summary(
            explicit_prefs=informations,
            implicit_prefs=implicit_clusters,
            topic_prefs=topic_clusters,
            user_prefs=user_preferences,
        )
def convert_messages_to_string(messages: "MessageList") -> str:
    """Render a message list as alternating "Query:"/"Answer:" lines.

    User messages become "Query: ..." lines and assistant messages become
    "Answer: ..." lines; any other role is skipped. Surrounding whitespace
    is stripped from the final text.
    """
    role_prefix = {"user": "Query", "assistant": "Answer"}
    lines = [
        f"{role_prefix[message['role']]}: {message['content']}"
        for message in messages
        if message["role"] in role_prefix
    ]
    return "\n".join(lines).strip()
vector_db=self.vector_db + vector_db=self.vector_db, + extractor=self.extractor ) self.assembler = AssemblerFactory.from_config( config.assembler, @@ -58,16 +66,14 @@ def __init__(self, config: PreferenceTextMemoryConfig): vector_db=self.vector_db ) - def build_preferences(self, history: ChatHistory) -> None: - """Build memory from the original dialogs. (Initialize memory) - + def get_memory(self, messages: MessageList, type: str, info: dict[str, Any]) -> list[TextualMemoryItem]: + """Get memory based on the messages. Args: - history: The chat history to build memory from. - - Returns: - Memory content string formatted according to the build strategy + messages (MessageList): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. """ - return self.builder.build(history) + self.extractor.extract(messages, type, info) def update_preferences(self, new_dialog: MessageList) -> None: """Update a memory by new dialog. @@ -75,15 +81,6 @@ def update_preferences(self, new_dialog: MessageList) -> None: new_dialog (MessageList): The new dialog to update. """ self.updater.update(new_dialog) - - def search_preferences(self, query: str, top_k: int, info=None) -> list[TextualMemoryItem]: - """Search for preferences based on a query. - Args: - query (str): The query to search for. - top_k (int): The number of top results to return. - info (dict): Leave a record of memory consumption. - """ - return self.retriever.retrieve(query, top_k, info) def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. @@ -140,27 +137,7 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: Args: memories: List of TextualMemoryItem objects or dictionaries to add. 
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
    """Add memories by delegating to the configured adder.

    Args:
        memories: TextualMemoryItem objects or plain dicts to add.

    Returns:
        list[str]: IDs of the added memories, per the declared return type
        (produced by the adder backend).
    """
    added_ids = self.adder.add(memories)
    return added_ids
Including but not limited to: - Role descriptions (e.g., "You are a history teacher") - Style constraints (e.g., "humorous style", "academic style") - Format requirements (e.g., "table", "Markdown") @@ -17,7 +17,7 @@ # Output Format { - "explicit_preferences": { + "explicit_preference": { "role": "", "style": "", "format": "", @@ -66,7 +66,7 @@ # Output Format: { - "implicit_preferences": [ + "implicit_preference": [ "preference 1", "preference 2", "preference 3" @@ -108,7 +108,7 @@ { "topic_cluster_name": "", "topic_cluster_description": "", - "topic_preferences": "" + "topic_preference": "" } # Notes: @@ -157,7 +157,7 @@ # Output Format: { - "user_preferences": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" + "user_preference": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" } # Notes: @@ -230,7 +230,7 @@ 1. Current query preferences: Constraints explicitly stated in the current user question 2. Related dialogue preferences: Preference references from Q&A pairs related to the current query 3. Related topic preferences: Preference references from topics related to the current query -4. User preferences: Common preference references from the user's historical conversations +4. 
User preference: Common preference references from the user's historical conversations Priority: Current query preferences > Related dialogue preferences > Related topic preferences > User preferences > Implicit preferences diff --git a/src/memos/types.py b/src/memos/types.py index 0897cecdd..140911978 100644 --- a/src/memos/types.py +++ b/src/memos/types.py @@ -53,3 +53,4 @@ class MOSSearchResult(TypedDict): text_mem: list[dict[str, str | list[TextualMemoryItem]]] act_mem: list[dict[str, str | list[ActivationMemoryItem]]] para_mem: list[dict[str, str | list[ParametricMemoryItem]]] + pref_mem: list[dict[str, str | list[TextualMemoryItem]]] From 6457f59b38dcde65661e207153ecf0d42afbab65 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 11:57:46 +0800 Subject: [PATCH 06/62] modify pref module --- .../textual/prefer_text_memory/assemble.py | 49 ++++- .../textual/prefer_text_memory/clustering.py | 4 + .../textual/prefer_text_memory/extractor.py | 10 +- .../textual/prefer_text_memory/retrievers.py | 12 +- .../textual/prefer_text_memory/updater.py | 171 +++++++++--------- src/memos/memories/textual/preference.py | 20 +- src/memos/templates/prefer_complete_prompt.py | 2 +- 7 files changed, 155 insertions(+), 113 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index ac7762046..2201dc270 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod +from typing import Any, Dict, List +import json +from memos.templates.prefer_complete_prompt import NAIVE_PREFERENCE_INTEGRATION_PROMPT from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.prefer_text_memory.naive_op import NaiveOp class BaseAssembler(ABC): """Abstract base class for assemblers.""" @@ -29,22 +31,53 @@ def __init__(self, llm_provider=None, 
def _preference_integration(self, query: str,
                            explicit_prefs: List[Dict[str, Any]],
                            implicit_prefs: List[Dict[str, Any]],
                            topic_prefs: List[Dict[str, Any]],
                            user_prefs: List[Dict[str, Any]]) -> str:
    """Merge the grouped preferences into a final prompt via the LLM.

    Returns:
        The LLM-produced ``final_prompt`` string, or ``""`` on any failure
        (LLM error or a malformed/non-JSON response).
    """
    prompt = NAIVE_PREFERENCE_INTEGRATION_PROMPT.format(
        query_preference=query,
        explicit_preference=json.dumps(explicit_prefs, ensure_ascii=False, indent=2),
        implicit_preference=json.dumps(implicit_prefs, ensure_ascii=False, indent=2),
        topic_preference=json.dumps(topic_prefs, ensure_ascii=False, indent=2),
        user_preference=json.dumps(user_prefs, ensure_ascii=False, indent=2),
    )
    try:
        response = self.llm_provider.generate([{"role": "user", "content": prompt}])
        return json.loads(response)["final_prompt"]
    except Exception as e:
        # Best-effort: an empty string lets callers fall back gracefully.
        print(f"Error in preference_integration: {e}")
        return ""


def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str = "semi") -> str:
    """Assemble the query and retrieved memories into a single prompt.

    Strategies:
        raw:  return ``memories`` unchanged. NOTE(review): this returns a
              list, not ``str`` as annotated — confirm the intended contract.
        semi: wrap query and memories in a fixed textual template.
        full: LLM-based integration of the preferences grouped by type.

    Raises:
        ValueError: for an unknown ``assemble_strategy``.
    """
    if assemble_strategy == "raw":
        return memories
    if assemble_strategy == "semi":
        return f"Query: {query}\n\n In addition to the above Query, you can refer to the following preference below memories. \n\nMemories: {memories}. \n\nWhen encountering conflicts, prioritize following the query."
    if assemble_strategy == "full":
        # Group memories by preference type only when actually needed: the
        # raw/semi strategies previously paid this cost (and touched
        # metadata attributes) for nothing.
        explicit_prefs = [{"dialog_str": m.metadata.dialog_str,
                           "explicit_preference": m.metadata.explicit_preference}
                          for m in memories if m.metadata.preference_type == "explicit_preference"]
        implicit_prefs = [{"center_dialog_str": m.metadata.center_dialog,
                           "implicit_preference": m.metadata.implicit_preference}
                          for m in memories if m.metadata.preference_type == "implicit_preference"]
        # NOTE(review): the updater now stores singular payload keys
        # ("topic_preference" / "user_preference"); the plural attribute
        # reads below look stale — confirm against
        # PreferenceTextualMemoryMetadata before renaming.
        topic_prefs = [{"center_dialog_str": m.metadata.center_dialog,
                       "topic_preferences": m.metadata.topic_preferences}
                      for m in memories if m.metadata.preference_type == "topic_preference"]
        user_prefs = [{"user_preferences": m.metadata.user_preferences}
                      for m in memories if m.metadata.preference_type == "user_preference"]
        return self._preference_integration(query, explicit_prefs, implicit_prefs, topic_prefs, user_prefs)
    raise ValueError(f"Invalid assemble strategy: {assemble_strategy}")
json.dumps(topic_preferences, ensure_ascii=False, indent=2)) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) result = json.loads(response) - if result.get("user_preferences"): + if result.get("user_preference"): return result except Exception as e: print(f"Error processing user preferences: {topic_preferences}\n{e}") diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 471f7345b..4133dbfe1 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -38,10 +38,14 @@ def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[Te else: user_prefs = [] - explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs] - implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs] - topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs] - user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preferences", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs] + explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs] + implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs] + topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), + 
metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs] + user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preferences", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs] return explicit_prefs + implicit_prefs + topic_prefs + user_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index a33f586ac..f4d53e4f6 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -41,14 +41,22 @@ def _implicit_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str vectors = [info.get("dialog_vector") for info in informations] if not vectors: return [] - return self.clusterer.cluster(vectors) + res = self.clusterer.cluster(vectors) + for cluster in res: + cluster.center_dialog_msgs = informations[cluster.center_index].get("dialog_msgs", []) + cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") + return res def _topic_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Topic cluster.""" vectors = [info.get("topic_vector") for info in informations] if not vectors: return [] - return self.clusterer.cluster(vectors) + res = self.clusterer.cluster(vectors) + for cluster in res: + cluster.center_dialog_msgs = informations[cluster.center_index].get("dialog_msgs", []) + cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") + return res def _create_cluster_extract_input(self, cluster_results: List[ClusterResult], informations: List[Dict[str, Any]], input_type: str, k: int = 5) -> Dict[str, List[str]]: """Create cluster extract input. 
@@ -175,96 +183,92 @@ def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]], return results + def _extract_user_preferences(self, topic_cluster_pref_infos: Dict[str, Any]) -> List[Dict[str, Any]]: + """Extract user preferences from topic cluster info.""" + topic_cluster_pref = [info["topic_exract_result"] for info in topic_cluster_pref_infos] + return self.extractor.extract_user_preferences(topic_cluster_pref) + def _store_preferences(self, - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]], - user_id: str): - """Store all preferences in memory.""" - - # Store implicit preferences - if implicit_prefs: - implicit_memories = [] - for pref in implicit_prefs: - # Create VecDBItem directly using existing embedding - vec_db_item = VecDBItem( - id=pref.get("cluster_id", ""), - vector=pref.get("center_vector", []), + implicit_clusters, + topic_clusters, + implicit_cluster_prefs, + topic_cluster_prefs, + user_prefs, + user_id): + """Create store data.""" + implicit_memories = [] + topic_memories = [] + + if implicit_clusters: + for cluster in implicit_clusters: + pref = implicit_cluster_prefs[cluster.cluster_id] + mem = VecDBItem( + id=cluster.cluster_id, + vector=cluster.center_vector, payload={ - "cluster_id": pref.get("cluster_id", ""), - "center_dialog": pref.get("center_dialog", ""), - "center_vector": pref.get("center_vector", []), + "cluster_id": cluster.cluster_id, + "center_dialog_msgs": cluster.center_dialog_msgs, + "center_dialog_str": cluster.center_dialog_str, + "center_vector": cluster.center_vector, "implicit_preference": pref.get("implicit_preference", ""), - "created_at": pref.get("created_at", datetime.now().isoformat()), + "created_at": cluster.created_at, "user_id": user_id, + "size": cluster.size, "preference_type": "implicit_preference" } ) - implicit_memories.append(vec_db_item) - - # Store in implicit_preference collection + implicit_memories.append(mem) + 
self.vector_db.add("implicit_preference", implicit_memories) - - # Store topic preferences - if topic_prefs: - topic_memories = [] - for pref in topic_prefs: - # Create VecDBItem directly using existing embedding - vec_db_item = VecDBItem( - id=pref.get("cluster_id", ""), - vector=pref.get("center_vector", []), + + if topic_clusters: + for cluster in topic_clusters: + pref = topic_cluster_prefs[cluster.cluster_id] + mem = VecDBItem( + id=cluster.cluster_id, + vector=cluster.center_vector, payload={ - "cluster_id": pref.get("cluster_id", ""), - "center_dialog": pref.get("center_dialog", ""), - "center_vector": pref.get("center_vector", []), + "cluster_id": cluster.cluster_id, + "center_dialog_msgs": cluster.center_dialog_msgs, + "center_dialog_str": cluster.center_dialog_str, + "center_vector": cluster.center_vector, "topic_cluster_name": pref.get("topic_cluster_name", ""), "topic_cluster_description": pref.get("topic_cluster_description", ""), - "topic_preferences": pref.get("topic_preferences", ""), - "created_at": pref.get("created_at", datetime.now().isoformat()), + "topic_preference": pref.get("topic_preference", ""), + "created_at": cluster.created_at, "user_id": user_id, + "size": cluster.size, "preference_type": "topic_preference" } ) - topic_memories.append(vec_db_item) - - # Store in topic_preference collection + topic_memories.append(mem) + self.vector_db.add("topic_preference", topic_memories) - - # Store user preferences + if user_prefs: - user_memories = [] - for pref in user_prefs: - # Create VecDBItem with zero vector (user preferences don't need vector search) - # Use zero vector to satisfy Milvus collection dimension requirements - # Get embedding dimension from embedder config - embedding_dim = getattr(self.embedder.config, 'embedding_dims', 768) # Default to 768 if not available - zero_vector = [0.0] * embedding_dim - vec_db_item = VecDBItem( - id=user_id, - vector=zero_vector, - payload={ - "user_id": user_id, - "user_preferences": 
pref.get("user_preferences", ""), - "created_at": datetime.now().isoformat(), - "preference_type": "user_preference" - } - ) - user_memories.append(vec_db_item) - - # Store in user_preference collection - self.vector_db.add("user_preference", user_memories) - - def _generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]]) -> str: + mem = VecDBItem( + id=user_id, + vector=[0.0] * self.embedder.config.embedding_dims, + payload={ + "user_id": user_id, + "user_preference": user_prefs.get("user_preference", ""), + "created_at": datetime.now().isoformat(), + "preference_type": "user_preference" + } + ) + self.vector_db.add("user_preference", [mem]) + + def _generate_memory_summary(self, explicit_infos: List[Dict[str, Any]], + implicit_infos: List[Dict[str, Any]], + topic_infos: List[Dict[str, Any]], + user_infos: List[Dict[str, Any]]) -> str: """Generate a summary of the built memory.""" summary = { "memory_build_summary": { - "explicit_preferences_count": len(explicit_prefs), - "implicit_preferences_count": len(implicit_prefs), - "topic_preferences_count": len(topic_prefs), - "user_preferences_count": len(user_prefs), + "explicit_preferences_count": len(explicit_infos), + "implicit_preferences_count": len(implicit_infos), + "topic_preferences_count": len(topic_infos), + "user_preferences_count": len(user_infos), "build_timestamp": datetime.now().isoformat() } } @@ -300,31 +304,32 @@ def slow_update(self): # Extract preferences if implicit_extract_inputs: - implicit_cluster_info = self._extract_implicit_preferences(implicit_extract_inputs) + implicit_cluster_prefs = self._extract_implicit_preferences(implicit_extract_inputs) if topic_extract_inputs: - topic_cluster_info = self._extract_topic_preferences(topic_extract_inputs) + topic_cluster_prefs = self._extract_topic_preferences(topic_extract_inputs) # Extract user preferences - user_preferences = 
self.extractor.extract_user_preferences(topic_cluster_info) + user_prefs = self._extract_user_preferences(topic_cluster_prefs) # Store all preferences in memory self._store_preferences( - explicit_prefs=informations, - implicit_prefs=implicit_clusters, - topic_prefs=topic_clusters, - user_prefs=user_preferences, + implicit_clusters=implicit_clusters, + topic_clusters=topic_clusters, + implicit_cluster_prefs=implicit_cluster_prefs, + topic_cluster_prefs=topic_cluster_prefs, + user_prefs=user_prefs, user_id=user_id, ) # Return summary of built memory return self._generate_memory_summary( - explicit_prefs=informations, - implicit_prefs=implicit_clusters, - topic_prefs=topic_clusters, - user_prefs=user_preferences, + explicit_infos=informations, + implicit_infos=implicit_clusters, + topic_infos=topic_clusters, + user_infos=user_prefs, ) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 327ade501..879b75a3d 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -111,7 +111,7 @@ def load(self, dir: str) -> None: """ # For preference memory, we don't need to load from files # as the data is stored in the vector database - pass + raise NotImplementedError def dump(self, dir: str) -> None: """Dump memories to the specified directory. @@ -120,7 +120,7 @@ def dump(self, dir: str) -> None: """ # For preference memory, we don't need to dump to files # as the data is stored in the vector database - pass + raise NotImplementedError def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """Extract memories based on the messages. @@ -129,7 +129,7 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of extracted memory items. """ - pass + raise NotImplementedError def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: """Add memories. 
@@ -141,7 +141,7 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: """Update a memory by memory_id.""" - pass + raise NotImplementedError def get(self, memory_id: str) -> TextualMemoryItem: """Get a memory by its ID. @@ -150,7 +150,7 @@ def get(self, memory_id: str) -> TextualMemoryItem: Returns: TextualMemoryItem: The memory with the given ID. """ - pass + raise NotImplementedError def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """Get memories by their IDs. @@ -159,28 +159,28 @@ def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of memories with the specified IDs. """ - pass + raise NotImplementedError def get_all(self) -> list[TextualMemoryItem]: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - pass + raise NotImplementedError def delete(self, memory_ids: list[str]) -> None: """Delete memories. Args: memory_ids (list[str]): List of memory IDs to delete. 
""" - pass + raise NotImplementedError def delete_all(self) -> None: """Delete all memories.""" - pass + raise NotImplementedError def drop( self, ) -> None: """Drop all databases.""" - pass \ No newline at end of file + raise NotImplementedError \ No newline at end of file diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 34480f502..d998ad0b5 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -127,7 +127,7 @@ - topic_cluster_name: The name of the topic cluster - topic_cluster_description: The description of the topic cluster -- topic_preferences: Natural language description of preferences in that cluster +- topic_preference: Natural language description of preferences in that cluster Your task is to extract the user's **highest-level common preferences** by focusing on these three key dimensions: From e1f3ecddab3427686c4708e4f4a279ccdc5854e9 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 15:02:08 +0800 Subject: [PATCH 07/62] implement remain preference function --- src/memos/memories/textual/preference.py | 95 ++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 5 deletions(-) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 879b75a3d..017939827 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -12,7 +12,7 @@ from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.base import BaseTextMemory -from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata from memos.types import ChatHistory, MessageList from memos.llms.base import BaseLLM from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory, MilvusVecDB @@ -20,6 +20,9 @@ from 
def load(self, dir: str) -> None:
    """Load memories from a JSON dump into the vector database.

    The file is expected to map collection names to lists of serialized
    vector-db items — the format written by ``dump``.

    Args:
        dir: Directory containing ``self.config.memory_filename``.
    """
    try:
        memory_file = os.path.join(dir, self.config.memory_filename)
        if not os.path.exists(memory_file):
            logger.warning(f"Memory file not found: {memory_file}")
            return

        with open(memory_file, encoding="utf-8") as f:
            memories = json.load(f)
        for collection_name, items in memories.items():
            vec_db_items = [VecDBItem.from_dict(m) for m in items]
            self.vector_db.add(collection_name, vec_db_items)
            logger.info(f"Loaded {len(items)} memories from {collection_name} in {memory_file}")
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from memory file: {e}")
    except Exception as e:
        # Best-effort load: log and continue rather than crash startup.
        # (The dedicated FileNotFoundError handler was dead code after the
        # explicit existence check above; any race is still caught here.)
        logger.error(f"An error occurred while loading memories: {e}")
def dump(self, dir: str) -> None:
    """Dump all vector-db collections to a JSON file in ``dir``.

    Args:
        dir: Target directory; created if missing. The file name comes
            from ``self.config.memory_filename``.

    Raises:
        Exception: any failure is logged and then re-raised, so callers
            can tell a dump did not complete.
    """
    try:
        json_memories = {
            collection_name: [memory.to_dict() for memory in self.vector_db.get_all(collection_name)]
            for collection_name in self.vector_db.list_collections()
        }

        os.makedirs(dir, exist_ok=True)
        memory_file = os.path.join(dir, self.config.memory_filename)
        with open(memory_file, "w", encoding="utf-8") as f:
            json.dump(json_memories, f, indent=4, ensure_ascii=False)

        logger.info(f"Dumped {len(json_memories)} collections, {sum(len(items) for items in json_memories.values())} memories to {memory_file}")
    except Exception as e:
        logger.error(f"An error occurred while dumping memories: {e}")
        raise
@@ -160,13 +213,35 @@ def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: list[TextualMemoryItem]: List of memories with the specified IDs. """ raise NotImplementedError + + def get_by_ids_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> list[TextualMemoryItem]: + """Get memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to retrieve the memory from. + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs and collection name. + """ + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if res is None: + raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}") + return [TextualMemoryItem(id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in res] def get_all(self) -> list[TextualMemoryItem]: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - raise NotImplementedError + all_collections = self.vector_db.list_collections() + all_memories = {} + for collection_name in all_collections: + items = self.vector_db.get_all(collection_name) + all_memories[collection_name] = [TextualMemoryItem(id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in items] + return all_memories def delete(self, memory_ids: list[str]) -> None: """Delete memories. @@ -174,10 +249,20 @@ def delete(self, memory_ids: list[str]) -> None: memory_ids (list[str]): List of memory IDs to delete. """ raise NotImplementedError + + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: + """Delete memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to delete the memory from. + memory_ids (list[str]): List of memory IDs to delete. 
+ """ + self.vector_db.delete(collection_name, memory_ids) def delete_all(self) -> None: """Delete all memories.""" - raise NotImplementedError + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() def drop( self, From d412c7878effa2ad43f92a92e4392765f1ca9582 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 15:14:01 +0800 Subject: [PATCH 08/62] modify preference.py --- src/memos/memories/textual/preference.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 017939827..d385eb69e 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -8,7 +8,11 @@ from typing import Any from memos.configs.memory import PreferenceTextMemoryConfig -from memos.embedders.factory import EmbedderFactory, OllamaEmbedder, ArkEmbedder, SenTranEmbedder, UniversalAPIEmbedder +from memos.embedders.factory import (EmbedderFactory, + OllamaEmbedder, + ArkEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder) from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.base import BaseTextMemory @@ -17,7 +21,11 @@ from memos.llms.base import BaseLLM from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory, MilvusVecDB from memos.vec_dbs.item import VecDBItem -from memos.memories.textual.prefer_text_memory.factory import AdderFactory, ExtractorFactory, RetrieverFactory, UpdaterFactory, AssemblerFactory +from memos.memories.textual.prefer_text_memory.factory import (AdderFactory, + ExtractorFactory, + RetrieverFactory, + UpdaterFactory, + AssemblerFactory) logger = get_logger(__name__) From 0e502f441dc7cc9ac59544e12095b18020183a4e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 15:35:13 +0800 Subject: [PATCH 09/62] modify bug in 
milvus --- src/memos/vec_dbs/milvus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 7b529cdc1..7bb61a91c 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -62,6 +62,7 @@ def create_collection_by_name(self, collection_name: str) -> None: dimension=self.config.vector_dimension, metric_type=self._get_metric_type(), id_type="string", # Use string ID type, align with VecDBItem id type + max_length=self.config.max_length, # Use max_length from config ) def list_collections(self) -> list[str]: From 83a9a3bdd52dec9dcfcca9c471e5bbc3a9381e21 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 20:51:42 +0800 Subject: [PATCH 10/62] finish debug --- .../textual/prefer_text_memory/adder.py | 18 +- .../textual/prefer_text_memory/clustering.py | 2 +- .../textual/prefer_text_memory/extractor.py | 22 +- .../textual/prefer_text_memory/factory.py | 4 +- .../textual/prefer_text_memory/naive_op.py | 439 ------------------ .../textual/prefer_text_memory/spliter.py | 2 +- .../textual/prefer_text_memory/updater.py | 62 +-- src/memos/memories/textual/preference.py | 42 +- src/memos/templates/prefer_complete_prompt.py | 18 +- src/memos/vec_dbs/milvus.py | 1 + 10 files changed, 87 insertions(+), 523 deletions(-) delete mode 100644 src/memos/memories/textual/prefer_text_memory/naive_op.py diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index c3d91ca05..79f9959aa 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -35,36 +35,34 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def _judge_update_or_add(self, old_msg: MessageList, new_msg: MessageList) -> bool: + def _judge_update_or_add(self, old_msg: str, new_msg: str) -> bool: """Judge if the new message 
expresses the same core content as the old message.""" - # Convert messages to string format for comparison - old_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in old_msg]) - new_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in new_msg]) - # Use the template prompt with placeholders - prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_str).replace("{new_information}", new_str) + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace("{new_information}", new_msg) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) response = result.get("is_same", False) return response if isinstance(response, bool) else response == "true" except Exception as e: print(f"Error in judge_update_or_add: {e}") # Fallback to simple string comparison - return old_str == new_str + return old_msg == new_msg def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: """Process a single memory and return its ID if added successfully.""" try: - payload = {**memory.metadata.model_dump()} + payload = memory.to_dict()["metadata"] vec_db_item = VecDBItem( id=memory.id, vector=memory.metadata.dialog_vector, payload=payload - ) + ) - recall = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1)[0] + search_results = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1) + recall = search_results[0] if search_results else None if not recall or (recall.score is not None and recall.score < 0.5): self.vector_db.update("explicit_preference", memory.id, vec_db_item) return memory.id diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py index bd64ce6ec..beaaf108e 100644 --- 
a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -106,7 +106,7 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: _, indices = nn.kneighbors([geometric_center_vector.tolist()]) closest_idx = indices[0][0] center_index = items_in_cluster[closest_idx]["index"] - center_vector = cluster_vectors[closest_idx]["vector"] + center_vector = items_in_cluster[closest_idx]["vector"] cluster_info.append(ClusterResult( cluster_id=str(uuid.uuid4()), diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 53182e3e7..202514ede 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -47,11 +47,12 @@ def extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: def extract_topic_info(self, qa_pair: MessageList | str) -> Dict[str, Any]: """Extract topic information from a QA pair.""" - qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, MessageList) else qa_pair + qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) return result except Exception: @@ -59,11 +60,12 @@ def extract_topic_info(self, qa_pair: MessageList | str) -> Dict[str, Any]: def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, Any]: """Extract explicit preference from a QA pair.""" - qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, MessageList) else qa_pair + qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair prompt = 
NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) return result except Exception: @@ -71,33 +73,39 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, A def extract_implicit_preferences(self, qa_pairs: list[MessageList] | list[str]) -> List[Dict[str, Any]]: """Extract implicit preferences from cluster qa pairs.""" - qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs, MessageList) else "\n\n".join(qa_pairs) + if not qa_pairs: + return None + qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) if result.get("implicit_preference"): return result except Exception as e: print(f"Error processing cluster: {qa_pairs}\n{e}") - return "" + return None def extract_topic_preferences(self, qa_pairs: list[MessageList] | list[str]) -> List[Dict[str, Any]]: """Extract topic preferences from cluster qa pairs.""" - qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs, MessageList) else "\n\n".join(qa_pairs) + if not qa_pairs: + return None + qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) if result.get("topic_cluster_name"): return 
result except Exception as e: print(f"Error processing cluster: {qa_pairs}\n{e}") - return "" + return None def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Extract user-level preferences.""" @@ -108,8 +116,8 @@ def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> L try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) - if result.get("user_preference"): return result except Exception as e: diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index f591d40b1..56259bacc 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -68,13 +68,13 @@ class UpdaterFactory(BaseUpdater): } @classmethod - def from_config(cls, config_factory: UpdaterConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseUpdater: + def from_config(cls, config_factory: UpdaterConfigFactory, llm_provider=None, embedder=None, vector_db=None, extractor=None) -> BaseUpdater: """Create a Updater instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") updater_class = cls.backend_to_class[backend] - return updater_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + return updater_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, extractor=extractor) class AssemblerFactory(BaseAssembler): """Factory class for creating Assembler instances.""" diff --git a/src/memos/memories/textual/prefer_text_memory/naive_op.py b/src/memos/memories/textual/prefer_text_memory/naive_op.py deleted file mode 100644 index 08c4da61c..000000000 --- 
a/src/memos/memories/textual/prefer_text_memory/naive_op.py +++ /dev/null @@ -1,439 +0,0 @@ -from memos.types import MessageList -from typing import List, Dict, Any, Optional -import uuid -import json -from datetime import datetime - -from memos.templates.prefer_complete_prompt import ( - NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, - NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, - NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_INFO_EXTRACT_PROMPT, - NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, - NAIVE_PREFERENCE_INTEGRATION_PROMPT -) -from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer -from memos.vec_dbs.item import VecDBItem - - -class NaiveOp: - """Naive operation.""" - def __init__(self, llm_provider=None, embedder=None, vector_db=None): - """Initialize the naive operation.""" - self.llm_provider = llm_provider - self.embedder = embedder - self.vector_db = vector_db - - def build_qa_pairs(self, chat_history: MessageList) -> List[MessageList]: - """Build QA pairs from chat history.""" - qa_pairs = [] - current_qa_pair = [] - - for message in chat_history: - role = message["role"] - - if role == "user": - # If we have a complete QA pair, save it - if len(current_qa_pair) >= 2: # At least question + answer - qa_pairs.append(current_qa_pair) - - # Start new QA pair - current_qa_pair = [message] - - elif role == "assistant": - # Add answer to current QA pair - current_qa_pair.append(message) - - # Don't forget the last QA pair if it exists and is complete - if len(current_qa_pair) >= 2: - qa_pairs.append(current_qa_pair) - - return qa_pairs - - def extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: - """Extract basic information from a QA pair (no LLM needed).""" - basic_info = { - "dialog_id": str(uuid.uuid4()), - "dialog_msgs": qa_pair, - "dialog_str": "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]), - "created_at": datetime.now().isoformat() - } - - return basic_info - - 
def extract_topic_info(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: - """Extract topic information from a QA pair.""" - # Convert qa_pair to string format - qa_pair_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]) - prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result - except Exception: - return response - - def extract_explicit_preference(self, qa_pair: MessageList) -> Optional[Dict[str, Any]]: - """Extract explicit preference from a QA pair (LLM-1).""" - # Convert qa_pair to string format - qa_pair_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in qa_pair]) - prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result - except Exception: - return response - - def generate_dialogue_vectors(self, basic_infos: List[Dict[str, Any]]) -> List[List[float]]: - """Generate embeddings for dialogue segments.""" - if not self.embedder or not basic_infos: - return [] - - texts = [info.get("dialog_str", "") for info in basic_infos] - embdeddings = self.embedder.embed(texts) - return [{"dialog_vector": embedding} for embedding in embdeddings] - - def generate_topic_vectors(self, topic_infos: List[Dict[str, Any]]) -> List[List[float]]: - """Generate embeddings for topic information.""" - if not self.embedder: - return [] - - texts = [f"{info.get('topic_name', '')} {info.get('topic_description', '')}" for info in topic_infos] - embdeddings = self.embedder.embed(texts) - return [{"topic_vector": embedding} for embedding in embdeddings] - - def concat_infos( - self, - basic_infos: List[Dict[str, Any]] = None, - explicit_preferences: List[Dict[str, Any]] = None, - topic_infos: List[Dict[str, Any]] = None, - 
dialogue_vectors: List[Dict[str, Any]] = None, - topic_vectors: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]: - """Concatenate infos - only merge if not None.""" - # Get all non-None lists - lists_to_zip = [] - for lst in [basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors]: - if lst is not None: - lists_to_zip.append(lst) - - if not lists_to_zip: - return [] - - # Use the first list to determine length - length = len(lists_to_zip[0]) - - whole_infos = [] - for i in range(length): - merged_dict = {} - - # Only merge if not None - if basic_infos is not None and i < len(basic_infos): - merged_dict.update(basic_infos[i]) - if explicit_preferences is not None and i < len(explicit_preferences): - merged_dict.update(explicit_preferences[i]) - if topic_infos is not None and i < len(topic_infos): - merged_dict.update(topic_infos[i]) - if dialogue_vectors is not None and i < len(dialogue_vectors): - merged_dict.update(dialogue_vectors[i]) - if topic_vectors is not None and i < len(topic_vectors): - merged_dict.update(topic_vectors[i]) - - whole_infos.append(merged_dict) - - return whole_infos - - def implicit_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Implicit cluster.""" - # Extract vectors for clustering - vectors = [info.get("dialog_vector") for info in whole_infos] - if not vectors: - return [] - # Perform clustering - cluster_results = clusterer.cluster(vectors) - - # Map cluster results back to original data - for cluster in cluster_results: - cluster["center_dialog"] = whole_infos[cluster["center_index"]]["dialog_msgs"] - cluster["center_dialog_str"] = whole_infos[cluster["center_index"]]["dialog_str"] - original_infos = [] - for item in cluster["items"]: - index = item["index"] - original_info = whole_infos[index] - original_infos.append({ - "dialog_id": original_info.get("dialog_id"), - "dialog_msgs": original_info.get("dialog_msgs"), - "dialog_str": 
original_info.get("dialog_str"), - "created_at": original_info.get("created_at") - }) - cluster["original_data"] = original_infos - - return cluster_results - - def topic_cluster(self, clusterer: HDBSCANClusterer, whole_infos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Topic cluster.""" - # Extract vectors for clustering - vectors = [info.get("topic_vector") for info in whole_infos] - if not vectors: - return [] - # Perform clustering - cluster_results = clusterer.cluster(vectors) - - # Map cluster results back to original data - for cluster in cluster_results: - cluster["center_dialog"] = whole_infos[cluster["center_index"]]["dialog_str"] - original_infos = [] - for item in cluster["items"]: - index = item["index"] - original_info = whole_infos[index] - original_infos.append({ - "dialog_id": original_info.get("dialog_id"), - "dialog_msgs": original_info.get("dialog_msgs"), - "dialog_str": original_info.get("dialog_str"), - "created_at": original_info.get("created_at") - }) - cluster["original_data"] = original_infos - - return cluster_results - - - def extract_implicit_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Extract implicit preferences from clusters.""" - if not clusters: - return [] - for cluster in clusters: - # Get dialogue segments in this cluster - qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) - - prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - if result.get("implicit_preference"): - cluster["implicit_preference"] = result - except Exception as e: - print(e) - cluster["implicit_preference"] = "" - - return clusters - - def extract_topic_preferences(self, clusters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Extract topic preferences from clusters.""" - if not clusters: - return [] - for cluster in clusters: - # 
Get dialogue segments in this cluster - qa_pairs = "\n".join([info["dialog_str"] for info in cluster["original_data"]]) - - prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - cluster["topic_cluster_name"] = result.get("topic_cluster_name", "") - cluster["topic_cluster_description"] = result.get("topic_cluster_description", "") - cluster["topic_preferences"] = result.get("topic_preferences", "") - except Exception as e: - print(e) - cluster["topic_cluster_name"] = "" - cluster["topic_cluster_description"] = "" - cluster["topic_preferences"] = "" - - return clusters - - def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Extract user-level preferences.""" - if not topic_preferences: - return [] - cluster_infos = [{ - "topic_cluster_name": cluster["topic_cluster_name"], - "topic_cluster_description": cluster["topic_cluster_description"], - "topic_preferences": cluster["topic_preferences"]} - for cluster in topic_preferences] - prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace("{cluster_info}", json.dumps(cluster_infos, ensure_ascii=False, indent=2)) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - - if result.get("user_preferences"): - return result - except Exception as e: - print(e) - return "" - - def judge_update_or_add(self, old_msg: MessageList, new_msg: MessageList) -> bool: - """Judge if the new message expresses the same core content as the old message.""" - # Convert messages to string format for comparison - old_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in old_msg]) - new_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in new_msg]) - - # Use the template prompt with placeholders - prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", 
old_str).replace("{new_information}", new_str) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - response = result.get("is_same", False) - return response if isinstance(response, bool) else response == "true" - except Exception as e: - print(f"Error in judge_update_or_add: {e}") - # Fallback to simple string comparison - return old_str == new_str - - def preference_integration(self, query: str, - explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]]) -> str: - """Integrate preferences.""" - explicit_prefs_str = json.dumps(explicit_prefs, ensure_ascii=False, indent=2) - implicit_prefs_str = json.dumps(implicit_prefs, ensure_ascii=False, indent=2) - topic_prefs_str = json.dumps(topic_prefs, ensure_ascii=False, indent=2) - user_prefs_str = json.dumps(user_prefs, ensure_ascii=False, indent=2) - - prompt = NAIVE_PREFERENCE_INTEGRATION_PROMPT.format( - query_preference=query, - explicit_preference=explicit_prefs_str, - implicit_preference=implicit_prefs_str, - topic_preference=topic_prefs_str, - user_preference=user_prefs_str - ) - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result["final_prompt"] - except Exception as e: - print(f"Error in preference_integration: {e}") - return "" - - - def store_preferences(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]], - user_id: str): - """Store all preferences in memory.""" - - # Convert to VecDBItem format and store in separate collections - - # Store explicit preferences - if explicit_prefs: - explicit_memories = [] - for pref in explicit_prefs: - # Create VecDBItem directly using existing embedding - vec_db_item = VecDBItem( - id=pref.get("dialog_id", ""), - 
vector=pref.get("dialog_vector", []), - payload={ - "dialog_id": pref.get("dialog_id", ""), - "dialog_msgs": pref.get("dialog_msgs", []), - "dialog_str": pref.get("dialog_str", ""), - "dialog_vector": pref.get("dialog_vector", []), - "created_at": pref.get("created_at", datetime.now().isoformat()), - "topic_name": pref.get("topic_name", ""), - "topic_description": pref.get("topic_description", ""), - "topic_vector": pref.get("topic_vector", []), - "user_id": user_id, - "preference_type": "explicit_preference" - } - ) - explicit_memories.append(vec_db_item) - - # Store in explicit_preference collection - self.vector_db.add("explicit_preference", explicit_memories) - - # Store implicit preferences - if implicit_prefs: - implicit_memories = [] - for pref in implicit_prefs: - # Create VecDBItem directly using existing embedding - vec_db_item = VecDBItem( - id=pref.get("cluster_id", ""), - vector=pref.get("center_vector", []), - payload={ - "cluster_id": pref.get("cluster_id", ""), - "center_dialog": pref.get("center_dialog", ""), - "center_vector": pref.get("center_vector", []), - "implicit_preference": pref.get("implicit_preference", ""), - "created_at": pref.get("created_at", datetime.now().isoformat()), - "user_id": user_id, - "preference_type": "implicit_preference" - } - ) - implicit_memories.append(vec_db_item) - - # Store in implicit_preference collection - self.vector_db.add("implicit_preference", implicit_memories) - - # Store topic preferences - if topic_prefs: - topic_memories = [] - for pref in topic_prefs: - # Create VecDBItem directly using existing embedding - vec_db_item = VecDBItem( - id=pref.get("cluster_id", ""), - vector=pref.get("center_vector", []), - payload={ - "cluster_id": pref.get("cluster_id", ""), - "center_dialog": pref.get("center_dialog", ""), - "center_vector": pref.get("center_vector", []), - "topic_cluster_name": pref.get("topic_cluster_name", ""), - "topic_cluster_description": pref.get("topic_cluster_description", ""), - 
"topic_preferences": pref.get("topic_preferences", ""), - "created_at": pref.get("created_at", datetime.now().isoformat()), - "user_id": user_id, - "preference_type": "topic_preference" - } - ) - topic_memories.append(vec_db_item) - - # Store in topic_preference collection - self.vector_db.add("topic_preference", topic_memories) - - # Store user preferences - if user_prefs: - user_memories = [] - for pref in user_prefs: - # Create VecDBItem with zero vector (user preferences don't need vector search) - # Use zero vector to satisfy Milvus collection dimension requirements - # Get embedding dimension from embedder config - embedding_dim = getattr(self.embedder.config, 'embedding_dims', 768) # Default to 768 if not available - zero_vector = [0.0] * embedding_dim - vec_db_item = VecDBItem( - id=user_id, - vector=zero_vector, - payload={ - "user_id": user_id, - "user_preferences": pref.get("user_preferences", ""), - "created_at": datetime.now().isoformat(), - "preference_type": "user_preference" - } - ) - user_memories.append(vec_db_item) - - # Store in user_preference collection - self.vector_db.add("user_preference", user_memories) - - def generate_memory_summary(self, explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]]) -> str: - """Generate a summary of the built memory.""" - summary = { - "memory_build_summary": { - "explicit_preferences_count": len(explicit_prefs), - "implicit_preferences_count": len(implicit_prefs), - "topic_preferences_count": len(topic_prefs), - "user_preferences_count": len(user_prefs), - "build_timestamp": datetime.now().isoformat() - } - } - - return json.dumps(summary, ensure_ascii=False, indent=2) diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index bee0891c7..b9c186b17 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ 
b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -53,7 +53,7 @@ def split_chunks(self, data: MessageList | str) -> list[MessageList] | list[str] Returns: List of MessageList chunks or list of string chunks """ - if isinstance(data, MessageList): + if isinstance(data, list): # Build QA pairs from chat history pairs = self.build_qa_pairs(data) chunks = [] diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index f4d53e4f6..be1ffcb92 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, List, Dict from datetime import datetime +import uuid import json from concurrent.futures import ThreadPoolExecutor, as_completed from memos.types import MessageList @@ -15,14 +16,6 @@ class BaseUpdater(ABC): def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None): """Initialize the updater.""" - @abstractmethod - def update(self, new_dialog: MessageList, *args, **kwargs) -> None: - """Update the dialog. - Args: - new_dialog (MessageList): The new dialog to update. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. 
- """ class NaiveUpdater(BaseUpdater): @@ -126,26 +119,26 @@ def _process_single_implicit_cluster(self, cluster_id: str, cluster_dialogs: Lis print(f"Error processing implicit cluster {cluster_id}: {e}") return {"cluster_id": cluster_id, "implicit_exract_result": None} - def _extract_implicit_preferences(self, implicit_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> List[Dict[str, Any]]: + def _extract_implicit_preferences(self, implicit_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> Dict[str, Dict[str, Any]]: """Extract implicit preferences from implicit extract inputs using thread pool.""" if not implicit_extract_inputs: - return [] + return {} - results = [] + results = {} with ThreadPoolExecutor(max_workers=min(max_workers, len(implicit_extract_inputs))) as executor: - future_to_cluster = { - executor.submit(self._process_single_implicit_cluster, cluster_id, cluster_dialogs): cluster_id + futures = [ + executor.submit(self._process_single_implicit_cluster, cluster_id, cluster_dialogs) for cluster_id, cluster_dialogs in implicit_extract_inputs.items() - } + ] - for future in as_completed(future_to_cluster): + for future in as_completed(futures): try: result = future.result() if result["implicit_exract_result"] is not None: - results.append(result) + cluster_id = result["cluster_id"] + results[cluster_id] = result["implicit_exract_result"] except Exception as e: - cluster_id = future_to_cluster[future] - print(f"Error processing implicit cluster {cluster_id}: {e}") + print(f"Error processing implicit cluster: {e}") continue return results @@ -159,33 +152,34 @@ def _process_single_topic_cluster(self, cluster_id: str, cluster_dialogs: List[s print(f"Error processing topic cluster {cluster_id}: {e}") return {"cluster_id": cluster_id, "topic_exract_result": None} - def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> List[Dict[str, Any]]: + def _extract_topic_preferences(self, 
topic_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> Dict[str, Dict[str, Any]]: """Extract topic preferences from topic extract inputs using thread pool.""" if not topic_extract_inputs: - return [] + return {} - results = [] + results = {} with ThreadPoolExecutor(max_workers=min(max_workers, len(topic_extract_inputs))) as executor: - future_to_cluster = { - executor.submit(self._process_single_topic_cluster, cluster_id, cluster_dialogs): cluster_id + futures = [ + executor.submit(self._process_single_topic_cluster, cluster_id, cluster_dialogs) for cluster_id, cluster_dialogs in topic_extract_inputs.items() - } + ] - for future in as_completed(future_to_cluster): + for future in as_completed(futures): try: result = future.result() if result["topic_exract_result"] is not None: - results.append(result) + cluster_id = result["cluster_id"] + results[cluster_id] = result["topic_exract_result"] except Exception as e: - cluster_id = future_to_cluster[future] - print(f"Error processing topic cluster {cluster_id}: {e}") + print(f"Error processing topic cluster: {e}") continue return results def _extract_user_preferences(self, topic_cluster_pref_infos: Dict[str, Any]) -> List[Dict[str, Any]]: """Extract user preferences from topic cluster info.""" - topic_cluster_pref = [info["topic_exract_result"] for info in topic_cluster_pref_infos] + # topic_prefs is a dict, so we just pass the values as a list + topic_cluster_pref = list(topic_cluster_pref_infos.values()) return self.extractor.extract_user_preferences(topic_cluster_pref) def _store_preferences(self, @@ -201,6 +195,9 @@ def _store_preferences(self, if implicit_clusters: for cluster in implicit_clusters: + if cluster.cluster_id not in implicit_cluster_prefs: + print(f"Warning: No preference found for cluster {cluster.cluster_id}, skipping...") + continue pref = implicit_cluster_prefs[cluster.cluster_id] mem = VecDBItem( id=cluster.cluster_id, @@ -223,6 +220,9 @@ def _store_preferences(self, if 
topic_clusters: for cluster in topic_clusters: + if cluster.cluster_id not in topic_cluster_prefs: + print(f"Warning: No preference found for topic cluster {cluster.cluster_id}, skipping...") + continue pref = topic_cluster_prefs[cluster.cluster_id] mem = VecDBItem( id=cluster.cluster_id, @@ -247,8 +247,8 @@ def _store_preferences(self, if user_prefs: mem = VecDBItem( - id=user_id, - vector=[0.0] * self.embedder.config.embedding_dims, + id=str(uuid.uuid4()), + vector=[0.0] * self.vector_db.config.vector_dimension, payload={ "user_id": user_id, "user_preference": user_prefs.get("user_preference", ""), diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index d385eb69e..0c95daf90 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -86,12 +86,12 @@ def get_memory(self, messages: MessageList, type: str, info: dict[str, Any]) -> """ self.extractor.extract(messages, type, info) - def update_preferences(self, new_dialog: MessageList) -> None: - """Update a memory by new dialog. - Args: - new_dialog (MessageList): The new dialog to update. + def slow_update(self) -> str: + """Perform a slow update of preferences by reconstructing all preference collections. + Returns: + str: Summary of the memory build process. """ - self.updater.update(new_dialog) + return self.updater.slow_update() def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. @@ -206,12 +206,16 @@ def get_with_collection_name(self, collection_name: str, memory_id: str) -> Text Returns: TextualMemoryItem: The memory with the given ID and collection name. 
""" - res = self.vector_db.get_by_id(collection_name, memory_id) - if res is None: - raise ValueError(f"Memory with ID {memory_id} not found in collection {collection_name}") - return TextualMemoryItem(id=res.id, - memory=res.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**res.payload)) + try: + res = self.vector_db.get_by_id(collection_name, memory_id) + if res is None: + raise ValueError(f"Memory with ID {memory_id} not found in collection {collection_name}") + return TextualMemoryItem(id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload)) + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError(f"Memory with ID {memory_id} not found in collection {collection_name}: {e}") def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """Get memories by their IDs. @@ -230,12 +234,16 @@ def get_by_ids_with_collection_name(self, collection_name: str, memory_ids: list Returns: list[TextualMemoryItem]: List of memories with the specified IDs and collection name. 
""" - res = self.vector_db.get_by_ids(collection_name, memory_ids) - if res is None: - raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}") - return [TextualMemoryItem(id=memo.id, - memory=memo.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in res] + try: + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if not res: + raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}") + return [TextualMemoryItem(id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in res] + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}") def get_all(self) -> list[TextualMemoryItem]: """Get all memories. diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index d998ad0b5..bfdbd9eed 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -17,15 +17,7 @@ # Output Format { - "explicit_preference": { - "role": "", - "style": "", - "format": "", - "length": "", - "language": "", - "safety": "", - "quality": "" - } + "explicit_preference": "Write the explicit preferences here" } # Notes @@ -66,15 +58,11 @@ # Output Format: { - "implicit_preference": [ - "preference 1", - "preference 2", - "preference 3" - ] + "implicit_preference": "Write the implicit preferences here, 1. preference 1, 2. preference 2, 3. 
preference 3" } # Notes: -- If no clear implicit preferences are found, return an empty array [] +- If no clear implicit preferences are found, return an empty string "" - Only output JSON, no explanations - Focus on meaningful patterns, not trivial observations diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 7bb61a91c..a733dd657 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -198,6 +198,7 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim collection_name=collection_name, filter=expr, batch_size=scroll_limit, + output_fields=["*"], # Include all fields including payload ) # Iterate through all batches From 9f4de02f200c22066ed24eba3a4efb1ce9e9b2fa Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 18 Sep 2025 21:02:21 +0800 Subject: [PATCH 11/62] modify user pref user id code --- src/memos/memories/textual/prefer_text_memory/updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index be1ffcb92..e09609ae7 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -247,7 +247,7 @@ def _store_preferences(self, if user_prefs: mem = VecDBItem( - id=str(uuid.uuid4()), + id=user_id, vector=[0.0] * self.vector_db.config.vector_dimension, payload={ "user_id": user_id, From 9ffd09567cea5f447d3fedb23636eed6afd06919 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 19 Sep 2025 10:17:13 +0800 Subject: [PATCH 12/62] modify bug in milvus --- src/memos/vec_dbs/milvus.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index a733dd657..62047b369 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -150,7 +150,7 @@ def get_by_id(self, collection_name: str, id: str) -> 
VecDBItem | None: return None entity = results[0] - payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} return VecDBItem( id=entity["id"], @@ -170,7 +170,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: items = [] for entity in results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append(VecDBItem( id=entity["id"], vector=entity.get("vector"), @@ -210,7 +210,7 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim # Convert batch results to VecDBItem objects for entity in batch_results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector"]} + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} all_items.append(VecDBItem( id=entity["id"], vector=entity.get("vector"), From 52dbbac6f8b4ee9e092e8f474ddea0865e5c1afc Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 19 Sep 2025 17:15:25 +0800 Subject: [PATCH 13/62] finish debug in core --- src/memos/configs/mem_cube.py | 2 +- src/memos/mem_os/core.py | 3 +- src/memos/memories/factory.py | 2 +- src/memos/memories/textual/item.py | 2 +- .../textual/prefer_text_memory/extractor.py | 7 +- .../textual/prefer_text_memory/retrievers.py | 5 + src/memos/memories/textual/preference.py | 7 +- src/memos/vec_dbs/milvus.py | 93 ++++++++++++------- 8 files changed, 77 insertions(+), 44 deletions(-) diff --git a/src/memos/configs/mem_cube.py b/src/memos/configs/mem_cube.py index afca52a26..bc3fdf557 100644 --- a/src/memos/configs/mem_cube.py +++ b/src/memos/configs/mem_cube.py @@ -97,7 +97,7 @@ def validate_para_mem(cls, para_mem: MemoryConfigFactory) -> MemoryConfigFactory @classmethod def validate_pref_mem(cls, pref_mem: MemoryConfigFactory) -> MemoryConfigFactory: """Validate the pref_mem 
field.""" - allowed_backends = ["naive_pref", "uninitialized"] + allowed_backends = ["pref_text", "uninitialized"] if pref_mem.backend not in allowed_backends: raise ConfigurationError( f"GeneralMemCubeConfig requires pref_mem backend to be one of {allowed_backends}, got '{pref_mem.backend}'" diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 9eece178b..b2db9b4dd 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -631,7 +631,6 @@ def search( "session_id": self.session_id, "chat_history": chat_history.chat_history, }, - moscube=moscube, ) result["pref_mem"].append({"cube_id": mem_cube_id, "memories": memories}) logger.info( @@ -734,7 +733,7 @@ def add( messages_list = [messages] pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, - msg_type="chat", + type="chat", info={"user_id": target_user_id, "session_id": self.session_id}) preferences = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) logger.info( diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index e53cc9569..d796b5d04 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -21,7 +21,7 @@ class MemoryFactory(BaseMemory): "naive_text": NaiveTextMemory, "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, - "preference_text": PreferenceTextMemory, + "pref_text": PreferenceTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 490fc1a72..d5902b554 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -113,7 +113,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): dialog_id: str | None = Field( default=None, description="ID of the dialog." ) - dialog_msgs: list[str] | None = Field( + dialog_msgs: list[dict[str, str]] | None = Field( default=None, description="Messages of the dialog." 
) dialog_str: str | None = Field( diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 202514ede..4f55213f8 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -180,9 +180,12 @@ def _process_single_chunk(self, chunk: MessageList, msg_type: str, info: dict[st memory = TextualMemoryItem(id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata) return memory - def extract(self, messages: MessageList, msg_type: str, info: dict[str, Any], max_workers: int = 10) -> list[TextualMemoryItem]: + def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, Any], max_workers: int = 10) -> list[TextualMemoryItem]: """Extract preference memories based on the messages using thread pool for acceleration.""" - chunks = self.splitter.split_chunks(messages) + chunks: list[MessageList] = [] + for message in messages: + chunk = self.splitter.split_chunks(message) + chunks.extend(chunk) if not chunks: return [] diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 4133dbfe1..9936ec375 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -25,6 +25,11 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" + # TODO: query rewriting and session filtering are not supported yet + if info: + info = info.copy() # Create a copy to avoid modifying the original + info.pop("chat_history", None) + info.pop("session_id", None) query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings query_embedding = query_embeddings[0] # Get the 
first (and only) embedding explicit_prefs = self.vector_db.search(query_embedding, "explicit_preference", top_k, info) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 0c95daf90..4abfa86bf 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -84,7 +84,7 @@ def get_memory(self, messages: MessageList, type: str, info: dict[str, Any]) -> type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. """ - self.extractor.extract(messages, type, info) + return self.extractor.extract(messages, type, info) def slow_update(self) -> str: """Perform a slow update of preferences by reconstructing all preference collections. @@ -139,7 +139,10 @@ def load(self, dir: str) -> None: except FileNotFoundError: logger.error(f"Memory file not found in directory: {dir}") except json.JSONDecodeError as e: - logger.error(f"Error decoding JSON from memory file: {e}") + if e.pos == 0 and "Expecting value" in str(e): + logger.warning(f"Memory file is empty or contains only whitespace: {memory_file}") + else: + logger.error(f"Error decoding JSON from memory file: {e}") except Exception as e: logger.error(f"An error occurred while loading memories: {e}") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 62047b369..6b7614dbe 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -1,5 +1,5 @@ from typing import Any - +from pymilvus import MilvusClient, DataType from memos.configs.vec_db import MilvusVecDBConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -20,14 +20,32 @@ class MilvusVecDB(BaseVecDB): ) def __init__(self, config: MilvusVecDBConfig): """Initialize the Milvus vector database and the collection.""" - from pymilvus import MilvusClient self.config = config # Create Milvus client self.client = MilvusClient(uri=self.config.uri) + self.schema = self.create_schema() + 
self.index_params = self.create_index() self.create_collection() + def create_schema(self): + """Create schema for the milvus collection.""" + schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) + schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True) + schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension) + schema.add_field(field_name="payload", datatype=DataType.JSON) + + return schema + + def create_index(self): + """Create index for the milvus collection.""" + index_params = self.client.prepare_index_params() + index_params.add_index(field_name="vector", index_type="FLAT", metric_type=self._get_metric_type()) + + return index_params + + def create_collection(self) -> None: """Create a new collection with specified parameters.""" for collection_name in self.config.collection_name: @@ -41,8 +59,8 @@ def create_collection(self) -> None: collection_name=collection_name, dimension=self.config.vector_dimension, metric_type=self._get_metric_type(), - id_type="string", # Use string ID type, align with VecDBItem id type - max_length=self.config.max_length, # Use max_length from config + schema=self.schema, + index_params=self.index_params, ) logger.info( @@ -61,8 +79,8 @@ def create_collection_by_name(self, collection_name: str) -> None: collection_name=collection_name, dimension=self.config.vector_dimension, metric_type=self._get_metric_type(), - id_type="string", # Use string ID type, align with VecDBItem id type - max_length=self.config.max_length, # Use max_length from config + schema=self.schema, + index_params=self.index_params, ) def list_collections(self) -> list[str]: @@ -100,21 +118,18 @@ def search( data=[query_vector], limit=top_k, filter=expr, + output_fields=["*"], # Return all fields ) items = [] for hit in results[0]: - # Extract payload from hit - payload = {} - for key, value in hit.get("entity", {}).items(): - if key not in ["id", 
"vector"]: - payload[key] = value + entity = hit.get("entity", {}) items.append(VecDBItem( id=str(hit["id"]), - vector=hit.get("entity", {}).get("vector"), - payload=payload, - score=float(hit["distance"]), + vector=entity.get("vector"), + payload=entity.get("payload", {}), + score=1-float(hit["distance"]), )) logger.info(f"Milvus search completed with {len(items)} results.") @@ -124,10 +139,16 @@ def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: """Convert a dictionary filter to a Milvus expression string.""" conditions = [] for field, value in filter_dict.items(): + # For JSON fields, we need to use payload["field"] syntax if isinstance(value, str): - conditions.append(f'{field} == "{value}"') + conditions.append(f"payload['{field}'] == '{value}'") + elif isinstance(value, list) and len(value) == 0: + # Skip empty lists as they cause Milvus query syntax errors + continue + elif isinstance(value, list) and len(value) > 0: + conditions.append(f"payload['{field}'] in {value}") else: - conditions.append(f"{field} == {value}") + conditions.append(f"payload['{field}'] == '{value}'") return " and ".join(conditions) def _get_metric_type(self) -> str: @@ -202,23 +223,27 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim ) # Iterate through all batches - while True: - batch_results = iterator.next() - - if not batch_results: - break + try: + while True: + batch_results = iterator.next() - # Convert batch results to VecDBItem objects - for entity in batch_results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - all_items.append(VecDBItem( - id=entity["id"], - vector=entity.get("vector"), - payload=payload, - )) - - # Close the iterator - iterator.close() + if not batch_results: + break + + # Convert batch results to VecDBItem objects + for entity in batch_results: + payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} + all_items.append(VecDBItem( + 
id=entity["id"], + vector=entity.get("vector"), + payload=payload, + )) + except Exception as e: + logger.warning(f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far.") + # 返回已经找到的项目,而不是空列表 + finally: + # Close the iterator + iterator.close() logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") return all_items @@ -264,10 +289,8 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N entity = { "id": item.id, "vector": item.vector, + "payload": item.payload if item.payload else {} } - # Add payload fields - if item.payload: - entity.update(item.payload) entities.append(entity) From 0d1a85945c4ef064db348a63857a03eccb727513 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sun, 21 Sep 2025 16:26:52 +0800 Subject: [PATCH 14/62] repair bug in milvus get_all --- src/memos/vec_dbs/milvus.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 6b7614dbe..248dfa8ec 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -232,7 +232,8 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim # Convert batch results to VecDBItem objects for entity in batch_results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} + # Extract the actual payload from Milvus entity + payload = entity.get("payload", {}) all_items.append(VecDBItem( id=entity["id"], vector=entity.get("vector"), From d706bb2cef46cdcb91614c741221e738ea29e75b Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 22 Sep 2025 14:40:44 +0800 Subject: [PATCH 15/62] add pref mem esarch time in core --- src/memos/mem_os/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index b2db9b4dd..e9e579a0f 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -622,6 +622,7 @@ def search( and 
(mem_cube.pref_mem is not None) and self.config.enable_preference_memory ): + time_start = time.time() memories = mem_cube.pref_mem.search( query, top_k=top_k if top_k else self.config.top_k, From 2fab119845f049b3d9f760a2fe4e82cb8921ae6f Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 22 Sep 2025 17:33:42 +0800 Subject: [PATCH 16/62] modify search for pref mem in product..py --- src/memos/mem_os/product.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 5899b680c..fe3f912a4 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1327,6 +1327,22 @@ def search( reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) logger.info(f"search memory list is : {reformat_memory_list}") search_result["text_mem"] = reformat_memory_list + + pref_memory_list = search_result["pref_mem"] + reformat_pref_memory_list = [] + for memory in pref_memory_list: + memories_list = [] + for data in memory["memories"]: + memories = data.model_dump() + memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["embedding"] = [] + memories["metadata"]["sources"] = [] + memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["id"] = memories["id"] + memories["metadata"]["memory"] = memories["memory"] + memories_list.append(memories) + reformat_pref_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) + search_result["pref_mem"] = reformat_pref_memory_list time_end = time.time() logger.info( f"time search: total time for user_id: {user_id} time is: {time_end - time_start}" From a906230a772f15e04ebdc2d5e536954cd7eec81e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 22 Sep 2025 19:03:56 +0800 Subject: [PATCH 17/62] add simple pref memos example --- examples/mem_os/simple_prefs_memos_product.py | 389 ++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100644 
examples/mem_os/simple_prefs_memos_product.py diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py new file mode 100644 index 000000000..fc4e9ba60 --- /dev/null +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -0,0 +1,389 @@ +import uuid +import json +import sys +import os + +# 添加项目根目录到 Python 路径 +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.insert(0, project_root) + +# 确保使用本地代码 +sys.path.insert(0, os.path.join(project_root, "src")) + +from memos.configs.mem_os import MOSConfig +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.product import MOSProduct + + +# init MOS +mos_config = { + "user_id": "root", + "chat_model": { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.1, + "remove_think_prefix": True, + "max_tokens": 4096 + } + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.0, + "remove_think_prefix": True, + "max_tokens": 8192 + } + }, + "embedder": { + "backend": "ollama", + "config": { + "model_name_or_path": "nomic-embed-text:latest" + } + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1 + } + } + } + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + "enable_preference_memory": True +} + +cube_config = { + "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig", + "user_id": "6dab9067-173a-4034-b5d9-735fc01999e3", + "cube_id": 
"6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", + "config_filename": "config.json", + "text_mem": { + "backend": "tree_text", + "config": { + "cube_id": "6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", + "extractor_llm": { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, + "remove_think_prefix": True + } + }, + "dispatcher_llm": { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, + "remove_think_prefix": True + } + }, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": "bolt://localhost:7687", + "user": "neo4j", + "password": "12345678", + "db_name": "neo4j", + "user_name": "memosneo4j", + "embedding_dimension": 768, + "use_multi_db": False, + "auto_create": False + } + }, + "embedder": { + "backend": "ollama", + "config": { + "model_name_or_path": "nomic-embed-text:latest" + } + } + } + }, + "act_mem": { + "backend": "uninitialized", + "config": {} + }, + "para_mem": { + "backend": "uninitialized", + "config": {} + }, + "pref_mem": { + "backend": "pref_text", + "config": { + "cube_id": "6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", + "extractor_llm": { + "backend": "openai", + "config": { + "model_name_or_path": "gpt-4o-mini", + "api_key": "sk-xxxxx", + "api_base": "http://xxxx/v1", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, + "remove_think_prefix": False + } + }, + "vector_db": { + "backend": "milvus", + "config": { + "collection_name": [ + "explicit_preference", + "implicit_preference", + "topic_preference", + "user_preference" + ], + "vector_dimension": 768, + "distance_metric": "cosine", + "uri": "./milvus_demo.db" + } + }, + "embedder": { + "backend": "ollama", + "config": { + "model_name_or_path": 
"nomic-embed-text:latest", + "embedding_dims": 768, + "api_base": "http://localhost:11434" + } + }, + "extractor": { + "backend": "naive", + "config": {} + }, + "adder": { + "backend": "naive", + "config": {} + }, + "retriever": { + "backend": "naive", + "config": {} + }, + "updater": { + "backend": "naive", + "config": {} + }, + "assembler": { + "backend": "naive", + "config": {} + } + } + } +} +mos_config = MOSConfig(**mos_config) +memory = MOSProduct(default_config=mos_config) + +# register user and mem cube +mem_cube_config = GeneralMemCubeConfig(**cube_config) +mem_cube = GeneralMemCube(mem_cube_config) +user_id = "6dab9067-173a-4034-b5d9-735fc01999e3" +register_result = memory.user_register(user_id=user_id, user_name="root", default_mem_cube=mem_cube) +print(f"register_result: {register_result}") + +# add memory for user +memory.add( + messages=[ + # Sports and hobbies + {"role": "user", "content": "I like playing football."}, + {"role": "assistant", "content": "I like playing football too. It's a great way to stay active and work as a team."}, + + {"role": "user", "content": "Do you enjoy watching basketball games?"}, + {"role": "assistant", "content": "Yes, I find basketball exciting to watch, especially the fast-paced action and strategic plays."}, + + {"role": "user", "content": "I'm learning to play tennis this summer."}, + {"role": "assistant", "content": "That's wonderful! Tennis is a fantastic sport that improves both physical fitness and mental focus."}, + + # Food and cooking + {"role": "user", "content": "What's your favorite type of cuisine?"}, + {"role": "assistant", "content": "I really enjoy Italian cuisine, especially pasta dishes and authentic pizza."}, + + {"role": "user", "content": "I love cooking Asian food at home."}, + {"role": "assistant", "content": "Asian cooking is amazing! 
The balance of flavors and techniques is so interesting to learn."}, + + {"role": "user", "content": "Have you ever tried making homemade bread?"}, + {"role": "assistant", "content": "Yes, there's something magical about the process of making bread from scratch. The smell is incredible!"}, + + # Travel and places + {"role": "user", "content": "I'm planning a trip to Japan next year."}, + {"role": "assistant", "content": "Japan is an incredible destination! The culture, food, and technology blend is fascinating."}, + + {"role": "user", "content": "What's the most beautiful place you've ever seen?"}, + {"role": "assistant", "content": "I find mountain landscapes particularly breathtaking, especially during sunrise or sunset."}, + + {"role": "user", "content": "I prefer beach vacations over city trips."}, + {"role": "assistant", "content": "Beach vacations are so relaxing! The sound of waves and warm sand is perfect for unwinding."}, + + # Technology and innovation + {"role": "user", "content": "What do you think about artificial intelligence?"}, + {"role": "assistant", "content": "AI is transforming many aspects of our lives. It's exciting to see how it can help solve complex problems."}, + + {"role": "user", "content": "I'm learning to code in Python."}, + {"role": "assistant", "content": "Python is a great language to start with! It's versatile and has a wonderful community."}, + + {"role": "user", "content": "Do you use any productivity apps?"}, + {"role": "assistant", "content": "I find task management apps really helpful for organizing daily activities and staying focused."}, + + # Music and entertainment + {"role": "user", "content": "What kind of music do you enjoy?"}, + {"role": "assistant", "content": "I appreciate many genres, but I particularly enjoy jazz and classical music for their complexity."}, + + {"role": "user", "content": "I play the guitar in my free time."}, + {"role": "assistant", "content": "That's fantastic! 
Playing an instrument is such a rewarding creative outlet."}, + + {"role": "user", "content": "Have you seen any good movies lately?"}, + {"role": "assistant", "content": "I enjoy films that make you think deeply about human nature and society."}, + + # Books and learning + {"role": "user", "content": "Do you read books regularly?"}, + {"role": "assistant", "content": "Yes, reading is one of my favorite ways to learn and explore new ideas."}, + + {"role": "user", "content": "I'm currently reading a science fiction novel."}, + {"role": "assistant", "content": "Science fiction is great for exploring possibilities and imagining different futures."}, + + {"role": "user", "content": "What's your favorite subject to learn about?"}, + {"role": "assistant", "content": "I'm fascinated by psychology and how the human mind works."}, + + # Nature and environment + {"role": "user", "content": "I love hiking in the mountains."}, + {"role": "assistant", "content": "Hiking is wonderful! It's a great way to connect with nature and get exercise."}, + + {"role": "user", "content": "Are you concerned about climate change?"}, + {"role": "assistant", "content": "Yes, it's one of the most important challenges of our time that requires collective action."}, + + {"role": "user", "content": "I try to recycle and reduce waste at home."}, + {"role": "assistant", "content": "That's commendable! 
Every small action contributes to environmental protection."}, + + # Work and career + {"role": "user", "content": "What do you do for work?"}, + {"role": "assistant", "content": "I work in technology, helping people solve problems and improve their daily lives."}, + + {"role": "user", "content": "I'm considering a career change."}, + {"role": "assistant", "content": "Career changes can be challenging but also very rewarding when you find the right fit."}, + + {"role": "user", "content": "Do you work from home or in an office?"}, + {"role": "assistant", "content": "I work remotely, which gives me flexibility in how I organize my day."}, + + # Health and wellness + {"role": "user", "content": "I've started doing yoga in the morning."}, + {"role": "assistant", "content": "Yoga is excellent for both physical and mental well-being. Great way to start the day!"}, + + {"role": "user", "content": "How do you manage stress?"}, + {"role": "assistant", "content": "I find meditation and regular exercise help me stay balanced and focused."}, + + {"role": "user", "content": "I'm trying to eat more vegetables."}, + {"role": "assistant", "content": "That's a great health goal! 
Vegetables provide so many essential nutrients."}, + + # Art and creativity + {"role": "user", "content": "Do you enjoy painting or drawing?"}, + {"role": "assistant", "content": "I appreciate visual arts, though I'm more of an observer than a creator."}, + + {"role": "user", "content": "I'm learning photography."}, + {"role": "assistant", "content": "Photography is a beautiful way to capture moments and express your perspective."}, + + {"role": "user", "content": "What's your favorite type of art?"}, + {"role": "assistant", "content": "I'm drawn to abstract art because it allows for personal interpretation."}, + + # Family and relationships + {"role": "user", "content": "I have two siblings."}, + {"role": "assistant", "content": "Having siblings can be wonderful - they're often our first friends and lifelong companions."}, + + {"role": "user", "content": "Do you have any pets?"}, + {"role": "assistant", "content": "I don't have pets, but I think they bring so much joy and companionship to people's lives."}, + + {"role": "user", "content": "I'm planning to visit my grandparents this weekend."}, + {"role": "assistant", "content": "That's lovely! 
Grandparents often have such interesting stories and wisdom to share."}, + + # Education and growth + {"role": "user", "content": "I'm taking an online course in data science."}, + {"role": "assistant", "content": "Data science is a fascinating field with many applications across different industries."}, + + {"role": "user", "content": "What's the best way to learn a new language?"}, + {"role": "assistant", "content": "Immersion and consistent practice are key, along with finding methods that work for your learning style."}, + + {"role": "user", "content": "I enjoy attending workshops and conferences."}, + {"role": "assistant", "content": "Workshops are great for hands-on learning and networking with like-minded people."}, + + # Hobbies and interests + {"role": "user", "content": "I collect vintage postcards."}, + {"role": "assistant", "content": "What a unique hobby! Postcards offer glimpses into different times and places."}, + + {"role": "user", "content": "Do you garden?"}, + {"role": "assistant", "content": "I appreciate gardening, though I'm more of a plant observer than a gardener myself."}, + + {"role": "user", "content": "I love solving puzzles and brain teasers."}, + {"role": "assistant", "content": "Puzzles are excellent for keeping the mind sharp and developing problem-solving skills."}, + + # Future and goals + {"role": "user", "content": "What are your goals for next year?"}, + {"role": "assistant", "content": "I aim to continue learning and growing, helping more people with their challenges."}, + + {"role": "user", "content": "I want to start my own business someday."}, + {"role": "assistant", "content": "Entrepreneurship is exciting! 
It requires courage, planning, and persistence."}, + + {"role": "user", "content": "Do you believe in setting New Year's resolutions?"}, + {"role": "assistant", "content": "I think any time is good for setting positive goals and making positive changes."}, + + # Philosophy and life + {"role": "user", "content": "What's the meaning of life to you?"}, + {"role": "assistant", "content": "I believe it's about learning, growing, and making positive connections with others."}, + + {"role": "user", "content": "I try to practice gratitude daily."}, + {"role": "assistant", "content": "Gratitude is such a powerful practice that can transform your perspective on life."}, + + {"role": "user", "content": "What's your biggest fear?"}, + {"role": "assistant", "content": "I think fear of not being able to help when someone really needs it would be challenging."}, + + # Technology and daily life + {"role": "user", "content": "How has technology changed your daily routine?"}, + {"role": "assistant", "content": "Technology has made many tasks more efficient and connected me with people worldwide."}, + + {"role": "user", "content": "I'm trying to reduce my screen time."}, + {"role": "assistant", "content": "That's a healthy goal! Finding balance with technology is important for well-being."}, + + {"role": "user", "content": "Do you use social media?"}, + {"role": "assistant", "content": "I think social media can be a great tool for connection when used mindfully."}, + + # Final thoughts + {"role": "user", "content": "This has been a great conversation!"}, + {"role": "assistant", "content": "I've really enjoyed our chat too! 
It's wonderful to explore so many different topics together."}, + ], + user_id=user_id, + mem_cube_id=register_result["default_cube_id"] + +) + +user_cube_id = memory.user_manager.get_user_cubes(user_id)[0].cube_id +print(f"user_cube_id: {user_cube_id}") +memory.mem_cubes[user_cube_id].pref_mem.slow_update() + + +retrieved_memories = memory.search(query="What do you like?", user_id=user_id) +# output text_memories: I like playing football, act_memories, para_memories +# print(f"text_memories: {retrieved_memories['text_mem']}") +# print(f"pref_memories: {retrieved_memories['pref_mem']}") +print(f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}") From a31745779275f62a79eb2c6b16299edb5f9de383 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 23 Sep 2025 14:41:30 +0800 Subject: [PATCH 18/62] modify bug in examples/mem_os/simple_prefs_memos_product.py --- examples/mem_os/simple_prefs_memos_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index fc4e9ba60..776cc1671 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -199,7 +199,7 @@ memory = MOSProduct(default_config=mos_config) # register user and mem cube -mem_cube_config = GeneralMemCubeConfig(**cube_config) +mem_cube_config = GeneralMemCubeConfig.model_validate(cube_config) mem_cube = GeneralMemCube(mem_cube_config) user_id = "6dab9067-173a-4034-b5d9-735fc01999e3" register_result = memory.user_register(user_id=user_id, user_name="root", default_mem_cube=mem_cube) From 2493b8ad9a5d0baddc9dcb4d17798629fcc7b9e2 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 23 Sep 2025 16:35:48 +0800 Subject: [PATCH 19/62] repair bug in user id related part --- examples/mem_os/simple_prefs_memos_product.py | 2 +- .../textual/prefer_text_memory/updater.py | 17 ++++++++--------- src/memos/memories/textual/preference.py | 6 
++++-- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index 776cc1671..8f2973be3 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -379,7 +379,7 @@ user_cube_id = memory.user_manager.get_user_cubes(user_id)[0].cube_id print(f"user_cube_id: {user_cube_id}") -memory.mem_cubes[user_cube_id].pref_mem.slow_update() +memory.mem_cubes[user_cube_id].pref_mem.slow_update(user_id) retrieved_memories = memory.search(query="What do you like?", user_id=user_id) diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index e09609ae7..57167a6db 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -247,7 +247,7 @@ def _store_preferences(self, if user_prefs: mem = VecDBItem( - id=user_id, + id=str(uuid.uuid4()), vector=[0.0] * self.vector_db.config.vector_dimension, payload={ "user_id": user_id, @@ -275,20 +275,19 @@ def _generate_memory_summary(self, explicit_infos: List[Dict[str, Any]], return json.dumps(summary, ensure_ascii=False, indent=2) - def slow_update(self): + def slow_update(self, user_id: str): """Retrieve all dialog info from the expicit preference collection, and reconstruct the implicit preference collection, topic collection and user preference collection. 
""" # refresh the implicit preference collection, topic collection and user preference collection - self.vector_db.delete_collection("implicit_preference") - self.vector_db.delete_collection("topic_preference") - self.vector_db.delete_collection("user_preference") - - self.vector_db.create_collection_by_name("implicit_preference") - self.vector_db.create_collection_by_name("topic_preference") - self.vector_db.create_collection_by_name("user_preference") + impl_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="implicit_preference", filter={"user_id": user_id})] + topic_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="topic_preference", filter={"user_id": user_id})] + user_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="user_preference", filter={"user_id": user_id})] + self.vector_db.delete("implicit_preference", impl_ids) + self.vector_db.delete("topic_preference", topic_ids) + self.vector_db.delete("user_preference", user_ids) all_data = self.vector_db.get_all("explicit_preference") user_id = all_data[0].payload.get("user_id", "") diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 4abfa86bf..b55d6307f 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -86,12 +86,14 @@ def get_memory(self, messages: MessageList, type: str, info: dict[str, Any]) -> """ return self.extractor.extract(messages, type, info) - def slow_update(self) -> str: + def slow_update(self, user_id: str) -> str: """Perform a slow update of preferences by reconstructing all preference collections. + Args: + user_id (str): The user ID to update preferences for. Returns: str: Summary of the memory build process. """ - return self.updater.slow_update() + return self.updater.slow_update(user_id) def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. 
From 3b84f8e0c2dfe44cb6e60f1043f1d03924810310 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 23 Sep 2025 20:57:52 +0800 Subject: [PATCH 20/62] modify search --- .../memories/textual/prefer_text_memory/retrievers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 9936ec375..f1d1d6eae 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -44,13 +44,13 @@ def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[Te user_prefs = [] explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs] + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs if pref.payload["explicit_preference"]] implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs] + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs if pref.payload["implicit_preference"]] topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs] + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs if pref.payload["topic_preferences"]] user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preferences", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs] + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs if pref.payload["user_preferences"]] return explicit_prefs + implicit_prefs + topic_prefs + user_prefs From 
8945b1072c2e1d7a682fc73956bf807183cf3688 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 24 Sep 2025 15:59:28 +0800 Subject: [PATCH 21/62] repair bug in slow update --- src/memos/memories/textual/item.py | 4 +- .../textual/prefer_text_memory/retrievers.py | 6 +- .../textual/prefer_text_memory/spliter.py | 69 +++++++--- .../textual/prefer_text_memory/updater.py | 6 +- src/memos/templates/prefer_complete_prompt.py | 120 ++++++++++-------- src/memos/vec_dbs/milvus.py | 16 ++- 6 files changed, 138 insertions(+), 83 deletions(-) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index d5902b554..c6cc67fdb 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -152,10 +152,10 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): topic_cluster_description: str | None = Field( default=None, description="Description of the topic cluster." ) - topic_preferences: str | None = Field( + topic_preference: str | None = Field( default=None, description="Preferences of the topic." ) - user_preferences: str | None = Field( + user_preference: str | None = Field( default=None, description="Preferences of the user." 
) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index f1d1d6eae..8804c9a21 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -48,9 +48,9 @@ def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[Te implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs if pref.payload["implicit_preference"]] topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs if pref.payload["topic_preferences"]] - user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preferences", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs if pref.payload["user_preferences"]] + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs if pref.payload["topic_preference"]] + user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preference", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs if pref.payload["user_preference"]] return explicit_prefs + implicit_prefs + topic_prefs + user_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index b9c186b17..e4f62d8a8 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -1,3 +1,5 @@ +import copy + from memos.types import MessageList from memos.parsers.factory import ParserFactory from memos.configs.parser import ParserConfigFactory @@ -44,6 +46,53 @@ def __init__(self, lookback_turns: int = 1, ) self.chunker = 
ChunkerFactory.from_config(chunker_config) + def _split_with_lookback(self, data: MessageList) -> list[MessageList]: + """Split the messages or files into chunks by looking back fixed number of turns. + adjacent chunk with high duplicate rate""" + # Build QA pairs from chat history + pairs = self.build_qa_pairs(data) + chunks = [] + + # Create chunks by looking back fixed number of turns + for i in range(len(pairs)): + # Calculate the start index for lookback + start_idx = max(0, i + 1 - self.lookback_turns) + # Get the chunk of pairs (as many as available, up to lookback_turns) + chunk_pairs = pairs[start_idx:i+1] + + # Flatten chunk_pairs (list[list[dict]]) to MessageList (list[dict]) + chunk_messages = [] + for pair in chunk_pairs: + chunk_messages.extend(pair) + + chunks.append(chunk_messages) + return chunks + + def _split_with_overlap(self, data: MessageList) -> list[MessageList]: + """split the messages or files into chunks with overlap. + adjacent chunk with low duplicate rate""" + chunks = [] + chunk = [] + for item in data: + # Convert dictionary to string + if "chat_time" in item: + mem = item["role"] + ": " + f"[{item['chat_time']}]: " + item["content"] + chunk.append(mem) + else: + mem = item["role"] + ":" + item["content"] + chunk.append(mem) + # 3 turns (Q + A = 6) each chunk + if len(chunk) >= 6: + chunks.append(chunk) + # overlap 1 turns (Q + A = 2) + context = copy.deepcopy(chunk[-2:]) + chunk = context + if chunk: + chunks.append(chunk) + + return chunks + + def split_chunks(self, data: MessageList | str) -> list[MessageList] | list[str]: """Split the messages or files into chunks. 
@@ -54,25 +103,7 @@ def split_chunks(self, data: MessageList | str) -> list[MessageList] | list[str] List of MessageList chunks or list of string chunks """ if isinstance(data, list): - # Build QA pairs from chat history - pairs = self.build_qa_pairs(data) - chunks = [] - - # Create chunks by looking back fixed number of turns - for i in range(len(pairs)): - # Calculate the start index for lookback - start_idx = max(0, i + 1 - self.lookback_turns) - # Get the chunk of pairs (as many as available, up to lookback_turns) - chunk_pairs = pairs[start_idx:i+1] - - # Flatten chunk_pairs (list[list[dict]]) to MessageList (list[dict]) - chunk_messages = [] - for pair in chunk_pairs: - chunk_messages.extend(pair) - - chunks.append(chunk_messages) - - return chunks + return self._split_with_lookback(data) else: # Parse and chunk the string data using pre-initialized components text = self.parser.parse(data) diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index 57167a6db..052cc0445 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -284,13 +284,13 @@ def slow_update(self, user_id: str): impl_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="implicit_preference", filter={"user_id": user_id})] topic_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="topic_preference", filter={"user_id": user_id})] user_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="user_preference", filter={"user_id": user_id})] - + self.vector_db.delete("implicit_preference", impl_ids) self.vector_db.delete("topic_preference", topic_ids) self.vector_db.delete("user_preference", user_ids) - all_data = self.vector_db.get_all("explicit_preference") - user_id = all_data[0].payload.get("user_id", "") + # get all data from explicit preference collection + all_data = 
self.vector_db.get_by_filter("explicit_preference", filter={"user_id": user_id}) informations = [item.payload for item in all_data] # Perform clustering diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index bfdbd9eed..82fa16328 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -2,75 +2,91 @@ NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT = """ -You are an information extraction assistant. You will be given a QA pair (user question Q and assistant answer A). -Please extract the user's explicit preferences from the Q and implicit preferences from the A, and output JSON strictly according to the requirements. +You are a preference extraction assistant. +Please extract the user's explicitly mentioned preferences from the following conversation. -# Extraction Rules -Explicit preferences, extract only from the user's Q. Including but not limited to: -- Role descriptions (e.g., "You are a history teacher") -- Style constraints (e.g., "humorous style", "academic style") -- Format requirements (e.g., "table", "Markdown") -- Length limitations (e.g., "within 100 words") -- Language requirements (e.g., "write in English") -- Safety compliance requirements (e.g., "don't involve sensitive content") -- Quality standards (e.g., "be concise and clear") - -# Output Format -{ - "explicit_preference": "Write the explicit preferences here" -} +Notes: +- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". +- Any clearly expressed inclination, desire, rejection, or priority counts as an explicit preference. -# Notes -If there is no information for a certain item, please leave an empty string "". -Only output JSON, no explanations. +Requirements: +1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. +2. 
Output should be a concise natural language summary, not a list or categories. +3. If there are no explicit preferences in the conversation, output an empty string "". +4. Output only the preference statements themselves, without any additional explanation. -# Conversation Content +Conversation: {qa_pair} + +Output format: +```json +{ + "explicit_preference": "A short natural language summary of the preferences, or an empty string" +} +``` +Don't output anything except the JSON. """ NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ -You are a preference extraction expert. You will be given multiple user QA pairs (questions and answers). -Your task is to extract the user's **implicit preferences** from these QA pairs. - -# Definitions: -1. **Explicit Preferences**: Constraints explicitly stated by the user in their questions, such as: - - Role requirements (e.g., "act as a teacher") - - Style preferences (e.g., "be humorous", "be formal") - - Format requirements (e.g., "use bullet points", "create a table") - - Length constraints (e.g., "keep it short", "be detailed") - - Language requirements (e.g., "write in English") - - Safety guidelines (e.g., "avoid sensitive topics") - -2. **Implicit Preferences**: Patterns that are NOT explicitly stated but consistently appear across multiple QA pairs: - - Recurring themes or topics the user frequently asks about - - Consistent communication style preferences - - Repeated information depth requirements - - Common response format expectations - - Underlying values or priorities - -# Extraction Rules: -- Focus on patterns that appear across MULTIPLE QA pairs, not single occurrences -- Look for consistent behaviors, not one-time requests -- Extract only implicit preferences, do not repeat explicit ones -- Use concise language, avoid redundant words -- Each preference should be distinct and non-overlapping +You are a preference inference assistant. 
Please extract **implicit preferences** from the following conversation +(preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). -# Output Format: +Notes: +- Implicit preferences refer to user inclinations or choices that are not directly expressed, but can be reasonably inferred from factual cues in the conversation. +- Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. + +Requirements: +1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. +3. Only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. +4. If no implicit preference can be reasonably inferred, directly return an empty string "" (do not output anything else). + +Conversation: +{qa_pair} + +Output format: +```json { - "implicit_preference": "Write the implicit preferences here, 1. preference 1, 2. preference 2, 3. preference 3" + "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string" } +``` +Don't output anything except the JSON. +""" -# Notes: -- If no clear implicit preferences are found, return an empty string "" -- Only output JSON, no explanations -- Focus on meaningful patterns, not trivial observations -# QA Pairs: +NAIVE_EXPLICIT_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference extraction and inference assistant. Please extract the user's preferences from the following conversation, including: + +1. **Explicit preferences**: Preferences that the user directly expresses, such as likes, dislikes, wants, does not want, or prioritized choices. +2. 
**Implicit preferences**: Preferences that are not explicitly stated but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices. + +Notes: +- For explicit preferences, only extract what the user directly states, do not infer. +- For implicit preferences, only infer when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +- Do not duplicate: do not treat explicit preferences as implicit preferences. + +Requirements: +1. Output in JSON format with two fields: "explicit_preferences" and "implicit_preferences". +2. Each field should be an array, with each element being a concise natural language preference statement. +3. Output only the preference statements themselves; do not include any extra explanation, reasoning, or confidence information. +4. If a type of preference does not exist, its array should be empty. + +Conversation: {qa_pair} + +Output Format: +```json +{ + "explicit_preferences": ["The user clearly likes coffee", "The user does not want to sit by the window"], + "implicit_preferences": ["The user prefers a quiet environment"] +} +``` +Don't output anything except the JSON. """ + NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT = """ You are a topic cluster analysis expert. You will be given a topic cluster containing multiple QA conversations with their preferences, topic names, and descriptions. Your task is to analyze all information in this cluster and extract three key components, then output in strict JSON format. diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 248dfa8ec..2ffdfdd53 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -111,7 +111,7 @@ def search( List of search results with distance scores and payloads. 
""" # Convert filter to Milvus expression - expr = self._dict_to_expr(filter) if filter else None + expr = self._dict_to_expr(filter) if filter else "" results = self.client.search( collection_name=collection_name, @@ -137,10 +137,16 @@ def search( def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: """Convert a dictionary filter to a Milvus expression string.""" + if not filter_dict: + return "" + conditions = [] for field, value in filter_dict.items(): + # Skip None values as they cause Milvus query syntax errors + if value is None: + continue # For JSON fields, we need to use payload["field"] syntax - if isinstance(value, str): + elif isinstance(value, str): conditions.append(f"payload['{field}'] == '{value}'") elif isinstance(value, list) and len(value) == 0: # Skip empty lists as they cause Milvus query syntax errors @@ -211,7 +217,7 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim Returns: List of items including vectors and payload that match the filter """ - expr = self._dict_to_expr(filter) if filter else None + expr = self._dict_to_expr(filter) if filter else "" all_items = [] # Use query_iterator for efficient pagination @@ -257,7 +263,7 @@ def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> i """Count items in the database, optionally with filter.""" if filter: # If there's a filter, use query method - expr = self._dict_to_expr(filter) + expr = self._dict_to_expr(filter) if filter else "" results = self.client.query( collection_name=collection_name, filter=expr, @@ -334,6 +340,8 @@ def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) - def delete(self, collection_name: str, ids: list[str]) -> None: """Delete items from the vector database.""" + if not ids: + return self.client.delete( collection_name=collection_name, ids=ids, From ee63e18bbbe2f897a29cf6dfd80aac913b38592d Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 24 Sep 2025 16:18:48 +0800 
Subject: [PATCH 22/62] modify define error in extractor -> extract_implicit_preferences --- src/memos/memories/textual/prefer_text_memory/extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 4f55213f8..aa9036af1 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -71,7 +71,7 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, A except Exception: return response - def extract_implicit_preferences(self, qa_pairs: list[MessageList] | list[str]) -> List[Dict[str, Any]]: + def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> List[Dict[str, Any]]: """Extract implicit preferences from cluster qa pairs.""" if not qa_pairs: return None From 67ed7a8788cb246bf992f09f03a30569f91b7b74 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 24 Sep 2025 16:35:32 +0800 Subject: [PATCH 23/62] reapair define error in extractor and modify split func in spliter --- .../textual/prefer_text_memory/extractor.py | 4 ++-- .../textual/prefer_text_memory/spliter.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index aa9036af1..95308b86e 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -89,7 +89,7 @@ def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Lis print(f"Error processing cluster: {qa_pairs}\n{e}") return None - def extract_topic_preferences(self, qa_pairs: list[MessageList] | list[str]) -> List[Dict[str, Any]]: + def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> List[Dict[str, Any]]: """Extract topic 
preferences from cluster qa pairs.""" if not qa_pairs: return None @@ -184,7 +184,7 @@ def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, An """Extract preference memories based on the messages using thread pool for acceleration.""" chunks: list[MessageList] = [] for message in messages: - chunk = self.splitter.split_chunks(message) + chunk = self.splitter.split_chunks(message, split_type="lookback") chunks.extend(chunk) if not chunks: return [] diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index e4f62d8a8..94aba0ec0 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -48,7 +48,8 @@ def __init__(self, lookback_turns: int = 1, def _split_with_lookback(self, data: MessageList) -> list[MessageList]: """Split the messages or files into chunks by looking back fixed number of turns. - adjacent chunk with high duplicate rate""" + adjacent chunk with high duplicate rate, + default lookback turns is 1, only current turn in chunk""" # Build QA pairs from chat history pairs = self.build_qa_pairs(data) chunks = [] @@ -81,8 +82,8 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]: else: mem = item["role"] + ":" + item["content"] chunk.append(mem) - # 3 turns (Q + A = 6) each chunk - if len(chunk) >= 6: + # 5 turns (Q + A = 10) each chunk + if len(chunk) >= 10: chunks.append(chunk) # overlap 1 turns (Q + A = 2) context = copy.deepcopy(chunk[-2:]) @@ -93,7 +94,7 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]: return chunks - def split_chunks(self, data: MessageList | str) -> list[MessageList] | list[str]: + def split_chunks(self, data: MessageList | str, **kwargs) -> list[MessageList] | list[str]: """Split the messages or files into chunks. 
Args: @@ -103,7 +104,11 @@ def split_chunks(self, data: MessageList | str) -> list[MessageList] | list[str] List of MessageList chunks or list of string chunks """ if isinstance(data, list): - return self._split_with_lookback(data) + if kwargs.get("split_type") == "lookback": + chunks = self._split_with_lookback(data) + elif kwargs.get("split_type") == "overlap": + chunks = self._split_with_overlap(data) + return chunks else: # Parse and chunk the string data using pre-initialized components text = self.parser.parse(data) From 2ab893e3fab59736ae0147fcc3d9128112840c0a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 26 Sep 2025 17:32:35 +0800 Subject: [PATCH 24/62] modify code --- .../textual/prefer_text_memory/adder.py | 40 +++-- .../textual/prefer_text_memory/assemble.py | 28 ---- .../textual/prefer_text_memory/extractor.py | 138 +++++++++--------- .../textual/prefer_text_memory/spliter.py | 8 +- .../textual/prefer_text_memory/updater.py | 6 +- src/memos/templates/prefer_complete_prompt.py | 62 ++++---- 6 files changed, 130 insertions(+), 152 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 79f9959aa..fbc0d4ddf 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -60,21 +60,35 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: vector=memory.metadata.dialog_vector, payload=payload ) - - search_results = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1) - recall = search_results[0] if search_results else None - if not recall or (recall.score is not None and recall.score < 0.5): + if memory.metadata.preference_type == "explicit_preference": + search_results = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1) + recall = search_results[0] if search_results else None + if not recall or (recall.score 
is not None and recall.score < 0.5): + self.vector_db.update("explicit_preference", memory.id, vec_db_item) + return memory.id + + old_msg_str = recall.payload.get("dialog_str", "") + new_msg_str = memory.metadata.dialog_str + is_same = self._judge_update_or_add(old_msg_str, new_msg_str) + if is_same: + self.vector_db.delete("explicit_preference", [recall.id]) self.vector_db.update("explicit_preference", memory.id, vec_db_item) return memory.id - - old_msg_str = recall.payload.get("dialog_str", "") - new_msg_str = memory.metadata.dialog_str - is_same = self._judge_update_or_add(old_msg_str, new_msg_str) - if is_same: - self.vector_db.delete("explicit_preference", [recall.id]) - self.vector_db.update("explicit_preference", memory.id, vec_db_item) - return memory.id - + + if memory.metadata.preference_type == "implicit_preference": + search_results = self.vector_db.search(memory.metadata.dialog_vector, "implicit_preference", top_k=1) + recall = search_results[0] if search_results else None + if not recall or (recall.score is not None and recall.score < 0.5): + self.vector_db.update("implicit_preference", memory.id, vec_db_item) + return memory.id + + old_msg_str = recall.payload.get("dialog_str", "") + new_msg_str = memory.metadata.dialog_str + is_same = self._judge_update_or_add(old_msg_str, new_msg_str) + if is_same: + self.vector_db.delete("implicit_preference", [recall.id]) + self.vector_db.update("implicit_preference", memory.id, vec_db_item) + return memory.id except Exception as e: print(f"Error processing memory {memory.id}: {e}") return None diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 2201dc270..0b009b754 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -31,32 +31,6 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = 
vector_db - def _preference_integration(self, query: str, - explicit_prefs: List[Dict[str, Any]], - implicit_prefs: List[Dict[str, Any]], - topic_prefs: List[Dict[str, Any]], - user_prefs: List[Dict[str, Any]]) -> str: - """Integrate preferences.""" - explicit_prefs_str = json.dumps(explicit_prefs, ensure_ascii=False, indent=2) - implicit_prefs_str = json.dumps(implicit_prefs, ensure_ascii=False, indent=2) - topic_prefs_str = json.dumps(topic_prefs, ensure_ascii=False, indent=2) - user_prefs_str = json.dumps(user_prefs, ensure_ascii=False, indent=2) - - prompt = NAIVE_PREFERENCE_INTEGRATION_PROMPT.format( - query_preference=query, - explicit_preference=explicit_prefs_str, - implicit_preference=implicit_prefs_str, - topic_preference=topic_prefs_str, - user_preference=user_prefs_str - ) - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - result = json.loads(response) - return result["final_prompt"] - except Exception as e: - print(f"Error in preference_integration: {e}") - return "" - def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: """Assemble query and memories into a single memory.""" @@ -76,8 +50,6 @@ def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strat return memories elif assemble_strategy == "semi": return f"Query: {query}\n\n In addition to the above Query, you can refer to the following preference below memories. \n\nMemories: {memories}. \n\nWhen encountering conflicts, prioritize following the query." 
- elif assemble_strategy == "full": - return self._preference_integration(query, explicit_prefs, implicit_prefs, topic_prefs, user_prefs) else: raise ValueError(f"Invalid assemble strategy: {assemble_strategy}") diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 95308b86e..c8259aa16 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -55,10 +55,11 @@ def extract_topic_info(self, qa_pair: MessageList | str) -> Dict[str, Any]: response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) return result - except Exception: - return response + except Exception as e: + print(f"Error extracting topic info: {e}, return None") + return None - def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, Any]: + def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, Any] | None: """Extract explicit preference from a QA pair.""" qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) @@ -68,10 +69,11 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, A response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) return result - except Exception: - return response + except Exception as e: + print(f"Error extracting explicit preference: {e}, return None") + return None - def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> List[Dict[str, Any]]: + def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: """Extract implicit preferences from cluster qa pairs.""" if not qa_pairs: return None @@ -82,14 +84,30 @@ def extract_implicit_preferences(self, qa_pairs: 
MessageList | list[str]) -> Lis response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) - - if result.get("implicit_preference"): - return result + return result except Exception as e: - print(f"Error processing cluster: {qa_pairs}\n{e}") + print(f"Error extracting implicit preferences: {e}, return None") return None + + # def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: + # """Extract implicit preferences from cluster qa pairs.""" + # if not qa_pairs: + # return None + # qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) + # prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) + + # try: + # response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + # response = response.strip().replace("```json", "").replace("```", "").strip() + # result = json.loads(response) + + # if result.get("implicit_preference"): + # return result + # except Exception as e: + # print(f"Error processing cluster: {qa_pairs}\n{e}") + # return None - def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> List[Dict[str, Any]]: + def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: """Extract topic preferences from cluster qa pairs.""" if not qa_pairs: return None @@ -104,10 +122,10 @@ def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> List[D if result.get("topic_cluster_name"): return result except Exception as e: - print(f"Error processing cluster: {qa_pairs}\n{e}") + print(f"Error extracting topic preferences: {qa_pairs}\n{e}, return None") return None - def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def extract_user_preferences(self, topic_preferences: 
List[Dict[str, Any]]) -> Dict[str, Any] | None: """Extract user-level preferences.""" if not topic_preferences: return [] @@ -121,54 +139,16 @@ def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> L if result.get("user_preference"): return result except Exception as e: - print(f"Error processing user preferences: {topic_preferences}\n{e}") + print(f"Error processing user preferences: {topic_preferences}\n{e}, return None") return "" - def concat_infos( - self, - basic_infos: List[Dict[str, Any]] = None, - explicit_preferences: List[Dict[str, Any]] = None, - topic_infos: List[Dict[str, Any]] = None, - dialogue_vectors: List[Dict[str, Any]] = None, - topic_vectors: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]: - """Concatenate infos - only merge if not None.""" - # Get all non-None lists - list_to_concat = [] - for lst in [basic_infos, explicit_preferences, topic_infos, dialogue_vectors, topic_vectors]: - if lst is not None: - list_to_concat.append(lst) - - if not list_to_concat: - return [] - - # Use the first list to determine length - length = len(list_to_concat[0]) - - whole_infos = [] - for i in range(length): - merged_dict = {} - - # Only merge if not None - if basic_infos is not None and i < len(basic_infos): - merged_dict.update(basic_infos[i]) - if explicit_preferences is not None and i < len(explicit_preferences): - merged_dict.update(explicit_preferences[i]) - if topic_infos is not None and i < len(topic_infos): - merged_dict.update(topic_infos[i]) - if dialogue_vectors is not None and i < len(dialogue_vectors): - merged_dict.update(dialogue_vectors[i]) - if topic_vectors is not None and i < len(topic_vectors): - merged_dict.update(topic_vectors[i]) - - whole_infos.append(merged_dict) - - return whole_infos - - def _process_single_chunk(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem: + def _process_single_chunk_explicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> 
def _process_single_chunk_implicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem | None:
    """Build one implicit-preference memory item from an overlap chunk.

    Returns None when the LLM extracts no implicit preference, so the
    caller can drop the chunk silently.
    """
    basic_info = self.extract_basic_info(chunk)
    implicit_pref = self.extract_implicit_preferences(chunk)
    if not implicit_pref:
        return None

    vector_info = {
        # Embed the rendered dialog so the memory is retrievable by similarity.
        "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0],
    }
    extract_info = {**basic_info, **implicit_pref, **vector_info, **info}

    metadata = PreferenceTextualMemoryMetadata(
        type=msg_type, preference_type="implicit_preference", **extract_info
    )
    return TextualMemoryItem(
        id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata
    )

def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, Any], max_workers: int = 10) -> list[TextualMemoryItem]:
    """Extract explicit and implicit preference memories from messages.

    Splits each message list twice — "lookback" chunks feed explicit
    preference extraction, "overlap" chunks feed implicit preference
    extraction — then processes all chunks concurrently.

    Args:
        messages: Conversations to mine for preferences.
        msg_type: Memory type tag stored on each produced item.
        info: Extra metadata merged into every item's payload.
        max_workers: Upper bound on the thread pool size.

    Returns:
        Successfully extracted memory items; chunks that yield no
        preference or raise are skipped (errors are logged to stdout).
    """
    def collect(split_type: str) -> list[MessageList]:
        # Shared chunking loop; was duplicated verbatim for each split type.
        collected: list[MessageList] = []
        for message in messages:
            collected.extend(self.splitter.split_chunks(message, split_type=split_type))
        return collected

    chunks_for_explicit = collect("lookback")
    chunks_for_implicit = collect("overlap")
    total = len(chunks_for_explicit) + len(chunks_for_implicit)
    # Bug fix: previously an empty explicit-chunk list returned [] before the
    # implicit chunks were even computed, silently dropping implicit work.
    if total == 0:
        return []

    memories: list[TextualMemoryItem] = []
    # max(1, ...) guards ThreadPoolExecutor against a non-positive worker count.
    with ThreadPoolExecutor(max_workers=max(1, min(max_workers, total))) as executor:
        futures = {
            executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ("explicit", chunk)
            for chunk in chunks_for_explicit
        }
        futures.update({
            executor.submit(self._process_single_chunk_implicit, chunk, msg_type, info): ("implicit", chunk)
            for chunk in chunks_for_implicit
        })

        for future in as_completed(futures):
            try:
                memory = future.result()
                if memory:
                    memories.append(memory)
            except Exception as e:
                task_type, chunk = futures[future]
                print(f"Error processing {task_type} chunk: {chunk}\n{e}")
                continue
    return memories
a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -176,7 +176,7 @@ def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]], return results - def _extract_user_preferences(self, topic_cluster_pref_infos: Dict[str, Any]) -> List[Dict[str, Any]]: + def _extract_user_preferences(self, topic_cluster_pref_infos: Dict[str, Any]) -> Dict[str, Any] | None: """Extract user preferences from topic cluster info.""" # topic_prefs is a dict, so we just pass the values as a list topic_cluster_pref = list(topic_cluster_pref_infos.values()) @@ -261,14 +261,14 @@ def _store_preferences(self, def _generate_memory_summary(self, explicit_infos: List[Dict[str, Any]], implicit_infos: List[Dict[str, Any]], topic_infos: List[Dict[str, Any]], - user_infos: List[Dict[str, Any]]) -> str: + user_infos: Dict[str, Any]) -> str: """Generate a summary of the built memory.""" summary = { "memory_build_summary": { "explicit_preferences_count": len(explicit_infos), "implicit_preferences_count": len(implicit_infos), "topic_preferences_count": len(topic_infos), - "user_preferences_count": len(user_infos), + "user_preferences_count": 1 if user_infos else 0, "build_timestamp": datetime.now().isoformat() } } diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 82fa16328..8b439c415 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -38,9 +38,10 @@ Requirements: 1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. -2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. -3. Only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. -4. 
If no implicit preference can be reasonably inferred, directly return an empty string "" (do not output anything else). +2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. +3. Inferred implicit preferences must not conflict with explicit preferences. +4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. +5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). Conversation: {qa_pair} @@ -48,7 +49,8 @@ Output format: ```json { - "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string" + "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string", + "reasoning": "Briefly explain the reasoning process for the implicit preference" } ``` Don't output anything except the JSON. @@ -228,43 +230,33 @@ NAIVE_PREFERENCE_INTEGRATION_PROMPT = """ -You are a preference integration expert. Your task is to integrate preference constraints from different sources and generate a final prompt that can be directly input to an LLM. Please note the following information sources and their priority levels (from high to low): +You are a memory integration expert. Your task is to integrate various memories from the user and generate a final prompt that can be directly input into a large language model (LLM). The generated prompt should include all relevant memories, but you must filter them based on the current query. -Sources: -1. Current query preferences: Constraints explicitly stated in the current user question -2. Related dialogue preferences: Preference references from Q&A pairs related to the current query -3. 
Related topic preferences: Preference references from topics related to the current query -4. User preference: Common preference references from the user's historical conversations +1. Requirements and Rules: +- Filter Relevant Memories: Only retain memories that are related to the current query, and discard irrelevant ones. -Priority: Current query preferences > Related dialogue preferences > Related topic preferences > User preferences > Implicit preferences +2. Handle Preference Conflicts: +- Explicit preference memories take precedence over implicit preference memories. +- If conflicts arise, delete lower-priority memories according to the hierarchy to ensure consistency. -Requirements: -- If conflicts exist between preferences, strictly follow the priority order, with higher priority preferences overriding lower priority ones. -- Generate a comprehensive prompt that includes all integrated preferences and constraints. -- The final prompt should be ready to be input directly to an LLM for answering the user's query. -- Keep the integrated preferences specific and actionable. -- Ensure the prompt is clear, structured, and contains all necessary context and constraints. - -Please generate the final integrated prompt based on the input, strictly resolve conflicts by priority, and output in JSON format as follows: -{{ - "final_prompt": "Complete prompt ready for LLM input, including query, context, and all integrated preferences", - "conflict_handling": ["Conflict resolution explanation 1", "Conflict resolution explanation 2", "..."], - "preference_summary": "Summary of all integrated preferences and constraints" -}} - -# Current query -{query_preference} +3. Generate Final Integrated Prompt: +- The prompt should include all filtered, conflict-free memories. +- Provide structured, clear, specific, and actionable context. +- The prompt must be directly usable by the LLM to answer the user's query. -# Related dialogue preferences -{explicit_preference} +4. 
Priority Rules: +Explicit memories > Implicit memories -# Implicit preferences -{implicit_preference} +Please output in the following JSON formatm, don't output anything else: -# Related topic preferences -{topic_preference} +{ + "final_prompt": "The complete prompt containing filtered, integrated, and conflict-free memories, ready for LLM input", + "explanation": "Briefly explain the reasoning process for the final prompt and process of filtering and integrating memories" +} -# User preferences -{user_preference} +Query: +{query} +Memories: +{memories} """ \ No newline at end of file From 8ce0ca3015a8f05322789d2ba746c98e8f2e9b3a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 26 Sep 2025 17:45:53 +0800 Subject: [PATCH 25/62] modify adder --- .../textual/prefer_text_memory/adder.py | 50 ++++++++----------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index fbc0d4ddf..c69d1d94f 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -59,36 +59,30 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: id=memory.id, vector=memory.metadata.dialog_vector, payload=payload - ) - if memory.metadata.preference_type == "explicit_preference": - search_results = self.vector_db.search(memory.metadata.dialog_vector, "explicit_preference", top_k=1) - recall = search_results[0] if search_results else None - if not recall or (recall.score is not None and recall.score < 0.5): - self.vector_db.update("explicit_preference", memory.id, vec_db_item) - return memory.id - - old_msg_str = recall.payload.get("dialog_str", "") - new_msg_str = memory.metadata.dialog_str - is_same = self._judge_update_or_add(old_msg_str, new_msg_str) - if is_same: - self.vector_db.delete("explicit_preference", [recall.id]) - self.vector_db.update("explicit_preference", memory.id, 
vec_db_item) - return memory.id + ) + pref_type_collection_map = { + "explicit_preference": "explicit_preference", + "implicit_preference": "implicit_preference", + "topic_preference": "topic_preference", + "user_preference": "user_preference" + } + preference_type = memory.metadata.preference_type + collection_name = pref_type_collection_map[preference_type] - if memory.metadata.preference_type == "implicit_preference": - search_results = self.vector_db.search(memory.metadata.dialog_vector, "implicit_preference", top_k=1) - recall = search_results[0] if search_results else None - if not recall or (recall.score is not None and recall.score < 0.5): - self.vector_db.update("implicit_preference", memory.id, vec_db_item) - return memory.id - - old_msg_str = recall.payload.get("dialog_str", "") - new_msg_str = memory.metadata.dialog_str - is_same = self._judge_update_or_add(old_msg_str, new_msg_str) - if is_same: - self.vector_db.delete("implicit_preference", [recall.id]) - self.vector_db.update("implicit_preference", memory.id, vec_db_item) + search_results = self.vector_db.search(memory.metadata.dialog_vector, collection_name, top_k=1) + recall = search_results[0] if search_results else None + if not recall or (recall.score is not None and recall.score < 0.5): + self.vector_db.update(collection_name, memory.id, vec_db_item) return memory.id + + old_msg_str = recall.payload.get("dialog_str", "") + new_msg_str = memory.metadata.dialog_str + is_same = self._judge_update_or_add(old_msg_str, new_msg_str) + if is_same: + self.vector_db.delete(collection_name, [recall.id]) + self.vector_db.update(collection_name, memory.id, vec_db_item) + return memory.id + except Exception as e: print(f"Error processing memory {memory.id}: {e}") return None From b89e28a49bd02899353aecb2ffaf9390a029e393 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sun, 28 Sep 2025 14:16:39 +0800 Subject: [PATCH 26/62] optimize the code --- src/memos/configs/vec_db.py | 2 + 
.../textual/prefer_text_memory/assemble.py | 90 +++++++++++++++---- .../textual/prefer_text_memory/extractor.py | 18 ---- .../textual/prefer_text_memory/retrievers.py | 31 ++++--- .../textual/prefer_text_memory/spliter.py | 4 - src/memos/vec_dbs/milvus.py | 2 +- 6 files changed, 99 insertions(+), 48 deletions(-) diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index 457a09a21..17611f41e 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -45,6 +45,8 @@ class MilvusVecDBConfig(BaseVecDBConfig): uri: str = Field(..., description="URI for Milvus connection") collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") max_length: int = Field(default=65535, description="Maximum length for string fields (varChar type)") + user_name: str = Field(default="", description="User name for Milvus connection") + password: str = Field(default="", description="Password for Milvus connection") class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 0b009b754..e16274737 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -11,7 +11,7 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the assembler.""" @abstractmethod - def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: + def get_instruction(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: """Assemble query and memories into a single memory. Args: query: The query to assemble. 
@@ -31,25 +31,85 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def assemble(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: + def get_instruction(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: """Assemble query and memories into a single memory.""" - explicit_prefs = [{"dialog_str": memory.metadata.dialog_str, - "explicit_preference": memory.metadata.explicit_preference} - for memory in memories if memory.metadata.preference_type == "explicit_preference"] - implicit_prefs = [{"center_dialog_str": memory.metadata.center_dialog, - "implicit_preference": memory.metadata.implicit_preference} - for memory in memories if memory.metadata.preference_type == "implicit_preference"] - topic_prefs = [{"center_dialog_str": memory.metadata.center_dialog, - "topic_preferences": memory.metadata.topic_preferences} - for memory in memories if memory.metadata.preference_type == "topic_preference"] - user_prefs = [{"user_preferences": memory.metadata.user_preferences} - for memory in memories if memory.metadata.preference_type == "user_preference"] + # Initialize all preference lists + textual_mems = [] + explicit_prefs = [] + implicit_prefs = [] + topic_prefs = [] + user_prefs = [] + + # Single loop to categorize all memories by preference type + for memory in memories: + if memory.metadata.preference_type == "explicit_preference": + explicit_prefs.append({ + "dialog_str": memory.metadata.dialog_str, + "explicit_preference": memory.metadata.explicit_preference + }) + elif memory.metadata.preference_type == "implicit_preference": + implicit_prefs.append({ + "dialog_str": memory.metadata.center_dialog, + "implicit_preference": memory.metadata.implicit_preference + }) + elif memory.metadata.preference_type == "topic_preference": + topic_prefs.append({ + "center_dialog_str": memory.metadata.center_dialog, + 
"topic_preferences": memory.metadata.topic_preferences + }) + elif memory.metadata.preference_type == "user_preference": + user_prefs.append({ + "user_preferences": memory.metadata.user_preferences + }) + else: + textual_mems.append(memory.memory) + # Build memories string with different titles for different types + memories_parts = [] + if textual_mems: + memories_parts.append("## Textual Memories:") + for i, mem in enumerate(textual_mems, 1): + memories_parts.append(f"{i}. {mem}") + if explicit_prefs: + memories_parts.append("## Explicit Preferences:") + for i, pref in enumerate(explicit_prefs, 1): + memories_parts.append(f"{i}. {pref['dialog_str']}") + + if implicit_prefs: + memories_parts.append("\n## Implicit Preferences:") + for i, pref in enumerate(implicit_prefs, 1): + memories_parts.append(f"{i}. {pref['dialog_str']}") + + if topic_prefs: + memories_parts.append("\n## Topic Preferences:") + for i, pref in enumerate(topic_prefs, 1): + memories_parts.append(f"{i}. {pref['center_dialog_str']}") + + if user_prefs: + memories_parts.append("\n## User Preferences:") + for i, pref in enumerate(user_prefs, 1): + memories_parts.append(f"{i}. {pref['user_preferences']}") + + memories_str = "\n".join(memories_parts) + + system_prompt = ( + "You are a knowledgeable and helpful AI assistant. " + "You have access to conversation memories that help you provide more personalized responses. " + "Use the memories to understand the user's context, preferences, and past interactions. " + "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." + f"\n\n## Memories:\n{memories_str}" + ) + if assemble_strategy == "raw": - return memories + return system_prompt.replace("{memories}", memories_str) elif assemble_strategy == "semi": - return f"Query: {query}\n\n In addition to the above Query, you can refer to the following preference below memories. \n\nMemories: {memories}. 
\n\nWhen encountering conflicts, prioritize following the query." + return ( + system_prompt + + ("Note: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " + \ + "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts." + \ + "When encountering preference conflicts, the priority is: explicit preferences > implicit preferences > textual memories.") + ).replace("{memories}", memories_str) else: raise ValueError(f"Invalid assemble strategy: {assemble_strategy}") diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index c8259aa16..16016f2b8 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -88,24 +88,6 @@ def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dic except Exception as e: print(f"Error extracting implicit preferences: {e}, return None") return None - - # def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: - # """Extract implicit preferences from cluster qa pairs.""" - # if not qa_pairs: - # return None - # qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) - # prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) - - # try: - # response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - # response = response.strip().replace("```json", "").replace("```", "").strip() - # result = json.loads(response) - - # if result.get("implicit_preference"): - # return result - # except Exception as e: - # print(f"Error processing cluster: {qa_pairs}\n{e}") - # return None def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: 
"""Extract topic preferences from cluster qa pairs.""" diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 8804c9a21..8245396d3 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, Dict +from concurrent.futures import ThreadPoolExecutor, as_completed from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata @@ -32,20 +33,30 @@ def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[Te info.pop("session_id", None) query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings query_embedding = query_embeddings[0] # Get the first (and only) embedding - explicit_prefs = self.vector_db.search(query_embedding, "explicit_preference", top_k, info) - implicit_prefs = self.vector_db.search(query_embedding, "implicit_preference", top_k, info) - topic_prefs = self.vector_db.search(query_embedding, "topic_preference", top_k, info) - # Get user preferences directly (no vector search needed since there's only one per user) - user_id = info.get("user_id") if info else None - if user_id: - user_prefs = self.vector_db.get_by_filter("user_preference", {"user_id": user_id}) - else: - user_prefs = [] + # Use thread pool to parallelize the searches + with ThreadPoolExecutor(max_workers=2) as executor: + # Submit all search tasks + future_explicit = executor.submit(self.vector_db.search, query_embedding, "explicit_preference", top_k, info) + future_implicit = executor.submit(self.vector_db.search, query_embedding, "implicit_preference", top_k, info) + future_topic = executor.submit(self.vector_db.search, query_embedding, "topic_preference", top_k, info) + + # Get user preferences directly (no vector search needed since there's only one per user) + user_id = 
info.get("user_id") if info else None + if user_id: + future_user = executor.submit(self.vector_db.get_by_filter, "user_preference", {"user_id": user_id}) + else: + future_user = None + + # Wait for all results + explicit_prefs = future_explicit.result() + implicit_prefs = future_implicit.result() + topic_prefs = future_topic.result() + user_prefs = future_user.result() if future_user else [] explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs if pref.payload["explicit_preference"]] - implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), + implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs if pref.payload["implicit_preference"]] topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs if pref.payload["topic_preference"]] diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index c0466718a..8b5982c2d 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -127,7 +127,3 @@ def build_qa_pairs(self, chat_history: MessageList) -> list[MessageList]: current_qa_pair = [] # reset return qa_pairs - - def recursive_split_merge(): - pass - \ No newline at end of file diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 2ffdfdd53..e612b94b1 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -24,7 +24,7 @@ def __init__(self, config: MilvusVecDBConfig): self.config = config # Create Milvus client - self.client = MilvusClient(uri=self.config.uri) + self.client = 
MilvusClient(uri=self.config.uri, user=self.config.user_name, password=self.config.password) self.schema = self.create_schema() self.index_params = self.create_index() self.create_collection() From d77a0d188aaa1b20031d58a615b53033f03b954f Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sun, 28 Sep 2025 20:22:07 +0800 Subject: [PATCH 27/62] repair bug in adder and extractor --- .../memories/textual/prefer_text_memory/adder.py | 2 +- .../textual/prefer_text_memory/extractor.py | 13 ++++++++----- .../memories/textual/prefer_text_memory/utils.py | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index c69d1d94f..919e6dd0e 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -87,7 +87,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: print(f"Error processing memory {memory.id}: {e}") return None - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], max_workers: int = 10, *args, **kwargs) -> list[str]: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], max_workers: int = 8, *args, **kwargs) -> list[str]: """Add the instruct preference memories using thread pool for acceleration.""" if not memories: return [] diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 16016f2b8..7ede2e45b 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -127,16 +127,17 @@ def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> D def _process_single_chunk_explicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem | None: """Process a single chunk and return a TextualMemoryItem.""" basic_info = 
self.extract_basic_info(chunk) - topic_info = self.extract_topic_info(chunk) - explicit_pref = self.extract_explicit_preference(chunk) + if not basic_info["dialog_str"]: + return None + + explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) if not explicit_pref: return None vector_info = { "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0], - "topic_vector": self.embedder.embed([topic_info["topic_name"] + topic_info["topic_description"]])[0] } - extract_info = {**basic_info, **topic_info, **explicit_pref, **vector_info, **info} + extract_info = {**basic_info, **explicit_pref, **vector_info, **info} metadata = PreferenceTextualMemoryMetadata(type=msg_type, preference_type="explicit_preference", **extract_info) memory = TextualMemoryItem(id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata) @@ -144,7 +145,9 @@ def _process_single_chunk_explicit(self, chunk: MessageList, msg_type: str, info def _process_single_chunk_implicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem | None: basic_info = self.extract_basic_info(chunk) - implicit_pref = self.extract_implicit_preferences(chunk) + if not basic_info["dialog_str"]: + return None + implicit_pref = self.extract_implicit_preferences(basic_info["dialog_str"]) if not implicit_pref: return None diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index 380e5a890..7c4feced2 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -7,8 +7,8 @@ def convert_messages_to_string(messages: MessageList) -> str: message_text = "" for message in messages: if message["role"] == "user": - message_text += f"Query: {message['content']}\n" + message_text += f"Query: {message['content']}\n" if message["content"].strip() else "" elif message["role"] == "assistant": - message_text += f"Answer: 
{message['content']}\n" + message_text += f"Answer: {message['content']}\n" if message["content"].strip() else "" message_text = message_text.strip() return message_text \ No newline at end of file From 21ede589e3298b65097a9fc51f165205ac04449e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 29 Sep 2025 20:09:14 +0800 Subject: [PATCH 28/62] finish make test and make pre-commit --- docs/openapi.json | 8 +- examples/mem_os/simple_prefs_memos_product.py | 594 +++++++++--------- src/memos/configs/mem_cube.py | 2 +- src/memos/configs/memory.py | 10 +- src/memos/configs/vec_db.py | 5 +- src/memos/mem_cube/general.py | 10 +- src/memos/mem_os/core.py | 9 +- src/memos/mem_os/product.py | 4 +- src/memos/memories/factory.py | 2 +- src/memos/memories/textual/item.py | 58 +- .../textual/prefer_text_memory/adder.py | 47 +- .../textual/prefer_text_memory/assemble.py | 74 ++- .../textual/prefer_text_memory/clustering.py | 118 ++-- .../textual/prefer_text_memory/config.py | 18 +- .../textual/prefer_text_memory/extractor.py | 136 ++-- .../textual/prefer_text_memory/factory.py | 80 ++- .../textual/prefer_text_memory/retrievers.py | 91 ++- .../textual/prefer_text_memory/spliter.py | 53 +- .../textual/prefer_text_memory/strategies.py | 18 - .../textual/prefer_text_memory/updater.py | 193 +++--- .../textual/prefer_text_memory/utils.py | 3 +- src/memos/memories/textual/preference.py | 127 ++-- src/memos/templates/prefer_complete_prompt.py | 48 +- src/memos/vec_dbs/factory.py | 2 +- src/memos/vec_dbs/milvus.py | 114 ++-- tests/configs/test_mem_cube.py | 2 +- 26 files changed, 1026 insertions(+), 800 deletions(-) delete mode 100644 src/memos/memories/textual/prefer_text_memory/strategies.py diff --git a/docs/openapi.json b/docs/openapi.json index 5a3471ac0..ee2ff1368 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -884,7 +884,7 @@ "type": "string", "title": "Session Id", "description": "Session ID for the MOS. 
This is used to distinguish between different dialogue", - "default": "0ce84b9c-0615-4b9d-83dd-fba50537d5d3" + "default": "41bb5e18-252d-4948-918c-07d82aa47086" }, "chat_model": { "$ref": "#/components/schemas/LLMConfigFactory", @@ -939,6 +939,12 @@ "description": "Enable parametric memory for the MemChat", "default": false }, + "enable_preference_memory": { + "type": "boolean", + "title": "Enable Preference Memory", + "description": "Enable preference memory for the MemChat", + "default": false + }, "enable_mem_scheduler": { "type": "boolean", "title": "Enable Mem Scheduler", diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index 8f2973be3..9d0722223 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -1,25 +1,11 @@ -import uuid -import json -import sys -import os - -# 添加项目根目录到 Python 路径 -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) -sys.path.insert(0, project_root) - -# 确保使用本地代码 -sys.path.insert(0, os.path.join(project_root, "src")) - -from memos.configs.mem_os import MOSConfig from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube from memos.mem_os.product import MOSProduct -# init MOS -mos_config = { - "user_id": "root", - "chat_model": { +def get_config(user_id: str): + llm_config = { "backend": "openai", "config": { "model_name_or_path": "gpt-4o-mini", @@ -27,182 +13,117 @@ "api_base": "http://xxxx/v1", "temperature": 0.1, "remove_think_prefix": True, - "max_tokens": 4096 - } - }, - "mem_reader": { - "backend": "simple_struct", - "config": { - "llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "api_key": "sk-xxxxx", - "api_base": "http://xxxx/v1", - "temperature": 0.0, - "remove_think_prefix": True, - "max_tokens": 8192 - } - }, - "embedder": { - "backend": "ollama", - 
"config": { - "model_name_or_path": "nomic-embed-text:latest" - } - }, - "chunker": { - "backend": "sentence", - "config": { - "tokenizer_or_token_counter": "gpt2", - "chunk_size": 512, - "chunk_overlap": 128, - "min_sentences_per_chunk": 1 - } - } - } - }, - "max_turns_window": 20, - "top_k": 5, - "enable_textual_memory": True, - "enable_activation_memory": False, - "enable_parametric_memory": False, - "enable_preference_memory": True -} + "max_tokens": 4096, + }, + } -cube_config = { - "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig", - "user_id": "6dab9067-173a-4034-b5d9-735fc01999e3", - "cube_id": "6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", - "config_filename": "config.json", - "text_mem": { - "backend": "tree_text", - "config": { - "cube_id": "6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", - "extractor_llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "api_key": "sk-xxxxx", - "api_base": "http://xxxx/v1", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "remove_think_prefix": True - } - }, - "dispatcher_llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "api_key": "sk-xxxxx", - "api_base": "http://xxxx/v1", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "remove_think_prefix": True - } - }, - "graph_db": { - "backend": "neo4j", - "config": { - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", - "db_name": "neo4j", - "user_name": "memosneo4j", - "embedding_dimension": 768, - "use_multi_db": False, - "auto_create": False - } - }, - "embedder": { - "backend": "ollama", - "config": { - "model_name_or_path": "nomic-embed-text:latest" - } - } - } - }, - "act_mem": { - "backend": "uninitialized", - "config": {} - }, - "para_mem": { - "backend": "uninitialized", - "config": {} - }, - "pref_mem": { - "backend": "pref_text", - "config": { - "cube_id": "6dab9067-173a-4034-b5d9-735fc01999e3/mem_cube", 
- "extractor_llm": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-4o-mini", - "api_key": "sk-xxxxx", - "api_base": "http://xxxx/v1", - "temperature": 0.8, - "max_tokens": 1024, - "top_p": 0.9, - "top_k": 50, - "remove_think_prefix": False - } - }, - "vector_db": { - "backend": "milvus", - "config": { - "collection_name": [ - "explicit_preference", - "implicit_preference", - "topic_preference", - "user_preference" - ], - "vector_dimension": 768, - "distance_metric": "cosine", - "uri": "./milvus_demo.db" - } - }, - "embedder": { - "backend": "ollama", - "config": { - "model_name_or_path": "nomic-embed-text:latest", - "embedding_dims": 768, - "api_base": "http://localhost:11434" - } - }, - "extractor": { - "backend": "naive", - "config": {} - }, - "adder": { - "backend": "naive", - "config": {} + embedder_config = { + "backend": "ollama", + "config": {"model_name_or_path": "nomic-embed-text:latest"}, + } + + # init MOS + mos_config = { + "user_id": user_id, + "chat_model": llm_config, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": llm_config, + "embedder": embedder_config, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, }, - "retriever": { - "backend": "naive", - "config": {} + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + "enable_preference_memory": True, + } + + cube_config = { + "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig", + "user_id": user_id, + "cube_id": f"{user_id}/mem_cube", + "config_filename": "config.json", + "text_mem": { + "backend": "tree_text", + "config": { + "cube_id": f"{user_id}/mem_cube", + "extractor_llm": llm_config, + "dispatcher_llm": llm_config, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": "bolt://localhost:7687", + "user": 
"neo4j", + "password": "12345678", + "db_name": "neo4j", + "user_name": "memosneo4j", + "embedding_dimension": 768, + "use_multi_db": False, + "auto_create": False, + }, + }, + "embedder": embedder_config, }, - "updater": { - "backend": "naive", - "config": {} + }, + "act_mem": {"backend": "uninitialized", "config": {}}, + "para_mem": {"backend": "uninitialized", "config": {}}, + "pref_mem": { + "backend": "pref_text", + "config": { + "cube_id": f"{user_id}/mem_cube", + "extractor_llm": llm_config, + "vector_db": { + "backend": "milvus", + "config": { + "collection_name": [ + "explicit_preference", + "implicit_preference", + "topic_preference", + "user_preference", + ], + "vector_dimension": 768, + "distance_metric": "cosine", + "uri": "./milvus_demo.db", + }, + }, + "embedder": embedder_config, + "extractor": {"backend": "naive", "config": {}}, + "adder": {"backend": "naive", "config": {}}, + "retriever": {"backend": "naive", "config": {}}, + "updater": {"backend": "naive", "config": {}}, + "assembler": {"backend": "naive", "config": {}}, }, - "assembler": { - "backend": "naive", - "config": {} - } - } + }, } -} -mos_config = MOSConfig(**mos_config) + mos_config = MOSConfig(**mos_config) + cube_config = GeneralMemCubeConfig.model_validate(cube_config) + + return mos_config, cube_config + + +user_id = "superman" +mos_config, mem_cube_config = get_config(user_id) + memory = MOSProduct(default_config=mos_config) # register user and mem cube -mem_cube_config = GeneralMemCubeConfig.model_validate(cube_config) mem_cube = GeneralMemCube(mem_cube_config) -user_id = "6dab9067-173a-4034-b5d9-735fc01999e3" -register_result = memory.user_register(user_id=user_id, user_name="root", default_mem_cube=mem_cube) + +register_result = memory.user_register( + user_id=user_id, user_name=user_id, default_mem_cube=mem_cube +) print(f"register_result: {register_result}") # add memory for user @@ -210,171 +131,269 @@ messages=[ # Sports and hobbies {"role": "user", "content": "I like 
playing football."}, - {"role": "assistant", "content": "I like playing football too. It's a great way to stay active and work as a team."}, - + { + "role": "assistant", + "content": "I like playing football too. It's a great way to stay active and work as a team.", + }, {"role": "user", "content": "Do you enjoy watching basketball games?"}, - {"role": "assistant", "content": "Yes, I find basketball exciting to watch, especially the fast-paced action and strategic plays."}, - + { + "role": "assistant", + "content": "Yes, I find basketball exciting to watch, especially the fast-paced action and strategic plays.", + }, {"role": "user", "content": "I'm learning to play tennis this summer."}, - {"role": "assistant", "content": "That's wonderful! Tennis is a fantastic sport that improves both physical fitness and mental focus."}, - + { + "role": "assistant", + "content": "That's wonderful! Tennis is a fantastic sport that improves both physical fitness and mental focus.", + }, # Food and cooking {"role": "user", "content": "What's your favorite type of cuisine?"}, - {"role": "assistant", "content": "I really enjoy Italian cuisine, especially pasta dishes and authentic pizza."}, - + { + "role": "assistant", + "content": "I really enjoy Italian cuisine, especially pasta dishes and authentic pizza.", + }, {"role": "user", "content": "I love cooking Asian food at home."}, - {"role": "assistant", "content": "Asian cooking is amazing! The balance of flavors and techniques is so interesting to learn."}, - + { + "role": "assistant", + "content": "Asian cooking is amazing! The balance of flavors and techniques is so interesting to learn.", + }, {"role": "user", "content": "Have you ever tried making homemade bread?"}, - {"role": "assistant", "content": "Yes, there's something magical about the process of making bread from scratch. The smell is incredible!"}, - + { + "role": "assistant", + "content": "Yes, there's something magical about the process of making bread from scratch. 
The smell is incredible!", + }, # Travel and places {"role": "user", "content": "I'm planning a trip to Japan next year."}, - {"role": "assistant", "content": "Japan is an incredible destination! The culture, food, and technology blend is fascinating."}, - + { + "role": "assistant", + "content": "Japan is an incredible destination! The culture, food, and technology blend is fascinating.", + }, {"role": "user", "content": "What's the most beautiful place you've ever seen?"}, - {"role": "assistant", "content": "I find mountain landscapes particularly breathtaking, especially during sunrise or sunset."}, - + { + "role": "assistant", + "content": "I find mountain landscapes particularly breathtaking, especially during sunrise or sunset.", + }, {"role": "user", "content": "I prefer beach vacations over city trips."}, - {"role": "assistant", "content": "Beach vacations are so relaxing! The sound of waves and warm sand is perfect for unwinding."}, - + { + "role": "assistant", + "content": "Beach vacations are so relaxing! The sound of waves and warm sand is perfect for unwinding.", + }, # Technology and innovation {"role": "user", "content": "What do you think about artificial intelligence?"}, - {"role": "assistant", "content": "AI is transforming many aspects of our lives. It's exciting to see how it can help solve complex problems."}, - + { + "role": "assistant", + "content": "AI is transforming many aspects of our lives. It's exciting to see how it can help solve complex problems.", + }, {"role": "user", "content": "I'm learning to code in Python."}, - {"role": "assistant", "content": "Python is a great language to start with! It's versatile and has a wonderful community."}, - + { + "role": "assistant", + "content": "Python is a great language to start with! 
It's versatile and has a wonderful community.", + }, {"role": "user", "content": "Do you use any productivity apps?"}, - {"role": "assistant", "content": "I find task management apps really helpful for organizing daily activities and staying focused."}, - + { + "role": "assistant", + "content": "I find task management apps really helpful for organizing daily activities and staying focused.", + }, # Music and entertainment {"role": "user", "content": "What kind of music do you enjoy?"}, - {"role": "assistant", "content": "I appreciate many genres, but I particularly enjoy jazz and classical music for their complexity."}, - + { + "role": "assistant", + "content": "I appreciate many genres, but I particularly enjoy jazz and classical music for their complexity.", + }, {"role": "user", "content": "I play the guitar in my free time."}, - {"role": "assistant", "content": "That's fantastic! Playing an instrument is such a rewarding creative outlet."}, - + { + "role": "assistant", + "content": "That's fantastic! 
Playing an instrument is such a rewarding creative outlet.", + }, {"role": "user", "content": "Have you seen any good movies lately?"}, - {"role": "assistant", "content": "I enjoy films that make you think deeply about human nature and society."}, - + { + "role": "assistant", + "content": "I enjoy films that make you think deeply about human nature and society.", + }, # Books and learning {"role": "user", "content": "Do you read books regularly?"}, - {"role": "assistant", "content": "Yes, reading is one of my favorite ways to learn and explore new ideas."}, - + { + "role": "assistant", + "content": "Yes, reading is one of my favorite ways to learn and explore new ideas.", + }, {"role": "user", "content": "I'm currently reading a science fiction novel."}, - {"role": "assistant", "content": "Science fiction is great for exploring possibilities and imagining different futures."}, - + { + "role": "assistant", + "content": "Science fiction is great for exploring possibilities and imagining different futures.", + }, {"role": "user", "content": "What's your favorite subject to learn about?"}, - {"role": "assistant", "content": "I'm fascinated by psychology and how the human mind works."}, - + { + "role": "assistant", + "content": "I'm fascinated by psychology and how the human mind works.", + }, # Nature and environment {"role": "user", "content": "I love hiking in the mountains."}, - {"role": "assistant", "content": "Hiking is wonderful! It's a great way to connect with nature and get exercise."}, - + { + "role": "assistant", + "content": "Hiking is wonderful! 
It's a great way to connect with nature and get exercise.", + }, {"role": "user", "content": "Are you concerned about climate change?"}, - {"role": "assistant", "content": "Yes, it's one of the most important challenges of our time that requires collective action."}, - + { + "role": "assistant", + "content": "Yes, it's one of the most important challenges of our time that requires collective action.", + }, {"role": "user", "content": "I try to recycle and reduce waste at home."}, - {"role": "assistant", "content": "That's commendable! Every small action contributes to environmental protection."}, - + { + "role": "assistant", + "content": "That's commendable! Every small action contributes to environmental protection.", + }, # Work and career {"role": "user", "content": "What do you do for work?"}, - {"role": "assistant", "content": "I work in technology, helping people solve problems and improve their daily lives."}, - + { + "role": "assistant", + "content": "I work in technology, helping people solve problems and improve their daily lives.", + }, {"role": "user", "content": "I'm considering a career change."}, - {"role": "assistant", "content": "Career changes can be challenging but also very rewarding when you find the right fit."}, - + { + "role": "assistant", + "content": "Career changes can be challenging but also very rewarding when you find the right fit.", + }, {"role": "user", "content": "Do you work from home or in an office?"}, - {"role": "assistant", "content": "I work remotely, which gives me flexibility in how I organize my day."}, - + { + "role": "assistant", + "content": "I work remotely, which gives me flexibility in how I organize my day.", + }, # Health and wellness {"role": "user", "content": "I've started doing yoga in the morning."}, - {"role": "assistant", "content": "Yoga is excellent for both physical and mental well-being. 
Great way to start the day!"}, - + { + "role": "assistant", + "content": "Yoga is excellent for both physical and mental well-being. Great way to start the day!", + }, {"role": "user", "content": "How do you manage stress?"}, - {"role": "assistant", "content": "I find meditation and regular exercise help me stay balanced and focused."}, - + { + "role": "assistant", + "content": "I find meditation and regular exercise help me stay balanced and focused.", + }, {"role": "user", "content": "I'm trying to eat more vegetables."}, - {"role": "assistant", "content": "That's a great health goal! Vegetables provide so many essential nutrients."}, - + { + "role": "assistant", + "content": "That's a great health goal! Vegetables provide so many essential nutrients.", + }, # Art and creativity {"role": "user", "content": "Do you enjoy painting or drawing?"}, - {"role": "assistant", "content": "I appreciate visual arts, though I'm more of an observer than a creator."}, - + { + "role": "assistant", + "content": "I appreciate visual arts, though I'm more of an observer than a creator.", + }, {"role": "user", "content": "I'm learning photography."}, - {"role": "assistant", "content": "Photography is a beautiful way to capture moments and express your perspective."}, - + { + "role": "assistant", + "content": "Photography is a beautiful way to capture moments and express your perspective.", + }, {"role": "user", "content": "What's your favorite type of art?"}, - {"role": "assistant", "content": "I'm drawn to abstract art because it allows for personal interpretation."}, - + { + "role": "assistant", + "content": "I'm drawn to abstract art because it allows for personal interpretation.", + }, # Family and relationships {"role": "user", "content": "I have two siblings."}, - {"role": "assistant", "content": "Having siblings can be wonderful - they're often our first friends and lifelong companions."}, - + { + "role": "assistant", + "content": "Having siblings can be wonderful - they're 
often our first friends and lifelong companions.", + }, {"role": "user", "content": "Do you have any pets?"}, - {"role": "assistant", "content": "I don't have pets, but I think they bring so much joy and companionship to people's lives."}, - + { + "role": "assistant", + "content": "I don't have pets, but I think they bring so much joy and companionship to people's lives.", + }, {"role": "user", "content": "I'm planning to visit my grandparents this weekend."}, - {"role": "assistant", "content": "That's lovely! Grandparents often have such interesting stories and wisdom to share."}, - + { + "role": "assistant", + "content": "That's lovely! Grandparents often have such interesting stories and wisdom to share.", + }, # Education and growth {"role": "user", "content": "I'm taking an online course in data science."}, - {"role": "assistant", "content": "Data science is a fascinating field with many applications across different industries."}, - + { + "role": "assistant", + "content": "Data science is a fascinating field with many applications across different industries.", + }, {"role": "user", "content": "What's the best way to learn a new language?"}, - {"role": "assistant", "content": "Immersion and consistent practice are key, along with finding methods that work for your learning style."}, - + { + "role": "assistant", + "content": "Immersion and consistent practice are key, along with finding methods that work for your learning style.", + }, {"role": "user", "content": "I enjoy attending workshops and conferences."}, - {"role": "assistant", "content": "Workshops are great for hands-on learning and networking with like-minded people."}, - + { + "role": "assistant", + "content": "Workshops are great for hands-on learning and networking with like-minded people.", + }, # Hobbies and interests {"role": "user", "content": "I collect vintage postcards."}, - {"role": "assistant", "content": "What a unique hobby! 
Postcards offer glimpses into different times and places."}, - + { + "role": "assistant", + "content": "What a unique hobby! Postcards offer glimpses into different times and places.", + }, {"role": "user", "content": "Do you garden?"}, - {"role": "assistant", "content": "I appreciate gardening, though I'm more of a plant observer than a gardener myself."}, - + { + "role": "assistant", + "content": "I appreciate gardening, though I'm more of a plant observer than a gardener myself.", + }, {"role": "user", "content": "I love solving puzzles and brain teasers."}, - {"role": "assistant", "content": "Puzzles are excellent for keeping the mind sharp and developing problem-solving skills."}, - + { + "role": "assistant", + "content": "Puzzles are excellent for keeping the mind sharp and developing problem-solving skills.", + }, # Future and goals {"role": "user", "content": "What are your goals for next year?"}, - {"role": "assistant", "content": "I aim to continue learning and growing, helping more people with their challenges."}, - + { + "role": "assistant", + "content": "I aim to continue learning and growing, helping more people with their challenges.", + }, {"role": "user", "content": "I want to start my own business someday."}, - {"role": "assistant", "content": "Entrepreneurship is exciting! It requires courage, planning, and persistence."}, - + { + "role": "assistant", + "content": "Entrepreneurship is exciting! 
It requires courage, planning, and persistence.", + }, {"role": "user", "content": "Do you believe in setting New Year's resolutions?"}, - {"role": "assistant", "content": "I think any time is good for setting positive goals and making positive changes."}, - + { + "role": "assistant", + "content": "I think any time is good for setting positive goals and making positive changes.", + }, # Philosophy and life {"role": "user", "content": "What's the meaning of life to you?"}, - {"role": "assistant", "content": "I believe it's about learning, growing, and making positive connections with others."}, - + { + "role": "assistant", + "content": "I believe it's about learning, growing, and making positive connections with others.", + }, {"role": "user", "content": "I try to practice gratitude daily."}, - {"role": "assistant", "content": "Gratitude is such a powerful practice that can transform your perspective on life."}, - + { + "role": "assistant", + "content": "Gratitude is such a powerful practice that can transform your perspective on life.", + }, {"role": "user", "content": "What's your biggest fear?"}, - {"role": "assistant", "content": "I think fear of not being able to help when someone really needs it would be challenging."}, - + { + "role": "assistant", + "content": "I think fear of not being able to help when someone really needs it would be challenging.", + }, # Technology and daily life {"role": "user", "content": "How has technology changed your daily routine?"}, - {"role": "assistant", "content": "Technology has made many tasks more efficient and connected me with people worldwide."}, - + { + "role": "assistant", + "content": "Technology has made many tasks more efficient and connected me with people worldwide.", + }, {"role": "user", "content": "I'm trying to reduce my screen time."}, - {"role": "assistant", "content": "That's a healthy goal! 
Finding balance with technology is important for well-being."}, - + { + "role": "assistant", + "content": "That's a healthy goal! Finding balance with technology is important for well-being.", + }, {"role": "user", "content": "Do you use social media?"}, - {"role": "assistant", "content": "I think social media can be a great tool for connection when used mindfully."}, - + { + "role": "assistant", + "content": "I think social media can be a great tool for connection when used mindfully.", + }, # Final thoughts {"role": "user", "content": "This has been a great conversation!"}, - {"role": "assistant", "content": "I've really enjoyed our chat too! It's wonderful to explore so many different topics together."}, + { + "role": "assistant", + "content": "I've really enjoyed our chat too! It's wonderful to explore so many different topics together.", + }, ], user_id=user_id, - mem_cube_id=register_result["default_cube_id"] - + mem_cube_id=register_result["default_cube_id"], ) user_cube_id = memory.user_manager.get_user_cubes(user_id)[0].cube_id @@ -383,7 +402,8 @@ retrieved_memories = memory.search(query="What do you like?", user_id=user_id) -# output text_memories: I like playing football, act_memories, para_memories -# print(f"text_memories: {retrieved_memories['text_mem']}") -# print(f"pref_memories: {retrieved_memories['pref_mem']}") -print(f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}") +print( + f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}" + if retrieved_memories["pref_mem"] + else 0 +) diff --git a/src/memos/configs/mem_cube.py b/src/memos/configs/mem_cube.py index bc3fdf557..4bd709fab 100644 --- a/src/memos/configs/mem_cube.py +++ b/src/memos/configs/mem_cube.py @@ -102,4 +102,4 @@ def validate_pref_mem(cls, pref_mem: MemoryConfigFactory) -> MemoryConfigFactory raise ConfigurationError( f"GeneralMemCubeConfig requires pref_mem backend to be one of {allowed_backends}, got '{pref_mem.backend}'" ) - return 
pref_mem \ No newline at end of file + return pref_mem diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 6ee76ffae..cf5cdfbda 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -11,11 +11,12 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.exceptions import ConfigurationError from memos.memories.textual.prefer_text_memory.config import ( - AdderConfigFactory, + AdderConfigFactory, + AssemblerConfigFactory, ExtractorConfigFactory, - RetrieverConfigFactory, - UpdaterConfigFactory, - AssemblerConfigFactory) + RetrieverConfigFactory, + UpdaterConfigFactory, +) # ─── 1. Global Base Memory Config ───────────────────────────────────────────── @@ -230,6 +231,7 @@ class PreferenceTextMemoryConfig(BaseTextMemoryConfig): description="Assembler configuration for the memory assembling", ) + # ─── 3. Global Memory Config Factory ────────────────────────────────────────── diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index 17611f41e..dd1748714 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -44,10 +44,13 @@ class MilvusVecDBConfig(BaseVecDBConfig): uri: str = Field(..., description="URI for Milvus connection") collection_name: list[str] = Field(..., description="Name(s) of the collection(s)") - max_length: int = Field(default=65535, description="Maximum length for string fields (varChar type)") + max_length: int = Field( + default=65535, description="Maximum length for string fields (varChar type)" + ) user_name: str = Field(default="", description="User name for Milvus connection") password: str = Field(default="", description="Password for Milvus connection") + class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" diff --git a/src/memos/mem_cube/general.py b/src/memos/mem_cube/general.py index 0fa3b26e2..d6a57012e 100644 --- a/src/memos/mem_cube/general.py +++ b/src/memos/mem_cube/general.py @@ -45,7 
+45,9 @@ def __init__(self, config: GeneralMemCubeConfig): ) def load( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Load memories. Args: @@ -85,7 +87,9 @@ def load( logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: @@ -231,4 +235,4 @@ def pref_mem(self, value: BaseTextMemory) -> None: """Set the preference memory.""" if not isinstance(value, BaseTextMemory): raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") - self._pref_mem = value \ No newline at end of file + self._pref_mem = value diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index e9e579a0f..d96b555fd 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -616,7 +616,7 @@ def search( logger.info( f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" ) - + if ( (mem_cube_id in install_cube_ids) and (mem_cube.pref_mem is not None) @@ -733,9 +733,10 @@ def add( ): messages_list = [messages] pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}) + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}, + ) preferences = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) logger.info( f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {preferences}" diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index fe3f912a4..de6113d61 100644 --- 
a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1341,7 +1341,9 @@ def search( memories["metadata"]["id"] = memories["id"] memories["metadata"]["memory"] = memories["memory"] memories_list.append(memories) - reformat_pref_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) + reformat_pref_memory_list.append( + {"cube_id": memory["cube_id"], "memories": memories_list} + ) search_result["pref_mem"] = reformat_pref_memory_list time_end = time.time() logger.info( diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index d796b5d04..7a907bf2f 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,8 +10,8 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.tree import TreeTextMemory class MemoryFactory(BaseMemory): diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index c6cc67fdb..f335175a7 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -107,57 +107,31 @@ class SearchedTreeNodeTextualMemoryMetadata(TreeNodeTextualMemoryMetadata): class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): """Metadata for preference memory item.""" - preference_type: Literal["explicit_preference", "implicit_preference", "topic_preference", "user_preference"] = Field( - default="explicit_preference", description="Type of preference." - ) - dialog_id: str | None = Field( - default=None, description="ID of the dialog." 
- ) + preference_type: Literal[ + "explicit_preference", "implicit_preference", "topic_preference", "user_preference" + ] = Field(default="explicit_preference", description="Type of preference.") + dialog_id: str | None = Field(default=None, description="ID of the dialog.") dialog_msgs: list[dict[str, str]] | None = Field( default=None, description="Messages of the dialog." ) - dialog_str: str | None = Field( - default=None, description="String of the dialog." - ) - dialog_vector: list[float] | None = Field( - default=None, description="Vector of the dialog." - ) - topic_name: str | None = Field( - default=None, description="Name of the topic." - ) - topic_description: str | None = Field( - default=None, description="Description of the topic." - ) - topic_vector: list[float] | None = Field( - default=None, description="Vector of the topic." - ) - explicit_preference: str | None = Field( - default=None, description="Explicit preference." - ) - center_dialog: str | None = Field( - default=None, description="String of the center dialog." - ) + dialog_str: str | None = Field(default=None, description="String of the dialog.") + dialog_vector: list[float] | None = Field(default=None, description="Vector of the dialog.") + topic_name: str | None = Field(default=None, description="Name of the topic.") + topic_description: str | None = Field(default=None, description="Description of the topic.") + topic_vector: list[float] | None = Field(default=None, description="Vector of the topic.") + explicit_preference: str | None = Field(default=None, description="Explicit preference.") + center_dialog: str | None = Field(default=None, description="String of the center dialog.") center_vector: list[float] | None = Field( default=None, description="Vector of the center dialog." ) - created_at: str | None = Field( - default=None, description="Timestamp of the dialog." - ) - implicit_preference: str | None = Field( - default=None, description="Implicit preference." 
- ) - topic_cluster_name: str | None = Field( - default=None, description="Name of the topic cluster." - ) + created_at: str | None = Field(default=None, description="Timestamp of the dialog.") + implicit_preference: str | None = Field(default=None, description="Implicit preference.") + topic_cluster_name: str | None = Field(default=None, description="Name of the topic cluster.") topic_cluster_description: str | None = Field( default=None, description="Description of the topic cluster." ) - topic_preference: str | None = Field( - default=None, description="Preferences of the topic." - ) - user_preference: str | None = Field( - default=None, description="Preferences of the user." - ) + topic_preference: str | None = Field(default=None, description="Preferences of the topic.") + user_preference: str | None = Field(default=None, description="Preferences of the user.") class TextualMemoryItem(BaseModel): diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 919e6dd0e..c61b322b3 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -1,17 +1,17 @@ -from abc import ABC, abstractmethod -from typing import Any import json -from datetime import datetime + +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + from memos.memories.textual.item import TextualMemoryItem -from memos.vec_dbs.item import VecDBItem from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT -from memos.types import MessageList +from memos.vec_dbs.item import VecDBItem class BaseAdder(ABC): """Abstract base class for adders.""" - + @abstractmethod def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the adder.""" @@ -26,8 +26,10 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwarg list[str]: 
List of added memory IDs. """ + class NaiveAdder(BaseAdder): """Naive adder.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive adder.""" super().__init__(llm_provider, embedder, vector_db) @@ -38,8 +40,10 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): def _judge_update_or_add(self, old_msg: str, new_msg: str) -> bool: """Judge if the new message expresses the same core content as the old message.""" # Use the template prompt with placeholders - prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace("{new_information}", new_msg) - + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace( + "{new_information}", new_msg + ) + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -50,31 +54,31 @@ def _judge_update_or_add(self, old_msg: str, new_msg: str) -> bool: print(f"Error in judge_update_or_add: {e}") # Fallback to simple string comparison return old_msg == new_msg - + def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: """Process a single memory and return its ID if added successfully.""" try: payload = memory.to_dict()["metadata"] vec_db_item = VecDBItem( - id=memory.id, - vector=memory.metadata.dialog_vector, - payload=payload + id=memory.id, vector=memory.metadata.dialog_vector, payload=payload ) pref_type_collection_map = { "explicit_preference": "explicit_preference", "implicit_preference": "implicit_preference", "topic_preference": "topic_preference", - "user_preference": "user_preference" + "user_preference": "user_preference", } preference_type = memory.metadata.preference_type collection_name = pref_type_collection_map[preference_type] - search_results = self.vector_db.search(memory.metadata.dialog_vector, collection_name, top_k=1) + search_results = self.vector_db.search( + 
memory.metadata.dialog_vector, collection_name, top_k=1 + ) recall = search_results[0] if search_results else None if not recall or (recall.score is not None and recall.score < 0.5): self.vector_db.update(collection_name, memory.id, vec_db_item) return memory.id - + old_msg_str = recall.payload.get("dialog_str", "") new_msg_str = memory.metadata.dialog_str is_same = self._judge_update_or_add(old_msg_str, new_msg_str) @@ -87,7 +91,13 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: print(f"Error processing memory {memory.id}: {e}") return None - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], max_workers: int = 8, *args, **kwargs) -> list[str]: + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + max_workers: int = 8, + *args, + **kwargs, + ) -> list[str]: """Add the instruct preference memories using thread pool for acceleration.""" if not memories: return [] @@ -95,10 +105,9 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]], max_workers: i added_ids = [] with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: future_to_memory = { - executor.submit(self._process_single_memory, memory): memory - for memory in memories + executor.submit(self._process_single_memory, memory): memory for memory in memories } - + for future in as_completed(future_to_memory): try: memory_id = future.result() diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index e16274737..9f4885c38 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -1,17 +1,19 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List -import json -from memos.templates.prefer_complete_prompt import NAIVE_PREFERENCE_INTEGRATION_PROMPT + from memos.memories.textual.item import TextualMemoryItem + class BaseAssembler(ABC): 
"""Abstract base class for assemblers.""" + @abstractmethod def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the assembler.""" @abstractmethod - def get_instruction(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: + def get_instruction( + self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str = "semi" + ) -> str: """Assemble query and memories into a single memory. Args: query: The query to assemble. @@ -24,6 +26,7 @@ def get_instruction(self, query: str, memories: list[TextualMemoryItem], assembl class NaiveAssembler(BaseAssembler): """Naive assembler.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive assembler.""" super().__init__(llm_provider, embedder, vector_db) @@ -31,7 +34,9 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def get_instruction(self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str="semi") -> str: + def get_instruction( + self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str = "semi" + ) -> str: """Assemble query and memories into a single memory.""" # Initialize all preference lists @@ -40,28 +45,32 @@ def get_instruction(self, query: str, memories: list[TextualMemoryItem], assembl implicit_prefs = [] topic_prefs = [] user_prefs = [] - + # Single loop to categorize all memories by preference type for memory in memories: if memory.metadata.preference_type == "explicit_preference": - explicit_prefs.append({ - "dialog_str": memory.metadata.dialog_str, - "explicit_preference": memory.metadata.explicit_preference - }) + explicit_prefs.append( + { + "dialog_str": memory.metadata.dialog_str, + "explicit_preference": memory.metadata.explicit_preference, + } + ) elif memory.metadata.preference_type == "implicit_preference": - implicit_prefs.append({ - "dialog_str": 
memory.metadata.center_dialog, - "implicit_preference": memory.metadata.implicit_preference - }) + implicit_prefs.append( + { + "dialog_str": memory.metadata.center_dialog, + "implicit_preference": memory.metadata.implicit_preference, + } + ) elif memory.metadata.preference_type == "topic_preference": - topic_prefs.append({ - "center_dialog_str": memory.metadata.center_dialog, - "topic_preferences": memory.metadata.topic_preferences - }) + topic_prefs.append( + { + "center_dialog_str": memory.metadata.center_dialog, + "topic_preferences": memory.metadata.topic_preferences, + } + ) elif memory.metadata.preference_type == "user_preference": - user_prefs.append({ - "user_preferences": memory.metadata.user_preferences - }) + user_prefs.append({"user_preferences": memory.metadata.user_preferences}) else: textual_mems.append(memory.memory) @@ -75,24 +84,24 @@ def get_instruction(self, query: str, memories: list[TextualMemoryItem], assembl memories_parts.append("## Explicit Preferences:") for i, pref in enumerate(explicit_prefs, 1): memories_parts.append(f"{i}. {pref['dialog_str']}") - + if implicit_prefs: memories_parts.append("\n## Implicit Preferences:") for i, pref in enumerate(implicit_prefs, 1): memories_parts.append(f"{i}. {pref['dialog_str']}") - + if topic_prefs: memories_parts.append("\n## Topic Preferences:") for i, pref in enumerate(topic_prefs, 1): memories_parts.append(f"{i}. {pref['center_dialog_str']}") - + if user_prefs: memories_parts.append("\n## User Preferences:") for i, pref in enumerate(user_prefs, 1): memories_parts.append(f"{i}. {pref['user_preferences']}") - + memories_str = "\n".join(memories_parts) - + system_prompt = ( "You are a knowledgeable and helpful AI assistant. " "You have access to conversation memories that help you provide more personalized responses. 
" @@ -100,16 +109,17 @@ def get_instruction(self, query: str, memories: list[TextualMemoryItem], assembl "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." f"\n\n## Memories:\n{memories_str}" ) - + if assemble_strategy == "raw": return system_prompt.replace("{memories}", memories_str) elif assemble_strategy == "semi": return ( - system_prompt + - ("Note: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " + \ - "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts." + \ - "When encountering preference conflicts, the priority is: explicit preferences > implicit preferences > textual memories.") + system_prompt + + ( + "Note: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " + + "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts." + + "When encountering preference conflicts, the priority is: explicit preferences > implicit preferences > textual memories." 
+ ) ).replace("{memories}", memories_str) else: raise ValueError(f"Invalid assemble strategy: {assemble_strategy}") - diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py index beaaf108e..cb68370b6 100644 --- a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -1,9 +1,12 @@ +import uuid + from abc import ABC, abstractmethod -from pydantic import BaseModel -from typing import List, Dict, Any -import numpy as np from datetime import datetime -import uuid +from typing import Any + +import numpy as np + +from pydantic import BaseModel from sklearn.cluster import HDBSCAN from sklearn.neighbors import NearestNeighbors @@ -14,37 +17,38 @@ class ClusterResult(BaseModel): cluster_id: str cluster_label: int center_index: int - center_vector: List[float] + center_vector: list[float] size: int created_at: str - items: List[Dict[str, Any]] + items: list[dict[str, Any]] center_dialog_msgs: MessageList = None center_dialog_str: str = None + class BaseClusterer(ABC): """Abstract base class for clustering algorithms.""" - + @abstractmethod - def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: + def cluster(self, vectors: list[list[float]], **kwargs) -> list[ClusterResult]: """ Perform clustering on the given vectors. - + Args: vectors: List of vectors to cluster **kwargs: Additional clustering parameters - + Returns: List of cluster information dictionaries """ - pass + class HDBSCANClusterer(BaseClusterer): """HDBSCAN-based clustering implementation.""" - + def __init__(self, min_cluster_size: int = 2, min_samples: int = 1, metric: str = "cosine"): """ Initialize HDBSCAN clusterer. 
- + Args: min_cluster_size: Minimum size of a cluster min_samples: Minimum number of samples in a neighborhood @@ -55,50 +59,50 @@ def __init__(self, min_cluster_size: int = 2, min_samples: int = 1, metric: str self.metric = metric self.clusterer = HDBSCAN( - min_cluster_size=self.min_cluster_size, - min_samples=self.min_samples, - metric=self.metric - ) - - def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: + min_cluster_size=self.min_cluster_size, min_samples=self.min_samples, metric=self.metric + ) + + def cluster(self, vectors: list[list[float]], **kwargs) -> list[ClusterResult]: """ Perform HDBSCAN clustering on the given vectors. - + Args: vectors: List of vectors to cluster **kwargs: Additional clustering parameters - + Returns: List of cluster information dictionaries """ if not vectors or len(vectors) < 2: - return [] + return [] # Convert to numpy array vectors_array = np.array(vectors) # Perform clustering cluster_labels = self.clusterer.fit_predict(vectors_array) - + # Group vectors by cluster clusters = {} for i, label in enumerate(cluster_labels): if label not in clusters: clusters[label] = [] - clusters[label].append({ - "vector": vectors[i], # origin sample vector - "index": i # origin sample index - }) - + clusters[label].append( + { + "vector": vectors[i], # origin sample vector + "index": i, # origin sample index + } + ) + # Convert to cluster information cluster_info = [] for cluster_label, items_in_cluster in clusters.items(): if cluster_label == -1: # Skip noise points continue - + # Calculate cluster geometric center geometric_center_vector = np.mean([item["vector"] for item in items_in_cluster], axis=0) - + # Find the vector closest to the geometric center using sklearn cluster_vectors = [item["vector"] for item in items_in_cluster] nn = NearestNeighbors(n_neighbors=1, metric=self.metric) @@ -107,48 +111,54 @@ def cluster(self, vectors: List[List[float]], **kwargs) -> List[ClusterResult]: closest_idx = 
indices[0][0] center_index = items_in_cluster[closest_idx]["index"] center_vector = items_in_cluster[closest_idx]["vector"] - - cluster_info.append(ClusterResult( - cluster_id=str(uuid.uuid4()), - cluster_label=cluster_label, # cluseter label: -1(noise) 0 1 2 3 ... - center_index=center_index, # center sample origin index - center_vector=center_vector, - size=len(items_in_cluster), # size of the cluster - created_at=datetime.now().isoformat(), - items=items_in_cluster # list of origin sample index and vector in this cluster - )) - + + cluster_info.append( + ClusterResult( + cluster_id=str(uuid.uuid4()), + cluster_label=cluster_label, # cluseter label: -1(noise) 0 1 2 3 ... + center_index=center_index, # center sample origin index + center_vector=center_vector, + size=len(items_in_cluster), # size of the cluster + created_at=datetime.now().isoformat(), + items=items_in_cluster, # list of origin sample index and vector in this cluster + ) + ) + return cluster_info - def search_knn_by_center_embeddings(self, center_emb: List[float], vectors: List[List[float]], top_k: int = 5) -> List[Dict[str, Any]]: + def search_knn_by_center_embeddings( + self, center_emb: list[float], vectors: list[list[float]], top_k: int = 5 + ) -> list[dict[str, Any]]: """ Search for k-nearest neighbors to the center embedding using sklearn. 
- + Args: center_emb: Center embedding vector vectors: List of vectors to search in top_k: Number of nearest neighbors to return (maybe include the center embedding itself) - + Returns: List of dictionaries containing index, distance, and vector for each neighbor """ if not vectors: return [] - + # Initialize NearestNeighbors with the same metric nn = NearestNeighbors(n_neighbors=min(top_k, len(vectors)), metric=self.metric) nn.fit(vectors) - + # Search for nearest neighbors distances, indices = nn.kneighbors([center_emb]) - + # Format results results = [] - for dist, idx in zip(distances[0], indices[0]): - results.append({ - "index": int(idx), # neighbor original index - "distance": float(dist), # neighbor distance - "vector": vectors[idx] # neighbor vector - }) - + for dist, idx in zip(distances[0], indices[0], strict=False): + results.append( + { + "index": int(idx), # neighbor original index + "distance": float(dist), # neighbor distance + "vector": vectors[idx], # neighbor vector + } + ) + return results diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py index fa6067528..f11e106b7 100644 --- a/src/memos/memories/textual/prefer_text_memory/config.py +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -1,11 +1,8 @@ -from memos.configs.base import BaseConfig from typing import Any, ClassVar -from pydantic import Field -from pydantic import field_validator, model_validator -from memos.configs.vec_db import VectorDBConfigFactory -from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.llm import LLMConfigFactory +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig class BaseAdderConfig(BaseConfig): @@ -14,8 +11,8 @@ class BaseAdderConfig(BaseConfig): class NaiveAdderConfig(BaseAdderConfig): """Configuration for Naive Adder.""" - # No additional config needed since components are passed from parent + # No 
additional config needed since components are passed from parent class AdderConfigFactory(BaseConfig): @@ -50,6 +47,7 @@ class BaseExtractorConfig(BaseConfig): class NaiveExtractorConfig(BaseExtractorConfig): """Configuration for Naive Extractor.""" + class ExtractorConfigFactory(BaseConfig): """Factory class for creating Extractor configurations.""" @@ -74,6 +72,7 @@ def create_config(self) -> "ExtractorConfigFactory": self.config = config_class(**self.config) return self + class BaseRetrieverConfig(BaseConfig): """Base configuration class for Retrievers.""" @@ -124,7 +123,7 @@ class UpdaterConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveUpdaterConfig, } - + @field_validator("backend") @classmethod def validate_backend(cls, backend: str) -> str: @@ -157,7 +156,7 @@ class AssemblerConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveAssemblerConfig, } - + @field_validator("backend") @classmethod def validate_backend(cls, backend: str) -> str: @@ -171,4 +170,3 @@ def create_config(self) -> "AssemblerConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self - diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 7ede2e45b..81b5cf8d6 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -1,24 +1,27 @@ import json import uuid + from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any, Dict, Optional, List from concurrent.futures import ThreadPoolExecutor, as_completed -from memos.types import MessageList -from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata +from datetime import datetime +from typing import Any + +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, 
TextualMemoryItem +from memos.memories.textual.prefer_text_memory.spliter import Splitter +from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string from memos.templates.prefer_complete_prompt import ( NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_TOPIC_INFO_EXTRACT_PROMPT, NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_INFO_EXTRACT_PROMPT, ) -from memos.memories.textual.prefer_text_memory.spliter import Splitter -from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string +from memos.types import MessageList + class BaseExtractor(ABC): """Abstract base class for extractors.""" - + @abstractmethod def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the extractor.""" @@ -26,6 +29,7 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): class NaiveExtractor(BaseExtractor): """Extractor.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the extractor.""" super().__init__(llm_provider, embedder, vector_db) @@ -34,22 +38,22 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.vector_db = vector_db self.splitter = Splitter() - def extract_basic_info(self, qa_pair: MessageList) -> Dict[str, Any]: + def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: """Extract basic information from a QA pair (no LLM needed).""" basic_info = { "dialog_id": str(uuid.uuid4()), "dialog_msgs": qa_pair, "dialog_str": convert_messages_to_string(qa_pair), - "created_at": datetime.now().isoformat() + "created_at": datetime.now().isoformat(), } - + return basic_info - def extract_topic_info(self, qa_pair: MessageList | str) -> Dict[str, Any]: + def extract_topic_info(self, qa_pair: MessageList | str) -> dict[str, Any]: """Extract topic information from a QA pair.""" qa_pair_str = convert_messages_to_string(qa_pair) if 
isinstance(qa_pair, list) else qa_pair prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) - + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -57,13 +61,13 @@ def extract_topic_info(self, qa_pair: MessageList | str) -> Dict[str, Any]: return result except Exception as e: print(f"Error extracting topic info: {e}, return None") - return None - - def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, Any] | None: + return None + + def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: """Extract explicit preference from a QA pair.""" qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) - + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -73,13 +77,19 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> Dict[str, A print(f"Error extracting explicit preference: {e}, return None") return None - def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: + def extract_implicit_preferences( + self, qa_pairs: MessageList | list[str] + ) -> dict[str, Any] | None: """Extract implicit preferences from cluster qa pairs.""" if not qa_pairs: return None - qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) + qa_pairs_str = ( + convert_messages_to_string(qa_pairs) + if isinstance(qa_pairs[0], dict) + else "\n\n".join(qa_pairs) + ) prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) - + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = 
response.strip().replace("```json", "").replace("```", "").strip() @@ -88,32 +98,40 @@ def extract_implicit_preferences(self, qa_pairs: MessageList | list[str]) -> Dic except Exception as e: print(f"Error extracting implicit preferences: {e}, return None") return None - - def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> Dict[str, Any] | None: + + def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> dict[str, Any] | None: """Extract topic preferences from cluster qa pairs.""" if not qa_pairs: return None - qa_pairs_str = convert_messages_to_string(qa_pairs) if isinstance(qa_pairs[0], dict) else "\n\n".join(qa_pairs) + qa_pairs_str = ( + convert_messages_to_string(qa_pairs) + if isinstance(qa_pairs[0], dict) + else "\n\n".join(qa_pairs) + ) prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) - + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) - + if result.get("topic_cluster_name"): return result except Exception as e: print(f"Error extracting topic preferences: {qa_pairs}\n{e}, return None") return None - - def extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> Dict[str, Any] | None: + + def extract_user_preferences( + self, topic_preferences: list[dict[str, Any]] + ) -> dict[str, Any] | None: """Extract user-level preferences.""" if not topic_preferences: return [] - prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace("{cluster_info}", json.dumps(topic_preferences, ensure_ascii=False, indent=2)) - + prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace( + "{cluster_info}", json.dumps(topic_preferences, ensure_ascii=False, indent=2) + ) + try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -124,7 +142,9 @@ def 
extract_user_preferences(self, topic_preferences: List[Dict[str, Any]]) -> D print(f"Error processing user preferences: {topic_preferences}\n{e}, return None") return "" - def _process_single_chunk_explicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem | None: + def _process_single_chunk_explicit( + self, chunk: MessageList, msg_type: str, info: dict[str, Any] + ) -> TextualMemoryItem | None: """Process a single chunk and return a TextualMemoryItem.""" basic_info = self.extract_basic_info(chunk) if not basic_info["dialog_str"]: @@ -139,29 +159,45 @@ def _process_single_chunk_explicit(self, chunk: MessageList, msg_type: str, info } extract_info = {**basic_info, **explicit_pref, **vector_info, **info} - metadata = PreferenceTextualMemoryMetadata(type=msg_type, preference_type="explicit_preference", **extract_info) - memory = TextualMemoryItem(id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata) + metadata = PreferenceTextualMemoryMetadata( + type=msg_type, preference_type="explicit_preference", **extract_info + ) + memory = TextualMemoryItem( + id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata + ) return memory - def _process_single_chunk_implicit(self, chunk: MessageList, msg_type: str, info: dict[str, Any]) -> TextualMemoryItem | None: + def _process_single_chunk_implicit( + self, chunk: MessageList, msg_type: str, info: dict[str, Any] + ) -> TextualMemoryItem | None: basic_info = self.extract_basic_info(chunk) if not basic_info["dialog_str"]: return None implicit_pref = self.extract_implicit_preferences(basic_info["dialog_str"]) if not implicit_pref: return None - + vector_info = { "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0], } extract_info = {**basic_info, **implicit_pref, **vector_info, **info} - metadata = PreferenceTextualMemoryMetadata(type=msg_type, preference_type="implicit_preference", **extract_info) - memory = 
TextualMemoryItem(id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata) + metadata = PreferenceTextualMemoryMetadata( + type=msg_type, preference_type="implicit_preference", **extract_info + ) + memory = TextualMemoryItem( + id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata + ) return memory - def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, Any], max_workers: int = 10) -> list[TextualMemoryItem]: + def extract( + self, + messages: list[MessageList], + msg_type: str, + info: dict[str, Any], + max_workers: int = 10, + ) -> list[TextualMemoryItem]: """Extract preference memories based on the messages using thread pool for acceleration.""" chunks_for_explicit: list[MessageList] = [] for message in messages: @@ -176,16 +212,26 @@ def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, An chunks_for_implicit.extend(chunk) memories = [] - with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks_for_explicit) + len(chunks_for_implicit))) as executor: + with ThreadPoolExecutor( + max_workers=min(max_workers, len(chunks_for_explicit) + len(chunks_for_implicit)) + ) as executor: futures = { - executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ("explicit", chunk) + executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( + "explicit", + chunk, + ) for chunk in chunks_for_explicit } - futures.update({ - executor.submit(self._process_single_chunk_implicit, chunk, msg_type, info): ("implicit", chunk) - for chunk in chunks_for_implicit - }) - + futures.update( + { + executor.submit(self._process_single_chunk_implicit, chunk, msg_type, info): ( + "implicit", + chunk, + ) + for chunk in chunks_for_implicit + } + ) + for future in as_completed(futures): try: memory = future.result() @@ -196,4 +242,4 @@ def extract(self, messages: list[MessageList], msg_type: str, info: dict[str, An print(f"Error processing 
{task_type} chunk: {chunk}\n{e}") continue - return memories \ No newline at end of file + return memories diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index 56259bacc..253f21662 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -1,26 +1,30 @@ from typing import Any, ClassVar from memos.memories.textual.prefer_text_memory.adder import BaseAdder, NaiveAdder +from memos.memories.textual.prefer_text_memory.assemble import BaseAssembler, NaiveAssembler +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + AssemblerConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, + UpdaterConfigFactory, +) from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor, NaiveExtractor from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever, NaiveRetriever from memos.memories.textual.prefer_text_memory.updater import BaseUpdater, NaiveUpdater -from memos.memories.textual.prefer_text_memory.assemble import BaseAssembler, NaiveAssembler -from memos.memories.textual.prefer_text_memory.config import (AdderConfigFactory, - ExtractorConfigFactory, - RetrieverConfigFactory, - UpdaterConfigFactory, - AssemblerConfigFactory) class AdderFactory(BaseAdder): """Factory class for creating Adder instances.""" - + backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveAdder, } - + @classmethod - def from_config(cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseAdder: + def from_config( + cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None + ) -> BaseAdder: """Create a Adder instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: @@ -28,15 +32,22 @@ def from_config(cls, config_factory: 
AdderConfigFactory, llm_provider=None, embe adder_class = cls.backend_to_class[backend] return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + class ExtractorFactory(BaseExtractor): """Factory class for creating Extractor instances.""" - + backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveExtractor, } - + @classmethod - def from_config(cls, config_factory: ExtractorConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseExtractor: + def from_config( + cls, + config_factory: ExtractorConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + ) -> BaseExtractor: """Create a Extractor instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: @@ -44,15 +55,22 @@ def from_config(cls, config_factory: ExtractorConfigFactory, llm_provider=None, extractor_class = cls.backend_to_class[backend] return extractor_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + class RetrieverFactory(BaseRetriever): """Factory class for creating Retriever instances.""" - + backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveRetriever, } - + @classmethod - def from_config(cls, config_factory: RetrieverConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseRetriever: + def from_config( + cls, + config_factory: RetrieverConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + ) -> BaseRetriever: """Create a Retriever instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: @@ -60,35 +78,51 @@ def from_config(cls, config_factory: RetrieverConfigFactory, llm_provider=None, retriever_class = cls.backend_to_class[backend] return retriever_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + class UpdaterFactory(BaseUpdater): """Factory class for creating Updater instances.""" - + backend_to_class: ClassVar[dict[str, Any]] = { 
"naive": NaiveUpdater, } - + @classmethod - def from_config(cls, config_factory: UpdaterConfigFactory, llm_provider=None, embedder=None, vector_db=None, extractor=None) -> BaseUpdater: + def from_config( + cls, + config_factory: UpdaterConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + extractor=None, + ) -> BaseUpdater: """Create a Updater instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") updater_class = cls.backend_to_class[backend] - return updater_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, extractor=extractor) + return updater_class( + llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, extractor=extractor + ) + class AssemblerFactory(BaseAssembler): """Factory class for creating Assembler instances.""" - + backend_to_class: ClassVar[dict[str, Any]] = { "naive": NaiveAssembler, } - + @classmethod - def from_config(cls, config_factory: AssemblerConfigFactory, llm_provider=None, embedder=None, vector_db=None) -> BaseAssembler: + def from_config( + cls, + config_factory: AssemblerConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + ) -> BaseAssembler: """Create a Assembler instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") assembler_class = cls.backend_to_class[backend] return assembler_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) - \ No newline at end of file diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 8245396d3..0e5572c94 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -1,30 +1,36 @@ from abc import ABC, abstractmethod -from typing import 
Any, Dict -from concurrent.futures import ThreadPoolExecutor, as_completed -from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem class BaseRetriever(ABC): """Abstract base class for retrievers.""" - + @abstractmethod def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the retriever.""" - @abstractmethod - def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[TextualMemoryItem]: + def retrieve( + self, query: str, top_k: int, info: dict[str, Any] | None = None + ) -> list[TextualMemoryItem]: """Retrieve memories from the retriever.""" + class NaiveRetriever(BaseRetriever): """Naive retriever.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None): """Initialize the naive retriever.""" super().__init__(llm_provider, embedder, vector_db) self.vector_db = vector_db self.embedder = embedder - def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[TextualMemoryItem]: + def retrieve( + self, query: str, top_k: int, info: dict[str, Any] | None = None + ) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" # TODO: un-support rewrite query and session filter now if info: @@ -33,35 +39,70 @@ def retrieve(self, query: str, top_k: int, info: dict[str, Any]=None) -> list[Te info.pop("session_id", None) query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings query_embedding = query_embeddings[0] # Get the first (and only) embedding - + # Use thread pool to parallelize the searches with ThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks - future_explicit = executor.submit(self.vector_db.search, query_embedding, "explicit_preference", top_k, info) - future_implicit = 
executor.submit(self.vector_db.search, query_embedding, "implicit_preference", top_k, info) - future_topic = executor.submit(self.vector_db.search, query_embedding, "topic_preference", top_k, info) - + future_explicit = executor.submit( + self.vector_db.search, query_embedding, "explicit_preference", top_k, info + ) + future_implicit = executor.submit( + self.vector_db.search, query_embedding, "implicit_preference", top_k, info + ) + future_topic = executor.submit( + self.vector_db.search, query_embedding, "topic_preference", top_k, info + ) + # Get user preferences directly (no vector search needed since there's only one per user) user_id = info.get("user_id") if info else None if user_id: - future_user = executor.submit(self.vector_db.get_by_filter, "user_preference", {"user_id": user_id}) + future_user = executor.submit( + self.vector_db.get_by_filter, "user_preference", {"user_id": user_id} + ) else: future_user = None - + # Wait for all results explicit_prefs = future_explicit.result() implicit_prefs = future_implicit.result() topic_prefs = future_topic.result() user_prefs = future_user.result() if future_user else [] - - explicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in explicit_prefs if pref.payload["explicit_preference"]] - implicit_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in implicit_prefs if pref.payload["implicit_preference"]] - topic_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("center_dialog", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in topic_prefs if pref.payload["topic_preference"]] - user_prefs = [TextualMemoryItem(id=pref.id, memory=pref.payload.get("user_preference", ""), - metadata=PreferenceTextualMemoryMetadata(**pref.payload)) for pref in user_prefs if 
pref.payload["user_preference"]] - - return explicit_prefs + implicit_prefs + topic_prefs + user_prefs + explicit_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in explicit_prefs + if pref.payload["explicit_preference"] + ] + implicit_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in implicit_prefs + if pref.payload["implicit_preference"] + ] + topic_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.payload.get("center_dialog", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in topic_prefs + if pref.payload["topic_preference"] + ] + user_prefs = [ + TextualMemoryItem( + id=pref.id, + memory=pref.payload.get("user_preference", ""), + metadata=PreferenceTextualMemoryMetadata(**pref.payload), + ) + for pref in user_prefs + if pref.payload["user_preference"] + ] + + return explicit_prefs + implicit_prefs + topic_prefs + user_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py index 8b5982c2d..59a6b0052 100644 --- a/src/memos/memories/textual/prefer_text_memory/spliter.py +++ b/src/memos/memories/textual/prefer_text_memory/spliter.py @@ -1,20 +1,25 @@ import copy -from memos.types import MessageList -from memos.parsers.factory import ParserFactory -from memos.configs.parser import ParserConfigFactory from memos.chunkers import ChunkerFactory from memos.configs.chunker import ChunkerConfigFactory +from memos.configs.parser import ParserConfigFactory +from memos.parsers.factory import ParserFactory +from memos.types import MessageList + class Splitter: """Splitter.""" - def __init__(self, lookback_turns: int = 1, - chunk_size: int = 256, - chunk_overlap: int = 128, - min_sentences_per_chunk: int = 1, - tokenizer: str = 
"gpt2", - parser_backend: str = "markitdown", - chunker_backend: str = "sentence"): + + def __init__( + self, + lookback_turns: int = 1, + chunk_size: int = 256, + chunk_overlap: int = 128, + min_sentences_per_chunk: int = 1, + tokenizer: str = "gpt2", + parser_backend: str = "markitdown", + chunker_backend: str = "sentence", + ): """Initialize the splitter.""" self.lookback_turns = lookback_turns self.chunk_size = chunk_size @@ -31,7 +36,7 @@ def __init__(self, lookback_turns: int = 1, } ) self.parser = ParserFactory.from_config(parser_config) - + # Initialize chunker chunker_config = ChunkerConfigFactory.model_validate( { @@ -40,37 +45,37 @@ def __init__(self, lookback_turns: int = 1, "tokenizer_or_token_counter": self.tokenizer, "chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap, - "min_sentences_per_chunk": self.min_sentences_per_chunk - } + "min_sentences_per_chunk": self.min_sentences_per_chunk, + }, } ) self.chunker = ChunkerFactory.from_config(chunker_config) def _split_with_lookback(self, data: MessageList) -> list[MessageList]: - """Split the messages or files into chunks by looking back fixed number of turns. + """Split the messages or files into chunks by looking back fixed number of turns. 
adjacent chunk with high duplicate rate, default lookback turns is 1, only current turn in chunk""" # Build QA pairs from chat history pairs = self.build_qa_pairs(data) chunks = [] - + # Create chunks by looking back fixed number of turns for i in range(len(pairs)): # Calculate the start index for lookback start_idx = max(0, i + 1 - self.lookback_turns) # Get the chunk of pairs (as many as available, up to lookback_turns) - chunk_pairs = pairs[start_idx:i+1] - + chunk_pairs = pairs[start_idx : i + 1] + # Flatten chunk_pairs (list[list[dict]]) to MessageList (list[dict]) chunk_messages = [] for pair in chunk_pairs: chunk_messages.extend(pair) - + chunks.append(chunk_messages) return chunks def _split_with_overlap(self, data: MessageList) -> list[MessageList]: - """split the messages or files into chunks with overlap. + """split the messages or files into chunks with overlap. adjacent chunk with low duplicate rate""" chunks = [] chunk = [] @@ -87,13 +92,12 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]: return chunks - def split_chunks(self, data: MessageList | str, **kwargs) -> list[MessageList] | list[str]: """Split the messages or files into chunks. 
- + Args: data: MessageList or string to split - + Returns: List of MessageList chunks or list of string chunks """ @@ -107,15 +111,14 @@ def split_chunks(self, data: MessageList | str, **kwargs) -> list[MessageList] | # Parse and chunk the string data using pre-initialized components text = self.parser.parse(data) chunks = self.chunker.chunk(text) - - return [chunk.text for chunk in chunks] + return [chunk.text for chunk in chunks] def build_qa_pairs(self, chat_history: MessageList) -> list[MessageList]: """Build QA pairs from chat history.""" qa_pairs = [] current_qa_pair = [] - + for message in chat_history: if message["role"] == "user": current_qa_pair.append(message) diff --git a/src/memos/memories/textual/prefer_text_memory/strategies.py b/src/memos/memories/textual/prefer_text_memory/strategies.py deleted file mode 100644 index 11c8f44e0..000000000 --- a/src/memos/memories/textual/prefer_text_memory/strategies.py +++ /dev/null @@ -1,18 +0,0 @@ -from enum import Enum - -class BuildStrategy(Enum): - """Strategy for building memories.""" - NAIVE = "naive" - -class RetrieveStrategy(Enum): - """Strategy for retrieving memories.""" - NAIVE = "naive" - -class UpdateStrategy(Enum): - """Strategy for updating memories.""" - NAIVE = "naive" - -class AssembleStrategy(Enum): - """Strategy for assembling memories.""" - NAIVE = "naive" - diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index b30a59242..18b4ae480 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -1,25 +1,26 @@ -from abc import ABC, abstractmethod -from typing import Any, List, Dict -from datetime import datetime -import uuid import json +import uuid + +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed -from memos.types import MessageList +from datetime import datetime +from typing import Any + 
+from memos.memories.textual.prefer_text_memory.clustering import ClusterResult, HDBSCANClusterer from memos.vec_dbs.item import VecDBItem -from memos.memories.textual.prefer_text_memory.clustering import HDBSCANClusterer, ClusterResult class BaseUpdater(ABC): """Abstract base class for updaters.""" - + @abstractmethod def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None): """Initialize the updater.""" - class NaiveUpdater(BaseUpdater): """Naive updater.""" + def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None): """Initialize the naive updater.""" super().__init__(llm_provider, embedder, vector_db, extractor) @@ -29,7 +30,7 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=N self.extractor = extractor self.clusterer = HDBSCANClusterer() - def _implicit_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _implicit_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, Any]]: """Implicit cluster.""" vectors = [info.get("dialog_vector") for info in informations] if not vectors: @@ -39,8 +40,8 @@ def _implicit_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str cluster.center_dialog_msgs = informations[cluster.center_index].get("dialog_msgs", []) cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") return res - - def _topic_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + + def _topic_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, Any]]: """Topic cluster.""" vectors = [info.get("topic_vector") for info in informations] if not vectors: @@ -51,10 +52,16 @@ def _topic_cluster(self, informations: List[Dict[str, Any]]) -> List[Dict[str, A cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") return res - def _create_cluster_extract_input(self, cluster_results: List[ClusterResult], informations: List[Dict[str, 
Any]], input_type: str, k: int = 5) -> Dict[str, List[str]]: + def _create_cluster_extract_input( + self, + cluster_results: list[ClusterResult], + informations: list[dict[str, Any]], + input_type: str, + k: int = 5, + ) -> dict[str, list[str]]: """Create cluster extract input. Args: - cluster_results: List[ClusterResult] + cluster_results: List[ClusterResult] informations: List[Dict[str, Any]] containing dialog information input_type: str, "original" or "knn" k: int, number of nearest neighbors for knn mode @@ -76,41 +83,45 @@ def _create_cluster_extract_input(self, cluster_results: List[ClusterResult], in cluster_dialogs.append(dialog_str) if cluster_dialogs: result[cluster.cluster_id] = cluster_dialogs - + elif input_type == "knn": # Use knn search from cluster center to find k nearest neighbors for cluster in cluster_results: # Extract all vectors from items in this cluster item_vectors = [item["vector"] for item in cluster.items] - + if not item_vectors: result[cluster.cluster_id] = [] continue - + # Use clusterer to find k nearest neighbors to cluster center knn_results = self.clusterer.search_knn_by_center_embeddings( center_emb=cluster.center_vector, vectors=item_vectors, - top_k=min(k, len(item_vectors)) + top_k=min(k, len(item_vectors)), ) - + # Get dialog strings for knn results cluster_dialogs = [] for knn_item in knn_results: # knn_item contains {"index": idx_in_item_vectors, "distance": ..., "vector": ...} item_idx = knn_item["index"] # index in item_vectors - original_index = cluster.items[item_idx]["index"] # original index in informations + original_index = cluster.items[item_idx][ + "index" + ] # original index in informations dialog_str = informations[original_index].get("dialog_str", "") cluster_dialogs.append(dialog_str) if cluster_dialogs: result[cluster.cluster_id] = cluster_dialogs - + else: raise ValueError(f"Invalid input type: {input_type}") - + return result - def _process_single_implicit_cluster(self, cluster_id: str, cluster_dialogs: 
List[str]) -> Dict[str, Any]: + def _process_single_implicit_cluster( + self, cluster_id: str, cluster_dialogs: list[str] + ) -> dict[str, Any]: """Process a single implicit cluster.""" try: result = self.extractor.extract_implicit_preferences(cluster_dialogs) @@ -119,18 +130,22 @@ def _process_single_implicit_cluster(self, cluster_id: str, cluster_dialogs: Lis print(f"Error processing implicit cluster {cluster_id}: {e}") return {"cluster_id": cluster_id, "implicit_exract_result": None} - def _extract_implicit_preferences(self, implicit_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> Dict[str, Dict[str, Any]]: + def _extract_implicit_preferences( + self, implicit_extract_inputs: dict[str, list[str]], max_workers: int = 10 + ) -> dict[str, dict[str, Any]]: """Extract implicit preferences from implicit extract inputs using thread pool.""" if not implicit_extract_inputs: return {} - + results = {} - with ThreadPoolExecutor(max_workers=min(max_workers, len(implicit_extract_inputs))) as executor: + with ThreadPoolExecutor( + max_workers=min(max_workers, len(implicit_extract_inputs)) + ) as executor: futures = [ executor.submit(self._process_single_implicit_cluster, cluster_id, cluster_dialogs) for cluster_id, cluster_dialogs in implicit_extract_inputs.items() ] - + for future in as_completed(futures): try: result = future.result() @@ -140,10 +155,12 @@ def _extract_implicit_preferences(self, implicit_extract_inputs: Dict[str, List[ except Exception as e: print(f"Error processing implicit cluster: {e}") continue - + return results - - def _process_single_topic_cluster(self, cluster_id: str, cluster_dialogs: List[str]) -> Dict[str, Any]: + + def _process_single_topic_cluster( + self, cluster_id: str, cluster_dialogs: list[str] + ) -> dict[str, Any]: """Process a single topic cluster.""" try: result = self.extractor.extract_topic_preferences(cluster_dialogs) @@ -152,18 +169,22 @@ def _process_single_topic_cluster(self, cluster_id: str, cluster_dialogs: 
List[s print(f"Error processing topic cluster {cluster_id}: {e}") return {"cluster_id": cluster_id, "topic_exract_result": None} - def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]], max_workers: int = 10) -> Dict[str, Dict[str, Any]]: + def _extract_topic_preferences( + self, topic_extract_inputs: dict[str, list[str]], max_workers: int = 10 + ) -> dict[str, dict[str, Any]]: """Extract topic preferences from topic extract inputs using thread pool.""" if not topic_extract_inputs: return {} - + results = {} - with ThreadPoolExecutor(max_workers=min(max_workers, len(topic_extract_inputs))) as executor: + with ThreadPoolExecutor( + max_workers=min(max_workers, len(topic_extract_inputs)) + ) as executor: futures = [ executor.submit(self._process_single_topic_cluster, cluster_id, cluster_dialogs) for cluster_id, cluster_dialogs in topic_extract_inputs.items() ] - + for future in as_completed(futures): try: result = future.result() @@ -173,22 +194,26 @@ def _extract_topic_preferences(self, topic_extract_inputs: Dict[str, List[str]], except Exception as e: print(f"Error processing topic cluster: {e}") continue - + return results - def _extract_user_preferences(self, topic_cluster_pref_infos: Dict[str, Any]) -> Dict[str, Any] | None: + def _extract_user_preferences( + self, topic_cluster_pref_infos: dict[str, Any] + ) -> dict[str, Any] | None: """Extract user preferences from topic cluster info.""" # topic_prefs is a dict, so we just pass the values as a list topic_cluster_pref = list(topic_cluster_pref_infos.values()) return self.extractor.extract_user_preferences(topic_cluster_pref) - def _store_preferences(self, - implicit_clusters, - topic_clusters, - implicit_cluster_prefs, - topic_cluster_prefs, - user_prefs, - user_id): + def _store_preferences( + self, + implicit_clusters, + topic_clusters, + implicit_cluster_prefs, + topic_cluster_prefs, + user_prefs, + user_id, + ): """Create store data.""" implicit_memories = [] topic_memories = [] @@ 
-196,7 +221,9 @@ def _store_preferences(self, if implicit_clusters: for cluster in implicit_clusters: if cluster.cluster_id not in implicit_cluster_prefs: - print(f"Warning: No preference found for cluster {cluster.cluster_id}, skipping...") + print( + f"Warning: No preference found for cluster {cluster.cluster_id}, skipping..." + ) continue pref = implicit_cluster_prefs[cluster.cluster_id] mem = VecDBItem( @@ -211,8 +238,8 @@ def _store_preferences(self, "created_at": cluster.created_at, "user_id": user_id, "size": cluster.size, - "preference_type": "implicit_preference" - } + "preference_type": "implicit_preference", + }, ) implicit_memories.append(mem) @@ -221,7 +248,9 @@ def _store_preferences(self, if topic_clusters: for cluster in topic_clusters: if cluster.cluster_id not in topic_cluster_prefs: - print(f"Warning: No preference found for topic cluster {cluster.cluster_id}, skipping...") + print( + f"Warning: No preference found for topic cluster {cluster.cluster_id}, skipping..." + ) continue pref = topic_cluster_prefs[cluster.cluster_id] mem = VecDBItem( @@ -238,8 +267,8 @@ def _store_preferences(self, "created_at": cluster.created_at, "user_id": user_id, "size": cluster.size, - "preference_type": "topic_preference" - } + "preference_type": "topic_preference", + }, ) topic_memories.append(mem) @@ -253,15 +282,18 @@ def _store_preferences(self, "user_id": user_id, "user_preference": user_prefs.get("user_preference", ""), "created_at": datetime.now().isoformat(), - "preference_type": "user_preference" - } + "preference_type": "user_preference", + }, ) self.vector_db.add("user_preference", [mem]) - def _generate_memory_summary(self, explicit_infos: List[Dict[str, Any]], - implicit_infos: List[Dict[str, Any]], - topic_infos: List[Dict[str, Any]], - user_infos: Dict[str, Any]) -> str: + def _generate_memory_summary( + self, + explicit_infos: list[dict[str, Any]], + implicit_infos: list[dict[str, Any]], + topic_infos: list[dict[str, Any]], + user_infos: dict[str, 
Any], + ) -> str: """Generate a summary of the built memory.""" summary = { "memory_build_summary": { @@ -269,22 +301,37 @@ def _generate_memory_summary(self, explicit_infos: List[Dict[str, Any]], "implicit_preferences_count": len(implicit_infos), "topic_preferences_count": len(topic_infos), "user_preferences_count": 1 if user_infos else 0, - "build_timestamp": datetime.now().isoformat() + "build_timestamp": datetime.now().isoformat(), } } - + return json.dumps(summary, ensure_ascii=False, indent=2) - + def slow_update(self, user_id: str): - """Retrieve all dialog info from the expicit preference collection, + """Retrieve all dialog info from the expicit preference collection, and reconstruct the implicit preference collection, topic collection and user preference collection. """ # refresh the implicit preference collection, topic collection and user preference collection - impl_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="implicit_preference", filter={"user_id": user_id})] - topic_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="topic_preference", filter={"user_id": user_id})] - user_ids = [item.id for item in self.vector_db.get_by_filter(collection_name="user_preference", filter={"user_id": user_id})] - + impl_ids = [ + item.id + for item in self.vector_db.get_by_filter( + collection_name="implicit_preference", filter={"user_id": user_id} + ) + ] + topic_ids = [ + item.id + for item in self.vector_db.get_by_filter( + collection_name="topic_preference", filter={"user_id": user_id} + ) + ] + user_ids = [ + item.id + for item in self.vector_db.get_by_filter( + collection_name="user_preference", filter={"user_id": user_id} + ) + ] + self.vector_db.delete("implicit_preference", impl_ids) self.vector_db.delete("topic_preference", topic_ids) self.vector_db.delete("user_preference", user_ids) @@ -292,26 +339,27 @@ def slow_update(self, user_id: str): # get all data from explicit preference collection all_data = 
self.vector_db.get_by_filter("explicit_preference", filter={"user_id": user_id}) informations = [item.payload for item in all_data] - + # Perform clustering implicit_clusters = self._implicit_cluster(informations) topic_clusters = self._topic_cluster(informations) # create extract inputs for each implicit and topic cluster - implicit_extract_inputs = self._create_cluster_extract_input(implicit_clusters, informations, "original") - topic_extract_inputs = self._create_cluster_extract_input(topic_clusters, informations, "original") - + implicit_extract_inputs = self._create_cluster_extract_input( + implicit_clusters, informations, "original" + ) + topic_extract_inputs = self._create_cluster_extract_input( + topic_clusters, informations, "original" + ) + # Extract preferences if implicit_extract_inputs: implicit_cluster_prefs = self._extract_implicit_preferences(implicit_extract_inputs) if topic_extract_inputs: topic_cluster_prefs = self._extract_topic_preferences(topic_extract_inputs) - + # Extract user preferences user_prefs = self._extract_user_preferences(topic_cluster_prefs) - - - # Store all preferences in memory self._store_preferences( @@ -322,7 +370,7 @@ def slow_update(self, user_id: str): user_prefs=user_prefs, user_id=user_id, ) - + # Return summary of built memory return self._generate_memory_summary( explicit_infos=informations, @@ -330,6 +378,3 @@ def slow_update(self, user_id: str): topic_infos=topic_clusters, user_infos=user_prefs, ) - - - diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index 7c4feced2..e0752d899 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -1,7 +1,6 @@ from memos.types import MessageList - def convert_messages_to_string(messages: MessageList) -> str: """Convert a list of messages to a string.""" message_text = "" @@ -11,4 +10,4 @@ def convert_messages_to_string(messages: 
MessageList) -> str: elif message["role"] == "assistant": message_text += f"Answer: {message['content']}\n" if message["content"].strip() else "" message_text = message_text.strip() - return message_text \ No newline at end of file + return message_text diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index b55d6307f..2a6d09224 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -1,31 +1,30 @@ import json import os -import shutil -import tempfile -from datetime import datetime -from pathlib import Path from typing import Any from memos.configs.memory import PreferenceTextMemoryConfig -from memos.embedders.factory import (EmbedderFactory, - OllamaEmbedder, - ArkEmbedder, - SenTranEmbedder, - UniversalAPIEmbedder) +from memos.embedders.factory import ( + ArkEmbedder, + EmbedderFactory, + OllamaEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder, +) from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.base import BaseTextMemory -from memos.memories.textual.item import TextualMemoryItem, PreferenceTextualMemoryMetadata -from memos.types import ChatHistory, MessageList -from memos.llms.base import BaseLLM -from memos.vec_dbs.factory import QdrantVecDB, VecDBFactory, MilvusVecDB +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + AssemblerFactory, + ExtractorFactory, + RetrieverFactory, + UpdaterFactory, +) +from memos.types import MessageList +from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB, VecDBFactory from memos.vec_dbs.item import VecDBItem -from memos.memories.textual.prefer_text_memory.factory import (AdderFactory, - ExtractorFactory, - RetrieverFactory, - UpdaterFactory, - AssemblerFactory) logger = get_logger(__name__) @@ -41,43 +40,46 @@ def 
__init__(self, config: PreferenceTextMemoryConfig): config.extractor_llm ) self.vector_db: MilvusVecDB | QdrantVecDB = VecDBFactory.from_config(config.vector_db) - self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = \ + self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = ( EmbedderFactory.from_config(config.embedder) + ) self.extractor = ExtractorFactory.from_config( config.extractor, llm_provider=self.extractor_llm, embedder=self.embedder, - vector_db=self.vector_db + vector_db=self.vector_db, ) self.adder = AdderFactory.from_config( - config.adder, + config.adder, llm_provider=self.extractor_llm, embedder=self.embedder, - vector_db=self.vector_db + vector_db=self.vector_db, ) self.retriever = RetrieverFactory.from_config( config.retriever, llm_provider=self.extractor_llm, embedder=self.embedder, - vector_db=self.vector_db + vector_db=self.vector_db, ) self.updater = UpdaterFactory.from_config( config.updater, llm_provider=self.extractor_llm, embedder=self.embedder, vector_db=self.vector_db, - extractor=self.extractor + extractor=self.extractor, ) self.assembler = AssemblerFactory.from_config( config.assembler, llm_provider=self.extractor_llm, embedder=self.embedder, - vector_db=self.vector_db + vector_db=self.vector_db, ) - def get_memory(self, messages: MessageList, type: str, info: dict[str, Any]) -> list[TextualMemoryItem]: + def get_memory( + self, messages: MessageList, type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: """Get memory based on the messages. Args: messages (MessageList): The messages to get memory from. @@ -94,7 +96,7 @@ def slow_update(self, user_id: str) -> str: str: Summary of the memory build process. """ return self.updater.slow_update(user_id) - + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. 
Args: @@ -106,7 +108,6 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem """ return self.retriever.retrieve(query, top_k, info) - def get_prompt(self, query: str, memories: list[TextualMemoryItem]) -> str: """Construct the prompt for the query with memories. Args: @@ -167,7 +168,9 @@ def dump(self, dir: str) -> None: with open(memory_file, "w", encoding="utf-8") as f: json.dump(json_memories, f, indent=4, ensure_ascii=False) - logger.info(f"Dumped {len(json_memories)} collections, {sum(len(items) for items in json_memories.values())} memories to {memory_file}") + logger.info( + f"Dumped {len(json_memories)} collections, {sum(len(items) for items in json_memories.values())} memories to {memory_file}" + ) except Exception as e: logger.error(f"An error occurred while dumping memories: {e}") @@ -189,11 +192,11 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: memories: List of TextualMemoryItem objects or dictionaries to add. """ return self.adder.add(memories) - + def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: """Update a memory by memory_id.""" raise NotImplementedError - + def get(self, memory_id: str) -> TextualMemoryItem: """Get a memory by its ID. Args: @@ -203,7 +206,9 @@ def get(self, memory_id: str) -> TextualMemoryItem: """ raise NotImplementedError - def get_with_collection_name(self, collection_name: str, memory_id: str) -> TextualMemoryItem: + def get_with_collection_name( + self, collection_name: str, memory_id: str + ) -> TextualMemoryItem | None: """Get a memory by its ID and collection name. Args: memory_id (str): The ID of the memory to retrieve. 
@@ -214,14 +219,18 @@ def get_with_collection_name(self, collection_name: str, memory_id: str) -> Text try: res = self.vector_db.get_by_id(collection_name, memory_id) if res is None: - raise ValueError(f"Memory with ID {memory_id} not found in collection {collection_name}") - return TextualMemoryItem(id=res.id, - memory=res.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**res.payload)) + return None + return TextualMemoryItem( + id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload), + ) except Exception as e: # Convert any other exception to ValueError for consistent error handling - raise ValueError(f"Memory with ID {memory_id} not found in collection {collection_name}: {e}") - + raise ValueError( + f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" + ) from e + def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """Get memories by their IDs. Args: @@ -231,7 +240,9 @@ def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """ raise NotImplementedError - def get_by_ids_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> list[TextualMemoryItem]: + def get_by_ids_with_collection_name( + self, collection_name: str, memory_ids: list[str] + ) -> list[TextualMemoryItem]: """Get memories by their IDs and collection name. Args: collection_name (str): The name of the collection to retrieve the memory from. 
@@ -242,14 +253,21 @@ def get_by_ids_with_collection_name(self, collection_name: str, memory_ids: list try: res = self.vector_db.get_by_ids(collection_name, memory_ids) if not res: - raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}") - return [TextualMemoryItem(id=memo.id, - memory=memo.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in res] + return [] + return [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in res + ] except Exception as e: # Convert any other exception to ValueError for consistent error handling - raise ValueError(f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}") - + raise ValueError( + f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" + ) from e + def get_all(self) -> list[TextualMemoryItem]: """Get all memories. Returns: @@ -259,11 +277,16 @@ def get_all(self) -> list[TextualMemoryItem]: all_memories = {} for collection_name in all_collections: items = self.vector_db.get_all(collection_name) - all_memories[collection_name] = [TextualMemoryItem(id=memo.id, - memory=memo.payload.get("dialog_str", ""), - metadata=PreferenceTextualMemoryMetadata(**memo.payload)) for memo in items] + all_memories[collection_name] = [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in items + ] return all_memories - + def delete(self, memory_ids: list[str]) -> None: """Delete memories. Args: @@ -278,15 +301,15 @@ def delete_with_collection_name(self, collection_name: str, memory_ids: list[str memory_ids (list[str]): List of memory IDs to delete. 
""" self.vector_db.delete(collection_name, memory_ids) - + def delete_all(self) -> None: """Delete all memories.""" for collection_name in self.vector_db.config.collection_name: self.vector_db.delete_collection(collection_name) self.vector_db.create_collection() - + def drop( self, ) -> None: """Drop all databases.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 8b439c415..3f2e6b2f1 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -1,21 +1,18 @@ - - - NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT = """ -You are a preference extraction assistant. -Please extract the user's explicitly mentioned preferences from the following conversation. +You are a preference extraction assistant. +Please extract the user's explicitly mentioned preferences from the following conversation. -Notes: -- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". -- Any clearly expressed inclination, desire, rejection, or priority counts as an explicit preference. +Notes: +- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". +- Any clearly expressed inclination, desire, rejection, or priority counts as an explicit preference. -Requirements: -1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. -2. Output should be a concise natural language summary, not a list or categories. -3. If there are no explicit preferences in the conversation, output an empty string "". -4. Output only the preference statements themselves, without any additional explanation. +Requirements: +1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. +2. 
Output should be a concise natural language summary, not a list or categories. +3. If there are no explicit preferences in the conversation, output an empty string "". +4. Output only the preference statements themselves, without any additional explanation. -Conversation: +Conversation: {qa_pair} Output format: @@ -29,7 +26,7 @@ NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ -You are a preference inference assistant. Please extract **implicit preferences** from the following conversation +You are a preference inference assistant. Please extract **implicit preferences** from the following conversation (preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). Notes: @@ -37,16 +34,16 @@ - Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. Requirements: -1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. 2. Output a concise natural language statement; do not use lists, categories, or include the reasoning process. 3. Inferred implicit preferences must not conflict with explicit preferences. -4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. +4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field. 5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else). 
Conversation: {qa_pair} -Output format: +Output format: ```json { "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string", @@ -60,18 +57,18 @@ NAIVE_EXPLICIT_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ You are a preference extraction and inference assistant. Please extract the user's preferences from the following conversation, including: -1. **Explicit preferences**: Preferences that the user directly expresses, such as likes, dislikes, wants, does not want, or prioritized choices. +1. **Explicit preferences**: Preferences that the user directly expresses, such as likes, dislikes, wants, does not want, or prioritized choices. 2. **Implicit preferences**: Preferences that are not explicitly stated but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices. Notes: -- For explicit preferences, only extract what the user directly states, do not infer. -- For implicit preferences, only infer when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +- For explicit preferences, only extract what the user directly states, do not infer. +- For implicit preferences, only infer when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. - Do not duplicate: do not treat explicit preferences as implicit preferences. Requirements: -1. Output in JSON format with two fields: "explicit_preferences" and "implicit_preferences". -2. Each field should be an array, with each element being a concise natural language preference statement. -3. Output only the preference statements themselves; do not include any extra explanation, reasoning, or confidence information. +1. Output in JSON format with two fields: "explicit_preferences" and "implicit_preferences". +2. Each field should be an array, with each element being a concise natural language preference statement. +3. 
Output only the preference statements themselves; do not include any extra explanation, reasoning, or confidence information. 4. If a type of preference does not exist, its array should be empty. Conversation: @@ -88,7 +85,6 @@ """ - NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT = """ You are a topic cluster analysis expert. You will be given a topic cluster containing multiple QA conversations with their preferences, topic names, and descriptions. Your task is to analyze all information in this cluster and extract three key components, then output in strict JSON format. @@ -259,4 +255,4 @@ Memories: {memories} -""" \ No newline at end of file +""" diff --git a/src/memos/vec_dbs/factory.py b/src/memos/vec_dbs/factory.py index d8545027c..f2950b4ea 100644 --- a/src/memos/vec_dbs/factory.py +++ b/src/memos/vec_dbs/factory.py @@ -2,8 +2,8 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.vec_dbs.base import BaseVecDB -from memos.vec_dbs.qdrant import QdrantVecDB from memos.vec_dbs.milvus import MilvusVecDB +from memos.vec_dbs.qdrant import QdrantVecDB class VecDBFactory(BaseVecDB): diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index e612b94b1..a46c4df79 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -1,5 +1,7 @@ from typing import Any -from pymilvus import MilvusClient, DataType + +from pymilvus import DataType, MilvusClient + from memos.configs.vec_db import MilvusVecDBConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -24,7 +26,9 @@ def __init__(self, config: MilvusVecDBConfig): self.config = config # Create Milvus client - self.client = MilvusClient(uri=self.config.uri, user=self.config.user_name, password=self.config.password) + self.client = MilvusClient( + uri=self.config.uri, user=self.config.user_name, password=self.config.password + ) self.schema = self.create_schema() self.index_params = self.create_index() self.create_collection() @@ -32,8 +36,12 @@ 
def __init__(self, config: MilvusVecDBConfig): def create_schema(self): """Create schema for the milvus collection.""" schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) - schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True) - schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension) + schema.add_field( + field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True + ) + schema.add_field( + field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension + ) schema.add_field(field_name="payload", datatype=DataType.JSON) return schema @@ -41,18 +49,17 @@ def create_schema(self): def create_index(self): """Create index for the milvus collection.""" index_params = self.client.prepare_index_params() - index_params.add_index(field_name="vector", index_type="FLAT", metric_type=self._get_metric_type()) + index_params.add_index( + field_name="vector", index_type="FLAT", metric_type=self._get_metric_type() + ) return index_params - def create_collection(self) -> None: """Create a new collection with specified parameters.""" for collection_name in self.config.collection_name: if self.collection_exists(collection_name): - logger.warning( - f"Collection '{collection_name}' already exists. Skipping creation." - ) + logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.") continue self.client.create_collection( @@ -70,11 +77,9 @@ def create_collection(self) -> None: def create_collection_by_name(self, collection_name: str) -> None: """Create a new collection with specified parameters.""" if self.collection_exists(collection_name): - logger.warning( - f"Collection '{collection_name}' already exists. Skipping creation." - ) + logger.warning(f"Collection '{collection_name}' already exists. 
Skipping creation.") return - + self.client.create_collection( collection_name=collection_name, dimension=self.config.vector_dimension, @@ -96,7 +101,11 @@ def collection_exists(self, name: str) -> bool: return self.client.has_collection(collection_name=name) def search( - self, query_vector: list[float], collection_name: str, top_k: int, filter: dict[str, Any] | None = None + self, + query_vector: list[float], + collection_name: str, + top_k: int, + filter: dict[str, Any] | None = None, ) -> list[VecDBItem]: """ Search for similar items in the database. @@ -112,7 +121,7 @@ def search( """ # Convert filter to Milvus expression expr = self._dict_to_expr(filter) if filter else "" - + results = self.client.search( collection_name=collection_name, data=[query_vector], @@ -120,18 +129,20 @@ def search( filter=expr, output_fields=["*"], # Return all fields ) - + items = [] for hit in results[0]: entity = hit.get("entity", {}) - - items.append(VecDBItem( - id=str(hit["id"]), - vector=entity.get("vector"), - payload=entity.get("payload", {}), - score=1-float(hit["distance"]), - )) - + + items.append( + VecDBItem( + id=str(hit["id"]), + vector=entity.get("vector"), + payload=entity.get("payload", {}), + score=1 - float(hit["distance"]), + ) + ) + logger.info(f"Milvus search completed with {len(items)} results.") return items @@ -139,7 +150,7 @@ def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: """Convert a dictionary filter to a Milvus expression string.""" if not filter_dict: return "" - + conditions = [] for field, value in filter_dict.items(): # Skip None values as they cause Milvus query syntax errors @@ -178,7 +189,7 @@ def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: entity = results[0] payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - + return VecDBItem( id=entity["id"], vector=entity.get("vector"), @@ -198,15 +209,19 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: 
items = [] for entity in results: payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - items.append(VecDBItem( - id=entity["id"], - vector=entity.get("vector"), - payload=payload, - )) - + items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) + return items - def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100) -> list[VecDBItem]: + def get_by_filter( + self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 + ) -> list[VecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. @@ -219,7 +234,7 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim """ expr = self._dict_to_expr(filter) if filter else "" all_items = [] - + # Use query_iterator for efficient pagination iterator = self.client.query_iterator( collection_name=collection_name, @@ -227,27 +242,30 @@ def get_by_filter(self, collection_name: str, filter: dict[str, Any], scroll_lim batch_size=scroll_limit, output_fields=["*"], # Include all fields including payload ) - + # Iterate through all batches try: while True: batch_results = iterator.next() - + if not batch_results: break - + # Convert batch results to VecDBItem objects for entity in batch_results: # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) - all_items.append(VecDBItem( - id=entity["id"], - vector=entity.get("vector"), - payload=payload, - )) + all_items.append( + VecDBItem( + id=entity["id"], + vector=entity.get("vector"), + payload=payload, + ) + ) except Exception as e: - logger.warning(f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far.") - # 返回已经找到的项目,而不是空列表 + logger.warning( + f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." 
+ ) finally: # Close the iterator iterator.close() @@ -291,14 +309,14 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N if isinstance(item, dict): item = item.copy() item = VecDBItem.from_dict(item) - + # Prepare entity data entity = { "id": item.id, "vector": item.vector, - "payload": item.payload if item.payload else {} + "payload": item.payload if item.payload else {}, } - + entities.append(entity) # Use upsert to be safe (insert or update) @@ -345,4 +363,4 @@ def delete(self, collection_name: str, ids: list[str]) -> None: self.client.delete( collection_name=collection_name, ids=ids, - ) \ No newline at end of file + ) diff --git a/tests/configs/test_mem_cube.py b/tests/configs/test_mem_cube.py index 6c962dd01..c50195558 100644 --- a/tests/configs/test_mem_cube.py +++ b/tests/configs/test_mem_cube.py @@ -28,7 +28,7 @@ def test_base_mem_cube_config(): def test_general_mem_cube_config(): check_config_base_class( GeneralMemCubeConfig, - factory_fields=["text_mem", "act_mem", "para_mem"], + factory_fields=["text_mem", "act_mem", "para_mem", "pref_mem"], required_fields=[], optional_fields=["config_filename", "user_id", "cube_id"], reserved_fields=["model_schema"], From 54a6b23f470de925c2512b6d61918cf9687c6578 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 29 Sep 2025 20:58:59 +0800 Subject: [PATCH 29/62] repair bug in preference --- src/memos/mem_os/core.py | 245 +++++++++++++---------- src/memos/memories/textual/preference.py | 3 +- 2 files changed, 140 insertions(+), 108 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index d96b555fd..9ae4176b6 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -6,6 +6,7 @@ from pathlib import Path from threading import Lock from typing import Any, Literal +from concurrent.futures import ThreadPoolExecutor from memos.configs.mem_os import MOSConfig from memos.llms.factory import LLMFactory @@ -590,57 +591,77 @@ def search( 
tmp_mem_cubes[mem_cube_id] = self.mem_cubes.get(mem_cube_id) for mem_cube_id, mem_cube in tmp_mem_cubes.items(): - if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.text_mem is not None) - and self.config.enable_textual_memory - ): - time_start = time.time() - memories = mem_cube.text_mem.search( - query, - top_k=top_k if top_k else self.config.top_k, - mode=mode, - manual_close_internet=not internet_search, - info={ - "user_id": target_user_id, - "session_id": self.session_id, - "chat_history": chat_history.chat_history, - }, - moscube=moscube, - ) - result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories}) - logger.info( - f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" - ) - search_time_end = time.time() - logger.info( - f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" - ) - - if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.pref_mem is not None) - and self.config.enable_preference_memory - ): - time_start = time.time() - memories = mem_cube.pref_mem.search( - query, - top_k=top_k if top_k else self.config.top_k, - mode=mode, - info={ - "user_id": target_user_id, - "session_id": self.session_id, - "chat_history": chat_history.chat_history, - }, - ) - result["pref_mem"].append({"cube_id": mem_cube_id, "memories": memories}) - logger.info( - f"🧠 [Memory] Searched preferences from {mem_cube_id}:\n{self._str_memories(memories)}\n" - ) - search_time_end = time.time() - logger.info( - f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" - ) + # Define internal functions for parallel search execution + def search_textual_memory(): + if ( + (mem_cube_id in install_cube_ids) + and (mem_cube.text_mem is not None) + and self.config.enable_textual_memory + ): + time_start = time.time() + memories = mem_cube.text_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + mode=mode, + 
manual_close_internet=not internet_search, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, + moscube=moscube, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": mem_cube_id, "memories": memories} + return None + + def search_preference_memory(): + if ( + (mem_cube_id in install_cube_ids) + and (mem_cube.pref_mem is not None) + and self.config.enable_preference_memory + ): + time_start = time.time() + memories = mem_cube.pref_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + mode=mode, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched preferences from {mem_cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": mem_cube_id, "memories": memories} + return None + + # Execute both search functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(search_textual_memory) + pref_future = executor.submit(search_preference_memory) + + # Wait for both tasks to complete and collect results + text_result = text_future.result() + pref_result = pref_future.result() + + # Add results to the main result dictionary + if text_result is not None: + result["text_mem"].append(text_result) + if pref_result is not None: + result["pref_mem"].append(pref_result) return result @@ -682,65 +703,77 @@ def add( if mem_cube_id not in self.mem_cubes: raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. 
Please register.") - if ( - (messages is not None) - and self.config.enable_textual_memory - and self.mem_cubes[mem_cube_id].text_mem - ): - if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": - add_memory = [] - metadata = TextualMemoryMetadata( - user_id=target_user_id, session_id=self.session_id, source="conversation" - ) - for message in messages: - add_memory.append( - TextualMemoryItem(memory=message["content"], metadata=metadata) + + # Define internal functions for parallel execution + def process_textual_memory(): + if ( + (messages is not None) + and self.config.enable_textual_memory + and self.mem_cubes[mem_cube_id].text_mem + ): + if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": + add_memory = [] + metadata = TextualMemoryMetadata( + user_id=target_user_id, session_id=self.session_id, source="conversation" ) - self.mem_cubes[mem_cube_id].text_mem.add(add_memory) - else: - messages_list = [messages] - memories = self.mem_reader.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, - ) - - mem_ids = [] - for mem in memories: - mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) - mem_ids.extend(mem_id_list) - logger.info( - f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_id_list}" + for message in messages: + add_memory.append( + TextualMemoryItem(memory=message["content"], metadata=metadata) + ) + self.mem_cubes[mem_cube_id].text_mem.add(add_memory) + else: + messages_list = [messages] + memories = self.mem_reader.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}, ) - # submit messages for scheduler - if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - 
content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + mem_ids = [] + for mem in memories: + mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_ids.extend(mem_id_list) + logger.info( + f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_id_list}" + ) + + # submit messages for scheduler + if self.enable_mem_scheduler and self.mem_scheduler is not None: + mem_cube = self.mem_cubes[mem_cube_id] + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + + def process_preference_memory(): + if ( + (messages is not None) + and self.config.enable_preference_memory + and self.mem_cubes[mem_cube_id].pref_mem + ): + messages_list = [messages] + pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}) + preferences = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) + logger.info( + f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {preferences}" + ) - if ( - (messages is not None) - and self.config.enable_preference_memory - and self.mem_cubes[mem_cube_id].pref_mem - ): - messages_list = [messages] - pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, - ) - preferences = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) - logger.info( - f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {preferences}" - ) + # Execute both memory processing functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(process_textual_memory) + pref_future = 
executor.submit(process_preference_memory) + + # Wait for both tasks to complete + text_future.result() + pref_future.result() # user profile if ( diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 2a6d09224..cf26f885c 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -157,9 +157,8 @@ def dump(self, dir: str) -> None: # For preference memory, we don't need to dump to files # as the data is stored in the vector database try: - all_collections = self.vector_db.list_collections() json_memories = {} - for collection_name in all_collections: + for collection_name in self.vector_db.config.collection_name: items = self.vector_db.get_all(collection_name) json_memories[collection_name] = [memory.to_dict() for memory in items] From ed2806a4422ce52c968d1d935444ddace6101971 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 30 Sep 2025 16:33:37 +0800 Subject: [PATCH 30/62] add memory field for milvusvecdbitem and modify related module --- src/memos/memories/textual/item.py | 9 +- .../textual/prefer_text_memory/adder.py | 15 +- .../textual/prefer_text_memory/assemble.py | 70 ++++------ .../textual/prefer_text_memory/clustering.py | 1 - .../textual/prefer_text_memory/extractor.py | 23 ++-- .../textual/prefer_text_memory/retrievers.py | 9 +- .../textual/prefer_text_memory/updater.py | 129 ++---------------- src/memos/templates/prefer_complete_prompt.py | 6 +- src/memos/vec_dbs/item.py | 6 + src/memos/vec_dbs/milvus.py | 43 +++--- 10 files changed, 104 insertions(+), 207 deletions(-) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index f335175a7..9db920f73 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -111,19 +111,12 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): "explicit_preference", "implicit_preference", "topic_preference", "user_preference" ] = 
Field(default="explicit_preference", description="Type of preference.") dialog_id: str | None = Field(default=None, description="ID of the dialog.") - dialog_msgs: list[dict[str, str]] | None = Field( - default=None, description="Messages of the dialog." - ) dialog_str: str | None = Field(default=None, description="String of the dialog.") - dialog_vector: list[float] | None = Field(default=None, description="Vector of the dialog.") + embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") topic_name: str | None = Field(default=None, description="Name of the topic.") topic_description: str | None = Field(default=None, description="Description of the topic.") topic_vector: list[float] | None = Field(default=None, description="Vector of the topic.") explicit_preference: str | None = Field(default=None, description="Explicit preference.") - center_dialog: str | None = Field(default=None, description="String of the center dialog.") - center_vector: list[float] | None = Field( - default=None, description="Vector of the center dialog." 
- ) created_at: str | None = Field(default=None, description="Timestamp of the dialog.") implicit_preference: str | None = Field(default=None, description="Implicit preference.") topic_cluster_name: str | None = Field(default=None, description="Name of the topic cluster.") diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index c61b322b3..e673c8fae 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -6,7 +6,7 @@ from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT -from memos.vec_dbs.item import VecDBItem +from memos.vec_dbs.item import MilvusVecDBItem class BaseAdder(ABC): @@ -59,9 +59,12 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: """Process a single memory and return its ID if added successfully.""" try: payload = memory.to_dict()["metadata"] - vec_db_item = VecDBItem( - id=memory.id, vector=memory.metadata.dialog_vector, payload=payload + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=memory.id, memory=memory.memory, vector=memory.metadata.embedding, payload=payload ) + pref_type_collection_map = { "explicit_preference": "explicit_preference", "implicit_preference": "implicit_preference", @@ -72,15 +75,15 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: collection_name = pref_type_collection_map[preference_type] search_results = self.vector_db.search( - memory.metadata.dialog_vector, collection_name, top_k=1 + memory.metadata.embedding, collection_name, top_k=1 ) recall = search_results[0] if search_results else None if not recall or (recall.score is not None and recall.score < 0.5): self.vector_db.update(collection_name, memory.id, 
vec_db_item) return memory.id - old_msg_str = recall.payload.get("dialog_str", "") - new_msg_str = memory.metadata.dialog_str + old_msg_str = recall.memory + new_msg_str = memory.memory is_same = self._judge_update_or_add(old_msg_str, new_msg_str) if is_same: self.vector_db.delete(collection_name, [recall.id]) diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 9f4885c38..14234decb 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -39,66 +39,52 @@ def get_instruction( ) -> str: """Assemble query and memories into a single memory.""" - # Initialize all preference lists - textual_mems = [] - explicit_prefs = [] - implicit_prefs = [] - topic_prefs = [] - user_prefs = [] - - # Single loop to categorize all memories by preference type + # Initialize all memory lists + mems = { + "textual_mems": [], + "explicit_prefs": [], + "implicit_prefs": [], + "topic_prefs": [], + "user_prefs": [], + } + for memory in memories: if memory.metadata.preference_type == "explicit_preference": - explicit_prefs.append( - { - "dialog_str": memory.metadata.dialog_str, - "explicit_preference": memory.metadata.explicit_preference, - } - ) + mems["explicit_prefs"].append(memory.metadata.explicit_preference) elif memory.metadata.preference_type == "implicit_preference": - implicit_prefs.append( - { - "dialog_str": memory.metadata.center_dialog, - "implicit_preference": memory.metadata.implicit_preference, - } - ) + mems["implicit_prefs"].append(memory.metadata.implicit_preference) elif memory.metadata.preference_type == "topic_preference": - topic_prefs.append( - { - "center_dialog_str": memory.metadata.center_dialog, - "topic_preferences": memory.metadata.topic_preferences, - } - ) + mems["topic_prefs"].append(memory.metadata.topic_preference) elif memory.metadata.preference_type == "user_preference": - 
user_prefs.append({"user_preferences": memory.metadata.user_preferences}) + mems["user_prefs"].append(memory.metadata.user_preference) else: - textual_mems.append(memory.memory) + mems["textual_mems"].append(memory.memory) # Build memories string with different titles for different types memories_parts = [] - if textual_mems: + if mems["textual_mems"]: memories_parts.append("## Textual Memories:") - for i, mem in enumerate(textual_mems, 1): + for i, mem in enumerate(mems["textual_mems"], 1): memories_parts.append(f"{i}. {mem}") - if explicit_prefs: + if mems["explicit_prefs"]: memories_parts.append("## Explicit Preferences:") - for i, pref in enumerate(explicit_prefs, 1): - memories_parts.append(f"{i}. {pref['dialog_str']}") + for i, pref in enumerate(mems["explicit_prefs"], 1): + memories_parts.append(f"{i}. {pref}") - if implicit_prefs: + if mems["implicit_prefs"]: memories_parts.append("\n## Implicit Preferences:") - for i, pref in enumerate(implicit_prefs, 1): - memories_parts.append(f"{i}. {pref['dialog_str']}") + for i, pref in enumerate(mems["implicit_prefs"], 1): + memories_parts.append(f"{i}. {pref}") - if topic_prefs: + if mems["topic_prefs"]: memories_parts.append("\n## Topic Preferences:") - for i, pref in enumerate(topic_prefs, 1): - memories_parts.append(f"{i}. {pref['center_dialog_str']}") + for i, pref in enumerate(mems["topic_prefs"], 1): + memories_parts.append(f"{i}. {pref}") - if user_prefs: + if mems["user_prefs"]: memories_parts.append("\n## User Preferences:") - for i, pref in enumerate(user_prefs, 1): - memories_parts.append(f"{i}. {pref['user_preferences']}") + for i, pref in enumerate(mems["user_prefs"], 1): + memories_parts.append(f"{i}. 
{pref}") memories_str = "\n".join(memories_parts) diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py index cb68370b6..53d858853 100644 --- a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ b/src/memos/memories/textual/prefer_text_memory/clustering.py @@ -21,7 +21,6 @@ class ClusterResult(BaseModel): size: int created_at: str items: list[dict[str, Any]] - center_dialog_msgs: MessageList = None center_dialog_str: str = None diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 81b5cf8d6..e88b71cc0 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -42,7 +42,6 @@ def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: """Extract basic information from a QA pair (no LLM needed).""" basic_info = { "dialog_id": str(uuid.uuid4()), - "dialog_msgs": qa_pair, "dialog_str": convert_messages_to_string(qa_pair), "created_at": datetime.now().isoformat(), } @@ -77,7 +76,7 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A print(f"Error extracting explicit preference: {e}, return None") return None - def extract_implicit_preferences( + def extract_implicit_preference( self, qa_pairs: MessageList | list[str] ) -> dict[str, Any] | None: """Extract implicit preferences from cluster qa pairs.""" @@ -99,7 +98,7 @@ def extract_implicit_preferences( print(f"Error extracting implicit preferences: {e}, return None") return None - def extract_topic_preferences(self, qa_pairs: MessageList | list[str]) -> dict[str, Any] | None: + def extract_topic_preference(self, qa_pairs: MessageList | list[str]) -> dict[str, Any] | None: """Extract topic preferences from cluster qa pairs.""" if not qa_pairs: return None @@ -121,15 +120,15 @@ def extract_topic_preferences(self, qa_pairs: 
MessageList | list[str]) -> dict[s print(f"Error extracting topic preferences: {qa_pairs}\n{e}, return None") return None - def extract_user_preferences( - self, topic_preferences: list[dict[str, Any]] + def extract_user_preference( + self, topic_preference: list[dict[str, Any]] ) -> dict[str, Any] | None: """Extract user-level preferences.""" - if not topic_preferences: + if not topic_preference: return [] prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace( - "{cluster_info}", json.dumps(topic_preferences, ensure_ascii=False, indent=2) + "{cluster_info}", json.dumps(topic_preference, ensure_ascii=False, indent=2) ) try: @@ -139,7 +138,7 @@ def extract_user_preferences( if result.get("user_preference"): return result except Exception as e: - print(f"Error processing user preferences: {topic_preferences}\n{e}, return None") + print(f"Error processing user preferences: {topic_preference}\n{e}, return None") return "" def _process_single_chunk_explicit( @@ -155,7 +154,7 @@ def _process_single_chunk_explicit( return None vector_info = { - "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0], + "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], } extract_info = {**basic_info, **explicit_pref, **vector_info, **info} @@ -165,6 +164,7 @@ def _process_single_chunk_explicit( memory = TextualMemoryItem( id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata ) + return memory def _process_single_chunk_implicit( @@ -173,12 +173,12 @@ def _process_single_chunk_implicit( basic_info = self.extract_basic_info(chunk) if not basic_info["dialog_str"]: return None - implicit_pref = self.extract_implicit_preferences(basic_info["dialog_str"]) + implicit_pref = self.extract_implicit_preference(basic_info["dialog_str"]) if not implicit_pref: return None vector_info = { - "dialog_vector": self.embedder.embed([basic_info["dialog_str"]])[0], + "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], } extract_info = 
{**basic_info, **implicit_pref, **vector_info, **info} @@ -189,6 +189,7 @@ def _process_single_chunk_implicit( memory = TextualMemoryItem( id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata ) + return memory def extract( diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 0e5572c94..0a6a10f2f 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -71,16 +71,17 @@ def retrieve( explicit_prefs = [ TextualMemoryItem( id=pref.id, - memory=pref.payload.get("dialog_str", ""), + memory=pref.memory, metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in explicit_prefs if pref.payload["explicit_preference"] ] + implicit_prefs = [ TextualMemoryItem( id=pref.id, - memory=pref.payload.get("dialog_str", ""), + memory=pref.memory, metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in implicit_prefs @@ -89,7 +90,7 @@ def retrieve( topic_prefs = [ TextualMemoryItem( id=pref.id, - memory=pref.payload.get("center_dialog", ""), + memory=pref.memory, metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in topic_prefs @@ -98,7 +99,7 @@ def retrieve( user_prefs = [ TextualMemoryItem( id=pref.id, - memory=pref.payload.get("user_preference", ""), + memory=pref.memory, metadata=PreferenceTextualMemoryMetadata(**pref.payload), ) for pref in user_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py index 18b4ae480..723107f75 100644 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ b/src/memos/memories/textual/prefer_text_memory/updater.py @@ -7,7 +7,7 @@ from typing import Any from memos.memories.textual.prefer_text_memory.clustering import ClusterResult, HDBSCANClusterer -from memos.vec_dbs.item import VecDBItem +from memos.vec_dbs.item 
import MilvusVecDBItem class BaseUpdater(ABC): @@ -30,17 +30,6 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=N self.extractor = extractor self.clusterer = HDBSCANClusterer() - def _implicit_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Implicit cluster.""" - vectors = [info.get("dialog_vector") for info in informations] - if not vectors: - return [] - res = self.clusterer.cluster(vectors) - for cluster in res: - cluster.center_dialog_msgs = informations[cluster.center_index].get("dialog_msgs", []) - cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") - return res - def _topic_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, Any]]: """Topic cluster.""" vectors = [info.get("topic_vector") for info in informations] @@ -48,7 +37,6 @@ def _topic_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, A return [] res = self.clusterer.cluster(vectors) for cluster in res: - cluster.center_dialog_msgs = informations[cluster.center_index].get("dialog_msgs", []) cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") return res @@ -119,57 +107,18 @@ def _create_cluster_extract_input( return result - def _process_single_implicit_cluster( - self, cluster_id: str, cluster_dialogs: list[str] - ) -> dict[str, Any]: - """Process a single implicit cluster.""" - try: - result = self.extractor.extract_implicit_preferences(cluster_dialogs) - return {"cluster_id": cluster_id, "implicit_exract_result": result} - except Exception as e: - print(f"Error processing implicit cluster {cluster_id}: {e}") - return {"cluster_id": cluster_id, "implicit_exract_result": None} - - def _extract_implicit_preferences( - self, implicit_extract_inputs: dict[str, list[str]], max_workers: int = 10 - ) -> dict[str, dict[str, Any]]: - """Extract implicit preferences from implicit extract inputs using thread pool.""" - if not implicit_extract_inputs: 
- return {} - - results = {} - with ThreadPoolExecutor( - max_workers=min(max_workers, len(implicit_extract_inputs)) - ) as executor: - futures = [ - executor.submit(self._process_single_implicit_cluster, cluster_id, cluster_dialogs) - for cluster_id, cluster_dialogs in implicit_extract_inputs.items() - ] - - for future in as_completed(futures): - try: - result = future.result() - if result["implicit_exract_result"] is not None: - cluster_id = result["cluster_id"] - results[cluster_id] = result["implicit_exract_result"] - except Exception as e: - print(f"Error processing implicit cluster: {e}") - continue - - return results - def _process_single_topic_cluster( self, cluster_id: str, cluster_dialogs: list[str] ) -> dict[str, Any]: """Process a single topic cluster.""" try: - result = self.extractor.extract_topic_preferences(cluster_dialogs) + result = self.extractor.extract_topic_preference(cluster_dialogs) return {"cluster_id": cluster_id, "topic_exract_result": result} except Exception as e: print(f"Error processing topic cluster {cluster_id}: {e}") return {"cluster_id": cluster_id, "topic_exract_result": None} - def _extract_topic_preferences( + def _extract_topic_preference( self, topic_extract_inputs: dict[str, list[str]], max_workers: int = 10 ) -> dict[str, dict[str, Any]]: """Extract topic preferences from topic extract inputs using thread pool.""" @@ -197,54 +146,24 @@ def _extract_topic_preferences( return results - def _extract_user_preferences( + def _extract_user_preference( self, topic_cluster_pref_infos: dict[str, Any] ) -> dict[str, Any] | None: """Extract user preferences from topic cluster info.""" # topic_prefs is a dict, so we just pass the values as a list topic_cluster_pref = list(topic_cluster_pref_infos.values()) - return self.extractor.extract_user_preferences(topic_cluster_pref) + return self.extractor.extract_user_preference(topic_cluster_pref) def _store_preferences( self, - implicit_clusters, topic_clusters, - implicit_cluster_prefs, 
topic_cluster_prefs, user_prefs, user_id, ): """Create store data.""" - implicit_memories = [] topic_memories = [] - if implicit_clusters: - for cluster in implicit_clusters: - if cluster.cluster_id not in implicit_cluster_prefs: - print( - f"Warning: No preference found for cluster {cluster.cluster_id}, skipping..." - ) - continue - pref = implicit_cluster_prefs[cluster.cluster_id] - mem = VecDBItem( - id=cluster.cluster_id, - vector=cluster.center_vector, - payload={ - "cluster_id": cluster.cluster_id, - "center_dialog_msgs": cluster.center_dialog_msgs, - "center_dialog_str": cluster.center_dialog_str, - "center_vector": cluster.center_vector, - "implicit_preference": pref.get("implicit_preference", ""), - "created_at": cluster.created_at, - "user_id": user_id, - "size": cluster.size, - "preference_type": "implicit_preference", - }, - ) - implicit_memories.append(mem) - - self.vector_db.add("implicit_preference", implicit_memories) - if topic_clusters: for cluster in topic_clusters: if cluster.cluster_id not in topic_cluster_prefs: @@ -253,14 +172,12 @@ def _store_preferences( ) continue pref = topic_cluster_prefs[cluster.cluster_id] - mem = VecDBItem( + mem = MilvusVecDBItem( id=cluster.cluster_id, + memory=cluster.center_dialog_str, vector=cluster.center_vector, payload={ "cluster_id": cluster.cluster_id, - "center_dialog_msgs": cluster.center_dialog_msgs, - "center_dialog_str": cluster.center_dialog_str, - "center_vector": cluster.center_vector, "topic_cluster_name": pref.get("topic_cluster_name", ""), "topic_cluster_description": pref.get("topic_cluster_description", ""), "topic_preference": pref.get("topic_preference", ""), @@ -275,7 +192,7 @@ def _store_preferences( self.vector_db.add("topic_preference", topic_memories) if user_prefs: - mem = VecDBItem( + mem = MilvusVecDBItem( id=str(uuid.uuid4()), vector=[0.0] * self.vector_db.config.vector_dimension, payload={ @@ -290,17 +207,15 @@ def _store_preferences( def _generate_memory_summary( self, 
explicit_infos: list[dict[str, Any]], - implicit_infos: list[dict[str, Any]], topic_infos: list[dict[str, Any]], user_infos: dict[str, Any], ) -> str: """Generate a summary of the built memory.""" summary = { "memory_build_summary": { - "explicit_preferences_count": len(explicit_infos), - "implicit_preferences_count": len(implicit_infos), - "topic_preferences_count": len(topic_infos), - "user_preferences_count": 1 if user_infos else 0, + "explicit_preference_count": len(explicit_infos), + "topic_preference_count": len(topic_infos), + "user_preference_count": 1 if user_infos else 0, "build_timestamp": datetime.now().isoformat(), } } @@ -312,13 +227,7 @@ def slow_update(self, user_id: str): and reconstruct the implicit preference collection, topic collection and user preference collection. """ - # refresh the implicit preference collection, topic collection and user preference collection - impl_ids = [ - item.id - for item in self.vector_db.get_by_filter( - collection_name="implicit_preference", filter={"user_id": user_id} - ) - ] + # refresh the topic collection and user preference collection topic_ids = [ item.id for item in self.vector_db.get_by_filter( @@ -332,7 +241,6 @@ def slow_update(self, user_id: str): ) ] - self.vector_db.delete("implicit_preference", impl_ids) self.vector_db.delete("topic_preference", topic_ids) self.vector_db.delete("user_preference", user_ids) @@ -341,31 +249,23 @@ def slow_update(self, user_id: str): informations = [item.payload for item in all_data] # Perform clustering - implicit_clusters = self._implicit_cluster(informations) topic_clusters = self._topic_cluster(informations) # create extract inputs for each implicit and topic cluster - implicit_extract_inputs = self._create_cluster_extract_input( - implicit_clusters, informations, "original" - ) topic_extract_inputs = self._create_cluster_extract_input( topic_clusters, informations, "original" ) # Extract preferences - if implicit_extract_inputs: - implicit_cluster_prefs = 
self._extract_implicit_preferences(implicit_extract_inputs) if topic_extract_inputs: - topic_cluster_prefs = self._extract_topic_preferences(topic_extract_inputs) + topic_cluster_prefs = self._extract_topic_preference(topic_extract_inputs) # Extract user preferences - user_prefs = self._extract_user_preferences(topic_cluster_prefs) + user_prefs = self._extract_user_preference(topic_cluster_prefs) # Store all preferences in memory self._store_preferences( - implicit_clusters=implicit_clusters, topic_clusters=topic_clusters, - implicit_cluster_prefs=implicit_cluster_prefs, topic_cluster_prefs=topic_cluster_prefs, user_prefs=user_prefs, user_id=user_id, @@ -374,7 +274,6 @@ def slow_update(self, user_id: str): # Return summary of built memory return self._generate_memory_summary( explicit_infos=informations, - implicit_infos=implicit_clusters, topic_infos=topic_clusters, user_infos=user_prefs, ) diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 3f2e6b2f1..6fbe17176 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -66,7 +66,7 @@ - Do not duplicate: do not treat explicit preferences as implicit preferences. Requirements: -1. Output in JSON format with two fields: "explicit_preferences" and "implicit_preferences". +1. Output in JSON format with two fields: "explicit_preference" and "implicit_preference". 2. Each field should be an array, with each element being a concise natural language preference statement. 3. Output only the preference statements themselves; do not include any extra explanation, reasoning, or confidence information. 4. If a type of preference does not exist, its array should be empty. 
@@ -77,8 +77,8 @@ Output Format: ```json { - "explicit_preferences": ["The user clearly likes coffee", "The user does not want to sit by the window"], - "implicit_preferences": ["The user prefers a quiet environment"] + "explicit_preference": ["The user clearly likes coffee", "The user does not want to sit by the window"], + "implicit_preference": ["The user prefers a quiet environment"] } ``` Don't output anything except the JSON. diff --git a/src/memos/vec_dbs/item.py b/src/memos/vec_dbs/item.py index 6f74879ac..c753ab177 100644 --- a/src/memos/vec_dbs/item.py +++ b/src/memos/vec_dbs/item.py @@ -41,3 +41,9 @@ def from_dict(cls, data: dict[str, Any]) -> "VecDBItem": def to_dict(self) -> dict[str, Any]: """Convert to dictionary format.""" return self.model_dump(exclude_none=True) + + +class MilvusVecDBItem(VecDBItem): + """Represents a single item in the Milvus vector database.""" + + memory: str | None = Field(default=None, description="Memory string") \ No newline at end of file diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index a46c4df79..8be6db6f1 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -6,7 +6,7 @@ from memos.dependency import require_python_package from memos.log import get_logger from memos.vec_dbs.base import BaseVecDB -from memos.vec_dbs.item import VecDBItem +from memos.vec_dbs.item import MilvusVecDBItem logger = get_logger(__name__) @@ -39,6 +39,9 @@ def create_schema(self): schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True ) + schema.add_field( + field_name="memory", datatype=DataType.VARCHAR, max_length=65535 + ) schema.add_field( field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension ) @@ -106,7 +109,7 @@ def search( collection_name: str, top_k: int, filter: dict[str, Any] | None = None, - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Search for similar items in the database. 
@@ -135,8 +138,9 @@ def search( entity = hit.get("entity", {}) items.append( - VecDBItem( + MilvusVecDBItem( id=str(hit["id"]), + memory=entity.get("memory"), vector=entity.get("vector"), payload=entity.get("payload", {}), score=1 - float(hit["distance"]), @@ -177,7 +181,7 @@ def _get_metric_type(self) -> str: } return metric_map.get(self.config.distance_metric, "L2") - def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: """Get a single item by ID.""" results = self.client.get( collection_name=collection_name, @@ -190,13 +194,14 @@ def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: entity = results[0] payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - return VecDBItem( + return MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) - def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]: """Get multiple items by their IDs.""" results = self.client.get( collection_name=collection_name, @@ -210,8 +215,9 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: for entity in results: payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -221,7 +227,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: def get_by_filter( self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. 
@@ -251,13 +257,14 @@ def get_by_filter( if not batch_results: break - # Convert batch results to VecDBItem objects + # Convert batch results to MilvusVecDBItem objects for entity in batch_results: # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) all_items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -273,7 +280,7 @@ def get_by_filter( logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") return all_items - def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database.""" return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) @@ -294,13 +301,14 @@ def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> i # Extract row count from stats - stats is a dict, not a list return int(stats.get("row_count", 0)) - def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None: """ Add data to the vector database. 
Args: - data: List of VecDBItem objects or dictionaries containing: + data: List of MilvusVecDBItem objects or dictionaries containing: - 'id': unique identifier + - 'memory': memory string - 'vector': embedding vector - 'payload': additional fields for filtering/retrieval """ @@ -308,11 +316,12 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N for item in data: if isinstance(item, dict): item = item.copy() - item = VecDBItem.from_dict(item) + item = MilvusVecDBItem.from_dict(item) # Prepare entity data entity = { "id": item.id, + "memory": item.memory, "vector": item.vector, "payload": item.payload if item.payload else {}, } @@ -325,11 +334,11 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N data=entities, ) - def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None: + def update(self, collection_name: str, id: str, data: MilvusVecDBItem | dict[str, Any]) -> None: """Update an item in the vector database.""" if isinstance(data, dict): data = data.copy() - data = VecDBItem.from_dict(data) + data = MilvusVecDBItem.from_dict(data) # Use upsert for updates self.upsert(collection_name, [data]) @@ -346,7 +355,7 @@ def ensure_payload_indexes(self, fields: list[str]) -> None: # Field indexes are created automatically for scalar fields logger.info(f"Milvus automatically indexes scalar fields: {fields}") - def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + def upsert(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None: """ Add or update data in the vector database. 
From 4bd97875a4ad9e43886124507772ec435a081843 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sun, 5 Oct 2025 14:00:41 +0800 Subject: [PATCH 31/62] pref code clean --- examples/mem_os/simple_prefs_memos_product.py | 1 - src/memos/configs/memory.py | 6 - src/memos/memories/textual/item.py | 9 - .../textual/prefer_text_memory/adder.py | 2 - .../textual/prefer_text_memory/assemble.py | 16 - .../textual/prefer_text_memory/clustering.py | 163 ---------- .../textual/prefer_text_memory/config.py | 33 --- .../textual/prefer_text_memory/extractor.py | 60 ---- .../textual/prefer_text_memory/factory.py | 28 -- .../textual/prefer_text_memory/retrievers.py | 34 +-- .../textual/prefer_text_memory/updater.py | 279 ------------------ src/memos/memories/textual/preference.py | 17 -- src/memos/templates/prefer_complete_prompt.py | 153 +--------- 13 files changed, 2 insertions(+), 799 deletions(-) delete mode 100644 src/memos/memories/textual/prefer_text_memory/clustering.py delete mode 100644 src/memos/memories/textual/prefer_text_memory/updater.py diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index 9d0722223..32d6d3d35 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -102,7 +102,6 @@ def get_config(user_id: str): "extractor": {"backend": "naive", "config": {}}, "adder": {"backend": "naive", "config": {}}, "retriever": {"backend": "naive", "config": {}}, - "updater": {"backend": "naive", "config": {}}, "assembler": {"backend": "naive", "config": {}}, }, }, diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index cf5cdfbda..8b5355a0f 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -15,7 +15,6 @@ AssemblerConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory, - UpdaterConfigFactory, ) @@ -220,11 +219,6 @@ class PreferenceTextMemoryConfig(BaseTextMemoryConfig): default_factory=RetrieverConfigFactory, 
description="Retriever configuration for the memory retrieving", ) - updater: UpdaterConfigFactory = Field( - ..., - default_factory=UpdaterConfigFactory, - description="Updater configuration for the memory updating", - ) assembler: AssemblerConfigFactory = Field( ..., default_factory=AssemblerConfigFactory, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 9db920f73..7cda1055e 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -113,18 +113,9 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): dialog_id: str | None = Field(default=None, description="ID of the dialog.") dialog_str: str | None = Field(default=None, description="String of the dialog.") embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") - topic_name: str | None = Field(default=None, description="Name of the topic.") - topic_description: str | None = Field(default=None, description="Description of the topic.") - topic_vector: list[float] | None = Field(default=None, description="Vector of the topic.") explicit_preference: str | None = Field(default=None, description="Explicit preference.") created_at: str | None = Field(default=None, description="Timestamp of the dialog.") implicit_preference: str | None = Field(default=None, description="Implicit preference.") - topic_cluster_name: str | None = Field(default=None, description="Name of the topic cluster.") - topic_cluster_description: str | None = Field( - default=None, description="Description of the topic cluster." 
- ) - topic_preference: str | None = Field(default=None, description="Preferences of the topic.") - user_preference: str | None = Field(default=None, description="Preferences of the user.") class TextualMemoryItem(BaseModel): diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index e673c8fae..66c1c9550 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -68,8 +68,6 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: pref_type_collection_map = { "explicit_preference": "explicit_preference", "implicit_preference": "implicit_preference", - "topic_preference": "topic_preference", - "user_preference": "user_preference", } preference_type = memory.metadata.preference_type collection_name = pref_type_collection_map[preference_type] diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py index 14234decb..309beb4da 100644 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ b/src/memos/memories/textual/prefer_text_memory/assemble.py @@ -44,8 +44,6 @@ def get_instruction( "textual_mems": [], "explicit_prefs": [], "implicit_prefs": [], - "topic_prefs": [], - "user_prefs": [], } for memory in memories: @@ -53,10 +51,6 @@ def get_instruction( mems["explicit_prefs"].append(memory.metadata.explicit_preference) elif memory.metadata.preference_type == "implicit_preference": mems["implicit_prefs"].append(memory.metadata.implicit_preference) - elif memory.metadata.preference_type == "topic_preference": - mems["topic_prefs"].append(memory.metadata.topic_preference) - elif memory.metadata.preference_type == "user_preference": - mems["user_prefs"].append(memory.metadata.user_preference) else: mems["textual_mems"].append(memory.memory) @@ -76,16 +70,6 @@ def get_instruction( for i, pref in enumerate(mems["implicit_prefs"], 
1): memories_parts.append(f"{i}. {pref}") - if mems["topic_prefs"]: - memories_parts.append("\n## Topic Preferences:") - for i, pref in enumerate(mems["topic_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - - if mems["user_prefs"]: - memories_parts.append("\n## User Preferences:") - for i, pref in enumerate(mems["user_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - memories_str = "\n".join(memories_parts) system_prompt = ( diff --git a/src/memos/memories/textual/prefer_text_memory/clustering.py b/src/memos/memories/textual/prefer_text_memory/clustering.py deleted file mode 100644 index 53d858853..000000000 --- a/src/memos/memories/textual/prefer_text_memory/clustering.py +++ /dev/null @@ -1,163 +0,0 @@ -import uuid - -from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any - -import numpy as np - -from pydantic import BaseModel -from sklearn.cluster import HDBSCAN -from sklearn.neighbors import NearestNeighbors - -from memos.types import MessageList - - -class ClusterResult(BaseModel): - cluster_id: str - cluster_label: int - center_index: int - center_vector: list[float] - size: int - created_at: str - items: list[dict[str, Any]] - center_dialog_str: str = None - - -class BaseClusterer(ABC): - """Abstract base class for clustering algorithms.""" - - @abstractmethod - def cluster(self, vectors: list[list[float]], **kwargs) -> list[ClusterResult]: - """ - Perform clustering on the given vectors. - - Args: - vectors: List of vectors to cluster - **kwargs: Additional clustering parameters - - Returns: - List of cluster information dictionaries - """ - - -class HDBSCANClusterer(BaseClusterer): - """HDBSCAN-based clustering implementation.""" - - def __init__(self, min_cluster_size: int = 2, min_samples: int = 1, metric: str = "cosine"): - """ - Initialize HDBSCAN clusterer. 
- - Args: - min_cluster_size: Minimum size of a cluster - min_samples: Minimum number of samples in a neighborhood - metric: Metric for clustering - """ - self.min_cluster_size = min_cluster_size - self.min_samples = min_samples - self.metric = metric - - self.clusterer = HDBSCAN( - min_cluster_size=self.min_cluster_size, min_samples=self.min_samples, metric=self.metric - ) - - def cluster(self, vectors: list[list[float]], **kwargs) -> list[ClusterResult]: - """ - Perform HDBSCAN clustering on the given vectors. - - Args: - vectors: List of vectors to cluster - **kwargs: Additional clustering parameters - - Returns: - List of cluster information dictionaries - """ - if not vectors or len(vectors) < 2: - return [] - - # Convert to numpy array - vectors_array = np.array(vectors) - - # Perform clustering - cluster_labels = self.clusterer.fit_predict(vectors_array) - - # Group vectors by cluster - clusters = {} - for i, label in enumerate(cluster_labels): - if label not in clusters: - clusters[label] = [] - clusters[label].append( - { - "vector": vectors[i], # origin sample vector - "index": i, # origin sample index - } - ) - - # Convert to cluster information - cluster_info = [] - for cluster_label, items_in_cluster in clusters.items(): - if cluster_label == -1: # Skip noise points - continue - - # Calculate cluster geometric center - geometric_center_vector = np.mean([item["vector"] for item in items_in_cluster], axis=0) - - # Find the vector closest to the geometric center using sklearn - cluster_vectors = [item["vector"] for item in items_in_cluster] - nn = NearestNeighbors(n_neighbors=1, metric=self.metric) - nn.fit(cluster_vectors) - _, indices = nn.kneighbors([geometric_center_vector.tolist()]) - closest_idx = indices[0][0] - center_index = items_in_cluster[closest_idx]["index"] - center_vector = items_in_cluster[closest_idx]["vector"] - - cluster_info.append( - ClusterResult( - cluster_id=str(uuid.uuid4()), - cluster_label=cluster_label, # cluseter label: 
-1(noise) 0 1 2 3 ... - center_index=center_index, # center sample origin index - center_vector=center_vector, - size=len(items_in_cluster), # size of the cluster - created_at=datetime.now().isoformat(), - items=items_in_cluster, # list of origin sample index and vector in this cluster - ) - ) - - return cluster_info - - def search_knn_by_center_embeddings( - self, center_emb: list[float], vectors: list[list[float]], top_k: int = 5 - ) -> list[dict[str, Any]]: - """ - Search for k-nearest neighbors to the center embedding using sklearn. - - Args: - center_emb: Center embedding vector - vectors: List of vectors to search in - top_k: Number of nearest neighbors to return (maybe include the center embedding itself) - - Returns: - List of dictionaries containing index, distance, and vector for each neighbor - """ - if not vectors: - return [] - - # Initialize NearestNeighbors with the same metric - nn = NearestNeighbors(n_neighbors=min(top_k, len(vectors)), metric=self.metric) - nn.fit(vectors) - - # Search for nearest neighbors - distances, indices = nn.kneighbors([center_emb]) - - # Format results - results = [] - for dist, idx in zip(distances[0], indices[0], strict=False): - results.append( - { - "index": int(idx), # neighbor original index - "distance": float(dist), # neighbor distance - "vector": vectors[idx], # neighbor vector - } - ) - - return results diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py index f11e106b7..b6e04cadf 100644 --- a/src/memos/memories/textual/prefer_text_memory/config.py +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -106,39 +106,6 @@ def create_config(self) -> "RetrieverConfigFactory": return self -class BaseUpdaterConfig(BaseConfig): - """Base configuration class for Updater.""" - - -class NaiveUpdaterConfig(BaseUpdaterConfig): - """Configuration for Naive Updater.""" - - -class UpdaterConfigFactory(BaseConfig): - """Factory class for 
creating Updater configurations.""" - - backend: str = Field(..., description="Backend for Updater") - config: dict[str, Any] = Field(..., description="Configuration for the Updater backend") - - backend_to_class: ClassVar[dict[str, Any]] = { - "naive": NaiveUpdaterConfig, - } - - @field_validator("backend") - @classmethod - def validate_backend(cls, backend: str) -> str: - """Validate the backend field.""" - if backend not in cls.backend_to_class: - raise ValueError(f"Invalid backend: {backend}") - return backend - - @model_validator(mode="after") - def create_config(self) -> "UpdaterConfigFactory": - config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) - return self - - class BaseAssemblerConfig(BaseConfig): """Base configuration class for Assembler.""" diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index e88b71cc0..7bb36c426 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -12,9 +12,6 @@ from memos.templates.prefer_complete_prompt import ( NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, - NAIVE_TOPIC_INFO_EXTRACT_PROMPT, - NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT, - NAIVE_USER_PREFERENCE_EXTRACT_PROMPT, ) from memos.types import MessageList @@ -48,20 +45,6 @@ def extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]: return basic_info - def extract_topic_info(self, qa_pair: MessageList | str) -> dict[str, Any]: - """Extract topic information from a QA pair.""" - qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair - prompt = NAIVE_TOPIC_INFO_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - response = response.strip().replace("```json", "").replace("```", "").strip() - result = 
json.loads(response) - return result - except Exception as e: - print(f"Error extracting topic info: {e}, return None") - return None - def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: """Extract explicit preference from a QA pair.""" qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair @@ -98,49 +81,6 @@ def extract_implicit_preference( print(f"Error extracting implicit preferences: {e}, return None") return None - def extract_topic_preference(self, qa_pairs: MessageList | list[str]) -> dict[str, Any] | None: - """Extract topic preferences from cluster qa pairs.""" - if not qa_pairs: - return None - qa_pairs_str = ( - convert_messages_to_string(qa_pairs) - if isinstance(qa_pairs[0], dict) - else "\n\n".join(qa_pairs) - ) - prompt = NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - response = response.strip().replace("```json", "").replace("```", "").strip() - result = json.loads(response) - - if result.get("topic_cluster_name"): - return result - except Exception as e: - print(f"Error extracting topic preferences: {qa_pairs}\n{e}, return None") - return None - - def extract_user_preference( - self, topic_preference: list[dict[str, Any]] - ) -> dict[str, Any] | None: - """Extract user-level preferences.""" - if not topic_preference: - return [] - - prompt = NAIVE_USER_PREFERENCE_EXTRACT_PROMPT.replace( - "{cluster_info}", json.dumps(topic_preference, ensure_ascii=False, indent=2) - ) - - try: - response = self.llm_provider.generate([{"role": "user", "content": prompt}]) - response = response.strip().replace("```json", "").replace("```", "").strip() - result = json.loads(response) - if result.get("user_preference"): - return result - except Exception as e: - print(f"Error processing user preferences: {topic_preference}\n{e}, return None") - return "" - def 
_process_single_chunk_explicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] ) -> TextualMemoryItem | None: diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index 253f21662..f0879c026 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -7,11 +7,9 @@ AssemblerConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory, - UpdaterConfigFactory, ) from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor, NaiveExtractor from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever, NaiveRetriever -from memos.memories.textual.prefer_text_memory.updater import BaseUpdater, NaiveUpdater class AdderFactory(BaseAdder): @@ -79,32 +77,6 @@ def from_config( return retriever_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) -class UpdaterFactory(BaseUpdater): - """Factory class for creating Updater instances.""" - - backend_to_class: ClassVar[dict[str, Any]] = { - "naive": NaiveUpdater, - } - - @classmethod - def from_config( - cls, - config_factory: UpdaterConfigFactory, - llm_provider=None, - embedder=None, - vector_db=None, - extractor=None, - ) -> BaseUpdater: - """Create a Updater instance from a configuration factory.""" - backend = config_factory.backend - if backend not in cls.backend_to_class: - raise ValueError(f"Invalid backend: {backend}") - updater_class = cls.backend_to_class[backend] - return updater_class( - llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, extractor=extractor - ) - - class AssemblerFactory(BaseAssembler): """Factory class for creating Assembler instances.""" diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 0a6a10f2f..74d8d24b9 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py 
+++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -49,24 +49,10 @@ def retrieve( future_implicit = executor.submit( self.vector_db.search, query_embedding, "implicit_preference", top_k, info ) - future_topic = executor.submit( - self.vector_db.search, query_embedding, "topic_preference", top_k, info - ) - - # Get user preferences directly (no vector search needed since there's only one per user) - user_id = info.get("user_id") if info else None - if user_id: - future_user = executor.submit( - self.vector_db.get_by_filter, "user_preference", {"user_id": user_id} - ) - else: - future_user = None # Wait for all results explicit_prefs = future_explicit.result() implicit_prefs = future_implicit.result() - topic_prefs = future_topic.result() - user_prefs = future_user.result() if future_user else [] explicit_prefs = [ TextualMemoryItem( @@ -87,23 +73,5 @@ def retrieve( for pref in implicit_prefs if pref.payload["implicit_preference"] ] - topic_prefs = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), - ) - for pref in topic_prefs - if pref.payload["topic_preference"] - ] - user_prefs = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), - ) - for pref in user_prefs - if pref.payload["user_preference"] - ] - return explicit_prefs + implicit_prefs + topic_prefs + user_prefs + return explicit_prefs + implicit_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/updater.py b/src/memos/memories/textual/prefer_text_memory/updater.py deleted file mode 100644 index 723107f75..000000000 --- a/src/memos/memories/textual/prefer_text_memory/updater.py +++ /dev/null @@ -1,279 +0,0 @@ -import json -import uuid - -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime -from typing import Any - -from 
memos.memories.textual.prefer_text_memory.clustering import ClusterResult, HDBSCANClusterer -from memos.vec_dbs.item import MilvusVecDBItem - - -class BaseUpdater(ABC): - """Abstract base class for updaters.""" - - @abstractmethod - def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None): - """Initialize the updater.""" - - -class NaiveUpdater(BaseUpdater): - """Naive updater.""" - - def __init__(self, llm_provider=None, embedder=None, vector_db=None, extractor=None): - """Initialize the naive updater.""" - super().__init__(llm_provider, embedder, vector_db, extractor) - self.llm_provider = llm_provider - self.embedder = embedder - self.vector_db = vector_db - self.extractor = extractor - self.clusterer = HDBSCANClusterer() - - def _topic_cluster(self, informations: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Topic cluster.""" - vectors = [info.get("topic_vector") for info in informations] - if not vectors: - return [] - res = self.clusterer.cluster(vectors) - for cluster in res: - cluster.center_dialog_str = informations[cluster.center_index].get("dialog_str", "") - return res - - def _create_cluster_extract_input( - self, - cluster_results: list[ClusterResult], - informations: list[dict[str, Any]], - input_type: str, - k: int = 5, - ) -> dict[str, list[str]]: - """Create cluster extract input. 
- Args: - cluster_results: List[ClusterResult] - informations: List[Dict[str, Any]] containing dialog information - input_type: str, "original" or "knn" - k: int, number of nearest neighbors for knn mode - Returns: - Dict[str, List[str]] - cluster_id -> list of dialog strings - """ - result = {} - if not cluster_results: - return result - - if input_type == "original": - # Use all original data in each cluster - for cluster in cluster_results: - cluster_dialogs = [] - for item in cluster.items: - # item contains {"vector": ..., "index": ...} - original_index = item["index"] - dialog_str = informations[original_index].get("dialog_str", "") - cluster_dialogs.append(dialog_str) - if cluster_dialogs: - result[cluster.cluster_id] = cluster_dialogs - - elif input_type == "knn": - # Use knn search from cluster center to find k nearest neighbors - for cluster in cluster_results: - # Extract all vectors from items in this cluster - item_vectors = [item["vector"] for item in cluster.items] - - if not item_vectors: - result[cluster.cluster_id] = [] - continue - - # Use clusterer to find k nearest neighbors to cluster center - knn_results = self.clusterer.search_knn_by_center_embeddings( - center_emb=cluster.center_vector, - vectors=item_vectors, - top_k=min(k, len(item_vectors)), - ) - - # Get dialog strings for knn results - cluster_dialogs = [] - for knn_item in knn_results: - # knn_item contains {"index": idx_in_item_vectors, "distance": ..., "vector": ...} - item_idx = knn_item["index"] # index in item_vectors - original_index = cluster.items[item_idx][ - "index" - ] # original index in informations - dialog_str = informations[original_index].get("dialog_str", "") - cluster_dialogs.append(dialog_str) - if cluster_dialogs: - result[cluster.cluster_id] = cluster_dialogs - - else: - raise ValueError(f"Invalid input type: {input_type}") - - return result - - def _process_single_topic_cluster( - self, cluster_id: str, cluster_dialogs: list[str] - ) -> dict[str, Any]: - 
"""Process a single topic cluster.""" - try: - result = self.extractor.extract_topic_preference(cluster_dialogs) - return {"cluster_id": cluster_id, "topic_exract_result": result} - except Exception as e: - print(f"Error processing topic cluster {cluster_id}: {e}") - return {"cluster_id": cluster_id, "topic_exract_result": None} - - def _extract_topic_preference( - self, topic_extract_inputs: dict[str, list[str]], max_workers: int = 10 - ) -> dict[str, dict[str, Any]]: - """Extract topic preferences from topic extract inputs using thread pool.""" - if not topic_extract_inputs: - return {} - - results = {} - with ThreadPoolExecutor( - max_workers=min(max_workers, len(topic_extract_inputs)) - ) as executor: - futures = [ - executor.submit(self._process_single_topic_cluster, cluster_id, cluster_dialogs) - for cluster_id, cluster_dialogs in topic_extract_inputs.items() - ] - - for future in as_completed(futures): - try: - result = future.result() - if result["topic_exract_result"] is not None: - cluster_id = result["cluster_id"] - results[cluster_id] = result["topic_exract_result"] - except Exception as e: - print(f"Error processing topic cluster: {e}") - continue - - return results - - def _extract_user_preference( - self, topic_cluster_pref_infos: dict[str, Any] - ) -> dict[str, Any] | None: - """Extract user preferences from topic cluster info.""" - # topic_prefs is a dict, so we just pass the values as a list - topic_cluster_pref = list(topic_cluster_pref_infos.values()) - return self.extractor.extract_user_preference(topic_cluster_pref) - - def _store_preferences( - self, - topic_clusters, - topic_cluster_prefs, - user_prefs, - user_id, - ): - """Create store data.""" - topic_memories = [] - - if topic_clusters: - for cluster in topic_clusters: - if cluster.cluster_id not in topic_cluster_prefs: - print( - f"Warning: No preference found for topic cluster {cluster.cluster_id}, skipping..." 
- ) - continue - pref = topic_cluster_prefs[cluster.cluster_id] - mem = MilvusVecDBItem( - id=cluster.cluster_id, - memory=cluster.center_dialog_str, - vector=cluster.center_vector, - payload={ - "cluster_id": cluster.cluster_id, - "topic_cluster_name": pref.get("topic_cluster_name", ""), - "topic_cluster_description": pref.get("topic_cluster_description", ""), - "topic_preference": pref.get("topic_preference", ""), - "created_at": cluster.created_at, - "user_id": user_id, - "size": cluster.size, - "preference_type": "topic_preference", - }, - ) - topic_memories.append(mem) - - self.vector_db.add("topic_preference", topic_memories) - - if user_prefs: - mem = MilvusVecDBItem( - id=str(uuid.uuid4()), - vector=[0.0] * self.vector_db.config.vector_dimension, - payload={ - "user_id": user_id, - "user_preference": user_prefs.get("user_preference", ""), - "created_at": datetime.now().isoformat(), - "preference_type": "user_preference", - }, - ) - self.vector_db.add("user_preference", [mem]) - - def _generate_memory_summary( - self, - explicit_infos: list[dict[str, Any]], - topic_infos: list[dict[str, Any]], - user_infos: dict[str, Any], - ) -> str: - """Generate a summary of the built memory.""" - summary = { - "memory_build_summary": { - "explicit_preference_count": len(explicit_infos), - "topic_preference_count": len(topic_infos), - "user_preference_count": 1 if user_infos else 0, - "build_timestamp": datetime.now().isoformat(), - } - } - - return json.dumps(summary, ensure_ascii=False, indent=2) - - def slow_update(self, user_id: str): - """Retrieve all dialog info from the expicit preference collection, - and reconstruct the implicit preference collection, topic collection and user preference collection. 
- """ - - # refresh the topic collection and user preference collection - topic_ids = [ - item.id - for item in self.vector_db.get_by_filter( - collection_name="topic_preference", filter={"user_id": user_id} - ) - ] - user_ids = [ - item.id - for item in self.vector_db.get_by_filter( - collection_name="user_preference", filter={"user_id": user_id} - ) - ] - - self.vector_db.delete("topic_preference", topic_ids) - self.vector_db.delete("user_preference", user_ids) - - # get all data from explicit preference collection - all_data = self.vector_db.get_by_filter("explicit_preference", filter={"user_id": user_id}) - informations = [item.payload for item in all_data] - - # Perform clustering - topic_clusters = self._topic_cluster(informations) - - # create extract inputs for each implicit and topic cluster - topic_extract_inputs = self._create_cluster_extract_input( - topic_clusters, informations, "original" - ) - - # Extract preferences - if topic_extract_inputs: - topic_cluster_prefs = self._extract_topic_preference(topic_extract_inputs) - - # Extract user preferences - user_prefs = self._extract_user_preference(topic_cluster_prefs) - - # Store all preferences in memory - self._store_preferences( - topic_clusters=topic_clusters, - topic_cluster_prefs=topic_cluster_prefs, - user_prefs=user_prefs, - user_id=user_id, - ) - - # Return summary of built memory - return self._generate_memory_summary( - explicit_infos=informations, - topic_infos=topic_clusters, - user_infos=user_prefs, - ) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index cf26f885c..30aed0dec 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -20,7 +20,6 @@ AssemblerFactory, ExtractorFactory, RetrieverFactory, - UpdaterFactory, ) from memos.types import MessageList from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB, VecDBFactory @@ -63,13 +62,6 @@ def __init__(self, config: 
PreferenceTextMemoryConfig): embedder=self.embedder, vector_db=self.vector_db, ) - self.updater = UpdaterFactory.from_config( - config.updater, - llm_provider=self.extractor_llm, - embedder=self.embedder, - vector_db=self.vector_db, - extractor=self.extractor, - ) self.assembler = AssemblerFactory.from_config( config.assembler, llm_provider=self.extractor_llm, @@ -88,15 +80,6 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def slow_update(self, user_id: str) -> str: - """Perform a slow update of preferences by reconstructing all preference collections. - Args: - user_id (str): The user ID to update preferences for. - Returns: - str: Summary of the memory build process. - """ - return self.updater.slow_update(user_id) - def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 6fbe17176..f1a575bd0 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -85,124 +85,6 @@ """ -NAIVE_TOPIC_PREFERENCE_EXTRACT_PROMPT = """ -You are a topic cluster analysis expert. You will be given a topic cluster containing multiple QA conversations with their preferences, topic names, and descriptions. -Your task is to analyze all information in this cluster and extract three key components, then output in strict JSON format. - -# Extraction Targets: - -1. **Cluster Name (cluster_name)**: - - Use concise words (2-6 words) to summarize the core theme of this cluster - - Should be more abstract and comprehensive than individual conversation topics - - Examples: "Machine Learning Basics", "Creative Writing", "Health & Wellness", "Career Development" - -2. 
**Cluster Description (cluster_description)**: - - Provide a brief summary of the main content and scope of this cluster theme - - Should be more specific than the cluster name - - Describe what topics and areas are covered within this cluster - -3. **Cluster Preferences (cluster_preferences)**: - - Combine and summarize common preferences from all explicit and implicit preferences in the cluster - - Focus on shared patterns across all QA pairs in this topic area - - Identify recurring user preferences specific to this topic domain - -# Output Format: -{ - "topic_cluster_name": "", - "topic_cluster_description": "", - "topic_preference": "" -} - -# Notes: -- If any field has no clear information, leave it as empty string "" -- Only output JSON, no explanations -- Focus on meaningful patterns that represent the cluster as a whole -- Cluster name should be broader than individual topic names - -# Cluster Information: -{cluster_info} -""" - - -NAIVE_USER_PREFERENCE_EXTRACT_PROMPT = """ -You are an advanced information integration assistant. You will be given a user's preference list across different topic clusters, where each cluster contains: - -- topic_cluster_name: The name of the topic cluster -- topic_cluster_description: The description of the topic cluster -- topic_preference: Natural language description of preferences in that cluster - -Your task is to extract the user's **highest-level common preferences** by focusing on these three key dimensions: - -# Analysis Dimensions: - -1. **Content Preferences**: What types of information, topics, or knowledge styles the user tends to prefer - - Subject matter interests and expertise areas - - Information depth and complexity preferences - - Knowledge domain preferences - -2. **Interaction Style Preferences**: How the user prefers information to be presented, structured, or delivered - - Communication format preferences (formal vs. casual) - - Information organization preferences (structured vs. 
narrative) - - Response style preferences (concise vs. detailed) - -3. **Value Orientations**: Core values or principles reflected in the user's information choices, processing, or practices - - Underlying priorities and decision-making patterns - - Quality standards and expectations - - Ethical or professional principles - -# Requirements: -- Synthesize common patterns across ALL topic clusters, not individual cluster details -- Express in natural language, highlighting the user's overall preference characteristics -- Summarize into a coherent paragraph that flows smoothly -- Avoid bullet points and don't repeat specific cluster examples or operational details -- Focus on high-level patterns that transcend individual topics - -# Output Format: -{ - "user_preference": "Write the synthesized highest-level common preferences here, covering content preferences, interaction style, and value orientations" -} - -# Notes: -- If no clear patterns emerge, describe the user as having diverse or varied preferences -- Focus on meaningful patterns, not trivial observations -- Only output JSON, no explanations - -# Cluster Information: -{cluster_info} -""" - - -NAIVE_TOPIC_INFO_EXTRACT_PROMPT = """ -You are a topic extraction assistant. You will be given a QA pair (user question Q and assistant answer A). -Please extract the main topic name and topic description from this conversation and output in JSON format. 
- -# Extraction Rules -- Topic Name (topic_name): - - Use concise words to summarize the core topic of the conversation - - Keep it between 2-6 words - - Examples: "Science Fiction Writing", "Technical Documentation", "Healthy Diet", "Career Advice", "Python Programming" - -- Topic Description (topic_description): - - Provide a brief 1-2 sentence summary of what the conversation is about - - Be more specific than the topic name but keep it under 50 words - - Focus on the main content and key points discussed - -# Output Format -{ - "topic_name": "", - "topic_description": "" -} - -# Notes -- If no clear topic can be identified, leave topic_name as empty string "" -- If the conversation is too brief or unclear, leave topic_description as empty string "" -- Only output JSON, no explanations - -# Conversation Content -{qa_pair} -""" - - NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: @@ -222,37 +104,4 @@ **New Information:** {new_information} -""" - - -NAIVE_PREFERENCE_INTEGRATION_PROMPT = """ -You are a memory integration expert. Your task is to integrate various memories from the user and generate a final prompt that can be directly input into a large language model (LLM). The generated prompt should include all relevant memories, but you must filter them based on the current query. - -1. Requirements and Rules: -- Filter Relevant Memories: Only retain memories that are related to the current query, and discard irrelevant ones. - -2. Handle Preference Conflicts: -- Explicit preference memories take precedence over implicit preference memories. 
-- If conflicts arise, delete lower-priority memories according to the hierarchy to ensure consistency. - -3. Generate Final Integrated Prompt: -- The prompt should include all filtered, conflict-free memories. -- Provide structured, clear, specific, and actionable context. -- The prompt must be directly usable by the LLM to answer the user's query. - -4. Priority Rules: -Explicit memories > Implicit memories - -Please output in the following JSON formatm, don't output anything else: - -{ - "final_prompt": "The complete prompt containing filtered, integrated, and conflict-free memories, ready for LLM input", - "explanation": "Briefly explain the reasoning process for the final prompt and process of filtering and integrating memories" -} - -Query: -{query} - -Memories: -{memories} -""" +""" \ No newline at end of file From 4e3bd967ba7994468979204ae235010a557cf28c Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 10 Oct 2025 10:39:10 +0800 Subject: [PATCH 32/62] modify prompt of extractor --- .../textual/prefer_text_memory/retrievers.py | 4 +++ src/memos/templates/prefer_complete_prompt.py | 28 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 74d8d24b9..f0291db12 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -54,6 +54,10 @@ def retrieve( explicit_prefs = future_explicit.result() implicit_prefs = future_implicit.result() + # sort by score + explicit_prefs.sort(key=lambda x: x.score, reverse=True) + implicit_prefs.sort(key=lambda x: x.score, reverse=True) + explicit_prefs = [ TextualMemoryItem( id=pref.id, diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index f1a575bd0..71f636160 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ 
b/src/memos/templates/prefer_complete_prompt.py
@@ -2,6 +2,34 @@
 You are a preference extraction assistant.
 Please extract the user's explicitly mentioned preferences from the following conversation.
 
+Notes:
+- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer".
+- This includes, but is not limited to, any clearly expressed inclination, desire, rejection, or priority that counts as an explicit preference.
+
+Requirements:
+1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume.
+2. Output should be a list of concise natural language summaries and the corresponding context summary, context summary must contain complete information of the conversation fragment in which the preference is mentioned.
+3. If multiple preferences are mentioned within the same topic, you need to merge the preferences and context summary.
+
+Conversation:
+{qa_pair}
+
+Find ALL explicit preferences. If no explicit preferences found, return []. Output JSON only:
+```json
+[
+    {
+        "explicit_preference": "A short natural language summary of the preferences",
+        "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not omit any scenario information",
+        "reasoning": "reasoning process to find the explicit preferences"
+    },
+]
+```
+"""
+
+NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_BAK = """
+You are a preference extraction assistant.
+Please extract the user's explicitly mentioned preferences from the following conversation.
+
 Notes:
 - A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer".
 - Any clearly expressed inclination, desire, rejection, or priority counts as an explicit preference.
From e1b505f60c7ac13f5653e9f6f3b5de5888433394 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 10 Oct 2025 10:40:30 +0800 Subject: [PATCH 33/62] modify extractor --- .../textual/prefer_text_memory/extractor.py | 92 ++++++++++++++++--- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 7bb36c426..3407d5c7b 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -81,6 +81,32 @@ def extract_implicit_preference( print(f"Error extracting implicit preferences: {e}, return None") return None + # def _process_single_chunk_explicit( + # self, chunk: MessageList, msg_type: str, info: dict[str, Any] + # ) -> TextualMemoryItem | None: + # """Process a single chunk and return a TextualMemoryItem.""" + # basic_info = self.extract_basic_info(chunk) + # if not basic_info["dialog_str"]: + # return None + + # explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) + # if not explicit_pref: + # return None + + # vector_info = { + # "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], + # } + # extract_info = {**basic_info, **explicit_pref, **vector_info, **info} + + # metadata = PreferenceTextualMemoryMetadata( + # type=msg_type, preference_type="explicit_preference", **extract_info + # ) + # memory = TextualMemoryItem( + # id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata + # ) + + # return memory + def _process_single_chunk_explicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] ) -> TextualMemoryItem | None: @@ -93,19 +119,56 @@ def _process_single_chunk_explicit( if not explicit_pref: return None - vector_info = { - "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], - } - extract_info = {**basic_info, **explicit_pref, **vector_info, **info} + memories = [] + for 
pref in explicit_pref: + vector_info = { + "embedding": self.embedder.embed([pref["context_summary"]])[0], + } + extract_info = {**basic_info, **pref, **vector_info, **info} - metadata = PreferenceTextualMemoryMetadata( - type=msg_type, preference_type="explicit_preference", **extract_info - ) - memory = TextualMemoryItem( - id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata - ) + metadata = PreferenceTextualMemoryMetadata( + type=msg_type, preference_type="explicit_preference", **extract_info + ) + memory = TextualMemoryItem( + id=str(uuid.uuid4()), memory=pref["context_summary"], metadata=metadata + ) - return memory + memories.append(memory) + + return memories + + # def _process_single_chunk_explicit( + # self, chunk: MessageList, msg_type: str, info: dict[str, Any] + # ) -> TextualMemoryItem | None: + # """Process a single chunk and return a TextualMemoryItem.""" + # basic_info = self.extract_basic_info(chunk) + # if not basic_info["dialog_str"]: + # return None + + # explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) + # if not explicit_pref: + # return None + + # mem_texts = [] + # prefs = [] + # for pref in explicit_pref: + # mem_texts.append(pref["context_summary"]) + # prefs.append(pref["explicit_preference"]) + # mem_str = "\n".join(mem_texts) + # pref_dict = {"explicit_preference": "\n".join(prefs)} + # vector_info = { + # "embedding": self.embedder.embed([mem_str])[0], + # } + # extract_info = {**basic_info, **pref_dict, **vector_info, **info} + + # metadata = PreferenceTextualMemoryMetadata( + # type=msg_type, preference_type="explicit_preference", **extract_info + # ) + # memory = TextualMemoryItem( + # id=basic_info["dialog_id"], memory=mem_str, metadata=metadata + # ) + + # return memory def _process_single_chunk_implicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] @@ -142,7 +205,7 @@ def extract( """Extract preference memories based on the messages using thread pool for 
acceleration.""" chunks_for_explicit: list[MessageList] = [] for message in messages: - chunk = self.splitter.split_chunks(message, split_type="lookback") + chunk = self.splitter.split_chunks(message, split_type="overlap") chunks_for_explicit.extend(chunk) if not chunks_for_explicit: return [] @@ -177,7 +240,10 @@ def extract( try: memory = future.result() if memory: - memories.append(memory) + if isinstance(memory, list): + memories.extend(memory) + else: + memories.append(memory) except Exception as e: task_type, chunk = futures[future] print(f"Error processing {task_type} chunk: {chunk}\n{e}") From eb1cdb031742e9060254e3f22ac1629ede36b4a5 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 10 Oct 2025 14:47:29 +0800 Subject: [PATCH 34/62] add reranker to pref mem --- src/memos/configs/memory.py | 4 ++++ .../textual/prefer_text_memory/factory.py | 3 ++- .../textual/prefer_text_memory/retrievers.py | 17 ++++++++++++----- src/memos/memories/textual/preference.py | 3 +++ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 8b5355a0f..d4b12f36d 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -204,6 +204,10 @@ class PreferenceTextMemoryConfig(BaseTextMemoryConfig): default_factory=EmbedderConfigFactory, description="Embedder configuration for the memory embedding", ) + reranker: RerankerConfigFactory | None = Field( + None, + description="Reranker configuration (optional).", + ) extractor: ExtractorConfigFactory = Field( ..., default_factory=ExtractorConfigFactory, diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index f0879c026..a5ec30c36 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -67,6 +67,7 @@ def from_config( config_factory: RetrieverConfigFactory, llm_provider=None, embedder=None, + 
reranker=None, vector_db=None, ) -> BaseRetriever: """Create a Retriever instance from a configuration factory.""" @@ -74,7 +75,7 @@ def from_config( if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") retriever_class = cls.backend_to_class[backend] - return retriever_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + return retriever_class(llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db) class AssemblerFactory(BaseAssembler): diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index f0291db12..b4841ea90 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -9,7 +9,7 @@ class BaseRetriever(ABC): """Abstract base class for retrievers.""" @abstractmethod - def __init__(self, llm_provider=None, embedder=None, vector_db=None): + def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None): """Initialize the retriever.""" @abstractmethod @@ -22,9 +22,10 @@ def retrieve( class NaiveRetriever(BaseRetriever): """Naive retriever.""" - def __init__(self, llm_provider=None, embedder=None, vector_db=None): + def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None): """Initialize the naive retriever.""" - super().__init__(llm_provider, embedder, vector_db) + super().__init__(llm_provider, embedder, reranker, vector_db) + self.reranker = reranker self.vector_db = vector_db self.embedder = embedder @@ -44,10 +45,10 @@ def retrieve( with ThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks future_explicit = executor.submit( - self.vector_db.search, query_embedding, "explicit_preference", top_k, info + self.vector_db.search, query_embedding, "explicit_preference", max(top_k, 20), info ) future_implicit = executor.submit( - self.vector_db.search, 
query_embedding, "implicit_preference", top_k, info + self.vector_db.search, query_embedding, "implicit_preference", max(top_k, 20), info ) # Wait for all results @@ -78,4 +79,10 @@ def retrieve( if pref.payload["implicit_preference"] ] + if self.reranker: + explicit_prefs = self.reranker.rerank(query, explicit_prefs, top_k) + implicit_prefs = self.reranker.rerank(query, implicit_prefs, top_k) + explicit_prefs = [item for item, _ in explicit_prefs] + implicit_prefs = [item for item, _ in implicit_prefs] + return explicit_prefs + implicit_prefs diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 30aed0dec..d1fd3e4a9 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -21,6 +21,7 @@ ExtractorFactory, RetrieverFactory, ) +from memos.reranker.factory import RerankerFactory from memos.types import MessageList from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB, VecDBFactory from memos.vec_dbs.item import VecDBItem @@ -42,6 +43,7 @@ def __init__(self, config: PreferenceTextMemoryConfig): self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = ( EmbedderFactory.from_config(config.embedder) ) + self.reranker = RerankerFactory.from_config(config.reranker) self.extractor = ExtractorFactory.from_config( config.extractor, @@ -60,6 +62,7 @@ def __init__(self, config: PreferenceTextMemoryConfig): config.retriever, llm_provider=self.extractor_llm, embedder=self.embedder, + reranker=self.reranker, vector_db=self.vector_db, ) self.assembler = AssemblerFactory.from_config( From f8e26ad481f5e9c23e778792b146400265be8e98 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 10 Oct 2025 16:22:02 +0800 Subject: [PATCH 35/62] remove assember in pref mem --- examples/mem_os/simple_prefs_memos_product.py | 1 - src/memos/configs/memory.py | 6 -- .../textual/prefer_text_memory/assemble.py | 95 ------------------- .../textual/prefer_text_memory/config.py 
| 33 ------- .../textual/prefer_text_memory/extractor.py | 33 ------- .../textual/prefer_text_memory/factory.py | 25 ----- src/memos/memories/textual/preference.py | 17 ---- 7 files changed, 210 deletions(-) delete mode 100644 src/memos/memories/textual/prefer_text_memory/assemble.py diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index 32d6d3d35..beee837a8 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -102,7 +102,6 @@ def get_config(user_id: str): "extractor": {"backend": "naive", "config": {}}, "adder": {"backend": "naive", "config": {}}, "retriever": {"backend": "naive", "config": {}}, - "assembler": {"backend": "naive", "config": {}}, }, }, } diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index d4b12f36d..f697a84f0 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -12,7 +12,6 @@ from memos.exceptions import ConfigurationError from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, - AssemblerConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory, ) @@ -223,11 +222,6 @@ class PreferenceTextMemoryConfig(BaseTextMemoryConfig): default_factory=RetrieverConfigFactory, description="Retriever configuration for the memory retrieving", ) - assembler: AssemblerConfigFactory = Field( - ..., - default_factory=AssemblerConfigFactory, - description="Assembler configuration for the memory assembling", - ) # ─── 3. 
Global Memory Config Factory ────────────────────────────────────────── diff --git a/src/memos/memories/textual/prefer_text_memory/assemble.py b/src/memos/memories/textual/prefer_text_memory/assemble.py deleted file mode 100644 index 309beb4da..000000000 --- a/src/memos/memories/textual/prefer_text_memory/assemble.py +++ /dev/null @@ -1,95 +0,0 @@ -from abc import ABC, abstractmethod - -from memos.memories.textual.item import TextualMemoryItem - - -class BaseAssembler(ABC): - """Abstract base class for assemblers.""" - - @abstractmethod - def __init__(self, llm_provider=None, embedder=None, vector_db=None): - """Initialize the assembler.""" - - @abstractmethod - def get_instruction( - self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str = "semi" - ) -> str: - """Assemble query and memories into a single memory. - Args: - query: The query to assemble. - memories: The memories to assemble. - assemble_strategy: The strategy to assemble the memories. option: [raw, semi, full] - Returns: - The assembled prompt. 
- """ - - -class NaiveAssembler(BaseAssembler): - """Naive assembler.""" - - def __init__(self, llm_provider=None, embedder=None, vector_db=None): - """Initialize the naive assembler.""" - super().__init__(llm_provider, embedder, vector_db) - self.llm_provider = llm_provider - self.embedder = embedder - self.vector_db = vector_db - - def get_instruction( - self, query: str, memories: list[TextualMemoryItem], assemble_strategy: str = "semi" - ) -> str: - """Assemble query and memories into a single memory.""" - - # Initialize all memory lists - mems = { - "textual_mems": [], - "explicit_prefs": [], - "implicit_prefs": [], - } - - for memory in memories: - if memory.metadata.preference_type == "explicit_preference": - mems["explicit_prefs"].append(memory.metadata.explicit_preference) - elif memory.metadata.preference_type == "implicit_preference": - mems["implicit_prefs"].append(memory.metadata.implicit_preference) - else: - mems["textual_mems"].append(memory.memory) - - # Build memories string with different titles for different types - memories_parts = [] - if mems["textual_mems"]: - memories_parts.append("## Textual Memories:") - for i, mem in enumerate(mems["textual_mems"], 1): - memories_parts.append(f"{i}. {mem}") - if mems["explicit_prefs"]: - memories_parts.append("## Explicit Preferences:") - for i, pref in enumerate(mems["explicit_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - - if mems["implicit_prefs"]: - memories_parts.append("\n## Implicit Preferences:") - for i, pref in enumerate(mems["implicit_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - - memories_str = "\n".join(memories_parts) - - system_prompt = ( - "You are a knowledgeable and helpful AI assistant. " - "You have access to conversation memories that help you provide more personalized responses. " - "Use the memories to understand the user's context, preferences, and past interactions. 
" - "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." - f"\n\n## Memories:\n{memories_str}" - ) - - if assemble_strategy == "raw": - return system_prompt.replace("{memories}", memories_str) - elif assemble_strategy == "semi": - return ( - system_prompt - + ( - "Note: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " - + "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts." - + "When encountering preference conflicts, the priority is: explicit preferences > implicit preferences > textual memories." - ) - ).replace("{memories}", memories_str) - else: - raise ValueError(f"Invalid assemble strategy: {assemble_strategy}") diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py index b6e04cadf..7e8354747 100644 --- a/src/memos/memories/textual/prefer_text_memory/config.py +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -104,36 +104,3 @@ def create_config(self) -> "RetrieverConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self - - -class BaseAssemblerConfig(BaseConfig): - """Base configuration class for Assembler.""" - - -class NaiveAssemblerConfig(BaseAssemblerConfig): - """Configuration for Naive Assembler.""" - - -class AssemblerConfigFactory(BaseConfig): - """Factory class for creating Assembler configurations.""" - - backend: str = Field(..., description="Backend for Assembler") - config: dict[str, Any] = Field(..., description="Configuration for the Assembler backend") - - backend_to_class: ClassVar[dict[str, Any]] = { - "naive": NaiveAssemblerConfig, - } - - @field_validator("backend") - @classmethod - def validate_backend(cls, backend: str) -> str: - """Validate the backend field.""" - if 
backend not in cls.backend_to_class: - raise ValueError(f"Invalid backend: {backend}") - return backend - - @model_validator(mode="after") - def create_config(self) -> "AssemblerConfigFactory": - config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) - return self diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 3407d5c7b..4eb3647fc 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -137,39 +137,6 @@ def _process_single_chunk_explicit( return memories - # def _process_single_chunk_explicit( - # self, chunk: MessageList, msg_type: str, info: dict[str, Any] - # ) -> TextualMemoryItem | None: - # """Process a single chunk and return a TextualMemoryItem.""" - # basic_info = self.extract_basic_info(chunk) - # if not basic_info["dialog_str"]: - # return None - - # explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) - # if not explicit_pref: - # return None - - # mem_texts = [] - # prefs = [] - # for pref in explicit_pref: - # mem_texts.append(pref["context_summary"]) - # prefs.append(pref["explicit_preference"]) - # mem_str = "\n".join(mem_texts) - # pref_dict = {"explicit_preference": "\n".join(prefs)} - # vector_info = { - # "embedding": self.embedder.embed([mem_str])[0], - # } - # extract_info = {**basic_info, **pref_dict, **vector_info, **info} - - # metadata = PreferenceTextualMemoryMetadata( - # type=msg_type, preference_type="explicit_preference", **extract_info - # ) - # memory = TextualMemoryItem( - # id=basic_info["dialog_id"], memory=mem_str, metadata=metadata - # ) - - # return memory - def _process_single_chunk_implicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] ) -> TextualMemoryItem | None: diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py 
b/src/memos/memories/textual/prefer_text_memory/factory.py index a5ec30c36..b0cba396f 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -1,10 +1,8 @@ from typing import Any, ClassVar from memos.memories.textual.prefer_text_memory.adder import BaseAdder, NaiveAdder -from memos.memories.textual.prefer_text_memory.assemble import BaseAssembler, NaiveAssembler from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, - AssemblerConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory, ) @@ -76,26 +74,3 @@ def from_config( raise ValueError(f"Invalid backend: {backend}") retriever_class = cls.backend_to_class[backend] return retriever_class(llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db) - - -class AssemblerFactory(BaseAssembler): - """Factory class for creating Assembler instances.""" - - backend_to_class: ClassVar[dict[str, Any]] = { - "naive": NaiveAssembler, - } - - @classmethod - def from_config( - cls, - config_factory: AssemblerConfigFactory, - llm_provider=None, - embedder=None, - vector_db=None, - ) -> BaseAssembler: - """Create a Assembler instance from a configuration factory.""" - backend = config_factory.backend - if backend not in cls.backend_to_class: - raise ValueError(f"Invalid backend: {backend}") - assembler_class = cls.backend_to_class[backend] - return assembler_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index d1fd3e4a9..0fc22917c 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -17,7 +17,6 @@ from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.prefer_text_memory.factory import ( AdderFactory, - AssemblerFactory, ExtractorFactory, RetrieverFactory, ) @@ 
-65,12 +64,6 @@ def __init__(self, config: PreferenceTextMemoryConfig): reranker=self.reranker, vector_db=self.vector_db, ) - self.assembler = AssemblerFactory.from_config( - config.assembler, - llm_provider=self.extractor_llm, - embedder=self.embedder, - vector_db=self.vector_db, - ) def get_memory( self, messages: MessageList, type: str, info: dict[str, Any] @@ -94,16 +87,6 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem """ return self.retriever.retrieve(query, top_k, info) - def get_prompt(self, query: str, memories: list[TextualMemoryItem]) -> str: - """Construct the prompt for the query with memories. - Args: - query (str): The query to get the prompt for. - memories (list[TextualMemoryItem]): The memories to get the prompt for. - Returns: - str: The prompt for the query with memories. - """ - return self.assembler.assemble(query, memories) - def load(self, dir: str) -> None: """Load memories from the specified directory. Args: From 666a6fc47da76d767b845893c881756b8cad7a8e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 10 Oct 2025 20:09:29 +0800 Subject: [PATCH 36/62] modify code --- .../textual/prefer_text_memory/extractor.py | 26 ------ src/memos/templates/instruction_completion.py | 82 +++++++++++++++++++ 2 files changed, 82 insertions(+), 26 deletions(-) create mode 100644 src/memos/templates/instruction_completion.py diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 4eb3647fc..86d3a8f89 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -81,32 +81,6 @@ def extract_implicit_preference( print(f"Error extracting implicit preferences: {e}, return None") return None - # def _process_single_chunk_explicit( - # self, chunk: MessageList, msg_type: str, info: dict[str, Any] - # ) -> TextualMemoryItem | None: - # """Process a single chunk and 
return a TextualMemoryItem.""" - # basic_info = self.extract_basic_info(chunk) - # if not basic_info["dialog_str"]: - # return None - - # explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"]) - # if not explicit_pref: - # return None - - # vector_info = { - # "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], - # } - # extract_info = {**basic_info, **explicit_pref, **vector_info, **info} - - # metadata = PreferenceTextualMemoryMetadata( - # type=msg_type, preference_type="explicit_preference", **extract_info - # ) - # memory = TextualMemoryItem( - # id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata - # ) - - # return memory - def _process_single_chunk_explicit( self, chunk: MessageList, msg_type: str, info: dict[str, Any] ) -> TextualMemoryItem | None: diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py new file mode 100644 index 000000000..50e477775 --- /dev/null +++ b/src/memos/templates/instruction_completion.py @@ -0,0 +1,82 @@ +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageList + + +def build_system_prompt( + memories: list[TextualMemoryItem] | None = None, + instruction_strategy: str = "process_conflict"): + # Initialize all memory lists + mems = { + "textual_mems": [], + "explicit_prefs": [], + "implicit_prefs": [], + } + + for memory in memories: + if memory.metadata.preference_type == "explicit_preference": + mems["explicit_prefs"].append(memory.metadata.explicit_preference) + elif memory.metadata.preference_type == "implicit_preference": + mems["implicit_prefs"].append(memory.metadata.implicit_preference) + else: + mems["textual_mems"].append(memory.memory) + + # Build memories string with different titles for different types + memories_parts = [] + if mems["textual_mems"]: + memories_parts.append("## Textual Memories:") + for i, mem in enumerate(mems["textual_mems"], 1): + 
 memories_parts.append(f"{i}. {mem}") + if mems["explicit_prefs"]: + memories_parts.append("## Explicit Preferences:") + for i, pref in enumerate(mems["explicit_prefs"], 1): + memories_parts.append(f"{i}. {pref}") + + if mems["implicit_prefs"]: + memories_parts.append("\n## Implicit Preferences:") + for i, pref in enumerate(mems["implicit_prefs"], 1): + memories_parts.append(f"{i}. {pref}") + + memories_str = "\n".join(memories_parts) + + system_prompt = ( + "You are a knowledgeable and helpful AI assistant. " + "You have access to conversation memories that help you provide more personalized responses. " + "Use the memories to understand the user's context, preferences, and past interactions. " + "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." + "\n\n# Memories:\n{memories}" + ) + + if instruction_strategy == "raw": + system_prompt = system_prompt.replace("{memories}", memories_str) + elif instruction_strategy == "process_conflict": + system_prompt = ( + system_prompt + + ( + "\nNote: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " + + "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. " + + "When encountering preference conflicts, the priority is: explicit preferences > textual memories > implicit preferences."
+ ) + ).replace("{memories}", memories_str) + else: + raise ValueError(f"Invalid instruction strategy: {instruction_strategy}") + + return system_prompt + + +def get_instruction( + query: str, + memories: list[TextualMemoryItem] | None = None, + history: MessageList | None = None, + instruction_strategy: str = "process_conflict" +) -> str: + """Build the chat message list for the query, incorporating the retrieved memories and preferences.""" + + system_prompt = build_system_prompt(memories, instruction_strategy) + + messages = [ + {"role": "system", "content": system_prompt}, + *history, + {"role": "user", "content": query}, + ] + + return messages From 3fc4371d0359acaf9f78e4ec34d12b0914079cb3 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 13 Oct 2025 18:58:41 +0800 Subject: [PATCH 37/62] add op trace based update method in add --- .../textual/prefer_text_memory/adder.py | 145 ++++++++-- src/memos/templates/prefer_complete_prompt.py | 272 +++++++++++++++++- src/memos/vec_dbs/milvus.py | 2 + 3 files changed, 398 insertions(+), 21 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 66c1c9550..cb0707998 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -5,7 +5,7 @@ from typing import Any from memos.memories.textual.item import TextualMemoryItem -from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT +from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE from memos.vec_dbs.item import MilvusVecDBItem @@ -37,7 +37,7 @@ def __init__(self, llm_provider=None, embedder=None, vector_db=None): self.embedder = embedder self.vector_db = vector_db - def _judge_update_or_add(self, old_msg: str, new_msg: str) -> bool: + def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: """Judge if the new message
expresses the same core content as the old message.""" # Use the template prompt with placeholders prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace( @@ -55,16 +55,127 @@ def _judge_update_or_add(self, old_msg: str, new_msg: str) -> bool: # Fallback to simple string comparison return old_msg == new_msg - def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: - """Process a single memory and return its ID if added successfully.""" + def _judge_update_or_add_trace_op(self, new_mem: str, retrieved_mems: str) -> dict[str, Any] | None: + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace("{retrieved_memories}", retrieved_mems) try: - payload = memory.to_dict()["metadata"] + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + print(f"Error in judge_update_or_add_trace_op: {e}") + return None + + def _update_memory_op_trace(self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str + ) -> list[str] | str: + + if not retrieved_memories: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + new_mem_input = { + "context_summary": new_memory.memory, + "preference": new_memory.metadata.explicit_preference if preference_type == "explicit_preference" else new_memory.metadata.implicit_preference, + } + retrieved_mem_inputs = [{"id": mem.id, "context_summary": mem.memory, "preference": 
mem.payload[preference_type]} for mem in retrieved_memories] + + rsp = self._judge_update_or_add_trace_op(new_mem=json.dumps(new_mem_input), retrieved_mems=json.dumps(retrieved_mem_inputs)) + if not rsp: + payload = new_memory.to_dict()["metadata"] fields_to_remove = {"dialog_id", "dialog_str", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} vec_db_item = MilvusVecDBItem( - id=memory.id, memory=memory.memory, vector=memory.metadata.embedding, payload=payload + id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + op_trace = rsp["trace"] + added_ids = [] + for op in op_trace: + if op["type"].lower() == "add": + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + ) + self.vector_db.add(collection_name, [vec_db_item]) + added_ids.append(new_memory.id) + elif op["type"].lower() == "update": + payload = {"preference_type": preference_type, preference_type: op["new_preference"]} + vec_db_item = MilvusVecDBItem( + id=op["target_id"], memory=op["new_context_summary"], vector=self.embedder.embed([op["new_context_summary"]])[0], payload=payload + ) + self.vector_db.update(collection_name, op["target_id"], vec_db_item) + added_ids.append(op["target_id"]) + elif op["type"].lower() == "delete": + self.vector_db.delete(collection_name, [op["target_id"]]) + else: + continue + + return added_ids + + def _update_memory_fast(self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + ) -> str: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: 
v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + ) + recall = retrieved_memories[0] if retrieved_memories else None + if not recall or (recall.score is not None and recall.score < 0.5): + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + old_msg_str = recall.memory + new_msg_str = new_memory.memory + is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str) + if is_same: + self.vector_db.delete(collection_name, [recall.id]) + self.vector_db.update(collection_name, new_memory.id, vec_db_item) + return new_memory.id + + def _update_memory(self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + update_mode: str = "op_trace", + ) -> list[str] | str | None: + """Update the memory. + Args: + new_memory: TextualMemoryItem + retrieved_memories: list[MilvusVecDBItem] + collection_name: str + preference_type: str + update_mode: str, "op_trace" or "fast" + """ + if update_mode == "op_trace": + return self._update_memory_op_trace(new_memory, retrieved_memories, collection_name, preference_type) + elif update_mode == "fast": + return self._update_memory_fast(new_memory, retrieved_memories, collection_name) + else: + raise ValueError(f"Invalid update mode: {update_mode}") + + + def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | None: + """Process a single memory and return its ID if added successfully.""" + try: pref_type_collection_map = { "explicit_preference": "explicit_preference", "implicit_preference": "implicit_preference", @@ -73,20 +184,11 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> str | None: collection_name = pref_type_collection_map[preference_type] search_results = self.vector_db.search( - memory.metadata.embedding, collection_name, top_k=1 + 
memory.metadata.embedding, collection_name, top_k=5 ) - recall = search_results[0] if search_results else None - if not recall or (recall.score is not None and recall.score < 0.5): - self.vector_db.update(collection_name, memory.id, vec_db_item) - return memory.id - - old_msg_str = recall.memory - new_msg_str = memory.memory - is_same = self._judge_update_or_add(old_msg_str, new_msg_str) - if is_same: - self.vector_db.delete(collection_name, [recall.id]) - self.vector_db.update(collection_name, memory.id, vec_db_item) - return memory.id + search_results.sort(key=lambda x: x.score, reverse=True) + + return self._update_memory(memory, search_results, collection_name, preference_type, update_mode="fast") except Exception as e: print(f"Error processing memory {memory.id}: {e}") @@ -113,7 +215,10 @@ def add( try: memory_id = future.result() if memory_id: - added_ids.append(memory_id) + if isinstance(memory_id, list): + added_ids.extend(memory_id) + else: + added_ids.append(memory_id) except Exception as e: memory = future_to_memory[future] print(f"Error processing memory {memory.id}: {e}") diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 71f636160..01a041c07 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -4,7 +4,9 @@ Notes: - A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". -- This includes, but is not limited to, any clearly expressed inclination, desire, rejection, or priority that counts as an explicit preference. +- This includes, but is not limited to, any user's explicitly expressed inclination, desire, rejection, or priority that counts as an explicit preference. +- Focus on extracting the user's preferences in query. 
Do not extract preferences from the assistant's responses unless the user explicitly agrees with or endorses the assistant's suggestions. +- When the user modifies or updates their preferences for the same topic or event, extract the complete evolution process of their preference changes, including both the original and updated preferences. Requirements: 1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. @@ -132,4 +134,272 @@ **New Information:** {new_information} +""" + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE_BAK = """ +You are a **User Preference Memory Management Agent**. +Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. + +You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. + +## Input Format + +New preference memory (new_memory): +{new_memory} + +Retrieved preference memories (retrieved_memories): +{retrieved_memories} + +## Task Instructions + +1. Analyze each retrieved memory and determine its relationship to the new memory: + - **Unrelated** → perform `"ADD"` (insert as a new independent memory); + - **Related** → perform `"UPDATE"` (refine, supplement, or merge with the old memory); + - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). + +2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry. + +3. Output a structured list of **operation traces**, each explicitly stating: + - which memory (by ID) is affected, + - what operation is performed, + - the before/after content, + - and the reasoning behind it. + +4. 
Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. + +## Output Format (JSON) + +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(the old memory ID; null if ADD)", + "old_content": "(old memory content; null if ADD)", + "new_content": "(the updated or newly created memory, if applicable)", + "reason": "(brief natural-language explanation for the decision)" + } + ], + "after_update_state": [ + {"id": "id1", "content": "…"}, + {"id": "id2", "content": "…"} + ] +} + +## Example + +**Input:** +new_memory: +"User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." + +retrieved_memories: +[ + {"id": "id1", "content": "User likes coffee."}, + {"id": "id2", "content": "User prefers Americanos."}, + {"id": "id3", "content": "User likes working from home."}, + {"id": "id4", "content": "User has no particular interest in tea."} +] + +**Output:** +{ + "trace": [ + { + "op_id": "op_1", + "type": "UPDATE", + "target_id": "id1", + "old_content": "User likes coffee.", + "new_content": "User likes coffee, especially lattes, but sometimes drinks Americanos.", + "reason": "The new memory refines and extends the user's coffee preference details." + }, + { + "op_id": "op_2", + "type": "DELETE", + "target_id": "id2", + "old_content": "User prefers Americanos.", + "new_content": null, + "reason": "This old memory has been integrated into a broader updated coffee preference (id1)." + }, + { + "op_id": "op_3", + "type": "UPDATE", + "target_id": "id3", + "old_content": "User likes working from home.", + "new_content": "User now prefers studying in quiet coffee shops instead of working from home.", + "reason": "The new memory shows a shift in environment preference; the old one is outdated." 
+ } + ], + "after_update_state": [ + {"id": "id1", "content": "User likes coffee, especially lattes, but sometimes drinks Americanos."}, + {"id": "id3", "content": "User now prefers studying in quiet coffee shops instead of working from home."}, + {"id": "id4", "content": "User has no particular interest in tea."} + ] +} + +## Output Requirements + +- The output **must** be valid JSON. +- Each operation must include a `reason`. +- Multiple retrieved memories may be merged into one unified updated memory. +- `after_update_state` must reflect the final, post-update state of the preference memory base. +- Do **not** include any explanatory text outside the JSON. +""" + +NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """ +# User Preference Memory Management Agent + +You are a **User Preference Memory Management Agent**. +Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. + +Each memory entry contains three fields: +- **id**: a unique identifier for the memory. +- **context_summary**: a factual summary of the dialogue or situation from which the preference was extracted. +- **preference**: the extracted statement describing the user's preference or tendency. + +When updating a preference, you should also integrate and update the corresponding `context_summary` to ensure both fields stay semantically consistent. + +You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. + +## Input Format + +New preference memory (new_memory): +{new_memory} + +Retrieved preference memories (retrieved_memories): +{retrieved_memories} + +## Task Instructions + +1. 
Analyze each retrieved memory and determine its relationship to the new memory: + - **Unrelated** → perform `"ADD"` (insert as a new independent memory); + - **Related** → perform `"UPDATE"` (refine, supplement, or merge both the `preference` and the `context_summary`); + - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). + +2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. + +3. Output a structured list of **operation traces**, each explicitly stating: + - which memory (by ID) is affected, + - what operation is performed, + - the before/after `preference` and `context_summary`, + - and the reasoning behind it. + +4. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. + +## Output Format (JSON) + +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(the old memory ID; null if ADD)", + "old_preference": "(the old preference text; null if ADD)", + "old_context_summary": "(the old context summary; null if ADD)", + "new_preference": "(the updated or newly created preference, if applicable)", + "new_context_summary": "(the updated or newly created context summary, if applicable)", + "reason": "(brief natural-language explanation for the decision)" + } + ], + "after_update_state": [ + { + "id": "id1", + "context_summary": "updated factual summary of the context", + "preference": "updated or final preference text" + } + ] +} + +## Example + +**Input:** +new_memory: +{ + "context_summary": "During a recent chat about study habits, the user mentioned that he often studies in quiet coffee shops and has started preferring lattes over Americanos, which he only drinks occasionally.", + "preference": "User now prefers lattes but occasionally drinks 
Americanos; he also enjoys studying in quiet coffee shops." +} + +retrieved_memories: +[ + { + "id": "id1", + "context_summary": "The user previously said he likes coffee in general.", + "preference": "User likes coffee." + }, + { + "id": "id2", + "context_summary": "The user once mentioned preferring Americanos during work breaks.", + "preference": "User prefers Americanos." + }, + { + "id": "id3", + "context_summary": "The user said he often works from home", + "preference": "User likes working from home." + }, + { + "id": "id4", + "context_summary": "The user noted he doesn't drink tea very often.", + "preference": "User has no particular interest in tea." + } +] + +**Output:** +{ + "trace": [ + { + "op_id": "op_1", + "type": "UPDATE", + "target_id": "id1", + "old_preference": "User likes coffee.", + "old_context_summary": "The user previously said he likes coffee in general.", + "new_preference": "User likes coffee, especially lattes, but occasionally drinks Americanos.", + "new_context_summary": "The user discussed his coffee habits, stating he now prefers lattes but only occasionally drinks Americanos", + "reason": "The new memory refines and expands the coffee preference and context while preserving frequency semantics ('occasionally')." + }, + { + "op_id": "op_2", + "type": "DELETE", + "target_id": "id2", + "old_preference": "User prefers Americanos.", + "old_context_summary": "The user once mentioned preferring Americanos during work breaks.", + "new_preference": null, + "new_context_summary": null, + "reason": "This old memory is now merged into the updated coffee preference (id1)." 
+ }, + { + "op_id": "op_3", + "type": "UPDATE", + "target_id": "id3", + "old_preference": "User likes working from home.", + "old_context_summary": "The user said he often works from home.", + "new_preference": "User now prefers studying in quiet coffee shops instead of working from home.", + "new_context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.", + "reason": "The preference has changed for the working environment." + } + ], + "after_update_state": [ + { + "id": "id1", + "context_summary": "The user discussed his coffee habits, saying he now prefers lattes but only occasionally drinks Americanos.", + "preference": "User likes coffee, especially lattes, but occasionally drinks Americanos." + }, + { + "id": "id3", + "context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.", + "preference": "User now prefers studying in quiet coffee shops instead of working from home." + }, + { + "id": "id4", + "context_summary": "The user noted he doesn't drink tea very often.", + "preference": "User has no particular interest in tea." + } + ] +} + +## Output Requirements + +- The output **must** be valid JSON. +- Each operation must include both `preference` and `context_summary` updates where applicable. +- Each operation must include a clear `reason`. +- Multiple retrieved memories may be merged into one unified updated memory. +- `after_update_state` must reflect the final, post-update state of the preference memory base. +- Do **not** include any explanatory text outside the JSON. 
""" \ No newline at end of file diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 8be6db6f1..88ec7df42 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -336,6 +336,8 @@ def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]] def update(self, collection_name: str, id: str, data: MilvusVecDBItem | dict[str, Any]) -> None: """Update an item in the vector database.""" + if id != data.id: + raise ValueError(f"The id of the data to update must be the same as the id of the item to update, ID mismatch: expected {id}, got {data.id}") if isinstance(data, dict): data = data.copy() data = MilvusVecDBItem.from_dict(data) From 90e1252b04d93e0681364d67d31193798c275e5e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 14 Oct 2025 10:12:05 +0800 Subject: [PATCH 38/62] modify slow update in adder --- .../textual/prefer_text_memory/adder.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index cb0707998..42f04c2e3 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -99,30 +99,37 @@ def _update_memory_op_trace(self, ) self.vector_db.add(collection_name, [vec_db_item]) return new_memory.id - - op_trace = rsp["trace"] - added_ids = [] - for op in op_trace: - if op["type"].lower() == "add": + + def execute_op(op): + op_type = op["type"].lower() + if op_type == "add": payload = new_memory.to_dict()["metadata"] - fields_to_remove = {"dialog_id", "dialog_str", "embedding"} - payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + payload = {k: v for k, v in payload.items() if k not in {"dialog_id", "dialog_str", "embedding"}} vec_db_item = MilvusVecDBItem( - id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + 
id=new_memory.id, memory=new_memory.memory, + vector=new_memory.metadata.embedding, payload=payload ) self.vector_db.add(collection_name, [vec_db_item]) - added_ids.append(new_memory.id) - elif op["type"].lower() == "update": + return new_memory.id + elif op_type == "update": payload = {"preference_type": preference_type, preference_type: op["new_preference"]} vec_db_item = MilvusVecDBItem( - id=op["target_id"], memory=op["new_context_summary"], vector=self.embedder.embed([op["new_context_summary"]])[0], payload=payload + id=op["target_id"], memory=op["new_context_summary"], + vector=self.embedder.embed([op["new_context_summary"]])[0], payload=payload ) self.vector_db.update(collection_name, op["target_id"], vec_db_item) - added_ids.append(op["target_id"]) - elif op["type"].lower() == "delete": + return op["target_id"] + elif op_type == "delete": self.vector_db.delete(collection_name, [op["target_id"]]) - else: - continue + return None + + with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: + future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} + added_ids = [] + for future in as_completed(future_to_op): + result = future.result() + if result is not None: + added_ids.append(result) return added_ids From 2e4714d25ae5086d304532ac732f8813f1af5d62 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 14 Oct 2025 11:22:05 +0800 Subject: [PATCH 39/62] modify implicit part code in extractor and add duplicate in utils --- .../textual/prefer_text_memory/extractor.py | 21 +-- .../textual/prefer_text_memory/retrievers.py | 4 + .../textual/prefer_text_memory/utils.py | 48 +++++ src/memos/templates/prefer_complete_prompt.py | 165 +----------------- 4 files changed, 61 insertions(+), 177 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 86d3a8f89..24ec6dbf3 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ 
b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -122,7 +122,7 @@ def _process_single_chunk_implicit( return None vector_info = { - "embedding": self.embedder.embed([basic_info["dialog_str"]])[0], + "embedding": self.embedder.embed([implicit_pref["context_summary"]])[0], } extract_info = {**basic_info, **implicit_pref, **vector_info, **info} @@ -131,7 +131,7 @@ def _process_single_chunk_implicit( type=msg_type, preference_type="implicit_preference", **extract_info ) memory = TextualMemoryItem( - id=extract_info["dialog_id"], memory=extract_info["dialog_str"], metadata=metadata + id=extract_info["dialog_id"], memory=implicit_pref["context_summary"], metadata=metadata ) return memory @@ -144,28 +144,23 @@ def extract( max_workers: int = 10, ) -> list[TextualMemoryItem]: """Extract preference memories based on the messages using thread pool for acceleration.""" - chunks_for_explicit: list[MessageList] = [] + chunks: list[MessageList] = [] for message in messages: chunk = self.splitter.split_chunks(message, split_type="overlap") - chunks_for_explicit.extend(chunk) - if not chunks_for_explicit: + chunks.extend(chunk) + if not chunks: return [] - chunks_for_implicit: list[MessageList] = [] - for message in messages: - chunk = self.splitter.split_chunks(message, split_type="overlap") - chunks_for_implicit.extend(chunk) - memories = [] with ThreadPoolExecutor( - max_workers=min(max_workers, len(chunks_for_explicit) + len(chunks_for_implicit)) + max_workers=min(max_workers, len(chunks)) ) as executor: futures = { executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( "explicit", chunk, ) - for chunk in chunks_for_explicit + for chunk in chunks } futures.update( { @@ -173,7 +168,7 @@ def extract( "implicit", chunk, ) - for chunk in chunks_for_implicit + for chunk in chunks } ) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index b4841ea90..0670a28f1 
100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -3,6 +3,7 @@ from typing import Any from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.prefer_text_memory.utils import deduplicate_preferences class BaseRetriever(ABC): @@ -85,4 +86,7 @@ def retrieve( explicit_prefs = [item for item, _ in explicit_prefs] implicit_prefs = [item for item, _ in implicit_prefs] + # deduplicate implicit preferences, due to more duplicates + implicit_prefs = deduplicate_preferences(implicit_prefs, similarity_threshold=0.8, num_perm=256) + return explicit_prefs + implicit_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index e0752d899..e3106edce 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -1,3 +1,7 @@ +import re + +from datasketch import MinHashLSH, MinHash +from memos.memories.textual.item import TextualMemoryItem from memos.types import MessageList @@ -11,3 +15,47 @@ def convert_messages_to_string(messages: MessageList) -> str: message_text += f"Answer: {message['content']}\n" if message["content"].strip() else "" message_text = message_text.strip() return message_text + + +def deduplicate_preferences(prefs: list[TextualMemoryItem], similarity_threshold: float = 0.8, num_perm: int = 256) -> list[TextualMemoryItem]: + """ + Deduplicate preference texts using MinHash algorithm. 
+ + Args: + prefs: List of preference memory items to deduplicate + similarity_threshold: Jaccard similarity threshold (0.0-1.0), default 0.8 + + Returns: + Deduplicated list of preference items + """ + if not prefs: + return prefs + + # Use MinHashLSH for efficient similarity search + lsh = MinHashLSH(threshold=similarity_threshold, num_perm=num_perm) + unique_prefs = [] + + for i, pref in enumerate(prefs): + # Extract preference text + if hasattr(pref.metadata, 'implicit_preference') and pref.metadata.implicit_preference: + text = pref.metadata.implicit_preference + elif hasattr(pref.metadata, 'explicit_preference') and pref.metadata.explicit_preference: + text = pref.metadata.explicit_preference + else: + text = pref.memory + + # Create MinHash from text tokens + minhash = MinHash(num_perm=num_perm) + # Simple tokenization: split by whitespace and clean + tokens = re.findall(r'\w+', text.lower()) + for token in tokens: + minhash.update(token.encode('utf8')) + + # Check for duplicates using LSH + similar_items = lsh.query(minhash) + + if not similar_items: # No similar items found + lsh.insert(i, minhash) + unique_prefs.append(pref) + + return unique_prefs diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 01a041c07..c4597073a 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -28,32 +28,6 @@ ``` """ -NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_BAK = """ -You are a preference extraction assistant. -Please extract the user's explicitly mentioned preferences from the following conversation. - -Notes: -- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". -- Any clearly expressed inclination, desire, rejection, or priority counts as an explicit preference. - -Requirements: -1. Keep only the preferences explicitly mentioned by the user. 
Do not infer or assume. -2. Output should be a concise natural language summary, not a list or categories. -3. If there are no explicit preferences in the conversation, output an empty string "". -4. Output only the preference statements themselves, without any additional explanation. - -Conversation: -{qa_pair} - -Output format: -```json -{ - "explicit_preference": "A short natural language summary of the preferences, or an empty string" -} -``` -Don't output anything except the JSON. -""" - NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ You are a preference inference assistant. Please extract **implicit preferences** from the following conversation @@ -77,6 +51,7 @@ ```json { "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string", + "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not lack any scenario information", "reasoning": "Briefly explain the reasoning process for the implicit preference" } ``` @@ -84,37 +59,6 @@ """ -NAIVE_EXPLICIT_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ -You are a preference extraction and inference assistant. Please extract the user's preferences from the following conversation, including: - -1. **Explicit preferences**: Preferences that the user directly expresses, such as likes, dislikes, wants, does not want, or prioritized choices. -2. **Implicit preferences**: Preferences that are not explicitly stated but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices. - -Notes: -- For explicit preferences, only extract what the user directly states, do not infer. -- For implicit preferences, only infer when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. -- Do not duplicate: do not treat explicit preferences as implicit preferences. - -Requirements: -1. 
Output in JSON format with two fields: "explicit_preference" and "implicit_preference". -2. Each field should be an array, with each element being a concise natural language preference statement. -3. Output only the preference statements themselves; do not include any extra explanation, reasoning, or confidence information. -4. If a type of preference does not exist, its array should be empty. - -Conversation: -{qa_pair} - -Output Format: -```json -{ - "explicit_preference": ["The user clearly likes coffee", "The user does not want to sit by the window"], - "implicit_preference": ["The user prefers a quiet environment"] -} -``` -Don't output anything except the JSON. -""" - - NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: @@ -136,113 +80,6 @@ {new_information} """ -NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE_BAK = """ -You are a **User Preference Memory Management Agent**. -Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. - -You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations. - -## Input Format - -New preference memory (new_memory): -{new_memory} - -Retrieved preference memories (retrieved_memories): -{retrieved_memories} - -## Task Instructions - -1. 
Analyze each retrieved memory and determine its relationship to the new memory: - - **Unrelated** → perform `"ADD"` (insert as a new independent memory); - - **Related** → perform `"UPDATE"` (refine, supplement, or merge with the old memory); - - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). - -2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry. - -3. Output a structured list of **operation traces**, each explicitly stating: - - which memory (by ID) is affected, - - what operation is performed, - - the before/after content, - - and the reasoning behind it. - -4. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. - -## Output Format (JSON) - -{ - "trace": [ - { - "op_id": "op_1", - "type": "ADD" | "UPDATE" | "DELETE", - "target_id": "(the old memory ID; null if ADD)", - "old_content": "(old memory content; null if ADD)", - "new_content": "(the updated or newly created memory, if applicable)", - "reason": "(brief natural-language explanation for the decision)" - } - ], - "after_update_state": [ - {"id": "id1", "content": "…"}, - {"id": "id2", "content": "…"} - ] -} - -## Example - -**Input:** -new_memory: -"User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." 
- -retrieved_memories: -[ - {"id": "id1", "content": "User likes coffee."}, - {"id": "id2", "content": "User prefers Americanos."}, - {"id": "id3", "content": "User likes working from home."}, - {"id": "id4", "content": "User has no particular interest in tea."} -] - -**Output:** -{ - "trace": [ - { - "op_id": "op_1", - "type": "UPDATE", - "target_id": "id1", - "old_content": "User likes coffee.", - "new_content": "User likes coffee, especially lattes, but sometimes drinks Americanos.", - "reason": "The new memory refines and extends the user's coffee preference details." - }, - { - "op_id": "op_2", - "type": "DELETE", - "target_id": "id2", - "old_content": "User prefers Americanos.", - "new_content": null, - "reason": "This old memory has been integrated into a broader updated coffee preference (id1)." - }, - { - "op_id": "op_3", - "type": "UPDATE", - "target_id": "id3", - "old_content": "User likes working from home.", - "new_content": "User now prefers studying in quiet coffee shops instead of working from home.", - "reason": "The new memory shows a shift in environment preference; the old one is outdated." - } - ], - "after_update_state": [ - {"id": "id1", "content": "User likes coffee, especially lattes, but sometimes drinks Americanos."}, - {"id": "id3", "content": "User now prefers studying in quiet coffee shops instead of working from home."}, - {"id": "id4", "content": "User has no particular interest in tea."} - ] -} - -## Output Requirements - -- The output **must** be valid JSON. -- Each operation must include a `reason`. -- Multiple retrieved memories may be merged into one unified updated memory. -- `after_update_state` must reflect the final, post-update state of the preference memory base. -- Do **not** include any explanatory text outside the JSON. 
-""" NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """ # User Preference Memory Management Agent From 8cbd710d1abb2f1f17664880c4fb2bf9329bbda2 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 14 Oct 2025 14:51:15 +0800 Subject: [PATCH 40/62] modify depulicate threshold --- src/memos/memories/textual/prefer_text_memory/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index e3106edce..33ca21481 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -17,7 +17,7 @@ def convert_messages_to_string(messages: MessageList) -> str: return message_text -def deduplicate_preferences(prefs: list[TextualMemoryItem], similarity_threshold: float = 0.8, num_perm: int = 256) -> list[TextualMemoryItem]: +def deduplicate_preferences(prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256) -> list[TextualMemoryItem]: """ Deduplicate preference texts using MinHash algorithm. 
From 0d205c5d89b6b4e167255c7ea7c2f55fe07fd33a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 15 Oct 2025 12:45:05 +0800 Subject: [PATCH 41/62] modify api config --- src/memos/api/config.py | 50 +++++++++++++++++++ .../textual/prefer_text_memory/extractor.py | 12 ++--- .../textual/prefer_text_memory/retrievers.py | 2 +- 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index c9ff70d4e..70fcf96b0 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -90,6 +90,44 @@ def get_activation_vllm_config() -> dict[str, Any]: }, } + @staticmethod + def get_preference_memory_config() -> dict[str, Any]: + """Get preference memory configuration.""" + return { + "backend": "pref_text", + "config": { + "extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()}, + "vector_db": { + "backend": "milvus", + "config": { + "collection_name": [ + "explicit_preference", + "implicit_preference", + ], + "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "distance_metric": "cosine", + "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), + "user_name": os.getenv("MILVUS_USER_NAME", "root"), + "password": os.getenv("MILVUS_PASSWORD", "12345678") + } + }, + "embedder": APIConfig.get_embedder_config(), + "reranker": APIConfig.get_reranker_config(), + "extractor": { + "backend": "naive", + "config": {} + }, + "adder": { + "backend": "naive", + "config": {} + }, + "retriever": { + "backend": "naive", + "config": {} + }, + } + } + @staticmethod def get_reranker_config() -> dict[str, Any]: """Get embedder configuration.""" @@ -370,6 +408,8 @@ def get_product_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "50")), "max_turns_window": 
int(os.getenv("MOS_MAX_TURNS_WINDOW", "20")), } @@ -399,6 +439,8 @@ def get_start_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "5")), "chat_model": { "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"), @@ -466,6 +508,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": 30, "max_turns_window": 20, } @@ -531,6 +575,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: @@ -593,6 +640,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 24ec6dbf3..bc081d795 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -60,17 +60,13 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A return None def 
extract_implicit_preference( - self, qa_pairs: MessageList | list[str] + self, qa_pair: MessageList | str ) -> dict[str, Any] | None: """Extract implicit preferences from cluster qa pairs.""" - if not qa_pairs: + if not qa_pair: return None - qa_pairs_str = ( - convert_messages_to_string(qa_pairs) - if isinstance(qa_pairs[0], dict) - else "\n\n".join(qa_pairs) - ) - prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pairs}", qa_pairs_str) + qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair + prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str) try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 0670a28f1..b91cfbb46 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -87,6 +87,6 @@ def retrieve( implicit_prefs = [item for item, _ in implicit_prefs] # deduplicate implicit preferences, due to more duplicates - implicit_prefs = deduplicate_preferences(implicit_prefs, similarity_threshold=0.8, num_perm=256) + # implicit_prefs = deduplicate_preferences(implicit_prefs, similarity_threshold=0.6, num_perm=256) return explicit_prefs + implicit_prefs From d6827ea36ca6ba1a9ef7067daccc20464481fcfb Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 15 Oct 2025 14:51:12 +0800 Subject: [PATCH 42/62] reapir bug in adder about search relate --- src/memos/memories/textual/prefer_text_memory/adder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 42f04c2e3..20087bfcc 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -191,7 
+191,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | collection_name = pref_type_collection_map[preference_type] search_results = self.vector_db.search( - memory.metadata.embedding, collection_name, top_k=5 + memory.metadata.embedding, collection_name, top_k=5, filter={"user_id": memory.metadata.user_id} ) search_results.sort(key=lambda x: x.score, reverse=True) From 0ee048c85378ed71617e71acd31ff12634213922 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 15 Oct 2025 18:59:48 +0800 Subject: [PATCH 43/62] repair bug in core , dupicate search --- src/memos/mem_os/core.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index e8dede58b..17b5b6115 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -680,33 +680,6 @@ def search_preference_memory(): if pref_result is not None: result["pref_mem"].append(pref_result) - if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.text_mem is not None) - and self.config.enable_textual_memory - ): - time_start = time.time() - memories = mem_cube.text_mem.search( - query, - top_k=top_k if top_k else self.config.top_k, - mode=mode, - manual_close_internet=not internet_search, - info={ - "user_id": target_user_id, - "session_id": target_session_id, - "chat_history": chat_history.chat_history, - }, - moscube=moscube, - search_filter=search_filter, - ) - result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories}) - logger.info( - f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" - ) - search_time_end = time.time() - logger.info( - f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" - ) return result def add( From 2e2ad34b4c7f5343f3f6ecc20dd4e655afb5511e Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 21 Oct 2025 11:13:26 +0800 Subject: [PATCH 44/62] add pref to new naive cube and 
server api --- src/memos/api/config.py | 26 +- src/memos/api/routers/server_router.py | 223 ++++++++++++++---- src/memos/graph_dbs/nebular.py | 2 +- src/memos/mem_cube/base.py | 1 + src/memos/mem_cube/navie.py | 48 +++- src/memos/mem_os/core.py | 1 - src/memos/memories/factory.py | 2 + .../memories/textual/simple_preference.py | 159 +++++++++++++ 8 files changed, 401 insertions(+), 61 deletions(-) create mode 100644 src/memos/memories/textual/simple_preference.py diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 16475d929..12ff91cfa 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -117,17 +117,7 @@ def get_preference_memory_config() -> dict[str, Any]: "extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()}, "vector_db": { "backend": "milvus", - "config": { - "collection_name": [ - "explicit_preference", - "implicit_preference", - ], - "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - "distance_metric": "cosine", - "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), - "user_name": os.getenv("MILVUS_USER_NAME", "root"), - "password": os.getenv("MILVUS_PASSWORD", "12345678") - } + "config": APIConfig.get_milvus_config(), }, "embedder": APIConfig.get_embedder_config(), "reranker": APIConfig.get_reranker_config(), @@ -313,6 +303,20 @@ def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } + @staticmethod + def get_milvus_config(): + return { + "collection_name": [ + "explicit_preference", + "implicit_preference", + ], + "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "distance_metric": "cosine", + "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), + "user_name": os.getenv("MILVUS_USER_NAME", "root"), + "password": os.getenv("MILVUS_PASSWORD", "12345678") + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" diff --git 
a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a332de583..125b9ea3f 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,4 +1,5 @@ import os +from concurrent.futures import ThreadPoolExecutor import traceback from typing import Any @@ -15,6 +16,7 @@ ) from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory from memos.configs.internet_retriever import InternetRetrieverConfigFactory from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory @@ -30,6 +32,9 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) +from memos.memories.textual.prefer_text_memory.factory import AdderFactory, ExtractorFactory, RetrieverFactory +from memos.memories.textual.prefer_text_memory.config import AdderConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory +from memos.vec_dbs.factory import VecDBFactory from memos.reranker.factory import RerankerFactory from memos.types import MOSSearchResult, UserContext @@ -56,6 +61,16 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: ) +def _build_vec_db_config() -> dict[str, Any]: + """Build vector database configuration.""" + return VectorDBConfigFactory.model_validate( + { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + } + ) + + def _build_llm_config() -> dict[str, Any]: """Build LLM configuration.""" return LLMConfigFactory.model_validate( @@ -88,6 +103,33 @@ def _build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) +def _build_extractor_config() -> dict[str, Any]: + """Build extractor configuration.""" + return ExtractorConfigFactory.model_validate( + { + "backend": "naive", + "config": {} + } + 
) + +def _build_adder_config() -> dict[str, Any]: + """Build adder configuration.""" + return AdderConfigFactory.model_validate( + { + "backend": "naive", + "config": {} + } + ) + +def _build_retriever_config() -> dict[str, Any]: + """Build retriever configuration.""" + return RetrieverConfigFactory.model_validate( + { + "backend": "naive", + "config": {} + } + ) + def _get_default_memory_size(cube_config) -> dict[str, int]: """Get default memory size configuration.""" return getattr(cube_config.text_mem.config, "memory_size", None) or { @@ -110,9 +152,14 @@ def init_server(): mem_reader_config = _build_mem_reader_config() reranker_config = _build_reranker_config() internet_retriever_config = _build_internet_retriever_config() + vector_db_config = _build_vec_db_config() + extractor_config = _build_extractor_config() + adder_config = _build_adder_config() + retriever_config = _build_retriever_config() # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) + vector_db = VecDBFactory.from_config(vector_db_config) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) @@ -120,6 +167,25 @@ def init_server(): internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) + extractor = ExtractorFactory.from_config( + config_factory=extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + adder = AdderFactory.from_config( + config_factory=adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + retriever = RetrieverFactory.from_config( + config_factory=retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) # Initialize memory manager memory_manager = MemoryManager( @@ -144,6 +210,10 @@ def init_server(): memory_manager, default_cube_config, mos_server, + vector_db, + extractor, + adder, + 
retriever, ) @@ -158,6 +228,10 @@ def init_server(): memory_manager, default_cube_config, mos_server, + vector_db, + extractor, + adder, + retriever, ) = init_server() @@ -172,6 +246,10 @@ def _create_naive_mem_cube() -> NaiveMemCube: internet_retriever=internet_retriever, memory_manager=memory_manager, default_cube_config=default_cube_config, + vector_db=vector_db, + extractor=extractor, + adder=adder, + retriever=retriever, ) return naive_mem_cube @@ -206,6 +284,7 @@ def search_memories(search_req: APISearchRequest): "text_mem": [], "act_mem": [], "para_mem": [], + "pref_mem": [], } target_session_id = search_req.session_id if not target_session_id: @@ -214,26 +293,53 @@ def search_memories(search_req: APISearchRequest): # Create MemCube and perform search naive_mem_cube = _create_naive_mem_cube() - search_results = naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=search_req.mode, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - formatted_memories = [_format_memory_item(data) for data in search_results] + + def _search_text(): + results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=search_req.mode, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + return [_format_memory_item(data) for data in results] + + def _search_pref(): + results = naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.top_k, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": 
search_req.chat_history, + }, + ) + return [_format_memory_item(data) for data in results] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_search_text) + pref_future = executor.submit(_search_pref) + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() memories_result["text_mem"].append( { "cube_id": search_req.mem_cube_id, - "memories": formatted_memories, + "memories": text_formatted_memories, + } + ) + + memories_result["pref_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": pref_formatted_memories, } ) @@ -256,38 +362,67 @@ def add_memories(add_req: APIADDRequest): target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" - memories = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) + def _process_text_mem() -> list[dict[str, str]]: + memories_local = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + flattened_local = [mm for m in memories_local for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + mem_ids_local: list[str] = naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + def _process_pref_mem() -> list[dict[str, str]]: + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: 
list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_process_text_mem) + pref_future = executor.submit(_process_pref_mem) + text_response_data = text_future.result() + pref_response_data = pref_future.result() - # Flatten memory list - flattened_memories = [mm for m in memories for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_id_list: list[str] = naive_mem_cube.text_mem.add( - flattened_memories, - user_name=user_context.mem_cube_id, - ) - logger.info( - f"Added {len(mem_id_list)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_id_list}" - ) - response_data = [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) - ] return MemoryResponse( message="Memory added successfully", - data=response_data, + data=text_response_data + pref_response_data, ) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index a6f6b82a4..d127c4534 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -436,7 +436,7 @@ def remove_oldest_memory( WHERE n.memory_type = '{memory_type}' {optional_condition} ORDER BY n.updated_at DESC - OFFSET {keep_latest} + OFFSET {int(keep_latest)} DETACH DELETE n """ self.execute_query(query) diff --git a/src/memos/mem_cube/base.py b/src/memos/mem_cube/base.py index 7d7c5e779..349d511fb 100644 --- a/src/memos/mem_cube/base.py +++ b/src/memos/mem_cube/base.py 
@@ -19,6 +19,7 @@ def __init__(self, config: BaseMemCubeConfig): self.text_mem: BaseTextMemory self.act_mem: BaseActMemory self.para_mem: BaseParaMemory + self.pref_mem: BaseTextMemory @abstractmethod def load(self, dir: str) -> None: diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 7ce3ca642..d0d5262d3 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -7,6 +7,7 @@ from memos.embedders.base import BaseEmbedder from memos.exceptions import ConfigurationError, MemCubeError from memos.graph_dbs.base import BaseGraphDB +from memos.vec_dbs.base import BaseVecDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.base import BaseMemCube @@ -17,6 +18,10 @@ from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.reranker.base import BaseReranker +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory +from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor +from memos.memories.textual.prefer_text_memory.adder import BaseAdder +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever logger = get_logger(__name__) @@ -34,7 +39,11 @@ def __init__( reranker: BaseReranker, memory_manager: MemoryManager, default_cube_config: GeneralMemCubeConfig, + vector_db: BaseVecDB, internet_retriever: None = None, + extractor: BaseExtractor=None, + adder: BaseAdder=None, + retriever: BaseRetriever=None, ): """Initialize the MemCube with a configuration.""" self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( @@ -49,6 +58,15 @@ def __init__( ) self._act_mem: BaseActMemory | None = None self._para_mem: BaseParaMemory | None = None + self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=extractor, + adder=adder, 
+ retriever=retriever, + ) def load( self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None @@ -69,7 +87,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -84,17 +102,21 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. 
- Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -106,7 +128,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -121,6 +143,10 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @property @@ -164,3 +190,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. 
Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value \ No newline at end of file diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index e9f603c5e..c5ee7b176 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -643,7 +643,6 @@ def search_preference_memory(): memories = mem_cube.pref_mem.search( query, top_k=top_k if top_k else self.config.top_k, - mode=mode, info={ "user_id": target_user_id, "session_id": self.session_id, diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index 6f0871f6b..c67e6ee2b 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -13,6 +13,7 @@ from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory class MemoryFactory(BaseMemory): @@ -24,6 +25,7 @@ class MemoryFactory(BaseMemory): "tree_text": TreeTextMemory, "simple_tree_text": SimpleTreeTextMemory, "pref_text": PreferenceTextMemory, + "simple_pref_text": SimplePreferenceTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py new file mode 100644 index 000000000..99bb31637 --- /dev/null +++ b/src/memos/memories/textual/simple_preference.py @@ -0,0 +1,159 @@ +import json +import os + +from typing import Any + +from memos.embedders.factory import ( + ArkEmbedder, + OllamaEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder, +) +from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM +from memos.log import 
get_logger +from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.types import MessageList +from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB + + +logger = get_logger(__name__) + + +class SimplePreferenceTextMemory(PreferenceTextMemory): + """Preference textual memory implementation for storing and retrieving memories.""" + + def __init__( + self, + extractor_llm: OpenAILLM | OllamaLLM | AzureLLM, + vector_db: MilvusVecDB | QdrantVecDB, + embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder, + reranker, + extractor, + adder, + retriever, + ): + """Initialize memory with the given configuration.""" + self.extractor_llm = extractor_llm + self.vector_db = vector_db + self.embedder = embedder + self.reranker = reranker + self.extractor = extractor + self.adder = adder + self.retriever = retriever + + def get_memory( + self, messages: list[MessageList], type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: + """Get memory based on the messages. + Args: + messages (MessageList): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. + """ + return self.extractor.extract(messages, type, info) + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. 
+ """ + return self.adder.add(memories) + + def get_with_collection_name( + self, collection_name: str, memory_id: str + ) -> TextualMemoryItem | None: + """Get a memory by its ID and collection name. + Args: + memory_id (str): The ID of the memory to retrieve. + collection_name (str): The name of the collection to retrieve the memory from. + Returns: + TextualMemoryItem: The memory with the given ID and collection name. + """ + try: + res = self.vector_db.get_by_id(collection_name, memory_id) + if res is None: + return None + return TextualMemoryItem( + id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload), + ) + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" + ) from e + + def get_by_ids_with_collection_name( + self, collection_name: str, memory_ids: list[str] + ) -> list[TextualMemoryItem]: + """Get memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to retrieve the memory from. + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs and collection name. + """ + try: + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if not res: + return [] + return [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in res + ] + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" + ) from e + + def get_all(self) -> list[TextualMemoryItem]: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. 
+ """ + all_collections = self.vector_db.list_collections() + all_memories = {} + for collection_name in all_collections: + items = self.vector_db.get_all(collection_name) + all_memories[collection_name] = [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in items + ] + return all_memories + + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: + """Delete memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to delete the memory from. + memory_ids (list[str]): List of memory IDs to delete. + """ + self.vector_db.delete(collection_name, memory_ids) + + def delete_all(self) -> None: + """Delete all memories.""" + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() From 0f5547e7e3c80d46fb19b78423337409f61448b2 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 21 Oct 2025 17:38:37 +0800 Subject: [PATCH 45/62] add async pref add by mem_schedular --- src/memos/mem_os/core.py | 34 ++++++++++--- src/memos/mem_scheduler/general_scheduler.py | 50 ++++++++++++++++++- .../mem_scheduler/schemas/general_schemas.py | 1 + .../mem_scheduler/schemas/message_schemas.py | 4 ++ src/memos/memories/textual/preference.py | 4 +- 5 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index c5ee7b176..4f56a2d78 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -19,6 +19,7 @@ ADD_LABEL, ANSWER_LABEL, QUERY_LABEL, + PREF_ADD_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_user.user_manager import UserManager, UserRole @@ -777,14 +778,31 @@ def process_preference_memory(): and self.mem_cubes[mem_cube_id].pref_mem ): messages_list = [messages] - pref_memories = 
self.mem_cubes[mem_cube_id].pref_mem.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}) - preferences = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) - logger.info( - f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {preferences}" - ) + mem_cube = self.mem_cubes[mem_cube_id] + sync_mode = "async" + if sync_mode == "sync": + pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": self.session_id}) + pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) + logger.info( + f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {pref_ids}" + ) + elif sync_mode == "async": + assert self.mem_scheduler is not None, ( + "Mem-Scheduler must be working when use asynchronous memory adding." + ) + message_item = ScheduleMessageItem( + user_id=target_user_id, + session_id=target_session_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) # Execute both memory processing functions in parallel with ThreadPoolExecutor(max_workers=2) as executor: diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 25c7b78fd..608db2ecc 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,5 +1,7 @@ import json +import concurrent.futures + from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -7,6 +9,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + PREF_ADD_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, QUERY_LABEL, WORKING_MEMORY_TYPE, @@ -16,7 +19,9 @@ from memos.mem_scheduler.schemas.message_schemas 
import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.preference import PreferenceTextMemory logger = get_logger(__name__) @@ -34,6 +39,7 @@ def __init__(self, config: GeneralSchedulerConfig): QUERY_LABEL: self._query_message_consumer, ANSWER_LABEL: self._answer_message_consumer, ADD_LABEL: self._add_message_consumer, + PREF_ADD_LABEL: self._pref_add_message_consumer, } self.dispatcher.register_handlers(handlers) @@ -222,6 +228,48 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: except Exception as e: logger.error(f"Error: {e}", exc_info=True) + def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {PREF_ADD_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + session_id = message.session_id + mem_cube_id = message.mem_cube_id + mem_cube = message.mem_cube + content = message.content + messages_list = json.loads(content) + + logger.info( + f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}" + ) + + # Get the preference memory from the mem_cube + pref_mem = mem_cube.pref_mem + if not isinstance(pref_mem, PreferenceTextMemory): + logger.error(f"Expected PreferenceTextMemory but got {type(pref_mem).__name__}") + return + + # Use pref_mem.get_memory to process the memories + pref_memories = pref_mem.get_memory(messages_list, type="chat", info={"user_id": user_id, "session_id": session_id}) + # Add pref_mem to vector db + pref_ids = pref_mem.add(pref_memories) + + logger.info( + f"Successfully processed and add preferences for user_id={user_id}, 
mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" + ) + + except Exception as e: + logger.error(f"Error processing pref_add message: {e}", exc_info=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Thread task failed: {e}", exc_info=True) + def process_session_turn( self, queries: str | list[str], diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index d0d83091b..10fd02010 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -8,6 +8,7 @@ QUERY_LABEL = "query" ANSWER_LABEL = "answer" ADD_LABEL = "add" +PREF_ADD_LABEL = "pref_add" TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 9b5bd5d81..906bb28cf 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -34,6 +34,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) user_id: str = Field(..., description="user id") + session_id: str | None = Field(default=None, description="session id") mem_cube_id: str = Field(..., description="memcube id") label: str = Field(..., description="Label of the schedule message") mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") @@ -54,6 +55,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "example": { "item_id": "123e4567-e89b-12d3-a456-426614174000", # Sample UUID "user_id": "user123", # Example user identifier + 
"session_id": "session123", # Example session identifier "mem_cube_id": "cube456", # Sample memory cube ID "label": "sample_label", # Demonstration label value "mem_cube": "obj of GeneralMemCube", # Added mem_cube example @@ -75,6 +77,7 @@ def to_dict(self) -> dict: return { "item_id": self.item_id, "user_id": self.user_id, + "session_id": self.session_id, "cube_id": self.mem_cube_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization @@ -88,6 +91,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": return cls( item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], + session_id=data["session_id"], cube_id=data["cube_id"], label=data["label"], cube="Not Applicable", # Custom cube deserialization diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 0fc22917c..5f85aa907 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -66,11 +66,11 @@ def __init__(self, config: PreferenceTextMemoryConfig): ) def get_memory( - self, messages: MessageList, type: str, info: dict[str, Any] + self, messages: list[MessageList], type: str, info: dict[str, Any] ) -> list[TextualMemoryItem]: """Get memory based on the messages. Args: - messages (MessageList): The messages to get memory from. + messages (list[MessageList]): The messages to get memory from. type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. 
""" From 72083dab256bceb17467e8964b711d699cf7b2c7 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 21 Oct 2025 17:54:18 +0800 Subject: [PATCH 46/62] modify --- src/memos/mem_os/core.py | 2 +- src/memos/memories/textual/prefer_text_memory/retrievers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 4f56a2d78..5f7ab7801 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -779,7 +779,7 @@ def process_preference_memory(): ): messages_list = [messages] mem_cube = self.mem_cubes[mem_cube_id] - sync_mode = "async" + sync_mode = "sync" if sync_mode == "sync": pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index b91cfbb46..bb35f6e10 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -46,10 +46,10 @@ def retrieve( with ThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks future_explicit = executor.submit( - self.vector_db.search, query_embedding, "explicit_preference", max(top_k, 20), info + self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info ) future_implicit = executor.submit( - self.vector_db.search, query_embedding, "implicit_preference", max(top_k, 20), info + self.vector_db.search, query_embedding, "implicit_preference", top_k * 2, info ) # Wait for all results From 6b8380146a3bf36aaadbe8f98a8dc0e99a81ab6c Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 22 Oct 2025 16:03:21 +0800 Subject: [PATCH 47/62] replace print to logger --- .../memories/textual/prefer_text_memory/adder.py | 12 ++++++++---- .../memories/textual/prefer_text_memory/extractor.py | 10 +++++++--- .../textual/prefer_text_memory/retrievers.py | 1 - 3 files changed, 15 insertions(+), 8 
deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 20087bfcc..db6647937 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -4,11 +4,15 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any +from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE from memos.vec_dbs.item import MilvusVecDBItem +logger = get_logger(__name__) + + class BaseAdder(ABC): """Abstract base class for adders.""" @@ -51,7 +55,7 @@ def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: response = result.get("is_same", False) return response if isinstance(response, bool) else response == "true" except Exception as e: - print(f"Error in judge_update_or_add: {e}") + logger.error(f"Error in judge_update_or_add: {e}") # Fallback to simple string comparison return old_msg == new_msg @@ -63,7 +67,7 @@ def _judge_update_or_add_trace_op(self, new_mem: str, retrieved_mems: str) -> di result = json.loads(response) return result except Exception as e: - print(f"Error in judge_update_or_add_trace_op: {e}") + logger.error(f"Error in judge_update_or_add_trace_op: {e}") return None def _update_memory_op_trace(self, @@ -198,7 +202,7 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | return self._update_memory(memory, search_results, collection_name, preference_type, update_mode="fast") except Exception as e: - print(f"Error processing memory {memory.id}: {e}") + logger.error(f"Error processing memory {memory.id}: {e}") return None def add( @@ -228,7 +232,7 @@ def add( added_ids.append(memory_id) except Exception as e: memory = future_to_memory[future] - print(f"Error processing memory 
{memory.id}: {e}") + logger.error(f"Error processing memory {memory.id}: {e}") continue return added_ids diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index bc081d795..bf7b5fb25 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import Any +from memos.log import get_logger from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.prefer_text_memory.spliter import Splitter from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string @@ -16,6 +17,9 @@ from memos.types import MessageList +logger = get_logger(__name__) + + class BaseExtractor(ABC): """Abstract base class for extractors.""" @@ -56,7 +60,7 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A result = json.loads(response) return result except Exception as e: - print(f"Error extracting explicit preference: {e}, return None") + logger.error(f"Error extracting explicit preference: {e}, return None") return None def extract_implicit_preference( @@ -74,7 +78,7 @@ def extract_implicit_preference( result = json.loads(response) return result except Exception as e: - print(f"Error extracting implicit preferences: {e}, return None") + logger.error(f"Error extracting implicit preferences: {e}, return None") return None def _process_single_chunk_explicit( @@ -178,7 +182,7 @@ def extract( memories.append(memory) except Exception as e: task_type, chunk = futures[future] - print(f"Error processing {task_type} chunk: {chunk}\n{e}") + logger.error(f"Error processing {task_type} chunk: {chunk}\n{e}") continue return memories diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 
bb35f6e10..91103619b 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -3,7 +3,6 @@ from typing import Any from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem -from memos.memories.textual.prefer_text_memory.utils import deduplicate_preferences class BaseRetriever(ABC): From 87e189b3e1f5ca19ee75bd29e37482630e839dd5 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 22 Oct 2025 17:05:45 +0800 Subject: [PATCH 48/62] repair bug from make pre-commit --- docker/requirements.txt | 2 +- evaluation/scripts/utils/client.py | 6 +- poetry.lock | 2 +- src/memos/api/config.py | 23 +--- src/memos/api/routers/server_router.py | 44 +++--- src/memos/mem_cube/navie.py | 22 +-- src/memos/mem_os/core.py | 35 ++--- src/memos/mem_scheduler/general_scheduler.py | 15 +- src/memos/memories/factory.py | 4 +- src/memos/memories/textual/item.py | 6 +- .../textual/prefer_text_memory/adder.py | 128 ++++++++++++------ .../textual/prefer_text_memory/extractor.py | 8 +- .../textual/prefer_text_memory/factory.py | 4 +- .../textual/prefer_text_memory/retrievers.py | 3 - .../textual/prefer_text_memory/utils.py | 31 +++-- .../memories/textual/simple_preference.py | 7 +- src/memos/templates/instruction_completion.py | 8 +- src/memos/templates/prefer_complete_prompt.py | 14 +- src/memos/vec_dbs/item.py | 2 +- src/memos/vec_dbs/milvus.py | 8 +- 20 files changed, 201 insertions(+), 171 deletions(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index d20c0b36e..4846f1832 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6 watchfiles==1.1.0 websockets==15.0.1 xlrd==2.0.2 -xlsxwriter==3.2.5 \ No newline at end of file +xlsxwriter==3.2.5 diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 33aea7497..da915f0fb 100644 --- a/evaluation/scripts/utils/client.py +++ 
b/evaluation/scripts/utils/client.py @@ -11,14 +11,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from memobase import ChatBlob +from utils.mem0_local import Mem0Client +from utils.memos_api import MemOSAPI +from utils.memos_filters import filter_memory_data from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube from memos.mem_os.product import MOSProduct -from utils.mem0_local import Mem0Client -from utils.memos_api import MemOSAPI -from utils.memos_filters import filter_memory_data load_dotenv() diff --git a/poetry.lock b/poetry.lock index d34f964b6..110b22175 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6310,4 +6310,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" \ No newline at end of file +content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 12ff91cfa..d26672883 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -121,20 +121,11 @@ def get_preference_memory_config() -> dict[str, Any]: }, "embedder": APIConfig.get_embedder_config(), "reranker": APIConfig.get_reranker_config(), - "extractor": { - "backend": "naive", - "config": {} - }, - "adder": { - "backend": "naive", - "config": {} - }, - "retriever": { - "backend": "naive", - "config": {} - }, + "extractor": {"backend": "naive", "config": {}}, + "adder": {"backend": "naive", "config": {}}, + "retriever": {"backend": "naive", "config": {}}, + }, } - } @staticmethod def get_reranker_config() -> dict[str, Any]: @@ -314,7 +305,7 @@ def get_milvus_config(): "distance_metric": "cosine", "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), "user_name": os.getenv("MILVUS_USER_NAME", "root"), - "password": os.getenv("MILVUS_PASSWORD", 
"12345678") + "password": os.getenv("MILVUS_PASSWORD", "12345678"), } @staticmethod @@ -591,7 +582,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, - "pref_mem": {} + "pref_mem": {} if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" else APIConfig.get_preference_memory_config(), } @@ -656,7 +647,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, - "pref_mem": {} + "pref_mem": {} if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" else APIConfig.get_preference_memory_config(), } diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 125b9ea3f..d198e8eb0 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,7 +1,7 @@ import os -from concurrent.futures import ThreadPoolExecutor import traceback +from concurrent.futures import ThreadPoolExecutor from typing import Any from fastapi import APIRouter, HTTPException @@ -16,11 +16,11 @@ ) from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphDBConfigFactory -from memos.configs.vec_db import VectorDBConfigFactory from memos.configs.internet_retriever import InternetRetrieverConfigFactory from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.reranker import RerankerConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -28,15 +28,23 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_os.product_server 
import MOSServer from memos.mem_reader.factory import MemReaderFactory +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) -from memos.memories.textual.prefer_text_memory.factory import AdderFactory, ExtractorFactory, RetrieverFactory -from memos.memories.textual.prefer_text_memory.config import AdderConfigFactory, ExtractorConfigFactory, RetrieverConfigFactory -from memos.vec_dbs.factory import VecDBFactory from memos.reranker.factory import RerankerFactory from memos.types import MOSSearchResult, UserContext +from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) @@ -105,30 +113,18 @@ def _build_internet_retriever_config() -> dict[str, Any]: def _build_extractor_config() -> dict[str, Any]: """Build extractor configuration.""" - return ExtractorConfigFactory.model_validate( - { - "backend": "naive", - "config": {} - } - ) + return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) + def _build_adder_config() -> dict[str, Any]: """Build adder configuration.""" - return AdderConfigFactory.model_validate( - { - "backend": "naive", - "config": {} - } - ) + return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) + def _build_retriever_config() -> dict[str, Any]: """Build retriever configuration.""" - return RetrieverConfigFactory.model_validate( - { - "backend": "naive", - "config": {} - } - ) + return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) + def _get_default_memory_size(cube_config) -> dict[str, int]: """Get default memory size configuration.""" @@ -362,6 +358,7 
@@ def add_memories(add_req: APIADDRequest): target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" + def _process_text_mem() -> list[dict[str, str]]: memories_local = mem_reader.get_memory( [add_req.messages], @@ -419,7 +416,6 @@ def _process_pref_mem() -> list[dict[str, str]]: text_response_data = text_future.result() pref_response_data = pref_future.result() - return MemoryResponse( message="Memory added successfully", data=text_response_data + pref_response_data, diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index d0d5262d3..8074c1c57 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -7,7 +7,6 @@ from memos.embedders.base import BaseEmbedder from memos.exceptions import ConfigurationError, MemCubeError from memos.graph_dbs.base import BaseGraphDB -from memos.vec_dbs.base import BaseVecDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.base import BaseMemCube @@ -15,13 +14,14 @@ from memos.memories.activation.base import BaseActMemory from memos.memories.parametric.base import BaseParaMemory from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.prefer_text_memory.adder import BaseAdder +from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.reranker.base import BaseReranker -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory -from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor -from memos.memories.textual.prefer_text_memory.adder import BaseAdder -from 
memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever +from memos.vec_dbs.base import BaseVecDB logger = get_logger(__name__) @@ -41,9 +41,9 @@ def __init__( default_cube_config: GeneralMemCubeConfig, vector_db: BaseVecDB, internet_retriever: None = None, - extractor: BaseExtractor=None, - adder: BaseAdder=None, - retriever: BaseRetriever=None, + extractor: BaseExtractor | None = None, + adder: BaseAdder | None = None, + retriever: BaseRetriever | None = None, ): """Initialize the MemCube with a configuration.""" self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( @@ -109,7 +109,9 @@ def load( logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: @@ -203,4 +205,4 @@ def pref_mem(self, value: BaseTextMemory) -> None: """Set the preference memory.""" if not isinstance(value, BaseTextMemory): raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") - self._pref_mem = value \ No newline at end of file + self._pref_mem = value diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 5f7ab7801..3804e82da 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,11 +2,11 @@ import os import time +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Lock from typing import Any, Literal -from concurrent.futures import ThreadPoolExecutor from memos.configs.mem_os import MOSConfig from memos.llms.factory import LLMFactory @@ -18,8 +18,8 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, - QUERY_LABEL, PREF_ADD_LABEL, + QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem 
from memos.mem_user.user_manager import UserManager, UserRole @@ -604,14 +604,14 @@ def search( for mem_cube_id, mem_cube in tmp_mem_cubes.items(): # Define internal functions for parallel search execution - def search_textual_memory(): + def search_textual_memory(cube_id, cube): if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.text_mem is not None) + (cube_id in install_cube_ids) + and (cube.text_mem is not None) and self.config.enable_textual_memory ): time_start = time.time() - memories = mem_cube.text_mem.search( + memories = cube.text_mem.search( query, top_k=top_k if top_k else self.config.top_k, mode=mode, @@ -626,22 +626,22 @@ def search_textual_memory(): ) search_time_end = time.time() logger.info( - f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" + f"🧠 [Memory] Searched memories from {cube_id}:\n{self._str_memories(memories)}\n" ) logger.info( f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" ) - return {"cube_id": mem_cube_id, "memories": memories} + return {"cube_id": cube_id, "memories": memories} return None - def search_preference_memory(): + def search_preference_memory(cube_id, cube): if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.pref_mem is not None) + (cube_id in install_cube_ids) + and (cube.pref_mem is not None) and self.config.enable_preference_memory ): time_start = time.time() - memories = mem_cube.pref_mem.search( + memories = cube.pref_mem.search( query, top_k=top_k if top_k else self.config.top_k, info={ @@ -652,18 +652,18 @@ def search_preference_memory(): ) search_time_end = time.time() logger.info( - f"🧠 [Memory] Searched preferences from {mem_cube_id}:\n{self._str_memories(memories)}\n" + f"🧠 [Memory] Searched preferences from {cube_id}:\n{self._str_memories(memories)}\n" ) logger.info( f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" ) - return {"cube_id": mem_cube_id, 
"memories": memories} + return {"cube_id": cube_id, "memories": memories} return None # Execute both search functions in parallel with ThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(search_textual_memory) - pref_future = executor.submit(search_preference_memory) + text_future = executor.submit(search_textual_memory, mem_cube_id, mem_cube) + pref_future = executor.submit(search_preference_memory, mem_cube_id, mem_cube) # Wait for both tasks to complete and collect results text_result = text_future.result() @@ -784,7 +784,8 @@ def process_preference_memory(): pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}) + info={"user_id": target_user_id, "session_id": self.session_id}, + ) pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) logger.info( f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {pref_ids}" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 608db2ecc..992eb4fa4 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,6 +1,5 @@ -import json - import concurrent.futures +import json from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger @@ -9,8 +8,8 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, - PREF_ADD_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, + PREF_ADD_LABEL, QUERY_LABEL, WORKING_MEMORY_TYPE, MemCubeID, @@ -20,8 +19,8 @@ from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.tree import 
TreeTextMemory logger = get_logger(__name__) @@ -240,9 +239,7 @@ def process_message(message: ScheduleMessageItem): content = message.content messages_list = json.loads(content) - logger.info( - f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) + logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") # Get the preference memory from the mem_cube pref_mem = mem_cube.pref_mem @@ -251,7 +248,9 @@ def process_message(message: ScheduleMessageItem): return # Use pref_mem.get_memory to process the memories - pref_memories = pref_mem.get_memory(messages_list, type="chat", info={"user_id": user_id, "session_id": session_id}) + pref_memories = pref_mem.get_memory( + messages_list, type="chat", info={"user_id": user_id, "session_id": session_id} + ) # Add pref_mem to vector db pref_ids = pref_mem.add(pref_memories) diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index c67e6ee2b..5ba1c6726 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,10 +10,10 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory -from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.preference import PreferenceTextMemory -from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.simple_preference import SimplePreferenceTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree import TreeTextMemory class MemoryFactory(BaseMemory): diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index eb4f0f7ac..6d975cfd7 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -170,9 +170,9 @@ class SearchedTreeNodeTextualMemoryMetadata(TreeNodeTextualMemoryMetadata): class 
PreferenceTextualMemoryMetadata(TextualMemoryMetadata): """Metadata for preference memory item.""" - preference_type: Literal[ - "explicit_preference", "implicit_preference" - ] = Field(default="explicit_preference", description="Type of preference.") + preference_type: Literal["explicit_preference", "implicit_preference"] = Field( + default="explicit_preference", description="Type of preference." + ) dialog_id: str | None = Field(default=None, description="ID of the dialog.") dialog_str: str | None = Field(default=None, description="String of the dialog.") embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index db6647937..390f048ef 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -6,7 +6,10 @@ from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem -from memos.templates.prefer_complete_prompt import NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE +from memos.templates.prefer_complete_prompt import ( + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE, +) from memos.vec_dbs.item import MilvusVecDBItem @@ -59,8 +62,12 @@ def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: # Fallback to simple string comparison return old_msg == new_msg - def _judge_update_or_add_trace_op(self, new_mem: str, retrieved_mems: str) -> dict[str, Any] | None: - prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace("{retrieved_memories}", retrieved_mems) + def _judge_update_or_add_trace_op( + self, new_mem: str, retrieved_mems: str + ) -> dict[str, Any] | None: + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace( + "{retrieved_memories}", retrieved_mems + ) 
try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) response = response.strip().replace("```json", "").replace("```", "").strip() @@ -70,63 +77,91 @@ def _judge_update_or_add_trace_op(self, new_mem: str, retrieved_mems: str) -> di logger.error(f"Error in judge_update_or_add_trace_op: {e}") return None - def _update_memory_op_trace(self, - new_memory: TextualMemoryItem, - retrieved_memories: list[MilvusVecDBItem], - collection_name: str, - preference_type: str - ) -> list[str] | str: - + def _update_memory_op_trace( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + ) -> list[str] | str: if not retrieved_memories: payload = new_memory.to_dict()["metadata"] fields_to_remove = {"dialog_id", "dialog_str", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} vec_db_item = MilvusVecDBItem( - id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, ) self.vector_db.add(collection_name, [vec_db_item]) return new_memory.id new_mem_input = { "context_summary": new_memory.memory, - "preference": new_memory.metadata.explicit_preference if preference_type == "explicit_preference" else new_memory.metadata.implicit_preference, + "preference": new_memory.metadata.explicit_preference + if preference_type == "explicit_preference" + else new_memory.metadata.implicit_preference, } - retrieved_mem_inputs = [{"id": mem.id, "context_summary": mem.memory, "preference": mem.payload[preference_type]} for mem in retrieved_memories] + retrieved_mem_inputs = [ + { + "id": mem.id, + "context_summary": mem.memory, + "preference": mem.payload[preference_type], + } + for mem in retrieved_memories + ] - rsp = self._judge_update_or_add_trace_op(new_mem=json.dumps(new_mem_input), 
retrieved_mems=json.dumps(retrieved_mem_inputs)) + rsp = self._judge_update_or_add_trace_op( + new_mem=json.dumps(new_mem_input), retrieved_mems=json.dumps(retrieved_mem_inputs) + ) if not rsp: payload = new_memory.to_dict()["metadata"] fields_to_remove = {"dialog_id", "dialog_str", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} vec_db_item = MilvusVecDBItem( - id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, ) self.vector_db.add(collection_name, [vec_db_item]) return new_memory.id - + def execute_op(op): op_type = op["type"].lower() if op_type == "add": payload = new_memory.to_dict()["metadata"] - payload = {k: v for k, v in payload.items() if k not in {"dialog_id", "dialog_str", "embedding"}} + payload = { + k: v + for k, v in payload.items() + if k not in {"dialog_id", "dialog_str", "embedding"} + } vec_db_item = MilvusVecDBItem( - id=new_memory.id, memory=new_memory.memory, - vector=new_memory.metadata.embedding, payload=payload + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, ) self.vector_db.add(collection_name, [vec_db_item]) return new_memory.id elif op_type == "update": - payload = {"preference_type": preference_type, preference_type: op["new_preference"]} + payload = { + "preference_type": preference_type, + preference_type: op["new_preference"], + } vec_db_item = MilvusVecDBItem( - id=op["target_id"], memory=op["new_context_summary"], - vector=self.embedder.embed([op["new_context_summary"]])[0], payload=payload + id=op["target_id"], + memory=op["new_context_summary"], + vector=self.embedder.embed([op["new_context_summary"]])[0], + payload=payload, ) self.vector_db.update(collection_name, op["target_id"], vec_db_item) return op["target_id"] elif op_type == "delete": self.vector_db.delete(collection_name, 
[op["target_id"]]) return None - + with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} added_ids = [] @@ -137,16 +172,20 @@ def execute_op(op): return added_ids - def _update_memory_fast(self, - new_memory: TextualMemoryItem, - retrieved_memories: list[MilvusVecDBItem], - collection_name: str, - ) -> str: + def _update_memory_fast( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + ) -> str: payload = new_memory.to_dict()["metadata"] fields_to_remove = {"dialog_id", "dialog_str", "embedding"} payload = {k: v for k, v in payload.items() if k not in fields_to_remove} vec_db_item = MilvusVecDBItem( - id=new_memory.id, memory=new_memory.memory, vector=new_memory.metadata.embedding, payload=payload + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, ) recall = retrieved_memories[0] if retrieved_memories else None if not recall or (recall.score is not None and recall.score < 0.5): @@ -161,13 +200,14 @@ def _update_memory_fast(self, self.vector_db.update(collection_name, new_memory.id, vec_db_item) return new_memory.id - def _update_memory(self, - new_memory: TextualMemoryItem, - retrieved_memories: list[MilvusVecDBItem], - collection_name: str, - preference_type: str, - update_mode: str = "op_trace", - ) -> list[str] | str | None: + def _update_memory( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + update_mode: str = "op_trace", + ) -> list[str] | str | None: """Update the memory. 
Args: new_memory: TextualMemoryItem @@ -177,13 +217,14 @@ def _update_memory(self, update_mode: str, "op_trace" or "fast" """ if update_mode == "op_trace": - return self._update_memory_op_trace(new_memory, retrieved_memories, collection_name, preference_type) + return self._update_memory_op_trace( + new_memory, retrieved_memories, collection_name, preference_type + ) elif update_mode == "fast": return self._update_memory_fast(new_memory, retrieved_memories, collection_name) else: raise ValueError(f"Invalid update mode: {update_mode}") - def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | None: """Process a single memory and return its ID if added successfully.""" try: @@ -195,11 +236,16 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | collection_name = pref_type_collection_map[preference_type] search_results = self.vector_db.search( - memory.metadata.embedding, collection_name, top_k=5, filter={"user_id": memory.metadata.user_id} + memory.metadata.embedding, + collection_name, + top_k=5, + filter={"user_id": memory.metadata.user_id}, ) search_results.sort(key=lambda x: x.score, reverse=True) - - return self._update_memory(memory, search_results, collection_name, preference_type, update_mode="fast") + + return self._update_memory( + memory, search_results, collection_name, preference_type, update_mode="fast" + ) except Exception as e: logger.error(f"Error processing memory {memory.id}: {e}") diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index bf7b5fb25..460b31f4f 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -63,9 +63,7 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A logger.error(f"Error extracting explicit preference: {e}, return None") return None - def extract_implicit_preference( - 
self, qa_pair: MessageList | str - ) -> dict[str, Any] | None: + def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None: """Extract implicit preferences from cluster qa pairs.""" if not qa_pair: return None @@ -152,9 +150,7 @@ def extract( return [] memories = [] - with ThreadPoolExecutor( - max_workers=min(max_workers, len(chunks)) - ) as executor: + with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: futures = { executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( "explicit", diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index b0cba396f..22182261a 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -73,4 +73,6 @@ def from_config( if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") retriever_class = cls.backend_to_class[backend] - return retriever_class(llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db) + return retriever_class( + llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db + ) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 91103619b..7f70bac3b 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -85,7 +85,4 @@ def retrieve( explicit_prefs = [item for item, _ in explicit_prefs] implicit_prefs = [item for item, _ in implicit_prefs] - # deduplicate implicit preferences, due to more duplicates - # implicit_prefs = deduplicate_preferences(implicit_prefs, similarity_threshold=0.6, num_perm=256) - return explicit_prefs + implicit_prefs diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py 
b/src/memos/memories/textual/prefer_text_memory/utils.py index 33ca21481..b8acb6c5b 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -1,6 +1,7 @@ import re -from datasketch import MinHashLSH, MinHash +from datasketch import MinHash, MinHashLSH + from memos.memories.textual.item import TextualMemoryItem from memos.types import MessageList @@ -17,45 +18,47 @@ def convert_messages_to_string(messages: MessageList) -> str: return message_text -def deduplicate_preferences(prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256) -> list[TextualMemoryItem]: +def deduplicate_preferences( + prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256 +) -> list[TextualMemoryItem]: """ Deduplicate preference texts using MinHash algorithm. - + Args: prefs: List of preference memory items to deduplicate similarity_threshold: Jaccard similarity threshold (0.0-1.0), default 0.8 - + Returns: Deduplicated list of preference items """ if not prefs: return prefs - + # Use MinHashLSH for efficient similarity search lsh = MinHashLSH(threshold=similarity_threshold, num_perm=num_perm) unique_prefs = [] - + for i, pref in enumerate(prefs): # Extract preference text - if hasattr(pref.metadata, 'implicit_preference') and pref.metadata.implicit_preference: + if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference: text = pref.metadata.implicit_preference - elif hasattr(pref.metadata, 'explicit_preference') and pref.metadata.explicit_preference: + elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference: text = pref.metadata.explicit_preference else: text = pref.memory - + # Create MinHash from text tokens minhash = MinHash(num_perm=num_perm) # Simple tokenization: split by whitespace and clean - tokens = re.findall(r'\w+', text.lower()) + tokens = re.findall(r"\w+", text.lower()) for token in 
tokens: - minhash.update(token.encode('utf8')) - + minhash.update(token.encode("utf8")) + # Check for duplicates using LSH similar_items = lsh.query(minhash) - + if not similar_items: # No similar items found lsh.insert(i, minhash) unique_prefs.append(pref) - + return unique_prefs diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 99bb31637..29f30d384 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -1,6 +1,3 @@ -import json -import os - from typing import Any from memos.embedders.factory import ( @@ -11,8 +8,8 @@ ) from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger -from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.preference import PreferenceTextMemory from memos.types import MessageList from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB @@ -32,7 +29,7 @@ def __init__( extractor, adder, retriever, - ): + ): """Initialize memory with the given configuration.""" self.extractor_llm = extractor_llm self.vector_db = vector_db diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index 50e477775..483a30766 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -3,8 +3,8 @@ def build_system_prompt( - memories: list[TextualMemoryItem] | None = None, - instruction_strategy: str = "process_conflict"): + memories: list[TextualMemoryItem] | None = None, instruction_strategy: str = "process_conflict" +): # Initialize all memory lists mems = { "textual_mems": [], @@ -67,12 +67,12 @@ def get_instruction( query: str, memories: list[TextualMemoryItem] | None = None, history: MessageList | None = None, - instruction_strategy: str = "process_conflict" + 
instruction_strategy: str = "process_conflict", ) -> str: """Create instruction following the memories, preference and tool information.""" system_prompt = build_system_prompt(memories, instruction_strategy) - + messages = [ {"role": "system", "content": system_prompt}, *history, diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index c4597073a..9712fa5ec 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -84,7 +84,7 @@ NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """ # User Preference Memory Management Agent -You are a **User Preference Memory Management Agent**. +You are a **User Preference Memory Management Agent**. Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories. Each memory entry contains three fields: @@ -233,10 +233,10 @@ ## Output Requirements -- The output **must** be valid JSON. -- Each operation must include both `preference` and `context_summary` updates where applicable. -- Each operation must include a clear `reason`. -- Multiple retrieved memories may be merged into one unified updated memory. -- `after_update_state` must reflect the final, post-update state of the preference memory base. +- The output **must** be valid JSON. +- Each operation must include both `preference` and `context_summary` updates where applicable. +- Each operation must include a clear `reason`. +- Multiple retrieved memories may be merged into one unified updated memory. +- `after_update_state` must reflect the final, post-update state of the preference memory base. - Do **not** include any explanatory text outside the JSON. 
-""" \ No newline at end of file +""" diff --git a/src/memos/vec_dbs/item.py b/src/memos/vec_dbs/item.py index c753ab177..081400f15 100644 --- a/src/memos/vec_dbs/item.py +++ b/src/memos/vec_dbs/item.py @@ -46,4 +46,4 @@ def to_dict(self) -> dict[str, Any]: class MilvusVecDBItem(VecDBItem): """Represents a single item in the Milvus vector database.""" - memory: str | None = Field(default=None, description="Memory string") \ No newline at end of file + memory: str | None = Field(default=None, description="Memory string") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 88ec7df42..50d53a54a 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -39,9 +39,7 @@ def create_schema(self): schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True ) - schema.add_field( - field_name="memory", datatype=DataType.VARCHAR, max_length=65535 - ) + schema.add_field(field_name="memory", datatype=DataType.VARCHAR, max_length=65535) schema.add_field( field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension ) @@ -337,7 +335,9 @@ def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]] def update(self, collection_name: str, id: str, data: MilvusVecDBItem | dict[str, Any]) -> None: """Update an item in the vector database.""" if id != data.id: - raise ValueError(f"The id of the data to update must be the same as the id of the item to update, ID mismatch: expected {id}, got {data.id}") + raise ValueError( + f"The id of the data to update must be the same as the id of the item to update, ID mismatch: expected {id}, got {data.id}" + ) if isinstance(data, dict): data = data.copy() data = MilvusVecDBItem.from_dict(data) From f9094e0fe8c406328906b7474e60d6454a0eea9a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 23 Oct 2025 11:32:57 +0800 Subject: [PATCH 49/62] inst cplt --- src/memos/api/routers/server_router.py | 5 + src/memos/mem_os/core.py | 1 
+ src/memos/templates/instruction_completion.py | 104 +++++------------- src/memos/templates/prefer_complete_prompt.py | 8 ++ 4 files changed, 44 insertions(+), 74 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index d198e8eb0..dc731977f 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -43,6 +43,7 @@ InternetRetrieverFactory, ) from memos.reranker.factory import RerankerFactory +from memos.templates.instruction_completion import instruct_completion from memos.types import MOSSearchResult, UserContext from memos.vec_dbs.factory import VecDBFactory @@ -281,6 +282,7 @@ def search_memories(search_req: APISearchRequest): "act_mem": [], "para_mem": [], "pref_mem": [], + "instruct_completion": "", } target_session_id = search_req.session_id if not target_session_id: @@ -339,6 +341,9 @@ def _search_pref(): } ) + pref_instruction = instruct_completion(pref_formatted_memories) + memories_result["instruct_completion"] = pref_instruction + return SearchResponse( message="Search completed successfully", data=memories_result, diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 3804e82da..89f5448b6 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -779,6 +779,7 @@ def process_preference_memory(): ): messages_list = [messages] mem_cube = self.mem_cubes[mem_cube_id] + # TODO: To be modified, currently only support sync mode, wait for plaintext mode sync_mode = "sync" if sync_mode == "sync": pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index 483a30766..ff4d83cd8 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -1,82 +1,38 @@ -from memos.memories.textual.item import TextualMemoryItem -from memos.types import MessageList +from typing import Any 
+from memos.templates.prefer_complete_prompt import PREF_INSTRUCTIONS -def build_system_prompt( - memories: list[TextualMemoryItem] | None = None, instruction_strategy: str = "process_conflict" -): - # Initialize all memory lists - mems = { - "textual_mems": [], - "explicit_prefs": [], - "implicit_prefs": [], - } +def instruct_completion( + memories: list[dict[str, Any]] | None = None, +) -> str: + """Create instruction following the preferences.""" + explicit_pref = [] + implicit_pref = [] for memory in memories: - if memory.metadata.preference_type == "explicit_preference": - mems["explicit_prefs"].append(memory.metadata.explicit_preference) - elif memory.metadata.preference_type == "implicit_preference": - mems["implicit_prefs"].append(memory.metadata.implicit_preference) - else: - mems["textual_mems"].append(memory.memory) - - # Build memories string with different titles for different types - memories_parts = [] - if mems["textual_mems"]: - memories_parts.append("## Textual Memories:") - for i, mem in enumerate(mems["textual_mems"], 1): - memories_parts.append(f"{i}. {mem}") - if mems["explicit_prefs"]: - memories_parts.append("## Explicit Preferences:") - for i, pref in enumerate(mems["explicit_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - - if mems["implicit_prefs"]: - memories_parts.append("\n## Implicit Preferences:") - for i, pref in enumerate(mems["implicit_prefs"], 1): - memories_parts.append(f"{i}. {pref}") - - memories_str = "\n".join(memories_parts) - - system_prompt = ( - "You are a knowledgeable and helpful AI assistant. " - "You have access to conversation memories that help you provide more personalized responses. " - "Use the memories to understand the user's context, preferences, and past interactions. " - "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." 
- "\n\n# Memories:\n{memories}" + pref_type = memory.get("metadata", {}).get("preference_type") + if pref_type == "explicit_preference": + pref = memory.get("metadata", {}).get("explicit_preference", None) + if pref: + explicit_pref.append(pref) + elif pref_type == "implicit_preference": + pref = memory.get("metadata", {}).get("implicit_preference", None) + if pref: + implicit_pref.append(pref) + + explicit_pref_str = ( + "Explicit Preferences:\n" + "\n".join(explicit_pref) if explicit_pref else "" + ) + implicit_pref_str = ( + "Implicit Preferences:\n" + "\n".join(implicit_pref) if implicit_pref else "" ) - if instruction_strategy == "raw": - system_prompt = system_prompt.replace("{memories}", memories_str) - elif instruction_strategy == "process_conflict": - system_prompt = ( - system_prompt - + ( - "\nNote: Textual memories are summaries of facts, while preference memories are summaries of user preferences. " - + "Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts." - + "When encountering preference conflicts, the priority is: explicit preferences > textual memories > implicit preferences." 
- ) - ).replace("{memories}", memories_str) - else: - raise ValueError(f"Invalid instruction strategy: {instruction_strategy}") - - return system_prompt - - -def get_instruction( - query: str, - memories: list[TextualMemoryItem] | None = None, - history: MessageList | None = None, - instruction_strategy: str = "process_conflict", -) -> str: - """Create instruction following the memories, preference and tool information.""" - - system_prompt = build_system_prompt(memories, instruction_strategy) + if not explicit_pref_str and not implicit_pref_str: + return "" - messages = [ - {"role": "system", "content": system_prompt}, - *history, - {"role": "user", "content": query}, - ] + if not explicit_pref_str: + return implicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("explicit preferences > ", "") + if not implicit_pref_str: + return explicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("implicit preferences > ", "") - return messages + return explicit_pref_str + "\n" + implicit_pref_str + "\n" + PREF_INSTRUCTIONS diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 9712fa5ec..d40b7b778 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -240,3 +240,11 @@ - `after_update_state` must reflect the final, post-update state of the preference memory base. - Do **not** include any explanatory text outside the JSON. """ + + +PREF_INSTRUCTIONS = """ +# Note: +Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. +Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. +When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. 
+""" From 4f764fc604d922296cf8f30310e4ede084f54a3a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 23 Oct 2025 15:28:35 +0800 Subject: [PATCH 50/62] align to liji cloud server --- src/memos/api/product_models.py | 1 + src/memos/api/routers/server_router.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 86751b008..940d7c743 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -179,6 +179,7 @@ class APISearchRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) + handle_pref_mem: bool = Field(False, description="Whether to handle preference memory") class APIADDRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index dc731977f..e7f50dcd3 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -281,8 +281,8 @@ def search_memories(search_req: APISearchRequest): "text_mem": [], "act_mem": [], "para_mem": [], - "pref_mem": [], - "instruct_completion": "", + "pref_mem": str, + "prefs": [], } target_session_id = search_req.session_id if not target_session_id: @@ -333,16 +333,17 @@ def _search_pref(): "memories": text_formatted_memories, } ) + if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true": + memories_result["prefs"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": pref_formatted_memories, + } + ) - memories_result["pref_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": pref_formatted_memories, - } - ) - - pref_instruction = instruct_completion(pref_formatted_memories) - memories_result["instruct_completion"] = pref_instruction + pref_instruction: str = instruct_completion(pref_formatted_memories) + if search_req.handle_pref_mem: + memories_result["pref_mem"] = pref_instruction return 
SearchResponse( message="Search completed successfully", From e02722afbb24baec065c16420a2f8c01c47d694c Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 23 Oct 2025 16:48:21 +0800 Subject: [PATCH 51/62] repair pkg problem --- poetry.lock | 70 ++++++++++++++++--- pyproject.toml | 8 +++ .../textual/prefer_text_memory/utils.py | 10 ++- src/memos/vec_dbs/milvus.py | 5 +- 4 files changed, 80 insertions(+), 13 deletions(-) diff --git a/poetry.lock b/poetry.lock index 110b22175..44265bca8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -690,6 +690,30 @@ toml = ["tomli (>=2.0.0) ; python_version < \"3.11\""] trio = ["trio (>=0.10.0)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "datasketch" +version = "1.6.5" +description = "Probabilistic data structures for processing and searching very large datasets" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "datasketch-1.6.5-py3-none-any.whl", hash = "sha256:59311b2925b2f37536e9f7c2f46bbc25e8e54379c8635a3fa7ca55d2abb66d1b"}, + {file = "datasketch-1.6.5.tar.gz", hash = "sha256:ba2848cb74f23d6d3dd444cf24edcbc47b1c34a171b1803231793ed4d74d4fcf"}, +] + +[package.dependencies] +numpy = ">=1.11" +scipy = ">=1.0.0" + +[package.extras] +benchmark = ["SetSimilaritySearch (>=0.1.7)", "matplotlib (>=3.1.2)", "nltk (>=3.4.5)", "pandas (>=0.25.3)", "pyfarmhash (>=0.2.2)", "pyhash (>=0.9.3)", "scikit-learn (>=0.21.3)", "scipy (>=1.3.3)"] +cassandra = ["cassandra-driver (>=3.20)"] +experimental-aio = ["aiounittest ; python_version >= \"3.6\"", "motor ; python_version >= \"3.6\""] +redis = ["redis (>=2.10.0)"] +test = ["cassandra-driver (>=3.20)", "coverage", "mock (>=2.0.0)", "mockredispy", "nose (>=1.3.7)", "nose-exclude 
(>=0.5.0)", "pymongo (>=3.9.0)", "pytest", "redis (>=2.10.0)"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1222,7 +1246,7 @@ files = [ {file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"}, {file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"}, ] -markers = {main = "extra == \"all\""} +markers = {main = "extra == \"pref-mem\" or extra == \"all\""} [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] @@ -3241,7 +3265,7 @@ files = [ {file = "pandas-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b4b0de34dc8499c2db34000ef8baad684cfa4cbd836ecee05f323ebfba348c7d"}, {file = "pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [package.dependencies] numpy = [ @@ -3560,7 +3584,7 @@ files = [ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "pycparser" @@ -3773,6 +3797,33 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pymilvus" +version = "2.6.2" +description = "Python Sdk for Milvus" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "pymilvus-2.6.2-py3-none-any.whl", hash = "sha256:933e447e09424d490dcf595053b01a7277dadea7ae3235cd704363bd6792509d"}, + {file = "pymilvus-2.6.2.tar.gz", hash = 
"sha256:b4802cc954de8f2d47bf8d6230e92196514dcb8a3726ba6098dc27909d4bc8e3"}, +] + +[package.dependencies] +grpcio = ">=1.66.2,<1.68.0 || >1.68.0,<1.68.1 || >1.68.1,<1.69.0 || >1.69.0,<1.70.0 || >1.70.0,<1.70.1 || >1.70.1,<1.71.0 || >1.71.0,<1.72.1 || >1.72.1,<1.73.0 || >1.73.0" +pandas = ">=1.2.4" +protobuf = ">=5.27.2" +python-dotenv = ">=1.0.1,<2.0.0" +setuptools = ">69" +ujson = ">=2.0.0" + +[package.extras] +bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests", "urllib3"] +dev = ["azure-storage-blob", "black", "grpcio (==1.66.2)", "grpcio-testing (==1.66.2)", "grpcio-tools (==1.66.2)", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "pytest (>=5.3.4)", "pytest-asyncio", "pytest-cov (>=5.0.0)", "pytest-timeout (>=1.3.4)", "requests", "ruff (>=0.12.9,<1)", "scipy", "urllib3"] +milvus-lite = ["milvus-lite (>=2.4.0) ; sys_platform != \"win32\""] +model = ["pymilvus.model (>=0.3.0)"] + [[package]] name = "pymysql" version = "1.1.2" @@ -3946,7 +3997,7 @@ files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] -markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\""} +markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\" or extra == \"pref-mem\""} [[package]] name = "pywin32" @@ -4955,7 +5006,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"all\" and platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or 
python_version >= \"3.12\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5578,7 +5629,7 @@ files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "ujson" @@ -6301,13 +6352,14 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["chonkie", "markitdown", "neo4j", "pika", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] +pref-mem = ["datasketch", "pymilvus"] tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" +content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" diff --git a/pyproject.toml b/pyproject.toml index a03b9174b..3745582f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,12 @@ mem-reader = [ "markitdown[docx,pdf,pptx,xls,xlsx] 
(>=0.1.1,<0.2.0)", # Markdown parser for various file formats ] +# PreferenceTextMemory +pref-mem = [ + "pymilvus (>=2.6.1,<3.0.0)", # Milvus Vector DB + "datasketch (>=1.6.5,<2.0.0)", # MinHash library +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -99,6 +105,8 @@ all = [ "pymysql (>=1.1.0,<2.0.0)", "chonkie (>=1.0.7,<2.0.0)", "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", + "pymilvus (>=2.6.1,<3.0.0)", + "datasketch (>=1.6.5,<2.0.0)", # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py index b8acb6c5b..85adc9304 100644 --- a/src/memos/memories/textual/prefer_text_memory/utils.py +++ b/src/memos/memories/textual/prefer_text_memory/utils.py @@ -1,7 +1,6 @@ import re -from datasketch import MinHash, MinHashLSH - +from memos.dependency import require_python_package from memos.memories.textual.item import TextualMemoryItem from memos.types import MessageList @@ -18,6 +17,11 @@ def convert_messages_to_string(messages: MessageList) -> str: return message_text +@require_python_package( + import_name="datasketch", + install_command="pip install datasketch", + install_link="https://github.com/ekzhu/datasketch", +) def deduplicate_preferences( prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256 ) -> list[TextualMemoryItem]: @@ -31,6 +35,8 @@ def deduplicate_preferences( Returns: Deduplicated list of preference items """ + from datasketch import MinHash, MinHashLSH + if not prefs: return prefs diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 50d53a54a..fb19fd6ff 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -1,7 +1,5 @@ from typing import Any -from pymilvus import DataType, MilvusClient - from memos.configs.vec_db import 
MilvusVecDBConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -22,6 +20,7 @@ class MilvusVecDB(BaseVecDB): ) def __init__(self, config: MilvusVecDBConfig): """Initialize the Milvus vector database and the collection.""" + from pymilvus import MilvusClient self.config = config @@ -35,6 +34,8 @@ def __init__(self, config: MilvusVecDBConfig): def create_schema(self): """Create schema for the milvus collection.""" + from pymilvus import DataType + schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True) schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True From abb93117032f08ef4c86c20c646ecf7975da758c Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 10:28:09 +0800 Subject: [PATCH 52/62] modify example of pref --- examples/mem_os/simple_prefs_memos_product.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py index beee837a8..40ec920f5 100644 --- a/examples/mem_os/simple_prefs_memos_product.py +++ b/examples/mem_os/simple_prefs_memos_product.py @@ -54,7 +54,6 @@ def get_config(user_id: str): "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig", "user_id": user_id, "cube_id": f"{user_id}/mem_cube", - "config_filename": "config.json", "text_mem": { "backend": "tree_text", "config": { @@ -90,8 +89,6 @@ def get_config(user_id: str): "collection_name": [ "explicit_preference", "implicit_preference", - "topic_preference", - "user_preference", ], "vector_dimension": 768, "distance_metric": "cosine", @@ -394,11 +391,6 @@ def get_config(user_id: str): mem_cube_id=register_result["default_cube_id"], ) -user_cube_id = memory.user_manager.get_user_cubes(user_id)[0].cube_id -print(f"user_cube_id: {user_cube_id}") -memory.mem_cubes[user_cube_id].pref_mem.slow_update(user_id) - - retrieved_memories = memory.search(query="What do you like?", user_id=user_id) 
print( f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}" From d3ad3652919ad98edaec19f759230463e5e9f8e3 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 11:10:58 +0800 Subject: [PATCH 53/62] pre_commit --- evaluation/.env-example | 2 +- evaluation/scripts/PrefEval/pref_eval.py | 44 ++-- evaluation/scripts/PrefEval/pref_memos.py | 6 +- .../scripts/PrefEval/prefeval_preprocess.py | 5 +- evaluation/scripts/locomo/locomo_ingestion.py | 10 +- evaluation/scripts/locomo/locomo_search.py | 8 +- evaluation/scripts/locomo/prompts.py | 6 +- evaluation/scripts/longmemeval/lme_eval.py | 1 + .../scripts/longmemeval/lme_ingestion.py | 6 +- evaluation/scripts/longmemeval/lme_search.py | 4 +- evaluation/scripts/personamem/pm_ingestion.py | 103 +++++---- evaluation/scripts/personamem/pm_metric.py | 205 ++++++++++-------- evaluation/scripts/personamem/pm_responses.py | 61 ++++-- evaluation/scripts/personamem/pm_search.py | 88 ++++---- evaluation/scripts/run_prefeval_eval.sh | 8 +- evaluation/scripts/utils/client.py | 6 +- 16 files changed, 318 insertions(+), 245 deletions(-) diff --git a/evaluation/.env-example b/evaluation/.env-example index fc57344da..1ffaedcfd 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -38,4 +38,4 @@ MEMSCHEDULER_GRAPHDBAUTH_URI=bolt://localhost:7687 MEMSCHEDULER_GRAPHDBAUTH_USER=neo4j MEMSCHEDULER_GRAPHDBAUTH_PASSWORD=*** MEMSCHEDULER_GRAPHDBAUTH_DB_NAME=neo4j -MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true \ No newline at end of file +MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py index cd9c5dde2..3eb1416a5 100644 --- a/evaluation/scripts/PrefEval/pref_eval.py +++ b/evaluation/scripts/PrefEval/pref_eval.py @@ -1,14 +1,18 @@ +import argparse import asyncio import json +import os import re -import argparse -from typing import Dict, Any + from collections import Counter -from tqdm.asyncio import tqdm -import os 
+from typing import Any + import pandas as pd + from dotenv import load_dotenv from openai import OpenAI +from tqdm.asyncio import tqdm + load_dotenv() @@ -34,7 +38,7 @@ async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str: ) return response.choices[0].message.content except Exception as e: - print(f"API call failed: {str(e)}") + print(f"API call failed: {e!s}") return "" @@ -45,7 +49,7 @@ def parse_xml_response(response: str, tag: str) -> str: async def evaluate_violate_preference_async( client: OpenAI, preference: str, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's response violates the user's stated preference. Evaluate the response based on these criteria: @@ -77,8 +81,8 @@ async def evaluate_violate_preference_async( async def evaluate_acknowledge_preference_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: - prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. +) -> dict[str, str]: + prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. Evaluate the response based on these stringent criteria: @@ -86,7 +90,7 @@ async def evaluate_acknowledge_preference_async( Answer "Yes" if: -1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. 
If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. +1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. 2. The response assumes the user preference in answering the user's question implicitly. For example, when the user asks 'Can you recommend me cars to drive?', if the response is 'Based on your preference, I will recommend non-EV cars, ...', then this indicates the assistant assumes that the user's preference is a dislike of EV cars, and you should answer 'Yes'. Answer "No" if the response does not mention or refer to any user preference explicitly or implicitly. If you cannot extract the sentence stating what the preference is, answer 'No'. @@ -107,7 +111,7 @@ async def evaluate_acknowledge_preference_async( async def evaluate_hallucinate_preference_async( client: OpenAI, preference: str, restatement: str -) -> Dict[str, str]: +) -> dict[str, str]: if not restatement.strip(): return {"explanation": "No restatement provided by assistant", "answer": "No"} prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's restatement of the user's stated preference is the same preference. Evaluate the response based on these stringent criteria to answer if the assistant has hallucinated the preference or not: @@ -124,10 +128,10 @@ async def evaluate_hallucinate_preference_async( 2. 
The assistant's restatement is a minor paraphrase that fully preserves the meaning and intent of the original preference. 3. The restatement is just empty, no hallucination. -Here is the information: -Original user preference: "{preference}" -Assistant's restatement: "{restatement}" -Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: +Here is the information: +Original user preference: "{preference}" +Assistant's restatement: "{restatement}" +Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: [1 short sentence explanation] [Yes/No]""" @@ -140,7 +144,7 @@ async def evaluate_hallucinate_preference_async( async def evaluate_helpful_response_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant provides any substantive response to the user's query. 
Evaluate the response based on these stringent criteria: @@ -178,7 +182,7 @@ async def evaluate_helpful_response_async( } -def classify_error_type(evaluation_results: Dict[str, Any]) -> str: +def classify_error_type(evaluation_results: dict[str, Any]) -> str: violate = evaluation_results["violate_preference"]["answer"] acknowledge = evaluation_results["acknowledge_preference"]["answer"] hallucinate = evaluation_results["hallucinate_preference"]["answer"] @@ -196,7 +200,7 @@ def classify_error_type(evaluation_results: Dict[str, Any]) -> str: return "Personalized Response" -async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> Dict[str, Any]: +async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> dict[str, Any]: async with semaphore: data = json.loads(line.strip()) preference = data["preference"] @@ -227,7 +231,7 @@ async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) return result -def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[str, float]]: +def log_summary(error_counter: Counter, total_samples: int) -> dict[str, dict[str, float]]: summary_data = {} print("\n--- Error Type Summary ---") @@ -251,7 +255,7 @@ def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[st def generate_excel_summary( - summary_results: Dict[str, Dict[str, float]], + summary_results: dict[str, dict[str, float]], avg_search_time: float, avg_context_tokens: float, avg_add_time: float, @@ -320,7 +324,7 @@ async def main(concurrency_limit: int): client = OpenAI(api_key=API_KEY, base_url=API_URL) try: - with open(INPUT_FILE, "r", encoding="utf-8") as f: + with open(INPUT_FILE, encoding="utf-8") as f: lines = f.readlines() except FileNotFoundError: print(f"Error: Input file not found at '{INPUT_FILE}'") diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index d1a901dd2..9d8457976 100644 --- 
a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -159,7 +161,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/prefeval_preprocess.py b/evaluation/scripts/PrefEval/prefeval_preprocess.py index 004d5e505..9ace9dec9 100644 --- a/evaluation/scripts/PrefEval/prefeval_preprocess.py +++ b/evaluation/scripts/PrefEval/prefeval_preprocess.py @@ -1,7 +1,8 @@ -from datasets import load_dataset import json import os +from datasets import load_dataset + def convert_dataset_to_jsonl(dataset_name, output_dir="./scripts/PrefEval"): if not os.path.exists(output_dir): @@ -64,7 +65,7 @@ def process_jsonl_file(input_filepath, output_filepath): line_count = 0 print(f"Start processing file: {input_filepath}") with ( - open(input_filepath, "r", encoding="utf-8") as infile, + open(input_filepath, encoding="utf-8") as infile, open(output_filepath, "w", encoding="utf-8") as outfile, ): for line in infile: diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py index 2a177a52a..375e45220 100644 --- a/evaluation/scripts/locomo/locomo_ingestion.py +++ b/evaluation/scripts/locomo/locomo_ingestion.py @@ -1,12 +1,16 @@ -import os -import sys import argparse import concurrent.futures +import os +import sys import time + from datetime import datetime, timezone + import pandas as pd + from dotenv import 
load_dotenv + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -88,8 +92,8 @@ def process_user(conv_idx, frame, locomo_df, version): client = None if frame == "mem0" or frame == "mem0_graph": - from utils.client import Mem0Client from prompts import custom_instructions + from utils.client import Mem0Client client = Mem0Client(enable_graph="graph" in frame) client.client.update_project(custom_instructions=custom_instructions) diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index d976b8f67..f97f2592a 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -1,14 +1,18 @@ -import os -import sys import argparse import json +import os +import sys + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from time import time + import pandas as pd + from dotenv import load_dotenv from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py index 2827716a0..6bae05f4d 100644 --- a/evaluation/scripts/locomo/prompts.py +++ b/evaluation/scripts/locomo/prompts.py @@ -49,12 +49,12 @@ 5. Always convert relative time references to specific dates, months, or years. 6. Be as specific as possible when talking about people, places, and events 7. Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message. - + Clarification: When interpreting memories, use the timestamp to determine when the described event happened, not when someone talked about the event. - + Example: - + Memory: (2023-03-15T16:33:00Z) I went to the vet yesterday. Question: What day did I go to the vet? 
Correct Answer: March 15, 2023 diff --git a/evaluation/scripts/longmemeval/lme_eval.py b/evaluation/scripts/longmemeval/lme_eval.py index 45c038a2b..73117b925 100644 --- a/evaluation/scripts/longmemeval/lme_eval.py +++ b/evaluation/scripts/longmemeval/lme_eval.py @@ -26,6 +26,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils.prompts import LME_JUDGE_MODEL_TEMPLATE + encoding = tiktoken.get_encoding("cl100k_base") logging.basicConfig(level=logging.CRITICAL) transformers.logging.set_verbosity_error() diff --git a/evaluation/scripts/longmemeval/lme_ingestion.py b/evaluation/scripts/longmemeval/lme_ingestion.py index 6e9bd5ab4..f77e74e08 100644 --- a/evaluation/scripts/longmemeval/lme_ingestion.py +++ b/evaluation/scripts/longmemeval/lme_ingestion.py @@ -1,11 +1,15 @@ import argparse import os import sys + from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone + import pandas as pd + from tqdm import tqdm + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -130,7 +134,7 @@ def main(frame, version, num_workers=2): success_records = [] record_file = f"results/lme/{frame}-{version}/success_records.txt" if os.path.exists(record_file): - with open(record_file, "r") as f: + with open(record_file) as f: for i in f.readlines(): success_records.append(i.strip()) diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index a24c0eaf5..b0b11f911 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -3,6 +3,7 @@ import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed @@ -10,13 +11,12 @@ from time import time import pandas as pd + from tqdm import tqdm from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, 
MEM0_GRAPH_CONTEXT_TEMPLATE, - MEMOBASE_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, - ZEP_CONTEXT_TEMPLATE, ) diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 5cd9d38a6..973f4fec0 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -1,43 +1,55 @@ import argparse -import os -import sys import csv import json +import os +import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from tqdm import tqdm -from utils.client import mem0_client,zep_client,memos_api_client +from utils.client import mem0_client, memos_api_client, zep_client from zep_cloud.types import Message def ingest_session(session, user_id, session_id, frame, client): messages = [] if frame == "zep": - pass for idx, msg in enumerate(session): print( - f"[{frame}] 💬 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...") - client.memory.add(messages=[Message(role=msg["role"], role_type=msg["role"], content=msg["content"], )], ) + f"[{frame}] 💬 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..." + ) + client.memory.add( + messages=[ + Message( + role=msg["role"], + role_type=msg["role"], + content=msg["content"], + ) + ], + ) elif frame == "mem0-local" or frame == "mem0-api": for idx, msg in enumerate(session): messages.append({"role": msg["role"], "content": msg["content"]}) print( - f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...") + f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..." 
+ ) if frame == "mem0-local": client.add(messages=messages, user_id=user_id) elif frame == "mem0-api": - client.add(messages=messages, - user_id=user_id, - session_id=session_id, - version="v2", ) + client.add( + messages=messages, + user_id=user_id, + session_id=session_id, + version="v2", + ) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") elif frame == "memos-local" or frame == "memos-api": - if os.getenv("PRE_SPLIT_CHUNK")=="true": + if os.getenv("PRE_SPLIT_CHUNK") == "true": for i in range(0, len(session), 10): - messages = session[i: i + 10] + messages = session[i : i + 10] client.add(messages=messages, user_id=user_id, conv_id=session_id) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages") else: @@ -51,7 +63,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -63,14 +75,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -82,7 +94,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -102,13 +114,13 @@ def load_rows_with_context(csv_path, 
jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 def ingest_conv(row_data, context, version, conv_idx, frame): end_index_in_shared_context = row_data["end_index_in_shared_context"] - context = context[:int(end_index_in_shared_context)] + context = context[: int(end_index_in_shared_context)] user_id = f"pm_exper_user_{conv_idx}_{version}" print(f"👤 User ID: {user_id}") print("\n" + "=" * 80) @@ -132,22 +144,12 @@ def ingest_conv(row_data, context, version, conv_idx, frame): print("🔌 Using Mem0 API client for ingestion...") client.delete_all(user_id=user_id) print(f"🗑️ Deleted existing memories for user {user_id}...") - elif frame == "memos-local": - client = memos_client( - mode="local", - db_name=f"pm_{frame}-{version}", - user_id=user_id, - top_k=20, - mem_cube_path=f"results/pm/{frame}-{version}/storages/{user_id}", - mem_cube_config_path="configs/mu_mem_cube_config.json", - mem_os_config_path="configs/mos_memos_config.json", - addorsearch="add", - ) - print("🔌 Using Memos Local client for ingestion...") elif frame == "memos-api": client = memos_api_client() - ingest_session(session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client) + ingest_session( + session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client + ) print(f"✅ Ingestion of conversation {conv_idx} completed") print("=" * 80) @@ -170,16 +172,25 @@ def main(frame, version, num_workers=2): with ThreadPoolExecutor(max_workers=num_workers) as executor: future_to_idx = { - executor.submit(ingest_conv, row_data=row_data, context=context, version=version, conv_idx=idx, - frame=frame, ): idx - for idx, (row_data, context) in enumerate(all_data)} - - for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"): + executor.submit( + ingest_conv, + row_data=row_data, + 
context=context, + version=version, + conv_idx=idx, + frame=frame, + ): idx + for idx, (row_data, context) in enumerate(all_data) + } + + for future in tqdm( + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + ): idx = future_to_idx[future] try: future.result() except Exception as exc: - print(f'\n❌ Conversation {idx} generated an exception: {exc}') + print(f"\n❌ Conversation {idx} generated an exception: {exc}") end_time = datetime.now() elapsed_time = end_time - start_time @@ -195,10 +206,18 @@ def main(frame, version, num_workers=2): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], - default='memos-api') - parser.add_argument("--version", type=str, default="0925-1", help="Version of the evaluation framework.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="0925-1", help="Version of the evaluation framework." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." 
+ ) args = parser.parse_args() main(frame=args.lib, version=args.version, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index 0f6a1e138..70ee1376c 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -8,40 +8,48 @@ def save_to_excel(results, output_path): """Save results to Excel file""" combined_data = [] - + # Add overall statistics row - overall_row = {"category": "overall", "accuracy": results["metrics"]["accuracy"], - "accuracy_std": results["metrics"]["accuracy_std"], - "total_questions": results["metrics"]["total_questions"], - "total_runs": results["metrics"]["total_runs"]} + overall_row = { + "category": "overall", + "accuracy": results["metrics"]["accuracy"], + "accuracy_std": results["metrics"]["accuracy_std"], + "total_questions": results["metrics"]["total_questions"], + "total_runs": results["metrics"]["total_runs"], + } # Add response duration metrics for metric, value in results["metrics"]["response_duration"].items(): overall_row[f"response_{metric}"] = value - + # Add search duration metrics (if exists) if "search_duration" in results["metrics"] and results["metrics"]["search_duration"]: for metric, value in results["metrics"]["search_duration"].items(): overall_row[f"search_{metric}"] = value - + combined_data.append(overall_row) - + # Add category statistics rows for category, scores in results["category_scores"].items(): - category_row = {"category": category, "accuracy": scores["accuracy"], "accuracy_std": scores["accuracy_std"], - "total_questions": scores["total_questions"], "total_runs": scores["total_runs"]} + category_row = { + "category": category, + "accuracy": scores["accuracy"], + "accuracy_std": scores["accuracy_std"], + "total_questions": scores["total_questions"], + "total_runs": scores["total_runs"], + } # Add response duration metrics for metric, value in scores["response_duration"].items(): 
category_row[f"response_{metric}"] = value - + # Add search duration metrics (if exists) - if "search_duration" in scores and scores["search_duration"]: + if scores.get("search_duration"): for metric, value in scores["search_duration"].items(): category_row[f"search_{metric}"] = value - + combined_data.append(category_row) - + # Save to Excel df = pd.DataFrame(combined_data) df.to_excel(output_path, sheet_name="PersonaMem_Metrics", index=False) @@ -50,62 +58,62 @@ def save_to_excel(results, output_path): def calculate_scores(data, grade_path, output_path): """Calculate PersonaMem evaluation metrics""" - + # Initialize statistics variables category_scores = {} user_metrics = {} - + # Overall metrics - collect accuracy for each run all_response_durations = [] all_search_durations = [] total_questions = 0 - + # For calculating accuracy across multiple runs num_runs = None # Will be determined from first user's data run_accuracies = [] # List to store accuracy for each run across all users - + # Category-wise statistics category_response_durations = {} category_search_durations = {} category_run_accuracies = {} # Store accuracy for each run by category - + print(f"📋 Processing response data for {len(data)} users...") - + # First pass: determine number of runs and initialize run accuracy arrays - for user_id, user_data in data.items(): + for _, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: continue - + results = user_data.get("results", []) if not results: continue - + if num_runs is None: num_runs = len(results) run_accuracies = [[] for _ in range(num_runs)] # Initialize for each run print(f"📊 Detected {num_runs} runs per user") break - + if num_runs is None: print("❌ Error: Could not determine number of runs from data") return - + # Iterate through all user data for user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and 
"topic" in user_data: print(f"⚠️ Skipping incomplete data for user {user_id}") continue - + # Get category and results category = user_data.get("category", "unknown") results = user_data.get("results", []) - + if not results: print(f"⚠️ No results found for user {user_id}") continue - + # Initialize category if not exists if category not in category_scores: category_scores[category] = { @@ -115,39 +123,39 @@ def calculate_scores(data, grade_path, output_path): "accuracy": 0.0, "accuracy_std": 0.0, "response_duration": {}, - "search_duration": {} + "search_duration": {}, } category_response_durations[category] = [] category_search_durations[category] = [] category_run_accuracies[category] = [[] for _ in range(num_runs)] - + # Process each run for this user user_response_durations = [] for run_idx, result in enumerate(results): is_correct = result.get("is_correct", False) - + # Collect accuracy for each run (1 if correct, 0 if not) if run_idx < num_runs: run_accuracies[run_idx].append(1.0 if is_correct else 0.0) category_run_accuracies[category][run_idx].append(1.0 if is_correct else 0.0) - + # Collect response duration response_duration = result.get("response_duration_ms", 0) if response_duration > 0: user_response_durations.append(response_duration) all_response_durations.append(response_duration) category_response_durations[category].append(response_duration) - + # Get search duration (usually same for all runs) search_duration = user_data.get("search_duration_ms", 0) if search_duration > 0: all_search_durations.append(search_duration) category_search_durations[category].append(search_duration) - + # Calculate user-level accuracy (average across runs) user_correct_count = sum(1 for result in results if result.get("is_correct", False)) user_accuracy = user_correct_count / len(results) if results else 0.0 - + # Store user-level metrics user_metrics[user_id] = { "user_id": user_id, @@ -156,22 +164,26 @@ def calculate_scores(data, grade_path, output_path): "accuracy": 
user_accuracy, "total_runs": len(results), "correct_runs": user_correct_count, - "avg_response_duration_ms": np.mean(user_response_durations) if user_response_durations else 0.0, + "avg_response_duration_ms": np.mean(user_response_durations) + if user_response_durations + else 0.0, "search_duration_ms": search_duration, "golden_answer": user_data.get("golden_answer", ""), - "topic": user_data.get("topic", "") + "topic": user_data.get("topic", ""), } - + # Count statistics total_questions += 1 category_scores[category]["total_questions"] += 1 category_scores[category]["total_runs"] += len(results) - + # Calculate overall accuracy and std across runs overall_run_accuracies = [np.mean(run_acc) for run_acc in run_accuracies if run_acc] overall_accuracy = np.mean(overall_run_accuracies) if overall_run_accuracies else 0.0 - overall_accuracy_std = np.std(overall_run_accuracies) if len(overall_run_accuracies) > 1 else 0.0 - + overall_accuracy_std = ( + np.std(overall_run_accuracies) if len(overall_run_accuracies) > 1 else 0.0 + ) + # Calculate response duration statistics response_duration_stats = {} if all_response_durations: @@ -182,9 +194,9 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(all_response_durations, 95), "std": np.std(all_response_durations), "min": np.min(all_response_durations), - "max": np.max(all_response_durations) + "max": np.max(all_response_durations), } - + # Calculate search duration statistics search_duration_stats = {} if all_search_durations: @@ -195,16 +207,22 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(all_search_durations, 95), "std": np.std(all_search_durations), "min": np.min(all_search_durations), - "max": np.max(all_search_durations) + "max": np.max(all_search_durations), } - + # Calculate category-wise metrics for category in category_scores: # Calculate accuracy mean and std across runs for this category - cat_run_accuracies = [np.mean(run_acc) for run_acc in 
category_run_accuracies[category] if run_acc] - category_scores[category]["accuracy"] = np.mean(cat_run_accuracies) if cat_run_accuracies else 0.0 - category_scores[category]["accuracy_std"] = np.std(cat_run_accuracies) if len(cat_run_accuracies) > 1 else 0.0 - + cat_run_accuracies = [ + np.mean(run_acc) for run_acc in category_run_accuracies[category] if run_acc + ] + category_scores[category]["accuracy"] = ( + np.mean(cat_run_accuracies) if cat_run_accuracies else 0.0 + ) + category_scores[category]["accuracy_std"] = ( + np.std(cat_run_accuracies) if len(cat_run_accuracies) > 1 else 0.0 + ) + # Response duration statistics for this category if category_response_durations[category]: durations = category_response_durations[category] @@ -215,14 +233,19 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(durations, 95), "std": np.std(durations), "min": np.min(durations), - "max": np.max(durations) + "max": np.max(durations), } else: category_scores[category]["response_duration"] = { - "mean": 0.0, "median": 0.0, "p50": 0.0, "p95": 0.0, - "std": 0.0, "min": 0.0, "max": 0.0 + "mean": 0.0, + "median": 0.0, + "p50": 0.0, + "p95": 0.0, + "std": 0.0, + "min": 0.0, + "max": 0.0, } - + # Search duration statistics for this category if category_search_durations[category]: durations = category_search_durations[category] @@ -233,14 +256,19 @@ def calculate_scores(data, grade_path, output_path): "p95": np.percentile(durations, 95), "std": np.std(durations), "min": np.min(durations), - "max": np.max(durations) + "max": np.max(durations), } else: category_scores[category]["search_duration"] = { - "mean": 0.0, "median": 0.0, "p50": 0.0, "p95": 0.0, - "std": 0.0, "min": 0.0, "max": 0.0 + "mean": 0.0, + "median": 0.0, + "p50": 0.0, + "p95": 0.0, + "std": 0.0, + "min": 0.0, + "max": 0.0, } - + # Build final results results = { "metrics": { @@ -249,22 +277,22 @@ def calculate_scores(data, grade_path, output_path): "total_questions": total_questions, 
"total_runs": total_questions * num_runs if num_runs else 0, "response_duration": response_duration_stats, - "search_duration": search_duration_stats + "search_duration": search_duration_stats, }, "category_scores": category_scores, - "user_scores": user_metrics + "user_scores": user_metrics, } - + # Save results to JSON file with open(grade_path, "w") as outfile: json.dump(results, outfile, indent=4, ensure_ascii=False) - + # Save to Excel save_to_excel(results, output_path) - + # Print summary print_summary(results) - + return results @@ -273,19 +301,19 @@ def print_summary(results): print("\n" + "=" * 80) print("📊 PERSONAMEM EVALUATION SUMMARY".center(80)) print("=" * 80) - + # Overall accuracy accuracy = results["metrics"]["accuracy"] accuracy_std = results["metrics"]["accuracy_std"] total_questions = results["metrics"]["total_questions"] total_runs = results["metrics"]["total_runs"] - + print(f"🎯 Overall Accuracy: {accuracy:.4f} ± {accuracy_std:.4f}") print(f"📋 Total Questions: {total_questions}") print(f"🔄 Total Runs: {total_runs}") - + print("-" * 80) - + # Response duration statistics if results["metrics"]["response_duration"]: rd = results["metrics"]["response_duration"] @@ -294,7 +322,7 @@ def print_summary(results): print(f" P50: \033[96m{rd['p50']:.2f}") print(f" P95: \033[91m{rd['p95']:.2f}") print(f" Std Dev: {rd['std']:.2f}") - + # Search duration statistics if results["metrics"]["search_duration"]: sd = results["metrics"]["search_duration"] @@ -303,9 +331,9 @@ def print_summary(results): print(f" P50: \033[96m{sd['p50']:.2f}") print(f" P95: \033[91m{sd['p95']:.2f}") print(f" Std Dev: {sd['std']:.2f}") - + print("-" * 80) - + # Category-wise accuracy print("📂 Category-wise Accuracy:") for category, scores in results["category_scores"].items(): @@ -313,50 +341,47 @@ def print_summary(results): acc_std = scores["accuracy_std"] total_cat = scores["total_questions"] total_runs_cat = scores["total_runs"] - print(f" {category:<35}: {acc:.4f} ± 
{acc_std:.4f} ({total_cat} questions, {total_runs_cat} runs)") - + print( + f" {category:<35}: {acc:.4f} ± {acc_std:.4f} ({total_cat} questions, {total_runs_cat} runs)" + ) + print("=" * 80 + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem evaluation metrics calculation script") parser.add_argument( - "--lib", - type=str, + "--lib", + type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], required=True, help="Memory library to evaluate", - default='memos-api' - ) - parser.add_argument( - "--version", - type=str, - default="0925", - help="Evaluation framework version" + default="memos-api", ) - + parser.add_argument("--version", type=str, default="0925", help="Evaluation framework version") + args = parser.parse_args() lib, version = args.lib, args.version - + # Define file paths responses_path = f"results/pm/{lib}-{version}/{lib}_pm_responses.json" grade_path = f"results/pm/{lib}-{version}/{lib}_pm_grades.json" output_path = f"results/pm/{lib}-{version}/{lib}_pm_results.xlsx" - + print(f"📂 Loading response data from: {responses_path}") - + try: - with open(responses_path, 'r', encoding='utf-8') as file: + with open(responses_path, encoding="utf-8") as file: data = json.load(file) - + # Calculate metrics results = calculate_scores(data, grade_path, output_path) - + print(f"📁 Results saved to: {grade_path}") print(f"📊 Excel report saved to: {output_path}") - + except FileNotFoundError: print(f"❌ Error: File not found {responses_path}") print("Please make sure to run pm_responses.py first to generate response data") except Exception as e: - print(f"❌ Error occurred during processing: {e}") \ No newline at end of file + print(f"❌ Error occurred during processing: {e}") diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index c48933c11..3c19d58fc 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ 
b/evaluation/scripts/personamem/pm_responses.py @@ -10,20 +10,21 @@ from openai import OpenAI from tqdm import tqdm -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.prompts import PM_ANSWER_PROMPT +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re +from utils.prompts import PM_ANSWER_PROMPT + def extract_choice_answer(predicted_answer, correct_answer): def _extract_only_options(text): text = text.lower() - in_parens = re.findall(r'\(([a-d])\)', text) + in_parens = re.findall(r"\(([a-d])\)", text) if in_parens: return set(in_parens) else: - return set(re.findall(r'\b([a-d])\b', text)) + return set(re.findall(r"\b([a-d])\b", text)) correct = correct_answer.lower().strip("() ") @@ -33,7 +34,7 @@ def _extract_only_options(text): if "" in predicted_answer: predicted_answer = predicted_answer.split("")[-1].strip() if predicted_answer.endswith(""): - predicted_answer = predicted_answer[:-len("")].strip() + predicted_answer = predicted_answer[: -len("")].strip() pred_options = _extract_only_options(predicted_answer) @@ -79,12 +80,14 @@ def process_qa(user_id, search_result, num_runs, llm_client): is_correct, answer = extract_choice_answer(answer, search_result.get("golden_answer", "")) response_duration_ms = (time() - start) * 1000 - run_results.append({ - "run_id": idx + 1, - "answer": answer, - "is_correct": is_correct, - "response_duration_ms": response_duration_ms, - }) + run_results.append( + { + "run_id": idx + 1, + "answer": answer, + "is_correct": is_correct, + "response_duration_ms": response_duration_ms, + } + ) response_duration_ms = sum(result["response_duration_ms"] for result in run_results) / num_runs @@ -95,8 +98,11 @@ def process_qa(user_id, search_result, num_runs, llm_client): print(f"💡 Golden Answer: {search_result.get('golden_answer', 'N/A')}") for idx, result in enumerate(run_results, start=1): print(f"\n🔄 Run {idx}/{num_runs}:") - print(f"💬 Run Answer: 
{result['answer'][:150]}..." if len( - result['answer']) > 150 else f"💬 Run Answer: {result['answer']}") + print( + f"💬 Run Answer: {result['answer'][:150]}..." + if len(result["answer"]) > 150 + else f"💬 Run Answer: {result['answer']}" + ) print(f"✅ Run Is Correct: {result['is_correct']}") print(f"⏱️ Run Duration: {result['response_duration_ms']:.2f} ms") print("-" * 80) @@ -122,7 +128,9 @@ def main(frame, version, num_runs=3, num_workers=4): load_dotenv() - oai_client = OpenAI(api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")) + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}") search_path = f"results/pm/{frame}-{version}/{frame}_pm_search_results.json" @@ -146,9 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id[future] = user_id for future in tqdm( - as_completed(future_to_user_id), - total=len(future_to_user_id), - desc="📝 Generating responses", + as_completed(future_to_user_id), + total=len(future_to_user_id), + desc="📝 Generating responses", ): user_id = future_to_user_id[future] try: @@ -177,10 +185,21 @@ def main(frame, version, num_runs=3, num_workers=4): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Response Generation Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], default='memos-api') - parser.add_argument("--version", type=str, default="0925", help="Version of the evaluation framework.") - parser.add_argument("--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation.") - parser.add_argument("--workers", type=int, default=3, help="Number of worker threads to use for processing.") + parser.add_argument( + "--lib", + type=str, + choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], + default="memos-api", + ) + 
parser.add_argument( + "--version", type=str, default="0925", help="Version of the evaluation framework." + ) + parser.add_argument( + "--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of worker threads to use for processing." + ) args = parser.parse_args() main(frame=args.lib, version=args.version, num_runs=args.num_runs, num_workers=args.workers) diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 50f46f692..b9c5a43d5 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -3,16 +3,17 @@ import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import csv + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time -import csv - from tqdm import tqdm -from utils.client import mem0_client,zep_client,memos_api_client +from utils.client import mem0_client, memos_api_client, zep_client from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -109,21 +110,13 @@ def mem0_search(client, user_id, query, top_k=20, enable_graph=False, frame="mem return context, duration_ms -def memos_search(client, user_id, query, top_k, frame="memos-local"): +def memos_search(client, user_id, query, top_k, frame="memos-api"): start = time() - if frame == "memos-local": - results = client.search( - query=query, - user_id=user_id, - ) - - results = filter_memory_data(results)["text_mem"][0]["memories"] - search_memories = "\n".join([f" - {item['memory']}" for item in results]) - - elif frame == "memos-api": + if frame == "memos-api": results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = "\n".join(f"- {entry.get('memory_value', '')}" - for entry in results.get("memory_detail_list", 
[])) + search_memories = "\n".join( + f"- {entry.get('memory_value', '')}" for entry in results.get("memory_detail_list", []) + ) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 @@ -136,7 +129,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -148,14 +141,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -167,7 +160,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -190,7 +183,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 @@ -231,19 +224,6 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): client = mem0_client(mode="api") print("🔌 Using Mem0 API client for search...") context, duration_ms = mem0_search(client, user_id, question, top_k=top_k, frame=frame) - elif frame == "memos-local": - client = memos_client( - mode="local", - 
db_name=f"pm_{frame}-{version}", - user_id=user_id, - top_k=top_k, - mem_cube_path=f"results/pm/{frame}-{version}/storages/{user_id}", - mem_cube_config_path="configs/mu_mem_cube_config.json", - mem_os_config_path="configs/mos_memos_config.json", - addorsearch="search", - ) - print("🔌 Using Memos Local client for search...") - context, duration_ms = memos_search(client, user_id, question, frame=frame) elif frame == "memos-api": client = memos_api_client() print("🔌 Using Memos API client for search...") @@ -266,7 +246,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): os.makedirs(f"results/pm/{frame}-{version}/tmp", exist_ok=True) with open( - f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" + f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{conv_idx}.json", "w" ) as f: json.dump(search_results, f, indent=4) print(f"💾 \033[92mSearch results for conversation {conv_idx} saved...") @@ -299,9 +279,7 @@ def main(frame, version, top_k=20, num_workers=2): print(f"📚 Loaded PersonaMem dataset from {question_csv_path} and {context_jsonl_path}") print(f"📊 Total conversations: {total_rows}") - print( - f"⚙️ Search parameters: top_k={top_k}, workers={num_workers}" - ) + print(f"⚙️ Search parameters: top_k={top_k}, workers={num_workers}") print("-" * 80) all_search_results = defaultdict(list) @@ -320,7 +298,9 @@ def main(frame, version, top_k=20, num_workers=2): for idx, (row_data, _) in enumerate(all_data) } - for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"): + for future in tqdm( + as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations" + ): idx = future_to_idx[future] try: search_results = future.result() @@ -328,7 +308,7 @@ def main(frame, version, top_k=20, num_workers=2): all_search_results[user_id].extend(results) print(f"✅ Conversation {idx} processed successfully.") except Exception as exc: - print(f'\n❌ Conversation 
{idx} generated an exception: {exc}') + print(f"\n❌ Conversation {idx} generated an exception: {exc}") end_time = datetime.now() elapsed_time = end_time - start_time @@ -337,12 +317,8 @@ def main(frame, version, top_k=20, num_workers=2): print("\n" + "=" * 80) print("✅ \033[1;32mSEARCH COMPLETE".center(80)) print("=" * 80) - print( - f"⏱️ Total time taken to search {total_rows} users: \033[92m{elapsed_time_str}" - ) - print( - f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}" - ) + print(f"⏱️ Total time taken to search {total_rows} users: \033[92m{elapsed_time_str}") + print(f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}") with open(f"results/pm/{frame}-{version}/{frame}_pm_search_results.json", "w") as f: json.dump(dict(all_search_results), f, indent=4) @@ -354,11 +330,21 @@ def main(frame, version, top_k=20, num_workers=2): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PersonaMem Search Script") - parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], - default='memos-api') - parser.add_argument("--version", type=str, default="0925", help="Version of the evaluation framework.") - parser.add_argument("--top_k", type=int, default=20, help="Number of top results to retrieve from the search.") - parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.") + parser.add_argument( + "--lib", + type=str, + choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"], + default="memos-api", + ) + parser.add_argument( + "--version", type=str, default="0925", help="Version of the evaluation framework." + ) + parser.add_argument( + "--top_k", type=int, default=20, help="Number of top results to retrieve from the search." + ) + parser.add_argument( + "--workers", type=int, default=3, help="Number of parallel workers for processing users." 
+ ) args = parser.parse_args() diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 8e718192a..b3b6eefbe 100644 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -11,13 +11,13 @@ WORKERS=10 # Parameters for pref_memos.py TOP_K=10 ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" -VERSION="1021-5" +LIB="memos-api" +VERSION="1021-5" # --- File Paths --- # You may need to adjust these paths based on your project structure. # Assumes Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" # Intermediate file (output of 'add' mode, input for 'process' mode) IDS_FILE="results/prefeval/pref_memos_add.jsonl" @@ -86,4 +86,4 @@ if [ $? -ne 0 ]; then fi echo "" -echo "--- PrefEval Pipeline completed successfully! ---" \ No newline at end of file +echo "--- PrefEval Pipeline completed successfully! ---" diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 87b863e86..5ab58f5b1 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -2,10 +2,14 @@ import os import sys import time + from datetime import datetime -from dotenv import load_dotenv + import requests +from dotenv import load_dotenv + + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) load_dotenv() From c78913eaa86b357bfd2b23c77b98864f6ffa4b79 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 11:57:04 +0800 Subject: [PATCH 54/62] fix api bug --- src/memos/api/routers/server_router.py | 3 +-- src/memos/types.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e7f50dcd3..5e8c0bff9 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -281,8 +281,6 @@ def 
search_memories(search_req: APISearchRequest): "text_mem": [], "act_mem": [], "para_mem": [], - "pref_mem": str, - "prefs": [], } target_session_id = search_req.session_id if not target_session_id: @@ -334,6 +332,7 @@ def _search_pref(): } ) if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true": + memories_result["prefs"] = [] memories_result["prefs"].append( { "cube_id": search_req.mem_cube_id, diff --git a/src/memos/types.py b/src/memos/types.py index ddcc81c54..635fabccc 100644 --- a/src/memos/types.py +++ b/src/memos/types.py @@ -56,7 +56,6 @@ class MOSSearchResult(TypedDict): text_mem: list[dict[str, str | list[TextualMemoryItem]]] act_mem: list[dict[str, str | list[ActivationMemoryItem]]] para_mem: list[dict[str, str | list[ParametricMemoryItem]]] - pref_mem: list[dict[str, str | list[TextualMemoryItem]]] # ─── API Types ──────────────────────────────────────────────────────────────────── From 06b41ca355399cd29ae99577f01b4a936d74b549 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 16:40:40 +0800 Subject: [PATCH 55/62] merge inst_cplt to dev --- evaluation/.env-example | 7 +- evaluation/scripts/PrefEval/pref_memos.py | 171 +++++++++++------- evaluation/scripts/locomo/locomo_responses.py | 7 +- evaluation/scripts/locomo/locomo_search.py | 11 +- evaluation/scripts/locomo/prompts.py | 15 +- .../scripts/longmemeval/lme_responses.py | 8 +- evaluation/scripts/longmemeval/lme_search.py | 3 +- evaluation/scripts/personamem/pm_responses.py | 8 +- evaluation/scripts/personamem/pm_search.py | 8 +- evaluation/scripts/run_prefeval_eval.sh | 72 +++++--- evaluation/scripts/utils/prompts.py | 28 ++- 11 files changed, 235 insertions(+), 103 deletions(-) diff --git a/evaluation/.env-example b/evaluation/.env-example index 1ffaedcfd..92b9e31a9 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -14,7 +14,12 @@ CHAT_MODEL_API_KEY="sk-***REDACTED***" MEMOS_KEY="Token mpg-xxxxx" 
MEMOS_URL="https://apigw-pre.memtensor.cn/api/openmem/v1" -PRE_SPLIT_CHUNK=false # pre split chunk in client end + +# pref +PRE_SPLIT_CHUNK=false # pre split chunk in client end, for personamem and prefeval +INSTRUCT_COMPLETE=true # use instruct complete format or not +ABLATION_PREF=false # remove pref mem, only text mem + MEMOBASE_API_KEY="xxxxx" MEMOBASE_PROJECT_URL="http://xxx.xxx.xxx.xxx:8019" diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 9d8457976..560df7558 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -4,14 +4,12 @@ import os import sys import time - import tiktoken - from dotenv import load_dotenv -from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm +from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -20,6 +18,11 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) + +from utils.client import MemosApiClient +from utils.prompts import PREF_INSTRUCTIONS +from utils.pref_mem_utils import create_mem_string + load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") BASE_URL = os.getenv("OPENAI_BASE_URL") @@ -66,11 +69,11 @@ def add_memory_for_line( return None -def process_line_with_id( - line_data: tuple, mem_client, openai_client: OpenAI, top_k_value: int, lib: str, version: str +def search_memory_for_line( + line_data: tuple, mem_client, top_k_value: int ) -> dict: """ - Processes a single line of data using a pre-existing user_id, searching memory and generating a response. + Processes a single line of data, searching memory based on the question. 
""" i, line = line_data try: @@ -81,24 +84,74 @@ def process_line_with_id( metrics_dict = original_data.get("metrics", {}) if not user_id: - original_data["response"] = ( + original_data["error"] = ( "Error: user_id not found in this line. Please run 'add' mode first." ) return original_data if not question: - original_data["response"] = "Question not found in this line." + original_data["error"] = "Question not found in this line." return original_data start_time_search = time.monotonic() relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value) search_memories_duration = time.monotonic() - start_time_search - memories_str = "\n".join( - f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"] - ) - memory_tokens_used = len(tokenizer.encode(memories_str)) + context = create_mem_string(relevant_memories) + + memory_tokens_used = len(tokenizer.encode(context)) + + metrics_dict.update({ + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": context + }) + original_data["metrics"] = metrics_dict + + return original_data + + except Exception as e: + user_id_from_data = json.loads(line).get("user_id", "N/A") + print(f"Error searching memory for line {i + 1} (user_id: {user_id_from_data}): {e}") + return None + + +def generate_response_for_line( + line_data: tuple, openai_client: OpenAI +) -> dict: + """ + Generates a response for a single line of data using pre-fetched memories. + """ + i, line = line_data + try: + original_data = json.loads(line) + + question = original_data.get("question") + metrics_dict = original_data.get("metrics", {}) + memories_str = metrics_dict.get("retrieved_memories_text") + + # If an error occurred in 'add' or 'search' mode, just pass the line through + if original_data.get("error"): + return original_data + + if not question: + original_data["error"] = "Question not found in this line." 
+            return original_data
+
+        # Check for None, as an empty string (no memories found) is a valid result
+        if memories_str is None:
+            original_data["error"] = (
+                "Error: retrieved_memories_text not found in metrics. "
+                "Please run 'search' mode first."
+            )
+            return original_data
+
+        if os.getenv("ABLATION_PREF") == "true":
+            memories_str = memories_str.split("Explicit Preference:", 1)[0]
 
         system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}"
+        if os.getenv("INSTRUCT_COMPLETE") == "true":
+            system_prompt = system_prompt + "\n" + PREF_INSTRUCTIONS
+
         messages = [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": question},
@@ -108,51 +161,42 @@
         assistant_response = response.choices[0].message.content
 
         original_data["response"] = assistant_response
-        metrics_dict.update(
-            {
-                "search_memories_duration_seconds": search_memories_duration,
-                "memory_tokens_used": memory_tokens_used,
-                "retrieved_memories_text": memories_str,
-            }
-        )
-        original_data["metrics"] = metrics_dict
-
         return original_data
 
     except Exception as e:
         user_id_from_data = json.loads(line).get("user_id", "N/A")
-        print(f"Error processing line {i + 1} (user_id: {user_id_from_data}): {e}")
+        print(f"Error generating response for line {i + 1} (user_id: {user_id_from_data}): {e}")
         return None
 
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Process conversations with MemOS. Run 'add' mode first, then 'process' mode."
+        description="Process conversations with MemOS. Run 'add', then 'search', then 'response'."
) parser.add_argument( "mode", - choices=["add", "process"], - help="The mode to run the script in ('add' or 'process').", + choices=["add", "search", "response"], + help="The mode to run the script in ('add', 'search', or 'response').", ) parser.add_argument("--input", required=True, help="Path to the input JSONL file.") parser.add_argument("--output", required=True, help="Path to the output JSONL file.") - parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve.") + parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve (used in 'search' mode).") parser.add_argument( "--add-turn", type=int, choices=[0, 10, 300], default=0, - help="Number of irrelevant turns to add (0, 10, or 300).", + help="Number of irrelevant turns to add (used in 'add' mode).", ) parser.add_argument( "--lib", type=str, choices=["memos-api", "memos-local"], default="memos-api", - help="Which MemOS library to use.", + help="Which MemOS library to use (used in 'add' mode).", ) parser.add_argument( - "--version", type=str, default="0929-1", help="Version identifier for user_id generation." + "--version", type=str, default="0929-1", help="Version identifier for user_id generation (used in 'add' mode)." ) parser.add_argument( "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." @@ -161,15 +205,13 @@ def main(): args = parser.parse_args() try: - with open(args.input, encoding="utf-8") as infile: + with open(args.input, "r", encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") return - from utils.client import memosApiClient - - mem_client = memosApiClient() + mem_client = MemosApiClient() if args.mode == "add": print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...") @@ -202,38 +244,45 @@ def main(): outfile.write(json.dumps(result, ensure_ascii=False) + "\n") print(f"\n'add' mode complete! 
Data with user_id written to '{args.output}'.") - elif args.mode == "process": - print(f"Running in 'process' mode. Processing questions from '{args.input}'...") + elif args.mode == "search": + print(f"Running in 'search' mode. Searching memories based on '{args.input}'...") print(f"Retrieving top {args.top_k} memories for each query.") print(f"Using {args.max_workers} workers.") - openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) - with ( - open(args.output, "w", encoding="utf-8") as outfile, - concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, - ): - futures = [ - executor.submit( - process_line_with_id, - (i, line), - mem_client, - openai_client, - args.top_k, - args.lib, - args.version, + with open(args.output, "w", encoding="utf-8") as outfile, \ + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: + + futures = [executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) for i, line in enumerate(lines)] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", ) - for i, line in enumerate(lines) - ] + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'.") + + elif args.mode == "response": + print(f"Running in 'response' mode. 
Generating responses based on '{args.input}'...") + print(f"Using {args.max_workers} workers.") + openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) + with open(args.output, "w", encoding="utf-8") as outfile, \ + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: - pbar = tqdm( - concurrent.futures.as_completed(futures), - total=len(lines), - desc="Processing questions...", - ) - for future in pbar: - result = future.result() - if result: - outfile.write(json.dumps(result, ensure_ascii=False) + "\n") - print(f"\n'process' mode complete! Final results written to '{args.output}'.") + futures = [executor.submit(generate_response_for_line, (i, line), openai_client) for i, line in enumerate(lines)] + + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"\n'response' mode complete! 
Final results written to '{args.output}'.") if __name__ == "__main__": diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 4e3b966a3..0f35cb55d 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -12,6 +12,8 @@ from prompts import ANSWER_PROMPT_MEM0, ANSWER_PROMPT_MEMOS, ANSWER_PROMPT_ZEP from tqdm import tqdm +from utils.pref_mem_utils import remove_pref_mem_from_mem_string + async def locomo_response(frame, llm_client, context: str, question: str) -> str: if frame == "zep": @@ -47,7 +49,10 @@ async def process_qa(frame, qa, search_result, oai_client): gold_answer = qa.get("answer") qa_category = qa.get("category") - answer = await locomo_response(frame, oai_client, search_result.get("context"), query) + context = search_result.get("context") + if os.getenv("ABLATION_PREF") == "true" and frame == "memos-api": + context = remove_pref_mem_from_mem_string(context) + answer = await locomo_response(frame, oai_client, context, query) response_duration_ms = (time() - start) * 1000 diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index f97f2592a..cded65978 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -21,6 +21,8 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) +from utils.pref_mem_utils import create_mem_string + def mem0_search(client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b): from prompts import TEMPLATE_MEM0 @@ -104,12 +106,9 @@ def memos_api_search( start = time() search_a_results = client.search(query=query, user_id=speaker_a_user_id, top_k=top_k) search_b_results = client.search(query=query, user_id=speaker_b_user_id, top_k=top_k) - speaker_a_context = "\n".join( - [i["memory"] for i in search_a_results["text_mem"][0]["memories"]] - ) - speaker_b_context = "\n".join( - 
[i["memory"] for i in search_b_results["text_mem"][0]["memories"]] - ) + + speaker_a_context = create_mem_string(search_a_results) + speaker_b_context = create_mem_string(search_b_results) context = TEMPLATE_MEMOS.format( speaker_1=speaker_a, diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py index 6bae05f4d..d489dfb3f 100644 --- a/evaluation/scripts/locomo/prompts.py +++ b/evaluation/scripts/locomo/prompts.py @@ -1,3 +1,11 @@ +PREF_INSTRUCTIONS = """ + # Note: + Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. + Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. + When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +""" + + ANSWER_PROMPT_MEM0 = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -103,7 +111,7 @@ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. 7. Ensure your final answer is specific and avoids vague time references. 
- + {pref_instructions} {context} Question: {question} @@ -111,6 +119,11 @@ Answer: """ +if os.getenv("INSTRUCT_COMPLETE") == "true": + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS) +else: + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "") + custom_instructions = """ Generate personal memories that follow these guidelines: diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index 3df3e2da4..8f0f25996 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -13,6 +13,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils.prompts import LME_ANSWER_PROMPT +from utils.pref_mem_utils import remove_pref_mem_from_mem_string def lme_response(llm_client, context, question, question_date): @@ -21,7 +22,6 @@ def lme_response(llm_client, context, question, question_date): question_date=question_date, context=context, ) - response = llm_client.chat.completions.create( model=os.getenv("CHAT_MODEL"), messages=[ @@ -34,12 +34,14 @@ def lme_response(llm_client, context, question, question_date): return result -def process_qa(user_id, search_result, llm_client): +def process_qa(user_id, search_result, llm_client, frame): start = time() search_result = search_result[0] question = search_result.get("question") question_date = search_result.get("date") context = search_result.get("search_context", "") + if os.getenv("ABLATION_PREF") == "true" and frame == "memos-api": + context = remove_pref_mem_from_mem_string(context) anwer = lme_response(llm_client, context, question, question_date) response_duration_ms = (time() - start) * 1000 @@ -95,7 +97,7 @@ def main(frame, version, num_workers=4): future_to_user_id = {} for user_id, search_results in lme_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, oai_client) + future = 
executor.submit(process_qa, user_id, search_results, oai_client, frame) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index b0b11f911..9e3e78749 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -18,6 +18,7 @@ MEM0_GRAPH_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, ) +from utils.pref_mem_utils import create_mem_string def mem0_search(client, query, user_id, top_k): @@ -44,7 +45,7 @@ def mem0_search(client, query, user_id, top_k): def memos_search(client, query, user_id, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - context = "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) + context = create_mem_string(results) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) duration_ms = (time() - start) * 1000 return context, duration_ms diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 3c19d58fc..84a4774f3 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -15,6 +15,7 @@ import re from utils.prompts import PM_ANSWER_PROMPT +from utils.pref_mem_utils import remove_pref_mem_from_mem_string def extract_choice_answer(predicted_answer, correct_answer): @@ -66,12 +67,15 @@ def pm_response(llm_client, context, question, options): return result -def process_qa(user_id, search_result, num_runs, llm_client): +def process_qa(user_id, search_result, num_runs, llm_client, frame): search_result = search_result[0] question = search_result.get("question") context = search_result.get("search_context", "") options = search_result.get("all_options", []) + if os.getenv("ABLATION_PREF") == "true" and frame == "memos-api": + context = remove_pref_mem_from_mem_string(context) + run_results = [] for idx in 
range(num_runs): @@ -150,7 +154,7 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client) + future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client, frame) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index b9c5a43d5..e83ea6d94 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -20,6 +20,7 @@ MEMOS_CONTEXT_TEMPLATE, ZEP_CONTEXT_TEMPLATE, ) +from utils.pref_mem_utils import create_mem_string def zep_search(client, user_id, query, top_k=20): @@ -114,9 +115,7 @@ def memos_search(client, user_id, query, top_k, frame="memos-api"): start = time() if frame == "memos-api": results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = "\n".join( - f"- {entry.get('memory_value', '')}" for entry in results.get("memory_detail_list", []) - ) + search_memories = create_mem_string(results) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 @@ -257,7 +256,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): def load_existing_results(frame, version, group_idx): result_path = ( - f"results/locomo/{frame}-{version}/tmp/{frame}_locomo_search_results_{group_idx}.json" + f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{group_idx}.json" ) if os.path.exists(result_path): try: @@ -320,6 +319,7 @@ def main(frame, version, top_k=20, num_workers=2): print(f"⏱️ Total time taken to search {total_rows} users: \033[92m{elapsed_time_str}") print(f"🔄 Framework: {frame} | Version: {version} | Workers: {num_workers}") + os.makedirs(f"results/pm/{frame}-{version}/", exist_ok=True) with 
open(f"results/pm/{frame}-{version}/{frame}_pm_search_results.json", "w") as f: json.dump(dict(all_search_results), f, indent=4) print( diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index b3b6eefbe..5e1160d7d 100644 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -6,24 +6,30 @@ # Number of workers for parallel processing. # This variable controls both pref_memos.py (--max-workers) # and pref_eval.py (--concurrency-limit). -WORKERS=10 +WORKERS=16 # Parameters for pref_memos.py TOP_K=10 -ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" -VERSION="1021-5" +ADD_TURN=10 # Options: 0, 10, or 300 +LIB="memos-api" +VERSION="102201-${ADD_TURN}turns" # --- File Paths --- # You may need to adjust these paths based on your project structure. -# Assumes Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +# Step 1 (preprocess) outputs this file: +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" -# Intermediate file (output of 'add' mode, input for 'process' mode) -IDS_FILE="results/prefeval/pref_memos_add.jsonl" +# Create a directory name based on LIB and VERSION +OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" -# Final response file (output of 'process' mode, input for Step 3) -RESPONSE_FILE="results/prefeval/pref_memos_process.jsonl" +# Intermediate file (output of 'add' mode, input for 'search' mode) +IDS_FILE="${OUTPUT_DIR}/pref_memos_add.jsonl" + +# Intermediate file (output of 'search' mode, input for 'response' mode) +SEARCH_FILE="${OUTPUT_DIR}/pref_memos_search.jsonl" + +# Final response file (output of 'response' mode, input for Step 3) +RESPONSE_FILE="${OUTPUT_DIR}/pref_memos_response.jsonl" # Set the Hugging Face mirror endpoint @@ -31,6 +37,7 @@ export HF_ENDPOINT="https://hf-mirror.com" echo "--- Starting PrefEval Pipeline ---" echo "Configuration: WORKERS=$WORKERS, TOP_K=$TOP_K, ADD_TURN=$ADD_TURN, LIB=$LIB, 
VERSION=$VERSION, HF_ENDPOINT=$HF_ENDPOINT" +echo "Results will be saved to: $OUTPUT_DIR" echo "" # --- Step 1: Preprocess the data --- @@ -42,7 +49,16 @@ if [ $? -ne 0 ]; then exit 1 fi -# --- Step 2: Generate responses using MemOS (split into 'add' and 'process') --- +# --- Create output directory --- +echo "" +echo "Creating output directory: $OUTPUT_DIR" +mkdir -p $OUTPUT_DIR +if [ $? -ne 0 ]; then + echo "Error: Could not create output directory '$OUTPUT_DIR'." + exit 1 +fi + +# --- Step 2: Generate responses using MemOS (split into 'add', 'search', and 'response') --- echo "" echo "Running pref_memos.py in 'add' mode..." # Step 2a: Ingest conversations into memory and generate user_ids @@ -60,26 +76,39 @@ if [ $? -ne 0 ]; then fi echo "" -echo "Running pref_memos.py in 'process' mode..." -# Step 2b: Search memories using user_ids and generate responses -python scripts/PrefEval/pref_memos.py process \ +echo "Running pref_memos.py in 'search' mode..." +# Step 2b: Search memories using user_ids +python scripts/PrefEval/pref_memos.py search \ --input $IDS_FILE \ - --output $RESPONSE_FILE \ + --output $SEARCH_FILE \ --top-k $TOP_K \ - --max-workers $WORKERS \ - --lib $LIB \ - --version $VERSION + --max-workers $WORKERS if [ $? -ne 0 ]; then - echo "Error: pref_memos.py 'process' mode failed." + echo "Error: pref_memos.py 'search' mode failed." + exit 1 +fi + +echo "" +echo "Running pref_memos.py in 'response' mode..." +# Step 2c: Generate responses based on searched memories +python scripts/PrefEval/pref_memos.py response \ + --input $SEARCH_FILE \ + --output $RESPONSE_FILE \ + --max-workers $WORKERS + +if [ $? -ne 0 ]; then + echo "Error: pref_memos.py 'response' mode failed." exit 1 fi # --- Step 3: Evaluate the generated responses --- echo "" echo "Running pref_eval.py..." 
-# Pass the WORKERS variable to the script's --concurrency-limit argument -python scripts/PrefEval/pref_eval.py --concurrency-limit $WORKERS +python scripts/PrefEval/pref_eval.py \ + --input $RESPONSE_FILE \ + --concurrency-limit $WORKERS + if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." exit 1 @@ -87,3 +116,4 @@ fi echo "" echo "--- PrefEval Pipeline completed successfully! ---" +echo "Final results are in $RESPONSE_FILE" \ No newline at end of file diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py index bd418af54..00dd853d5 100644 --- a/evaluation/scripts/utils/prompts.py +++ b/evaluation/scripts/utils/prompts.py @@ -1,3 +1,13 @@ +import os + +PREF_INSTRUCTIONS = """ + # Note: + Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. + Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. + When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. +""" + + LME_ANSWER_PROMPT = """ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. @@ -17,7 +27,7 @@ 5. Formulate a precise, concise answer based solely on the evidence in the memories. 6. Double-check that your answer directly addresses the question asked. 7. Ensure your final answer is specific and avoids vague time references. - + {pref_instructions} {context} Current Date: {question_date} @@ -27,6 +37,13 @@ Answer: """ + +if os.getenv("INSTRUCT_COMPLETE") == "true": + LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) +else: + LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", "") + + PM_ANSWER_PROMPT = """ You are a helpful assistant tasked with selecting the best answer to a user question, based solely on summarized conversation memories. 
@@ -46,7 +63,7 @@ - Your final answer **must use parentheses**, like (a) or (b). - Do NOT list multiple choices. Choose only one. - Do NOT include extra text after . Just output the answer. - + {pref_instructions} # QUESTION: {question} @@ -58,6 +75,13 @@ """ +if os.getenv("INSTRUCT_COMPLETE") == "true": + PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) +else: + PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", "") + + + ZEP_CONTEXT_TEMPLATE = """ FACTS and ENTITIES represent relevant context to the current conversation. From 090f0f6a808ebd0301f722751a4017a129e368fb Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 16:45:33 +0800 Subject: [PATCH 56/62] fix pre commit --- evaluation/scripts/PrefEval/pref_memos.py | 108 ++++++++++-------- evaluation/scripts/locomo/locomo_responses.py | 1 - evaluation/scripts/locomo/locomo_search.py | 3 +- evaluation/scripts/locomo/prompts.py | 9 +- .../scripts/longmemeval/lme_responses.py | 2 +- evaluation/scripts/longmemeval/lme_search.py | 2 +- evaluation/scripts/personamem/pm_responses.py | 6 +- evaluation/scripts/personamem/pm_search.py | 6 +- evaluation/scripts/run_prefeval_eval.sh | 6 +- evaluation/scripts/utils/prompts.py | 12 +- 10 files changed, 87 insertions(+), 68 deletions(-) diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 560df7558..901565e12 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -4,12 +4,17 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm +from utils.client import MemosApiClient +from utils.pref_mem_utils import create_mem_string +from utils.prompts import PREF_INSTRUCTIONS -from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( 
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -19,9 +24,6 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) -from utils.client import MemosApiClient -from utils.prompts import PREF_INSTRUCTIONS -from utils.pref_mem_utils import create_mem_string load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") @@ -69,9 +71,7 @@ def add_memory_for_line( return None -def search_memory_for_line( - line_data: tuple, mem_client, top_k_value: int -) -> dict: +def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict: """ Processes a single line of data, searching memory based on the question. """ @@ -100,11 +100,13 @@ def search_memory_for_line( memory_tokens_used = len(tokenizer.encode(context)) - metrics_dict.update({ - "search_memories_duration_seconds": search_memories_duration, - "memory_tokens_used": memory_tokens_used, - "retrieved_memories_text": context - }) + metrics_dict.update( + { + "search_memories_duration_seconds": search_memories_duration, + "memory_tokens_used": memory_tokens_used, + "retrieved_memories_text": context, + } + ) original_data["metrics"] = metrics_dict return original_data @@ -115,9 +117,7 @@ def search_memory_for_line( return None -def generate_response_for_line( - line_data: tuple, openai_client: OpenAI -) -> dict: +def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict: """ Generates a response for a single line of data using pre-fetched memories. 
""" @@ -180,7 +180,12 @@ def main(): ) parser.add_argument("--input", required=True, help="Path to the input JSONL file.") parser.add_argument("--output", required=True, help="Path to the output JSONL file.") - parser.add_argument("--top-k", type=int, default=10, help="Number of memories to retrieve (used in 'search' mode).") + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of memories to retrieve (used in 'search' mode).", + ) parser.add_argument( "--add-turn", type=int, @@ -196,7 +201,10 @@ def main(): help="Which MemOS library to use (used in 'add' mode).", ) parser.add_argument( - "--version", type=str, default="0929-1", help="Version identifier for user_id generation (used in 'add' mode)." + "--version", + type=str, + default="0929-1", + help="Version identifier for user_id generation (used in 'add' mode).", ) parser.add_argument( "--max-workers", type=int, default=20, help="Maximum number of concurrent workers." @@ -205,7 +213,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") @@ -248,40 +256,50 @@ def main(): print(f"Running in 'search' mode. 
Searching memories based on '{args.input}'...") print(f"Retrieving top {args.top_k} memories for each query.") print(f"Using {args.max_workers} workers.") - with open(args.output, "w", encoding="utf-8") as outfile, \ - concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: - - futures = [executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) for i, line in enumerate(lines)] + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(search_memory_for_line, (i, line), mem_client, args.top_k) + for i, line in enumerate(lines) + ] - pbar = tqdm( - concurrent.futures.as_completed(futures), - total=len(lines), - desc="Searching memories...", - ) - for future in pbar: - result = future.result() - if result: - outfile.write(json.dumps(result, ensure_ascii=False) + "\n") - print(f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'.") + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Searching memories...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + print( + f"\n'search' mode complete! Results with retrieved memories written to '{args.output}'." + ) elif args.mode == "response": print(f"Running in 'response' mode. 
Generating responses based on '{args.input}'...") print(f"Using {args.max_workers} workers.") openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL) - with open(args.output, "w", encoding="utf-8") as outfile, \ - concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: - - futures = [executor.submit(generate_response_for_line, (i, line), openai_client) for i, line in enumerate(lines)] + with ( + open(args.output, "w", encoding="utf-8") as outfile, + concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, + ): + futures = [ + executor.submit(generate_response_for_line, (i, line), openai_client) + for i, line in enumerate(lines) + ] - pbar = tqdm( - concurrent.futures.as_completed(futures), - total=len(lines), - desc="Generating responses...", - ) - for future in pbar: - result = future.result() - if result: - outfile.write(json.dumps(result, ensure_ascii=False) + "\n") + pbar = tqdm( + concurrent.futures.as_completed(futures), + total=len(lines), + desc="Generating responses...", + ) + for future in pbar: + result = future.result() + if result: + outfile.write(json.dumps(result, ensure_ascii=False) + "\n") print(f"\n'response' mode complete! 
Final results written to '{args.output}'.") diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 0f35cb55d..f3ab39d5e 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -11,7 +11,6 @@ from openai import AsyncOpenAI from prompts import ANSWER_PROMPT_MEM0, ANSWER_PROMPT_MEMOS, ANSWER_PROMPT_ZEP from tqdm import tqdm - from utils.pref_mem_utils import remove_pref_mem_from_mem_string diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index cded65978..833daa4b8 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv from tqdm import tqdm +from utils.pref_mem_utils import create_mem_string ROOT_DIR = os.path.dirname( @@ -21,8 +22,6 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) -from utils.pref_mem_utils import create_mem_string - def mem0_search(client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b): from prompts import TEMPLATE_MEM0 diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py index d489dfb3f..caf462f6a 100644 --- a/evaluation/scripts/locomo/prompts.py +++ b/evaluation/scripts/locomo/prompts.py @@ -1,5 +1,8 @@ +import os + + PREF_INSTRUCTIONS = """ - # Note: + # Note: Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. 
@@ -120,9 +123,9 @@ """ if os.getenv("INSTRUCT_COMPLETE") == "true": - ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS) + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS) else: - ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "") + ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "") custom_instructions = """ Generate personal memories that follow these guidelines: diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index 8f0f25996..bed1c5a15 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -12,8 +12,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.prompts import LME_ANSWER_PROMPT from utils.pref_mem_utils import remove_pref_mem_from_mem_string +from utils.prompts import LME_ANSWER_PROMPT def lme_response(llm_client, context, question, question_date): diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index 9e3e78749..266b0a2a0 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -13,12 +13,12 @@ import pandas as pd from tqdm import tqdm +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, ) -from utils.pref_mem_utils import create_mem_string def mem0_search(client, query, user_id, top_k): diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 84a4774f3..bb2ef18f9 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -14,8 +14,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re -from 
utils.prompts import PM_ANSWER_PROMPT from utils.pref_mem_utils import remove_pref_mem_from_mem_string +from utils.prompts import PM_ANSWER_PROMPT def extract_choice_answer(predicted_answer, correct_answer): @@ -154,7 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client, frame) + future = executor.submit( + process_qa, user_id, search_results, num_runs, oai_client, frame + ) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index e83ea6d94..19a437b5d 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -14,13 +14,13 @@ from tqdm import tqdm from utils.client import mem0_client, memos_api_client, zep_client +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, ZEP_CONTEXT_TEMPLATE, ) -from utils.pref_mem_utils import create_mem_string def zep_search(client, user_id, query, top_k=20): @@ -255,9 +255,7 @@ def process_user(row_data, conv_idx, frame, version, top_k=20): def load_existing_results(frame, version, group_idx): - result_path = ( - f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{group_idx}.json" - ) + result_path = f"results/pm/{frame}-{version}/tmp/{frame}_pm_search_results_{group_idx}.json" if os.path.exists(result_path): try: with open(result_path) as f: diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 5e1160d7d..cc7e3604f 100644 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -11,13 +11,13 @@ WORKERS=16 # Parameters for pref_memos.py TOP_K=10 ADD_TURN=10 # Options: 0, 10, or 300 -LIB="memos-api" +LIB="memos-api" 
VERSION="102201-${ADD_TURN}turns" # --- File Paths --- # You may need to adjust these paths based on your project structure. # Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" # Create a directory name based on LIB and VERSION OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" @@ -116,4 +116,4 @@ fi echo "" echo "--- PrefEval Pipeline completed successfully! ---" -echo "Final results are in $RESPONSE_FILE" \ No newline at end of file +echo "Final results are in $RESPONSE_FILE" diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py index 00dd853d5..5cdc6b15c 100644 --- a/evaluation/scripts/utils/prompts.py +++ b/evaluation/scripts/utils/prompts.py @@ -1,7 +1,8 @@ import os + PREF_INSTRUCTIONS = """ - # Note: + # Note: Plaintext memory are summaries of facts, while preference memories are summaries of user preferences. Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts. When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory. 
@@ -39,9 +40,9 @@ if os.getenv("INSTRUCT_COMPLETE") == "true": - LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) + LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) else: - LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", "") + LME_ANSWER_PROMPT = LME_ANSWER_PROMPT.replace("{pref_instructions}", "") PM_ANSWER_PROMPT = """ @@ -76,10 +77,9 @@ if os.getenv("INSTRUCT_COMPLETE") == "true": - PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) + PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", PREF_INSTRUCTIONS) else: - PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", "") - + PM_ANSWER_PROMPT = PM_ANSWER_PROMPT.replace("{pref_instructions}", "") ZEP_CONTEXT_TEMPLATE = """ From 403ab7f1adb4ddf9db854e5a172a6c11365c2100 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 16:46:29 +0800 Subject: [PATCH 57/62] fix pre commit --- evaluation/scripts/utils/pref_mem_utils.py | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 evaluation/scripts/utils/pref_mem_utils.py diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py new file mode 100644 index 000000000..724e292a3 --- /dev/null +++ b/evaluation/scripts/utils/pref_mem_utils.py @@ -0,0 +1,43 @@ + + +def create_mem_string(relevant_memories) -> str: + text_memories = [] + explicit = [] + implicit = [] + for item in relevant_memories["text_mem"]: + for mem in item["memories"]: + text_memories.append(mem['memory']) + text_context = "" + if text_memories: + text_memories_text = '\n'.join(text_memories) + text_context += f"Plaintext Memory:\n{text_memories_text}\n" + for item in relevant_memories["pref_mem"]: + for mem in item["memories"]: + if mem["metadata"]["preference_type"] == "explicit_preference": + explicit.append(mem["metadata"]["explicit_preference"]) + elif 
mem["metadata"]["preference_type"] == "implicit_preference": + implicit.append(mem["metadata"]["implicit_preference"]) + pref_context = "" + if explicit: + explicit_text = '\n'.join(explicit) + pref_context += f"Explicit Preference:\n{explicit_text}\n" + if implicit: + implicit_text = '\n'.join(implicit) + pref_context += f"Implicit Preference:\n{implicit_text}\n" + context = "" + if text_memories and explicit and implicit: + context = f"{text_context}\n{pref_context}" + return context + + +def remove_pref_mem_from_mem_string(mem_string: str) -> str: + tmp_list = mem_string.split("Plaintext Memory:") + if len(tmp_list) > 1: + return tmp_list[1].split("Explicit Preference:")[0] + else: + return mem_string + + +if __name__ == "__main__": + mem_string = "Plaintext Memory:\nadsadsasasadsadsadsadsa\nExplicit Preference:\nThis is an explicit preference\nImplicit Preference:\nThis is an implicit preference\n" + print(remove_pref_mem_from_mem_string(mem_string)) \ No newline at end of file From 0c8905db7fba3be06fe5361b46ea79b04be92a44 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 24 Oct 2025 21:07:22 +0800 Subject: [PATCH 58/62] fix pre commit error --- evaluation/scripts/PrefEval/pref_eval.py | 44 ++++++++++--------- evaluation/scripts/PrefEval/pref_mem0.py | 6 ++- evaluation/scripts/PrefEval/pref_memobase.py | 10 +++-- evaluation/scripts/PrefEval/pref_memos.py | 20 +++++---- evaluation/scripts/PrefEval/pref_memu.py | 10 +++-- .../scripts/PrefEval/pref_supermemory.py | 8 ++-- evaluation/scripts/PrefEval/pref_zep.py | 10 +++-- evaluation/scripts/locomo/locomo_responses.py | 2 +- .../scripts/longmemeval/lme_responses.py | 2 +- evaluation/scripts/personamem/pm_ingestion.py | 15 ++++--- evaluation/scripts/personamem/pm_metric.py | 4 +- evaluation/scripts/personamem/pm_responses.py | 11 +++-- evaluation/scripts/personamem/pm_search.py | 17 ++++--- evaluation/scripts/run_pm_eval.sh | 2 +- evaluation/scripts/run_prefeval_eval.sh | 8 ++-- 
evaluation/scripts/utils/client.py | 6 ++- evaluation/scripts/utils/mirix_utils.py | 11 +++-- evaluation/scripts/utils/pref_mem_utils.py | 13 +++--- 18 files changed, 120 insertions(+), 79 deletions(-) diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py index 10cf41bf3..f1966b847 100644 --- a/evaluation/scripts/PrefEval/pref_eval.py +++ b/evaluation/scripts/PrefEval/pref_eval.py @@ -1,14 +1,18 @@ +import argparse import asyncio import json +import os import re -import argparse -from typing import Dict, Any + from collections import Counter -from tqdm.asyncio import tqdm -import os +from typing import Any + import pandas as pd + from dotenv import load_dotenv from openai import OpenAI +from tqdm.asyncio import tqdm + load_dotenv() @@ -30,7 +34,7 @@ async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str: ) return response.choices[0].message.content except Exception as e: - print(f"API call failed: {str(e)}") + print(f"API call failed: {e!s}") return "" @@ -41,7 +45,7 @@ def parse_xml_response(response: str, tag: str) -> str: async def evaluate_violate_preference_async( client: OpenAI, preference: str, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's response violates the user's stated preference. Evaluate the response based on these criteria: @@ -73,8 +77,8 @@ async def evaluate_violate_preference_async( async def evaluate_acknowledge_preference_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: - prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. 
+) -> dict[str, str]: + prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. Evaluate the response based on these stringent criteria: @@ -82,7 +86,7 @@ async def evaluate_acknowledge_preference_async( Answer "Yes" if: -1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. +1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. 2. The response assumes the user preference in answering the user's question implicitly. For example, when the user asks 'Can you recommend me cars to drive?', if the response is 'Based on your preference, I will recommend non-EV cars, ...', then this indicates the assistant assumes that the user's preference is a dislike of EV cars, and you should answer 'Yes'. Answer "No" if the response does not mention or refer to any user preference explicitly or implicitly. 
If you cannot extract the sentence stating what the preference is, answer 'No'. @@ -103,7 +107,7 @@ async def evaluate_acknowledge_preference_async( async def evaluate_hallucinate_preference_async( client: OpenAI, preference: str, restatement: str -) -> Dict[str, str]: +) -> dict[str, str]: if not restatement.strip(): return {"explanation": "No restatement provided by assistant", "answer": "No"} prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's restatement of the user's stated preference is the same preference. Evaluate the response based on these stringent criteria to answer if the assistant has hallucinated the preference or not: @@ -120,10 +124,10 @@ async def evaluate_hallucinate_preference_async( 2. The assistant's restatement is a minor paraphrase that fully preserves the meaning and intent of the original preference. 3. The restatement is just empty, no hallucination. -Here is the information: -Original user preference: "{preference}" -Assistant's restatement: "{restatement}" -Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: +Here is the information: +Original user preference: "{preference}" +Assistant's restatement: "{restatement}" +Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: [1 short sentence explanation] [Yes/No]""" @@ -136,7 +140,7 @@ async def evaluate_hallucinate_preference_async( async def evaluate_helpful_response_async( client: OpenAI, question: str, response: str -) -> Dict[str, str]: +) -> dict[str, str]: prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant provides any substantive response to the user's query. 
Evaluate the response based on these stringent criteria: @@ -174,7 +178,7 @@ async def evaluate_helpful_response_async( } -def classify_error_type(evaluation_results: Dict[str, Any]) -> str: +def classify_error_type(evaluation_results: dict[str, Any]) -> str: violate = evaluation_results["violate_preference"]["answer"] acknowledge = evaluation_results["acknowledge_preference"]["answer"] hallucinate = evaluation_results["hallucinate_preference"]["answer"] @@ -192,7 +196,7 @@ def classify_error_type(evaluation_results: Dict[str, Any]) -> str: return "Personalized Response" -async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> Dict[str, Any]: +async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> dict[str, Any]: async with semaphore: data = json.loads(line.strip()) preference = data["preference"] @@ -223,7 +227,7 @@ async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) return result -def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[str, float]]: +def log_summary(error_counter: Counter, total_samples: int) -> dict[str, dict[str, float]]: summary_data = {} print("\n--- Error Type Summary ---") @@ -247,7 +251,7 @@ def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[st def generate_excel_summary( - summary_results: Dict[str, Dict[str, float]], + summary_results: dict[str, dict[str, float]], avg_search_time: float, avg_context_tokens: float, avg_add_time: float, @@ -317,7 +321,7 @@ async def main(concurrency_limit: int, input_file: str, output_file: str, output client = OpenAI(api_key=API_KEY, base_url=API_URL) try: - with open(input_file, "r", encoding="utf-8") as f: + with open(input_file, encoding="utf-8") as f: lines = f.readlines() except FileNotFoundError: print(f"Error: Input file not found at '{input_file}'") diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py index 
416d8045f..4bbdb0fd8 100644 --- a/evaluation/scripts/PrefEval/pref_mem0.py +++ b/evaluation/scripts/PrefEval/pref_mem0.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from irrelevant_conv import irre_10, irre_300 ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -199,7 +201,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py index 34d3ea86f..4f6174d3d 100644 --- a/evaluation/scripts/PrefEval/pref_memobase.py +++ b/evaluation/scripts/PrefEval/pref_memobase.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -import time -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -49,7 +51,7 @@ def add_memory_for_line( if conversation: messages = [] - for chunk_start in range(0, len(conversation)): + for chunk_start in range(len(conversation)): chunk = conversation[chunk_start : chunk_start + 1] timestamp_add = str(int(time.time() * 100)) time.sleep(0.001) # Ensure unique timestamp @@ -210,7 +212,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git 
a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 52d30f3c6..b4c5aec0a 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -4,15 +4,21 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm - -from irrelevant_conv import irre_10, irre_300 -from utils.pref_mem_utils import create_mem_string, remove_pref_mem_from_mem_string, add_pref_instruction +from utils.pref_mem_utils import ( + add_pref_instruction, + create_mem_string, + remove_pref_mem_from_mem_string, +) from utils.prompts import PREFEVAL_ANSWER_PROMPT + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -138,13 +144,11 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str "Please run 'search' mode first." ) return original_data - + memories_str = remove_pref_mem_from_mem_string(memories_str, frame=lib) template = add_pref_instruction(PREFEVAL_ANSWER_PROMPT, frame=lib) - system_prompt = template.format( - context=memories_str - ) + system_prompt = template.format(context=memories_str) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question}, @@ -206,7 +210,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py index 719f2b488..2b9f769a4 100644 --- a/evaluation/scripts/PrefEval/pref_memu.py +++ b/evaluation/scripts/PrefEval/pref_memu.py @@ -4,12 +4,16 @@ import os import sys import time + +from datetime import datetime + import tiktoken + from dotenv import 
load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -205,7 +209,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py index 85e84b6c9..88a64038b 100644 --- a/evaluation/scripts/PrefEval/pref_supermemory.py +++ b/evaluation/scripts/PrefEval/pref_supermemory.py @@ -4,12 +4,14 @@ import os import sys import time + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -201,7 +203,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py index 699660787..91aef1492 100644 --- a/evaluation/scripts/PrefEval/pref_zep.py +++ b/evaluation/scripts/PrefEval/pref_zep.py @@ -4,12 +4,16 @@ import os import sys import time + +from datetime import datetime + import tiktoken + from dotenv import load_dotenv +from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from datetime import datetime -from irrelevant_conv import 
irre_10, irre_300 + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -211,7 +215,7 @@ def main(): args = parser.parse_args() try: - with open(args.input, "r", encoding="utf-8") as infile: + with open(args.input, encoding="utf-8") as infile: lines = infile.readlines() except FileNotFoundError: print(f"Error: Input file '{args.input}' not found") diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index c5e342049..4403a9206 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -11,7 +11,7 @@ from openai import AsyncOpenAI from prompts import ANSWER_PROMPT_MEM0, ANSWER_PROMPT_MEMOS, ANSWER_PROMPT_ZEP from tqdm import tqdm -from utils.pref_mem_utils import remove_pref_mem_from_mem_string, add_pref_instruction +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string async def locomo_response(frame, llm_client, context: str, question: str) -> str: diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py index ad34ca091..22f17c304 100644 --- a/evaluation/scripts/longmemeval/lme_responses.py +++ b/evaluation/scripts/longmemeval/lme_responses.py @@ -12,7 +12,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.pref_mem_utils import remove_pref_mem_from_mem_string, add_pref_instruction +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string from utils.prompts import LME_ANSWER_PROMPT diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 8de23937c..2f7a83b81 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -3,10 +3,13 @@ import json import os import sys +import time + from concurrent.futures import ThreadPoolExecutor, 
as_completed from datetime import datetime + from tqdm import tqdm -import time + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -67,7 +70,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -79,14 +82,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -98,7 +101,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -118,7 +121,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index 653c5fc10..ae5f79d79 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -44,7 +44,7 @@ def save_to_excel(results, output_path): category_row[f"response_{metric}"] = value # Add search duration metrics (if exists) - if "search_duration" in scores and 
scores["search_duration"]: + if scores.get("search_duration"): for metric, value in scores["search_duration"].items(): category_row[f"search_{metric}"] = value @@ -371,7 +371,7 @@ def print_summary(results): print(f"📂 Loading response data from: {responses_path}") try: - with open(responses_path, "r", encoding="utf-8") as file: + with open(responses_path, encoding="utf-8") as file: data = json.load(file) # Calculate metrics diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 9f5dd419c..5b54f9bb8 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -10,12 +10,13 @@ from openai import OpenAI from tqdm import tqdm -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.prompts import PM_ANSWER_PROMPT -from utils.pref_mem_utils import remove_pref_mem_from_mem_string, add_pref_instruction +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string +from utils.prompts import PM_ANSWER_PROMPT + def extract_choice_answer(predicted_answer, correct_answer): def _extract_only_options(text): @@ -153,7 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client, frame) + future = executor.submit( + process_qa, user_id, search_results, num_runs, oai_client, frame + ) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 8a6f2487d..243c64589 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -1,23 +1,26 @@ import argparse +import csv import json import os import sys + from collections 
import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time + from tqdm import tqdm -import csv + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, MEMOS_CONTEXT_TEMPLATE, ZEP_CONTEXT_TEMPLATE, ) -from utils.pref_mem_utils import create_mem_string def zep_search(client, user_id, query, top_k=20): @@ -108,7 +111,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. """ index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -120,14 +123,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -139,7 +142,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -162,7 +165,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode='r', newline='', encoding='utf-8') as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/run_pm_eval.sh 
b/evaluation/scripts/run_pm_eval.sh index f83893fed..acda03a1b 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -62,4 +62,4 @@ else fi fi -echo "All scripts completed successfully!" \ No newline at end of file +echo "All scripts completed successfully!" diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 001f8299d..a79cefcc2 100644 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -11,13 +11,13 @@ WORKERS=10 # Parameters for pref_memos.py TOP_K=6 ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" +LIB="memos-api" VERSION="1022-0" # --- File Paths --- # You may need to adjust these paths based on your project structure. # Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" # Create a directory name based on the *specific* LIB (e.g., "memos") OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" @@ -54,7 +54,7 @@ export HF_ENDPOINT="https://hf-mirror.com" echo "--- Starting PrefEval Pipeline ---" echo "Configuration: WORKERS=$WORKERS, TOP_K=$TOP_K, ADD_TURN=$ADD_TURN, LIB=$LIB, VERSION=$VERSION, HF_ENDPOINT=$HF_ENDPOINT" echo "Results will be saved to: $OUTPUT_DIR" -echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" +echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" echo "" # --- Step 1: Preprocess the data --- @@ -134,7 +134,7 @@ echo "Running pref_eval.py..." python scripts/PrefEval/pref_eval.py \ --input $RESPONSE_FILE \ --concurrency-limit $WORKERS - + if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." 
exit 1 diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 2efb0493d..ffc9dda12 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -3,11 +3,15 @@ import sys import time import uuid + from contextlib import suppress from datetime import datetime -from dotenv import load_dotenv + import requests +from dotenv import load_dotenv + + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) load_dotenv() diff --git a/evaluation/scripts/utils/mirix_utils.py b/evaluation/scripts/utils/mirix_utils.py index e1b5f3de6..63cd490df 100644 --- a/evaluation/scripts/utils/mirix_utils.py +++ b/evaluation/scripts/utils/mirix_utils.py @@ -1,18 +1,21 @@ import os + import yaml + from tqdm import tqdm def get_mirix_client(config_path, load_from=None): - if os.path.exists(os.path.expanduser(f"~/.mirix")): - os.system(f"rm -rf ~/.mirix/*") + if os.path.exists(os.path.expanduser("~/.mirix")): + os.system("rm -rf ~/.mirix/*") - with open(config_path, "r") as f: + with open(config_path) as f: agent_config = yaml.safe_load(f) os.environ["OPENAI_API_KEY"] = agent_config["api_key"] import mirix - from mirix import Mirix, EmbeddingConfig, LLMConfig + + from mirix import EmbeddingConfig, LLMConfig, Mirix embedding_default_config = EmbeddingConfig( embedding_model=agent_config["embedding_model_name"], diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py index 5907afe9b..0fb4ec0ec 100644 --- a/evaluation/scripts/utils/pref_mem_utils.py +++ b/evaluation/scripts/utils/pref_mem_utils.py @@ -1,16 +1,18 @@ import os + from prompts import PREF_INSTRUCTIONS + def create_mem_string(relevant_memories) -> str: text_memories = [] explicit = [] implicit = [] for item in relevant_memories["text_mem"]: for mem in item["memories"]: - text_memories.append(mem['memory']) + text_memories.append(mem["memory"]) text_context = "" if text_memories: - text_memories_text = 
'\n'.join(text_memories) + text_memories_text = "\n".join(text_memories) text_context += f"Plaintext Memory:\n{text_memories_text}\n" for item in relevant_memories["pref_mem"]: for mem in item["memories"]: @@ -20,10 +22,10 @@ def create_mem_string(relevant_memories) -> str: implicit.append(mem["metadata"]["implicit_preference"]) pref_context = "" if explicit: - explicit_text = '\n'.join(explicit) + explicit_text = "\n".join(explicit) pref_context += f"Explicit Preference:\n{explicit_text}\n" if implicit: - implicit_text = '\n'.join(implicit) + implicit_text = "\n".join(implicit) pref_context += f"Implicit Preference:\n{implicit_text}\n" context = "" if text_memories and explicit and implicit: @@ -38,7 +40,8 @@ def remove_pref_mem_from_mem_string(mem_string: str, frame: str) -> str: return tmp_list[1].split("Explicit Preference:")[0] return mem_string + def add_pref_instruction(template: str, frame: str): if os.getenv("INSTRUCT_COMPLETE", "false").lower() == "true" and frame == "memos-api": return template.replace("{pref_instructions}", PREF_INSTRUCTIONS) - return template.replace("{pref_instructions}", "") \ No newline at end of file + return template.replace("{pref_instructions}", "") From 6a28893526d6ae3325f67bcbfc585fe68a5c957c Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 25 Oct 2025 14:30:59 +0800 Subject: [PATCH 59/62] modify code fllow reviewer --- evaluation/.env-example | 3 + evaluation/scripts/PrefEval/pref_memos.py | 16 ++-- evaluation/scripts/locomo/locomo_responses.py | 13 +++- evaluation/scripts/locomo/locomo_search.py | 4 +- evaluation/scripts/run_pm_eval.sh | 2 +- evaluation/scripts/run_prefeval_eval.sh | 0 evaluation/scripts/utils/pref_mem_utils.py | 29 +++---- src/memos/api/routers/server_router.py | 75 +++++++++++-------- src/memos/mem_cube/navie.py | 12 +-- src/memos/templates/instruction_completion.py | 5 +- 10 files changed, 88 insertions(+), 71 deletions(-) mode change 100644 => 100755 evaluation/scripts/run_prefeval_eval.sh diff 
--git a/evaluation/.env-example b/evaluation/.env-example index 956eead91..bda935442 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -24,6 +24,9 @@ MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" # pref PRE_SPLIT_CHUNK=false # pre split chunk in client end, for personamem and prefeval +# 1. text_mem + pref_mem + instruction_completion: set INSTRUCT_COMPLETE=true, ABLATION_PREF=false +# 2. text_mem + pref_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=false +# 3. text_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=true INSTRUCT_COMPLETE=true # use instruct complete format or not ABLATION_PREF=false # remove pref mem, only text mem diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index b4c5aec0a..8f8e0948f 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -11,13 +11,6 @@ from irrelevant_conv import irre_10, irre_300 from openai import OpenAI from tqdm import tqdm -from utils.pref_mem_utils import ( - add_pref_instruction, - create_mem_string, - remove_pref_mem_from_mem_string, -) -from utils.prompts import PREFEVAL_ANSWER_PROMPT - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -26,6 +19,15 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) + +from utils.pref_mem_utils import ( + add_pref_instruction, + create_mem_string, + remove_pref_mem_from_mem_string, +) +from utils.prompts import PREFEVAL_ANSWER_PROMPT + + load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") BASE_URL = os.getenv("OPENAI_BASE_URL") diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 4403a9206..026e90640 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -2,7 +2,7 @@ import asyncio import json import os - +import sys from time import time import pandas as pd @@ 
-11,7 +11,14 @@ from openai import AsyncOpenAI from prompts import ANSWER_PROMPT_MEM0, ANSWER_PROMPT_MEMOS, ANSWER_PROMPT_ZEP from tqdm import tqdm -from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string + +ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts") + +sys.path.insert(0, ROOT_DIR) +sys.path.insert(0, EVAL_SCRIPTS_DIR) async def locomo_response(frame, llm_client, context: str, question: str) -> str: @@ -26,6 +33,7 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str question=question, ) else: + from utils.pref_mem_utils import add_pref_instruction template = add_pref_instruction(ANSWER_PROMPT_MEMOS, frame=frame) prompt = template.format( context=context, @@ -44,6 +52,7 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str async def process_qa(frame, qa, search_result, oai_client): + from utils.pref_mem_utils import remove_pref_mem_from_mem_string start = time() query = qa.get("question") gold_answer = qa.get("answer") diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 51756145f..3d5a088a4 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -11,8 +11,6 @@ from dotenv import load_dotenv from tqdm import tqdm -from utils.pref_mem_utils import create_mem_string - ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -23,6 +21,7 @@ sys.path.insert(0, EVAL_SCRIPTS_DIR) + def mem0_search(client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b): from prompts import TEMPLATE_MEM0 @@ -101,6 +100,7 @@ def memos_api_search( client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b ): from prompts import TEMPLATE_MEMOS + from 
utils.pref_mem_utils import create_mem_string start = time() search_a_results = client.search(query=query, user_id=speaker_a_user_id, top_k=top_k) diff --git a/evaluation/scripts/run_pm_eval.sh b/evaluation/scripts/run_pm_eval.sh index acda03a1b..a46440bfc 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -1,7 +1,7 @@ #!/bin/bash # Common parameters for all scripts -LIB="memu" +LIB="memos-api" VERSION="072202" WORKERS=10 TOPK=20 diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh old mode 100644 new mode 100755 diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py index 0fb4ec0ec..33055a87d 100644 --- a/evaluation/scripts/utils/pref_mem_utils.py +++ b/evaluation/scripts/utils/pref_mem_utils.py @@ -1,8 +1,8 @@ import os - +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) from prompts import PREF_INSTRUCTIONS - def create_mem_string(relevant_memories) -> str: text_memories = [] explicit = [] @@ -10,27 +10,20 @@ def create_mem_string(relevant_memories) -> str: for item in relevant_memories["text_mem"]: for mem in item["memories"]: text_memories.append(mem["memory"]) - text_context = "" - if text_memories: - text_memories_text = "\n".join(text_memories) - text_context += f"Plaintext Memory:\n{text_memories_text}\n" - for item in relevant_memories["pref_mem"]: + text_memories_text = "\n".join(f"{i+1}. 
{mem}" for i, mem in enumerate(text_memories)).strip() + text_context = f"Plaintext Memory:\n{text_memories_text}\n" if text_memories_text else "" + + for item in relevant_memories.get("prefs", []): for mem in item["memories"]: if mem["metadata"]["preference_type"] == "explicit_preference": explicit.append(mem["metadata"]["explicit_preference"]) elif mem["metadata"]["preference_type"] == "implicit_preference": implicit.append(mem["metadata"]["implicit_preference"]) - pref_context = "" - if explicit: - explicit_text = "\n".join(explicit) - pref_context += f"Explicit Preference:\n{explicit_text}\n" - if implicit: - implicit_text = "\n".join(implicit) - pref_context += f"Implicit Preference:\n{implicit_text}\n" - context = "" - if text_memories and explicit and implicit: - context = f"{text_context}\n{pref_context}" - return context + explicit_text = "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(explicit)).strip() + explicit_context = f"Explicit Preference:\n{explicit_text}\n" if explicit_text else "" + implicit_text = "\n".join(f"{i+1}. 
{pref}" for i, pref in enumerate(implicit)).strip() + implicit_context = f"Implicit Preference:\n{implicit_text}\n" if implicit_text else "" + return text_context + explicit_context + implicit_context def remove_pref_mem_from_mem_string(mem_string: str, frame: str) -> str: diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 5e8c0bff9..fe2d8e796 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -112,17 +112,17 @@ def _build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) -def _build_extractor_config() -> dict[str, Any]: +def _build_pref_extractor_config() -> dict[str, Any]: """Build extractor configuration.""" return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) -def _build_adder_config() -> dict[str, Any]: +def _build_pref_adder_config() -> dict[str, Any]: """Build adder configuration.""" return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) -def _build_retriever_config() -> dict[str, Any]: +def _build_pref_retriever_config() -> dict[str, Any]: """Build retriever configuration.""" return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) @@ -150,9 +150,9 @@ def init_server(): reranker_config = _build_reranker_config() internet_retriever_config = _build_internet_retriever_config() vector_db_config = _build_vec_db_config() - extractor_config = _build_extractor_config() - adder_config = _build_adder_config() - retriever_config = _build_retriever_config() + pref_extractor_config = _build_pref_extractor_config() + pref_adder_config = _build_pref_adder_config() + pref_retriever_config = _build_pref_retriever_config() # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) @@ -164,20 +164,20 @@ def init_server(): internet_retriever = InternetRetrieverFactory.from_config( 
internet_retriever_config, embedder=embedder ) - extractor = ExtractorFactory.from_config( - config_factory=extractor_config, + pref_extractor = ExtractorFactory.from_config( + config_factory=pref_extractor_config, llm_provider=llm, embedder=embedder, vector_db=vector_db, ) - adder = AdderFactory.from_config( - config_factory=adder_config, + pref_adder = AdderFactory.from_config( + config_factory=pref_adder_config, llm_provider=llm, embedder=embedder, vector_db=vector_db, ) - retriever = RetrieverFactory.from_config( - config_factory=retriever_config, + pref_retriever = RetrieverFactory.from_config( + config_factory=pref_retriever_config, llm_provider=llm, embedder=embedder, reranker=reranker, @@ -208,9 +208,9 @@ def init_server(): default_cube_config, mos_server, vector_db, - extractor, - adder, - retriever, + pref_extractor, + pref_adder, + pref_retriever, ) @@ -226,9 +226,9 @@ def init_server(): default_cube_config, mos_server, vector_db, - extractor, - adder, - retriever, + pref_extractor, + pref_adder, + pref_retriever, ) = init_server() @@ -244,9 +244,9 @@ def _create_naive_mem_cube() -> NaiveMemCube: memory_manager=memory_manager, default_cube_config=default_cube_config, vector_db=vector_db, - extractor=extractor, - adder=adder, - retriever=retriever, + pref_extractor=pref_extractor, + pref_adder=pref_adder, + pref_retriever=pref_retriever, ) return naive_mem_cube @@ -266,6 +266,23 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]: return memory +def _post_process_pref_mem(memories_result: list[dict[str, Any]], pref_formatted_mem: list[dict[str, Any]], mem_cube_id: str, handle_pref_mem: bool): + if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true" and pref_formatted_mem: + memories_result["prefs"] = [] + memories_result["prefs"].append( + { + "cube_id": mem_cube_id, + "memories": pref_formatted_mem, + } + ) + + if handle_pref_mem: + pref_instruction: str = instruct_completion(pref_formatted_mem) + memories_result["pref_mem"] = 
pref_instruction + + return memories_result + + @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: APISearchRequest): @@ -308,6 +325,8 @@ def _search_text(): return [_format_memory_item(data) for data in results] def _search_pref(): + if not os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true": + return [] results = naive_mem_cube.pref_mem.search( query=search_req.query, top_k=search_req.top_k, @@ -331,18 +350,8 @@ def _search_pref(): "memories": text_formatted_memories, } ) - if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true": - memories_result["prefs"] = [] - memories_result["prefs"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": pref_formatted_memories, - } - ) - pref_instruction: str = instruct_completion(pref_formatted_memories) - if search_req.handle_pref_mem: - memories_result["pref_mem"] = pref_instruction + memories_result = _post_process_pref_mem(memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem) return SearchResponse( message="Search completed successfully", @@ -393,6 +402,8 @@ def _process_text_mem() -> list[dict[str, str]]: ] def _process_pref_mem() -> list[dict[str, str]]: + if not os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true": + return [] pref_memories_local = naive_mem_cube.pref_mem.get_memory( [add_req.messages], type="chat", diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 8074c1c57..ba9f136b7 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -41,9 +41,9 @@ def __init__( default_cube_config: GeneralMemCubeConfig, vector_db: BaseVecDB, internet_retriever: None = None, - extractor: BaseExtractor | None = None, - adder: BaseAdder | None = None, - retriever: BaseRetriever | None = None, + pref_extractor: BaseExtractor | None = None, + pref_adder: BaseAdder | None = None, + pref_retriever: BaseRetriever | None = None, ): 
"""Initialize the MemCube with a configuration.""" self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( @@ -63,9 +63,9 @@ def __init__( vector_db=vector_db, embedder=embedder, reranker=reranker, - extractor=extractor, - adder=adder, - retriever=retriever, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, ) def load( diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index ff4d83cd8..f7fba5414 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -21,15 +21,14 @@ def instruct_completion( implicit_pref.append(pref) explicit_pref_str = ( - "Explicit Preferences:\n" + "\n".join(explicit_pref) if explicit_pref else "" + "Explicit Preference:\n" + "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(explicit_pref)) if explicit_pref else "" ) implicit_pref_str = ( - "Implicit Preferences:\n" + "\n".join(implicit_pref) if implicit_pref else "" + "Implicit Preference:\n" + "\n".join(f"{i+1}. 
{pref}" for i, pref in enumerate(implicit_pref)) if implicit_pref else "" ) if not explicit_pref_str and not implicit_pref_str: return "" - if not explicit_pref_str: return implicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("explicit preferences > ", "") if not implicit_pref_str: From d7ec7642b9ff09e044a118bf4bf408b07a2ce6ac Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 25 Oct 2025 14:56:18 +0800 Subject: [PATCH 60/62] fix bug in make pre_commit --- evaluation/scripts/PrefEval/pref_memos.py | 13 ++++++------- evaluation/scripts/locomo/locomo_responses.py | 4 ++++ evaluation/scripts/locomo/locomo_search.py | 2 +- evaluation/scripts/personamem/pm_ingestion.py | 2 +- evaluation/scripts/personamem/pm_metric.py | 2 +- evaluation/scripts/utils/pref_mem_utils.py | 9 ++++++--- src/memos/api/routers/server_router.py | 19 +++++++++++++------ src/memos/templates/instruction_completion.py | 10 ++++++++-- 8 files changed, 40 insertions(+), 21 deletions(-) diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py index 8f8e0948f..753a77d99 100644 --- a/evaluation/scripts/PrefEval/pref_memos.py +++ b/evaluation/scripts/PrefEval/pref_memos.py @@ -12,6 +12,7 @@ from openai import OpenAI from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -20,13 +21,6 @@ sys.path.insert(0, ROOT_DIR) sys.path.insert(0, EVAL_SCRIPTS_DIR) -from utils.pref_mem_utils import ( - add_pref_instruction, - create_mem_string, - remove_pref_mem_from_mem_string, -) -from utils.prompts import PREFEVAL_ANSWER_PROMPT - load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") @@ -78,6 +72,8 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di """ Processes a single line of data, searching memory based on the question. 
""" + from utils.pref_mem_utils import create_mem_string + i, line = line_data try: original_data = json.loads(line) @@ -123,6 +119,9 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str """ Generates a response for a single line of data using pre-fetched memories. """ + from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string + from utils.prompts import PREFEVAL_ANSWER_PROMPT + i, line = line_data try: original_data = json.loads(line) diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py index 026e90640..2ae4dcb6e 100644 --- a/evaluation/scripts/locomo/locomo_responses.py +++ b/evaluation/scripts/locomo/locomo_responses.py @@ -3,6 +3,7 @@ import json import os import sys + from time import time import pandas as pd @@ -12,6 +13,7 @@ from prompts import ANSWER_PROMPT_MEM0, ANSWER_PROMPT_MEMOS, ANSWER_PROMPT_ZEP from tqdm import tqdm + ROOT_DIR = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -34,6 +36,7 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str ) else: from utils.pref_mem_utils import add_pref_instruction + template = add_pref_instruction(ANSWER_PROMPT_MEMOS, frame=frame) prompt = template.format( context=context, @@ -53,6 +56,7 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str async def process_qa(frame, qa, search_result, oai_client): from utils.pref_mem_utils import remove_pref_mem_from_mem_string + start = time() query = qa.get("question") gold_answer = qa.get("answer") diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py index 3d5a088a4..19efb5b92 100644 --- a/evaluation/scripts/locomo/locomo_search.py +++ b/evaluation/scripts/locomo/locomo_search.py @@ -12,6 +12,7 @@ from dotenv import load_dotenv from tqdm import tqdm + ROOT_DIR = os.path.dirname( 
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) @@ -21,7 +22,6 @@ sys.path.insert(0, EVAL_SCRIPTS_DIR) - def mem0_search(client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b): from prompts import TEMPLATE_MEM0 diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py index 2f7a83b81..5204b5c2a 100644 --- a/evaluation/scripts/personamem/pm_ingestion.py +++ b/evaluation/scripts/personamem/pm_ingestion.py @@ -37,7 +37,7 @@ def ingest_session(session, user_id, session_id, frame, client): client.add(messages=session, user_id=user_id, conv_id=session_id) print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages") elif frame == "memobase": - for idx, msg in enumerate(session): + for _idx, msg in enumerate(session): if msg["role"] != "system": messages.append( { diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index ae5f79d79..e88c538d4 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -80,7 +80,7 @@ def calculate_scores(data, grade_path, output_path): print(f"📋 Processing response data for {len(data)} users...") # First pass: determine number of runs and initialize run accuracy arrays - for user_id, user_data in data.items(): + for _user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: continue diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py index 33055a87d..22a5bb86c 100644 --- a/evaluation/scripts/utils/pref_mem_utils.py +++ b/evaluation/scripts/utils/pref_mem_utils.py @@ -1,8 +1,11 @@ import os import sys + + sys.path.append(os.path.dirname(os.path.abspath(__file__))) from prompts import PREF_INSTRUCTIONS + def create_mem_string(relevant_memories) -> str: text_memories = [] explicit = [] @@ -10,7 +13,7 
@@ def create_mem_string(relevant_memories) -> str: for item in relevant_memories["text_mem"]: for mem in item["memories"]: text_memories.append(mem["memory"]) - text_memories_text = "\n".join(f"{i+1}. {mem}" for i, mem in enumerate(text_memories)).strip() + text_memories_text = "\n".join(f"{i + 1}. {mem}" for i, mem in enumerate(text_memories)).strip() text_context = f"Plaintext Memory:\n{text_memories_text}\n" if text_memories_text else "" for item in relevant_memories.get("prefs", []): @@ -19,9 +22,9 @@ def create_mem_string(relevant_memories) -> str: explicit.append(mem["metadata"]["explicit_preference"]) elif mem["metadata"]["preference_type"] == "implicit_preference": implicit.append(mem["metadata"]["implicit_preference"]) - explicit_text = "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(explicit)).strip() + explicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit)).strip() explicit_context = f"Explicit Preference:\n{explicit_text}\n" if explicit_text else "" - implicit_text = "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(implicit)).strip() + implicit_text = "\n".join(f"{i + 1}. 
{pref}" for i, pref in enumerate(implicit)).strip() implicit_context = f"Implicit Preference:\n{implicit_text}\n" if implicit_text else "" return text_context + explicit_context + implicit_context diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index fe2d8e796..e5718046e 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -266,7 +266,13 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]: return memory -def _post_process_pref_mem(memories_result: list[dict[str, Any]], pref_formatted_mem: list[dict[str, Any]], mem_cube_id: str, handle_pref_mem: bool): + +def _post_process_pref_mem( + memories_result: list[dict[str, Any]], + pref_formatted_mem: list[dict[str, Any]], + mem_cube_id: str, + handle_pref_mem: bool, +): if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true" and pref_formatted_mem: memories_result["prefs"] = [] memories_result["prefs"].append( @@ -279,9 +285,8 @@ def _post_process_pref_mem(memories_result: list[dict[str, Any]], pref_formatted if handle_pref_mem: pref_instruction: str = instruct_completion(pref_formatted_mem) memories_result["pref_mem"] = pref_instruction - - return memories_result + return memories_result @router.post("/search", summary="Search memories", response_model=SearchResponse) @@ -325,7 +330,7 @@ def _search_text(): return [_format_memory_item(data) for data in results] def _search_pref(): - if not os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true": + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] results = naive_mem_cube.pref_mem.search( query=search_req.query, @@ -351,7 +356,9 @@ def _search_pref(): } ) - memories_result = _post_process_pref_mem(memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem) + memories_result = _post_process_pref_mem( + memories_result, pref_formatted_memories, search_req.mem_cube_id, 
search_req.handle_pref_mem + ) return SearchResponse( message="Search completed successfully", @@ -402,7 +409,7 @@ def _process_text_mem() -> list[dict[str, str]]: ] def _process_pref_mem() -> list[dict[str, str]]: - if not os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true": + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] pref_memories_local = naive_mem_cube.pref_mem.get_memory( [add_req.messages], diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py index f7fba5414..7ad0fe190 100644 --- a/src/memos/templates/instruction_completion.py +++ b/src/memos/templates/instruction_completion.py @@ -21,10 +21,16 @@ def instruct_completion( implicit_pref.append(pref) explicit_pref_str = ( - "Explicit Preference:\n" + "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(explicit_pref)) if explicit_pref else "" + "Explicit Preference:\n" + + "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit_pref)) + if explicit_pref + else "" ) implicit_pref_str = ( - "Implicit Preference:\n" + "\n".join(f"{i+1}. {pref}" for i, pref in enumerate(implicit_pref)) if implicit_pref else "" + "Implicit Preference:\n" + + "\n".join(f"{i + 1}. 
{pref}" for i, pref in enumerate(implicit_pref)) + if implicit_pref + else "" ) if not explicit_pref_str and not implicit_pref_str: From aac6a39de96d23c820cd0360dd72063b75bda91b Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 25 Oct 2025 17:19:44 +0800 Subject: [PATCH 61/62] repair bug in server router --- src/memos/api/routers/server_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 21b7fb0c3..ae51d5aea 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -356,7 +356,7 @@ def _search_pref(): top_k=search_req.top_k, info={ "user_id": search_req.user_id, - "session_id": target_session_id, + "session_id": search_req.session_id, "chat_history": search_req.chat_history, }, ) From 82a8b3f3058ebb7eb1a380a19c81fea84fa69f56 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 25 Oct 2025 17:24:27 +0800 Subject: [PATCH 62/62] fix pre commit bug --- src/memos/api/routers/server_router.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index ae51d5aea..d2392f927 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -338,11 +338,17 @@ def search_memories(search_req: APISearchRequest): def _search_text(): if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories(search_req=search_req, user_context=user_context) + formatted_memories = fast_search_memories( + search_req=search_req, user_context=user_context + ) elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories(search_req=search_req, user_context=user_context) + formatted_memories = fine_search_memories( + search_req=search_req, user_context=user_context + ) elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories(search_req=search_req, 
user_context=user_context) + formatted_memories = mix_search_memories( + search_req=search_req, user_context=user_context + ) else: logger.error(f"Unsupported search mode: {search_mode}") raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")