Skip to content

train_pytorch.py: wandb image logging fails with "Un-supported shape [224, 9, 224]" when images are NHWC #877

@leo038

Description

@leo038
## 环境
- 脚本: `scripts/train_pytorch.py`
- 运行方式: `torchrun --standalone --nnodes=1 --nproc_per_node=4 scripts/train_pytorch.py <config> --exp_name <name>`
- 触发位置: 首次 batch 时向 wandb 记录 sample images 的代码块

## 现象
```text
ValueError: Un-supported shape for image conversion [224, 9, 224]

发生在 wandb.Image(img_concatenated)

原因

  • 统一 dataloader 在 PyTorch 下返回的 observation 图像是 NHWC [B, H, W, C](与 JAX/LeRobot 一致)。
  • 当前实现假定图像为 NCHW,对 img[i](实际为 [H, W, C])做了 permute(1, 2, 0),得到 [224, 3, 224],再沿 dim=1 拼接 3 个视角得到 [224, 9, 224],wandb 无法识别。

建议修复

  1. 根据格式判断:若 frame.shape[0] == 3 视为 NCHW,做 permute(1, 2, 0) 转为 NHWC;否则按已是 NHWC 处理。
  2. 拼接时在 NHWC 的 width 维度上 cat,得到 [H, W_total, 3] 再传给 wandb.Image
  3. 若像素值在 [-1, 1],先线性缩放到 [0, 1] 再传入,避免 wandb 的数值范围警告。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions