Skip to content

Commit 26bb7fa

Browse files
[ptxla] fix pytorch xla inference on TPUs. (#13463)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent 5063aa5 commit 26bb7fa

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -877,10 +877,7 @@ def __call__(
877877
self.scheduler.config.get("max_shift", 1.15),
878878
)
879879

880-
if XLA_AVAILABLE:
881-
timestep_device = "cpu"
882-
else:
883-
timestep_device = device
880+
timestep_device = device
884881
timesteps, num_inference_steps = retrieve_timesteps(
885882
self.scheduler,
886883
num_inference_steps,

0 commit comments

Comments
 (0)