diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6703c9299e80..acf6e5ede3bb 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -491,6 +491,8 @@ title: AnimateDiff - local: api/pipelines/aura_flow title: AuraFlow + - local: api/pipelines/boogu + title: Boogu-Image - local: api/pipelines/bria_3_2 title: Bria 3.2 - local: api/pipelines/bria_fibo diff --git a/docs/source/en/api/pipelines/boogu.md b/docs/source/en/api/pipelines/boogu.md new file mode 100644 index 000000000000..ca214f5d9c88 --- /dev/null +++ b/docs/source/en/api/pipelines/boogu.md @@ -0,0 +1,153 @@ + + +# Boogu-Image + +## Overview + +Boogu-Image is an instruction-driven image generation and editing model. Rather than a +plain text prompt, it is conditioned on a natural-language *instruction* that is encoded +by a Qwen3-VL multimodal LLM, which can also attend to optional reference images. A +single/double-stream transformer denoiser then predicts the latent updates, and a +flow-matching scheduler with training-aligned time shifting controls the denoising +trajectory. The VAE maps between image and latent space. + +The model is released in several variants: + +- **Base** (`Boogu/Boogu-Image-0.1-Base`) — text-to-image, full sampling schedule. +- **Turbo** (`Boogu/Boogu-Image-0.1-Turbo`) — DMD student model for few-step + text-to-image generation. +- **Edit** (`Boogu/Boogu-Image-0.1-Edit`) — instruction-based image editing conditioned + on one or more reference images. + +FP8-quantized checkpoints are also available for each variant (the `-fp8` suffix). + +There are two pipeline classes: + +- [`BooguImagePipeline`] — text-to-image and instruction editing. +- [`BooguImageTurboPipeline`] — a subclass adding the DMD few-step inference path. It + defaults the guidance scales to the DMD-required values (`text_guidance_scale=1.0`, + `image_guidance_scale=1.0`, `empty_instruction_guidance_scale=0.0`). + +## Usage examples + +### Text-to-image + +```python +import torch +from diffusers.pipelines.boogu import BooguImagePipeline + +pipe = BooguImagePipeline.from_pretrained("Boogu/Boogu-Image-0.1-Base", torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +image = pipe( + instruction="A serene Chinese ink-wash landscape of the Guilin mountains bathed in golden light, layered peaks, mirror-like river, glowing golden contours.", + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, +).images[0] + +image.save("base.png") +``` + +### Few-step generation (Turbo) + +```python +import torch +from diffusers.pipelines.boogu import BooguImageTurboPipeline + +pipe = BooguImageTurboPipeline.from_pretrained("Boogu/Boogu-Image-0.1-Turbo", torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +image = pipe( + instruction="A serene Chinese ink-wash landscape of the Guilin mountains bathed in golden light.", + height=1024, + width=1024, + num_inference_steps=4, +).images[0] + +image.save("turbo.png") +``` + +### Instruction-based editing + +Pass one or more reference images through `input_images`: + +```python +import torch +from PIL import Image +from diffusers.pipelines.boogu import BooguImagePipeline + +pipe = BooguImagePipeline.from_pretrained("Boogu/Boogu-Image-0.1-Edit", torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +image = pipe( + instruction="Turn the image into a colored-pencil illustration.", + input_images=[Image.open("base.png").convert("RGB")], + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, + image_guidance_scale=1.0, +).images[0] + +image.save("edit.png") +``` + +### FP8 checkpoints + +FP8 weights are stored in a non-safetensors format, so load the transformer separately +with `use_safetensors=False` and pass it to the pipeline: + +```python +import torch +from diffusers import BooguImageTransformer2DModel +from diffusers.pipelines.boogu import BooguImagePipeline + +transformer = BooguImageTransformer2DModel.from_pretrained( + "Boogu/Boogu-Image-0.1-Base-fp8", + subfolder="transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, +) +pipe = BooguImagePipeline.from_pretrained( + "Boogu/Boogu-Image-0.1-Base-fp8", torch_dtype=torch.bfloat16, transformer=transformer +) +pipe = pipe.to("cuda") +``` + +Runnable scripts for every variant are available in +[`examples/boogu`](https://github.com/huggingface/diffusers/tree/main/examples/boogu). + +> [!TIP] +> The transformer uses fused `triton` (RMSNorm) and `flash_attn` (SwiGLU, variable-length +> attention) kernels when they are installed, and falls back to pure PyTorch otherwise. + +## BooguImagePipeline + +[[autodoc]] pipelines.boogu.pipeline_boogu.BooguImagePipeline + - all + - __call__ + +## BooguImageTurboPipeline + +[[autodoc]] pipelines.boogu.pipeline_boogu_turbo.BooguImageTurboPipeline + - all + - __call__ + +## FMPipelineOutput + +[[autodoc]] pipelines.boogu.pipeline_boogu.FMPipelineOutput diff --git a/examples/boogu/README.md b/examples/boogu/README.md new file mode 100644 index 000000000000..9f2bfeb7daa4 --- /dev/null +++ b/examples/boogu/README.md @@ -0,0 +1,81 @@ +# Boogu-Image + +[Boogu-Image](https://huggingface.co/Boogu) is an instruction-driven image generation and editing model. It pairs a Qwen3-VL multimodal LLM (instruction encoder) with a single/double-stream transformer denoiser and a flow-matching scheduler with training-aligned time shifting. + +This directory contains minimal inference scripts for the released checkpoints. + +## Environment installation +[Boogu-Image-quick-start](https://github.com/boogu-project/Boogu-Image/blob/main/quick_start.sh) + +## Pipelines + +| Pipeline | Class | Use case | +|---|---|---| +| Base | `BooguImagePipeline` | Text-to-image (50 steps) | +| Turbo | `BooguImageTurboPipeline` | Few-step DMD text-to-image (4 steps) | +| Edit | `BooguImagePipeline` | Instruction-based image editing (pass `input_images`) | + +## Scripts + +| Script | Checkpoint | +|---|---| +| `inference_base.py` | `Boogu/Boogu-Image-0.1-Base` | +| `inference_turbo.py` | `Boogu/Boogu-Image-0.1-Turbo` | +| `inference_edit.py` | `Boogu/Boogu-Image-0.1-Edit` | +| `inference_base_fp8.py` | `Boogu/Boogu-Image-0.1-Base-fp8` | +| `inference_turbo_fp8.py` | `Boogu/Boogu-Image-0.1-Turbo-fp8` | +| `inference_edit_fp8.py` | `Boogu/Boogu-Image-0.1-Edit-fp8` | + +## Usage + +Text-to-image: + +```bash +python inference_base.py +``` + +Few-step (Turbo): + +```bash +python inference_turbo.py +``` + +Image editing (reads `base.png` as the reference image, so run `inference_base.py` first): + +```bash +python inference_edit.py +``` + +## FP8 checkpoints + +FP8 weights are stored in a non-safetensors format, so the transformer is loaded +separately with `use_safetensors=False` and passed to the pipeline: + +```python +import torch +from diffusers import BooguImageTransformer2DModel +from diffusers.pipelines.boogu import BooguImagePipeline + +transformer = BooguImageTransformer2DModel.from_pretrained( + "Boogu/Boogu-Image-0.1-Base-fp8", + subfolder="transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, +) +pipe = BooguImagePipeline.from_pretrained( + "Boogu/Boogu-Image-0.1-Base-fp8", torch_dtype=torch.bfloat16, transformer=transformer +) +pipe = pipe.to("cuda") +``` + +The FP8 scripts also disable the DeepGEMM kernel for the FP8 VLM (forcing a Triton +finegrained-fp8 fallback) for broader hardware compatibility — see +`_disable_deepgemm_for_fp8_vlm()` in each FP8 script. + +## Optional performance dependencies + +The transformer can use fused kernels when available; without them it falls back to +pure PyTorch and prints a one-time warning: + +- `triton` — fused RMSNorm +- `flash_attn` — fused SwiGLU and variable-length flash attention diff --git a/examples/boogu/inference_base.py b/examples/boogu/inference_base.py new file mode 100644 index 000000000000..dfd7631ce4a6 --- /dev/null +++ b/examples/boogu/inference_base.py @@ -0,0 +1,20 @@ +import torch + +from diffusers.pipelines.boogu import BooguImagePipeline + + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Base" + +pipe = BooguImagePipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +images = pipe( + instruction="一幅国风琉金风格的山水画作,展现了桂林山水在金光普照下的壮丽景象。远山层叠,江水如镜,山峰边缘勾勒着发光的金色线条。画面采用石青石绿岩彩与鎏金质感相结合,局部有厚涂油画笔触,空中飘浮着金色粒子,营造出梦幻朦胧而又磅礴大气的意境。", + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, +).images + +images[0].save("base.png") +print("Inference OK, saved base.png") diff --git a/examples/boogu/inference_base_fp8.py b/examples/boogu/inference_base_fp8.py new file mode 100644 index 000000000000..faa47f5d879c --- /dev/null +++ b/examples/boogu/inference_base_fp8.py @@ -0,0 +1,52 @@ +import os + +import torch + +from diffusers import BooguImageTransformer2DModel +from diffusers.pipelines.boogu import BooguImagePipeline + + +def _disable_deepgemm_for_fp8_vlm() -> None: + # For transformers >= 5.11.0 + os.environ["TRANSFORMERS_DISABLE_DEEPGEMM_LINEAR"] = "1" + + try: + import transformers.integrations.finegrained_fp8 as fg_fp8 + except Exception: + return + + def _raise_import_error(*args, **kwargs): + raise ImportError("DeepGEMM disabled; forcing Triton finegrained-fp8 fallback.") + + if hasattr(fg_fp8, "deepgemm_fp8_fp4_linear"): + # For 5.10.1 <= transformers < 5.11.0 + fg_fp8.deepgemm_fp8_fp4_linear = _raise_import_error + elif hasattr(fg_fp8, "_load_deepgemm_kernel"): + # For 5.5.0 <= transoformers < 5.10.1 + fg_fp8._load_deepgemm_kernel = _raise_import_error + + +_disable_deepgemm_for_fp8_vlm() + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Base-fp8" + +transformer = BooguImageTransformer2DModel.from_pretrained( + MODEL_PATH, + subfolder="transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, +) +pipe = BooguImagePipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, transformer=transformer) +pipe = pipe.to("cuda") + +images = pipe( + instruction="一幅国风琉金风格的山水画作,展现了桂林山水在金光普照下的壮丽景象。远山层叠,江水如镜,山峰边缘勾勒着发光的金色线条。画面采用石青石绿岩彩与鎏金质感相结合,局部有厚涂油画笔触,空中飘浮着金色粒子,营造出梦幻朦胧而又磅礴大气的意境。", + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, +).images + +assert len(images) == 1 +images[0].save("base_fp8.png") +print("Inference OK, saved base_fp8.png") diff --git a/examples/boogu/inference_edit.py b/examples/boogu/inference_edit.py new file mode 100644 index 000000000000..ad6b7fcf3c08 --- /dev/null +++ b/examples/boogu/inference_edit.py @@ -0,0 +1,38 @@ +import os + +import torch +from PIL import Image + +from diffusers.pipelines.boogu import BooguImagePipeline + + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Edit" + +# Negative prompt steering quality away from common artifacts. With text_guidance_scale > 1 +# the model guides away from this prompt, so it noticeably improves style adherence. +NEGATIVE_INSTRUCTION = ( + "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, " + "mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, " + "broken legs censor, censored, censor_bar" +) + +if not os.path.exists("base.png"): + raise FileNotFoundError("base.png not found — run inference_base.py first to generate the reference image.") + +pipe = BooguImagePipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +images = pipe( + instruction="把图片风格调整为彩铅插画。", + negative_instruction=NEGATIVE_INSTRUCTION, + input_images=[Image.open("base.png").convert("RGB")], + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, + image_guidance_scale=1.0, +).images + +assert len(images) == 1 +images[0].save("edit.png") +print("Inference OK, saved edit.png") diff --git a/examples/boogu/inference_edit_fp8.py b/examples/boogu/inference_edit_fp8.py new file mode 100644 index 000000000000..c1d3d02731cb --- /dev/null +++ b/examples/boogu/inference_edit_fp8.py @@ -0,0 +1,67 @@ +import os + +import torch +from PIL import Image + +from diffusers import BooguImageTransformer2DModel +from diffusers.pipelines.boogu import BooguImagePipeline + + +def _disable_deepgemm_for_fp8_vlm() -> None: + # For transformers >= 5.11.0 + os.environ["TRANSFORMERS_DISABLE_DEEPGEMM_LINEAR"] = "1" + + try: + import transformers.integrations.finegrained_fp8 as fg_fp8 + except Exception: + return + + def _raise_import_error(*args, **kwargs): + raise ImportError("DeepGEMM disabled; forcing Triton finegrained-fp8 fallback.") + + if hasattr(fg_fp8, "deepgemm_fp8_fp4_linear"): + # For 5.10.1 <= transformers < 5.11.0 + fg_fp8.deepgemm_fp8_fp4_linear = _raise_import_error + elif hasattr(fg_fp8, "_load_deepgemm_kernel"): + # For 5.5.0 <= transoformers < 5.10.1 + fg_fp8._load_deepgemm_kernel = _raise_import_error + + +_disable_deepgemm_for_fp8_vlm() + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Edit-fp8" + +# Negative prompt steering quality away from common artifacts. With text_guidance_scale > 1 +# the model guides away from this prompt, so it noticeably improves style adherence. +NEGATIVE_INSTRUCTION = ( + "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, " + "mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, " + "broken legs censor, censored, censor_bar" +) + +if not os.path.exists("base.png"): + raise FileNotFoundError("base.png not found — run inference_base.py first to generate the reference image.") + +transformer = BooguImageTransformer2DModel.from_pretrained( + MODEL_PATH, + subfolder="transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, +) +pipe = BooguImagePipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, transformer=transformer) +pipe = pipe.to("cuda") + +images = pipe( + instruction="把图片风格调整为彩铅插画。", + negative_instruction=NEGATIVE_INSTRUCTION, + input_images=[Image.open("base.png").convert("RGB")], + height=1024, + width=1024, + num_inference_steps=50, + text_guidance_scale=4.0, + image_guidance_scale=1.0, +).images + +assert len(images) == 1 +images[0].save("edit_fp8.png") +print("Inference OK, saved edit_fp8.png") diff --git a/examples/boogu/inference_turbo.py b/examples/boogu/inference_turbo.py new file mode 100644 index 000000000000..99311356ee4c --- /dev/null +++ b/examples/boogu/inference_turbo.py @@ -0,0 +1,20 @@ +import torch + +from diffusers.pipelines.boogu import BooguImageTurboPipeline + + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Turbo" + +pipe = BooguImageTurboPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +images = pipe( + instruction="一幅国风琉金风格的山水画作,展现了桂林山水在金光普照下的壮丽景象。远山层叠,江水如镜,山峰边缘勾勒着发光的金色线条。画面采用石青石绿岩彩与鎏金质感相结合,局部有厚涂油画笔触,空中飘浮着金色粒子,营造出梦幻朦胧而又磅礴大气的意境。", + height=1024, + width=1024, + num_inference_steps=4, +).images + +assert len(images) == 1 +images[0].save("turbo.png") +print("Inference OK, saved turbo.png") diff --git a/examples/boogu/inference_turbo_fp8.py b/examples/boogu/inference_turbo_fp8.py new file mode 100644 index 000000000000..90f8385d33ae --- /dev/null +++ b/examples/boogu/inference_turbo_fp8.py @@ -0,0 +1,51 @@ +import os + +import torch + +from diffusers import BooguImageTransformer2DModel +from diffusers.pipelines.boogu import BooguImageTurboPipeline + + +def _disable_deepgemm_for_fp8_vlm() -> None: + # For transformers >= 5.11.0 + os.environ["TRANSFORMERS_DISABLE_DEEPGEMM_LINEAR"] = "1" + + try: + import transformers.integrations.finegrained_fp8 as fg_fp8 + except Exception: + return + + def _raise_import_error(*args, **kwargs): + raise ImportError("DeepGEMM disabled; forcing Triton finegrained-fp8 fallback.") + + if hasattr(fg_fp8, "deepgemm_fp8_fp4_linear"): + # For 5.10.1 <= transformers < 5.11.0 + fg_fp8.deepgemm_fp8_fp4_linear = _raise_import_error + elif hasattr(fg_fp8, "_load_deepgemm_kernel"): + # For 5.5.0 <= transoformers < 5.10.1 + fg_fp8._load_deepgemm_kernel = _raise_import_error + + +_disable_deepgemm_for_fp8_vlm() + +MODEL_PATH = "Boogu/Boogu-Image-0.1-Turbo-fp8" + +transformer = BooguImageTransformer2DModel.from_pretrained( + MODEL_PATH, + subfolder="transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, +) +pipe = BooguImageTurboPipeline.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, transformer=transformer) +pipe = pipe.to("cuda") + +images = pipe( + instruction="一幅国风琉金风格的山水画作,展现了桂林山水在金光普照下的壮丽景象。远山层叠,江水如镜,山峰边缘勾勒着发光的金色线条。画面采用石青石绿岩彩与鎏金质感相结合,局部有厚涂油画笔触,空中飘浮着金色粒子,营造出梦幻朦胧而又磅礴大气的意境。", + height=1024, + width=1024, + num_inference_steps=4, +).images + +assert len(images) == 1 +images[0].save("turbo_fp8.png") +print("Inference OK, saved turbo_fp8.png") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6353347503e1..3c5051277c21 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -263,6 +263,7 @@ "FluxMultiControlNetModel", "FluxTransformer2DModel", "GlmImageTransformer2DModel", + "BooguImageTransformer2DModel", "HeliosTransformer3DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", @@ -597,6 +598,8 @@ "FluxPipeline", "FluxPriorReduxPipeline", "GlmImagePipeline", + "BooguImagePipeline", + "BooguImageTurboPipeline", "HeliosPipeline", "HeliosPyramidPipeline", "HiDreamImagePipeline", @@ -1095,6 +1098,7 @@ AutoencoderTiny, AutoencoderVidTok, AutoModel, + BooguImageTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, CacheMixin, @@ -1382,6 +1386,8 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BooguImagePipeline, + BooguImageTurboPipeline, BriaFiboEditPipeline, BriaFiboPipeline, BriaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7a1d0801f2c5..da7b01128564 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -142,6 +142,7 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] + _import_structure["transformers.transformer_boogu"] = ["BooguImageTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -227,6 +228,7 @@ AnyFlowFARTransformer3DModel, AnyFlowTransformer3DModel, AuraFlowTransformer2DModel, + BooguImageTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, ChromaTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 1edceee3ca74..b5d0aa99d3e7 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -21,6 +21,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_anyflow import AnyFlowTransformer3DModel from .transformer_anyflow_far import AnyFlowFARTransformer3DModel + from .transformer_boogu import BooguImageTransformer2DModel from .transformer_bria import BriaTransformer2DModel from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_boogu.py b/src/diffusers/models/transformers/transformer_boogu.py new file mode 100644 index 000000000000..f0bc09e3581e --- /dev/null +++ b/src/diffusers/models/transformers/transformer_boogu.py @@ -0,0 +1,1885 @@ +""" +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import itertools +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import RMSNorm + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.teacache_util import TeaCacheParams + + +logger = logging.get_logger(__name__) + + +# ----------------------------- RoPE ----------------------------- +class BooguImageRotaryPosEmbed: + """Namespace for Boogu's rotary-position-embedding frequency table. + + Only the static `get_freqs_cis` is used (by the pipeline and the transformer's + internal double-stream RoPE); it does not hold any state. + """ + + @staticmethod + def get_freqs_cis( + axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int + ) -> List[torch.Tensor]: + freqs_cis = [] + freqs_dtype = torch.float32 + for d, e in zip(axes_dim, axes_lens): + emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_cis.append(emb) + return freqs_cis + + +class BooguImageDoubleStreamRotaryPosEmbed(nn.Module): + def __init__( + self, + theta: int, + axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int] = (300, 512, 512), + patch_size: int = 2, + ): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + @staticmethod + def get_freqs_cis( + axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int + ) -> List[torch.Tensor]: + freqs_cis = [] + freqs_dtype = torch.float32 + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_cis.append(emb) + return freqs_cis + + def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + + result = [] + for i in range(len(self.axes_dim)): + freqs = freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1).to(device) + + def forward( + self, + freqs_cis, + attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ): + batch_size = len(attention_mask) + p = self.patch_size + + encoder_seq_len = attention_mask.shape[1] + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + + seq_lengths = [ + cap_len + sum(ref_img_len) + img_len + for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len) + ] + + max_seq_len = max(seq_lengths) + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # Create position IDs + position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + # add text position ids + position_ids[i, :cap_seq_len] = ( + torch.arange(cap_seq_len, dtype=torch.int32, device=device).unsqueeze(1).expand(-1, 3) + ) + + pe_shift = cap_seq_len + pe_shift_len = cap_seq_len + + if ref_img_sizes[i] is not None: + for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): + H, W = ref_img_size + ref_H_tokens, ref_W_tokens = H // p, W // p + if ref_H_tokens * ref_W_tokens != ref_img_len: + raise ValueError( + f"Reference image token count mismatch: {ref_H_tokens * ref_W_tokens} != {ref_img_len}." + ) + # add image position ids + + row_ids = ( + torch.arange(ref_H_tokens, dtype=torch.int32, device=device) + .unsqueeze(1) + .expand(ref_H_tokens, ref_W_tokens) + .flatten() + ) + col_ids = ( + torch.arange(ref_W_tokens, dtype=torch.int32, device=device) + .unsqueeze(0) + .expand(ref_H_tokens, ref_W_tokens) + .flatten() + ) + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 0] = pe_shift + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 1] = row_ids + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 2] = col_ids + + pe_shift += max(ref_H_tokens, ref_W_tokens) + pe_shift_len += ref_img_len + + H, W = img_sizes[i] + H_tokens, W_tokens = H // p, W // p + if H_tokens * W_tokens != l_effective_img_len[i]: + raise ValueError(f"Image token count mismatch: {H_tokens * W_tokens} != {l_effective_img_len[i]}.") + + row_ids = ( + torch.arange(H_tokens, dtype=torch.int32, device=device) + .unsqueeze(1) + .expand(H_tokens, W_tokens) + .flatten() + ) + col_ids = ( + torch.arange(W_tokens, dtype=torch.int32, device=device) + .unsqueeze(0) + .expand(H_tokens, W_tokens) + .flatten() + ) + + if pe_shift_len + l_effective_img_len[i] != seq_len: + raise ValueError( + f"RoPE position length mismatch: {pe_shift_len + l_effective_img_len[i]} != {seq_len}." + ) + position_ids[i, pe_shift_len:seq_len, 0] = pe_shift + position_ids[i, pe_shift_len:seq_len, 1] = row_ids + position_ids[i, pe_shift_len:seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + ref_img_freqs_cis = torch.zeros( + batch_size, + max_ref_img_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + batch_size, + max_img_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + + # Calculate combined image sequence lengths (ref_img + img) for each sample + combined_img_seq_lengths = [ + sum(ref_img_len) + img_len for ref_img_len, img_len in zip(l_effective_ref_img_len, l_effective_img_len) + ] + max_combined_img_len = max(combined_img_seq_lengths) + + # Create combined image rotary embeddings + combined_img_freqs_cis = torch.zeros( + batch_size, + max_combined_img_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + + for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate( + zip( + l_effective_cap_len, + l_effective_ref_img_len, + l_effective_img_len, + seq_lengths, + ) + ): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + ref_img_freqs_cis[i, : sum(ref_img_len)] = freqs_cis[i, cap_seq_len : cap_seq_len + sum(ref_img_len)] + img_freqs_cis[i, :img_len] = freqs_cis[ + i, + cap_seq_len + sum(ref_img_len) : cap_seq_len + sum(ref_img_len) + img_len, + ] + + # Combined image rotary embeddings: ref_img + img (same order as img_patch_embed_and_refine) + combined_img_freqs_cis[i, : sum(ref_img_len)] = freqs_cis[i, cap_seq_len : cap_seq_len + sum(ref_img_len)] + combined_img_freqs_cis[i, sum(ref_img_len) : sum(ref_img_len) + img_len] = freqs_cis[ + i, + cap_seq_len + sum(ref_img_len) : cap_seq_len + sum(ref_img_len) + img_len, + ] + + return ( + cap_freqs_cis, + ref_img_freqs_cis, + img_freqs_cis, + freqs_cis, + l_effective_cap_len, + seq_lengths, + combined_img_freqs_cis, + combined_img_seq_lengths, + ) + + +# --------------- Norm / FeedForward / Embedding ---------------- +def _torch_swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y + + +swiglu = _torch_swiglu +torch_swiglu = _torch_swiglu + + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + swiglu_fn = torch_swiglu if torch.compiler.is_compiling() else self.swiglu + return self.linear_2(swiglu_fn(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + instruction_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + scale=timestep_scale, + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(instruction_feat_dim, eps=norm_eps), + nn.Linear(instruction_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, + timestep: torch.Tensor, + instruction_hidden_states: torch.Tensor, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(instruction_hidden_states) + return time_embed, caption_embed + + +# ----------------------- Attention processors ------------------ +def apply_rotary_emb(x, freqs_cis, use_real=True, **kwargs): + # use_real=True path delegates to the shared diffusers implementation. + # use_real=False (Lumina-style) uses explicit dim to handle 0-element tensors. + if use_real: + from diffusers.models.embeddings import apply_rotary_emb as _apply + + return _apply(x, freqs_cis, use_real=True, **kwargs) + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + return torch.view_as_real(x_rotated * freqs_cis).flatten(3).type_as(x) + + +def _prepare_attn_mask(attention_mask: Optional[torch.Tensor], batch_size: int) -> Optional[torch.Tensor]: + """Reshape a bool padding mask ``[B, L]`` to the ``[B, 1, 1, L]`` form `dispatch_attention_fn` expects. + + The mask is always materialized (not dropped to ``None`` when no token is masked): + the native backend rounds bf16 differently on its masked vs no-mask paths, and the + Boogu checkpoints were trained with the mask applied. + """ + if attention_mask is None: + return None + return attention_mask.bool().view(batch_size, 1, 1, -1) + + +class BooguImageDoubleStreamSelfAttnProcessor(nn.Module): + """ + Double-stream self-attention processor. + + Instruction and image features are projected separately, concatenated + (instruction first, then image) into a joint sequence, attended jointly via + [`dispatch_attention_fn`], then split back so each stream gets its own output + projection. The QKV / output projections live on this processor module, so the + checkpoint keys are ``...processor.img_to_q`` / ``...processor.instruct_to_q`` / + ``...processor.img_out`` / ``...processor.instruct_out``. + + Args: + head_dim: Dimension of each attention head + num_attention_heads: Number of attention heads for queries + num_kv_heads: Number of key-value heads + qkv_bias: Whether to use bias in QKV linear layers + """ + + _attention_backend = None + _parallel_config = None + + def __init__( + self, + head_dim: int, + num_attention_heads: int, + num_kv_heads: int, + qkv_bias: bool = False, + ) -> None: + """Initialize the double-stream attention processor.""" + super().__init__() + + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_kv_heads = num_kv_heads + + query_dim = head_dim * num_attention_heads + kv_dim = head_dim * num_kv_heads + + # Separate Q/K/V projections for instruction and image streams. + # Query uses num_attention_heads, Key/Value use num_kv_heads. + self.img_to_q = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.img_to_k = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + self.img_to_v = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + + self.instruct_to_q = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.instruct_to_k = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + self.instruct_to_v = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + + # Separate output projections for instruction and image streams. + self.instruct_out = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.img_out = nn.Linear(query_dim, query_dim, bias=qkv_bias) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Xavier-uniform init for the projection weights, zeros for any biases.""" + for proj in ( + self.img_to_q, + self.img_to_k, + self.img_to_v, + self.instruct_to_q, + self.instruct_to_k, + self.instruct_to_v, + self.instruct_out, + self.img_out, + ): + nn.init.xavier_uniform_(proj.weight) + if proj.bias is not None: + nn.init.zeros_(proj.bias) + + def _concat_instruction_image_features( + self, + img_hidden_states_list: List[torch.Tensor], + instruct_hidden_states_list: List[torch.Tensor], + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> List[torch.Tensor]: + """ + Concatenate instruction (text & image) and reference image features (instruction first, then image). + + Args: + img_hidden_states_list: List of image tensors [img_query, img_key, img_value] + instruct_hidden_states_list: List of instruction tensors [instruct_query, instruct_key, instruct_value] + encoder_seq_lengths: Instruction sequence lengths for each sample [B] + seq_lengths: Total sequence lengths for each sample [B] + + Returns: + List of concatenated tensors [query, key, value] + """ + if len(img_hidden_states_list) != len(instruct_hidden_states_list): + raise ValueError( + f"Length mismatch: img_list={len(img_hidden_states_list)}, " + f"instruct_list={len(instruct_hidden_states_list)}" + ) + + batch_size = img_hidden_states_list[0].shape[0] + max_seq_len = max(seq_lengths) + + concatenated_list = [] + + for img_tensor, instruct_tensor in zip(img_hidden_states_list, instruct_hidden_states_list): + # Ensure tensors are on the same device + device = img_tensor.device + if instruct_tensor.device != device: + instruct_tensor = instruct_tensor.to(device) + + # Create output tensor with proper shape [B, max_seq_len, feature_dim] + feature_dim = img_tensor.shape[-1] + concatenated = img_tensor.new_zeros(batch_size, max_seq_len, feature_dim) + + # Concatenate instruction first, then image for each sample + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + # Place instruction tokens first + concatenated[i, :encoder_seq_len] = instruct_tensor[i, :encoder_seq_len] + # Place image tokens after instruction + concatenated[i, encoder_seq_len:seq_len] = img_tensor[i, : seq_len - encoder_seq_len] + + concatenated_list.append(concatenated) + + return concatenated_list + + def _split_instruction_image_features( + self, + hidden_states_list: List[torch.Tensor], + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Split concatenated features back to instruction and image features. + Inverse operation of _concat_instruction_image_features. + + Args: + hidden_states_list: List of concatenated tensors (usually just one element) + encoder_seq_lengths: Instruction sequence lengths for each sample [B] + seq_lengths: Total sequence lengths for each sample [B] + + Returns: + List of tuples, each containing (instruct_hidden_states, img_hidden_states) + """ + result_list = [] + + for hidden_states in hidden_states_list: + batch_size = hidden_states.shape[0] + feature_dim = hidden_states.shape[-1] + + # Get maximum lengths for instruction and image + max_instruct_len = max(encoder_seq_lengths) + max_img_len = max( + seq_len - encoder_seq_len for seq_len, encoder_seq_len in zip(seq_lengths, encoder_seq_lengths) + ) + + # Create output tensors [B, max_len, feature_dim] + instruct_hidden_states = hidden_states.new_zeros(batch_size, max_instruct_len, feature_dim) + img_hidden_states = hidden_states.new_zeros(batch_size, max_img_len, feature_dim) + + # Split each sample back to instruction and image + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + img_len = seq_len - encoder_seq_len + + # Extract instruction portion + instruct_hidden_states[i, :encoder_seq_len] = hidden_states[i, :encoder_seq_len] + # Extract image portion + img_hidden_states[i, :img_len] = hidden_states[i, encoder_seq_len:seq_len] + + result_list.append((instruct_hidden_states, img_hidden_states)) + + return result_list + + def __call__( + self, + attn: Attention, + img_hidden_states: torch.Tensor, + instruct_hidden_states: torch.Tensor, + joint_attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + encoder_seq_lengths: List[int] = None, # [B] - Instruction sequence lengths for each sample + seq_lengths: List[int] = None, # [B] - Total sequence lengths for each sample + ) -> torch.Tensor: + """ + Process double-stream self-attention. + + Args: + attn: Attention module + img_hidden_states: Image hidden states tensor [B, L_img, D] + instruct_hidden_states: Instruction hidden states tensor [B, L_instruct, D] + joint_attention_mask: Combined padding mask [B, L_total] + rotary_emb: Rotary embeddings for the joint sequence + encoder_seq_lengths: Instruction sequence lengths for each sample [B] + seq_lengths: Total sequence lengths for each sample [B] + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size = img_hidden_states.shape[0] + + # Generate Q, K, V for image and instruction streams (NO head reshaping yet) + img_query = self.img_to_q(img_hidden_states) # [B, L_img, query_dim] + img_key = self.img_to_k(img_hidden_states) # [B, L_img, kv_dim] + img_value = self.img_to_v(img_hidden_states) # [B, L_img, kv_dim] + + instruct_query = self.instruct_to_q(instruct_hidden_states) # [B, L_instruct, query_dim] + instruct_key = self.instruct_to_k(instruct_hidden_states) # [B, L_instruct, kv_dim] + instruct_value = self.instruct_to_v(instruct_hidden_states) # [B, L_instruct, kv_dim] + + # Concatenate QKV across streams (instruction first, then image) + img_list = [img_query, img_key, img_value] # [B, L_img, feature_dim] each + instruct_list = [instruct_query, instruct_key, instruct_value] # [B, L_instruct, feature_dim] each + query, key, value = self._concat_instruction_image_features( + img_list, instruct_list, encoder_seq_lengths, seq_lengths + ) # [B, max_seq_len, feature_dim] each + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + dtype = query.dtype + + # Reshape to [B, L, H, head_dim] (the layout dispatch_attention_fn expects) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb, use_real=False) + key = apply_rotary_emb(key, rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=_prepare_attn_mask(joint_attention_mask, batch_size), + scale=attn.scale, + enable_gqa=kv_heads < attn.heads, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3).type_as(query) + + # Split back to instruction / image, apply separate output projections, then merge. + split_results = self._split_instruction_image_features([hidden_states], encoder_seq_lengths, seq_lengths) + instruct_hidden_states, img_hidden_states = split_results[0] + + instruct_projected = self.instruct_out(instruct_hidden_states) # [B, max_instruct_len, feature_dim] + img_projected = self.img_out(img_hidden_states) # [B, max_img_len, feature_dim] + + merged_list = self._concat_instruction_image_features( + [img_projected], [instruct_projected], encoder_seq_lengths, seq_lengths + ) + hidden_states = merged_list[0] # [B, max_seq_len, feature_dim] + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class BooguImageAttnProcessor: + """ + Single-stream self-attention processor. + + Projects Q/K/V from the (shared) `Attention` module, applies QK-norm and RoPE, + and attends via [`dispatch_attention_fn`]. Used for the refiner / single-stream + blocks and the image self-attention of the double-stream block. + """ + + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Process single-stream self-attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor (same as hidden_states for self-attention) + attention_mask: Optional bool padding mask [B, L] + image_rotary_emb: Optional rotary embeddings + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size = hidden_states.shape[0] + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + dtype = query.dtype + + # Reshape to [B, L, H, head_dim] (the layout dispatch_attention_fn expects) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=_prepare_attn_mask(attention_mask, batch_size), + scale=attn.scale, + enable_gqa=kv_heads < attn.heads, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3).type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class BooguImageTransformerBlock(nn.Module): + """ + Basic Boogu-Image transformer block: attention + MLP + RMSNorm. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + """Initialize the transformer block.""" + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + # Initialize attention layer + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=BooguImageAttnProcessor(), + ) + + # Initialize feed-forward network + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + # Initialize normalization layers + if modulation: + self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) + + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialize linear weights and modulation parameters.""" + nn.init.xavier_uniform_(self.attn.to_q.weight) + nn.init.xavier_uniform_(self.attn.to_k.weight) + nn.init.xavier_uniform_(self.attn.to_v.weight) + nn.init.xavier_uniform_(self.attn.to_out[0].weight) + + nn.init.xavier_uniform_(self.feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_3.weight) + + if self.modulation: + nn.init.zeros_(self.norm1.linear.weight) + nn.init.zeros_(self.norm1.linear.bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the transformer block. + + Args: + hidden_states: Input hidden states tensor + attention_mask: Attention mask tensor + image_rotary_emb: Rotary embeddings for image tokens + temb: Optional timestep embedding tensor + + Returns: + torch.Tensor: Output hidden states after transformer block processing + """ + if self.modulation: + if temb is None: + raise ValueError("temb must be provided when modulation is enabled") + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class BooguImageDoubleStreamTransformerBlock(nn.Module): + """ + Boogu-Image double-stream block. + Here "double-stream" is the same idea as a "dual-stream" layer: + instruction tokens and image tokens are processed in parallel streams. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + """Initialize the double stream transformer block.""" + super().__init__() + self.head_dim = dim // num_attention_heads + self.num_attention_heads = num_attention_heads + self.modulation = modulation + self.hidden_size = dim + + double_stream_processor = BooguImageDoubleStreamSelfAttnProcessor( + head_dim=self.head_dim, + num_attention_heads=num_attention_heads, + num_kv_heads=num_kv_heads, + qkv_bias=False, + ) + + # Image stream components. + self.img_instruct_attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=double_stream_processor, + ) + + self.img_self_attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=BooguImageAttnProcessor(), + ) + + self.img_feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + if modulation: + # Image modulation terms: cross-attn, MLP, self-attn. + self.img_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + self.img_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + self.img_norm3 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + else: + self.img_norm1 = RMSNorm(dim, eps=norm_eps) + self.img_norm2 = RMSNorm(dim, eps=norm_eps) + self.img_norm3 = RMSNorm(dim, eps=norm_eps) + + self.img_ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.img_attn_norm = RMSNorm(dim, eps=norm_eps) + self.img_self_attn_norm = RMSNorm(dim, eps=norm_eps) + self.img_ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + # Instruction stream components. + self.instruct_feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + if modulation: + # Instruction modulation terms: cross-attn, MLP. + self.instruct_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + self.instruct_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + else: + self.instruct_norm1 = RMSNorm(dim, eps=norm_eps) + self.instruct_norm2 = RMSNorm(dim, eps=norm_eps) + + self.instruct_ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.instruct_attn_norm = RMSNorm(dim, eps=norm_eps) + self.instruct_ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.initialize_weights() + + # double_stream_processor owns its own q/k/v projections. + for param in self.img_instruct_attn.to_q.parameters(): + param.requires_grad = False + for param in self.img_instruct_attn.to_k.parameters(): + param.requires_grad = False + for param in self.img_instruct_attn.to_v.parameters(): + param.requires_grad = False + + del self.img_instruct_attn.to_k + del self.img_instruct_attn.to_v + del self.img_instruct_attn.to_q + + def initialize_weights(self) -> None: + """Initialize linear weights and modulation parameters.""" + nn.init.xavier_uniform_(self.img_instruct_attn.to_out[0].weight) + + # Keep Xavier init consistent across Boogu-Image blocks. + nn.init.xavier_uniform_(self.img_self_attn.to_q.weight) + nn.init.xavier_uniform_(self.img_self_attn.to_k.weight) + nn.init.xavier_uniform_(self.img_self_attn.to_v.weight) + nn.init.xavier_uniform_(self.img_self_attn.to_out[0].weight) + + nn.init.xavier_uniform_(self.img_feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.img_feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.img_feed_forward.linear_3.weight) + + nn.init.xavier_uniform_(self.instruct_feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.instruct_feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.instruct_feed_forward.linear_3.weight) + + # Initialize modulation parameters + if self.modulation: + nn.init.zeros_(self.img_norm1.linear.weight) + nn.init.zeros_(self.img_norm1.linear.bias) + nn.init.zeros_(self.img_norm2.linear.weight) + nn.init.zeros_(self.img_norm2.linear.bias) + nn.init.zeros_(self.img_norm3.linear.weight) + nn.init.zeros_(self.img_norm3.linear.bias) + + nn.init.zeros_(self.instruct_norm1.linear.weight) + nn.init.zeros_(self.instruct_norm1.linear.bias) + nn.init.zeros_(self.instruct_norm2.linear.weight) + nn.init.zeros_(self.instruct_norm2.linear.bias) + + def forward( + self, + img_hidden_states: torch.Tensor, # [B, L_img, D] - Image tokens (ref_img + noise_img) + instruct_hidden_states: torch.Tensor, # [B, L_instruct, D] - Instruction tokens + img_attention_mask: torch.Tensor, # [B, L_img] - Attention mask for [ref_img + noise_img] + joint_attention_mask: torch.Tensor, # [B, L_total] - Combined attention mask for [instruct + img] + image_rotary_emb: torch.Tensor, # [B, L_img, head_dim] - Rotary embeddings for [ref_img + noise_img] + rotary_emb: torch.Tensor, # [B, L_total, head_dim] - Rotary embeddings for [instruct + img] + temb: Optional[torch.Tensor] = None, # [B, 1024] - Timestep embeddings + encoder_seq_lengths: List[int] = None, # [B] - Instruction sequence lengths for each sample + seq_lengths: List[int] = None, # [B] - Total sequence lengths for each sample + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Run one dual-stream (double-stream) block step. + Returns updated `(img_hidden_states, instruct_hidden_states)`. + """ + if self.modulation and temb is None: + raise ValueError("temb must be provided when modulation is enabled") + + # Extract dimensions + batch_size = img_hidden_states.shape[0] + L_instruct = instruct_hidden_states.shape[1] # Instruction sequence length + L_img = img_hidden_states.shape[1] # Image sequence length (ref_img + noise_img) + + if self.modulation: + # Step 1: modulation for both streams. + img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(img_hidden_states, temb) + img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb) + img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb) + + ( + instruct_norm1_out, + instruct_gate_msa, + instruct_scale_mlp, + instruct_gate_mlp, + ) = self.instruct_norm1(instruct_hidden_states, temb) + instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(instruct_hidden_states, temb) + + # Step 2: joint attention on [instruct + img]. + # Call processor directly because Attention.forward does not expose these dual-stream args. + joint_attn_out = self.img_instruct_attn.processor( + attn=self.img_instruct_attn, + img_hidden_states=img_norm1_out, + instruct_hidden_states=instruct_norm1_out, + joint_attention_mask=joint_attention_mask, + rotary_emb=rotary_emb, + encoder_seq_lengths=encoder_seq_lengths, + seq_lengths=seq_lengths, + ) + + # Split back into instruction/image segments. + instruct_attn_out = instruct_hidden_states.new_zeros(batch_size, L_instruct, self.hidden_size) + img_attn_out = img_hidden_states.new_zeros(batch_size, L_img, self.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + instruct_attn_out[i, :encoder_seq_len] = joint_attn_out[i, :encoder_seq_len] + img_attn_out[i, : seq_len - encoder_seq_len] = joint_attn_out[i, encoder_seq_len:seq_len] + + # Step 3: image self-attention. + img_self_attn_out = self.img_self_attn( + hidden_states=img_norm3_out, + encoder_hidden_states=img_norm3_out, + attention_mask=img_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # Step 4: residual updates. + img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(1).tanh() * self.img_attn_norm(img_attn_out) + img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(1).tanh() * self.img_self_attn_norm( + img_self_attn_out + ) + + img_mlp_input = (1 + img_scale_mlp.unsqueeze(1)) * img_norm2_out + img_shift_mlp.unsqueeze(1) + img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input)) + img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(1).tanh() * self.img_ffn_norm2(img_mlp_out) + + instruct_hidden_states = instruct_hidden_states + instruct_gate_msa.unsqueeze( + 1 + ).tanh() * self.instruct_attn_norm(instruct_attn_out) + + instruct_mlp_input = ( + 1 + instruct_scale_mlp.unsqueeze(1) + ) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1) + instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_mlp_input)) + instruct_hidden_states = instruct_hidden_states + instruct_gate_mlp.unsqueeze( + 1 + ).tanh() * self.instruct_ffn_norm2(instruct_mlp_out) + + else: + # Non-modulated branch used by context-style blocks. + img_norm1_out = self.img_norm1(img_hidden_states) + img_norm3_out = self.img_norm3(img_hidden_states) + instruct_norm1_out = self.instruct_norm1(instruct_hidden_states) + + # Same processor path as above. + joint_attn_out = self.img_instruct_attn.processor( + attn=self.img_instruct_attn, + img_hidden_states=img_norm1_out, + instruct_hidden_states=instruct_norm1_out, + joint_attention_mask=joint_attention_mask, + rotary_emb=rotary_emb, + encoder_seq_lengths=encoder_seq_lengths, + seq_lengths=seq_lengths, + ) + + instruct_attn_out = instruct_hidden_states.new_zeros(batch_size, L_instruct, self.hidden_size) + img_attn_out = img_hidden_states.new_zeros(batch_size, L_img, self.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + instruct_attn_out[i, :encoder_seq_len] = joint_attn_out[i, :encoder_seq_len] + img_attn_out[i, : seq_len - encoder_seq_len] = joint_attn_out[i, encoder_seq_len:seq_len] + + img_self_attn_out = self.img_self_attn( + hidden_states=img_norm3_out, + encoder_hidden_states=img_norm3_out, + attention_mask=img_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + img_hidden_states = img_hidden_states + self.img_attn_norm(img_attn_out) + img_hidden_states = img_hidden_states + self.img_self_attn_norm(img_self_attn_out) + img_norm2_out = self.img_norm2(img_hidden_states) + img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_norm2_out)) + img_hidden_states = img_hidden_states + self.img_ffn_norm2(img_mlp_out) + + instruct_hidden_states = instruct_hidden_states + self.instruct_attn_norm(instruct_attn_out) + instruct_norm2_out = self.instruct_norm2(instruct_hidden_states) + instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_norm2_out)) + instruct_hidden_states = instruct_hidden_states + self.instruct_ffn_norm2(instruct_mlp_out) + + return img_hidden_states, instruct_hidden_states + + +class BooguImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + Boogu-Image transformer with mixed stream topology. + Early layers use double-stream (aka dual-stream) processing, then switch + to single-stream joint processing. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = [ + "BooguImageTransformerBlock", + "BooguImageDoubleStreamTransformerBlock", + ] + _repeated_blocks = [ + "BooguImageTransformerBlock", + "BooguImageDoubleStreamTransformerBlock", + ] + _skip_layerwise_casting_patterns = ["x_embedder", "norm", "embedding"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_double_stream_layers: int = 2, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (40, 40, 40), + axes_lens: Tuple[int, int, int] = (2048, 1664, 1664), + instruction_feature_configs: Dict[str, Any] = { + "instruction_feat_dim": 1024, + "reduce_type": "mean", + "num_instruction_feat_layers": 1, + }, + timestep_scale: float = 1.0, + ) -> None: + """Initialize the Boogu-Image mixed single-double stream transformer model.""" + super().__init__() + + # Validate configuration + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + if num_double_stream_layers > num_layers: + raise ValueError( + f"num_double_stream_layers ({num_double_stream_layers}) cannot be greater than " + f"num_layers ({num_layers})" + ) + + self.out_channels = out_channels or in_channels + self.num_double_stream_layers = num_double_stream_layers + self.num_single_stream_layers = num_layers - num_double_stream_layers + self.instruction_feature_configs = instruction_feature_configs + self.preprocessed_instruction_feat_dim = self.cal_preprocessed_instruction_feat_dim( + instruction_feature_configs + ) + + # Initialize embeddings + self.rope_embedder = BooguImageDoubleStreamRotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + instruction_feat_dim=self.preprocessed_instruction_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, + ) + + # Refiner layers. + self.noise_refiner = nn.ModuleList( + [ + BooguImageTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.ref_image_refiner = nn.ModuleList( + [ + BooguImageTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + BooguImageTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + # Mixed architecture: dual-stream first, then single-stream. + # Here "double-stream" and "dual-stream" mean the same thing. + self.double_stream_layers = nn.ModuleList( + [ + BooguImageDoubleStreamTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_double_stream_layers) + ] + ) + + # Single-stream layers process the fused sequence; they reuse BooguImageTransformerBlock. + self.single_stream_layers = nn.ModuleList( + [ + BooguImageTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(self.num_single_stream_layers) + ] + ) + + # Output norm and projection. + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + # Distinguish multiple reference images. + self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images + + self.gradient_checkpointing = False + + self.initialize_weights() + + # TeaCache settings + self.enable_teacache = False + self.teacache_rel_l1_thresh = 0.05 + self.teacache_params = TeaCacheParams() + + # Polynomial (highest-degree first) rescaling the relative L1 distance used by TeaCache. + self.teacache_rescale_coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487] + + def initialize_weights(self) -> None: + """ + Initialize the weights of the model. + + Uses Xavier uniform initialization for linear layers. + """ + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) + nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) + + nn.init.zeros_(self.norm_out.linear_1.weight) + nn.init.zeros_(self.norm_out.linear_1.bias) + nn.init.zeros_(self.norm_out.linear_2.weight) + nn.init.zeros_(self.norm_out.linear_2.bias) + + nn.init.normal_(self.image_index_embedding, std=0.02) + + def img_patch_embed_and_refine( + self, + hidden_states, + ref_image_hidden_states, + padded_img_mask, + padded_ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ): + """Embed image patches and run the refiner blocks.""" + batch_size = len(hidden_states) + max_combined_img_len = max( + [img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)] + ) + + hidden_states = self.x_embedder(hidden_states) + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) + + for i in range(batch_size): + shift = 0 + for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): + ref_image_hidden_states[i, shift : shift + ref_img_len, :] = ( + ref_image_hidden_states[i, shift : shift + ref_img_len, :] + self.image_index_embedding[j] + ) + shift += ref_img_len + + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + + flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) + num_ref_images = len(flat_l_effective_ref_img_len) + max_ref_img_len = max(flat_l_effective_ref_img_len) + + batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool) + batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros( + num_ref_images, max_ref_img_len, self.config.hidden_size + ) + batch_ref_img_rotary_emb = hidden_states.new_zeros( + num_ref_images, + max_ref_img_len, + ref_img_rotary_emb.shape[-1], + dtype=ref_img_rotary_emb.dtype, + ) + batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) + + # Flatten reference images into a temporary batch. + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + batch_ref_img_mask[idx, :ref_img_len] = True + batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[ + i, shift : shift + ref_img_len + ] + batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift : shift + ref_img_len] + batch_temb[idx] = temb[i] + shift += ref_img_len + idx += 1 + + # Refine each reference-image sample. + for layer in self.ref_image_refiner: + batch_ref_image_hidden_states = layer( + batch_ref_image_hidden_states, + batch_ref_img_mask, + batch_ref_img_rotary_emb, + batch_temb, + ) + + # Restore reference-image sequence layout. + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + ref_image_hidden_states[i, shift : shift + ref_img_len] = batch_ref_image_hidden_states[ + idx, :ref_img_len + ] + shift += ref_img_len + idx += 1 + + combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size) + for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)): + combined_img_hidden_states[i, : sum(ref_img_len)] = ref_image_hidden_states[i, : sum(ref_img_len)] + combined_img_hidden_states[i, sum(ref_img_len) : sum(ref_img_len) + img_len] = hidden_states[i, :img_len] + + return combined_img_hidden_states + + def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): + """Flatten patch tokens and pad to batched sequences.""" + batch_size = len(hidden_states) + p = self.config.patch_size + device = hidden_states[0].device + + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] + + if ref_image_hidden_states is not None: + ref_img_sizes = [ + [(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None + for imgs in ref_image_hidden_states + ] + l_effective_ref_img_len = [ + [(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] + if _ref_img_sizes is not None + else [0] + for _ref_img_sizes in ref_img_sizes + ] + else: + ref_img_sizes = [None for _ in range(batch_size)] + l_effective_ref_img_len = [[0] for _ in range(batch_size)] + + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # Reference-image patch embeddings. + flat_ref_img_hidden_states = [] + for i in range(batch_size): + if ref_img_sizes[i] is not None: + imgs = [] + for ref_img in ref_image_hidden_states[i]: + C, H, W = ref_img.size() + # "c (h p1) (w p2) -> (h w) (p1 p2 c)" + ref_img = ref_img.reshape(C, H // p, p, W // p, p) + ref_img = ref_img.permute(1, 3, 2, 4, 0) + ref_img = ref_img.reshape((H // p) * (W // p), p * p * C) + imgs.append(ref_img) + + img = torch.cat(imgs, dim=0) + flat_ref_img_hidden_states.append(img) + else: + flat_ref_img_hidden_states.append(None) + + # Noise-image patch embeddings. + flat_hidden_states = [] + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + + # "c (h p1) (w p2) -> (h w) (p1 p2 c)" + img = img.reshape(C, H // p, p, W // p, p) + img = img.permute(1, 3, 2, 4, 0) + img = img.reshape((H // p) * (W // p), p * p * C) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros( + batch_size, + max_ref_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, : sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] + padded_ref_img_mask[i, : sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros( + batch_size, + max_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + padded_hidden_states[i, : l_effective_img_len[i]] = flat_hidden_states[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def cal_preprocessed_instruction_feat_dim(self, instruction_feature_configs: Dict[str, Any]): + num_instruction_feat_layers = max(instruction_feature_configs.get("num_instruction_feat_layers", 1), 1) + instruction_feat_dim = instruction_feature_configs.get("instruction_feat_dim", 4096) + reduce_type = instruction_feature_configs.get("reduce_type", "concat") + if "cat" in reduce_type.lower(): + return num_instruction_feat_layers * instruction_feat_dim + elif "mean" in reduce_type.lower(): + return instruction_feat_dim + else: + raise ValueError(f"Invalid reduce_type: {reduce_type}") + + def preprocess_instruction_hidden_states( + self, raw_instruction_hidden_states, instruction_feature_configs: Dict[str, Any] + ): + num_instruction_feat_layers = max(instruction_feature_configs.get("num_instruction_feat_layers", 1), 1) + reduce_type = instruction_feature_configs.get("reduce_type", "concat") + + instruction_hidden_states = None + if isinstance(raw_instruction_hidden_states, torch.Tensor): + instruction_hidden_states = raw_instruction_hidden_states + elif isinstance(raw_instruction_hidden_states, (list, tuple)): + if len(raw_instruction_hidden_states) != num_instruction_feat_layers: + raise ValueError( + f"Expected {num_instruction_feat_layers} instruction-feature layers, " + f"got {len(raw_instruction_hidden_states)}." + ) + if "cat" in reduce_type.lower(): + instruction_hidden_states = torch.cat(raw_instruction_hidden_states, dim=-1) + elif "mean" in reduce_type.lower(): + instruction_hidden_states = torch.mean(torch.stack(raw_instruction_hidden_states), dim=0) + else: + raise ValueError(f"Invalid reduce_type: {reduce_type}") + else: + raise ValueError( + f"Invalid type of raw_instruction_hidden_states, expected torch.Tensor or list, but got {type(raw_instruction_hidden_states)}" + ) + + if self.preprocessed_instruction_feat_dim != instruction_hidden_states.shape[-1]: + raise ValueError( + f"Instruction feature dim mismatch: expected {self.preprocessed_instruction_feat_dim}, " + f"got {instruction_hidden_states.shape[-1]}." + ) + + return instruction_hidden_states + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + instruction_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + instruction_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass: + context/refiner -> dual-stream (double-stream) -> fusion -> single-stream -> projection. + """ + instruction_hidden_states = self.preprocess_instruction_hidden_states( + instruction_hidden_states, self.instruction_feature_configs + ) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # === 1. Initial processing (same as original Boogu-Image) === + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + if hidden_states.ndim != 4: + raise ValueError(f"Expected hidden_states with 4 dims [B, C, H, W], got ndim={hidden_states.ndim}.") + hidden_states = list(hidden_states) + + device = hidden_states[0].device + + # Timestep and instruction embedding. + temb, instruction_hidden_states = self.time_caption_embed( + timestep, instruction_hidden_states, hidden_states[0].dtype + ) + + # Flatten and pad token sequences. + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + # Build rotary embeddings and sequence lengths. + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + combined_img_rotary_emb, + combined_img_seq_lengths, + ) = self.rope_embedder( + freqs_cis, + instruction_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # Context refinement. + for layer in self.context_refiner: + instruction_hidden_states = layer( + instruction_hidden_states, + instruction_attention_mask, + context_rotary_emb, + ) + + # Image patch embedding and refinement. + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + # Dual-stream (double-stream) stage. + instruct_hidden_states = instruction_hidden_states + img_hidden_states = combined_img_hidden_states + + # Joint mask for [instruct + image]. + max_seq_len = max(seq_lengths) + joint_attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + for i, seq_len in enumerate(seq_lengths): + joint_attention_mask[i, :seq_len] = True + + # Run dual-stream blocks. + if self.num_double_stream_layers > 0: + # Image-only mask for [ref + noise]. + max_img_len = max(combined_img_seq_lengths) + img_attention_mask = hidden_states.new_zeros(batch_size, max_img_len, dtype=torch.bool) + for i, img_seq_len in enumerate(combined_img_seq_lengths): + img_attention_mask[i, :img_seq_len] = True + + for layer in self.double_stream_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img_hidden_states, instruct_hidden_states = self._gradient_checkpointing_func( + layer, + img_hidden_states, + instruct_hidden_states, + img_attention_mask, + joint_attention_mask, + combined_img_rotary_emb, + rotary_emb, + temb, + encoder_seq_lengths, + seq_lengths, + ) + else: + img_hidden_states, instruct_hidden_states = layer( + img_hidden_states, + instruct_hidden_states, + img_attention_mask, + joint_attention_mask, + combined_img_rotary_emb, + rotary_emb, + temb, + encoder_seq_lengths, + seq_lengths, + ) + + # Fuse streams to joint sequence. + joint_hidden_states = hidden_states.new_zeros(batch_size, max(seq_lengths), self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + joint_hidden_states[i, :encoder_seq_len] = instruct_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = img_hidden_states[i, : seq_len - encoder_seq_len] + + # Single-stream stage. + hidden_states = joint_hidden_states + + # TeaCache optimization. + if self.enable_teacache and len(self.single_stream_layers) > 0: + teacache_hidden_states = hidden_states.clone() + teacache_temb = temb.clone() + modulated_inp, _, _, _ = self.single_stream_layers[0].norm1(teacache_hidden_states, teacache_temb) + if self.teacache_params.is_first_or_last_step: + should_calc = True + self.teacache_params.accumulated_rel_l1_distance = 0 + else: + rel_l1 = ( + ( + (modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() + / self.teacache_params.previous_modulated_inp.abs().mean() + ) + .cpu() + .item() + ) + rescaled = 0.0 + for coefficient in self.teacache_rescale_coefficients: + rescaled = rescaled * rel_l1 + coefficient + self.teacache_params.accumulated_rel_l1_distance += rescaled + if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.teacache_params.accumulated_rel_l1_distance = 0 + self.teacache_params.previous_modulated_inp = modulated_inp + else: + should_calc = True + + if self.enable_teacache and not should_calc: + hidden_states += self.teacache_params.previous_residual + else: + if self.enable_teacache: + ori_hidden_states = hidden_states.clone() + + for layer in self.single_stream_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, joint_attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer(hidden_states, joint_attention_mask, rotary_emb, temb) + + if self.enable_teacache: + self.teacache_params.previous_residual = hidden_states - ori_hidden_states + + # Output projection. + hidden_states = self.norm_out(hidden_states, temb) + + # Reshape back to image format. + p = self.config.patch_size + output = [] + for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): + height, width = img_size + img_tokens = hidden_states[i][seq_len - img_len : seq_len] + # "(h w) (p1 p2 c) -> c (h p1) (w p2)" + h, w = height // p, width // p + c = img_tokens.shape[-1] // (p * p) + img_output = img_tokens.reshape(h, w, p, p, c) + img_output = img_output.permute(4, 0, 2, 1, 3) + img_output = img_output.reshape(c, h * p, w * p) + output.append(img_output) + + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + # Reset LoRA scaling. + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 850a991941ff..e3625258bfa4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -464,6 +464,7 @@ ] _import_structure["chronoedit"] = ["ChronoEditPipeline"] _import_structure["glm_image"] = ["GlmImagePipeline"] + _import_structure["boogu"] = ["BooguImagePipeline", "BooguImageTurboPipeline"] try: if not is_onnx_available(): @@ -623,6 +624,7 @@ AudioLDM2UNet2DConditionModel, ) from .aura_flow import AuraFlowPipeline + from .boogu import BooguImagePipeline, BooguImageTurboPipeline from .bria import BriaPipeline from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline diff --git a/src/diffusers/pipelines/boogu/__init__.py b/src/diffusers/pipelines/boogu/__init__.py new file mode 100644 index 000000000000..8bdb02c3154c --- /dev/null +++ b/src/diffusers/pipelines/boogu/__init__.py @@ -0,0 +1,3 @@ +from .image_processor import BooguImageProcessor +from .pipeline_boogu import BooguImagePipeline +from .pipeline_boogu_turbo import BooguImageTurboPipeline diff --git a/src/diffusers/pipelines/boogu/image_processor.py b/src/diffusers/pipelines/boogu/image_processor.py new file mode 100644 index 000000000000..13dda1a39a22 --- /dev/null +++ b/src/diffusers/pipelines/boogu/image_processor.py @@ -0,0 +1,189 @@ +# Copyright (C) 2026 Boogu Team. +# This repository is a fork by Boogu Team; modifications have been made. +# +# Original work: Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from ...configuration_utils import register_to_config +from ...image_processor import ( + PipelineImageInput, + VaeImageProcessor, +) + + +class BooguImageProcessor(VaeImageProcessor): + """ + Boogu-Image image processor, with resize/crop behavior adapted from PixArt's + image processor implementation. + + This class keeps a Diffusers-compatible preprocessing contract while adding + Boogu-Image-specific pixel and side-length constraints. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + max_pixels (`int`, *optional*): + Maximum number of pixels; the image is downscaled to fit when set. + max_side_length (`int`, *optional*): + Maximum side length; the image is downscaled to fit when set. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + resample: str = "lanczos", + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + self.max_pixels = max_pixels + self.max_side_length = max_side_length + + def get_new_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns target `(height, width)` after optional downscaling and + rounding to `vae_scale_factor` multiples. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + if max_side_length is None: + max_side_length = self.max_side_length + + if max_pixels is None: + max_pixels = self.max_pixels + + # Clamp ratio to <=1 to avoid upscaling input images in preprocessing. + ratio = 1.0 + if max_side_length is not None: + longest_side = height if height > width else width + ratio = min(ratio, max_side_length / longest_side) + if max_pixels is not None: + ratio = min(ratio, (max_pixels / (height * width)) ** 0.5) + + new_height, new_width = ( + int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, + int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, + ) + return new_height, new_width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Identical to [`VaeImageProcessor.preprocess`], except the target size is derived from Boogu's + `max_pixels` / `max_side_length` downscaling (via [`get_new_height_width`]) instead of a fixed + default, before delegating the format handling, resize, and normalization to the parent. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; also a list thereof. + height (`int`, *optional*): + Target height. If `None`, derived from the image and the pixel / side-length constraints. + width (`int`, *optional*): + Target width. If `None`, derived from the image and the pixel / side-length constraints. + max_pixels (`int`, *optional*): + Maximum number of pixels; the image is downscaled to fit. Defaults to `self.max_pixels`. + max_side_length (`int`, *optional*): + Maximum side length; the image is downscaled to fit. Defaults to `self.max_side_length`. + resize_mode (`str`, *optional*, defaults to `default`): + One of `default`, `fill`, or `crop`; see [`VaeImageProcessor.preprocess`]. + crops_coords (`Tuple[int, int, int, int]`, *optional*): + The crop coordinates. If `None`, the image is not cropped. + + Returns: + `torch.Tensor`: + The preprocessed image tensor with shape `[B, C, H, W]`. + """ + if self.config.do_resize: + representative = image[0] if isinstance(image, list) else image + height, width = self.get_new_height_width(representative, height, width, max_pixels, max_side_length) + return super().preprocess( + image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) diff --git a/src/diffusers/pipelines/boogu/pipeline_boogu.py b/src/diffusers/pipelines/boogu/pipeline_boogu.py new file mode 100644 index 000000000000..381926f06a14 --- /dev/null +++ b/src/diffusers/pipelines/boogu/pipeline_boogu.py @@ -0,0 +1,1766 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers.transformer_boogu import BooguImageRotaryPosEmbed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + BaseOutput, + is_torch_xla_available, + logging, +) +from diffusers.utils.teacache_util import TeaCacheParams +from diffusers.utils.torch_utils import randn_tensor + +from ...models.transformers import BooguImageTransformer2DModel +from .image_processor import BooguImageProcessor + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FMPipelineOutput(BaseOutput): + """ + Output class for BooguImagePipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +def set_flow_match_timesteps( + scheduler: FlowMatchEulerDiscreteScheduler, + num_inference_steps: int, + device: str | torch.device | None = None, + seq_len: int | None = None, +) -> tuple[torch.Tensor, int]: + """Set Boogu's training-aligned timesteps on the official flow-match scheduler. + + Boogu trains with a static ``v1`` time shift and a sigma schedule that runs + ``0 -> 1``, feeding that sigma to the transformer as the timestep directly + (unlike the built-in scheduler, whose timesteps run ``1000 -> 0``). The shift + amount ``mu`` is a fixed function of ``seq_len`` (resolution-independent), and + the shift itself reuses the parent's exponential formula. This overwrites the + scheduler's ``timesteps`` / ``sigmas`` to that convention; ``step`` is the + official one and works unchanged on the resulting schedule. + """ + if seq_len is None: + seq_len = scheduler.config.seq_len + + # Static v1 shift: mu is a linear function of seq_len between (base_image_seq_len, + # base_shift) and (max_image_seq_len, max_shift). + slope = (scheduler.config.max_shift - scheduler.config.base_shift) / ( + scheduler.config.max_image_seq_len - scheduler.config.base_image_seq_len + ) + mu = scheduler.config.base_shift + slope * (seq_len - scheduler.config.base_image_seq_len) + + t = np.linspace(0.0, 1.0, num_inference_steps + 1, dtype=np.float32)[:-1] + # Boogu v1 == 1 - exponential_shift(mu, 1, 1 - t); reuse the parent's formula. + t = (1.0 - scheduler._time_shift_exponential(mu, 1.0, 1.0 - torch.from_numpy(t))).numpy() + + timesteps = torch.from_numpy(t).to(dtype=torch.float32, device=device) + scheduler.timesteps = timesteps # 0-1 sigma, fed to the transformer as the timestep + scheduler.sigmas = torch.cat([timesteps, torch.ones(1, device=timesteps.device)]) + scheduler.num_inference_steps = num_inference_steps + scheduler._step_index = None + scheduler._begin_index = None + + return scheduler.timesteps, num_inference_steps + + +# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps; +# the default branch routes the official flow-match scheduler through Boogu's 0->1 time-shift adapter. +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + # Boogu uses the official flow-match scheduler with a training-aligned + # 0->1 sigma schedule; the adapter overwrites timesteps/sigmas to it. + timesteps, num_inference_steps = set_flow_match_timesteps(scheduler, num_inference_steps, device=device) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MomentumRollingSum: + def __init__(self, momentum_weight: float = 0.1, current_weight: float = 0.9): + self.momentum_weight = momentum_weight + self.current_weight = current_weight + self.rolling_sum = 0 + + def update(self, current_step: torch.Tensor): + self.rolling_sum = self.current_weight * current_step + self.momentum_weight * self.rolling_sum + return self.rolling_sum + + +class BooguImagePipeline(DiffusionPipeline): + """ + Base pipeline for Boogu text-to-image and image-editing inference. + + The pipeline coordinates the main components used by Boogu inference: + the MLLM encodes text instructions and optional reference-image context, + the Boogu single/double-stream transformer predicts latent updates during + the denoising process, the VAE maps between image space and latent space, + and the scheduler defines the diffusion timesteps. + + It also owns the runtime orchestration around classifier + guidance variants, boosted orthogonal guidance, device + placement, and optional CPU/group offload strategies. + + Args: + transformer (BooguImageTransformer2DModel): Boogu transformer + denoiser used for T2I and TI2I latent prediction. + vae (AutoencoderKL): Autoencoder used to encode input/reference images + into latents and decode generated latents back to images. + scheduler (FlowMatchEulerDiscreteScheduler): Scheduler that provides + diffusion timesteps and controls the denoising trajectory. + mllm (Qwen3VLForConditionalGeneration): Multimodal language model used + as the instruction encoder. + processor (Qwen3VLProcessor): Processor paired with the MLLM for + tokenization, chat templating, and image preprocessing. + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + + def __init__( + self, + transformer: BooguImageTransformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen3VLForConditionalGeneration, + processor: Qwen3VLProcessor, + ) -> None: + """ + Initialize the Boogu-Image pipeline. + + Args: + transformer: Boogu transformer denoiser for latent prediction. + vae: Autoencoder used for latent/image encoding and decoding. + scheduler: Diffusion scheduler that controls denoising steps. + mllm: Multimodal language model used to encode instructions. + processor: Processor paired with the MLLM for text/image inputs. + """ + # Defer setting pipeline attributes until after super().__init__, + # to avoid accessing self.config before it's created by Diffusers base class. + if hasattr(mllm, "lm_head"): + # Use the inner model of the instruction encoder as the encoder backbone. + mllm = mllm.model + + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor, + ) + + # Now it is safe to set additional attributes + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = BooguImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + # System prompts matching dataset logic (specific to this pipeline) + + self.SYSTEM_PROMPT_4_TI2I_UNIFIED = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate." + self.SYSTEM_PROMPT_4_T2I_UNIFIED = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows." + + self.SYSTEM_PROMPT_4_T2I = self.SYSTEM_PROMPT_4_T2I_UNIFIED + self.SYSTEM_PROMPT_DROP = ( + self.SYSTEM_PROMPT_4_TI2I_UNIFIED + ) # This is for empty negative instruction for image guidance in double guidance. + self.SYSTEM_PROMPT_4_TI2I = self.SYSTEM_PROMPT_4_TI2I_UNIFIED + self.SYSTEM_PROMPT_4_I2I = self.SYSTEM_PROMPT_4_TI2I_UNIFIED + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: Union[torch.device, str], + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. + """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + def preprocess_vlm_input_pil_images( + self, + input_pil_images: List[PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", + crops_coords: List[Tuple[int, int, int, int]] = None, + ) -> List[PIL.Image.Image]: + """ + Resize input PIL images for VLM encoding. For each image, the target height/width is computed + from the pixel budget (max_pixels / max_side_length) and the image is resized to fit. + max_pixels is an int or None; per-image selection is handled by caller before passing here. + """ + + if input_pil_images is None or len(input_pil_images) <= 0: + return input_pil_images + + assert isinstance(input_pil_images, list), "`input_pil_images` should be a list." + assert all(isinstance(x, PIL.Image.Image) for x in input_pil_images), ( + "`input_pil_images` should be a list of PIL.Image.Image." + ) + + processed_input_pil_images = [] + for image in input_pil_images: + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + height, width = self.image_processor.get_new_height_width( + image, height, width, max_pixels, max_side_length + ) + processed_input_pil_images.append( + self.image_processor.resize(image, height, width, resize_mode=resize_mode) + ) + return processed_input_pil_images + + def prepare_image( + self, + images: Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]], + batch_size: int, + num_images_per_instruction: int, + max_input_image_pixels: Union[int, list, tuple], + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_instruction: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + + success, max_images_per_sample, wrapped_input_images = self._check_and_wrap_input_images(images) + + if wrapped_input_images is not None: + assert len(wrapped_input_images) == batch_size, ( + "`wrapped_input_images` should be List[List[PIL.Image.Image]] and the `len(wrapped_input_images)` should be equal to `batch_size`." + ) + else: + wrapped_input_images = [None] * batch_size + + latents = [] + + for i, img in enumerate(wrapped_input_images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + max_pixels = self._get_max_image_pixels( + num_images=len(img), + max_input_image_pixels=max_input_image_pixels, + ) + img_j = self.image_processor.preprocess( + img_j, max_pixels=max_pixels, max_side_length=max_side_length + ) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + + for _ in range(num_images_per_instruction): + latents.append(ref_latents) + + return latents + + def _check_and_wrap_input_images( + self, + input_images: Any, + treat_empty_list_as_none: bool = False, + ) -> Tuple[bool, int, Optional[Union[List[List[PIL.Image.Image]], List[List[str]]]]]: + """ + Normalize input_images into a two-level batch structure with per-sample lists: + - List[List[PIL.Image.Image]] or + - List[List[str]] (each str is an image path) + - Allowed per-sample "empty" markers: [] or None + + Rules: + - If input_images is None or []: + return (True, 0, None) + - If already in batch form such as [[image], [image,image], [], None] or [[str], [], [str,str], None], + return as is (optionally convert [] -> None if treat_empty_list_as_none=True). + - If List[PIL.Image.Image] / List[str] / List[None|PIL|str], wrap each non-None element as a single-image sample: + e.g. [img1, img2, None] -> [[img1], [img2], None] + - If single PIL.Image.Image / single str, wrap as [[item]] + - Otherwise attempt to iterate and collect valid items (PIL first, else paths) into a single batch sample. + + Returns: + (success, max_images_per_sample, wrapped_input_images) + - success: whether input_images is successfully wrapped + - max_images_per_sample: max number of images in any sample of the batch + - wrapped_input_images: List[List[PIL.Image.Image]] or List[List[str]] or None + """ + + # Case 0: input is None or empty + if input_images is None: + return True, 0, None + try: + # Safely check for emptiness without assuming it is a sequence + if hasattr(input_images, "__len__") and len(input_images) == 0: + return True, 0, None + except TypeError: + # If __len__ raises, ignore here; further logic will handle it + pass + + def is_pil_image(x: Any) -> bool: + return isinstance(x, Image.Image) + + def is_path(x: Any) -> bool: + return isinstance(x, str) + + def is_list_of_pil_images(x: Any) -> bool: + return isinstance(x, list) and all(is_pil_image(i) for i in x) + + def is_list_of_paths(x: Any) -> bool: + return isinstance(x, list) and all(is_path(i) for i in x) + + def is_list_of_list_of_pil_images(x: Any) -> bool: + return isinstance(x, list) and len(x) > 0 and all(is_list_of_pil_images(i) for i in x) + + def is_list_of_list_of_paths(x: Any) -> bool: + return isinstance(x, list) and len(x) > 0 and all(is_list_of_paths(i) for i in x) + + def is_batch_two_level_with_none(x: Any) -> bool: + """ + Accept batch-shaped inputs where each sample is: + - None (represents no image) + - [] (empty sample, can be converted to None if treat_empty_list_as_none=True) + - List[PIL.Image.Image] or List[str] + """ + if not isinstance(x, list) or len(x) == 0: + return False + for sample in x: + if sample is None: + continue + if isinstance(sample, list): + if len(sample) == 0: + continue + # Allow mixed PIL/str but all elements must be either PIL or str + all_pil = all(is_pil_image(i) for i in sample) + all_str = all(is_path(i) for i in sample) + if not (all_pil or all_str): + return False + else: + # Non-list, non-None found => not batch two-level + return False + return True + + # Case 1: already in normalized batch form (with None/[] allowed) + if is_batch_two_level_with_none(input_images): + wrapped = list(input_images) # shallow copy + # Optionally convert empty lists to None per sample + if treat_empty_list_as_none: + for idx, sample in enumerate(wrapped): + if isinstance(sample, list) and len(sample) == 0: + wrapped[idx] = None + max_len = 0 + for sample in wrapped: + if isinstance(sample, list): + max_len = max(max_len, len(sample)) + return True, max_len, wrapped + + # Case 2: List[PIL.Image.Image] -> single batch + if is_list_of_pil_images(input_images): + wrapped = [input_images] + max_len = len(input_images) + return True, max_len, wrapped + + # Case 2b: List[str] (paths) -> single batch + if is_list_of_paths(input_images): + wrapped = [input_images] + max_len = len(input_images) + return True, max_len, wrapped + + # Case 2c: Flat batch where elements can be PIL/str/None + if isinstance(input_images, list) and all( + (is_pil_image(x) or is_path(x) or x is None or (isinstance(x, list))) for x in input_images + ): + wrapped: List[Optional[List[Any]]] = [] + max_len = 0 + for item in input_images: + if item is None: + wrapped.append(None) + elif is_pil_image(item) or is_path(item): + wrapped.append([item]) + max_len = max(max_len, 1) + elif isinstance(item, list): + # Clean sublist: keep only PIL or str + pil_sub = [i for i in item if is_pil_image(i)] + str_sub = [i for i in item if is_path(i)] + if len(pil_sub) > 0 and len(str_sub) == 0: + wrapped.append(pil_sub) + max_len = max(max_len, len(pil_sub)) + elif len(str_sub) > 0 and len(pil_sub) == 0: + wrapped.append(str_sub) + max_len = max(max_len, len(str_sub)) + else: + # Empty or mixed invalid -> treat as empty + wrapped.append(None if treat_empty_list_as_none else []) + else: + # Unknown element -> treat as empty + wrapped.append(None if treat_empty_list_as_none else []) + # If all are None and we prefer None, keep as batch-level structure per spec + return True, max_len, wrapped + + # Case 3: single PIL.Image.Image -> [[image]] + if is_pil_image(input_images): + wrapped = [[input_images]] + return True, 1, wrapped + + # Case 3b: single path str -> [[path]] + if is_path(input_images): + wrapped = [[input_images]] + return True, 1, wrapped + + # Case 4: other types -> try to interpret as iterable and collect images/paths as a single sample + try: + as_list = list(input_images) + except TypeError: + # Cannot iterate; normalization fails + return False, 0, None + + pil_items = [x for x in as_list if is_pil_image(x)] + path_items = [x for x in as_list if is_path(x)] + + if pil_items: + # Treat all collected PIL images as one sample in a single batch + wrapped = [pil_items] + max_len = len(pil_items) + return True, max_len, wrapped + + if path_items: + # Treat all collected paths as one sample in a single batch + wrapped = [path_items] + max_len = len(path_items) + return True, max_len, wrapped + + # No valid entries found + return False, 0, None + + def _get_instruction_feature_embeds( + self, + instruction: Union[str, List[str]], + input_pil_images: Optional[List[List[PIL.Image.Image]]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + truncate_instruction_sequence: bool = False, + max_vlm_input_pil_pixels: Optional[Union[int, List[int]]] = None, + max_vlm_input_pil_side_length: Optional[int] = None, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get interleaved instruction embeddings from VLM (self.mllm), aligned with training: + - Build VLM inputs via processor.apply_chat_template (images + text) + - Optionally remove vision-token features by truncation + - Return last layer or last-N layers and the corresponding attention mask + + Args: + instruction: The instruction or list of instructions to encode. + input_pil_images: A list of PIL images to be included in the prompt (TI2I/I2I). + device: The device to place the embeddings on. If None, uses the pipeline's device. + max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The instruction embeddings tensor (or list of last-N layers) + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. + """ + device = device or self._execution_device + instruction = [instruction] if isinstance(instruction, str) else instruction + batch_size = len(instruction) + + # Build prompts with images+text. + # input_pil_images: Optional[List[List[PIL.Image.Image]]], outer length == batch_size, + # inner list contains K_i images for sample i. + prompts: List[list] = [] + processed_samples: List[Optional[List[PIL.Image.Image]]] = [] + + if input_pil_images is None or len(input_pil_images) == 0: + # No images for any sample -> pass None per sample + processed_samples = [None for _ in range(batch_size)] # type: List[Optional[List[PIL.Image.Image]]] + else: + # Validate shape: outer length must match batch_size + assert isinstance(input_pil_images, list) and len(input_pil_images) == batch_size, ( + "When provided, `input_pil_images` must be a List[List[PIL.Image.Image]] with len == batch size." + ) + for imgs in input_pil_images: + if imgs and len(imgs) > 0: + # Determine per-sample max_pixels as in dataset logic: + # - If max_vlm_input_pil_pixels is a list/tuple, require len >= K_i and take index K_i-1 + # - If it's an int, use it for all images in this sample + # - If None, do not constrain by pixels + max_pixels_i: Optional[int] = None + if isinstance(max_vlm_input_pil_pixels, (list, tuple)): + assert len(max_vlm_input_pil_pixels) >= len(imgs), ( + "`max_vlm_input_pil_pixels` length must be >= number of images in each sample" + ) + max_pixels_i = int(max_vlm_input_pil_pixels[len(imgs) - 1]) + elif isinstance(max_vlm_input_pil_pixels, int): + max_pixels_i = max_vlm_input_pil_pixels + else: + max_pixels_i = None + proc = self.preprocess_vlm_input_pil_images( + imgs, # List[PIL.Image.Image] for this sample + max_pixels=max_pixels_i, + max_side_length=max_vlm_input_pil_side_length, + ) + processed_samples.append(proc) + else: + # Empty inner list -> treat as no images for this sample + processed_samples.append(None) + + # Build the batched prompts; for each sample i, pass instruction[i] and its image list (or None) + for i in range(batch_size): + sample_imgs: Optional[List[PIL.Image.Image]] = None + if processed_samples and i < len(processed_samples): + sample_imgs = processed_samples[i] + # _apply_chat_template expects (instruction: str, input_pil_images: Optional[List[PIL.Image.Image]]) + prompts.append( + self._apply_chat_template( + instruction[i], + sample_imgs, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + ) + + # Processor produces dict with 'input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw' + vlm_inputs = self.processor.apply_chat_template( + prompts, + padding="longest", + max_length=max_sequence_length, + truncation=truncate_instruction_sequence, + padding_side="right", + return_tensors="pt", + tokenize=True, + return_dict=True, + ) + for k in vlm_inputs.keys(): + if isinstance(vlm_inputs[k], torch.Tensor): + vlm_inputs[k] = vlm_inputs[k].to(device) + + instruction_mask = vlm_inputs["attention_mask"] + + num_instruction_feature_layers = self.transformer.instruction_feature_configs.get( + "num_instruction_feature_layers", 1 + ) + final_instruction_mask = instruction_mask + + with torch.no_grad(): + if num_instruction_feature_layers > 1: + text_encoder_outputs = self.mllm(**vlm_inputs, output_hidden_states=True, return_dict=True) + all_hidden_states = ( + text_encoder_outputs.hidden_states + ) # Tuple of [B, extended_seq_len, text_hidden_dim] + instruction_feats = list(all_hidden_states)[ + -num_instruction_feature_layers: + ] # Convert to list for model processing + else: + instruction_feats = self.mllm(**vlm_inputs).last_hidden_state + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if isinstance(instruction_feats, (list, tuple)): + final_instruction_feats = [feat.to(dtype=dtype, device=device) for feat in instruction_feats] + else: + final_instruction_feats = instruction_feats.to(dtype=dtype, device=device) + # Keep the attention mask on the same execution device as the features + # before passing both into the diffusion transformer. + final_instruction_mask = final_instruction_mask.to(device=device) + + return final_instruction_feats, final_instruction_mask + + def _apply_chat_template( + self, + instruction: str, + input_pil_images: Optional[List[PIL.Image.Image]] = None, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ): + """ + Build chat template content with interleaved text and images. + If `system_prompt_follows_task_type` is True, the system prompt will be selected based on the task type. + If `system_prompt_follows_task_type` is False, the system prompt will be selected based on the input images. + Returns the prompt structure (list of messages with typed contents). + """ + user_text_content = [{"type": "text", "text": instruction}] + + if system_prompt_follows_task_type: + if task_type.lower() == "t2i": + system_prompt = self.SYSTEM_PROMPT_4_T2I + else: + system_prompt = self.SYSTEM_PROMPT_4_TI2I + else: + # Pick system prompt adaptively based on the input images and instruction. + if input_pil_images is None or len(input_pil_images) == 0: + if instruction is None or len(instruction.strip()) == 0: + system_prompt = self.SYSTEM_PROMPT_DROP + else: + system_prompt = self.SYSTEM_PROMPT_4_T2I + else: + if instruction is None or len(instruction.strip()) == 0: + system_prompt = self.SYSTEM_PROMPT_4_I2I + else: + system_prompt = self.SYSTEM_PROMPT_4_TI2I + + system_role = { + "role": "system", + "content": [{"type": "text", "text": system_prompt}], + } + if input_pil_images is None or len(input_pil_images) == 0: + prompt = [system_role, {"role": "user", "content": user_text_content}] + else: + images_content = [{"type": "image", "image": pil_img} for pil_img in input_pil_images] + prompt = [ + system_role, + {"role": "user", "content": images_content + user_text_content}, + ] + return prompt + + def _reshape_embeds_and_mask(self, embeds, mask, num_images_per_instruction): + """ + To duplicate text embeddings and attention mask for each generation per instruction, using mps friendly method + """ + if isinstance(embeds, (list, tuple)): + batch_size, seq_len, _ = embeds[0].shape + reshaped_embeds = [] + for embed in embeds: + embed = embed.repeat(1, num_images_per_instruction, 1) + reshaped_embeds.append(embed.view(batch_size * num_images_per_instruction, seq_len, -1)) + else: + batch_size, seq_len, _ = embeds.shape + embeds = embeds.repeat(1, num_images_per_instruction, 1) + reshaped_embeds = embeds.view(batch_size * num_images_per_instruction, seq_len, -1) + + mask = mask.repeat(num_images_per_instruction, 1) + reshaped_mask = mask.view(batch_size * num_images_per_instruction, -1) + + return batch_size, seq_len, reshaped_embeds, reshaped_mask + + def _get_max_image_pixels( + self, + num_images: int, + max_input_image_pixels: Union[int, list, tuple] = 1024 * 1024, + ): + + if (num_images <= 0) or (not max_input_image_pixels): + return 1024 * 1024 + + if isinstance(max_input_image_pixels, (list, tuple)): + assert len(max_input_image_pixels) >= num_images, ( + f"`len(max_input_image_pixels)` should be >= number of input images per sample, i.e., {num_images}" + ) + max_pixels = max_input_image_pixels[num_images - 1] + else: + max_pixels = max_input_image_pixels + + return max_pixels + + def encode_instruction( + self, + instruction: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_instruction: Optional[Union[str, List[str]]] = None, + input_images: Optional[Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, + use_input_images_4_neg_instruct: bool = False, + use_input_images_4_empty_instruct: bool = False, + max_vlm_input_pil_pixels: Optional[Union[int, List[int]]] = 384 * 384, + max_vlm_input_pil_side_length: Optional[int] = 384 * 2, + num_images_per_instruction: int = 1, + device: Optional[torch.device] = None, + instruction_embeds: Optional[torch.Tensor] = None, + negative_instruction_embeds: Optional[torch.Tensor] = None, + instruction_attention_mask: Optional[torch.Tensor] = None, + negative_instruction_attention_mask: Optional[torch.Tensor] = None, + # For double guidance + empty_instruction: Optional[Union[str, List[str]]] = " ", + empty_instruction_embeds: Optional[torch.Tensor] = None, + empty_instruction_attention_mask: Optional[torch.Tensor] = None, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide: bool = False, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide: bool = False, + max_sequence_length: int = 256, + truncate_instruction_sequence: bool = False, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ) -> Tuple[torch.Tensor, ...]: + r""" + Encodes the instruction into text encoder hidden states. + + Args: + instruction (`str` or `List[str]`, *optional*): + instruction to be encoded + negative_instruction (`str` or `List[str]`, *optional*): + The instruction not to guide the image generation. If not defined, one has to pass `negative_instruction_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_instruction (`int`, *optional*, defaults to 1): + number of images that should be generated per instruction + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + instruction_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* instruction weighting. If not + provided, text embeddings will be generated from `instruction` input argument. + negative_instruction_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the instruction. + """ + device = device or self._execution_device + + instruction = [instruction] if isinstance(instruction, str) else instruction + # Chat template with images is handled inside _get_instruction_feature_embeds + batch_size = len(instruction) + + if instruction_embeds is None: + instruction_embeds, instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=instruction, + input_pil_images=input_images, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + batch_size, seq_len, _ = instruction_embeds.shape + + batch_size, seq_len, instruction_embeds, instruction_attention_mask = self._reshape_embeds_and_mask( + instruction_embeds, + instruction_attention_mask, + num_images_per_instruction, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_instruction_embeds is None: + negative_instruction = negative_instruction if negative_instruction is not None else "" + + # Normalize str to list + negative_instruction = ( + batch_size * [negative_instruction] if isinstance(negative_instruction, str) else negative_instruction + ) + + if instruction is not None and type(instruction) is not type(negative_instruction): + raise TypeError( + f"`negative_instruction` should be the same type to `instruction`, but got {type(negative_instruction)} !=" + f" {type(instruction)}." + ) + # elif isinstance(negative_instruction, str): # not needed since negative_instruction is already a list + + elif batch_size != len(negative_instruction): + raise ValueError( + f"`negative_instruction`: {negative_instruction} has batch size {len(negative_instruction)}, but `instruction`:" + f" {instruction} has batch size {batch_size}. Please make sure that passed `negative_instruction` matches" + " the batch size of `instruction`." + ) + negative_instruction_embeds, negative_instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=negative_instruction, + input_pil_images=input_images if use_input_images_4_neg_instruct else None, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels if use_input_images_4_neg_instruct else None, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length + if use_input_images_4_neg_instruct + else None, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + ( + batch_size, + seq_len, + negative_instruction_embeds, + negative_instruction_attention_mask, + ) = self._reshape_embeds_and_mask( + negative_instruction_embeds, + negative_instruction_attention_mask, + num_images_per_instruction, + ) + + if ( + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide + or use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide + ): + if do_classifier_free_guidance and (empty_instruction_embeds is None): + empty_instruction = empty_instruction if empty_instruction is not None else [" "] * batch_size + + empty_instruction = ( + batch_size * [empty_instruction] if isinstance(empty_instruction, str) else empty_instruction + ) + + if instruction is not None and type(instruction) is not type(empty_instruction): + raise TypeError( + f"`empty_instruction` should be the same type as `instruction`, but got {type(empty_instruction)} !=" + f" {type(instruction)}." + ) + + elif batch_size != len(empty_instruction): + raise ValueError( + f"`empty_instruction`: {empty_instruction} has batch size {len(empty_instruction)}, but `instruction`:" + f" {instruction} has batch size {batch_size}. Please make sure that passed `empty_instruction` matches" + " the batch size of `instruction`." + ) + + empty_instruction_embeds, empty_instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=empty_instruction, + input_pil_images=input_images if use_input_images_4_empty_instruct else None, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels if use_input_images_4_empty_instruct else None, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length + if use_input_images_4_empty_instruct + else None, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + ( + batch_size, + seq_len, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) = self._reshape_embeds_and_mask( + empty_instruction_embeds, + empty_instruction_attention_mask, + num_images_per_instruction, + ) + + return ( + instruction_embeds, + instruction_attention_mask, + negative_instruction_embeds, + negative_instruction_attention_mask, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def empty_instruction_guidance_scale(self): + return self._empty_instruction_guidance_scale + + @property + def cfg_range(self): + return self._cfg_range + + @torch.no_grad() + def __call__( + self, + instruction: Optional[Union[str, List[str]]] = None, + negative_instruction: Optional[Union[str, List[str]]] = None, + instruction_embeds: Optional[torch.FloatTensor] = None, + negative_instruction_embeds: Optional[torch.FloatTensor] = None, + instruction_attention_mask: Optional[torch.LongTensor] = None, + negative_instruction_attention_mask: Optional[torch.LongTensor] = None, + # For double guidance + empty_instruction: Optional[Union[str, List[str]]] = " ", + empty_instruction_embeds: Optional[torch.Tensor] = None, + empty_instruction_attention_mask: Optional[torch.Tensor] = None, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide: bool = False, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide: bool = False, + max_sequence_length: int = 1280, + truncate_instruction_sequence: bool = False, + input_images: Optional[Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, + use_input_images_4_neg_instruct: bool = False, + use_input_images_4_empty_instruct: bool = False, + max_vlm_input_pil_pixels: Optional[Union[int, List[int]]] = 384 * 384, + max_vlm_input_pil_side_length: Optional[int] = 384 * 2, + num_images_per_instruction: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + max_input_image_pixels: Union[int, list, tuple] = 2048 * 2048, + max_input_image_side_length: int = 2048 * 2, + align_res: bool = True, + num_inference_steps: int = 50, + text_guidance_scale: float = 4.0, + image_guidance_scale: float = 1.0, + empty_instruction_guidance_scale: float = 0.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + system_prompt_follows_task_type: bool = False, + ### Momentum Config + use_boosted_orthogonal_guidance: bool = False, + text_momentum_rolling_sum_momentum_weight: float = 0.1, + text_momentum_rolling_sum_current_weight: float = 0.9, + image_momentum_rolling_sum_momentum_weight: float = 0.1, + image_momentum_rolling_sum_current_weight: float = 0.9, + empty_momentum_rolling_sum_momentum_weight: float = 0.1, + empty_momentum_rolling_sum_current_weight: float = 0.9, + bog_mu: float = 0.1, + bog_range=[0.0, 1.0], + bog_interval: int = 3, + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + step_func=None, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._empty_instruction_guidance_scale = empty_instruction_guidance_scale + + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 1. Define call parameters + if instruction is not None and isinstance(instruction, str): + batch_size = 1 + instruction = [instruction] + elif instruction is not None and isinstance(instruction, (list, tuple)): + batch_size = len(instruction) + else: + batch_size = instruction_embeds.shape[0] + + # Resolve the device the pipeline's modules live on. With offloading enabled the base + # class returns the right execution device; otherwise it reflects the last `.to(...)`. + device = self._execution_device + + max_images_per_sample = 0 + if input_images: + success, max_images_per_sample, input_images = self._check_and_wrap_input_images(input_images) + + task_type = self._get_task_type_by_input_images(input_images) + + # 2. Encode input instruction + ( + instruction_embeds, + instruction_attention_mask, + negative_instruction_embeds, + negative_instruction_attention_mask, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) = self.encode_instruction( + instruction, + self.text_guidance_scale > 1.0, + negative_instruction=negative_instruction, + input_images=input_images, + use_input_images_4_neg_instruct=use_input_images_4_neg_instruct, + use_input_images_4_empty_instruct=use_input_images_4_empty_instruct, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length, + num_images_per_instruction=num_images_per_instruction, + device=device, + instruction_embeds=instruction_embeds, + negative_instruction_embeds=negative_instruction_embeds, + instruction_attention_mask=instruction_attention_mask, + negative_instruction_attention_mask=negative_instruction_attention_mask, + # For double guidance + empty_instruction=empty_instruction, + empty_instruction_embeds=empty_instruction_embeds, + empty_instruction_attention_mask=empty_instruction_attention_mask, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide=use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide=use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + # Put ref_latents here before encoding instruction. + dtype = self.vae.dtype + + # 3. Prepare control image + ref_latents = self.prepare_image( + images=input_images, + batch_size=batch_size, + num_images_per_instruction=num_images_per_instruction, + max_input_image_pixels=max_input_image_pixels, + max_side_length=max_input_image_side_length, + device=device, + dtype=dtype, + ) + + input_images, width, height, ori_width, ori_height = self._resolve_output_and_original_size( + input_images=input_images, + ref_latents=ref_latents, + align_res=align_res, + width=width, + height=height, + max_input_image_pixels=max_input_image_pixels, + max_images_per_sample=max_images_per_sample, + img_scale_num=self.vae_scale_factor * 2, + ) + + if len(input_images) == 0: + self._image_guidance_scale = 1 + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_instruction, + latent_channels, + height, + width, + instruction_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = BooguImageRotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + image = self.processing( + latents=latents, + ref_latents=ref_latents, + instruction_embeds=instruction_embeds, + freqs_cis=freqs_cis, + negative_instruction_embeds=negative_instruction_embeds, + instruction_attention_mask=instruction_attention_mask, + negative_instruction_attention_mask=negative_instruction_attention_mask, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + device=device, + dtype=dtype, + step_func=step_func, + # For double guidance + empty_instruction_embeds=empty_instruction_embeds, + empty_instruction_attention_mask=empty_instruction_attention_mask, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide=use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide=use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide, + use_boosted_orthogonal_guidance=use_boosted_orthogonal_guidance, + tg_momentum_state=MomentumRollingSum( + momentum_weight=text_momentum_rolling_sum_momentum_weight, + current_weight=text_momentum_rolling_sum_current_weight, + ) + if use_boosted_orthogonal_guidance + else None, + ig_momentum_state=MomentumRollingSum( + momentum_weight=image_momentum_rolling_sum_momentum_weight, + current_weight=image_momentum_rolling_sum_current_weight, + ) + if use_boosted_orthogonal_guidance + else None, + eg_momentum_state=MomentumRollingSum( + momentum_weight=empty_momentum_rolling_sum_momentum_weight, + current_weight=empty_momentum_rolling_sum_current_weight, + ) + if use_boosted_orthogonal_guidance + else None, + bog_mu=bog_mu, + bog_range=bog_range, + bog_interval=bog_interval, + ) + + image = F.interpolate(image, size=(ori_height, ori_width), mode="bilinear") + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + else: + return FMPipelineOutput(images=image) + + def _resolve_output_and_original_size( + self, + input_images, + ref_latents: List[Union[List[torch.FloatTensor], None]], + align_res: bool, + width: int, + height: int, + max_input_image_pixels: Union[int, list, tuple], + max_images_per_sample: int, + img_scale_num: int = 16, + ) -> Tuple[List, int, int, int, int]: + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ( + ref_latents[0][0].shape[-1] * self.vae_scale_factor, + ref_latents[0][0].shape[-2] * self.vae_scale_factor, + ) + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + + if isinstance(max_input_image_pixels, (list, tuple)): + if (input_images is not None) and (len(input_images) > 0) and max_images_per_sample > 0: + assert len(max_input_image_pixels) >= max_images_per_sample, ( + f"When `max_input_image_pixels` is a list or tuple, the length of it (here is {len(max_input_image_pixels)}) should be >= max number of input images in all the samples (here is {max_images_per_sample})." + ) + max_pixels = max_input_image_pixels[max_images_per_sample - 1] + else: + max_pixels = max_input_image_pixels[0] + else: + max_pixels = max_input_image_pixels + + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = ( + int(height * ratio) // img_scale_num * img_scale_num, + int(width * ratio) // img_scale_num * img_scale_num, + ) + + return input_images, width, height, ori_width, ori_height + + def _get_task_type_by_ref_latents(self, ref_latents: List[Union[List[torch.FloatTensor], None]]): + if not ref_latents: + return "t2i" + + if isinstance(ref_latents, (list, tuple)): + for x in ref_latents: + if x: + return "ti2i" + return "t2i" + + def _get_task_type_by_input_images(self, input_images: Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]): + if not input_images: + return "t2i" + + if isinstance(input_images, (list, tuple)): + for x in input_images: + if x: + return "ti2i" + return "t2i" + + def _project_matrix( + self, + m0: torch.Tensor, # [B, C, H, W] # The delta: model_pred - model_pred_uncond + m1: torch.Tensor, # [B, C, H, W] # The conditional pred + dim: int = -2, + ): + """ + Project m0 onto m1 by treating each [H, W] slice as a matrix. + Args: + m0: Input tensor to be decomposed, shape [B, C, H, W]. + m1: Reference tensor that provides projection directions, shape [B, C, H, W]. + dim: Vector dimension to project along within each [H, W] matrix. + dim = -2 projects column vectors (along H), dim = -1 projects row vectors (along W). + Returns: + A tuple (m0_parallel, m0_orthogonal), both with shape [B, C, H, W]. + """ + dtype = m0.dtype + m0, m1 = m0.double(), m1.double() + b, c, h, w = m0.shape + # Only support projecting column vectors (dim=-2) or row vectors (dim=-1). + assert dim in (-1, -2), "dim must be -1 (rows) or -2 (columns)" + # Treat as a batch of matrices: [B*C, H, W] + m0_mat = m0.reshape(b * c, h, w) + m1_mat = m1.reshape(b * c, h, w) + # Normalize along the vector dimension selected by dim. + m1_unit = torch.nn.functional.normalize(m1_mat, dim=dim) + # Project each row/column vector of m0 onto the corresponding vector of m1. + m0_parallel = (m0_mat * m1_unit).sum(dim=dim, keepdim=True) * m1_unit + m0_orthogonal = m0_mat - m0_parallel + return m0_parallel.reshape(b, c, h, w).to(dtype), m0_orthogonal.reshape(b, c, h, w).to(dtype) + + def _newtonschulz5_batched(self, G: torch.Tensor, steps: int = 5, eps: float = 1e-7): + """ + Batched Newton-Schulz iteration. + + Accepts: + - (H, W) -> returns (H, W) + - (N, H, W) -> returns (N, H, W) + - (B, C, H, W) -> returns (B, C, H, W) + """ + a, b, c = (3.4445, -4.7750, 2.0315) + + orig_ndim = G.ndim + if orig_ndim == 2: + G3 = G.unsqueeze(0) # (1, H, W) + out_shape = None + elif orig_ndim == 3: + G3 = G # (N, H, W) + out_shape = None + elif orig_ndim == 4: + B, C, H, W = G.shape + G3 = G.reshape(B * C, H, W) # (N, H, W) + out_shape = (B, C, H, W) + else: + raise ValueError(f"Expected 2D/3D/4D tensor, got ndim={G.ndim}") + + # Match the original behavior: decide whether to transpose based on H/W + H, W = G3.shape[-2], G3.shape[-1] + + # Compute in bfloat16 (keeps the original logic) + X = G3.to(torch.bfloat16) + + # Normalize each matrix by its Frobenius norm: X /= (||X||_F + eps) + # Frobenius norm = sqrt(sum_ij X^2) + nrm = torch.linalg.norm(X, ord="fro", dim=(-2, -1)) # (N,) + X = X / (nrm.unsqueeze(-1).unsqueeze(-1) + eps) + + transposed = False + if H > W: + # Transpose the last two dims so we iterate on the "shorter" dimension first + X = X.transpose(-2, -1) # (N, W, H) + transposed = True + + # Newton–Schulz iterations (batched GEMMs) + for _ in range(steps): + A = X @ X.transpose(-2, -1) # (N, m, m) + Bm = b * A + c * (A @ A) # (N, m, m) + X = a * X + (Bm @ X) # (N, m, n) + + # Transpose back if we transposed at the beginning + if transposed: + X = X.transpose(-2, -1) + + # Restore original shape + if orig_ndim == 2: + return X.squeeze(0) + if out_shape is not None: + return X.reshape(out_shape) + return X + + def bog_norm(self, G: torch.Tensor) -> torch.Tensor: + """ + G: [..., H, W] + return: normalized tensor with same shape + """ + if G.dim() < 2: + raise ValueError("G must have at least 2 dims, got shape {}".format(tuple(G.shape))) + return self._newtonschulz5_batched(G) + + def calculate_boosted_orthogonal_guidance( + self, + model_pred: torch.Tensor, # [B, C, H, W] + model_pred_uncond: torch.Tensor, # [B, C, H, W] + momentum_state: MomentumRollingSum = None, + mu: float = 0.1, + ) -> torch.Tensor: + delta = model_pred - model_pred_uncond + + if momentum_state is not None: + delta = momentum_state.update(delta) + + ## Norm: Newton-Schulz Estimation. + + delta = self.bog_norm(delta) + + r = delta.shape[-2] * 1.0 + c = delta.shape[-1] * 1.0 + r_wei = r / (r + c + 1.0) + c_wei = c / (r + c + 1.0) + + delta_parallel_col, delta_orthogonal_col = self._project_matrix(delta, model_pred, dim=-2) + delta_parallel_row, delta_orthogonal_row = self._project_matrix(delta, model_pred, dim=-1) + + delta_bog = r_wei * (delta_orthogonal_row + mu * delta_parallel_row) + c_wei * ( + delta_orthogonal_col + mu * delta_parallel_col + ) + + return delta_bog + + def processing( + self, + latents, + ref_latents, + instruction_embeds, + freqs_cis, + negative_instruction_embeds, + instruction_attention_mask, + negative_instruction_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + step_func=None, + # For double guidance + empty_instruction_embeds=None, + empty_instruction_attention_mask=None, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide=False, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide=False, + use_boosted_orthogonal_guidance: bool = False, + # Boosted Orthogonal Guidance Momentum State + tg_momentum_state: MomentumRollingSum = None, + ig_momentum_state: MomentumRollingSum = None, + eg_momentum_state: MomentumRollingSum = None, + bog_mu: float = 0.1, + bog_range=[0.0, 1.0], + bog_interval: int = 3, + ): + task_type = self._get_task_type_by_ref_latents(ref_latents) + + logger.info("[Pipeline Processing]: The current task_type: %s.", task_type) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # NOTE: Declare optional per-condition caches upfront for static analyzers. + # They are populated below depending on which acceleration path is enabled. + teacache_params_drop_ref = None + teacache_params_ref_empty_instruct = None + use_ref_empty_instruct_pred = ( + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide + or use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide + ) + + enable_teacache = self.transformer.enable_teacache or getattr( + self.transformer, "enable_teacache_for_all_layers", False + ) + self.transformer.enable_teacache = enable_teacache + if enable_teacache: + # Use different TeaCacheParams for different conditions + teacache_params = TeaCacheParams() + teacache_params_uncond = TeaCacheParams() + teacache_params_ref = TeaCacheParams() + if use_ref_empty_instruct_pred: + # For double-guidance variants that use an "empty" instruction embedding when predicting ref-image condition. + # Keep TeaCache state isolated per condition; do NOT reuse uncond/ref/cond params here. + teacache_params_ref_empty_instruct = TeaCacheParams() + # For TI2I image-only guidance branch (drop reference image, keep text condition). + # Keep TeaCache state isolated per condition; do NOT reuse uncond/ref/cond params here. + teacache_params_drop_ref = TeaCacheParams() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if enable_teacache: + teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params + + model_pred = self.predict( + t=t, + latents=latents, + instruction_embeds=instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=instruction_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + text_guidance_scale = ( + self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + image_guidance_scale = ( + self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + empty_instruction_guidance_scale = ( + self.empty_instruction_guidance_scale + if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] + else 0.0 + ) + + if (task_type == "ti2i") and (text_guidance_scale > 1.0) and (image_guidance_scale > 1.0): # Checked + if enable_teacache: + teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params_ref + + model_pred_drop_text = self.predict( + t=t, + latents=latents, + instruction_embeds=negative_instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=negative_instruction_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if enable_teacache: + teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params_uncond + + model_pred_drop_all = self.predict( + t=t, + latents=latents, + instruction_embeds=negative_instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=negative_instruction_attention_mask, + ref_image_hidden_states=None, + ) + + if ( + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide + or use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide + ): + # Predict ref-image condition using an "empty" instruction embedding. + # IMPORTANT: This is a distinct condition from `model_pred_drop_text` (neg-text + ref), + # so we must keep TeaCache state isolated to avoid cache pollution. + if enable_teacache: + assert teacache_params_ref_empty_instruct is not None + teacache_params_ref_empty_instruct.is_first_or_last_step = ( + i == 0 or i == len(timesteps) - 1 + ) + self.transformer.teacache_params = teacache_params_ref_empty_instruct + + model_pred_drop_text_empty_instruct = self.predict( + t=t, + latents=latents, + instruction_embeds=empty_instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=empty_instruction_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + model_pred_drop_text_pos = model_pred_drop_text + model_pred_drop_text_neg = model_pred_drop_text + + if use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide: + model_pred_drop_text_pos = model_pred_drop_text_empty_instruct + if use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide: + model_pred_drop_text_neg = model_pred_drop_text_empty_instruct + + if ( + use_boosted_orthogonal_guidance + and (bog_range[0] <= t <= bog_range[1]) + and (i % bog_interval == 0) + ): + delta_text = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred, + model_pred_uncond=model_pred_drop_text, + momentum_state=tg_momentum_state, + mu=bog_mu, + ) + delta_image = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred_drop_text, + model_pred_uncond=model_pred_drop_all, + momentum_state=ig_momentum_state, + mu=bog_mu, + ) + else: + delta_text = model_pred - model_pred_drop_text + delta_image = model_pred_drop_text - model_pred_drop_all + + if (empty_instruction_guidance_scale != 0.0) and ( + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide + != use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide + ): + if ( + use_boosted_orthogonal_guidance + and (bog_range[0] <= t <= bog_range[1]) + and (i % bog_interval == 0) + ): + delta_empty_instruct = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred_drop_text_pos, + model_pred_uncond=model_pred_drop_text_neg, + momentum_state=eg_momentum_state, + mu=bog_mu, + ) + else: + delta_empty_instruct = model_pred_drop_text_pos - model_pred_drop_text_neg + + model_pred = ( + model_pred + + (text_guidance_scale - 1) * delta_text + + (image_guidance_scale - 1) * delta_image + + empty_instruction_guidance_scale * delta_empty_instruct + ) + + else: + model_pred = ( + model_pred + + (text_guidance_scale - 1) * delta_text + + (image_guidance_scale - 1) * delta_image + ) + + elif (task_type == "ti2i") and (text_guidance_scale > 1.0): # checked + # TI2I text-only guidance (keep reference-image condition, guide only by text): + + if enable_teacache: + # Keep TeaCache state isolated per condition (ref-only here). + teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params_ref + + model_pred_drop_text = self.predict( + t=t, + latents=latents, + instruction_embeds=negative_instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=negative_instruction_attention_mask, + ref_image_hidden_states=ref_latents, + ) + if ( + use_boosted_orthogonal_guidance + and (bog_range[0] <= t <= bog_range[1]) + and (i % bog_interval == 0) + ): + delta_text = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred, + model_pred_uncond=model_pred_drop_text, + momentum_state=tg_momentum_state, + mu=bog_mu, + ) + else: + delta_text = model_pred - model_pred_drop_text + + # Equivalent: model_pred = model_pred_drop_text + text_guidance_scale * (model_pred - model_pred_drop_text) + model_pred = model_pred + (text_guidance_scale - 1) * delta_text + + elif (task_type == "ti2i") and (image_guidance_scale > 1.0): # Checked + # TI2I image-only guidance (keep text condition, guide only by reference image): + # + # IMPORTANT: + # - TeaCache caches previous residuals per condition; we must not reuse the drop_all/drop_text TeaCache state here. + + if enable_teacache: + assert teacache_params_drop_ref is not None + teacache_params_drop_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params_drop_ref + + model_pred_drop_image = self.predict( + t=t, + latents=latents, + instruction_embeds=instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=instruction_attention_mask, + ref_image_hidden_states=None, + ) + if ( + use_boosted_orthogonal_guidance + and (bog_range[0] <= t <= bog_range[1]) + and (i % bog_interval == 0) + ): + delta_image = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred, + model_pred_uncond=model_pred_drop_image, + momentum_state=ig_momentum_state, + mu=bog_mu, + ) + else: + delta_image = model_pred - model_pred_drop_image + + # Equivalent: model_pred = model_pred_drop_image + image_guidance_scale * (model_pred - model_pred_drop_image) + model_pred = model_pred + (image_guidance_scale - 1) * delta_image + + elif text_guidance_scale > 1.0: # Checked + if enable_teacache: + teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1 + self.transformer.teacache_params = teacache_params_uncond + + model_pred_drop_all = self.predict( + t=t, + latents=latents, + instruction_embeds=negative_instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=negative_instruction_attention_mask, + ref_image_hidden_states=None, + ) + + if ( + use_boosted_orthogonal_guidance + and (bog_range[0] <= t <= bog_range[1]) + and (i % bog_interval == 0) + ): + delta_text = self.calculate_boosted_orthogonal_guidance( + model_pred=model_pred, + model_pred_uncond=model_pred_drop_all, + momentum_state=tg_momentum_state, + mu=bog_mu, + ) + else: + delta_text = model_pred - model_pred_drop_all + + # Equivalent: model_pred = model_pred_drop_all + text_guidance_scale * (model_pred - model_pred_drop_all) + model_pred = model_pred + (text_guidance_scale - 1) * delta_text + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def predict( + self, + t, + latents, + instruction_embeds, + freqs_cis, + instruction_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()): + optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + instruction_embeds, + freqs_cis, + instruction_attention_mask, + **optional_kwargs, + ) + return model_pred diff --git a/src/diffusers/pipelines/boogu/pipeline_boogu_turbo.py b/src/diffusers/pipelines/boogu/pipeline_boogu_turbo.py new file mode 100644 index 000000000000..c7304be2a0ac --- /dev/null +++ b/src/diffusers/pipelines/boogu/pipeline_boogu_turbo.py @@ -0,0 +1,1228 @@ +""" +Boogu-Image-Turbo (DMD few-step) pipeline. + +This module implements the DMD student few-step inference path as a standalone +`DiffusionPipeline` subclass. Per `.ai/pipelines.md` gotcha #4, each pipeline +variant lives in its own file with its own class (duplicated `__call__`, no +subclassing of another pipeline class); shared private utilities are reused via +`# Copied from` annotations so `make fix-copies` keeps them in sync with +`BooguImagePipeline`. + +The DMD path is pure text-to-image: it does not use the scheduler, reference +images, SDEdit, or classifier-free guidance. It builds its own sigma schedule, +runs `predict` -> renoise per step, then decodes the latents. + +# Copyright (C) 2026 Boogu Team. +# Licensed under the Apache License, Version 2.0 (the "License"). +""" + +from __future__ import annotations + +import inspect +from typing import Any, List, Optional, Tuple, Union + +import PIL.Image +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers.transformer_boogu import BooguImageRotaryPosEmbed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor + +from ...models.transformers import BooguImageTransformer2DModel +from .image_processor import BooguImageProcessor +from .pipeline_boogu import FMPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class BooguImageTurboPipeline(DiffusionPipeline): + """Standalone DMD student few-step text-to-image pipeline. + + Shares components and private utilities with `BooguImagePipeline` (kept in + sync via `# Copied from`), but runs a pure-T2I DMD denoising loop instead of + the scheduler-driven, guidance-capable loop. The DMD path requires pure T2I + inputs and no classifier-free guidance (`text_guidance_scale == + image_guidance_scale == 1.0`, `empty_instruction_guidance_scale == 0.0`). + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + + def __init__( + self, + transformer: BooguImageTransformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen3VLForConditionalGeneration, + processor: Qwen3VLProcessor, + ) -> None: + """ + Initialize the Boogu-Image-Turbo pipeline. + + Args: + transformer: Boogu transformer denoiser for latent prediction. + vae: Autoencoder used for latent/image encoding and decoding. + scheduler: Diffusion scheduler (unused by the DMD path, registered for parity). + mllm: Multimodal language model used to encode instructions. + processor: Processor paired with the MLLM for text/image inputs. + """ + # Defer setting pipeline attributes until after super().__init__, + # to avoid accessing self.config before it's created by Diffusers base class. + if hasattr(mllm, "lm_head"): + # Use the inner model of the instruction encoder as the encoder backbone. + mllm = mllm.model + + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor, + ) + + # Now it is safe to set additional attributes + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = BooguImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + # System prompts matching dataset logic (specific to this pipeline) + + self.SYSTEM_PROMPT_4_TI2I_UNIFIED = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate." + self.SYSTEM_PROMPT_4_T2I_UNIFIED = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows." + + self.SYSTEM_PROMPT_4_T2I = self.SYSTEM_PROMPT_4_T2I_UNIFIED + self.SYSTEM_PROMPT_DROP = ( + self.SYSTEM_PROMPT_4_TI2I_UNIFIED + ) # This is for empty negative instruction for image guidance in double guidance. + self.SYSTEM_PROMPT_4_TI2I = self.SYSTEM_PROMPT_4_TI2I_UNIFIED + self.SYSTEM_PROMPT_4_I2I = self.SYSTEM_PROMPT_4_TI2I_UNIFIED + + # ------------------------------------------------------------------ # + # DMD helpers (turbo-specific) # + # ------------------------------------------------------------------ # + def _build_dmd_student_sigmas( + self, + num_inference_steps: int, + device: torch.device, + dtype: torch.dtype, + conditioning_sigma: float, + timesteps: Optional[List[float]] = None, + ) -> torch.Tensor: + if timesteps is not None: + sigmas = torch.as_tensor(timesteps, device=device, dtype=dtype) + if sigmas.ndim != 1 or sigmas.numel() == 0: + raise ValueError("DMD inference timesteps must be a non-empty 1D sequence.") + if sigmas.max().item() > 1.0: + sigmas = sigmas / 1000.0 + return sigmas + + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1 for DMD student inference.") + + return torch.linspace( + conditioning_sigma, + 1.0, + num_inference_steps + 1, + device=device, + dtype=dtype, + )[:-1] + + def _predict_dmd_student_step( + self, + latents: torch.FloatTensor, + sigma: float, + instruction_embeds: torch.FloatTensor, + freqs_cis: torch.FloatTensor, + instruction_attention_mask: torch.Tensor, + ) -> torch.FloatTensor: + model_pred = self.predict( + t=torch.tensor(sigma, device=latents.device, dtype=latents.dtype), + latents=latents, + instruction_embeds=instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=instruction_attention_mask, + ref_image_hidden_states=None, + ) + + sigma_expanded = torch.full( + (latents.shape[0], 1, 1, 1), + sigma, + device=latents.device, + dtype=latents.dtype, + ) + return latents + (1 - sigma_expanded) * model_pred + + def _renoise_dmd_latents( + self, + latents: torch.FloatTensor, + sigma: float, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ) -> torch.FloatTensor: + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + sigma_expanded = torch.full( + (latents.shape[0], 1, 1, 1), + sigma, + device=latents.device, + dtype=latents.dtype, + ) + return (1 - sigma_expanded) * noise + sigma_expanded * latents + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: Union[torch.device, str], + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.encode_vae + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. + """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.preprocess_vlm_input_pil_images + def preprocess_vlm_input_pil_images( + self, + input_pil_images: List[PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", + crops_coords: List[Tuple[int, int, int, int]] = None, + ) -> List[PIL.Image.Image]: + """ + Resize input PIL images for VLM encoding. For each image, the target height/width is computed + from the pixel budget (max_pixels / max_side_length) and the image is resized to fit. + max_pixels is an int or None; per-image selection is handled by caller before passing here. + """ + + if input_pil_images is None or len(input_pil_images) <= 0: + return input_pil_images + + assert isinstance(input_pil_images, list), "`input_pil_images` should be a list." + assert all(isinstance(x, PIL.Image.Image) for x in input_pil_images), ( + "`input_pil_images` should be a list of PIL.Image.Image." + ) + + processed_input_pil_images = [] + for image in input_pil_images: + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + height, width = self.image_processor.get_new_height_width( + image, height, width, max_pixels, max_side_length + ) + processed_input_pil_images.append( + self.image_processor.resize(image, height, width, resize_mode=resize_mode) + ) + return processed_input_pil_images + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.prepare_image + def prepare_image( + self, + images: Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]], + batch_size: int, + num_images_per_instruction: int, + max_input_image_pixels: Union[int, list, tuple], + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_instruction: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + + success, max_images_per_sample, wrapped_input_images = self._check_and_wrap_input_images(images) + + if wrapped_input_images is not None: + assert len(wrapped_input_images) == batch_size, ( + "`wrapped_input_images` should be List[List[PIL.Image.Image]] and the `len(wrapped_input_images)` should be equal to `batch_size`." + ) + else: + wrapped_input_images = [None] * batch_size + + latents = [] + + for i, img in enumerate(wrapped_input_images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + max_pixels = self._get_max_image_pixels( + num_images=len(img), + max_input_image_pixels=max_input_image_pixels, + ) + img_j = self.image_processor.preprocess( + img_j, max_pixels=max_pixels, max_side_length=max_side_length + ) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + + for _ in range(num_images_per_instruction): + latents.append(ref_latents) + + return latents + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._check_and_wrap_input_images + def _check_and_wrap_input_images( + self, + input_images: Any, + treat_empty_list_as_none: bool = False, + ) -> Tuple[bool, int, Optional[Union[List[List[PIL.Image.Image]], List[List[str]]]]]: + """ + Normalize input_images into a two-level batch structure with per-sample lists: + - List[List[PIL.Image.Image]] or + - List[List[str]] (each str is an image path) + - Allowed per-sample "empty" markers: [] or None + + Rules: + - If input_images is None or []: + return (True, 0, None) + - If already in batch form such as [[image], [image,image], [], None] or [[str], [], [str,str], None], + return as is (optionally convert [] -> None if treat_empty_list_as_none=True). + - If List[PIL.Image.Image] / List[str] / List[None|PIL|str], wrap each non-None element as a single-image sample: + e.g. [img1, img2, None] -> [[img1], [img2], None] + - If single PIL.Image.Image / single str, wrap as [[item]] + - Otherwise attempt to iterate and collect valid items (PIL first, else paths) into a single batch sample. + + Returns: + (success, max_images_per_sample, wrapped_input_images) + - success: whether input_images is successfully wrapped + - max_images_per_sample: max number of images in any sample of the batch + - wrapped_input_images: List[List[PIL.Image.Image]] or List[List[str]] or None + """ + + # Case 0: input is None or empty + if input_images is None: + return True, 0, None + try: + # Safely check for emptiness without assuming it is a sequence + if hasattr(input_images, "__len__") and len(input_images) == 0: + return True, 0, None + except TypeError: + # If __len__ raises, ignore here; further logic will handle it + pass + + def is_pil_image(x: Any) -> bool: + return isinstance(x, Image.Image) + + def is_path(x: Any) -> bool: + return isinstance(x, str) + + def is_list_of_pil_images(x: Any) -> bool: + return isinstance(x, list) and all(is_pil_image(i) for i in x) + + def is_list_of_paths(x: Any) -> bool: + return isinstance(x, list) and all(is_path(i) for i in x) + + def is_list_of_list_of_pil_images(x: Any) -> bool: + return isinstance(x, list) and len(x) > 0 and all(is_list_of_pil_images(i) for i in x) + + def is_list_of_list_of_paths(x: Any) -> bool: + return isinstance(x, list) and len(x) > 0 and all(is_list_of_paths(i) for i in x) + + def is_batch_two_level_with_none(x: Any) -> bool: + """ + Accept batch-shaped inputs where each sample is: + - None (represents no image) + - [] (empty sample, can be converted to None if treat_empty_list_as_none=True) + - List[PIL.Image.Image] or List[str] + """ + if not isinstance(x, list) or len(x) == 0: + return False + for sample in x: + if sample is None: + continue + if isinstance(sample, list): + if len(sample) == 0: + continue + # Allow mixed PIL/str but all elements must be either PIL or str + all_pil = all(is_pil_image(i) for i in sample) + all_str = all(is_path(i) for i in sample) + if not (all_pil or all_str): + return False + else: + # Non-list, non-None found => not batch two-level + return False + return True + + # Case 1: already in normalized batch form (with None/[] allowed) + if is_batch_two_level_with_none(input_images): + wrapped = list(input_images) # shallow copy + # Optionally convert empty lists to None per sample + if treat_empty_list_as_none: + for idx, sample in enumerate(wrapped): + if isinstance(sample, list) and len(sample) == 0: + wrapped[idx] = None + max_len = 0 + for sample in wrapped: + if isinstance(sample, list): + max_len = max(max_len, len(sample)) + return True, max_len, wrapped + + # Case 2: List[PIL.Image.Image] -> single batch + if is_list_of_pil_images(input_images): + wrapped = [input_images] + max_len = len(input_images) + return True, max_len, wrapped + + # Case 2b: List[str] (paths) -> single batch + if is_list_of_paths(input_images): + wrapped = [input_images] + max_len = len(input_images) + return True, max_len, wrapped + + # Case 2c: Flat batch where elements can be PIL/str/None + if isinstance(input_images, list) and all( + (is_pil_image(x) or is_path(x) or x is None or (isinstance(x, list))) for x in input_images + ): + wrapped: List[Optional[List[Any]]] = [] + max_len = 0 + for item in input_images: + if item is None: + wrapped.append(None) + elif is_pil_image(item) or is_path(item): + wrapped.append([item]) + max_len = max(max_len, 1) + elif isinstance(item, list): + # Clean sublist: keep only PIL or str + pil_sub = [i for i in item if is_pil_image(i)] + str_sub = [i for i in item if is_path(i)] + if len(pil_sub) > 0 and len(str_sub) == 0: + wrapped.append(pil_sub) + max_len = max(max_len, len(pil_sub)) + elif len(str_sub) > 0 and len(pil_sub) == 0: + wrapped.append(str_sub) + max_len = max(max_len, len(str_sub)) + else: + # Empty or mixed invalid -> treat as empty + wrapped.append(None if treat_empty_list_as_none else []) + else: + # Unknown element -> treat as empty + wrapped.append(None if treat_empty_list_as_none else []) + # If all are None and we prefer None, keep as batch-level structure per spec + return True, max_len, wrapped + + # Case 3: single PIL.Image.Image -> [[image]] + if is_pil_image(input_images): + wrapped = [[input_images]] + return True, 1, wrapped + + # Case 3b: single path str -> [[path]] + if is_path(input_images): + wrapped = [[input_images]] + return True, 1, wrapped + + # Case 4: other types -> try to interpret as iterable and collect images/paths as a single sample + try: + as_list = list(input_images) + except TypeError: + # Cannot iterate; normalization fails + return False, 0, None + + pil_items = [x for x in as_list if is_pil_image(x)] + path_items = [x for x in as_list if is_path(x)] + + if pil_items: + # Treat all collected PIL images as one sample in a single batch + wrapped = [pil_items] + max_len = len(pil_items) + return True, max_len, wrapped + + if path_items: + # Treat all collected paths as one sample in a single batch + wrapped = [path_items] + max_len = len(path_items) + return True, max_len, wrapped + + # No valid entries found + return False, 0, None + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._get_instruction_feature_embeds + def _get_instruction_feature_embeds( + self, + instruction: Union[str, List[str]], + input_pil_images: Optional[List[List[PIL.Image.Image]]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + truncate_instruction_sequence: bool = False, + max_vlm_input_pil_pixels: Optional[Union[int, List[int]]] = None, + max_vlm_input_pil_side_length: Optional[int] = None, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get interleaved instruction embeddings from VLM (self.mllm), aligned with training: + - Build VLM inputs via processor.apply_chat_template (images + text) + - Optionally remove vision-token features by truncation + - Return last layer or last-N layers and the corresponding attention mask + + Args: + instruction: The instruction or list of instructions to encode. + input_pil_images: A list of PIL images to be included in the prompt (TI2I/I2I). + device: The device to place the embeddings on. If None, uses the pipeline's device. + max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The instruction embeddings tensor (or list of last-N layers) + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. + """ + device = device or self._execution_device + instruction = [instruction] if isinstance(instruction, str) else instruction + batch_size = len(instruction) + + # Build prompts with images+text. + # input_pil_images: Optional[List[List[PIL.Image.Image]]], outer length == batch_size, + # inner list contains K_i images for sample i. + prompts: List[list] = [] + processed_samples: List[Optional[List[PIL.Image.Image]]] = [] + + if input_pil_images is None or len(input_pil_images) == 0: + # No images for any sample -> pass None per sample + processed_samples = [None for _ in range(batch_size)] # type: List[Optional[List[PIL.Image.Image]]] + else: + # Validate shape: outer length must match batch_size + assert isinstance(input_pil_images, list) and len(input_pil_images) == batch_size, ( + "When provided, `input_pil_images` must be a List[List[PIL.Image.Image]] with len == batch size." + ) + for imgs in input_pil_images: + if imgs and len(imgs) > 0: + # Determine per-sample max_pixels as in dataset logic: + # - If max_vlm_input_pil_pixels is a list/tuple, require len >= K_i and take index K_i-1 + # - If it's an int, use it for all images in this sample + # - If None, do not constrain by pixels + max_pixels_i: Optional[int] = None + if isinstance(max_vlm_input_pil_pixels, (list, tuple)): + assert len(max_vlm_input_pil_pixels) >= len(imgs), ( + "`max_vlm_input_pil_pixels` length must be >= number of images in each sample" + ) + max_pixels_i = int(max_vlm_input_pil_pixels[len(imgs) - 1]) + elif isinstance(max_vlm_input_pil_pixels, int): + max_pixels_i = max_vlm_input_pil_pixels + else: + max_pixels_i = None + proc = self.preprocess_vlm_input_pil_images( + imgs, # List[PIL.Image.Image] for this sample + max_pixels=max_pixels_i, + max_side_length=max_vlm_input_pil_side_length, + ) + processed_samples.append(proc) + else: + # Empty inner list -> treat as no images for this sample + processed_samples.append(None) + + # Build the batched prompts; for each sample i, pass instruction[i] and its image list (or None) + for i in range(batch_size): + sample_imgs: Optional[List[PIL.Image.Image]] = None + if processed_samples and i < len(processed_samples): + sample_imgs = processed_samples[i] + # _apply_chat_template expects (instruction: str, input_pil_images: Optional[List[PIL.Image.Image]]) + prompts.append( + self._apply_chat_template( + instruction[i], + sample_imgs, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + ) + + # Processor produces dict with 'input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw' + vlm_inputs = self.processor.apply_chat_template( + prompts, + padding="longest", + max_length=max_sequence_length, + truncation=truncate_instruction_sequence, + padding_side="right", + return_tensors="pt", + tokenize=True, + return_dict=True, + ) + for k in vlm_inputs.keys(): + if isinstance(vlm_inputs[k], torch.Tensor): + vlm_inputs[k] = vlm_inputs[k].to(device) + + instruction_mask = vlm_inputs["attention_mask"] + + num_instruction_feature_layers = self.transformer.instruction_feature_configs.get( + "num_instruction_feature_layers", 1 + ) + final_instruction_mask = instruction_mask + + with torch.no_grad(): + if num_instruction_feature_layers > 1: + text_encoder_outputs = self.mllm(**vlm_inputs, output_hidden_states=True, return_dict=True) + all_hidden_states = ( + text_encoder_outputs.hidden_states + ) # Tuple of [B, extended_seq_len, text_hidden_dim] + instruction_feats = list(all_hidden_states)[ + -num_instruction_feature_layers: + ] # Convert to list for model processing + else: + instruction_feats = self.mllm(**vlm_inputs).last_hidden_state + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if isinstance(instruction_feats, (list, tuple)): + final_instruction_feats = [feat.to(dtype=dtype, device=device) for feat in instruction_feats] + else: + final_instruction_feats = instruction_feats.to(dtype=dtype, device=device) + # Keep the attention mask on the same execution device as the features + # before passing both into the diffusion transformer. + final_instruction_mask = final_instruction_mask.to(device=device) + + return final_instruction_feats, final_instruction_mask + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._apply_chat_template + def _apply_chat_template( + self, + instruction: str, + input_pil_images: Optional[List[PIL.Image.Image]] = None, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ): + """ + Build chat template content with interleaved text and images. + If `system_prompt_follows_task_type` is True, the system prompt will be selected based on the task type. + If `system_prompt_follows_task_type` is False, the system prompt will be selected based on the input images. + Returns the prompt structure (list of messages with typed contents). + """ + user_text_content = [{"type": "text", "text": instruction}] + + if system_prompt_follows_task_type: + if task_type.lower() == "t2i": + system_prompt = self.SYSTEM_PROMPT_4_T2I + else: + system_prompt = self.SYSTEM_PROMPT_4_TI2I + else: + # Pick system prompt adaptively based on the input images and instruction. + if input_pil_images is None or len(input_pil_images) == 0: + if instruction is None or len(instruction.strip()) == 0: + system_prompt = self.SYSTEM_PROMPT_DROP + else: + system_prompt = self.SYSTEM_PROMPT_4_T2I + else: + if instruction is None or len(instruction.strip()) == 0: + system_prompt = self.SYSTEM_PROMPT_4_I2I + else: + system_prompt = self.SYSTEM_PROMPT_4_TI2I + + system_role = { + "role": "system", + "content": [{"type": "text", "text": system_prompt}], + } + if input_pil_images is None or len(input_pil_images) == 0: + prompt = [system_role, {"role": "user", "content": user_text_content}] + else: + images_content = [{"type": "image", "image": pil_img} for pil_img in input_pil_images] + prompt = [ + system_role, + {"role": "user", "content": images_content + user_text_content}, + ] + return prompt + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._reshape_embeds_and_mask + def _reshape_embeds_and_mask(self, embeds, mask, num_images_per_instruction): + """ + To duplicate text embeddings and attention mask for each generation per instruction, using mps friendly method + """ + if isinstance(embeds, (list, tuple)): + batch_size, seq_len, _ = embeds[0].shape + reshaped_embeds = [] + for embed in embeds: + embed = embed.repeat(1, num_images_per_instruction, 1) + reshaped_embeds.append(embed.view(batch_size * num_images_per_instruction, seq_len, -1)) + else: + batch_size, seq_len, _ = embeds.shape + embeds = embeds.repeat(1, num_images_per_instruction, 1) + reshaped_embeds = embeds.view(batch_size * num_images_per_instruction, seq_len, -1) + + mask = mask.repeat(num_images_per_instruction, 1) + reshaped_mask = mask.view(batch_size * num_images_per_instruction, -1) + + return batch_size, seq_len, reshaped_embeds, reshaped_mask + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._get_max_image_pixels + def _get_max_image_pixels( + self, + num_images: int, + max_input_image_pixels: Union[int, list, tuple] = 1024 * 1024, + ): + + if (num_images <= 0) or (not max_input_image_pixels): + return 1024 * 1024 + + if isinstance(max_input_image_pixels, (list, tuple)): + assert len(max_input_image_pixels) >= num_images, ( + f"`len(max_input_image_pixels)` should be >= number of input images per sample, i.e., {num_images}" + ) + max_pixels = max_input_image_pixels[num_images - 1] + else: + max_pixels = max_input_image_pixels + + return max_pixels + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.encode_instruction + def encode_instruction( + self, + instruction: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_instruction: Optional[Union[str, List[str]]] = None, + input_images: Optional[Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None, + use_input_images_4_neg_instruct: bool = False, + use_input_images_4_empty_instruct: bool = False, + max_vlm_input_pil_pixels: Optional[Union[int, List[int]]] = 384 * 384, + max_vlm_input_pil_side_length: Optional[int] = 384 * 2, + num_images_per_instruction: int = 1, + device: Optional[torch.device] = None, + instruction_embeds: Optional[torch.Tensor] = None, + negative_instruction_embeds: Optional[torch.Tensor] = None, + instruction_attention_mask: Optional[torch.Tensor] = None, + negative_instruction_attention_mask: Optional[torch.Tensor] = None, + # For double guidance + empty_instruction: Optional[Union[str, List[str]]] = " ", + empty_instruction_embeds: Optional[torch.Tensor] = None, + empty_instruction_attention_mask: Optional[torch.Tensor] = None, + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide: bool = False, + use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide: bool = False, + max_sequence_length: int = 256, + truncate_instruction_sequence: bool = False, + system_prompt_follows_task_type: bool = False, + task_type: str = "ti2i", + ) -> Tuple[torch.Tensor, ...]: + r""" + Encodes the instruction into text encoder hidden states. + + Args: + instruction (`str` or `List[str]`, *optional*): + instruction to be encoded + negative_instruction (`str` or `List[str]`, *optional*): + The instruction not to guide the image generation. If not defined, one has to pass `negative_instruction_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_instruction (`int`, *optional*, defaults to 1): + number of images that should be generated per instruction + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + instruction_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* instruction weighting. If not + provided, text embeddings will be generated from `instruction` input argument. + negative_instruction_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the instruction. + """ + device = device or self._execution_device + + instruction = [instruction] if isinstance(instruction, str) else instruction + # Chat template with images is handled inside _get_instruction_feature_embeds + batch_size = len(instruction) + + if instruction_embeds is None: + instruction_embeds, instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=instruction, + input_pil_images=input_images, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + batch_size, seq_len, _ = instruction_embeds.shape + + batch_size, seq_len, instruction_embeds, instruction_attention_mask = self._reshape_embeds_and_mask( + instruction_embeds, + instruction_attention_mask, + num_images_per_instruction, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_instruction_embeds is None: + negative_instruction = negative_instruction if negative_instruction is not None else "" + + # Normalize str to list + negative_instruction = ( + batch_size * [negative_instruction] if isinstance(negative_instruction, str) else negative_instruction + ) + + if instruction is not None and type(instruction) is not type(negative_instruction): + raise TypeError( + f"`negative_instruction` should be the same type to `instruction`, but got {type(negative_instruction)} !=" + f" {type(instruction)}." + ) + # elif isinstance(negative_instruction, str): # not needed since negative_instruction is already a list + + elif batch_size != len(negative_instruction): + raise ValueError( + f"`negative_instruction`: {negative_instruction} has batch size {len(negative_instruction)}, but `instruction`:" + f" {instruction} has batch size {batch_size}. Please make sure that passed `negative_instruction` matches" + " the batch size of `instruction`." + ) + negative_instruction_embeds, negative_instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=negative_instruction, + input_pil_images=input_images if use_input_images_4_neg_instruct else None, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels if use_input_images_4_neg_instruct else None, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length + if use_input_images_4_neg_instruct + else None, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + ( + batch_size, + seq_len, + negative_instruction_embeds, + negative_instruction_attention_mask, + ) = self._reshape_embeds_and_mask( + negative_instruction_embeds, + negative_instruction_attention_mask, + num_images_per_instruction, + ) + + if ( + use_empty_neg_instruct_4_ref_img_pred_at_image_guide_in_double_guide + or use_empty_neg_instruct_4_ref_img_pred_at_text_guide_in_double_guide + ): + if do_classifier_free_guidance and (empty_instruction_embeds is None): + empty_instruction = empty_instruction if empty_instruction is not None else [" "] * batch_size + + empty_instruction = ( + batch_size * [empty_instruction] if isinstance(empty_instruction, str) else empty_instruction + ) + + if instruction is not None and type(instruction) is not type(empty_instruction): + raise TypeError( + f"`empty_instruction` should be the same type as `instruction`, but got {type(empty_instruction)} !=" + f" {type(instruction)}." + ) + + elif batch_size != len(empty_instruction): + raise ValueError( + f"`empty_instruction`: {empty_instruction} has batch size {len(empty_instruction)}, but `instruction`:" + f" {instruction} has batch size {batch_size}. Please make sure that passed `empty_instruction` matches" + " the batch size of `instruction`." + ) + + empty_instruction_embeds, empty_instruction_attention_mask = self._get_instruction_feature_embeds( + instruction=empty_instruction, + input_pil_images=input_images if use_input_images_4_empty_instruct else None, + device=device, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + max_vlm_input_pil_pixels=max_vlm_input_pil_pixels if use_input_images_4_empty_instruct else None, + max_vlm_input_pil_side_length=max_vlm_input_pil_side_length + if use_input_images_4_empty_instruct + else None, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + ( + batch_size, + seq_len, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) = self._reshape_embeds_and_mask( + empty_instruction_embeds, + empty_instruction_attention_mask, + num_images_per_instruction, + ) + + return ( + instruction_embeds, + instruction_attention_mask, + negative_instruction_embeds, + negative_instruction_attention_mask, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) + + @property + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @property + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.text_guidance_scale + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.image_guidance_scale + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.empty_instruction_guidance_scale + def empty_instruction_guidance_scale(self): + return self._empty_instruction_guidance_scale + + @property + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.cfg_range + def cfg_range(self): + return self._cfg_range + + @torch.no_grad() + def __call__( + self, + instruction: Optional[Union[str, List[str]]] = None, + instruction_embeds: Optional[torch.FloatTensor] = None, + instruction_attention_mask: Optional[torch.LongTensor] = None, + max_sequence_length: int = 1280, + truncate_instruction_sequence: bool = False, + num_images_per_instruction: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + align_res: bool = True, + num_inference_steps: int = 50, + system_prompt_follows_task_type: bool = False, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + step_func=None, + # DMD student inference controls + use_dmd_student_inference: bool = True, + dmd_conditioning_sigma: float = 0.001, + ): + """Run DMD student few-step text-to-image inference. + + This is a pure-T2I path: no reference images, no classifier-free + guidance, no scheduler. It mirrors `BooguImagePipeline.__call__`'s setup + for T2I and then runs the DMD predict/renoise loop directly. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # DMD requires no CFG: pin guidance scales to the no-guidance configuration. + self._text_guidance_scale = 1.0 + self._image_guidance_scale = 1.0 + self._empty_instruction_guidance_scale = 0.0 + + # 1. Define call parameters + if instruction is not None and isinstance(instruction, str): + batch_size = 1 + instruction = [instruction] + elif instruction is not None and isinstance(instruction, (list, tuple)): + batch_size = len(instruction) + else: + batch_size = instruction_embeds.shape[0] + + # Resolve the device the pipeline's modules live on. With offloading enabled the base + # class returns the right execution device; otherwise it reflects the last `.to(...)`. + device = self._execution_device + + # Pure T2I: no input images. + task_type = self._get_task_type_by_input_images(None) + + # 2. Encode input instruction (T2I, no negative/empty paths since tg == 1.0). + ( + instruction_embeds, + instruction_attention_mask, + negative_instruction_embeds, + negative_instruction_attention_mask, + empty_instruction_embeds, + empty_instruction_attention_mask, + ) = self.encode_instruction( + instruction, + self.text_guidance_scale > 1.0, + negative_instruction=None, + input_images=None, + num_images_per_instruction=num_images_per_instruction, + device=device, + instruction_embeds=instruction_embeds, + instruction_attention_mask=instruction_attention_mask, + max_sequence_length=max_sequence_length, + truncate_instruction_sequence=truncate_instruction_sequence, + system_prompt_follows_task_type=system_prompt_follows_task_type, + task_type=task_type, + ) + + # Put ref_latents here before encoding instruction. + dtype = self.vae.dtype + + # 3. Prepare control image (T2I -> empty ref latents). + ref_latents = self.prepare_image( + images=None, + batch_size=batch_size, + num_images_per_instruction=num_images_per_instruction, + max_input_image_pixels=2048 * 2048, + max_side_length=2048 * 2, + device=device, + dtype=dtype, + ) + + input_images, width, height, ori_width, ori_height = self._resolve_output_and_original_size( + input_images=None, + ref_latents=ref_latents, + align_res=align_res, + width=width, + height=height, + max_input_image_pixels=2048 * 2048, + max_images_per_sample=0, + img_scale_num=self.vae_scale_factor * 2, + ) + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_instruction, + latent_channels, + height, + width, + instruction_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = BooguImageRotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + # 5. DMD student few-step T2I denoising (no scheduler, no guidance). + if not use_dmd_student_inference: + raise ValueError( + "BooguImageTurboPipeline only supports DMD student inference; pass use_dmd_student_inference=True " + "or use BooguImagePipeline for the scheduler-driven path." + ) + + logger.info("[Turbo Pipeline Processing]: DMD student few-step T2I inference.") + + dmd_sigmas = self._build_dmd_student_sigmas( + num_inference_steps=num_inference_steps, + device=device, + dtype=latents.dtype, + conditioning_sigma=float(dmd_conditioning_sigma), + timesteps=timesteps, + ) + num_inference_steps = int(dmd_sigmas.numel()) + self._num_timesteps = num_inference_steps + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, sigma in enumerate(dmd_sigmas.tolist()): + latents = self._predict_dmd_student_step( + latents=latents, + sigma=sigma, + instruction_embeds=instruction_embeds, + freqs_cis=freqs_cis, + instruction_attention_mask=instruction_attention_mask, + ).to(dtype=dtype) + + if i < num_inference_steps - 1: + latents = self._renoise_dmd_latents( + latents, + sigma=dmd_sigmas[i + 1].item(), + generator=generator, + ).to(dtype=dtype) + + progress_bar.update() + if step_func is not None: + step_func(i, self._num_timesteps) + + # 6. Decode latents (same logic as the parent `processing` tail). + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + image = F.interpolate(image, size=(ori_height, ori_width), mode="bilinear") + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + else: + return FMPipelineOutput(images=image) + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._resolve_output_and_original_size + def _resolve_output_and_original_size( + self, + input_images, + ref_latents: List[Union[List[torch.FloatTensor], None]], + align_res: bool, + width: int, + height: int, + max_input_image_pixels: Union[int, list, tuple], + max_images_per_sample: int, + img_scale_num: int = 16, + ) -> Tuple[List, int, int, int, int]: + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ( + ref_latents[0][0].shape[-1] * self.vae_scale_factor, + ref_latents[0][0].shape[-2] * self.vae_scale_factor, + ) + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + + if isinstance(max_input_image_pixels, (list, tuple)): + if (input_images is not None) and (len(input_images) > 0) and max_images_per_sample > 0: + assert len(max_input_image_pixels) >= max_images_per_sample, ( + f"When `max_input_image_pixels` is a list or tuple, the length of it (here is {len(max_input_image_pixels)}) should be >= max number of input images in all the samples (here is {max_images_per_sample})." + ) + max_pixels = max_input_image_pixels[max_images_per_sample - 1] + else: + max_pixels = max_input_image_pixels[0] + else: + max_pixels = max_input_image_pixels + + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = ( + int(height * ratio) // img_scale_num * img_scale_num, + int(width * ratio) // img_scale_num * img_scale_num, + ) + + return input_images, width, height, ori_width, ori_height + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._get_task_type_by_ref_latents + def _get_task_type_by_ref_latents(self, ref_latents: List[Union[List[torch.FloatTensor], None]]): + if not ref_latents: + return "t2i" + + if isinstance(ref_latents, (list, tuple)): + for x in ref_latents: + if x: + return "ti2i" + return "t2i" + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline._get_task_type_by_input_images + def _get_task_type_by_input_images(self, input_images: Union[List[List[PIL.Image.Image]], List[PIL.Image.Image]]): + if not input_images: + return "t2i" + + if isinstance(input_images, (list, tuple)): + for x in input_images: + if x: + return "ti2i" + return "t2i" + + # Copied from diffusers.pipelines.boogu.pipeline_boogu.BooguImagePipeline.predict + def predict( + self, + t, + latents, + instruction_embeds, + freqs_cis, + instruction_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()): + optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + instruction_embeds, + freqs_cis, + instruction_attention_mask, + **optional_kwargs, + ) + return model_pred diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 447586c6f436..a1fb70e91d0e 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,6 +61,7 @@ _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] _import_structure["scheduling_flow_map_euler_discrete"] = ["FlowMapEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] + _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] _import_structure["scheduling_helios"] = ["HeliosScheduler"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8439a2b93371..e0cf790d0e58 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -900,6 +900,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BooguImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class BriaFiboTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0747e76cf715..99ed4116f943 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1082,6 +1082,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class BooguImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class BooguImageTurboPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaFiboEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/teacache_util.py b/src/diffusers/utils/teacache_util.py new file mode 100644 index 000000000000..a47076a97e9d --- /dev/null +++ b/src/diffusers/utils/teacache_util.py @@ -0,0 +1,41 @@ +""" +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class TeaCacheParams: + """ + TeaCache parameters for `BooguImageTransformer2DModel` + See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding + + Args: + previous_residual (Optional[torch.Tensor]): + The tensor difference between the output and the input of the transformer layers from the previous timestep. + previous_modulated_inp (Optional[torch.Tensor]): + The modulated input from the previous timestep used to indicate the change of the transformer layer's output. + accumulated_rel_l1_distance (float): + The accumulated relative L1 distance. + is_first_or_last_step (bool): + Whether the current timestep is the first or last step. + """ + + previous_residual: Optional[torch.Tensor] = None + previous_modulated_inp: Optional[torch.Tensor] = None + accumulated_rel_l1_distance: float = 0 + is_first_or_last_step: bool = False diff --git a/tests/models/transformers/test_models_transformer_boogu.py b/tests/models/transformers/test_models_transformer_boogu.py new file mode 100644 index 000000000000..ee6a7a4f6f67 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_boogu.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from diffusers import BooguImageTransformer2DModel +from diffusers.models.transformers.transformer_boogu import BooguImageRotaryPosEmbed +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +# Tiny config: hidden_size // num_attention_heads must equal sum(axes_dim_rope). +# Here 12 // 2 == 6 == 2 + 2 + 2. +_AXES_DIM_ROPE = (2, 2, 2) +_AXES_LENS = (16, 16, 16) +_INSTRUCTION_FEAT_DIM = 8 +_THETA = 10000 + + +class BooguImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return BooguImageTransformer2DModel + + @property + def pretrained_model_name_or_path(self): + return None # No tiny Hub checkpoint yet; hub-dependent tests are skipped. + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "patch_size": 2, + "in_channels": 4, + "hidden_size": 12, + "num_layers": 2, + "num_double_stream_layers": 1, + "num_refiner_layers": 1, + "num_attention_heads": 2, + "num_kv_heads": 1, + "multiple_of": 4, + "norm_eps": 1e-5, + "axes_dim_rope": _AXES_DIM_ROPE, + "axes_lens": _AXES_LENS, + "instruction_feature_configs": { + "instruction_feat_dim": _INSTRUCTION_FEAT_DIM, + "reduce_type": "mean", + "num_instruction_feat_layers": 1, + }, + "timestep_scale": 1.0, + } + + def get_dummy_inputs(self, height: int = 8, width: int = 8) -> dict: + batch_size = 1 + in_channels = 4 + instruction_len = 5 + gen = self.generator + + hidden_states = randn_tensor( + (batch_size, in_channels, height, width), generator=gen, device=torch.device(torch_device) + ) + timestep = torch.tensor([1.0], device=torch_device) + instruction_hidden_states = randn_tensor( + (batch_size, instruction_len, _INSTRUCTION_FEAT_DIM), generator=gen, device=torch.device(torch_device) + ) + instruction_attention_mask = torch.ones(batch_size, instruction_len, dtype=torch.long, device=torch_device) + freqs_cis = BooguImageRotaryPosEmbed.get_freqs_cis(_AXES_DIM_ROPE, _AXES_LENS, theta=_THETA) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "instruction_hidden_states": instruction_hidden_states, + "freqs_cis": freqs_cis, + "instruction_attention_mask": instruction_attention_mask, + } + + @property + def input_shape(self) -> tuple: + return (4, 8, 8) + + @property + def output_shape(self) -> tuple: + return (4, 8, 8) + + +class TestBooguImageTransformerModel(BooguImageTransformerTesterConfig, ModelTesterMixin): + pass + + +class TestBooguImageTransformerMemory(BooguImageTransformerTesterConfig, MemoryTesterMixin): + pass + + +class TestBooguImageTransformerTorchCompile(BooguImageTransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(8, 8), (8, 16), (16, 16)] + + def get_dummy_inputs(self, height: int = 8, width: int = 8) -> dict: + return BooguImageTransformerTesterConfig.get_dummy_inputs(self, height=height, width=width) + + +class TestBooguImageTransformerTraining(BooguImageTransformerTesterConfig, TrainingTesterMixin): + pass diff --git a/tests/pipelines/boogu/__init__.py b/tests/pipelines/boogu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/boogu/test_boogu.py b/tests/pipelines/boogu/test_boogu.py new file mode 100644 index 000000000000..5b995708ab0b --- /dev/null +++ b/tests/pipelines/boogu/test_boogu.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKL, + BooguImagePipeline, + BooguImageTransformer2DModel, + FlowMatchEulerDiscreteScheduler, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +# Tiny processor lives on the Hub (bundles tokenizer + image processor + chat template). +_TINY_QWEN_REPO = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" +# MLLM hidden size; the transformer's instruction_feat_dim must match it. +_MLLM_HIDDEN = 16 + + +class BooguImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BooguImagePipeline + # Boogu is instruction-driven, not prompt-driven. + params = frozenset(["instruction", "height", "width", "num_inference_steps"]) + batch_params = frozenset(["instruction"]) + required_optional_params = frozenset(["num_inference_steps", "generator", "output_type", "return_dict"]) + + # Boogu uses the base-class device placement (`.to(...)` / `_execution_device`), but the + # generic offload / casting / xformers paths do not apply to its instruction-encoder design. + test_xformers_attention = False + test_attention_slicing = False + test_layerwise_casting = False + test_group_offloading = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BooguImageTransformer2DModel( + patch_size=2, + in_channels=4, + hidden_size=12, + num_layers=2, + num_double_stream_layers=1, + num_refiner_layers=1, + num_attention_heads=2, + num_kv_heads=1, + multiple_of=4, + norm_eps=1e-5, + axes_dim_rope=(2, 2, 2), + axes_lens=(64, 64, 64), + instruction_feature_configs={ + "instruction_feat_dim": _MLLM_HIDDEN, + "reduce_type": "mean", + "num_instruction_feat_layers": 1, + }, + timestep_scale=1.0, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + latent_channels=4, + norm_num_groups=8, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + # Boogu's released configs carry `seq_len`, used for the static v1 time shift. + scheduler.register_to_config(seq_len=4096) + + torch.manual_seed(0) + mllm_config = Qwen3VLConfig( + text_config={ + "hidden_size": _MLLM_HIDDEN, + "intermediate_size": _MLLM_HIDDEN, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": {"mrope_section": [1, 1, 2], "rope_type": "default", "type": "default"}, + "rope_theta": 1000000.0, + "vocab_size": 151936, + "head_dim": 8, + }, + vision_config={ + "depth": 2, + "hidden_size": _MLLM_HIDDEN, + "intermediate_size": _MLLM_HIDDEN, + "num_heads": 2, + "out_hidden_size": _MLLM_HIDDEN, + }, + ) + mllm = Qwen3VLForConditionalGeneration(mllm_config).eval() + processor = Qwen3VLProcessor.from_pretrained(_TINY_QWEN_REPO) + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "mllm": mllm, + "processor": processor, + } + + def get_dummy_inputs(self, device, seed=0): + generator = torch.Generator("cpu").manual_seed(seed) + return { + "instruction": "a cat", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + # Pure T2I, no classifier-free guidance, run on CPU. + "text_guidance_scale": 1.0, + "image_guidance_scale": 1.0, + "empty_instruction_guidance_scale": 0.0, + "output_type": "np", + } + + def test_boogu_t2i_default(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + images = pipe(**inputs).images + images = np.asarray(images) + + self.assertEqual(images.shape, (1, 16, 16, 3)) + + @unittest.skip( + "Qwen3VLProcessor bundles an image processor that is not DDUF-serializable " + "(same limitation as other Qwen3VL-based pipelines)." + ) + def test_save_load_dduf(self): + pass + + @unittest.skip( + "save/load round-trips the Qwen3VLProcessor, whose image-processor chat-template " + "reload is not supported offline (same limitation as other Qwen3VL-based pipelines)." + ) + def test_save_load_local(self): + pass + + @unittest.skip("device_map sharding requires a hardware accelerator.") + def test_pipeline_with_accelerator_device_map(self): + pass