Skip to content
2 changes: 1 addition & 1 deletion evaluation/scripts/PrefEval/pref_memos.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def search_memory_for_line(line_data, mem_client, top_k_value):
f"- {entry.get('memory', '')}"
for entry in relevant_memories["text_mem"][0]["memories"]
)
+ f"\n{relevant_memories['pref_string']}"
+ f"\n{relevant_memories.get('pref_string', '')}"
)

memory_tokens_used = len(tokenizer.encode(memories_str))
Expand Down
4 changes: 2 additions & 2 deletions evaluation/scripts/locomo/locomo_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def memos_api_search(

speaker_a_context = (
"\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]])
+ f"\n{search_a_results['pref_string']}"
+ f"\n{search_a_results.get('pref_string', '')}"
)
speaker_b_context = (
"\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]])
+ f"\n{search_b_results['pref_string']}"
+ f"\n{search_b_results.get('pref_string', '')}"
)

context = TEMPLATE_MEMOS.format(
Expand Down
2 changes: 1 addition & 1 deletion evaluation/scripts/longmemeval/lme_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def memos_search(client, query, user_id, top_k):
results = client.search(query=query, user_id=user_id, top_k=top_k)
context = (
"\n".join([i["memory"] for i in results["text_mem"][0]["memories"]])
+ f"\n{results['pref_string']}"
+ f"\n{results.get('pref_string', '')}"
)
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context)
duration_ms = (time() - start) * 1000
Expand Down
4 changes: 2 additions & 2 deletions evaluation/scripts/personamem/pm_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def memos_search(client, user_id, query, top_k):
start = time()
results = client.search(query=query, user_id=user_id, top_k=top_k)
search_memories = (
"\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"])
+ f"\n{results['pref_string']}"
"\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"])
+ f"\n{results.get('pref_string', '')}"
)
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories)

Expand Down
10 changes: 6 additions & 4 deletions evaluation/scripts/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def search(self, query, user_id, top_k):
"conversation_id": "",
"top_k": top_k,
"mode": os.getenv("SEARCH_MODE", "fast"),
"handle_pref_mem": False,
"include_preference": True,
"pref_top_k": 6,
},
ensure_ascii=False,
)
Expand Down Expand Up @@ -344,9 +345,10 @@ def wait_for_completion(self, task_id):
query = "杭州西湖有什么"
top_k = 5

# MEMOBASE
client = MemobaseClient()
# MEMOS-API
client = MemosApiClient()
for m in messages:
m["created_at"] = iso_date
client.add(messages, user_id)
client.add(messages, user_id, user_id)
memories = client.search(query, user_id, top_k)
print(memories)
2 changes: 1 addition & 1 deletion src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def get_preference_memory_config() -> dict[str, Any]:
return {
"backend": "pref_text",
"config": {
"extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()},
"extractor_llm": APIConfig.get_memreader_config(),
"vector_db": {
"backend": "milvus",
"config": APIConfig.get_milvus_config(),
Expand Down
3 changes: 2 additions & 1 deletion src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ class APISearchRequest(BaseRequest):
operation: list[PermissionDict] | None = Field(
None, description="operation ids for multi cubes"
)
handle_pref_mem: bool = Field(False, description="Whether to handle preference memory")
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")


class APIADDRequest(BaseRequest):
Expand Down
16 changes: 10 additions & 6 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,18 @@ def _post_process_pref_mem(
memories_result: list[dict[str, Any]],
pref_formatted_mem: list[dict[str, Any]],
mem_cube_id: str,
handle_pref_mem: bool,
include_preference: bool,
):
if handle_pref_mem:
if include_preference:
memories_result["pref_mem"].append(
{
"cube_id": mem_cube_id,
"memories": pref_formatted_mem,
}
)
pref_instruction: str = instruct_completion(pref_formatted_mem)
pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
memories_result["pref_string"] = pref_instruction
memories_result["pref_note"] = pref_note

return memories_result

Expand All @@ -354,7 +355,7 @@ def search_memories(search_req: APISearchRequest):
"act_mem": [],
"para_mem": [],
"pref_mem": [],
"pref_string": "",
"pref_note": "",
}

search_mode = search_req.mode
Expand Down Expand Up @@ -382,7 +383,7 @@ def _search_pref():
return []
results = naive_mem_cube.pref_mem.search(
query=search_req.query,
top_k=search_req.top_k,
top_k=search_req.pref_top_k,
info={
"user_id": search_req.user_id,
"session_id": search_req.session_id,
Expand All @@ -405,7 +406,10 @@ def _search_pref():
)

memories_result = _post_process_pref_mem(
memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem
memories_result,
pref_formatted_memories,
search_req.mem_cube_id,
search_req.include_preference,
)

return SearchResponse(
Expand Down
16 changes: 11 additions & 5 deletions src/memos/templates/instruction_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def instruct_completion(
memories: list[dict[str, Any]] | None = None,
) -> str:
) -> [str, str]:
"""Create instruction following the preferences."""
explicit_pref = []
implicit_pref = []
Expand Down Expand Up @@ -49,10 +49,16 @@ def instruct_completion(
lang = detect_lang(explicit_pref_str + implicit_pref_str)

if not explicit_pref_str and not implicit_pref_str:
return ""
return "", ""
if not explicit_pref_str:
return implicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_exp_map[lang], "")
pref_note = _prompt_map[lang].replace(_remove_exp_map[lang], "")
pref_string = implicit_pref_str + "\n" + pref_note
return pref_string, pref_note
if not implicit_pref_str:
return explicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_imp_map[lang], "")
pref_note = _prompt_map[lang].replace(_remove_imp_map[lang], "")
pref_string = explicit_pref_str + "\n" + pref_note
return pref_string, pref_note

return explicit_pref_str + "\n" + implicit_pref_str + "\n" + _prompt_map[lang]
pref_note = _prompt_map[lang]
pref_string = explicit_pref_str + "\n" + implicit_pref_str + "\n" + pref_note
return pref_string, pref_note