Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ gcs_metrics: False
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-480P-Diffusers'
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-720P-Diffusers'
model_name: wan2.1
model_type: 'I2V'

Expand Down Expand Up @@ -280,16 +280,16 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 480
width: 832
height: 720
width: 1280
num_frames: 81
guidance_scale: 5.0
flow_shift: 3.0
flow_shift: 5.0

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
num_inference_steps: 50
fps: 16
save_final_checkpoint: False

# SDXL Lightning parameters
Expand Down
10 changes: 5 additions & 5 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 480
width: 832
height: 720
width: 1280
num_frames: 81
flow_shift: 3.0
flow_shift: 5.0

# Reference for the guidance scale and boundary values below: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
# guidance scale factor for low noise transformer
Expand All @@ -300,8 +300,8 @@ boundary_ratio: 0.875

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
num_inference_steps: 50
fps: 16
save_final_checkpoint: False

# SDXL Lightning parameters
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True

jax_cache_dir: ''
weights_dtype: 'bfloat16'
Expand Down
28 changes: 28 additions & 0 deletions src/maxdiffusion/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import flax
import flax.linen as nn
import jax
from jax import tree_util
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

Expand Down Expand Up @@ -930,3 +931,30 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r
return (sample,)

return FlaxDecoderOutput(sample=sample)


class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
  """Distinct subclass of `FlaxDiagonalGaussianDistribution` with no added behavior.

  The class body defines nothing of its own; every attribute and method comes
  from the parent. This module registers the type as a JAX pytree node via
  `tree_util.register_pytree_node`.
  """


def _wan_diag_gauss_dist_flatten(dist):
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)


def _wan_diag_gauss_dist_unflatten(aux, children):
  """Rebuild a `WanDiagonalGaussianDistribution` from flattened pytree parts.

  Args:
    aux: Indexable auxiliary data whose first element is the `deterministic`
      flag.
    children: Iterable of exactly four items, in order
      `(mean, logvar, std, var)`.

  Returns:
    A `WanDiagonalGaussianDistribution` with the five attributes set directly.
  """
  mean, logvar, std, var = children
  deterministic = aux[0]
  # Allocate through __new__ so __init__ is not run; the stored values are
  # attached as-is instead of being recomputed.
  restored = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
  attributes = (
      ("mean", mean),
      ("logvar", logvar),
      ("std", std),
      ("var", var),
      ("deterministic", deterministic),
  )
  for attr_name, attr_value in attributes:
    setattr(restored, attr_name, attr_value)
  return restored


# Register the distribution type as a JAX pytree node, supplying the flatten
# and unflatten functions defined above.
tree_util.register_pytree_node(
    nodetype=WanDiagonalGaussianDistribution,
    flatten_func=_wan_diag_gauss_dist_flatten,
    unflatten_func=_wan_diag_gauss_dist_unflatten,
)
Loading
Loading