From 7b23daaf3bb57ac0f23648b2532001e745ecac28 Mon Sep 17 00:00:00 2001 From: Abhay D Date: Thu, 12 Feb 2026 11:47:51 -0800 Subject: [PATCH 1/2] Allow configuring pytorch compilation mode --- src/openpi/models/pi0_config.py | 4 ++++ src/openpi/models_pytorch/pi0_pytorch.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/openpi/models/pi0_config.py b/src/openpi/models/pi0_config.py index b0f6b662ac..8d92ee4484 100644 --- a/src/openpi/models/pi0_config.py +++ b/src/openpi/models/pi0_config.py @@ -32,11 +32,15 @@ class Pi0Config(_model.BaseModelConfig): # This config option is not used directly by the model, but it is read by the ModelTransformFactory. discrete_state_input: bool = None # type: ignore + pytorch_compile_mode: str | None = "max-autotune" + def __post_init__(self): if self.max_token_len is None: object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48) if self.discrete_state_input is None: object.__setattr__(self, "discrete_state_input", self.pi05) + if self.pytorch_compile_mode is not None: + assert self.pytorch_compile_mode in ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"] @property @override diff --git a/src/openpi/models_pytorch/pi0_pytorch.py b/src/openpi/models_pytorch/pi0_pytorch.py index 92ff95028a..e68ddb7cc0 100644 --- a/src/openpi/models_pytorch/pi0_pytorch.py +++ b/src/openpi/models_pytorch/pi0_pytorch.py @@ -109,7 +109,8 @@ def __init__(self, config): self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) torch.set_float32_matmul_precision("high") - self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune") + if config.pytorch_compile_mode is not None: + self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode) # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False From 417042a927159619b3fb0c4605334e21fc9ba9cc Mon Sep 17 00:00:00 2001 From: Abhay D Date: Thu, 19 Mar 2026 12:54:39 -0700 Subject: [PATCH 2/2] ruff --- src/openpi/models/pi0_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/openpi/models/pi0_config.py b/src/openpi/models/pi0_config.py index 8d92ee4484..584d83f199 100644 --- a/src/openpi/models/pi0_config.py +++ b/src/openpi/models/pi0_config.py @@ -40,7 +40,12 @@ def __post_init__(self): if self.discrete_state_input is None: object.__setattr__(self, "discrete_state_input", self.pi05) if self.pytorch_compile_mode is not None: - assert self.pytorch_compile_mode in ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"] + assert self.pytorch_compile_mode in [ + "default", + "reduce-overhead", + "max-autotune", + "max-autotune-no-cudagraphs", + ] @property @override