diff --git a/README.md b/README.md index 4747f16..76222a2 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ The public release exposes a **single user-facing config** and a **single runner - [Converting .pt checkpoints to SafeTensors](#converting-pt-checkpoints-to-safetensors-optional-post-training) - [ONNX export and quantization](#onnx-export-and-quantization-optional-post-training) - [Generating data for CUDA-Q QEC](#generating-data-for-cuda-q-qec-realtime-predecoder-test-application) + - [Offline decoding from Stim detector samples](#offline-decoding-from-stim-detector-samples) - [Decoder ablation study with cudaq-qec](#decoder-ablation-study-with-cudaq-qec-optional) - [Configuration and advanced usage](#configuration-and-advanced-usage) - [GPU selection](#gpu-selection) @@ -359,6 +360,205 @@ Done. pymatching_predictions.bin 40,008 bytes ``` +### Offline decoding from Stim detector samples + +This is the file-based path for decoding detector samples produced outside the +in-memory simulator. It exists for two distinct audiences: + +1. **You already have detector samples** (from a QPU, a third-party simulator, + or a previously cached run) and want to feed them to the same decoders we + ship. Jump to [Bring your own detector samples](#bring-your-own-detector-samples). +2. **You want a reproducible end-to-end smoke test.** Use the local + generator below, then run the same decode commands. + +#### File contract + +Each basis is exactly two files: + +```text +/ + samples_X.dets # Stim sparse detector-sample format + metadata_X.json # circuit + noise fingerprint + samples_Z.dets + metadata_Z.json +``` + +`samples_*.dets` uses Stim's sparse format with logical observables appended, +so a line `shot D3 D8 L0` says detectors 3 and 8 fired and logical observable +0 flipped on that shot. Stim does not encode the memory basis in the sample +format, so X and Z always live in separate files; the LER loop iterates over +both when `cfg.test.meas_basis_test=both`. 
The resolver +(`resolve_stim_sample_paths`) also accepts the alternate layouts +`//samples.dets` + `metadata.json` and the flat +`/samples.dets` + `metadata.json`. + +The metadata JSON has the shape that +`qec.surface_code.stim_sample_io.build_stim_sample_metadata` writes: + +```json +{ + "schema_version": 2, + "artifact": "stim_detector_samples", + "format": "dets", + "append_observables": true, + "distance": 7, + "n_rounds": 7, + "basis": "X", + "code_rotation": "XV", + "num_detectors": 336, + "num_observables": 1, + "num_shots": 262144, + "p_error": 0.003, + "noise_model": "25-param", + "noise_model_sha256": "abcd…", + "noise_model_params": { "p_prep_X": 0.002, "...": 0.0 } +} +``` + +`p_error`, `noise_model`, `noise_model_sha256`, and `noise_model_params` are +optional but recommended; when present, the decoder cross-checks its active +noise model against the recorded fingerprint and raises by default if the two +disagree. Files written before this schema (no noise fields) keep loading +unchanged. `code_rotation` accepts both the canonical names (`XV`, `XH`, `ZV`, +`ZH`) and the public aliases (`O1`..`O4`). + +#### Bring your own detector samples + +If you have `.dets` data from elsewhere (a QPU, an external simulator), the +contract is exactly the three things below: + +1. Write `samples_{basis}.dets` in Stim's sparse format with observables + appended. +2. Write `metadata_{basis}.json` matching the schema above. The easiest way is + to call `build_stim_sample_metadata(...)` and `write_metadata_json(...)` + from `qec.surface_code.stim_sample_io`; you can also hand-author it. +3. Make sure `conf/config_public.yaml` reflects the experiment your samples + came from: `distance`, `n_rounds`, `data.code_rotation`, and + `data.noise_model` must match exactly. The decoder rebuilds a Stim memory + circuit from these and validates the file against it before decoding. 
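Steps 1 and 2 can be sketched without any project imports. This is a minimal, hedged illustration and not the repository's writer: `format_sparse_shot` and `write_byo_artifacts` are hypothetical helper names, the metadata keys are copied from the schema shown earlier, and the inputs are assumed to be dense 0/1 rows with detector bits and observable bits kept separate. Whether the reader accepts metadata authored this way is an assumption; prefer `build_stim_sample_metadata(...)` and `write_metadata_json(...)` when they are available.

```python
import json
from pathlib import Path


def format_sparse_shot(detectors, observables):
    """Render one shot in Stim's sparse `dets` text format.

    Fired detectors become `D<index>` tokens and flipped observables
    become `L<index>` tokens, all on one line that starts with `shot`.
    """
    tokens = ["shot"]
    tokens += [f"D{i}" for i, bit in enumerate(detectors) if bit]
    tokens += [f"L{i}" for i, bit in enumerate(observables) if bit]
    return " ".join(tokens)


def write_byo_artifacts(out_dir, basis, shots, *, distance, n_rounds, code_rotation):
    """Hypothetical helper: write samples_{basis}.dets and metadata_{basis}.json.

    `shots` is a non-empty list of (detector_bits, observable_bits) pairs.
    The optional noise fields are omitted here on purpose.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    lines = [format_sparse_shot(dets, obs) for dets, obs in shots]
    (out_dir / f"samples_{basis}.dets").write_text("\n".join(lines) + "\n")

    metadata = {
        "schema_version": 2,
        "artifact": "stim_detector_samples",
        "format": "dets",
        "append_observables": True,
        "distance": distance,
        "n_rounds": n_rounds,
        "basis": basis,
        "code_rotation": code_rotation,
        # Counts are taken from the data itself rather than hard-coded.
        "num_detectors": len(shots[0][0]),
        "num_observables": len(shots[0][1]),
        "num_shots": len(shots),
    }
    (out_dir / f"metadata_{basis}.json").write_text(json.dumps(metadata, indent=2))
    return metadata
```

Deriving `num_detectors` and `num_shots` from the rows keeps the metadata from drifting away from the sample file, which matters because structural mismatches are fatal at load time.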
+ +Then point the launcher at the directory: + +```bash +PREDECODER_STIM_SAMPLES_DIR=/path/to/your/dets \ +PREDECODER_DECODE_MODE=pymatching_only \ +WORKFLOW=inference bash code/scripts/local_run.sh +``` + +Validation is strict by default: mismatches in distance, rounds, basis, +orientation, detector count, observable presence, `p_error`, or +`noise_model_sha256` raise with one explicit error per mismatch before any +decoding happens. To downgrade only the **noise** mismatches to warnings (for +example when sweeping `p_error` for a calibration study), set +`PREDECODER_STIM_STRICT_NOISE=0`. Structural mismatches are always fatal. + +#### Generate local reference files + +```bash +WORKFLOW=generate_stim_data \ +EXPERIMENT_NAME=offline_stim_run \ +bash code/scripts/local_run.sh +``` + +The generator reads from `conf/config_public.yaml`: + +| config field | role | +| --- | --- | +| `distance` | surface-code distance | +| `n_rounds` | number of measurement rounds | +| `data.code_rotation` | code orientation (`XV`/`XH`/`ZV`/`ZH` or `O1`..`O4`) | +| `data.noise_model` | 25-parameter noise model dict (optional) | +| `test.meas_basis_test` | `X`, `Z`, or `both` (default `both`) | +| `test.p_error` | scalar noise level (default `0.003`) | +| `test.num_samples` | shots per basis (default `262144`, ~20 MB per file) | + +The default sample count is large because the smoke run targets LER stable to +~3 significant digits; override `++test.num_samples=N` (or set the field in a +local override config) to shrink it for a faster iteration. Output goes to: + +```text +outputs/offline_stim_run/stim_samples/samples_X.dets +outputs/offline_stim_run/stim_samples/metadata_X.json +outputs/offline_stim_run/stim_samples/samples_Z.dets +outputs/offline_stim_run/stim_samples/metadata_Z.json +``` + +The `generate_stim_data` workflow writes only the Stim sample artifacts. The +CUDA-Q `.bin` artifacts (`detectors.bin`, `H_csr.bin`, etc.) 
live in a +separate output dir and are produced by `python code/export/generate_test_data.py` +directly; see [the CUDA-Q section](#generating-data-for-cuda-q-qec-realtime-predecoder-test-application). + +#### Decode the files + +PyMatching only — useful as the apples-to-apples baseline to compare against +the Ising pre-decoder. In this mode the launcher replaces the neural model +with `torch.nn.Identity()` and **no checkpoint is required**: + +```bash +PREDECODER_STIM_SAMPLES_DIR=outputs/offline_stim_run/stim_samples \ +PREDECODER_DECODE_MODE=pymatching_only \ +WORKFLOW=inference bash code/scripts/local_run.sh +``` + +Ising pre-decoder followed by PyMatching — **requires a model checkpoint.** +Point `PREDECODER_MODEL_CHECKPOINT_FILE` (or `model_checkpoint_file` in the +config) at one of the released models, or run training under the same +`EXPERIMENT_NAME` first: + +```bash +PREDECODER_STIM_SAMPLES_DIR=outputs/offline_stim_run/stim_samples \ +PREDECODER_DECODE_MODE=ising_decoding_pymatching \ +EXTRA_PARAMS="++model_checkpoint_file=models/Ising-Decoder-SurfaceCode-1-Fast.pt" \ +WORKFLOW=inference bash code/scripts/local_run.sh +``` + +No changes to `conf/config_public.yaml` are required for either command; the +existing config controls the model, distance, rounds, orientation, and noise +model, and the Stim file metadata is checked against the rebuilt circuit +before decoding. + +To persist the per-shot comparison arrays, also set: + +```bash +PREDECODER_DECODE_OUTPUT_DIR=offline_decode_outputs +``` + +With that set, `pymatching_only` writes: + +* `{basis}_observables.npy` +* `{basis}_pymatching_predictions.npy` + +…and `ising_decoding_pymatching` writes those plus: + +* `{basis}_predecoder_residual_detectors.npy` +* `{basis}_ising_decoding_pymatching_predictions.npy` + +The directory is created lazily on the first write, so it is safe to point at +a path that does not yet exist. 
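Once the arrays are on disk they can be compared offline. The following is a minimal sketch, assuming each saved array has one row per shot and that predictions and observables hold the same number of bits per shot (neither is stated here, so check the shapes of your own files). `logical_error_rate` and `load_basis_ler` are illustrative names, not part of the repository.

```python
from pathlib import Path

import numpy as np


def logical_error_rate(observables, predictions):
    """Fraction of shots where any predicted observable bit is wrong."""
    obs = np.asarray(observables).astype(bool)
    pred = np.asarray(predictions).astype(bool)
    # Flatten to (num_shots, num_observables) so 1-D and 2-D inputs both work.
    obs = obs.reshape(obs.shape[0], -1)
    pred = pred.reshape(pred.shape[0], -1)
    if obs.shape != pred.shape:
        raise ValueError(f"shape mismatch: {obs.shape} vs {pred.shape}")
    return float(np.any(obs != pred, axis=1).mean())


def load_basis_ler(output_dir, basis, name="pymatching_predictions"):
    """Load `{basis}_observables.npy` and `{basis}_{name}.npy`, return the LER."""
    out = Path(output_dir)
    obs = np.load(out / f"{basis}_observables.npy")
    pred = np.load(out / f"{basis}_{name}.npy")
    return logical_error_rate(obs, pred)
```

For example, `load_basis_ler("offline_decode_outputs", "X")` would give the PyMatching-only rate for the X basis, and passing `name="ising_decoding_pymatching_predictions"` would give the rate after the pre-decoder.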
+ +#### Smoke script + +```bash +code/scripts/offline_smoketest.sh +``` + +The script defaults `EXPERIMENT_NAME=offline_stim_run` (matching the example +paths above), generates Stim files, decodes with `pymatching_only`, and (if +`models/Ising-Decoder-SurfaceCode-1-Fast.pt` is on disk) decodes again with +`ising_decoding_pymatching`. It then parses a structured +`[Inference Summary]` JSON marker that the inference loop emits on the last +line of its summary block. The marker is **off by default** to keep +interactive and notebook runs clean; the smoketest opts in by exporting +`PREDECODER_EMIT_INFERENCE_SUMMARY=1` before each inference call. Set the same +env var yourself if you want to pipe these results into other tooling. + +Example output from one `d=7`, `n_rounds=7`, `O1`, `262,144` shots per basis +run is shown below. Treat timing/speedup as a smoke signal, not a benchmark: + +```text +[offline_smoketest.sh] Avg LER 0.002678 (no pre-decoder) -> 0.002285 (after); PyMatching speedup 1.815x +``` + ### Decoder ablation study with cudaq-qec (optional) The `decoder_ablation` workflow compares multiple global decoders on the residual syndromes left diff --git a/code/data/datapipe_stim.py b/code/data/datapipe_stim.py index 25cb73d..ea3cd3d 100644 --- a/code/data/datapipe_stim.py +++ b/code/data/datapipe_stim.py @@ -25,12 +25,14 @@ from torch.utils.data import Dataset from qec.surface_code.memory_circuit import MemoryCircuit +from qec.surface_code.stim_sample_io import read_stim_detector_samples, resolve_stim_sample_paths from qec.surface_code.data_mapping import ( normalized_weight_mapping_Xstab_memory, normalized_weight_mapping_Zstab_memory, compute_stabX_to_data_index_map, compute_stabZ_to_data_index_map, ) +from data.predecoder_transform import dets_to_predecoder_inputs class QCDataPipePreDecoder_Memory_inference(Dataset): @@ -94,7 +96,7 @@ def __init__( self._presence_x_Z[:, 0] = 0 self._presence_x_Z[:, -1] = 0 - # If using explicit noise model, use a 
conservative scalar placeholder for MemoryCircuit's legacy slots. + # If using explicit noise model, use a conservative scalar placeholder for MemoryCircuit's scalar-rate slots. if noise_model is not None: p_placeholder = float(noise_model.get_max_probability()) else: @@ -380,4 +382,141 @@ def __getitem__(self, idx): } -__all__ = ['QCDataPipePreDecoder_Memory_inference'] +class QCDataPipePreDecoder_Memory_from_stim_file(Dataset): + """ + Datapipe for offline inference from Stim detector-sample files. + + The file stores detector events plus appended observables. Metadata is + validated against a freshly rebuilt MemoryCircuit before data is exposed. + + Noise-model validation: when ``noise_model`` is provided (the typical + inference path), the datapipe computes a deterministic fingerprint of its + 25-parameter dict and asks :func:`read_stim_detector_samples` to compare it + against the value recorded in the JSON metadata. Mismatches raise unless + ``strict_noise`` is ``False`` (in which case a warning is emitted). When + ``noise_model`` is ``None``, only the scalar ``p_error`` is checked. + + Args: + distance, n_rounds, num_samples, error_mode, measure_basis, + code_rotation: Standard circuit parameters; ``num_samples`` may + truncate the loaded file to the first N shots when positive. + stim_samples_dir: Directory containing ``samples_{basis}.dets`` and + ``metadata_{basis}.json``. + p_error: Scalar physical error rate used by the active config. Compared + against ``metadata['p_error']`` when present. + noise_model: Optional explicit :class:`NoiseModel`. When set, its + ``sha256()`` is compared against ``metadata['noise_model_sha256']``. + strict_noise: ``True`` (default) raises on noise-fingerprint drift; + ``False`` downgrades the failure to a :class:`UserWarning`. 
+ """ + + def __init__( + self, + distance, + n_rounds, + num_samples, + error_mode, + stim_samples_dir, + p_error=0.005, + measure_basis='X', + code_rotation='XV', + noise_model=None, + strict_noise: bool = True, + ): + self.distance = int(distance) + self.n_rounds = max(int(n_rounds), 1) + self.measure_basis = str(measure_basis).upper() + self.code_rotation = code_rotation.upper() if code_rotation else 'XV' + self.requested_num_samples = int(num_samples) if num_samples is not None else 0 + + if self.measure_basis not in ("X", "Z"): + raise ValueError( + "Stim file datapipe expects one basis at a time. " + f"Got measure_basis={measure_basis!r}." + ) + if error_mode != "circuit_level_surface_custom": + raise ValueError("error_mode not supported") + + D = self.distance + if noise_model is not None: + p_placeholder = float(noise_model.get_max_probability()) + noise_sha = noise_model.sha256() + noise_label = "25-param" + else: + p_placeholder = float(p_error) + noise_sha = None + noise_label = "simple" + + self.circ = MemoryCircuit( + distance=D, + idle_error=p_placeholder, + sqgate_error=p_placeholder, + tqgate_error=p_placeholder, + spam_error=(2.0 / 3.0) * p_placeholder, + n_rounds=self.n_rounds, + basis=self.measure_basis, + code_rotation=self.code_rotation, + noise_model=noise_model, + add_boundary_detectors=True, + ) + self.circ.set_error_rates() + + samples_path, metadata_path = resolve_stim_sample_paths( + stim_samples_dir, self.measure_basis + ) + dets_and_obs, metadata = read_stim_detector_samples( + samples_path=samples_path, + metadata_path=metadata_path, + distance=self.distance, + n_rounds=self.n_rounds, + basis=self.measure_basis, + code_rotation=self.code_rotation, + num_detectors=self.circ.stim_circuit.num_detectors, + num_observables=self.circ.stim_circuit.num_observables, + p_error=float(p_error), + noise_model_sha256=noise_sha, + noise_model_label=noise_label, + strict_noise=bool(strict_noise), + ) + if self.requested_num_samples > 0: + 
dets_and_obs = dets_and_obs[:self.requested_num_samples] + + self.samples_path = samples_path + self.metadata_path = metadata_path + self.metadata = metadata + self.dets_and_obs = torch.from_numpy(dets_and_obs).to(torch.uint8).contiguous() + self.num_samples = int(self.dets_and_obs.shape[0]) + self._half = (D * D - 1) // 2 + + self._precompute_transformations_from_dets() + + def _precompute_transformations_from_dets(self): + num_obs = self.circ.stim_circuit.num_observables + dets = self.dets_and_obs[:, :-num_obs].contiguous() + train_x, x_syn_diff, z_syn_diff = dets_to_predecoder_inputs( + dets, + distance=self.distance, + n_rounds=self.n_rounds, + basis=self.measure_basis, + code_rotation=self.code_rotation, + ) + self.x_syn_diff_all = x_syn_diff + self.z_syn_diff_all = z_syn_diff + self.trainX_all = train_x + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "x_syn_diff": self.x_syn_diff_all[idx], + "z_syn_diff": self.z_syn_diff_all[idx], + "trainX": self.trainX_all[idx], + "dets_and_obs": self.dets_and_obs[idx], + } + + +__all__ = [ + 'QCDataPipePreDecoder_Memory_inference', + 'QCDataPipePreDecoder_Memory_from_stim_file', +] diff --git a/code/data/factory.py b/code/data/factory.py index 8b876ec..ec3749f 100644 --- a/code/data/factory.py +++ b/code/data/factory.py @@ -17,6 +17,7 @@ Provides DatapipeFactory for instantiating data generators/datapipes from config. """ +import os import torch @@ -67,14 +68,17 @@ def _create_surface_datapipe_inference(cfg): Datapipe for inference using Stim. 
""" if cfg.datapipe == "memory": - from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + from data.datapipe_stim import ( + QCDataPipePreDecoder_Memory_from_stim_file, + QCDataPipePreDecoder_Memory_inference, + ) from qec.noise_model import NoiseModel error_mode_value = getattr(cfg.data, 'error_mode', 'circuit_level_surface_custom') code_rotation = getattr(cfg.data, 'code_rotation', 'XV') # Test-time noise model selection: # - cfg.test.noise_model='train': use cfg.data.noise_model (if present) - # - cfg.test.noise_model='none': ignore cfg.data.noise_model, use cfg.test.p_error (legacy single-p) + # - cfg.test.noise_model='none': ignore cfg.data.noise_model, use cfg.test.p_error (single-p) # Takes priority over cfg.test.p_error. test_nm_mode = getattr(getattr(cfg, "test", None), "noise_model", None) if test_nm_mode is None: @@ -100,24 +104,48 @@ def _create_surface_datapipe_inference(cfg): ) # Fail fast: if the user provided an explicit 25p noise model and asked to use it, - # do not silently fall back to legacy p_error-based generation. + # do not silently fall back to p_error-based generation. if test_nm_mode == "train" and getattr( cfg.data, "noise_model", None ) is not None and noise_model_obj is None: raise ValueError( "cfg.test.noise_model='train' but failed to construct NoiseModel from cfg.data.noise_model. " - "Refusing to fall back to legacy cfg.test.p_error." + "Refusing to fall back to cfg.test.p_error." ) - test_dataset = QCDataPipePreDecoder_Memory_inference( + stim_samples_dir = os.environ.get("PREDECODER_STIM_SAMPLES_DIR", "").strip() + if not stim_samples_dir: + stim_samples_dir = str(getattr(cfg.test, "stim_samples_dir", "") or "").strip() + measure_basis = cfg.test.meas_basis_test + if stim_samples_dir and str(measure_basis).upper() in ("BOTH", "MIXED"): + # The file pipe holds one basis per instance; callers that need + # both bases re-instantiate per basis (the LER loop does this). 
+ # Pick X for shape probing here; choice does not affect tensor + # shapes because they only depend on (distance, n_rounds). + measure_basis = "X" + + dataset_kwargs: dict = {} + if stim_samples_dir: + dataset_cls = QCDataPipePreDecoder_Memory_from_stim_file + dataset_kwargs["stim_samples_dir"] = stim_samples_dir + # PREDECODER_STIM_STRICT_NOISE=0 downgrades the noise-fingerprint + # mismatch from an error to a UserWarning. Default is strict. + strict_env = os.environ.get("PREDECODER_STIM_STRICT_NOISE", "").strip().lower() + if strict_env in ("0", "false", "no", "off"): + dataset_kwargs["strict_noise"] = False + else: + dataset_cls = QCDataPipePreDecoder_Memory_inference + + test_dataset = dataset_cls( distance=cfg.distance, n_rounds=cfg.n_rounds, num_samples=cfg.test.num_samples, error_mode=error_mode_value, p_error=cfg.test.p_error, - measure_basis=cfg.test.meas_basis_test, + measure_basis=measure_basis, code_rotation=code_rotation, noise_model=noise_model_obj, + **dataset_kwargs, ) return test_dataset else: diff --git a/code/data/predecoder_transform.py b/code/data/predecoder_transform.py new file mode 100644 index 0000000..8dac63e --- /dev/null +++ b/code/data/predecoder_transform.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Canonical transformation from Stim detector samples to pre-decoder inputs. 
+ +Single source of truth for converting Stim detector bits into the +``(trainX, x_syn_diff, z_syn_diff)`` tuple that the pre-decoder consumes. +:func:`dets_to_predecoder_inputs` (file-based datapipe) and +:meth:`evaluation.logical_error_rate.PreDecoderMemoryEvalModule._batch_to_trainx_and_syndromes` +(GPU/ONNX/TensorRT export path) both delegate to +:func:`_predecoder_transform_core`; the latter pre-registers the same buffers +the helper rebuilds ad-hoc per call. + +Input contract +-------------- +- ``dets``: ``(B, 2 * T * half)`` uint8/int tensor where ``half = (D*D - 1)//2`` + and detectors are stored in Stim's emission order, two ``half``-sized groups + per round (X-stab group followed by Z-stab group). +- ``basis``: ``'X'`` or ``'Z'`` — the memory basis being decoded. +- ``code_rotation``: one of ``'XV'``, ``'XH'``, ``'ZV'``, ``'ZH'``. + +Output +------ +- ``trainX``: ``(B, 4, T, D, D)`` float32 — channels are + ``[x_syn_grid, z_syn_grid, x_present, z_present]``. +- ``x_syn_diff``, ``z_syn_diff``: ``(B, half, T)`` int32 — boundary-masked + syndrome differences, suitable for residual-syndrome arithmetic. 
+""" + +from __future__ import annotations + +from typing import Tuple + +import torch + +from qec.surface_code.data_mapping import ( + compute_stabX_to_data_index_map, + compute_stabZ_to_data_index_map, + normalized_weight_mapping_Xstab_memory, + normalized_weight_mapping_Zstab_memory, +) + + +def _build_scatter_perm(idx_map: torch.Tensor, D2: int, half: int) -> torch.Tensor: + """Invert ``idx_map`` into a length-``D2`` permutation whose missing entries + point at ``half`` (a sentinel column that callers keep all-zero).""" + perm = torch.full((D2,), half, dtype=torch.long, device=idx_map.device) + perm[idx_map] = torch.arange(idx_map.shape[0], dtype=torch.long, device=idx_map.device) + return perm + + +def _predecoder_transform_core( + dets: torch.Tensor, + *, + D: int, + T: int, + half: int, + basis: str, + scatter_perm_x: torch.Tensor, + scatter_perm_z: torch.Tensor, + w_mapXgrid: torch.Tensor, + w_mapZgrid: torch.Tensor, + zero_pad_row: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Buffer-fused dets → ``(trainX, x_syn_diff, z_syn_diff)`` transform. + + Returns ``x_syn_diff`` and ``z_syn_diff`` as float32; callers that need + int32 cast at the boundary. The signature mirrors the buffers that + :class:`PreDecoderMemoryEvalModule` pre-registers so the eval/export path + can call this directly without per-call allocations. + + Args: + dets: ``(B, 2 * T * half)`` detector bits. + D, T, half, basis: Code / batch geometry. ``basis`` is ``'X'`` or ``'Z'``. + scatter_perm_x, scatter_perm_z: length-``D*D`` long tensors mapping + each grid position to a syndrome row (or to the sentinel column + at index ``half``); see :func:`_build_scatter_perm`. + w_mapXgrid, w_mapZgrid: ``(1, 1, D, D)`` float32 presence grids. + zero_pad_row: ``(1, 1, 1)`` float32 — broadcast to fabricate + sentinel rows/columns for scatter-via-gather. 
+ """ + B = dets.shape[0] + timeline_len = 2 * T + dev = dets.device + + # ── trt_L1: preprocessor (cast, deinterleave, index_select, boundary handling) ── + # (B, 2*T*half) -> (B, half, 2*T) float32, then pad a sentinel column for + # boundary rounds that have no corresponding detector. + dets_timeline = dets.to(torch.float32).view(B, timeline_len, half).permute(0, 2, 1).contiguous() + zero_col = zero_pad_row.expand(B, half, 1) + padded = torch.cat([dets_timeline, zero_col], dim=2) # (B, half, 2*T+1) + sentinel_idx = timeline_len + + x_bulk_idx = torch.arange(1, timeline_len - 1, 2, dtype=torch.long, device=dev) + z_bulk_idx = torch.arange(2, timeline_len, 2, dtype=torch.long, device=dev) + zero_idx = torch.zeros(1, dtype=torch.long, device=dev) + sentinel = torch.full((1,), sentinel_idx, dtype=torch.long, device=dev) + + if T == 1: + if basis == "X": + idx_x, idx_z = zero_idx, sentinel + else: + idx_z, idx_x = zero_idx, sentinel + else: + if basis == "X": + idx_x = torch.cat([zero_idx, x_bulk_idx]) + idx_z = torch.cat([sentinel, z_bulk_idx[:-1], sentinel]) + else: + idx_z = torch.cat([zero_idx, z_bulk_idx]) + idx_x = torch.cat([sentinel, x_bulk_idx[:-1], sentinel]) + + x_syn_diff = torch.index_select(padded, 2, idx_x) # (B, half, T) float32 + z_syn_diff = torch.index_select(padded, 2, idx_z) # (B, half, T) float32 + + # Presence: broadcast-multiply by round mask to zero boundary rounds (no clone/in-place) + if T == 1: + boundary_mask = torch.zeros(1, device=dev, dtype=torch.float32) + else: + boundary_mask = torch.cat( + [ + torch.zeros(1, device=dev, dtype=torch.float32), + torch.ones(T - 2, device=dev, dtype=torch.float32), + torch.zeros(1, device=dev, dtype=torch.float32), + ] + ) + boundary_mask = boundary_mask.view(1, T, 1, 1) + if basis == "X": + x_present = w_mapXgrid.expand(B, T, D, D) + z_present = (w_mapZgrid * boundary_mask).expand(B, T, D, D) + else: + x_present = (w_mapXgrid * boundary_mask).expand(B, T, D, D) + z_present = w_mapZgrid.expand(B, T, D, 
D) + + # ── trt_L2: trainX assembly (scatter-via-gather → grid reshape → cat) ── + zero_pad = zero_pad_row.expand(B, 1, T) + x_grid = torch.index_select( + torch.cat([x_syn_diff, zero_pad], dim=1), 1, scatter_perm_x + ) # (B, D², T) + z_grid = torch.index_select(torch.cat([z_syn_diff, zero_pad], dim=1), 1, scatter_perm_z) + x_type = x_grid.reshape(B, D, D, T).permute(0, 3, 1, 2).contiguous() # (B, T, D, D) + z_type = z_grid.reshape(B, D, D, T).permute(0, 3, 1, 2).contiguous() + trainX = torch.cat( + [ + x_type.unsqueeze(1), + z_type.unsqueeze(1), + x_present.unsqueeze(1), + z_present.unsqueeze(1), + ], + dim=1, + ).contiguous() + + return trainX, x_syn_diff, z_syn_diff + + +def dets_to_predecoder_inputs( + dets: torch.Tensor, + *, + distance: int, + n_rounds: int, + basis: str, + code_rotation: str = "XV", +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert detector bits to the pre-decoder's ``(trainX, x_syn_diff, z_syn_diff)``. + + Args: + dets: ``(B, 2 * T * half)`` tensor of detector bits. + distance: Surface-code distance ``D`` (must be odd). + n_rounds: Number of measurement rounds ``T`` (must be >= 1). + basis: Memory basis being decoded; ``'X'`` or ``'Z'``. + code_rotation: Code orientation; ``'XV'``, ``'XH'``, ``'ZV'``, ``'ZH'``. + + Returns: + ``(trainX, x_syn_diff, z_syn_diff)``. See module docstring for shapes. + + Raises: + ValueError: If ``dets`` does not have ``2 * T * half`` columns, or if + ``basis`` / ``code_rotation`` are not in the supported set. 
+ """ + if dets.ndim != 2: + raise ValueError(f"dets must be 2-D (B, num_detectors); got shape {tuple(dets.shape)}") + + D = int(distance) + T = int(n_rounds) + if T < 1: + raise ValueError(f"n_rounds must be >= 1, got {T}") + if D < 3 or (D % 2) == 0: + raise ValueError(f"distance must be an odd integer >= 3, got {D}") + half = (D * D - 1) // 2 + + basis_upper = str(basis).strip().upper() + if basis_upper not in ("X", "Z"): + raise ValueError(f"basis must be 'X' or 'Z', got {basis!r}") + + rotation = str(code_rotation).strip().upper() if code_rotation else "XV" + if rotation not in ("XV", "XH", "ZV", "ZH"): + raise ValueError(f"code_rotation must be one of XV/XH/ZV/ZH, got {code_rotation!r}") + + num_dets = dets.shape[1] + expected = 2 * T * half + if int(num_dets) != expected: + raise ValueError( + f"detector count mismatch: dets has {int(num_dets)} columns, " + f"expected 2 * n_rounds * (D^2 - 1)/2 = {expected} " + f"(distance={D}, n_rounds={T})." + ) + + dev = dets.device + idx_map_x = torch.as_tensor( + compute_stabX_to_data_index_map(D, rotation), dtype=torch.long, device=dev + ) + idx_map_z = torch.as_tensor( + compute_stabZ_to_data_index_map(D, rotation), dtype=torch.long, device=dev + ) + scatter_perm_x = _build_scatter_perm(idx_map_x, D * D, half) + scatter_perm_z = _build_scatter_perm(idx_map_z, D * D, half) + w_mapX = normalized_weight_mapping_Xstab_memory(D, rotation).reshape(D, D) + w_mapZ = normalized_weight_mapping_Zstab_memory(D, rotation).reshape(D, D) + w_mapXgrid = w_mapX.to(dtype=torch.float32, device=dev).unsqueeze(0).unsqueeze(0) + w_mapZgrid = w_mapZ.to(dtype=torch.float32, device=dev).unsqueeze(0).unsqueeze(0) + zero_pad_row = torch.zeros(1, 1, 1, dtype=torch.float32, device=dev) + + trainX, x_syn_diff, z_syn_diff = _predecoder_transform_core( + dets, + D=D, + T=T, + half=half, + basis=basis_upper, + scatter_perm_x=scatter_perm_x, + scatter_perm_z=scatter_perm_z, + w_mapXgrid=w_mapXgrid, + w_mapZgrid=w_mapZgrid, + zero_pad_row=zero_pad_row, 
+ ) + return ( + trainX, + x_syn_diff.to(torch.int32).contiguous(), + z_syn_diff.to(torch.int32).contiguous(), + ) + + +__all__ = ["dets_to_predecoder_inputs"] diff --git a/code/evaluation/inference.py b/code/evaluation/inference.py index 89d3ad6..99ff6df 100644 --- a/code/evaluation/inference.py +++ b/code/evaluation/inference.py @@ -38,6 +38,17 @@ def _detect_shm_bytes() -> Optional[int]: return None +def _emit_inference_summary_marker() -> bool: + """Whether to print the machine-readable ``[Inference Summary]`` JSON line. + + Off by default so interactive runs see only the pretty-printed table. + Set ``PREDECODER_EMIT_INFERENCE_SUMMARY=1`` to opt in; the offline + smoketest sets this automatically when it invokes inference. + """ + raw = os.environ.get("PREDECODER_EMIT_INFERENCE_SUMMARY", "").strip().lower() + return raw in ("1", "true", "yes", "on") + + def _safe_float(x: Any) -> float: try: return float(x) @@ -199,3 +210,37 @@ def _avg(a: float, b: float) -> float: f" {'LER - Avg:':<{label_w}}{_avg(x_base, z_base):>15.6f} {_avg(x_after, z_after):>17.6f}" ) print(f" {'PyMatching speedup (Avg X/Z):':<{label_w}}{avg_speedup:>15.3f}x") + + # Opt-in machine-readable summary. Off by default so interactive and + # notebook runs see only the pretty-printed table above. Tools that need a + # stable contract (e.g. ``code/scripts/offline_smoketest.sh``) set + # ``PREDECODER_EMIT_INFERENCE_SUMMARY=1`` and parse the JSON marker. The + # smoketest reads the LAST marker in the log to handle the case where a + # single run invokes inference more than once. 
+ if _emit_inference_summary_marker(): + import json as _json + + summary = { + "schema_version": 1, + "marker": "inference_summary", + "ler": + { + "x_basis_no_predecoder": float(x_base), + "x_basis_after_predecoder": float(x_after), + "z_basis_no_predecoder": float(z_base), + "z_basis_after_predecoder": float(z_after), + "avg_no_predecoder": float(_avg(x_base, z_base)), + "avg_after_predecoder": float(_avg(x_after, z_after)), + }, + "pymatching_latency_us_per_round": + { + "x_basis_no_predecoder": float(x_lat_base), + "x_basis_after_predecoder": float(x_lat_after), + "z_basis_no_predecoder": float(z_lat_base), + "z_basis_after_predecoder": float(z_lat_after), + "avg_no_predecoder": float(avg_lat_base), + "avg_after_predecoder": float(avg_lat_after), + }, + "pymatching_speedup_avg_xz": float(avg_speedup), + } + print("[Inference Summary] " + _json.dumps(summary, sort_keys=True)) diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 3b705e0..897f2ed 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -81,6 +81,7 @@ class OnnxWorkflow(IntEnum): from data.factory import DatapipeFactory +from data.predecoder_transform import _predecoder_transform_core from torch.utils.data.distributed import DistributedSampler from torch.utils.data import DataLoader from training.utils import * @@ -115,6 +116,40 @@ def _get_env_bool(name: str, default: bool = False) -> bool: return val not in ("0", "false", "no", "off", "") +def _get_decode_mode(cfg) -> str: + mode = os.environ.get("PREDECODER_DECODE_MODE", "").strip() + if not mode: + mode = str(getattr(getattr(cfg, "test", None), "decode_mode", "ising_decoding_pymatching")) + mode = mode.strip().lower() + valid = {"ising_decoding_pymatching", "pymatching_only"} + if mode not in valid: + raise ValueError(f"Invalid decode mode {mode!r}. 
Supported modes: {sorted(valid)}") + return mode + + +def _get_decode_output_dir(cfg) -> Optional[Path]: + """Return the configured decode-output directory without creating it. + + Resolution order: ``PREDECODER_DECODE_OUTPUT_DIR`` environment variable, then + ``cfg.test.decode_output_dir``. Returns ``None`` when neither is set. The + directory is materialised lazily by :func:`_save_decode_array` so that a + pure read does not leave empty directories on disk. + """ + raw = os.environ.get("PREDECODER_DECODE_OUTPUT_DIR", "").strip() + if not raw: + raw = str(getattr(getattr(cfg, "test", None), "decode_output_dir", "") or "").strip() + if not raw: + return None + return Path(raw) + + +def _save_decode_array(output_dir: Optional[Path], basis: str, name: str, arr) -> None: + if output_dir is None: + return + output_dir.mkdir(parents=True, exist_ok=True) + np.save(output_dir / f"{basis}_{name}.npy", np.asarray(arr)) + + def _parse_quant_format(rank: int = 0) -> str: """Read and validate the QUANT_FORMAT environment variable. @@ -625,75 +660,30 @@ def _batch_to_trainx_and_syndromes(self, dets: torch.Tensor): if not getattr(torch.onnx, "is_in_onnx_export", lambda: False)(): assert T >= 2, f"T={T} is too small for DEM (need T>=2)." - timeline_len = 2 * T + trainX, x_syn_diff, z_syn_diff = _predecoder_transform_core( + dets, + D=self.D, + T=T, + half=half, + basis=self.basis, + scatter_perm_x=self.scatter_perm_x, + scatter_perm_z=self.scatter_perm_z, + w_mapXgrid=self.w_mapXgrid, + w_mapZgrid=self.w_mapZgrid, + zero_pad_row=self.zero_pad_row, + ) - # ── trt_L1: preprocessor (cast, deinterleave, index_select, boundary handling) ── - # (B, 2*T*half) -> (B, half, 2*T) float32. + # baseline_detectors_batch: dets reshape pass-through, preserved + # explicitly to keep the original ONNX graph for the post-processing + # GEMM input. 
+ timeline_len = 2 * T dets_timeline = dets.to(torch.float32).view(B, timeline_len, half).permute(0, 2, 1).contiguous() - zero_col = self.zero_pad_row.expand(B, half, 1) # (B, half, 1) - dets_timeline_padded = torch.cat([dets_timeline, zero_col], dim=2) # (B, half, 2*T+1) - - # Build deinterleave indices dynamically for this T. - sentinel_idx = timeline_len # points to appended all-zero column - dev = dets.device - x_bulk_idx = torch.arange(1, timeline_len - 1, 2, dtype=torch.long, device=dev) # T-1 - z_bulk_idx = torch.arange(2, timeline_len, 2, dtype=torch.long, device=dev) # T-1 - - _zero = torch.zeros(1, dtype=torch.long, device=dev) - _sentinel = torch.full((1,), sentinel_idx, dtype=torch.long, device=dev) - - if self.basis == "X": - idx_x = torch.cat([_zero, x_bulk_idx]) - idx_z = torch.cat([_sentinel, z_bulk_idx[:-1], _sentinel]) - x_syn_diff = torch.index_select(dets_timeline_padded, 2, idx_x) # (B, half, T) - z_syn_diff = torch.index_select(dets_timeline_padded, 2, idx_z) # (B, half, T) - else: - idx_z = torch.cat([_zero, z_bulk_idx]) - idx_x = torch.cat([_sentinel, x_bulk_idx[:-1], _sentinel]) - z_syn_diff = torch.index_select(dets_timeline_padded, 2, idx_z) # (B, half, T) - x_syn_diff = torch.index_select(dets_timeline_padded, 2, idx_x) # (B, half, T) - - # Presence: broadcast-multiply by round mask to zero boundary rounds (no clone/in-place) - boundary_mask = torch.cat( - [ - torch.zeros(1, device=dev, dtype=torch.float32), - torch.ones(T - 2, device=dev, dtype=torch.float32), - torch.zeros(1, device=dev, dtype=torch.float32), - ] - ).view(1, T, 1, 1) - if self.basis == "X": - x_present = self.w_mapXgrid.expand(B, T, self.D, self.D) - z_present = (self.w_mapZgrid * boundary_mask).expand(B, T, self.D, self.D) - else: - x_present = (self.w_mapXgrid * boundary_mask).expand(B, T, self.D, self.D) - z_present = self.w_mapZgrid.expand(B, T, self.D, self.D) - - # ── trt_L2: trainX assembly (scatter-via-gather → grid reshape → cat) ── - zero_pad = 
self.zero_pad_row.expand(B, 1, T) - x_grid = torch.index_select( - torch.cat([x_syn_diff, zero_pad], dim=1), 1, self.scatter_perm_x - ) # (B, D², T) - z_grid = torch.index_select( - torch.cat([z_syn_diff, zero_pad], dim=1), 1, self.scatter_perm_z - ) # (B, D², T) - x_type = x_grid.reshape(B, self.D, self.D, T).permute(0, 3, 1, - 2).contiguous() # (B, T, D, D) - z_type = z_grid.reshape(B, self.D, self.D, T).permute(0, 3, 1, 2).contiguous() - trainX = torch.cat( - [ - x_type.unsqueeze(1), - z_type.unsqueeze(1), - x_present.unsqueeze(1), - z_present.unsqueeze(1), - ], - dim=1 - ).contiguous() + baseline_detectors_batch = dets_timeline.permute(0, 2, 1).contiguous().view(B, -1) n_x = half n_z = z_syn_diff.shape[1] num_boundary_dets = n_x if self.basis == "X" else n_z - baseline_detectors_batch = dets_timeline.permute(0, 2, 1).contiguous().view(B, -1) return trainX, x_syn_diff, z_syn_diff, baseline_detectors_batch, num_boundary_dets, B, T, n_x, n_z @@ -903,6 +893,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic verbose = bool(getattr(cfg.test, "verbose_inference", False) ) or bool(getattr(cfg.test, "verbose", False)) + decode_mode = _get_decode_mode(cfg) + decode_output_dir = _get_decode_output_dir(cfg) + basis = str(getattr(cfg.test, "meas_basis_test", "X")).upper() + if basis not in ("X", "Z"): + raise AssertionError(f"Invalid meas_basis_test='{basis}'. 
Use 'X' or 'Z'.") # Log distributed and sampling configuration (only on rank 0, verbose only) if verbose and dist.rank == 0: if dist.world_size > 1: @@ -915,6 +910,7 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic ) else: print(f"[LER] Sampling mode: threshold (th_data={th_data}, th_syn={th_syn})") + print(f"[LER] Decode mode: {decode_mode}") enable_timing_instrumentation = bool(getattr(cfg.test, "enable_timing_instrumentation", False)) enable_timing_instrumentation = _get_env_bool( @@ -1036,6 +1032,26 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic baseline_pred = np.asarray(baseline_pred, dtype=np.uint8).reshape(-1, circuit.num_observables) num_pymatch_errors = int((baseline_pred != stim_obs).sum()) + if dist.rank == 0: + _save_decode_array(decode_output_dir, basis, "observables", stim_obs) + _save_decode_array(decode_output_dir, basis, "pymatching_predictions", baseline_pred) + + if decode_mode == "pymatching_only": + if dist.rank == 0: + LAST_DEM_TIMING.update( + { + "dem_build_s": float(dem_build_s), + "dem_decode_s": float(dem_decode_s), + "basis": basis, + "num_batches": 0, + } + ) + print( + f"[DEM Timing] build={dem_build_s:.2f}s decode={dem_decode_s:.2f}s " + f"(basis={basis}, decode_mode=pymatching_only)" + ) + return num_pymatch_errors, int(stim_obs.shape[0] + ), num_pymatch_errors, float("nan"), float("nan") # --- DataLoader: NO DistributedSampler - each GPU processes ALL of its own samples --- test_loader_kwargs = dict(cfg.test.dataloader) @@ -1086,10 +1102,6 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic code_rotation = getattr(cfg.data, 'code_rotation', 'XV') maps = _build_stab_maps(D, code_rotation) - basis = str(getattr(cfg.test, "meas_basis_test", "X")).upper() - if basis not in ("X", "Z"): - raise AssertionError(f"Invalid meas_basis_test='{basis}'. 
Use 'X' or 'Z'.") - batch_size_original = test_loader_kwargs.get("batch_size", 1) T_original = cfg.test.n_rounds @@ -1342,6 +1354,8 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic ) / stim_dets.size if stim_dets.size > 0 else 0.0 floor_time_per_round = None detector_shape = None + final_prediction_rows = [] if dist.rank == 0 and decode_output_dir is not None else None + residual_rows = [] if dist.rank == 0 and decode_output_dir is not None else None t_start = time.perf_counter() t_model_time = 0.0 @@ -1426,6 +1440,10 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic pre_L_cpu = pre_L.cpu() if pre_L.is_cuda else pre_L pred_obs_t = pred_obs_t.view(-1).contiguous() # always (B,) final_L = (pre_L_cpu + pred_obs_t).remainder_(2) # (B,) + if final_prediction_rows is not None: + final_prediction_rows.append(final_L.numpy().reshape(-1, 1).copy()) + if residual_rows is not None: + residual_rows.append(residual_np.copy()) # Ground truth (same for X or Z; DEM has 1 observable) gt_obs = dets_and_obs[:, -num_obs:] @@ -1531,6 +1549,22 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic total_samples = int(t_n.item()) num_pymatch_errors = int(t_pymatch.item()) + if dist.rank == 0: + if final_prediction_rows: + _save_decode_array( + decode_output_dir, + basis, + "ising_decoding_pymatching_predictions", + np.concatenate(final_prediction_rows, axis=0), + ) + if residual_rows: + _save_decode_array( + decode_output_dir, + basis, + "predecoder_residual_detectors", + np.concatenate(residual_rows, axis=0), + ) + # Latency: single-shot (batch_size=1, matcher.decode) on a small subset on rank 0 only, # timed after the main loop for a clean CPU state. 
baseline_us_per_round = float("nan") diff --git a/code/export/generate_test_data.py b/code/export/generate_test_data.py index 47c773c..0cd1b69 100644 --- a/code/export/generate_test_data.py +++ b/code/export/generate_test_data.py @@ -95,6 +95,12 @@ sys.modules.setdefault("qec.surface_code", _sc_pkg) from qec.surface_code.memory_circuit import MemoryCircuit +from qec.surface_code.stim_sample_io import ( + build_stim_sample_metadata, + normalize_code_rotation, + write_metadata_json, + write_stim_detector_samples, +) from qec.noise_model import NoiseModel # Default 25-parameter noise model matching config_public.yaml at p=0.003 @@ -126,8 +132,6 @@ "p_cnot_ZZ": 0.0002, } -_ROTATION_ALIASES = {"O1": "XV", "O2": "XH", "O3": "ZV", "O4": "ZH"} - # --------------------------------------------------------------------------- # Binary I/O helpers # --------------------------------------------------------------------------- @@ -185,8 +189,10 @@ def generate_test_data( num_samples: int = 1000, onnx_model: str | None = None, output_dir: str = "test_data", + write_stim_artifacts: bool = False, + write_cudaq_artifacts: bool = True, ): - code_rotation = _ROTATION_ALIASES.get(code_rotation.upper(), code_rotation.upper()) + code_rotation = normalize_code_rotation(code_rotation) out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) @@ -230,17 +236,19 @@ def generate_test_data( print(f" DEM + matcher built in {time.perf_counter() - t0:.3f}s") print(f" Detectors: {det_model.num_detectors}, Observables: {det_model.num_observables}") - # ---- Extract H, O, priors via beliefmatching ---- - print("Extracting check matrices (beliefmatching)...") - matrices = detector_error_model_to_check_matrices(det_model) - H = matrices.edge_check_matrix - O = matrices.edge_observables_matrix - edge_probs = matrices.hyperedge_to_edge_matrix @ matrices.priors - eps = 1e-14 - edge_probs[edge_probs > 1 - eps] = 1 - eps - edge_probs[edge_probs < eps] = eps - priors = edge_probs - print(f" H shape: 
{H.shape}, O shape: {O.shape}, priors shape: {priors.shape}") + H = O = priors = None + if write_cudaq_artifacts: + # ---- Extract H, O, priors via beliefmatching ---- + print("Extracting check matrices (beliefmatching)...") + matrices = detector_error_model_to_check_matrices(det_model) + H = matrices.edge_check_matrix + O = matrices.edge_observables_matrix + edge_probs = matrices.hyperedge_to_edge_matrix @ matrices.priors + eps = 1e-14 + edge_probs[edge_probs > 1 - eps] = 1 - eps + edge_probs[edge_probs < eps] = eps + priors = edge_probs + print(f" H shape: {H.shape}, O shape: {O.shape}, priors shape: {priors.shape}") # ---- Sample syndrome data ---- print(f"Sampling {num_samples} shots...") @@ -291,33 +299,70 @@ def generate_test_data( # ---- Save everything ---- print(f"Writing outputs to {out}/") - save_dense_bin(str(out / "detectors.bin"), stim_dets) - save_dense_bin(str(out / "observables.bin"), stim_obs) - save_dense_bin(str(out / "pymatching_predictions.bin"), predictions) - save_csr_bin(str(out / "H_csr.bin"), H) - save_csr_bin(str(out / "O_csr.bin"), O) - save_priors_bin(str(out / "priors.bin"), priors) + if write_cudaq_artifacts: + save_dense_bin(str(out / "detectors.bin"), stim_dets) + save_dense_bin(str(out / "observables.bin"), stim_obs) + save_dense_bin(str(out / "pymatching_predictions.bin"), predictions) + save_csr_bin(str(out / "H_csr.bin"), H) + save_csr_bin(str(out / "O_csr.bin"), O) + save_priors_bin(str(out / "priors.bin"), priors) - if predecoder_outputs is not None: + if write_cudaq_artifacts and predecoder_outputs is not None: save_dense_bin(str(out / "predecoder_outputs.bin"), predecoder_outputs) noise_label = "25-param" if noise_model_params else "simple" - save_metadata( - str(out / "metadata.txt"), - distance=distance, - n_rounds=n_rounds, - basis=basis.upper(), - code_rotation=code_rotation, - p_error=p_error, - num_samples=num_samples, - num_detectors=det_model.num_detectors, - num_observables=det_model.num_observables, - 
H_shape=H.shape, - noise_model=noise_label, - **({ - "onnx_model": onnx_model - } if onnx_model else {}), - ) + if write_cudaq_artifacts: + save_metadata( + str(out / "metadata.txt"), + distance=distance, + n_rounds=n_rounds, + basis=basis.upper(), + code_rotation=code_rotation, + p_error=p_error, + num_samples=num_samples, + num_detectors=det_model.num_detectors, + num_observables=det_model.num_observables, + H_shape=H.shape, + noise_model=noise_label, + **({ + "onnx_model": onnx_model + } if onnx_model else {}), + ) + + if write_stim_artifacts: + basis_label = basis.upper() + samples_path = out / f"samples_{basis_label}.dets" + metadata_path = out / f"metadata_{basis_label}.json" + write_stim_detector_samples( + path=samples_path, + dets_and_obs=dets_and_obs, + num_detectors=det_model.num_detectors, + num_observables=det_model.num_observables, + sample_format="dets", + ) + if noise_model is not None: + noise_params_for_meta = noise_model.canonical_parameters() + noise_sha_for_meta = noise_model.sha256() + else: + noise_params_for_meta = None + noise_sha_for_meta = None + metadata = build_stim_sample_metadata( + distance=distance, + n_rounds=n_rounds, + basis=basis_label, + code_rotation=code_rotation, + num_detectors=det_model.num_detectors, + num_observables=det_model.num_observables, + num_shots=num_samples, + sample_format="dets", + append_observables=True, + p_error=float(p_error), + noise_model_label=noise_label, + noise_model_params=noise_params_for_meta, + noise_model_sha256=noise_sha_for_meta, + extra={"onnx_model": onnx_model} if onnx_model else None, + ) + write_metadata_json(metadata_path, metadata) print("Done.") for f in sorted(out.iterdir()): @@ -330,7 +375,7 @@ def generate_test_data( ) parser.add_argument("--distance", type=int, default=13) parser.add_argument("--n-rounds", type=int, default=104) - parser.add_argument("--basis", type=str, default="X", choices=["X", "Z"]) + parser.add_argument("--basis", type=str, default="X", choices=["X", "Z", 
"both"]) parser.add_argument( "--code-rotation", type=str, default="XV", help="XV, XH, ZV, ZH or public aliases O1-O4" ) @@ -350,6 +395,16 @@ def generate_test_data( action="store_true", help="Use simple p_error instead of 25-parameter noise model" ) + parser.add_argument( + "--stim-artifacts", + action="store_true", + help="Write Stim-standard samples_{basis}.dets and metadata_{basis}.json files" + ) + parser.add_argument( + "--no-cudaq-artifacts", + action="store_true", + help="Do not write the CUDA-Q/export .bin/.txt artifacts" + ) args = parser.parse_args() if args.output_dir is None: @@ -357,14 +412,24 @@ def generate_test_data( noise_params = None if args.simple_noise else DEFAULT_NOISE_PARAMS - generate_test_data( - distance=args.distance, - n_rounds=args.n_rounds, - basis=args.basis, - p_error=args.p_error, - code_rotation=args.code_rotation, - noise_model_params=noise_params, - num_samples=args.num_samples, - onnx_model=args.onnx_model, - output_dir=args.output_dir, - ) + if args.basis == "both" and not args.no_cudaq_artifacts: + raise ValueError( + "--basis both writes X and Z into one directory; use --no-cudaq-artifacts " + "to avoid collisions in detectors.bin/observables.bin outputs." + ) + + bases = ["X", "Z"] if args.basis == "both" else [args.basis] + for one_basis in bases: + generate_test_data( + distance=args.distance, + n_rounds=args.n_rounds, + basis=one_basis, + p_error=args.p_error, + code_rotation=args.code_rotation, + noise_model_params=noise_params, + num_samples=args.num_samples, + onnx_model=args.onnx_model, + output_dir=args.output_dir, + write_stim_artifacts=args.stim_artifacts, + write_cudaq_artifacts=not args.no_cudaq_artifacts, + ) diff --git a/code/qec/surface_code/stim_sample_io.py b/code/qec/surface_code/stim_sample_io.py new file mode 100644 index 0000000..05ff8d2 --- /dev/null +++ b/code/qec/surface_code/stim_sample_io.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helpers for Stim detector-sample files used by offline decoding. + +The pair of files written by :func:`write_stim_detector_samples` / +:func:`write_metadata_json` and consumed by :func:`read_stim_detector_samples` +forms the on-disk *Stim sample contract* between data generators (the local +simulator, a QPU wrapper, or a third-party producer) and the offline decoder. + +The contract has two layers: + +1. **Structural fields** — distance, rounds, basis, orientation, detector and + observable counts, and the on-disk format. These are always validated and + any mismatch is a hard error. +2. **Noise-model fingerprint** — ``p_error`` and the 25-parameter ``NoiseModel`` + parameters (via :func:`qec.noise_model.NoiseModel.sha256`). When the decoder + passes in an active noise model and ``strict=True``, mismatches are a hard + error; when ``strict=False`` they emit a warning. Older files that predate + this field bypass the check, so legacy artifacts keep loading. +""" + +from __future__ import annotations + +import json +import warnings +from pathlib import Path +from typing import Any, Mapping, Optional + +import numpy as np +import stim + +_ROTATION_ALIASES = {"O1": "XV", "O2": "XH", "O3": "ZV", "O4": "ZH"} +_SUPPORTED_FORMATS = {"dets"} + +#: Current Stim sample metadata schema version. 
+SCHEMA_VERSION = 2 + + +def normalize_code_rotation(value: Any) -> str: + rotation = str(value).strip().upper() + return _ROTATION_ALIASES.get(rotation, rotation) + + +def build_stim_sample_metadata( + *, + distance: int, + n_rounds: int, + basis: str, + code_rotation: str, + num_detectors: int, + num_observables: int, + num_shots: int, + sample_format: str = "dets", + append_observables: bool = True, + p_error: Optional[float] = None, + noise_model_label: Optional[str] = None, + noise_model_params: Optional[Mapping[str, float]] = None, + noise_model_sha256: Optional[str] = None, + extra: Optional[Mapping[str, Any]] = None, +) -> dict[str, Any]: + """Build a metadata dict matching the current Stim sample schema. + + Args: + distance, n_rounds, basis, code_rotation, num_detectors, num_observables, + num_shots: Structural fields the offline decoder must agree on. + sample_format: Only ``"dets"`` is supported today; the argument exists + so that future formats can be added without an API break. + append_observables: ``True`` if logical observables are appended to each + shot. The offline decoder requires this. + p_error: Scalar physical error rate used when generating samples (or + ``None`` if the generator uses an explicit ``NoiseModel``). + noise_model_label: Human-readable label, e.g. ``"25-param"`` or + ``"simple"``. Used for warnings and never for strict checks. + noise_model_params: The full 25-parameter dict (sorted) used to build + the ``NoiseModel`` instance, or ``None`` for the simple-noise case. + noise_model_sha256: Deterministic fingerprint of ``noise_model_params`` + (typically ``NoiseModel.sha256()``). Used for strict comparison. + extra: Optional additional fields to record alongside the contract. + + Returns: + A JSON-serializable dict describing one ``samples_{basis}.dets`` file. 
+ """ + meta: dict[str, Any] = { + "schema_version": SCHEMA_VERSION, + "artifact": "stim_detector_samples", + "format": sample_format, + "append_observables": bool(append_observables), + "distance": int(distance), + "n_rounds": int(n_rounds), + "basis": str(basis).strip().upper(), + "code_rotation": str(code_rotation).strip().upper(), + "num_detectors": int(num_detectors), + "num_observables": int(num_observables), + "num_shots": int(num_shots), + } + if p_error is not None: + meta["p_error"] = float(p_error) + if noise_model_label is not None: + meta["noise_model"] = str(noise_model_label) + if noise_model_params is not None: + meta["noise_model_params"] = { + str(k): float(v) for k, v in sorted(dict(noise_model_params).items()) + } + if noise_model_sha256 is not None: + meta["noise_model_sha256"] = str(noise_model_sha256) + if extra: + for k, v in dict(extra).items(): + meta.setdefault(str(k), v) + return meta + + +def write_metadata_json(path: str | Path, metadata: Mapping[str, Any]) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(dict(metadata), f, indent=2, sort_keys=True) + f.write("\n") + + +def read_metadata_json(path: str | Path) -> dict[str, Any]: + with Path(path).open("r", encoding="utf-8") as f: + metadata = json.load(f) + if not isinstance(metadata, dict): + raise ValueError(f"Stim metadata must be a JSON object: {path}") + return metadata + + +def validate_stim_sample_metadata( + metadata: Mapping[str, Any], + *, + distance: int, + n_rounds: int, + basis: str, + code_rotation: str, + num_detectors: int, + num_observables: int, + p_error: Optional[float] = None, + noise_model_sha256: Optional[str] = None, + noise_model_label: Optional[str] = None, + p_error_atol: float = 0.0, + strict_noise: bool = True, +) -> None: + """Validate metadata against the circuit/config used for decoding. 
+ + Structural checks (distance, rounds, basis, orientation, detector and + observable counts, format, observable appending) are always strict. The + noise-model checks are opt-in: pass ``p_error`` and/or + ``noise_model_sha256`` to compare against the recorded values. When the + decoder does not provide a noise model (``noise_model_sha256=None``), older + files that never recorded one are accepted as-is. + + Args: + metadata: Loaded metadata dict. + distance, n_rounds, basis, code_rotation, num_detectors, num_observables: + Structural fields the decoder expects. + p_error: Decoder's active scalar error rate, or ``None`` to skip. + noise_model_sha256: Decoder's active noise-model fingerprint, or + ``None`` to skip. + noise_model_label: Decoder's active noise-model label + (e.g. ``"25-param"``), used only for clarifying messages. + p_error_atol: Absolute tolerance when comparing scalar ``p_error``. + strict_noise: If ``True``, mismatches in ``p_error`` or + ``noise_model_sha256`` raise. If ``False``, they emit a + :class:`UserWarning` and the call continues. + + Raises: + ValueError: With one line per structural mismatch (and per noise + mismatch when ``strict_noise=True``). The messages are intentionally + explicit because these files are a cross-team contract. + """ + errors: list[str] = [] + # Legacy files (no schema_version key) predate the noise fingerprint; treat them as v1. 
+ sv = metadata.get("schema_version", 1) + if not isinstance(sv, int) or isinstance(sv, bool) or sv > SCHEMA_VERSION: + errors.append(f"unsupported schema_version: {sv!r} (max {SCHEMA_VERSION})") + sample_format = str(metadata.get("format", "")).strip().lower() + if sample_format not in _SUPPORTED_FORMATS: + errors.append( + f"metadata format mismatch: file has {metadata.get('format')!r}, " + f"supported formats are {sorted(_SUPPORTED_FORMATS)}" + ) + + if metadata.get("append_observables") is not True: + errors.append( + "metadata append_observables mismatch: expected true because LER requires logical labels" + ) + + checks = ( + ("distance", int(distance), lambda v: int(v)), + ("n_rounds", int(n_rounds), lambda v: int(v)), + ("basis", str(basis).strip().upper(), lambda v: str(v).strip().upper()), + ("num_detectors", int(num_detectors), lambda v: int(v)), + ("num_observables", int(num_observables), lambda v: int(v)), + ) + for key, expected, cast in checks: + if key not in metadata: + errors.append(f"metadata missing required field: {key}") + continue + try: + actual = cast(metadata[key]) + except Exception: + errors.append( + f"metadata {key} mismatch: file has {metadata[key]!r}, expected {expected!r}" + ) + continue + if actual != expected: + errors.append(f"metadata {key} mismatch: file has {actual!r}, expected {expected!r}") + + if "code_rotation" not in metadata: + errors.append("metadata missing required field: code_rotation") + else: + actual_rotation = normalize_code_rotation(metadata["code_rotation"]) + expected_rotation = normalize_code_rotation(code_rotation) + if actual_rotation != expected_rotation: + errors.append( + "metadata code_rotation mismatch: " + f"file has {metadata['code_rotation']!r}/{actual_rotation}, " + f"decode requested {code_rotation!r}/{expected_rotation}" + ) + + if int(num_observables) <= 0: + errors.append("missing observables: rebuilt circuit has num_observables=0") + else: + try: + file_num_obs = 
int(metadata.get("num_observables", 0)) + except Exception: + file_num_obs = 0 + if file_num_obs <= 0: + errors.append("missing observables: metadata num_observables must be positive") + + noise_messages: list[str] = [] + if p_error is not None and "p_error" in metadata: + try: + file_p_error = float(metadata["p_error"]) + except Exception: + file_p_error = None + if file_p_error is None: + noise_messages.append( + f"metadata p_error mismatch: file has {metadata['p_error']!r}, " + f"decoder uses {float(p_error)!r}" + ) + elif abs(file_p_error - float(p_error)) > float(p_error_atol): + noise_messages.append( + f"metadata p_error mismatch: file has {file_p_error!r}, " + f"decoder uses {float(p_error)!r} (atol={p_error_atol})" + ) + + if noise_model_sha256 is not None and "noise_model_sha256" in metadata: + file_sha = str(metadata.get("noise_model_sha256", "")).strip() + if file_sha != str(noise_model_sha256).strip(): + file_label = metadata.get("noise_model", "?") + local_label = noise_model_label or "?" + noise_messages.append( + "metadata noise_model_sha256 mismatch: " + f"file has {file_sha!r} (label={file_label!r}), " + f"decoder uses {noise_model_sha256!r} (label={local_label!r})" + ) + + if noise_messages and strict_noise: + errors.extend(noise_messages) + elif noise_messages: + warnings.warn( + "Stim sample noise-model drift (continuing because strict_noise=False):\n- " + + "\n- ".join(noise_messages), + UserWarning, + stacklevel=2, + ) + + if errors: + raise ValueError("Invalid Stim sample metadata:\n- " + "\n- ".join(errors)) + + +def write_stim_detector_samples( + *, + path: str | Path, + dets_and_obs: np.ndarray, + num_detectors: int, + num_observables: int, + sample_format: str = "dets", +) -> None: + """Write a ``samples_*.dets`` file using Stim's sparse format. + + Args: + path: Destination path (parent directories are created). + dets_and_obs: ``(num_shots, num_detectors + num_observables)`` array of + detector bits with logical observables appended. 
The width must + match exactly; mismatches raise rather than producing a malformed + file that would fail validation on read. + num_detectors, num_observables: Counts used to direct Stim's writer. + sample_format: Currently only ``"dets"`` is supported. + """ + sample_format = str(sample_format).lower() + if sample_format not in _SUPPORTED_FORMATS: + raise ValueError(f"Unsupported Stim sample format {sample_format!r}") + data = np.asarray(dets_and_obs, dtype=np.bool_) + if data.ndim != 2: + raise ValueError( + f"dets_and_obs must be 2-D (num_shots, num_detectors + num_observables); " + f"got shape {data.shape}" + ) + expected_width = int(num_detectors) + int(num_observables) + if int(data.shape[1]) != expected_width: + raise ValueError( + f"dets_and_obs width mismatch: array has {int(data.shape[1])} columns, " + f"expected num_detectors + num_observables = {expected_width} " + f"(num_detectors={int(num_detectors)}, num_observables={int(num_observables)})." + ) + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + stim.write_shot_data_file( + data=data, + path=str(path), + format=sample_format, + num_detectors=int(num_detectors), + num_observables=int(num_observables), + ) + + +def read_stim_detector_samples( + *, + samples_path: str | Path, + metadata_path: str | Path, + distance: int, + n_rounds: int, + basis: str, + code_rotation: str, + num_detectors: int, + num_observables: int, + p_error: Optional[float] = None, + noise_model_sha256: Optional[str] = None, + noise_model_label: Optional[str] = None, + strict_noise: bool = True, +) -> tuple[np.ndarray, dict[str, Any]]: + """Read a ``samples_*.dets`` file after validating its metadata. + + Args: + samples_path, metadata_path: File paths for the sample data and JSON + metadata, typically produced by :func:`resolve_stim_sample_paths`. + distance, n_rounds, basis, code_rotation, num_detectors, + num_observables: Structural parameters the decoder expects. 
+ p_error, noise_model_sha256, noise_model_label: Optional noise-model + fingerprint the decoder is using. See + :func:`validate_stim_sample_metadata` for semantics. + strict_noise: When ``True`` (the default), noise-fingerprint + mismatches raise. When ``False``, they emit a warning. + + Returns: + ``(dets_and_obs, metadata)`` where ``dets_and_obs`` has shape + ``(num_shots, num_detectors + num_observables)`` and dtype ``uint8``. + + Raises: + ValueError: If the metadata is inconsistent with the structural + parameters (always strict) or, when ``strict_noise=True``, with + the noise fingerprint. + """ + metadata = read_metadata_json(metadata_path) + validate_stim_sample_metadata( + metadata, + distance=distance, + n_rounds=n_rounds, + basis=basis, + code_rotation=code_rotation, + num_detectors=num_detectors, + num_observables=num_observables, + p_error=p_error, + noise_model_sha256=noise_model_sha256, + noise_model_label=noise_model_label, + strict_noise=strict_noise, + ) + data = stim.read_shot_data_file( + path=str(samples_path), + format=str(metadata["format"]).strip().lower(), + num_detectors=int(num_detectors), + num_observables=int(num_observables), + ) + arr = np.asarray(data, dtype=np.uint8) + expected_width = int(num_detectors) + int(num_observables) + if arr.ndim != 2 or arr.shape[1] != expected_width: + raise ValueError( + f"Stim sample shape mismatch: file produced shape {arr.shape}, " + f"expected (*, {expected_width})" + ) + expected_shots = int(metadata["num_shots"]) + if arr.shape[0] != expected_shots: + raise ValueError( + f"metadata num_shots mismatch: file has {arr.shape[0]} shots, " + f"metadata has {expected_shots}" + ) + return arr, dict(metadata) + + +def resolve_stim_sample_paths(root: str | Path, basis: str) -> tuple[Path, Path]: + """Resolve either flat or per-basis Stim artifact layouts.""" + root = Path(root) + basis = str(basis).strip().upper() + candidates = ( + (root / f"samples_{basis}.dets", root / f"metadata_{basis}.json"), + (root / 
basis / "samples.dets", root / basis / "metadata.json"), + (root / "samples.dets", root / "metadata.json"), + ) + for samples_path, metadata_path in candidates: + if samples_path.exists() and metadata_path.exists(): + return samples_path, metadata_path + expected = ", ".join(f"{s} + {m}" for s, m in candidates) + raise FileNotFoundError( + f"No Stim sample artifact found for basis {basis}. Expected one of: {expected}" + ) diff --git a/code/scripts/local_run.sh b/code/scripts/local_run.sh index ba117a4..3eab234 100644 --- a/code/scripts/local_run.sh +++ b/code/scripts/local_run.sh @@ -21,6 +21,7 @@ set -euo pipefail # Examples: # bash code/scripts/local_run.sh # WORKFLOW=inference bash code/scripts/local_run.sh +# WORKFLOW=generate_stim_data bash code/scripts/local_run.sh # GPUS=4 bash code/scripts/local_run.sh # CUDA_VISIBLE_DEVICES=1 bash code/scripts/local_run.sh # use only GPU 1 # @@ -85,30 +86,34 @@ else RESUME_FLAG="++load_checkpoint=True" fi -# GPU-only runs: require a visible GPU and nvidia-smi. -if ! command -v nvidia-smi >/dev/null 2>&1; then - echo "[local_run.sh] Error: GPU-only mode requires nvidia-smi on PATH." >&2 - echo "[local_run.sh] Hint: run on a GPU host or pass CUDA_VISIBLE_DEVICES." >&2 - exit 1 -fi +if [ "${WORKFLOW}" = "generate_stim_data" ]; then + GPUS=1 +else + # GPU-only runs: require a visible GPU and nvidia-smi. + if ! command -v nvidia-smi >/dev/null 2>&1; then + echo "[local_run.sh] Error: GPU-only mode requires nvidia-smi on PATH." >&2 + echo "[local_run.sh] Hint: run on a GPU host or pass CUDA_VISIBLE_DEVICES." >&2 + exit 1 + fi -# Respect CUDA_VISIBLE_DEVICES if set; otherwise auto-detect via nvidia-smi. -if [ -z "${GPUS}" ]; then - if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then - GPUS="$(python3 - <<'PY' + # Respect CUDA_VISIBLE_DEVICES if set; otherwise auto-detect via nvidia-smi. 
+ if [ -z "${GPUS}" ]; then + if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then + GPUS="$(python3 - <<'PY' import os v=os.environ.get('CUDA_VISIBLE_DEVICES','').strip() print(len([x for x in v.split(',') if x.strip()]) or 1) PY )" - else - GPUS="$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l | tr -d ' ')" + else + GPUS="$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l | tr -d ' ')" + fi fi -fi -if [ "${GPUS}" -le 0 ]; then - echo "[local_run.sh] Error: no GPUs detected. GPU-only mode requires CUDA." >&2 - exit 1 + if [ "${GPUS}" -le 0 ]; then + echo "[local_run.sh] Error: no GPUs detected. GPU-only mode requires CUDA." >&2 + exit 1 + fi fi if [ -z "${MASTER_PORT:-}" ]; then @@ -204,8 +209,8 @@ if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then fi fi -# Ensure CUDA is usable before launching the workflow. -if ! "${PYTHON_BIN}" - <<'PY' +# Ensure CUDA is usable before launching GPU workflows. +if [ "${WORKFLOW}" != "generate_stim_data" ] && ! "${PYTHON_BIN}" - <<'PY' import sys try: import torch diff --git a/code/scripts/offline_smoketest.sh b/code/scripts/offline_smoketest.sh new file mode 100755 index 0000000..8261973 --- /dev/null +++ b/code/scripts/offline_smoketest.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)" + +EXPERIMENT_NAME="${EXPERIMENT_NAME:-offline_stim_run}" +MODEL_CHECKPOINT="${MODEL_CHECKPOINT:-${REPO_ROOT}/models/Ising-Decoder-SurfaceCode-1-Fast.pt}" +PYTHON_BIN="${PREDECODER_PYTHON:-}" +if [ -z "${PYTHON_BIN}" ]; then + if [ -x "${REPO_ROOT}/.venv_gpu/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv_gpu/bin/python" + else + PYTHON_BIN="python3" + fi +fi + +SAMPLES_DIR="${REPO_ROOT}/outputs/${EXPERIMENT_NAME}/stim_samples" +LOG_PATH="${REPO_ROOT}/outputs/${EXPERIMENT_NAME}/run.log" + +echo "==========================================" +echo "Offline Stim smoke test" +echo "==========================================" +echo "experiment: ${EXPERIMENT_NAME}" +echo "python: ${PYTHON_BIN}" +echo "samples: ${SAMPLES_DIR}" +echo "model: ${MODEL_CHECKPOINT}" +echo "==========================================" + +cd "${REPO_ROOT}" + +PREDECODER_PYTHON="${PYTHON_BIN}" \ +WORKFLOW=generate_stim_data \ +EXPERIMENT_NAME="${EXPERIMENT_NAME}" \ +bash code/scripts/local_run.sh + +PREDECODER_PYTHON="${PYTHON_BIN}" \ +PREDECODER_STIM_SAMPLES_DIR="${SAMPLES_DIR}" \ +PREDECODER_DECODE_MODE=pymatching_only \ +PREDECODER_EMIT_INFERENCE_SUMMARY=1 \ +EXPERIMENT_NAME="${EXPERIMENT_NAME}" \ +GPUS=1 \ +WORKFLOW=inference \ +bash code/scripts/local_run.sh + +if [ -f "${MODEL_CHECKPOINT}" ]; then + PREDECODER_PYTHON="${PYTHON_BIN}" \ + PREDECODER_STIM_SAMPLES_DIR="${SAMPLES_DIR}" \ + PREDECODER_DECODE_MODE=ising_decoding_pymatching \ + PREDECODER_EMIT_INFERENCE_SUMMARY=1 \ + EXPERIMENT_NAME="${EXPERIMENT_NAME}" \ + GPUS=1 \ + WORKFLOW=inference \ + EXTRA_PARAMS="++model_checkpoint_file=${MODEL_CHECKPOINT}${MODEL_ID:+ ++model_id=${MODEL_ID}}" \ + bash code/scripts/local_run.sh +else + echo "[offline_smoketest.sh] Model not found; skipped ising_decoding_pymatching:" + echo " ${MODEL_CHECKPOINT}" +fi + +if [ -f "${LOG_PATH}" ]; then + "${PYTHON_BIN}" - 
"${LOG_PATH}" <<'PY' +import json +import sys +from pathlib import Path + +# Inference prints a single-line JSON marker: +# [Inference Summary] {"marker": "inference_summary", ...} +# Parse the LAST such marker (so we pick up the most recent run when the log +# file accumulates multiple inference passes, e.g. pymatching_only followed by +# ising_decoding_pymatching). +text = Path(sys.argv[1]).read_text(encoding="utf-8") +marker_prefix = "[Inference Summary] " +records = [] +for line in text.splitlines(): + idx = line.find(marker_prefix) + if idx < 0: + continue + payload = line[idx + len(marker_prefix):].strip() + if not payload: + continue + try: + record = json.loads(payload) + except json.JSONDecodeError: + continue + if record.get("marker") == "inference_summary": + records.append(record) + +if not records: + raise SystemExit(f"No [Inference Summary] JSON marker found in {sys.argv[1]}") + +summary = records[-1] +ler = summary.get("ler", {}) +speedup = summary.get("pymatching_speedup_avg_xz", float("nan")) + +# Full per-basis latency/LER/speedup table is already printed by +# code/evaluation/inference.py; just emit one headline line here. +print( + f"\n[offline_smoketest.sh] Avg LER {ler.get('avg_no_predecoder')} " + f"(no pre-decoder) -> {ler.get('avg_after_predecoder')} (after); " + f"PyMatching speedup {speedup}" +) +PY +fi diff --git a/code/tests/test_offline_stim_decoding.py b/code/tests/test_offline_stim_decoding.py new file mode 100644 index 0000000..d6c4a3c --- /dev/null +++ b/code/tests/test_offline_stim_decoding.py @@ -0,0 +1,852 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for offline decoding from Stim detector-sample files.""" + +import os +import sys +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pymatching +import torch +import torch.nn as nn + +_repo_code = Path(__file__).resolve().parent.parent +if str(_repo_code) not in sys.path: + sys.path.insert(0, str(_repo_code)) + +from data.datapipe_stim import QCDataPipePreDecoder_Memory_from_stim_file +from data.predecoder_transform import dets_to_predecoder_inputs +from evaluation.logical_error_rate import count_logical_errors_with_errorbar +from qec.noise_model import NoiseModel +from qec.surface_code.data_mapping import ( + compute_stabX_to_data_index_map, + compute_stabZ_to_data_index_map, + normalized_weight_mapping_Xstab_memory, + normalized_weight_mapping_Zstab_memory, +) +from qec.surface_code.memory_circuit import MemoryCircuit +from qec.surface_code.stim_sample_io import ( + build_stim_sample_metadata, + read_metadata_json, + read_stim_detector_samples, + resolve_stim_sample_paths, + write_metadata_json, + write_stim_detector_samples, +) + + +class _DummyDist: + rank = 0 + world_size = 1 + device = torch.device("cpu") + + +class _UnusedModel(nn.Module): + + def forward(self, x): + raise AssertionError("pymatching_only mode must not run the neural predecoder") + + +class _ZeroCorrectionModel(nn.Module): + + def __init__(self): + super().__init__() + self.anchor = nn.Parameter(torch.zeros(())) + + def forward(self, x): + B, _, T, D, _ = x.shape + return torch.full((B, 4, T, D, D), 
-1.0, device=x.device) + self.anchor * 0.0 + + +def _build_memory_circuit(distance=3, n_rounds=3, basis="X", rotation="XV", p_error=0.003): + noise_model = NoiseModel.from_single_p(p_error) + p_placeholder = float(noise_model.get_max_probability()) + circuit = MemoryCircuit( + distance=distance, + idle_error=p_placeholder, + sqgate_error=p_placeholder, + tqgate_error=p_placeholder, + spam_error=(2.0 / 3.0) * p_placeholder, + n_rounds=n_rounds, + basis=basis, + code_rotation=rotation, + noise_model=noise_model, + add_boundary_detectors=True, + ) + circuit.set_error_rates() + return circuit + + +def _write_artifact( + root: Path, + *, + basis: str, + shots: int = 128, + rotation: str = "XV", + distance: int = 3, + n_rounds: int = 3, + p_error: float = 0.003, + record_noise_fingerprint: bool = False, + samples_basename: str | None = None, + metadata_basename: str | None = None, +): + mem = _build_memory_circuit( + distance=distance, + n_rounds=n_rounds, + basis=basis, + rotation=rotation, + p_error=p_error, + ) + stim_circuit = mem.stim_circuit + dets_and_obs = stim_circuit.compile_detector_sampler(seed=1234).sample( + shots=shots, + append_observables=True, + ) + dets_and_obs = np.asarray(dets_and_obs, dtype=np.uint8) + samples_path = root / (samples_basename or f"samples_{basis}.dets") + metadata_path = root / (metadata_basename or f"metadata_{basis}.json") + write_stim_detector_samples( + path=samples_path, + dets_and_obs=dets_and_obs, + num_detectors=stim_circuit.num_detectors, + num_observables=stim_circuit.num_observables, + ) + noise_kwargs: dict = {} + if record_noise_fingerprint: + noise_model = NoiseModel.from_single_p(p_error) + noise_kwargs.update( + p_error=float(p_error), + noise_model_label="25-param", + noise_model_params=noise_model.canonical_parameters(), + noise_model_sha256=noise_model.sha256(), + ) + metadata = build_stim_sample_metadata( + distance=distance, + n_rounds=n_rounds, + basis=basis, + code_rotation=rotation, + 
num_detectors=stim_circuit.num_detectors, + num_observables=stim_circuit.num_observables, + num_shots=shots, + **noise_kwargs, + ) + write_metadata_json(metadata_path, metadata) + + dem = stim_circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True) + matcher = pymatching.Matching.from_detector_error_model(dem) + dets = dets_and_obs[:, :-stim_circuit.num_observables] + obs = dets_and_obs[:, -stim_circuit.num_observables:] + predictions = np.asarray(matcher.decode_batch(dets), dtype=np.uint8).reshape(obs.shape) + errors = int((predictions != obs).sum()) + return SimpleNamespace( + mem=mem, + dets_and_obs=dets_and_obs, + predictions=predictions, + errors=errors, + ler=errors / float(shots), + samples_path=samples_path, + metadata_path=metadata_path, + ) + + +def _write_measurement_artifact( + root: Path, + *, + basis: str, + distance: int, + n_rounds: int, + shots: int, + rotation: str = "XV", + p_error: float = 0.02, +): + mem = _build_memory_circuit( + distance=distance, + n_rounds=n_rounds, + basis=basis, + rotation=rotation, + p_error=p_error, + ) + stim_circuit = mem.stim_circuit + measurements = stim_circuit.compile_sampler(seed=1234).sample(shots=shots) + dets_and_obs = stim_circuit.compile_m2d_converter().convert( + measurements=measurements, + append_observables=True, + ) + dets_and_obs = np.asarray(dets_and_obs, dtype=np.uint8) + + samples_path = root / f"samples_{basis}.dets" + metadata_path = root / f"metadata_{basis}.json" + write_stim_detector_samples( + path=samples_path, + dets_and_obs=dets_and_obs, + num_detectors=stim_circuit.num_detectors, + num_observables=stim_circuit.num_observables, + ) + write_metadata_json( + metadata_path, + build_stim_sample_metadata( + distance=distance, + n_rounds=n_rounds, + basis=basis, + code_rotation=rotation, + num_detectors=stim_circuit.num_detectors, + num_observables=stim_circuit.num_observables, + num_shots=shots, + ), + ) + return SimpleNamespace( + mem=mem, + 
measurements=np.asarray(measurements, dtype=np.uint8), + dets_and_obs=dets_and_obs, + samples_path=samples_path, + metadata_path=metadata_path, + ) + + +def _reference_tensors_from_measurements( + *, + measurements: np.ndarray, + distance: int, + n_rounds: int, + basis: str, + rotation: str, +): + """Independent oracle: build (x_syn_diff, z_syn_diff, trainX) from raw + Stim measurements via XOR-differencing. + + This is the *second* implementation by design: it is the slowest, simplest + one written purely against the surface-code memory experiment convention, + so it cross-checks both the production dets-based helper and Stim's m2d + converter at once. Any drift in the production helper or Stim's detector + emission ordering will surface here. + """ + D = distance + T = n_rounds + shots = measurements.shape[0] + half = (D * D - 1) // 2 + frames = torch.from_numpy(measurements[..., :-(D * D)]).to(torch.uint8) + frames = frames.view(shots, T, D * D - 1).contiguous() + + x_raw = frames[:, :, :half].permute(0, 2, 1).contiguous() + z_raw = frames[:, :, half:].permute(0, 2, 1).contiguous() + zero_batch = torch.zeros((shots, half, 1), dtype=torch.uint8) + x_aug = torch.cat([zero_batch, x_raw], dim=2) + z_aug = torch.cat([zero_batch, z_raw], dim=2) + x_syn_diff = (x_aug[:, :, 1:] ^ x_aug[:, :, :-1]).to(torch.int32).contiguous() + z_syn_diff = (z_aug[:, :, 1:] ^ z_aug[:, :, :-1]).to(torch.int32).contiguous() + + w_map_x = normalized_weight_mapping_Xstab_memory(D, rotation).reshape(D, D) + w_map_z = normalized_weight_mapping_Zstab_memory(D, rotation).reshape(D, D) + x_present = w_map_x.unsqueeze(0).expand(shots, T, D, D).to(torch.float32) + z_present = w_map_z.unsqueeze(0).expand(shots, T, D, D).to(torch.float32) + if basis == "X": + z_syn_diff[:, :, 0] = 0 + z_syn_diff[:, :, -1] = 0 + z_present = z_present.clone() + z_present[:, 0] = 0 + z_present[:, -1] = 0 + else: + x_syn_diff[:, :, 0] = 0 + x_syn_diff[:, :, -1] = 0 + x_present = x_present.clone() + x_present[:, 0] = 0 + 
x_present[:, -1] = 0 + + idx_map_x = torch.as_tensor(compute_stabX_to_data_index_map(D, rotation), dtype=torch.long) + idx_map_z = torch.as_tensor(compute_stabZ_to_data_index_map(D, rotation), dtype=torch.long) + x_grid = torch.zeros(shots, D * D, T, dtype=torch.float32) + z_grid = torch.zeros(shots, D * D, T, dtype=torch.float32) + x_grid[:, idx_map_x, :] = x_syn_diff[:, :len(idx_map_x), :].to(torch.float32) + z_grid[:, idx_map_z, :] = z_syn_diff[:, :len(idx_map_z), :].to(torch.float32) + + x_type = x_grid.reshape(shots, D, D, T).permute(0, 3, 1, 2).contiguous() + z_type = z_grid.reshape(shots, D, D, T).permute(0, 3, 1, 2).contiguous() + train_x = torch.cat( + [ + x_type.unsqueeze(1), + z_type.unsqueeze(1), + x_present.unsqueeze(1), + z_present.unsqueeze(1), + ], + dim=1, + ).contiguous() + return x_syn_diff, z_syn_diff, train_x + + +class TestStimSampleFileContract(unittest.TestCase): + + def test_write_read_round_trip_for_x_and_z(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + for basis in ("X", "Z"): + artifact = _write_artifact(root, basis=basis, shots=64) + read, metadata = read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis=basis, + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + ) + self.assertTrue(np.array_equal(read, artifact.dets_and_obs)) + self.assertEqual(metadata["basis"], basis) + + def test_file_datapipe_exposes_valid_predecoder_inputs_for_x_and_z(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + for basis in ("X", "Z"): + artifact = _write_artifact(root, basis=basis, shots=32) + pipe = QCDataPipePreDecoder_Memory_from_stim_file( + distance=3, + n_rounds=3, + num_samples=32, + error_mode="circuit_level_surface_custom", + stim_samples_dir=root, + p_error=0.003, + measure_basis=basis, + code_rotation="XV", + 
noise_model=NoiseModel.from_single_p(0.003), + ) + self.assertEqual(len(pipe), 32) + self.assertTrue( + torch.equal(pipe.dets_and_obs, torch.from_numpy(artifact.dets_and_obs)) + ) + sample = pipe[0] + self.assertEqual(tuple(sample["trainX"].shape), (4, 3, 3, 3)) + self.assertEqual( + sample["dets_and_obs"].numel(), + artifact.mem.stim_circuit.num_detectors + + artifact.mem.stim_circuit.num_observables, + ) + + def test_file_datapipe_matches_measurement_xor_reference(self): + """File datapipe (consumes Stim dets) must agree with the independent + measurement-XOR oracle for several (D, T, basis) triples. + + This is the strongest cross-check: it pins three things at once — + Stim's detector emission ordering, the canonical helper, and the + measurement-based reference implementation. + """ + cases = ( + (3, 3, 32), + (5, 5, 24), + (7, 7, 8), + ) + for distance, n_rounds, shots in cases: + for basis in ("X", "Z"): + with self.subTest(distance=distance, n_rounds=n_rounds, basis=basis): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_measurement_artifact( + root, + basis=basis, + distance=distance, + n_rounds=n_rounds, + shots=shots, + ) + pipe = QCDataPipePreDecoder_Memory_from_stim_file( + distance=distance, + n_rounds=n_rounds, + num_samples=shots, + error_mode="circuit_level_surface_custom", + stim_samples_dir=root, + p_error=0.02, + measure_basis=basis, + code_rotation="XV", + noise_model=NoiseModel.from_single_p(0.02), + ) + expected_x, expected_z, expected_train_x = _reference_tensors_from_measurements( + measurements=artifact.measurements, + distance=distance, + n_rounds=n_rounds, + basis=basis, + rotation="XV", + ) + + self.assertTrue( + torch.equal(pipe.dets_and_obs, torch.from_numpy(artifact.dets_and_obs)) + ) + self.assertTrue(torch.equal(pipe.x_syn_diff_all, expected_x)) + self.assertTrue(torch.equal(pipe.z_syn_diff_all, expected_z)) + self.assertTrue(torch.equal(pipe.trainX_all, expected_train_x)) + + def 
test_canonical_helper_matches_measurement_xor_reference(self): + """`dets_to_predecoder_inputs` (the single source of truth for the + production datapipes) must match the independent measurement-XOR + oracle directly, without going through file I/O. + """ + for distance, n_rounds in ((3, 3), (5, 4)): + for basis in ("X", "Z"): + with self.subTest(distance=distance, n_rounds=n_rounds, basis=basis): + with tempfile.TemporaryDirectory() as tmp: + artifact = _write_measurement_artifact( + Path(tmp), + basis=basis, + distance=distance, + n_rounds=n_rounds, + shots=12, + ) + num_obs = artifact.mem.stim_circuit.num_observables + dets = torch.from_numpy(artifact.dets_and_obs[:, :-num_obs]).to(torch.uint8) + train_x, x_syn, z_syn = dets_to_predecoder_inputs( + dets, + distance=distance, + n_rounds=n_rounds, + basis=basis, + code_rotation="XV", + ) + expected_x, expected_z, expected_train_x = _reference_tensors_from_measurements( + measurements=artifact.measurements, + distance=distance, + n_rounds=n_rounds, + basis=basis, + rotation="XV", + ) + self.assertTrue(torch.equal(x_syn, expected_x)) + self.assertTrue(torch.equal(z_syn, expected_z)) + self.assertTrue(torch.equal(train_x, expected_train_x)) + + def test_metadata_mismatches_are_explicit(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact(root, basis="X", shots=16) + metadata = read_metadata_json(artifact.metadata_path) + metadata["basis"] = "Z" + write_metadata_json(artifact.metadata_path, metadata) + + with self.assertRaisesRegex(ValueError, "basis mismatch.*file has 'Z'.*expected 'X'"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + ) + + def test_wrong_orientation_detector_count_and_missing_observable_failures(self): + with 
tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact(root, basis="X", shots=16) + + metadata = read_metadata_json(artifact.metadata_path) + metadata["code_rotation"] = "XH" + write_metadata_json(artifact.metadata_path, metadata) + with self.assertRaisesRegex(ValueError, "code_rotation mismatch"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + ) + + metadata["code_rotation"] = "XV" + metadata["num_detectors"] = metadata["num_detectors"] + 1 + write_metadata_json(artifact.metadata_path, metadata) + with self.assertRaisesRegex(ValueError, "num_detectors mismatch"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + ) + + metadata["num_detectors"] = artifact.mem.stim_circuit.num_detectors + metadata["num_observables"] = 0 + write_metadata_json(artifact.metadata_path, metadata) + with self.assertRaisesRegex(ValueError, "missing observables"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + ) + + def test_p_error_drift_raises_under_strict_noise(self): + """When metadata records ``p_error`` and the decoder passes a different + ``p_error``, strict validation must surface the drift instead of + silently building a wrong-weight matcher. 
+ """ + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact( + root, + basis="X", + shots=16, + p_error=0.003, + record_noise_fingerprint=True, + ) + with self.assertRaisesRegex(ValueError, "p_error mismatch"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + p_error=0.005, + ) + + def test_noise_model_sha_drift_raises_under_strict_noise(self): + """A different :class:`NoiseModel` (even with the same scalar + ``p_error``) yields a different fingerprint, which must be caught.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact( + root, + basis="X", + shots=16, + p_error=0.003, + record_noise_fingerprint=True, + ) + different = NoiseModel.from_single_p(0.003, spam_factor=0.5) + with self.assertRaisesRegex(ValueError, "noise_model_sha256 mismatch"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + p_error=0.003, + noise_model_sha256=different.sha256(), + noise_model_label="25-param", + ) + + def test_noise_drift_downgraded_to_warning_when_strict_disabled(self): + """``strict_noise=False`` must downgrade the failure to a warning so + users can override after eyeballing the metadata.""" + import warnings as _warnings + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact( + root, + basis="X", + shots=16, + p_error=0.003, + record_noise_fingerprint=True, + ) + with _warnings.catch_warnings(record=True) as caught: + _warnings.simplefilter("always") + 
read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + p_error=0.005, + strict_noise=False, + ) + warning_messages = [str(w.message) for w in caught] + self.assertTrue( + any("p_error mismatch" in msg for msg in warning_messages), + f"expected p_error warning, got: {warning_messages}", + ) + + def test_legacy_metadata_without_noise_fields_still_loads(self): + """Older files that predate the noise fingerprint (no ``p_error`` or + ``noise_model_sha256`` in metadata) must keep loading even when the + decoder passes its active fingerprint.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact( + root, + basis="X", + shots=8, + p_error=0.003, + record_noise_fingerprint=False, + ) + metadata = read_metadata_json(artifact.metadata_path) + self.assertNotIn("p_error", metadata) + self.assertNotIn("noise_model_sha256", metadata) + data, _ = read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + p_error=0.005, + noise_model_sha256=NoiseModel.from_single_p(0.001).sha256(), + ) + self.assertEqual(data.shape[0], 8) + + def test_per_basis_subdirectory_layout_resolves(self): + """``{root}/{basis}/samples.dets`` is a valid layout; the resolver + must find it, and the file datapipe must consume it transparently.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + basis = "X" + (root / basis).mkdir(parents=True, exist_ok=True) + artifact = _write_artifact( + root, + basis=basis, + shots=8, + samples_basename=f"{basis}/samples.dets", + 
metadata_basename=f"{basis}/metadata.json", + record_noise_fingerprint=True, + ) + resolved_samples, resolved_metadata = resolve_stim_sample_paths(root, basis) + self.assertEqual(resolved_samples, root / basis / "samples.dets") + self.assertEqual(resolved_metadata, root / basis / "metadata.json") + + pipe = QCDataPipePreDecoder_Memory_from_stim_file( + distance=3, + n_rounds=3, + num_samples=8, + error_mode="circuit_level_surface_custom", + stim_samples_dir=root, + p_error=0.003, + measure_basis=basis, + code_rotation="XV", + noise_model=NoiseModel.from_single_p(0.003), + ) + self.assertEqual(len(pipe), 8) + self.assertTrue(torch.equal(pipe.dets_and_obs, torch.from_numpy(artifact.dets_and_obs))) + + def test_truncated_dets_file_raises_at_load_time(self): + """Truncating the .dets file must trip the num_shots/shape check in + :func:`read_stim_detector_samples` rather than silently producing + a partial array (which would skew LER computations).""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact( + root, + basis="X", + shots=16, + record_noise_fingerprint=True, + ) + raw_bytes = artifact.samples_path.read_bytes() + # Drop the trailing half of the shots, cutting at a line boundary + # rather than mid-shot. Stim's dets format is newline-delimited, so + # this keeps the parser from raising on its own; the shot-count check + # in read_stim_detector_samples should then notice.
+ lines = raw_bytes.splitlines(keepends=True) + truncated = b"".join(lines[:max(1, len(lines) // 2)]) + artifact.samples_path.write_bytes(truncated) + with self.assertRaisesRegex(ValueError, "num_shots mismatch"): + read_stim_detector_samples( + samples_path=artifact.samples_path, + metadata_path=artifact.metadata_path, + distance=3, + n_rounds=3, + basis="X", + code_rotation="XV", + num_detectors=artifact.mem.stim_circuit.num_detectors, + num_observables=artifact.mem.stim_circuit.num_observables, + p_error=0.003, + noise_model_sha256=NoiseModel.from_single_p(0.003).sha256(), + ) + + def test_write_rejects_shape_mismatch(self): + """``write_stim_detector_samples`` must refuse arrays whose width does + not match ``num_detectors + num_observables``; otherwise the file + would round-trip into a malformed shape check on read.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + bogus = np.zeros((4, 5), dtype=np.uint8) + with self.assertRaisesRegex(ValueError, "width mismatch"): + write_stim_detector_samples( + path=root / "samples_X.dets", + dets_and_obs=bogus, + num_detectors=10, + num_observables=1, + ) + + def test_helper_handles_t_equals_one(self): + """The canonical helper must produce the right tensor shapes for + ``T = 1`` even though the ONNX-bound eval module asserts ``T >= 2``. 
+ The file/in-memory datapipes use the helper, so ``T = 1`` flowing + from a one-round QPU dump must not silently corrupt the residual + masking on the cross-basis row.""" + distance, n_rounds, shots = 3, 1, 4 + half = (distance * distance - 1) // 2 + torch.manual_seed(13) + dets = torch.randint(0, 2, (shots, 2 * n_rounds * half), dtype=torch.uint8) + for basis in ("X", "Z"): + with self.subTest(basis=basis): + train_x, x_syn, z_syn = dets_to_predecoder_inputs( + dets, + distance=distance, + n_rounds=n_rounds, + basis=basis, + code_rotation="XV", + ) + self.assertEqual(tuple(train_x.shape), (shots, 4, n_rounds, distance, distance)) + self.assertEqual(tuple(x_syn.shape), (shots, half, n_rounds)) + self.assertEqual(tuple(z_syn.shape), (shots, half, n_rounds)) + # Cross-basis row must be all zero because there is no interior + # round to populate when T = 1. + if basis == "X": + self.assertTrue(torch.all(z_syn == 0)) + self.assertTrue(torch.all(train_x[:, 3] == 0)) + else: + self.assertTrue(torch.all(x_syn == 0)) + self.assertTrue(torch.all(train_x[:, 2] == 0)) + + +class TestOfflineStimLER(unittest.TestCase): + + def _base_cfg(self, num_samples: int): + return SimpleNamespace( + code="surface", + datapipe="memory", + distance=3, + n_rounds=3, + data=SimpleNamespace( + error_mode="circuit_level_surface_custom", + code_rotation="XV", + ), + test=SimpleNamespace( + num_samples=num_samples, + n_rounds=3, + p_error=0.003, + meas_basis_test="both", + noise_model="none", + latency_num_samples=0, + th_data=0.0, + th_syn=0.0, + sampling_mode="threshold", + temperature=1.0, + dataloader={ + "batch_size": 16, + "num_workers": 0, + "persistent_workers": False, + }, + ), + ) + + def _run_file_ler(self, root: Path, *, decode_mode: str, model: nn.Module, cfg): + old_env = { + "PREDECODER_STIM_SAMPLES_DIR": os.environ.get("PREDECODER_STIM_SAMPLES_DIR"), + "PREDECODER_DECODE_MODE": os.environ.get("PREDECODER_DECODE_MODE"), + "PREDECODER_TORCH_COMPILE": 
os.environ.get("PREDECODER_TORCH_COMPILE"), + "PREDECODER_INFERENCE_NUM_WORKERS": os.environ.get("PREDECODER_INFERENCE_NUM_WORKERS"), + } + os.environ["PREDECODER_STIM_SAMPLES_DIR"] = str(root) + os.environ["PREDECODER_DECODE_MODE"] = decode_mode + os.environ["PREDECODER_TORCH_COMPILE"] = "0" + os.environ["PREDECODER_INFERENCE_NUM_WORKERS"] = "0" + try: + return count_logical_errors_with_errorbar(model, torch.device("cpu"), _DummyDist(), cfg) + finally: + for key, value in old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def test_pymatching_only_ler_matches_golden_x_and_z(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + expected = { + basis: _write_artifact(root, basis=basis, shots=96).ler for basis in ("X", "Z") + } + + result = self._run_file_ler( + root, + decode_mode="pymatching_only", + model=_UnusedModel(), + cfg=self._base_cfg(96), + ) + + for basis in ("X", "Z"): + self.assertIn(basis, result) + observed = float(result[basis]["logical error ratio (pymatch mean)"]) + self.assertEqual(observed, expected[basis]) + self.assertEqual( + float(result[basis]["logical error ratio (mean)"]), expected[basis] + ) + self.assertGreaterEqual(observed, 0.0) + self.assertLessEqual(observed, 1.0) + + def test_ising_decoding_pymatching_smoke_matches_baseline_for_zero_correction_model(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + expected = { + basis: _write_artifact(root, basis=basis, shots=32).ler for basis in ("X", "Z") + } + + result = self._run_file_ler( + root, + decode_mode="ising_decoding_pymatching", + model=_ZeroCorrectionModel(), + cfg=self._base_cfg(32), + ) + + for basis in ("X", "Z"): + baseline = float(result[basis]["logical error ratio (pymatch mean)"]) + after = float(result[basis]["logical error ratio (mean)"]) + self.assertEqual(baseline, expected[basis]) + self.assertEqual(after, expected[basis]) + self.assertGreaterEqual(after, 0.0) + 
self.assertLessEqual(after, 1.0) + + def test_pymatching_only_single_basis_x_branch(self): + """Cover the ``meas_basis_test='X'`` branch of + ``count_logical_errors_with_errorbar`` — distinct from the ``both`` + branch we exercise above, and previously untested for file mode.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + artifact = _write_artifact(root, basis="X", shots=64) + + cfg = self._base_cfg(64) + cfg.test.meas_basis_test = "X" + + result = self._run_file_ler( + root, + decode_mode="pymatching_only", + model=_UnusedModel(), + cfg=cfg, + ) + + self.assertIn("X", result) + self.assertNotIn("Z", result) + observed = float(result["X"]["logical error ratio (pymatch mean)"]) + self.assertEqual(observed, artifact.ler) + + +if __name__ == "__main__": + unittest.main() diff --git a/code/workflows/run.py b/code/workflows/run.py index 4c8b443..486ae9a 100644 --- a/code/workflows/run.py +++ b/code/workflows/run.py @@ -78,8 +78,53 @@ def run_surface(cfg: DictConfig): from evaluation.inference import run_inference DistributedManager.initialize() dist = DistributedManager() - model = _load_model(cfg, dist) + decode_mode = os.environ.get("PREDECODER_DECODE_MODE", "").strip().lower() + if decode_mode == "pymatching_only": + model = torch.nn.Identity() + else: + model = _load_model(cfg, dist) run_inference(model, dist.device, dist, cfg) + elif cfg.workflow.task == "generate_stim_data": + from export.generate_test_data import generate_test_data + from hydra.core.hydra_config import HydraConfig + from omegaconf import OmegaConf + + basis_cfg = str(getattr(cfg.test, "meas_basis_test", "both")).upper() + bases = ["X", "Z"] if basis_cfg in ("BOTH", "MIXED") else [basis_cfg] + if any(b not in ("X", "Z") for b in bases): + raise ValueError(f"Invalid test.meas_basis_test={basis_cfg!r}; expected X, Z, or both.") + + num_samples = int(getattr(cfg.test, "num_samples", 1000)) + output_dir = os.path.join(HydraConfig.get().runtime.output_dir, "stim_samples") + 
noise_model_cfg = getattr(cfg.data, "noise_model", None) + noise_model_params = None + if noise_model_cfg is not None: + noise_model_params = OmegaConf.to_container(noise_model_cfg, resolve=True) + + # The generate_stim_data workflow ONLY writes Stim sample artifacts + # (samples_{basis}.dets + metadata_{basis}.json). The CUDA-Q .bin + # artifacts are produced by a separate workflow (see generate_test_data + # CLI with --stim-artifacts/--no-cudaq-artifacts) to keep the offline + # decoding output dir narrowly scoped. + write_cudaq_artifacts = False + + print( + "[generate_stim_data] Writing Stim detector samples " + f"to {output_dir} for basis={bases}, shots={num_samples}" + ) + for basis in bases: + generate_test_data( + distance=int(cfg.distance), + n_rounds=int(cfg.n_rounds), + basis=basis, + p_error=float(getattr(cfg.test, "p_error", 0.003)), + code_rotation=str(getattr(cfg.data, "code_rotation", "XV")), + noise_model_params=noise_model_params, + num_samples=num_samples, + output_dir=output_dir, + write_stim_artifacts=True, + write_cudaq_artifacts=write_cudaq_artifacts, + ) elif cfg.workflow.task == "data": DistributedManager.initialize() dist = DistributedManager() diff --git a/cookbook/predecoder.ipynb b/cookbook/predecoder.ipynb index 3351186..2c0e60b 100644 --- a/cookbook/predecoder.ipynb +++ b/cookbook/predecoder.ipynb @@ -1,1151 +1,1304 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# NVIDIA Ising pre-decoder tutorial\n", - "\n", - "On World Quantum Day April 14, 2026, the [NVIDIA Ising models](https://developer.nvidia.com/ising) were released: a family of open quantum models consisting of the [Ising pre-decoder model](https://github.com/NVIDIA/Ising-Decoding) and the [Ising calibration agent](https://github.com/NVIDIA/ising-calibration). 
This tutorial will focus on the former, by explaining how a pre-decoder works, showcasing data from the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes), and providing code examples so you can run the Ising decoder yourself. The full documentation and model parameters can be found on [GitHub](https://github.com/NVIDIA/Ising-Decoding) and [Hugging Face](https://huggingface.co/nvidia/ising_decoder_surface_code_1_fast). \n", - "\n", - "\n", - "## What is a pre-decoder?\n", - "\n", - "Fault tolerant quantum computing depends on scalable quantum error correction (QEC) workflows which encode logical qubits as many noisy physical qubits and perform logical operations using the logical qubits. This can only be achieved with continual measurement and decoding of syndrome data from the QPU to identify and correct errors that occur due to the noisy physical qubits.\n", - "\n", - "Currently, the decoding step (performed on a classical coprocessor) is likely to be the primary bottleneck of QEC, since errors generally need to be reconciled before each non-Clifford gate is applied. If a decoder throughput is too low it cannot process syndromes faster than they are measured from the QPU causing an exponential backlog to build up and the entire process to grind to a halt. Even if the throughput problem is solved, the speed of the decoder can determine the wall clock time of the QPU and present a practical limitation on quantum applications. Complicating this further, any optimizations or approximations made for decoder speed must maintain a low enough logical error rate for the QEC procedure to enable the end application. 
For more detail on these concepts, see the CUDA-Q Academic lessons found here in the [QEC 101 series](https://github.com/NVIDIA/cuda-q-academic/tree/main/qec101).\n", - "\n", - "Finding faster, more accurate, and more scalable decoders is at the heart of developing useful QEC approaches. On World Quantum Day 2026, NVIDIA released the Ising AI pre-decoder model that improves the performance of surface code error decoding compared to using the state-of-the-art algorithmic decoder PyMatching by itself. \n", - "\n", - "A pre-decoder is an additional step added before the default decoder (which is PyMatching in the case of the surface code.) The job of the pre-decoder is to quickly resolve easily identified local errors. The remaining reduced set of syndromes is sent to PyMatching which decodes these more challenging syndromes. Both decoder results are then combined to produce the final set of corrections. The workflow is depicted qualitatively below.\n", - "\n", - "\n", - "\n", - "\n", - "In the above image, the large blocks are the syndrome volumes where the height and depth are the spatial dimensions of the surface code and the width is the temporal dimension. Small colored blocks are flagged syndromes.\n", - "\n", - "At first it may seem that adding a step will slow things down, but in fact, this allows faster overall decoding compared to PyMatching alone. In many cases, the decoding time *and* the accuracy are improved using the predecoder approach! \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "In the figure above taken from the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes), with the model computed using a batch size of 1 on a GB300 with FP8 precision and PyMatching computed on a Grace Neoverse-V2 CPU, the grey and yellow lines correspond to PyMatching with increasing code distances. The blue and red lines are different combinations of the NVIDIA Ising pre-decoder and PyMatching. 
Notice, that many of the data points (with fixed distance) are below and left of the PyMatching points, meaning that we can produce lower logical error rates with less latency compared to PyMatching. Generally, PyMatching is more accurate at high distances, but the pre-decoder workflow can run much faster at the cost of a slightly elevated logical error rate (LER). Even in this case, the pre-decoder workflow provides additional freedom when balancing tradeoffs of accuracy and speed.\n", - "\n", - "### What you will learn in this tutorial\n", - "\n", - "The rest of this tutorial walks through the key workflows for using the Ising decoder:\n", - "\n", - "1. **Quick Start** \u2014 Run the full NVIDIA Ising pre-decoder pipeline end-to-end and compare to PyMatching.\n", - "2. **Training** \u2014 Train your own NVIDIA Ising pre-decoder from scratch with on-the-fly data generation and learn how to set key hyperparameters.\n", - "3. **Optimizing Inference** \u2014 Improve inference performance using ONNX export, TensorRT FP16, and FP8 quantization for real-time deployment.\n", - "\n", - "\n", - "The full documentation for the Ising decoder can be found [here](https://github.com/NVIDIA/Ising-Decoding), along with the [whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes) that explains the models in more detail." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NVIDIA Ising pre-decoder tutorial\n", + "\n", + "On World Quantum Day April 14, 2026, the [NVIDIA Ising models](https://developer.nvidia.com/ising) were released: a family of open quantum models consisting of the [Ising pre-decoder model](https://github.com/NVIDIA/Ising-Decoding) and the [Ising calibration agent](https://github.com/NVIDIA/ising-calibration). 
This tutorial will focus on the former by explaining how a pre-decoder works, showcasing data from the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes), and providing code examples so you can run the Ising decoder yourself. The full documentation and model parameters can be found on [GitHub](https://github.com/NVIDIA/Ising-Decoding) and [Hugging Face](https://huggingface.co/nvidia/ising_decoder_surface_code_1_fast). \n", + "\n", + "\n", + "## What is a pre-decoder?\n", + "\n", + "Fault-tolerant quantum computing depends on scalable quantum error correction (QEC) workflows that encode each logical qubit in many noisy physical qubits and perform logical operations on the logical qubits. This can only be achieved with continual measurement and decoding of syndrome data from the QPU to identify and correct errors that occur due to the noisy physical qubits.\n", + "\n", + "Currently, the decoding step (performed on a classical coprocessor) is likely to be the primary bottleneck of QEC, since errors generally need to be reconciled before each non-Clifford gate is applied. If a decoder's throughput is too low, it cannot process syndromes as fast as they are measured from the QPU, causing an exponentially growing backlog to build up and the entire process to grind to a halt. Even if the throughput problem is solved, the speed of the decoder can determine the wall-clock time of the QPU and present a practical limitation on quantum applications. Complicating this further, any optimizations or approximations made for decoder speed must maintain a low enough logical error rate for the QEC procedure to enable the end application. 
For more detail on these concepts, see the CUDA-Q Academic lessons in the [QEC 101 series](https://github.com/NVIDIA/cuda-q-academic/tree/main/qec101).\n", + "\n", + "Finding faster, more accurate, and more scalable decoders is at the heart of developing useful QEC approaches. On World Quantum Day 2026, NVIDIA released the Ising AI pre-decoder model that improves the performance of surface code error decoding compared to using the state-of-the-art algorithmic decoder PyMatching by itself. \n", + "\n", + "A pre-decoder is an additional step added before the default decoder (PyMatching, in the case of the surface code). The job of the pre-decoder is to quickly resolve easily identified local errors. The remaining reduced set of syndromes is sent to PyMatching, which decodes these more challenging syndromes. Both decoder results are then combined to produce the final set of corrections. The workflow is depicted qualitatively below.\n", + "\n", + "\n", + "\n", + "\n", + "In the above image, the large blocks are the syndrome volumes where the height and depth are the spatial dimensions of the surface code and the width is the temporal dimension. Small colored blocks are flagged syndromes.\n", + "\n", + "At first it may seem that adding a step will slow things down, but in fact, this allows faster overall decoding compared to PyMatching alone. In many cases, the decoding time *and* the accuracy are improved using the pre-decoder approach! \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "The figure above is taken from the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes). The model was run with a batch size of 1 on a GB300 with FP8 precision, and PyMatching was run on a Grace Neoverse-V2 CPU. The grey and yellow lines correspond to PyMatching with increasing code distances. The blue and red lines are different combinations of the NVIDIA Ising pre-decoder and PyMatching. 
Notice that many of the data points (at fixed distance) lie below and to the left of the PyMatching points, meaning that the pre-decoder workflow can achieve lower logical error rates with less latency than PyMatching alone. Generally, PyMatching is more accurate at high distances, but the pre-decoder workflow can run much faster at the cost of a slightly elevated logical error rate (LER). Even in this case, the pre-decoder workflow provides additional freedom when balancing tradeoffs of accuracy and speed.\n", + "\n", + "### What you will learn in this tutorial\n", + "\n", + "The rest of this tutorial walks through the key workflows for using the Ising decoder:\n", + "\n", + "1. **Quick Start**: Run the full NVIDIA Ising pre-decoder pipeline end-to-end and compare to PyMatching.\n", + "2. **Training**: Train your own NVIDIA Ising pre-decoder from scratch with on-the-fly data generation and learn how to set key hyperparameters.\n", + "3. **Optimizing Inference**: Improve inference performance using ONNX export, TensorRT FP16, and FP8 quantization for real-time deployment.\n", + "\n", + "\n", + "The full documentation for the Ising decoder can be found [here](https://github.com/NVIDIA/Ising-Decoding), along with the [whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes) that explains the models in more detail." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "This tutorial can be found in the `cookbook/` directory of the `Ising-Decoding` repository.\n", + "\n", + "**Prerequisites:**\n", + "- **NVIDIA GPU** with CUDA drivers installed (`nvidia-smi` must be on your PATH)\n", + "- **Python 3.11, 3.12, or 3.13**\n", + "\n", + "The cell below will:\n", + "1. Locate the repository root and add the predecoder source code to the Python path\n", + "2. Detect the CUDA version from your GPU driver and install the matching PyTorch build\n", + "3. 
Install all training dependencies (which are a superset of inference dependencies), including cuQuantum for GPU-accelerated data generation, ONNX tooling for the inference optimization sections, and all other required packages\n", + "4. Authenticate with Hugging Face to download pre-trained model weights (the model is access-controlled). **Paste your token into the `HF_TOKEN` variable** in the cell below (get one at https://huggingface.co/settings/tokens)\n", + "\n", + "**TensorRT** (used in the FP16 engine and FP8 quantization sections) is installed automatically with the correct CUDA version.\n", + "\n", + "**Note:** If you have already run the setup cell once and all packages are installed, you can skip it on subsequent runs — only the imports cell that follows is needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess, sys, os, re, shutil\n", + "\n", + "NOTEBOOK_DIR = os.path.abspath('')\n", + "PREDECODER_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR,'..'))\n", + "sys.path.insert(0, os.path.join(PREDECODER_ROOT, 'code'))\n", + "\n", + "print(f'PREDECODER_ROOT: {PREDECODER_ROOT}')\n", + "assert os.path.isdir(os.path.join(PREDECODER_ROOT, 'code')), (\n", + " f\"Cannot find predecoder source code at {PREDECODER_ROOT}/code. 
\"\n", + " f\"This notebook must live at /tutorials/predecoder_tutorial/.\"\n", + ")\n", + "\n", + "assert shutil.which('nvidia-smi'), 'nvidia-smi not found — this tutorial requires an NVIDIA GPU.'\n", + "\n", + "nvsmi_output = subprocess.check_output(['nvidia-smi'], text=True)\n", + "cuda_match = re.search(r'CUDA Version:\\s+([\\d.]+)', nvsmi_output)\n", + "assert cuda_match, 'Could not detect CUDA version from nvidia-smi output.'\n", + "cuda_ver = cuda_match.group(1)\n", + "cuda_major = cuda_ver.split('.')[0]\n", + "print(f'CUDA {cuda_ver} detected (major: {cuda_major})')\n", + "\n", + "def _pip(*args):\n", + " subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *args])\n", + "\n", + "gpu_info = subprocess.check_output(\n", + " ['nvidia-smi', '--query-gpu=name,compute_cap', '--format=csv,noheader'], text=True\n", + ").strip().split('\\n')[0]\n", + "gpu_name, gpu_cc = [s.strip() for s in gpu_info.split(',')]\n", + "gpu_cc_major = int(gpu_cc.split('.')[0])\n", + "gpu_cc_minor = int(gpu_cc.split('.')[1])\n", + "gpu_supports_fp8 = (gpu_cc_major > 8) or (gpu_cc_major == 8 and gpu_cc_minor >= 9)\n", + "print(f'GPU: {gpu_name} (compute capability {gpu_cc}, FP8: {\"yes\" if gpu_supports_fp8 else \"no\"})')\n", + "\n", + "print('Environment OK.')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_pip('--upgrade', 'pip', 'setuptools<82', 'wheel')\n", + "\n", + "torch_cuda_tag = {'12': 'cu128', '13': 'cu130'}[cuda_major]\n", + "print(f'Installing PyTorch (wheel index: {torch_cuda_tag})...')\n", + "_pip('torch', '--index-url', f'https://download.pytorch.org/whl/{torch_cuda_tag}',\n", + " '--extra-index-url', 'https://pypi.org/simple')\n", + "\n", + "import torch\n", + "assert torch.cuda.is_available(), 'PyTorch installed but CUDA is not available.'\n", + "print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}, '\n", + " f'GPU: {torch.cuda.get_device_name(0)}')" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_req = os.path.join(PREDECODER_ROOT, 'code', f'requirements_public_train-cu{cuda_major}.txt')\n", + "assert os.path.exists(train_req), (\n", + " f\"No training requirements for CUDA {cuda_major}: {train_req}\\n\"\n", + " f\"Available: requirements_public_train-cu12.txt, requirements_public_train-cu13.txt\"\n", + ")\n", + "print(f'Installing predecoder dependencies from: {os.path.basename(train_req)}')\n", + "print(f' (this includes: stim, pymatching, cuquantum, onnx tooling, and more)')\n", + "_pip('-r', train_req, 'huggingface_hub')\n", + "\n", + "trt_pkg = f'tensorrt-cu{cuda_major}'\n", + "print(f'Installing TensorRT ({trt_pkg}) for ONNX optimization sections...')\n", + "_pip(trt_pkg)\n", + "\n", + "print('All dependencies installed.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Paste your Hugging Face token below (get one at https://huggingface.co/settings/tokens).\n", + "HF_TOKEN = \"\" # <-- paste your token here, e.g. \"hf_abc123...\"\n", + "\n", + "import huggingface_hub\n", + "if HF_TOKEN:\n", + " huggingface_hub.login(token=HF_TOKEN)\n", + " print('Logged in to Hugging Face.')\n", + "elif huggingface_hub.get_token():\n", + " print('Using previously cached Hugging Face token.')\n", + "else:\n", + " raise RuntimeError(\n", + " 'No Hugging Face token found. 
Paste your token into the HF_TOKEN variable above.\\n'\n", + " 'Get one at: https://huggingface.co/settings/tokens'\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import pymatching\n", + "import time\n", + "from types import SimpleNamespace\n", + "\n", + "from qec.noise_model import NoiseModel\n", + "from qec.surface_code.memory_circuit import MemoryCircuit\n", + "from model.registry import get_model_spec\n", + "from model.factory import ModelFactory\n", + "from evaluation.logical_error_rate import PreDecoderMemoryEvalModule, _build_stab_maps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Quick Start\n", + "\n", + "The fastest path to running the pre-decoder is shown below. This code cell generates $d$=9 surface code syndrome data directly using the repo's `MemoryCircuit` and STIM, loads the pre-trained model from [Hugging Face](https://huggingface.co/nvidia/ising_decoder_surface_code_1_fast), and runs the full decode pipeline. Note, we are using the model optimized for speed, but a second pretrained model targeting accuracy is available [here](https://huggingface.co/nvidia/ising_decoder_surface_code_1_accurate).\n", + "\n", + "The workflow has four primary steps:\n", + "\n", + "- **Syndrome data generation** — A noisy surface code circuit is simulated using STIM. Each shot produces a set of detector bits (syndrome differences across rounds) and an observable bit (whether a logical error occurred). This is what a real QPU would send to the classical decoder.\n", + "- **Pre-trained model** — A 3D convolutional neural network trained to recognize local, easily identifiable error patterns directly from syndrome data. 
It processes thousands of shots in parallel on GPU and returns a set of corrections along with a reduced residual syndrome.\n", + "- **Pre-decoder inference pipeline** — The model is wrapped with preprocessing (mapping flat syndrome bits onto a spatial grid) and postprocessing (applying predicted corrections to reduce the syndrome). The output is a partial logical correction and a much sparser syndrome for PyMatching to handle.\n", + "- **PyMatching on residuals** — The industry-standard MWPM decoder finishes the job on the reduced syndrome. Its result is XOR'd with the pre-decoder's partial correction to produce the final logical error prediction.\n", + "\n", + "**Data format:** The pre-decoder takes a flat array of detector bits, obtained from a QPU or via simulation in the case of this example. Internally, these bits are reshaped into a 4-channel 3D grid of shape `(batch, 4, rounds, distance, distance)` before being passed to the CNN. The four channels encode X and Z syndrome values and the surface code boundary geometry. For full details on the data format and model architecture, see [the paper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes).\n", + "\n", + "\n", + "Run the code below to see how the performance of PyMatching, the Ising decoder, and the combination of both compare.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── 1. 
Generate syndrome data ────────────────────────────────────────────────\n", + "# These are the key QEC parameters for a surface code memory experiment.\n", + "# distance=9 means a 9x9 grid of data qubits (81 qubits + 80 ancilla).\n", + "# n_rounds=9 is the number of syndrome measurement cycles (standard: match distance).\n", + "# basis=\"X\" means we protect the X-type logical qubit (alternatively \"Z\").\n", + "# code_rotation=\"XV\" is the internal label for surface code orientation O1.\n", + "distance, n_rounds, basis, code_rotation, num_samples = 9, 9, \"X\", \"XV\", 5000\n", + "\n", + "# Build a 25-parameter depolarizing noise model from a single error rate p=0.005.\n", + "# from_single_p() distributes p across all error channels using standard ratios:\n", + "# prep/meas errors: 2p/3 (~0.0033)\n", + "# idle during CNOT: p/3 (~0.0017)\n", + "# idle during SPAM: 2p/3 (~0.0033)\n", + "# two-qubit CNOT: p/15 (~0.00033) per Pauli pair\n", + "\n", + "noise = NoiseModel.from_single_p(0.005)\n", + "p_max = noise.get_max_probability() # largest value across all 25 params (~2p/3)\n", + "\n", + "# Build the full noisy surface code circuit using Stim.\n", + "# The scalar error args (idle_error, sqgate_error, etc.) 
are placeholder values\n", + "# required by the constructor; the actual error channels are driven by noise_model.\n", + "# add_boundary_detectors=True adds virtual detectors at the start/end of the circuit,\n", + "# which improves decoding accuracy at the temporal boundaries of the experiment.\n", + "\n", + "circ = MemoryCircuit(\n", + " distance=distance, n_rounds=n_rounds, basis=basis, code_rotation=code_rotation,\n", + " idle_error=p_max, sqgate_error=p_max, tqgate_error=p_max,\n", + " spam_error=(2/3)*p_max, noise_model=noise, add_boundary_detectors=True,\n", + ")\n", + "circ.set_error_rates()\n", + "\n", + "# Sample the circuit: each row of `shots` is one full experiment run (all measurements).\n", + "shots = circ.stim_circuit.compile_sampler().sample(num_samples)\n", + "\n", + "# Convert raw measurement bits into detector outcomes and logical observable flips.\n", + "# Detectors are syndrome parities across two consecutive rounds: they fire (=1)\n", + "# when consecutive stabilizer measurements disagree, signaling an error.\n", + "# Observables are the true logical flip bits, i.e. what the decoder tries to predict.\n", + "\n", + "conv = circ.stim_circuit.compile_m2d_converter()\n", + "dets_obs = conv.convert(measurements=shots, append_observables=True)\n", + "num_obs = circ.stim_circuit.num_observables\n", + "detectors = dets_obs[:, :-num_obs].astype(np.uint8) # shape: (shots, n_detectors)\n", + "observables = dets_obs[:, -num_obs:].astype(np.uint8) # shape: (shots, n_observables)\n", + "\n", + "# Build the Detector Error Model (DEM): a graph where nodes are detectors and\n", + "# edges are independent error mechanisms (with probabilities) that flip them.\n", + "# PyMatching uses this graph to find minimum-weight correction paths.\n", + "\n", + "dem = circ.stim_circuit.detector_error_model(\n", + " decompose_errors=True, # split multi-detector errors into pairs for graph compatibility\n", + " approximate_disjoint_errors=True, # treat disjoint Pauli-channel cases as independent errors\n", + 
")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the syndrome data generated, we now have `detectors` (the parity-check outcomes the decoder will work with) and `observables` (the ground-truth logical errors we want to predict). We also built a Detector Error Model (`dem`) that PyMatching will use later.\n", + "\n", + "Next, we load the pre-trained Ising pre-decoder model from Hugging Face. Model 1 has a receptive field of R=9, which matches our $d$=9 experiment exactly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── 2. Load pre-trained model ────────────────────────────────────────────────\n", + "# Model 1 has receptive field R=9, matching our d=9 experiment.\n", + "# The receptive field sets the maximum space-time window the model can see.\n", + "\n", + "model_id = 1\n", + "spec = get_model_spec(model_id)\n", + "\n", + "# Build the model architecture from the registry spec.\n", + "# num_filters/kernel_size define the 3D convolutional layers that process the\n", + "# syndrome tensor (space x space x time).\n", + "\n", + "model_cfg = SimpleNamespace(\n", + " code=\"surface\", distance=distance, n_rounds=n_rounds,\n", + " model=SimpleNamespace(\n", + " version=\"predecoder_memory_v1\",\n", + " num_filters=list(spec.num_filters), kernel_size=list(spec.kernel_size),\n", + " dropout_p=0.0, activation=\"gelu\", input_channels=4, out_channels=4,\n", + " ),\n", + ")\n", + "model = ModelFactory.create_model(model_cfg)\n", + "\n", + "# Download weights from HuggingFace (cached after first run).\n", + "import logging\n", + "logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n", + "from huggingface_hub import hf_hub_download\n", + "from safetensors.torch import load_file\n", + "\n", + "model_path = hf_hub_download(\n", + " repo_id=\"nvidia/ising_decoder_surface_code_1_fast\",\n", + " filename=\"ising_decoder_surface_code_1_fast_r9_v1.0.77_fp16.safetensors\",\n", 
+ " token=HF_TOKEN or True,\n", + ")\n", + "sd = load_file(model_path, device=\"cpu\")\n", + "sd = {(k[7:] if k.startswith(\"module.\") else k): v.float() for k, v in sd.items()}\n", + "model.load_state_dict(sd)\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model printed above is just the neural network — it knows how to process data, but it expects a very specific input format (a 4-channel 3D tensor shaped to the surface code lattice), not the flat array of detector bits we generated earlier. The next cell wraps the model inside `PreDecoderMemoryEvalModule`, which handles the format conversion automatically: flat detector bits go in, and a partial logical correction plus a simplified syndrome for PyMatching come out." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── 3. Build inference pipeline ──────────────────────────────────────────────\n", + "# Move model to GPU if available; CPU fallback works but will be slow.\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model = model.to(device)\n", + "\n", + "# Build stabilizer maps: spatial lookup tables that tell the model which detectors\n", + "# correspond to which physical qubit locations on the surface code lattice.\n", + "maps = _build_stab_maps(distance, code_rotation)\n", + "\n", + "# PreDecoderMemoryEvalModule wraps the model with the pre/post-processing needed\n", + "# for inference: it accepts raw detector arrays and returns:\n", + "# out[:, 0] — predicted logical correction (0 or 1 per shot)\n", + "# out[:, 1:] — residual detectors after pre-decoding (passed to PyMatching)\n", + "eval_cfg = SimpleNamespace(\n", + " distance=distance,\n", + " test=SimpleNamespace(meas_basis_test=basis, th_data=0.0, th_syn=0.0,\n", + " sampling_mode=\"threshold\", temperature=1.0,\n", + " temperature_data=None, temperature_syn=None, n_rounds=n_rounds),\n", 
+ " data=SimpleNamespace(code_rotation=code_rotation),\n", + " enable_fp16=False,\n", + ")\n", + "pipeline = PreDecoderMemoryEvalModule(model, eval_cfg, maps, device).to(device)\n", + "pipeline.eval();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `pipeline` object now encapsulates the entire pre-decoder workflow in a single callable: give it a batch of raw detector bits and it returns a two-part result — a predicted logical correction (column 0) and a set of residual syndromes (remaining columns) that have been simplified by the neural network. Those residual syndromes are what we will pass to PyMatching as a second decoding step." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Everything is ready. We now compare four decoding strategies side by side:\n", + "\n", + "1. **Undecoded** — no correction at all (the raw logical error rate)\n", + "2. **PyMatching only** — the standard MWPM decoder applied to the full syndrome\n", + "3. **Pre-decoder alone** — just the neural network's partial correction, no follow-up\n", + "4. **Pre-decoder + PyMatching** — the full pipeline: the model reduces the syndrome, then PyMatching finishes the job on the sparser residual. The final correction is the XOR of both predictions.\n", + "\n", + "A single warmup pass runs first so that `torch.compile` JIT-compiles the model for this batch shape before we start timing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ── 4. Run all four decoding strategies ─────────────────────────────────────\n", + "# Build the PyMatching decoder from the DEM graph — this is the classical baseline.\n", + "\n", + "matcher = pymatching.Matching.from_detector_error_model(dem)\n", + "obs_flat = observables.flatten()\n", + "\n", + "# Warmup pass: torch.compile JIT-compiles the model on first call for this batch\n", + "# shape/dtype. 
We discard these timings — subsequent calls reflect true latency.\n", + "\n", + "with torch.no_grad():\n", + " _ = pipeline(torch.from_numpy(detectors).to(torch.uint8).to(device))\n", + "_ = matcher.decode_batch(detectors)\n", + "\n", + "\n", + "\n", + "# Strategy 1 — Undecoded (baseline: what LER looks like with no correction at all)\n", + "\n", + "undecoded_ler = obs_flat.mean()\n", + "\n", + "\n", + "\n", + "# Strategy 2 — PyMatching only (standard MWPM decoder, the classical state of the art)\n", + "\n", + "t0 = time.perf_counter()\n", + "pm_pred = np.asarray(matcher.decode_batch(detectors), dtype=np.uint8).reshape(-1, num_obs).flatten()\n", + "t_pm = time.perf_counter() - t0\n", + "pm_ler = (pm_pred != obs_flat).mean()\n", + "\n", + "\n", + "\n", + "# Strategy 3 — Pre-decoder alone (neural network output, no follow-up matching)\n", + "# This shows what the model achieves without a global decoder — useful for\n", + "# understanding the pre-decoder's contribution in isolation.\n", + "\n", + "dets_t = torch.from_numpy(detectors).to(torch.uint8).to(device)\n", + "t0 = time.perf_counter()\n", + "with torch.no_grad():\n", + " out = pipeline(dets_t)\n", + "t_pd = time.perf_counter() - t0\n", + "pre_L = out[:, 0].cpu().numpy() # model's logical correction prediction\n", + "res_dets = out[:, 1:].cpu().numpy().astype(np.uint8) # residual syndromes for PyMatching\n", + "pd_ler = (pre_L.astype(int) != obs_flat.astype(int)).mean()\n", + "\n", + "\n", + "\n", + "# Strategy 4 — Pre-decoder + PyMatching (the full pipeline)\n", + "# The model first simplifies the syndrome pattern (reducing syndrome density),\n", + "# then PyMatching decodes the cleaner residual syndromes. 
The final logical\n", + "# correction is XOR of the two predictions (mod 2 addition of Pauli corrections).\n", + "\n", + "t0 = time.perf_counter()\n", + "pm_res = np.asarray(matcher.decode_batch(res_dets), dtype=np.uint8).reshape(-1, num_obs).flatten()\n", + "t_pm_res = time.perf_counter() - t0\n", + "combo_L = (pre_L.astype(int) + pm_res.astype(int)) % 2 # combine both corrections\n", + "combo_ler = (combo_L != obs_flat).mean()\n", + "t_combo = t_pd + t_pm_res # total wall time for the combined pipeline\n", + "\n", + "\n", + "\n", + "# ── 5. Results ───────────────────────────────────────────────────────────────\n", + "print(f\" {'Strategy':<30} {'LER':>8} {'Time (ms)':>10} {'µs/shot':>8}\")\n", + "print(f\" {'-'*30} {'-'*8} {'-'*10} {'-'*8}\")\n", + "print(f\" {'Undecoded':<30} {undecoded_ler:>8.5f} {'—':>10} {'—':>8}\")\n", + "print(f\" {'PyMatching only':<30} {pm_ler:>8.5f} {t_pm*1e3:>9.1f}ms {t_pm*1e6/num_samples:>7.1f}µs\")\n", + "print(f\" {'Pre-decoder alone':<30} {pd_ler:>8.5f} {t_pd*1e3:>9.1f}ms {t_pd*1e6/num_samples:>7.1f}µs\")\n", + "print(f\" {'Pre-decoder + PyMatching':<30} {combo_ler:>8.5f} {t_combo*1e3:>9.1f}ms {t_combo*1e6/num_samples:>7.1f}µs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the LER of the NVIDIA Ising pre-decoder + PyMatching combination is nearly identical to PyMatching alone in this case, but the overall and per-shot decoding time is much smaller for the pre-decoder + PyMatching approach.\n", + "\n", + "An important caveat: the workflow above shows decode time for a batch of 5,000 syndromes. If you run the same code with a batch size of 1, as you would in a real-time QEC application, the pre-decoder + PyMatching approach is much slower than PyMatching alone, due to the significant overhead of launching a GPU kernel. Batched decoding amortizes that overhead across all 5,000 syndrome rounds at once. 
\n", + "\n", + "This behavior is expected at small scale and is consistent with the results from the [paper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes). Compare the d=9 point for the light-blue line (pre-decoder + PyMatching) and the grey line (PyMatching only).\n", + "\n", + "\n", + "\n", + "To reach a regime where the kernel-launch overhead is worth the tradeoff, we need a combination of:\n", + "\n", + "1. A higher-distance surface code\n", + "2. A more complex noise model (i.e., harder-to-decode syndrome patterns)\n", + "3. Optimizations for low-latency model inference\n", + "\n", + "In the next section, you will learn how to train your own model on higher-distance surface codes with much more complex noise patterns, using on-the-fly GPU-accelerated syndrome generation with cuStabilizer. After that, you will learn how to optimize model inference for the best possible real-time deployment performance on your specific GPU. Together, these sections will equip you to study regimes where the NVIDIA Ising pre-decoder + PyMatching approach significantly outperforms PyMatching alone." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Training your own Pre-decoder Model\n", + "\n", + "The pre-trained models ship with this repo, but there are several reasons you might want to train your own:\n", + "\n", + "- **Custom noise model**: The pre-trained models were trained on a specific noise profile. Real quantum hardware has different error characteristics and you may want to see how the model handles a custom noise profile.\n", + "- **Different code distances**: Pre-trained models are available for specific distances. Training lets you target any distance within the model's receptive field.\n", + "- **Different noise regimes**: Physical error rates vary by platform. 
A model trained at p ≈ 0.005 may not be optimal for p ≈ 0.001 or p ≈ 0.006.\n", + "- **Architecture exploration**: The 5 public models vary in depth and width. Training lets you experiment with the tradeoff between model size (latency) and correction quality (LER improvement).\n", + "\n", + "The following sections will walk you through the process for training your own model." + ] + }, + { + "cell_type": "markdown", + "id": "inserted", + "metadata": {}, + "source": [ + "### Model Variants\n", + "\n", + "Five model variants are available, trading off model size against correction quality. The **receptive field (RF)** determines the spatial and temporal window the model can see: a model with RF = 9 requires `distance ≥ 9` and `n_rounds ≥ 9`. Larger RF models can correct more complex, spatially spread errors but cost more compute per shot.\n", + "\n", + "Run the code below to see the model specs for the 5 included model architectures. The length of the `num_filters` list is the number of layers in the model. Generally, more filters allow the model to detect more complex error patterns. The kernel size prescribes the spatial and temporal window over which features are computed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "inserted", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "print(f\"{'Model':<8} {'num_filters':<30} {'kernel_size':<24} {'RF size':<10} {'num_params':>12}\")\n", + "print(\"-\" * 86)\n", + "for mid in range(1, 6):\n", + " spec = get_model_spec(mid)\n", + " cfg = SimpleNamespace(\n", + " code=\"surface\", distance=9, n_rounds=9,\n", + " model=SimpleNamespace(\n", + " version=\"predecoder_memory_v1\",\n", + " num_filters=list(spec.num_filters),\n", + " kernel_size=list(spec.kernel_size),\n", + " dropout_p=0.0, activation=\"gelu\",\n", + " input_channels=4, out_channels=4,\n", + " ),\n", + " )\n", + " m = ModelFactory.create_model(cfg)\n", + " nparams = sum(p.numel() for p in m.parameters())\n", + " print(f\" {mid:<6} {str(spec.num_filters):<30} {str(spec.kernel_size):<24} {spec.receptive_field:<10} {nparams:>12,}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### On-the-Fly Training Data Generation\n", + "\n", + "A key benefit of the Ising pre-decoder training pipeline is on-the-fly generation of syndrome data during each training epoch. The simulation uses a **25-parameter circuit-level noise model** that independently controls:\n", + "\n", + "- **State preparation errors** (2 params): `p_prep_X`, `p_prep_Z`\n", + "- **Measurement errors** (2 params): `p_meas_X`, `p_meas_Z`\n", + "- **Idle errors during CNOT layers** (3 params): `p_idle_cnot_X/Y/Z`\n", + "- **Idle errors during state-prep/measurement windows** (3 params): `p_idle_spam_X/Y/Z`\n", + "- **Two-qubit CNOT gate errors** (15 params): all Pauli pairs `IX, IY, IZ, XI, XX, ... 
ZZ` (excluding `II`)\n", + "\n", + "This noise model is defined in `conf/config_public.yaml` and can be customized to match real hardware calibration data.\n", + "\n", + "For training, the simulation runs inside `QCDataGeneratorTorch`, which uses GPU-accelerated detector error model (DEM) sampling built on the [NVIDIA cuStabilizer library](https://docs.nvidia.com/cuda/cuquantum/latest/custabilizer/index.html) within the [NVIDIA cuQuantum library](https://docs.nvidia.com/cuda/cuquantum/latest/index.html) to generate fresh batches of `(trainX, trainY)` tensors on every training step.\n", + "\n", + "GPU-accelerated sampling requires the `cuquantum-python` package (installed automatically by the setup cell above, matched to your CUDA version). Training uses cuStabilizer for on-the-fly data generation and will fail if cuquantum is not installed. Inference uses Stim instead and does not require cuquantum.\n", + "\n", + "For large-scale training runs, cuStabilizer provides massive speedups in data generation. \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "For example, in the figure above, cuStabilizer provides a 1060x speedup over Stim for sampling 1 million shots from a distance-30 surface code on an NVIDIA DGX B200 GPU.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training Configuration\n", + "\n", + "Training is controlled by a single user-facing file: `conf/config_public.yaml`. It exposes the QEC problem variables.\n", + "\n", + "Here are the main settings for your `config_public.yaml`:\n", + "\n", + "| Parameter | What it controls |\n", + "|---|---|\n", + "| `model_id` | Which model architecture to use (1–5). Each has a fixed **receptive field R**, the maximum distance/rounds it can handle. 
Models 1 & 2: R=9, Model 3: R=17, Models 4 & 5: R=13 |\n", + "| `distance` | Surface code distance for **evaluation** (Training always runs at distance=R) |\n", + "| `n_rounds` | Number of syndrome measurement rounds for **evaluation** |\n", + "| `workflow.task` | `train` to train a new model, `inference` to evaluate a saved model |\n", + "| `data.code_rotation` | Surface code orientation: `O1`, `O2`, `O3`, or `O4` (different logical operator placements) |\n", + "| `data.noise_model` | The 25-parameter circuit-level noise model: state prep errors (`p_prep_X/Z`), measurement errors (`p_meas_X/Z`), idle errors during CNOT layers (`p_idle_cnot_X/Y/Z`) and SPAM windows (`p_idle_spam_X/Y/Z`), and all 15 two-qubit CNOT Pauli error rates (`p_cnot_AB`) |\n", + "\n", + "\n", + "The cell below prints the full default `config_public.yaml` file.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# === conf/config_public.yaml (what users edit) ===\n", + "config_path = os.path.join(PREDECODER_ROOT, \"conf\", \"config_public.yaml\")\n", + "with open(config_path) as f:\n", + " print(f.read())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`config_validator.py` holds the hidden training hyperparameters: optimizer settings, batch schedule, EMA, early stopping, and more. These are intentionally not in the public config to prevent misconfiguration.\n", + "\n", + "You do not need to touch this file for a standard run. However, some parameters can be overridden at runtime via **environment variables**, without editing any file:\n", + "\n", + "| Environment Variable | Config Parameter | Default | Description |\n", + "|---|---|---|---|\n", + "| `PREDECODER_TRAIN_SAMPLES` | `train.num_samples` | 67,108,864 | Total training shots generated per epoch (divided into batches based on batch size and GPU count). Default is the production base for 8 GPUs, auto-scaled proportionally to GPU count (e.g. 
8,388,608 on 1 GPU). Setting the environment variable bypasses this scaling |\n", + "| `PREDECODER_TRAIN_EPOCHS` | `train.epochs` | 100 | Total number of training epochs |\n", + "| `PREDECODER_VAL_SAMPLES` | `val.num_samples` | 65,536 | Total validation shots evaluated at the end of each epoch. Split evenly across GPUs |\n", + "| `PREDECODER_TEST_SAMPLES` | `test.num_samples` | 262,144 | Total shots for final LER/SDR evaluation after training. Split evenly across GPUs |\n", + "| `PREDECODER_LR_MILESTONES` | `lr_scheduler.milestones` | `0.25,0.5,1.0` | Comma-separated fractions of total training steps. At each milestone the learning rate is multiplied by gamma=0.7. With the defaults, LR drops at 25%, 50%, and 100% of training |\n", + "| `PREDECODER_SKIP_NOISE_UPSCALING` | *(env var only)* | `0` | When `0` (default), training noise is automatically scaled up toward the surface code threshold (~0.006) to improve generalization. Set `1` to train with your exact noise model values |\n", + "| `PREDECODER_TORCH_COMPILE` | `torch_compile` | `1` | Enables `torch.compile` for the model. Set `0` to disable (useful if training segfaults on startup) |\n", + "| `PREDECODER_TORCH_COMPILE_MODE` | `torch_compile_mode` | `default` | `torch.compile` compilation mode. 
Options: `default`, `reduce-overhead`, `max-autotune` |\n", + "| `PREDECODER_BASE_OUTPUT_DIR` | *(env var only)* | `outputs/` | Root directory for all run output: checkpoints, TensorBoard logs, Hydra config snapshots |\n", + "\n", + "Parameters **not** accessible via environment variable, such as optimizer internals (`weight_decay`, `beta2`), EMA decay, batch schedule shape, and model architecture, can only be changed by editing `apply_public_defaults_and_model()` in `config_validator.py` directly.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Running Training from the CLI\n", + "\n", + "All training is launched through a single script that reads `conf/config_public.yaml` and applies the hidden defaults from `config_validator.py` automatically.\n", + "\n", + "#### Default training run\n", + "\n", + "All CLI commands below assume you are in the repository root directory (`Ising-Decoding/`).\n", + "\n", + "```bash\n", + "bash code/scripts/local_run.sh\n", + "```\n", + "\n", + "This trains Model 1 (R=9) at d=9, n_rounds=9 using the default noise model (p≈0.003), with all hidden defaults: Lion optimizer, LR warmup/decay, EMA, and noise upscaling. 
Output is written to `outputs/test1/`.\n", + "\n", + "---\n", + "\n", + "#### Full reference: everything you can change at the command line\n", + "\n", + "**Shell variables** — control the run itself (prepend before `bash code/scripts/local_run.sh`):\n", + "\n", + "| Variable | What it does | Example |\n", + "|---|---|---|\n", + "| `EXPERIMENT_NAME` | Names the output folder under `outputs/` — change this when you change config so runs don't mix | `EXPERIMENT_NAME=my_d9_run` |\n", + "| `WORKFLOW` | `train` to train, `inference` to evaluate a saved model | `WORKFLOW=inference` |\n", + "| `GPUS` | Number of GPUs (uses PyTorch DDP; each GPU gets its own data generator) | `GPUS=4` |\n", + "| `CUDA_VISIBLE_DEVICES` | Which physical GPUs to use | `CUDA_VISIBLE_DEVICES=0,1` |\n", + "| `FRESH_START` | Set `1` to ignore any existing checkpoint and restart from scratch | `FRESH_START=1` |\n", + "| `CONFIG_NAME` | Which config file to load from `conf/` (omit `.yaml`) | `CONFIG_NAME=config_local_test` |\n", + "\n", + "**Environment variables** — override `config_validator.py` hidden defaults at runtime (see the table in the previous section for defaults and detailed descriptions). 
Prepend these before `bash code/scripts/local_run.sh`:\n", + "\n", + "**Hydra overrides** — override any `config_public.yaml` value inline without editing the file (appended after `local_run.sh`):\n", + "\n", + "| Override | What it does | Example |\n", + "|---|---|---|\n", + "| `model_id=N` | Select model architecture 1–5 | `model_id=4` |\n", + "| `distance=N` | Evaluation distance (training always uses R) | `distance=13` |\n", + "| `n_rounds=N` | Evaluation rounds | `n_rounds=13` |\n", + "| `data.code_rotation=ON` | Code orientation O1–O4 | `data.code_rotation=O2` |\n", + "| `optimizer.lr=X` | Override the model-default learning rate | `optimizer.lr=0.0002` |\n", + "| `data.skip_noise_upscaling=true` | Disable noise auto-scaling inline | `data.skip_noise_upscaling=true` |\n", + "| `data.noise_model.p_prep_X=X` | Override any individual noise parameter | `data.noise_model.p_prep_X=0.004` |\n", + "\n", + "---\n", + "\n", + "Below are a few example runs you might do and the associated variables to set.\n", + "\n", + "#### Example 1: quick smoke test on 1 GPU\n", + "\n", + "```bash\n", + "# Separate output folder, small dataset, 3 epochs, fewer eval shots\n", + "EXPERIMENT_NAME=smoke_test \\\n", + "FRESH_START=1 \\\n", + "PREDECODER_TRAIN_SAMPLES=65536 \\\n", + "PREDECODER_TRAIN_EPOCHS=3 \\\n", + "PREDECODER_TEST_SAMPLES=8192 \\\n", + "GPUS=1 \\\n", + "bash code/scripts/local_run.sh\n", + "```\n", + "\n", + "#### Example 2: full d=13 run on 8 GPUs with hardware noise\n", + "\n", + "```bash\n", + "# 8 GPUs, Model 4 (R=13), hardware-calibrated noise (no auto-scaling)\n", + "EXPERIMENT_NAME=d13_hardware_noise \\\n", + "GPUS=8 \\\n", + "PREDECODER_SKIP_NOISE_UPSCALING=1 \\\n", + "bash code/scripts/local_run.sh \\\n", + " model_id=4 \\\n", + " distance=13 \\\n", + " n_rounds=13 \\\n", + " data.code_rotation=O1 \\\n", + " data.noise_model.p_prep_X=0.001 \\\n", + " data.noise_model.p_prep_Z=0.001 \\\n", + " data.noise_model.p_meas_X=0.003 \\\n", + " 
data.noise_model.p_meas_Z=0.003\n", + "```\n", + "\n", + "#### Example 3: inference on a saved model\n", + "\n", + "```bash\n", + "# Must use the same EXPERIMENT_NAME as the training run\n", + "EXPERIMENT_NAME=d13_hardware_noise \\\n", + "WORKFLOW=inference \\\n", + "bash code/scripts/local_run.sh \\\n", + " model_id=4 \\\n", + " distance=13 \\\n", + " n_rounds=13\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Optimizing Model Inference\n", + "\n", + "When using the predecoder in a realtime setting, it becomes crucial to minimize the model's inference time so that the decoder has sufficient throughput and the reaction time is as short as possible. Three optimizations on the model side improve the inference performance of the predecoder. Each is walked through below.\n", + "\n", + "1. **ONNX Export**: convert the PyTorch pipeline to a portable, framework-independent graph\n", + "2. **TensorRT Engine (FP16)**: compile the ONNX graph into a GPU-optimized engine with FP16 precision\n", + "3. **FP8 Quantization + TensorRT**: further reduce precision to FP8 for maximum throughput on supported GPUs\n", + "\n", + "The figure below demonstrates why such optimizations are so important for realtime deployment. It shows the speedups achieved for model 1 at three different code distances, relative to an FP16 PyTorch run, after optimization with TensorRT (FP16) and then quantization with TensorRT (FP8) on an NVIDIA GB300.\n", + "\n", + "\n", + "\n", + "In this case, speedups of around 10x are possible for a batch size of 1. This is a representative example rather than a formal benchmark, because model size, code distance, GPU selection, and initial PyTorch optimizations all affect the expected speedups. The main point is that these optimizations are easy to perform and always worth doing before your final inference runs. 
In general, quantization is worth the minimal cost in LER, but you will need to run benchmarks to assess the proper tradeoffs for your experiment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 1: ONNX Export\n", + "\n", + "[ONNX](https://onnx.ai/) (Open Neural Network Exchange) is a standard format for representing ML models that allows a model trained with a specific framework (e.g. PyTorch) to be converted to a framework-agnostic graph which can be optimized to improve inference performance. Exporting to ONNX gives us:\n", + "\n", + "- **Portability**: the model can be loaded by any ONNX-compatible runtime (TensorRT, ONNX Runtime, etc.)\n", + "- **Graph optimization**: constant folding and operator fusion happen at export time\n", + "- **Full pipeline in one graph**: we export the entire `PreDecoderMemoryEvalModule`, not just the CNN. This means detector preprocessing, the model forward pass, and residual syndrome computation are all fused into a single ONNX graph.\n", + "\n", + "\n", + "The input and output signatures of the ONNX graph are:\n", + "- **Input**: `dets` — `(B, num_dets)` uint8\n", + "- **Output**: `L_and_residual_dets` — `(B, 1 + num_residual_dets)` uint8\n", + "\n", + "B is the batch size and is dynamic, meaning the same ONNX graph can be used to process any batch size without regeneration.\n", + "\n", + "The code below converts our Ising model into an ONNX file. \n", + "\n", + "**Note:** If you are running CUDA 13 you will see a warning from onnxruntime-gpu about missing CUDA 13 support. You can safely ignore this warning. ONNX Runtime is used here only as a quick sanity check that the exported graph runs correctly — it automatically falls back to CPU execution. 
The actual GPU-accelerated inference in this tutorial is handled by TensorRT, which has full CUDA 13 support.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging, warnings\n", + "for _name in [\"onnxscript\", \"onnx_ir\", \"torch.onnx\", \"torch\"]:\n", + " logging.getLogger(_name).setLevel(logging.ERROR)\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# Export the full pipeline to ONNX\n", + "onnx_path = \"predecoder_memory_d9_T9_X.onnx\"\n", + "\n", + "# half = number of X-type (or Z-type) detectors per round for a distance-d surface code\n", + "half = (distance**2 - 1) // 2\n", + "\n", + "# Create a small dummy detector batch — torch.onnx.export traces the model\n", + "# using this example input to record the full computation graph\n", + "\n", + "example_dets = torch.randint(0, 2, (2, 2 * n_rounds * half), dtype=torch.uint8, device=device)\n", + "print(f\"Input shape: {example_dets.shape}\")\n", + "\n", + "try:\n", + " import onnx\n", + "\n", + " # Trace the pipeline and serialize the computation graph to an ONNX file.\n", + " # dynamic_axes lets the batch dimension vary at runtime (single shot or batched inference).\n", + " # do_constant_folding=True pre-computes constant sub-expressions at export time.\n", + " \n", + " torch.onnx.export(\n", + " pipeline,\n", + " example_dets,\n", + " onnx_path,\n", + " opset_version=18,\n", + " input_names=[\"dets\"],\n", + " output_names=[\"L_and_residual_dets\"],\n", + " dynamic_axes={\n", + " \"dets\": {0: \"batch\"},\n", + " \"L_and_residual_dets\": {0: \"batch\"},\n", + " },\n", + " do_constant_folding=True,\n", + " )\n", + "\n", + " # Verify the exported graph is valid\n", + " model_onnx = onnx.load(onnx_path)\n", + " onnx.checker.check_model(model_onnx)\n", + " onnx_size_mb = os.path.getsize(onnx_path) / (1024 ** 2)\n", + " print(f\"Exported and verified: {onnx_path} ({onnx_size_mb:.1f} MB)\")\n", + "\n", + " # Quick 
inference check with ONNX Runtime\n", + " # Tries CUDA first, falls back to CPU — confirms the graph runs end-to-end\n", + " try:\n", + " import onnxruntime as ort\n", + " providers = [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n", + " sess = ort.InferenceSession(onnx_path, providers=providers)\n", + " test_input = example_dets.cpu().numpy()\n", + " ort_out = sess.run(None, {\"dets\": test_input})\n", + " print(f\"ONNX Runtime inference OK — output shape: {ort_out[0].shape}\")\n", + " active_provider = sess.get_providers()[0]\n", + " print(f\" Running on: {active_provider}\")\n", + " except ImportError:\n", + " print(\"onnxruntime not installed — skipping runtime check.\")\n", + " print(\"Install with: pip install onnxruntime\")\n", + "\n", + "except ImportError:\n", + " print(\"onnx not installed — skipping export.\")\n", + " print(\"Install with: pip install onnx\")\n", + " onnx_path = None\n", + "except Exception as e:\n", + " print(f\"ONNX export failed: {e}\")\n", + " onnx_path = None\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2: TensorRT Engine (FP16)\n", + "\n", + "[TensorRT](https://developer.nvidia.com/tensorrt) is NVIDIA's inference compiler which takes an ONNX graph and compiles it into a GPU-architecture-specific engine that accounts for your target data type and batch size range in order to get the most performance out of your GPU. \n", + "\n", + "Under the hood, TensorRT benchmarks dozens of candidate GPU kernel implementations for each operation in the graph and selects the fastest one for your exact hardware. It then fuses adjacent operations (convolution, normalization, activation) into single GPU kernels, reorganizes tensor memory layouts to maximize cache efficiency, and eliminates the Python interpreter overhead that PyTorch incurs on every call. 
The result is a binary engine tuned to extract maximum throughput from your specific GPU — the same model weights, the same outputs, but with latency approaching the hardware limit.\n", + "\n", + "The build process:\n", + "1. Parse the ONNX model into a TensorRT network\n", + "2. Set an **optimization profile** defining the min/optimal/max batch sizes\n", + "3. Enable the FP16 flag (tells TensorRT it can use half-precision where beneficial)\n", + "4. Build and serialize the engine to a `.engine` file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a TensorRT engine from the ONNX model (FP16)\n", + "engine_path = \"predecoder_memory_d9_T9_X_fp16.engine\"\n", + "batch_size = 1024\n", + "in_cols = 2 * n_rounds * half\n", + "\n", + "if onnx_path is None or not os.path.exists(onnx_path):\n", + " print(\"ONNX model not available — run the ONNX export cell first.\")\n", + "else:\n", + " try:\n", + " import tensorrt as trt\n", + " print(f\"TensorRT version: {trt.__version__}\")\n", + "\n", + " # Create TensorRT logger and builder — entry points for all TRT operations.\n", + " # Store as trt_logger so the FP8 cell can reuse it (avoids a TRT duplicate-logger warning).\n", + " trt_logger = trt.Logger(trt.Logger.WARNING)\n", + " builder = trt.Builder(trt_logger)\n", + "\n", + " # TRT 10+ removed the EXPLICIT_BATCH flag (always on); TRT 8 needs it.\n", + " trt_major = int(trt.__version__.split(\".\")[0])\n", + " if trt_major >= 10:\n", + " network = builder.create_network()\n", + " else:\n", + " network = builder.create_network(\n", + " 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n", + " )\n", + "\n", + " # Parse the ONNX graph into the TensorRT network representation\n", + " parser = trt.OnnxParser(network, trt_logger)\n", + " with open(onnx_path, \"rb\") as f:\n", + " if not parser.parse(f.read()):\n", + " for err_i in range(parser.num_errors):\n", + " print(f\" Parse error: 
{parser.get_error(err_i)}\")\n", + " raise RuntimeError(\"ONNX parse failed\")\n", + "\n", + " # Enable FP16 mode — TensorRT will use half-precision kernels wherever beneficial\n", + " config = builder.create_builder_config()\n", + " config.set_flag(trt.BuilderFlag.FP16)\n", + "\n", + " # Define the supported batch size range: (min, optimal, max)\n", + " # TensorRT selects the best kernels for the optimal size during compilation\n", + " profile = builder.create_optimization_profile()\n", + " profile.set_shape(\"dets\", (1, in_cols), (batch_size, in_cols), (batch_size, in_cols))\n", + " config.add_optimization_profile(profile)\n", + "\n", + " # Compile the network into a hardware-specific binary engine and save it\n", + " print(f\"Building TensorRT FP16 engine (batch_size={batch_size}) — this may take a minute...\")\n", + " serialized = builder.build_serialized_network(network, config)\n", + " with open(engine_path, \"wb\") as f:\n", + " f.write(serialized)\n", + "\n", + " engine_size_mb = os.path.getsize(engine_path) / (1024 ** 2)\n", + " print(f\"Built: {engine_path} ({engine_size_mb:.1f} MB)\")\n", + "\n", + " except ImportError:\n", + " print(\"TensorRT not installed — skipping engine build.\")\n", + " print(\"Install with: pip install tensorrt\")\n", + " except Exception as e:\n", + " print(f\"Engine build failed: {e}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: FP8 Quantization + TensorRT\n", + "\n", + "AI models can be quantized such that their weights are converted to a lower precision. Often, this provides additional speedup with minimal loss in accuracy. The Ising models are FP16 by default, but TensorRT can quantize to FP8 and achieve further speedups which can run on any GPU with FP8 tensor cores. \n", + "\n", + "The quantization process is:\n", + "1. **Collect calibration data**: representative detector patterns from the test distribution\n", + "2. 
**Quantize the ONNX model**: `nvidia-modelopt` analyzes the calibration data to determine optimal quantization ranges, then inserts QuantizeLinear/DequantizeLinear (QDQ) nodes into the ONNX graph\n", + "3. **Build a TensorRT engine**: TensorRT recognizes the QDQ nodes and fuses them into native FP8 operations\n", + "\n", + "Only the convolution layers are quantized to FP8; the remaining operations (preprocessing, residual computation) stay in FP16 for accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import contextlib, io\n", + "\n", + "# FP8 quantization of the ONNX model + TensorRT engine build\n", + "fp8_onnx_path = \"predecoder_memory_d9_T9_X_fp8.onnx\"\n", + "fp8_engine_path = \"predecoder_memory_d9_T9_X_fp8.engine\"\n", + "\n", + "if onnx_path is None or not os.path.exists(onnx_path):\n", + " print(\"ONNX model not available — run the ONNX export cell first.\")\n", + "else:\n", + " calib_dets = detectors[:256]\n", + " print(f\"Calibration data: {calib_dets.shape} ({calib_dets.shape[0]} samples, dtype: {calib_dets.dtype})\")\n", + "\n", + " try:\n", + " import modelopt.onnx.quantization as mq\n", + " import tensorrt as trt\n", + " print(f\"TensorRT version: {trt.__version__}\")\n", + "\n", + " # Quantize convolution layers to FP8 using calibration data to determine\n", + " # optimal scale factors; non-conv ops stay at FP16 for numerical accuracy\n", + " print(f\"Quantizing {onnx_path} to FP8 (this may take a moment)...\")\n", + " _prev_level = logging.root.manager.disable\n", + " logging.disable(logging.CRITICAL)\n", + " _devnull_fd = os.open(os.devnull, os.O_WRONLY)\n", + " _old_stdout_fd = os.dup(1)\n", + " _old_stderr_fd = os.dup(2)\n", + " _devnull_py = open(os.devnull, \"w\")\n", + " _old_sys_stdout = sys.stdout\n", + " _old_sys_stderr = sys.stderr\n", + " os.dup2(_devnull_fd, 1)\n", + " os.dup2(_devnull_fd, 2)\n", + " sys.stdout = _devnull_py\n", + " sys.stderr = _devnull_py\n", 
+ " try:\n", + " mq.quantize(\n", + " onnx_path=onnx_path,\n", + " quantize_mode=\"fp8\",\n", + " calibration_data={\"dets\": calib_dets},\n", + " output_path=fp8_onnx_path,\n", + " op_types_to_quantize=[\"Conv\"],\n", + " high_precision_dtype=\"fp16\",\n", + " )\n", + " finally:\n", + " sys.stdout = _old_sys_stdout\n", + " sys.stderr = _old_sys_stderr\n", + " os.dup2(_old_stdout_fd, 1)\n", + " os.dup2(_old_stderr_fd, 2)\n", + " os.close(_devnull_fd)\n", + " os.close(_old_stdout_fd)\n", + " os.close(_old_stderr_fd)\n", + " _devnull_py.close()\n", + " logging.disable(_prev_level)\n", + " fp8_size_mb = os.path.getsize(fp8_onnx_path) / (1024 ** 2)\n", + " print(f\" Quantized: {fp8_onnx_path} ({fp8_size_mb:.1f} MB)\")\n", + "\n", + " if not gpu_supports_fp8:\n", + " print(f\"\\n Skipping FP8 engine build: {gpu_name} (compute capability {gpu_cc}) \"\n", + " f\"does not have FP8 tensor cores.\")\n", + " print(f\" FP8 requires Ada Lovelace (RTX 4000+, cc 8.9) or Hopper (H100+, cc 9.0).\")\n", + " print(f\" The quantized ONNX model ({fp8_onnx_path}) can be built on supported hardware.\")\n", + " else:\n", + " _trt_logger = trt_logger if 'trt_logger' in dir() else trt.Logger(trt.Logger.WARNING)\n", + " builder = trt.Builder(_trt_logger)\n", + "\n", + " trt_major = int(trt.__version__.split(\".\")[0])\n", + " if trt_major >= 10:\n", + " network = builder.create_network(\n", + " 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)\n", + " )\n", + " else:\n", + " network = builder.create_network(\n", + " 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n", + " | 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)\n", + " )\n", + "\n", + " parser = trt.OnnxParser(network, _trt_logger)\n", + " with open(fp8_onnx_path, \"rb\") as f:\n", + " if not parser.parse(f.read()):\n", + " raise RuntimeError(\"ONNX parse failed for FP8 model\")\n", + "\n", + " config = builder.create_builder_config()\n", + " profile = builder.create_optimization_profile()\n", + " 
profile.set_shape(\"dets\", (1, in_cols), (batch_size, in_cols), (batch_size, in_cols))\n", + " config.add_optimization_profile(profile)\n", + "\n", + " print(f\"Building TensorRT FP8 engine (batch_size={batch_size}) — this may take a minute...\")\n", + " serialized = builder.build_serialized_network(network, config)\n", + " with open(fp8_engine_path, \"wb\") as f:\n", + " f.write(serialized)\n", + "\n", + " engine_size_mb = os.path.getsize(fp8_engine_path) / (1024 ** 2)\n", + " print(f\"Built: {fp8_engine_path} ({engine_size_mb:.1f} MB)\")\n", + "\n", + " except ImportError as e:\n", + " print(f\"Required package not available: {e}\")\n", + " print(\"Install with: pip install nvidia-modelopt[onnx] tensorrt\")\n", + " except Exception as e:\n", + " print(f\"FP8 pipeline failed: {e}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Running Optimized Inference from CLI\n", + "\n", + "The full optimization pipeline can also be driven via environment variables from the command line (run from the repository root):\n", + "\n", + "```bash\n", + "# PyTorch only (default):\n", + "WORKFLOW=inference bash code/scripts/local_run.sh\n", + "\n", + "# ONNX export + TensorRT engine build (FP16):\n", + "ONNX_WORKFLOW=2 WORKFLOW=inference bash code/scripts/local_run.sh\n", + "\n", + "# ONNX + FP8 quantization + TensorRT:\n", + "ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh\n", + "\n", + "# Load a pre-built engine file:\n", + "ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh\n", + "```\n", + "\n", + "### Generating data for CUDA-Q QEC realtime predecoder test application\n", + "\n", + "When evaluating the neural pre-decoder in an end-to-end downstream system like\n", + "CUDA-Q Realtime, you will need a test harness with valid inputs—both the\n", + "exported neural network model and the corresponding syndrome data.\n", + "\n", + "The utility script `code/export/generate_test_data.py` is provided 
to generate\n", + "this exact data (both an `.onnx` file and several `.bin` files) so you can\n", + "easily consume it in the CUDA-Q QEC realtime AI decoder.\n", + "\n", + "> **Important:** The `--distance` and `--n-rounds` arguments provided to this\n", + "script **must match** the values used in the preceding section when running the\n", + "ONNX export (e.g. `ONNX_WORKFLOW=2`).\n", + "\n", + "For a detailed walkthrough on how to ingest these files into the CUDA-Q Realtime\n", + "C++ pipeline, see the downstream documentation here: [Realtime AI Predecoder\n", + "Pipeline](https://nvidia.github.io/cudaqx/examples_rst/qec/realtime_predecoder_pymatching.html).\n", + "\n", + "```text\n", + "python3 code/export/generate_test_data.py --distance 13 --n-rounds 104 --num-samples 10000 --basis X --p-error=0.003 --simple-noise\n", + "```\n", + "\n", + "**Example output:**\n", + "\n", + "```text\n", + "Building circuit: D=13, T=104, basis=X, rotation=XV, p=0.003\n", + " Circuit built in 0.007s\n", + "Building detector error model and PyMatching matcher...\n", + " DEM + matcher built in 0.083s\n", + " Detectors: 17472, Observables: 1\n", + "Extracting check matrices (beliefmatching)...\n", + " H shape: (17472, 93864), O shape: (1, 93864), priors shape: (93864,)\n", + "Sampling 10000 shots...\n", + " Sampled in 1.006s\n", + "Decoding with PyMatching (baseline)...\n", + " Errors: 30/10000, LER: 0.0030\n", + " Decode time: 5.439s (543.9 µs/shot)\n", + "Writing outputs to test_data/d13_T104_X/\n", + "Done.\n", + " H_csr.bin 808,944 bytes\n", + " O_csr.bin 2,932 bytes\n", + " detectors.bin 698,880,008 bytes\n", + " metadata.txt 162 bytes\n", + " observables.bin 40,008 bytes\n", + " priors.bin 750,916 bytes\n", + " pymatching_predictions.bin 40,008 bytes\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Offline Decoding From Stim Files\n", + "\n", + "The notebook so far has driven everything from in-memory Stim sampling, which is the right 
path when you control the experiment. There is a second, file-based path for **detector samples produced elsewhere** — typically a QPU dump or a third-party simulator — that you want to push through the same pre-decoder + PyMatching pipeline.\n", + "\n", + "The contract is two files per basis:\n", + "\n", + "```text\n", + "/\n", + " samples_X.dets # Stim sparse detector-sample format\n", + " metadata_X.json # circuit + noise fingerprint\n", + " samples_Z.dets\n", + " metadata_Z.json\n", + "```\n", + "\n", + "`samples_*.dets` is Stim's sparse format with logical observables appended (`shot D3 D8 L0` ⇒ detectors 3 and 8 fired, observable 0 flipped). The metadata JSON is what `qec.surface_code.stim_sample_io.build_stim_sample_metadata` writes and contains the circuit parameters (`distance`, `n_rounds`, `basis`, `code_rotation`, detector and observable counts) plus an optional noise fingerprint (`p_error`, `noise_model_sha256`). The decoder validates every recorded field against the rebuilt circuit before any decoding happens — distance, rounds, basis, orientation, detector count, and observable presence are always strict; the noise fingerprint is strict by default and can be downgraded to a warning with `PREDECODER_STIM_STRICT_NOISE=0`.\n", + "\n", + "The cell below is a self-contained, in-process demo of that round-trip: write a small artifact, read it back, push it through the canonical `dets_to_predecoder_inputs` helper, and confirm the result equals the file datapipe's output. No shell, no GPU.\n", + "\n", + "For the **full** CLI workflow (generate large reference files, decode with PyMatching only, decode with the Ising pre-decoder, print latency/LER) see the *Offline decoding from Stim detector samples* section of the repository [README](../README.md#offline-decoding-from-stim-detector-samples). The README is the canonical reference; the launcher commands, output schema, and smoke script live there." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from data.datapipe_stim import QCDataPipePreDecoder_Memory_from_stim_file\n", + "from data.predecoder_transform import dets_to_predecoder_inputs\n", + "from qec.noise_model import NoiseModel\n", + "from qec.surface_code.memory_circuit import MemoryCircuit\n", + "from qec.surface_code.stim_sample_io import (\n", + " build_stim_sample_metadata,\n", + " read_stim_detector_samples,\n", + " write_metadata_json,\n", + " write_stim_detector_samples,\n", + ")\n", + "\n", + "# 1) Build a tiny memory circuit (D=3, T=3, X basis) so the demo runs in <1s.\n", + "DISTANCE, N_ROUNDS, SHOTS, BASIS, ROTATION = 3, 3, 32, \"X\", \"XV\"\n", + "P_ERROR = 0.003\n", + "noise = NoiseModel.from_single_p(P_ERROR)\n", + "circ = MemoryCircuit(\n", + " distance=DISTANCE,\n", + " idle_error=float(noise.get_max_probability()),\n", + " sqgate_error=float(noise.get_max_probability()),\n", + " tqgate_error=float(noise.get_max_probability()),\n", + " spam_error=(2.0 / 3.0) * float(noise.get_max_probability()),\n", + " n_rounds=N_ROUNDS,\n", + " basis=BASIS,\n", + " code_rotation=ROTATION,\n", + " noise_model=noise,\n", + " add_boundary_detectors=True,\n", + ")\n", + "circ.set_error_rates()\n", + "\n", + "# 2) Sample detector events + appended observables from Stim.\n", + "dets_and_obs = np.asarray(\n", + " circ.stim_circuit.compile_detector_sampler(seed=42).sample(\n", + " shots=SHOTS, append_observables=True\n", + " ),\n", + " dtype=np.uint8,\n", + ")\n", + "\n", + "with tempfile.TemporaryDirectory() as tmp:\n", + " root = Path(tmp)\n", + "\n", + " # 3) Write the two files that make up the offline contract.\n", + " write_stim_detector_samples(\n", + " path=root / f\"samples_{BASIS}.dets\",\n", + " dets_and_obs=dets_and_obs,\n", + " 
num_detectors=circ.stim_circuit.num_detectors,\n", + " num_observables=circ.stim_circuit.num_observables,\n", + " )\n", + " write_metadata_json(\n", + " root / f\"metadata_{BASIS}.json\",\n", + " build_stim_sample_metadata(\n", + " distance=DISTANCE,\n", + " n_rounds=N_ROUNDS,\n", + " basis=BASIS,\n", + " code_rotation=ROTATION,\n", + " num_detectors=circ.stim_circuit.num_detectors,\n", + " num_observables=circ.stim_circuit.num_observables,\n", + " num_shots=SHOTS,\n", + " p_error=P_ERROR,\n", + " noise_model_label=\"25-param\",\n", + " noise_model_params=noise.canonical_parameters(),\n", + " noise_model_sha256=noise.sha256(),\n", + " ),\n", + " )\n", + "\n", + " # 4) Read the files back through the contract-enforcing reader.\n", + " read_dets_and_obs, read_meta = read_stim_detector_samples(\n", + " samples_path=root / f\"samples_{BASIS}.dets\",\n", + " metadata_path=root / f\"metadata_{BASIS}.json\",\n", + " distance=DISTANCE,\n", + " n_rounds=N_ROUNDS,\n", + " basis=BASIS,\n", + " code_rotation=ROTATION,\n", + " num_detectors=circ.stim_circuit.num_detectors,\n", + " num_observables=circ.stim_circuit.num_observables,\n", + " p_error=P_ERROR,\n", + " noise_model_sha256=noise.sha256(),\n", + " noise_model_label=\"25-param\",\n", + " )\n", + " assert np.array_equal(read_dets_and_obs, dets_and_obs), \"round-trip mismatch\"\n", + "\n", + " # 5) Push detector bits through the canonical pre-decoder transform.\n", + " num_obs = circ.stim_circuit.num_observables\n", + " dets_tensor = torch.from_numpy(read_dets_and_obs[:, :-num_obs]).to(torch.uint8)\n", + " train_x_helper, x_syn_helper, z_syn_helper = dets_to_predecoder_inputs(\n", + " dets_tensor,\n", + " distance=DISTANCE,\n", + " n_rounds=N_ROUNDS,\n", + " basis=BASIS,\n", + " code_rotation=ROTATION,\n", + " )\n", + "\n", + " # 6) The file datapipe wraps the same helper; its tensors must match\n", + " # exactly. 
This is the same check the offline tests run in CI.\n", + " pipe = QCDataPipePreDecoder_Memory_from_stim_file(\n", + " distance=DISTANCE,\n", + " n_rounds=N_ROUNDS,\n", + " num_samples=SHOTS,\n", + " error_mode=\"circuit_level_surface_custom\",\n", + " stim_samples_dir=root,\n", + " p_error=P_ERROR,\n", + " measure_basis=BASIS,\n", + " code_rotation=ROTATION,\n", + " noise_model=noise,\n", + " )\n", + " assert torch.equal(pipe.trainX_all, train_x_helper)\n", + " assert torch.equal(pipe.x_syn_diff_all, x_syn_helper)\n", + " assert torch.equal(pipe.z_syn_diff_all, z_syn_helper)\n", + "\n", + " print(f\"Wrote and read {SHOTS} shots from {root}\")\n", + " print(f\" dets_and_obs shape : {read_dets_and_obs.shape}\")\n", + " print(f\" trainX shape : {tuple(train_x_helper.shape)}\")\n", + " print(f\" metadata schema_ver: {read_meta['schema_version']}\")\n", + " print(f\" metadata p_error : {read_meta['p_error']}\")\n", + " print(f\" metadata sha256[:8]: {read_meta['noise_model_sha256'][:8]}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Learn More\n", + "\n", + "After completing this tutorial, you should have a foundational understanding of how the NVIDIA Ising pre-decoder works and how to get started training your own models. For more information on the code, explore the [GitHub repo](https://github.com/NVIDIA/Ising-Decoding). For more details on the model itself and its performance, read the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes).\n", + "\n", + "Visit the NVIDIA Ising webpage to learn more about other models in the [NVIDIA Ising family](https://developer.nvidia.com/ising) of open quantum models." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ising2", + "language": "python", + "name": "ising2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup\n", - "\n", - "This tutorial can be found in the `cookbook/` directory of the `Ising-Decoding` repository.\n", - "\n", - "**Prerequisites:**\n", - "- **NVIDIA GPU** with CUDA drivers installed (`nvidia-smi` must be on your PATH)\n", - "- **Python 3.11, 3.12, or 3.13**\n", - "\n", - "The cell below will:\n", - "1. Locate the repository root and add the predecoder source code to the Python path\n", - "2. Detect the CUDA version from your GPU driver and install the matching PyTorch build\n", - "3. Install all training dependencies (which are a superset of inference dependencies), including cuQuantum for GPU-accelerated data generation, ONNX tooling for the inference optimization sections, and all other required packages\n", - "4. Authenticate with Hugging Face to download pre-trained model weights (the model is access-controlled). **Paste your token into the `HF_TOKEN` variable** in the cell below (get one at https://huggingface.co/settings/tokens)\n", - "\n", - "**TensorRT** (used in the FP16 engine and FP8 quantization sections) is installed automatically with the correct CUDA version.\n", - "\n", - "**Note:** If you have already run the setup cell once and all packages are installed, you can skip it on subsequent runs \u2014 only the imports cell that follows is needed." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import subprocess, sys, os, re, shutil\n", - "\n", - "NOTEBOOK_DIR = os.path.abspath('')\n", - "PREDECODER_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR,'..'))\n", - "sys.path.insert(0, os.path.join(PREDECODER_ROOT, 'code'))\n", - "\n", - "print(f'PREDECODER_ROOT: {PREDECODER_ROOT}')\n", - "assert os.path.isdir(os.path.join(PREDECODER_ROOT, 'code')), (\n", - " f\"Cannot find predecoder source code at {PREDECODER_ROOT}/code. \"\n", - " f\"This notebook must live at /tutorials/predecoder_tutorial/.\"\n", - ")\n", - "\n", - "assert shutil.which('nvidia-smi'), 'nvidia-smi not found \u2014 this tutorial requires an NVIDIA GPU.'\n", - "\n", - "nvsmi_output = subprocess.check_output(['nvidia-smi'], text=True)\n", - "cuda_match = re.search(r'CUDA Version:\\s+([\\d.]+)', nvsmi_output)\n", - "assert cuda_match, 'Could not detect CUDA version from nvidia-smi output.'\n", - "cuda_ver = cuda_match.group(1)\n", - "cuda_major = cuda_ver.split('.')[0]\n", - "print(f'CUDA {cuda_ver} detected (major: {cuda_major})')\n", - "\n", - "def _pip(*args):\n", - " subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *args])\n", - "\n", - "gpu_info = subprocess.check_output(\n", - " ['nvidia-smi', '--query-gpu=name,compute_cap', '--format=csv,noheader'], text=True\n", - ").strip().split('\\n')[0]\n", - "gpu_name, gpu_cc = [s.strip() for s in gpu_info.split(',')]\n", - "gpu_cc_major = int(gpu_cc.split('.')[0])\n", - "gpu_cc_minor = int(gpu_cc.split('.')[1])\n", - "gpu_supports_fp8 = (gpu_cc_major > 8) or (gpu_cc_major == 8 and gpu_cc_minor >= 9)\n", - "print(f'GPU: {gpu_name} (compute capability {gpu_cc}, FP8: {\"yes\" if gpu_supports_fp8 else \"no\"})')\n", - "\n", - "print('Environment OK.')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "_pip('--upgrade', 'pip', 'setuptools<82', 
'wheel')\n", - "\n", - "torch_cuda_tag = {'12': 'cu128', '13': 'cu130'}[cuda_major]\n", - "print(f'Installing PyTorch (wheel index: {torch_cuda_tag})...')\n", - "_pip('torch', '--index-url', f'https://download.pytorch.org/whl/{torch_cuda_tag}',\n", - " '--extra-index-url', 'https://pypi.org/simple')\n", - "\n", - "import torch\n", - "assert torch.cuda.is_available(), 'PyTorch installed but CUDA is not available.'\n", - "print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}, '\n", - " f'GPU: {torch.cuda.get_device_name(0)}')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_req = os.path.join(PREDECODER_ROOT, 'code', f'requirements_public_train-cu{cuda_major}.txt')\n", - "assert os.path.exists(train_req), (\n", - " f\"No training requirements for CUDA {cuda_major}: {train_req}\\n\"\n", - " f\"Available: requirements_public_train-cu12.txt, requirements_public_train-cu13.txt\"\n", - ")\n", - "print(f'Installing predecoder dependencies from: {os.path.basename(train_req)}')\n", - "print(f' (this includes: stim, pymatching, cuquantum, onnx tooling, and more)')\n", - "_pip('-r', train_req, 'huggingface_hub')\n", - "\n", - "trt_pkg = f'tensorrt-cu{cuda_major}'\n", - "print(f'Installing TensorRT ({trt_pkg}) for ONNX optimization sections...')\n", - "_pip(trt_pkg)\n", - "\n", - "print('All dependencies installed.')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Paste your Hugging Face token below (get one at https://huggingface.co/settings/tokens).\n", - "HF_TOKEN = \"\" # <-- paste your token here, e.g. 
\"hf_abc123...\"\n", - "\n", - "import huggingface_hub\n", - "if HF_TOKEN:\n", - " huggingface_hub.login(token=HF_TOKEN)\n", - " print('Logged in to Hugging Face.')\n", - "elif huggingface_hub.get_token():\n", - " print('Using previously cached Hugging Face token.')\n", - "else:\n", - " raise RuntimeError(\n", - " 'No Hugging Face token found. Paste your token into the HF_TOKEN variable above.\\n'\n", - " 'Get one at: https://huggingface.co/settings/tokens'\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "import pymatching\n", - "import time\n", - "from types import SimpleNamespace\n", - "\n", - "from qec.noise_model import NoiseModel\n", - "from qec.surface_code.memory_circuit import MemoryCircuit\n", - "from model.registry import get_model_spec\n", - "from model.factory import ModelFactory\n", - "from evaluation.logical_error_rate import PreDecoderMemoryEvalModule, _build_stab_maps" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "## Quick Start\n", - "\n", - "The fastest path to running the pre-decoder is shown below. This code cell generates $d$=9 surface code syndrome data directly using the repo's `MemoryCircuit` and STIM, loads the pre-trained model from [Hugging Face](https://huggingface.co/nvidia/ising_decoder_surface_code_1_fast), and runs the full decode pipeline. Note, we are using the model optimized for speed, but a second pretrained model targeting accuracy is available [here](https://huggingface.co/nvidia/ising_decoder_surface_code_1_accurate).\n", - "\n", - "The workflow has four primary steps:\n", - "\n", - "- **Syndrome data generation** \u2014 A noisy surface code circuit is simulated using STIM. Each shot produces a set of detector bits (syndrome differences across rounds) and an observable bit (whether a logical error occurred). 
This is what a real QPU would send to the classical decoder.\n", - "- **Pre-trained model** \u2014 A 3D convolutional neural network trained to recognize local, easily identifiable error patterns directly from syndrome data. It processes thousands of shots in parallel on GPU and returns a set of corrections along with a reduced residual syndrome.\n", - "- **Pre-decoder inference pipeline** \u2014 The model is wrapped with preprocessing (mapping flat syndrome bits onto a spatial grid) and postprocessing (applying predicted corrections to reduce the syndrome). The output is a partial logical correction and a much sparser syndrome for PyMatching to handle.\n", - "- **PyMatching on residuals** \u2014 The industry-standard MWPM decoder finishes the job on the reduced syndrome. Its result is XOR'd with the pre-decoder's partial correction to produce the final logical error prediction.\n", - "\n", - "**Data format:** The pre-decoder takes a flat array of detector bits, obtained from a QPU or via simulation in the case of this example. Internally, these bits are reshaped into a 4-channel 3D grid of shape `(batch, 4, rounds, distance, distance)` before being passed to the CNN. The four channels encode X and Z syndrome values and the surface code boundary geometry. For full details on the data format and model architecture, see [the paper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes).\n", - "\n", - "\n", - "Run the code below to see how the performance of PyMatching, the Ising decoder, and the combination of both compare.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# \u2500\u2500 1. 
Generate syndrome data \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", - "# These are the key QEC parameters for a surface code memory experiment.\n", - "# distance=9 means a 9x9 grid of data qubits (81 qubits + 80 ancilla).\n", - "# n_rounds=9 is the number of syndrome measurement cycles (standard: match distance).\n", - "# basis=\"X\" means we protect the X-type logical qubit (alternatively \"Z\").\n", - "# code_rotation=\"XV\" is the internal label for surface code orientation O1.\n", - "distance, n_rounds, basis, code_rotation, num_samples = 9, 9, \"X\", \"XV\", 5000\n", - "\n", - "# Build a 25-parameter depolarizing noise model from a single error rate p=0.005.\n", - "# from_single_p() distributes p across all error channels using standard ratios:\n", - "# prep/meas errors: 2p/3 (~0.0033)\n", - "# idle during CNOT: p/3 (~0.0017)\n", - "# idle during SPAM: 2p/3 (~0.0033)\n", - "# two-qubit CNOT: p/15 (~0.00033) per Pauli pair\n", - "\n", - "noise = NoiseModel.from_single_p(0.005)\n", - "p_max = noise.get_max_probability() # largest value across all 25 params (~2p/3)\n", - "\n", - "# Build the full noisy surface code circuit using Stim.\n", - "# The scalar error args (idle_error, sqgate_error, etc.) 
are placeholder values\n", - "# required by the constructor \u2014 the actual error channels are driven by noise_model.\n", - "# add_boundary_detectors=True adds virtual detectors at the start/end of the circuit,\n", - "# which improves decoding accuracy at the temporal boundaries of the experiment.\n", - "\n", - "circ = MemoryCircuit(\n", - " distance=distance, n_rounds=n_rounds, basis=basis, code_rotation=code_rotation,\n", - " idle_error=p_max, sqgate_error=p_max, tqgate_error=p_max,\n", - " spam_error=(2/3)*p_max, noise_model=noise, add_boundary_detectors=True,\n", - ")\n", - "circ.set_error_rates()\n", - "\n", - "# Sample the circuit: each row of `shots` is one full experiment run (all measurements).\n", - "shots = circ.stim_circuit.compile_sampler().sample(num_samples)\n", - "\n", - "# Convert raw measurement bits into detector outcomes and logical observable flips.\n", - "# Detectors are syndrome parities across two consecutive rounds \u2014 they fire (=1)\n", - "# when consecutive stabilizer measurements disagree, signaling an error.\n", - "# Observables are the true logical error bits \u2014 what we're trying to predict.\n", - "\n", - "conv = circ.stim_circuit.compile_m2d_converter()\n", - "dets_obs = conv.convert(measurements=shots, append_observables=True)\n", - "num_obs = circ.stim_circuit.num_observables\n", - "detectors = dets_obs[:, :-num_obs].astype(np.uint8) # shape: (shots, n_detectors)\n", - "observables = dets_obs[:, -num_obs:].astype(np.uint8) # shape: (shots, n_observables)\n", - "\n", - "# Build the Detector Error Model (DEM): a graph where nodes are detectors and\n", - "# edges are the most likely single-error mechanisms connecting them.\n", - "# PyMatching uses this graph to find minimum-weight correction paths.\n", - "\n", - "dem = circ.stim_circuit.detector_error_model(\n", - " decompose_errors=True, # split multi-detector errors into pairs for graph compatibility\n", - " approximate_disjoint_errors=True, # treat disjoint error cases (e.g. Pauli channels) as independent error 
mechanisms\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With the syndrome data generated, we now have `detectors` (the parity-check outcomes the decoder will work with) and `observables` (the ground-truth logical errors we want to predict). We also built a Detector Error Model (`dem`) that PyMatching will use later.\n", - "\n", - "Next, we load the pre-trained Ising pre-decoder model from Hugging Face. Model 1 has a receptive field of R=9, which matches our $d$=9 experiment exactly." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# \u2500\u2500 2. Load pre-trained model \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", - "# Model 1 has receptive field R=9, matching our d=9 experiment.\n", - "# The receptive field sets the maximum space-time window the model can see.\n", - "\n", - "model_id = 1\n", - "spec = get_model_spec(model_id)\n", - "\n", - "# Build the model architecture from the registry spec.\n", - "# num_filters/kernel_size define the 3D convolutional layers that process the\n", - "# syndrome tensor (space x space x time).\n", - "\n", - "model_cfg = SimpleNamespace(\n", - " code=\"surface\", distance=distance, n_rounds=n_rounds,\n", - " model=SimpleNamespace(\n", - " version=\"predecoder_memory_v1\",\n", - " num_filters=list(spec.num_filters), kernel_size=list(spec.kernel_size),\n", - " dropout_p=0.0, activation=\"gelu\", input_channels=4, out_channels=4,\n", - " ),\n", - ")\n", - "model = ModelFactory.create_model(model_cfg)\n", - "\n", - "# Download weights from HuggingFace (cached after first run).\n", - "import logging\n", - "logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n", - "from 
huggingface_hub import hf_hub_download\n", - "from safetensors.torch import load_file\n", - "\n", - "model_path = hf_hub_download(\n", - " repo_id=\"nvidia/ising_decoder_surface_code_1_fast\",\n", - " filename=\"ising_decoder_surface_code_1_fast_r9_v1.0.77_fp16.safetensors\",\n", - " token=HF_TOKEN or True,\n", - ")\n", - "sd = load_file(model_path, device=\"cpu\")\n", - "sd = {(k[7:] if k.startswith(\"module.\") else k): v.float() for k, v in sd.items()}\n", - "model.load_state_dict(sd)\n", - "model.eval()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The model printed above is just the neural network \u2014 it knows how to process data, but it expects a very specific input format (a 4-channel 3D tensor shaped to the surface code lattice), not the flat array of detector bits we generated earlier. The next cell wraps the model inside `PreDecoderMemoryEvalModule`, which handles the format conversion automatically: flat detector bits go in, and a partial logical correction plus a simplified syndrome for PyMatching come out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# \u2500\u2500 3. 
Build inference pipeline \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", - "# Move model to GPU if available; CPU fallback works but will be slow.\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model = model.to(device)\n", - "\n", - "# Build stabilizer maps: spatial lookup tables that tell the model which detectors\n", - "# correspond to which physical qubit locations on the surface code lattice.\n", - "maps = _build_stab_maps(distance, code_rotation)\n", - "\n", - "# PreDecoderMemoryEvalModule wraps the model with the pre/post-processing needed\n", - "# for inference: it accepts raw detector arrays and returns:\n", - "# out[:, 0] \u2014 predicted logical correction (0 or 1 per shot)\n", - "# out[:, 1:] \u2014 residual detectors after pre-decoding (passed to PyMatching)\n", - "eval_cfg = SimpleNamespace(\n", - " distance=distance,\n", - " test=SimpleNamespace(meas_basis_test=basis, th_data=0.0, th_syn=0.0,\n", - " sampling_mode=\"threshold\", temperature=1.0,\n", - " temperature_data=None, temperature_syn=None, n_rounds=n_rounds),\n", - " data=SimpleNamespace(code_rotation=code_rotation),\n", - " enable_fp16=False,\n", - ")\n", - "pipeline = PreDecoderMemoryEvalModule(model, eval_cfg, maps, device).to(device)\n", - "pipeline.eval();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `pipeline` object now encapsulates the entire pre-decoder workflow in a single callable: give it a batch of raw detector bits and it returns a two-part result \u2014 a predicted logical correction (column 0) and a set of residual syndromes (remaining columns) that have been simplified by the neural network. 
Those residual syndromes are what we will pass to PyMatching as a second decoding step." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Everything is ready. We now compare four decoding strategies side by side:\n", - "\n", - "1. **Undecoded** \u2014 no correction at all (the raw logical error rate)\n", - "2. **PyMatching only** \u2014 the standard MWPM decoder applied to the full syndrome\n", - "3. **Pre-decoder alone** \u2014 just the neural network's partial correction, no follow-up\n", - "4. **Pre-decoder + PyMatching** \u2014 the full pipeline: the model reduces the syndrome, then PyMatching finishes the job on the sparser residual. The final correction is the XOR of both predictions.\n", - "\n", - "A single warmup pass runs first so that `torch.compile` JIT-compiles the model for this batch shape before we start timing." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# \u2500\u2500 4. Run all four decoding strategies \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", - "# Build the PyMatching decoder from the DEM graph \u2014 this is the classical baseline.\n", - "\n", - "matcher = pymatching.Matching.from_detector_error_model(dem)\n", - "obs_flat = observables.flatten()\n", - "\n", - "# Warmup pass: torch.compile JIT-compiles the model on first call for this batch\n", - "# shape/dtype. 
We discard these timings \u2014 subsequent calls reflect true latency.\n", - "\n", - "with torch.no_grad():\n", - " _ = pipeline(torch.from_numpy(detectors).to(torch.uint8).to(device))\n", - "_ = matcher.decode_batch(detectors)\n", - "\n", - "\n", - "\n", - "# Strategy 1 \u2014 Undecoded (baseline: what LER looks like with no correction at all)\n", - "\n", - "undecoded_ler = obs_flat.mean()\n", - "\n", - "\n", - "\n", - "# Strategy 2 \u2014 PyMatching only (standard MWPM decoder, the classical state of the art)\n", - "\n", - "t0 = time.perf_counter()\n", - "pm_pred = np.asarray(matcher.decode_batch(detectors), dtype=np.uint8).reshape(-1, num_obs).flatten()\n", - "t_pm = time.perf_counter() - t0\n", - "pm_ler = (pm_pred != obs_flat).mean()\n", - "\n", - "\n", - "\n", - "# Strategy 3 \u2014 Pre-decoder alone (neural network output, no follow-up matching)\n", - "# This shows what the model achieves without a global decoder \u2014 useful for\n", - "# understanding the pre-decoder's contribution in isolation.\n", - "\n", - "dets_t = torch.from_numpy(detectors).to(torch.uint8).to(device)\n", - "t0 = time.perf_counter()\n", - "with torch.no_grad():\n", - " out = pipeline(dets_t)\n", - "t_pd = time.perf_counter() - t0\n", - "pre_L = out[:, 0].cpu().numpy() # model's logical correction prediction\n", - "res_dets = out[:, 1:].cpu().numpy().astype(np.uint8) # residual syndromes for PyMatching\n", - "pd_ler = (pre_L.astype(int) != obs_flat.astype(int)).mean()\n", - "\n", - "\n", - "\n", - "# Strategy 4 \u2014 Pre-decoder + PyMatching (the full pipeline)\n", - "# The model first simplifies the syndrome pattern (reducing syndrome density),\n", - "# then PyMatching decodes the cleaner residual syndromes. 
The final logical\n", - "# correction is XOR of the two predictions (mod 2 addition of Pauli corrections).\n", - "\n", - "t0 = time.perf_counter()\n", - "pm_res = np.asarray(matcher.decode_batch(res_dets), dtype=np.uint8).reshape(-1, num_obs).flatten()\n", - "t_pm_res = time.perf_counter() - t0\n", - "combo_L = (pre_L.astype(int) + pm_res.astype(int)) % 2 # combine both corrections\n", - "combo_ler = (combo_L != obs_flat).mean()\n", - "t_combo = t_pd + t_pm_res # total wall time for the combined pipeline\n", - "\n", - "\n", - "\n", - "# \u2500\u2500 5. Results \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", - "print(f\" {'Strategy':<30} {'LER':>8} {'Time (ms)':>10} {'\u00b5s/shot':>8}\")\n", - "print(f\" {'-'*30} {'-'*8} {'-'*10} {'-'*8}\")\n", - "print(f\" {'Undecoded':<30} {undecoded_ler:>8.5f} {'\u2014':>10} {'\u2014':>8}\")\n", - "print(f\" {'PyMatching only':<30} {pm_ler:>8.5f} {t_pm*1e3:>9.1f}ms {t_pm*1e6/num_samples:>7.1f}\u00b5s\")\n", - "print(f\" {'Pre-decoder alone':<30} {pd_ler:>8.5f} {t_pd*1e3:>9.1f}ms {t_pd*1e6/num_samples:>7.1f}\u00b5s\")\n", - "print(f\" {'Pre-decoder + PyMatching':<30} {combo_ler:>8.5f} {t_combo*1e3:>9.1f}ms {t_combo*1e6/num_samples:>7.1f}\u00b5s\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Notice that the LER of the NVIDIA Ising pre-decoder + PyMatching combination is nearly identical to PyMatching alone in this case, but the overall and per-shot decoding time is much smaller for the pre-decoder + PyMatching approach.\n", - "\n", - "An important caveat: the workflow above shows decode time for a batch of 5,000 syndromes. 
If you run the same code with a batch size of 1, as you would in a real-time QEC application, the pre-decoder + PyMatching approach is much slower than PyMatching alone, due to the significant overhead of launching a GPU kernel. Batched decoding amortizes that overhead across all 5,000 shots at once.\n", - "\n", - "This behavior is expected at small scale and is consistent with the results from the [paper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes). Compare the d=9 point for the light-blue line (pre-decoder + PyMatching) and the grey line (PyMatching only).\n", - "\n", - "\n", - "\n", - "To reach a regime where the kernel-launch overhead is worth the tradeoff, we need a combination of:\n", - "\n", - "1. A higher-distance surface code\n", - "2. A more complex noise model (i.e., harder-to-decode syndrome patterns)\n", - "3. Optimizations for low-latency model inference\n", - "\n", - "In the next section, you will learn how to train your own model on higher-distance surface codes with much more complex noise patterns, using on-the-fly GPU-accelerated syndrome generation with cuStabilizer. After that, you will learn how to optimize model inference for the best possible real-time deployment performance on your specific GPU. Together, these sections will equip you to study regimes where the NVIDIA Ising pre-decoder + PyMatching approach significantly outperforms PyMatching alone." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "## Training your own Pre-decoder Model\n", - "\n", - "The pre-trained models ship with this repo, but there are several reasons you might want to train your own:\n", - "\n", - "- **Custom noise model**: The pre-trained models were trained on a specific noise profile. 
Real quantum hardware has different error characteristics and you may want to see how the model handles a custom noise profile.\n", - "- **Different code distances**: Pre-trained models are available for specific distances. Training lets you target any distance within the model's receptive field.\n", - "- **Different noise regimes**: Physical error rates vary by platform. A model trained at p \u2248 0.005 may not be optimal for p \u2248 0.001 or p \u2248 0.006.\n", - "- **Architecture exploration**: The 5 public models vary in depth and width. Training lets you experiment with the tradeoff between model size (latency) and correction quality (LER improvement).\n", - "\n", - "The following sections will walk you through the process for training your own model." - ] - }, - { - "cell_type": "markdown", - "id": "inserted", - "metadata": {}, - "source": [ - "### Model Variants\n", - "\n", - "Five model variants are available, trading off model size against correction quality. The **receptive field (RF)** determines the spatial and temporal window the model can see: a model with RF = 9 requires `distance \u2265 9` and `n_rounds \u2265 9`. Larger RF models can correct more complex, spatially spread errors but cost more compute per shot.\n", - "\n", - "Run the code below to see the model specs for the 5 included model architectures. The length of the `num_filters` list is the number of layers in the model. Generally, more filters allow the model to detect more complex error patterns. The kernel size prescribes the spatial and temporal window over which features are computed." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "inserted", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "print(f\"{'Model':<8} {'num_filters':<30} {'kernel_size':<24} {'RF size':<10} {'num_params':>12}\")\n", - "print(\"-\" * 86)\n", - "for mid in range(1, 6):\n", - " spec = get_model_spec(mid)\n", - " cfg = SimpleNamespace(\n", - " code=\"surface\", distance=9, n_rounds=9,\n", - " model=SimpleNamespace(\n", - " version=\"predecoder_memory_v1\",\n", - " num_filters=list(spec.num_filters),\n", - " kernel_size=list(spec.kernel_size),\n", - " dropout_p=0.0, activation=\"gelu\",\n", - " input_channels=4, out_channels=4,\n", - " ),\n", - " )\n", - " m = ModelFactory.create_model(cfg)\n", - " nparams = sum(p.numel() for p in m.parameters())\n", - " print(f\" {mid:<6} {str(spec.num_filters):<30} {str(spec.kernel_size):<24} {spec.receptive_field:<10} {nparams:>12,}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### On-the-Fly Training Data Generation\n", - "\n", - "A key benefit of the Ising pre-decoder training pipeline is on-the-fly generation of syndrome data during each training epoch. The simulation uses a **25-parameter circuit-level noise model** that independently controls:\n", - "\n", - "- **State preparation errors** (2 params): `p_prep_X`, `p_prep_Z`\n", - "- **Measurement errors** (2 params): `p_meas_X`, `p_meas_Z`\n", - "- **Idle errors during CNOT layers** (3 params): `p_idle_cnot_X/Y/Z`\n", - "- **Idle errors during state-prep/measurement windows** (3 params): `p_idle_spam_X/Y/Z`\n", - "- **Two-qubit CNOT gate errors** (15 params): all Pauli pairs `IX, IY, IZ, XI, XX, ... 
ZZ` (excluding `II`)\n", - "\n", - "This noise model is defined in `conf/config_public.yaml` and can be customized to match real hardware calibration data.\n", - "\n", - "For training, the simulation runs inside `QCDataGeneratorTorch`, which uses GPU-accelerated detector error model (DEM) sampling built on the [NVIDIA cuStabilizer library](https://docs.nvidia.com/cuda/cuquantum/latest/custabilizer/index.html) within the [NVIDIA cuQuantum library](https://docs.nvidia.com/cuda/cuquantum/latest/index.html) to generate fresh batches of `(trainX, trainY)` tensors on every training step.\n", - "\n", - "GPU-accelerated sampling requires the `cuquantum-python` package (installed automatically by the setup cell above, matched to your CUDA version). Training uses cuStabilizer for on-the-fly data generation and will fail if cuquantum is not installed. Inference uses Stim instead and does not require cuquantum.\n", - "\n", - "For large-scale training runs, cuStabilizer provides large speedups in data generation.\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "For example, in the figure above, cuStabilizer provides a 1060x speedup over Stim for sampling 1 million shots from a distance-30 surface code on an NVIDIA DGX B200 GPU.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training Configuration\n", - "\n", - "Training is controlled by a single user-facing file: `conf/config_public.yaml`. It exposes the QEC problem variables.\n", - "\n", - "Here are the main settings for your `config_public.yaml`:\n", - "\n", - "| Parameter | What it controls |\n", - "|---|---|\n", - "| `model_id` | Which model architecture to use (1\u20135). Each has a fixed **receptive field R** \u2014 the maximum distance/rounds it can handle. 
Models 1 & 2: R=9, Model 3: R=17, Models 4 & 5: R=13 |\n", - "| `distance` | Surface code distance for **evaluation** (Training always runs at distance=R) |\n", - "| `n_rounds` | Number of syndrome measurement rounds for **evaluation** |\n", - "| `workflow.task` | `train` to train a new model, `inference` to evaluate a saved model |\n", - "| `data.code_rotation` | Surface code orientation: `O1`, `O2`, `O3`, or `O4` (different logical operator placements) |\n", - "| `data.noise_model` | The 25-parameter circuit-level noise model: state prep errors (`p_prep_X/Z`), measurement errors (`p_meas_X/Z`), idle errors during CNOT layers (`p_idle_cnot_X/Y/Z`) and SPAM windows (`p_idle_spam_X/Y/Z`), and all 15 two-qubit CNOT Pauli error rates (`p_cnot_AB`) |\n", - "\n", - "\n", - "The cell below prints the full default `config_public.yaml` file.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# === conf/config_public.yaml (what users edit) ===\n", - "config_path = os.path.join(PREDECODER_ROOT, \"conf\", \"config_public.yaml\")\n", - "with open(config_path) as f:\n", - " print(f.read())\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`config_validator.py` holds the hidden training hyperparameters: optimizer settings, batch schedule, EMA, early stopping, and more. These are intentionally not in the public config to prevent misconfiguration.\n", - "\n", - "You do not need to touch this file for a standard run. However, some parameters can be overridden at runtime via **environment variables**, without editing any file:\n", - "\n", - "| Environment Variable | Config Parameter | Default | Description |\n", - "|---|---|---|---|\n", - "| `PREDECODER_TRAIN_SAMPLES` | `train.num_samples` | 67,108,864 | Total training shots generated per epoch (divided into batches based on batch size and GPU count). Default is the production base for 8 GPUs, auto-scaled proportionally to GPU count (e.g. 
8,388,608 on 1 GPU). Setting the env var bypasses this scaling |\n", - "| `PREDECODER_TRAIN_EPOCHS` | `train.epochs` | 100 | Total number of training epochs |\n", - "| `PREDECODER_VAL_SAMPLES` | `val.num_samples` | 65,536 | Total validation shots evaluated at the end of each epoch. Split evenly across GPUs |\n", - "| `PREDECODER_TEST_SAMPLES` | `test.num_samples` | 262,144 | Total shots for final LER/SDR evaluation after training. Split evenly across GPUs |\n", - "| `PREDECODER_LR_MILESTONES` | `lr_scheduler.milestones` | `0.25,0.5,1.0` | Comma-separated fractions of total training steps. At each milestone the learning rate is multiplied by gamma=0.7. With the defaults, LR drops at 25%, 50%, and 100% of training |\n", - "| `PREDECODER_SKIP_NOISE_UPSCALING` | *(env var only)* | `0` | When `0` (default), training noise is automatically scaled up toward the surface code threshold (~0.006) to improve generalization. Set `1` to train with your exact noise model values |\n", - "| `PREDECODER_TORCH_COMPILE` | `torch_compile` | `1` | Enables `torch.compile` for the model. Set `0` to disable (useful if training segfaults on startup) |\n", - "| `PREDECODER_TORCH_COMPILE_MODE` | `torch_compile_mode` | `default` | `torch.compile` compilation mode (its `mode` argument). 
Options: `default`, `reduce-overhead`, `max-autotune` |\n", - "| `PREDECODER_BASE_OUTPUT_DIR` | *(env var only)* | `outputs/` | Root directory for all run output: checkpoints, TensorBoard logs, Hydra config snapshots |\n", - "\n", - "Parameters **not** accessible via environment variable (such as optimizer internals like `weight_decay` and `beta2`, EMA decay, batch schedule shape, and model architecture) can only be changed by editing `apply_public_defaults_and_model()` in `config_validator.py` directly.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Running Training from the CLI\n", - "\n", - "All training is launched through a single script that reads `conf/config_public.yaml` and applies the hidden defaults from `config_validator.py` automatically.\n", - "\n", - "#### Default training run\n", - "\n", - "All CLI commands below assume you are in the repository root directory (`Ising-Decoding/`).\n", - "\n", - "```bash\n", - "bash code/scripts/local_run.sh\n", - "```\n", - "\n", - "This trains Model 1 (R=9) at d=9, n_rounds=9 using the default noise model (p\u22480.003), with all hidden defaults: Lion optimizer, LR warmup/decay, EMA, and noise upscaling. 
Output is written to `outputs/test1/`.\n", - "\n", - "---\n", - "\n", - "#### Full reference: everything you can change at the command line\n", - "\n", - "**Shell variables** \u2014 control the run itself (prepend before `bash code/scripts/local_run.sh`):\n", - "\n", - "| Variable | What it does | Example |\n", - "|---|---|---|\n", - "| `EXPERIMENT_NAME` | Names the output folder under `outputs/` \u2014 change this when you change config so runs don't mix | `EXPERIMENT_NAME=my_d9_run` |\n", - "| `WORKFLOW` | `train` to train, `inference` to evaluate a saved model | `WORKFLOW=inference` |\n", - "| `GPUS` | Number of GPUs (uses PyTorch DDP; each GPU gets its own data generator) | `GPUS=4` |\n", - "| `CUDA_VISIBLE_DEVICES` | Which physical GPUs to use | `CUDA_VISIBLE_DEVICES=0,1` |\n", - "| `FRESH_START` | Set `1` to ignore any existing checkpoint and restart from scratch | `FRESH_START=1` |\n", - "| `CONFIG_NAME` | Which config file to load from `conf/` (omit `.yaml`) | `CONFIG_NAME=config_local_test` |\n", - "\n", - "**Environment variables** \u2014 override `config_validator.py` hidden defaults at runtime (see the table in the previous section for defaults and detailed descriptions). 
Prepend these before `bash code/scripts/local_run.sh`.\n", - "\n", - "**Hydra overrides** \u2014 override any `config_public.yaml` value inline without editing the file (appended after `local_run.sh`):\n", - "\n", - "| Override | What it does | Example |\n", - "|---|---|---|\n", - "| `model_id=N` | Select model architecture 1\u20135 | `model_id=4` |\n", - "| `distance=N` | Evaluation distance (training always uses R) | `distance=13` |\n", - "| `n_rounds=N` | Evaluation rounds | `n_rounds=13` |\n", - "| `data.code_rotation=ON` | Code orientation O1\u2013O4 | `data.code_rotation=O2` |\n", - "| `optimizer.lr=X` | Override the model-default learning rate | `optimizer.lr=0.0002` |\n", - "| `data.skip_noise_upscaling=true` | Disable noise auto-scaling inline | `data.skip_noise_upscaling=true` |\n", - "| `data.noise_model.p_prep_X=X` | Override any individual noise parameter | `data.noise_model.p_prep_X=0.004` |\n", - "\n", - "---\n", - "\n", - "Below are a few example runs and the variables to set for each.\n", - "\n", - "#### Example 1: quick smoke test on 1 GPU\n", - "\n", - "```bash\n", - "# Separate output folder, small dataset, 3 epochs, fewer eval shots\n", - "EXPERIMENT_NAME=smoke_test \\\n", - "FRESH_START=1 \\\n", - "PREDECODER_TRAIN_SAMPLES=65536 \\\n", - "PREDECODER_TRAIN_EPOCHS=3 \\\n", - "PREDECODER_TEST_SAMPLES=8192 \\\n", - "GPUS=1 \\\n", - "bash code/scripts/local_run.sh\n", - "```\n", - "\n", - "#### Example 2: full d=13 run on 8 GPUs with hardware noise\n", - "\n", - "```bash\n", - "# 8 GPUs, Model 4 (R=13), hardware-calibrated noise (no auto-scaling)\n", - "EXPERIMENT_NAME=d13_hardware_noise \\\n", - "GPUS=8 \\\n", - "PREDECODER_SKIP_NOISE_UPSCALING=1 \\\n", - "bash code/scripts/local_run.sh \\\n", - " model_id=4 \\\n", - " distance=13 \\\n", - " n_rounds=13 \\\n", - " data.code_rotation=O1 \\\n", - " data.noise_model.p_prep_X=0.001 \\\n", - " data.noise_model.p_prep_Z=0.001 \\\n", - " data.noise_model.p_meas_X=0.003 \\\n", - " 
data.noise_model.p_meas_Z=0.003\n", - "```\n", - "\n", - "#### Example 3: inference on a saved model\n", - "\n", - "```bash\n", - "# Must use the same EXPERIMENT_NAME as the training run\n", - "EXPERIMENT_NAME=d13_hardware_noise \\\n", - "WORKFLOW=inference \\\n", - "bash code/scripts/local_run.sh \\\n", - " model_id=4 \\\n", - " distance=13 \\\n", - " n_rounds=13\n", - "```\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "## Optimizing Model Inference\n", - "\n", - "When using the predecoder in a realtime setting, it is crucial to minimize model inference time so that the decoder has sufficient throughput and the reaction time is as short as possible. Three optimizations on the model side can improve the predecoder's inference performance. Each is walked through below.\n", - "\n", - "1. **ONNX Export**: convert the PyTorch pipeline to a portable, framework-independent graph\n", - "2. **TensorRT Engine (FP16)**: compile the ONNX graph into a GPU-optimized engine with FP16 precision\n", - "3. **FP8 Quantization + TensorRT**: further reduce precision to FP8 for maximum throughput on supported GPUs\n", - "\n", - "The figure below demonstrates why such optimizations are so important for realtime deployment. It shows the speedups achieved for model 1 at three different code distances, relative to an FP16 PyTorch run, after optimization with TensorRT (FP16) and then quantization with TensorRT (FP8) on an NVIDIA GB300.\n", - "\n", - "\n", - "\n", - "In this case, speedups of around 10x are possible for a batch size of 1. This is a representative example, not a formal benchmark: model size, code distance, GPU selection, and initial PyTorch optimizations all affect the expected speedups. The main point is that these optimizations are easy to perform and always worth doing before your final inference runs. 
In general, quantization is worth the minimal cost in LER, but you will need to run benchmarks to assess the proper tradeoffs for your experiment." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Step 1: ONNX Export\n", - "\n", - "[ONNX](https://onnx.ai/) (Open Neural Network Exchange) is a standard format for representing ML models that allows a model trained with a specific framework (e.g. PyTorch) to be converted to a framework-agnostic graph which can be optimized to improve inference performance. Exporting to ONNX gives us:\n", - "\n", - "- **Portability**: the model can be loaded by any ONNX-compatible runtime (TensorRT, ONNX Runtime, etc.)\n", - "- **Graph optimization**: constant folding and operator fusion happen at export time\n", - "- **Full pipeline in one graph**: we export the entire `PreDecoderMemoryEvalModule`, not just the CNN. This means detector preprocessing, the model forward pass, and residual syndrome computation are all fused into a single ONNX graph.\n", - "\n", - "\n", - "The input and output signatures of the ONNX graph are:\n", - "- **Input**: `dets` \u2014 `(B, num_dets)` uint8\n", - "- **Output**: `L_and_residual_dets` \u2014 `(B, 1 + num_residual_dets)` uint8\n", - "\n", - "B is the batch size and is dynamic, meaning the same ONNX graph can be used to process any batch size without regeneration.\n", - "\n", - "The code below converts our Ising model into an ONNX file. \n", - "\n", - "**Note:** If you are running CUDA 13 you will see a warning from onnxruntime-gpu about missing CUDA 13 support. You can safely ignore this warning. ONNX Runtime is used here only as a quick sanity check that the exported graph runs correctly \u2014 it automatically falls back to CPU execution. 
The actual GPU-accelerated inference in this tutorial is handled by TensorRT, which has full CUDA 13 support.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import logging, warnings\n", - "for _name in [\"onnxscript\", \"onnx_ir\", \"torch.onnx\", \"torch\"]:\n", - " logging.getLogger(_name).setLevel(logging.ERROR)\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "# Export the full pipeline to ONNX\n", - "onnx_path = \"predecoder_memory_d9_T9_X.onnx\"\n", - "\n", - "# half = number of X-type (or Z-type) detectors per round for a distance-d surface code\n", - "half = (distance**2 - 1) // 2\n", - "\n", - "# Create a small dummy detector batch \u2014 torch.onnx.export traces the model\n", - "# using this example input to record the full computation graph\n", - "\n", - "example_dets = torch.randint(0, 2, (2, 2 * n_rounds * half), dtype=torch.uint8, device=device)\n", - "print(f\"Input shape: {example_dets.shape}\")\n", - "\n", - "try:\n", - " import onnx\n", - "\n", - " # Trace the pipeline and serialize the computation graph to an ONNX file.\n", - " # dynamic_axes lets the batch dimension vary at runtime (single shot or batched inference).\n", - " # do_constant_folding=True pre-computes constant sub-expressions at export time.\n", - " \n", - " torch.onnx.export(\n", - " pipeline,\n", - " example_dets,\n", - " onnx_path,\n", - " opset_version=18,\n", - " input_names=[\"dets\"],\n", - " output_names=[\"L_and_residual_dets\"],\n", - " dynamic_axes={\n", - " \"dets\": {0: \"batch\"},\n", - " \"L_and_residual_dets\": {0: \"batch\"},\n", - " },\n", - " do_constant_folding=True,\n", - " )\n", - "\n", - " # Verify the exported graph is valid\n", - " model_onnx = onnx.load(onnx_path)\n", - " onnx.checker.check_model(model_onnx)\n", - " onnx_size_mb = os.path.getsize(onnx_path) / (1024 ** 2)\n", - " print(f\"Exported and verified: {onnx_path} ({onnx_size_mb:.1f} MB)\")\n", - "\n", - " # Quick 
inference check with ONNX Runtime\n", - " # Tries CUDA first, falls back to CPU \u2014 confirms the graph runs end-to-end\n", - " try:\n", - " import onnxruntime as ort\n", - " providers = [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n", - " sess = ort.InferenceSession(onnx_path, providers=providers)\n", - " test_input = example_dets.cpu().numpy()\n", - " ort_out = sess.run(None, {\"dets\": test_input})\n", - " print(f\"ONNX Runtime inference OK \u2014 output shape: {ort_out[0].shape}\")\n", - " active_provider = sess.get_providers()[0]\n", - " print(f\" Running on: {active_provider}\")\n", - " except ImportError:\n", - " print(\"onnxruntime not installed \u2014 skipping runtime check.\")\n", - " print(\"Install with: pip install onnxruntime\")\n", - "\n", - "except ImportError:\n", - " print(\"onnx not installed \u2014 skipping export.\")\n", - " print(\"Install with: pip install onnx\")\n", - " onnx_path = None\n", - "except Exception as e:\n", - " print(f\"ONNX export failed: {e}\")\n", - " onnx_path = None\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Step 2: TensorRT Engine (FP16)\n", - "\n", - "[TensorRT](https://developer.nvidia.com/tensorrt) is NVIDIA's inference compiler which takes an ONNX graph and compiles it into a GPU-architecture-specific engine that accounts for your target data type and batch size range in order to get the most performance out of your GPU. \n", - "\n", - "Under the hood, TensorRT benchmarks dozens of candidate GPU kernel implementations for each operation in the graph and selects the fastest one for your exact hardware. It then fuses adjacent operations (convolution, normalization, activation) into single GPU kernels, reorganizes tensor memory layouts to maximize cache efficiency, and eliminates the Python interpreter overhead that PyTorch incurs on every call. 
The result is a binary engine tuned to extract maximum throughput from your specific GPU \u2014 the same model weights, the same outputs, but with latency approaching the hardware limit.\n", - "\n", - "The build process:\n", - "1. Parse the ONNX model into a TensorRT network\n", - "2. Set an **optimization profile** defining the min/optimal/max batch sizes\n", - "3. Enable the FP16 flag (tells TensorRT it can use half-precision where beneficial)\n", - "4. Build and serialize the engine to a `.engine` file" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Build a TensorRT engine from the ONNX model (FP16)\n", - "engine_path = \"predecoder_memory_d9_T9_X_fp16.engine\"\n", - "batch_size = 1024\n", - "in_cols = 2 * n_rounds * half\n", - "\n", - "if onnx_path is None or not os.path.exists(onnx_path):\n", - " print(\"ONNX model not available \u2014 run the ONNX export cell first.\")\n", - "else:\n", - " try:\n", - " import tensorrt as trt\n", - " print(f\"TensorRT version: {trt.__version__}\")\n", - "\n", - " # Create TensorRT logger and builder \u2014 entry points for all TRT operations.\n", - " # Store as trt_logger so the FP8 cell can reuse it (avoids a TRT duplicate-logger warning).\n", - " trt_logger = trt.Logger(trt.Logger.WARNING)\n", - " builder = trt.Builder(trt_logger)\n", - "\n", - " # TRT 10+ removed the EXPLICIT_BATCH flag (always on); TRT 8 needs it.\n", - " trt_major = int(trt.__version__.split(\".\")[0])\n", - " if trt_major >= 10:\n", - " network = builder.create_network()\n", - " else:\n", - " network = builder.create_network(\n", - " 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n", - " )\n", - "\n", - " # Parse the ONNX graph into the TensorRT network representation\n", - " parser = trt.OnnxParser(network, trt_logger)\n", - " with open(onnx_path, \"rb\") as f:\n", - " if not parser.parse(f.read()):\n", - " for err_i in range(parser.num_errors):\n", - " print(f\" Parse 
error: {parser.get_error(err_i)}\")\n", - " raise RuntimeError(\"ONNX parse failed\")\n", - "\n", - " # Enable FP16 mode \u2014 TensorRT will use half-precision kernels wherever beneficial\n", - " config = builder.create_builder_config()\n", - " config.set_flag(trt.BuilderFlag.FP16)\n", - "\n", - " # Define the supported batch size range: (min, optimal, max)\n", - " # TensorRT selects the best kernels for the optimal size during compilation\n", - " profile = builder.create_optimization_profile()\n", - " profile.set_shape(\"dets\", (1, in_cols), (batch_size, in_cols), (batch_size, in_cols))\n", - " config.add_optimization_profile(profile)\n", - "\n", - " # Compile the network into a hardware-specific binary engine and save it\n", - " print(f\"Building TensorRT FP16 engine (batch_size={batch_size}) \u2014 this may take a minute...\")\n", - " serialized = builder.build_serialized_network(network, config)\n", - " # build_serialized_network returns None on failure; fail before opening the output file\n", - " if serialized is None:\n", - " raise RuntimeError(\"TensorRT engine build returned None; see the TensorRT log above\")\n", - " with open(engine_path, \"wb\") as f:\n", - " f.write(serialized)\n", - "\n", - " engine_size_mb = os.path.getsize(engine_path) / (1024 ** 2)\n", - " print(f\"Built: {engine_path} ({engine_size_mb:.1f} MB)\")\n", - "\n", - " except ImportError:\n", - " print(\"TensorRT not installed \u2014 skipping engine build.\")\n", - " print(\"Install with: pip install tensorrt\")\n", - " except Exception as e:\n", - " print(f\"Engine build failed: {e}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Step 3: FP8 Quantization + TensorRT\n", - "\n", - "AI models can be quantized such that their weights are converted to a lower precision. Often, this provides additional speedup with minimal loss in accuracy. The Ising models are FP16 by default, but they can be quantized to FP8 and compiled with TensorRT for further speedups on any GPU with FP8 tensor cores.\n", - "\n", - "The quantization process is:\n", - "1. **Collect calibration data**: representative detector patterns from the test distribution\n", - "2. 
**Quantize the ONNX model**: `nvidia-modelopt` analyzes the calibration data to determine optimal quantization ranges, then inserts QuantizeLinear/DequantizeLinear (QDQ) nodes into the ONNX graph\n", - "3. **Build a TensorRT engine**: TensorRT recognizes the QDQ nodes and fuses them into native FP8 operations\n", - "\n", - "Only the convolution layers are quantized to FP8; the remaining operations (preprocessing, residual computation) stay in FP16 for accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import contextlib, io\n", - "\n", - "# FP8 quantization of the ONNX model + TensorRT engine build\n", - "fp8_onnx_path = \"predecoder_memory_d9_T9_X_fp8.onnx\"\n", - "fp8_engine_path = \"predecoder_memory_d9_T9_X_fp8.engine\"\n", - "\n", - "if onnx_path is None or not os.path.exists(onnx_path):\n", - " print(\"ONNX model not available \u2014 run the ONNX export cell first.\")\n", - "else:\n", - " calib_dets = detectors[:256]\n", - " print(f\"Calibration data: {calib_dets.shape} ({calib_dets.shape[0]} samples, dtype: {calib_dets.dtype})\")\n", - "\n", - " try:\n", - " import modelopt.onnx.quantization as mq\n", - " import tensorrt as trt\n", - " print(f\"TensorRT version: {trt.__version__}\")\n", - "\n", - " # Quantize convolution layers to FP8 using calibration data to determine\n", - " # optimal scale factors; non-conv ops stay at FP16 for numerical accuracy\n", - " print(f\"Quantizing {onnx_path} to FP8 (this may take a moment)...\")\n", - " _prev_level = logging.root.manager.disable\n", - " logging.disable(logging.CRITICAL)\n", - " _devnull_fd = os.open(os.devnull, os.O_WRONLY)\n", - " _old_stdout_fd = os.dup(1)\n", - " _old_stderr_fd = os.dup(2)\n", - " _devnull_py = open(os.devnull, \"w\")\n", - " _old_sys_stdout = sys.stdout\n", - " _old_sys_stderr = sys.stderr\n", - " os.dup2(_devnull_fd, 1)\n", - " os.dup2(_devnull_fd, 2)\n", - " sys.stdout = _devnull_py\n", - " sys.stderr = 
_devnull_py\n", - " try:\n", - " mq.quantize(\n", - " onnx_path=onnx_path,\n", - " quantize_mode=\"fp8\",\n", - " calibration_data={\"dets\": calib_dets},\n", - " output_path=fp8_onnx_path,\n", - " op_types_to_quantize=[\"Conv\"],\n", - " high_precision_dtype=\"fp16\",\n", - " )\n", - " finally:\n", - " sys.stdout = _old_sys_stdout\n", - " sys.stderr = _old_sys_stderr\n", - " os.dup2(_old_stdout_fd, 1)\n", - " os.dup2(_old_stderr_fd, 2)\n", - " os.close(_devnull_fd)\n", - " os.close(_old_stdout_fd)\n", - " os.close(_old_stderr_fd)\n", - " _devnull_py.close()\n", - " logging.disable(_prev_level)\n", - " fp8_size_mb = os.path.getsize(fp8_onnx_path) / (1024 ** 2)\n", - " print(f\" Quantized: {fp8_onnx_path} ({fp8_size_mb:.1f} MB)\")\n", - "\n", - " if not gpu_supports_fp8:\n", - " print(f\"\\n Skipping FP8 engine build: {gpu_name} (compute capability {gpu_cc}) \"\n", - " f\"does not have FP8 tensor cores.\")\n", - " print(f\" FP8 requires Ada Lovelace (RTX 4000+, cc 8.9) or Hopper (H100+, cc 9.0).\")\n", - " print(f\" The quantized ONNX model ({fp8_onnx_path}) can be built on supported hardware.\")\n", - " else:\n", - " _trt_logger = trt_logger if 'trt_logger' in dir() else trt.Logger(trt.Logger.WARNING)\n", - " builder = trt.Builder(_trt_logger)\n", - "\n", - " trt_major = int(trt.__version__.split(\".\")[0])\n", - " if trt_major >= 10:\n", - " network = builder.create_network(\n", - " 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)\n", - " )\n", - " else:\n", - " network = builder.create_network(\n", - " 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n", - " | 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)\n", - " )\n", - "\n", - " parser = trt.OnnxParser(network, _trt_logger)\n", - " with open(fp8_onnx_path, \"rb\") as f:\n", - " if not parser.parse(f.read()):\n", - " raise RuntimeError(\"ONNX parse failed for FP8 model\")\n", - "\n", - " config = builder.create_builder_config()\n", - " profile = 
builder.create_optimization_profile()\n", - " profile.set_shape(\"dets\", (1, in_cols), (batch_size, in_cols), (batch_size, in_cols))\n", - " config.add_optimization_profile(profile)\n", - "\n", - " print(f\"Building TensorRT FP8 engine (batch_size={batch_size}) \u2014 this may take a minute...\")\n", - " serialized = builder.build_serialized_network(network, config)\n", - " with open(fp8_engine_path, \"wb\") as f:\n", - " f.write(serialized)\n", - "\n", - " engine_size_mb = os.path.getsize(fp8_engine_path) / (1024 ** 2)\n", - " print(f\"Built: {fp8_engine_path} ({engine_size_mb:.1f} MB)\")\n", - "\n", - " except ImportError as e:\n", - " print(f\"Required package not available: {e}\")\n", - " print(\"Install with: pip install nvidia-modelopt[onnx] tensorrt\")\n", - " except Exception as e:\n", - " print(f\"FP8 pipeline failed: {e}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Running Optimized Inference from CLI\n", - "\n", - "The full optimization pipeline can also be driven via environment variables from the command line (run from the repository root):\n", - "\n", - "```bash\n", - "# PyTorch only (default):\n", - "WORKFLOW=inference bash code/scripts/local_run.sh\n", - "\n", - "# ONNX export + TensorRT engine build (FP16):\n", - "ONNX_WORKFLOW=2 WORKFLOW=inference bash code/scripts/local_run.sh\n", - "\n", - "# ONNX + FP8 quantization + TensorRT:\n", - "ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh\n", - "\n", - "# Load a pre-built engine file:\n", - "ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh\n", - "```\n", - "\n", - "### Generating data for CUDA-Q QEC realtime predecoder test application\n", - "\n", - "When evaluating the neural pre-decoder in an end-to-end downstream system like\n", - "CUDA-Q Realtime, you will need a test harness with valid inputs—both the\n", - "exported neural network model and the corresponding syndrome data.\n", - "\n", - "The utility 
script `code/export/generate_test_data.py` is provided to generate\n", - "this exact data (both an `.onnx` file and several `.bin` files) so you can\n", - "easily consume it in the CUDA-Q QEC realtime AI decoder.\n", - "\n", - "> **Important:** The `--distance` and `--n-rounds` arguments provided to this\n", - "script **must match** the values used in the preceding section when running the\n", - "ONNX export (e.g. `ONNX_WORKFLOW=2`).\n", - "\n", - "For a detailed walkthrough on how to ingest these files into the CUDA-Q Realtime\n", - "C++ pipeline, see the downstream documentation here: [Realtime AI Predecoder\n", - "Pipeline](https://nvidia.github.io/cudaqx/examples_rst/qec/realtime_predecoder_pymatching.html).\n", - "\n", - "```text\n", - "python3 code/export/generate_test_data.py --distance 13 --n-rounds 104 --num-samples 10000 --basis X --p-error=0.003 --simple-noise\n", - "```\n", - "\n", - "**Example output:**\n", - "\n", - "```text\n", - "Building circuit: D=13, T=104, basis=X, rotation=XV, p=0.003\n", - " Circuit built in 0.007s\n", - "Building detector error model and PyMatching matcher...\n", - " DEM + matcher built in 0.083s\n", - " Detectors: 17472, Observables: 1\n", - "Extracting check matrices (beliefmatching)...\n", - " H shape: (17472, 93864), O shape: (1, 93864), priors shape: (93864,)\n", - "Sampling 10000 shots...\n", - " Sampled in 1.006s\n", - "Decoding with PyMatching (baseline)...\n", - " Errors: 30/10000, LER: 0.0030\n", - " Decode time: 5.439s (543.9 µs/shot)\n", - "Writing outputs to test_data/d13_T104_X/\n", - "Done.\n", - " H_csr.bin 808,944 bytes\n", - " O_csr.bin 2,932 bytes\n", - " detectors.bin 698,880,008 bytes\n", - " metadata.txt 162 bytes\n", - " observables.bin 40,008 bytes\n", - " priors.bin 750,916 bytes\n", - " pymatching_predictions.bin 40,008 bytes\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "## Learn More\n", - "\n", - "After completing this tutorial, you should have a 
foundational understanding of how the NVIDIA Ising pre-decoder works and how to get started training your own models. For more information on the code, explore the [GitHub repo](https://github.com/NVIDIA/Ising-Decoding). For more details on the model itself and its performance, read the [NVIDIA Ising pre-decoder whitepaper](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes).\n", - "\n", - "Visit the NVIDIA Ising webpage to learn more about other models in the [NVIDIA Ising family](https://developer.nvidia.com/ising) of open quantum models." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "ising2", - "language": "python", - "name": "ising2" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 }