104 changes: 70 additions & 34 deletions README.md
limitations under the License.
-->

[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2026/01/29`**: Wan LoRA for inference is now supported
- **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
- **`2025/08/14`**: LTX-Video img2vid generation is now supported.
- **`2025/07/29`**: LTX-Video text2vid generation is now supported.
- **`2025/04/17`**: Flux Finetuning.
- **`2025/02/12`**: Flux LoRA for inference.
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/08/01`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/07/20`**: Dreambooth training for Stable Diffusion 1.x and 2.x is now supported.

# Overview

MaxDiffusion supports
- [SD 1.4](#stable-diffusion-14-training)
- [Dreambooth](#dreambooth)
- [Inference](#inference)
- [Wan](#wan-models)
- [LTX-Video](#ltx-video)
- [Flux](#flux)
- [Fused Attention for GPU](#fused-attention-for-gpu)
- [SDXL](#stable-diffusion-xl)
- [SD 2 base](#stable-diffusion-2-base)
- [SD 2.1](#stable-diffusion-21)
- [Wan LoRA](#wan-lora)
- [Flux LoRA](#flux-lora)
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
- [Load Multiple LoRA](#load-multiple-lora)
- [SDXL Lightning](#sdxl-lightning)

Add the conditioning image path as `conditioning_media_paths` in the form `["IMAGE_PATH"]`, along with the other generation parameters, in the `ltx_video.yml` file. Then follow the same instructions as above.
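
For instance, a minimal sketch of the relevant entry in `ltx_video.yml` — the field name comes from the text above, and the path is a placeholder to replace with your own:

```yaml
# src/maxdiffusion/configs/ltx_video.yml (snippet)
conditioning_media_paths: ["IMAGE_PATH"]  # replace with the path to your conditioning image
```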

## Wan Models

Although not required, attaching an external disk is recommended, as the weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

Supports both Text2Vid and Img2Vid pipelines.

The following command will run Wan2.1 T2V:

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention="flash" \
num_inference_steps=50 \
num_frames=81 \
width=1280 \
height=720 \
jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \
per_device_batch_size=.125 \
ici_data_parallelism=2 \
ici_context_parallelism=2 \
flow_shift=5.0 \
enable_profiler=True \
run_name=wan-inference-testing-720p \
output_dir=gs://jfacevedo-maxdiffusion \
fps=16 \
flash_min_seq_length=0 \
flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \
seed=118445
```

To run other Wan model inference pipelines, change the config file in the command above:

* For Wan2.1 I2V, use `base_wan_i2v_14b.yml`.
* For Wan2.2 T2V, use `base_wan_27b.yml`.
* For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.
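
For example, a sketch of the Wan2.2 I2V invocation, reusing flag values from this README's other Wan examples. The `HF_HUB_CACHE` and `LIBTPU_INIT_ARGS` environment variables from the T2V command above are assumed to be set, and flags not shown (e.g. `flash_block_sizes`) fall back to their defaults:

```bash
# Sketch only: same environment as the T2V command above; the config file and
# a few generation flags change for Wan2.2 I2V.
HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
  attention="flash" \
  num_inference_steps=30 \
  num_frames=81 \
  width=832 \
  height=480 \
  flow_shift=3.0 \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_context_parallelism=2 \
  run_name=wan-i2v-inference-testing-480p \
  output_dir=gs://jfacevedo-maxdiffusion \
  fps=16 \
  seed=118445
```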

## Flux

To generate images, run the following command:
```bash
NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
```
## Wan LoRA

Disclaimer: not all LoRA formats have been tested. ComfyUI and AI Toolkit formats are currently supported. If a specific LoRA doesn't load, please let us know.

First, create a copy of the relevant config file, e.g. `src/maxdiffusion/configs/base_wan_{*}.yml`. Update the prompt and LoRA details in the config, and make sure to set `enable_lora: True`.
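
For reference, a minimal sketch of the LoRA-related entries, modeled on the list-valued `lora_config` block used by this repo's other LoRA examples — treat the field names and values as illustrative and verify them against your copied config:

```yaml
# Illustrative sketch only; verify the exact schema in your copied config.
enable_lora: True
lora_config:
  lora_model_name_or_path: ["hf-org/your-wan-lora"]  # hypothetical HF repo or local path
  weight_name: ["your_lora.safetensors"]             # hypothetical weight file name
  adapter_name: ["my_adapter"]
  scale: [0.8]
  from_pt: [True]
```

Then run the following command: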

```bash
# Replace base_wan_i2v_14b.yml below with your copied config file.
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_14b.yml \
jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \
per_device_batch_size=.125 \
ici_data_parallelism=2 \
ici_context_parallelism=2 \
run_name=wan-lora-inference-testing-720p \
output_dir=gs://jfacevedo-maxdiffusion \
seed=118445 \
enable_lora=True
```

Loading multiple LoRAs is supported as well.
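
With the same illustrative schema as above, loading two LoRAs would be a matter of appending one entry per LoRA to each list (names and scales here are hypothetical):

```yaml
# Illustrative only; one entry per LoRA in each list.
lora_config:
  lora_model_name_or_path: ["hf-org/wan-lora-style", "hf-org/wan-lora-motion"]
  weight_name: ["style.safetensors", "motion.safetensors"]
  adapter_name: ["style", "motion"]
  scale: [0.8, 0.5]
  from_pt: [True, True]
```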

## Flux LoRA

4 changes: 2 additions & 2 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
profiler_steps: 10
enable_jax_named_scopes: False

# Generation parameters
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose."
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 720
4 changes: 2 additions & 2 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
profiler_steps: 10
enable_jax_named_scopes: False

# Generation parameters
prompt: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt_2: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 720