diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py
index 5c0c885..0e9fecf 100644
--- a/code/evaluation/logical_error_rate.py
+++ b/code/evaluation/logical_error_rate.py
@@ -1046,6 +1046,10 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
             pass
     except Exception:
         pass
+    # torch.compile combined with DataLoader spawn workers segfaults, so fall
+    # back to in-process loading (num_workers=0) once compile was applied.
+    if _applied_compile and int(test_loader_kwargs.get("num_workers", 0)) > 0:
+        test_loader_kwargs["num_workers"] = 0
     # Handle prefetch_factor when num_workers=0
     if test_loader_kwargs.get('num_workers', 0) == 0:
         test_loader_kwargs.pop('prefetch_factor', None)
@@ -1653,6 +1657,11 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
             pass
     except Exception:
         pass
+    # torch.compile + DataLoader spawn workers segfaults; force num_workers=0
+    # when PREDECODER_TORCH_COMPILE is enabled, i.e. compile will be applied.
+    if _get_env_bool("PREDECODER_TORCH_COMPILE",
+                     True) and int(test_loader_kwargs.get("num_workers", 0)) > 0:
+        test_loader_kwargs["num_workers"] = 0
     if int(test_loader_kwargs.get("num_workers", 0)) == 0:
         test_loader_kwargs.pop("prefetch_factor", None)
         if test_loader_kwargs.get("persistent_workers", False):
diff --git a/code/tests/test_v3_optimizations.py b/code/tests/test_v3_optimizations.py
index 8827ba3..cfd9af4 100644
--- a/code/tests/test_v3_optimizations.py
+++ b/code/tests/test_v3_optimizations.py
@@ -211,5 +211,63 @@ def test_exact_threshold_qualifies(self):
         )
 
 
+# ---------------------------------------------------------------------------
+# torch.compile + DataLoader spawn-worker guard
+# ---------------------------------------------------------------------------
+class TestCompileSpawnGuard(unittest.TestCase):
+    """Check the guard that forces num_workers=0 when torch.compile is active.
+
+    torch.compile + DataLoader multiprocessing_context=spawn segfaults (leaked
+    semaphores, core dump). The helpers below replicate the LER inference and
+    SDR guard conditions inline; they do not call the production functions.
+    """
+
+    def _apply_ler_guard(self, applied_compile: bool, num_workers: int) -> dict:
+        """Mirror the guard in run_inference_and_decode_pre_decoder_memory."""
+        kwargs = {"num_workers": num_workers, "prefetch_factor": 2}
+        if applied_compile and int(kwargs.get("num_workers", 0)) > 0:
+            kwargs["num_workers"] = 0
+        if kwargs.get("num_workers", 0) == 0:
+            kwargs.pop("prefetch_factor", None)
+        return kwargs
+
+    def _apply_sdr_guard(self, compile_env: str, num_workers: int) -> dict:
+        """Mirror the guard in compute_syndrome_density_reduction."""
+        kwargs = {"num_workers": num_workers}
+        with patch.dict(os.environ, {"PREDECODER_TORCH_COMPILE": compile_env}):
+            if _get_env_bool("PREDECODER_TORCH_COMPILE",
+                             True) and int(kwargs.get("num_workers", 0)) > 0:
+                kwargs["num_workers"] = 0
+        return kwargs
+
+    # --- LER inference guard ---
+
+    def test_ler_guard_forces_zero_workers_when_compiled(self):
+        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
+        self.assertEqual(kwargs["num_workers"], 0)
+
+    def test_ler_guard_keeps_workers_when_not_compiled(self):
+        kwargs = self._apply_ler_guard(applied_compile=False, num_workers=16)
+        self.assertEqual(kwargs["num_workers"], 16)
+
+    def test_ler_guard_noop_when_workers_already_zero(self):
+        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=0)
+        self.assertEqual(kwargs["num_workers"], 0)
+
+    def test_ler_guard_drops_prefetch_factor_when_workers_zeroed(self):
+        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
+        self.assertNotIn("prefetch_factor", kwargs)
+
+    # --- SDR guard (env-var based) ---
+
+    def test_sdr_guard_forces_zero_workers_when_compile_env_on(self):
+        kwargs = self._apply_sdr_guard(compile_env="1", num_workers=16)
+        self.assertEqual(kwargs["num_workers"], 0)
+
+    def test_sdr_guard_keeps_workers_when_compile_env_off(self):
+        kwargs = self._apply_sdr_guard(compile_env="0", num_workers=16)
+        self.assertEqual(kwargs["num_workers"], 16)
+
+
 if __name__ == "__main__":
     unittest.main()