Commit 5aeebdf

kvmto and bmhowe23 authored
Improve cuStabilizer DEM sampling backend and consolidate tests (#32)
* Make cuStabilizer the sole DEM sampling backend and consolidate tests

  Remove the torch fallback path from dem_sampling.py; cuquantum's BitMatrixSampler is now the only sampling backend, simplifying the module and eliminating the USE_CUSTAB toggle. The sampler cache uses identity-based comparison with a pre-cached transpose to avoid redundant reallocation.

  Merge test_dem_sampling_custab.py and test_dem_sampling_integration.py into test_dem_sampling.py for a single, comprehensive test suite.

  Also:
  - Add cuquantum>=26.3.0 to requirements_public_train.txt
  - Fix CI to install train (not inference) requirements for GPU tests
  - Apply yapf formatting (Google style, 100-col limit)

  Signed-off-by: kvmto <kmato@nvidia.com>

* fixed license

  Signed-off-by: kvmto <kmato@nvidia.com>

* fix: use cuquantum-python-cu12 wheel to avoid pkg_resources build failure

  The cuquantum meta-package fails to build in environments where pkg_resources is unavailable. Pin the CUDA-12 specific wheel directly to bypass the broken auto-detection setup.py.

  Signed-off-by: kvmto <kmato@nvidia.com>

* lazy imports for safe separation between training and inference

  Signed-off-by: kvmto <kmato@nvidia.com>

* quick fix to CI

  Signed-off-by: kvmto <kmato@nvidia.com>

* route cuQuantum dem_sampling tests to GPU CI

  Signed-off-by: kvmto <kmato@nvidia.com>

* left behind change

  Signed-off-by: kvmto <kmato@nvidia.com>

* missing bash session

  Signed-off-by: kvmto <kmato@nvidia.com>

* Make CUDA major version specific requirements files and use custabilizer-cuXY

* Revert some changes to test files that are hopefully no longer needed

* Revert REQUIRE_CUQUANTUM changes

* Change custabilizer version to 0.3.0

* Change custabilizer back to cuquantum-python

* Skip test_dem_sampling.py if required deps are not present

* Try again

* Skip a few more tests if cuquantum-python not installed

* Revert CUDA major version specific requirements files

  Since custabilizer-cuXY is not a viable standalone package, there is no need to try to make CUDA major version specific files. Rather, we just rely on the auto detection logic in cuquantum-python.

* Revert "Revert CUDA major version specific requirements files"

  This reverts commit ce055a2.

* small torch device object bug fix for nccl

  Signed-off-by: kvmto <kmato@nvidia.com>

* overcome custab device id limitation

  Signed-off-by: kvmto <kmato@nvidia.com>

* added tiny logging

  Signed-off-by: kvmto <kmato@nvidia.com>

* linted

  Signed-off-by: kvmto <kmato@nvidia.com>

* Revert "Revert "Revert CUDA major version specific requirements files""

  This reverts commit 05e92f8.

* Revert "Revert "Revert "Revert CUDA major version specific requirements files"""

  This reverts commit 84c814b.

---------

Signed-off-by: kvmto <kmato@nvidia.com>
Co-authored-by: Ben Howe <bhowe@nvidia.com>
1 parent 1b49b69 commit 5aeebdf
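
The commit message says the sampler cache now uses identity-based comparison with a pre-cached transpose. A minimal sketch of that idea follows, using plain Python lists instead of tensors. The `IdentityCache` class and the `transpose` helper are hypothetical names introduced only for illustration and are not part of the repository.

```python
from typing import Callable

_UNSET = object()  # sentinel so that a None key is still handled correctly


class IdentityCache:
    """Hypothetical helper: rebuild a derived value only when a different object is passed."""

    def __init__(self, build: Callable[[object], object]) -> None:
        self._build = build
        self._key: object = _UNSET    # strong reference to the last input
        self._value: object = None    # derived value for that input
        self.builds = 0               # number of rebuilds, for demonstration

    def get(self, key: object) -> object:
        # `is not` compares object identity, not content, so it costs O(1).
        if self._key is not key:
            self._value = self._build(key)
            self._key = key
            self.builds += 1
        return self._value


def transpose(rows):
    """Transpose a list-of-lists matrix."""
    return [list(col) for col in zip(*rows)]


cache = IdentityCache(transpose)
H = [[1, 0, 1], [0, 1, 1]]
first = cache.get(H)             # builds the transpose
second = cache.get(H)            # same object: served from the cache
H_copy = [row[:] for row in H]
third = cache.get(H_copy)        # equal content, different object: rebuilt
```

Holding a reference to the input and comparing with `is`, rather than storing `id(H)`, also keeps the cached object alive, so its identity cannot be reused by a different object after garbage collection. The trade-off is that an in-place modification of the same object is not detected.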

20 files changed

Lines changed: 821 additions & 410 deletions

.github/workflows/ci-gpu.yml

Lines changed: 4 additions & 2 deletions
@@ -154,7 +154,8 @@ jobs:
           python3.13 -m venv .venv_mid
           . .venv_mid/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt
 
       - name: Mid-tier training + inference with LER check (32k train, 2 epochs)
         shell: bash
@@ -212,7 +213,8 @@ jobs:
           python3 -m venv .venv_gpu_cov
           . .venv_gpu_cov/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_inference.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt
           pip install -r code/requirements_ci.txt
 
       - name: Run tests with GPU coverage

.github/workflows/long-running-tests.yml

Lines changed: 4 additions & 2 deletions
@@ -166,7 +166,8 @@ jobs:
           python${{ env.PYTHON_VERSION }} -m venv .venv
           . .venv/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt
 
       - name: Verify GPU
         run: |
@@ -302,7 +303,8 @@ jobs:
           python${{ env.PYTHON_VERSION }} -m venv .venv
           . .venv/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt
 
       - name: Verify GPU
         run: |

Dockerfile

Lines changed: 9 additions & 4 deletions
@@ -35,11 +35,16 @@ RUN python${PYTHON_VERSION} -m venv /opt/venv
 ENV PATH="/opt/venv/bin:$PATH"
 
 COPY code/requirements_public_inference.txt /tmp/requirements_public_inference.txt
-COPY code/requirements_public_train.txt /tmp/requirements_public_train.txt
-
-RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
+COPY code/requirements_public_train-cu*.txt /tmp/
+
+# Derive the CUDA major version from the base image's $CUDA_VERSION env var
+# (e.g. "12.1.0" -> "12") and install the matching requirements file.
+RUN CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | cut -d. -f1) && \
+    echo "Detected CUDA major version: ${CUDA_MAJOR_VERSION}" && \
+    echo "export CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}" >> /etc/bash.bashrc && \
+    pip install --no-cache-dir --upgrade pip setuptools wheel && \
     pip install --no-cache-dir \
-        -r /tmp/requirements_public_train.txt \
+        -r /tmp/requirements_public_train-cu${CUDA_MAJOR_VERSION}.txt \
         --index-url "https://download.pytorch.org/whl/${TORCH_CUDA}" \
         --extra-index-url https://pypi.org/simple && \
     python -c "import torch; print('PyTorch', torch.__version__, '(CUDA build:', torch.version.cuda, ')')"
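
The Dockerfile change above derives the CUDA major version with `cut -d. -f1` and uses it to pick a requirements file. The same mapping can be sketched in Python, for example for a setup script. The function names `cuda_major` and `train_requirements_path` are hypothetical and do not exist in the repository; only the file naming pattern is taken from the diff.

```python
def cuda_major(cuda_version: str) -> str:
    """Return the major component of a CUDA version string, e.g. "12.1.0" -> "12"."""
    major = cuda_version.strip().split(".")[0]
    if not major.isdigit():
        # Fail early rather than build a path such as "...-cu.txt".
        raise ValueError(f"cannot parse CUDA version: {cuda_version!r}")
    return major


def train_requirements_path(cuda_version: str, directory: str = "/tmp") -> str:
    """Build the path of the training requirements file for a CUDA version."""
    return f"{directory}/requirements_public_train-cu{cuda_major(cuda_version)}.txt"


path = train_requirements_path("12.1.0")
```

Unlike the shell pipeline, this sketch rejects an empty or malformed version string instead of silently producing a file name with no version in it.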

README.md

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@ Target Python versions: **3.11, 3.12, 3.13**.
 Two minimal requirements files are provided:
 
 - `code/requirements_public_inference.txt` (Stim + PyTorch path)
-- `code/requirements_public_train.txt` (training path)
+- `code/requirements_public_train-cuXY.txt` (training path, where XY = 12 or 13)
 
 Install examples (virtual environment is optional but recommended):
 
@@ -41,8 +41,8 @@ export TORCH_CUDA=cu130
 # Inference-only (training install is a superset)
 pip install -r code/requirements_public_inference.txt
 
-# Training (includes inference deps)
-pip install -r code/requirements_public_train.txt
+# Training (includes inference deps, adjust to cu13 as appropriate)
+pip install -r code/requirements_public_train-cu12.txt
 
 bash code/scripts/check_python_compat.sh
 ```

code/qec/dem_sampling.py

Lines changed: 73 additions & 104 deletions
@@ -15,18 +15,15 @@
 """
 DEM sampling utilities for training data generation.
 
-When cuQuantum's cuStabilizer (BitMatrixSampler) is installed the sampling
-runs on the GPU via the cuST sparse sampler with optional CuPy zero-copy
-DLPack transfers. When cuST is absent or disabled (USE_CUSTAB=0) the module
-falls back to a pure-torch implementation.
+Sampling runs on the GPU via cuQuantum's cuStabilizer BitMatrixSampler with
+optional CuPy zero-copy DLPack transfers. cuquantum>=26.3.0 is required.
 
 This module provides the sampling functions needed by MemoryCircuitTorch
 to generate training batches from precomputed DEM matrices (H, p, A).
 """
 
 from __future__ import annotations
 
-import os
 import time
 from collections import deque
 
@@ -38,6 +35,9 @@
     from cuquantum.stabilizer.simulator import Options
     _CUSTAB_AVAILABLE = True
 except ImportError:
+    # This should only happen if cuquantum is not installed. That is expected
+    # for some test environments that don't need DEM sampling, so handle that
+    # gracefully here.
     BitMatrixSampler = None  # type: ignore[misc, assignment]
     Options = None  # type: ignore[misc, assignment]
     _CUSTAB_AVAILABLE = False
@@ -54,14 +54,14 @@ def _custab_available() -> bool:
     return _CUSTAB_AVAILABLE
 
 
-_cached_sampler: "BitMatrixSampler | None" = None
-_cached_H_id: int | None = None
+_cached_sampler = None
+_cached_H: "torch.Tensor | None" = None
+_cached_HT: "torch.Tensor | None" = None
 _cached_max_shots: int = 0
+_cached_device_id: int | None = None
 
 _DEM_TIMINGS_S: deque[float] = deque(maxlen=200)
-_use_custab_cached: bool | None = None
 _custab_path_logged: bool = False
-_fallback_path_logged: bool = False
 
 _MIN_MAX_SHOTS = 1024
 
@@ -75,42 +75,75 @@ def get_dem_sampling_avg_ms() -> float:
 
 def _reset_sampler_cache() -> None:
     """Reset the module-level sampler cache."""
-    global _cached_sampler, _cached_H_id, _cached_max_shots
+    global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id
     _cached_sampler = None
-    _cached_H_id = None
+    _cached_H = None
+    _cached_HT = None
     _cached_max_shots = 0
+    _cached_device_id = None
 
 
-def custab_matrix_sampling(
-    H: torch.Tensor, p: torch.Tensor, batch_size: int, device_id: int = 0
+def dem_sampling(
+    H: torch.Tensor,
+    p: torch.Tensor,
+    batch_size: int,
+    device_id: int | None = None
 ) -> torch.Tensor:
     """
-    Sample from a DEM using cuST BitMatrixSampler. H must be (errors, result) layout.
+    Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.
+
+    Args:
+        H: (2*num_detectors, num_errors) uint8 - Detector-error incidence matrix
+        p: (num_errors,) float32 - Per-error probabilities
+        batch_size: int - Number of samples to generate
+        device_id: Optional int - Device ID for cuST. If omitted, infer from
+            H.device when H is on CUDA.
 
-    When CuPy is available the entire pipeline stays on GPU:
-        torch CUDA -> CuPy (zero-copy DLPack) -> cuStabilizer -> CuPy -> torch (zero-copy DLPack)
+    Returns:
+        frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
     """
-    if not _CUSTAB_AVAILABLE or BitMatrixSampler is None or Options is None:
-        raise RuntimeError("custab_matrix_sampling requires cuquantum.stabilizer")
+    from cuquantum.stabilizer.dem_sampling import BitMatrixSampler
+    from cuquantum.stabilizer.simulator import Options
+
+    global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
+    global _cached_device_id, _custab_path_logged
+
+    if H.ndim != 2:
+        raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
+    if p.ndim != 1:
+        raise ValueError(f"p must be 1-D, got ndim={p.ndim}")
+    if H.shape[1] != p.shape[0]:
+        raise ValueError(f"H has {H.shape[1]} columns but p has {p.shape[0]} entries")
+
+    if device_id is None:
+        if H.is_cuda:
+            device_index = H.device.index
+            device_id = int(torch.cuda.current_device() if device_index is None else device_index)
+        else:
+            device_id = 0
+
+    gpu_native = _CUPY_AVAILABLE and H.is_cuda
 
-    global _cached_sampler, _cached_H_id, _cached_max_shots, _custab_path_logged
+    if _cached_H is not H:
+        _cached_HT = H.T
+        _cached_H = H
+        _cached_sampler = None
+        _cached_device_id = None
 
-    # id(H) is the tensor's memory address — fast but not content-based.
-    # Safe in training loops where H is a long-lived tensor; a content hash
-    # (like cuda-qx-g uses) would be more robust but slower on every call.
-    H_id = id(H)
-    need_new = (_cached_sampler is None or _cached_H_id != H_id or batch_size > _cached_max_shots)
+    need_new = (
+        _cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
+    )
 
     if need_new:
         max_shots = max(batch_size, _MIN_MAX_SHOTS)
-        gpu_native = _CUPY_AVAILABLE and H.is_cuda
         if gpu_native:
             import cupy as cp
-            H_in = cp.from_dlpack(H.detach())
-            p_in = cp.from_dlpack(p.detach().to(torch.float64))
+            with cp.cuda.Device(device_id):
+                H_in = cp.from_dlpack(_cached_HT.detach())
+                p_in = cp.from_dlpack(p.detach().to(torch.float64))
             pkg = "cupy"
         else:
-            H_in = H.detach().cpu().numpy().astype(np.uint8)
+            H_in = _cached_HT.detach().cpu().numpy().astype(np.uint8)
             p_in = p.detach().cpu().numpy().astype(np.float64)
             pkg = "numpy"
         _cached_sampler = BitMatrixSampler(
@@ -120,98 +153,34 @@ def custab_matrix_sampling(
             package=pkg,
             options=Options(device_id=device_id),
         )
-        _cached_H_id = H_id
         _cached_max_shots = max_shots
+        _cached_device_id = device_id
 
-    _cached_sampler.sample(batch_size)
-
-    out = _cached_sampler.get_outcomes(bit_packed=False)
+    t0 = time.perf_counter()
+    if gpu_native:
+        import cupy as cp
+        with cp.cuda.Device(device_id):
+            _cached_sampler.sample(batch_size)
+            out = _cached_sampler.get_outcomes(bit_packed=False)
+    else:
+        _cached_sampler.sample(batch_size)
+        out = _cached_sampler.get_outcomes(bit_packed=False)
     if isinstance(out, np.ndarray):
         out = torch.as_tensor(out, device=H.device).to(dtype=torch.uint8)
     else:
         out = torch.from_dlpack(out).to(dtype=torch.uint8)
+    _DEM_TIMINGS_S.append(time.perf_counter() - t0)
 
     if not _custab_path_logged:
         print(
             f"---- [dem_sampling] Using cuST BitMatrixSampler path "
-            f"(max_shots={_cached_max_shots}, gpu_native={_CUPY_AVAILABLE})"
+            f"(max_shots={_cached_max_shots}, gpu_native={gpu_native}, device_id={device_id})"
         )
         _custab_path_logged = True
 
     return out
 
 
-def _use_custab() -> bool:
-    """Use cuST if available and not disabled by USE_CUSTAB=0. Cached after first call."""
-    global _use_custab_cached
-    if _use_custab_cached is None:
-        if not _CUSTAB_AVAILABLE:
-            _use_custab_cached = False
-        else:
-            v = os.environ.get("USE_CUSTAB", "1").strip().lower()
-            _use_custab_cached = v not in ("0", "false", "no", "off")
-    return _use_custab_cached
-
-
-def _reset_use_custab_cache() -> None:
-    """Reset the _use_custab cache (e.g. after changing USE_CUSTAB in tests)."""
-    global _use_custab_cached
-    _use_custab_cached = None
-
-
-def dem_sampling(
-    H: torch.Tensor, p: torch.Tensor, batch_size: int, device_id: int = 0
-) -> torch.Tensor:
-    """
-    Sample errors from a detector error model (DEM) using precomputed H and p matrices.
-    Uses cuST BitMatrixSampler when available; if cuST is not present or USE_CUSTAB=0,
-    uses the torch fallback.
-
-    Args:
-        H: (2*num_detectors, num_errors) uint8 - Detector-error incidence matrix
-        p: (num_errors,) float32 - Per-error probabilities
-        batch_size: int - Number of samples to generate
-        device_id: int - Device ID for cuST (ignored by torch path).
-
-    Returns:
-        frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
-    """
-    if H.ndim != 2:
-        raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
-    if p.ndim != 1:
-        raise ValueError(f"p must be 1-D, got ndim={p.ndim}")
-    if H.shape[1] != p.shape[0]:
-        raise ValueError(f"H has {H.shape[1]} columns but p has {p.shape[0]} entries")
-
-    global _fallback_path_logged
-    t0 = time.perf_counter()
-
-    if _use_custab():
-        # cuST expects (errors, result); dem_sampling H is (result, errors) -> pass H.T
-        out = custab_matrix_sampling(H.T, p, batch_size, device_id)
-        _DEM_TIMINGS_S.append(time.perf_counter() - t0)
-        return out
-
-    num_errors = int(H.shape[1])
-    device = H.device
-
-    # Sample errors according to their probabilities (independent Bernoulli)
-    rand_vals = torch.rand(batch_size, num_errors, device=device, dtype=torch.float32)
-    errors = (rand_vals < p[None, :]).to(torch.uint8)  # (batch_size, num_errors)
-
-    # Matrix multiply H @ errors^T to get detector outcomes
-    # H is (2*num_detectors, num_errors), errors is (batch_size, num_errors)
-    frames_xz = torch.matmul(errors.to(torch.float32), H.T.to(torch.float32))
-    frames_xz = frames_xz.to(torch.uint8) % 2  # Binary GF(2) arithmetic
-
-    _DEM_TIMINGS_S.append(time.perf_counter() - t0)
-    if not _fallback_path_logged:
-        print("Used fallback torch path for dem_sampling")
-        _fallback_path_logged = True
-
-    return frames_xz
-
-
 def measure_from_stacked_frames(
     frames_xz: torch.Tensor,
     meas_qubits: torch.Tensor,
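
The removed torch fallback documents what `dem_sampling` is expected to return: each error fires independently with probability `p[j]`, and the detector outcomes are the product of the error vector with `H` over GF(2). A small pure-Python reference of that computation can be handy for tests on tiny inputs. This is a sketch that mirrors the removed fallback's arithmetic, not the cuStabilizer implementation. The name `reference_dem_sample` and the explicit `rng` argument are assumptions made for illustration.

```python
import random


def reference_dem_sample(H, p, batch_size, rng):
    """Sample detector frames from a DEM given as nested lists.

    H: list of rows, shape (num_outcomes, num_errors), entries 0 or 1.
    p: list of per-error probabilities, length num_errors.
    rng: a random.Random instance, passed in for reproducibility.
    Returns a list of batch_size frames, each of length num_outcomes.
    """
    num_errors = len(p)
    if any(len(row) != num_errors for row in H):
        raise ValueError("every row of H must have len(p) entries")

    frames = []
    for _ in range(batch_size):
        # Independent Bernoulli draw for each error mechanism.
        errors = [1 if rng.random() < p_j else 0 for p_j in p]
        # Parity of the errors touching each outcome (matrix product mod 2).
        frame = [sum(h & e for h, e in zip(row, errors)) % 2 for row in H]
        frames.append(frame)
    return frames


# Degenerate probabilities make the result deterministic: errors 0 and 2
# always fire, error 1 never fires.
H = [[1, 1, 0],
     [1, 0, 1]]
p = [1.0, 0.0, 1.0]
frames = reference_dem_sample(H, p, batch_size=4, rng=random.Random(0))
```

For non-degenerate probabilities the output is random, so a comparison against the GPU sampler would have to be statistical (for example, per-outcome frequencies) rather than exact.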

code/qec/surface_code/homological_equivalence_torch.py

Lines changed: 21 additions & 6 deletions
@@ -334,6 +334,10 @@ def _simplify_time_w1_step_nobreak(
     return err_out, syn_out
 
 
+_INT8_GEMM_OK: dict[str, bool] = {}
+_INT8_GEMM_WARNED: set[str] = set()
+
+
 def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tensor:
     """
     Weight reduction (parallel within disjoint stabilizer layers).
@@ -355,7 +359,11 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
     # at most 4 (stabilizer support size) and act1/act2 are bool→int8 with at
     # most L ones, so intermediate sums stay well within int8 range as long as
     # L < 128 (true for practical surface code distances).
-    _use_int8 = True
+    #
+    # _INT8_GEMM_OK caches per-device results so after one failure on a given
+    # device we skip int8 permanently (no repeated exceptions / warnings).
+    dev_key = str(cfg.device)
+    _use_int8 = _INT8_GEMM_OK.get(dev_key, True)
 
     for layer_idx in cache.layers:
         if layer_idx.numel() == 0:
@@ -372,12 +380,16 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
                 flip_mask = ((act2.to(torch.int8) @ masks_i8).to(torch.int32)
                              > 0) & (~set_to_zero_mask)
             except RuntimeError as exc:
-                warnings.warn(
-                    f"Int8 GEMM failed, falling back to float32 for weight reduction: {exc}",
-                    RuntimeWarning,
-                    stacklevel=2,
-                )
                 _use_int8 = False
+                _INT8_GEMM_OK[dev_key] = False
+                if dev_key not in _INT8_GEMM_WARNED:
+                    _INT8_GEMM_WARNED.add(dev_key)
+                    warnings.warn(
+                        f"Int8 GEMM failed on {dev_key}, permanently falling back to "
+                        f"float32 for weight reduction: {exc}",
+                        RuntimeWarning,
+                        stacklevel=2,
+                    )
                 masks_f = cache.support_masks.to(torch.float32).index_select(0, layer_idx)
                 error_counts = (cfg.to(torch.float32) @ masks_f.t()).to(torch.int32)
                 act1 = (error_counts == 4) | ((error_counts == 2) & (sizes.unsqueeze(0) == 2))
@@ -396,6 +408,9 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
         cfg = cfg ^ flip_mask.to(cfg.dtype)
         cfg_i8 = cfg.to(torch.int8)
 
+    if _use_int8 and dev_key not in _INT8_GEMM_OK:
+        _INT8_GEMM_OK[dev_key] = True
+
     return cfg
 
