NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and blank context lines are reconstructed and should be confirmed
against the target file (or applied with `git apply --ignore-whitespace`).

diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py
index 5c0c885..54b65e9 100644
--- a/code/evaluation/logical_error_rate.py
+++ b/code/evaluation/logical_error_rate.py
@@ -58,6 +58,21 @@
 from typing import Optional
 
 
+def _is_compiled(model: nn.Module) -> bool:
+    """Return True if *model* is already wrapped by ``torch.compile``.
+
+    ``torch.compile`` returns an ``OptimizedModule`` wrapper; checking for it
+    lets callers avoid compiling (and wrapping) the same model a second time.
+    ``OptimizedModule`` lives in a private torch namespace, so a failed import
+    is treated as "not compiled" rather than as an error.
+    """
+    try:
+        from torch._dynamo.eval_frame import OptimizedModule
+    except ImportError:
+        return False
+    return isinstance(model, OptimizedModule)
+
+
 def _decode_batch(matcher, detectors, enable_correlated):
     return matcher.decode_batch(detectors, enable_correlations=enable_correlated)
 
@@ -917,9 +932,9 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
         if dist.rank == 0:
             print(f"[LER] channels_last_3d not applied: {e}")
 
-    _applied_compile = False
+    _applied_compile = _is_compiled(model)
     _compile_enabled = _get_env_bool("PREDECODER_TORCH_COMPILE", True)
-    if not _will_export_onnx and _compile_enabled:
+    if not _will_export_onnx and _compile_enabled and not _applied_compile:
         try:
             model = torch.compile(model, mode="default")
             _applied_compile = True
@@ -1686,10 +1701,13 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
     res_ones_Z = torch.tensor(0, dtype=torch.int64, device=device)
 
     model.eval()
-    try:
-        model = torch.compile(model, mode="default")
-    except Exception:
-        pass
+    if not _is_compiled(model):
+        try:
+            model = torch.compile(model, mode="default")
+        except Exception as e:
+            # Best-effort: keep running uncompiled, but report the fallback.
+            if dist.rank == 0:
+                print(f"[SDR] torch.compile not applied: {e}")
     try:
         model = model.to(memory_format=torch.channels_last_3d)
     except Exception:
@@ -1882,9 +1900,13 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
             res_ones_Z += R_Z[mask_Z].sum(dtype=torch.int64)
 
     t_end = time.perf_counter()
-    # Disable detailed printing for now
-    if False:
-        print(f"Time taken for batch_count={batch_count}: {t_end - t_start:.3f}s")
+
+    # Eagerly tear down DataLoader workers to release /dev/shm semaphores
+    # before LER spins up its own DataLoader.
+    del data_iter
+    del test_dataloader
+    import gc
+    gc.collect()
 
     # ----- All-reduce across ranks -----
     if torch.distributed.is_available() and torch.distributed.is_initialized():