Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .claude/skills/install-miles-diffusion/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Components (every version pinned):
2. **Conda env** — Python `3.11` (configurable via `PY_VER`).
3. **Tooling** — `pip==26.0.1`, `wheel==0.45.1`, `setuptools==82.0.1` (resolver behaviour depends on these).
4. **PyTorch** — `torch==2.9.1` on `cu129` (override via `TORCH_VER` / `CUDA_VER`).
5. **sglang-diffusion** — clones **`Rockdu/sglang` @ `sglang-diffusion-rollout-test`** into `$SGLANG_DIR` (default `../sglang`) and `git checkout --detach $SGLANG_COMMIT` (default `0372158dd66bc7cb0740c733bd60047db790ec7d`). Installed editable as `python[all]`. Pinning to a SHA (not just the branch tip) is required for bit-exact rollout reproducibility. Override `SGLANG_REPO` / `SGLANG_BRANCH` / `SGLANG_COMMIT` only if you know what you're doing.
5. **sglang-diffusion** — clones **`Rockdu/sglang` @ `feat/wan-rollout-optimization`** into `$SGLANG_DIR` (default `../sglang`) and `git checkout --detach $SGLANG_COMMIT` (default `553cacca96fbfa9af55cf0b07ab3b9d2595d35cd`). Installed editable as `python[all]`. Pinning to a SHA (not just the branch tip) is required for bit-exact rollout reproducibility. Override `SGLANG_REPO` / `SGLANG_BRANCH` / `SGLANG_COMMIT` only if you know what you're doing.
6. **miles package** — `pip install -r requirements.txt` (all `==`-pinned: transformers 5.5.4, accelerate 1.12.0, ray 2.53.0, datasets 4.4.2, safetensors 0.7.0, wandb 0.23.1, …) plus `pip install -e . --no-deps`.
7. **flow_grpo OCR deps** — runs `flow_grpo/setup.sh` (every line `--no-deps` and `==`-pinned: paddleocr 2.9.1, paddlepaddle-gpu 2.6.2, peft 0.18.1, diffusers 0.37.0, opencv 4.11.0.86, etc.).
8. **torch_memory_saver** — pinned to `0.0.9`, skipped silently on failure.
Expand All @@ -37,8 +37,8 @@ Before doing anything, surface these to the user and let them override:
- `PY_VER` (default `3.11`)
- `SGLANG_DIR` (default `$(dirname "$PWD")/sglang`)
- `SGLANG_REPO` (default `https://github.com/Rockdu/sglang.git`)
- `SGLANG_BRANCH` (default `sglang-diffusion-rollout-test`)
- `SGLANG_COMMIT` (default `0372158dd66bc7cb0740c733bd60047db790ec7d`)
- `SGLANG_BRANCH` (default `feat/wan-rollout-optimization`)
- `SGLANG_COMMIT` (default `553cacca96fbfa9af55cf0b07ab3b9d2595d35cd`)
- `CUDA_VER` (default `12.9`)
- `TORCH_VER` (default `2.9.1`)

Expand All @@ -54,7 +54,7 @@ It's long — run with `run_in_background: true` and stream with Monitor, or let

- **No conda/mamba** — helper aborts with a message telling the user to install miniforge. Don't auto-install conda.
- **No CUDA toolkit / no GPU** — `nvidia-smi` fails at smoke-test; install still succeeds but warn the user.
- **sglang branch missing / renamed** — if `Rockdu/sglang` no longer has `sglang-diffusion-rollout-test`, the clone/fetch fails. Do not silently fall back to upstream sgl-project/sglang: the required changes (multimodal_gen + weight-sync RPC) only live on the Rockdu fork. Surface the failure to the user and ask which branch to pin to instead.
- **sglang branch missing / renamed** — if `Rockdu/sglang` no longer has `feat/wan-rollout-optimization`, the clone/fetch fails. Do not silently fall back to upstream sgl-project/sglang: the required changes (multimodal_gen + weight-sync RPC) only live on the Rockdu fork. Surface the failure to the user and ask which branch to pin to instead.
- **paddlepaddle-gpu wheel mismatch** — pinned to 2.6.2 in flow_grpo/setup.sh. If the machine's CUDA is too new, you may need to swap the pin. Report the mismatch; don't silently change the pin.
- **System apt missing sudo** — fall back to `apt-get` without sudo (works in containers). If both fail, tell the user which .so is missing.

Expand Down
10 changes: 5 additions & 5 deletions .claude/skills/install-miles-diffusion/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# PY_VER python version (default: 3.11)
# SGLANG_DIR where to clone sglang (default: ../sglang)
# SGLANG_REPO sglang git URL (default: https://github.com/Rockdu/sglang.git)
# SGLANG_BRANCH sglang branch to check out (default: sglang-diffusion-rollout-test)
# SGLANG_BRANCH sglang branch to check out (default: feat/wan-rollout-optimization)
# SGLANG_COMMIT sglang commit SHA to pin (default: pinned working SHA below)
# CUDA_VER torch cuda tag (default: 12.9 -> cu129)
# TORCH_VER torch version (default: 2.9.1)
Expand All @@ -17,7 +17,7 @@
# Override only if you know what you're doing.
#
# sglang source of truth: the sglang-diffusion fork lives at
# Rockdu/sglang @ sglang-diffusion-rollout-test
# Rockdu/sglang @ feat/wan-rollout-optimization
# miles-diffusion depends on that branch (multimodal_gen +
# update_weights_from_tensor for RL weight sync). The branch tip moves; we pin
# to a specific commit SHA via SGLANG_COMMIT for bit-reproducibility.
Expand All @@ -29,8 +29,8 @@ PY_VER="${PY_VER:-3.11}"
CUDA_VER="${CUDA_VER:-12.9}"
TORCH_VER="${TORCH_VER:-2.9.1}"
SGLANG_REPO="${SGLANG_REPO:-https://github.com/Rockdu/sglang.git}"
SGLANG_BRANCH="${SGLANG_BRANCH:-sglang-diffusion-rollout-test}"
SGLANG_COMMIT="${SGLANG_COMMIT:-0372158dd66bc7cb0740c733bd60047db790ec7d}"
SGLANG_BRANCH="${SGLANG_BRANCH:-feat/wan-rollout-optimization}"
SGLANG_COMMIT="${SGLANG_COMMIT:-553cacca96fbfa9af55cf0b07ab3b9d2595d35cd}"

# Tooling pins (pip resolver behaviour depends on these).
PIP_VER="${PIP_VER:-26.0.1}"
Expand Down Expand Up @@ -107,7 +107,7 @@ else
fi

# ---------------------------------------------------------------- sglang-diffusion
# Depends on Rockdu/sglang @ sglang-diffusion-rollout-test (sglang-diffusion
# Depends on Rockdu/sglang @ feat/wan-rollout-optimization (sglang-diffusion
# fork with update_weights_from_tensor for multimodal_gen). Pinned to a
# specific commit so bit-exact rollout behaviour is reproducible.
if [[ ! -d "$SGLANG_DIR" ]]; then
Expand Down
102 changes: 76 additions & 26 deletions miles/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import miles.backends.fsdp_utils.configs.qwen_image # noqa: F401 — register pipeline config
import miles.backends.fsdp_utils.configs.sd3 # noqa: F401 — register pipeline config
import miles.backends.fsdp_utils.configs.wan2_2 # noqa: F401 — register pipeline config
from miles.ray.train_actor import TrainRayActor
from miles.utils import tracking_utils, train_metric_utils
from miles.utils.context_utils import with_defer
Expand All @@ -20,7 +21,6 @@
from miles.utils.sde_log_prob import sde_step_with_logprob
from miles.utils.timer import Timer, inverse_timer, timer
from miles.utils.tracking_utils import init_tracking

from . import checkpoint
from .configs.train_pipeline_config import get_train_pipeline_config
from .diffusion_update_weight_utils import DiffusionUpdateWeightFromTensor, DiffusionUpdateWeightFromTensorLoRA
Expand Down Expand Up @@ -74,34 +74,51 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
vae=None,
tokenizer=None,
)
model = pipeline.transformer
raw_models: dict[str, torch.nn.Module] = {}
for component in args.update_weight_target_modules:
sub_model = getattr(pipeline, component, None)
if sub_model is None:
raise ValueError(
f"--update-weight-target-module: pipeline {self.args.hf_checkpoint} "
f"has no component '{component}'"
)
raw_models[component] = sub_model
self.scheduler = pipeline.scheduler
del pipeline

self.train_pipeline_config = get_train_pipeline_config(args.diffusion_model)

if args.use_lora:
model = apply_lora(model, args, self.train_pipeline_config)
self.models: dict[str, torch.nn.Module] = {}
for component, model in raw_models.items():
if args.use_lora:
model = apply_lora(model, args, self.train_pipeline_config)

model.train()
model.train()

if args.gradient_checkpointing:
model.enable_gradient_checkpointing()
if args.gradient_checkpointing:
model.enable_gradient_checkpointing()

model.to(torch.cuda.current_device())
model.to(torch.cuda.current_device())

self.train_pipeline_config.preprocess_model_before_fsdp(model)
self.train_pipeline_config.preprocess_model_before_fsdp(model)

model = apply_fsdp2(
model,
mesh=self.parallel_state.dp_mesh,
cpu_offload=self.args.fsdp_cpu_offload,
args=self.args,
)
model = apply_fsdp2(
model,
mesh=self.parallel_state.dp_mesh,
cpu_offload=self.args.fsdp_cpu_offload,
args=self.args,
)
self.models[component] = model
# Force a sync to ensure sharding is complete and old memory is freed.
torch.cuda.synchronize()
clear_memory()
self.model = model
# Single component keeps the bare model as self.model so optimizer /
# checkpoint state-dict keys stay identical to pre-dual-DiT runs;
# multi-component wraps in a ModuleDict (keys get a component prefix).
if len(self.models) == 1:
self.model = next(iter(self.models.values()))
else:
self.model = torch.nn.ModuleDict(self.models)

if args.optimizer == "adam":
self.optimizer = torch.optim.AdamW(
Expand Down Expand Up @@ -133,9 +150,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
if self.args.debug_train_only:
self.weight_updater = None
elif self.args.use_lora:
self.weight_updater = DiffusionUpdateWeightFromTensorLoRA(self.args, self.model)
self.weight_updater = DiffusionUpdateWeightFromTensorLoRA(self.args, self.models)
else:
self.weight_updater = DiffusionUpdateWeightFromTensor(self.args, self.model)
self.weight_updater = DiffusionUpdateWeightFromTensor(self.args, self.models)

checkpoint.finalize_load(self, checkpoint_payload)

Expand Down Expand Up @@ -307,8 +324,8 @@ def _train_core(self, rollout_id: int, rollout_data) -> None:
raise ValueError(
"--diffusion-kl-beta currently requires --use-lora so the base model can be used as reference."
)
if kl_beta > 0 and not hasattr(self.model, "disable_adapter"):
raise RuntimeError("Diffusion KL requires a PEFT model exposing disable_adapter() after FSDP wrapping.")
if kl_beta > 0 and not all(hasattr(m, "disable_adapter") for m in self.models.values()):
raise RuntimeError("Diffusion KL requires PEFT models exposing disable_adapter() after FSDP wrapping.")

# ------------- training parameters -------------
# See docs/developer_guide/terminology.md for batch-size naming convention.
Expand Down Expand Up @@ -622,6 +639,36 @@ def _forward_tile(
latents_flat = latents_tile.reshape(tile_sample_count * tile_tstep_count, *latents_tile.shape[2:])
timesteps_flat = timesteps_tile.reshape(tile_sample_count * tile_tstep_count)

# Phase routing (multi-expert models, e.g. Wan2.2 high/low-noise): one
# tile must map to exactly one DiT so a single forward uses one model
# and one CFG scale, mirroring sgl-d's per-step model selection.
tile_components = {
train_pipeline_config.component_for_timestep(t, num_train_timesteps) for t in timesteps_flat.tolist()
}
if len(tile_components) > 1:
raise ValueError(
f"Tile mixes denoising phases {sorted(tile_components)}; shrink the tile so it is "
"phase-pure (--micro-batch-size-tstep 1 and --micro-batch-size-sample 1)."
)
tile_component = tile_components.pop()
if tile_component in self.models:
tile_model = self.models[tile_component]
elif len(self.models) == 1 and len(train_pipeline_config.target_components) == 1:
# Single-DiT model trained under a custom module name.
tile_model = self.model
else:
raise KeyError(
f"Step strategy selected timesteps for component '{tile_component}' but trained "
f"components are {list(self.models)}; align --diffusion-step-strategy-path with "
"--update-weight-target-module."
)
guidance_scale = train_pipeline_config.select_guidance_scale(
float(timesteps_flat[0]),
num_train_timesteps,
guidance_scale,
self.args.diffusion_guidance_scale_2,
)

# sgl-d's Qwen DiT divides timestep by num_train_timesteps inside
# forward; diffusers' does not. SD3 already expects raw timesteps.
if train_pipeline_config.needs_timestep_scaling:
Expand Down Expand Up @@ -666,7 +713,7 @@ def _forward_tile(
timesteps_input = timesteps_for_model.to(forward_dtype)

def _forward(cond: dict) -> torch.Tensor:
return self.model(
return tile_model(
hidden_states=latents_input,
timestep=timesteps_input,
return_dict=False,
Expand All @@ -676,13 +723,13 @@ def _forward(cond: dict) -> torch.Tensor:
cfg_batching = bool(self.args.fsdp_cfg_batching)

def _compute_noise_pred(disable_adapter: bool = False) -> torch.Tensor:
adapter_ctx = self.model.disable_adapter() if disable_adapter else nullcontext()
adapter_ctx = tile_model.disable_adapter() if disable_adapter else nullcontext()
with adapter_ctx:
if not use_cfg:
return _forward(pos_cond_tile)
if cfg_batching:
joint_cond = _pack_cond_for_joint_cfg(pos_cond_tile, neg_cond_tile)
joint_out = self.model(
joint_out = tile_model(
hidden_states=torch.cat([latents_input, latents_input], dim=0),
timestep=torch.cat([timesteps_input, timesteps_input], dim=0),
return_dict=False,
Expand Down Expand Up @@ -752,9 +799,10 @@ def _compute_noise_pred(disable_adapter: bool = False) -> torch.Tensor:
log_stats["clipfrac"].append(torch.mean((torch.abs(ratio - 1.0) > clip_range).float()).detach())
log_stats["log_prob_new_idx_0"].append(log_prob_new[0, 0].detach())
log_stats["log_prob_old_idx_0"].append(log_prob_old_tile[0, 0].detach())
log_stats["log_prob_mean_abs_diff"].append(
torch.mean(torch.abs(log_prob_new - log_prob_old_tile)).detach()
)
log_prob_mean_abs_diff = torch.mean(torch.abs(log_prob_new - log_prob_old_tile)).detach()
log_stats["log_prob_mean_abs_diff"].append(log_prob_mean_abs_diff)
if len(self.models) > 1:
log_stats[f"log_prob_mean_abs_diff_{tile_component}"].append(log_prob_mean_abs_diff)
# To log model output diff, please enable --diffusion-debug-mode
rollout_mo_window = grids.get("rollout_model_outputs")
if rollout_mo_window is not None:
Expand All @@ -767,6 +815,8 @@ def _compute_noise_pred(disable_adapter: bool = False) -> torch.Tensor:
log_stats["model_output_max_abs_diff"].append(diff.max().detach())
log_stats["model_output_mean_abs_diff"].append(diff.mean().detach())
log_stats["model_output_rel_max"].append((diff.max() / ref_max).detach())
if len(self.models) > 1:
log_stats[f"model_output_mean_abs_diff_{tile_component}"].append(diff.mean().detach())

return loss

Expand Down
22 changes: 22 additions & 0 deletions miles/backends/fsdp_utils/configs/train_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ class TrainPipelineConfig(abc.ABC):
lora_target_modules: list[str] = ["to_q", "to_k", "to_v", "to_out.0"]
needs_timestep_scaling: bool = True
optimizer_state_allowed_missing: list[str] = []
# Pipeline components (DiT modules) this model family can train. Multi-expert
# models (Wan2.2 high/low-noise) list all of them; --update-weight-target-module
# selects which subset actually gets trained in a run.
target_components: list[str] = ["transformer"]

def component_for_timestep(self, timestep: float, num_train_timesteps: int) -> str:
"""Which pipeline component denoises this (raw, unscaled) timestep.

Single-DiT models always route to the first target component. Multi-expert
models override this to mirror the rollout engine's per-step selection.
"""
return self.target_components[0]

def select_guidance_scale(
self,
timestep: float,
num_train_timesteps: int,
guidance_scale: float,
guidance_scale_2: float | None,
) -> float:
"""CFG scale for this (raw) timestep, mirroring sgl-d's per-step selection."""
return guidance_scale

def prepare_trajectory(
self,
Expand Down
69 changes: 69 additions & 0 deletions miles/backends/fsdp_utils/configs/wan2_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Wan2.2 training pipeline config."""

from __future__ import annotations

import torch
from miles.utils.types import CondKwargs

from .train_pipeline_config import TrainPipelineConfig, register_train_pipeline_config


@register_train_pipeline_config("Wan2.2-T2V-A14B", "Wan-AI/Wan2.2-T2V-A14B")
class Wan2_2TrainPipelineConfig(TrainPipelineConfig):
# High-noise expert ("transformer") handles t >= boundary, low-noise expert
# ("transformer_2") the rest — mirrors sgl-d's _select_and_manage_model.
target_components = ["transformer", "transformer_2"]
boundary_ratio = 0.875
# Wan DiT expects raw scheduler timesteps (0..num_train_timesteps), no /1000 scaling.
needs_timestep_scaling = False

def component_for_timestep(self, timestep: float, num_train_timesteps: int) -> str:
if timestep >= self.boundary_ratio * num_train_timesteps:
return "transformer"
return "transformer_2"

def select_guidance_scale(
self,
timestep: float,
num_train_timesteps: int,
guidance_scale: float,
guidance_scale_2: float | None,
) -> float:
if timestep >= self.boundary_ratio * num_train_timesteps:
return guidance_scale
# sgl-d uses batch.guidance_scale_2 for low-noise steps with NO fallback;
# a silent fallback here would desync train/rollout CFG and corrupt ratios.
assert guidance_scale_2 is not None, (
"Wan2.2 low-noise steps require --diffusion-guidance-scale-2 "
"(rollout already denoises them with guidance_scale_2)."
)
return guidance_scale_2

def prepare_cond_kwargs(self, cond: CondKwargs | None, device: torch.device) -> dict:
if cond is None or not cond.encoder_hidden_states:
return {}
enc = torch.cat(cond.encoder_hidden_states).to(device)
if enc.ndim == 2:
enc = enc.unsqueeze(0)
return {"encoder_hidden_states": enc}

def collate_cond_for_sample_batch(
self,
per_sample_cond_kwargs: list[dict],
device: torch.device,
) -> dict:
encs = [kw["encoder_hidden_states"] for kw in per_sample_cond_kwargs]
return {"encoder_hidden_states": torch.cat(encs, dim=0).to(device)}

def cfg_combine(
self,
noise_pred_pos: torch.Tensor,
noise_pred_neg: torch.Tensor,
guidance_scale: float,
true_cfg_scale: float | None = None,
) -> torch.Tensor:
scale = true_cfg_scale if true_cfg_scale is not None else guidance_scale
return noise_pred_neg + scale * (noise_pred_pos - noise_pred_neg)

def preprocess_model_before_fsdp(self, model: torch.nn.Module) -> None:
return None
Loading
Loading