diff --git a/examples/convert_jax_model_to_pytorch.py b/examples/convert_jax_model_to_pytorch.py index a605d0dc0f..75c5cee2c1 100644 --- a/examples/convert_jax_model_to_pytorch.py +++ b/examples/convert_jax_model_to_pytorch.py @@ -533,7 +533,7 @@ def __init__(self): safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors")) # Copy assets folder if it exists - assets_source = pathlib.Path(checkpoint_dir).parent / "assets" + assets_source = pathlib.Path(checkpoint_dir) / "assets" if assets_source.exists(): assets_dest = pathlib.Path(output_path) / "assets" if assets_dest.exists():