Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,10 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
pass
except Exception:
pass
# torch.compile + DataLoader spawn workers causes a segfault; fall back to
# single-process loading when compile was applied.
if _applied_compile and int(test_loader_kwargs.get("num_workers", 0)) > 0:
test_loader_kwargs["num_workers"] = 0
# prefetch_factor only applies to worker processes; drop it when num_workers=0
if test_loader_kwargs.get('num_workers', 0) == 0:
test_loader_kwargs.pop('prefetch_factor', None)
Expand Down Expand Up @@ -1653,6 +1657,11 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
pass
except Exception:
pass
# torch.compile + DataLoader spawn workers causes a segfault; fall back to
# single-process loading when compile will be applied.
if _get_env_bool("PREDECODER_TORCH_COMPILE",
True) and int(test_loader_kwargs.get("num_workers", 0)) > 0:
test_loader_kwargs["num_workers"] = 0
if int(test_loader_kwargs.get("num_workers", 0)) == 0:
test_loader_kwargs.pop("prefetch_factor", None)
if test_loader_kwargs.get("persistent_workers", False):
Expand Down
58 changes: 58 additions & 0 deletions code/tests/test_v3_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,63 @@ def test_exact_threshold_qualifies(self):
)


# ---------------------------------------------------------------------------
# torch.compile + DataLoader spawn-worker guard
# ---------------------------------------------------------------------------
class TestCompileSpawnGuard(unittest.TestCase):
    """Checks for the ``num_workers=0`` fallback used alongside torch.compile.

    Combining torch.compile with DataLoader workers started through the
    ``spawn`` multiprocessing context segfaults (leaked semaphores, core
    dump). The LER inference and SDR functions each apply a guard before
    the DataLoader is constructed; the helpers in this class replay those
    guards on a plain dict of loader options.
    """

    def _apply_ler_guard(self, applied_compile: bool, num_workers: int) -> dict:
        """Replay the guard from run_inference_and_decode_pre_decoder_memory.

        Args:
            applied_compile: Whether torch.compile is treated as applied.
            num_workers: Worker count the loader options start with.

        Returns:
            The loader options after the guard and the ``prefetch_factor``
            clean-up have run.
        """
        loader_opts = {"num_workers": num_workers, "prefetch_factor": 2}
        wants_workers = applied_compile and int(loader_opts.get("num_workers", 0)) > 0
        if wants_workers:
            loader_opts["num_workers"] = 0
        # prefetch_factor is dropped whenever loading ends up single-process.
        if loader_opts.get("num_workers", 0) == 0:
            loader_opts.pop("prefetch_factor", None)
        return loader_opts

    def _apply_sdr_guard(self, compile_env: str, num_workers: int) -> dict:
        """Replay the guard from compute_syndrome_density_reduction.

        Args:
            compile_env: Value given to ``PREDECODER_TORCH_COMPILE`` while
                the guard is evaluated.
            num_workers: Worker count the loader options start with.

        Returns:
            The loader options after the guard has run.
        """
        loader_opts = {"num_workers": num_workers}
        env_override = {"PREDECODER_TORCH_COMPILE": compile_env}
        # patch.dict puts os.environ back once the block exits.
        with patch.dict(os.environ, env_override):
            compile_requested = _get_env_bool("PREDECODER_TORCH_COMPILE", True)
            if compile_requested and int(loader_opts.get("num_workers", 0)) > 0:
                loader_opts["num_workers"] = 0
        return loader_opts

    # --- LER inference guard ---

    def test_ler_guard_forces_zero_workers_when_compiled(self):
        result = self._apply_ler_guard(applied_compile=True, num_workers=16)
        self.assertEqual(result["num_workers"], 0)

    def test_ler_guard_keeps_workers_when_not_compiled(self):
        result = self._apply_ler_guard(applied_compile=False, num_workers=16)
        self.assertEqual(result["num_workers"], 16)

    def test_ler_guard_noop_when_workers_already_zero(self):
        result = self._apply_ler_guard(applied_compile=True, num_workers=0)
        self.assertEqual(result["num_workers"], 0)

    def test_ler_guard_drops_prefetch_factor_when_workers_zeroed(self):
        result = self._apply_ler_guard(applied_compile=True, num_workers=16)
        self.assertNotIn("prefetch_factor", result)

    # --- SDR guard (env-var based) ---

    def test_sdr_guard_forces_zero_workers_when_compile_env_on(self):
        result = self._apply_sdr_guard(compile_env="1", num_workers=16)
        self.assertEqual(result["num_workers"], 0)

    def test_sdr_guard_keeps_workers_when_compile_env_off(self):
        result = self._apply_sdr_guard(compile_env="0", num_workers=16)
        self.assertEqual(result["num_workers"], 16)


# Run the test suite through unittest's runner when this file is executed
# directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading