Skip to content
Merged
38 changes: 24 additions & 14 deletions src/memos/multi_mem_cube/composite_cube.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -46,21 +47,30 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
"tool_mem": [],
}

for view in self.cube_views:
def _search_single_cube(view: SingleCubeView) -> dict[str, Any]:
self.logger.info(f"[CompositeCubeView] fan-out search to cube={view.cube_id}")
cube_result = view.search_memories(search_req)
merged_results["text_mem"].extend(cube_result.get("text_mem", []))
merged_results["act_mem"].extend(cube_result.get("act_mem", []))
merged_results["para_mem"].extend(cube_result.get("para_mem", []))
merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))

note = cube_result.get("pref_note")
if note:
if merged_results["pref_note"]:
merged_results["pref_note"] += " | " + note
else:
merged_results["pref_note"] = note
return view.search_memories(search_req)

# parallel search for each cube
with ThreadPoolExecutor(max_workers=2) as executor:
future_to_view = {
executor.submit(_search_single_cube, view): view for view in self.cube_views
}

for future in as_completed(future_to_view):
cube_result = future.result()
merged_results["text_mem"].extend(cube_result.get("text_mem", []))
merged_results["act_mem"].extend(cube_result.get("act_mem", []))
merged_results["para_mem"].extend(cube_result.get("para_mem", []))
merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))

note = cube_result.get("pref_note")
if note:
if merged_results["pref_note"]:
merged_results["pref_note"] += " | " + note
else:
merged_results["pref_note"] = note

return merged_results

Expand Down