Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions nle/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sqlite3
import threading
from collections import defaultdict
from functools import partial

Expand Down Expand Up @@ -74,11 +73,11 @@ def convert_frames(


def _ttyrec_generator(
batch_size, seq_length, rows, cols, load_fn, map_fn, ttyrec_version
batch_size, seq_length, rows, cols, load_fns, map_fn, ttyrec_version
):
"""A generator to fill minibatches with ttyrecs.

:param load_fn: a function to load the next ttyrec into a converter.
:param load_fns: a list of functions to load the next ttyrec into a converter.
        each load_fn(ttyrecs.Converter conv) -> bool is_success
:param map_fn: a function that maps a series of iterables through a fn.
map_fn(fn, *iterables) -> <generator> (can use built-in map)
Expand Down Expand Up @@ -110,17 +109,18 @@ def _ttyrec_generator(
converters = [
converter.Converter(rows, cols, ttyrec_version) for _ in range(batch_size)
]
assert all(load_fn(c) for c in converters), "Not enough ttyrecs to fill a batch!"
assert all(
fn(c) for fn, c in zip(load_fns, converters)
), "Not enough ttyrecs to fill a batch!"

# Convert (at least one minibatch)
_convert_frames = partial(convert_frames, load_fn=load_fn)
gameids[0, -1] = 1 # basically creating a "do-while" loop by setting an indicator
while np.any(
gameids[:, -1] != 0
): # loop until only padding is found, i.e. end of data
list(
map_fn(
_convert_frames,
convert_frames,
converters,
chars,
colors,
Expand All @@ -130,6 +130,7 @@ def _ttyrec_generator(
scores,
resets,
gameids,
load_fns,
)
)

Expand Down Expand Up @@ -287,9 +288,17 @@ def populate_metadata(self):
self._meta[row[0]].append(row)
self._meta_cols = [desc[0] for desc in c.description]

def _make_load_fn(self, gameids):
def _make_load_fns(self, gameids, batch_size):
"""Create one closure per batch dimension."""
load_fns = []
for i in range(batch_size):
# Deterministically select a subset of games for this dimension.
my_gameids = gameids[i::batch_size]
load_fns.append(self._make_one_load_fn(my_gameids))
return load_fns

def _make_one_load_fn(self, gameids):
"""Make a closure to load the next gameid from the db into the converter."""
lock = threading.Lock()
count = [0]

def _load_fn(converter):
Expand All @@ -300,9 +309,8 @@ def _load_fn(converter):

files = self.get_paths(gameid)
if gameid == 0 or part >= len(files):
with lock:
i = count[0]
count[0] += 1
i = count[0]
count[0] += 1

if (not self.loop_forever) and i >= len(gameids):
return False
Expand All @@ -323,26 +331,31 @@ def __iter__(self):
if self.shuffle:
np.random.shuffle(gameids)

load_fns = self._make_load_fns(gameids, self.batch_size)

return _ttyrec_generator(
self.batch_size,
self.seq_length,
self.rows,
self.cols,
self._make_load_fn(gameids),
load_fns,
self._map,
self._ttyrec_version,
)

def get_ttyrecs(self, gameids, chunk_size=None):
"""Fetch data from a single episode, chunked into a sequence of tensors."""
seq_length = chunk_size or self.seq_length
batch_size = len(gameids)
load_fns = self._make_load_fns(gameids, batch_size)

mbs = []
for mb in _ttyrec_generator(
len(gameids),
batch_size,
seq_length,
self.rows,
self.cols,
self._make_load_fn(gameids),
load_fns,
self._map,
self._ttyrec_version,
):
Expand Down
20 changes: 13 additions & 7 deletions nle/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_minibatches(self, db_exists, pool):
seq_length=50,
batch_size=4,
threadpool=pool,
gameids=range(1, 8),
# Repeated gameids ensure enough data for batches and resets
gameids=[1, 4, 2, 5, 3, 6, 7, 1],
shuffle=False,
)
# starting gameids = [TTYREC, TTYREC, TTYREC, TTYREC2]
Expand All @@ -70,23 +71,28 @@ def test_minibatches(self, db_exists, pool):
for name, array in mb.items():
if name in ("gameids",):
continue
# Test first three rows are the same, and differ from from fourth
np.testing.assert_array_equal(array[0], array[1])
np.testing.assert_array_equal(array[0], array[2])
# Check data for different batch dimensions are different
np.testing.assert_raises(
AssertionError, np.testing.assert_array_equal, array[0], array[3]
AssertionError, np.testing.assert_array_equal, array[0], array[1]
)

# Check resetting occurred
reset = np.where(mb["done"][3] == 1)[0][0]
assert reset == 31

# Check the data at location is the same. Note reset occurs for batch 4
# Check that the data after the reset is from a new game.
# With deterministic loading, the worker for this batch dimension will switch
# to a new game after the reset, so the data should be different.
seq = 10
for name, array in mb.items():
if name in ("done", "gameids"):
continue
np.testing.assert_array_equal(array[3][:seq], array[3][reset : reset + seq])
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
array[3][:seq],
array[3][reset : reset + seq],
)

# No leading 1s
assert (mb["done"][:, 0] == 0).all()
Expand Down
1 change: 1 addition & 0 deletions nle/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def conn(mockdata):
yield conn


@pytest.mark.usefixtures("mockdata")
class TestDB:
def test_conn(self, conn):
assert conn
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ dev = ["nle[dev]"]
[project.optional-dependencies]
agent = ["torch>=1.3.1"]
dev = [
"pre-commit>=2.0.1",
"cmake_format>=0.6.10",
"memory-profiler>=0.60.0",
"pytest>=6.2.5",
"pre-commit>=2.0.1",
"pytest-benchmark>=3.4.1",
"sphinx>=2.4.4",
"sphinx-rtd-theme>=0.4.3",
"setuptools>=69.5.1",
"pytest-xdist>=3.8.0",
"pytest>=6.2.5",
"ruff==0.4.3",
"setuptools>=69.5.1",
"sphinx-rtd-theme>=0.4.3",
"sphinx>=2.4.4",
]
all = ["nle[agent,dev]"]

Expand Down Expand Up @@ -115,3 +116,4 @@ before-all = "rm -rf {project}/build {project}/*.so {project}/CMakeCache.txt &&

[tool.pytest.ini_options]
testpaths = ["nle/tests"]
addopts = "-n auto"
25 changes: 25 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading