Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@
from typing import Optional


def _is_compiled(model: nn.Module) -> bool:
"""True if *model* is already wrapped by torch.compile (OptimizedModule)."""
try:
from torch._dynamo.eval_frame import OptimizedModule
return isinstance(model, OptimizedModule)
except ImportError:
return False


def _decode_batch(matcher, detectors, enable_correlated):
return matcher.decode_batch(detectors, enable_correlations=enable_correlated)

Expand Down Expand Up @@ -917,9 +926,9 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
if dist.rank == 0:
print(f"[LER] channels_last_3d not applied: {e}")

_applied_compile = False
_applied_compile = _is_compiled(model)
_compile_enabled = _get_env_bool("PREDECODER_TORCH_COMPILE", True)
if not _will_export_onnx and _compile_enabled:
if not _will_export_onnx and _compile_enabled and not _applied_compile:
try:
model = torch.compile(model, mode="default")
_applied_compile = True
Expand Down Expand Up @@ -1686,10 +1695,11 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
res_ones_Z = torch.tensor(0, dtype=torch.int64, device=device)

model.eval()
try:
model = torch.compile(model, mode="default")
except Exception:
pass
if not _is_compiled(model):
try:
model = torch.compile(model, mode="default")
except Exception:
pass
try:
model = model.to(memory_format=torch.channels_last_3d)
except Exception:
Expand Down Expand Up @@ -1882,9 +1892,13 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict:
res_ones_Z += R_Z[mask_Z].sum(dtype=torch.int64)

t_end = time.perf_counter()
# Disable detailed printing for now
if False:
print(f"Time taken for batch_count={batch_count}: {t_end - t_start:.3f}s")

# Eagerly tear down DataLoader workers to release /dev/shm semaphores
# before LER spins up its own DataLoader.
del data_iter
del test_dataloader
import gc
gc.collect()

# ----- All-reduce across ranks -----
if torch.distributed.is_available() and torch.distributed.is_initialized():
Expand Down
Loading