From 989853cfde4e0f991d3b5c97509508c051b5edc4 Mon Sep 17 00:00:00 2001 From: ParisT97 Date: Tue, 12 May 2026 18:02:40 -0400 Subject: [PATCH] fix: add user_ids kwarg to LoCoMo/MemBench/MemSim/PersonaMem load_documents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The base class Dataset.load_documents() declares ``user_ids: set[str] | None = None`` as part of its abstract contract. The runner at ``src/memory_bench/runner.py:128`` always passes ``user_ids=query_user_ids`` when ``dataset.isolation_unit is not None and query_limit is not None``, but 4 of 7 concrete dataset overrides forgot to include the parameter in their signature: - LoComoDataset (isolation_unit='conversation') — actively fires ``TypeError: LoComoDataset.load_documents() got an unexpected keyword argument 'user_ids'`` on ``omb run --dataset locomo --memory bm25 --split locomo10 --query-limit N``. - MemBenchDataset, MemSimDataset, PersonaMemDataset (isolation_unit=None) — latent signature-contract violations that would surface if upstream ever sets their isolation_unit to a non-None value. This change: - LoComoDataset: add ``user_ids`` to the signature AND apply the filter ``if user_ids is not None and sample_id not in user_ids: continue`` — matches the reference pattern in BEAMDataset / LongMemEvalDataset / LifeBenchDataset (which all filter by their own per-doc id and were already correct). - MemBench / MemSim / PersonaMem: signature-only fixes; the runner never passes ``user_ids`` to them today (their isolation_unit is ``None``), so adding filter logic without a trigger path would be premature. The signature parity matches the base-class abstract method. Live-validated downstream by re-running ``omb run --dataset locomo --memory bm25 --split locomo10 --query-limit 3`` — the TypeError is gone; the run progresses to answer generation. 
--- src/memory_bench/dataset/locomo.py | 3 +++ src/memory_bench/dataset/membench.py | 1 + src/memory_bench/dataset/memsim.py | 1 + src/memory_bench/dataset/personamem.py | 1 + 4 files changed, 6 insertions(+) diff --git a/src/memory_bench/dataset/locomo.py b/src/memory_bench/dataset/locomo.py index 535985d..a2403a3 100644 --- a/src/memory_bench/dataset/locomo.py +++ b/src/memory_bench/dataset/locomo.py @@ -284,6 +284,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: data = self._load_raw() documents: list[Document] = [] @@ -295,6 +296,8 @@ def load_documents( sample_id = item["sample_id"] if conv_filter is not None and sample_id != conv_filter: continue + if user_ids is not None and sample_id not in user_ids: + continue conv = item["conversation"] speaker_a = conv.get("speaker_a", "A") speaker_b = conv.get("speaker_b", "B") diff --git a/src/memory_bench/dataset/membench.py b/src/memory_bench/dataset/membench.py index cc1c769..6021a92 100644 --- a/src/memory_bench/dataset/membench.py +++ b/src/memory_bench/dataset/membench.py @@ -145,6 +145,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: trajectories = self._load_trajectories(split) documents: list[Document] = [] diff --git a/src/memory_bench/dataset/memsim.py b/src/memory_bench/dataset/memsim.py index 441d200..094ae35 100644 --- a/src/memory_bench/dataset/memsim.py +++ b/src/memory_bench/dataset/memsim.py @@ -139,6 +139,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: trajectories = self._load_trajectories(split) documents: list[Document] = [] diff --git a/src/memory_bench/dataset/personamem.py b/src/memory_bench/dataset/personamem.py index c81a41e..3ca2a5c 100644 --- 
a/src/memory_bench/dataset/personamem.py +++ b/src/memory_bench/dataset/personamem.py @@ -284,6 +284,7 @@ def load_documents( category: str | None = None, limit: int | None = None, ids: set[str] | None = None, + user_ids: set[str] | None = None, ) -> list[Document]: sessions_by_ctx = self._load_sessions(split) documents: list[Document] = []