Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,31 @@ def get_embedder_config() -> dict[str, Any]:
},
}

@staticmethod
def get_reader_config() -> dict[str, Any]:
"""Get reader configuration."""
return {
"backend": os.getenv("MEM_READER_BACKEND", "simple_struct"),
"config": {
"chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"),
"chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)),
"chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)),
"chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)),
},
}

@staticmethod
def get_internet_config() -> dict[str, Any]:
"""Get embedder configuration."""
reader_config = APIConfig.get_reader_config()
return {
"backend": "bocha",
"config": {
"api_key": os.getenv("BOCHA_API_KEY"),
"max_results": 15,
"num_per_request": 10,
"reader": {
"backend": "simple_struct",
"backend": reader_config["backend"],
"config": {
"llm": {
"backend": "openai",
Expand All @@ -215,6 +229,7 @@ def get_internet_config() -> dict[str, Any]:
"min_sentences_per_chunk": 1,
},
},
"chat_chunker": reader_config,
},
},
},
Expand Down Expand Up @@ -416,6 +431,8 @@ def get_product_default_config() -> dict[str, Any]:
openai_config = APIConfig.get_openai_config()
qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
reader_config = APIConfig.get_reader_config()

backend_model = {
"openai": openai_config,
"huggingface": qwen_config,
Expand All @@ -427,7 +444,7 @@ def get_product_default_config() -> dict[str, Any]:
"user_id": os.getenv("MOS_USER_ID", "root"),
"chat_model": {"backend": backend, "config": backend_model[backend]},
"mem_reader": {
"backend": "simple_struct",
"backend": reader_config["backend"],
"config": {
"llm": APIConfig.get_memreader_config(),
"embedder": APIConfig.get_embedder_config(),
Expand All @@ -440,6 +457,7 @@ def get_product_default_config() -> dict[str, Any]:
"min_sentences_per_chunk": 1,
},
},
"chat_chunker": reader_config,
},
},
"enable_textual_memory": True,
Expand Down Expand Up @@ -510,6 +528,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
mysql_config = APIConfig.get_mysql_config()
reader_config = APIConfig.get_reader_config()
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
backend_model = {
"openai": openai_config,
Expand All @@ -524,7 +543,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"config": backend_model[backend],
},
"mem_reader": {
"backend": "simple_struct",
"backend": reader_config["backend"],
"config": {
"llm": APIConfig.get_memreader_config(),
"embedder": APIConfig.get_embedder_config(),
Expand All @@ -537,6 +556,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"min_sentences_per_chunk": 1,
},
},
"chat_chunker": reader_config,
},
},
"enable_textual_memory": True,
Expand Down Expand Up @@ -605,6 +625,10 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
"search_strategy": {
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
},
},
},
"act_mem": {}
Expand Down Expand Up @@ -672,6 +696,10 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
"search_strategy": {
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
},
"mode": os.getenv("ASYNC_MODE", "sync"),
},
},
Expand Down
9 changes: 9 additions & 0 deletions src/memos/configs/mem_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@ def parse_datetime(cls, value):
description="whether remove example in memory extraction prompt to save token",
)

chat_chunker: dict[str, Any] = Field(
..., description="Configuration for the MemReader chat chunk strategy"
)


class SimpleStructMemReaderConfig(BaseMemReaderConfig):
    """Configuration for the ``simple_struct`` MemReader backend; adds no fields beyond BaseMemReaderConfig."""


class StrategyStructMemReaderConfig(BaseMemReaderConfig):
    """Configuration for the ``strategy_struct`` MemReader backend; adds no fields beyond BaseMemReaderConfig."""


class MemReaderConfigFactory(BaseConfig):
"""Factory class for creating MemReader configurations."""

Expand All @@ -49,6 +57,7 @@ class MemReaderConfigFactory(BaseConfig):

backend_to_class: ClassVar[dict[str, Any]] = {
"simple_struct": SimpleStructMemReaderConfig,
"strategy_struct": StrategyStructMemReaderConfig,
}

@field_validator("backend")
Expand Down
7 changes: 7 additions & 0 deletions src/memos/configs/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
),
)

search_strategy: dict[str, bool] | None = Field(
default=None,
description=(
'Set search strategy for this memory configuration.{"bm25": true, "cot": false}'
),
)

mode: str | None = Field(
default="sync",
description=("whether use asynchronous mode in memory add"),
Expand Down
21 changes: 21 additions & 0 deletions src/memos/llms/openai.py
Comment thread
CaralHsi marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ def generate(self, messages: MessageList) -> str:
else:
return response_content

def customized_generate(self, messages: MessageList, **kwargs) -> str:
    """Generate a chat completion from the OpenAI LLM with optional overrides.

    Recognized kwargs — ``temperature``, ``max_tokens``, ``top_p`` — take
    precedence over the corresponding values in ``self.config``.
    """
    sampling = {
        "temperature": kwargs.get("temperature", self.config.temperature),
        "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
        "top_p": kwargs.get("top_p", self.config.top_p),
    }

    response = self.client.chat.completions.create(
        model=self.config.model_name_or_path,
        messages=messages,
        extra_body=self.config.extra_body,
        **sampling,
    )
    logger.info(f"Response from OpenAI: {response.model_dump_json()}")

    content = response.choices[0].message.content
    if self.config.remove_think_prefix:
        return remove_thinking_tags(content)
    return content

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe merge two functions like this:

def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
temperature = kwargs.get("temperature", self.config.temperature)
max_tokens = kwargs.get("max_tokens", self.config.max_tokens)
top_p = kwargs.get("top_p", self.config.top_p)

response = self.client.chat.completions.create(
model=self.config.model_name_or_path,
messages=messages,
extra_body=self.config.extra_body,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)

logger.info(f"Response from OpenAI: {response.model_dump_json()}")

response_content = response.choices[0].message.content
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
return response_content

@timed(log=True, log_prefix="OpenAI LLM")
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
"""Stream response from OpenAI LLM with optional reasoning support."""
Expand Down
2 changes: 2 additions & 0 deletions src/memos/mem_reader/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from memos.configs.mem_reader import MemReaderConfigFactory
from memos.mem_reader.base import BaseMemReader
from memos.mem_reader.simple_struct import SimpleStructMemReader
from memos.mem_reader.strategy_struct import StrategyStructMemReader
from memos.memos_tools.singleton import singleton_factory


Expand All @@ -11,6 +12,7 @@ class MemReaderFactory(BaseMemReader):

backend_to_class: ClassVar[dict[str, Any]] = {
"simple_struct": SimpleStructMemReader,
"strategy_struct": StrategyStructMemReader,
}

@classmethod
Expand Down
135 changes: 135 additions & 0 deletions src/memos/mem_reader/strategy_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import os
import re

from abc import ABC

from memos import log
from memos.configs.mem_reader import StrategyStructMemReaderConfig
from memos.configs.parser import ParserConfigFactory
from memos.mem_reader.simple_struct import (
SimpleStructMemReader,
)
from memos.parsers.factory import ParserFactory
from memos.templates.mem_reader_prompts import (
SIMPLE_STRUCT_DOC_READER_PROMPT,
SIMPLE_STRUCT_DOC_READER_PROMPT_ZH,
SIMPLE_STRUCT_MEM_READER_EXAMPLE,
SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH,
)
from memos.templates.mem_reader_strategy_prompts import (
STRATEGY_STRUCT_MEM_READER_PROMPT,
STRATEGY_STRUCT_MEM_READER_PROMPT_ZH,
)


logger = log.get_logger(__name__)
# Prompt templates keyed by scene type ("chat" / "doc") and language ("en" / "zh").
# Chat uses the strategy-specific reader prompts with the simple-struct few-shot
# examples ("en_example" / "zh_example"); doc reuses the simple-struct doc prompts.
PROMPT_DICT = {
    "chat": {
        "en": STRATEGY_STRUCT_MEM_READER_PROMPT,
        "zh": STRATEGY_STRUCT_MEM_READER_PROMPT_ZH,
        "en_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE,
        "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH,
    },
    "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH},
}

try:
import tiktoken

try:
_ENC = tiktoken.encoding_for_model("gpt-4o-mini")
except Exception:
_ENC = tiktoken.get_encoding("cl100k_base")

def _count_tokens_text(s: str) -> int:
return len(_ENC.encode(s or ""))
except Exception:
# Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars
def _count_tokens_text(s: str) -> int:
if not s:
return 0
zh_chars = re.findall(r"[\u4e00-\u9fff]", s)
zh = len(zh_chars)
rest = len(s) - zh
return zh + max(1, rest // 4)


class StrategyStructMemReader(SimpleStructMemReader, ABC):
    """MemReader that chunks chat history using a configurable strategy.

    Extends SimpleStructMemReader; the ``chat_chunker`` section of the config
    selects how chat scene data is split into chunks before memory extraction.
    """

    def __init__(self, config: StrategyStructMemReaderConfig):
        super().__init__(config)
        # Chunking parameters: chunk_type, chunk_length, chunk_session, chunk_overlap.
        self.chat_chunker = config.chat_chunker["config"]

    def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:
        """
        Get raw information from scene_data.

        For chat data, group messages into chunks whose cumulative content
        length stays at or below the configured threshold, carrying the last
        message of each chunk over as overlap into the next. For documents,
        parse each existing file path with the markitdown parser; items that
        are not paths on disk are treated as raw text.

        Args:
            scene_data: List of dialogue information or document paths
            type: Type of scene data: ['doc', 'chat']
        Returns:
            List of strings containing the processed scene data
        """
        results = []

        if type == "chat":
            # NOTE(review): only the "content_length" strategy is implemented;
            # any other chunk_type currently yields an empty result for chat data
            # — confirm whether a fallback to the parent behavior is intended.
            if self.chat_chunker["chunk_type"] == "content_length":
                content_len_threshold = self.chat_chunker["chunk_length"]  # fixed typo: "thredshold"
                for items in scene_data:
                    if not items:
                        continue

                    results.append([])
                    current_length = 0

                    for item in items:
                        content_length = (
                            len(item.get("content", ""))
                            if isinstance(item, dict)
                            else len(str(item))
                        )
                        if not results[-1]:
                            # First message of a fresh chunk always goes in,
                            # even if it alone exceeds the threshold.
                            results[-1].append(item)
                            current_length = content_length
                            continue

                        if current_length + content_length <= content_len_threshold:
                            results[-1].append(item)
                            current_length += content_length
                        else:
                            # Threshold exceeded: start a new chunk, repeating the
                            # previous chunk's last message as overlap for context.
                            overlap_item = results[-1][-1]
                            overlap_length = (
                                len(overlap_item.get("content", ""))
                                if isinstance(overlap_item, dict)
                                else len(str(overlap_item))
                            )

                            results.append([overlap_item, item])
                            current_length = overlap_length + content_length
        elif type == "doc":
            parser_config = ParserConfigFactory.model_validate(
                {
                    "backend": "markitdown",
                    "config": {},
                }
            )
            parser = ParserFactory.from_config(parser_config)
            for item in scene_data:
                try:
                    if os.path.exists(item):
                        try:
                            parsed_text = parser.parse(item)
                            results.append({"file": item, "text": parsed_text})
                        except Exception as e:
                            logger.error(f"[SceneParser] Error parsing {item}: {e}")
                            continue
                    else:
                        # Not a path on disk: treat the item itself as raw text.
                        parsed_text = item
                        results.append({"file": "pure_text", "text": parsed_text})
                except Exception as e:
                    # Fixed: was a bare print(); log through the module logger
                    # like the inner handler above.
                    logger.error(f"Error parsing file {item}: {e!s}")

        return results
18 changes: 18 additions & 0 deletions src/memos/memories/textual/simple_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree import TreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.reranker.base import BaseReranker
from memos.types import MessageList
Expand Down Expand Up @@ -62,6 +63,19 @@ def __init__(
self.graph_store: Neo4jGraphDB = graph_db
logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}")

time_start_bm = time.time()
self.search_strategy = config.search_strategy
self.bm25_retriever = (
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
)
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")

self.vec_cot = (
self.search_strategy["cot"]
if self.search_strategy and "cot" in self.search_strategy
else False
)

time_start_rr = time.time()
self.reranker = reranker
logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
Expand Down Expand Up @@ -144,17 +158,21 @@ def search(
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=None,
moscube=moscube,
vec_cot=self.vec_cot,
)
else:
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
moscube=moscube,
vec_cot=self.vec_cot,
)
return searcher.search(
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
Expand Down
Loading
Loading