From a0823154f9df95ad14f9dd61f86ee83d3a26fc89 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Thu, 18 Jun 2026 12:10:13 +0000 Subject: [PATCH 1/7] feat: add image edit plus --- ...convert_joyimage_edit_plus_to_diffusers.py | 290 ++++++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_joyimage.py | 2 + .../transformer_joyimage_edit_plus.py | 365 +++++++++ src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/joyimage/__init__.py | 7 +- .../joyimage/pipeline_joyimage_edit_plus.py | 697 ++++++++++++++++++ .../pipelines/joyimage/pipeline_output.py | 8 + 10 files changed, 1377 insertions(+), 5 deletions(-) create mode 100644 scripts/convert_joyimage_edit_plus_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_joyimage_edit_plus.py create mode 100644 src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py diff --git a/scripts/convert_joyimage_edit_plus_to_diffusers.py b/scripts/convert_joyimage_edit_plus_to_diffusers.py new file mode 100644 index 000000000000..f01adb03c747 --- /dev/null +++ b/scripts/convert_joyimage_edit_plus_to_diffusers.py @@ -0,0 +1,290 @@ +import argparse +from typing import Any, Dict, Tuple + +import torch +from accelerate import init_empty_weights +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration + +from diffusers import ( + AutoencoderKLWan, + JoyImageEditPlusPipeline, +) +from diffusers.models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + + +# VAE conversion reused from convert_joyimage_edit_to_diffusers.py (identical VAE) +def convert_vae(vae_ckpt_path): + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + middle_key_mapping = { + 
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + attention_mapping = { + 
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + head_mapping = { + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + for key, value in old_state_dict.items(): + if key in middle_key_mapping: + new_state_dict[middle_key_mapping[key]] = value + elif key in attention_mapping: + new_state_dict[attention_mapping[key]] = value + elif key in head_mapping: + new_state_dict[head_mapping[key]] = value + elif key in quant_mapping: + new_state_dict[quant_mapping[key]] = value + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + 
new_state_dict["decoder.conv_in.bias"] = value + elif key.startswith("encoder.downsamples."): + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + new_state_dict[new_key] = value + elif key.startswith("decoder.upsamples."): + parts = key.split(".") + block_idx = int(parts[2]) + + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + new_state_dict[key] = value + continue + + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + 
elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + new_state_dict[new_key] = value + + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + new_state_dict[new_key] = value + + elif ".resample." in key or ".time_conv." in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + + +def get_transformer_config() -> Dict[str, Any]: + return { + "hidden_size": 4096, + "in_channels": 16, + "num_attention_heads": 32, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_dim": 4096, + "rope_type": "rope", + "theta": 10000, + } + + +def convert_transformer(ckpt_path: str): + checkpoint = torch.load(ckpt_path, weights_only=True) + if "model" in checkpoint: + original_state_dict = checkpoint["model"] + else: + original_state_dict = checkpoint + + attn_suffixes = ( + 
"img_attn_qkv.", + "img_attn_q_norm.", + "img_attn_k_norm.", + "img_attn_proj.", + "txt_attn_qkv.", + "txt_attn_q_norm.", + "txt_attn_k_norm.", + "txt_attn_proj.", + ) + remapped = {} + for key, value in original_state_dict.items(): + new_key = key + if key.startswith("double_blocks."): + for suffix in attn_suffixes: + if "." + suffix in key and ".attn." + suffix not in key: + new_key = key.replace("." + suffix, ".attn." + suffix) + break + remapped[new_key] = value + + config = get_transformer_config() + with init_empty_weights(): + transformer = JoyImageEditPlusTransformer3DModel(**config) + transformer.load_state_dict(remapped, strict=True, assign=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser(description="Convert JoyImage Edit Plus checkpoints to diffusers format") + parser.add_argument("--transformer_ckpt_path", type=str, default=None) + parser.add_argument("--vae_ckpt_path", type=str, default=None) + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="bf16", help="Torch dtype (fp32, fp16, bf16)") + parser.add_argument("--flow_shift", type=float, default=1.5) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +if __name__ == "__main__": + args = get_args() + transformer = None + vae = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if 
args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + vae = vae.to(dtype=dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + processor = AutoProcessor.from_pretrained(args.text_encoder_path) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + args.text_encoder_path, torch_dtype=torch.bfloat16 + ).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift) + transformer = transformer.to("cuda") + vae = vae.to("cuda") + pipe = JoyImageEditPlusPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + processor.save_pretrained(f"{args.output_path}/processor") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9ec449df0508..b3c62bb70cc1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -275,6 +275,7 @@ "I2VGenXLUNet", "Ideogram4Transformer2DModel", "JoyImageEditTransformer3DModel", + "JoyImageEditPlusTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "Krea2Transformer2DModel", @@ -624,6 +625,8 @@ "ImageTextPipelineOutput", "JoyImageEditPipeline", "JoyImageEditPipelineOutput", + "JoyImageEditPlusPipeline", + "JoyImageEditPlusPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", "Kandinsky5I2IPipeline", @@ -1137,6 +1140,7 @@ I2VGenXLUNet, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, @@ -1461,6 +1465,8 @@ ImageTextPipelineOutput, JoyImageEditPipeline, JoyImageEditPipelineOutput, + JoyImageEditPlusPipeline, + 
JoyImageEditPlusPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, Kandinsky5I2IPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3e56e49ce04e..30eec69dd02a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -121,6 +121,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_ideogram4"] = ["Ideogram4Transformer2DModel"] _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"] + _import_structure["transformers.transformer_joyimage_edit_plus"] = ["JoyImageEditPlusTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_krea2"] = ["Krea2Transformer2DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] @@ -255,6 +256,7 @@ HunyuanVideoTransformer3DModel, Ideogram4Transformer2DModel, JoyImageEditTransformer3DModel, + JoyImageEditPlusTransformer3DModel, Kandinsky5Transformer3DModel, Krea2Transformer2DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 4ba9703b5fc0..21f5cb853643 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,6 +42,7 @@ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_ideogram4 import Ideogram4Transformer2DModel from .transformer_joyimage import JoyImageEditTransformer3DModel + from .transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_krea2 import Krea2Transformer2DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer diff --git 
a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py index b17ddb05f799..d30b0501e02f 100644 --- a/src/diffusers/models/transformers/transformer_joyimage.py +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -283,6 +283,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # modulation ( @@ -312,6 +313,7 @@ def forward( hidden_states=img_modulated, encoder_hidden_states=txt_modulated, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, ) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..abc8c2b4340a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -0,0 +1,365 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm +from .transformer_joyimage import ( + JoyImageAttention, + JoyImageModulate, + JoyImageTimeTextImageEmbedding, + JoyImageTransformerBlock, +) + + +logger = logging.get_logger(__name__) + + +def _apply_rotary_emb_batched( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" + cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) + + if cos.ndim == 2: + # unbatched: [S, D] -> [1, S, 1, D] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + elif cos.ndim == 3: + # batched: [B, S, D] -> [B, S, 1, D] + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + + def _rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk) + return xq_out, xk_out + + +class JoyImageEditPlusAttnProcessor: + """Attention processor that supports batched RoPE embeddings for edit-plus multi-image input.""" + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + attn: "JoyImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: 
torch.Tensor | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") + + heads = attn.heads + + img_qkv = attn.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + txt_qkv = attn.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + img_query = attn.img_attn_q_norm(img_query) + img_key = attn.img_attn_k_norm(img_key) + txt_query = attn.txt_attn_q_norm(txt_query) + txt_key = attn.txt_attn_k_norm(txt_key) + + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb_batched(img_query, img_key, vis_freqs) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb_batched(txt_query, txt_key, txt_freqs) + + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + img_attn_output = attn.img_attn_proj(img_attn_output) + txt_attn_output = 
attn.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): + """JoyImage Edit Plus Transformer for multi-image editing. + + Uses a patchify+padding approach where each reference image and the target noise are independently + patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. + + Input format: [B, max_patches, C, pt, ph, pw] (6D padded patches) + """ + + _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] + _no_split_modules = ["JoyImageTransformerBlock"] + _supports_gradient_checkpointing = True + _keep_in_fp32_modules = [ + "time_embedder", + "norm1", + "norm2", + "norm_out", + ] + _repeated_blocks = ["JoyImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + ): + super().__init__() + + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list + self.rope_type = rope_type + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + self.condition_embedder = JoyImageTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + ) + + self.double_blocks = 
nn.ModuleList( + [ + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = FP32LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) + + self.gradient_checkpointing = False + + # Set batched-RoPE-aware attention processor on all blocks + for block in self.double_blocks: + block.attn.set_processor(JoyImageEditPlusAttnProcessor()) + + def _get_rotary_pos_embed_for_range( + self, + start: Tuple[int, int, int], + stop: Tuple[int, int, int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate 3D RoPE for a spatial range [start, stop).""" + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // 3] * 3 + + grids = [] + for i in range(3): + grids.append(torch.arange(start[i], stop[i], dtype=torch.float32)) + + mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0) + + cos_parts, sin_parts = [], [] + for i, dim in enumerate(rope_dim_list): + pos = mesh[i].reshape(-1) + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + angles = torch.outer(pos, freqs) + cos_parts.append(angles.cos().repeat_interleave(2, dim=1)) + sin_parts.append(angles.sin().repeat_interleave(2, dim=1)) + + return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor | None = None, + shape_list: List[List[Tuple[int, int, int]]] | None = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Tuple]: + """ + Args: + hidden_states: [B, max_patches, C, pt, ph, pw] - patchified latent input. + timestep: [B] - diffusion timestep. 
+ encoder_hidden_states: [B, L, D] - text encoder outputs. + encoder_hidden_states_mask: [B, L] - attention mask for text tokens. + shape_list: Per-sample list of (t, h, w) tuples for each component (target + references). + return_dict: Whether to return a dict or tuple. + """ + batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape + device = hidden_states.device + + # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if isinstance(encoder_hidden_states_mask, list): + encoder_hidden_states_mask = encoder_hidden_states_mask[0] + + # Resolve shape_list from forward context if not explicitly provided + if shape_list is None: + try: + from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context + + forward_batch = get_forward_context().forward_batch + if forward_batch is not None and forward_batch.vae_image_sizes is not None: + shape_list = [list(forward_batch.vae_image_sizes)] * batch_size + except (ImportError, AttributeError): + pass + if shape_list is None: + raise ValueError( + "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" + ) + + # 1. Condition embeddings + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + # 2. Patchify via Conv3d: flatten (B, N) -> apply conv -> reshape back + x = hidden_states.reshape(batch_size * max_num_patches, channels, pt, ph, pw) + x = self.img_in(x) # (B*N, D, 1, 1, 1) + img = x.reshape(batch_size, max_num_patches, -1) + + # 3. 
Build per-component RoPE with temporal offsets + sample_cos_list, sample_sin_list = [], [] + + for i in range(batch_size): + s_cos_parts, s_sin_parts = [], [] + current_t_offset = 0 + + for thw in shape_list[i]: + t, h, w = thw + start = (current_t_offset, 0, 0) + stop = (current_t_offset + t, h, w) + cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(start, stop) + s_cos_parts.append(cos_emb) + s_sin_parts.append(sin_emb) + current_t_offset += t + + s_cos = torch.cat(s_cos_parts, dim=0).to(device) + s_sin = torch.cat(s_sin_parts, dim=0).to(device) + + actual_len = s_cos.shape[0] + pad_len = max_num_patches - actual_len + if pad_len > 0: + s_cos = F.pad(s_cos, (0, 0, 0, pad_len), value=1.0) + s_sin = F.pad(s_sin, (0, 0, 0, pad_len), value=0.0) + + sample_cos_list.append(s_cos) + sample_sin_list.append(s_sin) + + vis_freqs = (torch.stack(sample_cos_list), torch.stack(sample_sin_list)) + + # 4. Build attention mask: [B, 1, 1, img_seq + txt_seq] + # img patches: only actual (non-padding) patches are valid; txt uses encoder_hidden_states_mask + attention_mask = None + if encoder_hidden_states_mask is not None: + img_mask = torch.zeros(batch_size, max_num_patches, device=device, dtype=encoder_hidden_states_mask.dtype) + for i in range(batch_size): + actual_len = sum(t * h * w for t, h, w in shape_list[i]) + img_mask[i, :actual_len] = 1.0 + full_mask = torch.cat([img_mask, encoder_hidden_states_mask], dim=1) + attention_mask = full_mask.unsqueeze(1).unsqueeze(1).bool() + + # 5. Run double blocks + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img, txt = self._gradient_checkpointing_func(block, img, txt, vec, (vis_freqs, None), attention_mask) + else: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, None), + attention_mask=attention_mask, + ) + + # 6. 
Output projection + reshape to 6D patches + img = self.proj_out(self.norm_out(img)) + img = img.reshape( + batch_size, max_num_patches, pt, ph, pw, self.out_channels + ).permute(0, 1, 5, 2, 3, 4) # -> [B, N, C, pt, ph, pw] + + if not return_dict: + return (img,) + return Transformer2DModelOutput(sample=img) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 234085456708..0e25c647299b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -345,7 +345,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] - _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput"] + _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput", "JoyImageEditPlusPipeline", "JoyImageEditPlusPipelineOutput"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -758,7 +758,7 @@ from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .ideogram4 import Ideogram4Pipeline, Ideogram4PromptEnhancerHead - from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput + from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput, JoyImageEditPlusPipeline, JoyImageEditPlusPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py index 85b9246b22a6..a5faea9d9763 100644 --- a/src/diffusers/pipelines/joyimage/__init__.py +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -22,8 +22,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] - - 
_import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit_plus"] = ["JoyImageEditPlusPipeline"] + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput", "JoyImageEditPlusPipelineOutput"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,8 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_joyimage_edit import JoyImageEditPipeline - from .pipeline_output import JoyImageEditPipelineOutput + from .pipeline_joyimage_edit_plus import JoyImageEditPlusPipeline + from .pipeline_output import JoyImageEditPipelineOutput, JoyImageEditPlusPipelineOutput else: import sys diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py new file mode 100644 index 000000000000..c938e8e8ab32 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -0,0 +1,697 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from transformers import ( + Qwen2Tokenizer, + Qwen3VLForConditionalGeneration, + Qwen3VLProcessor, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLWan +from ...models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import JoyImageEditImageProcessor, find_best_bucket +from .pipeline_output import JoyImageEditPlusPipelineOutput + + +EXAMPLE_DOC_STRING = """ +Examples: + ```python + >>> import torch + >>> from diffusers import JoyImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> model_id = "jdopensource/JoyAI-Image-Edit-Plus-Diffusers" + >>> pipe = JoyImageEditPlusPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> images = [ + ... load_image("dog.png"), + ... load_image("person.png"), + ... ] + >>> output = pipe( + ... images=images, + ... prompt="Let the person lovingly play with the dog.", + ... height=1024, + ... width=1024, + ... num_inference_steps=30, + ... guidance_scale=4.0, + ... generator=torch.manual_seed(42), + ... 
) + >>> output.images[0].save("output.png") + ``` +""" + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +class JoyImageEditPlusPipeline(DiffusionPipeline): + """Diffusion pipeline for multi-image editing using JoyImage Edit Plus. + + Supports multiple reference images with different resolutions. Each reference image is independently + VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. + + Model offloading order: text_encoder -> transformer -> vae. 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditPlusTransformer3DModel, + processor: Qwen3VLProcessor, + text_token_max_length: int = 2048, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.text_token_max_length = text_token_max_length + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.vae_image_processor = JoyImageEditImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + ) + + self.prompt_template_encode = { + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + self.prompt_template_encode_start_idx = { + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_last_decoder_hidden_states(self, forward_fn, **kwargs): + """ + Run ``forward_fn(**kwargs)`` while capturing the **pre-norm** output of the last decoder layer via a forward + hook. + + This model was trained on transformers 4.57, where ``Qwen3VLForConditionalGeneration``'s + ``@check_model_inputs`` decorator monkey-patched each decoder layer to collect ``hidden_states``. 
Because + ``Qwen3VLCausalLMOutputWithPast`` has no ``last_hidden_state`` field, ``tie_last_hidden_states`` had no effect + and ``hidden_states[-1]`` was the **pre-norm** output of the last decoder layer. + + Starting from https://github.com/huggingface/transformers/pull/42609 the CausalLM forward explicitly returns + ``hidden_states=outputs.hidden_states`` from the inner model. Combined with the subsequent + ``@check_model_inputs`` → ``@capture_outputs`` migration (transformers 5.x), ``hidden_states`` is now captured + at the ``Qwen3VLTextModel`` level where ``tie_last_hidden_states=True`` replaces ``hidden_states[-1]`` with the + **post-norm** ``last_hidden_state``. The CausalLM simply passes this through, so ``hidden_states[-1]`` becomes + post-norm – a ~10x scale difference (std ~2 vs ~21) that breaks inference. + + This helper bypasses both mechanisms by hooking the last decoder layer directly, returning the raw pre-norm + output regardless of the transformers version. + """ + captured = {} + + def _hook(_module, _input, output): + captured["hidden_states"] = output[0] if isinstance(output, tuple) else output + + handle = self.text_encoder.model.language_model.layers[-1].register_forward_hook(_hook) + try: + forward_fn(**kwargs) + finally: + handle.remove() + return captured["hidden_states"] + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[List[Image.Image]] = None, + max_sequence_length: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode prompts with inline tokens via the Qwen3-VL processor.""" + device = device or self._execution_device + template = self.prompt_template_encode["multiple_images"] + drop_idx = self.prompt_template_encode_start_idx["multiple_images"] + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) 
for p in prompt] + + inputs = self.processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + last_hidden_states = self._get_last_decoder_hidden_states(self.text_encoder, **inputs) + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + return prompt_embeds, prompt_embeds_mask + + def _pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor: + current_length = x.shape[1] + if current_length >= target_length: + return x[:, -target_length:] + padding_length = target_length - current_length + if x.ndim >= 3: + padding = torch.zeros( + (x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device + ) + else: + padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) + return torch.cat([x, padding], dim=1) + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 
1).to(latent.device, latent.dtype) + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + w, h = img.size + bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) + img = img.resize((resize_w, resize_h), Image.LANCZOS) + left = (resize_w - bw) // 2 + top = (resize_h - bh) // 2 + img = img.crop((left, top, left + bw, top + bh)) + return img + + def _get_bucket_size(self, img: Image.Image) -> Tuple[int, int]: + return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + reference_images: Optional[List[List[Image.Image]]] = None, + enable_denormalization: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: + """Prepare 6D padded latent tensor with target noise + reference image latents. 
+ + Returns: + padded_latents: [B, max_patches, C, pt, ph, pw] + target_mask: [B, max_patches] (True for target patches) + shape_list: per-sample list of (t, h, w) tuples for each component + """ + pt, ph, pw = self.transformer.config.patch_size + + all_patches = [] + all_target_masks = [] + all_shape_lists = [] + max_patches = 0 + + for i in range(batch_size): + sample_gen = generator[i] if isinstance(generator, list) else generator + + # Target noise + t_target = 1 + h_target = int(height) // self.vae_scale_factor_spatial + w_target = int(width) // self.vae_scale_factor_spatial + noise_shape = (num_channels_latents, t_target, h_target, w_target) + noise_block = randn_tensor(noise_shape, generator=sample_gen, device=device, dtype=dtype) + + sample_items = [noise_block] + + # Reference images + if reference_images is not None and reference_images[i]: + for ref_img_pil in reference_images[i]: + ref_h, ref_w = self._get_bucket_size(ref_img_pil) + ref_img_pil = self._resize_center_crop(ref_img_pil, (ref_h, ref_w)) + + ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) + ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() + ref_latent = ref_latent.to(dtype) + ref_latent = self.normalize_latents(ref_latent) + ref_latent = ref_latent.squeeze(0) # [C, 1, H', W'] + sample_items.append(ref_latent) + + # Patchify each item and build shape_list + sample_patches = [] + sample_masks = [] + sample_shapes = [] + + for j, item in enumerate(sample_items): + c, t, h, w = item.shape + l_t, l_h, l_w = t // pt, h // ph, w // pw + sample_shapes.append((l_t, l_h, l_w)) + + patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + sample_patches.append(patches) + sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, 
dtype=torch.bool)) + + combined_patches = torch.cat(sample_patches, dim=0) + combined_masks = torch.cat(sample_masks, dim=0) + + all_patches.append(combined_patches) + all_target_masks.append(combined_masks) + all_shape_lists.append(sample_shapes) + max_patches = max(max_patches, combined_patches.shape[0]) + + # Pad to uniform size + padded_latents = torch.zeros( + (batch_size, max_patches, num_channels_latents, pt, ph, pw), device=device, dtype=dtype + ) + target_mask = torch.zeros((batch_size, max_patches), device=device, dtype=torch.bool) + + for i in range(batch_size): + n = all_patches[i].shape[0] + padded_latents[i, :n] = all_patches[i] + target_mask[i, :n] = all_target_masks[i] + + return padded_latents, target_mask, all_shape_lists + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def guidance_scale(self) -> float: + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1 + + @property + def num_timesteps(self) -> int: + return self._num_timesteps + + @property + def interrupt(self) -> bool: + return self._interrupt + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + images: List[Image.Image] | List[List[Image.Image]] | None = None, + prompt: str | List[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, 
+ prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 4096, + enable_denormalization: bool = True, + ): + r""" + Generate an edited image from multiple reference images and a text prompt. + + Args: + images (`List[Image.Image]` or `List[List[Image.Image]]`): + Reference images for editing. Each image can have a different resolution. + If a flat list is provided, it's treated as one sample with multiple references. + prompt (`str` or `List[str]`): + Text prompt describing the desired edit. + height (`int`, *optional*): + Output height in pixels. If None, determined from the last reference image's bucket. + width (`int`, *optional*): + Output width in pixels. If None, determined from the last reference image's bucket. + num_inference_steps (`int`, defaults to 30): + Number of denoising steps. + guidance_scale (`float`, defaults to 4.0): + Classifier-free guidance scale. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt for CFG. + generator (`torch.Generator`, *optional*): + RNG generator for reproducibility. + enable_denormalization (`bool`, defaults to True): + Whether to denormalize latents before VAE decoding. + + Examples: + + Returns: + [`JoyImageEditPlusPipelineOutput`] or `tuple`. 
+ """ + # Normalize images input to List[List[Image]] + if images is not None: + if isinstance(images[0], Image.Image): + images = [images] # single sample + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Determine output resolution from last reference image if not specified + if height is None or width is None: + if images is not None and len(images[0]) > 0: + last_img = images[0][-1] + height, width = self._get_bucket_size(last_img) + else: + height = height or 1024 + width = width or 1024 + + device = self._execution_device + + # Pre-process images: bucket-resize each reference image (matching original pipeline) + if images is not None: + processed_images = [] + for sample_imgs in images: + processed_sample = [] + for img in sample_imgs: + ref_h, ref_w = self._get_bucket_size(img) + resize_img = self._resize_center_crop(img, (ref_h, ref_w)) + processed_sample.append(resize_img) + processed_images.append(processed_sample) + images = processed_images + + # Construct prompts with tokens + prompt = [prompt] if isinstance(prompt, str) else prompt + if images is not None: + formatted_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + p = prompt[i] if i < len(prompt) else prompt[0] + formatted_prompts.append(f"<|im_start|>user\n{image_tags}{p}<|im_end|>\n") + else: + formatted_prompts = [f"<|im_start|>user\n{p}<|im_end|>\n" for p in prompt] + + # Flatten all images for the processor + flattened_images = None + if images is not None: + flattened_images = [img for sublist in images for img in sublist] + + # Encode prompt + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self.encode_prompt_multiple_images( + 
prompt=formatted_prompts, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + torch.save(prompt_embeds, "prompt_embeds.pt") + # Encode negative prompt for CFG + if self.do_classifier_free_guidance: + print(f"negative_prompt: {negative_prompt}") + if negative_prompt is None and negative_prompt_embeds is None: + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + neg_prompts.append(f"<|im_start|>user\n{image_tags} <|im_end|>\n") + negative_prompt = neg_prompts + elif negative_prompt is not None and negative_prompt_embeds is None: + neg_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + neg_prompts = [] + for i in range(batch_size): + num_refs = len(images[i]) if images is not None and i < len(images) else 0 + image_tags = "".join(["\n" for _ in range(num_refs)]) + n = neg_list[i] if i < len(neg_list) else neg_list[0] + neg_prompts.append(f"<|im_start|>user\n{image_tags}{n}<|im_end|>\n") + negative_prompt = neg_prompts + + if negative_prompt_embeds is None: + neg_prompt_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=neg_prompt_list, + images=flattened_images, + device=device, + max_sequence_length=max_sequence_length, + ) + + # Pad and concatenate [negative, positive] + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat([ + self._pad_sequence(negative_prompt_embeds, max_seq_len), + self._pad_sequence(prompt_embeds, max_seq_len), + ]) + if prompt_embeds_mask is not None and negative_prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat([ + self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self._pad_sequence(prompt_embeds_mask, max_seq_len), + ]) + 
torch.save(prompt_embeds, 'prompt_embeds_2.pt') + + # Prepare timesteps — compute sigmas with single shift to match original scheduler + if timesteps is None and sigmas is None: + shift = getattr(self.scheduler.config, "shift", 1.0) + raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) + shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) + sigmas = shifted_sigmas[:-1].tolist() + original_shift = self.scheduler.shift + self.scheduler.set_shift(1.0) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self.scheduler.set_shift(original_shift) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # Prepare latents (patchified) + num_channels_latents = self.transformer.config.in_channels + padded_latents, target_mask, shape_list = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + reference_images=images, + enable_denormalization=enable_denormalization, + ) + torch.save(padded_latents, "padded_latents.pt") + torch.save(target_mask, "target_mask.pt") + # exit(0) + + # Zero out padding text tokens to prevent them from corrupting attention + # (original uses explicit attention masking; here we neutralize padding values) + if prompt_embeds_mask is not None: + prompt_embeds = prompt_embeds * prompt_embeds_mask.unsqueeze(-1) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + clean_reference_backup = padded_latents.clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference patches + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + + 
model_input = padded_latents + + # CFG expansion + if self.do_classifier_free_guidance: + model_input_cfg = torch.cat([model_input] * 2) + t_expand = t.repeat(model_input_cfg.shape[0]) + cfg_shape_list = shape_list * 2 + else: + model_input_cfg = model_input + t_expand = t.repeat(batch_size) + cfg_shape_list = shape_list + + # Transformer forward + noise_pred = self.transformer( + hidden_states=model_input_cfg, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + shape_list=cfg_shape_list, + return_dict=False, + )[0] + + # CFG combination with norm rescaling + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=2, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # Scheduler step + padded_latents = self.scheduler.step(noise_pred, t, padded_latents, return_dict=False)[0].to( + dtype=prompt_embeds.dtype + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + padded_latents = callback_outputs.pop("latents", padded_latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + # Post-processing: decode target latents + if output_type != "latent": + padded_latents[~target_mask] = clean_reference_backup[~target_mask] + pt, ph, pw = self.transformer.config.patch_size + + image_list = [] + for b_idx in range(batch_size): + l_t, l_h, l_w = shape_list[b_idx][0] + target_len = l_t * l_h * l_w + + target_patches = 
padded_latents[b_idx, :target_len] + video_latent = rearrange( + target_patches, + "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", + t=l_t, h=l_h, w=l_w, + ) + + video_latent = self.denormalize_latents(video_latent) + + with torch.autocast(device_type="cuda", dtype=torch.float32): + sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] + sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() + image_list.append(sample_image) + + # Convert to output format + output_images = [] + for img_tensor in image_list: + # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) + img_tensor = img_tensor[:, 0] + img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) + if output_type == "pil": + output_images.append(Image.fromarray(img_np)) + elif output_type == "np": + output_images.append(img_np) + else: + output_images.append(img_tensor) + + image = output_images + else: + image = padded_latents + + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return JoyImageEditPlusPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 175dce3540d7..40d9d3aa100f 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -14,3 +14,11 @@ class JoyImageEditPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + +@dataclass +class JoyImageEditPlusPipelineOutput(BaseOutput): + """ + Output class for JoyImage Edit Plus multi-image editing pipelines. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file From d6375a8b618bbc96078356bd1248efff202b7977 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Mon, 22 Jun 2026 05:31:11 +0000 Subject: [PATCH 2/7] refactor: remove debug code --- .../pipelines/joyimage/pipeline_joyimage_edit_plus.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index c938e8e8ab32..980939f427d6 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -506,10 +506,8 @@ def __call__( max_sequence_length=max_sequence_length, ) - torch.save(prompt_embeds, "prompt_embeds.pt") # Encode negative prompt for CFG if self.do_classifier_free_guidance: - print(f"negative_prompt: {negative_prompt}") if negative_prompt is None and negative_prompt_embeds is None: neg_prompts = [] for i in range(batch_size): @@ -547,7 +545,6 @@ def __call__( self._pad_sequence(negative_prompt_embeds_mask, max_seq_len), self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - torch.save(prompt_embeds, 'prompt_embeds_2.pt') # Prepare timesteps — compute sigmas with single shift to match original scheduler if timesteps is None and sigmas is None: @@ -579,9 +576,6 @@ def __call__( reference_images=images, enable_denormalization=enable_denormalization, ) - torch.save(padded_latents, "padded_latents.pt") - torch.save(target_mask, "target_mask.pt") - # exit(0) # Zero out padding text tokens to prevent them from corrupting attention # (original uses explicit attention masking; here we neutralize padding values) From 885186a66852d10a3f733aad8189d278348ea13f Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:15:16 +0000 Subject: [PATCH 3/7] fix: address review issues for JoyImage Edit Plus - Remove einops dependency: replace rearrange with 
reshape/permute - Remove sglang-specific code from transformer forward - Remove unused import inspect from transformer - Fix hardcoded device_type="cuda" to use device.type - Simplify scheduler sigma math: delegate to retrieve_timesteps - Remove unused enable_denormalization parameter - Fix callback latents variable binding - Fix output_type="pt" to return stacked tensor - Set return_dict default to True in transformer forward - Add dummy objects for JoyImageEditPlus classes - Add transformer and pipeline test files --- .../transformer_joyimage_edit_plus.py | 19 +- .../joyimage/pipeline_joyimage_edit_plus.py | 50 ++-- src/diffusers/utils/dummy_pt_objects.py | 15 ++ .../dummy_torch_and_transformers_objects.py | 30 +++ ...t_models_transformer_joyimage_edit_plus.py | 114 +++++++++ .../joyimage/test_joyimage_edit_plus.py | 225 ++++++++++++++++++ 6 files changed, 403 insertions(+), 50 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_joyimage_edit_plus.py create mode 100644 tests/pipelines/joyimage/test_joyimage_edit_plus.py diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index abc8c2b4340a..572c983ec453 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import math from typing import List, Tuple, Union @@ -255,7 +254,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, shape_list: List[List[Tuple[int, int, int]]] | None = None, - return_dict: bool = False, + return_dict: bool = True, ) -> Union[torch.Tensor, Tuple]: """ Args: @@ -269,22 +268,6 @@ def forward( batch_size, max_num_patches, channels, pt, ph, pw = hidden_states.shape device = hidden_states.device - # Unwrap list inputs (SglangXvideo passes these as lists from CFG branches) - if not isinstance(encoder_hidden_states, torch.Tensor): - encoder_hidden_states = encoder_hidden_states[0] - if isinstance(encoder_hidden_states_mask, list): - encoder_hidden_states_mask = encoder_hidden_states_mask[0] - - # Resolve shape_list from forward context if not explicitly provided - if shape_list is None: - try: - from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context - - forward_batch = get_forward_context().forward_batch - if forward_batch is not None and forward_batch.vae_image_sizes is not None: - shape_list = [list(forward_batch.vae_image_sizes)] * batch_size - except (ImportError, AttributeError): - pass if shape_list is None: raise ValueError( "shape_list must be provided either as an argument or via forward_batch.vae_image_sizes" diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 980939f427d6..144650d46b05 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -18,7 +18,6 @@ import numpy as np import torch -from einops import rearrange from PIL import Image from transformers import ( Qwen2Tokenizer, @@ -282,7 +281,6 @@ def prepare_latents( device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]], reference_images: 
Optional[List[List[Image.Image]]] = None, - enable_denormalization: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: """Prepare 6D padded latent tensor with target noise + reference image latents. @@ -319,7 +317,7 @@ def prepare_latents( ref_tensor = torch.from_numpy(np.array(ref_img_pil.convert("RGB"))).to(device=device, dtype=dtype) ref_tensor = (ref_tensor / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) - with torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): ref_latent = self.vae.encode(ref_tensor.float()).latent_dist.mode() ref_latent = ref_latent.to(dtype) ref_latent = self.normalize_latents(ref_latent) @@ -336,7 +334,8 @@ def prepare_latents( l_t, l_h, l_w = t // pt, h // ph, w // pw sample_shapes.append((l_t, l_h, l_w)) - patches = rearrange(item, "c (t pt) (h ph) (w pw) -> (t h w) c pt ph pw", pt=pt, ph=ph, pw=pw) + patches = item.reshape(c, l_t, pt, l_h, ph, l_w, pw) + patches = patches.permute(1, 3, 5, 0, 2, 4, 6).reshape(-1, c, pt, ph, pw) sample_patches.append(patches) sample_masks.append(torch.full((patches.shape[0],), j == 0, device=device, dtype=torch.bool)) @@ -411,7 +410,6 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 4096, - enable_denormalization: bool = True, ): r""" Generate an edited image from multiple reference images and a text prompt. @@ -434,8 +432,6 @@ def __call__( Negative prompt for CFG. generator (`torch.Generator`, *optional*): RNG generator for reproducibility. - enable_denormalization (`bool`, defaults to True): - Whether to denormalize latents before VAE decoding. 
Examples: @@ -546,22 +542,10 @@ def __call__( self._pad_sequence(prompt_embeds_mask, max_seq_len), ]) - # Prepare timesteps — compute sigmas with single shift to match original scheduler - if timesteps is None and sigmas is None: - shift = getattr(self.scheduler.config, "shift", 1.0) - raw_sigmas = torch.linspace(1, 0, num_inference_steps + 1) - shifted_sigmas = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas) - sigmas = shifted_sigmas[:-1].tolist() - original_shift = self.scheduler.shift - self.scheduler.set_shift(1.0) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self.scheduler.set_shift(original_shift) - else: - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) + # Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # Prepare latents (patchified) num_channels_latents = self.transformer.config.in_channels @@ -574,7 +558,6 @@ def __call__( device=device, generator=generator, reference_images=images, - enable_denormalization=enable_denormalization, ) # Zero out padding text tokens to prevent them from corrupting attention @@ -631,6 +614,7 @@ def __call__( ) if callback_on_step_end is not None: + latents = padded_latents callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] @@ -653,15 +637,13 @@ def __call__( target_len = l_t * l_h * l_w target_patches = padded_latents[b_idx, :target_len] - video_latent = rearrange( - target_patches, - "(t h w) c pt ph pw -> 1 c (t pt) (h ph) (w pw)", - t=l_t, h=l_h, w=l_w, - ) + c_lat = target_patches.shape[1] + video_latent = target_patches.reshape(l_t, l_h, l_w, c_lat, pt, ph, pw) + video_latent = video_latent.permute(3, 0, 4, 1, 5, 2, 6).reshape(1, c_lat, l_t * pt, l_h * ph, l_w * pw) video_latent = self.denormalize_latents(video_latent) - with 
torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=device.type, dtype=torch.float32): sample_image = self.vae.decode(video_latent.float(), return_dict=False)[0] sample_image = (sample_image / 2 + 0.5).clamp(0, 1).squeeze(0).cpu().float() image_list.append(sample_image) @@ -671,15 +653,19 @@ def __call__( for img_tensor in image_list: # img_tensor: [C, T, H, W] -> [C, H, W] (T=1) img_tensor = img_tensor[:, 0] - img_np = (img_tensor.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) if output_type == "pil": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(Image.fromarray(img_np)) elif output_type == "np": + img_np = (img_tensor.permute(1, 2, 0).cpu().float().numpy() * 255).clip(0, 255).astype(np.uint8) output_images.append(img_np) else: output_images.append(img_tensor) - image = output_images + if output_type == "pt": + image = torch.stack(output_images) + else: + image = output_images else: image = padded_latents diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8eb942e68075..06c5b1d425fe 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1500,6 +1500,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class JoyImageEditPlusTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4d7710adcdd1..8955e52aae6f 100644 --- 
a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2222,6 +2222,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class JoyImageEditPlusPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class JoyImageEditPlusPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py new file mode 100644 index 000000000000..451dbfbbf0ca --- /dev/null +++ b/tests/models/transformers/test_models_transformer_joyimage_edit_plus.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import JoyImageEditPlusTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class JoyImageEditPlusTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return JoyImageEditPlusTransformer3DModel + + @property + def output_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def input_shape(self) -> tuple[int, ...]: + return (2, 16, 1, 2, 2) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + return True + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": [1, 2, 2], + "in_channels": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "text_dim": 16, + "num_layers": 2, + "rope_dim_list": [4, 6, 6], + "theta": 256, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + max_patches = 2 + hidden_states = randn_tensor( + (batch_size, max_patches, 16, 1, 2, 2), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor((batch_size, 12, 16), generator=self.generator, device=torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + shape_list = [[(1, 1, 1), (1, 1, 1)]] + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "shape_list": 
shape_list, + } + + +class TestJoyImageEditPlusTransformer(JoyImageEditPlusTransformerTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestJoyImageEditPlusTransformerMemory(JoyImageEditPlusTransformerTesterConfig, MemoryTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerTraining(JoyImageEditPlusTransformerTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"JoyImageEditPlusTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestJoyImageEditPlusTransformerAttention(JoyImageEditPlusTransformerTesterConfig, AttentionTesterMixin): + pass + + +class TestJoyImageEditPlusTransformerCompile(JoyImageEditPlusTransformerTesterConfig, TorchCompileTesterMixin): + pass diff --git a/tests/pipelines/joyimage/test_joyimage_edit_plus.py b/tests/pipelines/joyimage/test_joyimage_edit_plus.py new file mode 100644 index 000000000000..e41265d30128 --- /dev/null +++ b/tests/pipelines/joyimage/test_joyimage_edit_plus.py @@ -0,0 +1,225 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + JoyImageEditPlusPipeline, + JoyImageEditPlusTransformer3DModel, +) +from diffusers.hooks import apply_group_offloading + +from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class JoyImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = JoyImageEditPlusPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "images"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + super().setUp() + self._bucket_patcher = patch( + "diffusers.pipelines.joyimage.image_processor.find_best_bucket", + return_value=(32, 32), + ) + self._bucket_patcher.start() + + def tearDown(self): + self._bucket_patcher.stop() + super().tearDown() + + def get_dummy_components(self): + tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration" + + torch.manual_seed(0) + transformer = JoyImageEditPlusTransformer3DModel( + patch_size=[1, 2, 2], + in_channels=16, + hidden_size=32, + num_attention_heads=2, + text_dim=16, + num_layers=1, + rope_dim_list=[4, 6, 6], + theta=256, + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + scheduler = 
FlowMatchEulerDiscreteScheduler() + + processor = Qwen3VLProcessor.from_pretrained(tiny_ckpt_id) + processor.image_processor.min_pixels = 4 * 28 * 28 + processor.image_processor.max_pixels = 4 * 28 * 28 + + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(tiny_ckpt_id) + text_encoder.resize_token_embeddings(len(processor.tokenizer)) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": processor.tokenizer, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "combine the two images", + "images": [Image.new("RGB", (32, 32)), Image.new("RGB", (32, 32))], + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + @unittest.skip("num_images_per_prompt not applicable: each prompt is bound to reference images") + def test_num_images_per_prompt(self): + pass + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=False) + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, 
atol=1e-4, rtol=1e-4): + super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol) + + @require_torch_accelerator + def test_group_offloading_inference(self): + if not self.test_group_offloading: + return + + def create_pipe(): + torch.manual_seed(0) + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + return pipe(**inputs)[0] + + pipe = create_pipe().to(torch_device) + output_without_group_offloading = run_forward(pipe) + + pipe = create_pipe() + for component_name in ["transformer", "text_encoder"]: + component = getattr(pipe, component_name, None) + if component is None: + continue + if hasattr(component, "enable_group_offload"): + component.enable_group_offload( + torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1 + ) + else: + apply_group_offloading( + component, + onload_device=torch.device(torch_device), + offload_type="block_level", + num_blocks_per_group=1, + ) + pipe.vae.to(torch_device) + output_with_block_level = run_forward(pipe) + + pipe = create_pipe() + pipe.transformer.enable_group_offload(torch.device(torch_device), offload_type="leaf_level") + pipe.text_encoder.to(torch_device) + pipe.vae.to(torch_device) + output_with_leaf_level = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_block_level = output_with_block_level.detach().cpu().numpy() + output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4)) + + @unittest.skip("Qwen3VLForConditionalGeneration does not 
support leaf-level group offloading") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_cpu_offload_forward_pass(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_offload_forward_pass_twice(self): + pass From aa2f5638b31463b1074b9c3e60cbf78332a989db Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:30:41 +0000 Subject: [PATCH 4/7] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 40d9d3aa100f..23cb24431462 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -21,4 +21,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): Output class for JoyImage Edit Plus multi-image editing pipelines. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] From 8a911e5614155a0e1c112ac217b5c4e19763b500 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Tue, 23 Jun 2026 02:39:50 +0000 Subject: [PATCH 5/7] fix: add missing newline at end of pipeline_output.py --- src/diffusers/pipelines/joyimage/pipeline_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py index 23cb24431462..30be7c248e33 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_output.py +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -22,3 +22,4 @@ class JoyImageEditPlusPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + From 1344fd0275326ff909aa7aa260d84c042c0c40db Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Wed, 24 Jun 2026 10:16:46 +0000 Subject: [PATCH 6/7] doc: add joyimage-edit-plus doc --- docs/source/en/_toctree.yml | 4 ++ .../models/transformer_joyimage_edit_plus.md | 29 +++++++++ .../en/api/pipelines/joyimage_edit_plus.md | 61 +++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 docs/source/en/api/models/transformer_joyimage_edit_plus.md create mode 100644 docs/source/en/api/pipelines/joyimage_edit_plus.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 23e2c867b580..f3239722c64f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -355,6 +355,8 @@ title: Ideogram4Transformer2DModel - local: api/models/transformer_joyimage title: JoyImageEditTransformer3DModel + - local: api/models/transformer_joyimage_edit_plus + title: JoyImageEditPlusTransformer3DModel - local: api/models/krea2_transformer2d title: Krea2Transformer2DModel - local: api/models/latte_transformer3d @@ -555,6 +557,8 @@ title: InstructPix2Pix - local: api/pipelines/joyimage_edit title: JoyImage Edit + - local: 
api/pipelines/joyimage_edit_plus + title: JoyImage Edit Plus - local: api/pipelines/kandinsky title: Kandinsky 2.1 - local: api/pipelines/kandinsky_v22 diff --git a/docs/source/en/api/models/transformer_joyimage_edit_plus.md b/docs/source/en/api/models/transformer_joyimage_edit_plus.md new file mode 100644 index 000000000000..776c53eaf20c --- /dev/null +++ b/docs/source/en/api/models/transformer_joyimage_edit_plus.md @@ -0,0 +1,29 @@ + + +# JoyImageEditPlusTransformer3DModel + +The model can be loaded with the following code snippet. + +```python +from diffusers import JoyImageEditPlusTransformer3DModel + +transformer = JoyImageEditPlusTransformer3DModel.from_pretrained("jdopensource/JoyAI-Image-Edit-Plus-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## JoyImageEditPlusTransformer3DModel + +[[autodoc]] JoyImageEditPlusTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/joyimage_edit_plus.md b/docs/source/en/api/pipelines/joyimage_edit_plus.md new file mode 100644 index 000000000000..2ce8e2f29647 --- /dev/null +++ b/docs/source/en/api/pipelines/joyimage_edit_plus.md @@ -0,0 +1,61 @@ + + +# JoyAI-Image-Edit-Plus + +[JoyAI-Image](https://github.com/jd-opensource/JoyAI-Image) is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT). + +JoyAI-Image-Edit-Plus is a multi-image instruction-guided editing model that accepts **multiple reference images** and a text instruction to generate a new image that combines elements from the references according to the instruction. It supports 1–5 reference images per sample. 
+ +| Model | Description | Download | +|:-----:|:-----------:|:--------:| +| JoyAI-Image-Edit-Plus | Multi-image instruction-guided editing with element composition from multiple references | [Hugging Face](https://huggingface.co/jdopensource/JoyAI-Image-Edit-Plus-Diffusers) | + +```python +import torch +from PIL import Image +from diffusers import JoyImageEditPlusPipeline + +pipeline = JoyImageEditPlusPipeline.from_pretrained( + "jdopensource/JoyAI-Image-Edit-Plus-Diffusers", torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +images = [ + Image.open("reference_0.png").convert("RGB"), + Image.open("reference_1.png").convert("RGB"), +] + +target_h, target_w = pipeline._get_bucket_size(images[-1]) + +output = pipeline( + images=images, + prompt="Combine the person from the second image with the scene from the first image.", + negative_prompt="low quality, blurry, deformed", + height=target_h, + width=target_w, + num_inference_steps=30, + guidance_scale=4.0, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +output.save("joyimage_edit_plus_output.png") +``` + +## JoyImageEditPlusPipeline + +[[autodoc]] JoyImageEditPlusPipeline + - all + - __call__ + +## JoyImageEditPlusPipelineOutput + +[[autodoc]] pipelines.joyimage.pipeline_output.JoyImageEditPlusPipelineOutput From 53d0b331fa6cc7a482b3e6bde50c9f20922db6b4 Mon Sep 17 00:00:00 2001 From: "tangyanfei.8" Date: Wed, 24 Jun 2026 11:27:30 +0000 Subject: [PATCH 7/7] refactor: update code format --- .../transformer_joyimage_edit_plus.py | 48 +++-- .../joyimage/pipeline_joyimage_edit_plus.py | 174 +++++++++++++----- 2 files changed, 159 insertions(+), 63 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py index 572c983ec453..125dc30cc726 100644 --- a/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py +++ b/src/diffusers/models/transformers/transformer_joyimage_edit_plus.py @@ 
-13,7 +13,6 @@ # limitations under the License. import math -from typing import List, Tuple, Union import torch import torch.nn as nn @@ -40,8 +39,8 @@ def _apply_rotary_emb_batched( xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: Tuple[torch.Tensor, torch.Tensor], -) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: """RoPE that handles both batched [B, S, D] and unbatched [S, D] freqs.""" cos, sin = freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device) @@ -77,10 +76,10 @@ def __call__( attn: "JoyImageAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, - image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if encoder_hidden_states is None: raise ValueError("JoyImageEditPlusAttnProcessor requires encoder_hidden_states") @@ -140,12 +139,37 @@ def __call__( class JoyImageEditPlusTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): - """JoyImage Edit Plus Transformer for multi-image editing. + r""" + JoyImage Edit Plus Transformer for multi-image editing. Uses a patchify+padding approach where each reference image and the target noise are independently patchified and concatenated into a flat patch sequence. Supports variable-resolution reference images. - Input format: [B, max_patches, C, pt, ph, pw] (6D padded patches) + Input format: `[B, max_patches, C, pt, ph, pw]` (6D padded patches). + + Args: + patch_size (`list`, defaults to `[1, 2, 2]`): + Patch size for patchifying the latent input along `(t, h, w)` dimensions. + in_channels (`int`, defaults to `16`): + The number of channels in the input latent. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. 
If not specified, it defaults to `in_channels`. + hidden_size (`int`, defaults to `3072`): + The dimensionality of the hidden representations. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads. + text_dim (`int`, defaults to `4096`): + The dimensionality of the text encoder output. + mlp_width_ratio (`float`, defaults to `4.0`): + The ratio of MLP hidden dimension to `hidden_size`. + num_layers (`int`, defaults to `20`): + The number of double-stream transformer blocks. + rope_dim_list (`list[int]`, defaults to `[16, 56, 56]`): + The dimensions for 3D rotary positional embeddings along `(t, h, w)`. + rope_type (`str`, defaults to `"rope"`): + The type of rotary positional embedding. + theta (`int`, defaults to `256`): + The base frequency for rotary embeddings. """ _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] @@ -222,9 +246,9 @@ def __init__( def _get_rotary_pos_embed_for_range( self, - start: Tuple[int, int, int], - stop: Tuple[int, int, int], - ) -> Tuple[torch.Tensor, torch.Tensor]: + start: tuple[int, int, int], + stop: tuple[int, int, int], + ) -> tuple[torch.Tensor, torch.Tensor]: """Generate 3D RoPE for a spatial range [start, stop).""" head_dim = self.hidden_size // self.num_attention_heads rope_dim_list = self.rope_dim_list @@ -253,9 +277,9 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None = None, - shape_list: List[List[Tuple[int, int, int]]] | None = None, + shape_list: list[list[tuple[int, int, int]]] | None = None, return_dict: bool = True, - ) -> Union[torch.Tensor, Tuple]: + ) -> torch.Tensor | tuple: """ Args: hidden_states: [B, max_patches, C, pt, ph, pw] - patchified latent input. 
diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py index 144650d46b05..d314219b45d4 100644 --- a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit_plus.py @@ -14,7 +14,7 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable import numpy as np import torch @@ -30,13 +30,16 @@ from ...models import AutoencoderKLWan from ...models.transformers.transformer_joyimage_edit_plus import JoyImageEditPlusTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import replace_example_docstring +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .image_processor import JoyImageEditImageProcessor, find_best_bucket from .pipeline_output import JoyImageEditPlusPipelineOutput +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -68,12 +71,35 @@ def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") @@ -97,12 +123,27 @@ def retrieve_timesteps( class JoyImageEditPlusPipeline(DiffusionPipeline): - """Diffusion pipeline for multi-image editing using JoyImage Edit Plus. + r""" + Diffusion pipeline for multi-image instruction-guided editing using JoyImage Edit Plus. Supports multiple reference images with different resolutions. Each reference image is independently VAE-encoded and patchified, then concatenated with the target noise patches for joint denoising. - Model offloading order: text_encoder -> transformer -> vae. + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3VLForConditionalGeneration`]): + Multimodal text encoder for prompt encoding with inline image understanding. + tokenizer ([`Qwen2Tokenizer`]): + Tokenizer for text processing. 
+ transformer ([`JoyImageEditPlusTransformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + processor ([`Qwen3VLProcessor`]): + Processor for multimodal inputs (text + images). + text_token_max_length (`int`, defaults to `2048`): + Maximum token length for text encoding. """ model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -186,11 +227,11 @@ def _hook(_module, _input, output): def encode_prompt_multiple_images( self, - prompt: Union[str, List[str]], - device: Optional[torch.device] = None, - images: Optional[List[Image.Image]] = None, - max_sequence_length: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + prompt: str | list[str], + device: torch.device | None = None, + images: list[Image.Image] | None = None, + max_sequence_length: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """Encode prompts with inline tokens via the Qwen3-VL processor.""" device = device or self._execution_device template = self.prompt_template_encode["multiple_images"] @@ -257,7 +298,7 @@ def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: latent = latent / self.vae.config.scaling_factor return latent - def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + def _resize_center_crop(self, img: Image.Image, target_size: tuple[int, int]) -> Image.Image: w, h = img.size bh, bw = target_size scale = max(bh / h, bw / w) @@ -268,7 +309,7 @@ def _resize_center_crop(self, img: Image.Image, target_size: Tuple[int, int]) -> img = img.crop((left, top, left + bw, top + bh)) return img - def _get_bucket_size(self, img: Image.Image) -> Tuple[int, int]: + def _get_bucket_size(self, img: Image.Image) -> tuple[int, int]: return find_best_bucket(img.size[1], img.size[0], self.vae_image_processor.config.basesize) def prepare_latents( @@ -279,9 +320,9 @@ def prepare_latents( width: int, dtype: torch.dtype, device: torch.device, - generator: 
Optional[Union[torch.Generator, List[torch.Generator]]], - reference_images: Optional[List[List[Image.Image]]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int, int]]]]: + generator: torch.Generator | list[torch.Generator] | None, + reference_images: list[list[Image.Image]] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, list[list[tuple[int, int, int]]]]: """Prepare 6D padded latent tensor with target noise + reference image latents. Returns: @@ -388,55 +429,86 @@ def interrupt(self) -> bool: @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - images: List[Image.Image] | List[List[Image.Image]] | None = None, - prompt: str | List[str] = None, + images: list[Image.Image] | list[list[Image.Image]] | None = None, + prompt: str | list[str] = None, height: int | None = None, width: int | None = None, num_inference_steps: int = 30, - timesteps: List[int] = None, - sigmas: List[float] = None, + timesteps: list[int] = None, + sigmas: list[float] = None, guidance_scale: float = 4.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, 
MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 4096, ): r""" - Generate an edited image from multiple reference images and a text prompt. + Function invoked when calling the pipeline for generation. Args: - images (`List[Image.Image]` or `List[List[Image.Image]]`): - Reference images for editing. Each image can have a different resolution. - If a flat list is provided, it's treated as one sample with multiple references. - prompt (`str` or `List[str]`): - Text prompt describing the desired edit. + images (`list[Image.Image]` or `list[list[Image.Image]]`, *optional*): + Reference images for editing. Each image can have a different resolution. If a flat list is provided, + it is treated as one sample with multiple references. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. height (`int`, *optional*): - Output height in pixels. If None, determined from the last reference image's bucket. + The height in pixels of the generated image. If `None`, determined from the last reference image. width (`int`, *optional*): - Output width in pixels. If None, determined from the last reference image's bucket. - num_inference_steps (`int`, defaults to 30): - Number of denoising steps. - guidance_scale (`float`, defaults to 4.0): - Classifier-free guidance scale. - negative_prompt (`str` or `List[str]`, *optional*): - Negative prompt for CFG. - generator (`torch.Generator`, *optional*): - RNG generator for reproducibility. + The width in pixels of the generated image. If `None`, determined from the last reference image. + num_inference_steps (`int`, *optional*, defaults to `30`): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spacing is used. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Classifier-free guidance scale. Higher values encourage the model to generate images more aligned + with the `prompt` at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, a blank prompt is used + for classifier-free guidance. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.ndarray`), `"pt"` (`torch.Tensor`), or `"latent"` for raw latent output. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`JoyImageEditPlusPipelineOutput`] instead of a plain tuple. 
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function called at the end of each denoising step with arguments: the pipeline, step index, + timestep, and a dict of callback tensor inputs. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to `4096`): + Maximum sequence length for the text encoder. Examples: Returns: - [`JoyImageEditPlusPipelineOutput`] or `tuple`. + [`JoyImageEditPlusPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`JoyImageEditPlusPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list of generated images. """ # Normalize images input to List[List[Image]] if images is not None: