From 38007b441cdd258fa52423cad517ecaeab25d3c0 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Fri, 29 May 2026 18:15:09 +0000 Subject: [PATCH 1/3] feat: add torchtpu --- docs/source/en/optimization/tpu.md | 205 ++++++++++++++++++ .../models/unets/unet_2d_condition.py | 15 +- src/diffusers/pipelines/flux/pipeline_flux.py | 6 +- .../flux/pipeline_flux_controlnet.py | 6 +- .../pipelines/flux/pipeline_flux_inpaint.py | 6 +- .../pipelines/flux/pipeline_flux_kontext.py | 6 +- .../flux/pipeline_flux_kontext_inpaint.py | 6 +- src/diffusers/pipelines/pipeline_utils.py | 97 ++++++++- .../pipeline_stable_diffusion_xl.py | 4 +- .../pipeline_stable_diffusion_xl_img2img.py | 4 +- .../pipeline_stable_diffusion_xl_inpaint.py | 4 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 11 + src/diffusers/utils/torch_utils.py | 16 +- 14 files changed, 366 insertions(+), 21 deletions(-) create mode 100644 docs/source/en/optimization/tpu.md diff --git a/docs/source/en/optimization/tpu.md b/docs/source/en/optimization/tpu.md new file mode 100644 index 000000000000..cea7dd0e388c --- /dev/null +++ b/docs/source/en/optimization/tpu.md @@ -0,0 +1,205 @@ + + +# TorchTPU + +[TorchTPU](https://github.com/pytorch/tpu) registers a `"tpu"` device type with PyTorch, enabling you to run +diffusers pipelines on Google Cloud TPUs (v4, v5p, v5e, …) with minimal code changes. + +Three execution modes are available: + +| Mode | How to activate | Speed | Notes | +|---|---|---|---| +| **Lazy** (default) | just `import torch_tpu` | baseline | XLA traces the graph lazily | +| **Eager** | `set_eager_mode(EagerMode.DEFER_NEVER)` | medium | dispatch ops eagerly | +| **Compile** | `pipe.enable_tpu_compile()` | fastest (~4–6×) | static compilation with `TpuBackend` | + +## Installation + +Follow the [TorchTPU installation guide](https://github.com/pytorch/tpu). After installation, +`import torch_tpu` registers the `"tpu"` device automatically. + +## Text encoders always stay on CPU + +XLA's static graph compiler does not support certain dynamic ops used in text encoders (notably +`index_select` on large embedding tables). Text encoders must therefore remain on CPU. Their +output embeddings are moved to the TPU after encoding. + +Diffusers handles this transparently: +- `_execution_device` detects any component on TPU and returns that device. +- `encode_prompt` runs the text encoder on its own device (`cpu`) and moves the resulting + embeddings to the execution device (TPU). +- `randn_tensor` generates initial noise on CPU and moves it to TPU, avoiding a TPU RNG + unaligned DUS (dynamic-update-slice) bug. + +## Basic usage (lazy mode) + +```python +import torch +import torch_tpu # noqa: F401 — registers torch.tpu + +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, +) + +# Move only the denoising components to TPU; text encoders stay on CPU. +pipe.transformer.to("tpu") +pipe.vae.to("tpu") + +# _execution_device is now "tpu" automatically. +image = pipe( + prompt="a golden retriever surfing a wave, photorealistic", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +).images[0] + +image.save("output.png") +``` + +## Compiled mode (recommended for production) + +`torch.compile` with `TpuBackend` traces the transformer statically and gives the largest +speedup. The first call (warmup) is slow because it triggers compilation; subsequent calls +reuse the compiled graph. + +```python +import torch +import torch_tpu # noqa: F401 + +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, +) +pipe.transformer.to("tpu") +pipe.vae.to("tpu") + +# Compile TPU components with TpuBackend. +# Also applies AttnProcessor to replace SDP-based attention (required for XLA). +pipe.enable_tpu_compile() + +# Warmup — triggers static graph compilation. +pipe.tpu_warmup( + prompt="warmup", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +) + +# Timed inference reuses the compiled graph. +image = pipe( + prompt="a golden retriever surfing a wave, photorealistic", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +).images[0] + +image.save("output.png") +``` + +## SDXL + +SDXL uses a UNet instead of a transformer. The same approach applies. + +```python +import torch +import torch_tpu # noqa: F401 + +from diffusers import StableDiffusionXLPipeline + +pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.bfloat16, + use_safetensors=True, +) + +pipe.unet.to("tpu") +pipe.vae.to("tpu") + +pipe.enable_tpu_compile() +pipe.tpu_warmup( + prompt="warmup", + height=1024, + width=1024, + num_inference_steps=20, + guidance_scale=7.5, +) + +image = pipe( + prompt="a golden retriever surfing a wave, photorealistic", + height=1024, + width=1024, + num_inference_steps=20, + guidance_scale=7.5, +).images[0] + +image.save("output.png") +``` + +> [!NOTE] +> In SDXL **lazy/eager mode** (without `enable_tpu_compile`), `time_proj` inside the UNet +> runs on CPU automatically to avoid an XLA unaligned DUS crash. `enable_tpu_compile` uses +> `TpuBackend` which handles the layout internally, so no wrapper is needed in compiled mode. + +## Eager mode + +Eager mode dispatches ops immediately instead of accumulating a lazy graph. Enter it +**before loading or moving models** to TPU: + +```python +import torch +import torch_tpu # noqa: F401 +from torch_tpu._internal.execution_mode import EagerMode, set_eager_mode + +eager_ctx = set_eager_mode(EagerMode.DEFER_NEVER) +eager_ctx.__enter__() + +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) +pipe.transformer.to("tpu") +pipe.vae.to("tpu") + +image = pipe(prompt="a cat", height=1024, width=1024, num_inference_steps=4, guidance_scale=0.0).images[0] +image.save("output.png") + +eager_ctx.__exit__(None, None, None) +``` + +## Performance benchmarks (v5p, BF16) + +| Model | Mode | Steps | Resolution | Time/iter | +|---|---|---|---|---| +| FLUX.2-klein-9B | Lazy | 4 | 1024×1024 | 7.82 s | +| FLUX.2-klein-9B | Compile | 4 | 1024×1024 | 1.94 s | +| ERNIE-Image-Turbo | Lazy | 8 | 1024×1024 | 5.97 s | +| ERNIE-Image-Turbo | Compile | 8 | 1024×1024 | 2.24 s | +| Wan2.2-TI2V (video) | Eager | 50 | 480×832 | 82.2 s | +| Wan2.2-TI2V (video) | Compile | 50 | 480×832 | 14.2 s | + +## API reference + +### `enable_tpu_compile` + +[[autodoc]] diffusers.DiffusionPipeline.enable_tpu_compile + +### `tpu_warmup` + +[[autodoc]] diffusers.DiffusionPipeline.tpu_warmup diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..22e1f9ac2a61 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -26,6 +26,7 @@ deprecate, logging, ) +from ...utils.torch_utils import is_compiled_module from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import ( @@ -855,10 +856,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" is_npu = sample.device.type == "npu" + is_tpu = sample.device.type == "tpu" if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_tpu) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_tpu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) @@ -866,7 +868,14 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) + # On TPU in eager/lazy mode, torch.cat([sin, cos], dim=-1) inside time_proj + # lands at an unaligned offset in the XLA DUS fusion emitter → crash. + # torch.compile with TpuBackend handles this internally, so only wrap for + # non-compiled modules. + if sample.device.type == "tpu" and not is_compiled_module(self): + t_emb = self.time_proj(timesteps.cpu()).to(sample.device) + else: + t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e125924adf7f..c88ca0735ba9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -251,7 +251,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -296,7 +297,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index da81563e4a66..b1c9f848f2c8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -282,7 +282,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -327,7 +328,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 914274397944..f8d81faa7e00 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -275,7 +275,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -321,7 +322,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index efddc6cea139..d6a0df89ac32 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -297,7 +297,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -343,7 +344,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index c85299eedcd3..8bf645cab04c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -330,7 +330,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -376,7 +377,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1fa4db90d995..b046b7436b01 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -64,10 +64,12 @@ is_bitsandbytes_version, is_hpu_available, is_torch_npu_available, + is_torch_tpu_available, is_torch_version, is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -1162,6 +1164,13 @@ def _execution_device(self): except ValueError: pass + # For TPU pipelines, text encoders stay on CPU while denoising components + # (transformer, unet, vae) live on TPU. The standard self.device check below + # would return CPU (first component). Detect any TPU component and prefer it. + for name, model in self.components.items(): + if isinstance(model, torch.nn.Module) and model.device.type == "tpu": + return model.device + for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: continue @@ -2387,3 +2396,89 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): else: self.vae.unfuse_qkv_projections() self.fusing_vae = False + + def enable_tpu_compile( + self, + model_names: Optional[List[str]] = None, + **compile_kwargs, + ) -> None: + """Compile pipeline components that are on TPU using ``torch.compile`` with the ``TpuBackend``. + + Before compiling, each component that exposes ``set_attn_processor`` has ``AttnProcessor`` + applied. This replaces ``AttnProcessor2_0`` (SDP-based) which triggers XLA fusion-emitter + crashes in eager/lazy mode. ``TpuBackend`` handles the resulting ``torch.cat`` layout + internally during static tracing, so no additional wrapper is needed at compile time. + + Args: + model_names (`list[str]`, *optional*): + Names of pipeline components to compile. Defaults to all ``torch.nn.Module`` + components currently resident on a TPU device. + **compile_kwargs: + Extra keyword arguments forwarded to ``torch.compile``. ``backend`` defaults to + ``TpuBackend()`` and ``dynamic`` defaults to ``False`` (required for static tracing). + + Example: + ```python + import torch + import torch_tpu # noqa: F401 + + pipe.transformer.to("tpu") + pipe.vae.to("tpu") + pipe.enable_tpu_compile() + ``` + """ + requires_backends(self, "torch_tpu") + from torch_tpu._internal.compile import TpuBackend + + from ..models.attention_processor import AttnProcessor + + if model_names is None: + model_names = [ + name + for name, comp in self.components.items() + if isinstance(comp, torch.nn.Module) and comp.device.type == "tpu" + ] + + for name in model_names: + component = getattr(self, name, None) + if not isinstance(component, torch.nn.Module): + logger.warning(f"`enable_tpu_compile`: component '{name}' is not a nn.Module, skipping.") + continue + if is_compiled_module(component): + logger.warning(f"`enable_tpu_compile`: component '{name}' is already compiled, skipping.") + continue + if hasattr(component, "set_attn_processor"): + component.set_attn_processor(AttnProcessor()) + compile_kwargs.setdefault("backend", TpuBackend()) + compile_kwargs.setdefault("dynamic", False) + logger.info(f"Compiling '{name}' with TpuBackend.") + setattr(self, name, torch.compile(component, **compile_kwargs)) + + def tpu_warmup(self, *args, **kwargs) -> None: + """Run a single forward pass to trigger XLA / ``TpuBackend`` compilation. + + Call this after ``enable_tpu_compile`` and before timed inference. The warmup + pass compiles the static computation graphs; subsequent calls reuse the compiled + graphs and run at full speed. + + Args: + *args: Positional arguments forwarded to the pipeline ``__call__``. + **kwargs: Keyword arguments forwarded to the pipeline ``__call__``. + + Example: + ```python + pipe.tpu_warmup( + prompt="warmup", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, + ) + ``` + """ + logger.info("Running TPU warmup pass to trigger XLA compilation...") + with torch.no_grad(): + self(*args, **kwargs) + if hasattr(torch, "tpu") and hasattr(torch.tpu, "synchronize"): + torch.tpu.synchronize() + logger.info("TPU warmup complete.") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8148fac123e0..856dc27a7125 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -402,7 +402,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -463,7 +463,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 19ccfab3de0a..eadaf543b9d0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -420,7 +420,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -481,7 +481,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7382d597102c..f03d01c90ec1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -524,7 +524,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -585,7 +585,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 4b41622b2a4a..da402b602802 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -114,6 +114,7 @@ is_torch_available, is_torch_mlu_available, is_torch_npu_available, + is_torch_tpu_available, is_torch_version, is_torch_xla_available, is_torch_xla_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 5323dfe5ec82..92799bee6de2 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_tpu_available, _torch_tpu_version = _is_package_available("torch_tpu") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -250,6 +251,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_tpu_available(): + return _torch_tpu_available + + def is_flax_available(): return _flax_available @@ -579,6 +584,11 @@ def is_av_available(): torchao` """ +TORCH_TPU_IMPORT_ERROR = """ +{0} requires the torch_tpu library but it was not found in your environment. Please follow the installation +instructions at https://github.com/pytorch/tpu +""" + QUANTO_IMPORT_ERROR = """ {0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip install optimum-quanto` @@ -630,6 +640,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_tpu", (is_torch_tpu_available, TORCH_TPU_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index c314a8609bec..b66b8fbc9913 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -33,12 +33,13 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "tpu": False, "default": True} BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "tpu": getattr(getattr(torch, "tpu", None), "empty_cache", None), "default": None, } BACKEND_DEVICE_COUNT = { @@ -46,6 +47,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "tpu": lambda: getattr(getattr(torch, "tpu", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -53,6 +55,9 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + # TPU latents are always generated on CPU (TPU RNG has unaligned DUS bug), + # so CPU seeding is the correct behaviour here. + "tpu": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -60,6 +65,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "tpu": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -67,6 +73,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "tpu": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -74,6 +81,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "tpu": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -81,6 +89,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "tpu": getattr(getattr(torch, "tpu", None), "synchronize", None), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -166,6 +175,11 @@ def randn_tensor( rand_device = device batch_size = shape[0] + # TPU RNG has an unaligned DUS (dynamic-update-slice) bug — generate on CPU + # and move to TPU via the existing .to(device) call at the end. + if device is not None and device.type == "tpu": + rand_device = torch.device("cpu") + layout = layout or torch.strided device = device or torch.device("cpu") From 8ed15f91cd128748618add135ab4815d4e31ff22 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 2 Jun 2026 13:51:03 +0000 Subject: [PATCH 2/3] feat:draft TorchTPU support --- .../pipelines/ernie_image/pipeline_ernie_image.py | 4 ++-- src/diffusers/pipelines/flux2/pipeline_flux2_klein.py | 5 +++-- src/diffusers/pipelines/pipeline_utils.py | 10 ++++++---- src/diffusers/pipelines/wan/pipeline_wan.py | 3 ++- src/diffusers/pipelines/wan/pipeline_wan_animate.py | 3 ++- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 3 ++- src/diffusers/pipelines/wan/pipeline_wan_vace.py | 3 ++- .../pipelines/wan/pipeline_wan_video2video.py | 3 ++- 8 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 11fce6a204bf..df5b27ace653 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -114,7 +114,7 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(self.pe.device) output_ids = self.pe.generate( **inputs, max_new_tokens=self.pe_tokenizer.model_max_length, @@ -155,7 +155,7 @@ def encode_prompt( else: ids = [0] - input_ids = torch.tensor([ids], device=device) + input_ids = torch.tensor([ids], device=self.text_encoder.device) with torch.no_grad(): outputs = self.text_encoder( input_ids=input_ids, diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index 9a3468525c0c..69788479dc86 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -241,8 +241,9 @@ def _get_qwen3_prompt_embeds( all_input_ids.append(inputs["input_ids"]) all_attention_masks.append(inputs["attention_mask"]) - input_ids = torch.cat(all_input_ids, dim=0).to(device) - attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + model_device = text_encoder.device + input_ids = torch.cat(all_input_ids, dim=0).to(model_device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(model_device) # Forward pass through the model output = text_encoder( diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b046b7436b01..b55d3fed29ad 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1164,11 +1164,13 @@ def _execution_device(self): except ValueError: pass - # For TPU pipelines, text encoders stay on CPU while denoising components - # (transformer, unet, vae) live on TPU. The standard self.device check below - # would return CPU (first component). Detect any TPU component and prefer it. + # When text encoders are offloaded to CPU while the denoising backbone + # (unet, transformer, vae) runs on an accelerator, self.device returns CPU + # (first component). Prefer any non-CPU, non-meta component device so that + # scheduler and latent tensors land on the accelerator. This covers TPU, + # NPU (npu), Intel GPU (xpu), Habana (hpu), and any other backend. for name, model in self.components.items(): - if isinstance(model, torch.nn.Module) and model.device.type == "tpu": + if isinstance(model, torch.nn.Module) and model.device.type not in ("cpu", "meta"): return model.device for name, model in self.components.items(): diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index be2d53f17932..08c4a2b23c66 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -182,7 +182,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py index 5806032c0142..91960fad7d36 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_animate.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -259,7 +259,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 8061f67ab6b9..59a844e088ae 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -223,7 +223,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index b0896d382d67..8c72adf09d6a 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -228,7 +228,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 8993475a2851..9064000ab35b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -246,7 +246,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( From e343c0aef0522ab8e673935f1cbcb098624ecb41 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 9 Jun 2026 15:38:59 +0000 Subject: [PATCH 3/3] fix: wan overflow issue + compile mode error on sdxl --- src/diffusers/models/unets/unet_2d_condition.py | 6 +++--- src/diffusers/pipelines/wan/pipeline_wan.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 22e1f9ac2a61..94a079dca027 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -870,9 +870,9 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # On TPU in eager/lazy mode, torch.cat([sin, cos], dim=-1) inside time_proj # lands at an unaligned offset in the XLA DUS fusion emitter → crash. - # torch.compile with TpuBackend handles this internally, so only wrap for - # non-compiled modules. - if sample.device.type == "tpu" and not is_compiled_module(self): + # torch.compile with TpuBackend handles this internally, so skip the CPU + # workaround when we're inside a compiled graph. + if sample.device.type == "tpu" and not torch.compiler.is_compiling(): t_emb = self.time_proj(timesteps.cpu()).to(sample.device) else: t_emb = self.time_proj(timesteps) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 08c4a2b23c66..b7e8a71cd50c 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -655,7 +655,7 @@ def __call__( self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) + latents = latents.to(self.vae.device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1)