Skip to content

Commit 06ba22d

Browse files
committed
Transfer env states all at once for rendering
1 parent f5cb4a9 commit 06ba22d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/lerax/callback/logging/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def render_frame(env_state) -> np.ndarray:
308308

309309
init_key, policy_key, rollout_key = jr.split(key, 3)
310310
env_states = run_rollout(env, policy, init_key, policy_key, rollout_key)
311-
jax.block_until_ready(env_states)
311+
env_states = jax.device_get(env_states)
312312

313313
frames = [
314314
render_frame(jax.tree.map(lambda x: x[i], env_states))

0 commit comments

Comments
 (0)