Batch collate error, "found object" from raw_language and a question about pad_mask #25

@badinkajink


Edit: This error only occurs when you load the dataset in isolation, maybe because of EagerTensors. No issues when running the train script.

Hi, I ran into a collate error when trying to read a batch from an RLDS dataset passed to the TorchRLDSDataset wrapper in rlds_utils.py. Error from \torch\utils\data\_utils\collate.py:169:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

This comes from the raw_language key:

type of raw_language: <class 'numpy.ndarray'>
dtype of raw_language: object

I resolved this with the following in TorchRLDSDataset:

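    def __iter__(self):
        for sample in self._rlds_dataset.as_numpy_iterator():
            # raw_language arrives as an object-dtype ndarray, which
            # default_collate rejects; convert it to plain Python values.
            rl = sample['obs']['raw_language']
            sample['obs']['raw_language'] = rl.tolist()
            yield sample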
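For reference, a more general version of the same workaround might look like the sketch below. It walks a nested sample dict and converts every object-dtype array (not only raw_language) to plain Python strings, leaving numeric arrays untouched. The helper names `make_collatable` and `_decode` and the toy sample are hypothetical and are not part of rlds_utils.py or octo; the UTF-8 decoding step is an assumption about how the strings are encoded.

```python
import numpy as np


def _decode(item):
    """Turn bytes (or nested lists of bytes) into Python strings."""
    if isinstance(item, bytes):
        return item.decode("utf-8")  # assumption: strings are UTF-8 encoded
    if isinstance(item, list):
        return [_decode(x) for x in item]
    return item


def make_collatable(value):
    """Recursively replace object-dtype arrays with plain Python values.

    Hypothetical helper: numeric arrays pass through unchanged, so only
    the fields that would trip up the default collate function are touched.
    """
    if isinstance(value, dict):
        return {key: make_collatable(val) for key, val in value.items()}
    if isinstance(value, np.ndarray) and value.dtype == object:
        return _decode(value.tolist())
    return value


# Toy sample shaped like one element of as_numpy_iterator() (made-up data).
sample = {
    "obs": {
        "raw_language": np.array(
            [b"pick up the cup", b"pick up the cup"], dtype=object
        ),
        "state": np.zeros((2, 7), dtype=np.float32),
    }
}
clean = make_collatable(sample)
```

Calling `make_collatable(sample)` inside `__iter__` before `yield` would replace the two key-specific lines above, at the cost of one extra traversal per sample.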

I have been following the droid_dataloader.py example with a small custom dataset of my own in RLDS format, and I've also reproduced this error on the droid_100 dataset. For now, my goal is to train a diffusion policy on my single dataset. I had to change BASE_DATASET_KWARGS slightly to work with the latest octo; I'm not sure whether that could have caused this error.

Also, I have an unrelated, potentially simple question:
I'm assuming pad_mask is equivalent to timestep_pad_mask, which as far as I can tell is generated by octo.HistoryWrapper. The droid_100 dataset is also missing this key, but the robomimic_transform seems to require it. Do I need it? Every field in my dataset is populated with real data at every timestep, but my trajectories vary in length. Would I have to go back into the RLDS builder and pad them manually?

Edit: For anyone who runs into this, pad mask is generated in octo/data/traj_transforms.py:56 in the chunk_act_obs function. Variable step length in trajectories does not matter.
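To illustrate why variable trajectory length is not a problem, here is a rough NumPy sketch of how a history-window pad mask can be computed per trajectory. This is my reading of the general idea, not a copy of octo's code: the function name `history_pad_mask` and the `window_size` parameter are made up for the example. Each timestep looks back over a fixed window, and any window slot that would fall before the start of the trajectory is marked as padding.

```python
import numpy as np


def history_pad_mask(traj_len, window_size):
    """Return (indices, pad_mask), each of shape (traj_len, window_size).

    Hypothetical sketch: pad_mask is False wherever the history window
    reaches before timestep 0, and indices are clipped so those slots
    repeat the first timestep instead of indexing out of range.
    """
    offsets = np.arange(-window_size + 1, 1)  # e.g. [-1, 0] for window 2
    history_indices = np.arange(traj_len)[:, None] + offsets[None, :]
    pad_mask = history_indices >= 0
    return np.maximum(history_indices, 0), pad_mask


# The mask depends only on each trajectory's own length, so trajectories
# of different lengths need no manual padding in the dataset builder.
short_idx, short_mask = history_pad_mask(traj_len=3, window_size=2)
long_idx, long_mask = history_pad_mask(traj_len=5, window_size=2)
```

In both cases only the first row contains a padded slot; every later timestep has a full window of real data.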

Running:
octo=latest commit (not at the pinned commit)
torch=2.0.1
platform=Windows
