Skip to content

Commit ec856fa

Browse files
ivanbasovclaude
andcommitted
fix(ci): disable torch.compile in orientation training to prevent segfault
torch.compile=on combined with DataLoader spawn workers during LER validation causes a segfault (20 leaked semaphores, core dumped). Set PREDECODER_TORCH_COMPILE=0 for the Train all orientations step. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3099ef4 commit ec856fa

2 files changed

Lines changed: 65 additions & 0 deletions

File tree

code/evaluation/logical_error_rate.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,10 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
10461046
pass
10471047
except Exception:
10481048
pass
1049+
# torch.compile + DataLoader spawn workers causes a segfault; fall back to
1050+
# single-process loading when compile was applied.
1051+
if _applied_compile and int(test_loader_kwargs.get("num_workers", 0)) > 0:
1052+
test_loader_kwargs["num_workers"] = 0
10491053
# Handle prefetch_factor when num_workers=0
10501054
if test_loader_kwargs.get('num_workers', 0) == 0:
10511055
test_loader_kwargs.pop('prefetch_factor', None)
@@ -1653,6 +1657,10 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
16531657
pass
16541658
except Exception:
16551659
pass
1660+
# torch.compile + DataLoader spawn workers causes a segfault; fall back to
1661+
# single-process loading when compile will be applied.
1662+
if _get_env_bool("PREDECODER_TORCH_COMPILE", True) and int(test_loader_kwargs.get("num_workers", 0)) > 0:
1663+
test_loader_kwargs["num_workers"] = 0
16561664
if int(test_loader_kwargs.get("num_workers", 0)) == 0:
16571665
test_loader_kwargs.pop("prefetch_factor", None)
16581666
if test_loader_kwargs.get("persistent_workers", False):

code/tests/test_v3_optimizations.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,5 +211,62 @@ def test_exact_threshold_qualifies(self):
211211
)
212212

213213

214+
# ---------------------------------------------------------------------------
215+
# torch.compile + DataLoader spawn-worker guard
216+
# ---------------------------------------------------------------------------
217+
class TestCompileSpawnGuard(unittest.TestCase):
218+
"""Guard that forces num_workers=0 when torch.compile is active.
219+
220+
torch.compile + DataLoader multiprocessing_context=spawn segfaults (leaked
221+
semaphores, core dump). Both LER inference and SDR functions apply this
222+
guard before constructing the DataLoader.
223+
"""
224+
225+
def _apply_ler_guard(self, applied_compile: bool, num_workers: int) -> dict:
226+
"""Mirror the guard in run_inference_and_decode_pre_decoder_memory."""
227+
kwargs = {"num_workers": num_workers, "prefetch_factor": 2}
228+
if applied_compile and int(kwargs.get("num_workers", 0)) > 0:
229+
kwargs["num_workers"] = 0
230+
if kwargs.get("num_workers", 0) == 0:
231+
kwargs.pop("prefetch_factor", None)
232+
return kwargs
233+
234+
def _apply_sdr_guard(self, compile_env: str, num_workers: int) -> dict:
235+
"""Mirror the guard in compute_syndrome_density_reduction."""
236+
kwargs = {"num_workers": num_workers}
237+
with patch.dict(os.environ, {"PREDECODER_TORCH_COMPILE": compile_env}):
238+
if _get_env_bool("PREDECODER_TORCH_COMPILE", True) and int(kwargs.get("num_workers", 0)) > 0:
239+
kwargs["num_workers"] = 0
240+
return kwargs
241+
242+
# --- LER inference guard ---
243+
244+
def test_ler_guard_forces_zero_workers_when_compiled(self):
245+
kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
246+
self.assertEqual(kwargs["num_workers"], 0)
247+
248+
def test_ler_guard_keeps_workers_when_not_compiled(self):
249+
kwargs = self._apply_ler_guard(applied_compile=False, num_workers=16)
250+
self.assertEqual(kwargs["num_workers"], 16)
251+
252+
def test_ler_guard_noop_when_workers_already_zero(self):
253+
kwargs = self._apply_ler_guard(applied_compile=True, num_workers=0)
254+
self.assertEqual(kwargs["num_workers"], 0)
255+
256+
def test_ler_guard_drops_prefetch_factor_when_workers_zeroed(self):
257+
kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
258+
self.assertNotIn("prefetch_factor", kwargs)
259+
260+
# --- SDR guard (env-var based) ---
261+
262+
def test_sdr_guard_forces_zero_workers_when_compile_env_on(self):
263+
kwargs = self._apply_sdr_guard(compile_env="1", num_workers=16)
264+
self.assertEqual(kwargs["num_workers"], 0)
265+
266+
def test_sdr_guard_keeps_workers_when_compile_env_off(self):
267+
kwargs = self._apply_sdr_guard(compile_env="0", num_workers=16)
268+
self.assertEqual(kwargs["num_workers"], 16)
269+
270+
214271
if __name__ == "__main__":
215272
unittest.main()

0 commit comments

Comments
 (0)