From fad399cd73e0d095c71bc27307d4d96e337517b1 Mon Sep 17 00:00:00 2001 From: "liang.feng" Date: Wed, 17 Jun 2026 08:51:33 -0700 Subject: [PATCH] fix(action): handle action_normalization=None in ActionBaseDataset._build_result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit joint_pos DROID actions are raw/un-normalized: DROIDLeRobotDataset forces action_normalization=None for action_space="joint_pos", and the parameter is typed str | None across all action datasets. But _build_result always called normalize_action(action, None, ...), which only accepts quantile|meanstd|minmax and raises "ValueError: Unknown normalization method: None" — so the shipped action_policy_droid_nano recipe crashed at dataloader start (no 8-D joint_pos stats exist to normalize against anyway). Guard the None case: pass actions through untouched and skip stats loading. The fix lives at the dataset layer (where the str | None contract lives); normalize_action's strict method validation is left intact. Adds base_dataset_test.py covering both the None pass-through and the method-set normalization paths. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../data/vfm/action/datasets/base_dataset.py | 9 ++- .../vfm/action/datasets/base_dataset_test.py | 69 +++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 cosmos_framework/data/vfm/action/datasets/base_dataset_test.py diff --git a/cosmos_framework/data/vfm/action/datasets/base_dataset.py b/cosmos_framework/data/vfm/action/datasets/base_dataset.py index 564d48e..2a8d23b 100644 --- a/cosmos_framework/data/vfm/action/datasets/base_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/base_dataset.py @@ -102,7 +102,7 @@ def domain_id(self) -> int: return self._domain_id @property - def action_normalization(self) -> str: + def action_normalization(self) -> str | None: return self._action_normalization @property @@ -186,7 +186,12 @@ def _build_result( **extras: Any, ) -> dict[str, Any]: idle_frames = self._compute_idle_frames(action) - normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats()) + # action_normalization=None means "raw / un-normalized" (e.g. DROID + # joint_pos): pass actions through untouched and skip loading stats. + if self.action_normalization is None: + normalized_action = action + else: + normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats()) formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) return { "ai_caption": ai_caption, diff --git a/cosmos_framework/data/vfm/action/datasets/base_dataset_test.py b/cosmos_framework/data/vfm/action/datasets/base_dataset_test.py new file mode 100644 index 0000000..21718fb --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/base_dataset_test.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Tests for ActionBaseDataset._build_result normalization handling.""" + +from pathlib import Path + +import torch + +from cosmos_framework.data.vfm.action.action_spec import Gripper, Joint, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset + + +class _StubDataset(ActionBaseDataset): + """Concrete subclass exposing _build_result without touching disk.""" + + @property + def action_dim(self) -> int: + return 8 + + def _action_spec(self): + # 8D joint_pos layout: 7 arm joints + gripper (matches DROID joint_pos). + return build_action_spec(Joint(n=7, label="arm"), Gripper()) + + @classmethod + def _stats_path(cls) -> Path: + return Path("/nonexistent/stats.json") + + def __getitem__(self, idx): # pragma: no cover - not exercised + raise NotImplementedError + + +def _make_dataset(action_normalization, norm_stats=None) -> _StubDataset: + # Bypass __init__ (which reads dataset files from disk) and set only the + # attributes _build_result touches. + ds = object.__new__(_StubDataset) + ds._fps = 15.0 + ds._viewpoint = "concat_view" + ds._domain_id = 0 + ds._action_normalization = action_normalization + ds._norm_stats = norm_stats + return ds + + +def _video() -> torch.Tensor: + return torch.zeros(2, 3, 4, 4) # [C, T, H, W] -> permuted inside _build_result + + +def test_build_result_skips_normalization_when_none(): + """action_normalization=None (raw joint_pos) must pass actions through unchanged.""" + action = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8) + ds = _make_dataset(action_normalization=None) + + result = ds._build_result(mode="policy", video=_video(), action=action, ai_caption="x") + + assert torch.equal(result["action"], action) + + +def test_build_result_applies_normalization_when_method_set(): + """A configured method still normalizes (regression guard for the None fix).""" + action = torch.full((4, 8), 0.5) + stats = {"min": torch.zeros(8), "max": torch.ones(8)} + ds = _make_dataset(action_normalization="minmax", norm_stats=stats) + + result = ds._build_result(mode="policy", video=_video(), action=action, ai_caption="x") + + # minmax with [0,1] range maps 0.5 -> 0.0; must differ from the raw input. + assert torch.allclose(result["action"], torch.zeros(4, 8)) + assert not torch.equal(result["action"], action)