From dd7828ad090ecc29bd3141417b5df86d7f307a24 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 19 Aug 2025 15:45:31 -0700 Subject: [PATCH] always add indices to the sample --- viscy/data/triplet.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 72828af42..e4167d117 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -309,17 +309,16 @@ def __getitems__(self, indices: list[int]) -> list[TripletSample]: if self.return_negative: for sample, negative_patch in zip(samples, negative_patches): sample["negative"] = negative_patch - else: - for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()): - # For new predictions, ensure all INDEX_COLUMNS are included - index_dict = {} - for col in INDEX_COLUMNS: - if col in anchor_row.index: - index_dict[col] = anchor_row[col] - elif col not in ["y", "x", "z"]: - # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS - raise KeyError(f"Required column '{col}' not found in data") - sample["index"] = index_dict + for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()): + # For new predictions, ensure all INDEX_COLUMNS are included + index_dict = {} + for col in INDEX_COLUMNS: + if col in anchor_row.index: + index_dict[col] = anchor_row[col] + elif col not in ["y", "x", "z"]: + # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS + raise KeyError(f"Required column '{col}' not found in data") + sample["index"] = index_dict return samples