"""Convert original JoyImage-Edit checkpoints into the diffusers format.

Modified from ``convert_wan_to_diffusers.py`` so that local checkpoint paths
can be supplied directly instead of Hub repo ids.
"""

import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration

from diffusers import (
    AutoencoderKLWan,
    JoyImageEditPipeline,
    JoyImageEditTransformer3DModel,
)
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


def convert_vae(vae_ckpt_path):
    """Load an original JoyImage VAE checkpoint and remap it onto ``AutoencoderKLWan``.

    Args:
        vae_ckpt_path: Path to the original VAE state-dict file (``torch.load``-able).

    Returns:
        An ``AutoencoderKLWan`` with the converted weights loaded (``strict=True``).
    """
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Explicit key tables for components whose names cannot be derived mechanically.
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Middle-block attention layers.
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Output heads of encoder/decoder.
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Latent quantization convs.
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    for key, value in old_state_dict.items():
        if key in middle_key_mapping:
            new_state_dict[middle_key_mapping[key]] = value
        elif key in attention_mapping:
            new_state_dict[attention_mapping[key]] = value
        elif key in head_mapping:
            new_state_dict[head_mapping[key]] = value
        elif key in quant_mapping:
            new_state_dict[quant_mapping[key]] = value
        # Input convs
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Encoder downsample stack: mechanical rename, structure preserved.
        elif key.startswith("encoder.downsamples."):
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
            new_state_dict[new_key] = value
        # Decoder upsample stack: flat indices 0..14 are regrouped into
        # up_blocks.{0..3} with three resnets each (3, 7, 11 are upsamplers).
        elif key.startswith("decoder.upsamples."):
            parts = key.split(".")
            block_idx = int(parts[2])

            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    new_state_dict[key] = value
                    continue

                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key
                new_state_dict[new_key] = value

            elif ".shortcut." in key:
                # Index 4 is the first resnet of up_blocks.1 (channel change).
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
                new_state_dict[new_key] = value

            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae


def get_transformer_config() -> Dict[str, Any]:
    """Return the fixed JoyImage-Edit transformer config (legacy key names)."""
    # NOTE: annotation fixed — this returns a single dict, not a tuple of dicts.
    config = {
        "diffusers_config": {
            "hidden_size": 4096,
            "in_channels": 16,
            "heads_num": 32,
            "mm_double_blocks_depth": 40,
            "out_channels": 16,
            "patch_size": [1, 2, 2],
            "rope_dim_list": [16, 56, 56],
            "text_states_dim": 4096,
            "rope_type": "rope",
            "dit_modulation_type": "wanx",
            "unpatchify_new": True,
            "rope_theta": 10000,
        },
    }
    return config


def convert_transformer(ckpt_path: str):
    """Load an original transformer checkpoint into ``JoyImageEditTransformer3DModel``.

    Accepts either a raw state dict or a training checkpoint wrapping it
    under the ``"model"`` key.
    """
    checkpoint = torch.load(ckpt_path, weights_only=True)
    original_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    config = get_transformer_config()
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def get_args():
    """Parse the command-line arguments for the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    parser.add_argument("--flow_shift", type=float, default=7.0)
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

if __name__ == "__main__":
    args = get_args()
    transformer = None
    vae = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        vae = vae.to(dtype=dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    if args.save_pipeline:
        processor = AutoProcessor.from_pretrained(args.text_encoder_path)
        text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(
            args.text_encoder_path, torch_dtype=torch.bfloat16
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path)
        # Bug fix: the --flow_shift flag was parsed but ignored in favor of a
        # hard-coded 1.5; honor the flag so the saved scheduler matches the CLI.
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift)
        transformer = transformer.to("cuda")
        vae = vae.to("cuda")
        pipe = JoyImageEditPipeline(
            processor=processor,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        ).to("cuda")
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
        processor.save_pretrained(f"{args.output_path}/processor")
HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, + JoyImageEditTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, @@ -1359,6 +1362,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, LucyEditPipeline, + JoyImageEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, LuminaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..6364ce0ab78f 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -110,6 +110,7 @@ _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] + _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel", "JoyImageTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -225,6 +226,7 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + JoyImageEditTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..1fc78c124618 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from 
.transformer_hunyuanimage import HunyuanImageTransformer2DModel + from .transformer_joyimage import JoyImageEditTransformer3DModel, JoyImageTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py new file mode 100644 index 000000000000..6bf3404c196f --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -0,0 +1,653 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# --------------------------------------------------------------------------- +# Rotary position embedding utilities +# --------------------------------------------------------------------------- + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + if len(x) == dim: + return tuple(x) + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def _get_meshgrid_nd(start, *args, dim=2): + if len(args) == 0: + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = _to_tuple(args[1], dim=dim) + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + return torch.stack(grid, dim=0) + + +def _reshape_for_broadcast(freqs_cis, x: torch.Tensor, head_first: bool = False): + ndim = x.ndim + if isinstance(freqs_cis, tuple): + if head_first: + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: 
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + if head_first: + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def _rotate_half(x: torch.Tensor): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def _apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = _reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + _rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + + +def _get_1d_rotary_pos_embed( + dim: int, + pos, + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +): + if isinstance(pos, int): + pos = torch.arange(pos).float() + + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + freqs = torch.outer(pos.float() * interpolation_factor, freqs) + + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) + return freqs_cos, freqs_sin + + return torch.polar(torch.ones_like(freqs), freqs) + + +def _get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + txt_rope_size=None, + theta_rescale_factor=1.0, + interpolation_factor=1.0, +): + rope_dim_list = list(rope_dim_list) + grid = _get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) 
+ + if isinstance(theta_rescale_factor, (int, float)): + theta_rescale_factor = [float(theta_rescale_factor)] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [float(theta_rescale_factor[0])] * len(rope_dim_list) + + if isinstance(interpolation_factor, (int, float)): + interpolation_factor = [float(interpolation_factor)] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [float(interpolation_factor[0])] * len(rope_dim_list) + + embs = [] + for i in range(len(rope_dim_list)): + emb = _get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs.append(emb) + + if use_real: + vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1)) + else: + vis_emb = torch.cat(embs, dim=1) + + if txt_rope_size is None: + return vis_emb, None + + embs_txt = [] + vis_max_ids = grid.view(-1).max().item() + grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1 + for i in range(len(rope_dim_list)): + emb = _get_1d_rotary_pos_embed( + rope_dim_list[i], + grid_txt, + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs_txt.append(emb) + + if use_real: + txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1)) + else: + txt_emb = torch.cat(embs_txt, dim=1) + + return vis_emb, txt_emb + + +# --------------------------------------------------------------------------- +# Modulation +# --------------------------------------------------------------------------- + + +class JoyImageModulate(nn.Module): + """Wan-style learnable modulation table. 
+ + Produces `factor` modulation vectors by adding the conditioning signal to a + learnable parameter table. + """ + + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + if x.ndim != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + +# --------------------------------------------------------------------------- +# Attention processor +# --------------------------------------------------------------------------- + + +class JoyImageAttnProcessor: + """Attention processor for JoyImage double-stream joint attention. + + Implements the joint attention computation where text and image streams are + processed together. The block stores fused QKV projections directly + (``img_attn_qkv`` / ``txt_attn_qkv``) so this processor operates on + a ``JoyImageTransformerBlock`` rather than on a generic ``Attention`` module. 
+ """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + block: "JoyImageTransformerBlock", + hidden_states: torch.Tensor, # image stream (B, S_img, D) + encoder_hidden_states: torch.Tensor = None, # text stream (B, S_txt, D) + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageAttnProcessor requires encoder_hidden_states (text stream)") + + heads = block.num_attention_heads + + # image stream: fused QKV -> split + img_qkv = block.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + # text stream: fused QKV -> split + txt_qkv = block.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + # reshape to multi-head: (B, S, H, D) + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + # QK norm + img_query = block.img_attn_q_norm(img_query) + img_key = block.img_attn_k_norm(img_key) + txt_query = block.txt_attn_q_norm(txt_query) + txt_key = block.txt_attn_k_norm(txt_key) + + # RoPE (custom implementation) + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb(img_query, img_key, vis_freqs, head_first=False) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb(txt_query, txt_key, txt_freqs, head_first=False) + + # concatenate for joint attention: [img, txt] + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = 
dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # split back + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + # output projections + img_attn_output = block.img_attn_proj(img_attn_output) + txt_attn_output = block.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def _apply_gate(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return x * gate.unsqueeze(1) + + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- + + +@maybe_allow_in_graph +class JoyImageTransformerBlock(nn.Module): + """Double-stream transformer block for JoyImage. + + Each block processes an image stream and a text stream jointly through + shared attention, following the SD3 / Flux double-stream pattern with + WAN-style modulation. + + Attention projections are stored **directly** on the block (fused QKV) + so that weight keys match the checkpoint layout, e.g. + ``double_blocks.0.img_attn_qkv.weight``. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + mlp_hidden_dim = int(dim * mlp_width_ratio) + + # image stream + self.img_mod = JoyImageModulate(dim, factor=6) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # text stream + self.txt_mod = JoyImageModulate(dim, factor=6) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # ---- joint attention (fused QKV, directly on the block) ---- + # image attention layers + self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + # text attention layers + self.txt_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.txt_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + self.processor = JoyImageAttnProcessor() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # modulation + ( + img_mod1_shift, + img_mod1_scale, + 
img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(temb) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(temb) + + # --- attention --- + img_modulated = _modulate(self.img_norm1(hidden_states), img_mod1_shift, img_mod1_scale) + txt_modulated = _modulate(self.txt_norm1(encoder_hidden_states), txt_mod1_shift, txt_mod1_scale) + + img_attn, txt_attn = self.processor( + self, + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + _apply_gate(img_attn, img_mod1_gate) + encoder_hidden_states = encoder_hidden_states + _apply_gate(txt_attn, txt_mod1_gate) + + # --- FFN --- + hidden_states = hidden_states + _apply_gate( + self.img_mlp(_modulate(self.img_norm2(hidden_states), img_mod2_shift, img_mod2_scale)), + img_mod2_gate, + ) + encoder_hidden_states = encoder_hidden_states + _apply_gate( + self.txt_mlp(_modulate(self.txt_norm2(encoder_hidden_states), txt_mod2_shift, txt_mod2_scale)), + txt_mod2_gate, + ) + + return hidden_states, encoder_hidden_states + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + + +class JoyImageTransformer3DModel(ModelMixin, ConfigMixin): + """JoyImage Transformer model for image generation / editing. + + Dual-stream DiT architecture with WAN-style conditioning embeddings and + custom rotary position embeddings. 
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["JoyImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + # legacy config.json keys (kept for backward compatibility) + heads_num: int | None = None, + mm_double_blocks_depth: int | None = None, + text_states_dim: int | None = None, + rope_theta: int | None = None, + ): + super().__init__() + + # --- backward-compatible parameter mapping --- + if heads_num is not None: + num_attention_heads = heads_num + if mm_double_blocks_depth is not None: + num_layers = mm_double_blocks_depth + if text_states_dim is not None: + text_dim = text_states_dim + if rope_theta is not None: + theta = rope_theta + + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list + self.rope_type = rope_type + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + + # image projection + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + # condition embedder (re-uses WAN implementation) + from .transformer_wan import WanTimeTextImageEmbedding + + self.condition_embedder = WanTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + ) + + # double-stream blocks + self.double_blocks = nn.ModuleList( + [ + JoyImageTransformerBlock( + dim=hidden_size, + 
num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + ) + for _ in range(num_layers) + ] + ) + + # output head + self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) + + # ------------------------------------------------------------------ + # RoPE helper + # ------------------------------------------------------------------ + + def get_rotary_pos_embed( + self, + vis_rope_size: list[int], + txt_rope_size: int | None = None, + ): + target_ndim = 3 + if len(vis_rope_size) != target_ndim: + vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + list(vis_rope_size) + + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim" + + vis_freqs, txt_freqs = _get_nd_rotary_pos_embed( + rope_dim_list, + vis_rope_size, + txt_rope_size=txt_rope_size, + theta=self.theta, + use_real=True, + theta_rescale_factor=1, + ) + return vis_freqs, txt_freqs + + # ------------------------------------------------------------------ + # Unpatchify + # ------------------------------------------------------------------ + + def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + c = self.out_channels + pt, ph, pw = self.patch_size + assert t * h * w == x.shape[1] + + x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c) + x = torch.einsum("nthwopqc->nctohpwq", x) + return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: 
torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + return_dict: bool = True, + ): + # handle multi-item input (b, n, c, t, h, w) + is_multi_item = hidden_states.ndim == 6 + num_items = 0 + if is_multi_item: + num_items = hidden_states.shape[1] + if num_items > 1: + assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1" + hidden_states = torch.cat( + [hidden_states[:, -1:], hidden_states[:, :-1]], dim=1 + ) + # rearrange: (b, n, c, t, h, w) -> (b, c, n*t, h, w) + b, n, c, t, h, w = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t, h, w) + + batch_size, _, ot, oh, ow = hidden_states.shape + tt = ot // self.patch_size[0] + th = oh // self.patch_size[1] + tw = ow // self.patch_size[2] + + # patchify + img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + + # condition embeddings + _, vec, txt, _ = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + txt_seq_len = txt.shape[1] + + # RoPE + vis_freqs, txt_freqs = self.get_rotary_pos_embed( + vis_rope_size=[tt, th, tw], + txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + ) + + # main loop + for block in self.double_blocks: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, txt_freqs), + ) + + # final layer + img = self.proj_out(self.norm_out(img)) + img = self.unpatchify(img, tt, th, tw) + + # un-multi-item: (b, c, n*t, h, w) -> (b, n, c, t, h, w) + if is_multi_item: + c_out = img.shape[1] + img = img.reshape(batch_size, c_out, num_items, -1, oh, ow) + img = img.permute(0, 2, 1, 3, 4, 5) # (b, n, c, t, h, w) + if num_items > 1: + img = torch.cat([img[:, 1:], img[:, :1]], dim=1) + + if not return_dict: + return (img,) + return Transformer2DModelOutput(sample=img) + + +class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel): + """Alias kept for backward 
compatibility with pipeline imports.""" + + pass \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..d9159be6156b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -299,6 +299,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] + _import_structure["joyimage"] = ["JoyImageEditPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -740,6 +741,7 @@ ) from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline + from .joyimage import JoyImageEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py new file mode 100644 index 000000000000..a6d5f31fe63c --- /dev/null +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] + + +if TYPE_CHECKING or 
DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyimage_edit import JoyImageEditPipeline + from .pipeline_output import JoyImageEditPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py new file mode 100644 index 000000000000..2a5e042a7999 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -0,0 +1,1138 @@ +import inspect +import math +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torchvision.transforms.functional as TF + +from PIL import Image +from transformers import AutoProcessor, Qwen2Tokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLWan, JoyImageEditTransformer3DModel +from ..pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import BaseOutput, deprecate, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import JoyImageEditPipelineOutput + + +EXAMPLE_DOC_STRING = """""" + +# Mapping from precision string to torch dtype. 
+PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +class BucketGroup: + """Manages dynamic batch grouping buckets for image inference.""" + + def __init__( + self, + bucket_configs: list[tuple[int, int, int, int, int]], + prioritize_frame_matching: bool = True, + ): + """ + Initialize bucket group with predefined configurations. + + Args: + bucket_configs: List of (batch_size, num_items, num_frames, height, width) tuples. + prioritize_frame_matching: Unused, kept for API compatibility. + """ + self.bucket_configs = [tuple(b) for b in bucket_configs] + + def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]: + """ + Find the best matching bucket for given media dimensions. + + Selects the bucket whose aspect ratio (height/width) is closest to that of + the input media. Only image inference (num_frames=1) is supported. + + Args: + media_shape: (num_items, num_frames, height, width) of the input media. + + Returns: + Best matching bucket as (batch_size, num_items, num_frames, height, width). + + Raises: + ValueError: If num_frames != 1 or no valid bucket is found. + """ + num_items, num_frames, height, width = media_shape + target_aspect_ratio = height / width + + if num_frames != 1: + raise ValueError( + f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}" + ) + + valid_buckets = [ + b for b in self.bucket_configs + if b[1] == num_items and b[2] == 1 + ] + if not valid_buckets: + raise ValueError(f"No image buckets found for shape {media_shape}") + + return min( + valid_buckets, + key=lambda bucket: abs((bucket[3] / bucket[4]) - target_aspect_ratio), + ) + + +def _get_text_encoder_ckpt( + text_encoder: Qwen3VLForConditionalGeneration, + fallback: str = "Qwen/Qwen3-VL-8B-Instruct", +) -> str: + """ + Retrieve the checkpoint identifier from the text encoder. + + Args: + text_encoder: The text encoder model instance. 
+ fallback: Default checkpoint name if none can be resolved. + + Returns: + A non-empty string identifying the checkpoint. + """ + candidates = [ + getattr(text_encoder, "name_or_path", None), + getattr(getattr(text_encoder, "config", None), "_name_or_path", None), + ] + for c in candidates: + if isinstance(c, str) and len(c) > 0: + return c + return fallback + + +def _generate_hw_buckets( + base_height: int = 256, + base_width: int = 256, + step_width: int = 16, + step_height: int = 16, + max_ratio: float = 4.0, +) -> list[tuple[int, int, int, int, int]]: + """ + Generate (batch_size=1, num_items=1, num_frames=1, height, width) bucket tuples + covering a range of aspect ratios while keeping total pixels close to + base_height * base_width. + + Args: + base_height: Reference height in pixels. + base_width: Reference width in pixels. + step_width: Width increment per step. + step_height: Height decrement per step. + max_ratio: Maximum allowed aspect ratio (long side / short side). + + Returns: + List of bucket tuples (1, 1, 1, height, width). + """ + buckets = [] + target_pixels = base_height * base_width + + height = target_pixels // step_width + width = step_width + + while height >= step_height: + if max(height, width) / min(height, width) <= max_ratio: + buckets.append((1, 1, 1, height, width)) + if height * (width + step_width) <= target_pixels: + width += step_width + else: + height -= step_height + + return buckets + + +def generate_video_image_bucket( + basesize: int = 256, + min_temporal: int = 65, + max_temporal: int = 129, + bs_img: int = 8, + bs_vid: int = 1, + bs_mimg: int = 4, + min_items: int = 1, + max_items: int = 1, +) -> list[list[int]]: + """ + Generate bucket configurations for image inference. + + Each bucket is represented as [batch_size, num_items, num_frames, height, width]. + Spatial dimensions are scaled by (basesize // 256) when basesize > 256. + + Args: + basesize: Base spatial resolution. Must be one of {256, 512, 768, 1024}. 
+ min_temporal: Unused; kept for API compatibility. + max_temporal: Unused; kept for API compatibility. + bs_img: Batch size for single-image buckets. + bs_vid: Unused; kept for API compatibility. + bs_mimg: Batch size for multi-image buckets. + min_items: Minimum number of items in multi-image buckets. + max_items: Maximum number of items in multi-image buckets. + + Returns: + List of bucket configs as [batch_size, num_items, num_frames, height, width]. + + Raises: + AssertionError: If basesize is not in {256, 512, 768, 1024}. + """ + assert basesize in [256, 512, 768, 1024], ( + f"[generate_video_image_bucket] unsupported basesize {basesize}" + ) + bucket_list = [] + base_bucket_list = _generate_hw_buckets() + + # Single-image buckets. + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_img + bucket_list.append(bucket) + + # Multi-image buckets. + for num_items in range(min_items, max_items + 1): + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_mimg + bucket[1] = num_items + bucket_list.append(bucket) + + # Scale spatial dimensions when basesize exceeds 256. + if basesize > 256: + ratio = basesize // 256 + + def _scale(bucket: list[int], r: int) -> list[int]: + bucket[-2] *= r + bucket[-1] *= r + return bucket + + bucket_list = [_scale(bucket, ratio) for bucket in bucket_list] + + return bucket_list + +def _resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + """Scale to cover target_size, then center-crop.""" + w, h = img.size # PIL uses (width, height). 
def _resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    """
    Scale ``img`` so it covers ``target_size`` (height, width), then center-crop.

    Args:
        img: Input PIL image.
        target_size: Target (height, width) in pixels.

    Returns:
        The resized and center-cropped PIL image.
    """
    w, h = img.size  # PIL uses (width, height).
    bh, bw = target_size
    # Cover scaling: neither output dimension may fall below the target.
    scale = max(bh / h, bw / w)
    resize_h = math.ceil(h * scale)
    resize_w = math.ceil(w * scale)
    img = TF.resize(img, (resize_h, resize_w), interpolation=TF.InterpolationMode.BILINEAR, antialias=True)
    img = TF.center_crop(img, target_size)
    return img


def _dynamic_resize_from_bucket(image_size: Tuple[int, int], basesize: int = 512) -> Tuple[int, int]:
    """
    Select the target (height, width) for an input image from the bucket list.

    The bucket whose aspect ratio is closest to the input's is chosen from a
    fixed configuration generated at ``basesize``.

    Args:
        image_size: Source image size as (width, height), matching ``PIL.Image.size``.
        basesize: Base resolution used to generate candidate buckets.

    Returns:
        Tuple of (target_height, target_width) from the best-matching bucket.
    """
    bucket_config = generate_video_image_bucket(
        basesize=basesize,
        min_temporal=56,
        max_temporal=56,
        bs_img=4,
        bs_vid=4,
        bs_mimg=8,
        min_items=2,
        max_items=2,
    )
    bucket_group = BucketGroup(bucket_config)
    src_w, src_h = image_size
    bucket = bucket_group.find_best_bucket((1, 1, src_h, src_w))
    target_height, target_width = bucket[-2], bucket[-1]
    return target_height, target_width


def _build_args(
    args: Any,
    text_encoder: Qwen3VLForConditionalGeneration,
) -> Any:
    """
    Return ``args`` unchanged if provided, otherwise construct a default namespace.

    Args:
        args: Existing args object, or None.
        text_encoder: Text encoder used to resolve the checkpoint path when args is None.

    Returns:
        The original args object, or a SimpleNamespace with sensible defaults.
    """
    if args is not None:
        return args

    text_encoder_ckpt = _get_text_encoder_ckpt(text_encoder)
    return SimpleNamespace(
        enable_multi_task_training=False,
        text_token_max_length=2048,
        dit_precision="bf16",
        vae_precision="bf16",
        text_encoder_arch_config={"params": {"text_encoder_ckpt": text_encoder_ckpt}},
    )
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +@dataclass +class _LegacyPipelineOutput(BaseOutput): + """Legacy output dataclass retained for backward compatibility.""" + + videos: Union[torch.Tensor, np.ndarray] + + +class JoyImageEditPipeline(DiffusionPipeline): + """ + Diffusion pipeline for image editing using the JoyImage architecture. + + The pipeline encodes text and image conditioning via a Qwen3-VL text encoder, + denoises latents with a 3-D transformer, and decodes the result with a WAN VAE. + + Model offloading order: text_encoder -> transformer -> vae. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditTransformer3DModel, + processor: Qwen3VLProcessor, + args: Any = None, + ): + """ + Initialise the pipeline and register all sub-modules. 
+ + Args: + scheduler: Noise scheduler for the denoising process. + vae: Variational autoencoder used for encoding / decoding latents. + text_encoder: Qwen3-VL multimodal language model for prompt encoding. + tokenizer: Tokenizer paired with the text encoder. + transformer: 3-D transformer denoising network. + processor: Qwen3-VL processor for multi-image prompt preparation. + args: Optional configuration namespace. Defaults are inferred when None. + """ + super().__init__() + self.args = _build_args(args=args, text_encoder=text_encoder) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + text_encoder_ckpt = dict(self.args.text_encoder_arch_config.get("params", {})).get( + "text_encoder_ckpt", _get_text_encoder_ckpt(self.text_encoder) + ) + self.qwen_processor = ( + processor if processor is not None else AutoProcessor.from_pretrained(text_encoder_ckpt) + ) + + self.text_token_max_length = self.args.text_token_max_length + + # Prompt templates used when encoding text with / without image tokens. 
+ self.prompt_template_encode = { + "image": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ), + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + # Number of system-prompt tokens to drop from the beginning of hidden states. + self.prompt_template_encode_start_idx = { + "image": 34, + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _extract_masked_hidden( + self, hidden_states: torch.Tensor, mask: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + """ + Extract valid (non-padded) hidden states for each sequence in the batch. + + Args: + hidden_states: Shape (B, T, D). + mask: Binary attention mask of shape (B, T). + + Returns: + Tuple of tensors, one per batch element, each of shape (valid_T, D). + """ + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths.tolist(), dim=0) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + template_type: str = "image", + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode text prompts using the Qwen tokenizer (text-only path). + + Args: + prompt: A single prompt string or a list of prompt strings. + template_type: Key into ``prompt_template_encode`` / ``prompt_template_encode_start_idx``. + device: Target device. + dtype: Target floating-point dtype. 
+ + Returns: + Tuple of (prompt_embeds, encoder_attention_mask) where both tensors + have shape (B, max_seq_len, D) and (B, max_seq_len) respectively, + zero-padded to the same length. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.text_token_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) + + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + # Drop system-prompt prefix tokens and re-pack into a padded batch. + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [ + torch.ones(e.size(0), dtype=torch.long, device=e.device) + for e in split_hidden_states + ] + + max_seq_len = min( + self.text_token_max_length, + max(u.size(0) for u in split_hidden_states), + max(u.size(0) for u in attn_mask_list), + ) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds, encoder_attention_mask + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[torch.Tensor] = None, + template_type: Optional[str] = "multiple_images", + max_sequence_length: Optional[int] 
    def encode_prompt_multiple_images(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        images: Optional[torch.Tensor] = None,
        template_type: Optional[str] = "multiple_images",
        max_sequence_length: Optional[int] = None,
        drop_vit_feature: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts that contain inline image tokens via the Qwen processor.

        ``\\n`` placeholders in each prompt string are replaced by the
        Qwen vision special tokens before being fed to the multimodal encoder.

        Args:
            prompt: Prompt string(s), optionally containing ``\\n`` tokens.
            device: Target device.
            images: Pixel tensors corresponding to the inline image tokens.
            template_type: Must be ``"multiple_images"``.
            max_sequence_length: If set, truncate the output to this length
                (keeping the last ``max_sequence_length`` tokens).
            drop_vit_feature: When True, drop all tokens up to and including the
                last vision-end token so that only the text portion is returned.

        Returns:
            Tuple of (prompt_embeds, prompt_embeds_mask).
        """
        assert template_type == "multiple_images"
        device = device or self._execution_device
        template = self.prompt_template_encode[template_type]
        drop_idx = self.prompt_template_encode_start_idx[template_type]

        prompt = [prompt] if isinstance(prompt, str) else prompt

        # If no image tokens are present, discard the image tensors.
        # NOTE(review): the inline-image placeholder is a literal newline here, so
        # ANY multi-line prompt would be rewritten into vision tokens below. This
        # looks like a mangled placeholder token (e.g. "<image>") — confirm against
        # the original training/inference code.
        if not any("\n" in p for p in prompt):
            images = None

        prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt]
        prompt = [template.format(p) for p in prompt]

        inputs = self.qwen_processor(
            text=prompt,
            images=images,
            padding=True,
            return_tensors="pt",
        ).to(device)

        encoder_hidden_states = self.text_encoder(**inputs, output_hidden_states=True)
        last_hidden_states = encoder_hidden_states.hidden_states[-1]

        if drop_vit_feature:
            # Find the last vision-end token and drop everything before it.
            # NOTE(review): 151653 is presumably the <|vision_end|> token id, and only
            # batch element 0 is inspected — verify both if batch_size > 1 is expected.
            input_ids = inputs["input_ids"]
            vlm_image_end_idx = torch.where(input_ids[0] == 151653)[0][-1]
            drop_idx = vlm_image_end_idx + 1

        prompt_embeds = last_hidden_states[:, drop_idx:]
        prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:]

        # Keep the trailing tokens when truncating (the tail is closest to generation).
        if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
            prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
            prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]

        return prompt_embeds, prompt_embeds_mask
+ """ + if images is not None: + return self.encode_prompt_multiple_images( + prompt=prompt, + images=images, + device=device, + max_sequence_length=max_sequence_length, + drop_vit_feature=drop_vit_feature, + ) + + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, template_type, device + ) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def decode_latents(self, latents: torch.Tensor, enable_tiling: bool = True) -> torch.Tensor: + """ + Decode latents to pixel values. + + .. deprecated:: 1.0.0 + Use the VAE directly instead of calling this method. + + Args: + latents: Latent tensor to decode. + enable_tiling: Whether to enable tiled decoding to save memory. + + Returns: + Float tensor of shape (..., H, W, C) with values in [0, 1]. + """ + deprecation_message = "The decode_latents method is deprecated." 
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + if image.ndim == 4: + image = image.cpu().permute(0, 2, 3, 1).float() + else: + image = image.cpu().float() + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + """ + Validate pipeline inputs before the forward pass. + + Raises: + ValueError: On any invalid combination of arguments. + """ + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError("`callback_on_step_end_tensor_inputs` has invalid keys.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") + elif prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError("`prompt` has to be of type `str` or `list`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.") + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` is required.") + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` is required." 
def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor:
    """
    Map a raw VAE latent into the normalised space used by the transformer.

    Applies ``(latent - latents_mean) / latents_std`` when the VAE config
    exposes per-channel statistics; otherwise multiplies by ``scaling_factor``.

    Args:
        latent: Raw latent tensor produced by ``vae.encode``.

    Returns:
        Normalised latent tensor, same shape/dtype/device as the input.
    """
    cfg = self.vae.config
    if hasattr(cfg, "latents_mean") and hasattr(cfg, "latents_std"):
        # new_tensor inherits the latent's device and dtype in one step.
        mean = latent.new_tensor(cfg.latents_mean).view(1, -1, 1, 1, 1)
        std = latent.new_tensor(cfg.latents_std).view(1, -1, 1, 1, 1)
        return (latent - mean) / std
    return latent * cfg.scaling_factor


def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor:
    """
    Invert :meth:`normalize_latents`, recovering the scale ``vae.decode`` expects.

    Args:
        latent: Normalised latent tensor.

    Returns:
        Latent tensor in the raw VAE scale.
    """
    cfg = self.vae.config
    if hasattr(cfg, "latents_mean") and hasattr(cfg, "latents_std"):
        mean = latent.new_tensor(cfg.latents_mean).view(1, -1, 1, 1, 1)
        std = latent.new_tensor(cfg.latents_std).view(1, -1, 1, 1, 1)
        return latent * std + mean
    return latent / cfg.scaling_factor


def prepare_latents(
    self,
    batch_size: int,
    num_items: int,
    num_channels_latents: int,
    height: int,
    width: int,
    video_length: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    latents: Optional[torch.Tensor] = None,
    reference_images: Optional[List["Image.Image"]] = None,
    enable_denormalization: bool = True,
) -> torch.Tensor:
    """
    Prepare the initial latent tensor for the denoising loop.

    Behaviour by argument:
      * ``latents`` given: moved to ``device`` unchanged.
      * ``reference_images`` given: the first ``num_items - 1`` slots hold
        VAE-encoded reference latents, the last slot holds random noise.
      * neither: pure random noise of the full shape.

    Args:
        batch_size: Number of samples in the batch.
        num_items: Number of image slots (reference + target).
        num_channels_latents: Latent channel count from the transformer config.
        height: Spatial height in pixels.
        width: Spatial width in pixels.
        video_length: Number of frames (1 for image inference).
        dtype: Floating-point dtype for the latent tensor.
        device: Target device.
        generator: RNG generator(s) for reproducible sampling.
        latents: Optional pre-allocated latent tensor.
        reference_images: Optional PIL images to encode as conditioning.
        enable_denormalization: Whether to normalise encoded reference latents.

    Returns:
        Latent tensor of shape (B, num_items, C, T, H', W').

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    latent_frames = (video_length - 1) // self.vae_scale_factor_temporal + 1
    latent_height = int(height) // self.vae_scale_factor_spatial
    latent_width = int(width) // self.vae_scale_factor_spatial
    shape = (batch_size, num_items, num_channels_latents, latent_frames, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError("Generator list length must match batch size.")

    if latents is not None:
        return latents.to(device)

    if reference_images is None:
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # Encode reference images into the first (num_items - 1) slots and append
    # one noise slot for the generation target.
    pixels = torch.stack(
        [torch.from_numpy(np.array(img.convert("RGB"))) for img in reference_images]
    ).to(device=device, dtype=dtype)
    pixels = pixels / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    pixels = pixels.permute(0, 3, 1, 2).unsqueeze(2)  # (N, C, 1, H, W)
    encoded = self.vae.encode(pixels).latent_dist.sample()
    if enable_denormalization:
        encoded = self.normalize_latents(encoded)
    encoded = encoded.view(batch_size, num_items - 1, *encoded.shape[1:])
    noise = randn_tensor((batch_size, 1, *shape[2:]), generator=generator, device=device, dtype=dtype)
    return torch.cat([encoded, noise], dim=1)


# ------------------------------------------------------------------
# Pipeline properties
# ------------------------------------------------------------------


@property
def guidance_scale(self) -> float:
    """Classifier-free guidance scale of the active forward pass."""
    return self._guidance_scale


@property
def do_classifier_free_guidance(self) -> bool:
    """Whether classifier-free guidance is active (scale above 1)."""
    return self._guidance_scale > 1


@property
def num_timesteps(self) -> int:
    """Number of denoising timesteps in the active forward pass."""
    return self._num_timesteps


@property
def interrupt(self) -> bool:
    """When True, the denoising loop is interrupted at the next step."""
    return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: PipelineImageInput | None = None,
    prompt: str | list[str] = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[
            Callable[[int, int, Dict], None],
            PipelineCallback,
            MultiPipelineCallbacks,
        ]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 4096,
    drop_vit_feature: bool = False,
    enable_denormalization: bool = True,
    **kwargs,
):
    r"""
    Generate an edited image conditioned on a reference image and a text prompt.

    Args:
        image (`PipelineImageInput`, *optional*):
            Reference image used for conditioning. When provided the pipeline
            operates in image-editing mode with ``num_items=2``. When omitted
            the pipeline runs unconditionally (``num_items=1``) and both
            ``height`` and ``width`` must be given.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide generation.
        height (`int`, *optional*):
            Height of the generated output in pixels.
        width (`int`, *optional*):
            Width of the generated output in pixels.
        num_inference_steps (`int`, *optional*, defaults to 50):
            Number of denoising steps. More steps generally improve quality at
            the cost of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps for the denoising process. When provided,
            ``num_inference_steps`` is inferred from the list length.
        sigmas (`List[float]`, *optional*):
            Custom sigmas for the denoising process. Mutually exclusive with
            ``timesteps``.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Classifier-free guidance scale.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s) used to suppress undesired content.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of generated samples per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            RNG generator(s) for deterministic sampling.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents. Sampled from a Gaussian distribution
            when not provided.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings. When provided ``prompt`` can be omitted.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask for ``prompt_embeds``.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative prompt embeddings.
        negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask for ``negative_prompt_embeds``.
        output_type (`str`, *optional*, defaults to ``"pil"``):
            Output format. Pass ``"latent"`` to return raw latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a :class:`JoyImageEditPipelineOutput` or a plain tensor.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            Callback invoked at the end of each denoising step with signature
            ``(self, step: int, timestep: int, callback_kwargs: Dict)``.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to ``["latents"]``):
            Tensor keys included in ``callback_kwargs`` for ``callback_on_step_end``.
        max_sequence_length (`int`, *optional*, defaults to 4096):
            Maximum sequence length for prompt encoding.
        drop_vit_feature (`bool`, *optional*, defaults to `False`):
            Drop vision tokens in the multi-image encoding path.
        enable_denormalization (`bool`, *optional*, defaults to `True`):
            Denormalise latents before VAE decoding.
        **kwargs:
            Additional keyword arguments for forward compatibility.

    Examples:

    Returns:
        [`~pipelines.joyimage.JoyImageEditPipelineOutput`] or `torch.Tensor`:
            If ``return_dict`` is ``True``, returns a pipeline output object
            containing the generated image(s). Otherwise returns the image
            tensor directly.

    Raises:
        ValueError: If ``image`` is None and ``height``/``width`` are not given.
    """
    # Resolve the working resolution: derive it from the reference image, or
    # from the explicit height/width, then snap to the nearest bucket size.
    if image is None:
        # Unconditional mode (num_items == 1): there is no reference image to
        # infer a resolution from, so height/width are mandatory.
        if height is None or width is None:
            raise ValueError("`height` and `width` must be provided when `image` is None.")
        image_size = (width, height)
    else:
        image_size = image[0].size if isinstance(image, list) else image.size
        if height is not None and width is not None:
            # Explicit height/width override the reference image size.
            image_size = (width, height)

    height, width = _dynamic_resize_from_bucket(image_size, basesize=1024)
    processed_image = None if image is None else _resize_center_crop(image, (height, width))

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        negative_prompt_embeds_mask=negative_prompt_embeds_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._interrupt = False

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # num_items: 1 for unconditional generation, 2 for reference-image editing.
    num_items = 1 if image is None else 2

    # Encode the conditioning prompt (and reference image when present).
    # NOTE(review): assumes encode_prompt accepts images=None in the
    # unconditional path — confirm against its implementation.
    prompt_embeds, prompt_embeds_mask = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        images=processed_image,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        template_type="image",
        drop_vit_feature=drop_vit_feature,
    )

    if self.do_classifier_free_guidance:
        # Build default negative prompts when none are provided.
        if negative_prompt is None and negative_prompt_embeds is None:
            if num_items <= 1:
                negative_prompt = ["<|im_start|>user\n<|im_end|>\n"] * batch_size
            else:
                negative_prompt = ["<|im_start|>user\n\n<|im_end|>\n"] * batch_size

        negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=negative_prompt_embeds_mask,
            images=processed_image,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            template_type="image",
        )

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
    )

    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_items,
        num_channels_latents,
        height,
        width,
        1,  # video_length = 1 for image inference
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        # Only pass reference latents in editing mode; None in unconditional mode.
        reference_images=[processed_image] if processed_image is not None else None,
        enable_denormalization=enable_denormalization,
    )

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    # Cache reference latents to restore them at each denoising step.
    if num_items > 1:
        ref_latents = latents[:, : (num_items - 1)].clone()

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # Restore reference latents so they are never overwritten by the scheduler.
            if num_items > 1:
                latents[:, : (num_items - 1)] = ref_latents.clone()

            latent_model_input = latents
            t_expand = t.repeat(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=t_expand,
                encoder_hidden_states=prompt_embeds,
                encoder_hidden_states_mask=prompt_embeds_mask,
                return_dict=False,
            )[0]

            if self.do_classifier_free_guidance:
                noise_pred_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=t_expand,
                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    return_dict=False,
                )[0]

                comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond)
                # Rescale to match the conditional prediction norm (guidance rescaling).
                cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
                noise_norm = torch.norm(comb_pred, dim=2, keepdim=True)
                noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6))

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                # BUGFIX: in Python 3 a dict comprehension runs in its own
                # scope, so `locals()` inside one cannot see this function's
                # variables; build the dict with a plain loop instead.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop(
                    "negative_prompt_embeds", negative_prompt_embeds
                )

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                if progress_bar is not None:
                    progress_bar.update()

    if output_type != "latent":
        latents = latents.flatten(0, 1)
        if enable_denormalization:
            latents = self.denormalize_latents(latents)

        image = self.vae.decode(latents, return_dict=False)[0]
        image = image.unflatten(0, (batch_size, -1))
    else:
        image = latents

    # Extract the last item (target slot) from the batch, shape: (F, C, H, W).
    image = image.float().permute(0, 1, 3, 2, 4, 5)[0, -1]

    image = self.image_processor.postprocess(image, output_type=output_type)

    self.maybe_free_model_hooks()

    if not return_dict:
        return image

    return JoyImageEditPipelineOutput(images=image)


# ---------------------------------------------------------------------------
# NOTE(review): the following class belongs to a separate module in the patch
# (src/diffusers/pipelines/joyimage/pipeline_output.py); it is reproduced here
# because it falls inside the same reviewed span.
# ---------------------------------------------------------------------------
@dataclass
class JoyImageEditPipelineOutput(BaseOutput):
    """
    Output class for JoyImageEdit generation pipelines.

    Attributes:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            The generated image(s) in the requested output format.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]