Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
Mirrors the vision SFT stack (PackingDataLoader + RankPartitionedDataLoader),
but feeds the DROID action dataset (``joint_pos`` 8D + ``use_state``, raw/
un-normalized — same as the internal ``droid_lerobot_8b_policy`` run) through
``ActionTransformPipeline``, and trains the generation + action heads from the
public ``nvidia/Cosmos3-Nano`` base.
un-normalized) through ``ActionTransformPipeline``, and trains the generation +
action heads from the public ``nvidia/Cosmos3-Nano`` base.
Usage (1 node, 8 GPU)::
Expand Down Expand Up @@ -41,13 +40,10 @@
{"override /model": "mot_fsdp"},
{"override /data_train": None},
{"override /data_val": None},
# Match internal droid_lerobot_8b_policy: apex FusedAdam with fp32
# master_weights + eps 1e-8. adamw + fused + eps 1e-6 (bf16, no fp32
# master) under-steps the small 5x-lr action heads and leaves the action
# loss on a noisy high plateau; an exact-match forward/optimizer test
# confirmed the convergence gap was the optimizer, not the model.
# FusedAdam with fp32 master_weights + eps 1e-8 (bf16 params + eps 1e-6
# diverged on the action loss).
{"override /optimizer": "fusedadamw"},
{"override /scheduler": "lambdalinear"}, # matches internal droid_lerobot_8b (was lambdacosine)
{"override /scheduler": "lambdalinear"}, # linear LR decay
{"override /checkpoint": "s3"},
{
"override /callbacks": [
Expand Down Expand Up @@ -76,7 +72,7 @@
betas=[0.9, 0.99],
eps=1.0e-08,
fused=True, # popped by build_optimizer for FusedAdam (fused by construction)
# Generation + action heads (mirrors internal droid_lerobot_8b_policy).
# Train the generation + action heads.
keys_to_select=[
"moe_gen",
"time_embedder",
Expand All @@ -86,7 +82,7 @@
"llm2action",
"action_modality_embed",
],
lr=2.0e-04, # matches internal droid_lerobot_8b_policy submit (--lr 2e-4)
lr=2.0e-04, # for the 8192 global batch
lr_multipliers={
"action2llm": 5.0,
"llm2action": 5.0,
Expand All @@ -96,7 +92,7 @@
weight_decay=0.05,
),
scheduler=dict(
lr_scheduler_type="LambdaLinear", # matches internal droid_lerobot_8b (was LambdaCosine)
lr_scheduler_type="LambdaLinear",
cycle_lengths=[100], # smoke: 100 iters (real run sets via TOML)
f_max=[0.4],
f_min=[0.0],
Expand Down Expand Up @@ -125,7 +121,7 @@
device_monitor=dict(
every_n=200, log_memory_detail=True, save_s3=False, step_size=1, upload_every_n_mul=5
),
grad_clip=dict(clip_norm=1.0, force_finite=True), # matches internal make_8b
grad_clip=dict(clip_norm=1.0, force_finite=True),
heart_beat=dict(every_n=200, save_s3=False, step_size=1, update_interval_in_minute=20),
iter_speed=dict(every_n=1, hit_thres=50, save_s3=False, save_s3_every_log_n=500),
low_precision=dict(update_iter=1),
Expand All @@ -140,10 +136,9 @@
dcp_async_mode_enabled=False,
enable_gcs_patch_in_boto3=True,
keys_not_to_resume=[],
# Skip net_ema. (→ EMA warm-start copies net→net_ema, see dcp.py) AND the
# action heads, so they init fresh from the base — matches internal
# make_8b _DEFAULT_KEYS_TO_SKIP (Cosmos3-Nano's action heads are not
# DROID-policy-trained).
# Skip net_ema. (EMA warm-starts from net, see dcp.py) and the action
# heads, so they init fresh from the base (the base has no DROID-trained
# action heads).
keys_to_skip_loading=[
"net_ema.",
"action2llm",
Expand Down Expand Up @@ -171,7 +166,7 @@
dataloader_train=L(PackingDataLoader)(
audio_sample_rate=48000,
dataset_name="action_droid",
max_samples_per_batch=128, # count-based batch (matches internal res480 8B)
max_samples_per_batch=128, # per rank -> 8192 global batch at 64 ranks (16 nodes, shard 8 x replicate 8)
max_sequence_length=None, # None disables token packing (TOML can't express null)
patch_spatial=2,
sound_latent_fps=0,
Expand All @@ -185,6 +180,13 @@
pin_memory=True,
prefetch_factor=4,
sampler=None,
# Shuffling is handled by the dataset (iterable_shuffle=True below):
# ActionIterableShuffleDataset streams rank x worker-sharded, episode-order-
# shuffled, sequential-within-episode. The map-style dataset has no internal
# shuffle, so a SequentialSampler would feed every rank the SAME consecutive
# overlapping windows -> global batch ~1 episode -> unstable grad-norm; a plain
# RandomSampler decorrelates but does random-access I/O -> slow + OOM. The
# iterable gives decorrelation with sequential reads.
datasets=dict(
droid=dict(
ratio=1,
Expand All @@ -193,15 +195,21 @@
fps=15.0,
chunk_length=32,
action_space="joint_pos",
# Policy-only task mode. "joint" would randomly pick
# forward_dynamics/inverse_dynamics/policy per sample (multi-task),
# which dilutes each per-task loss by ~1/3.
mode="policy",
use_state=True,
iterable_shuffle=True, # rank x worker episode-shuffle stream
episode_shuffle_seed=42,
use_image_augmentation=True, # SR boost (random crop+rescale + color jitter)
# Keep-ranges window filter (drops idle/non-task frames). Off by default;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Keep-ranges window filter (drops idle/non-task frames). Off by default;
# keep_ranges_1_0_1.json window filter (drops idle/non-task frames). Off by default;

# the launcher sets use_filter_dict=True + filter_dict_path for internal parity.
# set use_filter_dict=True + filter_dict_path to enable.
use_filter_dict=False,
filter_dict_path=None,
Comment on lines 208 to 209

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to change use_filter_dict to True.
Nvm. I see you override it in the launch command.

action_normalization=None,
viewpoint="concat_view", # wrist 480p (top) + L/R shoulder 320x180 (bottom)
resolution="480", # 640x360 data @ 480p (matches internal res480 run)
resolution="480", # 640x360 data @ 480p
max_action_dim="${model.config.max_action_dim}",
cfg_dropout_rate=0.1,
tokenizer_config="${model.config.vlm_config.tokenizer}",
Expand All @@ -217,12 +225,24 @@
)


# chunk_length=32 → 33 observation frames; pin the VAE encode duration to match
# (internal used [17] for chunk_length=16). Set post-construction so it lands on
# the deep-copied NANO_MODEL_CONFIG.tokenizer.
# chunk_length=32 -> 33 observation frames; pin the VAE encode duration to match.
# Set post-construction so it lands on the deep-copied NANO_MODEL_CONFIG.tokenizer.
action_policy_droid_nano["model"]["config"]["tokenizer"]["encode_exact_durations"] = [33]


# Uncap the packed-sequence length. The NANO default (45056) caps the packed sequence,
# truncating long DROID windows to ~1/4 of their natural length; -1 (uncapped) processes
# the full vision sequence per step. Does not change the per-token loss; widens the
# effective vision context per step.
action_policy_droid_nano["model"]["config"]["max_num_tokens_after_packing"] = -1


# Weight the vision flow-matching loss 10x in the total loss (the NANO default is 1.0).
# loss_scale multiplies only the vision term, balancing it against the action loss
# (action_loss_weight=10) so both heads train at comparable gradient magnitude.
action_policy_droid_nano["model"]["config"]["rectified_flow_training_config"]["loss_scale"] = 10.0


for _item in [action_policy_droid_nano]:
_name = [k for k, v in globals().items() if v is _item][0]
cs.store(group="experiment", package="_global_", name=_name, node=_item)
67 changes: 61 additions & 6 deletions cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Any

from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset, get_worker_info

from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset
from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline
Expand All @@ -37,13 +37,63 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> dict[str, Any]:
return self._transform(self._dataset[idx], self._resolution)

def get_shuffle_blocks(self):
"""Delegate to the inner DROIDLeRobotDataset (per-episode/segment flat-index blocks)."""
return self._dataset.get_shuffle_blocks()



class ActionIterableShuffleDataset(IterableDataset):
"""Streaming view of a map-style ``ActionSFTDataset``.

Each ``(rank, worker)`` is assigned a DISJOINT subset of episodes (sharded over
``shard_world_size * num_workers``), shuffles its episode ORDER, and streams the
windows WITHIN each episode sequentially -> within-rank batch diversity (the N
workers of a rank stream N different episodes) AND cross-rank diversity, while
keeping reads sequential (I/O locality + COW; no RandomSampler random-access OOM).
Re-shuffles each epoch and streams indefinitely (the trainer stops at ``max_iter``).

``shard_world_size`` / ``shard_rank`` are set by ``RankPartitionedDataLoader``.
"""

def __init__(self, dataset: "ActionSFTDataset", seed: int = 42):
super().__init__()
self._dataset = dataset
self._seed = int(seed)
self.shard_world_size = 1
self.shard_rank = 0

def __len__(self) -> int: # informational only; iteration is infinite
return len(self._dataset)

def __iter__(self):
import torch

blocks = self._dataset.get_shuffle_blocks()
wi = get_worker_info()
wid = wi.id if wi is not None else 0
nw = wi.num_workers if wi is not None else 1
global_shard = int(self.shard_rank) * nw + wid
total_shards = max(1, int(self.shard_world_size) * nw)
epoch = 0
while True:
g = torch.Generator()
g.manual_seed(self._seed + epoch) # same permutation across all (rank,worker) -> disjoint shard
order = torch.randperm(len(blocks), generator=g).tolist()
for b in order[global_shard::total_shards]:
Comment thread
lfengad marked this conversation as resolved.
start, length = blocks[b]
for idx in range(start, start + length):
yield self._dataset[idx]
epoch += 1


def get_action_droid_sft_dataset(
*,
root: str,
fps: float = 15.0,
chunk_length: int = 32,
action_space: str = "joint_pos",
mode: str = "policy",
use_state: bool = True,
action_normalization: str | None = None,
viewpoint: str = "concat_view",
Expand All @@ -58,16 +108,18 @@ def get_action_droid_sft_dataset(
append_duration_fps_timestamps: bool = True,
append_resolution_info: bool = True,
append_idle_frames: bool = False,
) -> ActionSFTDataset:
"""Build the DROID action SFT dataset (joint_pos 8D by default), matching the
internal ``droid_lerobot_8b_policy`` data: ``action_space='joint_pos'`` +
``use_state`` (8D, raw/un-normalized), concat_view, chunk_length 32."""
iterable_shuffle: bool = False,
episode_shuffle_seed: int = 42,
) -> Dataset:
"""Build the DROID action SFT dataset: ``action_space='joint_pos'`` (8D) +
``use_state`` (raw/un-normalized), concat_view, chunk_length 32."""
dataset = DROIDLeRobotDataset(
root=root,
fps=fps,
chunk_length=chunk_length,
viewpoint=viewpoint,
action_space=action_space,
mode=mode,
use_state=use_state,
action_normalization=action_normalization,
use_image_augmentation=use_image_augmentation,
Expand All @@ -83,4 +135,7 @@ def get_action_droid_sft_dataset(
append_resolution_info=append_resolution_info,
append_idle_frames=append_idle_frames,
)
return ActionSFTDataset(dataset, transform, resolution)
sft = ActionSFTDataset(dataset, transform, resolution)
if iterable_shuffle:
return ActionIterableShuffleDataset(sft, seed=episode_shuffle_seed)
return sft
6 changes: 5 additions & 1 deletion cosmos_framework/data/vfm/action/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ def _build_result(
**extras: Any,
) -> dict[str, Any]:
idle_frames = self._compute_idle_frames(action)
normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats())
# action_normalization=None -> use raw actions (no normalization), e.g. joint_pos.
if self.action_normalization is None:
normalized_action = action
else:
normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats())
formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3)
return {
"ai_caption": ai_caption,
Expand Down
16 changes: 16 additions & 0 deletions cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,19 @@ def __len__(self) -> int:
if self._use_filter_dict:
return int(self._seg_cum[-1]) if self._seg_cum.size else 0
return int(self._valid_cum[-1]) if self._valid_cum.size else 0

def get_shuffle_blocks(self) -> list[tuple[int, int]]:
"""Per-episode (or per kept-segment, when ``use_filter_dict``) flat-index blocks
``(start, length)``. ``ActionIterableShuffleDataset`` shuffles the ORDER of these
blocks and shards them disjointly across ranks, while keeping windows *within* a
block sequential -> decorrelates batches across ranks without random-access I/O
(preserves locality + copy-on-write memory sharing across workers)."""
cum = self._seg_cum if self._use_filter_dict else self._valid_cum
blocks: list[tuple[int, int]] = []
prev = 0
for c in np.asarray(cum).tolist():
c = int(c)
if c > prev:
blocks.append((prev, c - prev))
prev = c
return blocks
Loading