From e01c1da932bf214b3b149d251734b9cf2156f1ac Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 27 Jan 2026 19:07:25 +0530 Subject: [PATCH] Enable JIT Compilation of WAN VAE Encoder/Decoder Forward Passes --- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 12 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 10 +- src/maxdiffusion/configs/ltx_video.yml | 2 + src/maxdiffusion/models/vae_flax.py | 28 +++ .../models/wan/autoencoder_kl_wan.py | 223 +++++++++++------- .../tests/ltx_transformer_step_test.py | 2 +- src/maxdiffusion/tests/wan_vae_test.py | 14 +- 7 files changed, 185 insertions(+), 106 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 07a84419..dce63bf1 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -27,7 +27,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-480P-Diffusers' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-720P-Diffusers' model_name: wan2.1 model_type: 'I2V' @@ -280,16 +280,16 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 guidance_scale: 5.0 -flow_shift: 3.0 +flow_shift: 5.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 -fps: 24 +num_inference_steps: 50 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index bcc69e66..1baf8dfd 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -281,10 +281,10 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 -flow_shift: 3.0 +flow_shift: 5.0 # Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py # guidance scale factor for low noise transformer @@ -300,8 +300,8 @@ boundary_ratio: 0.875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 -fps: 24 +num_inference_steps: 50 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 0ab88f70..4328da18 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,6 +1,8 @@ #hardware hardware: 'tpu' skip_jax_distributed_system: False +attention: 'flash' +attention_sharding_uniform: True jax_cache_dir: '' weights_dtype: 'bfloat16' diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index 013087ab..042ec275 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -21,6 +21,7 @@ import flax import flax.linen as nn import jax +from jax import tree_util import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict @@ -930,3 +931,30 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r return (sample,) return FlaxDecoderOutput(sample=sample) + + +class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution): + pass + + +def _wan_diag_gauss_dist_flatten(dist): + return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,) + + +def _wan_diag_gauss_dist_unflatten(aux, children): + mean, logvar, std, var = children + deterministic = aux[0] + obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution) + obj.mean = mean + obj.logvar = logvar + obj.std = std + obj.var = var + obj.deterministic = deterministic + return obj + + +tree_util.register_pytree_node( + WanDiagonalGaussianDistribution, + _wan_diag_gauss_dist_flatten, + _wan_diag_gauss_dist_unflatten, +) diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index b737fb6d..0328f6ac 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -19,11 +19,17 @@ import flax import jax import jax.numpy as jnp +from jax import tree_util from flax import nnx from ...configuration_utils import ConfigMixin from ..modeling_flax_utils import FlaxModelMixin, get_activation from ... import common_types -from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) +from ..vae_flax import ( + FlaxAutoencoderKLOutput, + FlaxDiagonalGaussianDistribution, + FlaxDecoderOutput, + WanDiagonalGaussianDistribution, +) BlockSizes = common_types.BlockSizes @@ -34,6 +40,12 @@ pass +def _update_cache(cache, idx, value): + if cache is None: + return None + return cache[:idx] + (value,) + cache[idx + 1 :] + + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: """Canonicalizes a value to a tuple of integers.""" @@ -45,6 +57,15 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") +class RepSentinel: + + def __eq__(self, other): + return isinstance(other, RepSentinel) + + +tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel()) + + class WanCausalConv3d(nnx.Module): def __init__( @@ -76,6 +97,7 @@ def __init__( # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + self.mesh = mesh # Set sharding dynamically based on out_channels. num_context_axis_devices = mesh.shape["context"] kernel_sharding = (None, None, None, None, None) @@ -121,6 +143,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: x_padded = x + out = self.conv(x_padded) return out @@ -312,30 +335,30 @@ def __init__( else: self.resample = Identity() - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # Input x: (N, D, H, W, C), assume C = self.dim b, t, h, w, c = x.shape assert c == self.dim if self.mode == "upsample3d": if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, RepSentinel()) + feat_idx += 1 else: cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and not isinstance(feat_cache[idx], RepSentinel): # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel): cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) - if feat_cache[idx] == "Rep": + if isinstance(feat_cache[idx], RepSentinel): x = self.time_conv(x) else: x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 x = x.reshape(b, t, h, w, 2, c) x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1) x = x.reshape(b, t * 2, h, w, c) @@ -347,17 +370,17 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: if self.mode == "downsample3d": if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx if feat_cache[idx] is None: - feat_cache[idx] = jnp.copy(x) - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, jnp.copy(x)) + feat_idx += 1 else: cache_x = jnp.copy(x[:, -1:, :, :, :]) x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 - return x + return x, feat_cache, feat_idx class WanResidualBlock(nnx.Module): @@ -416,7 +439,7 @@ def __init__( else Identity() ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # Apply shortcut connection h = self.conv_shortcut(x) @@ -424,32 +447,31 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx], idx) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv1(x) x = self.norm2(x) x = self.nonlinearity(x) - idx = feat_idx[0] if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv2(x) x = x + h - return x + return x, feat_cache, feat_idx class WanAttentionBlock(nnx.Module): @@ -486,7 +508,7 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array): + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): identity = x batch_size, time, height, width, channels = x.shape @@ -509,7 +531,7 @@ def __call__(self, x: jax.Array): # Reshape back x = x.reshape(batch_size, time, height, width, channels) - return x + identity + return x + identity, feat_cache, feat_idx class WanMidBlock(nnx.Module): @@ -561,13 +583,13 @@ def __init__( self.attentions = nnx.data(attentions) self.resnets = nnx.data(resnets) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.resnets[0](x, feat_cache, feat_idx) + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): + x, feat_cache, feat_idx = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - x = attn(x) - x = resnet(x, feat_cache, feat_idx) - return x + x, feat_cache, feat_idx = attn(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) + return x, feat_cache, feat_idx class WanUpBlock(nnx.Module): @@ -622,19 +644,13 @@ def __init__( ) ] - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): for resnet in self.resnets: - if feat_cache is not None: - x = resnet(x, feat_cache, feat_idx) - else: - x = resnet(x) + x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) if self.upsamplers is not None: - if feat_cache is not None: - x = self.upsamplers[0](x, feat_cache, feat_idx) - else: - x = self.upsamplers[0](x) - return x + x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx) + return x, feat_cache, feat_idx class WanEncoder3d(nnx.Module): @@ -743,40 +759,38 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + @nnx.jit(static_argnames="feat_idx") + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_in(x) for layer in self.down_blocks: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) - x = self.mid_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = self.mid_block(x, feat_cache, feat_idx) x = self.norm_out(x) x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_out(x) - return x + return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32) class WanDecoder3d(nnx.Module): @@ -894,50 +908,47 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + @nnx.jit(static_argnames="feat_idx") + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_in(x) ## middle - x = self.mid_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = self.mid_block(x, feat_cache, feat_idx) ## upsamples for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = up_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_out(x) - return x + return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32) class AutoencoderKLWanCache: def __init__(self, module): self.module = module - self.clear_cache() - - def clear_cache(self): - """Resets cache dictionaries and indices""" def _count_conv3d(module): count = 0 @@ -948,12 +959,38 @@ def _count_conv3d(module): return count self._conv_num = _count_conv3d(self.module.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode self._enc_conv_num = _count_conv3d(self.module.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num + self.init_cache() + + def init_cache(self): + """Resets cache dictionaries and indices""" + self._feat_map = (None,) * self._conv_num + # cache encode + self._enc_feat_map = (None,) * self._enc_conv_num + + +def _wan_cache_flatten(cache): + return (cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num) + + +def _wan_cache_unflatten(aux, children): + conv_num, enc_conv_num = aux + feat_map, enc_feat_map = children + # Create a dummy object or one without module reference for JIT internal use + # We can't easily reconstruct 'module' but we don't need it for init_cache anymore + # if we store counts in aux. + # However, __init__ expects module. + # We will bypass __init__ for unflattening. + obj = AutoencoderKLWanCache.__new__(AutoencoderKLWanCache) + obj._conv_num = conv_num + obj._enc_conv_num = enc_conv_num + obj._feat_map = feat_map + obj._enc_feat_map = enc_feat_map + obj.module = None # module is not needed inside the trace for the cache logic now + return obj + + +tree_util.register_pytree_node(AutoencoderKLWanCache, _wan_cache_flatten, _wan_cache_unflatten) class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -1067,7 +1104,7 @@ def __init__( ) def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): - feat_cache.clear_cache() + feat_cache.init_cache() if x.shape[-1] != 3: # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) @@ -1075,21 +1112,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): t = x.shape[1] iter_ = 1 + (t - 1) // 4 + enc_feat_map = feat_cache._enc_feat_map + for i in range(iter_): - feat_cache._enc_conv_idx = [0] + enc_conv_idx = 0 if i == 0: - out = self.encoder(x[:, :1, :, :, :], feat_cache=feat_cache._enc_feat_map, feat_idx=feat_cache._enc_conv_idx) + out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx) else: - out_ = self.encoder( + out_, enc_feat_map, enc_conv_idx = self.encoder( x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], - feat_cache=feat_cache._enc_feat_map, - feat_idx=feat_cache._enc_conv_idx, + feat_cache=enc_feat_map, + feat_idx=enc_conv_idx, ) out = jnp.concatenate([out, out_], axis=1) + + # Update back to the wrapper object if needed, but for result we use local vars + feat_cache._enc_feat_map = enc_feat_map + enc = self.quant_conv(out) mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] enc = jnp.concatenate([mu, logvar], axis=-1) - feat_cache.clear_cache() + feat_cache.init_cache() return enc def encode( @@ -1097,7 +1140,7 @@ def encode( ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """Encode video into latent distribution.""" h = self._encode(x, feat_cache) - posterior = FlaxDiagonalGaussianDistribution(h) + posterior = WanDiagonalGaussianDistribution(h) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) @@ -1105,15 +1148,18 @@ def encode( def _decode( self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxDecoderOutput, jax.Array]: - feat_cache.clear_cache() + feat_cache.init_cache() iter_ = z.shape[1] x = self.post_quant_conv(z) + + dec_feat_map = feat_cache._feat_map + for i in range(iter_): - feat_cache._conv_idx = [0] + conv_idx = 0 if i == 0: - out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) + out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) else: - out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) + out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) # This is to bypass an issue where frame[1] should be frame[2] and vise versa. # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. @@ -1131,8 +1177,11 @@ def _decode( fm3 = jnp.expand_dims(fm3, axis=axis) fm4 = jnp.expand_dims(fm4, axis=axis) out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) + + feat_cache._feat_map = dec_feat_map + out = jnp.clip(out, min=-1.0, max=1.0) - feat_cache.clear_cache() + feat_cache.init_cache() if not return_dict: return (out,) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 083ed265..c868bd95 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -104,7 +104,7 @@ def test_one_step_transformer(self): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + config_path = os.path.join(base_dir, "../models/ltx_video/ltxv-13B.json") with open(config_path, "r") as f: model_config = json.load(f) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0f9158cb..73db7173 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -266,7 +266,7 @@ def test_wan_resample(self): # channels is always last here input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) - output = wan_resample(dummy_input) + output, _, _ = wan_resample(dummy_input) assert output.shape == (batch, t, h // 2, w // 2, dim) def test_3d_conv(self): @@ -347,7 +347,7 @@ def test_wan_residual(self): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) + dummy_output, _, _ = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape # --- Test Case 1: different in/out dim --- in_dim = 96 @@ -356,7 +356,7 @@ def test_wan_residual(self): wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) + dummy_output, _, _ = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape def test_wan_attention(self): @@ -371,7 +371,7 @@ def test_wan_attention(self): with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) - output = wan_attention(dummy_input) + output, _, _ = wan_attention(dummy_input) assert output.shape == input_shape def test_wan_midblock(self): @@ -396,7 +396,7 @@ def test_wan_midblock(self): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - output = wan_midblock(dummy_input) + output, _, _ = wan_midblock(dummy_input) assert output.shape == input_shape def test_wan_decode(self): @@ -522,11 +522,11 @@ def vae_encode(video, wan_vae, vae_cache, key): params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) wan_vae = nnx.merge(graphdef, params) - p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)) + p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key) original_video_shape = original_video.shape latent = p_vae_encode(original_video) - jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)) + jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False) video = jitted_decode(latent)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) assert video.shape == original_video_shape