Skip to content

dataset with TFDS feature Sequence will crash when trying to load #14

@buckleytoby

Description

@buckleytoby

dataset generation code:

from envlogger.backends.tfds_backend_writer import *
from envlogger.step_data import *
import numpy as np
import dm_env
import tensorflow as tf
from os.path import expanduser

# Dataset generation script for this report: writes one three-step episode
# (restart, transition, termination) to DIRECTORY through TFDSBackendWriter,
# using a fixed-length Sequence of Image features as the observation spec.

# Number of images in each observation; also used as the Sequence length below.
NB_FRAME = 3
# Shape handed to tfds.features.Image for every frame.
# NOTE(review): presumably (height, width, channels) -- confirm against the TFDS docs.
SHAPE = (3, 3, 3)
HOME = expanduser("~")
# Output directory for the generated dataset.
DIRECTORY = HOME + "/data/test"

# NOTE(review): the bare string below is an expression statement with no
# effect (it is not the module docstring and not a comment). Its wording
# does not describe this script -- looks like pasted text; confirm.
"""
Composite FeatureConnector for a dict where each value is a list.
"""

# a sequence of images, NB_FRAME long
# NOTE(review): `tfds` is not imported by name in this snippet; presumably it
# is pulled in by one of the star imports at the top -- confirm.
tfds_features = tfds.features.Sequence(tfds.features.Image(shape=SHAPE), length=NB_FRAME)

# All-zero uint8 array of shape (NB_FRAME,) + SHAPE; the same array is reused
# as the observation of every step.
observation = np.zeros((NB_FRAME, ) + SHAPE, dtype="uint8")

# Dataset spec: the Sequence-of-Image feature for observations, tf.float64
# for action, reward and discount.
ds_config = tfds.rlds.rlds_base.DatasetConfig(
    name='test',
    observation_info=tfds_features,
    action_info=tf.float64,
    reward_info=tf.float64,
    discount_info=tf.float64  # default python type for 0.
)


# Writer targeting DIRECTORY with the spec above.
writer = TFDSBackendWriter(data_directory=DIRECTORY,
                           split_name='train', # required
                           max_episodes_per_file=500,
                           ds_config=ds_config)
# Plain Python float used both as the reward and as the second StepData
# argument; the trailing comment records a NumPy alternative that is not used.
zero_float64 = 0.0  # np.array(0.0, dtype="float64")

# start episode
timestep = dm_env.restart(observation=observation)
# NOTE(review): the second StepData argument is presumably the action -- confirm.
data = StepData(timestep, zero_float64)
# NOTE(review): the boolean is True only for this first step; presumably it
# flags the start of a new episode -- confirm against the envlogger API.
writer.record_step(data, True)

# transition episode
timestep = dm_env.transition(reward=zero_float64, observation=observation)
data = StepData(timestep, zero_float64)
writer.record_step(data, False)

# end episode
timestep = dm_env.termination(reward=zero_float64, observation=observation)
data = StepData(timestep, zero_float64)
writer.record_step(data, False)

# close
writer.close()

dataset reader code:

from os.path import expanduser

import tensorflow_datasets as tfds
import rlds

# Dataset reader script for this report: builds a TFDS builder from
# DIRECTORY, loads every split, prints the episode count, then flattens the
# episodes into steps and prints the step count.

# NOTE(review): the bare string below is an expression statement with no
# effect; it only serves as a section label.
""" Parameters """
HOME = expanduser("~")
# Directory the dataset is loaded from.
DIRECTORY = HOME + "/data/test"

# load the dataset
builder = tfds.builder_from_directory(DIRECTORY)
# The traceback quoted in this report points at the call below, which raises
# TypeError (... got 'ragged_flat_values'); the lines after it are not
# reached in the failing run.
dataset = builder.as_dataset(split='all')

print("Nb episode: ", len(dataset))

# flatten dataset
# Each episode is replaced by its nested steps (the entry under rlds.STEPS).
dataset = dataset.flat_map(lambda episode: episode[rlds.STEPS])

# Called on the flattened dataset and converted with .numpy() for printing.
# NOTE(review): presumably this yields the total number of steps -- confirm
# against the RLDS documentation.
nb_steps = rlds.transformations.episode_length(dataset).numpy()
print("Nb steps: ", nb_steps)

Error generated:

Exception has occurred: TypeError
Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'ragged_flat_values'
  File "/home/omnid/dexnex/ws_dexnex/src/ros2-to-rlds/ros2-to-rlds/test/test_ds_load.py", line 12, in <module>
    dataset = builder.as_dataset(split='all')
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'ragged_flat_values'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions