Skip to content
32 changes: 21 additions & 11 deletions code/qec/dem_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _custab_available() -> bool:
_cached_HT: "torch.Tensor | None" = None
_cached_max_shots: int = 0
_cached_device_id: int | None = None
_cached_seed: "int | None" = None

_DEM_TIMINGS_S: deque[float] = deque(maxlen=200)
_custab_path_logged: bool = False
Expand All @@ -75,19 +76,21 @@ def get_dem_sampling_avg_ms() -> float:

def _reset_sampler_cache() -> None:
    """Reset the module-level sampler cache.

    Clears the cached sampler object, the cached ``H`` / ``HT`` tensors, the
    cached device id and the cached seed, and sets the cached shot capacity
    back to 0.  ``dem_sampling`` treats a ``None`` cached sampler as "build a
    new one", so the call after a reset always constructs a fresh sampler.
    """
    # Declared once, split over two statements like the declarations inside
    # dem_sampling; a second declaration of the same names is redundant.
    global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
    global _cached_device_id, _cached_seed
    _cached_sampler = None
    _cached_H = None
    _cached_HT = None
    _cached_max_shots = 0  # 0 shots means any positive batch size exceeds the cache
    _cached_device_id = None
    _cached_seed = None


def dem_sampling(
H: torch.Tensor,
p: torch.Tensor,
batch_size: int,
device_id: int | None = None
device_id: int | None = None,
seed: int | None = None,
) -> torch.Tensor:
"""
Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.
Expand All @@ -98,6 +101,10 @@ def dem_sampling(
batch_size: int - Number of samples to generate
device_id: Optional int - Device ID for cuST. If omitted, infer from
H.device when H is on CUDA.
seed: Optional int - RNG seed passed directly to BitMatrixSampler.
When provided, a fresh sampler is created with that seed so repeated
calls with the same seed produce identical outputs.
When None (default), the cached sampler is reused across calls.

Returns:
frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
Expand All @@ -106,7 +113,7 @@ def dem_sampling(
from cuquantum.stabilizer.simulator import Options

global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
global _cached_device_id, _custab_path_logged
global _cached_device_id, _cached_seed, _custab_path_logged

if H.ndim != 2:
raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
Expand All @@ -129,9 +136,14 @@ def dem_sampling(
_cached_H = H
_cached_sampler = None
_cached_device_id = None
_cached_seed = None

# When a seed is given always create a fresh sampler so its internal RNG is
# reset to that seed, giving identical outputs on repeated calls with the
# same seed value (per the BitMatrixSampler constructor contract).
need_new = (
_cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
_cached_sampler is None or batch_size > _cached_max_shots or
_cached_device_id != device_id or seed is not None
)

if need_new:
Expand All @@ -146,15 +158,13 @@ def dem_sampling(
H_in = _cached_HT.detach().cpu().numpy().astype(np.uint8)
p_in = p.detach().cpu().numpy().astype(np.float64)
pkg = "numpy"
_cached_sampler = BitMatrixSampler(
H_in,
p_in,
max_shots,
package=pkg,
options=Options(device_id=device_id),
)
bms_kwargs: dict = {"package": pkg, "options": Options(device_id=device_id)}
if seed is not None:
bms_kwargs["seed"] = seed
_cached_sampler = BitMatrixSampler(H_in, p_in, max_shots, **bms_kwargs)
_cached_max_shots = max_shots
_cached_device_id = device_id
_cached_seed = seed

t0 = time.perf_counter()
if gpu_native:
Expand Down
4 changes: 4 additions & 0 deletions code/qec/surface_code/memory_circuit_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def generate_batch(
batch_size: int,
return_aux: bool = False,
collect_timing: bool = False,
seed: int | None = None,
) -> Union[
tuple[torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
Expand All @@ -248,6 +249,8 @@ def generate_batch(
build Stim dets_and_obs from circuit-order measurements.
- If collect_timing=True, also return timing breakdown in milliseconds:
data generation, HE, format, and total.
- If seed is given, the BitMatrixSampler is re-created with that seed so
repeated calls with the same seed produce identical outputs.
"""
if self._compile_thread is not None:
# torch.compile warmup can be slow; 20 min cap prevents silent hangs.
Expand All @@ -269,6 +272,7 @@ def generate_batch(
self.p,
int(batch_size),
device_id=device_id,
seed=seed,
) # (B, 2*num_detectors)
meas_old = measure_from_stacked_frames(
frames_xz, self.meas_qubits, self.meas_bases, nq=self.nq
Expand Down
6 changes: 2 additions & 4 deletions code/tests/mid/test_homological_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,10 +1250,8 @@ def test_he_reduces_error_weight(self):
gen0 = self._make_generator(num_he_cycles=0)
gen1 = self._make_generator(num_he_cycles=2)

torch.manual_seed(123)
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size)
torch.manual_seed(123)
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size)
trainX0, trainY0 = gen0.generate_batch(batch_size=self.batch_size, seed=123)
trainX1, trainY1 = gen1.generate_batch(batch_size=self.batch_size, seed=123)

# trainX is derived from meas_old and should be identical for same sample.
self.assertTrue(torch.equal(trainX0, trainX1))
Expand Down
6 changes: 2 additions & 4 deletions code/tests/mid/test_w2_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def test_full_pipeline_w2_reproducible(self):
A=None,
)
B = 128
torch.manual_seed(100)
tX_a, tY_a = gen_w2.generate_batch(batch_size=B)
torch.manual_seed(100)
tX_b, tY_b = gen_w2.generate_batch(batch_size=B)
tX_a, tY_a = gen_w2.generate_batch(batch_size=B, seed=100)
tX_b, tY_b = gen_w2.generate_batch(batch_size=B, seed=100)
self.assertTrue(torch.allclose(tX_a, tX_b), "trainX should match for same seed")
self.assertTrue(torch.allclose(tY_a, tY_b), "trainY should match for same seed")
73 changes: 73 additions & 0 deletions code/tests/test_dem_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,5 +711,78 @@ def test_generate_batch_output_distribution(self) -> None:
)


# ============================================================================
# Seeded reproducibility
# ============================================================================


class TestDEMSamplingReproducibility(unittest.TestCase):
    """
    Verify that dem_sampling(seed=N) produces identical outputs on repeated
    calls with the same seed, using stochastic (non-trivial) probabilities.

    This is the regression test for the bug where torch.manual_seed() was
    assumed to control BitMatrixSampler's internal RNG (it does not).
    The fix exposes an explicit seed= parameter that re-creates the sampler,
    resetting cuST's RNG to a known state.
    """

    # 3x5 binary matrix with one probability per column.  Every probability
    # lies strictly between 0 and 1, so the sampled frames are genuinely
    # random and equality between two runs is a meaningful check.
    H = torch.tensor(
        [[1, 0, 1, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 1, 1]],
        dtype=torch.uint8,
    )
    p = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9], dtype=torch.float32)

    def setUp(self):
        """Start every test from an empty module-level sampler cache."""
        _reset_sampler_cache()

    def test_same_seed_cpu_reproducible(self) -> None:
        """Two calls with the same seed on CPU must produce identical frames."""
        out_a = dem_sampling(self.H, self.p, 64, seed=42)
        _reset_sampler_cache()
        out_b = dem_sampling(self.H, self.p, 64, seed=42)
        self.assertTrue(
            torch.equal(out_a, out_b),
            "dem_sampling with the same seed must be bit-exact reproducible",
        )

    def test_same_seed_reproducible_without_cache_reset(self) -> None:
        """Back-to-back seeded calls must match even with a populated cache.

        The other seeded tests clear the cache between calls, which forces a
        new sampler regardless of the seed.  This test leaves the cache
        alone, so it only passes if passing a seed by itself makes
        dem_sampling build a fresh sampler instead of continuing the RNG
        stream of the cached one.
        """
        out_a = dem_sampling(self.H, self.p, 64, seed=42)
        out_b = dem_sampling(self.H, self.p, 64, seed=42)
        self.assertTrue(
            torch.equal(out_a, out_b),
            "dem_sampling with the same seed must be reproducible without a cache reset",
        )

    def test_different_seeds_cpu_differ(self) -> None:
        """Two calls with different seeds should (with overwhelming probability) differ."""
        out_a = dem_sampling(self.H, self.p, 256, seed=1)
        _reset_sampler_cache()
        out_b = dem_sampling(self.H, self.p, 256, seed=2)
        self.assertFalse(
            torch.equal(out_a, out_b),
            "dem_sampling with different seeds should produce different frames",
        )

    def test_no_seed_does_not_break_cache(self) -> None:
        """Unseeded calls still reuse the cached sampler (no regression to perf path)."""
        # Imported as a module so the private cache variable can be read
        # after each call; a from-import would only capture a stale binding.
        import qec.dem_sampling as _mod

        _reset_sampler_cache()
        dem_sampling(self.H, self.p, 64)
        sampler_after_first = _mod._cached_sampler
        dem_sampling(self.H, self.p, 64)
        sampler_after_second = _mod._cached_sampler
        self.assertIs(
            sampler_after_first,
            sampler_after_second,
            "Unseeded calls must reuse the cached BitMatrixSampler",
        )

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA GPU")
    def test_same_seed_gpu_reproducible(self) -> None:
        """Two calls with the same seed on GPU must produce identical frames."""
        H = self.H.cuda()
        p = self.p.cuda()
        out_a = dem_sampling(H, p, 64, seed=42)
        _reset_sampler_cache()
        out_b = dem_sampling(H, p, 64, seed=42)
        self.assertTrue(
            torch.equal(out_a, out_b),
            "GPU: dem_sampling with the same seed must be bit-exact reproducible",
        )


if __name__ == "__main__":
unittest.main()
Loading