diff --git a/src/openpi/models/pi0_config.py b/src/openpi/models/pi0_config.py
index b0f6b662ac..584d83f199 100644
--- a/src/openpi/models/pi0_config.py
+++ b/src/openpi/models/pi0_config.py
@@ -32,11 +32,21 @@ class Pi0Config(_model.BaseModelConfig):
     # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
     discrete_state_input: bool = None  # type: ignore
 
+    # torch.compile mode applied to sample_actions; None disables compilation entirely.
+    pytorch_compile_mode: str | None = "max-autotune"
+
     def __post_init__(self):
         if self.max_token_len is None:
             object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
         if self.discrete_state_input is None:
             object.__setattr__(self, "discrete_state_input", self.pi05)
+        if self.pytorch_compile_mode is not None and self.pytorch_compile_mode not in (
+            "default",
+            "reduce-overhead",
+            "max-autotune",
+            "max-autotune-no-cudagraphs",
+        ):
+            raise ValueError(f"Invalid pytorch_compile_mode: {self.pytorch_compile_mode!r}")
 
     @property
     @override
diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py
index 92ff95028a..e68ddb7cc0 100644
--- a/src/openpi/models_pytorch/pi0_pytorch.py
+++ b/src/openpi/models_pytorch/pi0_pytorch.py
@@ -109,7 +109,8 @@ def __init__(self, config):
         self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
 
         torch.set_float32_matmul_precision("high")
-        self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
+        if config.pytorch_compile_mode is not None:
+            self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
 
         # Initialize gradient checkpointing flag
         self.gradient_checkpointing_enabled = False