diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 8aa30ee082ff..ba401e7fdef1 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -777,7 +777,8 @@ def _prepare_sequence( # Pad token feats_cat = torch.cat(feats, dim=0) - feats_cat[torch.cat(inner_pad_mask)] = pad_token + mask = torch.cat(inner_pad_mask).unsqueeze(-1) + feats_cat = torch.where(mask, pad_token, feats_cat) feats = list(feats_cat.split(item_seqlens, dim=0)) # RoPE diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py index 959368ec1cd1..46403a0719cd 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py @@ -486,6 +486,15 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + + if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1: + _precomputed_t_norms = ((1000 - timesteps.float()) / 1000).tolist() + else: + _precomputed_t_norms = None + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -495,17 +504,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) timestep = (1000 - timestep) / 1000 - # Normalized time for time-aware config (0 at start, 1 at end) - t_norm = timestep[0].item() - - # Handle cfg truncation current_guidance_scale = self.guidance_scale - if ( - self.do_classifier_free_guidance - and self._cfg_truncation is not None - and float(self._cfg_truncation) <= 1 - ): - if t_norm > self._cfg_truncation: + if _precomputed_t_norms is not None: + if _precomputed_t_norms[i] > self._cfg_truncation: current_guidance_scale = 0.0 # Run CFG only if configured AND scale is non-zero