Support Wan2.2-T2V-A14B (dual-expert MoE) GRPO training#8

Open
zhihengy wants to merge 13 commits into main from feat/wan

Support Wan2.2-T2V-A14B (dual-expert MoE) GRPO training

Adds end-to-end support for RL (GRPO) fine-tuning of Wan2.2-T2V-A14B in
miles-diffusion. Wan2.2's A14B is a two-expert MoE (high-noise transformer
for t≥boundary, low-noise transformer_2 below), which the existing
single-DiT pipeline could not train. Rollout runs on the sgl-d fork.

What's included (feat/wan, 13 commits)

Model adaptation

  • configs/wan2_2.py + train_pipeline_config.py hooks (target_components,
    component_for_timestep, select_guidance_scale) — single-DiT models keep
    identical behavior.
  • actor.py: multi-model load (self.models dict / ModuleDict), per-tile
    phase routing + per-tile CFG scale, per-component log-prob/output-diff metrics.
  • diffusion_update_weight_utils.py: per-component weight push to sgl-d.
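
A minimal sketch of the routing hooks listed above, for reviewers who want the intended semantics in one place. The class name `DualExpertRouting`, the helper `component_for_tile`, the 1000-step timestep range and the guidance values are assumptions for illustration; only the hook names, the component names and `boundary_ratio=0.875` come from this PR.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class DualExpertRouting:
    """Hypothetical stand-in for the Wan2.2 train-pipeline config hooks."""

    boundary_ratio: float = 0.875        # value quoted in the commit notes
    num_train_timesteps: int = 1000      # assumption: diffusers-style timestep range
    guidance_scale: float = 4.0          # placeholder value, not from the recipe
    guidance_scale_2: Optional[float] = None

    def component_for_timestep(self, t: float) -> str:
        # High-noise expert at or above the boundary, low-noise expert below it.
        boundary = self.boundary_ratio * self.num_train_timesteps
        return "transformer" if t >= boundary else "transformer_2"

    def select_guidance_scale(self, t: float) -> float:
        if self.component_for_timestep(t) == "transformer":
            return self.guidance_scale
        if self.guidance_scale_2 is None:
            # No silent fallback to guidance_scale for the low-noise phase.
            raise ValueError("guidance_scale_2 must be set for low-noise steps")
        return self.guidance_scale_2

    def component_for_tile(self, timesteps: Sequence[float]) -> str:
        # A tile must be phase-pure: every timestep routes to the same expert.
        components = {self.component_for_timestep(t) for t in timesteps}
        if len(components) != 1:
            raise AssertionError(f"tile is not phase-pure: {sorted(components)}")
        return components.pop()
```

Under these assumptions a timestep of 857.7 routes to `transformer_2`, and a tile that spans both sides of t=875 raises instead of being silently split.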

Rollout / schedule control

  • --diffusion-flow-shift S: the client sends the per-request sigma schedule,
    pre-shifted by S/12 so that sgl-d's hardcoded shift of 12 composes to the
    target (shift_12(shift_{S/12}(x)) == shift_S(x)); step strategies derive
    phase boundaries from the same effective shift.
  • --diffusion-sde-candidate-steps: configurable SDE window candidate set
    (reproducible, logged to wandb config).
  • step strategies: wan_high_window, wan_dual_window (one window per phase),
    wan_ff_global_window (reproduces Flow-Factory's per-epoch draw for A/B).
  • Wan rollout frame/guidance params; rollout response parsing fix.
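
The shift identity used by --diffusion-flow-shift can be checked numerically. The sketch below assumes the standard flow-matching time shift, shift_s(x) = s·x / (1 + (s − 1)·x), for which shift_a(shift_b(x)) = a·b·x / (1 + (a·b − 1)·x) = shift_{ab}(x). The function names and the uniform base grid are hypothetical and are not the client code in this PR.

```python
from typing import List


def shift_sigma(sigma: float, shift: float) -> float:
    """Flow-matching time shift: shift * sigma / (1 + (shift - 1) * sigma)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def base_grid(num_steps: int) -> List[float]:
    """Assumed unshifted grid: uniform sigmas from 1.0 down towards 0 (exclusive)."""
    return [1.0 - i / num_steps for i in range(num_steps)]


def client_schedule(num_steps: int, target_shift: float,
                    server_shift: float = 12.0) -> List[float]:
    """Sigmas pre-shifted by target/server, so a later server-side shift
    of `server_shift` lands on the schedule for `target_shift`."""
    pre_shift = target_shift / server_shift
    return [shift_sigma(s, pre_shift) for s in base_grid(num_steps)]


if __name__ == "__main__":
    sent = client_schedule(num_steps=10, target_shift=3.0)
    for raw, pre in zip(base_grid(10), sent):
        # Applying the server's shift of 12 recovers the shift-3 sigma.
        assert abs(shift_sigma(pre, 12.0) - shift_sigma(raw, 3.0)) < 1e-12
```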

Reward

  • PickScore: mean-pool over all generated video frames (single-frame reduces
    to prior behavior).
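
As a rough illustration of the frame pooling described above: score each frame independently and average per sample. `video_reward` and `score_frame` are hypothetical names; the real reward calls the PickScore model, which is replaced here by an arbitrary callable.

```python
from statistics import fmean
from typing import Callable, Sequence, TypeVar

Frame = TypeVar("Frame")


def video_reward(frames: Sequence[Frame],
                 score_frame: Callable[[Frame], float]) -> float:
    """Mean-pool a per-frame score over every frame of one sample."""
    if not frames:
        raise ValueError("expected at least one frame")
    # With a single frame the mean equals that frame's score,
    # so image rollouts keep their previous reward.
    return fmean([score_frame(frame) for frame in frames])
```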

Recipe / infra

  • scripts/run-diffusion-grpo-wan22-pickscore-4gpu.sh, sglang install pin,
    install skill tweaks.

Validation

  • Dual-DiT smoke test verified end-to-end (per-phase model_output diff ~2e-2,
    at the bf16 noise floor; per-component weight-sync checksums consistent;
    post-sync log_prob alignment holds).
  • --diffusion-flow-shift verified against --save-debug-rollout-data
    trajectory timesteps (matches shift-3 grid bit-for-bit).
  • Clean window {1,2,3} confirmed: first train-step clipfrac≈0, log_prob diff
    ~8e-6 (well under clip_range 1e-4).
  • transformer_2 training confirmed by checkpoint weight diff (lora_B grows).

Known limitations / notes for review

  • Dual-expert + KL (disable_adapter) path not implemented.
  • With --micro-batch-size-tstep > 1, mixed-phase tiles are rejected by an
    assertion (1×1 is clean).
  • The Wan SDE schedule grid (sigma_min ≈ 1/num_train, diffusers-standard) places
    step 3 at t≈857.7 (low-noise) under shift 3; frameworks using a different
    sigma_min convention will route step 3 differently — align grids for any
    cross-framework A/B.

zhihengy and others added 13 commits June 10, 2026 23:32
Repoint CUDA devices, HF/flashinfer caches, Python, datasets, and
the Wan2.2 checkpoint to /cluster-storage paths. Raise rollout batch
size to 48, num-rollout to 10000, microgroup to 8, sglang server
concurrency to 8, and pickscore GPUs per worker to 1. Parameterize
diffusion-clip-range and enable --diffusion-debug-mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- --update-weight-target-module accepts a comma-separated list; validated
  and parsed into args.update_weight_target_modules.
- TrainPipelineConfig grows target_components, component_for_timestep()
  and select_guidance_scale() with single-DiT defaults.
- Wan2_2TrainPipelineConfig routes by boundary_ratio=0.875, mirroring
  sgl-d's _select_and_manage_model: high-noise steps use 'transformer' +
  guidance_scale, low-noise 'transformer_2' + guidance_scale_2 (no silent
  fallback when guidance_scale_2 is unset).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
… sync

Actor loads every component named by --update-weight-target-module instead
of hardcoding pipeline.transformer (Wan2.2's transformer_2 was previously
discarded). Single component keeps the bare model as self.model so optimizer
and checkpoint state-dict keys are unchanged; multiple components wrap in an
nn.ModuleDict.

_forward_tile asserts the tile is phase-pure, dispatches to the component's
model via component_for_timestep, and selects the per-phase CFG scale
(low-noise tiles must use --diffusion-guidance-scale-2 to match rollout).
Per-component log_prob/model_output diff metrics are logged when training
more than one component.

DiffusionUpdateWeight holds {component: model}; buckets, LoRA merge indices
and checksum verification are grouped per component and pushed to sgl-d with
target_modules=[component] (the engine API already supports per-module
payload dicts).

Validated end-to-end on Wan2.2 (3-rollout smoke, wandb ti4yxlxr): both
phases reproduce rollout model outputs at ~2e-02 bf16 noise floor, and
post-sync log_prob alignment holds for both experts. Note: low-noise tiles
show log_prob_mean_abs_diff ~5e-04 vs ~1e-05 for high — inherent SDE
variance shrinkage at low sigma amplifying forward noise, not misalignment;
consider relaxing --diffusion-clip-range for dual runs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
sgl-d 929dc3b37 changed compute_weights_checksum to hash name + dtype +
shape + bytes; the miles-side mirror still hashed name + bytes only, so
MILES_VERIFY_WEIGHT_SYNC could never match. Restores algorithm parity and
documents the remaining structural limit: fp32 master vs bf16 engine params
hash different bytes by construction, so rely on log_prob alignment for
end-to-end sync correctness.
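
To illustrate why the two hash recipes could never agree, and why fp32 master weights still cannot match bf16 engine weights, here is a stdlib-only sketch. It is not sgl-d's compute_weights_checksum: the `Param` tuple, the SHA-256 choice and the byte layout are all assumptions made for the example.

```python
import hashlib
import struct
from typing import Iterable, NamedTuple, Tuple


class Param(NamedTuple):
    name: str
    dtype: str
    shape: Tuple[int, ...]
    data: bytes  # raw little-endian tensor bytes


def checksum(params: Iterable[Param], include_meta: bool = True) -> str:
    """Hash name (+ dtype + shape when include_meta) + bytes for each parameter."""
    digest = hashlib.sha256()
    for p in sorted(params, key=lambda p: p.name):
        digest.update(p.name.encode())
        if include_meta:
            digest.update(p.dtype.encode())
            digest.update(struct.pack(f"<{len(p.shape)}q", *p.shape))
        digest.update(p.data)
    return digest.hexdigest()


# The value 1.0 stored as fp32 and as bf16 (the upper two bytes of the fp32 word).
fp32_one = Param("w", "float32", (1,), struct.pack("<f", 1.0))
bf16_one = Param("w", "bfloat16", (1,), struct.pack("<f", 1.0)[2:])

# Name + bytes only versus name + dtype + shape + bytes: different digests.
assert checksum([fp32_one], include_meta=False) != checksum([fp32_one])
# Same value at different precision hashes different bytes by construction.
assert checksum([fp32_one]) != checksum([bf16_one])
```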

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Samples one SDE window per Wan2.2 phase (high + low noise) and merges them
into a single non-contiguous index list — sgl-d gates SDE per step by list
membership, so non-contiguity is supported. --diffusion-sde-window-size
applies per phase; --diffusion-sde-window-range keeps its wan_high_window
meaning and restricts the high-noise phase only.
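
The per-phase draw can be sketched as follows. The signature, the assumption that each window is contiguous within its phase, and the convention that steps before `boundary_step` are high-noise are illustrative only and may differ from the wan_dual_window strategy itself.

```python
import random
from typing import List, Optional, Tuple


def sample_dual_window(num_steps: int, boundary_step: int, window_size: int,
                       rng: random.Random,
                       high_range: Optional[Tuple[int, int]] = None) -> List[int]:
    """Draw one contiguous SDE window per phase and merge the step indices.

    Steps [0, boundary_step) are treated as high-noise and
    [boundary_step, num_steps) as low-noise. `high_range` (half-open)
    restricts only the high-noise draw.
    """

    def draw(lo: int, hi: int) -> List[int]:
        if hi - lo < window_size:
            raise ValueError(f"window of {window_size} does not fit in [{lo}, {hi})")
        start = rng.randint(lo, hi - window_size)  # randint is inclusive
        return list(range(start, start + window_size))

    h_lo, h_hi = high_range if high_range is not None else (0, boundary_step)
    h_hi = min(h_hi, boundary_step)  # never let the high window cross the boundary
    # High-phase indices all precede low-phase ones, so the merged list is
    # sorted but generally not contiguous.
    return draw(h_lo, h_hi) + draw(boundary_step, num_steps)
```

The merged list may contain a gap between the two windows, which is acceptable when the engine gates SDE per step by list membership.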

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Dual-expert run: UPDATE_WEIGHT_TARGET_MODULES="transformer,transformer_2"
DIFFUSION_STEP_STRATEGY_PATH=miles.rollout.step_strategy_hub.wan_dual_window.
Default stays single 'transformer' (behavior unchanged).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The PickScore reward read only frame 0 of generated_output; for multi-frame
(video) rollouts this ignored all but the first frame. Score every frame and
mean-pool per sample, mirroring Flow-Factory's PickScore video handling.
Single-frame rollouts reduce to the previous behavior.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…tegies

Add control over the Wan2.2 rollout/training schedule for matching external
recipes and dual-expert training:

- --diffusion-flow-shift S: client computes the per-request sigma schedule and
  sends it with each rollout (composing out sgl-d's hardcoded shift via the
  multiplicative shift law shift_12(shift_{S/12}(x)) == shift_S(x)); the step
  strategies derive phase boundaries from the same effective shift, so routing
  and rollout stay consistent. None = server default (12.0).
- --diffusion-sde-candidate-steps "i,j,k": configurable SDE window candidate
  set for list-drawing step strategies (replaces a hardcoded default), so it
  goes into wandb config and is reproducible.
- wan_ff_global_window: per-epoch global SDE window that reproduces
  Flow-Factory's FlowMatchEulerDiscreteSDEScheduler draw (seed = epoch +
  rollout_seed) for A/B comparison.
- recipe: pass DIFFUSION_FLOW_SHIFT through.
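
A small sketch of how a candidate list in the "i,j,k" format and a per-epoch seeded draw could fit together. Both function names are hypothetical, and Python's `random` module will not reproduce Flow-Factory's scheduler draw bit-for-bit; the point is only that seeding with epoch + rollout_seed makes the draw deterministic per epoch.

```python
import random
from typing import List


def parse_candidate_steps(raw: str) -> List[int]:
    """Parse a comma-separated step list such as "1,2,3"."""
    steps = [int(token) for token in raw.split(",") if token.strip()]
    if not steps:
        raise ValueError("candidate step list is empty")
    if len(set(steps)) != len(steps):
        raise ValueError("candidate steps must be unique")
    return steps


def draw_epoch_step(candidates: List[int], epoch: int, rollout_seed: int) -> int:
    """Pick one candidate per epoch from an RNG seeded with epoch + rollout_seed."""
    rng = random.Random(epoch + rollout_seed)
    return rng.choice(candidates)
```

Because the seed depends only on the epoch and the rollout seed, every rollout in the same epoch sees the same draw under this sketch.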

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>