diff --git a/src/memory_bench/dataset/locomo.py b/src/memory_bench/dataset/locomo.py index 535985d..a2403a3 100644 --- a/src/memory_bench/dataset/locomo.py +++ b/src/memory_bench/dataset/locomo.py @@ -284,6 +284,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: data = self._load_raw() documents: list[Document] = [] @@ -295,6 +296,8 @@ def load_documents( sample_id = item["sample_id"] if conv_filter is not None and sample_id != conv_filter: continue + if user_ids is not None and sample_id not in user_ids: + continue conv = item["conversation"] speaker_a = conv.get("speaker_a", "A") speaker_b = conv.get("speaker_b", "B") diff --git a/src/memory_bench/dataset/membench.py b/src/memory_bench/dataset/membench.py index cc1c769..6021a92 100644 --- a/src/memory_bench/dataset/membench.py +++ b/src/memory_bench/dataset/membench.py @@ -145,6 +145,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: trajectories = self._load_trajectories(split) documents: list[Document] = [] diff --git a/src/memory_bench/dataset/memsim.py b/src/memory_bench/dataset/memsim.py index 441d200..094ae35 100644 --- a/src/memory_bench/dataset/memsim.py +++ b/src/memory_bench/dataset/memsim.py @@ -139,6 +139,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: trajectories = self._load_trajectories(split) documents: list[Document] = [] diff --git a/src/memory_bench/dataset/personamem.py b/src/memory_bench/dataset/personamem.py index c81a41e..3ca2a5c 100644 --- a/src/memory_bench/dataset/personamem.py +++ b/src/memory_bench/dataset/personamem.py @@ -284,6 +284,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: sessions_by_ctx = self._load_sessions(split) documents: list[Document] = []