[Feature] Add LoRA Inference Support for WAN Models via Flax NNX #308
Conversation
Does LoRA support the I2V pipelines as well?
Added examples of I2V support.
entrpn left a comment:
Can this implementation load multiple LoRAs at once?
Review comment on src/maxdiffusion/models/lora_nnx.py (outdated diff context):

    return jnp.array(v)

def parse_lora_dict(state_dict):
Do you know which LoRA formats are supported by this function? There are a couple of LoRA trainers out there; we might want to specify in a comment or readme which ones we're specifically targeting (diffusers, or others).
Added a comment that it supports the ComfyUI and AI Toolkit LoRA formats.
Now supports multiple LoRAs at once. Example added to the description.
@Perseus14 please squash your commits and make sure the linter tests pass. Other than that, looks good.
Summary
This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.
Unlike previous implementations in this codebase that rely on `flax.linen`, this implementation leverages `flax.nnx`. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features
1. Transition to `flax.nnx`
WAN models in MaxDiffusion are implemented using `flax.nnx`. To support LoRA, we implemented a native NNX loader rather than wrapping `linen` modules.
- The loader traverses the model graph (`nnx.iter_graph`) to identify target layers (`nnx.Linear`, `nnx.Conv`, `nnx.Embed`, `nnx.LayerNorm`) and merges LoRA weights directly into the kernel values. A sketch of this traversal is shown below.
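For illustration only, here is a minimal sketch (not the PR's actual code) of locating the targeted layer types with `nnx.iter_graph`; the helper name `find_lora_targets` and the toy `Block` module are hypothetical:

```python
from flax import nnx

def find_lora_targets(model: nnx.Module):
  """Yield (path, module) pairs for the layer types the loader targets."""
  target_types = (nnx.Linear, nnx.Conv, nnx.Embed, nnx.LayerNorm)
  for path, node in nnx.iter_graph(model):
    if isinstance(node, target_types):
      yield path, node

# Toy example: collect target layers of a small module.
class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(16, 32, rngs=rngs)
    self.norm = nnx.LayerNorm(32, rngs=rngs)

targets = dict(find_lora_targets(Block(nnx.Rngs(0))))
```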
2. Robust Weight Merging Strategy
This implementation solves several critical distributed training/inference challenges:
- On-device merging (`jax.jit`): To avoid `ShardingMismatch` and `DeviceArray` errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (`_compute_and_add_*_jit`). This ensures weight updates occur efficiently on-device across the TPU mesh (see the sketch after this list).
- Efficient tensor conversion: Uses `jax.dlpack` where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead.
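For intuition, a minimal sketch of a kernel + delta merge, assuming the usual LoRA factor layout (`lora_down`: (rank, in), `lora_up`: (out, rank)) and a JAX kernel stored as (in, out); the function name and signature are illustrative, not the PR's `_compute_and_add_*_jit` API:

```python
from functools import partial
import jax

@partial(jax.jit, donate_argnums=(0,))
def merge_single_kernel(kernel, lora_down, lora_up, scale):
  # kernel: (in, out); lora_down: (rank, in); lora_up: (out, rank)
  # delta = (up @ down).T, expressed directly in the (in, out) kernel layout.
  delta = lora_down.T @ lora_up.T
  return kernel + scale * delta

# PyTorch LoRA tensors can be handed to JAX without an extra host copy, e.g.:
#   jax_array = jax.dlpack.from_dlpack(torch_tensor)  # where dlpack is supported
```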
3. Advanced LoRA Support
Beyond standard `Linear` rank reduction, this PR supports:
- Preparation of `diff` weights before device-side merging.
- Difference injections (`diff`, `diff_b`): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes.
- Embeddings and norms: `text_embedding`, `time_embedding`, and `LayerNorm`/`RMSNorm` scales and biases.
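A hedged sketch of how full-parameter and bias offsets could be folded in on-device, assuming a PyTorch-style `diff` of shape (out, in) and `diff_b` of shape (out,); the function name is illustrative:

```python
import jax

@jax.jit
def apply_diff_offsets(kernel, bias, diff, diff_b):
  # kernel: (in, out) JAX layout; diff: (out, in) PyTorch layout; diff_b: (out,)
  return kernel + diff.T, bias + diff_b
```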
4. Scanned vs. Unscanned Layers
MaxDiffusion supports enabling `jax.scan` for transformer layers via the `scan_layers: True` configuration flag. This improves training memory efficiency by stacking the weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:
- Unscanned layers: the `merge_lora()` function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (`_compute_and_add_single_jit`).
- Scanned layers: the `merge_lora_for_scanned()` function is used. It detects which parameters are stacked (e.g., `kernel.ndim > 2`) and which are not.
  - For stacked parameters, it uses `_compute_and_add_scanned_jit`. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer (see the sketch below).
  - For parameters that are not stacked (`embeddings`, `proj_out`), it merges them individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching.
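A minimal sketch of the batched merge for stacked (scanned) kernels, assuming the per-layer LoRA factors are stacked along a leading layer axis; the function name and shapes are assumptions, not the PR's exact `_compute_and_add_scanned_jit`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def merge_scanned_kernels(kernels, downs, ups, scale):
  # kernels: (L, in, out); downs: (L, rank, in); ups: (L, out, rank)
  deltas = jnp.einsum("lri,lor->lio", downs, ups)  # per-layer (in, out) deltas
  return kernels + scale * deltas
```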
Files Added / Modified
- `src/maxdiffusion/models/lora_nnx.py`: [NEW] Core logic. Contains the JIT merge functions, `parse_lora_dict`, and the graph traversal logic (`merge_lora`, `merge_lora_for_scanned`) to inject weights into NNX modules.
- `src/maxdiffusion/loaders/wan_lora_nnx_loader.py`: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions.
- `src/maxdiffusion/generate_wan.py`: Updated the generation pipeline to check whether `lora` is enabled and trigger the loading sequence before inference.
- `src/maxdiffusion/lora_conversion_utils.py`: Updated `translate_wan_nnx_path_to_diffusers_lora` to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.
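For orientation, a hedged sketch of how these pieces could fit together at load time; the wrapper name `load_and_merge_lora`, the argument order, and the exact signatures of `parse_lora_dict`, `merge_lora`, and `merge_lora_for_scanned` are assumptions, not the PR's actual API:

```python
from safetensors import safe_open

from maxdiffusion.models.lora_nnx import merge_lora, merge_lora_for_scanned, parse_lora_dict

def load_and_merge_lora(model, lora_path, scale=1.0, scan_layers=False):
  # Read the LoRA checkpoint as host arrays; the merges themselves run on-device under JIT.
  with safe_open(lora_path, framework="np") as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}
  lora_tree = parse_lora_dict(state_dict)  # group lora_down/lora_up/diff/diff_b per layer
  if scan_layers:
    merge_lora_for_scanned(model, lora_tree, scale)
  else:
    merge_lora(model, lora_tree, scale)
  return model
```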
Testing
Scenario 2: Validation of Multiple LoRA weights
- WAN2.1: distill_lora and divine_power_lora
- WAN2.2: distill_lora and orbit_shot_lora
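As a usage illustration (building on the hypothetical `load_and_merge_lora` wrapper sketched above), multiple LoRAs can be merged sequentially into the same model; the file names and scales below are placeholders:

```python
for path, scale in [("distill_lora.safetensors", 1.0), ("orbit_shot_lora.safetensors", 0.8)]:
  model = load_and_merge_lora(model, path, scale=scale, scan_layers=True)
```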