From 21387f73ac746fdc59d0e0d78091dbd19cebd7b6 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 22 Oct 2025 19:58:33 +0800 Subject: [PATCH 1/9] feat:change reranking source filed --- src/memos/reranker/concat.py | 62 ++++++++++++++++++++++++------ src/memos/reranker/cosine_local.py | 4 +- src/memos/reranker/http_bge.py | 8 ++-- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py index 5ad339529..19be9d01d 100644 --- a/src/memos/reranker/concat.py +++ b/src/memos/reranker/concat.py @@ -2,12 +2,50 @@ from typing import Any +from memos.memories.textual.item import SourceMessage + _TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") +def get_encoded_tokens(content: str) -> int: + """ + Get encoded tokens. + Args: + content: str + Returns: + int: Encoded tokens. + """ + return len(content) + + +def truncate_data(data: list[str | dict[str, Any] | Any], max_tokens: int) -> list[str]: + """ + Truncate data to max tokens. + Args: + data: List of strings or dictionaries. + max_tokens: Maximum number of tokens. + Returns: + str: Truncated string. + """ + total_tokens = 0 + truncated_string = "" + for item in data: + if isinstance(item, SourceMessage): + content = getattr(item, "content", "") + chat_time = getattr(item, "chat_time", "") + if not content: + continue + truncated_string += f"[{chat_time}]: {content}\n" + if get_encoded_tokens(truncated_string) > max_tokens: + break + return truncated_string + + def process_source( - items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3 + items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, + recent_num: int = 10, + max_tokens: int = 2048, ) -> str: """ Args: @@ -23,19 +61,16 @@ def process_source( memory = None for item in items: memory, source = item - for content in source: - if isinstance(content, str): - if "assistant:" in content: - continue - concat_data.append(content) + concat_data.extend(source[-recent_num:]) + truncated_string = truncate_data(concat_data, max_tokens) if memory is not None: - concat_data = [memory, *concat_data] - return "\n".join(concat_data) + truncated_string = f"{memory}\n{truncated_string}" + return truncated_string def concat_original_source( graph_results: list, - merge_field: list[str] | None = None, + rerank_source: str | None = None, ) -> list[str]: """ Merge memory items with original dialogue. @@ -45,14 +80,19 @@ def concat_original_source( Returns: list[str]: List of memory and concat orginal memory. """ - if merge_field is None: + merge_field = [] + if rerank_source is None: merge_field = ["sources"] + else: + merge_field = rerank_source.split(",") documents = [] for item in graph_results: memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m sources = [] for field in merge_field: - source = getattr(item.metadata, field, "") + source = getattr(item.metadata, field, None) + if source is None: + continue sources.append((memory, source)) concat_string = process_source(sources) documents.append(concat_string) diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py index 000b64cf4..c9a9172b7 100644 --- a/src/memos/reranker/cosine_local.py +++ b/src/memos/reranker/cosine_local.py @@ -5,6 +5,7 @@ from .base import BaseReranker +from memos.log import get_logger if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem @@ -16,6 +17,7 @@ except Exception: _HAS_NUMPY = False +logger = get_logger(__name__) def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]: """ @@ -92,5 +94,5 @@ def get_weight(it: TextualMemoryItem) -> float: chosen = {it.id for it, _ in top_items} remain = [(it, -1.0) for it in graph_results if it.id not in chosen] top_items.extend(remain[: top_k - len(top_items)]) - + logger.info(f"CosineLocalReranker rerank result: {top_items[:1]}") return top_items diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index f0f5d17a0..9cce12786 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -80,7 +80,7 @@ def __init__( model: str = "bge-reranker-v2-m3", timeout: int = 10, headers_extra: dict | None = None, - rerank_source: list[str] | None = None, + rerank_source: str | None = None, boost_weights: dict[str, float] | None = None, boost_default: float = 0.0, warn_unknown_filter_keys: bool = True, @@ -107,7 +107,7 @@ def __init__( self.model = model self.timeout = timeout self.headers_extra = headers_extra or {} - self.concat_source = rerank_source + self.rerank_source = rerank_source self.boost_weights = ( DEFAULT_BOOST_WEIGHTS.copy() @@ -152,8 +152,8 @@ def rerank( # Build a mapping from "payload docs index" -> "original graph_results index" # Only include items that have a non-empty string memory. This ensures that # any index returned by the server can be mapped back correctly. - if self.concat_source: - documents = concat_original_source(graph_results, self.concat_source) + if self.rerank_source: + documents = concat_original_source(graph_results, self.rerank_source) else: documents = [ (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m) From c53dcacef9edde1fbe8759d5b5a51399a174e5cb Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 22 Oct 2025 19:59:18 +0800 Subject: [PATCH 2/9] fix: code ci --- src/memos/reranker/cosine_local.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py index c9a9172b7..fc1dada2b 100644 --- a/src/memos/reranker/cosine_local.py +++ b/src/memos/reranker/cosine_local.py @@ -3,9 +3,10 @@ from typing import TYPE_CHECKING +from memos.log import get_logger + from .base import BaseReranker -from memos.log import get_logger if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem @@ -19,6 +20,7 @@ logger = get_logger(__name__) + def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]: """ Compute cosine similarities between a single vector q and a matrix m (rows are candidates). From 39b0f91583e886346ec6a0898b68f0084c80948a Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 25 Oct 2025 14:21:12 +0800 Subject: [PATCH 3/9] feat: add reranker strategy --- src/memos/api/config.py | 5 +- src/memos/reranker/base.py | 8 +- src/memos/reranker/factory.py | 11 + src/memos/reranker/http_bge_strategy.py | 316 ++++++++++++++++++ src/memos/reranker/strategies/__init__.py | 4 + src/memos/reranker/strategies/base.py | 58 ++++ .../reranker/strategies/concat_background.py | 93 ++++++ .../reranker/strategies/dialogue_common.py | 104 ++++++ src/memos/reranker/strategies/factory.py | 26 ++ src/memos/reranker/strategies/single_turn.py | 116 +++++++ .../reranker/strategies/singleturn_outmem.py | 96 ++++++ 11 files changed, 831 insertions(+), 6 deletions(-) create mode 100644 src/memos/reranker/http_bge_strategy.py create mode 100644 src/memos/reranker/strategies/__init__.py create mode 100644 src/memos/reranker/strategies/base.py create mode 100644 src/memos/reranker/strategies/concat_background.py create mode 100644 src/memos/reranker/strategies/dialogue_common.py create mode 100644 src/memos/reranker/strategies/factory.py create mode 100644 src/memos/reranker/strategies/single_turn.py create mode 100644 src/memos/reranker/strategies/singleturn_outmem.py diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d552369c5..4cc4c8431 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -113,15 +113,16 @@ def get_reranker_config() -> dict[str, Any]: """Get embedder configuration.""" embedder_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge") - if embedder_backend == "http_bge": + if embedder_backend in ["http_bge", "http_bge_strategy"]: return { - "backend": "http_bge", + "backend": embedder_backend, "config": { "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), + "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, } else: diff --git a/src/memos/reranker/base.py b/src/memos/reranker/base.py index 77a24c164..89b474e0d 100644 --- a/src/memos/reranker/base.py +++ b/src/memos/reranker/base.py @@ -2,8 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem @@ -16,9 +15,10 @@ class BaseReranker(ABC): def rerank( self, query: str, - graph_results: list, + graph_results: list[TextualMemoryItem], top_k: int, + search_filter: dict | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: """Return top_k (item, score) sorted by score desc.""" - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py index 134e29eb9..91a995ab5 100644 --- a/src/memos/reranker/factory.py +++ b/src/memos/reranker/factory.py @@ -9,6 +9,7 @@ from .cosine_local import CosineLocalReranker from .http_bge import HTTPBGEReranker from .noop import NoopReranker +from .http_bge_strategy import HTTPBGERerankerStrategy if TYPE_CHECKING: @@ -45,4 +46,14 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None: if backend in {"noop", "none", "disabled"}: return NoopReranker() + if backend in {"http_bge_strategy", "bge_strategy"}: + return HTTPBGERerankerStrategy( + reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"), + model=c.get("model", "bge-reranker-v2-m3"), + timeout=int(c.get("timeout", 10)), + headers_extra=c.get("headers_extra"), + rerank_source=c.get("rerank_source"), + reranker_strategy=c.get("reranker_strategy"), + ) + raise ValueError(f"Unknown reranker backend: {cfg.backend}") diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py new file mode 100644 index 000000000..1b46b13c4 --- /dev/null +++ b/src/memos/reranker/http_bge_strategy.py @@ -0,0 +1,316 @@ +# memos/reranker/http_bge.py +from __future__ import annotations + +import re + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import requests + +from memos.log import get_logger + +from .base import BaseReranker +from memos.reranker.strategies import RerankerStrategyFactory + + + +logger = get_logger(__name__) + + +if TYPE_CHECKING: + from memos.memories.textual.item import TextualMemoryItem + +# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") +# before sending text to the reranker. This keeps inputs clean and +# avoids misleading the model with bracketed prefixes. +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") +DEFAULT_BOOST_WEIGHTS = {"user_id": 0.5, "tags": 0.2, "session_id": 0.3} + + +def _value_matches(item_value: Any, wanted: Any) -> bool: + """ + Generic matching: + - if item_value is list/tuple/set: check membership (any match if wanted is iterable) + - else: equality (any match if wanted is iterable) + """ + + def _iterable(x): + # exclude strings from "iterable" + return isinstance(x, Iterable) and not isinstance(x, str | bytes) + + if _iterable(item_value): + if _iterable(wanted): + return any(w in item_value for w in wanted) + return wanted in item_value + else: + if _iterable(wanted): + return any(item_value == w for w in wanted) + return item_value == wanted + + +class HTTPBGERerankerStrategy(BaseReranker): + """ + HTTP-based BGE reranker. + + This class sends (query, documents[]) to a remote HTTP endpoint that + performs cross-encoder-style re-ranking (e.g., BGE reranker) and returns + relevance scores. It then maps those scores back onto the original + TextualMemoryItem list and returns (item, score) pairs sorted by score. + + Notes + ----- + - The endpoint is expected to accept JSON: + { + "model": "", + "query": "", + "documents": ["doc1", "doc2", ...] + } + - Two response shapes are supported: + 1) {"results": [{"index": , "relevance_score": }, ...]} + where "index" refers to the *position in the documents array*. + 2) {"data": [{"score": }, ...]} (aligned by list order) + - If the service fails or responds unexpectedly, this falls back to + returning the original items with 0.0 scores (best-effort). + """ + + def __init__( + self, + reranker_url: str, + token: str = "", + model: str = "bge-reranker-v2-m3", + timeout: int = 10, + headers_extra: dict | None = None, + rerank_source: str | None = None, + boost_weights: dict[str, float] | None = None, + boost_default: float = 0.0, + warn_unknown_filter_keys: bool = True, + reranker_strategy: str = "singleturn_outputmem", + **kwargs, + ): + """ + Parameters + ---------- + reranker_url : str + HTTP endpoint for the reranker service. + token : str, optional + Bearer token for auth. If non-empty, added to the Authorization header. + model : str, optional + Model identifier understood by the server. + timeout : int, optional + Request timeout (seconds). + headers_extra : dict | None, optional + Additional headers to merge into the request headers. + """ + if not reranker_url: + raise ValueError("reranker_url must not be empty") + self.reranker_url = reranker_url + self.token = token or "" + self.model = model + self.timeout = timeout + self.headers_extra = headers_extra or {} + + self.boost_weights = ( + DEFAULT_BOOST_WEIGHTS.copy() + if boost_weights is None + else {k: float(v) for k, v in boost_weights.items()} + ) + self.boost_default = float(boost_default) + self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) + self._warned_missing_keys: set[str] = set() + self.reranker_strategy = RerankerStrategyFactory.from_config(reranker_strategy) + + def rerank( + self, + query: str, + graph_results: list[TextualMemoryItem], + top_k: int, + search_filter: dict | None = None, + **kwargs, + ) -> list[tuple[TextualMemoryItem, float]]: + """ + Rank candidate memories by relevance to the query. + + Parameters + ---------- + query : str + The search query. + graph_results : list[TextualMemoryItem] + Candidate items to re-rank. Each item is expected to have a + `.memory` str field; non-strings are ignored. + top_k : int + Return at most this many items. + search_filter : dict | None + Currently unused. Present to keep signature compatible. + + Returns + ------- + list[tuple[TextualMemoryItem, float]] + Re-ranked items with scores, sorted descending by score. + """ + if not graph_results: + return [] + + tracker, original_items, documents = self.reranker_strategy.prepare_documents(query, graph_results, top_k) + + logger.info( + f"[HTTPBGEWithSourceReranker] strategy: {self.reranker_strategy}, " + f"query: {query}, documents count: {len(documents)}" + ) + logger.debug(f"[HTTPBGEWithSourceReranker] sample documents: {documents[:2]}...") + + if not documents: + return [] + + headers = {"Content-Type": "application/json", **self.headers_extra} + payload = {"model": self.model, "query": query, "documents": documents} + + try: + # Make the HTTP request to the reranker service + resp = requests.post( + self.reranker_url, headers=headers, json=payload, timeout=30 + ) + resp.raise_for_status() + data = resp.json() + + scored_items: list[tuple[TextualMemoryItem, float]] = [] + + if "results" in data: + # Format: + # dict("results": [{"index": int, "relevance_score": float}, + # ...]) + rows = data.get("results", []) + + ranked_indices = [] + scores = [] + for r in rows: + idx = r.get("index") + # The returned index refers to 'documents' (i.e., our 'pairs' order), + # so we must map it back to the original graph_results index. + if isinstance(idx, int) and 0 <= idx < len(graph_results): + raw_score = float(r.get("relevance_score", r.get("score", 0.0))) + ranked_indices.append(idx) + scores.append(raw_score) + reconstructed_items = self.reranker_strategy.reconstruct_items( + ranked_indices=ranked_indices, + scores=scores, + tracker=tracker, + original_items=original_items, + top_k=top_k, + graph_results=graph_results, + documents=documents + ) + return reconstructed_items + + elif "data" in data: + # Format: {"data": [{"score": float}, ...]} aligned by list order + rows = data.get("data", []) + # Build a list of scores aligned with our 'documents' (pairs) + score_list = [float(r.get("score", 0.0)) for r in rows] + + if len(score_list) < len(graph_results): + score_list += [0.0] * (len(graph_results) - len(score_list)) + elif len(score_list) > len(graph_results): + score_list = score_list[: len(graph_results)] + + scored_items = [] + for item, raw_score in zip(graph_results, score_list, strict=False): + score = self._apply_boost_generic(item, raw_score, search_filter) + scored_items.append((item, score)) + + scored_items.sort(key=lambda x: x[1], reverse=True) + return scored_items[: min(top_k, len(scored_items))] + + else: + # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs + # Note: we use 'pairs' to keep alignment with valid (string) docs. + return [(item, 0.0) for item in graph_results[:top_k]] + + except Exception as e: + # Network error, timeout, JSON decode error, etc. + # Degrade gracefully by returning first top_k valid docs with 0.0 score. + logger.error(f"[HTTPBGEReranker] request failed: {e}") + return [(item, 0.0) for item in graph_results[:top_k]] + + def _get_attr_or_key(self, obj: Any, key: str) -> Any: + """ + Resolve `key` on `obj` with one-level fallback into `obj.metadata`. + + Priority: + 1) obj. + 2) obj[key] + 3) obj.metadata. + 4) obj.metadata[key] + """ + if obj is None: + return None + + # support input like "metadata.user_id" + if "." in key: + head, tail = key.split(".", 1) + base = self._get_attr_or_key(obj, head) + return self._get_attr_or_key(base, tail) + + def _resolve(o: Any, k: str): + if o is None: + return None + v = getattr(o, k, None) + if v is not None: + return v + if hasattr(o, "get"): + try: + return o.get(k) + except Exception: + return None + return None + + # 1) find in obj + v = _resolve(obj, key) + if v is not None: + return v + + # 2) find in obj.metadata + meta = _resolve(obj, "metadata") + if meta is not None: + return _resolve(meta, key) + + return None + + def _apply_boost_generic( + self, + item: TextualMemoryItem, + base_score: float, + search_filter: dict | None, + ) -> float: + """ + Multiply base_score by (1 + weight) for each matching key in search_filter. + - key resolution: self._get_attr_or_key(item, key) + - weight = boost_weights.get(key, self.boost_default) + - unknown key -> one-time warning + """ + if not search_filter: + return base_score + + score = float(base_score) + + for key, wanted in search_filter.items(): + # _get_attr_or_key automatically find key in item and + # item.metadata ("metadata.user_id" supported) + resolved = self._get_attr_or_key(item, key) + + if resolved is None: + if self.warn_unknown_filter_keys and key not in self._warned_missing_keys: + logger.warning( + "[HTTPBGEReranker] search_filter key '%s' not found on TextualMemoryItem or metadata", + key, + ) + self._warned_missing_keys.add(key) + continue + + if _value_matches(resolved, wanted): + w = float(self.boost_weights.get(key, self.boost_default)) + if w != 0.0: + score *= 1.0 + w + score = min(max(0.0, score), 1.0) + + return score diff --git a/src/memos/reranker/strategies/__init__.py b/src/memos/reranker/strategies/__init__.py new file mode 100644 index 000000000..36186c0ac --- /dev/null +++ b/src/memos/reranker/strategies/__init__.py @@ -0,0 +1,4 @@ +from .factory import RerankerStrategyFactory + + +__all__ = ["RerankerStrategyFactory"] \ No newline at end of file diff --git a/src/memos/reranker/strategies/base.py b/src/memos/reranker/strategies/base.py new file mode 100644 index 000000000..3a35d2baa --- /dev/null +++ b/src/memos/reranker/strategies/base.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any +from memos.memories.textual.item import TextualMemoryItem +from .dialogue_common import DialogueRankingTracker + +class BaseRerankerStrategy(ABC): + """Abstract interface for memory rerankers with concatenation strategy.""" + + @abstractmethod + def prepare_documents( + self, + query: str, + graph_results: list[TextualMemoryItem], + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents for ranking based on the strategy. + + Args: + query: The search query + graph_results: List of TextualMemoryItem objects to process + top_k: Maximum number of items to return + **kwargs: Additional strategy-specific parameters + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + raise NotImplementedError + + @abstractmethod + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[TextualMemoryItem, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked results. + + Args: + ranked_indices: List of indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + **kwargs: Additional strategy-specific parameters + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + raise NotImplementedError \ No newline at end of file diff --git a/src/memos/reranker/strategies/concat_background.py b/src/memos/reranker/strategies/concat_background.py new file mode 100644 index 000000000..6583ba96c --- /dev/null +++ b/src/memos/reranker/strategies/concat_background.py @@ -0,0 +1,93 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations +import re +from typing import Any +from collections import defaultdict +from copy import deepcopy +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content + +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") + +class ConcatBackgroundStrategy(BaseRerankerStrategy): + """ + Concat background strategy. + + This strategy processes dialogue pairs by concatenating background and + user and assistant messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + + original_items = {} + tracker = DialogueRankingTracker() + documents = [] + for item in graph_results: + memory = getattr(item, "memory", None) + if isinstance(memory, str): + memory = _TAG1.sub("", memory) + + background = "" + if hasattr(item, "metadata") and hasattr(item.metadata, "background"): + background = getattr(item.metadata, "background", "") + if not isinstance(background, str): + background = "" + + documents.append(f"{memory}\n{background}") + return tracker, original_items, documents + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + graph_results = kwargs.get("graph_results", None) + documents = kwargs.get("documents", None) + reconstructed_items = [] + for idx in ranked_indices: + item = graph_results[idx] + item.memory = f"{item.memory}\n{documents[idx]}" + reconstructed_items.append((item, scores[idx])) + + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] + + diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py new file mode 100644 index 000000000..bf9c7d360 --- /dev/null +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import re +from typing import Any, Literal +from pydantic import BaseModel +from memos.memories.textual.item import SourceMessage + +# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") +# before sending text to the reranker. This keeps inputs clean and +# avoids misleading the model with bracketed prefixes. +_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") + + +def strip_memory_tags(item: TextualMemoryItem) -> str: + """Strip leading tags from memory text.""" + memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m + return memory + +def extract_content(msg: dict[str, Any] | str) -> str: + """Extract content from message, handling both string and dict formats.""" + if isinstance(msg, dict): + return msg.get('content', str(msg)) + if isinstance(msg, SourceMessage): + return msg.content + return str(msg) + + +class DialoguePair(BaseModel): + """Represents a single dialogue pair extracted from sources.""" + + pair_id: str # Unique identifier for this dialogue pair + memory_id: str # ID of the source TextualMemoryItem + memory: str + pair_index: int # Index of this pair within the source memory's dialogue + user_msg: str | dict[str, Any] | SourceMessage # User message content + assistant_msg: str | dict[str, Any] | SourceMessage # Assistant message content + combined_text: str # The concatenated text used for ranking + chat_time: str | None = None + + @property + def user_content(self) -> str: + """Get user message content as string.""" + return extract_content(self.user_msg) + + @property + def assistant_content(self) -> str: + """Get assistant message content as string.""" + return extract_content(self.assistant_msg) + + +class DialogueRankingTracker: + """Tracks dialogue pairs and their rankings for memory reconstruction.""" + + def __init__(self): + self.dialogue_pairs: list[DialoguePair] = [] + + def add_dialogue_pair( + self, + memory_id: str, + pair_index: int, + user_msg: str | dict[str, Any], + assistant_msg: str | dict[str, Any], + memory: str, + chat_time: str | None = None, + concat_format: Literal["user_assistant", "user_only"] = "user_assistant" + ) -> str: + """Add a dialogue pair and return its unique ID.""" + user_content = extract_content(user_msg) + assistant_content = extract_content(assistant_msg) + if concat_format == "user_assistant": + combined_text = f"[{chat_time}]: \nuser: {user_content}\nassistant: {assistant_content}" + elif concat_format == "user_only": + combined_text = f"[{chat_time}]: \nuser: {user_content}" + else: + raise ValueError(f"Invalid concat format: {concat_format}") + + pair_id = f"{memory_id}_{pair_index}" + + dialogue_pair = DialoguePair( + pair_id=pair_id, + memory_id=memory_id, + pair_index=pair_index, + user_msg=user_msg, + assistant_msg=assistant_msg, + combined_text=combined_text, + memory=memory, + chat_time=chat_time + ) + + self.dialogue_pairs.append(dialogue_pair) + return pair_id + + def get_documents_for_ranking(self, concat_memory: bool = True) -> list[str]: + """Get the combined text documents for ranking.""" + if concat_memory: + return [(pair.memory + "\n\n" + pair.combined_text) for pair in self.dialogue_pairs] + else: + return [pair.combined_text for pair in self.dialogue_pairs] + + def get_dialogue_pair_by_index(self, index: int) -> DialoguePair | None: + """Get dialogue pair by its index in the ranking results.""" + if 0 <= index < len(self.dialogue_pairs): + return self.dialogue_pairs[index] + return None diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py new file mode 100644 index 000000000..4f67b4b38 --- /dev/null +++ b/src/memos/reranker/strategies/factory.py @@ -0,0 +1,26 @@ +# memos/reranker/factory.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Any,ClassVar +from .base import BaseRerankerStrategy +from .single_turn import SingleTurnStrategy +from .concat_background import ConcatBackgroundStrategy +from .singleturn_outmem import SingleTurnOutMemStrategy + +class RerankerStrategyFactory(): + """Factory class for creating reranker strategy instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "single_turn": SingleTurnStrategy, + "concat_background": ConcatBackgroundStrategy, + "singleturn_outputmem": SingleTurnOutMemStrategy, + } + + @classmethod + def from_config( + cls, config_factory: str = "single_turn" + ) -> BaseRerankerStrategy: + if config_factory not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + strategy_class = cls.backend_to_class[config_factory] + return strategy_class() diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py new file mode 100644 index 000000000..c4f77f57c --- /dev/null +++ b/src/memos/reranker/strategies/single_turn.py @@ -0,0 +1,116 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations + +from typing import Any +from collections import defaultdict +from copy import deepcopy +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content + + + +class SingleTurnStrategy(BaseRerankerStrategy): + """ + Single turn dialogue strategy. + + This strategy processes dialogue pairs by concatenating user and assistant + messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + example: + >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + + original_items = {} + tracker = DialogueRankingTracker() + for item in graph_results: + memory = strip_memory_tags(item) + sources = getattr(item.metadata, "sources", []) + original_items[item.id] = item + + # Group messages into pairs and concatenate + for i in range(0, len(sources), 2): + user_msg = sources[i] if i < len(sources) else {} + assistant_msg = sources[i + 1] if i + 1 < len(sources) else {} + + user_content = extract_content(user_msg) + assistant_content = extract_content(assistant_msg) + chat_time = getattr(user_msg, "chat_time", "") + + if user_content or assistant_content: # Only add non-empty pairs + pair_index = i // 2 + tracker.add_dialogue_pair( + item.id, + pair_index, + user_msg, + assistant_msg, + memory or "", + chat_time + ) + + documents = tracker.get_documents_for_ranking() + return tracker, original_items, documents + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + documents = kwargs.get("documents", []) + # Group ranked pairs by memory_id + memory_groups = defaultdict(list) + memory_scores = defaultdict(list) + + reconstructed_items = [] + for idx, score in zip(ranked_indices, scores): + dialogue_pair = tracker.get_dialogue_pair_by_index(idx) + if dialogue_pair and dialogue_pair.memory_id in original_items: + original_item = original_items[dialogue_pair.memory_id] + reconstructed_item = deepcopy(original_item) + reconstructed_item.memory = dialogue_pair.memory + "\n\n" + dialogue_pair.combined_text + reconstructed_items.append((reconstructed_item, score)) + + # Sort by aggregated score and return top_k + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] + + diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py new file mode 100644 index 000000000..442e1e2d2 --- /dev/null +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -0,0 +1,96 @@ +# memos/reranker/strategies/single_turn.py +from __future__ import annotations + +from typing import Any +from collections import defaultdict +from copy import deepcopy +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content +from .single_turn import SingleTurnStrategy + + +class SingleTurnOutMemStrategy(SingleTurnStrategy): + """ + Single turn dialogue strategy. + + This strategy processes dialogue pairs by concatenating user and assistant + messages into single strings for ranking. Each dialogue pair becomes a + separate document for ranking. + example: + >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] + """ + + def prepare_documents( + self, + query: str, + graph_results: list, + top_k: int, + **kwargs, + ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + """ + Prepare documents based on single turn concatenation strategy. + + Args: + query: The search query + graph_results: List of graph results + top_k: Maximum number of items to return + + Returns: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + - Tracker: DialogueRankingTracker instance + - original_items: Dict mapping memory_id to original TextualMemoryItem + - documents: List of text documents ready for ranking + """ + return super().prepare_documents(query, graph_results, top_k, **kwargs) + + def reconstruct_items( + self, + ranked_indices: list[int], + scores: list[float], + tracker: DialogueRankingTracker, + original_items: dict[str, Any], + top_k: int, + **kwargs, + ) -> list[tuple[Any, float]]: + """ + Reconstruct TextualMemoryItem objects from ranked dialogue pairs. + + Args: + ranked_indices: List of dialogue pair indices sorted by relevance + scores: Corresponding relevance scores + tracker: DialogueRankingTracker instance + original_items: Dict mapping memory_id to original TextualMemoryItem + top_k: Maximum number of items to return + + Returns: + List of (reconstructed_memory_item, aggregated_score) tuples + """ + # Group ranked pairs by memory_id + memory_groups = defaultdict(list) + memory_scores = defaultdict(list) + + for idx, score in zip(ranked_indices, scores): + dialogue_pair = tracker.get_dialogue_pair_by_index(idx) + if dialogue_pair: + memory_groups[dialogue_pair.memory_id].append(dialogue_pair) + memory_scores[dialogue_pair.memory_id].append(score) + + reconstructed_items = [] + + for memory_id, pairs in memory_groups.items(): + if memory_id not in original_items: + continue + original_item = original_items[memory_id] + + # Calculate aggregated score (e.g., max, mean, or weighted average) + pair_scores = memory_scores[memory_id] + aggregated_score = max(pair_scores) if pair_scores else 0.0 + + reconstructed_items.append((original_item, aggregated_score)) + + # Sort by aggregated score and return top_k + reconstructed_items.sort(key=lambda x: x[1], reverse=True) + return reconstructed_items[:top_k] + + From 2f2d11ad6ee030626b33d1d942a6d05c35a01805 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 25 Oct 2025 14:31:34 +0800 Subject: [PATCH 4/9] fix: code suffix --- src/memos/reranker/concat.py | 6 +----- src/memos/reranker/strategies/dialogue_common.py | 1 + src/memos/reranker/strategies/factory.py | 6 ++++-- src/memos/reranker/strategies/single_turn.py | 7 +------ src/memos/reranker/strategies/singleturn_outmem.py | 8 ++++++-- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py index 19be9d01d..502af18b6 100644 --- a/src/memos/reranker/concat.py +++ b/src/memos/reranker/concat.py @@ -28,7 +28,6 @@ def truncate_data(data: list[str | dict[str, Any] | Any], max_tokens: int) -> li Returns: str: Truncated string. """ - total_tokens = 0 truncated_string = "" for item in data: if isinstance(item, SourceMessage): @@ -81,10 +80,7 @@ def concat_original_source( list[str]: List of memory and concat orginal memory. """ merge_field = [] - if rerank_source is None: - merge_field = ["sources"] - else: - merge_field = rerank_source.split(",") + merge_field = ["sources"] if rerank_source is None else rerank_source.split(",") documents = [] for item in graph_results: memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py index bf9c7d360..7385bcae4 100644 --- a/src/memos/reranker/strategies/dialogue_common.py +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel from memos.memories.textual.item import SourceMessage +from memos.memories.textual.item import TextualMemoryItem # Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") # before sending text to the reranker. This keeps inputs clean and diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py index 4f67b4b38..0a6ad52a2 100644 --- a/src/memos/reranker/strategies/factory.py +++ b/src/memos/reranker/strategies/factory.py @@ -2,11 +2,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any,ClassVar -from .base import BaseRerankerStrategy from .single_turn import SingleTurnStrategy from .concat_background import ConcatBackgroundStrategy from .singleturn_outmem import SingleTurnOutMemStrategy +if TYPE_CHECKING: + from .base import BaseRerankerStrategy + class RerankerStrategyFactory(): """Factory class for creating reranker strategy instances.""" @@ -21,6 +23,6 @@ def from_config( cls, config_factory: str = "single_turn" ) -> BaseRerankerStrategy: if config_factory not in cls.backend_to_class: - raise ValueError(f"Invalid backend: {backend}") + raise ValueError(f"Invalid backend: {config_factory}") strategy_class = cls.backend_to_class[config_factory] return strategy_class() diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py index c4f77f57c..44391284a 100644 --- a/src/memos/reranker/strategies/single_turn.py +++ b/src/memos/reranker/strategies/single_turn.py @@ -95,12 +95,7 @@ def reconstruct_items( Returns: List of (reconstructed_memory_item, aggregated_score) tuples """ - documents = kwargs.get("documents", []) - # Group ranked pairs by memory_id - memory_groups = defaultdict(list) - memory_scores = defaultdict(list) - - reconstructed_items = [] + reconstructed_items = [] for idx, score in zip(ranked_indices, scores): dialogue_pair = tracker.get_dialogue_pair_by_index(idx) if dialogue_pair and dialogue_pair.memory_id in original_items: diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py index 442e1e2d2..4be653872 100644 --- a/src/memos/reranker/strategies/singleturn_outmem.py +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -1,13 +1,17 @@ # memos/reranker/strategies/single_turn.py from __future__ import annotations -from typing import Any +from typing import Any, TYPE_CHECKING + from collections import defaultdict from copy import deepcopy from .base import BaseRerankerStrategy from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content from .single_turn import SingleTurnStrategy +if TYPE_CHECKING: + from .dialogue_common import DialogueRankingTracker + class SingleTurnOutMemStrategy(SingleTurnStrategy): """ @@ -78,7 +82,7 @@ def reconstruct_items( reconstructed_items = [] - for memory_id, pairs in memory_groups.items(): + for memory_id, _pairs in memory_groups.items(): if memory_id not in original_items: continue original_item = original_items[memory_id] From f37ceed9dd423a0fae61f8aa9d32dcfa51b13ad6 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 25 Oct 2025 14:37:56 +0800 Subject: [PATCH 5/9] fix: code suffix --- src/memos/reranker/base.py | 5 ++- src/memos/reranker/factory.py | 2 +- src/memos/reranker/http_bge_strategy.py | 21 +++++----- src/memos/reranker/strategies/__init__.py | 2 +- src/memos/reranker/strategies/base.py | 19 +++++---- .../reranker/strategies/concat_background.py | 37 ++++++++--------- .../reranker/strategies/dialogue_common.py | 22 +++++----- src/memos/reranker/strategies/factory.py | 13 +++--- src/memos/reranker/strategies/single_turn.py | 40 ++++++++----------- .../reranker/strategies/singleturn_outmem.py | 35 ++++++++-------- 10 files changed, 98 insertions(+), 98 deletions(-) diff --git a/src/memos/reranker/base.py b/src/memos/reranker/base.py index 89b474e0d..1c2f86ac5 100644 --- a/src/memos/reranker/base.py +++ b/src/memos/reranker/base.py @@ -2,7 +2,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING + if TYPE_CHECKING: from memos.memories.textual.item import TextualMemoryItem @@ -21,4 +22,4 @@ def rerank( **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: """Return top_k (item, score) sorted by score desc.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py index 91a995ab5..57460a4af 100644 --- a/src/memos/reranker/factory.py +++ b/src/memos/reranker/factory.py @@ -8,8 +8,8 @@ from .cosine_local import CosineLocalReranker from .http_bge import HTTPBGEReranker -from .noop import NoopReranker from .http_bge_strategy import HTTPBGERerankerStrategy +from .noop import NoopReranker if TYPE_CHECKING: diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index 1b46b13c4..510eaca4b 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -9,10 +9,9 @@ import requests from memos.log import get_logger - -from .base import BaseReranker from memos.reranker.strategies import RerankerStrategyFactory +from .base import BaseReranker logger = get_logger(__name__) @@ -151,7 +150,9 @@ def rerank( if not graph_results: return [] - tracker, original_items, documents = self.reranker_strategy.prepare_documents(query, graph_results, top_k) + tracker, original_items, documents = self.reranker_strategy.prepare_documents( + query, graph_results, top_k + ) logger.info( f"[HTTPBGEWithSourceReranker] strategy: {self.reranker_strategy}, " @@ -167,9 +168,7 @@ def rerank( try: # Make the HTTP request to the reranker service - resp = requests.post( - self.reranker_url, headers=headers, json=payload, timeout=30 - ) + resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout) resp.raise_for_status() data = resp.json() @@ -192,13 +191,13 @@ def rerank( ranked_indices.append(idx) scores.append(raw_score) reconstructed_items = self.reranker_strategy.reconstruct_items( - ranked_indices=ranked_indices, - scores=scores, - tracker=tracker, - original_items=original_items, + ranked_indices=ranked_indices, + scores=scores, + tracker=tracker, + original_items=original_items, top_k=top_k, graph_results=graph_results, - documents=documents + documents=documents, ) return reconstructed_items diff --git a/src/memos/reranker/strategies/__init__.py b/src/memos/reranker/strategies/__init__.py index 36186c0ac..cee60f2be 100644 --- a/src/memos/reranker/strategies/__init__.py +++ b/src/memos/reranker/strategies/__init__.py @@ -1,4 +1,4 @@ from .factory import RerankerStrategyFactory -__all__ = ["RerankerStrategyFactory"] \ No newline at end of file +__all__ = ["RerankerStrategyFactory"] diff --git a/src/memos/reranker/strategies/base.py b/src/memos/reranker/strategies/base.py index 3a35d2baa..43166dd92 100644 --- a/src/memos/reranker/strategies/base.py +++ b/src/memos/reranker/strategies/base.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import Any + from memos.memories.textual.item import TextualMemoryItem + from .dialogue_common import DialogueRankingTracker + class BaseRerankerStrategy(ABC): """Abstract interface for memory rerankers with concatenation strategy.""" @@ -16,21 +19,21 @@ def prepare_documents( ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: """ Prepare documents for ranking based on the strategy. - + Args: query: The search query graph_results: List of TextualMemoryItem objects to process top_k: Maximum number of items to return **kwargs: Additional strategy-specific parameters - + Returns: - tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: - Tracker: DialogueRankingTracker instance - original_items: Dict mapping memory_id to original TextualMemoryItem - documents: List of text documents ready for ranking """ raise NotImplementedError - + @abstractmethod def reconstruct_items( self, @@ -43,7 +46,7 @@ def reconstruct_items( ) -> list[tuple[TextualMemoryItem, float]]: """ Reconstruct TextualMemoryItem objects from ranked results. - + Args: ranked_indices: List of indices sorted by relevance scores: Corresponding relevance scores @@ -51,8 +54,8 @@ def reconstruct_items( original_items: Dict mapping memory_id to original TextualMemoryItem top_k: Maximum number of items to return **kwargs: Additional strategy-specific parameters - + Returns: List of (reconstructed_memory_item, aggregated_score) tuples """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/memos/reranker/strategies/concat_background.py b/src/memos/reranker/strategies/concat_background.py index 6583ba96c..a52313548 100644 --- a/src/memos/reranker/strategies/concat_background.py +++ b/src/memos/reranker/strategies/concat_background.py @@ -1,22 +1,25 @@ # memos/reranker/strategies/single_turn.py from __future__ import annotations + import re + from typing import Any -from collections import defaultdict -from copy import deepcopy + from .base import BaseRerankerStrategy -from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content +from .dialogue_common import DialogueRankingTracker + _TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*") + class ConcatBackgroundStrategy(BaseRerankerStrategy): """ Concat background strategy. - + This strategy processes dialogue pairs by concatenating background and user and assistant messages into single strings for ranking. Each dialogue pair becomes a separate document for ranking. - """ + """ def prepare_documents( self, @@ -27,19 +30,19 @@ def prepare_documents( ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: """ Prepare documents based on single turn concatenation strategy. - + Args: query: The search query graph_results: List of graph results top_k: Maximum number of items to return - + Returns: - tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: - Tracker: DialogueRankingTracker instance - original_items: Dict mapping memory_id to original TextualMemoryItem - documents: List of text documents ready for ranking """ - + original_items = {} tracker = DialogueRankingTracker() documents = [] @@ -47,13 +50,13 @@ def prepare_documents( memory = getattr(item, "memory", None) if isinstance(memory, str): memory = _TAG1.sub("", memory) - + background = "" if hasattr(item, "metadata") and hasattr(item.metadata, "background"): background = getattr(item.metadata, "background", "") if not isinstance(background, str): background = "" - + documents.append(f"{memory}\n{background}") return tracker, original_items, documents @@ -68,26 +71,24 @@ def reconstruct_items( ) -> list[tuple[Any, float]]: """ Reconstruct TextualMemoryItem objects from ranked dialogue pairs. - + Args: ranked_indices: List of dialogue pair indices sorted by relevance scores: Corresponding relevance scores tracker: DialogueRankingTracker instance original_items: Dict mapping memory_id to original TextualMemoryItem top_k: Maximum number of items to return - + Returns: List of (reconstructed_memory_item, aggregated_score) tuples """ - graph_results = kwargs.get("graph_results", None) - documents = kwargs.get("documents", None) + graph_results = kwargs.get("graph_results") + documents = kwargs.get("documents") reconstructed_items = [] for idx in ranked_indices: item = graph_results[idx] item.memory = f"{item.memory}\n{documents[idx]}" reconstructed_items.append((item, scores[idx])) - + reconstructed_items.sort(key=lambda x: x[1], reverse=True) return reconstructed_items[:top_k] - - diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py index 7385bcae4..ce0138284 100644 --- a/src/memos/reranker/strategies/dialogue_common.py +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -1,10 +1,13 @@ from __future__ import annotations import re + from typing import Any, Literal + from pydantic import BaseModel -from memos.memories.textual.item import SourceMessage -from memos.memories.textual.item import TextualMemoryItem + +from memos.memories.textual.item import SourceMessage, TextualMemoryItem + # Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...") # before sending text to the reranker. This keeps inputs clean and @@ -17,10 +20,11 @@ def strip_memory_tags(item: TextualMemoryItem) -> str: memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m return memory + def extract_content(msg: dict[str, Any] | str) -> str: """Extract content from message, handling both string and dict formats.""" if isinstance(msg, dict): - return msg.get('content', str(msg)) + return msg.get("content", str(msg)) if isinstance(msg, SourceMessage): return msg.content return str(msg) @@ -33,7 +37,7 @@ class DialoguePair(BaseModel): memory_id: str # ID of the source TextualMemoryItem memory: str pair_index: int # Index of this pair within the source memory's dialogue - user_msg: str | dict[str, Any] | SourceMessage # User message content + user_msg: str | dict[str, Any] | SourceMessage # User message content assistant_msg: str | dict[str, Any] | SourceMessage # Assistant message content combined_text: str # The concatenated text used for ranking chat_time: str | None = None @@ -56,14 +60,14 @@ def __init__(self): self.dialogue_pairs: list[DialoguePair] = [] def add_dialogue_pair( - self, - memory_id: str, + self, + memory_id: str, pair_index: int, - user_msg: str | dict[str, Any], + user_msg: str | dict[str, Any], assistant_msg: str | dict[str, Any], memory: str, chat_time: str | None = None, - concat_format: Literal["user_assistant", "user_only"] = "user_assistant" + concat_format: Literal["user_assistant", "user_only"] = "user_assistant", ) -> str: """Add a dialogue pair and return its unique ID.""" user_content = extract_content(user_msg) @@ -85,7 +89,7 @@ def add_dialogue_pair( assistant_msg=assistant_msg, combined_text=combined_text, memory=memory, - chat_time=chat_time + chat_time=chat_time, ) self.dialogue_pairs.append(dialogue_pair) diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py index 0a6ad52a2..7bb223ce7 100644 --- a/src/memos/reranker/strategies/factory.py +++ b/src/memos/reranker/strategies/factory.py @@ -1,15 +1,18 @@ # memos/reranker/factory.py from __future__ import annotations -from typing import TYPE_CHECKING, Any,ClassVar -from .single_turn import SingleTurnStrategy +from typing import TYPE_CHECKING, Any, ClassVar + from .concat_background import ConcatBackgroundStrategy +from .single_turn import SingleTurnStrategy from .singleturn_outmem import SingleTurnOutMemStrategy + if TYPE_CHECKING: from .base import BaseRerankerStrategy -class RerankerStrategyFactory(): + +class RerankerStrategyFactory: """Factory class for creating reranker strategy instances.""" backend_to_class: ClassVar[dict[str, Any]] = { @@ -19,9 +22,7 @@ class RerankerStrategyFactory(): } @classmethod - def from_config( - cls, config_factory: str = "single_turn" - ) -> BaseRerankerStrategy: + def from_config(cls, config_factory: str = "single_turn") -> BaseRerankerStrategy: if config_factory not in cls.backend_to_class: raise ValueError(f"Invalid backend: {config_factory}") strategy_class = cls.backend_to_class[config_factory] diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py index 44391284a..19fdd4a5a 100644 --- a/src/memos/reranker/strategies/single_turn.py +++ b/src/memos/reranker/strategies/single_turn.py @@ -1,25 +1,24 @@ # memos/reranker/strategies/single_turn.py from __future__ import annotations -from typing import Any -from collections import defaultdict from copy import deepcopy -from .base import BaseRerankerStrategy -from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content +from typing import Any +from .base import BaseRerankerStrategy +from .dialogue_common import DialogueRankingTracker, extract_content, strip_memory_tags class SingleTurnStrategy(BaseRerankerStrategy): """ Single turn dialogue strategy. - + This strategy processes dialogue pairs by concatenating user and assistant messages into single strings for ranking. Each dialogue pair becomes a separate document for ranking. example: >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] - """ + """ def prepare_documents( self, @@ -30,19 +29,19 @@ def prepare_documents( ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: """ Prepare documents based on single turn concatenation strategy. - + Args: query: The search query graph_results: List of graph results top_k: Maximum number of items to return - + Returns: - tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: - Tracker: DialogueRankingTracker instance - original_items: Dict mapping memory_id to original TextualMemoryItem - documents: List of text documents ready for ranking """ - + original_items = {} tracker = DialogueRankingTracker() for item in graph_results: @@ -58,16 +57,11 @@ def prepare_documents( user_content = extract_content(user_msg) assistant_content = extract_content(assistant_msg) chat_time = getattr(user_msg, "chat_time", "") - + if user_content or assistant_content: # Only add non-empty pairs pair_index = i // 2 tracker.add_dialogue_pair( - item.id, - pair_index, - user_msg, - assistant_msg, - memory or "", - chat_time + item.id, pair_index, user_msg, assistant_msg, memory or "", chat_time ) documents = tracker.get_documents_for_ranking() @@ -84,28 +78,28 @@ def reconstruct_items( ) -> list[tuple[Any, float]]: """ Reconstruct TextualMemoryItem objects from ranked dialogue pairs. - + Args: ranked_indices: List of dialogue pair indices sorted by relevance scores: Corresponding relevance scores tracker: DialogueRankingTracker instance original_items: Dict mapping memory_id to original TextualMemoryItem top_k: Maximum number of items to return - + Returns: List of (reconstructed_memory_item, aggregated_score) tuples """ reconstructed_items = [] - for idx, score in zip(ranked_indices, scores): + for idx, score in zip(ranked_indices, scores, strict=False): dialogue_pair = tracker.get_dialogue_pair_by_index(idx) if dialogue_pair and dialogue_pair.memory_id in original_items: original_item = original_items[dialogue_pair.memory_id] reconstructed_item = deepcopy(original_item) - reconstructed_item.memory = dialogue_pair.memory + "\n\n" + dialogue_pair.combined_text + reconstructed_item.memory = ( + dialogue_pair.memory + "\n\n" + dialogue_pair.combined_text + ) reconstructed_items.append((reconstructed_item, score)) # Sort by aggregated score and return top_k reconstructed_items.sort(key=lambda x: x[1], reverse=True) return reconstructed_items[:top_k] - - diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py index 4be653872..43f57630b 100644 --- a/src/memos/reranker/strategies/singleturn_outmem.py +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -1,14 +1,13 @@ # memos/reranker/strategies/single_turn.py from __future__ import annotations -from typing import Any, TYPE_CHECKING - from collections import defaultdict -from copy import deepcopy -from .base import BaseRerankerStrategy -from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content +from typing import TYPE_CHECKING, Any + +from .dialogue_common import DialogueRankingTracker from .single_turn import SingleTurnStrategy + if TYPE_CHECKING: from .dialogue_common import DialogueRankingTracker @@ -16,14 +15,14 @@ class SingleTurnOutMemStrategy(SingleTurnStrategy): """ Single turn dialogue strategy. - + This strategy processes dialogue pairs by concatenating user and assistant messages into single strings for ranking. Each dialogue pair becomes a separate document for ranking. example: >>> documents = ["chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] >>> output memory item: ["Memory:xxx \n\n chat_time: 2025-01-01 12:00:00\nuser: hello\nassistant: hi there"] - """ + """ def prepare_documents( self, @@ -34,14 +33,14 @@ def prepare_documents( ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]: """ Prepare documents based on single turn concatenation strategy. - + Args: query: The search query graph_results: List of graph results top_k: Maximum number of items to return - + Returns: - tuple[DialogueRankingTracker, dict[str, Any], list[str]]: + tuple[DialogueRankingTracker, dict[str, Any], list[str]]: - Tracker: DialogueRankingTracker instance - original_items: Dict mapping memory_id to original TextualMemoryItem - documents: List of text documents ready for ranking @@ -59,14 +58,14 @@ def reconstruct_items( ) -> list[tuple[Any, float]]: """ Reconstruct TextualMemoryItem objects from ranked dialogue pairs. - + Args: ranked_indices: List of dialogue pair indices sorted by relevance scores: Corresponding relevance scores tracker: DialogueRankingTracker instance original_items: Dict mapping memory_id to original TextualMemoryItem top_k: Maximum number of items to return - + Returns: List of (reconstructed_memory_item, aggregated_score) tuples """ @@ -74,27 +73,25 @@ def reconstruct_items( memory_groups = defaultdict(list) memory_scores = defaultdict(list) - for idx, score in zip(ranked_indices, scores): + for idx, score in zip(ranked_indices, scores, strict=False): dialogue_pair = tracker.get_dialogue_pair_by_index(idx) if dialogue_pair: memory_groups[dialogue_pair.memory_id].append(dialogue_pair) memory_scores[dialogue_pair.memory_id].append(score) - + reconstructed_items = [] - + for memory_id, _pairs in memory_groups.items(): if memory_id not in original_items: continue original_item = original_items[memory_id] - + # Calculate aggregated score (e.g., max, mean, or weighted average) pair_scores = memory_scores[memory_id] aggregated_score = max(pair_scores) if pair_scores else 0.0 - + reconstructed_items.append((original_item, aggregated_score)) # Sort by aggregated score and return top_k reconstructed_items.sort(key=lambda x: x[1], reverse=True) return reconstructed_items[:top_k] - - From e33ba0bdfb4e75fb27c3b0333996aa2ae685347a Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 25 Oct 2025 14:49:10 +0800 Subject: [PATCH 6/9] fix:change strategy name --- src/memos/reranker/http_bge_strategy.py | 2 +- src/memos/reranker/strategies/factory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index 510eaca4b..f02c7b665 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -84,7 +84,7 @@ def __init__( boost_weights: dict[str, float] | None = None, boost_default: float = 0.0, warn_unknown_filter_keys: bool = True, - reranker_strategy: str = "singleturn_outputmem", + reranker_strategy: str = "single_turn", **kwargs, ): """ diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py index 7bb223ce7..d93cbd65a 100644 --- a/src/memos/reranker/strategies/factory.py +++ b/src/memos/reranker/strategies/factory.py @@ -18,7 +18,7 @@ class RerankerStrategyFactory: backend_to_class: ClassVar[dict[str, Any]] = { "single_turn": SingleTurnStrategy, "concat_background": ConcatBackgroundStrategy, - "singleturn_outputmem": SingleTurnOutMemStrategy, + "singleturn_outmem": SingleTurnOutMemStrategy, } @classmethod From 0ccbceedddcd22af1532fd1cf8d9b2c3a8fa5e72 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 25 Oct 2025 14:56:54 +0800 Subject: [PATCH 7/9] fix: code format --- src/memos/reranker/http_bge_strategy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index f02c7b665..acc3f30af 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -168,7 +168,9 @@ def rerank( try: # Make the HTTP request to the reranker service - resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout) + resp = requests.post( + self.reranker_url, headers=headers, json=payload, timeout=self.timeout + ) resp.raise_for_status() data = resp.json() From c45a0826e2e533df2158a44e7bacc12dc908f35f Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 27 Oct 2025 10:25:47 +0800 Subject: [PATCH 8/9] feat: update memory strategies --- src/memos/graph_dbs/neo4j.py | 2 +- src/memos/reranker/http_bge_strategy.py | 2 +- src/memos/reranker/strategies/dialogue_common.py | 4 ++-- src/memos/reranker/strategies/single_turn.py | 4 ++-- src/memos/reranker/strategies/singleturn_outmem.py | 1 + 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 55db60ed2..e7807b6f5 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -157,7 +157,7 @@ def remove_oldest_memory( """ if not self.config.use_multi_db and (self.config.user_name or user_name): query += f"\nAND n.user_name = '{user_name}'" - + keep_latest = int(keep_latest) query += f""" WITH n ORDER BY n.updated_at DESC SKIP {keep_latest} diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index acc3f30af..8cbf633a6 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -158,7 +158,7 @@ def rerank( f"[HTTPBGEWithSourceReranker] strategy: {self.reranker_strategy}, " f"query: {query}, documents count: {len(documents)}" ) - logger.debug(f"[HTTPBGEWithSourceReranker] sample documents: {documents[:2]}...") + logger.info(f"[HTTPBGEWithSourceReranker] sample documents: {documents[:3]}...") if not documents: return [] diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py index ce0138284..6bac54449 100644 --- a/src/memos/reranker/strategies/dialogue_common.py +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -80,7 +80,7 @@ def add_dialogue_pair( raise ValueError(f"Invalid concat format: {concat_format}") pair_id = f"{memory_id}_{pair_index}" - + dialogue_pair = DialoguePair( pair_id=pair_id, memory_id=memory_id, @@ -91,7 +91,7 @@ def add_dialogue_pair( memory=memory, chat_time=chat_time, ) - + self.dialogue_pairs.append(dialogue_pair) return pair_id diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py index 19fdd4a5a..bc4fe53ce 100644 --- a/src/memos/reranker/strategies/single_turn.py +++ b/src/memos/reranker/strategies/single_turn.py @@ -92,11 +92,11 @@ def reconstruct_items( reconstructed_items = [] for idx, score in zip(ranked_indices, scores, strict=False): dialogue_pair = tracker.get_dialogue_pair_by_index(idx) - if dialogue_pair and dialogue_pair.memory_id in original_items: + if dialogue_pair and (dialogue_pair.memory_id in original_items): original_item = original_items[dialogue_pair.memory_id] reconstructed_item = deepcopy(original_item) reconstructed_item.memory = ( - dialogue_pair.memory + "\n\n" + dialogue_pair.combined_text + dialogue_pair.memory + "\n\nsources-dialogue-pairs" + dialogue_pair.combined_text ) reconstructed_items.append((reconstructed_item, score)) diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py index 43f57630b..220045fa2 100644 --- a/src/memos/reranker/strategies/singleturn_outmem.py +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -88,6 +88,7 @@ def reconstruct_items( # Calculate aggregated score (e.g., max, mean, or weighted average) pair_scores = memory_scores[memory_id] + aggregated_score = max(pair_scores) if pair_scores else 0.0 reconstructed_items.append((original_item, aggregated_score)) From b60710e63e46970f340c64ecf5f4c89b2affa1f2 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 27 Oct 2025 10:28:35 +0800 Subject: [PATCH 9/9] fix: code ci --- src/memos/reranker/strategies/dialogue_common.py | 4 ++-- src/memos/reranker/strategies/single_turn.py | 4 +++- src/memos/reranker/strategies/singleturn_outmem.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/memos/reranker/strategies/dialogue_common.py b/src/memos/reranker/strategies/dialogue_common.py index 6bac54449..ce0138284 100644 --- a/src/memos/reranker/strategies/dialogue_common.py +++ b/src/memos/reranker/strategies/dialogue_common.py @@ -80,7 +80,7 @@ def add_dialogue_pair( raise ValueError(f"Invalid concat format: {concat_format}") pair_id = f"{memory_id}_{pair_index}" - + dialogue_pair = DialoguePair( pair_id=pair_id, memory_id=memory_id, @@ -91,7 +91,7 @@ def add_dialogue_pair( memory=memory, chat_time=chat_time, ) - + self.dialogue_pairs.append(dialogue_pair) return pair_id diff --git a/src/memos/reranker/strategies/single_turn.py b/src/memos/reranker/strategies/single_turn.py index bc4fe53ce..d86744811 100644 --- a/src/memos/reranker/strategies/single_turn.py +++ b/src/memos/reranker/strategies/single_turn.py @@ -96,7 +96,9 @@ def reconstruct_items( original_item = original_items[dialogue_pair.memory_id] reconstructed_item = deepcopy(original_item) reconstructed_item.memory = ( - dialogue_pair.memory + "\n\nsources-dialogue-pairs" + dialogue_pair.combined_text + dialogue_pair.memory + + "\n\nsources-dialogue-pairs" + + dialogue_pair.combined_text ) reconstructed_items.append((reconstructed_item, score)) diff --git a/src/memos/reranker/strategies/singleturn_outmem.py b/src/memos/reranker/strategies/singleturn_outmem.py index 220045fa2..de59fec97 100644 --- a/src/memos/reranker/strategies/singleturn_outmem.py +++ b/src/memos/reranker/strategies/singleturn_outmem.py @@ -88,7 +88,7 @@ def reconstruct_items( # Calculate aggregated score (e.g., max, mean, or weighted average) pair_scores = memory_scores[memory_id] - + aggregated_score = max(pair_scores) if pair_scores else 0.0 reconstructed_items.append((original_item, aggregated_score))