|
58 | 58 | from typing import Optional |
59 | 59 |
|
60 | 60 |
|
| 61 | +def _is_compiled(model: nn.Module) -> bool: |
| 62 | + """True if *model* is already wrapped by torch.compile (OptimizedModule).""" |
| 63 | + try: |
| 64 | + from torch._dynamo.eval_frame import OptimizedModule |
| 65 | + return isinstance(model, OptimizedModule) |
| 66 | + except ImportError: |
| 67 | + return False |
| 68 | + |
| 69 | + |
61 | 70 | def _decode_batch(matcher, detectors, enable_correlated): |
62 | 71 | return matcher.decode_batch(detectors, enable_correlations=enable_correlated) |
63 | 72 |
|
@@ -917,9 +926,9 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic |
917 | 926 | if dist.rank == 0: |
918 | 927 | print(f"[LER] channels_last_3d not applied: {e}") |
919 | 928 |
|
920 | | - _applied_compile = False |
| 929 | + _applied_compile = _is_compiled(model) |
921 | 930 | _compile_enabled = _get_env_bool("PREDECODER_TORCH_COMPILE", True) |
922 | | - if not _will_export_onnx and _compile_enabled: |
| 931 | + if not _will_export_onnx and _compile_enabled and not _applied_compile: |
923 | 932 | try: |
924 | 933 | model = torch.compile(model, mode="default") |
925 | 934 | _applied_compile = True |
@@ -1686,10 +1695,11 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict: |
1686 | 1695 | res_ones_Z = torch.tensor(0, dtype=torch.int64, device=device) |
1687 | 1696 |
|
1688 | 1697 | model.eval() |
1689 | | - try: |
1690 | | - model = torch.compile(model, mode="default") |
1691 | | - except Exception: |
1692 | | - pass |
| 1698 | + if not _is_compiled(model): |
| 1699 | + try: |
| 1700 | + model = torch.compile(model, mode="default") |
| 1701 | + except Exception: |
| 1702 | + pass |
1693 | 1703 | try: |
1694 | 1704 | model = model.to(memory_format=torch.channels_last_3d) |
1695 | 1705 | except Exception: |
@@ -1882,9 +1892,13 @@ def compute_syndrome_density_reduction(model, device, dist, cfg) -> dict: |
1882 | 1892 | res_ones_Z += R_Z[mask_Z].sum(dtype=torch.int64) |
1883 | 1893 |
|
1884 | 1894 | t_end = time.perf_counter() |
1885 | | - # Disable detailed printing for now |
1886 | | - if False: |
1887 | | - print(f"Time taken for batch_count={batch_count}: {t_end - t_start:.3f}s") |
| 1895 | + |
| 1896 | + # Eagerly tear down DataLoader workers to release /dev/shm semaphores |
| 1897 | + # before LER spins up its own DataLoader. |
| 1898 | + del data_iter |
| 1899 | + del test_dataloader |
| 1900 | + import gc |
| 1901 | + gc.collect() |
1888 | 1902 |
|
1889 | 1903 | # ----- All-reduce across ranks ----- |
1890 | 1904 | if torch.distributed.is_available() and torch.distributed.is_initialized(): |
|
0 commit comments