Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/openpi/models/pi0_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@ class Pi0Config(_model.BaseModelConfig):
# This config option is not used directly by the model, but it is read by the ModelTransformFactory.
discrete_state_input: bool = None # type: ignore

pytorch_compile_mode: str | None = "max-autotune"

def __post_init__(self):
    """Fill in config defaults and validate the torch.compile mode.

    Defaults applied:
    - ``max_token_len``: 200 for pi05 models, 48 otherwise.
    - ``discrete_state_input``: mirrors the ``pi05`` flag.

    Raises:
        ValueError: if ``pytorch_compile_mode`` is set but is not a mode
            accepted by ``torch.compile``. (Previously this used ``assert``,
            which is stripped under ``python -O`` and would let a typo
            surface only later, deep inside ``torch.compile``.)

    Uses ``object.__setattr__`` — presumably because the dataclass is
    frozen; TODO confirm against the class decorator, which is outside
    this diff view.
    """
    if self.max_token_len is None:
        object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
    if self.discrete_state_input is None:
        object.__setattr__(self, "discrete_state_input", self.pi05)
    # Validate eagerly so a bad mode fails at config construction time.
    valid_modes = (
        "default",
        "reduce-overhead",
        "max-autotune",
        "max-autotune-no-cudagraphs",
    )
    if self.pytorch_compile_mode is not None and self.pytorch_compile_mode not in valid_modes:
        raise ValueError(
            f"pytorch_compile_mode must be one of {valid_modes}, "
            f"got {self.pytorch_compile_mode!r}"
        )

@property
@override
Expand Down
3 changes: 2 additions & 1 deletion src/openpi/models_pytorch/pi0_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def __init__(self, config):
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
if config.pytorch_compile_mode is not None:
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)

# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
Expand Down
Loading