Skip to content

Commit ec4e8bf

Browse files
ivanbasov and claude committed
fix(mid): seed BitMatrixSampler explicitly to restore test reproducibility
torch.manual_seed() does not control the internal RNG of cuQuantum's BitMatrixSampler, so the two mid-GPU tests that relied on it for reproducibility were non-deterministic and failed intermittently.

Add an optional `seed` parameter to `dem_sampling()` and `MemoryCircuitTorch.generate_batch()`. When a seed is provided, a fresh BitMatrixSampler is always created with `Options(seed=N)`, resetting its internal RNG and guaranteeing identical outputs on every call with the same seed. Production paths (seed=None) are unaffected — the cached sampler is reused as before.

Update the two failing tests to use the explicit seed kwarg instead of torch.manual_seed():
- test_he_reduces_error_weight: seed=123
- test_full_pipeline_w2_reproducible: seed=100

Fixes: NVIDIA/Ising-Decoding CI run 23963347042

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 838d14f commit ec4e8bf

4 files changed

Lines changed: 29 additions & 13 deletions

File tree

code/qec/dem_sampling.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def _custab_available() -> bool:
5959
_cached_HT: "torch.Tensor | None" = None
6060
_cached_max_shots: int = 0
6161
_cached_device_id: int | None = None
62+
_cached_seed: "int | None" = None
6263

6364
_DEM_TIMINGS_S: deque[float] = deque(maxlen=200)
6465
_custab_path_logged: bool = False
@@ -75,19 +76,21 @@ def get_dem_sampling_avg_ms() -> float:
7576

7677
def _reset_sampler_cache() -> None:
7778
"""Reset the module-level sampler cache."""
78-
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id
79+
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id, _cached_seed
7980
_cached_sampler = None
8081
_cached_H = None
8182
_cached_HT = None
8283
_cached_max_shots = 0
8384
_cached_device_id = None
85+
_cached_seed = None
8486

8587

8688
def dem_sampling(
8789
H: torch.Tensor,
8890
p: torch.Tensor,
8991
batch_size: int,
90-
device_id: int | None = None
92+
device_id: int | None = None,
93+
seed: int | None = None,
9194
) -> torch.Tensor:
9295
"""
9396
Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.
@@ -98,6 +101,10 @@ def dem_sampling(
98101
batch_size: int - Number of samples to generate
99102
device_id: Optional int - Device ID for cuST. If omitted, infer from
100103
H.device when H is on CUDA.
104+
seed: Optional int - RNG seed for the BitMatrixSampler. When provided,
105+
a fresh sampler is always created so that the same seed produces
106+
identical outputs on repeated calls (useful for reproducible tests).
107+
When None (default), the cached sampler is reused across calls.
101108
102109
Returns:
103110
frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
@@ -106,7 +113,7 @@ def dem_sampling(
106113
from cuquantum.stabilizer.simulator import Options
107114

108115
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
109-
global _cached_device_id, _custab_path_logged
116+
global _cached_device_id, _cached_seed, _custab_path_logged
110117

111118
if H.ndim != 2:
112119
raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
@@ -129,9 +136,14 @@ def dem_sampling(
129136
_cached_H = H
130137
_cached_sampler = None
131138
_cached_device_id = None
139+
_cached_seed = None
132140

141+
# When a seed is requested we always create a fresh sampler so that the
142+
# BitMatrixSampler's internal RNG is reset to that seed, giving bit-for-bit
143+
# reproducibility across repeated calls with the same seed value.
133144
need_new = (
134-
_cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
145+
_cached_sampler is None or batch_size > _cached_max_shots
146+
or _cached_device_id != device_id or seed is not None
135147
)
136148

137149
if need_new:
@@ -146,15 +158,19 @@ def dem_sampling(
146158
H_in = _cached_HT.detach().cpu().numpy().astype(np.uint8)
147159
p_in = p.detach().cpu().numpy().astype(np.float64)
148160
pkg = "numpy"
161+
opt_kwargs: dict = {"device_id": device_id}
162+
if seed is not None:
163+
opt_kwargs["seed"] = seed
149164
_cached_sampler = BitMatrixSampler(
150165
H_in,
151166
p_in,
152167
max_shots,
153168
package=pkg,
154-
options=Options(device_id=device_id),
169+
options=Options(**opt_kwargs),
155170
)
156171
_cached_max_shots = max_shots
157172
_cached_device_id = device_id
173+
_cached_seed = seed
158174

159175
t0 = time.perf_counter()
160176
if gpu_native:

code/qec/surface_code/memory_circuit_torch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def generate_batch(
234234
batch_size: int,
235235
return_aux: bool = False,
236236
collect_timing: bool = False,
237+
seed: int | None = None,
237238
) -> Union[
238239
tuple[torch.Tensor, torch.Tensor],
239240
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
@@ -248,6 +249,8 @@ def generate_batch(
248249
build Stim dets_and_obs from circuit-order measurements.
249250
- If collect_timing=True, also return timing breakdown in milliseconds:
250251
data generation, HE, format, and total.
252+
- If seed is given, the BitMatrixSampler is re-created with that seed so
253+
repeated calls with the same seed produce identical outputs.
251254
"""
252255
if self._compile_thread is not None:
253256
# torch.compile warmup can be slow; 20 min cap prevents silent hangs.
@@ -269,6 +272,7 @@ def generate_batch(
269272
self.p,
270273
int(batch_size),
271274
device_id=device_id,
275+
seed=seed,
272276
) # (B, 2*num_detectors)
273277
meas_old = measure_from_stacked_frames(
274278
frames_xz, self.meas_qubits, self.meas_bases, nq=self.nq

code/tests/mid/test_homological_equivalence.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,10 +1250,8 @@ def test_he_reduces_error_weight(self):
12501250
gen0 = self._make_generator(num_he_cycles=0)
12511251
gen1 = self._make_generator(num_he_cycles=2)
12521252

1253-
torch.manual_seed(123)
1254-
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size)
1255-
torch.manual_seed(123)
1256-
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size)
1253+
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size, seed=123)
1254+
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size, seed=123)
12571255

12581256
# trainX is derived from meas_old and should be identical for same sample.
12591257
self.assertTrue(torch.equal(trainX0, trainX1))

code/tests/mid/test_w2_verify.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ def test_full_pipeline_w2_reproducible(self):
8888
A=None,
8989
)
9090
B = 128
91-
torch.manual_seed(100)
92-
tX_a, tY_a = gen_w2.generate_batch(batch_size=B)
93-
torch.manual_seed(100)
94-
tX_b, tY_b = gen_w2.generate_batch(batch_size=B)
91+
tX_a, tY_a = gen_w2.generate_batch(batch_size=B, seed=100)
92+
tX_b, tY_b = gen_w2.generate_batch(batch_size=B, seed=100)
9593
self.assertTrue(torch.allclose(tX_a, tX_b), "trainX should match for same seed")
9694
self.assertTrue(torch.allclose(tY_a, tY_b), "trainY should match for same seed")

0 commit comments

Comments
 (0)