Skip to content

Commit ea133c5

Browse files
authored
fix: guard against double torch.compile in SDR→LER path (#39)
When SDR runs before LER, the same model object gets torch.compile'd twice, producing a nested OptimizedModule that segfaults during the first forward pass. Skip compilation when the model is already compiled. Also eagerly tear down SDR's DataLoader workers before LER starts to prevent leaked /dev/shm semaphores. Signed-off-by: kvmto <kmato@nvidia.com>
1 parent d36b7da commit ea133c5

1 file changed

Lines changed: 23 additions & 9 deletions

File tree

code/evaluation/logical_error_rate.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@
5858
from typing import Optional
5959

6060

61+
def _is_compiled(model: nn.Module) -> bool:
    """Report whether *model* is a ``torch.compile`` wrapper (``OptimizedModule``).

    Returns ``False`` when ``OptimizedModule`` cannot be imported from
    ``torch._dynamo.eval_frame``.
    """
    try:
        from torch._dynamo.eval_frame import OptimizedModule as _compiled_wrapper
    except ImportError:
        # The wrapper class is unavailable, so nothing can be an instance of it.
        return False
    return isinstance(model, _compiled_wrapper)
68+
69+
6170
def _decode_batch(matcher, detectors, enable_correlated):
    """Decode *detectors* with *matcher* and return the matcher's result.

    Forwards *enable_correlated* to ``matcher.decode_batch`` as the
    ``enable_correlations`` keyword argument.
    """
    decoded = matcher.decode_batch(
        detectors,
        enable_correlations=enable_correlated,
    )
    return decoded
6372

@@ -917,9 +926,9 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
917926
if dist.rank == 0:
918927
print(f"[LER] channels_last_3d not applied: {e}")
919928

920-
_applied_compile = False
929+
_applied_compile = _is_compiled(model)
921930
_compile_enabled = _get_env_bool("PREDECODER_TORCH_COMPILE", True)
922-
if not _will_export_onnx and _compile_enabled:
931+
if not _will_export_onnx and _compile_enabled and not _applied_compile:
923932
try:
924933
model = torch.compile(model, mode="default")
925934
_applied_compile = True
@@ -1686,10 +1695,11 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
16861695
res_ones_Z = torch.tensor(0, dtype=torch.int64, device=device)
16871696

16881697
model.eval()
1689-
try:
1690-
model = torch.compile(model, mode="default")
1691-
except Exception:
1692-
pass
1698+
if not _is_compiled(model):
1699+
try:
1700+
model = torch.compile(model, mode="default")
1701+
except Exception:
1702+
pass
16931703
try:
16941704
model = model.to(memory_format=torch.channels_last_3d)
16951705
except Exception:
@@ -1882,9 +1892,13 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
18821892
res_ones_Z += R_Z[mask_Z].sum(dtype=torch.int64)
18831893

18841894
t_end = time.perf_counter()
1885-
# Disable detailed printing for now
1886-
if False:
1887-
print(f"Time taken for batch_count={batch_count}: {t_end - t_start:.3f}s")
1895+
1896+
# Eagerly tear down DataLoader workers to release /dev/shm semaphores
1897+
# before LER spins up its own DataLoader.
1898+
del data_iter
1899+
del test_dataloader
1900+
import gc
1901+
gc.collect()
18881902

18891903
# ----- All-reduce across ranks -----
18901904
if torch.distributed.is_available() and torch.distributed.is_initialized():

0 commit comments

Comments
 (0)