From 54afa6501af00560070419d16151de0b8a451614 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Thu, 2 Apr 2026 01:14:57 +0000 Subject: [PATCH 1/6] [feat] JoyAI-JoyImage-Edit support --- scripts/convert_joyimage_edit_to_diffusers.py | 306 +++++ setup.py | 1 + src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_joyimage.py | 658 +++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/joyimage/__init__.py | 49 + .../joyimage/pipeline_joyimage_edit.py | 1187 +++++++++++++++++ .../pipelines/joyimage/pipeline_output.py | 16 + 10 files changed, 2226 insertions(+) create mode 100644 scripts/convert_joyimage_edit_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_joyimage.py create mode 100644 src/diffusers/pipelines/joyimage/__init__.py create mode 100644 src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py create mode 100644 src/diffusers/pipelines/joyimage/pipeline_output.py diff --git a/scripts/convert_joyimage_edit_to_diffusers.py b/scripts/convert_joyimage_edit_to_diffusers.py new file mode 100644 index 000000000000..37506ea05d17 --- /dev/null +++ b/scripts/convert_joyimage_edit_to_diffusers.py @@ -0,0 +1,306 @@ +import argparse +import pathlib +from typing import Any, Dict, Tuple +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from safetensors.torch import load_file +from diffusers import ( + AutoencoderKLWan, + JoyImageEditTransformer3DModel, + JoyImageEditPipeline, +) +# This code is modified from convert_wan_to_diffusers.py to support input ckpt path +def convert_vae(vae_ckpt_path): + old_state_dict = 
torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + 
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in old_state_dict.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + new_state_dict[new_key] = value + # Handle attention blocks 
using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + new_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + new_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + new_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + new_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + new_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to 
up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + new_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + new_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + new_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." 
in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + # Keep other keys unchanged + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + +def get_transformer_config() -> Tuple[Dict[str, Any], ...]: + config = { + "diffusers_config": { + "hidden_size": 4096, + "in_channels": 16, + "heads_num": 32, + "mm_double_blocks_depth": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_states_dim": 4096, + "rope_type": "rope", + "dit_modulation_type": "wanx", + "unpatchify_new": True, + "rope_theta": 10000, + }, + } + return config +def convert_transformer(ckpt_path: str): + checkpoint = torch.load(ckpt_path, weights_only=True) + if "model" in checkpoint: + original_state_dict = checkpoint["model"] + else: + original_state_dict = checkpoint + config = get_transformer_config() + with init_empty_weights(): + transformer = JoyImageEditTransformer3DModel(**config['diffusers_config']) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") + 
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + parser.add_argument("--flow_shift", type=float, default=7.0) + return parser.parse_args() + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} +if __name__ == "__main__": + args = get_args() + transformer = None + vae = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + # assert args.tokenizer_path is not None + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + vae = vae.to(dtype=dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.save_pipeline: + processor = AutoProcessor.from_pretrained(args.text_encoder_path) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) + flow_shift = 1.5 + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=flow_shift + ) + transformer = transformer.to("cuda") + vae = vae.to("cuda") + pipe = JoyImageEditPipeline( + 
processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + processor.save_pretrained(f"{args.output_path}/processor") \ No newline at end of file diff --git a/setup.py b/setup.py index d42da57920a0..e16a2b792e25 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,7 @@ "accelerate>=0.31.0", "compel==0.1.8", "datasets", + "einops", "filelock", "flax>=0.4.1", "ftfy", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7d966452d1a2..0caedd8d7b81 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -237,6 +237,7 @@ "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "HunyuanImageTransformer2DModel", + "JoyImageEditTransformer3DModel", "HunyuanVideo15Transformer3DModel", "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", @@ -596,6 +597,7 @@ "LTXLatentUpsamplePipeline", "LTXPipeline", "LucyEditPipeline", + "JoyImageEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", "LuminaPipeline", @@ -1025,6 +1027,7 @@ HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, + JoyImageEditTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, @@ -1359,6 +1362,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, LucyEditPipeline, + JoyImageEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, LuminaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..6364ce0ab78f 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -110,6 +110,7 @@ _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] 
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] + _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel", "JoyImageTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -225,6 +226,7 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + JoyImageEditTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..1fc78c124618 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel + from .transformer_joyimage import JoyImageEditTransformer3DModel, JoyImageTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py new file mode 100644 index 000000000000..c29284b3d8f7 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -0,0 +1,658 @@ +import math +from types import SimpleNamespace + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...configuration_utils import 
ConfigMixin, register_to_config +from ..attention import FeedForward +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + +ATTN_BACKEND = 'sdpa' +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + ATTN_BACKEND = 'flash_attn' +except: + pass + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + if len(x) == dim: + return tuple(x) + raise ValueError(f"Expected length {dim} or int, but got {x}") + +def get_meshgrid_nd(start, *args, dim=2): + if len(args) == 0: + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = _to_tuple(args[1], dim=dim) + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + return torch.stack(grid, dim=0) + +def reshape_for_broadcast(freqs_cis, x: torch.Tensor, head_first: bool = False): + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]) + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + 
else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x: torch.Tensor): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis, head_first: bool = False): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + + +def get_1d_rotary_pos_embed( + dim: int, + pos, + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +): + if isinstance(pos, int): + pos = torch.arange(pos).float() + + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + freqs = torch.outer(pos.float() * interpolation_factor, freqs) + + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) + return freqs_cos, freqs_sin + + return torch.polar(torch.ones_like(freqs), freqs) + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + txt_rope_size=None, + theta_rescale_factor=1.0, + interpolation_factor=1.0, +): + rope_dim_list = list(rope_dim_list) + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) + + if isinstance(theta_rescale_factor, (int, float)): + theta_rescale_factor = [float(theta_rescale_factor)] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [float(theta_rescale_factor[0])] * 
len(rope_dim_list) + + if isinstance(interpolation_factor, (int, float)): + interpolation_factor = [float(interpolation_factor)] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [float(interpolation_factor[0])] * len(rope_dim_list) + + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs.append(emb) + + if use_real: + vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1)) + else: + vis_emb = torch.cat(embs, dim=1) + + if txt_rope_size is None: + return vis_emb, None + + embs_txt = [] + vis_max_ids = grid.view(-1).max().item() + grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1 + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid_txt, + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs_txt.append(emb) + + if use_real: + txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1)) + else: + txt_emb = torch.cat(embs_txt, dim=1) + + return vis_emb, txt_emb + +def get_cu_seqlens(text_mask, img_len): + """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len + + Args: + text_mask (torch.Tensor): the mask of text + img_len (int): the length of image + + Returns: + torch.Tensor: the calculated cu_seqlens for flash attention + """ + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], + dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 
+ cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + +def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + if modulate_type == "wanx": + return ModulateWan(hidden_size, factor, **factory_kwargs) + if modulate_type == "adaLN": + return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs) + if modulate_type == "jdx": + return ModulateX(hidden_size, factor, **factory_kwargs) + raise ValueError(f"Unknown modulation type: {modulate_type}.") + + +class ModulateDiT(nn.Module): + def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.factor = factor + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor): + return self.linear(self.act(x)).chunk(self.factor, dim=-1) + + +class ModulateWan(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, requires_grad=True + ) + + def forward(self, x: torch.Tensor): + if len(x.shape) != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + +class ModulateX(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + + def forward(self, x: torch.Tensor): + if len(x.shape) != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)] + + +def modulate(x, shift=None, scale=None): + if scale is None and shift is None: + return x + if shift is None: + return x * (1 + 
scale.unsqueeze(1)) + if scale is None: + return x + shift.unsqueeze(1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + if gate is None: + return x + return x * (gate.unsqueeze(1).tanh() if tanh else gate.unsqueeze(1)) + +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_kwargs=None): + batch_size = q.shape[0] + if ATTN_BACKEND == 'sdpa': + q = rearrange(q, "b l h c -> b h l c") + k = rearrange(k, "b l h c -> b h l c") + v = rearrange(v, "b l h c -> b h l c") + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_kwargs['attn_mask']) + output = rearrange(output, "b h l c -> b l h c") + elif ATTN_BACKEND == 'flash_attn': + cu_seqlens_q = attn_kwargs['cu_seqlens_q'] + cu_seqlens_kv = attn_kwargs['cu_seqlens_kv'] + max_seqlen_q = attn_kwargs['max_seqlen_q'] + max_seqlen_kv = attn_kwargs['max_seqlen_kv'] + x = flash_attn_varlen_func( + q.view(q.shape[0] * q.shape[1], *q.shape[2:]), + k.view(k.shape[0] * k.shape[1], *k.shape[2:]), + v.view(v.shape[0] * v.shape[1], *v.shape[2:]), + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ) + # x with shape [(bxs), a, d] + output = x.view( + batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] + ) # reshape + return output + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +class MMDoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + 
dtype=None, + device=None, + dit_modulation_type: str = "wanx", + attn_backend: str = "torch_spda", + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_backend = attn_backend + self.dit_modulation_type = dit_modulation_type + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) + self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + self.txt_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) + self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, 
vis_freqs_cis=None, txt_freqs_cis=None, attn_kwargs=None): + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(vec) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(vec) + + img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale) + img_qkv = self.img_attn_qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + if vis_freqs_cis is not None: + img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False) + + txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale) + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + if txt_freqs_cis is not None: + txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False) + + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + v = torch.cat((img_v, txt_v), dim=1) + + attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3) + img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + img = img + apply_gate( + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), + gate=img_mod2_gate, + ) + + txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + txt = txt + apply_gate( + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), + gate=txt_mod2_gate, + ) + + return img, txt + + +class WanTimeTextImageEmbedding(nn.Module): + def 
class JoyImageTransformer3DModel(ModelMixin, ConfigMixin):
    """
    3-D diffusion transformer for JoyImage built from double-stream (image/text)
    MM-DiT blocks, with rotary position embeddings over a (t, h, w) latent grid.

    Fixes vs. the original:
    - ``get_rotary_pos_embed`` had an unreachable ``rope_dim_list is None``
      fallback (``list(self.rope_dim_list)`` ran first, so ``list(None)`` would
      raise ``TypeError`` before the check).
    - ``forward`` contained a no-op ``torch.cat((img, txt), 1)`` followed by
      slicing the ``img`` part straight back out; removed.
    - ``max_seqlen_q`` was computed and written into ``attn_kwargs`` twice;
      consolidated.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: tuple[int, int, int] = (1, 2, 2),
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        heads_num: int = 32,
        text_states_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        mm_double_blocks_depth: int = 40,
        rope_dim_list: tuple[int, int, int] = (16, 56, 56),
        rope_type: str = "rope",
        dit_modulation_type: str = "wanx",
        attn_backend: str = "torch_spda",
        unpatchify_new: bool = True,
        rope_theta: int = 256,
        enable_activation_checkpointing: bool = False,
        is_repa: bool = False,
        repa_layer: int = 13,
    ):
        super().__init__()

        # Small namespace retained for compatibility with training-side code.
        self.args = SimpleNamespace(
            enable_activation_checkpointing=enable_activation_checkpointing,
            is_repa=is_repa,
            repa_layer=repa_layer,
        )

        self.out_channels = out_channels or in_channels
        self.patch_size = tuple(patch_size)
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        # Keep None as None so get_rotary_pos_embed's per-axis fallback is usable.
        self.rope_dim_list = tuple(rope_dim_list) if rope_dim_list is not None else None
        self.dit_modulation_type = dit_modulation_type
        self.mm_double_blocks_depth = mm_double_blocks_depth
        self.attn_backend = attn_backend
        self.rope_type = rope_type
        self.unpatchify_new = unpatchify_new
        self.theta = rope_theta

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")

        # Patchify: Conv3d with kernel == stride == patch_size.
        self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=self.patch_size, stride=self.patch_size)

        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            # Non-adaLN modulation needs 6 chunks (shift/scale/gate x2 streams).
            time_proj_dim=hidden_size * 6 if dit_modulation_type != "adaLN" else hidden_size,
            text_embed_dim=text_states_dim,
        )

        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    hidden_size=self.hidden_size,
                    heads_num=self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    dit_modulation_type=self.dit_modulation_type,
                    attn_backend=attn_backend,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(hidden_size, out_channels * math.prod(self.patch_size))

        if self.args.is_repa:
            # REPA feature projection from an intermediate block's image stream.
            self.repa_proj = nn.Linear(hidden_size, text_states_dim)
            if self.args.repa_layer > mm_double_blocks_depth:
                raise ValueError("repa_layer should be smaller than total depth")

        self.gradient_checkpointing = enable_activation_checkpointing

    @property
    def device(self) -> torch.device:
        """Device of the model parameters."""
        return next(self.parameters()).device

    def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
        """Build rotary embeddings for a (t, h, w) latent grid, plus an optional text axis.

        Args:
            vis_rope_size: Iterable of up to three grid sizes; left-padded with 1s.
            txt_rope_size: Text sequence length for mrope-style text rotary, or None.
        """
        target_ndim = 3
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + list(vis_rope_size)

        head_dim = self.hidden_size // self.heads_num
        # Fix: check for None BEFORE converting; the original called
        # list(self.rope_dim_list) first, making the fallback unreachable.
        if self.rope_dim_list is not None:
            rope_dim_list = list(self.rope_dim_list)
        else:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        return get_nd_rotary_pos_embed(
            rope_dim_list,
            vis_rope_size,
            txt_rope_size=txt_rope_size,
            theta=self.theta,
            use_real=True,
            theta_rescale_factor=1,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        return_dict: bool = True,
    ):
        """Denoise latents conditioned on text.

        Args:
            hidden_states: (B, C, T, H, W) latents, or (B, N, C, T, H, W) for
                multi-item (reference + target) input.
            timestep: Diffusion timestep(s).
            encoder_hidden_states: Text embeddings (required).
            encoder_hidden_states_mask: Optional (B, L) text validity mask.
            return_dict: Return a Transformer2DModelOutput when True, else a tuple.
        """
        if encoder_hidden_states is None:
            raise ValueError("encoder_hidden_states is required.")

        is_multi_item = len(hidden_states.shape) == 6
        num_items = 0
        if is_multi_item:
            num_items = hidden_states.shape[1]
            if num_items > 1:
                if self.patch_size[0] != 1:
                    raise ValueError("For multi-item input, patch_size[0] must be 1")
                # Move the target item to the front; restored after unpatchify.
                hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
            # Fold items into the temporal axis so Conv3d sees 5-D input.
            hidden_states = rearrange(hidden_states, "b n c t h w -> b c (n t) h w")

        _, _, ot, oh, ow = hidden_states.shape
        tt = ot // self.patch_size[0]
        th = oh // self.patch_size[1]
        tw = ow // self.patch_size[2]

        if encoder_hidden_states_mask is None:
            encoder_hidden_states_mask = torch.ones(
                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
                dtype=torch.bool,
                device=encoder_hidden_states.device,
            )

        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Non-adaLN modulation: split into 6 (shift/scale/gate x 2) chunks.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]
        vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
            vis_rope_size=(tt, th, tw),
            txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
        )

        img_seq_len = img.shape[1]
        cu_seqlens_q = get_cu_seqlens(encoder_hidden_states_mask, img_seq_len)
        max_seqlen_q = img_seq_len + txt_seq_len
        # Sanity check: padded text length must match the mask width.
        assert max_seqlen_q == encoder_hidden_states_mask.shape[1] + img_seq_len

        # Dense attention mask: position p is valid when p < img_seq_len + n_text(b).
        seq_lens = encoder_hidden_states_mask.sum(dim=1) + img_seq_len
        positions = torch.arange(max_seqlen_q, device=img.device).unsqueeze(0)
        valid = (positions < seq_lens.unsqueeze(1)).unsqueeze(1).unsqueeze(2)
        attn_mask = valid & valid.transpose(-1, -2)

        attn_kwargs = {
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "cu_seqlens_q": cu_seqlens_q,
            "cu_seqlens_kv": cu_seqlens_q,
            "max_seqlen_q": max_seqlen_q,
            "max_seqlen_kv": max_seqlen_q,
            "attn_mask": attn_mask,
        }

        img_hidden_states = []
        for block in self.double_blocks:
            img, txt = block(img, txt, vec, vis_freqs_cis, txt_freqs_cis, attn_kwargs)
            img_hidden_states.append(img)

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)

        repa_hidden_state = None
        if self.args.is_repa:
            repa_hidden_state = self.repa_proj(img_hidden_states[self.args.repa_layer])
            repa_hidden_state = repa_hidden_state.view(img.shape[0], tt, th, tw, -1)

        if is_multi_item:
            img = rearrange(img, "b c (n t) h w -> b n c t h w", n=num_items)
            if num_items > 1:
                # Undo the target-first rotation applied on input.
                img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
            if repa_hidden_state is not None:
                repa_hidden_state = rearrange(repa_hidden_state, "b (n t) h w c -> b n t h w c", n=num_items)

        if not return_dict:
            return (img, txt, repa_hidden_state)

        return Transformer2DModelOutput(sample=img)

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int):
        """Invert patchification: (B, t*h*w, C*pt*ph*pw) -> (B, C, t*pt, h*ph, w*pw)."""
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError("Invalid token length for unpatchify.")

        # Two layouts are supported for checkpoint compatibility: channels-last
        # (unpatchify_new) vs channels-before-patch in the flattened token.
        if self.unpatchify_new:
            x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
            x = torch.einsum("nthwopqc->nctohpwq", x)
        else:
            x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
            x = torch.einsum("nthwcopq->nctohpwq", x)

        return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
+ """ + + pass \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..d9159be6156b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -299,6 +299,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] + _import_structure["joyimage"] = ["JoyImageEditPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -740,6 +741,7 @@ ) from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline + from .joyimage import JoyImageEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py new file mode 100644 index 000000000000..a6d5f31fe63c --- /dev/null +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not 
(is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyimage_edit import JoyImageEditPipeline + from .pipeline_output import JoyImageEditPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py new file mode 100644 index 000000000000..190cb3d90c31 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -0,0 +1,1187 @@ +import inspect +import math +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torchvision.transforms.functional as TF +from einops import rearrange +from PIL import Image +from transformers import AutoProcessor, Qwen2Tokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLWan, JoyImageEditTransformer3DModel +from ..pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import BaseOutput, deprecate, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import JoyImageEditPipelineOutput + + +EXAMPLE_DOC_STRING = """""" + +# Mapping from precision string to torch dtype. 
class BucketGroup:
    """Selects the best-fitting predefined bucket for an image's dimensions."""

    def __init__(
        self,
        bucket_configs: list[tuple[int, int, int, int, int]],
        prioritize_frame_matching: bool = True,
    ):
        """
        Store the candidate buckets.

        Args:
            bucket_configs: Iterable of (batch_size, num_items, num_frames, height, width).
            prioritize_frame_matching: Unused, kept for API compatibility.
        """
        # Normalise every config to a tuple so comparisons are cheap and hashable.
        self.bucket_configs = [tuple(cfg) for cfg in bucket_configs]

    def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]:
        """
        Return the bucket whose height/width ratio is closest to the input's.

        Only image inference (num_frames == 1) is supported; candidate buckets
        must also match the requested number of items.

        Args:
            media_shape: (num_items, num_frames, height, width) of the input.

        Returns:
            The best (batch_size, num_items, num_frames, height, width) bucket.

        Raises:
            ValueError: If num_frames != 1 or no candidate bucket matches.
        """
        num_items, num_frames, height, width = media_shape
        target_aspect_ratio = height / width

        if num_frames != 1:
            raise ValueError(
                f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}"
            )

        candidates = [
            cfg for cfg in self.bucket_configs if cfg[1] == num_items and cfg[2] == 1
        ]
        if not candidates:
            raise ValueError(f"No image buckets found for shape {media_shape}")

        def ratio_distance(cfg):
            return abs((cfg[3] / cfg[4]) - target_aspect_ratio)

        return min(candidates, key=ratio_distance)
+ fallback: Default checkpoint name if none can be resolved. + + Returns: + A non-empty string identifying the checkpoint. + """ + candidates = [ + getattr(text_encoder, "name_or_path", None), + getattr(getattr(text_encoder, "config", None), "_name_or_path", None), + ] + for c in candidates: + if isinstance(c, str) and len(c) > 0: + return c + return fallback + + +def _generate_hw_buckets( + base_height: int = 256, + base_width: int = 256, + step_width: int = 16, + step_height: int = 16, + max_ratio: float = 4.0, +) -> list[tuple[int, int, int, int, int]]: + """ + Generate (batch_size=1, num_items=1, num_frames=1, height, width) bucket tuples + covering a range of aspect ratios while keeping total pixels close to + base_height * base_width. + + Args: + base_height: Reference height in pixels. + base_width: Reference width in pixels. + step_width: Width increment per step. + step_height: Height decrement per step. + max_ratio: Maximum allowed aspect ratio (long side / short side). + + Returns: + List of bucket tuples (1, 1, 1, height, width). + """ + buckets = [] + target_pixels = base_height * base_width + + height = target_pixels // step_width + width = step_width + + while height >= step_height: + if max(height, width) / min(height, width) <= max_ratio: + buckets.append((1, 1, 1, height, width)) + if height * (width + step_width) <= target_pixels: + width += step_width + else: + height -= step_height + + return buckets + + +def generate_video_image_bucket( + basesize: int = 256, + min_temporal: int = 65, + max_temporal: int = 129, + bs_img: int = 8, + bs_vid: int = 1, + bs_mimg: int = 4, + min_items: int = 1, + max_items: int = 1, +) -> list[list[int]]: + """ + Generate bucket configurations for image inference. + + Each bucket is represented as [batch_size, num_items, num_frames, height, width]. + Spatial dimensions are scaled by (basesize // 256) when basesize > 256. + + Args: + basesize: Base spatial resolution. Must be one of {256, 512, 768, 1024}. 
def generate_video_image_bucket(
    basesize: int = 256,
    min_temporal: int = 65,
    max_temporal: int = 129,
    bs_img: int = 8,
    bs_vid: int = 1,
    bs_mimg: int = 4,
    min_items: int = 1,
    max_items: int = 1,
) -> list[list[int]]:
    """
    Generate bucket configurations for image inference.

    Each bucket is [batch_size, num_items, num_frames, height, width]. Spatial
    dimensions are scaled by (basesize // 256) when basesize > 256.

    Args:
        basesize: Base spatial resolution. Must be one of {256, 512, 768, 1024}.
        min_temporal: Unused; kept for API compatibility.
        max_temporal: Unused; kept for API compatibility.
        bs_img: Batch size for single-image buckets.
        bs_vid: Unused; kept for API compatibility.
        bs_mimg: Batch size for multi-image buckets.
        min_items: Minimum number of items in multi-image buckets.
        max_items: Maximum number of items in multi-image buckets.

    Returns:
        List of bucket configs as [batch_size, num_items, num_frames, height, width].

    Raises:
        ValueError: If basesize is not in {256, 512, 768, 1024}.
    """
    # Fix: validate with an explicit raise instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if basesize not in (256, 512, 768, 1024):
        raise ValueError(f"[generate_video_image_bucket] unsupported basesize {basesize}")

    base_bucket_list = _generate_hw_buckets()
    bucket_list: list[list[int]] = []

    # Single-image buckets.
    for base in base_bucket_list:
        bucket = list(base)
        bucket[0] = bs_img
        bucket_list.append(bucket)

    # Multi-image buckets, one group per item count.
    for num_items in range(min_items, max_items + 1):
        for base in base_bucket_list:
            bucket = list(base)
            bucket[0] = bs_mimg
            bucket[1] = num_items
            bucket_list.append(bucket)

    # Scale spatial dimensions when basesize exceeds the 256 reference.
    if basesize > 256:
        ratio = basesize // 256
        for bucket in bucket_list:
            bucket[-2] *= ratio
            bucket[-1] *= ratio

    return bucket_list
def _dynamic_resize_from_bucket(image_size: Tuple[int, int], basesize: int = 512) -> Tuple[int, int]:
    """
    Pick the bucket resolution whose aspect ratio best matches ``image_size``.

    Candidate buckets are generated for ``basesize`` and the one closest in
    height/width ratio to the source image is selected.

    Fix: the previous docstring described a nonexistent ``image`` PIL parameter
    and claimed a PIL image (or None) was returned; this function actually
    takes a size pair and returns target dimensions.

    Args:
        image_size: Source size as (width, height), matching ``PIL.Image.size``.
        basesize: Base resolution used to generate candidate buckets.

    Returns:
        (target_height, target_width) of the best-matching bucket.
    """
    bucket_config = generate_video_image_bucket(
        basesize=basesize,
        min_temporal=56,
        max_temporal=56,
        bs_img=4,
        bs_vid=4,
        bs_mimg=8,
        min_items=2,
        max_items=2,
    )
    bucket_group = BucketGroup(bucket_config)
    src_w, src_h = image_size
    bucket = bucket_group.find_best_bucket((1, 1, src_h, src_w))
    target_height, target_width = bucket[-2], bucket[-1]
    return target_height, target_width
+ """ + if args is not None: + return args + + text_encoder_ckpt = _get_text_encoder_ckpt(text_encoder) + return SimpleNamespace( + enable_multi_task_training=False, + text_token_max_length=2048, + dit_precision="bf16", + vae_precision="bf16", + text_encoder_arch_config={"params": {"text_encoder_ckpt": text_encoder_ckpt}}, + ) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Configure the scheduler and return its timestep sequence. + + Exactly one of ``timesteps``, ``sigmas``, or ``num_inference_steps`` should be + provided to control the denoising schedule. + + Args: + scheduler: The diffusion scheduler. + num_inference_steps: Number of denoising steps (used when neither + ``timesteps`` nor ``sigmas`` is given). + device: Target device for the timestep tensor. + timesteps: Custom discrete timesteps. + sigmas: Custom sigma values (alternative to ``timesteps``). + **kwargs: Additional keyword arguments forwarded to ``set_timesteps``. + + Returns: + Tuple of (timesteps tensor, num_inference_steps int). + + Raises: + ValueError: If both ``timesteps`` and ``sigmas`` are provided, or if the + scheduler does not support the requested schedule parameterisation. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +@dataclass +class _LegacyPipelineOutput(BaseOutput): + """Legacy output dataclass retained for backward compatibility.""" + + videos: Union[torch.Tensor, np.ndarray] + + +class JoyImageEditPipeline(DiffusionPipeline): + """ + Diffusion pipeline for image editing using the JoyImage architecture. + + The pipeline encodes text and image conditioning via a Qwen3-VL text encoder, + denoises latents with a 3-D transformer, and decodes the result with a WAN VAE. + + Model offloading order: text_encoder -> transformer -> vae. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditTransformer3DModel, + processor: Qwen3VLProcessor, + args: Any = None, + ): + """ + Initialise the pipeline and register all sub-modules. 
+ + Args: + scheduler: Noise scheduler for the denoising process. + vae: Variational autoencoder used for encoding / decoding latents. + text_encoder: Qwen3-VL multimodal language model for prompt encoding. + tokenizer: Tokenizer paired with the text encoder. + transformer: 3-D transformer denoising network. + processor: Qwen3-VL processor for multi-image prompt preparation. + args: Optional configuration namespace. Defaults are inferred when None. + """ + super().__init__() + self.args = _build_args(args=args, text_encoder=text_encoder) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + text_encoder_ckpt = dict(self.args.text_encoder_arch_config.get("params", {})).get( + "text_encoder_ckpt", _get_text_encoder_ckpt(self.text_encoder) + ) + self.qwen_processor = ( + processor if processor is not None else AutoProcessor.from_pretrained(text_encoder_ckpt) + ) + + self.text_token_max_length = self.args.text_token_max_length + + # Prompt templates used when encoding text with / without image tokens. 
+ self.prompt_template_encode = { + "image": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ), + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + # Number of system-prompt tokens to drop from the beginning of hidden states. + self.prompt_template_encode_start_idx = { + "image": 34, + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _extract_masked_hidden( + self, hidden_states: torch.Tensor, mask: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + """ + Extract valid (non-padded) hidden states for each sequence in the batch. + + Args: + hidden_states: Shape (B, T, D). + mask: Binary attention mask of shape (B, T). + + Returns: + Tuple of tensors, one per batch element, each of shape (valid_T, D). + """ + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths.tolist(), dim=0) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + template_type: str = "image", + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode text prompts using the Qwen tokenizer (text-only path). + + Args: + prompt: A single prompt string or a list of prompt strings. + template_type: Key into ``prompt_template_encode`` / ``prompt_template_encode_start_idx``. + device: Target device. + dtype: Target floating-point dtype. 
+ + Returns: + Tuple of (prompt_embeds, encoder_attention_mask) where both tensors + have shape (B, max_seq_len, D) and (B, max_seq_len) respectively, + zero-padded to the same length. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.text_token_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) + + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + # Drop system-prompt prefix tokens and re-pack into a padded batch. + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [ + torch.ones(e.size(0), dtype=torch.long, device=e.device) + for e in split_hidden_states + ] + + max_seq_len = min( + self.text_token_max_length, + max(u.size(0) for u in split_hidden_states), + max(u.size(0) for u in attn_mask_list), + ) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds, encoder_attention_mask + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[torch.Tensor] = None, + template_type: Optional[str] = "multiple_images", + max_sequence_length: Optional[int] 
= None, + drop_vit_feature: Optional[float] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode prompts that contain inline image tokens via the Qwen processor. + + ``\\n`` placeholders in each prompt string are replaced by the + Qwen vision special tokens before being fed to the multimodal encoder. + + Args: + prompt: Prompt string(s), optionally containing ``\\n`` tokens. + device: Target device. + images: Pixel tensors corresponding to the inline image tokens. + template_type: Must be ``"multiple_images"``. + max_sequence_length: If set, truncate the output to this length + (keeping the last ``max_sequence_length`` tokens). + drop_vit_feature: When True, drop all tokens up to and including the + last vision-end token so that only the text portion is returned. + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). + """ + assert template_type == "multiple_images" + device = device or self._execution_device + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # If no image tokens are present, discard the image tensors. + if not any("\n" in p for p in prompt): + images = None + + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) for p in prompt] + + inputs = self.qwen_processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + encoder_hidden_states = self.text_encoder(**inputs, output_hidden_states=True) + last_hidden_states = encoder_hidden_states.hidden_states[-1] + + if drop_vit_feature: + # Find the last vision-end token and drop everything before it. 
+ input_ids = inputs["input_ids"] + vlm_image_end_idx = torch.where(input_ids[0] == 151653)[0][-1] + drop_idx = vlm_image_end_idx + 1 + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + images: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + template_type: str = "image", + drop_vit_feature: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode a text prompt (and optional inline images) into embeddings. + + When ``images`` is provided the multi-image encoding path is used; + otherwise the text-only Qwen tokenizer path is used. Pre-computed + ``prompt_embeds`` bypass encoding entirely. + + Args: + prompt: Prompt string or list of prompt strings. + images: Optional image tensors for multi-image conditioning. + device: Target device. + num_images_per_prompt: Number of outputs to generate per prompt. + prompt_embeds: Pre-computed prompt embeddings. + prompt_embeds_mask: Attention mask for pre-computed embeddings. + max_sequence_length: Maximum output sequence length. + template_type: Prompt template key (``"image"`` or ``"multiple_images"``). + drop_vit_feature: Drop vision tokens in the multi-image path. + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). 
+ """ + if images is not None: + return self.encode_prompt_multiple_images( + prompt=prompt, + images=images, + device=device, + max_sequence_length=max_sequence_length, + drop_vit_feature=drop_vit_feature, + ) + + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, template_type, device + ) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def decode_latents(self, latents: torch.Tensor, enable_tiling: bool = True) -> torch.Tensor: + """ + Decode latents to pixel values. + + .. deprecated:: 1.0.0 + Use the VAE directly instead of calling this method. + + Args: + latents: Latent tensor to decode. + enable_tiling: Whether to enable tiled decoding to save memory. + + Returns: + Float tensor of shape (..., H, W, C) with values in [0, 1]. + """ + deprecation_message = "The decode_latents method is deprecated." 
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + if image.ndim == 4: + image = image.cpu().permute(0, 2, 3, 1).float() + else: + image = image.cpu().float() + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + """ + Validate pipeline inputs before the forward pass. + + Raises: + ValueError: On any invalid combination of arguments. + """ + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError("`callback_on_step_end_tensor_inputs` has invalid keys.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") + elif prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError("`prompt` has to be of type `str` or `list`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.") + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` is required.") + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` is required." 
+ ) + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Normalise latents using per-channel statistics from the VAE config. + + Uses (latent - mean) / std when the VAE exposes ``latents_mean`` and + ``latents_std``; otherwise falls back to scaling by ``scaling_factor``. + + Args: + latent: Raw latent tensor from ``vae.encode``. + + Returns: + Normalised latent tensor. + """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Invert :meth:`normalize_latents` to recover the original latent scale. + + Args: + latent: Normalised latent tensor. + + Returns: + Latent tensor in the scale expected by ``vae.decode``. 
+ """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def prepare_latents( + self, + batch_size: int, + num_items: int, + num_channels_latents: int, + height: int, + width: int, + video_length: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.Tensor] = None, + reference_images: Optional[List[Image.Image]] = None, + enable_denormalization: bool = True, + ) -> torch.Tensor: + """ + Prepare the initial noisy latent tensor for the denoising loop. + + When ``reference_images`` is provided the first (num_items - 1) slots are + filled with VAE-encoded reference image latents; the last slot is random noise. + When ``latents`` is provided it is moved to ``device`` without modification. + Otherwise pure random noise is returned. + + Args: + batch_size: Number of samples in the batch. + num_items: Number of image slots (reference + target). + num_channels_latents: Latent channel dimension from the transformer config. + height: Spatial height in pixels. + width: Spatial width in pixels. + video_length: Number of frames (1 for image inference). + dtype: Floating-point dtype for the latent tensor. + device: Target device. + generator: RNG generator(s) for reproducible sampling. + latents: Optional pre-allocated latent tensor. + reference_images: Optional list of PIL images to encode as conditioning. + enable_denormalization: Whether to normalise encoded reference latents. + + Returns: + Latent tensor of shape (B, num_items, C, T, H', W'). 
+ + Raises: + ValueError: If ``generator`` is a list whose length differs from ``batch_size``. + """ + shape = ( + batch_size, + num_items, + num_channels_latents, + (video_length - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError("Generator list length must match batch size.") + + if latents is None: + if reference_images is not None: + # Encode reference images and concatenate with a noise slot. + ref_img = [torch.from_numpy(np.array(x.convert("RGB"))) for x in reference_images] + ref_img = torch.stack(ref_img).to(device=device, dtype=dtype) + ref_img = ref_img / 127.5 - 1.0 + ref_img = rearrange(ref_img, "x h w c -> x c 1 h w") + ref_vae = self.vae.encode(ref_img).latent_dist.sample() + if enable_denormalization: + ref_vae = self.normalize_latents(ref_vae) + ref_vae = rearrange(ref_vae, "(b n) c 1 h w -> b n c 1 h w", n=(num_items - 1)) + noise = randn_tensor((shape[0], 1, *shape[2:]), generator=generator, device=device, dtype=dtype) + latents = torch.cat([ref_vae, noise], dim=1) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + # ------------------------------------------------------------------ + # Pipeline properties + # ------------------------------------------------------------------ + + @property + def guidance_scale(self) -> float: + """Classifier-free guidance scale used in the current forward pass.""" + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + """True when guidance_scale > 1, enabling classifier-free guidance.""" + return self._guidance_scale > 1 + + @property + def num_timesteps(self) -> int: + """Total number of denoising timesteps in the current forward pass.""" + return self._num_timesteps + + @property + def 
interrupt(self) -> bool: + """When True, the denoising loop is interrupted at the next step.""" + return self._interrupt + + # ------------------------------------------------------------------ + # Utility + # ------------------------------------------------------------------ + + def pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: + """ + Truncate or zero-pad a sequence tensor along dimension 1. + + If the sequence is longer than ``target_length`` the last + ``target_length`` elements are kept. If it is shorter, zero-padding + is appended on the right. + + Args: + x: Input tensor of shape (B, T, ...) or (B, T). + target_length: Desired sequence length. + + Returns: + Tensor of shape (B, target_length, ...) or (B, target_length). + """ + current_length = x.shape[1] + if current_length >= target_length: + return x[:, -target_length:] + padding_length = target_length - current_length + if x.ndim >= 3: + padding = torch.zeros((x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device) + else: + padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) + return torch.cat([x, padding], dim=1) + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] 
= None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + enable_tiling: bool = False, + max_sequence_length: int = 4096, + drop_vit_feature: bool = False, + enable_denormalization: bool = True, + **kwargs, + ): + r""" + Generate an edited image conditioned on a reference image and a text prompt. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide generation. + height (`int`): + Height of the generated output in pixels. + width (`int`): + Width of the generated output in pixels. + image (`PipelineImageInput`, *optional*): + Reference image used for conditioning. When provided the pipeline + operates in image-editing mode with ``num_items=2``. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps. More steps generally improve quality at + the cost of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps for the denoising process. When provided, + ``num_inference_steps`` is inferred from the list length. + sigmas (`List[float]`, *optional*): + Custom sigmas for the denoising process. Mutually exclusive with + ``timesteps``. + guidance_scale (`float`, *optional*, defaults to 4.0): + Classifier-free guidance scale. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt(s) used to suppress undesired content. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of generated samples per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + RNG generator(s) for deterministic sampling. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. 
Sampled from a Gaussian distribution + when not provided. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed prompt embeddings. When provided ``prompt`` can be omitted. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for ``prompt_embeds``. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative prompt embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for ``negative_prompt_embeds``. + output_type (`str`, *optional*, defaults to ``"pil"``): + Output format. Pass ``"latent"`` to return raw latents. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a :class:`JoyImageEditPipelineOutput` or a plain tensor. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + Callback invoked at the end of each denoising step with signature + ``(self, step: int, timestep: int, callback_kwargs: Dict)``. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to ``["latents"]``): + Tensor keys included in ``callback_kwargs`` for ``callback_on_step_end``. + enable_tiling (`bool`, *optional*, defaults to `False`): + Enable tiled VAE decoding to reduce peak memory usage. + max_sequence_length (`int`, *optional*, defaults to 4096): + Maximum sequence length for prompt encoding. + drop_vit_feature (`bool`, *optional*, defaults to `False`): + Drop vision tokens in the multi-image encoding path. + enable_denormalization (`bool`, *optional*, defaults to `True`): + Denormalise latents before VAE decoding. + **kwargs: + Additional keyword arguments for forward compatibility. + + Examples: + + Returns: + [`~pipelines.joyimage.JoyImageEditPipelineOutput`] or `torch.Tensor`: + If ``return_dict`` is ``True``, returns a pipeline output object + containing the generated image(s). Otherwise returns the image + tensor directly. + """ + # Resize the input image to the nearest bucket resolution. 
+ # Or resize the specified height and width to the nearest bucket resolution. + image_size = image[0].size if isinstance(image, list) else image.size + if height is not None and width is not None: + # Override the image size if specified. + image_size = (width, height) + + height, width = _dynamic_resize_from_bucket(image_size, basesize=1024) + processed_image = _resize_center_crop(image, (height, width)) + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # num_items: 1 for unconditional generation, 2 for reference-image editing. + num_items = 1 if image is None else 2 + + # Encode the conditioning prompt (and reference image when present). + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + images=processed_image, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + template_type="image", + drop_vit_feature=drop_vit_feature, + ) + + if self.do_classifier_free_guidance: + # Build default negative prompts when none are provided. 
+ if negative_prompt is None and negative_prompt_embeds is None: + if num_items <= 1: + negative_prompt = ["<|im_start|>user\n<|im_end|>\n"] * batch_size + else: + negative_prompt = ["<|im_start|>user\n\n<|im_end|>\n"] * batch_size + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + images=processed_image, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + template_type="image", + ) + + # Pad both embeddings to the same sequence length and concatenate + # in (unconditional, conditional) order for a single forward pass. + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds, max_seq_len), + self.pad_sequence(prompt_embeds, max_seq_len), + ] + ) + if prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self.pad_sequence(prompt_embeds_mask, max_seq_len), + ] + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_items, + num_channels_latents, + height, + width, + 1, # video_length = 1 for image inference + prompt_embeds.dtype, + device, + generator, + latents, + reference_images=[processed_image], + enable_denormalization=enable_denormalization, + ) + + target_dtype = PRECISION_TO_TYPE[self.args.dit_precision] + autocast_enabled = target_dtype != torch.float32 + vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] + vae_autocast_enabled = vae_dtype != torch.float32 + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = 
len(timesteps) + + # Cache reference latents to restore them at each denoising step. + if num_items > 1: + ref_latents = latents[:, :(num_items - 1)].clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference latents so they are never overwritten by the scheduler. + if num_items > 1: + latents[:, :(num_items - 1)] = ref_latents.clone() + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + t_expand = t.repeat(latent_model_input.shape[0]) + + with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # Rescale to match the conditional prediction norm (guidance rescaling). 
+ cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=2, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm.clamp_min(1e-6)) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + if progress_bar is not None: + progress_bar.update() + + if output_type != "latent": + latents = rearrange(latents, "b n c f h w -> (b n) c f h w") + if enable_denormalization: + latents = self.denormalize_latents(latents) + + with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + image = rearrange(image, "(b n) c f h w -> b n c f h w", b=batch_size) + else: + image = latents + + # Extract the last item (target slot) from the batch, shape: (F, C, H, W). 
+ image = image.float().permute(0, 1, 3, 2, 4, 5)[0, -1] + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return JoyImageEditPipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py new file mode 100644 index 000000000000..a98b9066c69a --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +import torch + +from ...utils import BaseOutput +import PIL.Image +from typing import Union, List, Tuple, Optional +import numpy as np + +@dataclass +class JoyImageEditPipelineOutput(BaseOutput): + """ + Output class for JoyImageEdit generation pipelines. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] From 84597596ce7b96601bb94be381a9ca6872650a44 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Tue, 14 Apr 2026 14:44:17 +0800 Subject: [PATCH 2/6] [fix] remove rearrange --- .../transformers/transformer_joyimage.py | 28 ++++++++++++------- .../joyimage/pipeline_joyimage_edit.py | 10 +++---- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index c29284b3d8f7..b7de3dc39518 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange + from ...configuration_utils import ConfigMixin, register_to_config from ..attention import FeedForward @@ -268,12 +268,12 @@ def apply_gate(x, gate=None, tanh=False): def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_kwargs=None): batch_size = q.shape[0] if ATTN_BACKEND == 'sdpa': - q = rearrange(q, "b l h c -> b 
h l c") - k = rearrange(k, "b l h c -> b h l c") - v = rearrange(v, "b l h c -> b h l c") + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) output = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_kwargs['attn_mask']) - output = rearrange(output, "b h l c -> b l h c") + output = output.transpose(1, 2) elif ATTN_BACKEND == 'flash_attn': cu_seqlens_q = attn_kwargs['cu_seqlens_q'] cu_seqlens_kv = attn_kwargs['cu_seqlens_kv'] @@ -369,7 +369,8 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale) img_qkv = self.img_attn_qkv(img_modulated) - img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + B, L, _ = img_qkv.shape + img_q, img_k, img_v = img_qkv.reshape(B, L, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4).unbind(0) img_q = self.img_attn_q_norm(img_q).to(img_v) img_k = self.img_attn_k_norm(img_k).to(img_v) if vis_freqs_cis is not None: @@ -377,7 +378,8 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale) txt_qkv = self.txt_attn_qkv(txt_modulated) - txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + B2, L2, _ = txt_qkv.shape + txt_q, txt_k, txt_v = txt_qkv.reshape(B2, L2, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4).unbind(0) txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) if txt_freqs_cis is not None: @@ -548,7 +550,7 @@ def forward( if self.patch_size[0] != 1: raise ValueError("For multi-item input, patch_size[0] must be 1") hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1) - hidden_states = rearrange(hidden_states, "b n c t h w -> b c (n t) h w") + hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).flatten(2, 3) _, _, ot, 
oh, ow = hidden_states.shape tt, th, tw = ( @@ -623,11 +625,17 @@ def forward( repa_hidden_state = repa_hidden_state.view(img.shape[0], tt, th, tw, -1) if is_multi_item: - img = rearrange(img, "b c (n t) h w -> b n c t h w", n=num_items) + # b c (n t) h w -> b n c t h w + b, c, nt, h, w = img.shape + t = nt // num_items + img = img.reshape(b, c, num_items, t, h, w).permute(0, 2, 1, 3, 4, 5) if num_items > 1: img = torch.cat([img[:, 1:], img[:, :1]], dim=1) if repa_hidden_state is not None: - repa_hidden_state = rearrange(repa_hidden_state, "b (n t) h w c -> b n t h w c", n=num_items) + # b (n t) h w c -> b n t h w c + b2, nt2, h2, w2, c2 = repa_hidden_state.shape + t2 = nt2 // num_items + repa_hidden_state = repa_hidden_state.reshape(b2, num_items, t2, h2, w2, c2) if not return_dict: return (img, txt, repa_hidden_state) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py index 190cb3d90c31..e8877cfa3e34 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -7,7 +7,7 @@ import numpy as np import torch import torchvision.transforms.functional as TF -from einops import rearrange + from PIL import Image from transformers import AutoProcessor, Qwen2Tokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor @@ -817,11 +817,11 @@ def prepare_latents( ref_img = [torch.from_numpy(np.array(x.convert("RGB"))) for x in reference_images] ref_img = torch.stack(ref_img).to(device=device, dtype=dtype) ref_img = ref_img / 127.5 - 1.0 - ref_img = rearrange(ref_img, "x h w c -> x c 1 h w") + ref_img = ref_img.permute(0, 3, 1, 2).unsqueeze(2) ref_vae = self.vae.encode(ref_img).latent_dist.sample() if enable_denormalization: ref_vae = self.normalize_latents(ref_vae) - ref_vae = rearrange(ref_vae, "(b n) c 1 h w -> b n c 1 h w", n=(num_items - 1)) + ref_vae = ref_vae.view(shape[0], num_items - 1, *ref_vae.shape[1:]) 
noise = randn_tensor((shape[0], 1, *shape[2:]), generator=generator, device=device, dtype=dtype) latents = torch.cat([ref_vae, noise], dim=1) else: @@ -1162,7 +1162,7 @@ def __call__( progress_bar.update() if output_type != "latent": - latents = rearrange(latents, "b n c f h w -> (b n) c f h w") + latents = latents.flatten(0, 1) if enable_denormalization: latents = self.denormalize_latents(latents) @@ -1170,7 +1170,7 @@ def __call__( if enable_tiling: self.vae.enable_tiling() image = self.vae.decode(latents, return_dict=False)[0] - image = rearrange(image, "(b n) c f h w -> b n c f h w", b=batch_size) + image = image.unflatten(0, (batch_size, -1)) else: image = latents From e6e6df53174a120548505d31fa7b1342eb9868f9 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Tue, 14 Apr 2026 16:54:45 +0800 Subject: [PATCH 3/6] [refactor] two pass when do cfg --- .../transformers/transformer_joyimage.py | 100 ++---------------- .../joyimage/pipeline_joyimage_edit.py | 66 +++--------- 2 files changed, 24 insertions(+), 142 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index b7de3dc39518..815e9a8ffd64 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -12,13 +12,6 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -ATTN_BACKEND = 'sdpa' -try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func - ATTN_BACKEND = 'flash_attn' -except: - pass - def _to_tuple(x, dim=2): if isinstance(x, int): return (x,) * dim @@ -174,31 +167,6 @@ def get_nd_rotary_pos_embed( return vis_emb, txt_emb -def get_cu_seqlens(text_mask, img_len): - """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len - - Args: - text_mask (torch.Tensor): the mask of text - img_len (int): the length of image - - Returns: - torch.Tensor: the calculated 
cu_seqlens for flash attention - """ - batch_size = text_mask.shape[0] - text_len = text_mask.sum(dim=1) - max_len = text_mask.shape[1] + img_len - - cu_seqlens = torch.zeros([2 * batch_size + 1], - dtype=torch.int32, device="cuda") - - for i in range(batch_size): - s = text_len[i] + img_len - s1 = i * max_len + s - s2 = (i + 1) * max_len - cu_seqlens[2 * i + 1] = s1 - cu_seqlens[2 * i + 2] = s2 - - return cu_seqlens def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): factory_kwargs = {"dtype": dtype, "device": device} @@ -265,33 +233,12 @@ def apply_gate(x, gate=None, tanh=False): return x return x * (gate.unsqueeze(1).tanh() if tanh else gate.unsqueeze(1)) -def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_kwargs=None): - batch_size = q.shape[0] - if ATTN_BACKEND == 'sdpa': - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attn_kwargs['attn_mask']) - output = output.transpose(1, 2) - elif ATTN_BACKEND == 'flash_attn': - cu_seqlens_q = attn_kwargs['cu_seqlens_q'] - cu_seqlens_kv = attn_kwargs['cu_seqlens_kv'] - max_seqlen_q = attn_kwargs['max_seqlen_q'] - max_seqlen_kv = attn_kwargs['max_seqlen_kv'] - x = flash_attn_varlen_func( - q.view(q.shape[0] * q.shape[1], *q.shape[2:]), - k.view(k.shape[0] * k.shape[1], *k.shape[2:]), - v.view(v.shape[0] * v.shape[1], *v.shape[2:]), - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ) - # x with shape [(bxs), a, d] - output = x.view( - batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] - ) # reshape +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + output = torch.nn.functional.scaled_dot_product_attention(q, k, v) + output = output.transpose(1, 2) return output class RMSNorm(nn.Module): @@ -349,7 +296,7 @@ def __init__( 
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") - def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_freqs_cis=None, txt_freqs_cis=None, attn_kwargs=None): + def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_freqs_cis=None, txt_freqs_cis=None): ( img_mod1_shift, img_mod1_scale, @@ -389,7 +336,7 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) - attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3) + attn = attention(q, k, v).flatten(2, 3) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) @@ -579,37 +526,10 @@ def forward( txt_seq_len = txt.shape[1] img_seq_len = img.shape[1] - - cu_seqlens_q = get_cu_seqlens( - encoder_hidden_states_mask, img_seq_len) - cu_seqlens_kv = cu_seqlens_q - max_seqlen_q = img_seq_len + txt_seq_len - max_seqlen_kv = max_seqlen_q - - - attn_kwargs = {"encoder_hidden_states_mask": encoder_hidden_states_mask} - attn_kwargs.update({ - 'cu_seqlens_q': cu_seqlens_q, - 'cu_seqlens_kv': cu_seqlens_kv, - 'max_seqlen_q': max_seqlen_q, - 'max_seqlen_kv': max_seqlen_kv, - }) - - max_seqlen_q = img_seq_len + txt_seq_len - seq_lens = encoder_hidden_states_mask.sum(dim=1) + img_seq_len - max_len = encoder_hidden_states_mask.shape[1] + img_seq_len - assert max_seqlen_q == max_len - positions = torch.arange(max_seqlen_q, device=img.device).unsqueeze(0) - seq_lens_expanded = seq_lens.unsqueeze(1) - mask = positions < seq_lens_expanded - mask = mask.unsqueeze(1).unsqueeze(2) - attn_mask = mask & mask.transpose(-1, -2) - - attn_kwargs.update({'attn_mask': attn_mask, 'max_seqlen_q': max_seqlen_q}) - + img_hidden_states = [] for block in 
self.double_blocks: - img, txt = block(img, txt, vec, vis_freqs_cis, txt_freqs_cis, attn_kwargs) + img, txt = block(img, txt, vec, vis_freqs_cis, txt_freqs_cis) img_hidden_states.append(img) img_len = img.shape[1] diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py index e8877cfa3e34..c9688e15adbd 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -855,35 +855,6 @@ def interrupt(self) -> bool: """When True, the denoising loop is interrupted at the next step.""" return self._interrupt - # ------------------------------------------------------------------ - # Utility - # ------------------------------------------------------------------ - - def pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: - """ - Truncate or zero-pad a sequence tensor along dimension 1. - - If the sequence is longer than ``target_length`` the last - ``target_length`` elements are kept. If it is shorter, zero-padding - is appended on the right. - - Args: - x: Input tensor of shape (B, T, ...) or (B, T). - target_length: Desired sequence length. - - Returns: - Tensor of shape (B, target_length, ...) or (B, target_length). 
- """ - current_length = x.shape[1] - if current_length >= target_length: - return x[:, -target_length:] - padding_length = target_length - current_length - if x.ndim >= 3: - padding = torch.zeros((x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device) - else: - padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) - return torch.cat([x, padding], dim=1) - # ------------------------------------------------------------------ # Forward pass # ------------------------------------------------------------------ @@ -1062,23 +1033,6 @@ def __call__( template_type="image", ) - # Pad both embeddings to the same sequence length and concatenate - # in (unconditional, conditional) order for a single forward pass. - max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) - prompt_embeds = torch.cat( - [ - self.pad_sequence(negative_prompt_embeds, max_seq_len), - self.pad_sequence(prompt_embeds, max_seq_len), - ] - ) - if prompt_embeds_mask is not None: - prompt_embeds_mask = torch.cat( - [ - self.pad_sequence(negative_prompt_embeds_mask, max_seq_len), - self.pad_sequence(prompt_embeds_mask, max_seq_len), - ] - ) - timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, @@ -1124,7 +1078,7 @@ def __call__( if num_items > 1: latents[:, :(num_items - 1)] = ref_latents.clone() - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latents t_expand = t.repeat(latent_model_input.shape[0]) with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): @@ -1137,12 +1091,20 @@ def __call__( )[0] if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): + noise_pred_uncond = 
self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + return_dict=False, + )[0] + + comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond) # Rescale to match the conditional prediction norm (guidance rescaling). - cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) - noise_norm = torch.norm(noise_pred, dim=2, keepdim=True) - noise_pred = noise_pred * (cond_norm / noise_norm.clamp_min(1e-6)) + cond_norm = torch.norm(noise_pred, dim=2, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=2, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6)) latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] From f557113ae51572f8ea0acd60535e9a4d20d44097 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Tue, 14 Apr 2026 17:35:31 +0800 Subject: [PATCH 4/6] [refactor] remove repa, use wantimetextembeding, refactor modulate code --- .../transformers/transformer_joyimage.py | 96 ++----------------- 1 file changed, 7 insertions(+), 89 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index 815e9a8ffd64..1f5153b2fc3f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -11,6 +11,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from .transformer_wan import WanTimeTextImageEmbedding def _to_tuple(x, dim=2): if isinstance(x, int): @@ -167,33 +168,7 @@ def get_nd_rotary_pos_embed( return vis_emb, txt_emb - -def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): - factory_kwargs = {"dtype": dtype, 
"device": device} - if modulate_type == "wanx": - return ModulateWan(hidden_size, factor, **factory_kwargs) - if modulate_type == "adaLN": - return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs) - if modulate_type == "jdx": - return ModulateX(hidden_size, factor, **factory_kwargs) - raise ValueError(f"Unknown modulation type: {modulate_type}.") - - -class ModulateDiT(nn.Module): - def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): - factory_kwargs = {"dtype": dtype, "device": device} - super().__init__() - self.factor = factor - self.act = act_layer() - self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) - nn.init.zeros_(self.linear.weight) - nn.init.zeros_(self.linear.bias) - - def forward(self, x: torch.Tensor): - return self.linear(self.act(x)).chunk(self.factor, dim=-1) - - -class ModulateWan(nn.Module): +class JoyImageModulate(nn.Module): def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): super().__init__() self.factor = factor @@ -207,17 +182,6 @@ def forward(self, x: torch.Tensor): return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] -class ModulateX(nn.Module): - def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): - super().__init__() - self.factor = factor - - def forward(self, x: torch.Tensor): - if len(x.shape) != 3: - x = x.unsqueeze(1) - return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)] - - def modulate(x, shift=None, scale=None): if scale is None and shift is None: return x @@ -267,18 +231,16 @@ def __init__( mlp_width_ratio: float, dtype=None, device=None, - dit_modulation_type: str = "wanx", attn_backend: str = "torch_spda", ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.attn_backend = attn_backend - self.dit_modulation_type = dit_modulation_type self.heads_num = heads_num head_dim = hidden_size // heads_num mlp_hidden_dim = 
int(hidden_size * mlp_width_ratio) - self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs) + self.img_mod = JoyImageModulate(hidden_size, 6, **factory_kwargs) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) @@ -287,7 +249,7 @@ def __init__( self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") - self.txt_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs) + self.txt_mod = JoyImageModulate(hidden_size, 6, **factory_kwargs) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) @@ -352,28 +314,6 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f ) return img, txt - - -class WanTimeTextImageEmbedding(nn.Module): - def __init__(self, dim: int, time_freq_dim: int, time_proj_dim: int, text_embed_dim: int): - super().__init__() - self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) - self.act_fn = nn.SiLU() - self.time_proj = nn.Linear(dim, time_proj_dim) - self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") - - def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor): - timestep = self.timesteps_proj(timestep) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype - if 
timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: - timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) - timestep_proj = self.time_proj(self.act_fn(temb)) - encoder_hidden_states = self.text_embedder(encoder_hidden_states) - return temb, timestep_proj, encoder_hidden_states - - class JoyImageTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -390,20 +330,15 @@ def __init__( mm_double_blocks_depth: int = 40, rope_dim_list: tuple[int, int, int] = (16, 56, 56), rope_type: str = "rope", - dit_modulation_type: str = "wanx", attn_backend: str = "torch_spda", unpatchify_new: bool = True, rope_theta: int = 256, enable_activation_checkpointing: bool = False, - is_repa: bool = False, - repa_layer: int = 13, ): super().__init__() self.args = SimpleNamespace( enable_activation_checkpointing=enable_activation_checkpointing, - is_repa=is_repa, - repa_layer=repa_layer, ) self.out_channels = out_channels or in_channels @@ -411,7 +346,6 @@ def __init__( self.hidden_size = hidden_size self.heads_num = heads_num self.rope_dim_list = tuple(rope_dim_list) - self.dit_modulation_type = dit_modulation_type self.mm_double_blocks_depth = mm_double_blocks_depth self.attn_backend = attn_backend self.rope_type = rope_type @@ -426,7 +360,7 @@ def __init__( self.condition_embedder = WanTimeTextImageEmbedding( dim=hidden_size, time_freq_dim=256, - time_proj_dim=hidden_size * 6 if dit_modulation_type != "adaLN" else hidden_size, + time_proj_dim=hidden_size * 6, text_embed_dim=text_states_dim, ) @@ -436,7 +370,6 @@ def __init__( hidden_size=self.hidden_size, heads_num=self.heads_num, mlp_width_ratio=mlp_width_ratio, - dit_modulation_type=self.dit_modulation_type, attn_backend=attn_backend, ) for _ in range(mm_double_blocks_depth) @@ -446,11 +379,6 @@ def __init__( self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.proj_out = 
nn.Linear(hidden_size, out_channels * math.prod(self.patch_size)) - if self.args.is_repa: - self.repa_proj = nn.Linear(hidden_size, text_states_dim) - if self.args.repa_layer > mm_double_blocks_depth: - raise ValueError("repa_layer should be smaller than total depth") - self.gradient_checkpointing = enable_activation_checkpointing @property @@ -514,7 +442,7 @@ def forward( ) img = self.img_in(hidden_states).flatten(2).transpose(1, 2) - _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + _, vec, txt, _ = self.condition_embedder(timestep, encoder_hidden_states) if vec.shape[-1] > self.hidden_size: vec = vec.unflatten(1, (6, -1)) @@ -539,11 +467,6 @@ def forward( img = self.proj_out(self.norm_out(img)) img = self.unpatchify(img, tt, th, tw) - repa_hidden_state = None - if self.args.is_repa: - repa_hidden_state = self.repa_proj(img_hidden_states[self.args.repa_layer]) - repa_hidden_state = repa_hidden_state.view(img.shape[0], tt, th, tw, -1) - if is_multi_item: # b c (n t) h w -> b n c t h w b, c, nt, h, w = img.shape @@ -551,14 +474,9 @@ def forward( img = img.reshape(b, c, num_items, t, h, w).permute(0, 2, 1, 3, 4, 5) if num_items > 1: img = torch.cat([img[:, 1:], img[:, :1]], dim=1) - if repa_hidden_state is not None: - # b (n t) h w c -> b n t h w c - b2, nt2, h2, w2, c2 = repa_hidden_state.shape - t2 = nt2 // num_items - repa_hidden_state = repa_hidden_state.reshape(b2, num_items, t2, h2, w2, c2) if not return_dict: - return (img, txt, repa_hidden_state) + return (img, txt) return Transformer2DModelOutput(sample=img) From d397b6812ff5bb11772abbcaa47db9cd5eedc292 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Tue, 14 Apr 2026 22:19:06 +0800 Subject: [PATCH 5/6] [refactor] Joyimage Attention refactor --- .../transformers/transformer_joyimage.py | 597 +++++++++++------- 1 file changed, 373 insertions(+), 224 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage.py 
b/src/diffusers/models/transformers/transformer_joyimage.py index 1f5153b2fc3f..6bf3404c196f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -1,17 +1,41 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math -from types import SimpleNamespace +from typing import Any, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F - from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from .transformer_wan import WanTimeTextImageEmbedding + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# --------------------------------------------------------------------------- +# Rotary position embedding utilities +# --------------------------------------------------------------------------- + def _to_tuple(x, dim=2): if isinstance(x, int): @@ -20,7 +44,8 @@ def _to_tuple(x, dim=2): return tuple(x) raise ValueError(f"Expected length {dim} or int, but got {x}") -def get_meshgrid_nd(start, 
*args, dim=2): + +def _get_meshgrid_nd(start, *args, dim=2): if len(args) == 0: num = _to_tuple(start, dim=dim) start = (0,) * dim @@ -44,42 +69,42 @@ def get_meshgrid_nd(start, *args, dim=2): grid = torch.meshgrid(*axis_grid, indexing="ij") return torch.stack(grid, dim=0) -def reshape_for_broadcast(freqs_cis, x: torch.Tensor, head_first: bool = False): - ndim = x.ndim - assert 0 <= 1 < ndim +def _reshape_for_broadcast(freqs_cis, x: torch.Tensor, head_first: bool = False): + ndim = x.ndim if isinstance(freqs_cis, tuple): if head_first: - assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]) shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] else: - assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) if head_first: - assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] else: - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) -def rotate_half(x: torch.Tensor): +def _rotate_half(x: torch.Tensor): x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) return torch.stack([-x_imag, x_real], dim=-1).flatten(3) -def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis, head_first: bool = False): - cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) +def _apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = _reshape_for_broadcast(freqs_cis, xq, head_first) cos, sin = cos.to(xq.device), sin.to(xq.device) - xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) - xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) 
+ xq_out = (xq.float() * cos + _rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk.float()) * sin).type_as(xk) return xq_out, xk_out -def get_1d_rotary_pos_embed( +def _get_1d_rotary_pos_embed( dim: int, pos, theta: float = 10000.0, @@ -104,7 +129,7 @@ def get_1d_rotary_pos_embed( return torch.polar(torch.ones_like(freqs), freqs) -def get_nd_rotary_pos_embed( +def _get_nd_rotary_pos_embed( rope_dim_list, start, *args, @@ -115,7 +140,7 @@ def get_nd_rotary_pos_embed( interpolation_factor=1.0, ): rope_dim_list = list(rope_dim_list) - grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) + grid = _get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) if isinstance(theta_rescale_factor, (int, float)): theta_rescale_factor = [float(theta_rescale_factor)] * len(rope_dim_list) @@ -129,7 +154,7 @@ def get_nd_rotary_pos_embed( embs = [] for i in range(len(rope_dim_list)): - emb = get_1d_rotary_pos_embed( + emb = _get_1d_rotary_pos_embed( rope_dim_list[i], grid[i].reshape(-1), theta, @@ -151,7 +176,7 @@ def get_nd_rotary_pos_embed( vis_max_ids = grid.view(-1).max().item() grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1 for i in range(len(rope_dim_list)): - emb = get_1d_rotary_pos_embed( + emb = _get_1d_rotary_pos_embed( rope_dim_list[i], grid_txt, theta, @@ -168,97 +193,210 @@ def get_nd_rotary_pos_embed( return vis_emb, txt_emb + +# --------------------------------------------------------------------------- +# Modulation +# --------------------------------------------------------------------------- + + class JoyImageModulate(nn.Module): + """Wan-style learnable modulation table. + + Produces `factor` modulation vectors by adding the conditioning signal to a + learnable parameter table. 
+ """ + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): super().__init__() self.factor = factor self.modulate_table = nn.Parameter( - torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, requires_grad=True + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, + requires_grad=True, ) - def forward(self, x: torch.Tensor): - if len(x.shape) != 3: + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + if x.ndim != 3: x = x.unsqueeze(1) return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] -def modulate(x, shift=None, scale=None): - if scale is None and shift is None: - return x - if shift is None: - return x * (1 + scale.unsqueeze(1)) - if scale is None: - return x + shift.unsqueeze(1) +# --------------------------------------------------------------------------- +# Attention processor +# --------------------------------------------------------------------------- + + +class JoyImageAttnProcessor: + """Attention processor for JoyImage double-stream joint attention. + + Implements the joint attention computation where text and image streams are + processed together. The block stores fused QKV projections directly + (``img_attn_qkv`` / ``txt_attn_qkv``) so this processor operates on + a ``JoyImageTransformerBlock`` rather than on a generic ``Attention`` module. 
+ """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + block: "JoyImageTransformerBlock", + hidden_states: torch.Tensor, # image stream (B, S_img, D) + encoder_hidden_states: torch.Tensor = None, # text stream (B, S_txt, D) + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageAttnProcessor requires encoder_hidden_states (text stream)") + + heads = block.num_attention_heads + + # image stream: fused QKV -> split + img_qkv = block.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + # text stream: fused QKV -> split + txt_qkv = block.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + # reshape to multi-head: (B, S, H, D) + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + # QK norm + img_query = block.img_attn_q_norm(img_query) + img_key = block.img_attn_k_norm(img_key) + txt_query = block.txt_attn_q_norm(txt_query) + txt_key = block.txt_attn_k_norm(txt_key) + + # RoPE (custom implementation) + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb(img_query, img_key, vis_freqs, head_first=False) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb(txt_query, txt_key, txt_freqs, head_first=False) + + # concatenate for joint attention: [img, txt] + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = 
dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # split back + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + # output projections + img_attn_output = block.img_attn_proj(img_attn_output) + txt_attn_output = block.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -def apply_gate(x, gate=None, tanh=False): - if gate is None: - return x - return x * (gate.unsqueeze(1).tanh() if tanh else gate.unsqueeze(1)) +def _apply_gate(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return x * gate.unsqueeze(1) -def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - output = torch.nn.functional.scaled_dot_product_attention(q, k, v) - output = output.transpose(1, 2) - return output -class RMSNorm(nn.Module): - def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- + - def 
_norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) +@maybe_allow_in_graph +class JoyImageTransformerBlock(nn.Module): + """Double-stream transformer block for JoyImage. - def forward(self, x): - output = self._norm(x.float()).type_as(x) - if hasattr(self, "weight"): - output = output * self.weight - return output + Each block processes an image stream and a text stream jointly through + shared attention, following the SD3 / Flux double-stream pattern with + WAN-style modulation. + Attention projections are stored **directly** on the block (fused QKV) + so that weight keys match the checkpoint layout, e.g. + ``double_blocks.0.img_attn_qkv.weight``. + """ -class MMDoubleStreamBlock(nn.Module): def __init__( self, - hidden_size: int, - heads_num: int, - mlp_width_ratio: float, - dtype=None, - device=None, - attn_backend: str = "torch_spda", + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + qk_norm: str = "rms_norm", + eps: float = 1e-6, ): - factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.attn_backend = attn_backend - self.heads_num = heads_num - head_dim = hidden_size // heads_num - mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - - self.img_mod = JoyImageModulate(hidden_size, 6, **factory_kwargs) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) - self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, 
activation_fn="gelu-approximate") - - self.txt_mod = JoyImageModulate(hidden_size, 6, **factory_kwargs) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) - self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") - - def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_freqs_cis=None, txt_freqs_cis=None): + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + mlp_hidden_dim = int(dim * mlp_width_ratio) + + # image stream + self.img_mod = JoyImageModulate(dim, factor=6) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # text stream + self.txt_mod = JoyImageModulate(dim, factor=6) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # ---- joint attention (fused QKV, directly on the block) ---- + # image attention layers + self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_k_norm = nn.RMSNorm(attention_head_dim, 
eps=eps) + self.img_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + # text attention layers + self.txt_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.txt_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + self.processor = JoyImageAttnProcessor() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # modulation ( img_mod1_shift, img_mod1_scale, @@ -266,7 +404,7 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f img_mod2_shift, img_mod2_scale, img_mod2_gate, - ) = self.img_mod(vec) + ) = self.img_mod(temb) ( txt_mod1_shift, txt_mod1_scale, @@ -274,130 +412,146 @@ def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, vis_f txt_mod2_shift, txt_mod2_scale, txt_mod2_gate, - ) = self.txt_mod(vec) - - img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale) - img_qkv = self.img_attn_qkv(img_modulated) - B, L, _ = img_qkv.shape - img_q, img_k, img_v = img_qkv.reshape(B, L, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4).unbind(0) - img_q = self.img_attn_q_norm(img_q).to(img_v) - img_k = self.img_attn_k_norm(img_k).to(img_v) - if vis_freqs_cis is not None: - img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False) - - txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale) - txt_qkv = self.txt_attn_qkv(txt_modulated) - B2, L2, _ = txt_qkv.shape - txt_q, txt_k, txt_v = txt_qkv.reshape(B2, L2, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4).unbind(0) - txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) - txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) - if txt_freqs_cis is not None: - txt_q, txt_k = apply_rotary_emb(txt_q, 
txt_k, txt_freqs_cis, head_first=False) - - q = torch.cat((img_q, txt_q), dim=1) - k = torch.cat((img_k, txt_k), dim=1) - v = torch.cat((img_v, txt_v), dim=1) - - attn = attention(q, k, v).flatten(2, 3) - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] - - img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) - img = img + apply_gate( - self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), - gate=img_mod2_gate, + ) = self.txt_mod(temb) + + # --- attention --- + img_modulated = _modulate(self.img_norm1(hidden_states), img_mod1_shift, img_mod1_scale) + txt_modulated = _modulate(self.txt_norm1(encoder_hidden_states), txt_mod1_shift, txt_mod1_scale) + + img_attn, txt_attn = self.processor( + self, + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + image_rotary_emb=image_rotary_emb, ) - txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) - txt = txt + apply_gate( - self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), - gate=txt_mod2_gate, + hidden_states = hidden_states + _apply_gate(img_attn, img_mod1_gate) + encoder_hidden_states = encoder_hidden_states + _apply_gate(txt_attn, txt_mod1_gate) + + # --- FFN --- + hidden_states = hidden_states + _apply_gate( + self.img_mlp(_modulate(self.img_norm2(hidden_states), img_mod2_shift, img_mod2_scale)), + img_mod2_gate, ) + encoder_hidden_states = encoder_hidden_states + _apply_gate( + self.txt_mlp(_modulate(self.txt_norm2(encoder_hidden_states), txt_mod2_shift, txt_mod2_scale)), + txt_mod2_gate, + ) + + return hidden_states, encoder_hidden_states + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + - return img, txt class JoyImageTransformer3DModel(ModelMixin, ConfigMixin): + """JoyImage Transformer model for image generation / editing. 
+ + Dual-stream DiT architecture with WAN-style conditioning embeddings and + custom rotary position embeddings. + """ + _supports_gradient_checkpointing = True + _no_split_modules = ["JoyImageTransformerBlock"] @register_to_config def __init__( self, - patch_size: tuple[int, int, int] = (1, 2, 2), + patch_size: list = [1, 2, 2], in_channels: int = 16, - out_channels: int = 16, - hidden_size: int = 4096, - heads_num: int = 32, - text_states_dim: int = 4096, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, mlp_width_ratio: float = 4.0, - mm_double_blocks_depth: int = 40, - rope_dim_list: tuple[int, int, int] = (16, 56, 56), + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], rope_type: str = "rope", - attn_backend: str = "torch_spda", - unpatchify_new: bool = True, - rope_theta: int = 256, - enable_activation_checkpointing: bool = False, + theta: int = 256, + # legacy config.json keys (kept for backward compatibility) + heads_num: int | None = None, + mm_double_blocks_depth: int | None = None, + text_states_dim: int | None = None, + rope_theta: int | None = None, ): super().__init__() - self.args = SimpleNamespace( - enable_activation_checkpointing=enable_activation_checkpointing, - ) + # --- backward-compatible parameter mapping --- + if heads_num is not None: + num_attention_heads = heads_num + if mm_double_blocks_depth is not None: + num_layers = mm_double_blocks_depth + if text_states_dim is not None: + text_dim = text_states_dim + if rope_theta is not None: + theta = rope_theta self.out_channels = out_channels or in_channels - self.patch_size = tuple(patch_size) + self.patch_size = patch_size self.hidden_size = hidden_size - self.heads_num = heads_num - self.rope_dim_list = tuple(rope_dim_list) - self.mm_double_blocks_depth = mm_double_blocks_depth - self.attn_backend = attn_backend + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list 
self.rope_type = rope_type - self.unpatchify_new = unpatchify_new - self.theta = rope_theta + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) - if hidden_size % heads_num != 0: - raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") + # image projection + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + # condition embedder (re-uses WAN implementation) + from .transformer_wan import WanTimeTextImageEmbedding self.condition_embedder = WanTimeTextImageEmbedding( dim=hidden_size, time_freq_dim=256, time_proj_dim=hidden_size * 6, - text_embed_dim=text_states_dim, + text_embed_dim=text_dim, ) + # double-stream blocks self.double_blocks = nn.ModuleList( [ - MMDoubleStreamBlock( - hidden_size=self.hidden_size, - heads_num=self.heads_num, + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, - attn_backend=attn_backend, ) - for _ in range(mm_double_blocks_depth) + for _ in range(num_layers) ] ) + # output head self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(hidden_size, out_channels * math.prod(self.patch_size)) - - self.gradient_checkpointing = enable_activation_checkpointing + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) - @property - def device(self) -> torch.device: - return next(self.parameters()).device + # ------------------------------------------------------------------ + # RoPE helper + # ------------------------------------------------------------------ - def 
get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None): + def get_rotary_pos_embed( + self, + vis_rope_size: list[int], + txt_rope_size: int | None = None, + ): target_ndim = 3 if len(vis_rope_size) != target_ndim: vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + list(vis_rope_size) - head_dim = self.hidden_size // self.heads_num - rope_dim_list = list(self.rope_dim_list) + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list if rope_dim_list is None: rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] - if sum(rope_dim_list) != head_dim: - raise ValueError("sum(rope_dim_list) should equal head_dim") + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim" - return get_nd_rotary_pos_embed( + vis_freqs, txt_freqs = _get_nd_rotary_pos_embed( rope_dim_list, vis_rope_size, txt_rope_size=txt_rope_size, @@ -405,6 +559,24 @@ def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None): use_real=True, theta_rescale_factor=1, ) + return vis_freqs, txt_freqs + + # ------------------------------------------------------------------ + # Unpatchify + # ------------------------------------------------------------------ + + def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + c = self.out_channels + pt, ph, pw = self.patch_size + assert t * h * w == x.shape[1] + + x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c) + x = torch.einsum("nthwopqc->nctohpwq", x) + return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ def forward( self, @@ -414,91 +586,68 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, return_dict: bool = True, ): - if encoder_hidden_states is None: - raise ValueError("encoder_hidden_states is required.") - - is_multi_item = len(hidden_states.shape) == 6 + # handle 
multi-item input (b, n, c, t, h, w) + is_multi_item = hidden_states.ndim == 6 num_items = 0 if is_multi_item: num_items = hidden_states.shape[1] if num_items > 1: - if self.patch_size[0] != 1: - raise ValueError("For multi-item input, patch_size[0] must be 1") - hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).flatten(2, 3) - - _, _, ot, oh, ow = hidden_states.shape - tt, th, tw = ( - ot // self.patch_size[0], - oh // self.patch_size[1], - ow // self.patch_size[2], - ) + assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1" + hidden_states = torch.cat( + [hidden_states[:, -1:], hidden_states[:, :-1]], dim=1 + ) + # rearrange: (b, n, c, t, h, w) -> (b, c, n*t, h, w) + b, n, c, t, h, w = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t, h, w) - if encoder_hidden_states_mask is None: - encoder_hidden_states_mask = torch.ones( - (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), - dtype=torch.bool, - device=encoder_hidden_states.device, - ) + batch_size, _, ot, oh, ow = hidden_states.shape + tt = ot // self.patch_size[0] + th = oh // self.patch_size[1] + tw = ow // self.patch_size[2] + # patchify img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + + # condition embeddings _, vec, txt, _ = self.condition_embedder(timestep, encoder_hidden_states) if vec.shape[-1] > self.hidden_size: vec = vec.unflatten(1, (6, -1)) txt_seq_len = txt.shape[1] - vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed( - vis_rope_size=(tt, th, tw), + + # RoPE + vis_freqs, txt_freqs = self.get_rotary_pos_embed( + vis_rope_size=[tt, th, tw], txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, ) - txt_seq_len = txt.shape[1] - img_seq_len = img.shape[1] - - img_hidden_states = [] + # main loop for block in self.double_blocks: - img, txt = block(img, txt, vec, vis_freqs_cis, txt_freqs_cis) - 
img_hidden_states.append(img) - - img_len = img.shape[1] - x = torch.cat((img, txt), 1) - img = x[:, :img_len, ...] + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, txt_freqs), + ) + # final layer img = self.proj_out(self.norm_out(img)) img = self.unpatchify(img, tt, th, tw) + # un-multi-item: (b, c, n*t, h, w) -> (b, n, c, t, h, w) if is_multi_item: - # b c (n t) h w -> b n c t h w - b, c, nt, h, w = img.shape - t = nt // num_items - img = img.reshape(b, c, num_items, t, h, w).permute(0, 2, 1, 3, 4, 5) + c_out = img.shape[1] + img = img.reshape(batch_size, c_out, num_items, -1, oh, ow) + img = img.permute(0, 2, 1, 3, 4, 5) # (b, n, c, t, h, w) if num_items > 1: img = torch.cat([img[:, 1:], img[:, :1]], dim=1) if not return_dict: - return (img, txt) - + return (img,) return Transformer2DModelOutput(sample=img) - def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int): - c = self.out_channels - pt, ph, pw = self.patch_size - if t * h * w != x.shape[1]: - raise ValueError("Invalid token length for unpatchify.") - - if self.unpatchify_new: - x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) - x = torch.einsum("nthwopqc->nctohpwq", x) - else: - x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) - x = torch.einsum("nthwcopq->nctohpwq", x) - - return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) - class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel): - """ - Backward-compatible alias of JoyImageTransformer3DModel. 
- """ + """Alias kept for backward compatibility with pipeline imports.""" pass \ No newline at end of file From 9d78e4e1fd34b2243e164f08e712bde69999b732 Mon Sep 17 00:00:00 2001 From: "zhangmaoquan.1" Date: Tue, 14 Apr 2026 22:41:51 +0800 Subject: [PATCH 6/6] remove vae tiling and autocast --- .../joyimage/pipeline_joyimage_edit.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py index c9688e15adbd..2a5e042a7999 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -889,7 +889,6 @@ def __call__( ] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - enable_tiling: bool = False, max_sequence_length: int = 4096, drop_vit_feature: bool = False, enable_denormalization: bool = True, @@ -1057,11 +1056,6 @@ def __call__( enable_denormalization=enable_denormalization, ) - target_dtype = PRECISION_TO_TYPE[self.args.dit_precision] - autocast_enabled = target_dtype != torch.float32 - vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] - vae_autocast_enabled = vae_dtype != torch.float32 - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -1081,25 +1075,23 @@ def __call__( latent_model_input = latents t_expand = t.repeat(latent_model_input.shape[0]) - with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): - noise_pred = self.transformer( + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( hidden_states=latent_model_input, timestep=t_expand, - encoder_hidden_states=prompt_embeds, - 
encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, return_dict=False, )[0] - if self.do_classifier_free_guidance: - with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=t_expand, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - return_dict=False, - )[0] - comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond) # Rescale to match the conditional prediction norm (guidance rescaling). cond_norm = torch.norm(noise_pred, dim=2, keepdim=True) @@ -1128,11 +1120,8 @@ def __call__( if enable_denormalization: latents = self.denormalize_latents(latents) - with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): - if enable_tiling: - self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False)[0] - image = image.unflatten(0, (batch_size, -1)) + image = self.vae.decode(latents, return_dict=False)[0] + image = image.unflatten(0, (batch_size, -1)) else: image = latents