diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 46e241d817b5..64a4222845b0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -114,6 +114,8 @@ title: Guiders - local: modular_diffusers/custom_blocks title: Building Custom Blocks + - local: modular_diffusers/mellon + title: Using Custom Blocks with Mellon title: Modular Diffusers - isExpanded: false sections: diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 24776b42309e..c77efa09f594 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -106,8 +106,6 @@ video, audio = pipe( output_type="np", return_dict=False, ) -video = (video * 255).round().astype("uint8") -video = torch.from_numpy(video) encode_video( video[0], @@ -185,8 +183,6 @@ video, audio = pipe( output_type="np", return_dict=False, ) -video = (video * 255).round().astype("uint8") -video = torch.from_numpy(video) encode_video( video[0], diff --git a/docs/source/en/api/pipelines/z_image.md b/docs/source/en/api/pipelines/z_image.md index 5175f6b0fb6f..cf4c1aefb81f 100644 --- a/docs/source/en/api/pipelines/z_image.md +++ b/docs/source/en/api/pipelines/z_image.md @@ -53,6 +53,41 @@ image = pipe( image.save("zimage_img2img.png") ``` +## Inpainting + +Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask. + +```python +import torch +import numpy as np +from PIL import Image +from diffusers import ZImageInpaintPipeline +from diffusers.utils import load_image + +pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +init_image = load_image(url).resize((1024, 1024)) + +# Create a mask (white = inpaint, black = preserve) +mask = np.zeros((1024, 1024), dtype=np.uint8) +mask[256:768, 256:768] = 255 # Inpaint center region +mask_image = Image.fromarray(mask) + +prompt = "A beautiful lake with mountains in the background" +image = pipe( + prompt, + image=init_image, + mask_image=mask_image, + strength=1.0, + num_inference_steps=9, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +image.save("zimage_inpaint.png") +``` + ## ZImagePipeline [[autodoc]] ZImagePipeline @@ -64,3 +99,9 @@ image.save("zimage_img2img.png") [[autodoc]] ZImageImg2ImgPipeline - all - __call__ + +## ZImageInpaintPipeline + +[[autodoc]] ZImageInpaintPipeline + - all + - __call__ diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md index af53411b9533..426739347f27 100644 --- a/docs/source/en/modular_diffusers/components_manager.md +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -12,179 +12,85 @@ specific language governing permissions and limitations under the License. # ComponentsManager -The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading. +The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), and supports offloading. This guide will show you how to use [`ComponentsManager`] to manage components and device memory. -## Add a component +## Connect to a pipeline -The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`]. +Create a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`]. -> [!TIP] -> The `collection` parameter is optional but makes it easier to organize and manage components. ```py from diffusers import ModularPipeline, ComponentsManager +import torch -comp = ComponentsManager() -pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1") +manager = ComponentsManager() +pipe = ModularPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", components_manager=manager) +pipe.load_components(torch_dtype=torch.bfloat16) ``` ```py -from diffusers import ComponentsManager -from diffusers.modular_pipelines import SequentialPipelineBlocks -from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS - -t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) - -modular_repo_id = "YiYiXu/modular-loader-t2i-0704" -components = ComponentsManager() -t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components) +from diffusers import ModularPipelineBlocks, ComponentsManager +import torch +manager = ComponentsManager() +blocks = ModularPipelineBlocks.from_pretrained("diffusers/Florence2-image-Annotator", trust_remote_code=True) +pipe= blocks.init_pipeline(components_manager=manager) +pipe.load_components(torch_dtype=torch.bfloat16) ``` -Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection - -```py -pipe.load_components() -pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2") -``` - -Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components. - -```py -pipe2.null_component_names -['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet'] - -comp_dict = comp.get_components_by_names(names=pipe2.null_component_names) -pipe2.update_components(**comp_dict) -``` - -To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id. - -```py -from diffusers import AutoModel - -text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") -component_id = comp.add("text_encoder", text_encoder) -comp -``` - -Use [`~ComponentsManager.remove`] to remove a component using their id. - -```py -comp.remove("text_encoder_139917733042864") -``` - -## Retrieve a component - -The [`ComponentsManager`] provides several methods to retrieve registered components. - -### get_one - -The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] returns an error. - -| Pattern | Example | Description | -|-------------|----------------------------------|-------------------------------------------| -| exact | `comp.get_one(name="unet")` | exact name match | -| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" | -| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" | -| or | `comp.get_one(name="unet|vae")` | name is "unet" or "vae" | - -[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument. - -```py -comp.get_one(name="unet", collection="sdxl") -``` - -### get_components_by_names - -The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] since they provide lists of required component names and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`]. - -```py -component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"]) -{"text_encoder": component1, "unet": component2, "vae": component3} -``` - -## Duplicate detection - -It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint. - -```py -from diffusers import ComponentSpec, ComponentsManager -from transformers import CLIPTextModel - -comp = ComponentsManager() - -# Create ComponentSpec for the first text encoder -spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel) -# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder) -spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel) - -# Load and add both components - the manager will detect they're the same model -comp.add("text_encoder", spec.load()) -comp.add("text_encoder_duplicated", spec_duplicated.load()) -``` - -This returns a warning with instructions for removing the duplicate. +Components loaded by the pipeline are automatically registered in the manager. You can inspect them right away. -```py -ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`. -'text_encoder_duplicated_139917580682672' -``` - -You could also add a component without using [`ComponentSpec`] and duplicate detection still works in most cases even if you're adding the same component under a different name. - -However, [`ComponentManager`] can't detect duplicates when you load the same component into different objects. In this case, you should load a model with [`ComponentSpec`]. - -```py -text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") -comp.add("text_encoder", text_encoder_2) -'text_encoder_139917732983664' -``` +## Inspect components -## Collections +Print the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID. -Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`]. - -Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component. +The output below corresponds to the `from_pretrained` example above. ```py -from diffusers import ComponentSpec, ComponentsManager - -comp = ComponentsManager() -# Create ComponentSpec for the first UNet -spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel) -# Create ComponentSpec for a different UNet -spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16") - -# Add both UNets to the same collection - the second one will replace the first -comp.add("unet", spec.load(), collection="sdxl") -comp.add("unet", spec2.load(), collection="sdxl") +Components: +============================================================================================================================= +Models: +----------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID +----------------------------------------------------------------------------------------------------------------------------- +text_encoder_140458257514752 | Qwen3Model | cpu | torch.bfloat16 | 7.49 | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null +vae_140458257515376 | AutoencoderKL | cpu | torch.bfloat16 | 0.16 | Tongyi-MAI/Z-Image-Turbo|vae|null|null +transformer_140458257515616 | ZImageTransformer2DModel | cpu | torch.bfloat16 | 11.46 | Tongyi-MAI/Z-Image-Turbo|transformer|null|null +----------------------------------------------------------------------------------------------------------------------------- + +Other Components: +----------------------------------------------------------------------------------------------------------------------------- +ID | Class | Collection +----------------------------------------------------------------------------------------------------------------------------- +scheduler_140461023555264 | FlowMatchEulerDiscreteScheduler | N/A +tokenizer_140458256346432 | Qwen2Tokenizer | N/A +----------------------------------------------------------------------------------------------------------------------------- ``` -This makes it convenient to work with node-based systems because you can: - -- Mark all models as loaded from one node with the `collection` label. -- Automatically replace models when new checkpoints are loaded under the same name. -- Batch delete all models in a collection when a node is removed. +The table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom. ## Offloading The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components. ```py -comp.enable_auto_cpu_offload(device="cuda") +manager.enable_auto_cpu_offload(device="cuda") ``` All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low. -You can set your own rules for which models to offload first. +Call [`~ComponentsManager.disable_auto_cpu_offload`] to disable offloading. + +```py +manager.disable_auto_cpu_offload() +``` diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md index 6ef8db613f7f..b412e0e58abc 100644 --- a/docs/source/en/modular_diffusers/custom_blocks.md +++ b/docs/source/en/modular_diffusers/custom_blocks.md @@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License. [ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block. > [!TIP] -> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana. +> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks. ## Project Structure @@ -31,54 +31,58 @@ Your custom block project should use the following structure: - `block.py` contains the custom block implementation - `modular_config.json` contains the metadata needed to load the block -## Example: Florence 2 Inpainting Block +## Quick Start with Template -In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting. +The fastest way to create a custom block is to start from our template. The template provides a pre-configured project structure with `block.py` and `modular_config.json` files, plus commented examples showing how to define components, inputs, outputs, and the `__call__` method—so you can focus on your custom logic instead of boilerplate setup. -The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub. +### Download the template -```py -# Inside block.py -from diffusers.modular_pipelines import ( - ModularPipelineBlocks, - ComponentSpec, +```python +from diffusers import ModularPipelineBlocks + +model_id = "diffusers/custom-block-template" +local_dir = model_id.split("/")[-1] + +blocks = ModularPipelineBlocks.from_pretrained( + model_id, + trust_remote_code=True, + local_dir=local_dir ) -from transformers import AutoProcessor, Florence2ForConditionalGeneration +``` +This saves the template files to `custom-block-template/` locally or you could use `local_dir` to save to a specific location. -class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): +### Edit locally - @property - def expected_components(self): - return [ - ComponentSpec( - name="image_annotator", - type_hint=Florence2ForConditionalGeneration, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ComponentSpec( - name="image_annotator_processor", - type_hint=AutoProcessor, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ] +Open `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation. + +### Test your block + +```python +from diffusers import ModularPipelineBlocks + +blocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True) +pipeline = blocks.init_pipeline() +output = pipeline(...) # your inputs here ``` -Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations. +### Upload to the Hub -```py -from typing import List, Union -from PIL import Image, ImageDraw -import torch -import numpy as np - -from diffusers.modular_pipelines import ( - PipelineState, - ModularPipelineBlocks, - InputParam, - ComponentSpec, - OutputParam, -) +```python +pipeline.save_pretrained(local_dir, repo_id="your-username/your-block-name", push_to_hub=True) +``` + +## Example: Florence-2 Image Annotator + +This example creates a custom block with [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting. + +### Define components + +Define the components the block needs, `Florence2ForConditionalGeneration` and its processor. When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from). + +```python +# Inside block.py +from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec from transformers import AutoProcessor, Florence2ForConditionalGeneration @@ -98,122 +102,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): pretrained_model_name_or_path="florence-community/Florence-2-base-ft", ), ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "image", - type_hint=Union[Image.Image, List[Image.Image]], - required=True, - description="Image(s) to annotate", - ), - InputParam( - "annotation_task", - type_hint=Union[str, List[str]], - required=True, - default="", - description="""Annotation Task to perform on the image. - Supported Tasks: - - - - - - - - - - - """, - ), - InputParam( - "annotation_prompt", - type_hint=Union[str, List[str]], - required=True, - description="""Annotation Prompt to provide more context to the task. - Can be used to detect or segment out specific elements in the image - """, - ), - InputParam( - "annotation_output_type", - type_hint=str, - required=True, - default="mask_image", - description="""Output type from annotation predictions. Available options are - mask_image: - -black and white mask image for the given image based on the task type - mask_overlay: - - mask overlayed on the original image - bounding_box: - - bounding boxes drawn on the original image - """, - ), - InputParam( - "annotation_overlay", - type_hint=bool, - required=True, - default=False, - description="", - ), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "mask_image", - type_hint=Image, - description="Inpainting Mask for input Image(s)", - ), - OutputParam( - "annotations", - type_hint=dict, - description="Annotations Predictions for input Image(s)", - ), - OutputParam( - "image", - type_hint=Image, - description="Annotated input Image(s)", - ), - ] - ``` -Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask. +### Define inputs and outputs -```py +Inputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations. + +```python from typing import List, Union -from PIL import Image, ImageDraw -import torch -import numpy as np - -from diffusers.modular_pipelines import ( - PipelineState, - ModularPipelineBlocks, - InputParam, - ComponentSpec, - OutputParam, -) -from transformers import AutoProcessor, Florence2ForConditionalGeneration +from PIL import Image +from diffusers.modular_pipelines import InputParam, OutputParam class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): - @property - def expected_components(self): - return [ - ComponentSpec( - name="image_annotator", - type_hint=Florence2ForConditionalGeneration, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ComponentSpec( - name="image_annotator_processor", - type_hint=AutoProcessor, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ] + # ... expected_components from above ... @property def inputs(self) -> List[InputParam]: @@ -226,51 +129,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): ), InputParam( "annotation_task", - type_hint=Union[str, List[str]], - required=True, + type_hint=str, default="", - description="""Annotation Task to perform on the image. - Supported Tasks: - - - - - - - - - - - """, + description="Annotation task to perform (e.g., , , )", ), InputParam( "annotation_prompt", - type_hint=Union[str, List[str]], + type_hint=str, required=True, - description="""Annotation Prompt to provide more context to the task. - Can be used to detect or segment out specific elements in the image - """, + description="Prompt to provide context for the annotation task", ), InputParam( "annotation_output_type", type_hint=str, - required=True, default="mask_image", - description="""Output type from annotation predictions. Available options are - mask_image: - -black and white mask image for the given image based on the task type - mask_overlay: - - mask overlayed on the original image - bounding_box: - - bounding boxes drawn on the original image - """, - ), - InputParam( - "annotation_overlay", - type_hint=bool, - required=True, - default=False, - description="", + description="Output type: 'mask_image', 'mask_overlay', or 'bounding_box'", ), ] @@ -279,109 +152,45 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): return [ OutputParam( "mask_image", - type_hint=Image, - description="Inpainting Mask for input Image(s)", + type_hint=Image.Image, + description="Inpainting mask for the input image", ), OutputParam( "annotations", type_hint=dict, - description="Annotations Predictions for input Image(s)", + description="Raw annotation predictions", ), OutputParam( "image", - type_hint=Image, - description="Annotated input Image(s)", + type_hint=Image.Image, + description="Annotated image", ), ] +``` - def get_annotations(self, components, images, prompts, task): - task_prompts = [task + prompt for prompt in prompts] +### Implement the `__call__` method - inputs = components.image_annotator_processor( - text=task_prompts, images=images, return_tensors="pt" - ).to(components.image_annotator.device, components.image_annotator.dtype) +The `__call__` method contains the block's logic. Access inputs via `block_state`, run your computation, and set outputs back to `block_state`. - generated_ids = components.image_annotator.generate( - input_ids=inputs["input_ids"], - pixel_values=inputs["pixel_values"], - max_new_tokens=1024, - early_stopping=False, - do_sample=False, - num_beams=3, - ) - annotations = components.image_annotator_processor.batch_decode( - generated_ids, skip_special_tokens=False - ) - outputs = [] - for image, annotation in zip(images, annotations): - outputs.append( - components.image_annotator_processor.post_process_generation( - annotation, task=task, image_size=(image.width, image.height) - ) - ) - return outputs - - def prepare_mask(self, images, annotations, overlay=False, fill="white"): - masks = [] - for image, annotation in zip(images, annotations): - mask_image = image.copy() if overlay else Image.new("L", image.size, 0) - draw = ImageDraw.Draw(mask_image) - - for _, _annotation in annotation.items(): - if "polygons" in _annotation: - for polygon in _annotation["polygons"]: - polygon = np.array(polygon).reshape(-1, 2) - if len(polygon) < 3: - continue - polygon = polygon.reshape(-1).tolist() - draw.polygon(polygon, fill=fill) - - elif "bbox" in _annotation: - bbox = _annotation["bbox"] - draw.rectangle(bbox, fill="white") - - masks.append(mask_image) - - return masks - - def prepare_bounding_boxes(self, images, annotations): - outputs = [] - for image, annotation in zip(images, annotations): - image_copy = image.copy() - draw = ImageDraw.Draw(image_copy) - for _, _annotation in annotation.items(): - bbox = _annotation["bbox"] - label = _annotation["label"] - - draw.rectangle(bbox, outline="red", width=3) - draw.text((bbox[0], bbox[1] - 20), label, fill="red") - - outputs.append(image_copy) - - return outputs - - def prepare_inputs(self, images, prompts): - prompts = prompts or "" - - if isinstance(images, Image.Image): - images = [images] - if isinstance(prompts, str): - prompts = [prompts] - - if len(images) != len(prompts): - raise ValueError("Number of images and annotation prompts must match.") - - return images, prompts +```python +import torch +from diffusers.modular_pipelines import PipelineState + + +class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): + + # ... expected_components, inputs, intermediate_outputs from above ... @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + images, annotation_task_prompt = self.prepare_inputs( block_state.image, block_state.annotation_prompt ) task = block_state.annotation_task fill = block_state.fill - + annotations = self.get_annotations( components, images, annotation_task_prompt, task ) @@ -400,67 +209,69 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): self.set_block_state(state, block_state) return components, state - -``` - -Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines. - - - - -```shell -# In the folder with the `block.py` file, run: -diffusers-cli custom_block -``` - -Then upload the block to the Hub: - -```shell -hf upload . . -``` - - - -```py -from block import Florence2ImageAnnotatorBlock -block = Florence2ImageAnnotatorBlock() -block.push_to_hub("") + + # Helper methods for mask/bounding box generation... ``` - - +> [!TIP] +> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator). ## Using Custom Blocks -Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`. +Load a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`. ```py import torch -from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks -from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers import ModularPipeline from diffusers.utils import load_image -# Fetch the Florence2 image annotator block that will create our mask -image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True) +# Load the Florence-2 annotator pipeline +image_annotator = ModularPipeline.from_pretrained( + "diffusers/Florence2-image-Annotator", + trust_remote_code=True +) -my_blocks = INPAINT_BLOCKS.copy() -# insert the annotation block before the image encoding step -my_blocks.insert("image_annotator", image_annotator_block, 1) +# Check the docstring to see inputs/outputs +print(image_annotator.blocks.doc) +``` -# Create our initial set of inpainting blocks -blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks) +Use the block to generate a mask: -repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0" -pipe = blocks.init_pipeline(repo_id) -pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True) +```python +image_annotator.load_components(torch_dtype=torch.bfloat16) +image_annotator.to("cuda") -image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true") +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg") image = image.resize((1024, 1024)) - prompt = ["A red car"] annotation_task = "" annotation_prompt = ["the car"] +mask_image = image_annotator_node( + prompt=prompt, + image=image, + annotation_task=annotation_task, + annotation_prompt=annotation_prompt, + annotation_output_type="mask_image", +).images +mask_image[0].save("car-mask.png") +``` + +Compose it with other blocks to create a new pipeline: + +```python +# Get the annotator block +annotator_block = image_annotator.blocks + +# Get an inpainting workflow and insert the annotator at the beginning +inpaint_blocks = ModularPipeline.from_pretrained("Qwen/Qwen-Image").blocks.get_workflow("inpainting") +inpaint_blocks.sub_blocks.insert("image_annotator", annotator_block, 0) + +# Initialize the combined pipeline +pipe = inpaint_blocks.init_pipeline() +pipe.load_components(torch_dtype=torch.float16, device="cuda") + +# Now the pipeline automatically generates masks from prompts output = pipe( prompt=prompt, image=image, @@ -475,18 +286,50 @@ output = pipe( output[0].save("florence-inpainting.png") ``` -## Editing Custom Blocks +## Editing custom blocks -By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder. +Edit custom blocks by downloading it locally. This is the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template. -```py -import torch -from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks -from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS -from diffusers.utils import load_image +Use the `local_dir` argument to download a custom block to a specific folder: + +```python +from diffusers import ModularPipelineBlocks -# Fetch the Florence2 image annotator block that will create our mask -image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder") +# Download to a local folder for editing +annotator_block = ModularPipelineBlocks.from_pretrained( + "diffusers/Florence2-image-Annotator", + trust_remote_code=True, + local_dir="./my-florence-block" +) ``` -Any changes made to the block files in this folder will be reflected when you load the block again. +Any changes made to the block files in this folder will be reflected when you load the block again. When you're ready to share your changes, upload to a new repository: + +```python +pipeline = annotator_block.init_pipeline() +pipeline.save_pretrained("./my-florence-block", repo_id="your-username/my-custom-florence", push_to_hub=True) +``` + +## Next Steps + + + + +This guide covered creating a single custom block. Learn how to compose multiple blocks together: + +- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to execute in sequence +- [ConditionalPipelineBlocks](./auto_pipeline_blocks): Create conditional blocks that select different execution paths +- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks): Define an iterative workflows like the denoising loop + + + + +Make your custom block work with Mellon's visual interface. See the [Mellon Custom Blocks](./mellon) guide. + + + + +Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks. + + + \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/mellon.md b/docs/source/en/modular_diffusers/mellon.md new file mode 100644 index 000000000000..808e62ad7966 --- /dev/null +++ b/docs/source/en/modular_diffusers/mellon.md @@ -0,0 +1,270 @@ + + + +## Using Custom Blocks with Mellon + +[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows. + +> [!WARNING] +> Mellon is in early development and not ready for production use yet. Consider this a sneak peek of how the integration works! + + +Custom blocks work in Mellon out of the box - just need to add a `mellon_pipeline_config.json` to your repository. This config file tells Mellon how to render your block's parameters as UI components. + +Here's what it looks like in action with the [Gemini Prompt Expander](https://huggingface.co/diffusers/gemini-prompt-expander-mellon) block: + +![Mellon custom block demo](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/modular_demo_dynamic.gif) + +To use a modular diffusers custom block in Mellon: +1. Drag a **Dynamic Block Node** from the ModularDiffusers section +2. Enter the `repo_id` (e.g., `diffusers/gemini-prompt-expander-mellon`) +3. Click **Load Custom Block** +4. The node transforms to show your block's inputs and outputs + +Now let's walk through how to create this config for your own custom block. + +## Steps to create a Mellon config + +1. **Specify Mellon types for your parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `"textbox"`, `"dropdown"`, `"image"`). +2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a config template and push it to your Hub repository. +3. **(Optional) Manually adjust the config** - Fine-tune the generated config for your specific needs. + +## Specify Mellon types for parameters + +Mellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `"custom"`, which renders as a simple connection dot. You can always adjust this later in the generated config. + + +| Type | Input/Output | Description | +|------|--------------|-------------| +| `image` | Both | Image (PIL Image) | +| `video` | Both | Video | +| `text` | Both | Text display | +| `textbox` | Input | Text input | +| `dropdown` | Input | Dropdown selection menu | +| `slider` | Input | Slider for numeric values | +| `number` | Input | Numeric input | +| `checkbox` | Input | Boolean toggle | + +For parameters that need more configuration (like dropdowns with options, or sliders with min/max values), pass a `MellonParam` instance directly instead of a string. You can use one of the class methods below, or create a fully custom one with `MellonParam(name, label, type, ...)`. + +| Method | Description | +|--------|-------------| +| `MellonParam.Input.image(name)` | Image input | +| `MellonParam.Input.textbox(name, default)` | Text input as textarea | +| `MellonParam.Input.dropdown(name, options, default)` | Dropdown selection | +| `MellonParam.Input.slider(name, default, min, max, step)` | Slider for numeric values | +| `MellonParam.Input.number(name, default, min, max, step)` | Numeric input (no slider) | +| `MellonParam.Input.seed(name, default)` | Seed input with randomize button | +| `MellonParam.Input.checkbox(name, default)` | Boolean checkbox | +| `MellonParam.Input.model(name)` | Model input for diffusers components | +| `MellonParam.Output.image(name)` | Image output | +| `MellonParam.Output.video(name)` | Video output | +| `MellonParam.Output.text(name)` | Text output | +| `MellonParam.Output.model(name)` | Model output for diffusers components | + +Choose one of the methods below to specify a Mellon type. + +### Using `metadata` in block definitions + +If you're defining a custom block from scratch, add `metadata={"mellon": ""}` directly to your `InputParam` and `OutputParam` definitions. If you're editing an existing custom block from the Hub, see [Editing custom blocks](./custom_blocks#editing-custom-blocks) for how to download it locally. + +```python +class GeminiPromptExpander(ModularPipelineBlocks): + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "prompt", + type_hint=str, + required=True, + description="Prompt to use", + metadata={"mellon": "textbox"}, # Text input + ) + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt", + type_hint=str, + description="Expanded prompt by the LLM", + metadata={"mellon": "text"}, # Text output + ), + OutputParam( + "old_prompt", + type_hint=str, + description="Old prompt provided by the user", + # No metadata - we don't want to render this in UI + ) + ] +``` + +For full control over UI configuration, pass a `MellonParam` instance directly: +```python +from diffusers.modular_pipelines.mellon_node_utils import MellonParam + +InputParam( + "mode", + type_hint=str, + default="balanced", + metadata={"mellon": MellonParam.Input.dropdown("mode", options=["fast", "balanced", "quality"])}, +) +``` + +### Using `input_types` and `output_types` when Generating Config + +If you're working with an existing pipeline or prefer to keep your block definitions clean, specify types when generating the config using the `input_types/output_types` argument: +```python +from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig + +mellon_config = MellonPipelineConfig.from_custom_block( + blocks, + input_types={"prompt": "textbox"}, + output_types={"prompt": "text"} +) +``` + +> [!NOTE] +> When both `metadata` and `input_types`/`output_types` are specified, the arguments overrides `metadata`. + +## Generate and push the Mellon config + +After adding metadata to your block, generate the default Mellon configuration template and push it to the Hub: + +```python +from diffusers import ModularPipelineBlocks +from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig + +# load your custom blocks from your local dir +blocks = ModularPipelineBlocks.from_pretrained("/path/local/folder", trust_remote_code=True) + +# Generate the default config template +mellon_config = MellonPipelineConfig.from_custom_block(blocks) +# push the default template to `repo_id`, you will need to pass the same local folder path so that it will save the config locally first +mellon_config.save( + local_dir="/path/local/folder", + repo_id= repo_id, + push_to_hub=True +) +``` + +This creates a `mellon_pipeline_config.json` file in your repository. + +## Review and adjust the config + +The generated template is a starting point - you may want to adjust it for your needs. Let's walk through the generated config for the Gemini Prompt Expander: + +```json +{ + "label": "Gemini Prompt Expander", + "default_repo": "", + "default_dtype": "", + "node_params": { + "custom": { + "params": { + "prompt": { + "label": "Prompt", + "type": "string", + "display": "textarea", + "default": "" + }, + "out_prompt": { + "label": "Prompt", + "type": "string", + "display": "output" + }, + "old_prompt": { + "label": "Old Prompt", + "type": "custom", + "display": "output" + }, + "doc": { + "label": "Doc", + "type": "string", + "display": "output" + } + }, + "input_names": ["prompt"], + "model_input_names": [], + "output_names": ["out_prompt", "old_prompt", "doc"], + "block_name": "custom", + "node_type": "custom" + } + } +} +``` + +### Understanding the Structure + +The `params` dict defines how each UI element renders. The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface: + +| Mellon Config | ModularPipelineBlocks | +|---------------|----------------------| +| `input_names` | `inputs` property | +| `model_input_names` | `expected_components` property | +| `output_names` | `intermediate_outputs` property | + +In this example: `prompt` is the only input. There are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`. + +Now let's look at the `params` dict: + +- **`prompt`**: An input parameter with `display: "textarea"` (renders as a text input box), `label: "Prompt"` (shown in the UI), and `default: ""` (starts empty). The `type: "string"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with "noodles". + +- **`out_prompt`**: The expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: "output"` which renders as an output socket. + +- **`old_prompt`**: Has `type: "custom"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it. + +- **`doc`**: The documentation output, automatically added to all custom blocks. + +### Making Adjustments + +Remove `old_prompt` from both `params` and `output_names` because you won't need to use it. + +```json +{ + "label": "Gemini Prompt Expander", + "default_repo": "", + "default_dtype": "", + "node_params": { + "custom": { + "params": { + "prompt": { + "label": "Prompt", + "type": "string", + "display": "textarea", + "default": "" + }, + "out_prompt": { + "label": "Prompt", + "type": "string", + "display": "output" + }, + "doc": { + "label": "Doc", + "type": "string", + "display": "output" + } + }, + "input_names": ["prompt"], + "model_input_names": [], + "output_names": ["out_prompt", "doc"], + "block_name": "custom", + "node_type": "custom" + } + } +} +``` + +See the final config at [diffusers/gemini-prompt-expander-mellon](https://huggingface.co/diffusers/gemini-prompt-expander-mellon). \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 8e27cad6eb91..83975200d664 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -33,9 +33,14 @@ The Modular Diffusers docs are organized as shown below. - [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together. - [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`]. - [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`]. +- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub. ## ModularPipeline - [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`]. - [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines. -- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline. \ No newline at end of file +- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline. + +## Mellon Integration + +- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. \ No newline at end of file diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md index 3854ecd469f8..4eccd70cb304 100644 --- a/docs/source/en/optimization/cache.md +++ b/docs/source/en/optimization/cache.md @@ -111,3 +111,57 @@ config = TaylorSeerCacheConfig( ) pipe.transformer.enable_cache(config) ``` + +## MagCache + +[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual. + +MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler. + +### Usage + +To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**. + +1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console. +2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration. + +```python +import torch +from diffusers import FluxPipeline, MagCacheConfig + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16 +).to("cuda") + +# 1. Calibration Step +# Run full inference to measure model behavior. +calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4) +pipe.transformer.enable_cache(calib_config) + +# Run a prompt to trigger calibration +pipe("A cat playing chess", num_inference_steps=4) +# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]" + +# 2. Inference Step +# Apply the specific ratios obtained from calibration for optimized speed. +# Note: For Flux models, you can also import defaults: +# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS +mag_config = MagCacheConfig( + mag_ratios=[1.0, 1.37, 0.97, 0.87], + num_inference_steps=4 +) + +pipe.transformer.enable_cache(mag_config) + +image = pipe("A cat playing chess", num_inference_steps=4).images[0] +``` + +> [!NOTE] +> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps. + +> [!TIP] +> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional). + +> [!TIP] +> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification. diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index aa415dbc36e2..de90c3006e8f 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf from torchao.quantization import Int4WeightOnlyConfig pipeline_quant_config = PipelineQuantizationConfig( - quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))} + quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))} ) pipeline = DiffusionPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", diff --git a/examples/dreambooth/README_z_image.md b/examples/dreambooth/README_z_image.md new file mode 100644 index 000000000000..cded38f3f11f --- /dev/null +++ b/examples/dreambooth/README_z_image.md @@ -0,0 +1,347 @@ +# DreamBooth training example for Z-Image + +[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept. +[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. + +The `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image). + +> [!NOTE] +> **About Z-Image** +> +> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters. + +> [!NOTE] +> **Memory consumption** +> +> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training. + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run +```bash +pip install -r requirements_z_image.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. + +## Memory Optimizations + +> [!NOTE] +> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run. + +### CPU Offloading +To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed. + +### Latent Caching +Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`. + +### QLoRA: Low Precision Training with Quantization +Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags: + +- **FP8 training** with `torchao`: +Enable FP8 training by passing `--do_fp8_training`. +> [!IMPORTANT] +> Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers. + +- **NF4 training** with `bitsandbytes`: +Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization. + +### Gradient Checkpointing and Accumulation +* `--gradient_accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the amount of backward/update passes and hence also memory requirements. +* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass. + +### 8-bit-Adam Optimizer +When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. + +### Image Resolution +An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. +Note that by default, images are resized to resolution of 1024, but it's good to keep in mind in case you're training on higher resolutions. + +### Precision of saved LoRA layers +By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well. +This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`. + +## Training Examples + +### Z-Image Training + +To perform DreamBooth with LoRA on Z-Image, run: + +```bash +export MODEL_NAME="Tongyi-MAI/Z-Image" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-z-image-lora" + +accelerate launch train_dreambooth_lora_z_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=5.0 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +> [!NOTE] +> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases. + +### Training with FP8 Quantization + +For reduced memory usage with FP8 training: + +```bash +export MODEL_NAME="Tongyi-MAI/Z-Image" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-z-image-lora-fp8" + +accelerate launch train_dreambooth_lora_z_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --do_fp8_training \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=5.0 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +### FSDP on the transformer + +By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to: + +```yaml +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_sharding_strategy: HYBRID_SHARD + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock + fsdp_forward_prefetch: true + fsdp_sync_module_states: false + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_use_orig_params: false + fsdp_activation_checkpointing: true + fsdp_reshard_after_forward: true + fsdp_cpu_ram_efficient_loading: false +``` + +### Prodigy Optimizer + +Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. +By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). + +To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify: +```bash +--optimizer="prodigy" +``` + +> [!TIP] +> When using prodigy it's generally good practice to set `--learning_rate=1.0` + +```bash +export MODEL_NAME="Tongyi-MAI/Z-Image" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-z-image-lora-prodigy" + +accelerate launch train_dreambooth_lora_z_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=5.0 \ + --gradient_accumulation_steps=4 \ + --optimizer="prodigy" \ + --learning_rate=1.0 \ + --report_to="wandb" \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +### LoRA Rank and Alpha + +Two key LoRA hyperparameters are LoRA rank and LoRA alpha: + +- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters). +- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`. + +**lora_alpha vs. rank:** + +This ratio dictates the LoRA's effective strength: +- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16) +- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16) +- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16) + +> [!TIP] +> A common starting point is to set `lora_alpha` equal to `rank`. +> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) +> to give the LoRA updates more influence without increasing parameter count. +> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank` +> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case. + +### Target Modules + +When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them. +More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training onto different types of layers and blocks. + +To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide: + +- For attention only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"` +- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"` + +> [!NOTE] +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string. + +> [!NOTE] +> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. + +### Aspect Ratio Bucketing + +We've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency. + +To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as: + +```bash +--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672" +``` + +### Bilingual Prompts + +Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8: + +```bash +--instance_prompt="一只sks狗的照片" +--validation_prompt="一只sks狗在桶里的照片" +``` + +> [!TIP] +> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data. + +## Inference + +Once you have trained a LoRA, you can load it for inference: + +```python +import torch +from diffusers import ZImagePipeline + +pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load your trained LoRA +pipe.load_lora_weights("path/to/your/trained-z-image-lora") + +# Generate an image +image = pipe( + prompt="A photo of sks dog in a bucket", + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] + +image.save("output.png") +``` + +--- + +Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗 \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py new file mode 100644 index 000000000000..c77953f16410 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -0,0 +1,1912 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Qwen2Tokenizer, Qwen3Model + +import diffusers +from diffusers import ( + AutoencoderKL, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + ZImagePipeline, + ZImageTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.37.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + quant_training=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Z Image DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Z Image diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_z_image.md). + +Quant training? {quant_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Apace License 2.0 +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="apache-2.0", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "z-image", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=args.validation_prompt, + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="z-image-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image = self.train_transform( + image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + + return image + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + pipeline = ZImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = Qwen2Tokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = ZImageTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder = Qwen3Model.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = ZImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + ZImagePipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = ZImageTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = ZImagePipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, _ = text_encoding_pipeline.encode_prompt( + prompt=prompt, + max_sequence_length=args.max_sequence_length, + ) + return prompt_embeds + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_prompt_hidden_states = compute_text_embeddings(args.class_prompt, text_encoding_pipeline) + validation_embeddings = {} + if args.validation_prompt is not None: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + validation_embeddings["prompt_embeds"] = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.fsdp_text_encoder: + prompt_embeds = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-z-image-lora" + args_cp = vars(args).copy() + accelerator.init_trackers(tracker_name, config=args_cp) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_repeat_elements)] + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + timestep_normalized = (1000 - timesteps) / 1000 + + noisy_model_input_5d = noisy_model_input.unsqueeze(2) # (B, C, H, W) -> (B, C, 1, H, W) + noisy_model_input_list = list(noisy_model_input_5d.unbind(dim=0)) # List of (C, 1, H, W) + + model_pred_list = transformer( + noisy_model_input_list, + timestep_normalized, + prompt_embeds, # This is a List[torch.Tensor] for Z-Image + return_dict=False, + )[0] + model_pred = torch.stack(model_pred_list, dim=0) # (B, C, 1, H, W) + model_pred = model_pred.squeeze(2) # (B, C, H, W) + model_pred = -model_pred # z-Image negates the prediction + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = ZImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + modules_to_save["transformer"] = transformer + + ZImagePipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = ZImagePipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + images = None + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + quant_training = None + if args.do_fp8_training: + quant_training = "FP8 TorchAO" + elif args.bnb_quantization_config_path: + quant_training = "BitsandBytes" + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + quant_training=quant_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index bc6014068e87..ae66c9b8197c 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -78,12 +78,67 @@ --save_pipeline ``` +# Cosmos 2.5 Transfer + +Download checkpoint +```bash +hf download nvidia/Cosmos-Transfer2.5-2B +``` + +Convert checkpoint +```bash +# depth +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/depth/626e6618-bfcd-4d9a-a077-1409e2ce353f_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/depth \ + --save_pipeline + +# edge +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/edge/pipeline \ + --save_pipeline + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/edge/models + +# blur +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/blur/ba2f44f2-c726-4fe7-949f-597069d9b91c_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/blur \ + --save_pipeline + +# seg +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/seg \ + --save_pipeline +``` """ import argparse import pathlib import sys -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from accelerate import init_empty_weights @@ -95,6 +150,7 @@ AutoencoderKLWan, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, + CosmosControlNetModel, CosmosTextToWorldPipeline, CosmosTransformer3DModel, CosmosVideoToWorldPipeline, @@ -103,6 +159,7 @@ UniPCMultistepScheduler, ) from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline +from diffusers.pipelines.cosmos.pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -356,8 +413,62 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]): "crossattn_proj_in_channels": 100352, "encoder_hidden_states_channels": 1024, }, + "Cosmos-2.5-Transfer-General-2B": { + "in_channels": 16 + 1, + "out_channels": 16, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 28, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "max_size": (128, 240, 240), + "patch_size": (1, 2, 2), + "rope_scale": (1.0, 3.0, 3.0), + "concat_padding_mask": True, + "extra_pos_embed_type": None, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + "controlnet_block_every_n": 7, + "img_context_dim_in": 1152, + "img_context_dim_out": 2048, + "img_context_num_tokens": 256, + }, +} + +CONTROLNET_CONFIGS = { + "Cosmos-2.5-Transfer-General-2B": { + "n_controlnet_blocks": 4, + "model_channels": 2048, + "in_channels": 130, + "latent_channels": 18, # (16 latent + 1 condition_mask) + 1 padding_mask = 18 + "num_attention_heads": 16, + "attention_head_dim": 128, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "patch_size": (1, 2, 2), + "max_size": (128, 240, 240), + "rope_scale": (1.0, 3.0, 3.0), + "extra_pos_embed_type": None, + "img_context_dim_in": 1152, + "img_context_dim_out": 2048, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + }, } +CONTROLNET_KEYS_RENAME_DICT = { + **TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0, + "blocks": "blocks", + "control_embedder.proj.1": "patch_embed.proj", +} + + +CONTROLNET_SPECIAL_KEYS_REMAP = {**TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0} + VAE_KEYS_RENAME_DICT = { "down.0": "down_blocks.0", "down.1": "down_blocks.1", @@ -447,9 +558,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: return state_dict -def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True): +def convert_transformer( + transformer_type: str, + state_dict: Optional[Dict[str, Any]] = None, + weights_only: bool = True, +): PREFIX_KEY = "net." - original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only)) if "Cosmos-1.0" in transformer_type: TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 @@ -467,23 +581,29 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo config = TRANSFORMER_CONFIGS[transformer_type] transformer = CosmosTransformer3DModel(**config) - for key in list(original_state_dict.keys()): + old2new = {} + new2old = {} + for key in list(state_dict.keys()): new_key = key[:] if new_key.startswith(PREFIX_KEY): new_key = new_key.removeprefix(PREFIX_KEY) for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) print(key, "->", new_key, flush=True) - update_state_dict_(original_state_dict, key, new_key) + assert new_key not in new2old, f"new key {new_key} already mapped" + assert key not in old2new, f"old key {key} already mapped" + old2new[key] = new_key + new2old[new_key] = key + update_state_dict_(state_dict, key, new_key) - for key in list(original_state_dict.keys()): + for key in list(state_dict.keys()): for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): if special_key not in key: continue - handler_fn_inplace(key, original_state_dict) + handler_fn_inplace(key, state_dict) expected_keys = set(transformer.state_dict().keys()) - mapped_keys = set(original_state_dict.keys()) + mapped_keys = set(state_dict.keys()) missing_keys = expected_keys - mapped_keys unexpected_keys = mapped_keys - expected_keys if missing_keys: @@ -497,10 +617,86 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo print(k) sys.exit(2) - transformer.load_state_dict(original_state_dict, strict=True, assign=True) + transformer.load_state_dict(state_dict, strict=True, assign=True) return transformer +def convert_controlnet( + transformer_type: str, + control_state_dict: Dict[str, Any], + base_state_dict: Dict[str, Any], + weights_only: bool = True, +): + """ + Convert controlnet weights. + + Args: + transformer_type: The type of transformer/controlnet + control_state_dict: State dict containing controlnet-specific weights + base_state_dict: State dict containing base transformer weights (for shared modules) + weights_only: Whether to use weights_only loading + """ + if transformer_type not in CONTROLNET_CONFIGS: + raise AssertionError(f"{transformer_type} does not define a ControlNet config") + + PREFIX_KEY = "net." + + # Process control-specific keys + for key in list(control_state_dict.keys()): + new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = new_key.removeprefix(PREFIX_KEY) + for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(control_state_dict, key, new_key) + + for key in list(control_state_dict.keys()): + for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, control_state_dict) + + # Copy shared weights from base transformer to controlnet + # These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj + shared_module_mappings = { + # transformer key prefix -> controlnet key prefix + "patch_embed.": "patch_embed_base.", + "time_embed.": "time_embed.", + "learnable_pos_embed.": "learnable_pos_embed.", + "img_context_proj.": "img_context_proj.", + "crossattn_proj.": "crossattn_proj.", + } + + for key in list(base_state_dict.keys()): + for transformer_prefix, controlnet_prefix in shared_module_mappings.items(): + if key.startswith(transformer_prefix): + controlnet_key = controlnet_prefix + key[len(transformer_prefix) :] + control_state_dict[controlnet_key] = base_state_dict[key].clone() + print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True) + break + + cfg = CONTROLNET_CONFIGS[transformer_type] + controlnet = CosmosControlNetModel(**cfg) + + expected_keys = set(controlnet.state_dict().keys()) + mapped_keys = set(control_state_dict.keys()) + missing_keys = expected_keys - mapped_keys + unexpected_keys = mapped_keys - expected_keys + if missing_keys: + print(f"WARNING: missing controlnet keys ({len(missing_keys)}):", file=sys.stderr, flush=True) + for k in sorted(missing_keys): + print(k, file=sys.stderr) + sys.exit(3) + if unexpected_keys: + print(f"WARNING: unexpected controlnet keys ({len(unexpected_keys)}):", file=sys.stderr, flush=True) + for k in sorted(unexpected_keys): + print(k, file=sys.stderr) + sys.exit(4) + + controlnet.load_state_dict(control_state_dict, strict=True, assign=True) + return controlnet + + def convert_vae(vae_type: str): model_name = VAE_CONFIGS[vae_type]["name"] snapshot_directory = snapshot_download(model_name, repo_type="model") @@ -586,7 +782,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") -def save_pipeline_cosmos2_5(args, transformer, vae): +def save_pipeline_cosmos2_5_predict(args, transformer, vae): text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B" tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct" @@ -614,6 +810,35 @@ def save_pipeline_cosmos2_5(args, transformer, vae): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") +def save_pipeline_cosmos2_5_transfer(args, transformer, controlnet, vae): + text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B" + tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct" + + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + text_encoder_path, torch_dtype="auto", device_map="cpu" + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + scheduler = UniPCMultistepScheduler( + use_karras_sigmas=True, + use_flow_sigmas=True, + prediction_type="flow_prediction", + sigma_max=200.0, + sigma_min=0.01, + ) + + pipe = Cosmos2_5_TransferPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + controlnet=controlnet, + vae=vae, + scheduler=scheduler, + safety_checker=lambda *args, **kwargs: None, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys())) @@ -642,18 +867,61 @@ def get_args(): args = get_args() transformer = None + controlnet = None dtype = DTYPE_MAPPING[args.dtype] if args.save_pipeline: assert args.transformer_ckpt_path is not None assert args.vae_type is not None + raw_state_dict = None if args.transformer_ckpt_path is not None: weights_only = "Cosmos-1.0" in args.transformer_type - transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only) - transformer = transformer.to(dtype=dtype) - if not args.save_pipeline: - transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + raw_state_dict = get_state_dict( + torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only) + ) + + if raw_state_dict is not None: + if "Transfer" in args.transformer_type: + base_state_dict = {} + control_state_dict = {} + for k, v in raw_state_dict.items(): + plain_key = k.removeprefix("net.") if k.startswith("net.") else k + if "control" in plain_key.lower(): + control_state_dict[k] = v + else: + base_state_dict[k] = v + assert len(base_state_dict.keys() & control_state_dict.keys()) == 0 + + # Convert transformer first to get the processed base state dict + transformer = convert_transformer( + args.transformer_type, state_dict=base_state_dict, weights_only=weights_only + ) + transformer = transformer.to(dtype=dtype) + + # Get converted transformer state dict to copy shared weights to controlnet + converted_base_state_dict = transformer.state_dict() + + # Convert controlnet with both control-specific and shared weights from transformer + controlnet = convert_controlnet( + args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only + ) + controlnet = controlnet.to(dtype=dtype) + + if not args.save_pipeline: + transformer.save_pretrained( + pathlib.Path(args.output_path) / "transformer", safe_serialization=True, max_shard_size="5GB" + ) + controlnet.save_pretrained( + pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB" + ) + else: + transformer = convert_transformer( + args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only + ) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.vae_type is not None: if "Cosmos-1.0" in args.transformer_type: @@ -667,6 +935,8 @@ def get_args(): if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + else: + vae = None if args.save_pipeline: if "Cosmos-1.0" in args.transformer_type: @@ -678,6 +948,11 @@ def get_args(): assert args.tokenizer_path is not None save_pipeline_cosmos_2_0(args, transformer, vae) elif "Cosmos-2.5" in args.transformer_type: - save_pipeline_cosmos2_5(args, transformer, vae) + if "Predict" in args.transformer_type: + save_pipeline_cosmos2_5_predict(args, transformer, vae) + elif "Transfer" in args.transformer_type: + save_pipeline_cosmos2_5_transfer(args, transformer, None, vae) + else: + raise AssertionError(f"{args.transformer_type} not supported") else: raise AssertionError(f"{args.transformer_type} not supported") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 52ec30c536bd..61ccfd85c192 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -168,12 +168,14 @@ "FirstBlockCacheConfig", "HookRegistry", "LayerSkipConfig", + "MagCacheConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "TaylorSeerCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", + "apply_mag_cache", "apply_pyramid_attention_broadcast", "apply_taylorseer_cache", ] @@ -219,6 +221,7 @@ "ControlNetModel", "ControlNetUnionModel", "ControlNetXSAdapter", + "CosmosControlNetModel", "CosmosTransformer3DModel", "DiTTransformer2DModel", "EasyAnimateTransformer3DModel", @@ -415,6 +418,7 @@ "Flux2AutoBlocks", "Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks", + "Flux2KleinBaseModularPipeline", "Flux2KleinModularPipeline", "Flux2ModularPipeline", "FluxAutoBlocks", @@ -431,8 +435,13 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", - "Wan22AutoBlocks", - "WanAutoBlocks", + "Wan22Blocks", + "Wan22Image2VideoBlocks", + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanBlocks", + "WanImage2VideoAutoBlocks", + "WanImage2VideoModularPipeline", "WanModularPipeline", "ZImageAutoBlocks", "ZImageModularPipeline", @@ -477,6 +486,7 @@ "CogView4Pipeline", "ConsisIDPipeline", "Cosmos2_5_PredictBasePipeline", + "Cosmos2_5_TransferPipeline", "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "CosmosTextToWorldPipeline", @@ -694,6 +704,7 @@ "ZImageControlNetInpaintPipeline", "ZImageControlNetPipeline", "ZImageImg2ImgPipeline", + "ZImageInpaintPipeline", "ZImageOmniPipeline", "ZImagePipeline", ] @@ -932,12 +943,14 @@ FirstBlockCacheConfig, HookRegistry, LayerSkipConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, + apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, ) @@ -981,6 +994,7 @@ ControlNetModel, ControlNetUnionModel, ControlNetXSAdapter, + CosmosControlNetModel, CosmosTransformer3DModel, DiTTransformer2DModel, EasyAnimateTransformer3DModel, @@ -1151,6 +1165,7 @@ Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, + Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline, FluxAutoBlocks, @@ -1167,8 +1182,13 @@ QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, - Wan22AutoBlocks, - WanAutoBlocks, + Wan22Blocks, + Wan22Image2VideoBlocks, + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanBlocks, + WanImage2VideoAutoBlocks, + WanImage2VideoModularPipeline, WanModularPipeline, ZImageAutoBlocks, ZImageModularPipeline, @@ -1209,6 +1229,7 @@ CogView4Pipeline, ConsisIDPipeline, Cosmos2_5_PredictBasePipeline, + Cosmos2_5_TransferPipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, @@ -1424,6 +1445,7 @@ ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, ) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index eb12b8a52a1e..23c8bc92b2f1 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,6 +23,7 @@ from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .mag_cache import MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index ca7934e5c313..f5dd1f8c7c4d 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -23,7 +23,13 @@ _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) -_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ( + "blocks", + "transformer_blocks", + "single_transformer_blocks", + "layers", + "visual_transformer_blocks", +) _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index da7313cb4737..1cbc3a35d5b9 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -26,6 +26,7 @@ class AttentionProcessorMetadata: class TransformerBlockMetadata: return_hidden_states_index: int = None return_encoder_hidden_states_index: int = None + hidden_states_argument_name: str = "hidden_states" _cls: Type = None _cached_parameter_indices: Dict[str, int] = None @@ -169,7 +170,7 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): - from ..models.attention import BasicTransformerBlock + from ..models.attention import BasicTransformerBlock, JointTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock @@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata(): HunyuanImageSingleTransformerBlock, HunyuanImageTransformerBlock, ) + from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock from ..models.transformers.transformer_mochi import MochiTransformerBlock from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock @@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata(): ), ) + TransformerBlockRegistry.register( + model_class=JointTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock) + TransformerBlockRegistry.register( + model_class=Kandinsky5TransformerDecoderBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + hidden_states_argument_name="visual_embed", + ), + ) + # fmt: off def _skip_attention___ret___hidden_states(self, *args, **kwargs): diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py new file mode 100644 index 000000000000..d28cd2d793b6 --- /dev/null +++ b/src/diffusers/hooks/mag_cache.py @@ -0,0 +1,468 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" +_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" + +# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience. +# Users must explicitly pass these to the config if using Flux. +# Reference: https://github.com/Zehong-Ma/MagCache +FLUX_MAG_RATIOS = torch.tensor( + [1.0] + + [ + 1.21094, + 1.11719, + 1.07812, + 1.0625, + 1.03906, + 1.03125, + 1.03906, + 1.02344, + 1.03125, + 1.02344, + 0.98047, + 1.01562, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.0, + 0.99609, + 0.99609, + 0.98047, + 0.98828, + 0.96484, + 0.95703, + 0.93359, + 0.89062, + ] +) + + +def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """ + Interpolate the source array to the target length using nearest neighbor interpolation. + """ + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + return src_array[mapped_indices] + + +@dataclass +class MagCacheConfig: + r""" + Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache). + + Args: + threshold (`float`, defaults to `0.06`): + The threshold for the accumulated error. If the accumulated error is below this threshold, the block + computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade + quality. + max_skip_steps (`int`, defaults to `3`): + The maximum number of consecutive steps that can be skipped (K in the paper). + retention_ratio (`float`, defaults to `0.2`): + The fraction of initial steps during which skipping is disabled to ensure stability. For example, if + `num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped. + num_inference_steps (`int`, defaults to `28`): + The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly. + mag_ratios (`torch.Tensor`, *optional*): + The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must + set `calibrate=True` to calculate them for your specific model. For Flux models, you can use + `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`. + calibrate (`bool`, defaults to `False`): + If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the + magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new + models or schedulers. + """ + + threshold: float = 0.06 + max_skip_steps: int = 3 + retention_ratio: float = 0.2 + num_inference_steps: int = 28 + mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None + calibrate: bool = False + + def __post_init__(self): + # User MUST provide ratios OR enable calibration. + if self.mag_ratios is None and not self.calibrate: + raise ValueError( + " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n" + "To get them for your model:\n" + "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n" + "2. Run inference on your model once.\n" + "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n" + "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`." + ) + + if not self.calibrate and self.mag_ratios is not None: + if not torch.is_tensor(self.mag_ratios): + self.mag_ratios = torch.tensor(self.mag_ratios) + + if len(self.mag_ratios) != self.num_inference_steps: + logger.debug( + f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + ) + self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) + + +class MagCacheState(BaseState): + def __init__(self) -> None: + super().__init__() + # Cache for the residual (output - input) from the *previous* timestep + self.previous_residual: torch.Tensor = None + + # State inputs/outputs for the current forward pass + self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + # MagCache accumulators + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + + # Current step counter (timestep index) + self.step_index: int = 0 + + # Calibration storage + self.calibration_ratios: List[float] = [] + + def reset(self): + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + + +class MagCacheHeadHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, config: MagCacheConfig): + self.state_manager = state_manager + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + + arg_name = self._metadata.hidden_states_argument_name + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + state: MagCacheState = self.state_manager.get_state() + state.head_block_input = hidden_states + + should_compute = True + + if self.config.calibrate: + # Never skip during calibration + should_compute = True + else: + # MagCache Logic + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + current_scale = 1.0 + else: + current_scale = self.config.mag_ratios[current_step] + + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + + state.should_compute = should_compute + + if not should_compute: + logger.debug(f"MagCache: Skipping step {state.step_index}") + # Apply MagCache: Output = Input + Previous Residual + + output = hidden_states + res = state.previous_residual + + if res.device != output.device: + res = res.to(output.device) + + # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only) + if res.shape == output.shape: + output = output + res + elif ( + output.ndim == 3 + and res.ndim == 3 + and output.shape[0] == res.shape[0] + and output.shape[2] == res.shape[2] + ): + # Assuming concatenation where image part is at the end (standard in Flux/SD3) + diff = output.shape[1] - res.shape[1] + if diff > 0: + output = output.clone() + output[:, diff:, :] = output[:, diff:, :] + res + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = output + ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return tuple(ret_list) + else: + return output + + else: + # Compute original forward + output = self.fn_ref.original_forward(*args, **kwargs) + return output + + def reset_state(self, module): + self.state_manager.reset() + return module + + +class MagCacheBlockHook(ModelHook): + def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + state: MagCacheState = self.state_manager.get_state() + + if not state.should_compute: + arg_name = self._metadata.hidden_states_argument_name + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + if self.is_tail: + # Still need to advance step index even if we skip + self._advance_step(state) + + if self._metadata.return_encoder_hidden_states_index is not None: + encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = hidden_states + ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return tuple(ret_list) + + return hidden_states + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_tail: + # Calculate residual for next steps + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + if in_hidden is None: + return output + + # Determine residual + if out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: + diff = in_hidden.shape[1] - out_hidden.shape[1] + if diff == 0: + residual = out_hidden - in_hidden + else: + residual = out_hidden - in_hidden # Fallback to matching tail + else: + # Fallback for completely mismatched shapes + residual = out_hidden + + if self.config.calibrate: + self._perform_calibration_step(state, residual) + + state.previous_residual = residual + self._advance_step(state) + + return output + + def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor): + if state.previous_residual is None: + # First step has no previous residual to compare against. + # log 1.0 as a neutral starting point. + ratio = 1.0 + else: + # MagCache Calibration Formula: mean(norm(curr) / norm(prev)) + # norm(dim=-1) gives magnitude of each token vector + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1) + + # Avoid division by zero + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + state.calibration_ratios.append(ratio) + + def _advance_step(self, state: MagCacheState): + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + # End of inference loop + if self.config.calibrate: + print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") + print(f"{state.calibration_ratios}\n") + logger.info(f"MagCache Calibration Results: {state.calibration_ratios}") + + # Reset state + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + + +def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: + """ + Applies MagCache to a given module (typically a Transformer). + + Args: + module (`torch.nn.Module`): + The module to apply MagCache to. + config (`MagCacheConfig`): + The configuration for MagCache. + """ + # Initialize registry on the root module so the Pipeline can set context. + HookRegistry.check_if_exists_or_initialize(module) + + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + # Handle single-block models + if len(remaining_blocks) == 1: + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") + _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True) + _apply_mag_cache_head_hook(block, state_manager, config) + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True) + + +def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application (e.g. switching modes) + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + + hook = MagCacheHeadHook(state_manager, config) + registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + + +def _apply_mag_cache_block_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + is_tail: bool = False, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application + if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) + + hook = MagCacheBlockHook(state_manager, is_tail, config) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4d1db36a7352..96953afa4f4a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -54,6 +54,7 @@ _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] + _import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ "HunyuanDiT2DControlNetModel", @@ -175,6 +176,7 @@ ControlNetModel, ControlNetUnionModel, ControlNetXSAdapter, + CosmosControlNetModel, FluxControlNetModel, FluxMultiControlNetModel, HunyuanDiT2DControlNetModel, diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index c96b4fa88c49..0a5b7fff1c66 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -18,7 +18,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin -from ..utils import logging +from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code @@ -220,4 +220,11 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") kwargs = {**load_config_kwargs, **kwargs} - return model_cls.from_pretrained(pretrained_model_or_path, **kwargs) + model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs} + parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS] + load_id = "|".join("null" if p is None else p for p in parts) + model._diffusers_load_id = load_id + + return model diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 153608bb2bf8..04c90668a1db 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -68,10 +68,12 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, + apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, ) @@ -85,6 +87,8 @@ def enable_cache(self, config) -> None: apply_faster_cache(self, config) elif isinstance(config, FirstBlockCacheConfig): apply_first_block_cache(self, config) + elif isinstance(config, MagCacheConfig): + apply_mag_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) elif isinstance(config, TaylorSeerCacheConfig): @@ -99,11 +103,13 @@ def disable_cache(self) -> None: FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, + MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK + from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK @@ -118,6 +124,9 @@ def disable_cache(self) -> None: elif isinstance(self._cache_config, FirstBlockCacheConfig): registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, MagCacheConfig): + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index fee7f231e899..853a2207f903 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -3,6 +3,7 @@ if is_torch_available(): from .controlnet import ControlNetModel, ControlNetOutput + from .controlnet_cosmos import CosmosControlNetModel from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel from .controlnet_hunyuan import ( HunyuanControlNetOutput, diff --git a/src/diffusers/models/controlnets/controlnet_cosmos.py b/src/diffusers/models/controlnets/controlnet_cosmos.py new file mode 100644 index 000000000000..6ea7d629b816 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_cosmos.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import BaseOutput, is_torchvision_available, logging +from ..modeling_utils import ModelMixin +from ..transformers.transformer_cosmos import ( + CosmosEmbedding, + CosmosLearnablePositionalEmbed, + CosmosPatchEmbed, + CosmosRotaryPosEmbed, + CosmosTransformerBlock, +) + + +if is_torchvision_available(): + from torchvision import transforms + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CosmosControlNetOutput(BaseOutput): + """ + Output of [`CosmosControlNetModel`]. + + Args: + control_block_samples (`list[torch.Tensor]`): + List of control block activations to be injected into transformer blocks. + """ + + control_block_samples: List[torch.Tensor] + + +class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + ControlNet for Cosmos Transfer2.5. + + This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed, + learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything + internally from raw inputs. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"] + _no_split_modules = ["CosmosTransformerBlock"] + _keep_in_fp32_modules = ["learnable_pos_embed"] + + @register_to_config + def __init__( + self, + n_controlnet_blocks: int = 4, + in_channels: int = 130, + latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask + model_channels: int = 2048, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + mlp_ratio: float = 4.0, + text_embed_dim: int = 1024, + adaln_lora_dim: int = 256, + patch_size: Tuple[int, int, int] = (1, 2, 2), + max_size: Tuple[int, int, int] = (128, 240, 240), + rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0), + extra_pos_embed_type: Optional[str] = None, + img_context_dim_in: Optional[int] = None, + img_context_dim_out: int = 2048, + use_crossattn_projection: bool = False, + crossattn_proj_in_channels: int = 1024, + encoder_hidden_states_channels: int = 1024, + ): + super().__init__() + + self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False) + + self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False) + self.time_embed = CosmosEmbedding(model_channels, model_channels) + + self.learnable_pos_embed = None + if extra_pos_embed_type == "learnable": + self.learnable_pos_embed = CosmosLearnablePositionalEmbed( + hidden_size=model_channels, + max_size=max_size, + patch_size=patch_size, + ) + + self.img_context_proj = None + if img_context_dim_in is not None and img_context_dim_in > 0: + self.img_context_proj = nn.Sequential( + nn.Linear(img_context_dim_in, img_context_dim_out, bias=True), + nn.GELU(), + ) + + # Cross-attention projection for text embeddings (same as transformer) + self.crossattn_proj = None + if use_crossattn_projection: + self.crossattn_proj = nn.Sequential( + nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True), + nn.GELU(), + ) + + # RoPE for both control and base latents + self.rope = CosmosRotaryPosEmbed( + hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale + ) + + self.control_blocks = nn.ModuleList( + [ + CosmosTransformerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=text_embed_dim, + mlp_ratio=mlp_ratio, + adaln_lora_dim=adaln_lora_dim, + qk_norm="rms_norm", + out_bias=False, + img_context=img_context_dim_in is not None and img_context_dim_in > 0, + before_proj=(block_idx == 0), + after_proj=True, + ) + for block_idx in range(n_controlnet_blocks) + ] + ) + + self.gradient_checkpointing = False + + def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]: + if isinstance(conditioning_scale, list): + scales = conditioning_scale + else: + scales = [conditioning_scale] * len(self.control_blocks) + + if len(scales) < len(self.control_blocks): + logger.warning( + "Received %d control scales, but control network defines %d blocks. " + "Scales will be trimmed or repeated to match.", + len(scales), + len(self.control_blocks), + ) + scales = (scales * len(self.control_blocks))[: len(self.control_blocks)] + return scales + + def forward( + self, + controls_latents: torch.Tensor, + latents: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], + condition_mask: torch.Tensor, + conditioning_scale: Union[float, List[float]] = 1.0, + padding_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + fps: Optional[int] = None, + return_dict: bool = True, + ) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]: + """ + Forward pass for the ControlNet. + + Args: + controls_latents: Control signal latents [B, C, T, H, W] + latents: Base latents from the noising process [B, C, T, H, W] + timestep: Diffusion timestep tensor + encoder_hidden_states: Tuple of (text_context, img_context) or text_context + condition_mask: Conditioning mask [B, 1, T, H, W] + conditioning_scale: Scale factor(s) for control outputs + padding_mask: Padding mask [B, 1, H, W] or None + attention_mask: Optional attention mask or None + fps: Frames per second for RoPE or None + return_dict: Whether to return a CosmosControlNetOutput or a tuple + + Returns: + CosmosControlNetOutput or tuple of control tensors + """ + B, C, T, H, W = controls_latents.shape + + # 1. Prepare control latents + control_hidden_states = controls_latents + vace_in_channels = self.config.in_channels - 1 + if control_hidden_states.shape[1] < vace_in_channels - 1: + pad_C = vace_in_channels - 1 - control_hidden_states.shape[1] + control_hidden_states = torch.cat( + [ + control_hidden_states, + torch.zeros( + (B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device + ), + ], + dim=1, + ) + + control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1) + + padding_mask_resized = transforms.functional.resize( + padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + control_hidden_states = torch.cat( + [control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1 + ) + + # 2. Prepare base latents (same processing as transformer.forward) + base_hidden_states = latents + if condition_mask is not None: + base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1) + + base_padding_mask = transforms.functional.resize( + padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + base_hidden_states = torch.cat( + [base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1 + ) + + # 3. Generate positional embeddings (shared for both) + image_rotary_emb = self.rope(control_hidden_states, fps=fps) + extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None + + # 4. Patchify control latents + control_hidden_states = self.patch_embed(control_hidden_states) + control_hidden_states = control_hidden_states.flatten(1, 3) + + # 5. Patchify base latents + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = T // p_t + post_patch_height = H // p_h + post_patch_width = W // p_w + + base_hidden_states = self.patch_embed_base(base_hidden_states) + base_hidden_states = base_hidden_states.flatten(1, 3) + + # 6. Time embeddings + if timestep.ndim == 1: + temb, embedded_timestep = self.time_embed(base_hidden_states, timestep) + elif timestep.ndim == 5: + batch_size, _, num_frames, _, _ = latents.shape + assert timestep.shape == (batch_size, 1, num_frames, 1, 1), ( + f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}" + ) + timestep_flat = timestep.flatten() + temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat) + temb, embedded_timestep = ( + x.view(batch_size, post_patch_num_frames, 1, 1, -1) + .expand(-1, -1, post_patch_height, post_patch_width, -1) + .flatten(1, 3) + for x in (temb, embedded_timestep) + ) + else: + raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}") + + # 7. Process encoder hidden states + if isinstance(encoder_hidden_states, tuple): + text_context, img_context = encoder_hidden_states + else: + text_context = encoder_hidden_states + img_context = None + + # Apply cross-attention projection to text context + if self.crossattn_proj is not None: + text_context = self.crossattn_proj(text_context) + + # Apply cross-attention projection to image context (if provided) + if img_context is not None and self.img_context_proj is not None: + img_context = self.img_context_proj(img_context) + + # Combine text and image context into a single tuple + if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0: + processed_encoder_hidden_states = (text_context, img_context) + else: + processed_encoder_hidden_states = text_context + + # 8. Prepare attention mask + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S] + + # 9. Run control blocks + scales = self._expand_conditioning_scale(conditioning_scale) + result = [] + for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)): + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_hidden_states, control_proj = self._gradient_checkpointing_func( + block, + control_hidden_states, + processed_encoder_hidden_states, + embedded_timestep, + temb, + image_rotary_emb, + extra_pos_emb, + attention_mask, + None, # controlnet_residual + base_hidden_states, + block_idx, + ) + else: + control_hidden_states, control_proj = block( + hidden_states=control_hidden_states, + encoder_hidden_states=processed_encoder_hidden_states, + embedded_timestep=embedded_timestep, + temb=temb, + image_rotary_emb=image_rotary_emb, + extra_pos_emb=extra_pos_emb, + attention_mask=attention_mask, + controlnet_residual=None, + latents=base_hidden_states, + block_idx=block_idx, + ) + result.append(control_proj * scale) + + if not return_dict: + return (result,) + + return CosmosControlNetOutput(control_block_samples=result) diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py index 09f79619320d..e10bfddcbc86 100644 --- a/src/diffusers/models/transformers/transformer_bria_fibo.py +++ b/src/diffusers/models/transformers/transformer_bria_fibo.py @@ -125,9 +125,9 @@ def __call__( encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[0](hidden_states.contiguous()) hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous()) return hidden_states, encoder_hidden_states else: diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 2b0c2667072b..0f1a5f295c34 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin from ...utils import is_torchvision_available from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..embeddings import Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -152,7 +152,7 @@ def forward( class CosmosAttnProcessor2_0: def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") def __call__( @@ -191,7 +191,6 @@ def __call__( query_idx = torch.tensor(query.size(3), device=query.device) key_idx = torch.tensor(key.size(3), device=key.device) value_idx = torch.tensor(value.size(3), device=value.device) - else: query_idx = query.size(3) key_idx = key.size(3) @@ -200,18 +199,148 @@ def __call__( value = value.repeat_interleave(query_idx // value_idx, dim=3) # 5. Attention - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query) - - # 6. Output projection + hidden_states = hidden_states.flatten(2, 3).type_as(query) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states +class CosmosAttnProcessor2_5: + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], + attention_mask: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], + image_rotary_emb=None, + ) -> torch.Tensor: + if not isinstance(encoder_hidden_states, tuple): + raise ValueError("Expected encoder_hidden_states as (text_context, img_context) tuple.") + + text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None) + text_mask, img_mask = attention_mask if attention_mask else (None, None) + + if text_context is None: + text_context = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(text_context) + value = attn.to_v(text_context) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) + + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) + key = key.repeat_interleave(query_idx // key_idx, dim=3) + value = value.repeat_interleave(query_idx // value_idx, dim=3) + + attn_out = dispatch_attention_fn( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=text_mask, + dropout_p=0.0, + is_causal=False, + ) + attn_out = attn_out.flatten(2, 3).type_as(query) + + if img_context is not None: + q_img = attn.q_img(hidden_states) + k_img = attn.k_img(img_context) + v_img = attn.v_img(img_context) + + batch_size = hidden_states.shape[0] + dim_head = attn.out_dim // attn.heads + + q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2) + k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2) + v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2) + + q_img = attn.q_img_norm(q_img) + k_img = attn.k_img_norm(k_img) + + q_img_idx = q_img.size(3) + k_img_idx = k_img.size(3) + v_img_idx = v_img.size(3) + k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3) + v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3) + + img_out = dispatch_attention_fn( + q_img.transpose(1, 2), + k_img.transpose(1, 2), + v_img.transpose(1, 2), + attn_mask=img_mask, + dropout_p=0.0, + is_causal=False, + ) + img_out = img_out.flatten(2, 3).type_as(q_img) + hidden_states = attn_out + img_out + else: + hidden_states = attn_out + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class CosmosAttention(Attention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # add parameters for image q/k/v + inner_dim = self.heads * self.to_q.out_features // self.heads + self.q_img = nn.Linear(self.query_dim, inner_dim, bias=False) + self.k_img = nn.Linear(self.query_dim, inner_dim, bias=False) + self.v_img = nn.Linear(self.query_dim, inner_dim, bias=False) + self.q_img_norm = RMSNorm(self.to_q.out_features // self.heads, eps=1e-6, elementwise_affine=True) + self.k_img_norm = RMSNorm(self.to_k.out_features // self.heads, eps=1e-6, elementwise_affine=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + return super().forward( + hidden_states=hidden_states, + # NOTE: type-hint in base class can be ignored + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + class CosmosTransformerBlock(nn.Module): def __init__( self, @@ -222,12 +351,16 @@ def __init__( adaln_lora_dim: int = 256, qk_norm: str = "rms_norm", out_bias: bool = False, + img_context: bool = False, + before_proj: bool = False, + after_proj: bool = False, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim) + self.img_context = img_context self.attn1 = Attention( query_dim=hidden_size, cross_attention_dim=None, @@ -240,30 +373,58 @@ def __init__( ) self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim) - self.attn2 = Attention( - query_dim=hidden_size, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - qk_norm=qk_norm, - elementwise_affine=True, - out_bias=out_bias, - processor=CosmosAttnProcessor2_0(), - ) + if img_context: + self.attn2 = CosmosAttention( + query_dim=hidden_size, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + qk_norm=qk_norm, + elementwise_affine=True, + out_bias=out_bias, + processor=CosmosAttnProcessor2_5(), + ) + else: + self.attn2 = Attention( + query_dim=hidden_size, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + qk_norm=qk_norm, + elementwise_affine=True, + out_bias=out_bias, + processor=CosmosAttnProcessor2_0(), + ) self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim) self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias) + # NOTE: zero conv for CosmosControlNet + self.before_proj = None + self.after_proj = None + if before_proj: + self.before_proj = nn.Linear(hidden_size, hidden_size) + if after_proj: + self.after_proj = nn.Linear(hidden_size, hidden_size) + def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, + encoder_hidden_states: Union[ + Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] + ], embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, extra_pos_emb: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + controlnet_residual: Optional[torch.Tensor] = None, + latents: Optional[torch.Tensor] = None, + block_idx: Optional[int] = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.before_proj is not None: + hidden_states = self.before_proj(hidden_states) + latents + if extra_pos_emb is not None: hidden_states = hidden_states + extra_pos_emb @@ -284,6 +445,16 @@ def forward( ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate * ff_output + if controlnet_residual is not None: + assert self.after_proj is None + # NOTE: this is assumed to be scaled by the controlnet + hidden_states += controlnet_residual + + if self.after_proj is not None: + assert controlnet_residual is None + hs_proj = self.after_proj(hidden_states) + return hidden_states, hs_proj + return hidden_states @@ -416,6 +587,17 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): Whether to concatenate the padding mask to the input latent tensors. extra_pos_embed_type (`str`, *optional*, defaults to `learnable`): The type of extra positional embeddings to use. Can be one of `None` or `learnable`. + controlnet_block_every_n (`int`, *optional*): + Interval between transformer blocks that should receive control residuals (for example, `7` to inject after + every seventh block). Required for Cosmos Transfer2.5. + img_context_dim_in (`int`, *optional*): + The dimension of the input image context feature vector, i.e. it is the D in [B, N, D]. + img_context_num_tokens (`int`): + The number of tokens in the image context feature vector, i.e. it is the N in [B, N, D]. If + `img_context_dim_in` is not provided, then this parameter is ignored. + img_context_dim_out (`int`): + The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then + this parameter is ignored. """ _supports_gradient_checkpointing = True @@ -442,6 +624,10 @@ def __init__( use_crossattn_projection: bool = False, crossattn_proj_in_channels: int = 1024, encoder_hidden_states_channels: int = 1024, + controlnet_block_every_n: Optional[int] = None, + img_context_dim_in: Optional[int] = None, + img_context_num_tokens: int = 256, + img_context_dim_out: int = 2048, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -477,6 +663,7 @@ def __init__( adaln_lora_dim=adaln_lora_dim, qk_norm="rms_norm", out_bias=False, + img_context=self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0, ) for _ in range(num_layers) ] @@ -496,17 +683,24 @@ def __init__( self.gradient_checkpointing = False + if self.config.img_context_dim_in: + self.img_context_proj = nn.Sequential( + nn.Linear(self.config.img_context_dim_in, self.config.img_context_dim_out, bias=True), + nn.GELU(), + ) + def forward( self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, + block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, fps: Optional[int] = None, condition_mask: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> torch.Tensor: + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: batch_size, num_channels, num_frames, height, width = hidden_states.shape # 1. Concatenate padding mask if needed & prepare attention mask @@ -514,11 +708,11 @@ def forward( hidden_states = torch.cat([hidden_states, condition_mask], dim=1) if self.config.concat_padding_mask: - padding_mask = transforms.functional.resize( + padding_mask_resized = transforms.functional.resize( padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST ) hidden_states = torch.cat( - [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1 + [hidden_states, padding_mask_resized.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1 ) if attention_mask is not None: @@ -554,36 +748,59 @@ def forward( for x in (temb, embedded_timestep) ) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C] else: - assert False + raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}") + # 5. Process encoder hidden states + text_context, img_context = ( + encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None) + ) if self.config.use_crossattn_projection: - encoder_hidden_states = self.crossattn_proj(encoder_hidden_states) + text_context = self.crossattn_proj(text_context) + + if img_context is not None and self.config.img_context_dim_in: + img_context = self.img_context_proj(img_context) - # 5. Transformer blocks - for block in self.transformer_blocks: + processed_encoder_hidden_states = ( + (text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context + ) + + # 6. Build controlnet block index map + controlnet_block_index_map = {} + if block_controlnet_hidden_states is not None: + n_blocks = len(self.transformer_blocks) + controlnet_block_index_map = { + block_idx: block_controlnet_hidden_states[idx] + for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n))) + } + + # 7. Transformer blocks + for block_idx, block in enumerate(self.transformer_blocks): + controlnet_residual = controlnet_block_index_map.get(block_idx) if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( block, hidden_states, - encoder_hidden_states, + processed_encoder_hidden_states, embedded_timestep, temb, image_rotary_emb, extra_pos_emb, attention_mask, + controlnet_residual, ) else: hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - embedded_timestep=embedded_timestep, - temb=temb, - image_rotary_emb=image_rotary_emb, - extra_pos_emb=extra_pos_emb, - attention_mask=attention_mask, + hidden_states, + processed_encoder_hidden_states, + embedded_timestep, + temb, + image_rotary_emb, + extra_pos_emb, + attention_mask, + controlnet_residual, ) - # 6. Output norm & projection & unpatchify + # 8. Output norm & projection & unpatchify hidden_states = self.norm_out(hidden_states, embedded_timestep, temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1)) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1a4464432425..37b995867fd0 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -130,9 +130,9 @@ def __call__( encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[0](hidden_states.contiguous()) hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous()) return hidden_states, encoder_hidden_states else: diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..74461a07667e 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -561,11 +561,11 @@ def __call__( img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part # Apply output projections - img_attn_output = attn.to_out[0](img_attn_output) + img_attn_output = attn.to_out[0](img_attn_output.contiguous()) if len(attn.to_out) > 1: img_attn_output = attn.to_out[1](img_attn_output) # dropout - txt_attn_output = attn.to_add_out(txt_attn_output) + txt_attn_output = attn.to_add_out(txt_attn_output.contiguous()) return img_attn_output, txt_attn_output diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 823a3d263ea9..94b87c61c234 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,7 +45,16 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"] + _import_structure["wan"] = [ + "WanBlocks", + "Wan22Blocks", + "WanImage2VideoAutoBlocks", + "Wan22Image2VideoBlocks", + "WanModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "Wan22Image2VideoModularPipeline", + ] _import_structure["flux"] = [ "FluxAutoBlocks", "FluxModularPipeline", @@ -58,6 +67,7 @@ "Flux2KleinBaseAutoBlocks", "Flux2ModularPipeline", "Flux2KleinModularPipeline", + "Flux2KleinBaseModularPipeline", ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", @@ -88,6 +98,7 @@ Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, + Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline, ) @@ -112,7 +123,16 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline - from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline + from .wan import ( + Wan22Blocks, + Wan22Image2VideoBlocks, + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanBlocks, + WanImage2VideoAutoBlocks, + WanImage2VideoModularPipeline, + WanModularPipeline, + ) from .z_image import ZImageAutoBlocks, ZImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index e16abb382313..4a7ea8502c86 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -324,6 +324,7 @@ class ComponentsManager: "has_hook", "execution_device", "ip_adapter", + "quantization", ] def __init__(self): @@ -356,7 +357,9 @@ def _lookup_ids( ids_by_name.add(component_id) else: ids_by_name = set(components.keys()) - if collection: + if collection and collection not in self.collections: + return set() + elif collection and collection in self.collections: ids_by_collection = set() for component_id, component in components.items(): if component_id in self.collections[collection]: @@ -423,7 +426,8 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): # add component to components manager self.components[component_id] = component - self.added_time[component_id] = time.time() + if is_new_component: + self.added_time[component_id] = time.time() if collection: if collection not in self.collections: @@ -760,7 +764,6 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - # YiYi TODO: (1) add quantization info def get_model_info( self, component_id: str, @@ -836,6 +839,17 @@ def get_model_info( if scales: info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + # Check for quantization + hf_quantizer = getattr(component, "hf_quantizer", None) + if hf_quantizer is not None: + quant_config = hf_quantizer.quantization_config + if hasattr(quant_config, "to_diff_dict"): + info["quantization"] = quant_config.to_diff_dict() + else: + info["quantization"] = quant_config.to_dict() + else: + info["quantization"] = None + # If fields specified, filter info if fields is not None: return {k: v for k, v in info.items() if k in fields} @@ -966,12 +980,16 @@ def format_device(component, info): output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" for name in self.components: info = self.get_model_info(name) - if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): + if info is not None and ( + info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization") + ): output += f"\n{name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): output += " IP-Adapter: Enabled\n" + if info.get("quantization"): + output += f" Quantization: {info['quantization']}\n" return output diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 220ec0c4ab65..74907a9af806 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -55,7 +55,11 @@ "Flux2VaeEncoderSequentialStep", ] _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] - _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"] + _import_structure["modular_pipeline"] = [ + "Flux2ModularPipeline", + "Flux2KleinModularPipeline", + "Flux2KleinBaseModularPipeline", + ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -101,7 +105,7 @@ Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, ) - from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index 29fbeba07c24..31ba5aec7cfb 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -13,8 +13,6 @@ # limitations under the License. -from typing import Any, Dict, Optional - from ...loaders import Flux2LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline @@ -59,46 +57,35 @@ def num_channels_latents(self): return num_channels_latents -class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): +class Flux2KleinModularPipeline(Flux2ModularPipeline): """ - A ModularPipeline for Flux2-Klein. + A ModularPipeline for Flux2-Klein (distilled model). > [!WARNING] > This is an experimental feature and is likely to change in the future. """ - default_blocks_name = "Flux2KleinBaseAutoBlocks" - - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: - return "Flux2KleinAutoBlocks" - else: - return "Flux2KleinBaseAutoBlocks" + default_blocks_name = "Flux2KleinAutoBlocks" @property - def default_height(self): - return self.default_sample_size * self.vae_scale_factor + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False - @property - def default_width(self): - return self.default_sample_size * self.vae_scale_factor + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 - @property - def default_sample_size(self): - return 128 + return requires_unconditional_embeds - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if getattr(self, "vae", None) is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - @property - def num_channels_latents(self): - num_channels_latents = 32 - if getattr(self, "transformer", None): - num_channels_latents = self.transformer.config.in_channels // 4 - return num_channels_latents +class Flux2KleinBaseModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (base model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" @property def requires_unconditional_embeds(self): diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index f848afe9a3ae..5eb6319a6b7a 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -1,3 +1,4 @@ +import copy import json import logging import os @@ -6,7 +7,7 @@ from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Union -from huggingface_hub import create_repo, hf_hub_download, upload_folder +from huggingface_hub import create_repo, hf_hub_download, upload_file from huggingface_hub.utils import ( EntryNotFoundError, HfHubHTTPError, @@ -15,25 +16,262 @@ ) from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT +from .modular_pipeline_utils import InputParam, OutputParam logger = logging.getLogger(__name__) +def _name_to_label(name: str) -> str: + """Convert snake_case name to Title Case label.""" + return name.replace("_", " ").title() + + +# Template definitions for standard diffuser pipeline parameters +MELLON_PARAM_TEMPLATES = { + # Image I/O + "image": {"label": "Image", "type": "image", "display": "input", "required_block_params": ["image"]}, + "images": {"label": "Images", "type": "image", "display": "output", "required_block_params": ["images"]}, + "control_image": { + "label": "Control Image", + "type": "image", + "display": "input", + "required_block_params": ["control_image"], + }, + # Latents + "latents": {"label": "Latents", "type": "latents", "display": "input", "required_block_params": ["latents"]}, + "image_latents": { + "label": "Image Latents", + "type": "latents", + "display": "input", + "required_block_params": ["image_latents"], + }, + "first_frame_latents": { + "label": "First Frame Latents", + "type": "latents", + "display": "input", + "required_block_params": ["first_frame_latents"], + }, + "latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}, + # Image Latents with Strength + "image_latents_with_strength": { + "name": "image_latents", # name is not same as template key + "label": "Image Latents", + "type": "latents", + "display": "input", + "onChange": {"false": ["height", "width"], "true": ["strength"]}, + "required_block_params": ["image_latents", "strength"], + }, + # Embeddings + "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}, + "image_embeds": { + "label": "Image Embeddings", + "type": "image_embeds", + "display": "output", + "required_block_params": ["image_embeds"], + }, + # Text inputs + "prompt": { + "label": "Prompt", + "type": "string", + "display": "textarea", + "default": "", + "required_block_params": ["prompt"], + }, + "negative_prompt": { + "label": "Negative Prompt", + "type": "string", + "display": "textarea", + "default": "", + "required_block_params": ["negative_prompt"], + }, + # Numeric params + "guidance_scale": { + "label": "Guidance Scale", + "type": "float", + "display": "slider", + "default": 5.0, + "min": 1.0, + "max": 30.0, + "step": 0.1, + }, + "strength": { + "label": "Strength", + "type": "float", + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["strength"], + }, + "height": { + "label": "Height", + "type": "int", + "default": 1024, + "min": 64, + "step": 8, + "required_block_params": ["height"], + }, + "width": { + "label": "Width", + "type": "int", + "default": 1024, + "min": 64, + "step": 8, + "required_block_params": ["width"], + }, + "seed": { + "label": "Seed", + "type": "int", + "default": 0, + "min": 0, + "max": 4294967295, + "display": "random", + "required_block_params": ["generator"], + }, + "num_inference_steps": { + "label": "Steps", + "type": "int", + "default": 25, + "min": 1, + "max": 100, + "display": "slider", + "required_block_params": ["num_inference_steps"], + }, + "num_frames": { + "label": "Frames", + "type": "int", + "default": 81, + "min": 1, + "max": 480, + "display": "slider", + "required_block_params": ["num_frames"], + }, + "layers": { + "label": "Layers", + "type": "int", + "default": 4, + "min": 1, + "max": 10, + "display": "slider", + "required_block_params": ["layers"], + }, + "output_type": { + "label": "Output Type", + "type": "dropdown", + "default": "np", + "options": ["np", "pil", "pt"], + }, + # ControlNet + "controlnet_conditioning_scale": { + "label": "Controlnet Conditioning Scale", + "type": "float", + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["controlnet_conditioning_scale"], + }, + "control_guidance_start": { + "label": "Control Guidance Start", + "type": "float", + "default": 0.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["control_guidance_start"], + }, + "control_guidance_end": { + "label": "Control Guidance End", + "type": "float", + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["control_guidance_end"], + }, + # Video + "videos": {"label": "Videos", "type": "video", "display": "output", "required_block_params": ["videos"]}, + # Models + "vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input", "required_block_params": ["vae"]}, + "image_encoder": { + "label": "Image Encoder", + "type": "diffusers_auto_model", + "display": "input", + "required_block_params": ["image_encoder"], + }, + "unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}, + "scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}, + "controlnet": { + "label": "ControlNet Model", + "type": "diffusers_auto_model", + "display": "input", + "required_block_params": ["controlnet"], + }, + "text_encoders": { + "label": "Text Encoders", + "type": "diffusers_auto_models", + "display": "input", + "required_block_params": ["text_encoder"], + }, + # Bundles/Custom + "controlnet_bundle": { + "label": "ControlNet", + "type": "custom_controlnet", + "display": "input", + "required_block_params": "controlnet_image", + }, + "ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}, + "guider": { + "label": "Guider", + "type": "custom_guider", + "display": "input", + "onChange": {False: ["guidance_scale"], True: []}, + }, + "doc": {"label": "Doc", "type": "string", "display": "output"}, +} + + +class MellonParamMeta(type): + """Metaclass that enables MellonParam.template_name(**overrides) syntax.""" + + def __getattr__(cls, name: str): + if name in MELLON_PARAM_TEMPLATES: + + def factory(default=None, **overrides): + template = MELLON_PARAM_TEMPLATES[name] + # Use template's name if specified, otherwise use the key + params = {"name": template.get("name", name), **template, **overrides} + if default is not None: + params["default"] = default + return cls(**params) + + return factory + + raise AttributeError(f"type object 'MellonParam' has no attribute '{name}'") + + @dataclass(frozen=True) -class MellonParam: +class MellonParam(metaclass=MellonParamMeta): """ Parameter definition for Mellon nodes. - Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with - MellonParam(name="...", label="...", type="..."). - - Example: + Usage: ```python - # Custom param - MellonParam(name="my_param", label="My Param", type="float", default=0.5) - # Output in Mellon node definition: - # "my_param": {"label": "My Param", "type": "float", "default": 0.5} + # From template (standard diffuser params) + MellonParam.seed() + MellonParam.prompt(default="a cat") + MellonParam.latents(display="output") + + # Generic inputs (for custom blocks) + MellonParam.Input.slider("my_scale", default=1.0, min=0.0, max=2.0) + MellonParam.Input.dropdown("mode", options=["fast", "slow"]) + + # Generic outputs + MellonParam.Output.image("result_images") + + # Fully custom + MellonParam(name="custom", label="Custom", type="float", default=0.5) ``` """ @@ -53,577 +291,204 @@ class MellonParam: required_block_params: Optional[Union[str, List[str]]] = None def to_dict(self) -> Dict[str, Any]: - """Convert to dict for Mellon schema, excluding None values and name.""" + """Convert to dict for Mellon schema, excluding None values and internal fields.""" data = asdict(self) return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")} - @classmethod - def image(cls) -> "MellonParam": - """ - Image input parameter. - - Mellon node definition: - "image": {"label": "Image", "type": "image", "display": "input"} - """ - return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"]) - - @classmethod - def images(cls) -> "MellonParam": - """ - Images output parameter. - - Mellon node definition: - "images": {"label": "Images", "type": "image", "display": "output"} - """ - return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"]) - - @classmethod - def control_image(cls, display: str = "input") -> "MellonParam": - """ - Control image parameter for ControlNet. - - Mellon node definition (display="input"): - "control_image": {"label": "Control Image", "type": "image", "display": "input"} - """ - return cls( - name="control_image", - label="Control Image", - type="image", - display=display, - required_block_params=["control_image"], - ) - - @classmethod - def latents(cls, display: str = "input") -> "MellonParam": - """ - Latents parameter. - - Mellon node definition (display="input"): - "latents": {"label": "Latents", "type": "latents", "display": "input"} - - Mellon node definition (display="output"): - "latents": {"label": "Latents", "type": "latents", "display": "output"} - """ - return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"]) - - @classmethod - def image_latents(cls, display: str = "input") -> "MellonParam": - """ - Image latents parameter for img2img workflows. - - Mellon node definition (display="input"): - "image_latents": {"label": "Image Latents", "type": "latents", "display": "input"} - """ - return cls( - name="image_latents", - label="Image Latents", - type="latents", - display=display, - required_block_params=["image_latents"], - ) - - @classmethod - def first_frame_latents(cls, display: str = "input") -> "MellonParam": - """ - First frame latents for video generation. - - Mellon node definition (display="input"): - "first_frame_latents": {"label": "First Frame Latents", "type": "latents", "display": "input"} - """ - return cls( - name="first_frame_latents", - label="First Frame Latents", - type="latents", - display=display, - required_block_params=["first_frame_latents"], - ) - - @classmethod - def image_latents_with_strength(cls) -> "MellonParam": - """ - Image latents with strength-based onChange behavior. When connected, shows strength slider; when disconnected, - shows height/width. - - Mellon node definition: - "image_latents": { - "label": "Image Latents", "type": "latents", "display": "input", "onChange": {"false": ["height", - "width"], "true": ["strength"]} - } - """ - return cls( - name="image_latents", - label="Image Latents", - type="latents", - display="input", - onChange={"false": ["height", "width"], "true": ["strength"]}, - required_block_params=["image_latents", "strength"], - ) - - @classmethod - def latents_preview(cls) -> "MellonParam": - """ - Latents preview output for visualizing latents in the UI. - - Mellon node definition: - "latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"} - """ - return cls(name="latents_preview", label="Latents Preview", type="latent", display="output") - - @classmethod - def embeddings(cls, display: str = "output") -> "MellonParam": - """ - Text embeddings parameter. - - Mellon node definition (display="output"): - "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"} - - Mellon node definition (display="input"): - "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "input"} - """ - return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display) - - @classmethod - def image_embeds(cls, display: str = "output") -> "MellonParam": - """ - Image embeddings parameter for IP-Adapter workflows. - - Mellon node definition (display="output"): - "image_embeds": {"label": "Image Embeddings", "type": "image_embeds", "display": "output"} - """ - return cls( - name="image_embeds", - label="Image Embeddings", - type="image_embeds", - display=display, - required_block_params=["image_embeds"], - ) - - @classmethod - def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam": - """ - ControlNet conditioning scale slider. - - Mellon node definition (default=0.5): - "controlnet_conditioning_scale": { - "label": "Controlnet Conditioning Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, - "step": 0.01 - } - """ - return cls( - name="controlnet_conditioning_scale", - label="Controlnet Conditioning Scale", - type="float", - default=default, - min=0.0, - max=1.0, - step=0.01, - required_block_params=["controlnet_conditioning_scale"], - ) - - @classmethod - def control_guidance_start(cls, default: float = 0.0) -> "MellonParam": - """ - Control guidance start timestep. - - Mellon node definition (default=0.0): - "control_guidance_start": { - "label": "Control Guidance Start", "type": "float", "default": 0.0, "min": 0.0, "max": 1.0, "step": - 0.01 - } - """ - return cls( - name="control_guidance_start", - label="Control Guidance Start", - type="float", - default=default, - min=0.0, - max=1.0, - step=0.01, - required_block_params=["control_guidance_start"], - ) - - @classmethod - def control_guidance_end(cls, default: float = 1.0) -> "MellonParam": - """ - Control guidance end timestep. - - Mellon node definition (default=1.0): - "control_guidance_end": { - "label": "Control Guidance End", "type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01 - } - """ - return cls( - name="control_guidance_end", - label="Control Guidance End", - type="float", - default=default, - min=0.0, - max=1.0, - step=0.01, - required_block_params=["control_guidance_end"], - ) - - @classmethod - def prompt(cls, default: str = "") -> "MellonParam": - """ - Text prompt input as textarea. - - Mellon node definition (default=""): - "prompt": {"label": "Prompt", "type": "string", "default": "", "display": "textarea"} - """ - return cls( - name="prompt", - label="Prompt", - type="string", - default=default, - display="textarea", - required_block_params=["prompt"], - ) - - @classmethod - def negative_prompt(cls, default: str = "") -> "MellonParam": - """ - Negative prompt input as textarea. - - Mellon node definition (default=""): - "negative_prompt": {"label": "Negative Prompt", "type": "string", "default": "", "display": "textarea"} - """ - return cls( - name="negative_prompt", - label="Negative Prompt", - type="string", - default=default, - display="textarea", - required_block_params=["negative_prompt"], - ) - - @classmethod - def strength(cls, default: float = 0.5) -> "MellonParam": - """ - Denoising strength for img2img. - - Mellon node definition (default=0.5): - "strength": {"label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01} - """ - return cls( - name="strength", - label="Strength", - type="float", - default=default, - min=0.0, - max=1.0, - step=0.01, - required_block_params=["strength"], - ) - - @classmethod - def guidance_scale(cls, default: float = 5.0) -> "MellonParam": - """ - CFG guidance scale slider. - - Mellon node definition (default=5.0): - "guidance_scale": { - "label": "Guidance Scale", "type": "float", "display": "slider", "default": 5.0, "min": 1.0, "max": - 30.0, "step": 0.1 - } - """ - return cls( - name="guidance_scale", - label="Guidance Scale", - type="float", - display="slider", - default=default, - min=1.0, - max=30.0, - step=0.1, - ) - - @classmethod - def height(cls, default: int = 1024) -> "MellonParam": - """ - Image height in pixels. - - Mellon node definition (default=1024): - "height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8} - """ - return cls( - name="height", - label="Height", - type="int", - default=default, - min=64, - step=8, - required_block_params=["height"], - ) - - @classmethod - def width(cls, default: int = 1024) -> "MellonParam": - """ - Image width in pixels. - - Mellon node definition (default=1024): - "width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8} - """ - return cls( - name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"] - ) - - @classmethod - def seed(cls, default: int = 0) -> "MellonParam": - """ - Random seed with randomize button. - - Mellon node definition (default=0): - "seed": { - "label": "Seed", "type": "int", "default": 0, "min": 0, "max": 4294967295, "display": "random" - } - """ - return cls( - name="seed", - label="Seed", - type="int", - default=default, - min=0, - max=4294967295, - display="random", - required_block_params=["generator"], - ) - - @classmethod - def num_inference_steps(cls, default: int = 25) -> "MellonParam": - """ - Number of denoising steps slider. - - Mellon node definition (default=25): - "num_inference_steps": { - "label": "Steps", "type": "int", "default": 25, "min": 1, "max": 100, "display": "slider" - } - """ - return cls( - name="num_inference_steps", - label="Steps", - type="int", - default=default, - min=1, - max=100, - display="slider", - required_block_params=["num_inference_steps"], - ) - - @classmethod - def num_frames(cls, default: int = 81) -> "MellonParam": - """ - Number of video frames slider. - - Mellon node definition (default=81): - "num_frames": {"label": "Frames", "type": "int", "default": 81, "min": 1, "max": 480, "display": "slider"} - """ - return cls( - name="num_frames", - label="Frames", - type="int", - default=default, - min=1, - max=480, - display="slider", - required_block_params=["num_frames"], - ) - - @classmethod - def layers(cls, default: int = 4) -> "MellonParam": - """ - Number of layers slider (for layered diffusion). - - Mellon node definition (default=4): - "layers": {"label": "Layers", "type": "int", "default": 4, "min": 1, "max": 10, "display": "slider"} - """ - return cls( - name="layers", - label="Layers", - type="int", - default=default, - min=1, - max=10, - display="slider", - required_block_params=["layers"], - ) - - @classmethod - def videos(cls) -> "MellonParam": - """ - Video output parameter. - - Mellon node definition: - "videos": {"label": "Videos", "type": "video", "display": "output"} - """ - return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"]) - - @classmethod - def vae(cls) -> "MellonParam": - """ - VAE model input. - - Mellon node definition: - "vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input"} - - Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use - components.get_one(model_id) to retrieve the actual model. - """ - return cls( - name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"] - ) - - @classmethod - def image_encoder(cls) -> "MellonParam": - """ - Image encoder model input. - - Mellon node definition: - "image_encoder": {"label": "Image Encoder", "type": "diffusers_auto_model", "display": "input"} - - Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use - components.get_one(model_id) to retrieve the actual model. - """ - return cls( - name="image_encoder", - label="Image Encoder", - type="diffusers_auto_model", - display="input", - required_block_params=["image_encoder"], - ) - - @classmethod - def unet(cls) -> "MellonParam": - """ - Denoising model (UNet/Transformer) input. - - Mellon node definition: - "unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"} - - Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use - components.get_one(model_id) to retrieve the actual model. - """ - return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input") - - @classmethod - def scheduler(cls) -> "MellonParam": - """ - Scheduler model input. - - Mellon node definition: - "scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"} - - Note: The value received is a model info dict with keys like 'model_id', 'repo_id'. Use - components.get_one(model_id) to retrieve the actual scheduler. - """ - return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input") - - @classmethod - def controlnet(cls) -> "MellonParam": - """ - ControlNet model input. - - Mellon node definition: - "controlnet": {"label": "ControlNet Model", "type": "diffusers_auto_model", "display": "input"} - - Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use - components.get_one(model_id) to retrieve the actual model. - """ - return cls( - name="controlnet", - label="ControlNet Model", - type="diffusers_auto_model", - display="input", - required_block_params=["controlnet"], - ) - - @classmethod - def text_encoders(cls) -> "MellonParam": - """ - Text encoders dict input (multiple encoders). - - Mellon node definition: - "text_encoders": {"label": "Text Encoders", "type": "diffusers_auto_models", "display": "input"} - - Note: The value received is a dict of model info dicts: - { - 'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...}, - 'repo_id': '...' - } - Use components.get_one(model_id) to retrieve each model. - """ - return cls( - name="text_encoders", - label="Text Encoders", - type="diffusers_auto_models", - display="input", - required_block_params=["text_encoder"], - ) - - @classmethod - def controlnet_bundle(cls, display: str = "input") -> "MellonParam": - """ - ControlNet bundle containing model and processed control inputs. Output from ControlNet node, input to Denoise - node. - - Mellon node definition (display="input"): - "controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "input"} + # ========================================================================= + # Input: Generic input parameter factories (for custom blocks) + # ========================================================================= + class Input: + """input UI elements for custom blocks.""" + + @classmethod + def image(cls, name: str) -> "MellonParam": + """image input.""" + return MellonParam(name=name, label=_name_to_label(name), type="image", display="input") + + @classmethod + def textbox(cls, name: str, default: str = "") -> "MellonParam": + """text input as textarea.""" + return MellonParam( + name=name, label=_name_to_label(name), type="string", display="textarea", default=default + ) - Mellon node definition (display="output"): - "controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "output"} + @classmethod + def dropdown(cls, name: str, options: List[str] = None, default: str = None) -> "MellonParam": + """dropdown selection.""" + if options and not default: + default = options[0] + if not default: + default = "" + if not options: + options = [default] + return MellonParam(name=name, label=_name_to_label(name), type="string", options=options, value=default) + + @classmethod + def slider( + cls, name: str, default: float = 0, min: float = None, max: float = None, step: float = None + ) -> "MellonParam": + """slider input.""" + is_float = isinstance(default, float) or (step is not None and isinstance(step, float)) + param_type = "float" if is_float else "int" + if min is None: + min = default + if max is None: + max = default + if step is None: + step = 0.01 if is_float else 1 + return MellonParam( + name=name, + label=_name_to_label(name), + type=param_type, + display="slider", + default=default, + min=min, + max=max, + step=step, + ) - Note: The value is a dict containing: - { - 'controlnet': {'model_id': ..., ...}, # controlnet model info 'control_image': ..., # processed control - image/embeddings 'controlnet_conditioning_scale': ..., # and other denoise block inputs - } - """ - return cls( - name="controlnet_bundle", - label="ControlNet", - type="custom_controlnet", - display=display, - required_block_params="controlnet_image", - ) + @classmethod + def number( + cls, name: str, default: float = 0, min: float = None, max: float = None, step: float = None + ) -> "MellonParam": + """number input (no slider).""" + is_float = isinstance(default, float) or (step is not None and isinstance(step, float)) + param_type = "float" if is_float else "int" + return MellonParam( + name=name, label=_name_to_label(name), type=param_type, default=default, min=min, max=max, step=step + ) - @classmethod - def ip_adapter(cls) -> "MellonParam": - """ - IP-Adapter input. + @classmethod + def seed(cls, name: str = "seed", default: int = 0) -> "MellonParam": + """seed input with randomize button.""" + return MellonParam( + name=name, + label=_name_to_label(name), + type="int", + display="random", + default=default, + min=0, + max=4294967295, + ) - Mellon node definition: - "ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"} - """ - return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input") + @classmethod + def checkbox(cls, name: str, default: bool = False) -> "MellonParam": + """boolean checkbox.""" + return MellonParam(name=name, label=_name_to_label(name), type="boolean", value=default) + + @classmethod + def custom_type(cls, name: str, type: str) -> "MellonParam": + """custom type input for node connections.""" + return MellonParam(name=name, label=_name_to_label(name), type=type, display="input") + + @classmethod + def model(cls, name: str) -> "MellonParam": + """model input for diffusers components.""" + return MellonParam(name=name, label=_name_to_label(name), type="diffusers_auto_model", display="input") + + # ========================================================================= + # Output: Generic output parameter factories (for custom blocks) + # ========================================================================= + class Output: + """output UI elements for custom blocks.""" + + @classmethod + def image(cls, name: str) -> "MellonParam": + """image output.""" + return MellonParam(name=name, label=_name_to_label(name), type="image", display="output") + + @classmethod + def video(cls, name: str) -> "MellonParam": + """video output.""" + return MellonParam(name=name, label=_name_to_label(name), type="video", display="output") + + @classmethod + def text(cls, name: str) -> "MellonParam": + """text output.""" + return MellonParam(name=name, label=_name_to_label(name), type="string", display="output") + + @classmethod + def custom_type(cls, name: str, type: str) -> "MellonParam": + """custom type output for node connections.""" + return MellonParam(name=name, label=_name_to_label(name), type=type, display="output") + + @classmethod + def model(cls, name: str) -> "MellonParam": + """model output for diffusers components.""" + return MellonParam(name=name, label=_name_to_label(name), type="diffusers_auto_model", display="output") + + +def input_param_to_mellon_param(input_param: "InputParam") -> MellonParam: + """ + Convert an InputParam to a MellonParam using metadata. - @classmethod - def guider(cls) -> "MellonParam": - """ - Custom guider input. When connected, hides the guidance_scale slider. + Args: + input_param: An InputParam with optional metadata containing either: + - {"mellon": ""} for simple types (image, textbox, slider, etc.) + - {"mellon": MellonParam(...)} for full control over UI configuration - Mellon node definition: - "guider": { - "label": "Guider", "type": "custom_guider", "display": "input", "onChange": {false: ["guidance_scale"], - true: []} - } - """ - return cls( - name="guider", - label="Guider", - type="custom_guider", - display="input", - onChange={False: ["guidance_scale"], True: []}, - ) + Returns: + MellonParam instance + """ + name = input_param.name + metadata = input_param.metadata + mellon_value = metadata.get("mellon") if metadata else None + default = input_param.default + + # If it's already a MellonParam, return it directly + if isinstance(mellon_value, MellonParam): + return mellon_value + + mellon_type = mellon_value + + if mellon_type == "image": + return MellonParam.Input.image(name) + elif mellon_type == "textbox": + return MellonParam.Input.textbox(name, default=default or "") + elif mellon_type == "dropdown": + return MellonParam.Input.dropdown(name, default=default or "") + elif mellon_type == "slider": + return MellonParam.Input.slider(name, default=default or 0) + elif mellon_type == "number": + return MellonParam.Input.number(name, default=default or 0) + elif mellon_type == "seed": + return MellonParam.Input.seed(name, default=default or 0) + elif mellon_type == "checkbox": + return MellonParam.Input.checkbox(name, default=default or False) + elif mellon_type == "model": + return MellonParam.Input.model(name) + else: + # None or unknown -> custom + return MellonParam.Input.custom_type(name, type="custom") + + +def output_param_to_mellon_param(output_param: "OutputParam") -> MellonParam: + """ + Convert an OutputParam to a MellonParam using metadata. - @classmethod - def doc(cls) -> "MellonParam": - """ - Documentation output for inspecting the underlying modular pipeline. + Args: + output_param: An OutputParam with optional metadata={"mellon": ""} where type is one of: + image, video, text, model. If metadata is None or unknown, maps to "custom". - Mellon node definition: - "doc": {"label": "Doc", "type": "string", "display": "output"} - """ - return cls(name="doc", label="Doc", type="string", display="output") + Returns: + MellonParam instance + """ + name = output_param.name + metadata = output_param.metadata + mellon_type = metadata.get("mellon") if metadata else None + + if mellon_type == "image": + return MellonParam.Output.image(name) + elif mellon_type == "video": + return MellonParam.Output.video(name) + elif mellon_type == "text": + return MellonParam.Output.text(name) + elif mellon_type == "model": + return MellonParam.Output.model(name) + else: + # None or unknown -> custom + return MellonParam.Output.custom_type(name, type="custom") DEFAULT_NODE_SPECS = { @@ -804,10 +669,15 @@ def node_spec_to_mellon_dict(node_spec: Dict[str, Any], node_type: str) -> Dict[ params[p.name] = param_dict model_input_names.append(p.name) - # Process outputs + # Process outputs: add a prefix to the output name if it already exists as an input for p in node_spec.get("outputs", []): - params[p.name] = p.to_dict() - output_names.append(p.name) + if p.name in input_names: + # rename to out_ + output_name = f"out_{p.name}" + else: + output_name = p.name + params[output_name] = p.to_dict() + output_names.append(output_name) return { "params": params, @@ -959,7 +829,7 @@ def from_json_file(cls, json_file_path: Union[str, os.PathLike]) -> "MellonPipel return cls.from_dict(data) def save(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """Save the pipeline config to a directory.""" + """Save the mellon pipeline config to a directory.""" if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -975,15 +845,14 @@ def save(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = Fals token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - subfolder = kwargs.pop("subfolder", None) - upload_folder( + upload_file( + path_or_fileobj=output_path, + path_in_repo=self.config_name, repo_id=repo_id, - folder_path=save_directory, token=token, commit_message=commit_message or "Upload MellonPipelineConfig", create_pr=create_pr, - path_in_repo=subfolder, ) logger.info(f"Pipeline config pushed to hub: {repo_id}") @@ -1150,3 +1019,83 @@ def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict default_repo=default_repo, default_dtype=default_dtype, ) + + @classmethod + def from_custom_block( + cls, + block, + node_label: str = None, + input_types: Optional[Dict[str, str]] = None, + output_types: Optional[Dict[str, str]] = None, + ) -> "MellonPipelineConfig": + """ + Create a MellonPipelineConfig from a custom block. + + Args: + block: A block instance with `inputs`, `outputs`, and `expected_components`/`component_names` properties. + Each InputParam/OutputParam should have metadata={"mellon": ""} where type is one of: image, + video, text, checkbox, number, slider, dropdown, model. If metadata is None, maps to "custom". + node_label: The display label for the node. Defaults to block class name with spaces. + input_types: + Optional dict mapping input param names to mellon types. Overrides the block's metadata if provided. + Example: {"prompt": "textbox", "image": "image"} + output_types: + Optional dict mapping output param names to mellon types. Overrides the block's metadata if provided. + Example: {"prompt": "text", "images": "image"} + + Returns: + MellonPipelineConfig instance + """ + if node_label is None: + class_name = block.__class__.__name__ + node_label = "".join([" " + c if c.isupper() else c for c in class_name]).strip() + + if input_types is None: + input_types = {} + if output_types is None: + output_types = {} + + inputs = [] + model_inputs = [] + outputs = [] + + # Process block inputs + for input_param in block.inputs: + if input_param.name is None: + continue + if input_param.name in input_types: + input_param = copy.copy(input_param) + input_param.metadata = {"mellon": input_types[input_param.name]} + print(f" processing input: {input_param.name}, metadata: {input_param.metadata}") + inputs.append(input_param_to_mellon_param(input_param)) + + # Process block outputs + for output_param in block.outputs: + if output_param.name is None: + continue + if output_param.name in output_types: + output_param = copy.copy(output_param) + output_param.metadata = {"mellon": output_types[output_param.name]} + outputs.append(output_param_to_mellon_param(output_param)) + + # Process expected components (all map to model inputs) + component_names = block.component_names + for component_name in component_names: + model_inputs.append(MellonParam.Input.model(component_name)) + + # Always add doc output + outputs.append(MellonParam.doc()) + + node_spec = { + "inputs": inputs, + "model_inputs": model_inputs, + "outputs": outputs, + "required_inputs": [], + "required_model_inputs": [], + "block_name": "custom", + } + + return cls( + node_specs={"custom": node_spec}, + label=node_label, + ) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 98ede73c21fe..f0722ffe99f8 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -34,6 +34,7 @@ from ..utils.hub_utils import load_or_create_model_card, populate_model_card from .components_manager import ComponentsManager from .modular_pipeline_utils import ( + MODULAR_MODEL_CARD_TEMPLATE, ComponentSpec, ConfigSpec, InputParam, @@ -41,6 +42,7 @@ OutputParam, format_components, format_configs, + generate_modular_model_card_content, make_doc_string, ) @@ -52,19 +54,61 @@ # map regular pipeline to modular pipeline class name + + +def _create_default_map_fn(pipeline_class_name: str): + """Create a mapping function that always returns the same pipeline class.""" + + def _map_fn(config_dict=None): + return pipeline_class_name + + return _map_fn + + +def _flux2_klein_map_fn(config_dict=None): + if config_dict is None: + return "Flux2KleinModularPipeline" + + if "is_distilled" in config_dict and config_dict["is_distilled"]: + return "Flux2KleinModularPipeline" + else: + return "Flux2KleinBaseModularPipeline" + + +def _wan_map_fn(config_dict=None): + if config_dict is None: + return "WanModularPipeline" + + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22ModularPipeline" + else: + return "WanModularPipeline" + + +def _wan_i2v_map_fn(config_dict=None): + if config_dict is None: + return "WanImage2VideoModularPipeline" + + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22Image2VideoModularPipeline" + else: + return "WanImage2VideoModularPipeline" + + MODULAR_PIPELINE_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), - ("wan", "WanModularPipeline"), - ("flux", "FluxModularPipeline"), - ("flux-kontext", "FluxKontextModularPipeline"), - ("flux2", "Flux2ModularPipeline"), - ("flux2-klein", "Flux2KleinModularPipeline"), - ("qwenimage", "QwenImageModularPipeline"), - ("qwenimage-edit", "QwenImageEditModularPipeline"), - ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), - ("qwenimage-layered", "QwenImageLayeredModularPipeline"), - ("z-image", "ZImageModularPipeline"), + ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), + ("wan", _wan_map_fn), + ("wan-i2v", _wan_i2v_map_fn), + ("flux", _create_default_map_fn("FluxModularPipeline")), + ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), + ("flux2", _create_default_map_fn("Flux2ModularPipeline")), + ("flux2-klein", _flux2_klein_map_fn), + ("qwenimage", _create_default_map_fn("QwenImageModularPipeline")), + ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), + ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), + ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("z-image", _create_default_map_fn("ZImageModularPipeline")), ] ) @@ -366,7 +410,8 @@ def init_pipeline( """ create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub. """ - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) + map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline")) + pipeline_class_name = map_fn() diffusers_module = importlib.import_module("diffusers") pipeline_class = getattr(diffusers_module, pipeline_class_name) @@ -1545,7 +1590,7 @@ def __init__( if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") else: - blocks_class_name = self.get_default_blocks_name(config_dict) + blocks_class_name = self.default_blocks_name if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") blocks_class = getattr(diffusers_module, blocks_class_name) @@ -1553,11 +1598,11 @@ def __init__( else: logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") - self.blocks = blocks + self._blocks = blocks self._components_manager = components_manager self._collection = collection - self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} - self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} + self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs} # update component_specs and config_specs based on modular_model_index.json if modular_config_dict is not None: @@ -1604,7 +1649,9 @@ def __init__( for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) + self.register_to_config( + _blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None + ) @property def default_call_parameters(self) -> Dict[str, Any]: @@ -1613,13 +1660,10 @@ def default_call_parameters(self) -> Dict[str, Any]: - Dictionary mapping input names to their default values """ params = {} - for input_param in self.blocks.inputs: + for input_param in self._blocks.inputs: params[input_param.name] = input_param.default return params - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - return self.default_blocks_name - @classmethod def _load_pipeline_config( cls, @@ -1715,7 +1759,8 @@ def from_pretrained( logger.debug(" try to determine the modular pipeline class from model_index.json") standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) model_name = _get_model(standard_pipeline_class.__name__) - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) + map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline")) + pipeline_class_name = map_fn(config_dict) diffusers_module = importlib.import_module("diffusers") pipeline_class = getattr(diffusers_module, pipeline_class_name) else: @@ -1753,9 +1798,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + # Generate modular pipeline card content + card_content = generate_modular_model_card_content(self.blocks) + # Create a new empty model card and eventually tag it - model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) - model_card = populate_model_card(model_card) + model_card = load_or_create_model_card( + repo_id, + token=token, + is_pipeline=True, + model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content), + is_modular=True, + ) + model_card = populate_model_card(model_card, tags=card_content["tags"]) + model_card.save(os.path.join(save_directory, "README.md")) # YiYi TODO: maybe order the json file to make it more readable: configs first, then components @@ -1776,7 +1831,15 @@ def doc(self): Returns: - The docstring of the pipeline blocks """ - return self.blocks.doc + return self._blocks.doc + + @property + def blocks(self) -> ModularPipelineBlocks: + """ + Returns: + - A copy of the pipeline blocks + """ + return deepcopy(self._blocks) def register_components(self, **kwargs): """ @@ -2143,6 +2206,8 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained" + and self._component_specs[name].pretrained_model_name_or_path is not None + and getattr(self, name, None) is None ] elif isinstance(names, str): names = [names] @@ -2510,7 +2575,7 @@ def _dict_to_component_spec( ) def set_progress_bar_config(self, **kwargs): - for sub_block_name, sub_block in self.blocks.sub_blocks.items(): + for sub_block_name, sub_block in self._blocks.sub_blocks.items(): if hasattr(sub_block, "set_progress_bar_config"): sub_block.set_progress_bar_config(**kwargs) @@ -2564,7 +2629,7 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs - for expected_input_param in self.blocks.inputs: + for expected_input_param in self._blocks.inputs: name = expected_input_param.name default = expected_input_param.default kwargs_type = expected_input_param.kwargs_type @@ -2583,9 +2648,9 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Run the pipeline with torch.no_grad(): try: - _, state = self.blocks(self, state) + _, state = self._blocks(self, state) except Exception: - error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n" logger.error(error_msg) raise diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index f3b12d716160..9e11fb7ef79b 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -15,7 +15,7 @@ import inspect import re from collections import OrderedDict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Type, Union import PIL.Image @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, FrozenDict from ..loaders.single_file_utils import _is_single_file_path_or_url -from ..utils import is_torch_available, logging +from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging if is_torch_available(): @@ -31,6 +31,30 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Template for modular pipeline model card description with placeholders +MODULAR_MODEL_CARD_TEMPLATE = """{model_description} + +## Example Usage + +[TODO] + +## Pipeline Architecture + +This modular pipeline is composed of the following blocks: + +{blocks_description} {trigger_inputs_section} + +## Model Components + +{components_description} {configs_section} + +## Input/Output Specification + +### Inputs {inputs_description} + +### Outputs {outputs_description} +""" + class InsertableDict(OrderedDict): def insert(self, key, value, index): @@ -186,7 +210,7 @@ def loading_fields(cls) -> List[str]: """ Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True). """ - return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + return DIFFUSERS_LOAD_ID_FIELDS.copy() @property def load_id(self) -> str: @@ -198,7 +222,7 @@ def load_id(self) -> str: return "null" parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] - return "|".join(p for p in parts if p) + return "|".join(parts) @classmethod def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: @@ -520,6 +544,7 @@ class InputParam: required: bool = False description: str = "" kwargs_type: str = None + metadata: Dict[str, Any] = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -553,6 +578,7 @@ class OutputParam: type_hint: Any = None description: str = "" kwargs_type: str = None + metadata: Dict[str, Any] = None def __repr__(self): return ( @@ -914,3 +940,178 @@ def make_doc_string( output += format_output_params(outputs, indent_level=2) return output + + +def generate_modular_model_card_content(blocks) -> Dict[str, Any]: + """ + Generate model card content for a modular pipeline. + + This function creates a comprehensive model card with descriptions of the pipeline's architecture, components, + configurations, inputs, and outputs. + + Args: + blocks: The pipeline's blocks object containing all pipeline specifications + + Returns: + Dict[str, Any]: A dictionary containing formatted content sections: + - pipeline_name: Name of the pipeline + - model_description: Overall description with pipeline type + - blocks_description: Detailed architecture of blocks + - components_description: List of required components + - configs_section: Configuration parameters section + - inputs_description: Input parameters specification + - outputs_description: Output parameters specification + - trigger_inputs_section: Conditional execution information + - tags: List of relevant tags for the model card + """ + blocks_class_name = blocks.__class__.__name__ + pipeline_name = blocks_class_name.replace("Blocks", " Pipeline") + description = getattr(blocks, "description", "A modular diffusion pipeline.") + + # generate blocks architecture description + blocks_desc_parts = [] + sub_blocks = getattr(blocks, "sub_blocks", None) or {} + if sub_blocks: + for i, (name, block) in enumerate(sub_blocks.items()): + block_class = block.__class__.__name__ + block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else "" + blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)") + if block_desc: + blocks_desc_parts.append(f" - {block_desc}") + + # add sub-blocks if any + if hasattr(block, "sub_blocks") and block.sub_blocks: + for sub_name, sub_block in block.sub_blocks.items(): + sub_class = sub_block.__class__.__name__ + sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else "" + blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`") + if sub_desc: + blocks_desc_parts.append(f" - {sub_desc}") + + blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined." + + components = getattr(blocks, "expected_components", []) + if components: + components_str = format_components(components, indent_level=0, add_empty_lines=False) + # remove the "Components:" header since template has its own + components_description = components_str.replace("Components:\n", "").strip() + if components_description: + # Convert to enumerated list + lines = [line.strip() for line in components_description.split("\n") if line.strip()] + enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)] + components_description = "\n".join(enumerated_lines) + else: + components_description = "No specific components required." + else: + components_description = "No specific components required. Components can be loaded dynamically." + + configs = getattr(blocks, "expected_configs", []) + configs_section = "" + if configs: + configs_str = format_configs(configs, indent_level=0, add_empty_lines=False) + configs_description = configs_str.replace("Configs:\n", "").strip() + if configs_description: + configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}" + + inputs = blocks.inputs + outputs = blocks.outputs + + # format inputs as markdown list + inputs_parts = [] + required_inputs = [inp for inp in inputs if inp.required] + optional_inputs = [inp for inp in inputs if not inp.required] + + if required_inputs: + inputs_parts.append("**Required:**\n") + for inp in required_inputs: + if hasattr(inp.type_hint, "__name__"): + type_str = inp.type_hint.__name__ + elif inp.type_hint is not None: + type_str = str(inp.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = inp.description or "No description provided" + inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}") + + if optional_inputs: + if required_inputs: + inputs_parts.append("") + inputs_parts.append("**Optional:**\n") + for inp in optional_inputs: + if hasattr(inp.type_hint, "__name__"): + type_str = inp.type_hint.__name__ + elif inp.type_hint is not None: + type_str = str(inp.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = inp.description or "No description provided" + default_str = f", default: `{inp.default}`" if inp.default is not None else "" + inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}") + + inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined." + + # format outputs as markdown list + outputs_parts = [] + for out in outputs: + if hasattr(out.type_hint, "__name__"): + type_str = out.type_hint.__name__ + elif out.type_hint is not None: + type_str = str(out.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = out.description or "No description provided" + outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}") + + outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs." + + trigger_inputs_section = "" + if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None]) + if trigger_inputs_list: + trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list) + trigger_inputs_section = f""" +### Conditional Execution + +This pipeline contains blocks that are selected at runtime based on inputs: +- **Trigger Inputs**: {trigger_inputs_str} +""" + + # generate tags based on pipeline characteristics + tags = ["modular-diffusers", "diffusers"] + + if hasattr(blocks, "model_name") and blocks.model_name: + tags.append(blocks.model_name) + + if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + triggers = blocks.trigger_inputs + if any(t in triggers for t in ["mask", "mask_image"]): + tags.append("inpainting") + if any(t in triggers for t in ["image", "image_latents"]): + tags.append("image-to-image") + if any(t in triggers for t in ["control_image", "controlnet_cond"]): + tags.append("controlnet") + if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]): + tags.append("text-to-image") + else: + tags.append("text-to-image") + + block_count = len(blocks.sub_blocks) + model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework. + +**Pipeline Type**: {blocks_class_name} + +**Description**: {description} + +This pipeline uses a {block_count}-block architecture that can be customized and extended.""" + + return { + "pipeline_name": pipeline_name, + "model_description": model_description, + "blocks_description": blocks_description, + "components_description": components_description, + "configs_section": configs_section, + "inputs_description": inputs_description, + "outputs_description": outputs_description, + "trigger_inputs_section": trigger_inputs_section, + "tags": tags, + } diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py index 73f67c9afed2..284b6c9fa436 100644 --- a/src/diffusers/modular_pipelines/wan/__init__.py +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -21,16 +21,16 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["decoders"] = ["WanImageVaeDecoderStep"] - _import_structure["encoders"] = ["WanTextEncoderStep"] - _import_structure["modular_blocks"] = [ - "ALL_BLOCKS", - "Wan22AutoBlocks", - "WanAutoBlocks", - "WanAutoImageEncoderStep", - "WanAutoVaeImageEncoderStep", + _import_structure["modular_blocks_wan"] = ["WanBlocks"] + _import_structure["modular_blocks_wan22"] = ["Wan22Blocks"] + _import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"] + _import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "WanModularPipeline", ] - _import_structure["modular_pipeline"] = ["WanModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -39,16 +39,16 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .decoders import WanImageVaeDecoderStep - from .encoders import WanTextEncoderStep - from .modular_blocks import ( - ALL_BLOCKS, - Wan22AutoBlocks, - WanAutoBlocks, - WanAutoImageEncoderStep, - WanAutoVaeImageEncoderStep, + from .modular_blocks_wan import WanBlocks + from .modular_blocks_wan22 import Wan22Blocks + from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks + from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks + from .modular_pipeline import ( + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanImage2VideoModularPipeline, + WanModularPipeline, ) - from .modular_pipeline import WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index e2f8d3e7d88b..719ba4c21148 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks): def __init__( self, - image_latent_inputs: List[str] = ["first_frame_latents"], + image_latent_inputs: List[str] = ["image_condition_latents"], additional_batch_inputs: List[str] = [], ): """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" @@ -294,20 +294,16 @@ def __init__( Args: image_latent_inputs (List[str], optional): Names of image latent tensors to process. In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be - a single string or list of strings. Defaults to ["first_frame_latents"]. + a single string or list of strings. Defaults to ["image_condition_latents"]. additional_batch_inputs (List[str], optional): Names of additional conditional input tensors to expand batch size. These tensors will only have their batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. Defaults to []. Examples: - # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep() - - # Configure to process multiple image latent inputs - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"]) - - # Configure to process image latents and additional batch inputs WanAdditionalInputsStep( - image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"] + # Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to + process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"] ) """ if not isinstance(image_latent_inputs, list): @@ -557,81 +553,3 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state - - -class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return "step that prepares the masked first frame latents and add it to the latent condition" - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), - InputParam("num_frames", type_hint=int), - ] - - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape - - mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 - - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave( - first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal - ) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view( - batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width - ) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) - block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) - - self.set_block_state(state, block_state) - return components, state - - -class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return "step that prepares the masked latents with first and last frames and add it to the latent condition" - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), - InputParam("num_frames", type_hint=int), - ] - - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape - - mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 - - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave( - first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal - ) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view( - batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width - ) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) - block_state.first_last_frame_latents = torch.concat( - [mask_lat_size, block_state.first_last_frame_latents], dim=1 - ) - - self.set_block_state(state, block_state) - return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 7cec318c1706..c26a8b11ba5c 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanImageVaeDecoderStep(ModularPipelineBlocks): +class WanVaeDecoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -56,7 +56,10 @@ def inputs(self) -> List[Tuple[str, Any]]: required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step", - ) + ), + InputParam( + "output_type", default="np", type_hint=str, description="The output type of the decoded videos" + ), ] @property @@ -87,7 +90,8 @@ def __call__(self, components, state: PipelineState) -> PipelineState: latents = latents.to(vae_dtype) block_state.videos = components.vae.decode(latents, return_dict=False)[0] - block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np") + output_type = getattr(block_state, "output_type", "np") + block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 2da36f52da87..7f44b0230d78 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -89,52 +89,10 @@ def inputs(self) -> List[InputParam]: description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( - "first_frame_latents", + "image_condition_latents", required=True, type_hint=torch.Tensor, - description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", - ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs. Can be generated in input step.", - ), - ] - - @torch.no_grad() - def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to( - block_state.dtype - ) - return components, block_state - - -class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return ( - "step within the denoising loop that prepares the latent input for the denoiser. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " - "object (e.g. `WanDenoiseLoopWrapper`)" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ), - InputParam( - "first_last_frame_latents", - required=True, - type_hint=torch.Tensor, - description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.", + description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.", ), InputParam( "dtype", @@ -147,7 +105,7 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): block_state.latent_model_input = torch.cat( - [block_state.latents, block_state.first_last_frame_latents], dim=1 + [block_state.latents, block_state.image_condition_latents], dim=1 ).to(block_state.dtype) return components, block_state @@ -584,29 +542,3 @@ def description(self) -> str: " - `WanLoopAfterDenoiser`\n" "This block supports image-to-video tasks for Wan2.2." ) - - -class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper): - block_classes = [ - WanFLF2VLoopBeforeDenoiser, - WanLoopDenoiser( - guider_input_fields={ - "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), - "encoder_hidden_states_image": "image_embeds", - } - ), - WanLoopAfterDenoiser, - ] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. \n" - "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" - " - `WanFLF2VLoopBeforeDenoiser`\n" - " - `WanLoopDenoiser`\n" - " - `WanLoopAfterDenoiser`\n" - "This block supports FLF2V tasks for wan2.1." - ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 4fd69c6ca6ab..22b62a601d34 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -468,7 +468,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanVaeImageEncoderStep(ModularPipelineBlocks): +class WanVaeEncoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -493,7 +493,7 @@ def inputs(self) -> List[InputParam]: InputParam("resized_image", type_hint=PIL.Image.Image, required=True), InputParam("height"), InputParam("width"), - InputParam("num_frames"), + InputParam("num_frames", type_hint=int, default=81), InputParam("generator"), ] @@ -564,7 +564,51 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked first frame latents and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -590,7 +634,7 @@ def inputs(self) -> List[InputParam]: InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), InputParam("height"), InputParam("width"), - InputParam("num_frames"), + InputParam("num_frames", type_hint=int, default=81), InputParam("generator"), ] @@ -667,3 +711,49 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked latents with first and last frames and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int, required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.image_condition_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py deleted file mode 100644 index 905111bcf42d..000000000000 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ /dev/null @@ -1,474 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict -from .before_denoise import ( - WanAdditionalInputsStep, - WanPrepareFirstFrameLatentsStep, - WanPrepareFirstLastFrameLatentsStep, - WanPrepareLatentsStep, - WanSetTimestepsStep, - WanTextInputStep, -) -from .decoders import WanImageVaeDecoderStep -from .denoise import ( - Wan22DenoiseStep, - Wan22Image2VideoDenoiseStep, - WanDenoiseStep, - WanFLF2VDenoiseStep, - WanImage2VideoDenoiseStep, -) -from .encoders import ( - WanFirstLastFrameImageEncoderStep, - WanFirstLastFrameVaeImageEncoderStep, - WanImageCropResizeStep, - WanImageEncoderStep, - WanImageResizeStep, - WanTextEncoderStep, - WanVaeImageEncoderStep, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# wan2.1 -# wan2.1: text2vid -class WanCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanDenoiseStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] - - @property - def description(self): - return ( - "denoise block that takes encoded conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: image2video -## image encoder -class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageEncoderStep] - block_names = ["image_resize", "image_encoder"] - - @property - def description(self): - return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" - - -## vae encoder -class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] - block_names = ["image_resize", "vae_encoder"] - - @property - def description(self): - return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" - - -## denoise -class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstFrameLatentsStep, - WanImage2VideoDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" - + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: FLF2v - - -## image encoder -class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] - block_names = ["image_resize", "last_image_resize", "image_encoder"] - - @property - def description(self): - return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" - - -## vae encoder -class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] - block_names = ["image_resize", "last_image_resize", "vae_encoder"] - - @property - def description(self): - return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" - - -## denoise -class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstLastFrameLatentsStep, - WanFLF2VDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_last_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n" - + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: auto blocks -## image encoder -class WanAutoImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] - block_names = ["flf2v_image_encoder", "image2video_image_encoder"] - block_trigger_inputs = ["last_image", "image"] - - @property - def description(self): - return ( - "Image Encoder step that encode the image to generate the image embeddings" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped." - ) - - -## vae encoder -class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] - block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] - block_trigger_inputs = ["last_image", "image"] - - @property - def description(self): - return ( - "Vae Image Encoder step that encode the image to generate the image latents" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped." - ) - - -## denoise -class WanAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - WanFLF2VCoreDenoiseStep, - WanImage2VideoCoreDenoiseStep, - WanCoreDenoiseStep, - ] - block_names = ["flf2v", "image2video", "text2video"] - block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2video and image2video tasks." - " - `WanCoreDenoiseStep` (text2video) for text2vid tasks." - " - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks." - + " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n" - + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n" - ) - - -# auto pipeline blocks -class WanAutoBlocks(SequentialPipelineBlocks): - block_classes = [ - WanTextEncoderStep, - WanAutoImageEncoderStep, - WanAutoVaeImageEncoderStep, - WanAutoDenoiseStep, - WanImageVaeDecoderStep, - ] - block_names = [ - "text_encoder", - "image_encoder", - "vae_encoder", - "denoise", - "decode", - ] - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-video using Wan.\n" - + "- for text-to-video generation, all you need to provide is `prompt`" - ) - - -# wan22 -# wan2.2: text2vid - - -## denoise -class Wan22CoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanSetTimestepsStep, - WanPrepareLatentsStep, - Wan22DenoiseStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] - - @property - def description(self): - return ( - "denoise block that takes encoded conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" - ) - - -# wan2.2: image2video -## denoise -class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstFrameLatentsStep, - Wan22Image2VideoDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" - + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" - ) - - -class Wan22AutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - Wan22Image2VideoCoreDenoiseStep, - Wan22CoreDenoiseStep, - ] - block_names = ["image2video", "text2video"] - block_trigger_inputs = ["first_frame_latents", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2video and image2video tasks." - " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks." - " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks." - + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n" - + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n" - ) - - -class Wan22AutoBlocks(SequentialPipelineBlocks): - block_classes = [ - WanTextEncoderStep, - WanAutoVaeImageEncoderStep, - Wan22AutoDenoiseStep, - WanImageVaeDecoderStep, - ] - block_names = [ - "text_encoder", - "vae_encoder", - "denoise", - "decode", - ] - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-video using Wan2.2.\n" - + "- for text-to-video generation, all you need to provide is `prompt`" - ) - - -# presets for wan2.1 and wan2.2 -# YiYi Notes: should we move these to doc? -# wan2.1 -TEXT2VIDEO_BLOCKS = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", WanDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -IMAGE2VIDEO_BLOCKS = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("image_encoder", WanImage2VideoImageEncoderStep), - ("vae_encoder", WanImage2VideoVaeImageEncoderStep), - ("input", WanTextInputStep), - ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep), - ("denoise", WanImage2VideoDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - - -FLF2V_BLOCKS = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("last_image_resize", WanImageCropResizeStep), - ("image_encoder", WanFLF2VImageEncoderStep), - ("vae_encoder", WanFLF2VVaeImageEncoderStep), - ("input", WanTextInputStep), - ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep), - ("denoise", WanFLF2VDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("image_encoder", WanAutoImageEncoderStep), - ("vae_encoder", WanAutoVaeImageEncoderStep), - ("denoise", WanAutoDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -# wan2.2 presets - -TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", Wan22DenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("vae_encoder", WanImage2VideoVaeImageEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", Wan22DenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -AUTO_BLOCKS_WAN22 = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("vae_encoder", WanAutoVaeImageEncoderStep), - ("denoise", Wan22AutoDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -# presets all blocks (wan and wan22) - - -ALL_BLOCKS = { - "wan2.1": { - "text2video": TEXT2VIDEO_BLOCKS, - "image2video": IMAGE2VIDEO_BLOCKS, - "flf2v": FLF2V_BLOCKS, - "auto": AUTO_BLOCKS, - }, - "wan2.2": { - "text2video": TEXT2VIDEO_BLOCKS_WAN22, - "image2video": IMAGE2VIDEO_BLOCKS_WAN22, - "auto": AUTO_BLOCKS_WAN22, - }, -} diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py new file mode 100644 index 000000000000..d01a86ca09b5 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py @@ -0,0 +1,83 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanDenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise +class WanCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanDenoiseStep` is used to denoise the latents\n" + ) + + +# ==================== +# 2. BLOCKS (Wan2.1 text2video) +# ==================== + + +class WanBlocks(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + WanCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Modular pipeline blocks for Wan2.1.\n" + + "- `WanTextEncoderStep` is used to encode the text\n" + + "- `WanCoreDenoiseStep` is used to denoise the latents\n" + + "- `WanVaeDecoderStep` is used to decode the latents to images" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py new file mode 100644 index 000000000000..21164422f3d9 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22DenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise + + +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# ==================== +# 2. BLOCKS (Wan2.2 text2video) +# ==================== + + +class Wan22Blocks(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + Wan22CoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Modular pipeline for text-to-video using Wan2.2.\n" + + " - `WanTextEncoderStep` encodes the text\n" + + " - `Wan22CoreDenoiseStep` denoes the latents\n" + + " - `WanVaeDecoderStep` decodes the latents to video frames\n" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py new file mode 100644 index 000000000000..3db1c8fa837b --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py @@ -0,0 +1,117 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22Image2VideoDenoiseStep, +) +from .encoders import ( + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. VAE ENCODER +# ==================== + + +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# ==================== +# 2. DENOISE +# ==================== + + +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# ==================== +# 3. BLOCKS (Wan2.2 Image2Video) +# ==================== + + +class Wan22Image2VideoBlocks(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanImage2VideoVaeEncoderStep, + Wan22Image2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Modular pipeline for image-to-video using Wan2.2.\n" + + " - `WanTextEncoderStep` encodes the text\n" + + " - `WanImage2VideoVaeEncoderStep` encodes the image\n" + + " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n" + + " - `WanVaeDecoderStep` decodes the latents to video frames\n" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py new file mode 100644 index 000000000000..d07ab8ecf473 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py @@ -0,0 +1,203 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# ==================== +# 1. IMAGE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageEncoderStep] + block_names = ["image_resize", "image_encoder"] + + @property + def description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + +# wan2.1 FLF2V (first and last frame) +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" + + +# wan2.1 Auto Image Encoder +class WanAutoImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + model_name = "wan-i2v" + + @property + def description(self): + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# wan2.1 FLF2V (first and last frame) +class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanImageResizeStep, + WanImageCropResizeStep, + WanFirstLastFrameVaeEncoderStep, + WanPrepareFirstLastFrameLatentsStep, + ] + block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" + + +# wan2.1 Auto Vae Encoder +class WanAutoVaeEncoderStep(AutoPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise) +# ==================== + + +# wan2.1 I2V core denoise (support both I2V and FLF2V) +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanImage2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" + ) + + +# ==================== +# 4. BLOCKS (Wan2.1 Image2Video) +# ==================== + + +# wan2.1 Image2Video Auto Blocks +class WanImage2VideoAutoBlocks(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanAutoImageEncoderStep, + WanAutoVaeEncoderStep, + WanImage2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "image_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for image-to-video using Wan.\n" + + "- for I2V workflow, all you need to provide is `image`" + + "- for FLF2V workflow, all you need to provide is `last_image` and `image`" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index 930b25e4b905..0e52026a51bf 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -13,8 +13,6 @@ # limitations under the License. -from typing import Any, Dict, Optional - from ...loaders import WanLoraLoaderMixin from ...pipelines.pipeline_utils import StableDiffusionMixin from ...utils import logging @@ -30,19 +28,12 @@ class WanModularPipeline( WanLoraLoaderMixin, ): """ - A ModularPipeline for Wan. + A ModularPipeline for Wan2.1 text2video. > [!WARNING] > This is an experimental feature and is likely to change in the future. """ - default_blocks_name = "WanAutoBlocks" - - # override the default_blocks_name in base class, which is just return self.default_blocks_name - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: - return "Wan22AutoBlocks" - else: - return "WanAutoBlocks" + default_blocks_name = "WanBlocks" @property def default_height(self): @@ -118,3 +109,33 @@ def num_train_timesteps(self): if hasattr(self, "scheduler") and self.scheduler is not None: num_train_timesteps = self.scheduler.config.num_train_timesteps return num_train_timesteps + + +class WanImage2VideoModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "WanImage2VideoAutoBlocks" + + +class Wan22ModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.2 text2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Blocks" + + +class Wan22Image2VideoModularPipeline(Wan22ModularPipeline): + """ + A ModularPipeline for Wan2.2 image2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Image2VideoBlocks" diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 65378631a172..cfa1f8d92558 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -167,6 +167,7 @@ _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["cosmos"] = [ "Cosmos2_5_PredictBasePipeline", + "Cosmos2_5_TransferPipeline", "Cosmos2TextToImagePipeline", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", @@ -410,11 +411,12 @@ "Kandinsky5I2IPipeline", ] _import_structure["z_image"] = [ - "ZImageImg2ImgPipeline", - "ZImagePipeline", - "ZImageControlNetPipeline", "ZImageControlNetInpaintPipeline", + "ZImageControlNetPipeline", + "ZImageImg2ImgPipeline", + "ZImageInpaintPipeline", "ZImageOmniPipeline", + "ZImagePipeline", ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", @@ -630,6 +632,7 @@ ) from .cosmos import ( Cosmos2_5_PredictBasePipeline, + Cosmos2_5_TransferPipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, @@ -870,6 +873,7 @@ ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, ) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 5ee44190e23b..be1d7ea5a54e 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -127,6 +127,7 @@ ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, ) @@ -235,6 +236,7 @@ ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), ("qwenimage", QwenImageInpaintPipeline), ("qwenimage-edit", QwenImageEditInpaintPipeline), + ("z-image", ZImageInpaintPipeline), ] ) @@ -246,7 +248,7 @@ AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict( [ - ("wan", WanImageToVideoPipeline), + ("wan-i2v", WanImageToVideoPipeline), ] ) diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 944f16553173..5fc66cdf84b6 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -25,6 +25,9 @@ _import_structure["pipeline_cosmos2_5_predict"] = [ "Cosmos2_5_PredictBasePipeline", ] + _import_structure["pipeline_cosmos2_5_transfer"] = [ + "Cosmos2_5_TransferPipeline", + ] _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] @@ -41,6 +44,7 @@ from .pipeline_cosmos2_5_predict import ( Cosmos2_5_PredictBasePipeline, ) + from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py new file mode 100644 index 000000000000..13f583e8df8a --- /dev/null +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py @@ -0,0 +1,923 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torchvision +import torchvision.transforms +import torchvision.transforms.functional +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _maybe_pad_video(video: torch.Tensor, num_frames: int): + n_pad_frames = num_frames - video.shape[2] + if n_pad_frames > 0: + last_frame = video[:, :, -1:, :, :] + video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2) + return video + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import cv2 + >>> import numpy as np + >>> import torch + >>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "nvidia/Cosmos-Transfer2.5-2B" + >>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur) + >>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge") + >>> pipe = Cosmos2_5_TransferPipeline.from_pretrained( + ... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Video2World with edge control: Generate video guided by edge maps extracted from input video. + >>> prompt = ( + ... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It" + ... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige" + ... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are" + ... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is" + ... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the" + ... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper" + ... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist." + ... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm" + ... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in" + ... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision," + ... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing" + ... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed" + ... "movements and coordination involved in the task." + ... ) + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4" + ... ) + >>> num_frames = 93 + + >>> # Extract edge maps from the input video using Canny edge detection + >>> edge_maps = [ + ... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200) + ... for frame in input_video[:num_frames] + ... ] + >>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W) + >>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W) + >>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)] + >>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30) + + >>> video = pipe( + ... video=input_video[:num_frames], + ... controls=controls, + ... controls_conditioning_scale=1.0, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=num_frames, + ... ).frames[0] + >>> export_to_video(video, "edge_controlled_video.mp4", fps=30) + ``` +""" + + +class Cosmos2_5_TransferPipeline(DiffusionPipeline): + r""" + Pipeline for Cosmos Transfer2.5 base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker", "controlnet"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + controlnet: Optional[CosmosControlNetModel], + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + def _encode_controls( + self, + controls: Optional[torch.Tensor], + height: int, + width: int, + num_frames: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> Optional[torch.Tensor]: + if controls is None: + return None + + control_video = self.video_processor.preprocess_video(controls, height, width) + control_video = _maybe_pad_video(control_video, num_frames) + + control_video = control_video.to(device=device, dtype=self.vae.dtype) + control_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video + ] + control_latents = torch.cat(control_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + control_latents = (control_latents - latents_mean) / latents_std + return control_latents + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + video: List[PipelineImageInput] | None = None, + prompt: Union[str, List[str]] | None = None, + negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT, + height: int = 704, + width: Optional[int] = None, + num_frames: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 3.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None, + controls_conditioning_scale: Union[float, List[float]] = 1.0, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + ): + r""" + The call function to the pipeline for generation. Supports three modes: + + - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. + - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. + - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. + + Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the + above in "*2Image mode"). + + Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional single image for Image2World conditioning. Must be `None` when `video` is provided. + video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional input video for Video2World conditioning. Must be `None` when `image` is provided. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. If not provided, this will be determined based on the + aspect ratio of the input and the provided height. + num_frames (`int`, defaults to `93`): + Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*): + Control image or video input used by the ControlNet. If `None`, ControlNet is skipped. + controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`): + The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if width is None: + frame = image or video[0] if image or video else None + if frame is None and controls is not None: + frame = controls[0] if isinstance(controls, list) else controls + if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4: + frame = controls[0] + + if frame is None: + width = int((height + 16) * (1280 / 720)) + elif isinstance(frame, PIL.Image.Image): + width = int((height + 16) * (frame.width / frame.height)) + else: + width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + img_context = torch.zeros( + batch_size, + self.transformer.config.img_context_num_tokens, + self.transformer.config.img_context_dim_in, + device=prompt_embeds.device, + dtype=transformer_dtype, + ) + encoder_hidden_states = (prompt_embeds, img_context) + neg_encoder_hidden_states = (negative_prompt_embeds, img_context) + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + num_frames_in = len(video) + + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # pad with last frame (for video2world) + num_frames_out = num_frames + video = _maybe_pad_video(video, num_frames_out) + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + controls_latents = None + if controls is not None: + controls_latents = self._encode_controls( + controls, + height=height, + width=width, + num_frames=num_frames, + dtype=transformer_dtype, + device=device, + generator=generator, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + control_blocks = None + if controls_latents is not None and self.controlnet is not None: + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + control_blocks = None + if controls_latents is not None and self.controlnet is not None: + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred_neg = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + if vid is None: + video_batch.append(np.zeros_like(video[0])) + else: + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) + + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py index 0bc7a59db228..347601422c83 100644 --- a/src/diffusers/pipelines/ltx2/export_utils.py +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -13,12 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator from fractions import Fraction -from typing import Optional +from itertools import chain +from typing import List, Optional, Union +import numpy as np +import PIL.Image import torch +from tqdm import tqdm -from ...utils import is_av_available +from ...utils import get_logger, is_av_available + + +logger = get_logger(__name__) # pylint: disable=invalid-name _CAN_USE_AV = is_av_available() @@ -101,11 +109,59 @@ def _write_audio( def encode_video( - video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str + video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]], + fps: int, + audio: Optional[torch.Tensor], + audio_sample_rate: Optional[int], + output_path: str, + video_chunks_number: int = 1, ) -> None: - video_np = video.cpu().numpy() - - _, height, width, _ = video_np.shape + """ + Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo: + https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182 + + Args: + video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the + input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines + usually return with `output_type="np"`). + fps (`int`) + The frames per second (FPS) of the encoded video. + audio (`torch.Tensor`, *optional*): + An audio waveform of shape [audio_channels, samples]. + audio_sample_rate: (`int`, *optional*): + The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz). + output_path (`str`): + The path to save the encoded video to. + video_chunks_number (`int`, *optional*, defaults to `1`): + The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The + number of chunks to use often depends on the tiling config for the video VAE. + """ + if isinstance(video, list) and isinstance(video[0], PIL.Image.Image): + # Pipeline output_type="pil"; assumes each image is in "RGB" mode + video_frames = [np.array(frame) for frame in video] + video = np.stack(video_frames, axis=0) + video = torch.from_numpy(video) + elif isinstance(video, np.ndarray): + # Pipeline output_type="np" + is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video)) + if np.all(is_denormalized): + video = (video * 255).round().astype("uint8") + else: + logger.warning( + "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel " + "values in [0, ..., 255] and will be used as is." + ) + video = torch.from_numpy(video) + + if isinstance(video, torch.Tensor): + # Split into video_chunks_number along the frame dimension + video = torch.tensor_split(video, video_chunks_number, dim=0) + video = iter(video) + + first_chunk = next(video) + + _, height, width, _ = first_chunk.shape container = av.open(output_path, mode="w") stream = container.add_stream("libx264", rate=int(fps)) @@ -119,10 +175,12 @@ def encode_video( audio_stream = _prepare_audio_stream(container, audio_sample_rate) - for frame_array in video_np: - frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") - for packet in stream.encode(frame): - container.mux(packet) + for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"): + video_chunk_cpu = video_chunk.to("cpu").numpy() + for frame_array in video_chunk_cpu: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) # Flush encoder for packet in stream.encode(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index a92a7a2c8869..cb01159a81a7 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -69,8 +69,6 @@ ... output_type="np", ... return_dict=False, ... ) - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0], diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 04d7ee89c52a..c120e1f010e9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -75,8 +75,6 @@ ... output_type="np", ... return_dict=False, ... ) - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0], diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 340efd10f24f..b0db1bdee317 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -76,8 +76,6 @@ ... output_type="np", ... return_dict=False, ... )[0] - >>> video = (video * 255).round().astype("uint8") - >>> video = torch.from_numpy(video) >>> encode_video( ... video[0], diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index 78bd3bfacbec..14629a6e2160 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"] _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] @@ -42,6 +43,7 @@ from .pipeline_z_image_controlnet import ZImageControlNetPipeline from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline from .pipeline_z_image_img2img import ZImageImg2ImgPipeline + from .pipeline_z_image_inpaint import ZImageInpaintPipeline from .pipeline_z_image_omni import ZImageOmniPipeline else: import sys diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 08fc4da0e7ba..3c8db4a0f748 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -635,10 +635,12 @@ def __call__( latent_model_input = latents_typed.repeat(2, 1, 1, 1) prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds timestep_model_input = timestep.repeat(2) + control_image_input = control_image.repeat(2, 1, 1, 1, 1) else: latent_model_input = latents.to(self.transformer.dtype) prompt_embeds_model_input = prompt_embeds timestep_model_input = timestep + control_image_input = control_image latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) @@ -647,7 +649,7 @@ def __call__( latent_model_input_list, timestep_model_input, prompt_embeds_model_input, - control_image, + control_image_input, conditioning_scale=controlnet_conditioning_scale, ) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py index 3b0f8dc288d3..cdc60eaf4dd3 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -657,10 +657,12 @@ def __call__( latent_model_input = latents_typed.repeat(2, 1, 1, 1) prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds timestep_model_input = timestep.repeat(2) + control_image_input = control_image.repeat(2, 1, 1, 1, 1) else: latent_model_input = latents.to(self.transformer.dtype) prompt_embeds_model_input = prompt_embeds timestep_model_input = timestep + control_image_input = control_image latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) @@ -669,7 +671,7 @@ def __call__( latent_model_input_list, timestep_model_input, prompt_embeds_model_input, - control_image, + control_image_input, conditioning_scale=controlnet_conditioning_scale, ) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py new file mode 100644 index 000000000000..de9fb3c82f8c --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py @@ -0,0 +1,932 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> # Create a mask (white = inpaint, black = preserve) + >>> import numpy as np + >>> from PIL import Image + + >>> mask = np.zeros((1024, 1024), dtype=np.uint8) + >>> mask[256:768, 256:768] = 255 # Inpaint center region + >>> mask_image = Image.fromarray(mask) + + >>> prompt = "A beautiful lake with mountains in the background" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... strength=1.0, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + r""" + The ZImage pipeline for inpainting. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`PreTrainedModel`]): + A text encoder model to encode text prompts. + tokenizer ([`AutoTokenizer`]): + A tokenizer to tokenize text prompts. + transformer ([`ZImageTransformer2DModel`]): + A ZImage transformer model to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + ): + """Prepare mask and masked image latents for inpainting. + + Args: + mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region. + masked_image: Original image with masked regions zeroed out. + batch_size: Number of images to generate. + height: Output image height. + width: Output image width. + dtype: Data type for the tensors. + device: Device to place tensors on. + generator: Random generator for reproducibility. + + Returns: + Tuple of (mask, masked_image_latents) prepared for the denoising loop. + """ + # Calculate latent dimensions + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # Resize mask to latent dimensions + mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest") + mask = mask.to(device=device, dtype=dtype) + + # Encode masked image to latents + masked_image = masked_image.to(device=device, dtype=dtype) + if isinstance(generator, list): + masked_image_latents = [ + retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i]) + for i in range(masked_image.shape[0]) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + # Apply VAE scaling + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # Expand for batch size + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + return mask, masked_image_latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + """Prepare latents for inpainting, returning noise and image_latents for blending. + + Returns: + Tuple of (latents, noise, image_latents) where: + - latents: Noised image latents for denoising + - noise: The noise tensor used for blending + - image_latents: Clean image latents for blending + """ + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + # Generate noise for blending even if latents are provided + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Encode image for blending + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0) + return latents.to(device=device, dtype=dtype), noise, image_latents + + # Encode the input image + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + image_latents = image + + # Handle batch size expansion + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + # Generate noise for both initial noising and later blending + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # Add noise using flow matching scale_noise + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + return latents, noise, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if image is None: + raise ValueError("`image` input cannot be undefined for inpainting.") + + if mask_image is None: + raise ValueError("`mask_image` input cannot be undefined for inpainting.") + + if output_type not in ["latent", "pil", "np", "pt"]: + raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: Optional[torch.FloatTensor] = None, + strength: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for inpainting. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the + mask will be inpainted, black pixels (value 0) will be preserved from the original image. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image` in the masked region. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. If not provided, uses the input image height. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. If not provided, uses the input image width. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # 1. Check inputs + self.check_inputs( + prompt=prompt, + image=image, + mask_image=mask_image, + strength=strength, + height=height, + width=width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Preprocess image and mask + init_image = self.image_processor.preprocess(image) + init_image = init_image.to(dtype=torch.float32) + + # Get dimensions from the preprocessed image if not specified + if height is None: + height = init_image.shape[-2] + if width is None: + width = init_image.shape[-1] + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # Preprocess mask + mask = self.mask_processor.preprocess(mask_image, height=height, width=width) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + + # Calculate latent dimensions for image_seq_len + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + + # 6. Adjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(actual_batch_size) + + # 7. Prepare latents from image (returns noise and image_latents for blending) + latents, noise, image_latents = self.prepare_latents( + init_image, + latent_timestep, + actual_batch_size, + num_channels_latents, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + latents, + ) + + # 8. Prepare mask and masked image latents + # Create masked image: preserve only unmasked regions (mask=0) + if masked_image_latents is None: + masked_image = init_image * (mask < 0.5) + else: + masked_image = None # Will use provided masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image if masked_image is not None else init_image, + actual_batch_size, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 9. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + # Inpainting blend: combine denoised latents with original image latents + init_latents_proper = image_latents + + # Re-scale original latents to current noise level for proper blending + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + # Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original) + latents = (1 - mask) * init_latents_proper + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 2fba9986e825..adb429688723 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -79,7 +79,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: - return x @ qweight.T + weight = dequantize_gguf_tensor(qweight) + return x @ weight.T # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for # contiguous batching and inefficient with diffusers' batching, diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index 9c6b0fcf69b6..796d62e5b6a0 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -545,7 +545,9 @@ def multistep_dpm_solver_second_order_update( # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 7c2dfd8e503f..58086c13b5dc 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -867,7 +867,9 @@ def ind_fn(t, b, c, d): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 07cb64f32b58..a1b118b2adcc 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -245,13 +245,26 @@ def __init__( ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + deprecate( + "algorithm_types dpmsolver and sde-dpmsolver", + "1.0.0", + deprecation_message, + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -259,7 +272,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -287,7 +308,12 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type not in [ + "dpmsolver", + "dpmsolver++", + "sde-dpmsolver", + "sde-dpmsolver++", + ]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -724,7 +750,7 @@ def convert_model_output( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -738,7 +764,7 @@ def convert_model_output( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. Returns: @@ -822,7 +848,7 @@ def dpm_solver_first_order_update( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -832,8 +858,10 @@ def dpm_solver_first_order_update( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. + noise (`torch.Tensor`, *optional*): + The noise tensor. Returns: `torch.Tensor`: @@ -860,7 +888,10 @@ def dpm_solver_first_order_update( "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -891,7 +922,7 @@ def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -901,7 +932,7 @@ def multistep_dpm_solver_second_order_update( Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. Returns: @@ -1014,7 +1045,7 @@ def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -1024,8 +1055,10 @@ def multistep_dpm_solver_third_order_update( Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by diffusion process. + noise (`torch.Tensor`, *optional*): + The noise tensor. Returns: `torch.Tensor`: @@ -1106,7 +1139,9 @@ def multistep_dpm_solver_third_order_update( return x_t def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. @@ -1216,7 +1251,10 @@ def step( sample = sample.to(torch.float32) if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32, ) elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = variance_noise.to(device=model_output.device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 2da90d287cf8..e7f0d1a13a9c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -141,6 +141,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to 1.0): + The flow shift factor. Valid only when `use_flow_sigmas=True`. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. @@ -163,15 +167,15 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", + algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, @@ -180,19 +184,32 @@ def __init__( use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - timestep_spacing: str = "linspace", + variance_type: Optional[Literal["learned", "learned_range"]] = None, + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + deprecate( + "algorithm_types dpmsolver and sde-dpmsolver", + "1.0.0", + deprecation_message, + ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -200,7 +217,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -219,7 +244,12 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type not in [ + "dpmsolver", + "dpmsolver++", + "sde-dpmsolver", + "sde-dpmsolver++", + ]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -250,7 +280,11 @@ def step_index(self): """ return self._step_index - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -382,7 +416,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -419,7 +453,7 @@ def _sigma_to_t(self, sigma, log_sigmas): return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert sigma values to alpha_t and sigma_t values. @@ -441,7 +475,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). @@ -567,7 +601,7 @@ def convert_model_output( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -581,7 +615,7 @@ def convert_model_output( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. Returns: @@ -666,7 +700,7 @@ def dpm_solver_first_order_update( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -676,8 +710,10 @@ def dpm_solver_first_order_update( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. + noise (`torch.Tensor`, *optional*): + The noise tensor. Returns: `torch.Tensor`: @@ -704,7 +740,10 @@ def dpm_solver_first_order_update( "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -736,7 +775,7 @@ def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -746,7 +785,7 @@ def multistep_dpm_solver_second_order_update( Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by the diffusion process. Returns: @@ -860,7 +899,7 @@ def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -870,8 +909,10 @@ def multistep_dpm_solver_third_order_update( Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): + sample (`torch.Tensor`, *optional*): A current instance of a sample created by diffusion process. + noise (`torch.Tensor`, *optional*): + The noise tensor. Returns: `torch.Tensor`: @@ -951,7 +992,7 @@ def multistep_dpm_solver_third_order_update( ) return x_t - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -975,7 +1016,7 @@ def step( model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: @@ -1027,7 +1068,10 @@ def step( if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = variance_noise @@ -1074,6 +1118,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the clean `original_samples` using the scheduler's equivalent function. + + Args: + original_samples (`torch.Tensor`): + The original samples to add noise to. + noise (`torch.Tensor`): + The noise tensor. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): @@ -1103,5 +1162,5 @@ def add_noise( noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 6f905a623d70..a1d4a997d126 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -51,7 +51,14 @@ class DPMSolverSDESchedulerOutput(BaseOutput): class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" - def __init__(self, x, t0, t1, seed=None, **kwargs): + def __init__( + self, + x: torch.Tensor, + t0: float, + t1: float, + seed: Optional[Union[int, List[int]]] = None, + **kwargs, + ): t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: @@ -79,10 +86,23 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): ] @staticmethod - def sort(a, b): - return (a, b, 1) if a < b else (b, a, -1) + def sort(a: float, b: float) -> Tuple[float, float, float]: + """ + Sorts two float values and returns them along with a sign indicating if they were swapped. + + Args: + a (`float`): + The first value. + b (`float`): + The second value. - def __call__(self, t0, t1): + Returns: + `Tuple[float, float, float]`: + A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise). + """ + return (a, b, 1.0) if a < b else (b, a, -1.0) + + def __call__(self, t0: float, t1: float) -> torch.Tensor: t0, t1, sign = self.sort(t0, t1) w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) return w if self.batched else w[0] @@ -92,23 +112,29 @@ class BrownianTreeNoiseSampler: """A noise sampler backed by a torchsde.BrownianTree. Args: - x (Tensor): The tensor whose shape, device and dtype to use to generate - random samples. - sigma_min (float): The low end of the valid interval. - sigma_max (float): The high end of the valid interval. - seed (int or List[int]): The random seed. If a list of seeds is + x (`torch.Tensor`): The tensor whose shape, device and dtype is used to generate random samples. + sigma_min (`float`): The low end of the valid interval. + sigma_max (`float`): The high end of the valid interval. + seed (`int` or `List[int]`): The random seed. If a list of seeds is supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each with its own seed. - transform (callable): A function that maps sigma to the sampler's + transform (`callable`): A function that maps sigma to the sampler's internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + def __init__( + self, + x: torch.Tensor, + sigma_min: float, + sigma_max: float, + seed: Optional[Union[int, List[int]]] = None, + transform: Callable[[float], float] = lambda x: x, + ): self.transform = transform t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) self.tree = BatchedBrownianTree(x, t0, t1, seed) - def __call__(self, sigma, sigma_next): + def __call__(self, sigma: float, sigma_next: float) -> torch.Tensor: t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) return self.tree(t0, t1) / (t1 - t0).abs().sqrt() @@ -216,19 +242,28 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, - timestep_spacing: str = "linspace", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) @@ -238,7 +273,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -305,7 +348,7 @@ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: self._step_index = self._begin_index @property - def init_noise_sigma(self): + def init_noise_sigma(self) -> torch.Tensor: # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() @@ -313,21 +356,21 @@ def init_noise_sigma(self): return (self.sigmas.max() ** 2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Union[int, None]: """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Union[int, None]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -369,7 +412,7 @@ def set_timesteps( num_inference_steps: int, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -378,6 +421,8 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + num_train_timesteps (`int`, *optional*): + The number of train timesteps. If `None`, uses `self.config.num_train_timesteps`. """ self.num_inference_steps = num_inference_steps @@ -443,7 +488,7 @@ def set_timesteps( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.noise_sampler = None - def _second_order_timesteps(self, sigmas, log_sigmas): + def _second_order_timesteps(self, sigmas: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: def sigma_fn(_t): return np.exp(-_t) @@ -459,7 +504,7 @@ def t_fn(_sigma): return timesteps # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -604,14 +649,14 @@ def _convert_to_beta( return sigmas @property - def state_in_first_order(self): + def state_in_first_order(self) -> bool: return self.sample is None def step( self, - model_output: Union[torch.Tensor, np.ndarray], + model_output: torch.Tensor, timestep: Union[float, torch.Tensor], - sample: Union[torch.Tensor, np.ndarray], + sample: torch.Tensor, return_dict: bool = True, s_noise: float = 1.0, ) -> Union[DPMSolverSDESchedulerOutput, Tuple]: @@ -620,11 +665,11 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.Tensor` or `np.ndarray`): + model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. - sample (`torch.Tensor` or `np.ndarray`): + sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or @@ -643,7 +688,9 @@ def step( # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() - self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) + self.noise_sampler = BrownianTreeNoiseSampler( + sample, min_sigma.item(), max_sigma.item(), self.noise_sampler_seed + ) # Define functions to compute sigma and t from each other def sigma_fn(_t: torch.Tensor) -> torch.Tensor: @@ -694,7 +741,10 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor: sigma_from = sigma_fn(t) sigma_to = sigma_fn(t_next) - sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_up = min( + sigma_to, + (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 ancestral_t = t_fn(sigma_down) prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( @@ -771,5 +821,5 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index e9bf815aba86..9ab362da3006 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -1120,7 +1120,9 @@ def singlestep_dpm_solver_update( # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index a573f032cad8..282540f5a23c 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -662,7 +662,9 @@ def multistep_dpm_solver_third_order_update( # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 565fae1c0d76..df8fbd7b1a47 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -1122,7 +1122,9 @@ def stochastic_adams_moulton_update( # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index d8e24d196418..de04e854d60c 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -1083,7 +1083,9 @@ def multistep_uni_c_bh_update( # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep( - self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + self, + timestep: Union[int, torch.Tensor], + schedule_timesteps: Optional[torch.Tensor] = None, ) -> int: """ Find the index for a given timestep in the schedule. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index e726bbb46913..3f736e2ee39b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -23,6 +23,7 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS, DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, + DIFFUSERS_LOAD_ID_FIELDS, FLAX_WEIGHTS_NAME, GGUF_FILE_EXTENSION, HF_ENABLE_PARALLEL_LOADING, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index c46fa4363483..4f94df656a65 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -73,3 +73,11 @@ ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" + + +DIFFUSERS_LOAD_ID_FIELDS = [ + "pretrained_model_name_or_path", + "subfolder", + "variant", + "revision", +] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7120ff1f6257..d75be9c4714f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MagCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -284,6 +299,10 @@ def apply_layer_skip(*args, **kwargs): requires_backends(apply_layer_skip, ["torch"]) +def apply_mag_cache(*args, **kwargs): + requires_backends(apply_mag_cache, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) @@ -877,6 +896,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CosmosControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CosmosTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a23f852616c0..8758c549ca77 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinBaseModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2KleinModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -287,7 +302,52 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class Wan22AutoBlocks(metaclass=DummyObject): +class Wan22Blocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Wan22Image2VideoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Wan22Image2VideoModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Wan22ModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -302,7 +362,37 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class WanAutoBlocks(metaclass=DummyObject): +class WanBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WanImage2VideoAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WanImage2VideoModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -887,6 +977,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Cosmos2_5_TransferPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Cosmos2TextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -4112,6 +4217,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ZImageInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ZImageOmniPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index d0b05c7d9541..58695bae1e9d 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -107,6 +107,7 @@ def load_or_create_model_card( license: Optional[str] = None, widget: Optional[List[dict]] = None, inference: Optional[bool] = None, + is_modular: bool = False, ) -> ModelCard: """ Loads or creates a model card. @@ -131,6 +132,8 @@ def load_or_create_model_card( widget (`List[dict]`, *optional*): Widget to accompany a gallery template. inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using `load_or_create_model_card` from a training script. + is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline. + When True, uses model_description as-is without additional template formatting. """ if not is_jinja_available(): raise ValueError( @@ -159,10 +162,14 @@ def load_or_create_model_card( ) else: card_data = ModelCardData() - component = "pipeline" if is_pipeline else "model" - if model_description is None: - model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." - model_card = ModelCard.from_template(card_data, model_description=model_description) + if is_modular and model_description is not None: + model_card = ModelCard(model_description) + model_card.data = card_data + else: + component = "pipeline" if is_pipeline else "model" + if model_description is None: + model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." + model_card = ModelCard.from_template(card_data, model_description=model_description) return model_card diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py new file mode 100644 index 000000000000..a7e1b52d3b69 --- /dev/null +++ b/tests/hooks/test_mag_cache.py @@ -0,0 +1,244 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import MagCacheConfig, apply_mag_cache +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.models import ModelMixin +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + + +class DummyBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Output is double input + # This ensures Residual = 2*Input - Input = Input + return hidden_states * 2.0 + + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + + +class TupleOutputBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Returns a tuple + return hidden_states * 2.0, encoder_hidden_states + + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + # Emulate Flux-like behavior + output = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = output[0] + encoder_hidden_states = output[1] + return hidden_states, encoder_hidden_states + + +class MagCacheTests(unittest.TestCase): + def setUp(self): + # Register standard dummy block + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), + ) + # Register tuple block (Flux style) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), + ) + + def _set_context(self, model, context_name): + """Helper to set context on all hooks in the model.""" + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + def _get_calibration_data(self, model): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("mag_cache_block_hook") + if hook: + return hook.state_manager.get_state().calibration_ratios + return [] + + def test_mag_cache_validation(self): + """Test that missing mag_ratios raises ValueError.""" + with self.assertRaises(ValueError): + MagCacheConfig(num_inference_steps=10, calibrate=False) + + def test_mag_cache_skipping_logic(self): + """ + Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. + """ + model = DummyTransformer() + + # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, # Enable immediate skipping + max_skip_steps=5, + mag_ratios=ratios, + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) + # HeadInput=10. Output=40. Residual=30. + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed") + + # Step 1: Input 11.0. + # If Skipped: Output = Input(11) + Residual(30) = 41.0 + # If Computed: Output = 11 * 4 = 44.0 + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}" + ) + + def test_mag_cache_retention(self): + """Test that retention_ratio prevents skipping even if error is low.""" + model = DummyTransformer() + # Ratios that imply 0 error, so it *would* skip if retention allowed it + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=1.0, # Force retention for ALL steps + mag_ratios=ratios, + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + model(torch.tensor([[[10.0]]])) + + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[44.0]]])), + f"Expected Compute (44.0) due to retention, got {output_t1.item()}", + ) + + def test_mag_cache_tuple_outputs(self): + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" + model = TupleTransformer() + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) + # Residual = 10.0 + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) + + # Step 1: Skip. Input 11.0. + # Skipped Output = 11 + 10 = 21.0 + input_t1 = torch.tensor([[[11.0]]]) + out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) + + self.assertTrue( + torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}" + ) + + def test_mag_cache_reset(self): + """Test that state resets correctly after num_inference_steps.""" + model = DummyTransformer() + config = MagCacheConfig( + threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0]) + ) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + input_t = torch.ones(1, 1, 1) + + model(input_t) # Step 0 + model(input_t) # Step 1 (Skipped) + + # Step 2 (Reset -> Step 0) -> Should Compute + # Input 2.0 -> Output 8.0 + input_t2 = torch.tensor([[[2.0]]]) + output_t2 = model(input_t2) + + self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly") + + def test_mag_cache_calibration(self): + """Test that calibration mode records ratios.""" + model = DummyTransformer() + config = MagCacheConfig(num_inference_steps=2, calibrate=True) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + # HeadInput = 10. Output = 40. Residual = 30. + # Ratio 0 is placeholder 1.0 + model(torch.tensor([[[10.0]]])) + + # Check intermediate state + ratios = self._get_calibration_data(model) + self.assertEqual(len(ratios), 1) + self.assertEqual(ratios[0], 1.0) + + # Step 1 + # HeadInput = 10. Output = 40. Residual = 30. + # PrevResidual = 30. CurrResidual = 30. + # Ratio = 30/30 = 1.0 + model(torch.tensor([[[10.0]]])) + + # Verify it computes fully (no skip) + # If it skipped, output would be 41.0. It should be 40.0 + # Actually in test setup, input is same (10.0) so output 40.0. + # Let's ensure list is empty after reset (end of step 1) + ratios_after = self._get_calibration_data(model) + self.assertEqual(ratios_after, []) diff --git a/tests/models/controlnets/__init__.py b/tests/models/controlnets/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/controlnets/test_models_controlnet_cosmos.py b/tests/models/controlnets/test_models_controlnet_cosmos.py new file mode 100644 index 000000000000..bf879b11663b --- /dev/null +++ b/tests/models/controlnets/test_models_controlnet_cosmos.py @@ -0,0 +1,255 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CosmosControlNetModel +from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetOutput + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = CosmosControlNetModel + main_input_name = "controls_latents" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 16 + num_frames = 1 + height = 16 + width = 16 + text_embed_dim = 32 + sequence_length = 12 + img_context_dim_in = 32 + img_context_num_tokens = 4 + + # Raw latents (not patchified) - the controlnet computes embeddings internally + controls_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.tensor([0.5]).to(torch_device) # Diffusion timestep + condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device) + padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device) + + # Text embeddings + text_context = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device) + # Image context for Cosmos 2.5 + img_context = torch.randn((batch_size, img_context_num_tokens, img_context_dim_in)).to(torch_device) + encoder_hidden_states = (text_context, img_context) + + return { + "controls_latents": controls_latents, + "latents": latents, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "condition_mask": condition_mask, + "conditioning_scale": 1.0, + "padding_mask": padding_mask, + } + + @property + def input_shape(self): + return (16, 1, 16, 16) + + @property + def output_shape(self): + # Output is tuple of n_controlnet_blocks tensors, each with shape (batch, num_patches, model_channels) + # After stacking by normalize_output: (n_blocks, batch, num_patches, model_channels) + # For test config: n_blocks=2, num_patches=64 (1*8*8), model_channels=32 + # output_shape is used as (batch_size,) + output_shape, so: (2, 64, 32) + return (2, 64, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "n_controlnet_blocks": 2, + "in_channels": 16 + 1 + 1, # control_latent_channels + condition_mask + padding_mask + "latent_channels": 16 + 1 + 1, # base_latent_channels (16) + condition_mask (1) + padding_mask (1) = 18 + "model_channels": 32, + "num_attention_heads": 2, + "attention_head_dim": 16, + "mlp_ratio": 2, + "text_embed_dim": 32, + "adaln_lora_dim": 4, + "patch_size": (1, 2, 2), + "max_size": (4, 32, 32), + "rope_scale": (2.0, 1.0, 1.0), + "extra_pos_embed_type": None, + "img_context_dim_in": 32, + "img_context_dim_out": 32, + "use_crossattn_projection": False, # Test doesn't need this projection + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output_format(self): + """Test that the model outputs CosmosControlNetOutput with correct structure.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertIsInstance(output, CosmosControlNetOutput) + self.assertIsInstance(output.control_block_samples, list) + self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"]) + for tensor in output.control_block_samples: + self.assertIsInstance(tensor, torch.Tensor) + + def test_output_list_format(self): + """Test that return_dict=False returns a tuple containing a list.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict, return_dict=False) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 1) + self.assertIsInstance(output[0], list) + self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"]) + + def test_conditioning_scale_single(self): + """Test that a single conditioning scale is broadcast to all blocks.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + inputs_dict["conditioning_scale"] = 0.5 + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"]) + + def test_conditioning_scale_list(self): + """Test that a list of conditioning scales is applied per block.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # Provide a scale for each block + inputs_dict["conditioning_scale"] = [0.5, 1.0] + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"]) + + def test_forward_with_none_img_context(self): + """Test forward pass when img_context is None.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # Set encoder_hidden_states to (text_context, None) + text_context = inputs_dict["encoder_hidden_states"][0] + inputs_dict["encoder_hidden_states"] = (text_context, None) + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertIsInstance(output, CosmosControlNetOutput) + self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"]) + + def test_forward_without_img_context_proj(self): + """Test forward pass when img_context_proj is not configured.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + # Disable img_context_proj + init_dict["img_context_dim_in"] = None + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # When img_context is disabled, pass only text context (not a tuple) + text_context = inputs_dict["encoder_hidden_states"][0] + inputs_dict["encoder_hidden_states"] = text_context + + with torch.no_grad(): + output = model(**inputs_dict) + + self.assertIsInstance(output, CosmosControlNetOutput) + self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"]) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CosmosControlNetModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # Note: test_set_attn_processor_for_determinism already handles uses_custom_attn_processor=True + # so no explicit skip needed for it + # Note: test_forward_signature and test_set_default_attn_processor don't exist in base class + + # Skip tests that don't apply to this architecture + @unittest.skip("CosmosControlNetModel doesn't use norm groups.") + def test_forward_with_norm_groups(self): + pass + + # Skip tests that expect .sample attribute - ControlNets don't have this + @unittest.skip("ControlNet output doesn't have .sample attribute") + def test_effective_gradient_checkpointing(self): + pass + + # Skip tests that compute MSE loss against single tensor output + @unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss") + def test_ema_training(self): + pass + + @unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss") + def test_training(self): + pass + + # Skip tests where output shape comparison doesn't apply to ControlNets + @unittest.skip("ControlNet output shape doesn't match input shape by design") + def test_output(self): + pass + + # Skip outputs_equivalence - dict/list comparison logic not compatible (recursive_check expects dict.values()) + @unittest.skip("ControlNet output structure not compatible with recursive dict check") + def test_outputs_equivalence(self): + pass + + # Skip model parallelism - base test uses torch.allclose(base_output[0], new_output[0]) which fails + # because output[0] is the list of control_block_samples, not a tensor + @unittest.skip("test_model_parallelism uses torch.allclose on output[0] which is a list, not a tensor") + def test_model_parallelism(self): + pass + + # Skip layerwise casting tests - these have two issues: + # 1. _inference and _memory: dtype compatibility issues with learnable_pos_embed and float8/bfloat16 + # 2. _training: same as test_training - mse_loss expects tensor, not list + @unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed") + def test_layerwise_casting_inference(self): + pass + + @unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed") + def test_layerwise_casting_memory(self): + pass + + @unittest.skip("test_layerwise_casting_training computes mse_loss on list output") + def test_layerwise_casting_training(self): + pass diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 661fcc253795..9ee5c6c2ac80 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -8,6 +8,13 @@ import diffusers from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks from diffusers.guiders import ClassifierFreeGuidance +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + generate_modular_model_card_content, +) from diffusers.utils import logging from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device @@ -30,6 +37,9 @@ class ModularPipelineTesterMixin: optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"]) # this is modular specific: generator needs to be a intermediate input because it's mutable intermediate_params = frozenset(["generator"]) + # Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines) + # Subclasses can override this to change the expected output type + output_name = "images" def get_generator(self, seed=0): generator = torch.Generator("cpu").manual_seed(seed) @@ -156,7 +166,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True) logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): - output = pipe(**batched_input, output="images") + output = pipe(**batched_input, output=self.output_name) assert len(output) == batch_size, "Output is different from expected batch size" def test_inference_batch_single_identical( @@ -190,12 +200,16 @@ def test_inference_batch_single_identical( if "batch_size" in inputs: batched_inputs["batch_size"] = batch_size - output = pipe(**inputs, output="images") - output_batch = pipe(**batched_inputs, output="images") + output = pipe(**inputs, output=self.output_name) + output_batch = pipe(**batched_inputs, output=self.output_name) assert output_batch.shape[0] == batch_size - max_diff = torch.abs(output_batch[0] - output[0]).max() + # For batch comparison, we only need to compare the first item + if output_batch.shape[0] == batch_size and output.shape[0] == 1: + output_batch = output_batch[0:1] + + max_diff = torch.abs(output_batch - output).max() assert max_diff < expected_max_diff, "Batch inference results different from single inference results" @require_accelerator @@ -210,19 +224,32 @@ def test_float16_inference(self, expected_max_diff=5e-2): # Reset generator in case it is used inside dummy inputs if "generator" in inputs: inputs["generator"] = self.get_generator(0) - output = pipe(**inputs, output="images") + + output = pipe(**inputs, output=self.output_name) fp16_inputs = self.get_dummy_inputs() # Reset generator in case it is used inside dummy inputs if "generator" in fp16_inputs: fp16_inputs["generator"] = self.get_generator(0) - output_fp16 = pipe_fp16(**fp16_inputs, output="images") - output = output.cpu() - output_fp16 = output_fp16.cpu() + output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name) + + output_tensor = output.float().cpu() + output_fp16_tensor = output_fp16.float().cpu() + + # Check for NaNs in outputs (can happen with tiny models in FP16) + if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any(): + pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models") - max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten()) - assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference" + max_diff = numpy_cosine_similarity_distance( + output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy() + ) + + # Check if cosine similarity is NaN (which can happen if vectors are zero or very small) + if torch.isnan(torch.tensor(max_diff)): + pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison") + + assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})" @require_accelerator def test_to_device(self): @@ -244,14 +271,16 @@ def test_to_device(self): def test_inference_is_not_nan_cpu(self): pipe = self.get_pipeline().to("cpu") - output = pipe(**self.get_dummy_inputs(), output="images") + inputs = self.get_dummy_inputs() + output = pipe(**inputs, output=self.output_name) assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN" @require_accelerator def test_inference_is_not_nan(self): pipe = self.get_pipeline().to(torch_device) - output = pipe(**self.get_dummy_inputs(), output="images") + inputs = self.get_dummy_inputs() + output = pipe(**inputs, output=self.output_name) assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN" def test_num_images_per_prompt(self): @@ -271,7 +300,7 @@ def test_num_images_per_prompt(self): if key in self.batch_params: inputs[key] = batch_size * [inputs[key]] - images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name) assert images.shape[0] == batch_size * num_images_per_prompt @@ -286,8 +315,7 @@ def test_components_auto_cpu_offload_inference_consistent(self): image_slices = [] for pipe in [base_pipe, offload_pipe]: inputs = self.get_dummy_inputs() - image = pipe(**inputs, output="images") - + image = pipe(**inputs, output=self.output_name) image_slices.append(image[0, -3:, -3:, -1].flatten()) assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 @@ -308,8 +336,7 @@ def test_save_from_pretrained(self): image_slices = [] for pipe in pipes: inputs = self.get_dummy_inputs() - image = pipe(**inputs, output="images") - + image = pipe(**inputs, output=self.output_name) image_slices.append(image[0, -3:, -3:, -1].flatten()) assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 @@ -324,14 +351,250 @@ def test_guider_cfg(self, expected_max_diff=1e-2): pipe.update_components(guider=guider) inputs = self.get_dummy_inputs() - out_no_cfg = pipe(**inputs, output="images") + out_no_cfg = pipe(**inputs, output=self.output_name) # forward pass with CFG applied guider = ClassifierFreeGuidance(guidance_scale=7.5) pipe.update_components(guider=guider) inputs = self.get_dummy_inputs() - out_cfg = pipe(**inputs, output="images") + out_cfg = pipe(**inputs, output=self.output_name) assert out_cfg.shape == out_no_cfg.shape max_diff = torch.abs(out_cfg - out_no_cfg).max() assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference" + + +class TestModularModelCardContent: + def create_mock_block(self, name="TestBlock", description="Test block description"): + class MockBlock: + def __init__(self, name, description): + self.__class__.__name__ = name + self.description = description + self.sub_blocks = {} + + return MockBlock(name, description) + + def create_mock_blocks( + self, + class_name="TestBlocks", + description="Test pipeline description", + num_blocks=2, + components=None, + configs=None, + inputs=None, + outputs=None, + trigger_inputs=None, + model_name=None, + ): + class MockBlocks: + def __init__(self): + self.__class__.__name__ = class_name + self.description = description + self.sub_blocks = {} + self.expected_components = components or [] + self.expected_configs = configs or [] + self.inputs = inputs or [] + self.outputs = outputs or [] + self.trigger_inputs = trigger_inputs + self.model_name = model_name + + blocks = MockBlocks() + + # Add mock sub-blocks + for i in range(num_blocks): + block_name = f"block_{i}" + blocks.sub_blocks[block_name] = self.create_mock_block(f"Block{i}", f"Description for block {i}") + + return blocks + + def test_basic_model_card_content_structure(self): + """Test that all expected keys are present in the output.""" + blocks = self.create_mock_blocks() + content = generate_modular_model_card_content(blocks) + + expected_keys = [ + "pipeline_name", + "model_description", + "blocks_description", + "components_description", + "configs_section", + "inputs_description", + "outputs_description", + "trigger_inputs_section", + "tags", + ] + + for key in expected_keys: + assert key in content, f"Expected key '{key}' not found in model card content" + + assert isinstance(content["tags"], list), "Tags should be a list" + + def test_pipeline_name_generation(self): + """Test that pipeline name is correctly generated from blocks class name.""" + blocks = self.create_mock_blocks(class_name="StableDiffusionBlocks") + content = generate_modular_model_card_content(blocks) + + assert content["pipeline_name"] == "StableDiffusion Pipeline" + + def test_tags_generation_text_to_image(self): + """Test that text-to-image tags are correctly generated.""" + blocks = self.create_mock_blocks(trigger_inputs=None) + content = generate_modular_model_card_content(blocks) + + assert "modular-diffusers" in content["tags"] + assert "diffusers" in content["tags"] + assert "text-to-image" in content["tags"] + + def test_tags_generation_with_trigger_inputs(self): + """Test that tags are correctly generated based on trigger inputs.""" + # Test inpainting + blocks = self.create_mock_blocks(trigger_inputs=["mask", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "inpainting" in content["tags"] + + # Test image-to-image + blocks = self.create_mock_blocks(trigger_inputs=["image", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "image-to-image" in content["tags"] + + # Test controlnet + blocks = self.create_mock_blocks(trigger_inputs=["control_image", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "controlnet" in content["tags"] + + def test_tags_with_model_name(self): + """Test that model name is included in tags when present.""" + blocks = self.create_mock_blocks(model_name="stable-diffusion-xl") + content = generate_modular_model_card_content(blocks) + + assert "stable-diffusion-xl" in content["tags"] + + def test_components_description_formatting(self): + """Test that components are correctly formatted.""" + components = [ + ComponentSpec(name="vae", description="VAE component"), + ComponentSpec(name="text_encoder", description="Text encoder component"), + ] + blocks = self.create_mock_blocks(components=components) + content = generate_modular_model_card_content(blocks) + + assert "vae" in content["components_description"] + assert "text_encoder" in content["components_description"] + # Should be enumerated + assert "1." in content["components_description"] + + def test_components_description_empty(self): + """Test handling of pipelines without components.""" + blocks = self.create_mock_blocks(components=None) + content = generate_modular_model_card_content(blocks) + + assert "No specific components required" in content["components_description"] + + def test_configs_section_with_configs(self): + """Test that configs section is generated when configs are present.""" + configs = [ + ConfigSpec(name="num_train_timesteps", default=1000, description="Number of training timesteps"), + ] + blocks = self.create_mock_blocks(configs=configs) + content = generate_modular_model_card_content(blocks) + + assert "## Configuration Parameters" in content["configs_section"] + + def test_configs_section_empty(self): + """Test that configs section is empty when no configs are present.""" + blocks = self.create_mock_blocks(configs=None) + content = generate_modular_model_card_content(blocks) + + assert content["configs_section"] == "" + + def test_inputs_description_required_and_optional(self): + """Test that required and optional inputs are correctly formatted.""" + inputs = [ + InputParam(name="prompt", type_hint=str, required=True, description="The input prompt"), + InputParam(name="num_steps", type_hint=int, required=False, default=50, description="Number of steps"), + ] + blocks = self.create_mock_blocks(inputs=inputs) + content = generate_modular_model_card_content(blocks) + + assert "**Required:**" in content["inputs_description"] + assert "**Optional:**" in content["inputs_description"] + assert "prompt" in content["inputs_description"] + assert "num_steps" in content["inputs_description"] + assert "default: `50`" in content["inputs_description"] + + def test_inputs_description_empty(self): + """Test handling of pipelines without specific inputs.""" + blocks = self.create_mock_blocks(inputs=[]) + content = generate_modular_model_card_content(blocks) + + assert "No specific inputs defined" in content["inputs_description"] + + def test_outputs_description_formatting(self): + """Test that outputs are correctly formatted.""" + outputs = [ + OutputParam(name="images", type_hint=torch.Tensor, description="Generated images"), + ] + blocks = self.create_mock_blocks(outputs=outputs) + content = generate_modular_model_card_content(blocks) + + assert "images" in content["outputs_description"] + assert "Generated images" in content["outputs_description"] + + def test_outputs_description_empty(self): + """Test handling of pipelines without specific outputs.""" + blocks = self.create_mock_blocks(outputs=[]) + content = generate_modular_model_card_content(blocks) + + assert "Standard pipeline outputs" in content["outputs_description"] + + def test_trigger_inputs_section_with_triggers(self): + """Test that trigger inputs section is generated when present.""" + blocks = self.create_mock_blocks(trigger_inputs=["mask", "image"]) + content = generate_modular_model_card_content(blocks) + + assert "### Conditional Execution" in content["trigger_inputs_section"] + assert "`mask`" in content["trigger_inputs_section"] + assert "`image`" in content["trigger_inputs_section"] + + def test_trigger_inputs_section_empty(self): + """Test that trigger inputs section is empty when not present.""" + blocks = self.create_mock_blocks(trigger_inputs=None) + content = generate_modular_model_card_content(blocks) + + assert content["trigger_inputs_section"] == "" + + def test_blocks_description_with_sub_blocks(self): + """Test that blocks with sub-blocks are correctly described.""" + + class MockBlockWithSubBlocks: + def __init__(self): + self.__class__.__name__ = "ParentBlock" + self.description = "Parent block" + self.sub_blocks = { + "child1": self.create_child_block("ChildBlock1", "Child 1 description"), + "child2": self.create_child_block("ChildBlock2", "Child 2 description"), + } + + def create_child_block(self, name, desc): + class ChildBlock: + def __init__(self): + self.__class__.__name__ = name + self.description = desc + + return ChildBlock() + + blocks = self.create_mock_blocks() + blocks.sub_blocks["parent"] = MockBlockWithSubBlocks() + + content = generate_modular_model_card_content(blocks) + + assert "parent" in content["blocks_description"] + assert "child1" in content["blocks_description"] + assert "child2" in content["blocks_description"] + + def test_model_description_includes_block_count(self): + """Test that model description includes the number of blocks.""" + blocks = self.create_mock_blocks(num_blocks=5) + content = generate_modular_model_card_content(blocks) + + assert "5-block architecture" in content["model_description"] diff --git a/tests/modular_pipelines/wan/__init__.py b/tests/modular_pipelines/wan/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/wan/test_modular_pipeline_wan.py b/tests/modular_pipelines/wan/test_modular_pipeline_wan.py new file mode 100644 index 000000000000..c5ed9613e40f --- /dev/null +++ b/tests/modular_pipelines/wan/test_modular_pipeline_wan.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from diffusers.modular_pipelines import WanBlocks, WanModularPipeline + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestWanModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = WanModularPipeline + pipeline_blocks_class = WanBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe" + + params = frozenset(["prompt", "height", "width", "num_frames"]) + batch_params = frozenset(["prompt"]) + optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"]) + output_name = "videos" + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + @pytest.mark.skip(reason="num_videos_per_prompt") + def test_num_images_per_prompt(self): + pass diff --git a/tests/modular_pipelines/z_image/__init__.py b/tests/modular_pipelines/z_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/z_image/test_modular_pipeline_z_image.py b/tests/modular_pipelines/z_image/test_modular_pipeline_z_image.py new file mode 100644 index 000000000000..29da18fce61b --- /dev/null +++ b/tests/modular_pipelines/z_image/test_modular_pipeline_z_image.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestZImageModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = ZImageModularPipeline + pipeline_blocks_class = ZImageAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=5e-3) diff --git a/tests/pipelines/cosmos/test_cosmos2_5_transfer.py b/tests/pipelines/cosmos/test_cosmos2_5_transfer.py new file mode 100644 index 000000000000..932443bceea2 --- /dev/null +++ b/tests/pipelines/cosmos/test_cosmos2_5_transfer.py @@ -0,0 +1,386 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +import tempfile +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLWan, + Cosmos2_5_TransferPipeline, + CosmosControlNetModel, + CosmosTransformer3DModel, + UniPCMultistepScheduler, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np +from .cosmos_guardrail import DummyCosmosSafetyChecker + + +enable_full_determinism() + + +class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline): + @staticmethod + def from_pretrained(*args, **kwargs): + if "safety_checker" not in kwargs or kwargs["safety_checker"] is None: + safety_checker = DummyCosmosSafetyChecker() + device_map = kwargs.get("device_map", "cpu") + torch_dtype = kwargs.get("torch_dtype") + if device_map is not None or torch_dtype is not None: + safety_checker = safety_checker.to(device_map, dtype=torch_dtype) + kwargs["safety_checker"] = safety_checker + return Cosmos2_5_TransferPipeline.from_pretrained(*args, **kwargs) + + +class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Cosmos2_5_TransferWrapper + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + # Transformer with img_context support for Transfer2.5 + transformer = CosmosTransformer3DModel( + in_channels=16 + 1, + out_channels=16, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=32, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(2.0, 1.0, 1.0), + concat_padding_mask=True, + extra_pos_embed_type="learnable", + controlnet_block_every_n=1, + img_context_dim_in=32, + img_context_num_tokens=4, + img_context_dim_out=32, + ) + + torch.manual_seed(0) + controlnet = CosmosControlNetModel( + n_controlnet_blocks=2, + in_channels=16 + 1 + 1, # control latent channels + condition_mask + padding_mask + latent_channels=16 + 1 + 1, # base latent channels (16) + condition_mask (1) + padding_mask (1) = 18 + model_channels=32, + num_attention_heads=2, + attention_head_dim=16, + mlp_ratio=2, + text_embed_dim=32, + adaln_lora_dim=4, + patch_size=(1, 2, 2), + max_size=(4, 32, 32), + rope_scale=(2.0, 1.0, 1.0), + extra_pos_embed_type="learnable", # Match transformer's config + img_context_dim_in=32, + img_context_dim_out=32, + use_crossattn_projection=False, # Test doesn't need this projection + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "controlnet": controlnet, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": DummyCosmosSafetyChecker(), + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "num_frames": 3, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))} + pipe = self.pipeline_class(**init_components) + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_inference_with_controls(self): + """Test inference with control inputs (ControlNet).""" + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + # Add control video input - should be a video tensor + inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width + inputs["controls_conditioning_scale"] = 1.0 + + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not getattr(self, "test_attention_slicing", True): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_serialization_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + model_components = [ + component_name + for component_name, component in pipe.components.items() + if isinstance(component, torch.nn.Module) + ] + # Remove components that aren't saved as standard diffusers models + if "safety_checker" in model_components: + model_components.remove("safety_checker") + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + + with open(f"{tmpdir}/model_index.json", "r") as f: + config = json.load(f) + + for subfolder in os.listdir(tmpdir): + if not os.path.isfile(subfolder) and subfolder in model_components: + folder_path = os.path.join(tmpdir, subfolder) + is_folder = os.path.isdir(folder_path) and subfolder in config + assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + if not components: + self.skipTest("No dummy components defined.") + + pipe = self.pipeline_class(**components) + + specified_key = next(iter(components.keys())) + + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: + pipe.save_pretrained(tmpdirname, safe_serialization=False) + torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained( + tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict + ) + + for name, component in loaded_pipe.components.items(): + # Skip components that are not loaded from disk or have special handling + if name == "safety_checker": + continue + if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"): + expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32)) + self.assertEqual( + component.dtype, + expected_dtype, + f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}", + ) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + self.pipeline_class._optional_components.remove("safety_checker") + super().test_save_load_optional_components(expected_max_difference=expected_max_difference) + self.pipeline_class._optional_components.append("safety_checker") + + @unittest.skip( + "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in " + "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is " + "too large and slow to run on CI." + ) + def test_encode_prompt_works_in_isolation(self): + pass diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 74499bfa607a..f7476a21de57 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -27,6 +27,7 @@ FasterCacheTesterMixin, FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, + MagCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, TaylorSeerCacheTesterMixin, @@ -41,6 +42,7 @@ class FluxPipelineFastTests( FasterCacheTesterMixin, FirstBlockCacheTesterMixin, TaylorSeerCacheTesterMixin, + MagCacheTesterMixin, unittest.TestCase, ): pipeline_class = FluxPipeline diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 7db5f4da89ca..f0eba0026b70 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -35,6 +35,7 @@ from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.hooks.first_block_cache import FirstBlockCacheConfig +from diffusers.hooks.mag_cache import MagCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig from diffusers.image_processor import VaeImageProcessor @@ -2405,7 +2406,11 @@ def test_pipeline_level_group_offloading_sanity_checks(self): if name not in [exclude_module_name] and isinstance(component, torch.nn.Module): # `component.device` prints the `onload_device` type. We should probably override the # `device` property in `ModelMixin`. - component_device = next(component.parameters())[0].device + # Skip modules with no parameters (e.g., dummy safety checkers with only buffers) + params = list(component.parameters()) + if not params: + continue + component_device = params[0].device self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type) @require_torch_accelerator @@ -2976,6 +2981,59 @@ def run_forward(pipe): ) +class MagCacheTesterMixin: + mag_cache_config = MagCacheConfig( + threshold=0.06, + max_skip_steps=3, + retention_ratio=0.2, + num_inference_steps=50, + mag_ratios=torch.ones(50), + ) + + def test_mag_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + # Match the config steps + inputs["num_inference_steps"] = 50 + return pipe(**inputs)[0] + + # 1. Run inference without MagCache (Baseline) + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # 2. Run inference with MagCache ENABLED + pipe = create_pipe() + pipe.transformer.enable_cache(self.mag_cache_config) + output = run_forward(pipe).flatten() + image_slice_enabled = np.concatenate((output[:8], output[-8:])) + + # 3. Run inference with MagCache DISABLED + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), ( + "MagCache outputs should not differ too much from baseline." + ) + + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), ( + "Outputs after disabling cache should match original inference exactly." + ) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. diff --git a/tests/pipelines/z_image/test_z_image_inpaint.py b/tests/pipelines/z_image/test_z_image_inpaint.py new file mode 100644 index 000000000000..e904a4e44bd7 --- /dev/null +++ b/tests/pipelines/z_image/test_z_image_inpaint.py @@ -0,0 +1,396 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import unittest + +import numpy as np +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + ZImageInpaintPipeline, + ZImageTransformer2DModel, +) +from diffusers.utils.testing_utils import floats_tensor + +from ...testing_utils import torch_device +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +class ZImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ZImageInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset(["image", "mask_image"]) + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "strength", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = ZImageTransformer2DModel( + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=32, + n_layers=2, + n_refiner_layers=1, + n_heads=2, + n_kv_heads=2, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=16, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[8, 4, 4], + axes_lens=[256, 32, 32], + ) + # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty` which contains + # uninitialized memory. Set them to known values for deterministic test behavior. + with torch.no_grad(): + transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data)) + transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data)) + + torch.manual_seed(0) + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + block_out_channels=[32, 64], + layers_per_block=1, + latent_channels=16, + norm_num_groups=32, + sample_size=32, + scaling_factor=0.3611, + shift_factor=0.1159, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3Config( + hidden_size=16, + intermediate_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + text_encoder = Qwen3Model(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + import random + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + # Create mask: 1 = inpaint region, 0 = preserve region + mask_image = torch.zeros((1, 1, 32, 32), device=device) + mask_image[:, :, 8:24, 8:24] = 1.0 # Inpaint center region + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "image": image, + "mask_image": mask_image, + "strength": 1.0, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "cfg_normalization": False, + "cfg_truncation": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "np", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (32, 32, 3)) + + def test_inference_batch_single_identical(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_num_images_per_prompt(self): + import inspect + + sig = inspect.signature(self.pipeline_class.__call__) + + if "num_images_per_prompt" not in sig.parameters: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1, 2] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + del pipe + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.7): + import random + + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + # Generate a larger image for the input + inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu") + # Generate a larger mask for the input + mask = torch.zeros((1, 1, 128, 128), device="cpu") + mask[:, :, 32:96, 32:96] = 1.0 + inputs["mask_image"] = mask + output_without_tiling = pipe(**inputs)[0] + + # With tiling (standard AutoencoderKL doesn't accept parameters) + pipe.vae.enable_tiling() + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu") + inputs["mask_image"] = mask + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-3): + # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance + # Inpainting mask blending adds additional numerical variance + super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference) + + def test_group_offloading_inference(self): + # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine. + self.skipTest("Using test_pipeline_level_group_offloading_inference instead") + + def test_save_load_float16(self, expected_max_diff=1e-2): + # Z-Image does not support FP16 due to complex64 RoPE embeddings + self.skipTest("Z-Image does not support FP16 inference") + + def test_float16_inference(self, expected_max_diff=5e-2): + # Z-Image does not support FP16 due to complex64 RoPE embeddings + self.skipTest("Z-Image does not support FP16 inference") + + def test_strength_parameter(self): + """Test that strength parameter affects the output correctly.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # Test with different strength values + inputs_low_strength = self.get_dummy_inputs(device) + inputs_low_strength["strength"] = 0.2 + + inputs_high_strength = self.get_dummy_inputs(device) + inputs_high_strength["strength"] = 0.8 + + # Both should complete without errors + output_low = pipe(**inputs_low_strength).images[0] + output_high = pipe(**inputs_high_strength).images[0] + + # Outputs should be different (different amount of transformation) + self.assertFalse(np.allclose(output_low, output_high, atol=1e-3)) + + def test_invalid_strength(self): + """Test that invalid strength values raise appropriate errors.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + + inputs = self.get_dummy_inputs(device) + + # Test strength < 0 + inputs["strength"] = -0.1 + with self.assertRaises(ValueError): + pipe(**inputs) + + # Test strength > 1 + inputs["strength"] = 1.5 + with self.assertRaises(ValueError): + pipe(**inputs) + + def test_mask_inpainting(self): + """Test that the mask properly controls which regions are inpainted.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # Generate with full mask (inpaint everything) + inputs_full = self.get_dummy_inputs(device) + inputs_full["mask_image"] = torch.ones((1, 1, 32, 32), device=device) + + # Generate with no mask (preserve everything) + inputs_none = self.get_dummy_inputs(device) + inputs_none["mask_image"] = torch.zeros((1, 1, 32, 32), device=device) + + # Both should complete without errors + output_full = pipe(**inputs_full).images[0] + output_none = pipe(**inputs_none).images[0] + + # Outputs should be different (full inpaint vs preserve) + self.assertFalse(np.allclose(output_full, output_none, atol=1e-3))