Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/long-running-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ jobs:
PREDECODER_TRAIN_SAMPLES: "32768"
PREDECODER_VAL_SAMPLES: "4096"
PREDECODER_TEST_SAMPLES: "4096"
PREDECODER_TRAIN_EPOCHS: "1"
PREDECODER_TRAIN_EPOCHS: "30"
PREDECODER_DISABLE_SDR: "1"

- name: Multi-orientation inference (O1–O4) with LER output check
shell: bash
Expand Down
5 changes: 5 additions & 0 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
pass
except Exception:
pass
# Combining torch.compile with spawned loader workers causes a segfault (a CUDA
# context conflict in the spawned subprocesses once the model has been compiled).
# Fall back to in-process loading whenever torch.compile has been applied.
if _applied_compile and int(test_loader_kwargs.get("num_workers", 0)) > 0:
test_loader_kwargs["num_workers"] = 0
# Handle prefetch_factor when num_workers=0
if test_loader_kwargs.get('num_workers', 0) == 0:
test_loader_kwargs.pop('prefetch_factor', None)
Expand Down
22 changes: 21 additions & 1 deletion code/tests/test_metrics_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Tests for evaluation.metrics (configure_metrics, _extract_reduction_factor)."""

import inspect
import sys
import unittest
from pathlib import Path
Expand All @@ -22,7 +23,7 @@
if str(_repo_code) not in sys.path:
sys.path.insert(0, str(_repo_code))

from evaluation.metrics import configure_metrics, _extract_reduction_factor
from evaluation.metrics import configure_metrics, _extract_reduction_factor, compute_syndrome_density


class TestConfigureMetrics(unittest.TestCase):
Expand Down Expand Up @@ -63,3 +64,22 @@ def test_extract_from_empty_dict_returns_none(self):
def test_extract_from_nested_stim(self):
result = {"other": 1, "stim": {"reduction factor (X/Z)": 2.5}}
self.assertEqual(_extract_reduction_factor(result), 2.5)


class TestComputeSyndromeDensitySignature(unittest.TestCase):
    """Keep ``sdr_as_percent`` out of the compute_syndrome_density() signature.

    ``sdr_as_percent`` is a display-only flag that belongs to train.py, where it
    selects "%" versus "x" in the log output. It has been passed to
    compute_syndrome_density() by mistake twice, and each time the resulting
    TypeError was only caught by the long-running GPU tests. Inspecting the
    signature here makes the same contract cheap to check on every short CI run.
    """

    def test_sdr_as_percent_not_a_parameter(self):
        # Only the declared parameter names are inspected; the function under
        # test is never called, so no GPU or sample data is needed.
        forbidden_kwarg = "sdr_as_percent"
        declared_parameters = inspect.signature(compute_syndrome_density).parameters
        failure_message = (
            "sdr_as_percent is a display-only flag in train.py and must not be added "
            "to compute_syndrome_density(); passing it causes TypeError at runtime."
        )
        self.assertNotIn(forbidden_kwarg, declared_parameters, failure_message)
65 changes: 33 additions & 32 deletions code/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,38 +687,39 @@ def init_process_group_with_timeout(*args, **kwargs):
if getattr(cfg.train, "epochs", None) is None:
cfg.train.epochs = 100

# Optional timing-mode overrides (env-based) for short measurement runs.
if os.environ.get("PREDECODER_TIMING_RUN", "0") == "1":
train_samples_env = os.environ.get("PREDECODER_TRAIN_SAMPLES")
val_samples_env = os.environ.get("PREDECODER_VAL_SAMPLES")
test_samples_env = os.environ.get("PREDECODER_TEST_SAMPLES")
epochs_env = os.environ.get("PREDECODER_TRAIN_EPOCHS")
try:
if train_samples_env:
cfg.train.num_samples = int(train_samples_env)
except Exception:
pass
try:
if val_samples_env:
cfg.val.num_samples = int(val_samples_env)
except Exception:
pass
try:
if test_samples_env:
cfg.test.num_samples = int(test_samples_env)
except Exception:
pass
try:
if epochs_env:
cfg.train.epochs = int(epochs_env)
except Exception:
pass
milestones_env = os.environ.get("PREDECODER_LR_MILESTONES")
try:
if milestones_env:
cfg.lr_scheduler.milestones = [float(x) for x in milestones_env.split(",")]
except Exception:
pass
# Env-based overrides for samples, epochs, and LR milestones.
# These apply unconditionally so that CI jobs and quick local runs can
# override config values without needing PREDECODER_TIMING_RUN=1.
train_samples_env = os.environ.get("PREDECODER_TRAIN_SAMPLES")
val_samples_env = os.environ.get("PREDECODER_VAL_SAMPLES")
test_samples_env = os.environ.get("PREDECODER_TEST_SAMPLES")
epochs_env = os.environ.get("PREDECODER_TRAIN_EPOCHS")
try:
if train_samples_env:
cfg.train.num_samples = int(train_samples_env)
except Exception:
pass
try:
if val_samples_env:
cfg.val.num_samples = int(val_samples_env)
except Exception:
pass
try:
if test_samples_env:
cfg.test.num_samples = int(test_samples_env)
except Exception:
pass
try:
if epochs_env:
cfg.train.epochs = int(epochs_env)
except Exception:
pass
milestones_env = os.environ.get("PREDECODER_LR_MILESTONES")
try:
if milestones_env:
cfg.lr_scheduler.milestones = [float(x) for x in milestones_env.split(",")]
except Exception:
pass

if dist.rank == 0:
print(f"Effective workflow.task: {cfg.workflow.task}")
Expand Down
Loading