Repoint CUDA devices, HF/flashinfer caches, Python, datasets, and the Wan2.2 checkpoint to /cluster-storage paths. Raise rollout batch size to 48, num-rollout to 10000, microgroup to 8, sglang server concurrency to 8, and pickscore GPUs per worker to 1. Parameterize diffusion-clip-range and enable --diffusion-debug-mode.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- --update-weight-target-module accepts a comma-separated list; validated and parsed into args.update_weight_target_modules.
- TrainPipelineConfig grows target_components, component_for_timestep() and select_guidance_scale() with single-DiT defaults.
- Wan2_2TrainPipelineConfig routes by boundary_ratio=0.875, mirroring sgl-d's _select_and_manage_model: high-noise steps use 'transformer' + guidance_scale, low-noise 'transformer_2' + guidance_scale_2 (no silent fallback when guidance_scale_2 is unset).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
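The flag parsing and phase routing described above can be sketched roughly as follows. This is an illustration only, not the miles code: it assumes timesteps live on a 0..1000 scale, that the boundary timestep is boundary_ratio * 1000, and that high-noise means t >= boundary (as the PR description states). parse_target_modules and the free-standing functions are hypothetical stand-ins for the config methods.

```python
from typing import Optional

NUM_TRAIN_TIMESTEPS = 1000  # assumption: timesteps on a 0..1000 scale
BOUNDARY_RATIO = 0.875      # value quoted in the commit message


def parse_target_modules(raw: str) -> list[str]:
    """Split a comma-separated flag value; reject empty or duplicate names."""
    names = [part.strip() for part in raw.split(",")]
    if any(not name for name in names):
        raise ValueError(f"empty component name in {raw!r}")
    if len(set(names)) != len(names):
        raise ValueError(f"duplicate component name in {raw!r}")
    return names


def component_for_timestep(t: float) -> str:
    """High-noise steps (t >= boundary) go to 'transformer', the rest to 'transformer_2'."""
    boundary = BOUNDARY_RATIO * NUM_TRAIN_TIMESTEPS
    return "transformer" if t >= boundary else "transformer_2"


def select_guidance_scale(
    t: float, guidance_scale: float, guidance_scale_2: Optional[float]
) -> float:
    """Pick the per-phase CFG scale; never fall back silently for low-noise steps."""
    if component_for_timestep(t) == "transformer":
        return guidance_scale
    if guidance_scale_2 is None:
        raise ValueError("guidance_scale_2 must be set for low-noise steps")
    return guidance_scale_2
```

With these assumptions, a step at t=900 routes to 'transformer' and a step at t=857.7 routes to 'transformer_2'; asking for a low-noise scale without guidance_scale_2 raises instead of reusing guidance_scale.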
… sync
Actor loads every component named by --update-weight-target-module instead
of hardcoding pipeline.transformer (Wan2.2's transformer_2 was previously
discarded). Single component keeps the bare model as self.model so optimizer
and checkpoint state-dict keys are unchanged; multiple components wrap in an
nn.ModuleDict.
_forward_tile asserts the tile is phase-pure, dispatches to the component's
model via component_for_timestep, and selects the per-phase CFG scale
(low-noise tiles must use --diffusion-guidance-scale-2 to match rollout).
Per-component log_prob/model_output diff metrics are logged when training
more than one component.
DiffusionUpdateWeight holds {component: model}; buckets, LoRA merge indices
and checksum verification are grouped per component and pushed to sgl-d with
target_modules=[component] (the engine API already supports per-module
payload dicts).
Validated end-to-end on Wan2.2 (3-rollout smoke, wandb ti4yxlxr): both
phases reproduce rollout model outputs at ~2e-02 bf16 noise floor, and
post-sync log_prob alignment holds for both experts. Note: low-noise tiles
show log_prob_mean_abs_diff ~5e-04 vs ~1e-05 for high — inherent SDE
variance shrinkage at low sigma amplifying forward noise, not misalignment;
consider relaxing --diffusion-clip-range for dual runs.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
sgl-d 929dc3b37 changed compute_weights_checksum to hash name + dtype + shape + bytes; the miles-side mirror still hashed name + bytes only, so MILES_VERIFY_WEIGHT_SYNC could never match. Restores algorithm parity and documents the remaining structural limit: fp32 master vs bf16 engine params hash different bytes by construction, so rely on log_prob alignment for end-to-end sync correctness.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
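To see why the two hashing schemes can never agree, here is a small stand-alone sketch. It is not the sgl-d or miles implementation: the use of SHA-256, the string dtype names, and the little-endian shape encoding are all assumptions chosen for illustration, and parameters are modelled as (name, dtype, shape, raw bytes) tuples rather than tensors.

```python
import hashlib
import struct


def checksum_name_bytes(params):
    """Older scheme: hash only the parameter name and its raw bytes."""
    h = hashlib.sha256()
    for name, _dtype, _shape, data in params:
        h.update(name.encode())
        h.update(data)
    return h.hexdigest()


def checksum_full(params):
    """Newer scheme: hash name + dtype + shape + bytes."""
    h = hashlib.sha256()
    for name, dtype, shape, data in params:
        h.update(name.encode())
        h.update(dtype.encode())
        # Encode each dimension as a little-endian int64 (illustrative choice).
        h.update(struct.pack(f"<{len(shape)}q", *shape))
        h.update(data)
    return h.hexdigest()


raw = bytes(16)
as_matrix = [("w", "float32", (2, 2), raw)]
as_vector = [("w", "float32", (4,), raw)]

# Same name and bytes: the older scheme cannot tell the two apart.
same_under_old = checksum_name_bytes(as_matrix) == checksum_name_bytes(as_vector)
# The newer scheme also folds in the shape, so the digests differ.
same_under_new = checksum_full(as_matrix) == checksum_full(as_vector)
```

Because one side mixed dtype and shape into the digest and the other did not, the two digests differed for every parameter, which matches the "could never match" observation in the commit message.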
Samples one SDE window per Wan2.2 phase (high + low noise) and merges them into a single non-contiguous index list; sgl-d gates SDE per step by list membership, so non-contiguity is supported. --diffusion-sde-window-size applies per phase; --diffusion-sde-window-range keeps its wan_high_window meaning and restricts the high-noise phase only.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
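A rough sketch of the per-phase window draw, under stated assumptions: timesteps are listed in descending order, each phase draws one contiguous window uniformly at random, and the window size is clamped to the phase length. sample_dual_window is a hypothetical name, and the --diffusion-sde-window-range restriction on the high-noise phase is deliberately not modelled here.

```python
import random


def sample_dual_window(timesteps, boundary, window_size, rng):
    """Draw one contiguous SDE window per phase and merge the step indices."""
    high = [i for i, t in enumerate(timesteps) if t >= boundary]
    low = [i for i, t in enumerate(timesteps) if t < boundary]
    picked = []
    for phase in (high, low):
        size = min(window_size, len(phase))
        if size == 0:
            continue  # phase has no steps in this schedule
        start = rng.randrange(len(phase) - size + 1)
        picked.extend(phase[start:start + size])
    return sorted(picked)


timesteps = [1000, 950, 900, 800, 600, 400, 200, 100]
sde_steps = sample_dual_window(timesteps, boundary=875, window_size=2,
                               rng=random.Random(0))
# Per-step gate by list membership, so gaps between the two windows are fine.
use_sde = [i in set(sde_steps) for i in range(len(timesteps))]
```

The merged list holds two steps from indices 0..2 (high noise) and two from indices 3..7 (low noise); the two windows need not be adjacent.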
Dual-expert run: UPDATE_WEIGHT_TARGET_MODULES="transformer,transformer_2" DIFFUSION_STEP_STRATEGY_PATH=miles.rollout.step_strategy_hub.wan_dual_window. Default stays single 'transformer' (behavior unchanged).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The PickScore reward read only frame 0 of generated_output; for multi-frame (video) rollouts this ignored all but the first frame. Score every frame and mean-pool per sample, mirroring Flow-Factory's PickScore video handling. Single-frame rollouts reduce to the previous behavior.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
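The pooling step can be illustrated without the reward model itself. The sketch below assumes the per-frame PickScore values are already computed and arrive as one list of floats per sample; pool_frame_scores is a hypothetical helper, not the function in the repository.

```python
from statistics import fmean


def pool_frame_scores(frame_scores):
    """Mean-pool per-frame rewards into one reward per sample.

    frame_scores[i] holds one score for each frame of sample i.
    """
    pooled = []
    for scores in frame_scores:
        if not scores:
            raise ValueError("sample has no frames to score")
        pooled.append(fmean(scores))
    return pooled


video_rewards = pool_frame_scores([[1.0, 3.0], [0.5]])
```

A single-frame sample pools to its only score, which is why image rollouts keep their earlier rewards.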
…tegies
Add control over the Wan2.2 rollout/training schedule for matching external
recipes and dual-expert training:
- --diffusion-flow-shift S: client computes the per-request sigma schedule and
sends it with each rollout (composing out sgl-d's hardcoded shift via the
multiplicative shift law shift_12(shift_{S/12}(x)) == shift_S(x)); the step
strategies derive phase boundaries from the same effective shift, so routing
and rollout stay consistent. None = server default (12.0).
- --diffusion-sde-candidate-steps "i,j,k": configurable SDE window candidate
set for list-drawing step strategies (replaces a hardcoded default), so it
goes into wandb config and is reproducible.
- wan_ff_global_window: per-epoch global SDE window that reproduces
Flow-Factory's FlowMatchEulerDiscreteSDEScheduler draw (seed = epoch +
rollout_seed) for A/B comparison.
- recipe: pass DIFFUSION_FLOW_SHIFT through.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
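The composition rule quoted for --diffusion-flow-shift can be checked numerically. The sketch assumes shift_s is the usual flow-matching time shift, shift_s(x) = s*x / (1 + (s - 1)*x); the commit message states only the law, not the formula. The linear sigma grid in client_sigmas (1 down to 1/num_steps) and the function names are likewise illustrative assumptions.

```python
def shift(sigma: float, s: float) -> float:
    """Assumed time-shift form: s * sigma / (1 + (s - 1) * sigma)."""
    return s * sigma / (1.0 + (s - 1.0) * sigma)


def client_sigmas(num_steps: int, target_shift: float, server_shift: float = 12.0):
    """Pre-shift a linear grid by target/server so the server's shift lands on target."""
    linear = [1.0 - i / num_steps for i in range(num_steps)]  # assumed grid convention
    return [shift(x, target_shift / server_shift) for x in linear]


# Applying the server's shift 12 on top of the client's S/12 grid gives shift S.
target = 3.0
max_error = max(
    abs(shift(pre, 12.0) - shift(x, target))
    for pre, x in zip(client_sigmas(10, target), [1.0 - i / 10 for i in range(10)])
)
```

Algebraically, shift_a(shift_b(x)) = a*b*x / (1 + (a*b - 1)*x) = shift_{ab}(x), so with a = 12 and b = S/12 the product is S; max_error above should sit at floating-point rounding level.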
Support Wan2.2-T2V-A14B (dual-expert MoE) GRPO training
Adds end-to-end support for RL (GRPO) fine-tuning of Wan2.2-T2V-A14B in
miles-diffusion. Wan2.2's A14B is a two-expert MoE (high-noise
transformer for t≥boundary, low-noise
transformer_2 below), which the existing single-DiT pipeline could not train. Rollout runs on the sgl-d fork.
What's included (feat/wan, 13 commits)
Model adaptation
configs/wan2_2.py + train_pipeline_config.py hooks (target_components, component_for_timestep, select_guidance_scale); single-DiT models keep identical behavior.
actor.py: multi-model load (self.models dict / ModuleDict), per-tile phase routing + per-tile CFG scale, per-component log-prob/output-diff metrics.
diffusion_update_weight_utils.py: per-component weight push to sgl-d.
Rollout / schedule control
--diffusion-flow-shift S: client sends the per-request sigma schedule, composing out sgl-d's hardcoded shift (shift_12(shift_{S/12}(x)) == shift_S(x));
step strategies derive phase boundaries from the same effective shift.
--diffusion-sde-candidate-steps: configurable SDE window candidate set (reproducible, logged to wandb config).
wan_high_window, wan_dual_window (one window per phase), wan_ff_global_window (reproduces Flow-Factory's per-epoch draw for A/B).
Reward
to prior behavior).
Recipe / infra
scripts/run-diffusion-grpo-wan22-pickscore-4gpu.sh, sglang install pin, install skill tweaks.
Validation
floor; per-component weight-sync checksums consistent; post-sync log_prob
alignment holds).
--diffusion-flow-shift verified against --save-debug-rollout-data trajectory timesteps (matches shift-3 grid bit-for-bit).
~8e-6 (well under clip_range 1e-4).
Known limitations / notes for review
disable_adapter) path not implemented.
--micro-batch-size-tstep>1 mixed-phase tiles are asserted out (1×1 is clean).
step 3 at t≈857.7 (low-noise) under shift 3; frameworks using a different
sigma_min convention will route step 3 differently — align grids for any
cross-framework A/B.