Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/memos/memories/textual/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,8 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
dialog_id: str | None = Field(default=None, description="ID of the dialog.")
original_text: str | None = Field(default=None, description="String of the dialog.")
embedding: list[float] | None = Field(default=None, description="Vector of the dialog.")
explicit_preference: str | None = Field(default=None, description="Explicit preference.")
preference: str | None = Field(default=None, description="Preference.")
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
implicit_preference: str | None = Field(default=None, description="Implicit preference.")


class TextualMemoryItem(BaseModel):
Expand Down
59 changes: 28 additions & 31 deletions src/memos/memories/textual/prefer_text_memory/adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def _update_memory_op_trace(
new_memories: list[TextualMemoryItem],
retrieved_memories: list[MilvusVecDBItem],
collection_name: str,
preference_type: str,
) -> list[str] | str:
# create new vec db items
new_vec_db_items: list[MilvusVecDBItem] = []
Expand All @@ -124,17 +123,19 @@ def _update_memory_op_trace(
{
"id": new_memory.id,
"context_summary": new_memory.memory,
"preference": new_memory.payload[preference_type],
"preference": new_memory.payload["preference"],
}
for new_memory in new_vec_db_items
if new_memory.payload.get("preference", None)
]
retrieved_mem_inputs = [
{
"id": mem.id,
"context_summary": mem.memory,
"preference": mem.payload[preference_type],
"preference": mem.payload["preference"],
}
for mem in retrieved_memories
if mem.payload.get("preference", None)
]

rsp = self._judge_update_or_add_trace_op(
Expand Down Expand Up @@ -168,7 +169,7 @@ def execute_op(
elif op_type == "update":
if op["target_id"] in retrieved_mem_db_item_map:
update_mem_db_item = retrieved_mem_db_item_map[op["target_id"]]
update_mem_db_item.payload[preference_type] = op["new_preference"]
update_mem_db_item.payload["preference"] = op["new_preference"]
update_mem_db_item.payload["updated_at"] = datetime.now().isoformat()
update_mem_db_item.memory = op["new_context_summary"]
update_mem_db_item.original_text = op["new_context_summary"]
Expand Down Expand Up @@ -198,7 +199,6 @@ def _update_memory_fine(
new_memory: TextualMemoryItem,
retrieved_memories: list[MilvusVecDBItem],
collection_name: str,
preference_type: str,
) -> str:
payload = new_memory.to_dict()["metadata"]
fields_to_remove = {"dialog_id", "original_text", "embedding"}
Expand All @@ -211,19 +211,15 @@ def _update_memory_fine(
payload=payload,
)

new_mem_input = {
"memory": new_memory.memory,
"preference": new_memory.metadata.explicit_preference
if preference_type == "explicit_preference"
else new_memory.metadata.implicit_preference,
}
new_mem_input = {"memory": new_memory.memory, "preference": new_memory.metadata.preference}
retrieved_mem_inputs = [
{
"id": mem.id,
"memory": mem.memory,
"preference": mem.payload[preference_type],
"preference": mem.payload["preference"],
}
for mem in retrieved_memories
if mem.payload.get("preference", None)
]
rsp = self._judge_update_or_add_fine(
new_mem=json.dumps(new_mem_input),
Expand All @@ -240,7 +236,7 @@ def _update_memory_fine(
)
if need_update and update_item and rsp:
update_vec_db_item = update_item[0]
update_vec_db_item.payload[preference_type] = rsp["new_preference"]
update_vec_db_item.payload["preference"] = rsp["new_preference"]
update_vec_db_item.payload["updated_at"] = vec_db_item.payload["updated_at"]
update_vec_db_item.memory = rsp["new_memory"]
update_vec_db_item.original_text = vec_db_item.original_text
Expand Down Expand Up @@ -287,23 +283,19 @@ def _update_memory(
new_memory: TextualMemoryItem,
retrieved_memories: list[MilvusVecDBItem],
collection_name: str,
preference_type: str,
update_mode: str = "fast",
) -> list[str] | str | None:
"""Update the memory.
Args:
new_memory: TextualMemoryItem
retrieved_memories: list[MilvusVecDBItem]
collection_name: str
preference_type: str
update_mode: str, "fast" or "fine"
"""
if update_mode == "fast":
return self._update_memory_fast(new_memory, retrieved_memories, collection_name)
elif update_mode == "fine":
return self._update_memory_fine(
new_memory, retrieved_memories, collection_name, preference_type
)
return self._update_memory_fine(new_memory, retrieved_memories, collection_name)
else:
raise ValueError(f"Invalid update mode: {update_mode}")

Expand All @@ -330,7 +322,6 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str |
memory,
search_results,
collection_name,
preference_type,
update_mode=os.getenv("PREFERENCE_ADDER_MODE", "fast"),
)

Expand Down Expand Up @@ -369,18 +360,24 @@ def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwarg
explicit_recalls = list({recall.id: recall for recall in explicit_recalls}.values())
implicit_recalls = list({recall.id: recall for recall in implicit_recalls}.values())

explicit_added_ids = self._update_memory_op_trace(
explicit_new_mems,
explicit_recalls,
pref_type_collection_map["explicit_preference"],
"explicit_preference",
)
implicit_added_ids = self._update_memory_op_trace(
implicit_new_mems,
implicit_recalls,
pref_type_collection_map["implicit_preference"],
"implicit_preference",
)
# 使用线程池并行处理显式和隐式偏好
with ContextThreadPoolExecutor(max_workers=2) as executor:
explicit_future = executor.submit(
self._update_memory_op_trace,
explicit_new_mems,
explicit_recalls,
pref_type_collection_map["explicit_preference"],
)
implicit_future = executor.submit(
self._update_memory_op_trace,
implicit_new_mems,
implicit_recalls,
pref_type_collection_map["implicit_preference"],
)

explicit_added_ids = explicit_future.result()
implicit_added_ids = implicit_future.result()

return explicit_added_ids + implicit_added_ids

def process_memory_single(
Expand Down
3 changes: 3 additions & 0 deletions src/memos/memories/textual/prefer_text_memory/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
response = response.strip().replace("```json", "").replace("```", "").strip()
result = json.loads(response)
for d in result:
d["preference"] = d.pop("explicit_preference")
return result
except Exception as e:
logger.error(f"Error extracting explicit preference: {e}, return None")
Expand All @@ -88,6 +90,7 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
response = response.strip().replace("```json", "").replace("```", "").strip()
result = json.loads(response)
result["preference"] = result.pop("implicit_preference")
return result
except Exception as e:
logger.error(f"Error extracting implicit preferences: {e}, return None")
Expand Down
4 changes: 2 additions & 2 deletions src/memos/memories/textual/prefer_text_memory/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def retrieve(
metadata=PreferenceTextualMemoryMetadata(**pref.payload),
)
for pref in explicit_prefs
if pref.payload["explicit_preference"]
if pref.payload.get("preference", None)
]

implicit_prefs_mem = [
Expand All @@ -116,7 +116,7 @@ def retrieve(
metadata=PreferenceTextualMemoryMetadata(**pref.payload),
)
for pref in implicit_prefs
if pref.payload["implicit_preference"]
if pref.payload.get("preference", None)
]

reranker_map = {
Expand Down
6 changes: 2 additions & 4 deletions src/memos/memories/textual/prefer_text_memory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ def deduplicate_preferences(

for i, pref in enumerate(prefs):
# Extract preference text
if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference:
text = pref.metadata.implicit_preference
elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference:
text = pref.metadata.explicit_preference
if hasattr(pref.metadata, "preference") and pref.metadata.preference:
text = pref.metadata.preference
else:
text = pref.memory

Expand Down
11 changes: 5 additions & 6 deletions src/memos/templates/instruction_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ def instruct_completion(
implicit_pref = []
for memory in memories:
pref_type = memory.get("metadata", {}).get("preference_type")
pref = memory.get("metadata", {}).get("preference", None)
if not pref:
continue
if pref_type == "explicit_preference":
pref = memory.get("metadata", {}).get("explicit_preference", None)
if pref:
explicit_pref.append(pref)
explicit_pref.append(pref)
elif pref_type == "implicit_preference":
pref = memory.get("metadata", {}).get("implicit_preference", None)
if pref:
implicit_pref.append(pref)
implicit_pref.append(pref)

explicit_pref_str = (
"Explicit Preference:\n"
Expand Down
Loading