@prishajain1 prishajain1 commented Jan 26, 2026

Refactor the WAN VAE so that individual encoder and decoder forward passes can be JIT-compiled, by converting mutable cache management to immutable, functional patterns. Only the forward pass through the encoder/decoder network is compiled; the overall encode/decode pipeline (chunking, looping) remains outside JIT.
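The split described above (compiled per-chunk forward pass, plain Python loop around it) can be sketched roughly as follows. This is a toy illustration, not the code in this PR: `forward_chunk`, `run_chunked`, `NUM_LAYERS`, and the 2-tap causal "layer" are all hypothetical stand-ins. The point is only that the cache goes in as an immutable tuple and comes back as a new tuple, so the function stays pure and can be wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 2  # hypothetical depth of the toy network


@jax.jit
def forward_chunk(x, feat_cache):
    """One forward pass over a chunk of frames with shape (time, features).

    feat_cache is an immutable tuple holding, per layer, the last input
    frame seen in the previous chunk. A new tuple is returned; nothing is
    mutated in place.
    """
    new_cache = []
    for prev in feat_cache:  # unrolled at trace time; the tuple length is static
        padded = jnp.concatenate([prev, x], axis=0)  # (1 + time, features)
        new_cache.append(x[-1:])  # this layer's last input frame, for the next chunk
        x = padded[1:] + padded[:-1]  # toy causal layer: x[t] + x[t-1]
    return x, tuple(new_cache)


def run_chunked(frames, chunk_size):
    """Chunking and looping stay in ordinary Python, outside the JIT boundary."""
    cache = tuple(jnp.zeros((1, frames.shape[1])) for _ in range(NUM_LAYERS))
    outputs = []
    for start in range(0, frames.shape[0], chunk_size):
        out, cache = forward_chunk(frames[start:start + chunk_size], cache)
        outputs.append(out)
    return jnp.concatenate(outputs, axis=0)


frames = jnp.arange(12.0).reshape(6, 2)
chunked = run_chunked(frames, chunk_size=2)  # three compiled calls, cache threaded through
whole = run_chunked(frames, chunk_size=6)    # one call over all frames
```

In this toy setup the chunked and single-pass results agree, because the cache carries exactly the frames each layer needs from the previous chunk.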

Files modified:

  • autoencoder_kl_wan.py:
    • Refactored cache management from mutable lists to immutable tuples for JAX compatibility
    • Added PyTree registration for AutoencoderKLWanCache and RepSentinel to enable passing through JIT boundaries
    • Updated internal module signatures to return tuples (output, feat_cache, feat_idx) for explicit data flow
    • Added @nnx.jit decorators to WanEncoder3d and WanDecoder3d forward passes
    • Optimized cache initialization by moving structure calculations to __init__()
  • vae_flax.py:
    • Added WanDiagonalGaussianDistribution subclass with PyTree registration for JIT-safe distribution handling
    • Inherits all functionality from FlaxDiagonalGaussianDistribution with added JAX transformation support
  • base_wan_i2v_14b.yml: configs modified
  • base_wan_i2v_27b.yml: configs modified
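
For readers unfamiliar with PyTree registration, the idea behind the cache and distribution changes listed above can be sketched as below. The classes `FeatureCache` and `DiagonalGaussian` and the function `encode` are hypothetical illustrations, not the PR's `AutoencoderKLWanCache` or `WanDiagonalGaussianDistribution`. The sketch uses `jax.tree_util.register_pytree_node_class` and assumes array fields are leaves while plain Python metadata is static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class FeatureCache:
    """Hypothetical cache container that can cross a jit boundary."""

    def __init__(self, feats, chunk_len):
        self.feats = tuple(feats)   # array leaves, traced under jit
        self.chunk_len = chunk_len  # static metadata, must be hashable

    def tree_flatten(self):
        return self.feats, self.chunk_len  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children, aux_data)


@tree_util.register_pytree_node_class
class DiagonalGaussian:
    """Hypothetical distribution object that a jitted function can return."""

    def __init__(self, mean, logvar):
        self.mean = mean
        self.logvar = logvar

    def sample(self, key):
        std = jnp.exp(0.5 * self.logvar)
        return self.mean + std * jax.random.normal(key, self.mean.shape)

    def tree_flatten(self):
        return (self.mean, self.logvar), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def encode(x, cache):
    # Toy computation: combine the input with the cached features, then
    # split the result into distribution parameters.
    h = x + sum(cache.feats)
    mean, logvar = jnp.split(h, 2, axis=-1)
    new_cache = FeatureCache([x for _ in cache.feats], cache.chunk_len)
    return DiagonalGaussian(mean, logvar), new_cache


cache = FeatureCache([jnp.zeros((1, 4))], chunk_len=1)
dist, cache = encode(jnp.ones((1, 4)), cache)
z = dist.sample(jax.random.PRNGKey(0))
```

Without the registration, `jax.jit` would reject both the `cache` argument and the returned distribution, since it only accepts and returns PyTrees of arrays.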

Tested with:

  • WAN 2.1 T2V inference and training
  • WAN 2.2 T2V inference
  • WAN 2.1 I2V inference
  • WAN 2.2 I2V inference

@prishajain1 prishajain1 requested a review from entrpn as a code owner January 26, 2026 14:55

@prishajain1 prishajain1 force-pushed the prisha/i2v_opt branch 3 times, most recently from a064fbe to 4cbc19c Compare January 27, 2026 12:48
@prishajain1 prishajain1 self-assigned this Jan 27, 2026