From b04b66f36d9e2dc906667ec1606f0842cbcfa7fa Mon Sep 17 00:00:00 2001 From: hitchhiker3010 Date: Tue, 14 Apr 2026 01:11:04 +0530 Subject: [PATCH 1/2] [core] Remove DtoH syncs from ZImage pipeline denoising loop --- .../pipelines/z_image/pipeline_z_image.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py index 959368ec1cd1..46403a0719cd 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py @@ -486,6 +486,15 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + + if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1: + _precomputed_t_norms = ((1000 - timesteps.float()) / 1000).tolist() + else: + _precomputed_t_norms = None + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -495,17 +504,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) timestep = (1000 - timestep) / 1000 - # Normalized time for time-aware config (0 at start, 1 at end) - t_norm = timestep[0].item() - - # Handle cfg truncation current_guidance_scale = self.guidance_scale - if ( - self.do_classifier_free_guidance - and self._cfg_truncation is not None - and float(self._cfg_truncation) <= 1 - ): - if t_norm > self._cfg_truncation: + if _precomputed_t_norms is not None: + if _precomputed_t_norms[i] > self._cfg_truncation: current_guidance_scale = 0.0 # Run CFG only if configured AND scale is non-zero From 2da67248e0d4c2cda112ed9e962dbcee90dffa08 Mon Sep 17 00:00:00 2001 From: hitchhiker3010 Date: Tue, 14 Apr 2026 13:02:56 +0530 Subject: [PATCH 2/2] [core] Replace boolean mask indexing with torch.where in ZImage transformer Boolean mask indexing (tensor[mask] = val) implicitly calls nonzero(), which triggers a DtoH sync that stalls the CPU while the GPU queue drains. Replacing it with torch.where eliminates these syncs from the transformer's pad-token assignment. Profiling (4-step turbo, fix_2 vs fix_1): - Eager: nonzero CPU time drops from ~2091 ms to <1 ms; index_put eliminated - Compile: nonzero CPU time drops from ~3057 ms to <1 ms; index_put eliminated --- src/diffusers/models/transformers/transformer_z_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 8aa30ee082ff..ba401e7fdef1 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -777,7 +777,8 @@ def _prepare_sequence( # Pad token feats_cat = torch.cat(feats, dim=0) - feats_cat[torch.cat(inner_pad_mask)] = pad_token + mask = torch.cat(inner_pad_mask).unsqueeze(-1) + feats_cat = torch.where(mask, pad_token, feats_cat) feats = list(feats_cat.split(item_seqlens, dim=0)) # RoPE