Skip to content

Commit 4fbe92c

Browse files
ivanbasovclaude
andauthored
fix(mid): seed BitMatrixSampler explicitly to restore test reproducibility (#43)
* fix(ci): disable torch.compile in orientation training to prevent segfault torch.compile=on combined with DataLoader spawn workers during LER validation causes a segfault (20 leaked semaphores, core dumped). Set PREDECODER_TORCH_COMPILE=0 for the Train all orientations step. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Revert "fix(ci): disable torch.compile in orientation training to prevent segfault" This reverts commit 7f0f6c8. * fix(mid): seed BitMatrixSampler explicitly to restore test reproducibility torch.manual_seed() does not control cuQuantum's BitMatrixSampler internal RNG, so the two mid-GPU tests that relied on it for reproducibility were non-deterministic and intermittently failing. Add an optional `seed` parameter to `dem_sampling()` and `MemoryCircuitTorch.generate_batch()`. When a seed is provided a fresh BitMatrixSampler is always created with `Options(seed=N)`, resetting its internal RNG and guaranteeing identical outputs on every call with the same seed. Production paths (seed=None) are unaffected — the cached sampler is reused as before. Update the two failing tests to use the explicit seed kwarg instead of torch.manual_seed(): - test_he_reduces_error_weight: seed=123 - test_full_pipeline_w2_reproducible: seed=100 Fixes: NVIDIA/Ising-Decoding CI run 23963347042 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style: fix yapf line-break position in need_new condition Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test: add dem_sampling reproducibility tests for seed= parameter Add TestDEMSamplingReproducibility to test_dem_sampling.py with four cases: - same seed on CPU produces bit-exact identical frames - different seeds produce different frames - unseeded calls still reuse the cached sampler (perf regression guard) - same seed on GPU produces bit-exact identical frames (GPU-only) These tests use stochastic p values (0.1–0.9) so they would have caught the original regression: before the seed= fix, BitMatrixSampler's internal RNG was not reset between calls, making "same seed" reproducibility impossible regardless of torch.manual_seed(). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: use torch.Generator for seeded path; BitMatrixSampler RNG is not seedable Options.__init__() does not accept a 'seed' keyword — the cuST BitMatrixSampler's internal RNG is not exposed via the public API. Replace the attempted Options(seed=N) approach with a small pure-torch fallback (_torch_dem_sampling) that uses a local torch.Generator seeded to the requested value. This path is only taken when seed= is explicitly passed (tests); the production BitMatrixSampler cache path is unchanged. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: pass seed directly to BitMatrixSampler constructor BitMatrixSampler accepts seed as a constructor kwarg (not via Options). Replace the torch fallback workaround with the correct cuST API: pass seed= directly to BitMatrixSampler(..., seed=seed). A fresh sampler is created on every seeded call so its internal RNG is reset to the requested seed, guaranteeing identical outputs on repeated calls with the same value. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ef2b160 commit 4fbe92c

5 files changed

Lines changed: 102 additions & 19 deletions

File tree

code/qec/dem_sampling.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def _custab_available() -> bool:
5959
_cached_HT: "torch.Tensor | None" = None
6060
_cached_max_shots: int = 0
6161
_cached_device_id: int | None = None
62+
_cached_seed: "int | None" = None
6263

6364
_DEM_TIMINGS_S: deque[float] = deque(maxlen=200)
6465
_custab_path_logged: bool = False
@@ -75,19 +76,21 @@ def get_dem_sampling_avg_ms() -> float:
7576

7677
def _reset_sampler_cache() -> None:
7778
"""Reset the module-level sampler cache."""
78-
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id
79+
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id, _cached_seed
7980
_cached_sampler = None
8081
_cached_H = None
8182
_cached_HT = None
8283
_cached_max_shots = 0
8384
_cached_device_id = None
85+
_cached_seed = None
8486

8587

8688
def dem_sampling(
8789
H: torch.Tensor,
8890
p: torch.Tensor,
8991
batch_size: int,
90-
device_id: int | None = None
92+
device_id: int | None = None,
93+
seed: int | None = None,
9194
) -> torch.Tensor:
9295
"""
9396
Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.
@@ -98,6 +101,10 @@ def dem_sampling(
98101
batch_size: int - Number of samples to generate
99102
device_id: Optional int - Device ID for cuST. If omitted, infer from
100103
H.device when H is on CUDA.
104+
seed: Optional int - RNG seed passed directly to BitMatrixSampler.
105+
When provided, a fresh sampler is created with that seed so repeated
106+
calls with the same seed produce identical outputs.
107+
When None (default), the cached sampler is reused across calls.
101108
102109
Returns:
103110
frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
@@ -106,7 +113,7 @@ def dem_sampling(
106113
from cuquantum.stabilizer.simulator import Options
107114

108115
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
109-
global _cached_device_id, _custab_path_logged
116+
global _cached_device_id, _cached_seed, _custab_path_logged
110117

111118
if H.ndim != 2:
112119
raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
@@ -129,9 +136,14 @@ def dem_sampling(
129136
_cached_H = H
130137
_cached_sampler = None
131138
_cached_device_id = None
139+
_cached_seed = None
132140

141+
# When a seed is given always create a fresh sampler so its internal RNG is
142+
# reset to that seed, giving identical outputs on repeated calls with the
143+
# same seed value (per the BitMatrixSampler constructor contract).
133144
need_new = (
134-
_cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
145+
_cached_sampler is None or batch_size > _cached_max_shots or
146+
_cached_device_id != device_id or seed is not None
135147
)
136148

137149
if need_new:
@@ -146,15 +158,13 @@ def dem_sampling(
146158
H_in = _cached_HT.detach().cpu().numpy().astype(np.uint8)
147159
p_in = p.detach().cpu().numpy().astype(np.float64)
148160
pkg = "numpy"
149-
_cached_sampler = BitMatrixSampler(
150-
H_in,
151-
p_in,
152-
max_shots,
153-
package=pkg,
154-
options=Options(device_id=device_id),
155-
)
161+
bms_kwargs: dict = {"package": pkg, "options": Options(device_id=device_id)}
162+
if seed is not None:
163+
bms_kwargs["seed"] = seed
164+
_cached_sampler = BitMatrixSampler(H_in, p_in, max_shots, **bms_kwargs)
156165
_cached_max_shots = max_shots
157166
_cached_device_id = device_id
167+
_cached_seed = seed
158168

159169
t0 = time.perf_counter()
160170
if gpu_native:

code/qec/surface_code/memory_circuit_torch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def generate_batch(
234234
batch_size: int,
235235
return_aux: bool = False,
236236
collect_timing: bool = False,
237+
seed: int | None = None,
237238
) -> Union[
238239
tuple[torch.Tensor, torch.Tensor],
239240
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
@@ -248,6 +249,8 @@ def generate_batch(
248249
build Stim dets_and_obs from circuit-order measurements.
249250
- If collect_timing=True, also return timing breakdown in milliseconds:
250251
data generation, HE, format, and total.
252+
- If seed is given, the BitMatrixSampler is re-created with that seed so
253+
repeated calls with the same seed produce identical outputs.
251254
"""
252255
if self._compile_thread is not None:
253256
# torch.compile warmup can be slow; 20 min cap prevents silent hangs.
@@ -269,6 +272,7 @@ def generate_batch(
269272
self.p,
270273
int(batch_size),
271274
device_id=device_id,
275+
seed=seed,
272276
) # (B, 2*num_detectors)
273277
meas_old = measure_from_stacked_frames(
274278
frames_xz, self.meas_qubits, self.meas_bases, nq=self.nq

code/tests/mid/test_homological_equivalence.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,10 +1250,8 @@ def test_he_reduces_error_weight(self):
12501250
gen0 = self._make_generator(num_he_cycles=0)
12511251
gen1 = self._make_generator(num_he_cycles=2)
12521252

1253-
torch.manual_seed(123)
1254-
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size)
1255-
torch.manual_seed(123)
1256-
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size)
1253+
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size, seed=123)
1254+
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size, seed=123)
12571255

12581256
# trainX is derived from meas_old and should be identical for same sample.
12591257
self.assertTrue(torch.equal(trainX0, trainX1))

code/tests/mid/test_w2_verify.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ def test_full_pipeline_w2_reproducible(self):
8888
A=None,
8989
)
9090
B = 128
91-
torch.manual_seed(100)
92-
tX_a, tY_a = gen_w2.generate_batch(batch_size=B)
93-
torch.manual_seed(100)
94-
tX_b, tY_b = gen_w2.generate_batch(batch_size=B)
91+
tX_a, tY_a = gen_w2.generate_batch(batch_size=B, seed=100)
92+
tX_b, tY_b = gen_w2.generate_batch(batch_size=B, seed=100)
9593
self.assertTrue(torch.allclose(tX_a, tX_b), "trainX should match for same seed")
9694
self.assertTrue(torch.allclose(tY_a, tY_b), "trainY should match for same seed")

code/tests/test_dem_sampling.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,5 +711,78 @@ def test_generate_batch_output_distribution(self) -> None:
711711
)
712712

713713

714+
# ============================================================================
715+
# Seeded reproducibility
716+
# ============================================================================
717+
718+
719+
class TestDEMSamplingReproducibility(unittest.TestCase):
720+
"""
721+
Verify that dem_sampling(seed=N) produces identical outputs on repeated
722+
calls with the same seed, using stochastic (non-trivial) probabilities.
723+
724+
This is the regression test for the bug where torch.manual_seed() was
725+
assumed to control BitMatrixSampler's internal RNG (it does not).
726+
The fix exposes an explicit seed= parameter that re-creates the sampler,
727+
resetting cuST's RNG to a known state.
728+
"""
729+
730+
H = torch.tensor(
731+
[[1, 0, 1, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 1, 1]],
732+
dtype=torch.uint8,
733+
)
734+
p = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9], dtype=torch.float32)
735+
736+
def setUp(self):
737+
_reset_sampler_cache()
738+
739+
def test_same_seed_cpu_reproducible(self) -> None:
740+
"""Two calls with the same seed on CPU must produce identical frames."""
741+
out_a = dem_sampling(self.H, self.p, 64, seed=42)
742+
_reset_sampler_cache()
743+
out_b = dem_sampling(self.H, self.p, 64, seed=42)
744+
self.assertTrue(
745+
torch.equal(out_a, out_b),
746+
"dem_sampling with the same seed must be bit-exact reproducible",
747+
)
748+
749+
def test_different_seeds_cpu_differ(self) -> None:
750+
"""Two calls with different seeds should (with overwhelming probability) differ."""
751+
out_a = dem_sampling(self.H, self.p, 256, seed=1)
752+
_reset_sampler_cache()
753+
out_b = dem_sampling(self.H, self.p, 256, seed=2)
754+
self.assertFalse(
755+
torch.equal(out_a, out_b),
756+
"dem_sampling with different seeds should produce different frames",
757+
)
758+
759+
def test_no_seed_does_not_break_cache(self) -> None:
760+
"""Unseeded calls still reuse the cached sampler (no regression to perf path)."""
761+
import qec.dem_sampling as _mod
762+
_reset_sampler_cache()
763+
dem_sampling(self.H, self.p, 64)
764+
sampler_after_first = _mod._cached_sampler
765+
dem_sampling(self.H, self.p, 64)
766+
sampler_after_second = _mod._cached_sampler
767+
self.assertIs(
768+
sampler_after_first,
769+
sampler_after_second,
770+
"Unseeded calls must reuse the cached BitMatrixSampler",
771+
)
772+
773+
@unittest.skipUnless(torch.cuda.is_available(), "requires CUDA GPU")
774+
def test_same_seed_gpu_reproducible(self) -> None:
775+
"""Two calls with the same seed on GPU must produce identical frames."""
776+
H = self.H.cuda()
777+
p = self.p.cuda()
778+
out_a = dem_sampling(H, p, 64, seed=42)
779+
_reset_sampler_cache()
780+
out_b = dem_sampling(H, p, 64, seed=42)
781+
self.assertTrue(
782+
torch.equal(out_a, out_b),
783+
"GPU: dem_sampling with the same seed must be bit-exact reproducible",
784+
)
785+
786+
714787
if __name__ == "__main__":
715788
unittest.main()

0 commit comments

Comments
 (0)