diff --git a/nle/dataset/dataset.py b/nle/dataset/dataset.py index cb54206db..665967609 100644 --- a/nle/dataset/dataset.py +++ b/nle/dataset/dataset.py @@ -1,6 +1,5 @@ import os import sqlite3 -import threading from collections import defaultdict from functools import partial @@ -74,11 +73,11 @@ def convert_frames( def _ttyrec_generator( - batch_size, seq_length, rows, cols, load_fn, map_fn, ttyrec_version + batch_size, seq_length, rows, cols, load_fns, map_fn, ttyrec_version ): """A generator to fill minibatches with ttyrecs. - :param load_fn: a function to load the next ttyrec into a converter. + :param load_fns: a list of functions to load the next ttyrec into a converter. load_fn(ttyrecs.Converter conv) -> bool is_success :param map_fn: a function that maps a series of iterables through a fn. map_fn(fn, *iterables) -> (can use built-in map) @@ -110,17 +109,18 @@ def _ttyrec_generator( converters = [ converter.Converter(rows, cols, ttyrec_version) for _ in range(batch_size) ] - assert all(load_fn(c) for c in converters), "Not enough ttyrecs to fill a batch!" + assert all( + fn(c) for fn, c in zip(load_fns, converters) + ), "Not enough ttyrecs to fill a batch!" # Convert (at least one minibatch) - _convert_frames = partial(convert_frames, load_fn=load_fn) gameids[0, -1] = 1 # basically creating a "do-while" loop by setting an indicator while np.any( gameids[:, -1] != 0 ): # loop until only padding is found, i.e. end of data list( map_fn( - _convert_frames, + convert_frames, converters, chars, colors, @@ -130,6 +130,7 @@ def _ttyrec_generator( scores, resets, gameids, + load_fns, ) ) @@ -287,9 +288,17 @@ def populate_metadata(self): self._meta[row[0]].append(row) self._meta_cols = [desc[0] for desc in c.description] - def _make_load_fn(self, gameids): + def _make_load_fns(self, gameids, batch_size): + """Create one closure per batch dimension.""" + load_fns = [] + for i in range(batch_size): + # Deterministically select a subset of games for this dimension. + my_gameids = gameids[i::batch_size] + load_fns.append(self._make_one_load_fn(my_gameids)) + return load_fns + + def _make_one_load_fn(self, gameids): """Make a closure to load the next gameid from the db into the converter.""" - lock = threading.Lock() count = [0] def _load_fn(converter): @@ -300,9 +309,8 @@ def _load_fn(converter): files = self.get_paths(gameid) if gameid == 0 or part >= len(files): - with lock: - i = count[0] - count[0] += 1 + i = count[0] + count[0] += 1 if (not self.loop_forever) and i >= len(gameids): return False @@ -323,12 +331,14 @@ def __iter__(self): if self.shuffle: np.random.shuffle(gameids) + load_fns = self._make_load_fns(gameids, self.batch_size) + return _ttyrec_generator( self.batch_size, self.seq_length, self.rows, self.cols, - self._make_load_fn(gameids), + load_fns, self._map, self._ttyrec_version, ) @@ -336,13 +346,16 @@ def __iter__(self): def get_ttyrecs(self, gameids, chunk_size=None): """Fetch data from a single episode, chunked into a sequence of tensors.""" seq_length = chunk_size or self.seq_length + batch_size = len(gameids) + load_fns = self._make_load_fns(gameids, batch_size) + mbs = [] for mb in _ttyrec_generator( - len(gameids), + batch_size, seq_length, self.rows, self.cols, - self._make_load_fn(gameids), + load_fns, self._map, self._ttyrec_version, ): diff --git a/nle/tests/test_dataset.py b/nle/tests/test_dataset.py index fbd183458..fd66015f2 100644 --- a/nle/tests/test_dataset.py +++ b/nle/tests/test_dataset.py @@ -61,7 +61,8 @@ def test_minibatches(self, db_exists, pool): seq_length=50, batch_size=4, threadpool=pool, - gameids=range(1, 8), + # Repeated gameids ensure enough data for batches and resets + gameids=[1, 4, 2, 5, 3, 6, 7, 1], shuffle=False, ) # starting gameids = [TTYREC, TTYREC, TTYREC, TTYREC2] @@ -70,23 +71,28 @@ def test_minibatches(self, db_exists, pool): for name, array in mb.items(): if name in ("gameids",): continue - # Test first three rows are the same, and differ from from fourth - np.testing.assert_array_equal(array[0], array[1]) - np.testing.assert_array_equal(array[0], array[2]) + # Check data for different batch dimensions are different np.testing.assert_raises( - AssertionError, np.testing.assert_array_equal, array[0], array[3] + AssertionError, np.testing.assert_array_equal, array[0], array[1] ) # Check reseting occured reset = np.where(mb["done"][3] == 1)[0][0] assert reset == 31 - # Check the data at location is the same. Note reset occurs for batch 4 + # Check that the data after the reset is from a new game. + # With deterministic loading, the worker for this batch dimension will switch + # to a new game after the reset, so the data should be different. seq = 10 for name, array in mb.items(): if name in ("done", "gameids"): continue - np.testing.assert_array_equal(array[3][:seq], array[3][reset : reset + seq]) + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + array[3][:seq], + array[3][reset : reset + seq], + ) # No leading 1s assert (mb["done"][:, 0] == 0).all() diff --git a/nle/tests/test_db.py b/nle/tests/test_db.py index 770cada33..36c769863 100644 --- a/nle/tests/test_db.py +++ b/nle/tests/test_db.py @@ -133,6 +133,7 @@ def conn(mockdata): yield conn +@pytest.mark.usefixtures("mockdata") class TestDB: def test_conn(self, conn): assert conn diff --git a/pyproject.toml b/pyproject.toml index d1ff3031e..091654575 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,15 +58,16 @@ dev = ["nle[dev]"] [project.optional-dependencies] agent = ["torch>=1.3.1"] dev = [ - "pre-commit>=2.0.1", "cmake_format>=0.6.10", "memory-profiler>=0.60.0", - "pytest>=6.2.5", + "pre-commit>=2.0.1", "pytest-benchmark>=3.4.1", - "sphinx>=2.4.4", - "sphinx-rtd-theme>=0.4.3", - "setuptools>=69.5.1", + "pytest-xdist>=3.8.0", + "pytest>=6.2.5", "ruff==0.4.3", + "setuptools>=69.5.1", + "sphinx-rtd-theme>=0.4.3", + "sphinx>=2.4.4", ] all = ["nle[agent,dev]"] @@ -115,3 +116,4 @@ before-all = "rm -rf {project}/build {project}/*.so {project}/CMakeCache.txt && [tool.pytest.ini_options] testpaths = ["nle/tests"] +addopts = "-n auto" diff --git a/uv.lock b/uv.lock index 170285647..571e70a81 100644 --- a/uv.lock +++ b/uv.lock @@ -220,6 +220,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + [[package]] name = "farama-notifications" version = "0.0.4" @@ -462,6 +471,7 @@ all = [ { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-benchmark" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "setuptools" }, { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -475,6 +485,7 @@ dev = [ { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-benchmark" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "setuptools" }, { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -498,6 +509,7 @@ requires-dist = [ { name = "pybind11", specifier = ">=2.2" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=6.2.5" }, { name = "pytest-benchmark", marker = "extra == 'dev'", specifier = ">=3.4.1" }, + { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.8.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = "==0.4.3" }, { name = "setuptools", marker = "extra == 'dev'", specifier = ">=69.5.1" }, { name = "sphinx", marker = "extra == 'dev'", specifier = ">=2.4.4" }, @@ -929,6 +941,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3"