From 15c6daca5891a38406ad84bb3661830d5a36a4ea Mon Sep 17 00:00:00 2001 From: "Jonathan B. Coe" Date: Sun, 8 Mar 2026 23:38:41 +0000 Subject: [PATCH] Ensure deterministic data loading This commit resolves a non-determinism issue in `TtyrecDataset` where the order of loaded games could vary between runs. The previous implementation used a single data loading function with a shared lock (`threading.Lock`) for all workers. This created a race condition where multiple threads competed to fetch the next game, resulting in an unpredictable data order. The fix introduces a deterministic assignment of games to each worker: - A new `_make_load_fns` method creates a separate, dedicated data loading function for each batch dimension. - Each function is assigned a unique, non-overlapping sequence of games by striding over the global `gameids` list. This eliminates the need for a shared lock and guarantees a deterministic loading order. The corresponding tests in `test_minibatches` have also been updated to validate this new, deterministic behavior. --- nle/dataset/dataset.py | 41 ++++++++++++++++++++++++++------------- nle/tests/test_dataset.py | 20 ++++++++++++------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/nle/dataset/dataset.py b/nle/dataset/dataset.py index cb54206db..665967609 100644 --- a/nle/dataset/dataset.py +++ b/nle/dataset/dataset.py @@ -1,6 +1,5 @@ import os import sqlite3 -import threading from collections import defaultdict from functools import partial @@ -74,11 +73,11 @@ def convert_frames( def _ttyrec_generator( - batch_size, seq_length, rows, cols, load_fn, map_fn, ttyrec_version + batch_size, seq_length, rows, cols, load_fns, map_fn, ttyrec_version ): """A generator to fill minibatches with ttyrecs. - :param load_fn: a function to load the next ttyrec into a converter. + :param load_fns: a list of functions to load the next ttyrec into a converter. load_fn(ttyrecs.Converter conv) -> bool is_success :param map_fn: a function that maps a series of iterables through a fn. map_fn(fn, *iterables) -> (can use built-in map) @@ -110,17 +109,18 @@ def _ttyrec_generator( converters = [ converter.Converter(rows, cols, ttyrec_version) for _ in range(batch_size) ] - assert all(load_fn(c) for c in converters), "Not enough ttyrecs to fill a batch!" + assert all( + fn(c) for fn, c in zip(load_fns, converters) + ), "Not enough ttyrecs to fill a batch!" # Convert (at least one minibatch) - _convert_frames = partial(convert_frames, load_fn=load_fn) gameids[0, -1] = 1 # basically creating a "do-while" loop by setting an indicator while np.any( gameids[:, -1] != 0 ): # loop until only padding is found, i.e. end of data list( map_fn( - _convert_frames, + convert_frames, converters, chars, colors, @@ -130,6 +130,7 @@ def _ttyrec_generator( scores, resets, gameids, + load_fns, ) ) @@ -287,9 +288,17 @@ def populate_metadata(self): self._meta[row[0]].append(row) self._meta_cols = [desc[0] for desc in c.description] - def _make_load_fn(self, gameids): + def _make_load_fns(self, gameids, batch_size): + """Create one closure per batch dimension.""" + load_fns = [] + for i in range(batch_size): + # Deterministically select a subset of games for this dimension. + my_gameids = gameids[i::batch_size] + load_fns.append(self._make_one_load_fn(my_gameids)) + return load_fns + + def _make_one_load_fn(self, gameids): """Make a closure to load the next gameid from the db into the converter.""" - lock = threading.Lock() count = [0] def _load_fn(converter): @@ -300,9 +309,8 @@ def _load_fn(converter): files = self.get_paths(gameid) if gameid == 0 or part >= len(files): - with lock: - i = count[0] - count[0] += 1 + i = count[0] + count[0] += 1 if (not self.loop_forever) and i >= len(gameids): return False @@ -323,12 +331,14 @@ def __iter__(self): if self.shuffle: np.random.shuffle(gameids) + load_fns = self._make_load_fns(gameids, self.batch_size) + return _ttyrec_generator( self.batch_size, self.seq_length, self.rows, self.cols, - self._make_load_fn(gameids), + load_fns, self._map, self._ttyrec_version, ) @@ -336,13 +346,16 @@ def __iter__(self): def get_ttyrecs(self, gameids, chunk_size=None): """Fetch data from a single episode, chunked into a sequence of tensors.""" seq_length = chunk_size or self.seq_length + batch_size = len(gameids) + load_fns = self._make_load_fns(gameids, batch_size) + mbs = [] for mb in _ttyrec_generator( - len(gameids), + batch_size, seq_length, self.rows, self.cols, - self._make_load_fn(gameids), + load_fns, self._map, self._ttyrec_version, ): diff --git a/nle/tests/test_dataset.py b/nle/tests/test_dataset.py index fbd183458..fd66015f2 100644 --- a/nle/tests/test_dataset.py +++ b/nle/tests/test_dataset.py @@ -61,7 +61,8 @@ def test_minibatches(self, db_exists, pool): seq_length=50, batch_size=4, threadpool=pool, - gameids=range(1, 8), + # Repeated gameids ensure enough data for batches and resets + gameids=[1, 4, 2, 5, 3, 6, 7, 1], shuffle=False, ) # starting gameids = [TTYREC, TTYREC, TTYREC, TTYREC2] @@ -70,23 +71,28 @@ def test_minibatches(self, db_exists, pool): for name, array in mb.items(): if name in ("gameids",): continue - # Test first three rows are the same, and differ from from fourth - np.testing.assert_array_equal(array[0], array[1]) - np.testing.assert_array_equal(array[0], array[2]) + # Check data for different batch dimensions are different np.testing.assert_raises( - AssertionError, np.testing.assert_array_equal, array[0], array[3] + AssertionError, np.testing.assert_array_equal, array[0], array[1] ) # Check reseting occured reset = np.where(mb["done"][3] == 1)[0][0] assert reset == 31 - # Check the data at location is the same. Note reset occurs for batch 4 + # Check that the data after the reset is from a new game. + # With deterministic loading, the worker for this batch dimension will switch + # to a new game after the reset, so the data should be different. seq = 10 for name, array in mb.items(): if name in ("done", "gameids"): continue - np.testing.assert_array_equal(array[3][:seq], array[3][reset : reset + seq]) + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + array[3][:seq], + array[3][reset : reset + seq], + ) # No leading 1s assert (mb["done"][:, 0] == 0).all()