Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,16 @@ def get_reranker_config() -> dict[str, Any]:
"""Get embedder configuration."""
embedder_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge")

if embedder_backend == "http_bge":
if embedder_backend in ["http_bge", "http_bge_strategy"]:
return {
"backend": "http_bge",
"backend": embedder_backend,
"config": {
"url": os.getenv("MOS_RERANKER_URL"),
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
"timeout": 10,
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
"reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"),
},
}
else:
Expand Down
2 changes: 1 addition & 1 deletion src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def remove_oldest_memory(
"""
if not self.config.use_multi_db and (self.config.user_name or user_name):
query += f"\nAND n.user_name = '{user_name}'"

keep_latest = int(keep_latest)
query += f"""
WITH n ORDER BY n.updated_at DESC
SKIP {keep_latest}
Expand Down
3 changes: 2 additions & 1 deletion src/memos/reranker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class BaseReranker(ABC):
def rerank(
self,
query: str,
graph_results: list,
graph_results: list[TextualMemoryItem],
top_k: int,
search_filter: dict | None = None,
**kwargs,
) -> list[tuple[TextualMemoryItem, float]]:
"""Return top_k (item, score) sorted by score desc."""
Expand Down
60 changes: 48 additions & 12 deletions src/memos/reranker/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,49 @@

from typing import Any

from memos.memories.textual.item import SourceMessage


_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")


def get_encoded_tokens(content: str) -> int:
"""
Get encoded tokens.
Args:
content: str
Returns:
int: Encoded tokens.
"""
return len(content)


def truncate_data(data: list[str | dict[str, Any] | Any], max_tokens: int) -> list[str]:
"""
Truncate data to max tokens.
Args:
data: List of strings or dictionaries.
max_tokens: Maximum number of tokens.
Returns:
str: Truncated string.
"""
truncated_string = ""
for item in data:
if isinstance(item, SourceMessage):
content = getattr(item, "content", "")
chat_time = getattr(item, "chat_time", "")
if not content:
continue
truncated_string += f"[{chat_time}]: {content}\n"
if get_encoded_tokens(truncated_string) > max_tokens:
break
return truncated_string


def process_source(
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None,
recent_num: int = 10,
max_tokens: int = 2048,
) -> str:
"""
Args:
Expand All @@ -23,19 +60,16 @@ def process_source(
memory = None
for item in items:
memory, source = item
for content in source:
if isinstance(content, str):
if "assistant:" in content:
continue
concat_data.append(content)
concat_data.extend(source[-recent_num:])
truncated_string = truncate_data(concat_data, max_tokens)
if memory is not None:
concat_data = [memory, *concat_data]
return "\n".join(concat_data)
truncated_string = f"{memory}\n{truncated_string}"
return truncated_string


def concat_original_source(
graph_results: list,
merge_field: list[str] | None = None,
rerank_source: str | None = None,
) -> list[str]:
"""
Merge memory items with original dialogue.
Expand All @@ -45,14 +79,16 @@ def concat_original_source(
Returns:
list[str]: List of memory and concat orginal memory.
"""
if merge_field is None:
merge_field = ["sources"]
merge_field = []
merge_field = ["sources"] if rerank_source is None else rerank_source.split(",")
documents = []
for item in graph_results:
memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
sources = []
for field in merge_field:
source = getattr(item.metadata, field, "")
source = getattr(item.metadata, field, None)
if source is None:
continue
sources.append((memory, source))
concat_string = process_source(sources)
documents.append(concat_string)
Expand Down
6 changes: 5 additions & 1 deletion src/memos/reranker/cosine_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from typing import TYPE_CHECKING

from memos.log import get_logger

from .base import BaseReranker


Expand All @@ -16,6 +18,8 @@
except Exception:
_HAS_NUMPY = False

logger = get_logger(__name__)


def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]:
"""
Expand Down Expand Up @@ -92,5 +96,5 @@ def get_weight(it: TextualMemoryItem) -> float:
chosen = {it.id for it, _ in top_items}
remain = [(it, -1.0) for it in graph_results if it.id not in chosen]
top_items.extend(remain[: top_k - len(top_items)])

logger.info(f"CosineLocalReranker rerank result: {top_items[:1]}")
return top_items
11 changes: 11 additions & 0 deletions src/memos/reranker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .cosine_local import CosineLocalReranker
from .http_bge import HTTPBGEReranker
from .http_bge_strategy import HTTPBGERerankerStrategy
from .noop import NoopReranker


Expand Down Expand Up @@ -45,4 +46,14 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
if backend in {"noop", "none", "disabled"}:
return NoopReranker()

if backend in {"http_bge_strategy", "bge_strategy"}:
return HTTPBGERerankerStrategy(
reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"),
model=c.get("model", "bge-reranker-v2-m3"),
timeout=int(c.get("timeout", 10)),
headers_extra=c.get("headers_extra"),
rerank_source=c.get("rerank_source"),
reranker_strategy=c.get("reranker_strategy"),
)

raise ValueError(f"Unknown reranker backend: {cfg.backend}")
8 changes: 4 additions & 4 deletions src/memos/reranker/http_bge.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
model: str = "bge-reranker-v2-m3",
timeout: int = 10,
headers_extra: dict | None = None,
rerank_source: list[str] | None = None,
rerank_source: str | None = None,
boost_weights: dict[str, float] | None = None,
boost_default: float = 0.0,
warn_unknown_filter_keys: bool = True,
Expand All @@ -107,7 +107,7 @@ def __init__(
self.model = model
self.timeout = timeout
self.headers_extra = headers_extra or {}
self.concat_source = rerank_source
self.rerank_source = rerank_source

self.boost_weights = (
DEFAULT_BOOST_WEIGHTS.copy()
Expand Down Expand Up @@ -152,8 +152,8 @@ def rerank(
# Build a mapping from "payload docs index" -> "original graph_results index"
# Only include items that have a non-empty string memory. This ensures that
# any index returned by the server can be mapped back correctly.
if self.concat_source:
documents = concat_original_source(graph_results, self.concat_source)
if self.rerank_source:
documents = concat_original_source(graph_results, self.rerank_source)
else:
documents = [
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
Expand Down
Loading