
Remove compile bottlenecks from ZImage pipeline#13461

Open
hitchhiker3010 wants to merge 3 commits into huggingface:main from hitchhiker3010:main

Conversation

@hitchhiker3010

What does this PR do?

Fixes performance issues identified by profiling ZImagePipeline with torch.profiler as part of #13401.

Profiled ZImagePipeline (using Tongyi-MAI/Z-Image-Turbo) in both eager and torch.compile modes following the profiling guide. The Chrome traces revealed two device-to-host (DtoH) synchronization points that break asynchronous GPU execution and prevent torch.compile from yielding its full speedup.
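For reference, a minimal sketch of the kind of profiling setup that produces these Chrome traces; the helper name `trace_call`, the output filename, and the commented pipeline call are illustrative, not code from this PR:

```python
import os
import torch
from torch.profiler import profile, ProfilerActivity

def trace_call(fn, out_path="zimage_trace.json"):
    """Run `fn` under torch.profiler and export a Chrome trace.

    The exported JSON can be opened in chrome://tracing or Perfetto to
    inspect cudaStreamSynchronize blocks like the ones discussed below.
    """
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)  # capture GPU-side events too
    with profile(activities=activities) as prof:
        fn()
    prof.export_chrome_trace(out_path)
    return prof

# With the pipeline this would be something like:
# trace_call(lambda: pipe(prompt, num_inference_steps=4, guidance_scale=0.0))
prof = trace_call(lambda: torch.randn(64, 64) @ torch.randn(64, 64))
```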

Pipeline denoising loop: `timestep[0].item()` DtoH sync

  1. Inside the denoising loop, `timestep[0].item()` triggers a GPU→CPU sync every step to read `t_norm` for the CFG-truncation logic. Since the full timestep schedule is known before the loop begins, we precompute all `t_norm` values into a plain Python list before entering the loop and index into it with `i`.
  2. This also lets us call `scheduler.set_begin_index(0)` upfront to avoid the DtoH sync in `_init_step_index` (same pattern as "Avoid DtoH sync from access of nonzero() item in scheduler", #11696).
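The change can be sketched as follows; the schedule below is a stand-in, and the variable names are illustrative rather than the pipeline's exact identifiers:

```python
import torch

# Stand-in for the scheduler's normalized timestep schedule (known up front).
timesteps = torch.linspace(1.0, 0.0, 4)

# Before: one DtoH sync per step inside the loop, e.g.
#   for t in timesteps:
#       t_norm = t.item()  # GPU -> CPU read stalls the stream every iteration

# After: a single DtoH transfer for the whole schedule, before the loop.
t_norms = timesteps.tolist()  # plain Python floats
# scheduler.set_begin_index(0)  # also skips the nonzero() sync in _init_step_index
for i, t in enumerate(timesteps):
    t_norm = t_norms[i]  # no sync inside the loop
```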

Profiling ZImagePipeline

- GPU: L4
- num_inference_steps: 4
- guidance_scale: 0.0 (guidance should be 0 for the Turbo models)

Before
The first `scheduler_step` took 657.8 µs.
Number of `cudaStreamSynchronize` blocks: 19

After
The first `scheduler_step` took 15.49 µs after this fix.
Number of `cudaStreamSynchronize` blocks: 13
Part of #13401 .

Before submitting

Who can review?

@sayakpaul @dg845

@github-actions github-actions bot added pipelines size/S PR with diff < 50 LOC labels Apr 13, 2026
@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Apr 14, 2026
@sayakpaul
Member

Thanks for your PR! Can we eliminate all the cudaStreamSynchronize calls?

…former

Boolean mask indexing (tensor[mask] = val) implicitly calls nonzero(),
which triggers a DtoH sync that stalls the CPU while the GPU queue drains.
Replacing it with torch.where eliminates these syncs from the transformer's
pad-token assignment.

Profiling (4-step turbo, fix_2 vs fix_1):
- Eager: nonzero CPU time drops from ~2091 ms to <1 ms; index_put eliminated
- Compile: nonzero CPU time drops from ~3057 ms to <1 ms; index_put eliminated
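A minimal sketch of the rewrite described above; `hidden_states`, `pad_mask`, and `pad_value` are illustrative names, not the transformer's actual identifiers:

```python
import torch

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, dim)
pad_mask = torch.tensor([[0, 0, 1, 1, 1],
                         [0, 1, 1, 1, 1]], dtype=torch.bool)
pad_value = 0.0

# Before: boolean-mask assignment calls nonzero() under the hood -> DtoH sync
#   hidden_states[pad_mask] = pad_value

# After: torch.where builds the result with no data-dependent shapes, so no
# sync is needed and the op traces cleanly under torch.compile.
hidden_states = torch.where(
    pad_mask.unsqueeze(-1),                      # broadcast over the feature dim
    torch.full_like(hidden_states, pad_value),   # value for padded positions
    hidden_states,                               # unchanged elsewhere
)
```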
@github-actions github-actions bot added models size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 14, 2026
@hitchhiker3010
Author

hitchhiker3010 commented Apr 14, 2026

Here are some comparison stats between commit_1 and commit_2

| Metric | commit_1 eager | commit_2 eager | commit_1 compile | commit_2 compile |
|---|---|---|---|---|
| `nonzero` calls | 28 | 4 | 28 | 4 |
| `nonzero` CPU time | 2091 ms | 0.72 ms | 3057 ms | 0.49 ms |
| `index_put` calls | 20 | 0 | 36 | 0 |
| `index_put` total | 4183 ms | 0 ms | 9172 ms | 0 ms |
| `cudaStreamSynchronize` calls | 13 | 5 | 13 | 5 |
| `cudaStreamSynchronize` total | 2089 ms | 0.47 ms | 3055 ms | 0.32 ms |

@hitchhiker3010
Author

All the trace files can be accessed here.

The `cudaStreamSynchronize` traces from the denoising phase are eliminated now. The remaining 5 `cudaStreamSynchronize` calls seem to come from the text-encoding phase; should we fix them too?

cc: @sayakpaul
